In [2]:
import importlib
import os
import pickle
import sys
from itertools import zip_longest

import torch
from tqdm.notebook import tqdm

sys.path.insert(0, '..')
sys.path.insert(0, '../..')
sys.path.insert(0, '../../..')
sys.path.insert(0, '../../../..')
sys.path.insert(0, '../../../../..')

from model.dropout_uncertainty_enc_dec_LSTM.dropout_uncertainty_model import DropoutUncertaintyEncoderDecoderLSTM
from evaluation.probabilistic_evaluation import ProbabilisticEvaluation


In [5]:
# --- Model -------------------------------------------------------------------
# Alternative checkpoint (not used in this run):
#   '../../../training_variational_dropout/sepsis/Sepsis_full_no_grad_norm_robustness.pkl'
file_path_model = '../../../training_variational_dropout/Sepsis/Sepsis_setting_2.pkl'
output_dir = '../../../../../evaluation_results/robustness/sepsis/redo_activity/'
model = DropoutUncertaintyEncoderDecoderLSTM.load(file_path_model, dropout=0.1)

# --- Datasets ----------------------------------------------------------------
# NOTE(review): the original and perturbed runs read the same test file; the perturbation
# appears to enter only through the predefined-prefix files below — confirm this is intended.
file_path_original = '../../../../../encoded_data/sepsis/Sepsis_all_5_test.pkl'
file_path_perturbed = '../../../../../encoded_data/sepsis/Sepsis_all_5_test.pkl'
file_path_redo_activity = '../../../../../encoded_data/sepsis/val.pkl'
file_path_redo_activity_pert = '../../../../../encoded_data/sepsis/redo_activity.pkl'


def _load_trusted(path):
    """Fully unpickle an object saved with torch.save.

    weights_only=False lets unpickling execute arbitrary code, so only call this on
    files produced by this project, never on files from untrusted sources.
    """
    return torch.load(path, weights_only=False)


# Each file is loaded separately so every evaluator gets its own independent object.
original_dataset = _load_trusted(file_path_original)
perturbed_dataset = _load_trusted(file_path_perturbed)
redo_activity_dataset = _load_trusted(file_path_redo_activity)
redo_activity_pert_dataset = _load_trusted(file_path_redo_activity_pert)

print(f"Original dataset loaded: {len(original_dataset)} cases")
print(f"Perturbed dataset loaded: {len(perturbed_dataset)} cases")


