# Automated Robustness Comparison

This notebook automatically compares robustness metrics across multiple datasets and attacks.

**Features:**
- Compares one or more models per attack (usually just U-ED-LSTM); the first three get distinct colors/markers, additional models reuse them
- Automatically processes multiple datasets and attacks
- Generates all comparison charts
- Saves charts to organized directories in the `img` folder
- Provides progress tracking and error handling
- Configurable result locations (`ATTACK_RESULTS`) and chart output root (`OUTPUT_BASE_DIR`)

## Configuration

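Edit the cell below to configure your datasets and attacks: map each attack to its result directories in `ATTACK_RESULTS`, and list one output subpath per attack in `OUTPUT_PATHS`, in the same order as the `ATTACK_RESULTS` keys.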

In [None]:
# ============================================================================
# CONFIGURATION
# ============================================================================
# ATTACK_RESULTS maps an attack name to a list of result locations.
#   * Each entry is a directory holding robustness_results.pkl (a direct path
#     to a .pkl file is accepted as well).
#   * Listing several entries under one attack compares those models against
#     each other on the same charts.
ATTACK_RESULTS = {
    "random_event_attack": [
        "../../evaluation_results/robustness/sepsis/random_event_attack",
    ],
    # Example: a second attack, with an extra model variant for comparison.
    # "last_event_attack": [
    #     "../../evaluation_results/robustness/sepsis/last_event_attack",
    #     "../../evaluation_results/robustness/sepsis/last_event_attack_camargo",
    # ],
}

# OUTPUT_PATHS holds one chart subdirectory (under OUTPUT_BASE_DIR) per
# ATTACK_RESULTS entry. Pairing is purely positional: the i-th key above goes
# with the i-th path here, so keep both in the same order.
# Use "dataset/attack", or just "attack" when no dataset folder is wanted.
OUTPUT_PATHS = [
    "sepsis/random_event_attack",
    # "sepsis/last_event_attack",
]

# Root folder for saved charts, resolved relative to the notebook's working
# directory.
OUTPUT_BASE_DIR = '../../img'

## Imports and Setup

In [2]:
import os
import sys
import traceback
from pathlib import Path

import numpy as np

sys.path.insert(0, '..')
sys.path.insert(0, '../..')

from robustness_metrics import (
    load_results, prepare_robustness_results, calculate_aggregate_metrics
)
from robustness_charts import (
    generate_all_charts_for_comparison,
    generate_summary_table
)

# Make sure the chart output root exists; already-existing directories are fine.
Path(OUTPUT_BASE_DIR).mkdir(exist_ok=True, parents=True)

# Plot styling per model, picked by the model's position in an attack's
# result-path list (the main loop indexes these with `%`, so styles wrap around).
MODEL_COLORS = ["blue", "orange", "green"]
MODEL_MARKERS = ["o", "s", "^"]

## Helper Functions

In [3]:
def load_and_prepare_model_results(result_path, model_name=None):
    """
    Load one model's robustness results and compute its aggregate metrics.

    Args:
        result_path: Directory containing ``robustness_results.pkl``, or the
                     path to a ``.pkl`` file itself. Either a ``str`` or any
                     ``os.PathLike`` (e.g. ``pathlib.Path``) is accepted.
        model_name: Display name for the model. When None, the generic name
                    "Model" is used; callers comparing several models should
                    pass distinct names (e.g. "Model 1", "Model 2").

    Returns:
        Tuple ``(results, aggregate_data, model_name)`` on success, or
        ``(None, None, None)`` when the file does not exist or when loading /
        preparation raises. A message is printed in both failure cases.
    """
    # Normalize PathLike -> str so the suffix check below works for Path objects too
    # (pathlib.Path has no .endswith()).
    result_path = os.fspath(result_path)

    # Accept either the .pkl file or the directory that contains it.
    if result_path.endswith('.pkl'):
        results_path = result_path
    else:
        results_path = os.path.join(result_path, 'robustness_results.pkl')

    if model_name is None:
        model_name = "Model"

    if not os.path.exists(results_path):
        print(f"  Warning: Results not found at {results_path}")
        return None, None, None

    try:
        # NOTE(review): load_results presumably unpickles the file; only point
        # this at result files you produced yourself (pickle can execute code).
        results = load_results(results_path)
        # save_path is forwarded to prepare_robustness_results; it may write the
        # prepared results back to this same file — TODO confirm.
        results = prepare_robustness_results(results, save_path=results_path)
        data = calculate_aggregate_metrics(results)
        return results, data, model_name
    except Exception as e:
        # Best-effort: report and let the caller skip this model instead of
        # aborting the whole batch.
        print(f"  Error loading from {results_path}: {e}")
        return None, None, None

## Main Automation Loop

In [4]:
def _split_output_path(output_path):
    """Split an output subpath into ``(dataset, attack)``.

    ``"dataset/attack"`` -> ``("dataset", "attack")``; a path without ``/``
    yields ``("", output_path)``. Only the first ``/`` splits, so deeper paths
    keep the remainder in ``attack``. Leading/trailing slashes are ignored so
    ``"sepsis/attack/"`` behaves like ``"sepsis/attack"``.
    """
    normalized = output_path.strip('/')
    if '/' in normalized:
        dataset, attack = normalized.split('/', 1)
        return dataset, attack
    return '', normalized


# Validate configuration: OUTPUT_PATHS is paired with ATTACK_RESULTS purely by
# position, so both must have the same number of entries.
if len(ATTACK_RESULTS) != len(OUTPUT_PATHS):
    raise ValueError(f"Mismatch: ATTACK_RESULTS has {len(ATTACK_RESULTS)} entries, but OUTPUT_PATHS has {len(OUTPUT_PATHS)} entries. They must match in order.")

# Track progress
total_combinations = len(ATTACK_RESULTS)
processed = 0
successful = 0
failed = []         # (attack_name, reason) for attacks that produced no charts
skipped_paths = []  # (attack_name, result_path) for individual models that failed to load

# Number of models that can get a unique color AND marker before styles repeat.
n_distinct_styles = min(len(MODEL_COLORS), len(MODEL_MARKERS))

print("="*80)
print("AUTOMATED ROBUSTNESS COMPARISON")
print("="*80)
print(f"Attacks: {len(ATTACK_RESULTS)}")
for attack_name, result_paths in ATTACK_RESULTS.items():
    print(f"  - {attack_name}: {len(result_paths)} result path(s)")
print(f"Total combinations: {total_combinations}")
print(f"Output directory: {OUTPUT_BASE_DIR}")
print("="*80)

# Materialize as a list so each attack can be paired with OUTPUT_PATHS by index
# (dicts preserve insertion order).
attack_items = list(ATTACK_RESULTS.items())

for attack_idx, (attack_name, result_paths) in enumerate(attack_items):
    processed += 1

    output_path = OUTPUT_PATHS[attack_idx]
    # dataset/attack are forwarded to the chart functions below.
    dataset, attack = _split_output_path(output_path)

    print(f"\n  [{processed}/{total_combinations}] Processing {attack_name}...")
    print(f"    Output path: {output_path}")

    if len(result_paths) > n_distinct_styles:
        print(f"    Note: {len(result_paths)} result paths but only {n_distinct_styles} distinct color/marker styles; styles will repeat")

    # Load results for all models from this attack's result paths
    loaded_models = []
    for path_idx, result_path in enumerate(result_paths):
        # Names/ids follow the position in the configured list, so a skipped
        # path leaves a gap (e.g. "Model 1", "Model 3") rather than renumbering.
        model_name = f"Model {path_idx + 1}"

        results, data, returned_model_name = load_and_prepare_model_results(result_path, model_name)

        if results is None or data is None:
            print(f"    Skipping {result_path}: Missing results")
            skipped_paths.append((attack_name, result_path))
            continue

        model_dict = {
            'name': returned_model_name,          # chart code reads 'name' (not 'model_name')
            'model_id': f"model_{path_idx + 1}",  # used for per-model chart filenames
            'results': results,
            'data': data,
            # Styles wrap around once the palettes are exhausted.
            'color': MODEL_COLORS[path_idx % len(MODEL_COLORS)],
            'marker': MODEL_MARKERS[path_idx % len(MODEL_MARKERS)],
        }
        loaded_models.append(model_dict)

    if not loaded_models:
        print("    Skipping: No models loaded for this attack")
        failed.append((attack_name, "No models loaded"))
        continue

    try:
        charts = generate_all_charts_for_comparison(
            dataset, attack, loaded_models, OUTPUT_BASE_DIR
        )

        summary = generate_summary_table(dataset, attack, loaded_models, OUTPUT_BASE_DIR)

        print(f"    ✓ Generated {len(charts)} charts")
        print(f"    ✓ Saved to: {OUTPUT_BASE_DIR}/{output_path}/")
        successful += 1

    except Exception as e:
        # Keep going with the remaining attacks, but surface the full traceback.
        print(f"    ✗ Error: {e}")
        failed.append((attack_name, str(e)))
        traceback.print_exc()

# Final summary
print("\n" + "="*80)
print("PROCESSING COMPLETE")
print("="*80)
print(f"Total combinations: {total_combinations}")
print(f"Successful: {successful}")
print(f"Failed: {len(failed)}")
print(f"Skipped result paths: {len(skipped_paths)}")

if failed:
    print("\nFailed combinations:")
    for attack_name, reason in failed:
        print(f"  - {attack_name}: {reason}")

if skipped_paths:
    print("\nSkipped result paths (missing or unloadable results):")
    for attack_name, result_path in skipped_paths:
        print(f"  - {attack_name}: {result_path}")

print(f"\nAll charts saved to: {OUTPUT_BASE_DIR}/")

AUTOMATED ROBUSTNESS COMPARISON
Attacks: 2
  - last_event_attack: 1 result path(s)
  - random_event_attack: 1 result path(s)
Total combinations: 2
Output directory: ../../img

  [1/2] Processing last_event_attack...
    Output path: sepsis/last_event_attack
    ✓ Generated 12 charts
    ✓ Saved to: ../../img/sepsis/last_event_attack/

  [2/2] Processing random_event_attack...
    Output path: sepsis/random_event_attack
    Skipping ../../evaluation_results/robustness/sepsis/random_event_attack: Missing results
    Skipping: No models loaded for this attack

PROCESSING COMPLETE
Total combinations: 2
Successful: 1
Failed: 1

Failed combinations:
  - random_event_attack: No models loaded

All charts saved to: ../../img/