Data set categories:  ([('concept:name', 18, {'Admission IC': 1, 'Admission NC': 2, 'CRP': 3, 'EOS': 4, 'ER Registration': 5, 'ER Sepsis Triage': 6, 'ER Triage': 7, 'IV Antibiotics': 8, 'IV Liquid': 9, 'LacticAcid': 10, 'Leucocytes': 11, 'Release A': 12, 'Release B': 13, 'Release C': 14, 'Release D': 15, 'Release E': 16, 'Return ER': 17}), ('InfectionSuspected', 5, {'EOS': 1, 'False': 2, 'True': 3, nan: 4}), ('org:group', 27, {'?': 1, 'A': 2, 'B': 3, 'C': 4, 'D': 5, 'E': 6, 'EOS': 7, 'F': 8, 'G': 9, 'H': 10, 'I': 11, 'J': 12, 'K': 13, 'L': 14, 'M': 15, 'N': 16, 'O': 17, 'P': 18, 'Q': 19, 'R': 20, 'S': 21, 'T': 22, 'U': 23, 'V': 24, 'W': 25, 'Y': 26}), ('DiagnosticBlood', 5, {'EOS': 1, 'False': 2, 'True': 3, nan: 4}), ('DisfuncOrg', 5, {'EOS': 1, 'False': 2, 'True': 3, nan: 4}), ('SIRSCritTachypnea', 5, {'EOS': 1, 'False': 2, 'True': 3, nan: 4}), ('Hypotensie', 5, {'EOS': 1, 'False': 2, 'True': 3, nan: 4}), ('SIRSCritHeartRate', 5, {'EOS': 1, 'False': 2, 'True': 3, nan: 4}), ('Infusion'

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


Original dataset loaded: 3237 cases
Perturbed dataset loaded: 3237 cases


In [9]:
# Create evaluation instances (NON-RANDOM ORDER).
# Both evaluators are built from one shared configuration so that the original and
# perturbed runs can only differ in their dataset and predefined prefixes.


def _probabilistic_evaluation_kwargs():
    """Return the keyword arguments shared by both ProbabilisticEvaluation instances.

    A new dict (with new lists) is built on every call so the two evaluators never
    share mutable arguments.
    """
    return dict(
        concept_name='concept:name',  # use e.g. 'Activity' for logs that name the activity column so
        num_processes=16,
        growing_num_values=['case_elapsed_time'],
        samples_per_case=100,
        sample_argmax=False,
        use_variance_cat=True,
        use_variance_num=True,
        # NOTE(review): confirm 'lifecycle:transition' is one of this encoding's categorical attributes.
        all_cat=['concept:name', 'org:group', 'lifecycle:transition'],
        all_num=['case_elapsed_time', 'event_elapsed_time'],
    )


def make_probabilistic_evaluation(dataset, predefined_prefixes):
    """Build a ProbabilisticEvaluation of the loaded `model` on `dataset`.

    `predefined_prefixes` is forwarded as `dataset_predefined_prefixes`; all other
    settings come from `_probabilistic_evaluation_kwargs()`.
    """
    return ProbabilisticEvaluation(
        model, dataset,
        dataset_predefined_prefixes=predefined_prefixes,
        **_probabilistic_evaluation_kwargs(),
    )


eval_original = make_probabilistic_evaluation(original_dataset, redo_activity_dataset)
eval_perturbed = make_probabilistic_evaluation(perturbed_dataset, redo_activity_pert_dataset)

print("ProbabilisticEvaluation instances created")


ProbabilisticEvaluation instances created


In [10]:
# Import robustness metrics module
import robustness.evaluator.robustness_metrics

# Set to True while editing robustness_metrics.py to pick up changes without restarting
# the kernel. The reload must run before the `from ... import` below so that
# `save_chunk` is rebound to the freshly reloaded function.
RELOAD_ROBUSTNESS_METRICS = False
if RELOAD_ROBUSTNESS_METRICS:
    importlib.reload(robustness.evaluator.robustness_metrics)
from robustness.evaluator.robustness_metrics import save_chunk

print("Robustness metrics module imported")

# Helper functions for filtering predictions and calculating remaining time
def filter_prediction_events(prediction_list, concept_name='concept:name'):
    """Reduce predicted events to their activity and case elapsed time.

    Args:
        prediction_list: iterable of events, or None.
        concept_name: key of the activity attribute to keep.

    Returns:
        None if `prediction_list` is None. Otherwise returns a list with one dict per
        dict-typed event, holding only `concept_name` and 'case_elapsed_time' (whichever
        are present, in that key order). Non-dict events are dropped.
    """
    if prediction_list is None:
        return None
    kept_keys = (concept_name, 'case_elapsed_time')
    return [
        {key: event[key] for key in kept_keys if key in event}
        for event in prediction_list
        if isinstance(event, dict)
    ]

def calculate_remaining_time(prefix, prediction, concept_name='concept:name'):
    """Remaining time implied by a prediction.

    Returns the 'case_elapsed_time' of the last predicted event minus that of the last
    prefix event. Returns None when either sequence is empty/None or when either last
    event is not a dict containing 'case_elapsed_time'. `concept_name` is accepted for
    signature symmetry with the other helpers but is not used.
    """
    if not (prefix and prediction):
        return None
    start_event = prefix[-1]
    if not (isinstance(start_event, dict) and 'case_elapsed_time' in start_event):
        return None
    end_event = prediction[-1]
    if not (isinstance(end_event, dict) and 'case_elapsed_time' in end_event):
        return None
    return end_event['case_elapsed_time'] - start_event['case_elapsed_time']

def calculate_sampled_remaining_times(prefix, predicted_suffixes, concept_name='concept:name'):
    """Remaining time implied by each sampled suffix.

    Returns None when `prefix` or `predicted_suffixes` is empty/None, or when the last
    prefix event is not a dict with 'case_elapsed_time'. Otherwise returns one entry per
    sample: the sample's last 'case_elapsed_time' minus the prefix's last one. An entry is
    None where the sample is empty or its last event lacks that key. `concept_name` is
    unused.
    """
    if not (prefix and predicted_suffixes):
        return None
    anchor_event = prefix[-1]
    if not (isinstance(anchor_event, dict) and 'case_elapsed_time' in anchor_event):
        return None
    start_time = anchor_event['case_elapsed_time']

    def _remaining(sample):
        # Mirror the prefix check for each sample's final event.
        if sample and isinstance(sample[-1], dict) and 'case_elapsed_time' in sample[-1]:
            return sample[-1]['case_elapsed_time'] - start_time
        return None

    return [_remaining(sample) for sample in predicted_suffixes]


Robustness metrics module imported


In [None]:
# Main evaluation loop: run the original and perturbed evaluators in lockstep over the
# same (case, prefix length) pairs and persist the paired predictions in chunks.
os.makedirs(output_dir, exist_ok=True)

save_every = 50
results = {}
concept_name = 'concept:name'  # Match the concept_name used in ProbabilisticEvaluation

# Sentinel filled in by zip_longest when one evaluation stream ends before the other.
_STREAM_EXHAUSTED = object()


def _summarize_side(prefix, suffix, predicted_suffixes, mean_pred, concept_name):
    """Build the stored 6-tuple for one side (original or perturbed) of a comparison.

    Returns:
        (prefix,                        # unchanged, all fields kept
         suffix,                        # unchanged, all fields kept
         mean_pred_filtered,            # only concept_name and 'case_elapsed_time'
         predicted_suffixes_filtered,   # same filtering per sample, or None if no samples
         mean_pred_remaining_time,      # see calculate_remaining_time
         sampled_remaining_times)       # see calculate_sampled_remaining_times
    """
    mean_pred_filtered = filter_prediction_events(mean_pred, concept_name=concept_name)
    predicted_suffixes_filtered = (
        [filter_prediction_events(sample, concept_name=concept_name) for sample in predicted_suffixes]
        if predicted_suffixes else None
    )
    # Remaining times are computed from the unfiltered predictions.
    mean_pred_remaining_time = calculate_remaining_time(prefix, mean_pred, concept_name=concept_name)
    sampled_remaining_times = calculate_sampled_remaining_times(prefix, predicted_suffixes, concept_name=concept_name)
    return (
        prefix,
        suffix,
        mean_pred_filtered,
        predicted_suffixes_filtered,
        mean_pred_remaining_time,
        sampled_remaining_times,
    )


# zip() would silently drop the tail of the longer stream. zip_longest plus a sentinel
# lets us fail loudly if the two evaluations yield different numbers of items.
paired_streams = zip_longest(
    eval_original.evaluate_with_predifined_prefix(random_order=False),
    eval_perturbed.evaluate_with_predifined_prefix(random_order=False),
    fillvalue=_STREAM_EXHAUSTED,
)

for i, (item_orig, item_pert) in enumerate(tqdm(paired_streams, desc="Evaluating robustness")):
    assert item_orig is not _STREAM_EXHAUSTED and item_pert is not _STREAM_EXHAUSTED, \
        f"Original and perturbed evaluations yielded different numbers of items (mismatch at index {i})"

    case_name_orig, prefix_len_orig, prefix_orig, predicted_suffixes_orig, suffix_orig, mean_pred_orig = item_orig
    case_name_pert, prefix_len_pert, prefix_pert, predicted_suffixes_pert, suffix_pert, mean_pred_pert = item_pert

    # Ensure we're comparing the same case and prefix length
    assert case_name_orig == case_name_pert, f"Case mismatch: {case_name_orig} != {case_name_pert}"
    assert prefix_len_orig == prefix_len_pert, f"Prefix length mismatch: {prefix_len_orig} != {prefix_len_pert}"

    results[(case_name_orig, prefix_len_orig)] = {
        'original': _summarize_side(prefix_orig, suffix_orig, predicted_suffixes_orig, mean_pred_orig, concept_name),
        'perturbed': _summarize_side(prefix_pert, suffix_pert, predicted_suffixes_pert, mean_pred_pert, concept_name),
    }

    # Flush a chunk every `save_every` pairs and start a fresh in-memory dict.
    if (i + 1) % save_every == 0:
        save_chunk(results, i, output_dir)
        results = {}

# Save the final partial chunk, if any (non-empty implies the loop ran, so `i` is bound).
if results:
    save_chunk(results, i, output_dir)

print("Robustness evaluation completed!")


Evaluating robustness: 0it [00:00, ?it/s]

  0%|          | 0/1754 [00:00<?, ?it/s]

  0%|          | 0/1754 [00:00<?, ?it/s]

Saved 50 results to ../../../../../evaluation_results/robustness/sepsis/redo_activity/robustness_results_part_050.pkl
Saved 50 results to ../../../../../evaluation_results/robustness/sepsis/redo_activity/robustness_results_part_100.pkl
Saved 50 results to ../../../../../evaluation_results/robustness/sepsis/redo_activity/robustness_results_part_150.pkl
Saved 50 results to ../../../../../evaluation_results/robustness/sepsis/redo_activity/robustness_results_part_200.pkl
Saved 50 results to ../../../../../evaluation_results/robustness/sepsis/redo_activity/robustness_results_part_250.pkl
Saved 50 results to ../../../../../evaluation_results/robustness/sepsis/redo_activity/robustness_results_part_300.pkl
Saved 50 results to ../../../../../evaluation_results/robustness/sepsis/redo_activity/robustness_results_part_350.pkl
Saved 50 results to ../../../../../evaluation_results/robustness/sepsis/redo_activity/robustness_results_part_400.pkl
Saved 50 results to ../../../../../evaluation_results/ro

In [6]:
# Load all saved chunks from `output_dir` and combine them into one results dict.
CHUNK_PREFIX = 'robustness_results_part_'
CHUNK_SUFFIX = '.pkl'


def _chunk_index(file_name):
    """Numeric part index of a chunk file name (e.g. 'robustness_results_part_050.pkl' -> 50)."""
    return int(file_name[len(CHUNK_PREFIX):-len(CHUNK_SUFFIX)])


# Sort numerically. A plain string sort puts 'part_1000' before 'part_150' once the index
# grows past three digits. The combined 'robustness_results.pkl' written below does not
# match CHUNK_PREFIX, so it is never re-read as a chunk.
chunk_files = sorted(
    (
        f for f in os.listdir(output_dir)
        if f.startswith(CHUNK_PREFIX)
        and f.endswith(CHUNK_SUFFIX)
        and f[len(CHUNK_PREFIX):-len(CHUNK_SUFFIX)].isdigit()
    ),
    key=_chunk_index,
)

print(f"Found {len(chunk_files)} chunk files")

all_results = {}
for chunk_file in chunk_files:
    chunk_path = os.path.join(output_dir, chunk_file)
    print(f"Loading {chunk_file}...")
    # These chunks are written by this notebook; never pickle.load files from untrusted sources.
    with open(chunk_path, 'rb') as f:
        chunk_results = pickle.load(f)
    # Keys are (case_name, prefix_len) tuples (assumes each chunk is a dict). An overlap may
    # indicate stale chunks from an earlier run left in output_dir.
    overlapping_keys = all_results.keys() & chunk_results.keys()
    if overlapping_keys:
        print(f"  WARNING: {len(overlapping_keys)} keys in {chunk_file} overwrite previously loaded results")
    all_results.update(chunk_results)
    print(f"  Added {len(chunk_results)} results from {chunk_file}")

# Also merge in-memory results not yet written to a chunk (e.g. if the evaluation loop was
# interrupted before its final save_chunk call). After a complete run these were already
# saved, so this only re-adds identical keys.
if 'results' in globals() and len(results) > 0:
    print(f"Adding final {len(results)} results...")
    all_results.update(results)

print(f"\nTotal results loaded: {len(all_results)}")

# Save combined results into a single pickle file
combined_results_path = os.path.join(output_dir, 'robustness_results.pkl')
with open(combined_results_path, 'wb') as f:
    pickle.dump(all_results, f)

print(f"Combined results saved to {combined_results_path}")

Found 10 chunk files
Loading robustness_results_part_050.pkl...
  Added 50 results from robustness_results_part_050.pkl
Loading robustness_results_part_100.pkl...
  Added 50 results from robustness_results_part_100.pkl
Loading robustness_results_part_150.pkl...
  Added 50 results from robustness_results_part_150.pkl
Loading robustness_results_part_200.pkl...
  Added 50 results from robustness_results_part_200.pkl
Loading robustness_results_part_250.pkl...
  Added 50 results from robustness_results_part_250.pkl
Loading robustness_results_part_300.pkl...
  Added 50 results from robustness_results_part_300.pkl
Loading robustness_results_part_350.pkl...
  Added 50 results from robustness_results_part_350.pkl
Loading robustness_results_part_400.pkl...
  Added 50 results from robustness_results_part_400.pkl
Loading robustness_results_part_450.pkl...
  Added 50 results from robustness_results_part_450.pkl
Loading robustness_results_part_500.pkl...
  Added 50 results from robustness_results_pa