In [None]:
# --- Cell 0: Fix package dependencies ---
!pip install --force-reinstall numpy
!pip install --force-reinstall pandas
!pip install --force-reinstall matplotlib seaborn scikit-learn tqdm
!pip install --force-reinstall torch torchvision
!pip install --force-reinstall captum

# After this cell finishes, restart the runtime (Runtime -> Restart session)
# before running the import cell (Cell 1) so the reinstalled packages are picked up.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# --- Cell 0.B: Install Captum ---
!pip install captum -q

print("Captum installation attempt complete.")
print(">>> CRITICAL: Please go to 'Runtime' -> 'Restart session' NOW one more time before running your main imports! <<<")

In [None]:
# --- Cell 1: Python Imports (Run this part AFTER pip installs and a manual Runtime Restart) ---
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence

import numpy as np
import pandas as pd # This is where the error was occurring
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, roc_auc_score, roc_curve, precision_recall_curve, auc
)
try:
    from captum.attr import IntegratedGradients, Saliency, NoiseTunnel
except ImportError:
    print("Captum not installed or import failed. Explanations requiring Captum will not be available.")
    IntegratedGradients = None
    Saliency = None
    NoiseTunnel = None

import os
import math
import time
import copy
import warnings
from typing import List, Tuple, Dict, Optional, Union, Callable, Any

# General settings
# Silence UserWarning-category warnings for the whole session.
warnings.filterwarnings('ignore', category=UserWarning)
# Prefer the 'seaborn-v0_8-whitegrid' matplotlib style; if matplotlib raises
# OSError for that name, fall back to 'ggplot'.
try:
    plt.style.use('seaborn-v0_8-whitegrid')
except OSError:
    print("Style 'seaborn-v0_8-whitegrid' not found, using 'ggplot'.")
    plt.style.use('ggplot')
sns.set_palette('pastel')
# Small positive constant; the encoder's masked mean pooling clamps its divisor to it.
EPSILON = 1e-8

# Report the library versions actually loaded after the reinstall/restart.
print(f"NumPy version: {np.__version__}")
print(f"Pandas version: {pd.__version__}")
print(f"Torch version: {torch.__version__}")
print("Cell 1: Python Imports executed successfully.")

In [None]:
# --- Cell 2: Configuration and Device Setup ---

# --- 1. General Settings ---
# Seed torch and NumPy (and all CUDA devices when present) for reproducibility.
# NOTE(review): Python's built-in `random` module is not seeded here; seed it
# as well if any later cell relies on it.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Ask cuDNN for deterministic kernels and disable its auto-tuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Use the GPU when one is visible to torch, otherwise the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# --- 2. Data Paths ---
DRIVE_MOUNT_POINT = '/content/drive' # Or None if not using Colab/Drive

# Preferred data location is <mount point>/MyDrive. When that folder is missing
# (Drive not mounted, or running outside Colab) fall back to the current
# directory and say so, instead of silently switching locations.
_drive_data_dir = os.path.join(DRIVE_MOUNT_POINT, 'MyDrive') if DRIVE_MOUNT_POINT else None
if _drive_data_dir and os.path.exists(_drive_data_dir):
    BASE_DATA_PATH = _drive_data_dir
    print(f"Using BASE_DATA_PATH: {BASE_DATA_PATH}")
else:
    if _drive_data_dir:
        print(f"Warning: Default BASE_DATA_PATH '{_drive_data_dir}' not found. Using current directory './' for data.")
    BASE_DATA_PATH = './'


# CSV files expected under BASE_DATA_PATH, keyed by dataset part.
DATA_PATHS = {
    'mitbih_train': os.path.join(BASE_DATA_PATH, 'mitbih_train.csv'),
    'mitbih_test': os.path.join(BASE_DATA_PATH, 'mitbih_test.csv'),
    'ptbdb_normal': os.path.join(BASE_DATA_PATH, 'ptbdb_normal.csv'),
    'ptbdb_abnormal': os.path.join(BASE_DATA_PATH, 'ptbdb_abnormal.csv'),
}

# --- 3. Model & Training Hyperparameters ---
# Module-level defaults; individual dataset experiments may override them.
GENERAL_TRAINING_CONFIG = {
    'epochs': 15, # Maximum number of epochs (early stopping may end training sooner)
    'batch_size': 64,
    'learning_rate': 1e-4,
    'patience': 10, # Early stopping patience
    'checkpoint_dir': './checkpoints', # Directory to save model checkpoints
}
# Create the checkpoint directory up front so later saves cannot fail on a missing folder.
os.makedirs(GENERAL_TRAINING_CONFIG['checkpoint_dir'], exist_ok=True)

# For Stage 1 (Best Base Classifier), we set loss_beta to 0.0 and deferral_threshold_train to a very large value (effectively infinity)
CGD_MODEL_CONFIG = {
    'input_dim': 1,                  # One value per time step (datasets add a trailing dim of size 1)
    'latent_dim': 64,                # Reused below as the encoder embed_dim and to size the predictor MLP
    'deferral_threshold_train': 1e9,
    'loss_alpha': 0.0,
    'loss_beta': 0.0,
    'defer_cost_factor': 0.3,        # NOTE(review): meaning is defined by the consumer of this config (not visible here)
    'max_seq_length': 187,           # Number of feature columns read per sample by the data-loading code
}


ENCODER_CONFIG = {
    'embed_dim': CGD_MODEL_CONFIG['latent_dim'],
    'num_heads': 4,
    'num_layers': 3,
    'dim_feedforward_factor': 4,     # Feed-forward width = embed_dim * this factor
    'dropout': 0.15,
    'activation': 'gelu',
    'aggregation_method': 'mean', # 'mean', 'last', 'cls'
}

PREDICTOR_CONFIG = {
    # output_dim will be set per dataset
    'hidden_dims': [CGD_MODEL_CONFIG['latent_dim'] * 2, CGD_MODEL_CONFIG['latent_dim']], # Larger hidden layers
    'dropout': 0.2,
    # No 'activation' key: the predictor then defaults to 'linear' (raw logits),
    # as expected by CrossEntropyLoss/BCEWithLogitsLoss
}

STRUCTURAL_REGULARIZER_CONFIG = {
    'regularization_type': 'contrastive', # 'contrastive', 'prototype', 'none'
    'temperature': 0.1,
    # For 'prototype'
    # 'num_prototypes': 10,
    # 'prototype_lambda': 0.1
}


# Active experiment: input-masking perturbations (PERTURBATION_CONFIG) scored
# with the mean-displacement sensitivity measure (SENSITIVITY_CONFIG).
PERTURBATION_CONFIG = {
    'active_types': ['input_masking'],    # Only input masking
    'gaussian_noise_level': 0.1,
    'feature_dropout_rate': 0.1,
    'temporal_swap_rate': 0.05,
    'input_masking_rate': 0.15,           # Mask 15% of the sequence
    'input_masking_chunk_size': 18,       # Chunk size ~10% of 187 steps, could mask a beat component
    'num_perturbations': 15,
}

SENSITIVITY_CONFIG = {
    'active_measures': ['mean_displacement'], # Only mean displacement
    'aggregation_method_for_multiple_sensitivities': 'mean',
}


ADAPTIVE_THRESHOLD_CONFIG = {
    'method': 'max_acc_under_budget',
    'percentile_value_for_threshold': 90, # Not used by this method
    'target_defer_rate_value': 0.10,  # Not used by this method
    'max_defer_rate_budget': 0.25,    # 0.25 means don't defer more than 25%
    'num_threshold_candidates': 200   # Usually fine
}

EXPLAINER_CONFIG = {
    'method': 'integrated_gradients', # 'saliency', 'integrated_gradients'
    'n_steps_ig': 25, # For Integrated Gradients
    'noise_tunnel_nt_type': 'smoothgrad_sq', # For NoiseTunnel with Saliency
    'noise_tunnel_stdevs': 0.1,           # For NoiseTunnel
    'noise_tunnel_nt_samples': 5,         # For NoiseTunnel
}

DEFERRAL_HEAD_CONFIG = {
    'hidden_dims_deferral_head': [64, 32], # Structure of the Deferral Head MLP
    'dropout_deferral_head': 0.2,
    'learning_rate': 1e-3,      # LR for training the Deferral Head
    'epochs': 50,               # Max epochs for training the Deferral Head
    'batch_size': 256,          # Batch size for Deferral Head training data
    'patience': 5,              # Early stopping patience for Deferral Head
    'target_type': 'error_prediction' # Currently only 'error_prediction' is fully fleshed out
}

# Echo every configuration object defined in this cell so a run's settings are
# recorded in the notebook output (DEFERRAL_HEAD_CONFIG was previously omitted).
print("\n--- Configurations ---")
print(f"DEVICE: {DEVICE}")
print(f"DATA_PATHS: {DATA_PATHS}")
print(f"GENERAL_TRAINING_CONFIG: {GENERAL_TRAINING_CONFIG}")
print(f"CGD_MODEL_CONFIG: {CGD_MODEL_CONFIG}")
print(f"ENCODER_CONFIG: {ENCODER_CONFIG}")
print(f"PREDICTOR_CONFIG: {PREDICTOR_CONFIG}")
print(f"STRUCTURAL_REGULARIZER_CONFIG: {STRUCTURAL_REGULARIZER_CONFIG}")
print(f"PERTURBATION_CONFIG: {PERTURBATION_CONFIG}")
print(f"SENSITIVITY_CONFIG: {SENSITIVITY_CONFIG}")
print(f"ADAPTIVE_THRESHOLD_CONFIG: {ADAPTIVE_THRESHOLD_CONFIG}")
print(f"EXPLAINER_CONFIG: {EXPLAINER_CONFIG}")
print(f"DEFERRAL_HEAD_CONFIG: {DEFERRAL_HEAD_CONFIG}")

print("\nCell 2: Configuration and Device Setup executed successfully.")

In [None]:
# --- Revised Cell 3: Unzip Data, Load, Preprocess, and Create DataLoaders ---
import zipfile

# --- 1. Unzip Dataset ---
# Look for the archive in the data directory resolved in Cell 2 rather than a
# hard-coded Drive path. With Drive mounted this is still
# '/content/drive/MyDrive/archive.zip'; otherwise it is './archive.zip'.
ZIP_DATASET_PATH = os.path.join(BASE_DATA_PATH, 'archive.zip')
EXTRACTION_DIR = './ecg_data_extracted/' # Directory to extract files into

if os.path.exists(ZIP_DATASET_PATH):
    print(f"Found dataset zip file at: {ZIP_DATASET_PATH}")
    os.makedirs(EXTRACTION_DIR, exist_ok=True)
    try:
        with zipfile.ZipFile(ZIP_DATASET_PATH, 'r') as zip_ref:
            zip_ref.extractall(EXTRACTION_DIR)
        print(f"Successfully extracted dataset to: {EXTRACTION_DIR}")

        # Point DATA_PATHS at the extracted copies (overrides the Cell 2 values).
        DATA_PATHS = {
            'mitbih_train': os.path.join(EXTRACTION_DIR, 'mitbih_train.csv'),
            'mitbih_test': os.path.join(EXTRACTION_DIR, 'mitbih_test.csv'),
            'ptbdb_normal': os.path.join(EXTRACTION_DIR, 'ptbdb_normal.csv'),
            'ptbdb_abnormal': os.path.join(EXTRACTION_DIR, 'ptbdb_abnormal.csv'),
        }
        print(f"Updated DATA_PATHS to use extracted files: {DATA_PATHS}")

        # Surface layout problems (e.g. CSVs nested in a sub-folder of the
        # archive) right here instead of as "file not found" much later.
        _missing_after_extract = [p for p in DATA_PATHS.values() if not os.path.exists(p)]
        if _missing_after_extract:
            print(f"Warning: expected files not found after extraction: {_missing_after_extract}")
    except Exception as e:
        # Best effort: keep going with whatever DATA_PATHS Cell 2 defined.
        print(f"Error unzipping {ZIP_DATASET_PATH}: {e}")
        print("Proceeding with potentially empty or previously defined DATA_PATHS from Cell 2.")
else:
    print(f"Warning: Dataset zip file not found at {ZIP_DATASET_PATH}.")
    print("Ensure the path is correct or data is already extracted and DATA_PATHS (from Cell 2) are set accordingly.")
    print("If DATA_PATHS are not set to existing CSVs, dummy data logic might be triggered later.")


# --- 2. Data Loading & Preprocessing Utilities (same as before) ---

def load_raw_ecg_data(data_paths: Dict[str, str]) -> Dict[str, Optional[pd.DataFrame]]:
    """Read each header-less CSV in ``data_paths`` and merge the two PTB parts.

    Args:
        data_paths: Mapping from dataset key to CSV path.

    Returns:
        A dict with one entry per input key: the loaded DataFrame, or ``None``
        when the file is missing or could not be read. When both
        ``'ptbdb_normal'`` and ``'ptbdb_abnormal'`` load, they are replaced by a
        single ``'ptbdb_combined'`` frame holding the first
        ``CGD_MODEL_CONFIG['max_seq_length']`` columns of each part plus a
        ``'label'`` column (0 = normal, 1 = abnormal). When either PTB key was
        requested but the merge is impossible, ``'ptbdb_combined'`` is ``None``.
    """
    print("\nLoading raw ECG data...")
    frames: Dict[str, Optional[pd.DataFrame]] = {}
    for name, csv_path in data_paths.items():
        frame = None
        try:
            if os.path.exists(csv_path):
                frame = pd.read_csv(csv_path, header=None)
                print(f"Successfully loaded {name} from {csv_path}, shape: {frame.shape}")
            else:
                print(f"Warning: File not found for {name} at {csv_path}. Skipping.")
        except FileNotFoundError:
            frame = None
            print(f"Error: File not found for {name} at {csv_path}. Returning None for this key.")
        except Exception as e:
            frame = None
            print(f"An error occurred while loading {name} from {csv_path}: {e}")
        frames[name] = frame

    normal_part = frames.get('ptbdb_normal')
    abnormal_part = frames.get('ptbdb_abnormal')

    # Guard clause: without both PTB parts there is nothing to merge.
    if normal_part is None or abnormal_part is None:
        if 'ptbdb_normal' in data_paths or 'ptbdb_abnormal' in data_paths:
            print("Warning: Could not combine PTB datasets as one or both parts are missing/failed to load.")
            frames['ptbdb_combined'] = None
        return frames

    # Keep only the leading feature columns of each PTB file and attach an
    # explicit binary label.
    n_feature_cols = CGD_MODEL_CONFIG['max_seq_length']
    labelled_parts = []
    for part, class_id in ((normal_part, 0), (abnormal_part, 1)):
        labelled = part.iloc[:, :n_feature_cols].copy()
        labelled['label'] = class_id
        labelled_parts.append(labelled)

    frames['ptbdb_combined'] = pd.concat(labelled_parts, axis=0, ignore_index=True)
    print(f"Combined PTB dataset created, shape: {frames['ptbdb_combined'].shape}")

    # The individual parts are redundant once the combined frame exists.
    frames.pop('ptbdb_normal', None)
    frames.pop('ptbdb_abnormal', None)
    return frames

def preprocess_ecg_features(
    df: pd.DataFrame,
    is_train: bool = True,
    scaler: Optional[StandardScaler] = None,
    expected_features: int = 187 # MIT-BIH and PTB datasets have 187 features
) -> Tuple[np.ndarray, np.ndarray, Optional[StandardScaler]]:
    """Split a frame into scaled features and integer labels.

    The first ``expected_features`` columns are taken as features and the very
    last column as the label.

    Args:
        df: Input frame (features followed by a label column).
        is_train: When True a new ``StandardScaler`` is fitted on the features;
            otherwise the supplied ``scaler`` is applied.
        scaler: Previously fitted scaler; required when ``is_train`` is False.
        expected_features: Number of leading feature columns.

    Returns:
        ``(X_scaled, y, scaler)``. For a ``None`` or empty frame two empty
        arrays are returned together with the ``scaler`` argument unchanged.

    Raises:
        ValueError: If fewer than ``expected_features`` columns are available,
            or if ``is_train`` is False and no scaler was given.
    """
    if df is None or df.empty:
        print("Warning: preprocess_ecg_features received an empty or None DataFrame.")
        return np.array([]), np.array([]), scaler

    feature_matrix = df.iloc[:, :expected_features].values
    label_vector = df.iloc[:, -1].values # Assumes label is the actual last column

    n_found = feature_matrix.shape[1]
    if n_found != expected_features:
        raise ValueError(f"Expected {expected_features} features, but got {n_found} from DataFrame slice.")

    # Validation/test path: reuse the statistics learned on the training split.
    if not is_train:
        if scaler is None:
            raise ValueError("Scaler must be provided for non-train (test/validation) data.")
        return scaler.transform(feature_matrix), label_vector.astype(int), scaler

    # Training path: fit fresh scaling statistics.
    scaler = StandardScaler()
    return scaler.fit_transform(feature_matrix), label_vector.astype(int), scaler


class TimeSeriesDataset(Dataset):
    """Fixed-length time-series dataset yielding ``(sequence, label)`` pairs.

    Sequences are stored as float tensors of shape ``[N, S, 1]``. Labels are
    stored as float for two-class problems (for BCEWithLogitsLoss) and as long
    otherwise (for CrossEntropyLoss).
    """
    def __init__(self, sequences: np.ndarray, labels: np.ndarray, num_classes_for_problem: int):
        """
        Args:
            sequences: Array of sequences [num_samples, seq_length] or
                [num_samples, seq_length, 1]. Expected to be scaled.
            labels: Array of corresponding labels [num_samples].
            num_classes_for_problem: Number of unique classes for the problem (e.g., 2 for PTB, 5 for MIT-BIH).

        Raises:
            ValueError: If ``sequences`` has an unsupported shape, or if
                ``sequences`` and ``labels`` hold a different number of samples.
        """
        if sequences.ndim == 2:
            self.sequences = torch.from_numpy(sequences).float().unsqueeze(-1) # [N, S, 1]
        elif sequences.ndim == 3 and sequences.shape[-1] == 1:
            self.sequences = torch.from_numpy(sequences).float()
        else:
            raise ValueError(f"Sequences should be 2D or 3D with last dim 1, got {sequences.shape}")

        # Fail fast on misaligned inputs; otherwise the mismatch only shows up
        # as an index error once iteration reaches the shorter array's end.
        if len(sequences) != len(labels):
            raise ValueError(
                f"sequences and labels must contain the same number of samples, "
                f"got {len(sequences)} and {len(labels)}"
            )

        if num_classes_for_problem == 2: # Binary classification
            self.labels = torch.from_numpy(labels).float() # For BCEWithLogitsLoss
        else: # Multiclass classification
            self.labels = torch.from_numpy(labels).long()  # For CrossEntropyLoss
        self.num_classes = num_classes_for_problem

    def __len__(self) -> int:
        """Number of samples."""
        return len(self.sequences)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``idx``-th ``(sequence [S, 1], label)`` pair."""
        return self.sequences[idx], self.labels[idx]

def collate_fn_fixed_length(
    batch: List[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Collate equal-length samples into batched tensors plus an all-False padding mask.

    Args:
        batch: List of ``(sequence, label)`` pairs; sequences must share a shape.

    Returns:
        ``(sequences [B, S, F], labels [B, ...], padding_mask [B, S])``. The
        mask is boolean and entirely False because nothing is padded.
    """
    per_sample_sequences, per_sample_labels = zip(*batch)
    sequence_batch = torch.stack(per_sample_sequences)
    label_batch = torch.stack(per_sample_labels)

    n_samples, n_steps, _ = sequence_batch.shape
    no_padding_mask = torch.zeros((n_samples, n_steps), dtype=torch.bool)
    return sequence_batch, label_batch, no_padding_mask

def prepare_dataloaders(
    X_train: np.ndarray, y_train: np.ndarray,
    X_val: np.ndarray, y_val: np.ndarray,
    num_classes: int, # Number of classes for the specific problem
    X_test: Optional[np.ndarray] = None, y_test: Optional[np.ndarray] = None,
    batch_size: int = 64, num_workers: int = 2
) -> Dict[str, DataLoader]:
    """Wrap train/val (and optionally test) arrays in DataLoaders.

    Only the training loader shuffles. All loaders use
    ``collate_fn_fixed_length`` and pin memory when ``DEVICE`` is CUDA.

    Args:
        X_train, y_train: Training features and labels.
        X_val, y_val: Validation features and labels.
        num_classes: Number of classes, forwarded to ``TimeSeriesDataset``.
        X_test, y_test: Optional test features and labels; a test loader is
            built only when both are given and ``X_test`` is non-empty.
        batch_size: Batch size for every loader.
        num_workers: Worker processes per loader.

    Returns:
        Dict with keys ``'train'``, ``'val'`` and, when built, ``'test'``.

    Raises:
        ValueError: If the training or validation features are empty.
    """
    if X_train.size == 0 or X_val.size == 0:
        raise ValueError("Train or Validation data is empty. Cannot create DataLoaders.")

    def _as_dataset(features: np.ndarray, targets: np.ndarray) -> TimeSeriesDataset:
        return TimeSeriesDataset(features, targets, num_classes_for_problem=num_classes)

    def _as_loader(dataset: TimeSeriesDataset, shuffle: bool) -> DataLoader:
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn_fixed_length,
            num_workers=num_workers,
            pin_memory=DEVICE.type == 'cuda',
        )

    train_set = _as_dataset(X_train, y_train)
    val_set = _as_dataset(X_val, y_val)
    dataloaders = {
        'train': _as_loader(train_set, shuffle=True),
        'val': _as_loader(val_set, shuffle=False),
    }

    test_is_usable = X_test is not None and y_test is not None and X_test.size > 0
    if test_is_usable:
        dataloaders['test'] = _as_loader(_as_dataset(X_test, y_test), shuffle=False)
    elif not (X_test is None and y_test is None):
        print("Warning: Test data or labels are partially provided. Test DataLoader not created.")

    summary = (
        f"Created DataLoaders for a {num_classes}-class problem: "
        f"Train batches={len(dataloaders['train'])}, Val batches={len(dataloaders['val'])}"
    )
    if 'test' in dataloaders:
        summary += f", Test batches={len(dataloaders['test'])}"
    print(summary)
    return dataloaders

# --- 3. Load and Process Data (using the actual data if unzipped) ---
# Loads whatever DATA_PATHS currently points to (the extracted files when the
# unzip step succeeded, otherwise the Cell 2 paths). Datasets whose files are
# missing are skipped with a message instead of raising.

print("\n--- Processing Datasets ---")
raw_dfs = load_raw_ecg_data(DATA_PATHS) # DATA_PATHS should be updated if unzipping was successful

# Process MIT-BIH
mitbih_train_df = raw_dfs.get('mitbih_train')
mitbih_test_df = raw_dfs.get('mitbih_test')
mitbih_loaders = None
expected_mitbih_features = CGD_MODEL_CONFIG['max_seq_length'] # Should be 187

if mitbih_train_df is not None and mitbih_test_df is not None:
    print(f"\nProcessing MIT-BIH (expecting {expected_mitbih_features} features)...")
    # For MIT-BIH, the last column (index 187) is the label. Features are columns 0-186.
    X_mtrain, y_mtrain, scaler_m = preprocess_ecg_features(mitbih_train_df, is_train=True, expected_features=expected_mitbih_features)
    # The provided test file is scaled with the training scaler, then divided
    # into a validation half and a held-out test half.
    X_mval_full, y_mval_full, _ = preprocess_ecg_features(mitbih_test_df, is_train=False, scaler=scaler_m, expected_features=expected_mitbih_features)

    if X_mval_full.shape[0] > 10: # Ensure there's enough data to split
        val_size_mit = int(0.5 * X_mval_full.shape[0]) # Splitting the original test set
        # Shuffle and stratify the split. Slicing the file in row order (first
        # half / second half) makes both halves depend on how the CSV happens
        # to be ordered; if rows are grouped by class, validation and test end
        # up with very different class mixes.
        try:
            X_mval, X_mtest, y_mval, y_mtest = train_test_split(
                X_mval_full, y_mval_full,
                train_size=val_size_mit, random_state=SEED, stratify=y_mval_full
            )
        except ValueError:
            # Stratification is impossible (e.g. a class with a single sample):
            # fall back to a plain shuffled split of the same sizes.
            X_mval, X_mtest, y_mval, y_mtest = train_test_split(
                X_mval_full, y_mval_full,
                train_size=val_size_mit, random_state=SEED, shuffle=True
            )
        print(f"MIT-BIH original test set split into: Val ({X_mval.shape[0]} samples), Test ({X_mtest.shape[0]} samples)")
    else: # Not enough data, use all for validation, none for test
        X_mval, y_mval = X_mval_full, y_mval_full
        X_mtest, y_mtest = None, None
        print("MIT-BIH original test set used entirely for validation due to small size.")


    if X_mtrain.size > 0 and X_mval.size > 0:
        print(f"MIT-BIH scaled shapes: X_train={X_mtrain.shape}, y_train={y_mtrain.shape}, X_val={X_mval.shape}, y_val={y_mval.shape}" + (f", X_test={X_mtest.shape}, y_test={y_mtest.shape}" if X_mtest is not None else ""))
        mitbih_loaders = prepare_dataloaders(
            X_mtrain, y_mtrain, X_mval, y_mval, X_test=X_mtest, y_test=y_mtest,
            num_classes=5, batch_size=GENERAL_TRAINING_CONFIG['batch_size']
        )
        # Pull one batch as a sanity check of shapes and label dtype.
        X_b, y_b, m_b = next(iter(mitbih_loaders['train']))
        print(f"MIT-BIH sample train batch: X_shape={X_b.shape}, y_shape={y_b.shape}, mask_shape={m_b.shape}, y_dtype={y_b.dtype}")
    else:
        print("Skipping MIT-BIH DataLoader creation due to empty preprocessed data.")
else:
    print("Skipping MIT-BIH processing (raw data missing or failed to load).")

# Process PTB
# Uses the combined PTB frame produced by load_raw_ecg_data (feature columns
# plus a trailing 'label' column); ptb_loaders stays None if it is unavailable.
ptb_combined_df = raw_dfs.get('ptbdb_combined')
ptb_loaders = None
expected_ptb_features = CGD_MODEL_CONFIG['max_seq_length'] # Should be 187

if ptb_combined_df is not None:
    print(f"\nProcessing PTB (expecting {expected_ptb_features} features)...")
    # Two stratified splits on the label (last) column: first hold out 20% as
    # test, then hold out 20% of the remainder as validation
    # (overall roughly 64% train / 16% val / 20% test).
    ptb_train_val_df, ptb_test_df = train_test_split(ptb_combined_df, test_size=0.2, random_state=SEED, stratify=ptb_combined_df.iloc[:, -1])
    ptb_train_df, ptb_val_df = train_test_split(ptb_train_val_df, test_size=0.2, random_state=SEED, stratify=ptb_train_val_df.iloc[:, -1]) # 0.2 of (0.8) = 0.16

    # Fit the scaler on the training split only and reuse it for val/test.
    X_ptrain, y_ptrain, scaler_p = preprocess_ecg_features(ptb_train_df, is_train=True, expected_features=expected_ptb_features)
    X_pval, y_pval, _ = preprocess_ecg_features(ptb_val_df, is_train=False, scaler=scaler_p, expected_features=expected_ptb_features)
    X_ptest, y_ptest, _ = preprocess_ecg_features(ptb_test_df, is_train=False, scaler=scaler_p, expected_features=expected_ptb_features)

    if X_ptrain.size > 0 and X_pval.size > 0:
        print(f"PTB scaled shapes: X_train={X_ptrain.shape}, y_train={y_ptrain.shape}, X_val={X_pval.shape}, y_val={y_pval.shape}, X_test={X_ptest.shape}, y_test={y_ptest.shape}")
        # num_classes=2 makes the dataset store labels as float tensors.
        ptb_loaders = prepare_dataloaders(
            X_ptrain, y_ptrain, X_pval, y_pval, X_test=X_ptest, y_test=y_ptest,
            num_classes=2, batch_size=GENERAL_TRAINING_CONFIG['batch_size']
        )
        # Pull one batch as a sanity check of shapes and label dtype.
        X_b_ptb, y_b_ptb, m_b_ptb = next(iter(ptb_loaders['train']))
        print(f"PTB sample train batch: X_shape={X_b_ptb.shape}, y_shape={y_b_ptb.shape}, mask_shape={m_b_ptb.shape}, y_dtype={y_b_ptb.dtype}")
    else:
        print("Skipping PTB DataLoader creation due to empty preprocessed data.")
else:
    print("Skipping PTB processing (raw combined data missing or failed to load).")


print("\nCell 3 (Revised): Data Unzipping, Loading & Preprocessing Utilities executed successfully.")

In [None]:
# --- Corrected Cell 4: Core Model Components - Encoder ---

class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding for batch-first inputs.

    Even embedding indices receive ``sin(pos * w_k)`` and odd indices
    ``cos(pos * w_k)`` with ``w_k = 10000 ** (-2k / d_model)``. Both even and
    odd ``d_model`` are supported.
    """
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 500): # Increased default max_len
        """
        Args:
            d_model: Embedding dimension.
            dropout: Dropout probability applied after adding the encoding.
            max_len: Longest sequence length that can be encoded.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1) # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) # [ceil(d_model/2)]
        angles = position * div_term # [max_len, ceil(d_model/2)]

        # Buffer keeps the [max_len, 1, d_model] layout so existing checkpoints
        # (state_dict key 'pe') remain loadable.
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd index than even index, so the
        # cosine table is trimmed to the number of odd slots.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim] (batch_first layout)
        Returns:
            Tensor of the same shape with the positional encoding added and
            dropout applied.
        Raises:
            ValueError: If seq_len exceeds the max_len given at construction.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(f"Input sequence length ({seq_len}) exceeds PositionalEncoding max_len ({self.pe.size(0)})")

        # pe[:seq_len] is [seq_len, 1, d_model]; swapping the first two dims
        # gives [1, seq_len, d_model], which broadcasts over the batch dim.
        return self.dropout(x + self.pe[:seq_len].transpose(0, 1))

class TimeSeriesTransformerEncoder(nn.Module):
    """
    Transformer encoder that maps a batch of sequences to one vector per sample.

    Pipeline: linear feature embedding -> optional CLS token -> positional
    encoding -> ``nn.TransformerEncoder`` -> pooling over time according to
    ``aggregation_method`` ('mean', 'last' or 'cls').
    """
    def __init__(self, input_dim: int, config: Dict[str, Any], max_seq_len_data: int):
        """
        Args:
            input_dim: Number of features per time step.
            config: Encoder settings; recognised keys are 'embed_dim',
                'aggregation_method', 'dropout', 'num_heads',
                'dim_feedforward_factor', 'activation', 'norm_first' and
                'num_layers' (each has a default).
            max_seq_len_data: Longest input sequence length, excluding any CLS token.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.get('embed_dim', 64)
        self.aggregation_method = config.get('aggregation_method', 'mean')

        with_cls_token = self.aggregation_method == 'cls'
        dropout_p = config.get('dropout', 0.1)
        pre_norm = config.get('norm_first', False)

        # Per-time-step projection into the model dimension.
        self.feature_embedding = nn.Linear(input_dim, self.embed_dim)

        # The positional table needs one extra slot when a CLS token is prepended.
        self.pos_encoder = PositionalEncoding(
            d_model=self.embed_dim,
            dropout=dropout_p,
            max_len=max_seq_len_data + 1 if with_cls_token else max_seq_len_data
        )

        # Learnable CLS token, only present for 'cls' pooling.
        if with_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
            nn.init.normal_(self.cls_token, std=0.02)

        # Stack of batch-first encoder layers; a final LayerNorm is added only
        # in the pre-norm configuration.
        layer_template = nn.TransformerEncoderLayer(
            d_model=self.embed_dim,
            nhead=config.get('num_heads', 4),
            dim_feedforward=self.embed_dim * config.get('dim_feedforward_factor', 4),
            dropout=dropout_p,
            activation=config.get('activation', 'relu'),
            batch_first=True,
            norm_first=pre_norm
        )
        final_norm = nn.LayerNorm(self.embed_dim) if pre_norm else None
        self.transformer_encoder = nn.TransformerEncoder(
            layer_template,
            num_layers=config.get('num_layers', 2),
            norm=final_norm
        )

    def forward(self, x: torch.Tensor, src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: Input of shape [batch_size, seq_len, input_dim].
            src_key_padding_mask: Optional bool mask [batch_size, seq_len],
                True at padded positions.
        Returns:
            Pooled representation of shape [batch_size, embed_dim].
        Raises:
            ValueError: If ``aggregation_method`` is not 'mean', 'last' or 'cls'.
        """
        n_batch, _, _ = x.shape
        tokens = self.feature_embedding(x)

        if self.aggregation_method == 'cls':
            # Prepend the CLS token and mark it as a real (non-padded) position.
            tokens = torch.cat((self.cls_token.expand(n_batch, -1, -1), tokens), dim=1)
            if src_key_padding_mask is not None:
                cls_not_padded = torch.zeros(n_batch, 1, dtype=torch.bool, device=x.device)
                src_key_padding_mask = torch.cat((cls_not_padded, src_key_padding_mask), dim=1)

        # PositionalEncoding.forward expects [B, S, E]
        z_seq = self.transformer_encoder(
            self.pos_encoder(tokens),
            src_key_padding_mask=src_key_padding_mask
        )

        pooling = self.aggregation_method
        if pooling == 'cls':
            return z_seq[:, 0, :] # CLS token sits at position 0

        if pooling == 'mean':
            if src_key_padding_mask is None:
                return z_seq.mean(dim=1)
            # Average over non-padded steps only; clamp guards against an all-padded row.
            keep = (~src_key_padding_mask).float().unsqueeze(-1)
            n_valid = torch.clamp(keep.sum(dim=1), min=EPSILON)
            return (z_seq * keep).sum(dim=1) / n_valid

        if pooling == 'last':
            if src_key_padding_mask is None:
                # No padding information: take the final position of the sequence.
                return z_seq[:, tokens.size(1) - 1, :]
            # Index of the last non-padded step per sample (never below 0).
            last_valid = torch.clamp((~src_key_padding_mask).sum(dim=1) - 1, min=0)
            return z_seq[torch.arange(n_batch, device=x.device), last_valid]

        raise ValueError(f"Unsupported aggregation_method: {self.aggregation_method}")

# --- Example Usage (can be commented out after testing) ---
# Smoke test for the encoder: builds a 'mean' and a 'cls' encoder from
# ENCODER_CONFIG, runs one forward pass each on random data and checks the
# output shape. Failures are printed with a traceback rather than raised.
if __name__ == '__main__':
    print("\n--- Testing Encoder Components (Corrected) ---")
    # Use configs from Cell 2
    test_batch_size = 4
    test_seq_length_data = CGD_MODEL_CONFIG['max_seq_length'] # 187 (original data seq length)
    test_input_dim = CGD_MODEL_CONFIG['input_dim']     # 1

    dummy_x = torch.randn(test_batch_size, test_seq_length_data, test_input_dim).to(DEVICE)
    dummy_padding_mask = torch.zeros(test_batch_size, test_seq_length_data, dtype=torch.bool).to(DEVICE)
    # Mark the last 10 steps of the second sample as padding so the masked
    # code paths are exercised.
    if test_batch_size > 1 and test_seq_length_data > 10:
        dummy_padding_mask[1, -10:] = True

    print(f"Dummy input shape: {dummy_x.shape}")
    print(f"Dummy padding mask shape: {dummy_padding_mask.shape}")

    # Test TimeSeriesTransformerEncoder with 'mean'
    try:
        # Copy the shared config so the override does not leak into ENCODER_CONFIG.
        mean_encoder_config = ENCODER_CONFIG.copy()
        mean_encoder_config['aggregation_method'] = 'mean'
        encoder_mean = TimeSeriesTransformerEncoder(
            input_dim=test_input_dim,
            config=mean_encoder_config,
            max_seq_len_data=test_seq_length_data # Pass original data max length
        ).to(DEVICE)
        encoder_mean.eval()
        with torch.no_grad():
            output_z_mean = encoder_mean(dummy_x, src_key_padding_mask=dummy_padding_mask)
        print(f"Encoder output shape (aggregation='mean'): {output_z_mean.shape}")
        assert output_z_mean.shape == (test_batch_size, ENCODER_CONFIG.get('embed_dim'))
        print("Encoder (mean aggregation) tested successfully.")

        # Same check with CLS-token aggregation
        cls_encoder_config = ENCODER_CONFIG.copy()
        cls_encoder_config['aggregation_method'] = 'cls'
        encoder_cls = TimeSeriesTransformerEncoder(
            input_dim=test_input_dim,
            config=cls_encoder_config,
            max_seq_len_data=test_seq_length_data # Pass original data max length
        ).to(DEVICE)
        encoder_cls.eval()
        with torch.no_grad():
            output_z_cls = encoder_cls(dummy_x, src_key_padding_mask=dummy_padding_mask)
        print(f"Encoder output shape (aggregation='cls'): {output_z_cls.shape}")
        assert output_z_cls.shape == (test_batch_size, cls_encoder_config.get('embed_dim'))
        print("Encoder (cls aggregation) tested successfully.")

    except Exception as e:
        # Report the failure but let the rest of the cell finish.
        print(f"Error during encoder component testing: {e}")
        import traceback
        traceback.print_exc()

print("\nCell 4 (Corrected): Core Model Components - Encoder executed successfully.")

# Cell 5: Core Model Components - Predictor

In [None]:
# --- Cell 5: Core Model Components - Predictor ---

class CGDPredictor(nn.Module):
    """
    MLP head that turns a latent vector into task predictions.

    The network is a stack of ``Linear -> BatchNorm1d -> ReLU -> Dropout``
    stages (one per entry of ``hidden_dims``), followed by a final ``Linear``
    projection and an optional output activation.

    Keys read from ``config`` (normally PREDICTOR_CONFIG):
        hidden_dims: widths of the hidden stages. Defaults to
            ``[max(latent_dim // 2, 16), max(latent_dim // 4, 8)]``.
        dropout: dropout probability used after every hidden stage (default 0.1).
        activation: name of the output activation, case-insensitive (default
            'linear'). 'sigmoid' is only honoured when ``output_dim == 1`` and
            'softmax' only when ``output_dim > 1``. 'linear' / 'none' keep the
            raw logits, which is what CrossEntropyLoss and BCEWithLogitsLoss
            expect. Anything else prints a warning and keeps the raw logits.
    """
    def __init__(self, latent_dim: int, output_dim: int, config: Dict[str, Any]):
        super().__init__()
        self.config = config

        fallback_widths = [max(latent_dim // 2, 16), max(latent_dim // 4, 8)]
        stage_widths = config.get('hidden_dims', fallback_widths)
        drop_prob = config.get('dropout', 0.1)
        head_activation = config.get('activation', 'linear').lower()

        # Hidden stages: each one consumes the width produced by the previous stage.
        modules = []
        in_features = latent_dim
        for out_features in stage_widths:
            modules += [
                nn.Linear(in_features, out_features),
                nn.BatchNorm1d(out_features),  # stabilises MLP training
                nn.ReLU(),
                nn.Dropout(drop_prob),
            ]
            in_features = out_features

        # Projection onto the requested number of outputs.
        modules.append(nn.Linear(in_features, output_dim))

        # Optional output non-linearity; only combinations that make sense are accepted.
        use_sigmoid = head_activation == 'sigmoid' and output_dim == 1
        use_softmax = head_activation == 'softmax' and output_dim > 1
        if use_sigmoid:
            modules.append(nn.Sigmoid())
        elif use_softmax:
            modules.append(nn.Softmax(dim=-1))
        elif head_activation not in ('linear', 'none'):
            print(f"Warning: Unsupported final_activation '{head_activation}' for predictor, defaulting to linear.")

        self.predictor_mlp = nn.Sequential(*modules)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Run the MLP on a batch of latent vectors.

        Args:
            z: Latent representation [batch_size, latent_dim]
        Returns:
            Logits, or probabilities when a sigmoid/softmax head was configured,
            of shape [batch_size, output_dim].
        """
        y_pred = self.predictor_mlp(z)
        return y_pred

# --- Example Usage (can be commented out after testing) ---
# Smoke test for CGDPredictor: builds a multiclass head and a single-output head on random
# latent vectors and checks only the output shapes.
# NOTE: inside a notebook kernel `__name__` is '__main__', so this guard passes and the test runs
# every time the cell is executed. All names assigned below become notebook globals.
if __name__ == '__main__':
    print("\n--- Testing Predictor Component ---")
    # Use configs from Cell 2
    test_batch_size = 4
    test_latent_dim = CGD_MODEL_CONFIG['latent_dim'] # 64

    # Test for MIT-BIH like (multiclass)
    test_output_dim_mitbih = 5
    # .copy() so the override below does not modify the shared PREDICTOR_CONFIG dict
    mitbih_predictor_config = PREDICTOR_CONFIG.copy()
    mitbih_predictor_config['activation'] = 'linear' # For CrossEntropyLoss

    dummy_z_mitbih = torch.randn(test_batch_size, test_latent_dim).to(DEVICE)

    try:
        predictor_mitbih = CGDPredictor(
            latent_dim=test_latent_dim,
            output_dim=test_output_dim_mitbih,
            config=mitbih_predictor_config
        ).to(DEVICE)
        # eval() puts the BatchNorm1d / Dropout layers of the predictor into inference mode
        predictor_mitbih.eval()
        with torch.no_grad():
            output_y_mitbih = predictor_mitbih(dummy_z_mitbih)
        print(f"Predictor output shape (MIT-BIH, 5 classes, linear activation): {output_y_mitbih.shape}")
        assert output_y_mitbih.shape == (test_batch_size, test_output_dim_mitbih)

        # Test for PTB like (binary)
        test_output_dim_ptb = 1
        ptb_predictor_config = PREDICTOR_CONFIG.copy()
        ptb_predictor_config['activation'] = 'linear' # For BCEWithLogitsLoss

        dummy_z_ptb = torch.randn(test_batch_size, test_latent_dim).to(DEVICE)
        predictor_ptb = CGDPredictor(
            latent_dim=test_latent_dim,
            output_dim=test_output_dim_ptb,
            config=ptb_predictor_config
        ).to(DEVICE)
        predictor_ptb.eval()
        with torch.no_grad():
            output_y_ptb = predictor_ptb(dummy_z_ptb)
        print(f"Predictor output shape (PTB, 1 class, linear activation): {output_y_ptb.shape}")
        assert output_y_ptb.shape == (test_batch_size, test_output_dim_ptb)

        print("Predictor component tested successfully.")

    # Failures (including the shape assertions above) are reported with a traceback but not
    # re-raised, so the cell keeps running and the final print below is still reached.
    except Exception as e:
        print(f"Error during predictor component testing: {e}")
        import traceback
        traceback.print_exc()

# Printed unconditionally, i.e. also when the smoke test above reported an error.
print("\nCell 5: Core Model Components - Predictor executed successfully.")

In [None]:
# --- Cell 6: Perturbation Engine ---

class PerturbationGenerator:
    """
    Builds M randomly perturbed copies of every input time series.

    Keys read from ``config`` (normally PERTURBATION_CONFIG):
        num_perturbations (default 10): number of perturbed copies M per sample.
        active_types (str or list of str, default ['gaussian_noise']): perturbations
            applied one after another, in the listed order. Supported names are
            'gaussian_noise', 'feature_dropout', 'temporal_swap' and 'input_masking'.
        gaussian_noise_level (default 0.1): std-dev multiplier of the additive noise.
        feature_dropout_rate (default 0.1): probability of zeroing a single value.
        temporal_swap_rate (default 0.05): probability of swapping steps t and t+1.
        input_masking_rate (default 0.1): approximate fraction of the sequence to zero.
        input_masking_chunk_size (default 10): length of each zeroed segment.
    """
    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.num_perturbations = config.get('num_perturbations', 10) # M

    def _get_active_perturbations(self) -> List[str]:
        """Returns the configured perturbation types as a list (a bare string is wrapped)."""
        active = self.config.get('active_types', ['gaussian_noise'])
        if not isinstance(active, list):
            active = [active]
        return active

    def _apply_gaussian_noise(self, X_expanded: torch.Tensor) -> torch.Tensor:
        """Adds independent Gaussian noise to every element of every perturbed copy."""
        # X_expanded shape: [batch_size, num_perturbations, seq_length, input_dim]
        noise_level = self.config.get('gaussian_noise_level', 0.1)
        noise = torch.randn_like(X_expanded) * noise_level
        return X_expanded + noise

    def _apply_feature_dropout(self, X_expanded: torch.Tensor) -> torch.Tensor:
        """Zeroes individual values at random and rescales the survivors (inverted dropout)."""
        # X_expanded shape: [batch_size, num_perturbations, seq_length, input_dim]
        dropout_rate = self.config.get('feature_dropout_rate', 0.1)
        # One independent keep/drop decision per element, so every copy and every
        # sample gets its own mask of shape [B, M, S, D].
        dropout_mask = (torch.rand_like(X_expanded) > dropout_rate).float()
        # EPSILON (module-level constant) keeps the division finite when dropout_rate == 1.
        return (X_expanded * dropout_mask) / (1.0 - dropout_rate + EPSILON)

    def _apply_temporal_swap(self, X_expanded: torch.Tensor,
                             padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Randomly swaps adjacent time steps in each perturbed copy.

        Swap decisions are drawn independently for every (sample, copy, t) and are
        applied in increasing order of t, so consecutive swaps can carry a value
        several steps forward. The work is vectorised over the batch and
        perturbation dimensions; only the time axis is walked in Python.

        Args:
            X_expanded: [batch_size, num_perturbations, seq_length, input_dim]
            padding_mask: optional [batch_size, seq_length] tensor, True on padded steps.
                When given, a pair (t, t+1) is only swapped if both steps are real data.
                Otherwise the pad value would be copied into the last valid step and
                the real sample it replaced would be lost.
        Returns:
            A new tensor with the same shape as X_expanded; the input is not modified.
        """
        swap_rate = self.config.get('temporal_swap_rate', 0.05)
        X_perturbed = X_expanded.clone()
        batch_size, M, seq_length, _ = X_perturbed.shape
        if seq_length < 2:
            return X_perturbed  # nothing to swap

        # All swap decisions at once: [B, M, S-1], entry t decides the pair (t, t+1).
        swap_flags = torch.rand(batch_size, M, seq_length - 1, device=X_perturbed.device) < swap_rate
        if padding_mask is not None:
            is_valid = ~padding_mask.bool()                      # [B, S]
            pair_is_valid = is_valid[:, :-1] & is_valid[:, 1:]   # [B, S-1]
            swap_flags = swap_flags & pair_is_valid.unsqueeze(1)

        for t in range(seq_length - 1):
            do_swap = swap_flags[:, :, t].unsqueeze(-1)          # [B, M, 1], broadcasts over D
            earlier = X_perturbed[:, :, t, :]
            later = X_perturbed[:, :, t + 1, :]
            # Both results are materialised before either slice is overwritten.
            new_earlier = torch.where(do_swap, later, earlier)
            new_later = torch.where(do_swap, earlier, later)
            X_perturbed[:, :, t, :] = new_earlier
            X_perturbed[:, :, t + 1, :] = new_later
        return X_perturbed

    def _apply_input_masking(self, X_expanded: torch.Tensor) -> torch.Tensor:
        """
        Zeros out random contiguous segments of the input time series.

        ``ceil(seq_length * input_masking_rate / input_masking_chunk_size)`` segments are
        placed independently (they may overlap) for every sample and every copy. If the
        sequence is not longer than the chunk size, the whole sequence is zeroed. A
        non-positive chunk size or segment count masks nothing.

        Args:
            X_expanded: [batch_size, num_perturbations, seq_length, input_dim]
        Returns:
            A new tensor with the same shape as X_expanded; the input is not modified.
        """
        masking_rate = self.config.get('input_masking_rate', 0.1) # Proportion of total sequence to mask
        chunk_size = self.config.get('input_masking_chunk_size', 10)
        batch_size, M, seq_length, _ = X_expanded.shape

        if chunk_size <= 0:
            # Avoids the division by zero below; there is no segment to place.
            return X_expanded.clone()
        num_chunks_to_mask = math.ceil((seq_length * masking_rate) / chunk_size)
        if num_chunks_to_mask <= 0:
            return X_expanded.clone()

        actual_chunk_size = min(chunk_size, seq_length)
        last_start = seq_length - actual_chunk_size  # 0 when the chunk covers the whole sequence
        device = X_expanded.device
        start_idx = torch.randint(0, last_start + 1, (batch_size, M, num_chunks_to_mask), device=device)

        # offset[b, m, c, s] = s - start of chunk c; a step is masked if it lies in any chunk.
        time_steps = torch.arange(seq_length, device=device).view(1, 1, 1, seq_length)
        offset = time_steps - start_idx.unsqueeze(-1)
        masked_steps = ((offset >= 0) & (offset < actual_chunk_size)).any(dim=2)  # [B, M, S]

        return X_expanded.masked_fill(masked_steps.unsqueeze(-1), 0.0)

    def generate(self, X: torch.Tensor, padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Generates perturbed versions of input X.

        The active perturbation types are applied in sequence. After each one, the padded
        positions are reset to the values they had before that perturbation, so padded
        positions of the returned tensor always equal the corresponding entries of X.

        Args:
            X: Input tensor [batch_size, seq_length, input_dim]
            padding_mask: Boolean tensor [batch_size, seq_length] (True for padded elements).
        Returns:
            X_perturbed_all: Tensor [batch_size, num_perturbations, seq_length, input_dim].
            With ``num_perturbations == 0`` the unperturbed input is returned with a
            singleton perturbation axis, [batch_size, 1, seq_length, input_dim].
        """
        if self.num_perturbations == 0:
            return X.unsqueeze(1) # Return original if no perturbations requested

        # [B, S, D] -> [B, M, S, D]; clone so the result owns its memory instead of
        # aliasing X through the stride-0 expanded view.
        X_current_perturbations = X.unsqueeze(1).expand(-1, self.num_perturbations, -1, -1).clone()

        # [B, 1, S, 1] boolean mask, True on real (non-padded) time steps.
        valid_positions = None
        if padding_mask is not None:
            valid_positions = (~padding_mask.bool()).unsqueeze(1).unsqueeze(-1)

        for pert_type in self._get_active_perturbations():
            # Every helper returns a new tensor, so a plain reference is enough to
            # remember the pre-perturbation values.
            original_values = X_current_perturbations
            if pert_type == 'gaussian_noise':
                X_current_perturbations = self._apply_gaussian_noise(X_current_perturbations)
            elif pert_type == 'feature_dropout':
                X_current_perturbations = self._apply_feature_dropout(X_current_perturbations)
            elif pert_type == 'temporal_swap':
                X_current_perturbations = self._apply_temporal_swap(X_current_perturbations, padding_mask)
            elif pert_type == 'input_masking':
                X_current_perturbations = self._apply_input_masking(X_current_perturbations)
            # Add more elif for other perturbation types here
            else:
                print(f"Warning: Unknown perturbation type '{pert_type}' in active_types. Skipping.")
                continue

            # Keep the perturbation on real data and restore the padded positions.
            # torch.where selects values instead of multiplying by a 0/1 float mask, so
            # NaN/Inf in the discarded branch cannot leak and the dtype of X is preserved.
            if valid_positions is not None:
                X_current_perturbations = torch.where(
                    valid_positions, X_current_perturbations, original_values
                )

        return X_current_perturbations

# --- Example Usage (can be commented out after testing) ---
# Smoke test for PerturbationGenerator: perturbs a small dummy batch, checks the output shape and
# verifies that padded time steps come back unchanged.
# NOTE: inside a notebook kernel `__name__` is '__main__', so this test runs whenever the cell runs.
if __name__ == '__main__':
    print("\n--- Testing Perturbation Engine ---")
    # Use configs from Cell 2
    test_batch_size = 2
    test_seq_length = CGD_MODEL_CONFIG['max_seq_length'] # 187
    test_input_dim = CGD_MODEL_CONFIG['input_dim']     # 1

    dummy_X_batch = torch.ones(test_batch_size, test_seq_length, test_input_dim).to(DEVICE) # Use ones for easy visual check of masking
    dummy_X_batch[0, :10, :] = 5 # Make some parts different

    # Create a dummy padding mask: last 5 elements of the first sample are padded
    dummy_padding_mask_pert = torch.zeros(test_batch_size, test_seq_length, dtype=torch.bool).to(DEVICE)
    if test_seq_length > 5:
        dummy_padding_mask_pert[0, -5:] = True

    print(f"Original X_batch[0, -10:] before perturbation:\n{dummy_X_batch[0, -10:].squeeze().cpu().numpy()}")
    if dummy_padding_mask_pert is not None:
        print(f"Padding mask for X_batch[0, -10:]: {dummy_padding_mask_pert[0, -10:].cpu().numpy()}")


    perturb_config_test = {
        'active_types': ['gaussian_noise', 'input_masking', 'temporal_swap'],
        'gaussian_noise_level': 0.1,
        'input_masking_rate': 0.05, # Mask ~5%
        'input_masking_chunk_size': 5,
        'temporal_swap_rate': 0.1,
        'num_perturbations': 3
    }
    generator = PerturbationGenerator(config=perturb_config_test)

    try:
        X_perturbed_batch = generator.generate(dummy_X_batch, padding_mask=dummy_padding_mask_pert)
        print(f"\nPerturbed X_batch shape: {X_perturbed_batch.shape}")
        assert X_perturbed_batch.shape == (test_batch_size, perturb_config_test['num_perturbations'], test_seq_length, test_input_dim)

        # generate() resets the padded positions to their pre-perturbation values after EVERY
        # perturbation type (noise, masking and swap alike). The padded tail of sample 0 must
        # therefore equal the original input (1.0) in all perturbed copies.
        print(f"Perturbed X_batch[0, 0, -10:] after all perturbations:\n{X_perturbed_batch[0, 0, -10:].squeeze().cpu().numpy()}")

        if dummy_padding_mask_pert is not None and test_seq_length > 5:
            original_padded_values = dummy_X_batch[0, -5:].squeeze()
            perturbed_padded_values = X_perturbed_batch[0, 0, -5:].squeeze() # Check first perturbation instance

            print(f"Original padded values (sample 0, last 5): {original_padded_values.cpu().numpy()}")
            print(f"Perturbed padded values (sample 0, pert 0, last 5): {perturbed_padded_values.cpu().numpy()}")

            # Check every perturbed copy, not just the one printed above.
            padded_tail_all_copies = X_perturbed_batch[0, :, -5:, :]
            padded_tail_expected = dummy_X_batch[0, -5:, :].unsqueeze(0).expand_as(padded_tail_all_copies)
            assert torch.allclose(padded_tail_all_copies, padded_tail_expected), \
                "Padded positions were modified by the perturbation pipeline."

        print("Perturbation Engine tested (check output shapes and example values).")

    # Failures (including the assertions above) are reported with a traceback but not re-raised,
    # so the final print below is still reached.
    except Exception as e:
        print(f"Error during Perturbation Engine testing: {e}")
        import traceback
        traceback.print_exc()

print("\nCell 6: Perturbation Engine executed successfully.")

# Cell 7: Geometric Sensitivity Calculation Module.

In [None]:
# --- Cell 7: Geometric Sensitivity Calculation Module ---

class GeometricSensitivityCalculator:
    """
    Scores how far latent embeddings move when their inputs are perturbed.

    Keys read from ``config`` (normally SENSITIVITY_CONFIG):
        active_measures (str or list of str, default ['mean_displacement']):
            any of 'max_displacement', 'mean_displacement',
            'variance_of_displacements', 'log_covariance_volume'.
        aggregation_method_for_multiple_sensitivities ('mean' or 'max',
            default 'mean'): how several measures are merged into one score.
    """
    def __init__(self, config: Dict[str, Any]):
        self.config = config

    def _get_active_measures(self) -> List[str]:
        """Configured measure names as a list (a bare string is wrapped)."""
        measures = self.config.get('active_measures', ['mean_displacement'])
        return measures if isinstance(measures, list) else [measures]

    def _compute_displacements(self, z: torch.Tensor, z_perturbed: torch.Tensor) -> torch.Tensor:
        """
        Euclidean distance from each original embedding to each of its perturbed versions.

        Args:
            z: Original latent embeddings [batch_size, latent_dim]
            z_perturbed: Perturbed latent embeddings [batch_size, num_perturbations, latent_dim]
        Returns:
            Distances [batch_size, num_perturbations]. Squared distances are clamped
            to at least EPSILON before the square root, for numerical stability.
        """
        # [B, L] -> [B, 1, L] so the subtraction broadcasts over the M perturbations.
        offsets = z_perturbed - z.unsqueeze(1)                 # [B, M, L]
        squared_norms = torch.sum(offsets**2, dim=2)           # [B, M]
        return torch.sqrt(torch.clamp(squared_norms, min=EPSILON))

    def _calculate_max_displacement(self, distances: torch.Tensor) -> torch.Tensor:
        """Largest displacement per sample; distances is [batch_size, num_perturbations]."""
        return torch.max(distances, dim=1).values

    def _calculate_mean_displacement(self, distances: torch.Tensor) -> torch.Tensor:
        """Average displacement per sample; distances is [batch_size, num_perturbations]."""
        return torch.mean(distances, dim=1)

    def _calculate_variance_of_displacements(self, distances: torch.Tensor) -> torch.Tensor:
        """Unbiased variance of the displacements per sample (zero when M < 2)."""
        # An unbiased variance needs at least two observations.
        if distances.size(1) >= 2:
            return torch.var(distances, dim=1, unbiased=True)
        return torch.zeros_like(distances[:, 0])

    def _calculate_log_covariance_volume(self, z: torch.Tensor, z_perturbed: torch.Tensor) -> torch.Tensor:
        """
        Log-determinant of the scatter of the perturbed embeddings around the original z.

        Args:
            z: Original latent embeddings [batch_size, latent_dim]
            z_perturbed: Perturbed latent embeddings [batch_size, num_perturbations, latent_dim]
        Returns:
            log_det: [batch_size]. A constant -100.0 is returned when fewer than two
            perturbations are available; NaN / -inf results are mapped to -100.0
            and +inf to 0.0.
        """
        batch_size, M, latent_dim = z_perturbed.shape

        if M < 2:
            return torch.full((batch_size,), -100.0, device=z.device, dtype=z.dtype)
        # NOTE: for M <= latent_dim the matrix below is rank-deficient before the
        # EPSILON ridge is added, so very negative values are expected in that regime.

        # Deviations from the original embedding (not from the perturbation mean).
        deviations = z_perturbed - z.unsqueeze(1)                              # [B, M, L]

        # Per-sample [L, L] matrix: D^T D / (M - 1); M >= 2 here so the divisor is >= 1.
        scatter = torch.bmm(deviations.transpose(1, 2), deviations) / (M - 1)

        # Small ridge on the diagonal to keep the determinant away from exactly zero.
        ridge = torch.eye(latent_dim, device=z.device, dtype=z.dtype) * EPSILON
        regularised = scatter + ridge.unsqueeze(0)

        # Only the log-magnitude is used; the sign returned by slogdet is ignored.
        log_abs_det = torch.linalg.slogdet(regularised)[1]
        return torch.nan_to_num(log_abs_det, nan=-100.0, posinf=0.0, neginf=-100.0)

    def calculate_sensitivities(self, z: torch.Tensor, z_perturbed: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Evaluates every active sensitivity measure.

        Args:
            z: Original latent embeddings [batch_size, latent_dim]
            z_perturbed: Perturbed latent embeddings [batch_size, num_perturbations, latent_dim]
        Returns:
            Dict mapping measure name -> [batch_size] tensor, in the order the measures
            are configured. Unknown names are reported with a printed warning and skipped.
        """
        from_distances = {
            'max_displacement': self._calculate_max_displacement,
            'mean_displacement': self._calculate_mean_displacement,
            'variance_of_displacements': self._calculate_variance_of_displacements,
        }
        requested = self._get_active_measures()

        # The pairwise distances are shared by all displacement-based measures, so
        # they are computed once, and only if at least one such measure is requested.
        distances = None
        if any(name in from_distances for name in requested):
            distances = self._compute_displacements(z, z_perturbed)

        results = {}
        for measure_type in requested:
            if measure_type in from_distances:
                results[measure_type] = from_distances[measure_type](distances)
            elif measure_type == 'log_covariance_volume':
                results[measure_type] = self._calculate_log_covariance_volume(z, z_perturbed)
            else:
                print(f"Warning: Unknown sensitivity measure type '{measure_type}'. Skipping.")
        return results

    def aggregate_sensitivities(self, sensitivity_scores_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Collapses several per-sample sensitivity scores into one score per sample.

        Args:
            sensitivity_scores_dict: measure name -> [batch_size] tensor.
        Returns:
            [batch_size] tensor: element-wise mean (default) or max across the measures.
            An unrecognised aggregation method prints a warning and falls back to the mean.
        Raises:
            ValueError: if the dictionary is empty.
        """
        if not sensitivity_scores_dict:
            raise ValueError("sensitivity_scores_dict is empty. Cannot aggregate.")

        # One row per measure: [Num_Measures, Batch_Size]
        per_measure = torch.stack(list(sensitivity_scores_dict.values()), dim=0)

        agg_method = self.config.get('aggregation_method_for_multiple_sensitivities', 'mean')
        if agg_method == 'max':
            return torch.max(per_measure, dim=0).values
        if agg_method != 'mean':
            print(f"Warning: Unknown sensitivity aggregation method '{agg_method}'. Using mean.")
        return torch.mean(per_measure, dim=0)

# --- Example Usage (can be commented out after testing) ---
# Smoke test for GeometricSensitivityCalculator: computes all four measures on random embeddings,
# aggregates them, and checks that every result has one value per sample.
# NOTE: inside a notebook kernel `__name__` is '__main__', so this test runs whenever the cell runs.
if __name__ == '__main__':
    print("\n--- Testing Geometric Sensitivity Calculator ---")
    # Use configs from Cell 2
    test_batch_size = 4
    test_latent_dim = CGD_MODEL_CONFIG['latent_dim'] # 64
    test_num_perturbations = PERTURBATION_CONFIG['num_perturbations'] # 10

    dummy_z = torch.randn(test_batch_size, test_latent_dim).to(DEVICE)
    dummy_z_perturbed = torch.randn(test_batch_size, test_num_perturbations, test_latent_dim).to(DEVICE)
    # Add some structure: make one perturbation far for one sample
    dummy_z_perturbed[0, 0, :] += 5.0
    # Make another sample have very little spread for its perturbations
    dummy_z_perturbed[1, :, :] = dummy_z[1].unsqueeze(0) + torch.randn(test_num_perturbations, test_latent_dim).to(DEVICE) * 0.01


    # Local config (not SENSITIVITY_CONFIG) so that all four measures are exercised.
    sensitivity_config_test = {
        'active_measures': ['max_displacement', 'mean_displacement', 'variance_of_displacements', 'log_covariance_volume'],
        'aggregation_method_for_multiple_sensitivities': 'mean',
    }
    calculator = GeometricSensitivityCalculator(config=sensitivity_config_test)

    try:
        all_scores_dict = calculator.calculate_sensitivities(dummy_z, dummy_z_perturbed)
        print("\nCalculated raw sensitivity scores (dict):")
        for k, v in all_scores_dict.items():
            # Only the first two samples are printed to keep the output short.
            print(f"  {k}: shape={v.shape}, example_values={v.detach().cpu().numpy()[:2]}")
            assert v.shape == (test_batch_size,)

        final_sensitivity_score = calculator.aggregate_sensitivities(all_scores_dict)
        print(f"\nFinal aggregated sensitivity score (method='{sensitivity_config_test['aggregation_method_for_multiple_sensitivities']}'):")
        print(f"  Shape: {final_sensitivity_score.shape}, example_values={final_sensitivity_score.detach().cpu().numpy()[:2]}")
        assert final_sensitivity_score.shape == (test_batch_size,)

        # Exercise log_covariance_volume with fewer perturbations than latent dimensions
        # (the rank-deficient case). The matching warning prints inside the calculator are
        # currently commented out, so nothing extra is printed here.
        print("\nTesting log_covariance_volume with M < L_dim:")
        # M = min(5, latent_dim - 1), or 1 when latent_dim <= 1. With M == 1 the calculator takes
        # its M < 2 shortcut and returns a constant -100.0.
        few_pert_config = {'num_perturbations': min(5, test_latent_dim -1) if test_latent_dim > 1 else 1}
        # For an integer latent_dim the value above is always >= 1, so the else branch is not
        # expected to run.
        if few_pert_config['num_perturbations'] > 0:
             dummy_z_perturbed_few = torch.randn(test_batch_size, few_pert_config['num_perturbations'], test_latent_dim).to(DEVICE)
             log_cov_vol_few = calculator._calculate_log_covariance_volume(dummy_z, dummy_z_perturbed_few)
             print(f"  log_cov_vol with M={few_pert_config['num_perturbations']}: {log_cov_vol_few.detach().cpu().numpy()[:2]}")
        else:
             print("  Skipping M < L_dim test as latent_dim is too small or num_perturbations is 0.")


        print("\nGeometric Sensitivity Calculator tested successfully.")

    # Failures (including the shape assertions above) are reported with a traceback but not
    # re-raised, so the final print below is still reached.
    except Exception as e:
        print(f"Error during Geometric Sensitivity Calculator testing: {e}")
        import traceback
        traceback.print_exc()

# Printed unconditionally, i.e. also when the smoke test above reported an error.
print("\nCell 7: Geometric Sensitivity Calculation Module executed successfully.")

# Cell 8: Structural Regularizer.

In [None]:
# --- Cell 8: Structural Regularizer ---

class StructuralRegularizer(nn.Module):
    """
    Structural regularization loss on the latent space.

    Keys read from ``config`` (normally STRUCTURAL_REGULARIZER_CONFIG):
        regularization_type ('contrastive' | 'prototype' | 'none', default
            'contrastive', case-insensitive): which loss ``forward`` computes.
        temperature (default 0.1): softmax temperature of the InfoNCE loss.
        projection_hidden_dim (default latent_dim) and projection_output_dim
            (default max(latent_dim // 2, 16)): sizes of the 2-layer projection
            head used by the contrastive loss.
        num_prototypes (default 10) and prototype_lambda (default 0.1): number of
            learnable prototypes and weight of the prototype diversity penalty.

    Raises:
        ValueError: from ``__init__`` if ``regularization_type`` is not one of the
            three supported values.
    """
    def __init__(self, latent_dim: int, config: Dict[str, Any]):
        super().__init__()
        self.latent_dim = latent_dim
        self.config = config
        self.regularization_type = config.get('regularization_type', 'contrastive').lower()
        self.temperature = config.get('temperature', 0.1)

        if self.regularization_type == 'contrastive':
            # Projection head (SimCLR-style): the contrastive loss is computed on the
            # projected embeddings rather than directly on the latent vectors.
            projection_hidden_dim = config.get('projection_hidden_dim', latent_dim)
            projection_output_dim = config.get('projection_output_dim', max(latent_dim // 2, 16))

            self.projection_head = nn.Sequential(
                nn.Linear(latent_dim, projection_hidden_dim),
                nn.ReLU(),
                nn.Linear(projection_hidden_dim, projection_output_dim)
            )
        elif self.regularization_type == 'prototype':
            # Basic prototype regularizer: learnable prototype vectors in latent space.
            self.num_prototypes = config.get('num_prototypes', 10)
            self.prototype_lambda = config.get('prototype_lambda', 0.1) # Weight of the diversity term
            self.prototypes = nn.Parameter(torch.randn(self.num_prototypes, latent_dim))
            nn.init.xavier_uniform_(self.prototypes)
            print("Prototype regularizer initialized (implementation is basic).")
        elif self.regularization_type != 'none':
            raise ValueError(f"Unsupported regularization_type: {self.regularization_type}")

    def _compute_cosine_similarity(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
        """Cosine similarity matrix between z1 [N, D] and z2 [M, D]; returns [N, M]."""
        z1_norm = F.normalize(z1, p=2, dim=1)
        z2_norm = F.normalize(z2, p=2, dim=1)
        return torch.mm(z1_norm, z2_norm.t())

    def _info_nce_loss(self, z_anchor: torch.Tensor, z_positive: torch.Tensor, temperature: float) -> torch.Tensor:
        """
        Symmetric InfoNCE loss over a batch of (anchor, positive) pairs.

        Every one of the 2 * batch_size embeddings is classified against all others;
        the correct class is its partner from the other view and the remaining
        embeddings act as negatives.

        Args:
            z_anchor: [batch_size, proj_dim] - e.g., projections of original samples
            z_positive: [batch_size, proj_dim] - e.g., projections of augmented/perturbed samples
            temperature: divisor applied to the cosine similarities.
        Returns:
            Scalar loss; a constant 0.0 for an empty batch.
        """
        batch_size = z_anchor.size(0)
        if batch_size == 0:
            return torch.tensor(0.0, device=z_anchor.device)

        device = z_anchor.device

        # Rows 0..B-1 are the anchors, rows B..2B-1 the positives.
        representations = torch.cat([z_anchor, z_positive], dim=0)

        # Temperature-scaled cosine similarities: [2B, 2B]
        similarity_matrix = self._compute_cosine_similarity(representations, representations) / temperature

        # Target column for each row: anchor i -> positive i (index i + B) and
        # positive i -> anchor i (index i).
        labels = torch.cat([
            torch.arange(batch_size, 2 * batch_size, device=device),
            torch.arange(batch_size, device=device),
        ])

        # A sample must not be contrasted with itself: -inf on the diagonal removes
        # the self-similarity from the softmax denominator.
        self_pairs = torch.eye(2 * batch_size, dtype=torch.bool, device=similarity_matrix.device)
        logits = similarity_matrix.masked_fill(self_pairs, float('-inf'))

        return F.cross_entropy(logits, labels)

    def contrastive_loss(self, z: torch.Tensor, z_perturbed_single: torch.Tensor) -> torch.Tensor:
        """
        Calculates contrastive loss (SimCLR-style).

        Args:
            z: Original latent embeddings [batch_size, latent_dim].
            z_perturbed_single: A single representative perturbed embedding per original sample.
                                e.g., mean of M perturbations [batch_size, latent_dim].
        Returns:
            Scalar InfoNCE loss computed with ``self.temperature``.
        """
        if hasattr(self, 'projection_head'):
            z_proj = self.projection_head(z)
            z_pert_proj = self.projection_head(z_perturbed_single)
        else:
            # Only reachable when this method is called on an instance that was not
            # configured as 'contrastive'; fall back to the raw embeddings.
            warnings.warn("Projection head not found for contrastive loss. Using raw embeddings.", UserWarning)
            z_proj = z
            z_pert_proj = z_perturbed_single

        return self._info_nce_loss(z_proj, z_pert_proj, self.temperature)

    def forward(self, z: torch.Tensor, z_perturbed: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute the structural regularization loss.

        Args:
            z: Original latent embeddings [batch_size, latent_dim]
            z_perturbed: All M perturbed latent embeddings [batch_size, num_perturbations, latent_dim]
                         OR a single representative perturbed embedding [batch_size, latent_dim].
                         Only used by the contrastive loss.
        Returns:
            Scalar loss tensor. A constant 0.0 for type 'none', and for type
            'contrastive' when ``z_perturbed`` is None.
        Raises:
            ValueError: if ``z_perturbed`` is neither 2-D nor 3-D, or if
                ``regularization_type`` is unsupported.
        """
        if self.regularization_type == 'contrastive':
            if z_perturbed is None:
                # No second view available, so there is nothing to contrast against.
                return torch.tensor(0.0, device=z.device)

            if z_perturbed.ndim == 3: # [B, M, L] -> mean over the M perturbations
                # The mean perturbed embedding serves as the single "positive" view.
                z_perturbed_representative = torch.mean(z_perturbed, dim=1) # -> [B, L]
            elif z_perturbed.ndim == 2: # Already [B, L]
                z_perturbed_representative = z_perturbed
            else:
                raise ValueError(f"z_perturbed has unexpected ndim: {z_perturbed.ndim}")

            return self.contrastive_loss(z, z_perturbed_representative)

        elif self.regularization_type == 'prototype':
            # Simplified prototype loss: pull each z towards its nearest prototype and
            # push different prototypes apart. Less developed than the contrastive path.
            z_expanded = z.unsqueeze(1) # [B, 1, L]
            prototypes_expanded = self.prototypes.unsqueeze(0) # [1, P, L]

            distances_sq = torch.sum((z_expanded - prototypes_expanded)**2, dim=2) # [B, P]
            min_distances_sq, _ = torch.min(distances_sq, dim=1) # [B], distance to closest prototype
            clustering_loss = min_distances_sq.mean()

            if self.num_prototypes > 1:
                # Mean squared cosine similarity between *different* prototypes.
                proto_norm = F.normalize(self.prototypes, dim=1)
                proto_sim_matrix = torch.mm(proto_norm, proto_norm.t()) # [P, P]
                off_diagonal = ~torch.eye(self.num_prototypes, dtype=torch.bool, device=z.device)
                diversity_loss = (proto_sim_matrix[off_diagonal]**2).mean()
            else:
                # A single prototype has no off-diagonal pairs; the mean over that empty
                # selection would be NaN and poison the whole loss.
                diversity_loss = clustering_loss.new_zeros(())

            return clustering_loss + self.prototype_lambda * diversity_loss

        elif self.regularization_type == 'none':
            return torch.tensor(0.0, device=z.device)
        else:
            raise ValueError(f"Unsupported regularization_type: {self.regularization_type}")

# --- Example Usage (can be commented out after testing) ---
# Smoke test for StructuralRegularizer: builds random embeddings and checks that each
# regularization mode returns a usable loss. In a notebook kernel __name__ is '__main__',
# so this block runs every time the cell is executed.
# Relies on CGD_MODEL_CONFIG, PERTURBATION_CONFIG, STRUCTURAL_REGULARIZER_CONFIG and DEVICE
# being defined by an earlier cell (not visible here).
if __name__ == '__main__':
    print("\n--- Testing Structural Regularizer ---")
    # Use configs from Cell 2
    test_batch_size = 16 # Contrastive loss works better with larger batches
    test_latent_dim = CGD_MODEL_CONFIG['latent_dim'] # 64
    test_num_perturbations = PERTURBATION_CONFIG['num_perturbations'] # 10

    # Random stand-ins for encoder outputs: originals [B, L] and perturbed views [B, M, L].
    dummy_z_orig = torch.randn(test_batch_size, test_latent_dim).to(DEVICE)
    # For contrastive, we usually pass the mean of perturbations, or one selected view
    dummy_z_pert_all = torch.randn(test_batch_size, test_num_perturbations, test_latent_dim).to(DEVICE)
    dummy_z_pert_mean = torch.mean(dummy_z_pert_all, dim=1)

    # Work on copies so the shared config object keeps its 'regularization_type'
    # (presumably a dict, in which case .copy() is shallow -- TODO confirm).
    reg_config_test_contrastive = STRUCTURAL_REGULARIZER_CONFIG.copy()
    reg_config_test_contrastive['regularization_type'] = 'contrastive'

    reg_config_test_proto = STRUCTURAL_REGULARIZER_CONFIG.copy()
    reg_config_test_proto['regularization_type'] = 'prototype'
    reg_config_test_proto['num_prototypes'] = 5
    reg_config_test_proto['prototype_lambda'] = 0.05


    try:
        print("\nTesting Contrastive Regularizer:")
        regularizer_contrastive = StructuralRegularizer(
            latent_dim=test_latent_dim,
            config=reg_config_test_contrastive
        ).to(DEVICE)
        regularizer_contrastive.train() # Projection head has dropout/batchnorm if any

        # Test with mean of perturbations (2-D input path)
        loss_contrastive_mean = regularizer_contrastive(dummy_z_orig, dummy_z_pert_mean)
        print(f"  Contrastive loss (with mean z_perturbed): {loss_contrastive_mean.item():.4f}")
        assert loss_contrastive_mean >= 0

        # Test with all M perturbations (3-D input path; forward() averages over M internally)
        loss_contrastive_all_m = regularizer_contrastive(dummy_z_orig, dummy_z_pert_all)
        print(f"  Contrastive loss (with all M z_perturbed): {loss_contrastive_all_m.item():.4f}")
        assert loss_contrastive_all_m >= 0

        # Test 'none' type: must return exactly zero
        reg_config_none = {'regularization_type': 'none'}
        regularizer_none = StructuralRegularizer(latent_dim=test_latent_dim, config=reg_config_none).to(DEVICE)
        loss_none = regularizer_none(dummy_z_orig, dummy_z_pert_mean)
        print(f"  'None' regularizer loss: {loss_none.item():.4f}")
        assert loss_none.item() == 0.0

        print("\nTesting Prototype Regularizer (basic implementation):")
        regularizer_proto = StructuralRegularizer(
            latent_dim=test_latent_dim,
            config=reg_config_test_proto
        ).to(DEVICE)
        regularizer_proto.train()
        loss_proto = regularizer_proto(dummy_z_orig) # Doesn't use z_perturbed in this simple version
        print(f"  Prototype loss: {loss_proto.item():.4f}")
        assert loss_proto >= 0


        print("\nStructural Regularizer tested successfully.")

    except Exception as e:
        # Failures (including failed asserts) are reported but not re-raised, so the cell
        # keeps running and the final "executed successfully" line below still prints.
        print(f"Error during Structural Regularizer testing: {e}")
        import traceback
        traceback.print_exc()

# NOTE: printed unconditionally -- it does not indicate that the smoke test above passed.
print("\nCell 8: Structural Regularizer executed successfully.")

In [None]:
# --- Cell 9: Main CGD Model Architecture (UniversalCGDModel) ---

class UniversalCGDModel(nn.Module):
    """
    Universal Causal Geometric Deferral (CGD) framework.
    Integrates encoder, predictor, perturbation, sensitivity, and regularization.
    Deferral during training is based on a fixed threshold for loss calculation.
    Evaluation can use an adaptive threshold.

    Keys read from ``model_config``: ``input_dim``, ``max_seq_length``, ``latent_dim``
    (required) and ``deferral_threshold_train``, ``loss_alpha``, ``loss_beta``,
    ``defer_cost_factor`` (optional, with defaults 0.5, 0.1, 0.1 and 0.3).
    """
    def __init__(
        self,
        model_config: Dict[str, Any],
        encoder_config: Dict[str, Any],
        predictor_config: Dict[str, Any],
        perturb_config: Dict[str, Any],
        sensitivity_config: Dict[str, Any],
        regularizer_config: Dict[str, Any],
        # output_dim must be passed explicitly based on dataset
        output_dim: int,
    ):
        """
        Build all sub-components.

        Args:
            model_config: Model-level settings (see class docstring for the keys used).
            encoder_config: Passed to ``TimeSeriesTransformerEncoder``; its ``embed_dim``
                            (if present) is taken as the encoder's output width.
            predictor_config: Passed to ``CGDPredictor``.
            perturb_config: Passed to ``PerturbationGenerator``.
            sensitivity_config: Passed to ``GeometricSensitivityCalculator``.
            regularizer_config: Passed to ``StructuralRegularizer``.
            output_dim: Number of logits; ``1`` selects the binary (BCE) code paths,
                        anything else the multiclass (cross-entropy) paths.
        """
        super().__init__()
        self.model_config = model_config
        self.output_dim_val = output_dim # Store for predictor and loss logic

        # Initialize Components
        self.encoder = TimeSeriesTransformerEncoder(
            input_dim=model_config['input_dim'],
            config=encoder_config,
            max_seq_len_data=model_config['max_seq_length']
        )

        # The encoder outputs `embed_dim` which might be different from the desired `latent_dim`
        # for the rest of the CGD components. Add a projection if necessary.
        encoder_actual_output_dim = encoder_config.get('embed_dim', model_config['latent_dim'])
        self.latent_dim = model_config['latent_dim']

        if encoder_actual_output_dim != self.latent_dim:
            self.latent_projection = nn.Linear(encoder_actual_output_dim, self.latent_dim)
            print(f"Added latent projection from encoder output {encoder_actual_output_dim} to latent_dim {self.latent_dim}")
        else:
            self.latent_projection = nn.Identity()

        self.predictor = CGDPredictor(
            latent_dim=self.latent_dim,
            output_dim=output_dim, # Passed based on dataset
            config=predictor_config
        )
        self.perturbation_generator = PerturbationGenerator(config=perturb_config)
        self.sensitivity_calculator = GeometricSensitivityCalculator(config=sensitivity_config)
        self.structural_regularizer = StructuralRegularizer(
            latent_dim=self.latent_dim, # Regularizer works on the final latent_dim
            config=regularizer_config
        )

        # Fixed threshold used by forward() to decide deferral for loss computation.
        self.deferral_threshold_train = model_config.get('deferral_threshold_train', 0.5)

    def forward(
        self,
        X_batch: torch.Tensor, # [B, S, D_in]
        padding_mask_batch: Optional[torch.Tensor] = None # [B, S]
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the UniversalCGDModel.

        Encodes the input and M perturbed copies of it, derives a per-sample sensitivity
        score from the original/perturbed embeddings, predicts logits from the original
        embedding and flags samples whose sensitivity exceeds the fixed training threshold.

        Args:
            X_batch: Input sequences [B, S, D_in].
            padding_mask_batch: Optional padding mask [B, S], forwarded to the encoder as
                                ``src_key_padding_mask`` and to the perturbation generator.
        Returns:
            Dict with ``z_original`` [B, L], ``z_perturbed_all`` [B, M, L],
            ``raw_sensitivity_scores_dict`` (dict of [B] tensors),
            ``final_sensitivity_score`` [B], ``y_pred_logits`` [B, Output_Dim] and
            ``defer_train_time`` [B] (bool).
        """
        # 1. Encode original input
        z_encoded_original = self.encoder(X_batch, src_key_padding_mask=padding_mask_batch)
        z_original = self.latent_projection(z_encoded_original) # -> [B, L]

        # 2. Generate & encode perturbations
        # X_perturbed_all: [B, M, S, D_in]
        X_perturbed_all = self.perturbation_generator.generate(X_batch, padding_mask=padding_mask_batch)

        batch_size, M, seq_length, input_dim = X_perturbed_all.shape
        X_perturbed_flat = X_perturbed_all.reshape(batch_size * M, seq_length, input_dim) # [B*M, S, D_in]

        pert_padding_mask_flat = None
        if padding_mask_batch is not None:
            # repeat_interleave keeps the M copies of each sample adjacent, matching the
            # row order produced by the reshape above.
            pert_padding_mask_flat = padding_mask_batch.repeat_interleave(M, dim=0) # [B*M, S]

        z_encoded_perturbed_flat = self.encoder(X_perturbed_flat, src_key_padding_mask=pert_padding_mask_flat)
        z_perturbed_flat = self.latent_projection(z_encoded_perturbed_flat) # -> [B*M, L]
        z_perturbed_all = z_perturbed_flat.reshape(batch_size, M, self.latent_dim) # -> [B, M, L]

        # 3. Calculate one or more raw sensitivity scores
        # raw_sensitivity_scores_dict: {'measure1': [B], 'measure2': [B], ...}
        raw_sensitivity_scores_dict = self.sensitivity_calculator.calculate_sensitivities(z_original, z_perturbed_all)

        # 4. Aggregate raw sensitivity scores if multiple are active
        if not raw_sensitivity_scores_dict: # Should not happen if config is valid
            final_sensitivity_score = torch.zeros(batch_size, device=X_batch.device)
        elif len(raw_sensitivity_scores_dict) == 1:
            final_sensitivity_score = next(iter(raw_sensitivity_scores_dict.values()))
        else:
            final_sensitivity_score = self.sensitivity_calculator.aggregate_sensitivities(raw_sensitivity_scores_dict)

        # 5. Make predictions
        y_pred_logits = self.predictor(z_original) # [B, Output_Dim]

        # 6. Deferral decision *during training* (for loss calculation purposes)
        # This uses the fixed training threshold. Evaluation will use an adaptive one.
        defer_train_time = final_sensitivity_score > self.deferral_threshold_train # [B] boolean

        return {
            'z_original': z_original,                       # [B, L]
            'z_perturbed_all': z_perturbed_all,             # [B, M, L]
            'raw_sensitivity_scores_dict': raw_sensitivity_scores_dict, # Dict of [B]
            'final_sensitivity_score': final_sensitivity_score,       # [B]
            'y_pred_logits': y_pred_logits,                 # [B, Output_Dim]
            'defer_train_time': defer_train_time            # [B] boolean
        }

    def compute_loss(
        self,
        model_output: Dict[str, torch.Tensor],
        y_true: torch.Tensor, # [B] (long for CE, float for BCE)
    ) -> Dict[str, torch.Tensor]:
        """
        Computes the multi-objective loss for the CGD model.

        ``total = pred + loss_alpha * struct + loss_beta * (defer_cost + sens_reg)``

        Args:
            model_output: Dict produced by ``forward``.
            y_true: Ground-truth labels [B].
        Returns:
            Dict of scalar tensors: ``total_loss``, ``pred_loss``, ``struct_loss``,
            ``defer_cost_loss``, ``sens_reg_loss`` and ``defer_rate_train_time``.
            For an empty batch every entry is a zero scalar on ``y_true``'s device.
        """

        z_original = model_output['z_original']
        z_perturbed_all = model_output['z_perturbed_all']
        final_sensitivity_score = model_output['final_sensitivity_score']
        y_pred_logits = model_output['y_pred_logits']
        defer_train_time = model_output['defer_train_time'] # Based on fixed training threshold

        loss_alpha = self.model_config.get('loss_alpha', 0.1)
        loss_beta = self.model_config.get('loss_beta', 0.1)
        defer_cost_factor = self.model_config.get('defer_cost_factor', 0.3)

        batch_size = y_true.size(0)
        if batch_size == 0: # Should not happen with valid dataloader
            # Use the input's own device instead of a notebook-level global so the method
            # works regardless of where the model/data live.
            return {k: torch.tensor(0.0, device=y_true.device) for k in ['total_loss', 'pred_loss', 'struct_loss', 'defer_cost_loss', 'sens_reg_loss', 'defer_rate_train_time']}


        # --- 1. Prediction Loss (L_pred) ---
        # Applied only to non-deferred samples (based on training-time deferral)
        non_deferred_mask_train = ~defer_train_time

        # Stays at a constant zero when every sample in the batch is deferred.
        pred_loss = torch.tensor(0.0, device=y_true.device)
        num_non_deferred = non_deferred_mask_train.sum().item()

        if num_non_deferred > 0:
            if self.output_dim_val == 1: # Binary classification (e.g., PTB)
                # BCEWithLogitsLoss expects float targets of shape [B] or [B,1]
                pred_loss_fn = nn.BCEWithLogitsLoss(reduction='none')
                current_pred_loss = pred_loss_fn(
                    y_pred_logits[non_deferred_mask_train].squeeze(-1), # Logits [N_nd]
                    y_true[non_deferred_mask_train].float()            # Targets [N_nd]
                )
            else: # Multiclass classification (e.g., MIT-BIH)
                # CrossEntropyLoss expects long targets of shape [B]
                pred_loss_fn = nn.CrossEntropyLoss(reduction='none')
                current_pred_loss = pred_loss_fn(
                    y_pred_logits[non_deferred_mask_train], # Logits [N_nd, C]
                    y_true[non_deferred_mask_train].long()  # Targets [N_nd]
                )
            pred_loss = current_pred_loss.mean()

        # --- 2. Structural Regularization Loss (L_struct) ---
        struct_loss = self.structural_regularizer(z_original, z_perturbed_all)

        # --- 3. Deferral Cost Loss (L_defer_cost) ---
        # Cost for samples deferred during training.
        # NOTE: defer_train_time is the result of a hard threshold comparison, so this term
        # carries no gradient; it shifts the reported loss but does not steer the parameters.
        defer_cost_loss = defer_train_time.float().mean() * defer_cost_factor

        # --- 4. Sensitivity Regularization Loss (L_sens_reg) ---
        with torch.no_grad(): # Detach correct_prediction from graph for this specific loss term
            if self.output_dim_val == 1: # Binary
                # Ensure y_pred_logits for binary is [B,1], then squeeze for comparison
                # (assumes y_true is 1-D [B]; a [B,1] target would broadcast -- TODO confirm)
                predicted_classes = (y_pred_logits.squeeze(-1).sigmoid() > 0.5).float()
                correct_prediction = (predicted_classes == y_true.float()).float()
            else: # Multiclass
                predicted_classes = torch.argmax(y_pred_logits, dim=1)
                correct_prediction = (predicted_classes == y_true.long()).float()

        # Shaping target: +1 for a correct prediction, -1 for an incorrect one.
        # Minimizing mean(sensitivity * target) therefore
        #   - pushes sensitivity DOWN where the prediction is correct (target = +1), and
        #   - pushes sensitivity UP where the prediction is wrong   (target = -1),
        # i.e. it aligns high sensitivity with prediction errors.
        target_for_sensitivity_shaping = (2 * correct_prediction) - 1
        sens_reg_loss = (final_sensitivity_score * target_for_sensitivity_shaping).mean()

        # --- 5. Total Loss ---
        total_loss = pred_loss + \
                     loss_alpha * struct_loss + \
                     loss_beta * (defer_cost_loss + sens_reg_loss)

        return {
            'total_loss': total_loss,
            'pred_loss': pred_loss,
            'struct_loss': struct_loss,
            'defer_cost_loss': defer_cost_loss,
            'sens_reg_loss': sens_reg_loss,
            'defer_rate_train_time': defer_train_time.float().mean() # For monitoring
        }

# --- Example Usage (can be commented out after testing) ---
# Smoke test for UniversalCGDModel in the multiclass (MIT-BIH-like) configuration: runs one
# forward pass and one loss computation and checks output shapes. In a notebook kernel
# __name__ is '__main__', so this block runs every time the cell is executed.
# Relies on CGD_MODEL_CONFIG, ENCODER_CONFIG, PREDICTOR_CONFIG, PERTURBATION_CONFIG,
# SENSITIVITY_CONFIG, STRUCTURAL_REGULARIZER_CONFIG and DEVICE from an earlier cell.
if __name__ == '__main__':
    print("\n--- Testing UniversalCGDModel ---")
    # Use configs from Cell 2
    # Test for MIT-BIH like (multiclass)
    mitbih_output_dim = 5

    # Create dummy data batch from MIT-BIH loader if available, else create simple dummy.
    # At cell top level locals() is the notebook's global namespace, so this checks whether
    # an earlier cell already defined `mitbih_loaders`.
    if 'mitbih_loaders' in locals() and mitbih_loaders is not None and 'train' in mitbih_loaders:
        try:
            dummy_X_mit, dummy_y_mit, dummy_mask_mit = next(iter(mitbih_loaders['train']))
            dummy_X_mit, dummy_y_mit, dummy_mask_mit = dummy_X_mit.to(DEVICE), dummy_y_mit.to(DEVICE), dummy_mask_mit.to(DEVICE)
            # Reduce batch size for faster test if loaded batch is large
            if dummy_X_mit.size(0) > 4:
                 dummy_X_mit = dummy_X_mit[:4]
                 dummy_y_mit = dummy_y_mit[:4]
                 dummy_mask_mit = dummy_mask_mit[:4]
            print(f"Using a batch from MIT-BIH DataLoader for testing: X_shape={dummy_X_mit.shape}")
        except StopIteration:
            # The loader exists but yields no batches: fall back to random data.
            print("MIT-BIH DataLoader is empty, creating simple dummy data.")
            dummy_X_mit = torch.randn(4, CGD_MODEL_CONFIG['max_seq_length'], CGD_MODEL_CONFIG['input_dim']).to(DEVICE)
            dummy_y_mit = torch.randint(0, mitbih_output_dim, (4,)).to(DEVICE)
            dummy_mask_mit = torch.zeros(4, CGD_MODEL_CONFIG['max_seq_length'], dtype=torch.bool).to(DEVICE)
    else:
        # Random batch of 4 sequences with an all-False padding mask
        # (presumably False means "not padded" -- confirm against the encoder).
        print("MIT-BIH loader not found, creating simple dummy data for model test.")
        dummy_X_mit = torch.randn(4, CGD_MODEL_CONFIG['max_seq_length'], CGD_MODEL_CONFIG['input_dim']).to(DEVICE)
        dummy_y_mit = torch.randint(0, mitbih_output_dim, (4,)).to(DEVICE)
        dummy_mask_mit = torch.zeros(4, CGD_MODEL_CONFIG['max_seq_length'], dtype=torch.bool).to(DEVICE)

    try:
        cgd_model_mitbih = UniversalCGDModel(
            model_config=CGD_MODEL_CONFIG,
            encoder_config=ENCODER_CONFIG,
            predictor_config=PREDICTOR_CONFIG,
            perturb_config=PERTURBATION_CONFIG,
            sensitivity_config=SENSITIVITY_CONFIG,
            regularizer_config=STRUCTURAL_REGULARIZER_CONFIG,
            output_dim=mitbih_output_dim
        ).to(DEVICE)
        cgd_model_mitbih.train() # Set to train for dropout, batchnorm etc.

        # Forward pass: verify the per-sample outputs have the expected shapes.
        model_output_mitbih = cgd_model_mitbih(dummy_X_mit, padding_mask_batch=dummy_mask_mit)
        print("\nMIT-BIH Model Output Keys:", list(model_output_mitbih.keys()))
        print(f"  y_pred_logits shape: {model_output_mitbih['y_pred_logits'].shape}")
        assert model_output_mitbih['y_pred_logits'].shape == (dummy_X_mit.size(0), mitbih_output_dim)
        print(f"  final_sensitivity_score shape: {model_output_mitbih['final_sensitivity_score'].shape}")
        assert model_output_mitbih['final_sensitivity_score'].shape == (dummy_X_mit.size(0),)
        print(f"  defer_train_time shape: {model_output_mitbih['defer_train_time'].shape}, example: {model_output_mitbih['defer_train_time'].cpu().numpy()}")
        assert model_output_mitbih['defer_train_time'].shape == (dummy_X_mit.size(0),)


        # Loss computation: every entry is a scalar tensor, printed via .item().
        loss_dict_mitbih = cgd_model_mitbih.compute_loss(model_output_mitbih, dummy_y_mit)
        print("\nMIT-BIH Model Loss Dict Keys:", list(loss_dict_mitbih.keys()))
        for k, v in loss_dict_mitbih.items():
            print(f"  {k}: {v.item():.4f}")
        assert 'total_loss' in loss_dict_mitbih

        print("\nUniversalCGDModel tested successfully for MIT-BIH case.")

    except Exception as e:
        # Failures (including failed asserts) are reported but not re-raised, so the cell
        # keeps running and the final "executed successfully" line below still prints.
        print(f"Error during UniversalCGDModel testing: {e}")
        import traceback
        traceback.print_exc()

# NOTE: printed unconditionally -- it does not indicate that the smoke test above passed.
print("\nCell 9: Main CGD Model Architecture (UniversalCGDModel) executed successfully.")

# Cell 10: Training Loop

In [None]:
# --- Cell 10: Training Loop ---

def train_epoch(
    model: UniversalCGDModel,
    dataloader: DataLoader,
    optimizer: optim.Optimizer,
    device: torch.device,
    epoch_num: int, # For logging
    total_epochs: int # For logging
) -> Dict[str, float]:
    """
    Run one optimisation pass over ``dataloader``.

    Each batch is moved to ``device``, pushed through the model, and the total loss is
    back-propagated. Batches whose total loss is NaN/Inf are reported and skipped (no
    optimizer step, not counted in the averages).

    Args:
        model: Model exposing ``__call__(X, padding_mask_batch=...)`` and ``compute_loss``.
        dataloader: Yields ``(X_batch, y_batch, padding_mask_batch)`` tuples.
        optimizer: Optimizer stepping the model's parameters.
        device: Device the batches are moved to.
        epoch_num: Zero-based epoch index (progress-bar label only).
        total_epochs: Total number of epochs (progress-bar label only).
    Returns:
        Sample-weighted average of every loss component over the processed batches
        (all zeros if no batch was processed).
    """
    model.train()

    metric_names = (
        'total_loss', 'pred_loss', 'struct_loss',
        'defer_cost_loss', 'sens_reg_loss', 'defer_rate_train_time'
    )
    weighted_sums = dict.fromkeys(metric_names, 0.0)
    samples_seen = 0

    progress = tqdm(dataloader, desc=f"Epoch {epoch_num+1}/{total_epochs} [Train]")
    for features, targets, pad_mask in progress:
        features = features.to(device)
        targets = targets.to(device)
        pad_mask = pad_mask.to(device)

        optimizer.zero_grad()
        outputs = model(features, padding_mask_batch=pad_mask)
        losses = model.compute_loss(outputs, targets)
        objective = losses['total_loss']

        # A non-finite objective would corrupt the weights; drop the batch instead.
        if not torch.isfinite(objective):
            print(f"Warning: NaN or Inf loss detected in training epoch {epoch_num+1}. Skipping batch.")
            continue

        objective.backward()
        # Optional: Gradient clipping can help with stability for complex models/losses
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        n_in_batch = features.size(0)
        samples_seen += n_in_batch
        for name in weighted_sums:
            weighted_sums[name] += losses[name].item() * n_in_batch

        mean_loss = weighted_sums['total_loss'] / samples_seen
        mean_defer_pct = weighted_sums['defer_rate_train_time'] / samples_seen * 100
        progress.set_postfix({
            'loss': f"{mean_loss:.4f}",
            'defer%': f"{mean_defer_pct:.1f}%"
        })

    if samples_seen > 0:
        return {name: total / samples_seen for name, total in weighted_sums.items()}
    return {name: 0.0 for name in weighted_sums}

def validate_epoch(
    model: UniversalCGDModel,
    dataloader: DataLoader,
    device: torch.device,
    epoch_num: int, # For logging
    total_epochs: int # For logging
) -> Dict[str, float]:
    """
    Validates the model for one epoch (no gradient tracking, model in eval mode).

    Args:
        model: Model exposing ``__call__(X, padding_mask_batch=...)``, ``compute_loss`` and
               ``output_dim_val`` (``1`` selects the binary accuracy path).
        dataloader: Yields ``(X_batch, y_batch, padding_mask_batch)`` tuples.
        device: Device the batches are moved to.
        epoch_num: Zero-based epoch index (progress-bar label only).
        total_epochs: Total number of epochs (progress-bar label only).
    Returns:
        Sample-weighted averages of every loss component, plus
        ``accuracy_nd_val_train_thresh``: accuracy on the samples that were NOT deferred
        under the fixed training threshold (0.0 when there are no such samples or the
        dataloader is empty).
    """
    model.eval()
    total_samples = 0
    running_losses = {
        'total_loss': 0.0, 'pred_loss': 0.0, 'struct_loss': 0.0,
        'defer_cost_loss': 0.0, 'sens_reg_loss': 0.0, 'defer_rate_train_time': 0.0
    }
    all_y_true = []
    all_y_pred_logits = []
    all_defer_train_time = []

    pbar = tqdm(dataloader, desc=f"Epoch {epoch_num+1}/{total_epochs} [Val]")
    with torch.no_grad():
        for X_batch, y_batch, padding_mask_batch in pbar:
            X_batch, y_batch, padding_mask_batch = X_batch.to(device), y_batch.to(device), padding_mask_batch.to(device)

            model_output = model(X_batch, padding_mask_batch=padding_mask_batch)
            loss_dict = model.compute_loss(model_output, y_batch)

            batch_size = X_batch.size(0)
            total_samples += batch_size
            for key in running_losses.keys():
                if key in loss_dict and loss_dict[key] is not None: # Ensure key exists
                    running_losses[key] += loss_dict[key].item() * batch_size

            # Keep everything on CPU so the metric computation below can use numpy/sklearn.
            all_y_true.append(y_batch.cpu())
            all_y_pred_logits.append(model_output['y_pred_logits'].cpu())
            all_defer_train_time.append(model_output['defer_train_time'].cpu())

            pbar.set_postfix({
                'val_loss': f"{running_losses['total_loss']/total_samples:.4f}",
                'val_defer%': f"{running_losses['defer_rate_train_time']/total_samples*100:.1f}%"
            })

    epoch_avg_losses = {key: val / total_samples if total_samples > 0 else 0.0 for key, val in running_losses.items()}

    # Calculate accuracy on non-deferred samples (using training-time deferral for this val metric)
    acc_nd_val = 0.0
    if not all_y_true:
        # torch.cat() raises on an empty list; an empty loader simply yields zero metrics.
        print("Warning: validation dataloader yielded no batches; reporting zero metrics.")
    else:
        y_true_cat = torch.cat(all_y_true)
        y_pred_logits_cat = torch.cat(all_y_pred_logits)
        defer_train_time_cat = torch.cat(all_defer_train_time)

        non_deferred_mask = ~defer_train_time_cat
        if non_deferred_mask.sum().item() > 0:
            y_true_nd = y_true_cat[non_deferred_mask]
            y_pred_logits_nd = y_pred_logits_cat[non_deferred_mask]

            if model.output_dim_val == 1: # Binary
                # reshape(-1) instead of a bare squeeze(): with exactly one non-deferred
                # sample squeeze() would produce a 0-d array, which accuracy_score rejects.
                preds_nd = (y_pred_logits_nd.sigmoid() > 0.5).float().reshape(-1)
                acc_nd_val = accuracy_score(y_true_nd.float().reshape(-1).numpy(), preds_nd.numpy())
            else: # Multiclass
                preds_nd = torch.argmax(y_pred_logits_nd, dim=1)
                acc_nd_val = accuracy_score(y_true_nd.long().numpy(), preds_nd.numpy())

    epoch_avg_losses['accuracy_nd_val_train_thresh'] = acc_nd_val # Accuracy on non-deferred (using training threshold)

    return epoch_avg_losses


def train_universal_cgd_model(
    model: UniversalCGDModel,
    train_loader: DataLoader,
    val_loader: DataLoader,
    training_config: Dict[str, Any], # From GENERAL_TRAINING_CONFIG
    model_specific_checkpoint_name: str = "universal_cgd_best.pt"
) -> Tuple[UniversalCGDModel, Dict[str, List[float]]]:
    """
    Main training loop for the UniversalCGDModel.
    Includes early stopping and saving the best model.

    Args:
        model: Model to train (its current device is used for all batches).
        train_loader: Training batches.
        val_loader: Validation batches; the validation total loss drives checkpointing
                    and early stopping.
        training_config: Must provide ``learning_rate``, ``epochs``, ``patience`` and
                         ``checkpoint_dir``.
        model_specific_checkpoint_name: File name of the best-model checkpoint inside
                                        ``checkpoint_dir``.
    Returns:
        ``(model, history)`` where ``model`` carries the best weights saved during THIS run
        (or the last-epoch weights if no checkpoint was written) and ``history`` maps
        ``train_*`` / ``val_*`` metric names to per-epoch lists.
    """
    device = next(model.parameters()).device # Get device from model

    optimizer = optim.Adam(model.parameters(), lr=training_config['learning_rate'])
    # Optional: Learning rate scheduler
    # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=training_config['patience'] // 2, factor=0.5, verbose=True)

    epochs = training_config['epochs']
    patience = training_config['patience']
    checkpoint_dir = training_config['checkpoint_dir']
    best_model_path = os.path.join(checkpoint_dir, model_specific_checkpoint_name)
    # torch.save does not create missing directories; make sure the target exists up front.
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    history_keys = [
        'train_total_loss', 'train_pred_loss', 'train_struct_loss',
        'train_defer_cost_loss', 'train_sens_reg_loss', 'train_defer_rate_train_time',
        'val_total_loss', 'val_pred_loss', 'val_struct_loss',
        'val_defer_cost_loss', 'val_sens_reg_loss', 'val_defer_rate_train_time',
        'val_accuracy_nd_val_train_thresh' # Accuracy on non-deferred samples in validation (using training defer threshold)
    ]
    history = {k: [] for k in history_keys}

    best_val_total_loss = float('inf')
    epochs_no_improve = 0
    # Tracks whether THIS run wrote the checkpoint, so a stale file left behind by an
    # earlier run (possibly a different architecture) is never loaded by mistake.
    checkpoint_saved_this_run = False

    print(f"Starting training for {epochs} epochs on {device}...")
    print(f"Best model will be saved to: {best_model_path}")

    for epoch in range(epochs):
        train_metrics = train_epoch(model, train_loader, optimizer, device, epoch, epochs)
        val_metrics = validate_epoch(model, val_loader, device, epoch, epochs)

        # Log metrics to history
        for key, value in train_metrics.items(): history[f'train_{key}'].append(value)
        for key, value in val_metrics.items(): history[f'val_{key}'].append(value)

        current_val_total_loss = val_metrics['total_loss']

        print(f"Epoch {epoch+1}/{epochs} Summary:")
        print(f"  Train: Loss={train_metrics['total_loss']:.4f}, Defer%={train_metrics['defer_rate_train_time']*100:.1f}%")
        print(f"  Val:   Loss={val_metrics['total_loss']:.4f}, Defer%={val_metrics['defer_rate_train_time']*100:.1f}%, Acc_ND_Val(TrainThresh)={val_metrics['accuracy_nd_val_train_thresh']:.4f}")

        # if scheduler: scheduler.step(current_val_total_loss)

        # A NaN validation loss compares False here and therefore counts as "no improvement".
        if current_val_total_loss < best_val_total_loss:
            best_val_total_loss = current_val_total_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), best_model_path)
            checkpoint_saved_this_run = True
            print(f"  -> New best validation loss: {best_val_total_loss:.4f}. Checkpoint saved.")
        else:
            epochs_no_improve += 1
            print(f"  -> Validation loss did not improve for {epochs_no_improve} epoch(s). Best: {best_val_total_loss:.4f}")

        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break

    print("Training finished.")
    if checkpoint_saved_this_run and os.path.exists(best_model_path):
        print(f"Loading best model state from {best_model_path}")
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    else:
        print("Warning: No best model checkpoint found. Using model from last epoch.")

    return model, history

# Cell 11: Adaptive Deferral Threshold Utilities

In [None]:
# --- Cell 11: Adaptive Deferral Threshold Utilities ---

def compute_adaptive_deferral_threshold(
    model: UniversalCGDModel,
    val_loader: DataLoader,
    adaptive_thresh_config: Dict[str, Any], # From ADAPTIVE_THRESHOLD_CONFIG
    device: torch.device
) -> float:
    """
    Computes an adaptive deferral threshold based on validation set sensitivities.

    Samples whose sensitivity is strictly ABOVE the returned threshold are deferred.
    The strategy is selected by ``adaptive_thresh_config['method']``:
      * ``'percentile'`` (default): the ``percentile_value_for_threshold``-th percentile
        (default 90) of the validation sensitivities.
      * ``'target_defer_rate'``: the candidate whose deferral rate is closest to
        ``target_defer_rate_value`` (default 0.10); on ties a rate below the target wins.
      * ``'max_acc_under_budget'``: among candidates whose deferral rate does not exceed
        ``max_defer_rate_budget`` (default 0.20), the one maximizing accuracy on
        non-deferred samples; on ties the higher threshold (fewer deferrals) wins.

    Args:
        model: Trained model; put into eval mode here.
        val_loader: Yields ``(X_batch, y_batch, padding_mask_batch)`` tuples.
        adaptive_thresh_config: Strategy settings (see above) plus
                                ``num_threshold_candidates`` (default 200).
        device: Device the batches are moved to.
    Returns:
        The chosen threshold as a Python float; 0.5 when no sensitivities are available
        or the method name is unknown.
    """
    model.eval()
    all_sensitivities = []
    all_y_true_val = []
    all_y_pred_logits_val = []

    print("Computing adaptive deferral threshold using validation set...")
    with torch.no_grad():
        for X_batch, y_batch, padding_mask_batch in tqdm(val_loader, desc="Adaptive Threshold Calc"):
            X_batch, y_batch, padding_mask_batch = X_batch.to(device), y_batch.to(device), padding_mask_batch.to(device)
            model_output = model(X_batch, padding_mask_batch=padding_mask_batch)
            all_sensitivities.append(model_output['final_sensitivity_score'].cpu())
            all_y_true_val.append(y_batch.cpu())
            all_y_pred_logits_val.append(model_output['y_pred_logits'].cpu())

    if not all_sensitivities:
        print("Warning: No sensitivities collected from validation set. Returning default threshold 0.5.")
        return 0.5

    sensitivities_np = torch.cat(all_sensitivities).numpy()
    y_true_np = torch.cat(all_y_true_val).numpy()
    y_pred_logits_np = torch.cat(all_y_pred_logits_val).numpy()

    if sensitivities_np.size == 0:
        print("Warning: Sensitivities array is empty. Returning default threshold 0.5.")
        return 0.5

    method = adaptive_thresh_config.get('method', 'percentile')
    num_candidates = adaptive_thresh_config.get('num_threshold_candidates', 200)

    # Candidate thresholds: the unique observed scores, or evenly spaced percentiles of them
    # when there are too many. Unique scores track a skewed distribution better than linspace.
    # (np.unique of a non-empty array is never empty, so no further emptiness check is needed.)
    candidate_thresholds = np.unique(sensitivities_np)
    if len(candidate_thresholds) > num_candidates: # If too many unique values, sample them
        candidate_thresholds = np.percentile(sensitivities_np, np.linspace(0, 100, num_candidates))

    optimal_threshold = 0.5 # Default

    if method == 'percentile':
        percentile_val = adaptive_thresh_config.get('percentile_value_for_threshold', 90)
        # Threshold is the sensitivity score at this percentile.
        # Samples with sensitivity *above* this threshold are deferred.
        optimal_threshold = np.percentile(sensitivities_np, percentile_val)
        print(f"Adaptive threshold (percentile {percentile_val}%): {optimal_threshold:.4f}")

    elif method == 'target_defer_rate':
        target_dr = adaptive_thresh_config.get('target_defer_rate_value', 0.10)
        best_thresh_for_target_dr = candidate_thresholds[0] # Start with lowest sens as threshold (max defer)
        min_dr_diff = float('inf')

        for thresh in candidate_thresholds:
            deferred_mask = sensitivities_np > thresh
            current_dr = np.mean(deferred_mask)
            dr_diff = abs(current_dr - target_dr)
            if dr_diff < min_dr_diff:
                min_dr_diff = dr_diff
                best_thresh_for_target_dr = thresh
            elif dr_diff == min_dr_diff and current_dr < target_dr: # Prefer lower DR if diff is same
                best_thresh_for_target_dr = thresh

        optimal_threshold = best_thresh_for_target_dr
        final_dr = np.mean(sensitivities_np > optimal_threshold)
        print(f"Adaptive threshold (target DR ~{target_dr*100:.1f}%): {optimal_threshold:.4f} (results in actual DR: {final_dr*100:.1f}%)")

    elif method == 'max_acc_under_budget':
        max_budget_dr = adaptive_thresh_config.get('max_defer_rate_budget', 0.20)
        best_thresh_for_acc = candidate_thresholds[-1] # Start with highest sens as threshold (min defer)
        max_acc_nd = -1.0

        # Determine predictions for accuracy calculation.
        # Labels and predictions are flattened with reshape(-1) rather than squeeze(): a
        # bare squeeze() turns a single-sample validation set into a 0-d array, which
        # cannot be indexed with the per-sample boolean masks below.
        if model.output_dim_val == 1: # Binary
            with np.errstate(over='ignore'): # exp() overflow just saturates the sigmoid to 0
                probs_np = 1 / (1 + np.exp(-y_pred_logits_np.reshape(-1))) # Sigmoid
            preds_np = (probs_np > 0.5).astype(float)
            true_labels_np = y_true_np.reshape(-1).astype(float)
        else: # Multiclass
            preds_np = np.argmax(y_pred_logits_np, axis=1)
            true_labels_np = y_true_np.reshape(-1).astype(int)

        for thresh in candidate_thresholds:
            deferred_mask = sensitivities_np > thresh
            current_dr = np.mean(deferred_mask)

            if current_dr <= max_budget_dr: # Only consider if within budget
                non_deferred_mask = ~deferred_mask
                if np.sum(non_deferred_mask) > 0:
                    acc_nd = accuracy_score(true_labels_np[non_deferred_mask], preds_np[non_deferred_mask])
                    if acc_nd > max_acc_nd:
                        max_acc_nd = acc_nd
                        best_thresh_for_acc = thresh
                    # If acc is same, prefer lower deferral rate (higher threshold)
                    elif acc_nd == max_acc_nd and thresh > best_thresh_for_acc:
                        best_thresh_for_acc = thresh

        optimal_threshold = best_thresh_for_acc
        final_deferred_mask = sensitivities_np > optimal_threshold
        final_kept_mask = ~final_deferred_mask
        final_dr = np.mean(final_deferred_mask)
        final_acc_nd = -1.0 # Sentinel reported when every sample ends up deferred
        if np.sum(final_kept_mask) > 0:
            final_acc_nd = accuracy_score(true_labels_np[final_kept_mask], preds_np[final_kept_mask])
        print(f"Adaptive threshold (max Acc_ND under {max_budget_dr*100:.1f}% DR): {optimal_threshold:.4f}")
        print(f"  Results in: Actual DR={final_dr*100:.1f}%, Acc_ND={final_acc_nd*100:.2f}%")

    else:
        print(f"Warning: Unknown adaptive threshold method '{method}'. Using default 0.5.")
        optimal_threshold = 0.5

    return float(optimal_threshold)


# Cell 12: Evaluation Function

In [None]:
# --- Cell 12: Evaluation Function ---

def evaluate_model_with_adaptive_deferral(
    model: UniversalCGDModel,
    test_loader: DataLoader,
    adaptive_threshold: float,
    device: torch.device,
    dataset_name: str = "Test Set" # For printing purposes
) -> Dict[str, Any]:
    """
    Evaluates the trained CGD model on a test set using an adaptive deferral threshold.

    A sample is deferred when its final sensitivity score is strictly greater than
    ``adaptive_threshold``; all remaining samples are treated as non-deferred.

    Args:
        model: Model whose forward pass returns a dict containing 'y_pred_logits' and
            'final_sensitivity_score'. ``model.output_dim_val == 1`` selects the binary
            (sigmoid, 0.5 cut-off) path; any other value selects the multiclass
            (softmax / argmax) path.
        test_loader: Yields ``(X_batch, y_batch, padding_mask_batch)`` tuples.
        adaptive_threshold: Sensitivity value above which a sample is deferred.
        device: Device every batch is moved to before the forward pass.
        dataset_name: Label used only in progress and report output.

    Returns:
        Dict with the keys 'dataset_name', 'adaptive_deferral_threshold',
        'accuracy_overall', 'defer_rate_eval_time', 'total_samples',
        'num_deferred_eval_time', 'num_non_deferred_eval_time',
        'accuracy_non_deferred', 'precision_non_deferred', 'recall_non_deferred',
        'f1_score_non_deferred', 'auc_non_deferred', 'accuracy_deferred' and
        'sensitivity_error_correlation'. Metrics that cannot be computed are NaN.
        An empty dict is returned when the loader yields no batches.
    """
    model.eval()
    all_y_true_test = []
    all_y_pred_logits_test = []
    all_final_sensitivity_test = []

    print(f"\nEvaluating model on {dataset_name} with adaptive threshold: {adaptive_threshold:.4f}")
    with torch.no_grad():
        for X_batch, y_batch, padding_mask_batch in tqdm(test_loader, desc=f"Evaluating {dataset_name}"):
            X_batch, y_batch, padding_mask_batch = X_batch.to(device), y_batch.to(device), padding_mask_batch.to(device)
            model_output = model(X_batch, padding_mask_batch=padding_mask_batch)

            all_y_true_test.append(y_batch.cpu())
            all_y_pred_logits_test.append(model_output['y_pred_logits'].cpu())
            all_final_sensitivity_test.append(model_output['final_sensitivity_score'].cpu())

    if not all_y_true_test:
        print(f"Warning: No data processed during evaluation for {dataset_name}. Returning empty metrics.")
        return {}

    y_true_np = torch.cat(all_y_true_test).numpy()
    y_pred_logits_np = torch.cat(all_y_pred_logits_test).numpy()
    # Flattened so a trailing singleton dimension cannot break the boolean-mask indexing below.
    sensitivities_np = torch.cat(all_final_sensitivity_test).numpy().reshape(-1)

    # --- Determine Predictions and Errors (Overall) ---
    # Labels are flattened in both paths: an (N, 1) label array compared against (N,)
    # predictions would otherwise broadcast into an (N, N) error matrix.
    is_binary_classification = model.output_dim_val == 1
    if is_binary_classification:
        # reshape(-1) instead of squeeze(): squeeze() turns a single-sample result into a
        # 0-d array that can no longer be indexed with the deferral masks.
        logits_flat = y_pred_logits_np.reshape(-1)
        # exp() overflows for very negative logits; the quotient is still the correct 0.0,
        # so only the overflow warning is silenced.
        with np.errstate(over='ignore'):
            y_pred_probs_overall = 1 / (1 + np.exp(-logits_flat)) # Sigmoid
        y_pred_classes_overall = (y_pred_probs_overall > 0.5).astype(int)
        y_true_labels_overall = y_true_np.reshape(-1).astype(float) # For consistency if originally float
    else: # Multiclass
        y_pred_probs_overall = F.softmax(torch.from_numpy(y_pred_logits_np), dim=1).numpy()
        y_pred_classes_overall = np.argmax(y_pred_logits_np, axis=1)
        y_true_labels_overall = y_true_np.reshape(-1).astype(int)

    errors_overall = (y_pred_classes_overall != y_true_labels_overall).astype(int)
    accuracy_overall = accuracy_score(y_true_labels_overall, y_pred_classes_overall)

    # --- Apply Adaptive Deferral ---
    defer_eval_time_mask = sensitivities_np > adaptive_threshold # True for deferred samples
    defer_rate_eval = np.mean(defer_eval_time_mask)
    non_deferred_mask_eval = ~defer_eval_time_mask
    num_deferred = np.sum(defer_eval_time_mask)
    num_non_deferred = np.sum(non_deferred_mask_eval)

    metrics = {
        'dataset_name': dataset_name,
        'adaptive_deferral_threshold': adaptive_threshold,
        'accuracy_overall': accuracy_overall, # Accuracy if no deferral happened
        'defer_rate_eval_time': defer_rate_eval,
        'total_samples': len(y_true_np),
        'num_deferred_eval_time': num_deferred,
        'num_non_deferred_eval_time': num_non_deferred,
    }

    # --- Metrics for Non-Deferred Samples ---
    nan = float('nan')
    if num_non_deferred > 0:
        y_true_nd = y_true_labels_overall[non_deferred_mask_eval]
        y_pred_classes_nd = y_pred_classes_overall[non_deferred_mask_eval]
        y_pred_probs_nd = y_pred_probs_overall[non_deferred_mask_eval] # For AUC

        # Binary metrics use sklearn's default averaging; multiclass uses support-weighted averaging.
        avg_kwargs = {} if is_binary_classification else {'average': 'weighted'}

        metrics['accuracy_non_deferred'] = accuracy_score(y_true_nd, y_pred_classes_nd)
        metrics['precision_non_deferred'] = precision_score(y_true_nd, y_pred_classes_nd, zero_division=0, **avg_kwargs)
        metrics['recall_non_deferred'] = recall_score(y_true_nd, y_pred_classes_nd, zero_division=0, **avg_kwargs)
        metrics['f1_score_non_deferred'] = f1_score(y_true_nd, y_pred_classes_nd, zero_division=0, **avg_kwargs)

        num_classes_present = len(np.unique(y_true_nd))
        if is_binary_classification:
            # AUC requires at least two classes in the true labels.
            if num_classes_present > 1:
                metrics['auc_non_deferred'] = roc_auc_score(y_true_nd, y_pred_probs_nd) # Use probabilities for AUC
            else:
                metrics['auc_non_deferred'] = nan
        elif num_classes_present >= model.output_dim_val and model.output_dim_val > 1:
            # One-vs-rest macro AUC is only attempted when every class survived the deferral filter.
            try:
                metrics['auc_non_deferred'] = roc_auc_score(y_true_nd, y_pred_probs_nd, multi_class='ovr', average='macro')
            except ValueError as e_auc:
                print(f"Warning: Could not compute multiclass AUC for non-deferred: {e_auc}")
                metrics['auc_non_deferred'] = nan
        else:
            metrics['auc_non_deferred'] = nan
    else:
        metrics.update({
            'accuracy_non_deferred': nan, 'precision_non_deferred': nan,
            'recall_non_deferred': nan, 'f1_score_non_deferred': nan,
            'auc_non_deferred': nan
        })

    # --- Metrics for Deferred Samples ---
    if num_deferred > 0:
        y_true_d = y_true_labels_overall[defer_eval_time_mask]
        y_pred_classes_d = y_pred_classes_overall[defer_eval_time_mask]
        metrics['accuracy_deferred'] = accuracy_score(y_true_d, y_pred_classes_d)
    else:
        metrics['accuracy_deferred'] = nan # Or 0.0 if preferred when no samples deferred

    # --- Correlation between Sensitivity and Error ---
    # Pearson correlation is undefined when either series is constant, hence the std checks.
    if len(sensitivities_np) > 1 and len(errors_overall) > 1 and np.std(sensitivities_np) > 0 and np.std(errors_overall) > 0:
        metrics['sensitivity_error_correlation'] = np.corrcoef(sensitivities_np, errors_overall)[0, 1]
    else:
        metrics['sensitivity_error_correlation'] = nan


    print(f"\n--- Evaluation Metrics for {dataset_name} ---")
    for key, value in metrics.items():
        if isinstance(value, float):
            print(f"  {key:<35}: {value:.4f}")
        else:
            print(f"  {key:<35}: {value}")
    print("--------------------------------------")

    return metrics


# Cell 13: Explanation & Visualization Utilities - Saliency Maps

In [None]:
# --- Cell 13: Explanation & Visualization Utilities - Saliency Maps ---

class CGDExplainer:
    """
    Explainer for the UniversalCGDModel using Captum.
    Provides methods to explain predictions and sensitivity scores.

    The attribution algorithm is selected by ``explainer_config['method']``:
    'saliency' (the default when the key is absent) or 'integrated_gradients'.
    Other keys read from the config: 'noise_tunnel_nt_samples', 'noise_tunnel_nt_type',
    'noise_tunnel_stdevs' (saliency only) and 'n_steps_ig' (integrated gradients only).
    """
    def __init__(self, model: UniversalCGDModel, explainer_config: Dict[str, Any]):
        self.model = model # UniversalCGDModel instance
        self.explainer_config = explainer_config
        self.device = next(model.parameters()).device

        if Saliency is None or IntegratedGradients is None:
            print("Warning: Captum library not fully available. Explanation capabilities will be limited.")


    def _get_prediction_output_for_target_class(self, X_batch: torch.Tensor, padding_mask: Optional[torch.Tensor], target_class_idx: int) -> torch.Tensor:
        """
        Wrapper for Captum: Returns the logit of the target_class_idx.
        Input X_batch is already on the correct device.
        """
        model_output = self.model(X_batch, padding_mask_batch=padding_mask)
        # y_pred_logits shape: [batch_size, num_classes]
        return model_output['y_pred_logits'][:, target_class_idx] # Return score for the specific target class

    def _get_sensitivity_score_output(self, X_batch: torch.Tensor, padding_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """
        Wrapper for Captum: Returns the final_sensitivity_score.
        Input X_batch is already on the correct device.
        """
        model_output = self.model(X_batch, padding_mask_batch=padding_mask)
        return model_output['final_sensitivity_score'] # Shape [batch_size]

    def _build_forward_callable(
        self,
        X_batch: torch.Tensor,
        padding_mask: Optional[torch.Tensor],
        explanation_target: str,
        target_class_idx: Optional[int]
    ) -> Optional[Callable]:
        """Returns the ``f(X, padding_mask)`` function Captum should differentiate, or None for an unknown target."""
        if explanation_target == 'sensitivity':
            return self._get_sensitivity_score_output
        if explanation_target != 'prediction':
            print(f"Error: Unknown explanation_target '{explanation_target}'.")
            return None

        if target_class_idx is not None:
            target_class_idx_eff = target_class_idx
        elif self.model.output_dim_val == 1:
            # Binary: the single output neuron is the only possible target, no forward pass needed.
            target_class_idx_eff = 0
        else:
            # Multiclass without an explicit target: explain the class predicted for the first sample.
            with torch.no_grad(): # Get predictions without affecting gradients for explanation
                model_output_temp = self.model(X_batch, padding_mask_batch=padding_mask)
            predicted_classes = torch.argmax(model_output_temp['y_pred_logits'], dim=1)
            print(f"Warning: Explaining predicted class for a batch. Using predicted class of first sample: {predicted_classes[0].item()}")
            target_class_idx_eff = predicted_classes[0].item()

        return lambda x_b, p_mask: self._get_prediction_output_for_target_class(x_b, p_mask, target_class_idx_eff)

    def _compute_attributions(
        self,
        method_name: str,
        forward_callable: Callable,
        X_batch: torch.Tensor,
        padding_mask: Optional[torch.Tensor]
    ) -> Optional[torch.Tensor]:
        """Runs the configured Captum algorithm; returns raw (signed) attributions or None if unsupported."""
        # Only X_batch is attributed; the padding mask is forwarded untouched to the model.
        extra_forward_args = (padding_mask,)

        if method_name == 'saliency' and Saliency is not None:
            explainer_algo = Saliency(forward_callable)
            # NoiseTunnel for SmoothGrad/SmoothGrad-Sq
            use_noise_tunnel = self.explainer_config.get('noise_tunnel_nt_samples', 0) > 0
            if use_noise_tunnel:
                nt_type = self.explainer_config.get('noise_tunnel_nt_type', 'smoothgrad')
                stdevs = self.explainer_config.get('noise_tunnel_stdevs', 0.1)
                nt_samples = self.explainer_config.get('noise_tunnel_nt_samples', 5)
                noise_tunnel = NoiseTunnel(explainer_algo)
                return noise_tunnel.attribute(
                    X_batch,
                    nt_type=nt_type, stdevs=stdevs, nt_samples=nt_samples,
                    additional_forward_args=extra_forward_args,
                    abs=False # Get raw grads, apply abs later if needed
                )
            return explainer_algo.attribute(
                X_batch,
                additional_forward_args=extra_forward_args,
                abs=False
            )

        if method_name == 'integrated_gradients' and IntegratedGradients is not None:
            explainer_algo = IntegratedGradients(forward_callable)
            baselines = torch.zeros_like(X_batch) # Common baseline for IG
            n_steps = self.explainer_config.get('n_steps_ig', 25)
            return explainer_algo.attribute(
                X_batch,
                baselines=baselines,
                additional_forward_args=extra_forward_args,
                n_steps=n_steps,
                internal_batch_size=X_batch.size(0) # Process batch at once if possible
            )

        print(f"Warning: Explanation method '{method_name}' not supported or Captum not available.")
        return None

    def attribute_input(
        self,
        X_batch: torch.Tensor, # Expected on self.device
        padding_mask: Optional[torch.Tensor], # Expected on self.device
        explanation_target: str = 'prediction', # 'prediction' or 'sensitivity'
        target_class_idx: Optional[int] = None, # Required if explanation_target is 'prediction'
        abs_attribution: bool = True
    ) -> Optional[np.ndarray]:
        """
        Computes feature attributions for the input batch.
        Args:
            X_batch: Input tensor [batch_size, seq_len, input_dim] on self.device.
                     Its requires_grad flag is switched on for the computation and
                     restored to its previous value before returning.
            padding_mask: Optional padding mask [batch_size, seq_len] on self.device.
                          Non-zero / True entries mark padded positions whose
                          attributions are zeroed in the result.
            explanation_target: 'prediction' or 'sensitivity'.
            target_class_idx: The index of the class to explain for 'prediction'.
                               If None: index 0 is used for a single-output model; for
                               multiclass the predicted class of the FIRST sample in the
                               batch is used for every sample (a warning is printed).
            abs_attribution: Whether to take the absolute value of attributions.
        Returns:
            Attributions as a NumPy array [batch_size, seq_len, input_dim] or None if error.
        """
        if Saliency is None and self.explainer_config.get('method') == 'saliency':
            print("Cannot compute Saliency: Captum not fully available or Saliency method failed to import.")
            return None
        if IntegratedGradients is None and self.explainer_config.get('method') == 'integrated_gradients':
            print("Cannot compute IntegratedGradients: Captum not fully available or IG method failed to import.")
            return None

        self.model.eval() # Ensure model is in eval mode for explanations
        method_name = self.explainer_config.get('method', 'saliency')

        original_input_requires_grad = X_batch.requires_grad
        X_batch.requires_grad_(True) # Input must require grad for Captum
        try:
            forward_callable = self._build_forward_callable(
                X_batch, padding_mask, explanation_target, target_class_idx
            )
            if forward_callable is None:
                return None

            try:
                attributions = self._compute_attributions(method_name, forward_callable, X_batch, padding_mask)
                if attributions is None:
                    return None

                if abs_attribution:
                    attributions = torch.abs(attributions)
                # Attributions shape: [batch_size, seq_len, input_dim]
                attributions_np = attributions.detach().cpu().numpy()
                if padding_mask is not None: # Zero out attributions for padded regions
                    # Cast to bool first: `~` on an integer mask is a bitwise NOT (~1 == -2),
                    # which would rescale attributions instead of masking them.
                    padded_positions = padding_mask.detach().cpu().numpy().astype(bool) # B, S
                    attributions_np = attributions_np * (~padded_positions[:, :, np.newaxis])
                return attributions_np

            except Exception as e:
                # Best-effort API: report the failure and return None rather than abort the notebook.
                print(f"Error during Captum attribution for method '{method_name}': {e}")
                import traceback
                traceback.print_exc()
                return None
        finally:
            # Single restore point covering every exit path, including unexpected exceptions.
            X_batch.requires_grad_(original_input_requires_grad)

# Printed once the cell has run to the end, i.e. after the CGDExplainer class above is defined.
print("\nCell 13: Explanation Utilities executed successfully.")

# Cell 14: Visualization Functions


In [None]:
# --- Corrected Cell 14: Visualization Functions ---

def plot_training_history(
    history: Dict[str, List[float]],
    metrics_to_plot: List[Dict[str, str]], # List of dicts specifying keys and plot titles
    save_path: Optional[str] = "training_history.png"
):
    """
    Draws one panel per requested metric from the training history.

    Every entry of metrics_to_plot is a dict of the form
    {'train_key': 'train_loss_key', 'val_key': 'val_loss_key',
     'title': 'Plot Title', 'ylabel': 'Y-axis Label'}.
    A series is drawn only when its key is given, present in `history` and non-empty.
    The figure is written to `save_path` (when truthy) and then shown.
    """
    panel_count = len(metrics_to_plot)
    if panel_count == 0:
        print("No metrics specified for plotting training history.")
        return

    # One column for a single metric, otherwise a two-column grid with enough rows.
    cols = 1 if panel_count <= 1 else 2
    rows, leftover = divmod(panel_count, cols)
    if leftover:
        rows += 1

    fig, axes = plt.subplots(rows, cols, figsize=(7 * cols, 5 * rows), squeeze=False)
    axes = axes.flatten()

    for ax, metric_info in zip(axes, metrics_to_plot):
        y_label = metric_info.get('ylabel', 'Value')
        series_specs = (
            ("Train", metric_info.get('train_key')),
            ("Validation", metric_info.get('val_key')),
        )

        drew_series = False
        for prefix, key in series_specs:
            if key and key in history and len(history[key]) > 0:
                ax.plot(history[key], label=f"{prefix} {y_label}", marker='.')
                drew_series = True

        ax.set_title(metric_info.get('title', 'Training Metric'))
        ax.set_xlabel("Epoch")
        ax.set_ylabel(y_label)
        if drew_series: # A legend is only meaningful when something was drawn
            ax.legend()
        ax.grid(True)

    # Drop the grid cells that have no metric assigned.
    for spare_ax in axes[panel_count:]:
        fig.delaxes(spare_ax)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
        print(f"Training history plot saved to {save_path}")
    plt.show()

def plot_confusion_matrix_custom(
    y_true: np.ndarray,
    y_pred_classes: np.ndarray,
    class_names: List[str],
    title: str = "Confusion Matrix",
    save_path: Optional[str] = "confusion_matrix.png"
):
    """
    Renders a row-normalized confusion matrix as an annotated heatmap.

    Labels 0..len(class_names)-1 are used as the matrix axes. Each row is divided by
    its total plus EPSILON, so a class without true samples yields a row of zeros
    instead of a division by zero. Nothing is drawn when either input is empty.
    """
    if len(y_true) == 0 or len(y_pred_classes) == 0:
        print(f"Cannot plot confusion matrix for '{title}': No data provided.")
        return

    n_classes = len(class_names)
    counts = confusion_matrix(y_true, y_pred_classes, labels=np.arange(n_classes))
    row_totals = counts.sum(axis=1)[:, np.newaxis]
    cm_normalized = counts.astype('float') / (row_totals + EPSILON)

    # Figure grows with the number of classes; annotations shrink once there are ten or more.
    fig_width = max(8, n_classes)
    fig_height = max(6, n_classes * 0.8)
    annotation_size = 10 if n_classes < 10 else 7

    plt.figure(figsize=(fig_width, fig_height))
    sns.heatmap(cm_normalized, annot=True, fmt=".2%", cmap="Blues",
                xticklabels=class_names, yticklabels=class_names,
                annot_kws={"size": annotation_size})
    plt.title(title)
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.xticks(rotation=45, ha="right")
    plt.yticks(rotation=0)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
        print(f"Confusion matrix plot saved to {save_path}")
    plt.show()


def plot_sensitivity_distributions(
    sensitivities: np.ndarray,
    errors: np.ndarray, # Boolean array (True if error)
    defer_mask_eval: Optional[np.ndarray], # Boolean array (True if deferred by adaptive threshold)
    adaptive_threshold: Optional[float],
    title: str = "Sensitivity Score Distribution",
    save_path: Optional[str] = "sensitivity_distribution.png"
):
    """
    Plots distributions of sensitivity scores in two side-by-side panels.

    Left panel: scores split by prediction correctness. Right panel: scores split by the
    deferral decision, or the overall distribution when `defer_mask_eval` is None.

    Args:
        sensitivities: Per-sample sensitivity scores.
        errors: Per-sample error flags; truthy means the prediction was wrong.
            0/1 integer arrays are accepted as well as boolean arrays.
        defer_mask_eval: Per-sample deferral flags (truthy = deferred), or None.
        adaptive_threshold: When given, drawn as a vertical line in each split panel.
        title: Figure title; the default is rendered as "Sensitivity Score Distributions".
        save_path: File to write the figure to; falsy disables saving.
    """
    sensitivities = np.asarray(sensitivities)
    if sensitivities.size == 0:
        print(f"Cannot plot sensitivity distribution for '{title}': No sensitivity data.")
        return

    # Masks are coerced to bool before use: on a 0/1 integer array `~` is a bitwise NOT
    # (giving -1 / -2), and indexing with integers selects positions instead of filtering.
    error_mask = np.asarray(errors).astype(bool)
    correct_mask = ~error_mask
    deferred_mask = None if defer_mask_eval is None else np.asarray(defer_mask_eval).astype(bool)

    def _draw_threshold_line():
        """Marks the adaptive threshold on the current panel, if one was supplied."""
        if adaptive_threshold is not None:
            plt.axvline(adaptive_threshold, color="purple", linestyle="--", label=f"Adaptive Thresh ({adaptive_threshold:.3f})")

    plt.figure(figsize=(12, 6))

    # --- Left panel: correct vs. incorrect predictions ---
    plt.subplot(1, 2, 1)
    if np.sum(correct_mask) > 0: # Skip empty groups; there is nothing to bin
        sns.histplot(sensitivities[correct_mask], label="Correct Predictions", color="green", kde=True, stat="density", element="step")
    if np.sum(error_mask) > 0:
        sns.histplot(sensitivities[error_mask], label="Incorrect Predictions", color="red", kde=True, stat="density", element="step")
    _draw_threshold_line()
    plt.title("Sensitivity by Prediction Correctness")
    plt.xlabel("Final Sensitivity Score")
    plt.ylabel("Density")
    plt.legend()
    plt.grid(True)

    # --- Right panel: deferred vs. non-deferred, or everything when no mask is given ---
    plt.subplot(1, 2, 2)
    if deferred_mask is not None:
        kept_mask = ~deferred_mask
        if np.sum(kept_mask) > 0:
            sns.histplot(sensitivities[kept_mask], label="Non-Deferred by Eval Thresh", color="blue", kde=True, stat="density", element="step")
        if np.sum(deferred_mask) > 0:
            sns.histplot(sensitivities[deferred_mask], label="Deferred by Eval Thresh", color="orange", kde=True, stat="density", element="step")
        _draw_threshold_line()
        plt.title("Sensitivity by Eval Deferral Decision")
    else:
        sns.histplot(sensitivities, label="All Samples", color="gray", kde=True, stat="density")
        plt.title("Overall Sensitivity Distribution")

    plt.xlabel("Final Sensitivity Score")
    plt.ylabel("Density")
    plt.legend()
    plt.grid(True)

    fig_title = title if title != "Sensitivity Score Distribution" else "Sensitivity Score Distributions"
    plt.suptitle(fig_title, fontsize=16)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Leave headroom for the suptitle
    if save_path:
        plt.savefig(save_path)
        print(f"Sensitivity distribution plot saved to {save_path}")
    plt.show()

def plot_roc_auc_curves(
    y_true_list: List[np.ndarray],
    y_pred_probs_list: List[np.ndarray],
    label_names_list: List[str],
    output_dim: int,
    class_names: Optional[List[str]] = None,
    title: str = "ROC Curves",
    save_path: Optional[str] = "roc_curves.png"
):
    """
    Plots ROC curves for binary or multiclass (OvR) classifications.

    One curve is drawn per (y_true, y_pred_probs) pair:
      * output_dim == 1: the ordinary binary ROC curve, labelled with its AUC.
      * otherwise: the macro-averaged one-vs-rest curve (per-class TPRs interpolated
        onto the union of all per-class FPR points and averaged), labelled with the
        mean of the per-class AUCs. Classes that are absent from, or make up all of,
        y_true are left out of the average.

    Args:
        y_true_list: True labels, one array per curve.
        y_pred_probs_list: Scores per curve; indexed as [:, class] when output_dim != 1.
        label_names_list: Legend name per curve. Missing names fall back to "Curve <n>".
        output_dim: 1 selects the binary path; any other value is the number of classes.
        class_names: Currently unused; accepted for interface compatibility.
        title: Plot title.
        save_path: File to write the figure to; falsy disables saving.
    """
    if not y_true_list or not y_pred_probs_list or len(y_true_list) != len(y_pred_probs_list):
        print(f"Cannot plot ROC for '{title}': Invalid input lists.")
        return

    # zip() stops at the shortest input, so a short label list would silently drop curves.
    curve_labels = list(label_names_list or [])
    curve_labels.extend(f"Curve {k + 1}" for k in range(len(curve_labels), len(y_true_list)))

    plt.figure(figsize=(8, 6))
    curves_drawn = 0

    for y_true, y_pred_probs, curve_label_name in zip(y_true_list, y_pred_probs_list, curve_labels):
        y_true = np.asarray(y_true)
        y_pred_probs = np.asarray(y_pred_probs)
        if y_true.size == 0 or y_pred_probs.size == 0:
            print(f"Skipping ROC curve for '{curve_label_name}': empty data.")
            continue

        if output_dim == 1:
            if len(np.unique(y_true)) < 2:
                print(f"Skipping ROC for '{curve_label_name}' (binary): only one class present in y_true.")
                continue
            fpr, tpr, _ = roc_curve(y_true, y_pred_probs)
            roc_auc = auc(fpr, tpr)
            plt.plot(fpr, tpr, lw=2, label=f'{curve_label_name} (AUC = {roc_auc:.3f})')
            curves_drawn += 1
            continue

        # Multiclass: gather one-vs-rest curves for every class that has positives and negatives.
        per_class_curves = []
        per_class_aucs = []
        for class_idx in range(output_dim):
            y_true_class = (y_true == class_idx).astype(int)
            if len(np.unique(y_true_class)) < 2:
                continue
            class_fpr, class_tpr, _ = roc_curve(y_true_class, y_pred_probs[:, class_idx])
            per_class_curves.append((class_fpr, class_tpr))
            per_class_aucs.append(auc(class_fpr, class_tpr))

        if not per_class_curves:
            print(f"No valid classes for ROC AUC calculation in '{curve_label_name}' (multiclass).")
            continue

        macro_auc = sum(per_class_aucs) / len(per_class_aucs)
        # Draw the curve that matches the reported number: average the per-class TPRs on a
        # shared FPR grid instead of showing one arbitrary class under a "Macro AUC" label.
        fpr_grid = np.unique(np.concatenate([class_fpr for class_fpr, _ in per_class_curves]))
        mean_tpr = np.mean(
            [np.interp(fpr_grid, class_fpr, class_tpr) for class_fpr, class_tpr in per_class_curves],
            axis=0,
        )
        plt.plot(fpr_grid, mean_tpr, lw=2, label=f'{curve_label_name} (Macro AUC = {macro_auc:.3f})')
        curves_drawn += 1

    plt.plot([0, 1], [0, 1], color='gray', lw=1, linestyle='--') # Chance diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(title)
    if curves_drawn > 0: # A legend without labelled curves only triggers a matplotlib warning
        plt.legend(loc="lower right")
    plt.grid(True)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
        print(f"ROC curve plot saved to {save_path}")
    plt.show()

def plot_ecg_with_saliency(
    ecg_signal: np.ndarray,
    attributions: Optional[np.ndarray],
    title: str = "ECG Signal with Saliency",
    save_path: Optional[str] = None,
    true_label_name: Optional[str] = None,
    pred_label_name: Optional[str] = None,
    is_deferred: Optional[bool] = None,
    sensitivity_score: Optional[float] = None,
    sampling_rate: int = 125
):
    """
    Draws an ECG trace against time and, when available, its attribution scores as bars
    on a secondary y-axis.

    Both arrays are squeezed first. Attributions whose squeezed shape differs from the
    signal's are reported with a warning and left out of the plot. The optional label,
    deferral and sensitivity details are appended to the title on a second line.
    """
    signal = ecg_signal.squeeze()

    # Keep the attributions only if they line up sample-for-sample with the signal.
    saliency = None
    if attributions is not None:
        squeezed_attr = attributions.squeeze()
        if squeezed_attr.shape == signal.shape:
            saliency = squeezed_attr
        else:
            print(f"Warning: Attributions shape {squeezed_attr.shape} differs from ECG signal shape {signal.shape}. Cannot overlay directly.")

    seconds = np.arange(len(signal)) / sampling_rate
    _fig, ax1 = plt.subplots(figsize=(15, 5))

    ecg_color = 'tab:blue'
    ax1.set_xlabel(f"Time (s) - Sample Rate: {sampling_rate}Hz")
    ax1.set_ylabel("ECG Amplitude", color=ecg_color)
    ax1.plot(seconds, signal, color=ecg_color, linewidth=1.5, label="ECG Signal")
    ax1.tick_params(axis='y', labelcolor=ecg_color)
    ax1.grid(True, axis='x', linestyle=':')
    ax1.grid(True, axis='y', linestyle=':', color=ecg_color, alpha=0.5)

    saliency_ax = None
    if saliency is not None:
        saliency_ax = ax1.twinx()
        saliency_color = 'tab:red'
        saliency_ax.set_ylabel("Attribution Score", color=saliency_color)
        # One bar per sample, each exactly one sampling interval wide.
        saliency_ax.bar(seconds, saliency, color=saliency_color, alpha=0.6, width=1/sampling_rate, label="Saliency")
        saliency_ax.tick_params(axis='y', labelcolor=saliency_color)
        saliency_ax.grid(True, axis='y', linestyle=':', color=saliency_color, alpha=0.5)

    # Conditional expressions keep the formatting lazy, so None values are never formatted.
    optional_details = (
        f"True: {true_label_name}" if true_label_name else None,
        f"Pred: {pred_label_name}" if pred_label_name else None,
        f"Deferred: {is_deferred}" if is_deferred is not None else None,
        f"Sens: {sensitivity_score:.3f}" if sensitivity_score is not None else None,
    )
    details = [entry for entry in optional_details if entry is not None]
    full_title = title + ("\n(" + " | ".join(details) + ")" if details else "")
    plt.title(full_title, fontsize=14)

    ecg_handles, ecg_labels = ax1.get_legend_handles_labels()
    if saliency_ax is None:
        ax1.legend(loc='upper right')
    else:
        # Merge both axes' entries into a single legend drawn on the secondary axis.
        sal_handles, sal_labels = saliency_ax.get_legend_handles_labels()
        saliency_ax.legend(ecg_handles + sal_handles, ecg_labels + sal_labels, loc='upper right')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
        print(f"ECG with saliency plot saved to {save_path}")
    plt.show()

def plot_deferral_performance_vs_threshold(
    sensitivities_val: np.ndarray,
    y_true_val: np.ndarray,
    y_pred_logits_val: np.ndarray,
    model_output_dim: int,
    title: str = "Deferral Performance vs. Sensitivity Threshold",
    save_path: Optional[str] = "deferral_performance.png",
    num_threshold_points: int = 50
):
    """
    Plot accuracy on non-deferred samples against deferral rate over a sweep of
    sensitivity thresholds.

    Candidate thresholds are the unique percentiles (0..100) of `sensitivities_val`.
    For each threshold, samples whose sensitivity is strictly greater than the
    threshold count as deferred; accuracy is computed on the remaining samples.

    Args:
        sensitivities_val: Per-sample sensitivity scores.
        y_true_val: Ground-truth labels aligned with `sensitivities_val`.
        y_pred_logits_val: Model logits. With `model_output_dim == 1` a single logit
            per sample (thresholded at probability 0.5); otherwise a 2-D array whose
            argmax over axis 1 is the predicted class.
        model_output_dim: Number of model outputs (1 selects the binary path).
        title: Figure title; also used in the "no data" message.
        save_path: Where to save the figure; a falsy value skips saving.
        num_threshold_points: Number of percentile points used to build the sweep.

    Returns:
        None. The figure is shown (and optionally saved) as a side effect.
    """
    if sensitivities_val.size == 0:
        print(f"Cannot plot deferral performance for '{title}': No sensitivity data.")
        return

    threshold_candidates = np.unique(
        np.percentile(sensitivities_val, np.linspace(0, 100, num_threshold_points))
    )

    if model_output_dim == 1:
        # reshape(-1) instead of squeeze(): a single-sample array must stay 1-D so it
        # can still be indexed with the boolean masks below.
        binary_logits = y_pred_logits_val.reshape(-1)
        preds_overall = (1 / (1 + np.exp(-binary_logits)) > 0.5).astype(int)
        true_labels_overall = y_true_val.astype(float)
    else:
        preds_overall = np.argmax(y_pred_logits_val, axis=1)
        true_labels_overall = y_true_val.astype(int)

    deferral_rates = []
    accuracies_nd = []
    for thresh in tqdm(threshold_candidates, desc="Plotting Deferral Curve"):
        defer_mask = sensitivities_val > thresh
        deferral_rates.append(np.mean(defer_mask))

        non_deferred_mask = ~defer_mask
        if np.sum(non_deferred_mask) > 0:
            accuracies_nd.append(
                accuracy_score(true_labels_overall[non_deferred_mask], preds_overall[non_deferred_mask])
            )
        else:
            # Nothing left to score at this threshold.
            accuracies_nd.append(0.0)

    fig, ax1 = plt.subplots(figsize=(10, 6))
    color1 = 'tab:blue'
    ax1.set_xlabel("Deferral Rate")
    ax1.set_ylabel("Accuracy on Non-Deferred Samples", color=color1)
    ax1.plot(deferral_rates, accuracies_nd, color=color1, marker='o', linestyle='-')
    ax1.tick_params(axis='y', labelcolor=color1)
    ax1.grid(True)

    # Secondary (top) axis: label a handful of plotted points with the threshold that
    # produced them. Deferral rate falls as the threshold rises and the mapping is not
    # linear, so an independent min..max threshold scale would be reversed and
    # misaligned with the curve. Instead the top axis shares the bottom axis limits and
    # carries threshold labels at the matching deferral-rate positions.
    ax2 = ax1.twiny()
    ax2.set_xlabel("Sensitivity Threshold Value")
    ax2.set_xlim(ax1.get_xlim())
    max_threshold_ticks = 6
    tick_indices = np.unique(
        np.linspace(
            0, len(threshold_candidates) - 1,
            num=min(max_threshold_ticks, len(threshold_candidates))
        ).round().astype(int)
    )
    tick_positions = []
    tick_labels = []
    for tick_idx in tick_indices:
        position = float(deferral_rates[tick_idx])
        if position in tick_positions:
            # Several thresholds can give the same deferral rate; keep one label per position.
            continue
        tick_positions.append(position)
        tick_labels.append(f"{threshold_candidates[tick_idx]:.3g}")
    ax2.set_xticks(tick_positions)
    ax2.set_xticklabels(tick_labels)

    plt.title(title)
    plt.tight_layout() # May need adjustment if labels overlap
    if save_path:
        plt.savefig(save_path)
        print(f"Deferral performance plot saved to {save_path}")
    plt.show()

# Status message emitted after the visualization helper definitions above have run.
print("\nCell 14 (Corrected): Visualization Functions executed successfully.")

# Cell 15: Setup and Model Configuration for MIT-BIH Dataset

In [None]:
# --- Cell 15: Setup and Model Configuration for MIT-BIH Dataset ---

# 1. Confirm DataLoaders are available (this cell only reports; it does not stop on a missing loader)
if 'mitbih_loaders' not in globals() or mitbih_loaders is None:
    print("ERROR: `mitbih_loaders` not found. Please ensure Cell 3 (Revised) was run successfully and created these DataLoaders.")
    # mitbih_loaders = {'train': None, 'val': None, 'test': None} # Placeholder to avoid immediate crash
else:
    print("MIT-BIH DataLoaders found.")
    if not all(k in mitbih_loaders for k in ['train', 'val', 'test']):
        print("Warning: MIT-BIH loaders might be incomplete. Expected 'train', 'val', 'test'.")


# 2. Define MIT-BIH Class Names
MITBIH_CLASS_NAMES = ['Normal (N)', 'Supraventricular (S)', 'Ventricular (V)', 'Fusion (F)', 'Unknown (Q)']
MITBIH_NUM_CLASSES = len(MITBIH_CLASS_NAMES) # Should be 5

# 3. Instantiate UniversalCGDModel for MIT-BIH
# We'll use the global configurations from Cell 2, specifying the output_dim.
print(f"\nConfiguring UniversalCGDModel for MIT-BIH ({MITBIH_NUM_CLASSES} classes)...")

# Ensure all config dictionaries from Cell 2 are available
config_vars = ['CGD_MODEL_CONFIG', 'ENCODER_CONFIG', 'PREDICTOR_CONFIG',
               'PERTURBATION_CONFIG', 'SENSITIVITY_CONFIG', 'STRUCTURAL_REGULARIZER_CONFIG']
missing_configs = [cv for cv in config_vars if cv not in globals()]
if missing_configs:
    raise NameError(f"Missing one or more configuration dictionaries: {missing_configs}. Please re-run Cell 2.")

# Create the model instance
try:
    cgd_model_mitbih = UniversalCGDModel(
        model_config=CGD_MODEL_CONFIG,
        encoder_config=ENCODER_CONFIG,
        predictor_config=PREDICTOR_CONFIG, # output_dim is passed separately
        perturb_config=PERTURBATION_CONFIG,
        sensitivity_config=SENSITIVITY_CONFIG,
        regularizer_config=STRUCTURAL_REGULARIZER_CONFIG,
        output_dim=MITBIH_NUM_CLASSES # Crucial: set number of output classes
    ).to(DEVICE)

    print("\nUniversalCGDModel for MIT-BIH instantiated successfully:")
    # Quick check on a component
    print(f"  Encoder aggregation: {cgd_model_mitbih.encoder.aggregation_method}")
    print(f"  Predictor output features: {cgd_model_mitbih.predictor.predictor_mlp[-1].out_features}") # Last layer of MLP
    print(f"  Perturbation types: {cgd_model_mitbih.perturbation_generator._get_active_perturbations()}")
    print(f"  Sensitivity measures: {cgd_model_mitbih.sensitivity_calculator._get_active_measures()}")
    print(f"  Regularizer type: {cgd_model_mitbih.structural_regularizer.regularization_type}")
    print(f"  Model is on device: {next(cgd_model_mitbih.parameters()).device}")

except Exception as e:
    # Keep the notebook running but make the failure visible; later cells check for None.
    print(f"Error instantiating UniversalCGDModel for MIT-BIH: {e}")
    import traceback
    traceback.print_exc()
    cgd_model_mitbih = None # Ensure it's None if instantiation fails

# Only report success when a model was actually created; a failure above must not be
# followed by a misleading "executed successfully" line.
if cgd_model_mitbih is not None:
    print("\nCell 15: Setup and Model Configuration for MIT-BIH Dataset executed successfully.")
else:
    print("\nCell 15: Setup and Model Configuration for MIT-BIH Dataset FAILED (model was not instantiated). Fix the error above and re-run this cell.")

# Cell 16: Training & Evaluation - MIT-BIH Dataset

In [None]:
# --- Cell 16: Stage 1 MIT-BIH Training (Train Best Base Classifier) - REWRITTEN ---

# Ensure necessary variables from previous cells are available
if 'cgd_model_mitbih' not in globals() or cgd_model_mitbih is None:
    raise NameError("`cgd_model_mitbih` is not defined. Please run Cell 15 with Stage 1 configurations first.")
if 'mitbih_loaders' not in globals() or mitbih_loaders is None or \
   not all(k in mitbih_loaders for k in ['train', 'val', 'test']):
    raise NameError("`mitbih_loaders` are not properly defined. Please run Cell 3 (Revised) first.")
if 'GENERAL_TRAINING_CONFIG' not in globals() or 'CGD_MODEL_CONFIG' not in globals() or \
   'ADAPTIVE_THRESHOLD_CONFIG' not in globals() or 'EXPLAINER_CONFIG' not in globals(): # Ensure all needed configs are present
    raise NameError("Core GENERAL_TRAINING_CONFIG or other specific configurations are missing. Please re-run Cell 2.")
if 'MITBIH_CLASS_NAMES' not in globals() or 'MITBIH_NUM_CLASSES' not in globals():
    raise NameError("`MITBIH_CLASS_NAMES` or `MITBIH_NUM_CLASSES` not defined. Please run Cell 15 first.")

print(f"Cell 16 (Stage 1 MIT-BIH) will use DEVICE: {DEVICE}")
# Verify Stage 1 specific configurations.
# Read each value once; a missing key yields None, which must trigger the warning
# rather than a TypeError from comparing None with a float.
mitbih_stage1_loss_beta = CGD_MODEL_CONFIG.get('loss_beta')
mitbih_stage1_deferral_threshold = CGD_MODEL_CONFIG.get('deferral_threshold_train')
print(f"Confirming CGD_MODEL_CONFIG for Stage 1: loss_beta={mitbih_stage1_loss_beta}, deferral_threshold_train={mitbih_stage1_deferral_threshold}")
if mitbih_stage1_loss_beta != 0.0 or mitbih_stage1_deferral_threshold is None or mitbih_stage1_deferral_threshold < 1e8:
    print("WARNING: CGD_MODEL_CONFIG does not appear to be correctly set for Stage 1 training.")
    print("For Stage 1 (Best Base Classifier), 'loss_beta' should be 0.0 and 'deferral_threshold_train' should be float('inf') or very large.")
    print("Please ensure configurations in Cell 2 are set for Stage 1 and re-run Cell 15 before this cell.")


# 1. Train the Base Model for MIT-BIH (Focus on classification accuracy)
print("--- Starting MIT-BIH Model Training (Stage 1 - Base Classifier) ---")
mitbih_stage1_checkpoint_filename = "mitbih_base_classifier_best.pt" # Specific name for Stage 1 model

cgd_model_mitbih.to(DEVICE) # Ensure model instance from Cell 15 is on the correct device

# Work on a copy so per-stage tweaks do not leak into the shared training config.
stage1_mitbih_train_config = GENERAL_TRAINING_CONFIG.copy()
# stage1_mitbih_train_config['epochs'] = 50 # Example: more epochs for base model training

# This single call trains the model and returns it together with the training history.
# The checkpoint file name for the best model is passed in explicitly.
trained_base_mitbih_model, mitbih_stage1_history = train_universal_cgd_model(
    model=cgd_model_mitbih,
    train_loader=mitbih_loaders['train'],
    val_loader=mitbih_loaders['val'],
    training_config=stage1_mitbih_train_config, # Use the potentially adjusted config
    model_specific_checkpoint_name=mitbih_stage1_checkpoint_filename
)
print("--- MIT-BIH Model Training (Stage 1) Finished ---")

# 2. Plot Training History for Stage 1
print("\n--- Plotting MIT-BIH Stage 1 Training History ---")
if not mitbih_stage1_history:
    # No recorded history: nothing to draw.
    print("MIT-BIH Stage 1 training history is empty, skipping plot.")
else:
    # One entry per panel; an entry without 'train_key' plots the validation series only.
    mitbih_stage1_metrics_to_plot = [
        dict(train_key='train_total_loss', val_key='val_total_loss',
             title='MIT-BIH Stage 1 Total Loss', ylabel='Loss'),
        dict(train_key='train_pred_loss', val_key='val_pred_loss',
             title='MIT-BIH Stage 1 Prediction Loss', ylabel='Pred Loss'),
        # Defer rate is expected to stay near 0% when deferral_threshold_train is float('inf').
        dict(train_key='train_defer_rate_train_time', val_key='val_defer_rate_train_time',
             title='MIT-BIH Stage 1 Defer Rate (During Train)', ylabel='Defer Rate'),
        # With no deferral this is expected to equal the overall validation accuracy.
        dict(val_key='val_accuracy_nd_val_train_thresh',
             title='MIT-BIH Stage 1 Val Acc (Effectively Overall)', ylabel='Accuracy'),
    ]
    plot_training_history(
        mitbih_stage1_history,
        mitbih_stage1_metrics_to_plot,
        save_path=os.path.join(
            GENERAL_TRAINING_CONFIG['checkpoint_dir'],
            "mitbih_stage1_training_history.png",
        ),
    )

# 3. Evaluate the Stage 1 Base Model on the Test Set
# For Stage 1, the primary interest is overall performance without deferral.
# We call evaluate_model_with_adaptive_deferral with an infinite threshold so that
# no sample is intended to be deferred and 'accuracy_overall' reflects the whole set.
print("\n--- Evaluating MIT-BIH Stage 1 Base Model on Test Set (Overall Performance) ---")
mitbih_stage1_eval_metrics = {} # Initialize
if trained_base_mitbih_model is not None and 'test' in mitbih_loaders and mitbih_loaders['test'] is not None:
    # Ensure the model returned by training is on the right device and in eval mode
    trained_base_mitbih_model.to(DEVICE)
    trained_base_mitbih_model.eval()

    mitbih_stage1_eval_metrics = evaluate_model_with_adaptive_deferral(
        model=trained_base_mitbih_model,
        test_loader=mitbih_loaders['test'],
        adaptive_threshold=float('inf'), # Intended to give a deferral rate of 0.0
        device=DEVICE,
        dataset_name="MIT-BIH Test Set (Stage 1 Base Model - Overall)"
    )
    # Format each metric only when it is present: applying ':.4f' to the 'N/A'
    # fallback string would raise ValueError.
    mitbih_stage1_test_accuracy = mitbih_stage1_eval_metrics.get('accuracy_overall')
    mitbih_stage1_test_auc = mitbih_stage1_eval_metrics.get('auc_non_deferred')
    mitbih_stage1_test_accuracy_text = (
        f"{mitbih_stage1_test_accuracy:.4f}" if mitbih_stage1_test_accuracy is not None else "N/A"
    )
    mitbih_stage1_test_auc_text = (
        f"{mitbih_stage1_test_auc:.4f}" if mitbih_stage1_test_auc is not None else "N/A"
    )
    print(f"MIT-BIH Stage 1 Base Model - Test Accuracy Overall: {mitbih_stage1_test_accuracy_text}")
    # With no deferral, 'auc_non_deferred' is read here as the overall AUC
    print(f"MIT-BIH Stage 1 Base Model - Test AUC Overall: {mitbih_stage1_test_auc_text}")
else:
    print("Skipping Stage 1 MIT-BIH model evaluation (trained model or test loader not available).")

# 4. Visualizations for Stage 1 Overall Performance on Test Set
print("\n--- Visualizing MIT-BIH Stage 1 Test Set Overall Performance ---")
if (
    trained_base_mitbih_model is None
    or 'test' not in mitbih_loaders
    or mitbih_loaders['test'] is None
    or not mitbih_stage1_eval_metrics
):
    print("Skipping MIT-BIH Stage 1 visualizations (trained base model or test loader not available, or prior evaluation failed).")
else:
    # Run the test loader once more to collect labels and raw logits for the plots.
    all_y_true_s1_viz = []
    all_y_pred_logits_s1_viz = []
    trained_base_mitbih_model.eval()
    with torch.no_grad():
        for X_b, y_b, p_mask_b in tqdm(mitbih_loaders['test'], desc="Fetching Test Data for Stage 1 Plots"):
            X_b = X_b.to(DEVICE)
            y_b = y_b.to(DEVICE)
            p_mask_b = p_mask_b.to(DEVICE)
            output = trained_base_mitbih_model(X_b, p_mask_b)
            all_y_true_s1_viz.append(y_b.cpu())
            all_y_pred_logits_s1_viz.append(output['y_pred_logits'].cpu())

    if not all_y_true_s1_viz:
        print("Skipping MIT-BIH Stage 1 test visualizations as test data could not be re-fetched or was empty.")
    else:
        y_true_test_np_s1 = torch.cat(all_y_true_s1_viz).numpy().astype(int)
        y_pred_logits_test_np_s1 = torch.cat(all_y_pred_logits_s1_viz).numpy()

        # Class probabilities (softmax over axis 1) feed the ROC plot; hard classes feed the confusion matrix.
        y_pred_probs_test_s1 = F.softmax(torch.from_numpy(y_pred_logits_test_np_s1), dim=1).numpy()
        y_pred_classes_test_s1 = y_pred_logits_test_np_s1.argmax(axis=1)

        plot_confusion_matrix_custom(
            y_true=y_true_test_np_s1,
            y_pred_classes=y_pred_classes_test_s1,
            class_names=MITBIH_CLASS_NAMES,
            title="MIT-BIH Stage 1 Base Model CM (Overall Test)",
            save_path=os.path.join(GENERAL_TRAINING_CONFIG['checkpoint_dir'], "mitbih_stage1_cm_overall_test.png"),
        )
        plot_roc_auc_curves(
            y_true_list=[y_true_test_np_s1],
            y_pred_probs_list=[y_pred_probs_test_s1],
            label_names_list=["MIT-BIH Stage 1 Base Model (Overall)"],
            output_dim=MITBIH_NUM_CLASSES,
            class_names=MITBIH_CLASS_NAMES,
            title="MIT-BIH Stage 1 Base Model ROC Curve (Overall Test)",
            save_path=os.path.join(GENERAL_TRAINING_CONFIG['checkpoint_dir'], "mitbih_stage1_roc_overall_test.png"),
        )

# 5. Explanations for the Stage 1 Base Model
# This helps understand what the base model learned, independently of deferral.
# Captum classes are set to None by the import cell when Captum is unavailable.
captum_ready_s1 = 'Saliency' in globals() and Saliency is not None and 'IntegratedGradients' in globals() and IntegratedGradients is not None
explainer_class_ready_s1 = 'CGDExplainer' in globals() and callable(CGDExplainer)

if captum_ready_s1 and explainer_class_ready_s1 and trained_base_mitbih_model is not None and \
   'mitbih_loaders' in globals() and mitbih_loaders['test'] is not None:
    print("\n--- (Optional) Visualizing Explanations for MIT-BIH Stage 1 Base Model ---")
    mitbih_base_explainer = CGDExplainer(trained_base_mitbih_model, EXPLAINER_CONFIG) # Using global EXPLAINER_CONFIG
    num_samples_to_explain_s1 = 2
    explained_count_s1 = 0
    try:
        for X_ex_b, y_ex_b, p_mask_ex_b in mitbih_loaders['test']:
            # Take only as many samples from this batch as are still needed.
            samples_to_take = min(X_ex_b.size(0), num_samples_to_explain_s1 - explained_count_s1)

            for i_s1 in range(samples_to_take):
                # Slicing with i:i+1 keeps the batch dimension for the model call.
                X_s1, y_s1_true, p_mask_s1 = X_ex_b[i_s1:i_s1+1].to(DEVICE), y_ex_b[i_s1].to(DEVICE), p_mask_ex_b[i_s1:i_s1+1].to(DEVICE)

                with torch.no_grad():
                    output_s1 = trained_base_mitbih_model(X_s1, p_mask_s1)
                pred_logits_s1 = output_s1['y_pred_logits']
                pred_class_s1 = torch.argmax(pred_logits_s1, dim=1).item()

                # int(...) so the lookup also works if labels are stored as floats
                # (matches the PTB explanation cell).
                true_lbl_name_s1 = MITBIH_CLASS_NAMES[int(y_s1_true.item())]
                pred_lbl_name_s1 = MITBIH_CLASS_NAMES[pred_class_s1]

                print(f"Explaining MIT-BIH Stage 1 Sample {explained_count_s1+1}: True='{true_lbl_name_s1}', Pred='{pred_lbl_name_s1}'")
                attrs_pred_s1 = mitbih_base_explainer.attribute_input(
                    X_s1, p_mask_s1, explanation_target='prediction', target_class_idx=pred_class_s1, abs_attribution=True
                )
                if attrs_pred_s1 is not None:
                    plot_ecg_with_saliency(
                        ecg_signal=X_s1.squeeze().cpu().numpy(), attributions=attrs_pred_s1.squeeze(),
                        title=f"MIT-BIH Stage 1 Sample {explained_count_s1+1} - Prediction Explanation",
                        save_path=os.path.join(GENERAL_TRAINING_CONFIG['checkpoint_dir'], f"mitbih_stage1_explain_pred_sample{explained_count_s1}.png"),
                        true_label_name=true_lbl_name_s1, pred_label_name=pred_lbl_name_s1
                    )
                explained_count_s1 += 1
            # Stop pulling batches once enough samples have been explained.
            if explained_count_s1 >= num_samples_to_explain_s1: break
    except Exception as e_s1_explain:
        # Explanations are optional: report the problem and let the rest of the cell run.
        print(f"Error during MIT-BIH Stage 1 explanation: {e_s1_explain}")
else:
    print("\nSkipping MIT-BIH Stage 1 explanation visualization (components missing).")


# Final check and message for Stage 1
# The checkpoint file name is `mitbih_stage1_checkpoint_filename`, set before training
# earlier in this cell (the previous `mitbih_stage1_checkpoint_name` was never defined).
final_mitbih_base_model_path = os.path.join(GENERAL_TRAINING_CONFIG['checkpoint_dir'], mitbih_stage1_checkpoint_filename)
if trained_base_mitbih_model is not None and os.path.exists(final_mitbih_base_model_path):
    print(f"\n--- MIT-BIH Stage 1 Analysis Complete. Best base model successfully saved to {final_mitbih_base_model_path} ---")
    # The metrics dict is empty when evaluation was skipped; format only a real number,
    # because ':.4f' applied to the 'N/A' fallback string raises ValueError.
    mitbih_stage1_final_accuracy = mitbih_stage1_eval_metrics.get('accuracy_overall')
    mitbih_stage1_final_accuracy_text = (
        f"{mitbih_stage1_final_accuracy:.4f}" if mitbih_stage1_final_accuracy is not None else "N/A"
    )
    print(f"This model achieved an overall test accuracy of: {mitbih_stage1_final_accuracy_text}")
else:
    print(f"\n--- MIT-BIH Stage 1 Analysis Incomplete. ---")
    if trained_base_mitbih_model is None:
         print("   Reason: Base model training did not complete or failed.")
    elif not os.path.exists(final_mitbih_base_model_path):
         print(f"   Reason: Best base model checkpoint NOT FOUND at {final_mitbih_base_model_path}")
         print("   Please check training logs from 'train_universal_cgd_model' to ensure the model improved and saved correctly.")

In [None]:
# --- Cell 17: Setup and Model Configuration for PTB Dataset ---

# 1. Confirm DataLoaders are available (this cell only reports; it does not stop on a missing loader)
if 'ptb_loaders' not in globals() or ptb_loaders is None:
    print("ERROR: `ptb_loaders` not found. Please ensure Cell 3 (Revised) was run successfully and created these DataLoaders.")
    # Fallback for standalone run (not recommended for full workflow)
    # ptb_loaders = {'train': None, 'val': None, 'test': None}
else:
    print("PTB DataLoaders found.")
    if not all(k in ptb_loaders for k in ['train', 'val', 'test']):
        print("Warning: PTB loaders might be incomplete. Expected 'train', 'val', 'test'.")

# 2. Define PTB Class Names
PTB_CLASS_NAMES = ['Normal', 'Myocardial Infarction (Abnormal)']
PTB_NUM_CLASSES = 1 # For binary classification with BCEWithLogitsLoss, output_dim is 1
# If using CrossEntropyLoss for binary, PTB_NUM_CLASSES would be 2.
# Our UniversalCGDModel's compute_loss handles output_dim_val == 1 for BCE.

# 3. Instantiate UniversalCGDModel for PTB
print(f"\nConfiguring UniversalCGDModel for PTB ({PTB_NUM_CLASSES} output neuron for BCEWithLogitsLoss)...")

# Ensure all config dictionaries from Cell 2 are available
config_vars = ['CGD_MODEL_CONFIG', 'ENCODER_CONFIG', 'PREDICTOR_CONFIG',
               'PERTURBATION_CONFIG', 'SENSITIVITY_CONFIG', 'STRUCTURAL_REGULARIZER_CONFIG']
missing_configs = [cv for cv in config_vars if cv not in globals()]
if missing_configs:
    raise NameError(f"Missing one or more configuration dictionaries: {missing_configs}. Please re-run Cell 2.")


try:
    cgd_model_ptb = UniversalCGDModel(
        model_config=CGD_MODEL_CONFIG,
        encoder_config=ENCODER_CONFIG,
        predictor_config=PREDICTOR_CONFIG, # output_dim is passed separately
        perturb_config=PERTURBATION_CONFIG, # Or ptb_perturb_config if defined
        sensitivity_config=SENSITIVITY_CONFIG, # Or ptb_sensitivity_config if defined
        regularizer_config=STRUCTURAL_REGULARIZER_CONFIG,
        output_dim=PTB_NUM_CLASSES # Crucial: set to 1 for binary tasks using BCEWithLogitsLoss
    ).to(DEVICE)

    print("\nUniversalCGDModel for PTB instantiated successfully:")
    print(f"  Encoder aggregation: {cgd_model_ptb.encoder.aggregation_method}")
    print(f"  Predictor output features: {cgd_model_ptb.predictor.predictor_mlp[-1].out_features}")
    print(f"  Perturbation types: {cgd_model_ptb.perturbation_generator._get_active_perturbations()}")
    print(f"  Sensitivity measures: {cgd_model_ptb.sensitivity_calculator._get_active_measures()}")
    print(f"  Regularizer type: {cgd_model_ptb.structural_regularizer.regularization_type}")
    print(f"  Model is on device: {next(cgd_model_ptb.parameters()).device}")

except Exception as e:
    # Keep the notebook running but make the failure visible; later cells check for None.
    print(f"Error instantiating UniversalCGDModel for PTB: {e}")
    import traceback
    traceback.print_exc()
    cgd_model_ptb = None # Ensure it's None if instantiation fails

# Only report success when a model was actually created; a failure above must not be
# followed by a misleading "executed successfully" line.
if cgd_model_ptb is not None:
    print("\nCell 17: Setup and Model Configuration for PTB Dataset executed successfully.")
else:
    print("\nCell 17: Setup and Model Configuration for PTB Dataset FAILED (model was not instantiated). Fix the error above and re-run this cell.")

In [None]:
# --- Cell 18: Stage 1 PTB Training (Train Best Base Classifier) - Corrected ---

# Ensure necessary variables from previous cells are available
if 'cgd_model_ptb' not in globals() or cgd_model_ptb is None:
    raise NameError("`cgd_model_ptb` is not defined. Please run Cell 17 with Stage 1 configurations first.")
if 'ptb_loaders' not in globals() or ptb_loaders is None or \
   not all(k in ptb_loaders for k in ['train', 'val', 'test']):
    raise NameError("`ptb_loaders` are not properly defined. Please run Cell 3 (Revised) first.")
if 'GENERAL_TRAINING_CONFIG' not in globals() or 'CGD_MODEL_CONFIG' not in globals() or \
   'ADAPTIVE_THRESHOLD_CONFIG' not in globals() or 'EXPLAINER_CONFIG' not in globals():
    raise NameError("Core GENERAL_TRAINING_CONFIG or other specific configurations are missing. Please re-run Cell 2.")
if 'PTB_CLASS_NAMES' not in globals() or 'PTB_NUM_CLASSES' not in globals():
    raise NameError("`PTB_CLASS_NAMES` or `PTB_NUM_CLASSES` not defined. Please run Cell 17 first.")

print(f"Cell 18 (Stage 1 PTB) will use DEVICE: {DEVICE}")
# Verify Stage 1 specific configurations.
# Read each value once; a missing key yields None, which must trigger the warning
# rather than a TypeError from comparing None with a float.
ptb_stage1_loss_beta = CGD_MODEL_CONFIG.get('loss_beta')
ptb_stage1_deferral_threshold = CGD_MODEL_CONFIG.get('deferral_threshold_train')
print(f"Confirming CGD_MODEL_CONFIG for Stage 1: loss_beta={ptb_stage1_loss_beta}, deferral_threshold_train={ptb_stage1_deferral_threshold}")
if ptb_stage1_loss_beta != 0.0 or ptb_stage1_deferral_threshold is None or ptb_stage1_deferral_threshold < 1e8:
    print("WARNING: CGD_MODEL_CONFIG does not appear to be correctly set for Stage 1 training.")
    print("For Stage 1 (Best Base Classifier), 'loss_beta' should be 0.0 and 'deferral_threshold_train' should be float('inf') or very large.")
    print("Please ensure configurations in Cell 2 are set for Stage 1 and re-run Cell 17 before this cell.")


# 1. Train the Base Model for PTB (Focus on classification accuracy)
print("--- Starting PTB Model Training (Stage 1 - Base Classifier) ---")
ptb_base_checkpoint_name = "ptb_base_classifier_best.pt"

cgd_model_ptb.to(DEVICE)

# Define training configuration for Stage 1 (can adjust epochs here).
# Work on a copy so per-stage tweaks do not leak into the shared training config.
stage1_ptb_training_config = GENERAL_TRAINING_CONFIG.copy()
# stage1_ptb_training_config['epochs'] = 50 # Example: Train base model for more epochs if needed

# This single call trains the model and returns it together with the training history.
# The trained model is returned as 'trained_ptb_base_model'
trained_ptb_base_model, ptb_stage1_history = train_universal_cgd_model(
    model=cgd_model_ptb, # This is the instance of UniversalCGDModel for PTB
    train_loader=ptb_loaders['train'],
    val_loader=ptb_loaders['val'],
    training_config=stage1_ptb_training_config,
    model_specific_checkpoint_name=ptb_base_checkpoint_name # Pass the specific Stage 1 checkpoint name
)
print("--- PTB Model Training (Stage 1) Finished ---")

# 2. Plot Training History for Stage 1
print("\n--- Plotting PTB Stage 1 Training History ---")
if not ptb_stage1_history:
    # No recorded history: nothing to draw.
    print("PTB Stage 1 training history is empty, skipping plot.")
else:
    # One entry per panel; an entry without 'train_key' plots the validation series only.
    ptb_stage1_metrics_to_plot = [
        dict(train_key='train_total_loss', val_key='val_total_loss',
             title='PTB Stage 1 Total Loss', ylabel='Loss'),
        dict(train_key='train_pred_loss', val_key='val_pred_loss',
             title='PTB Stage 1 Prediction Loss', ylabel='Pred Loss'),
        dict(train_key='train_defer_rate_train_time', val_key='val_defer_rate_train_time',
             title='PTB Stage 1 Defer Rate (During Train)', ylabel='Defer Rate'),
        dict(val_key='val_accuracy_nd_val_train_thresh',
             title='PTB Stage 1 Val Acc (Effectively Overall)', ylabel='Accuracy'),
    ]
    plot_training_history(
        ptb_stage1_history,
        ptb_stage1_metrics_to_plot,
        save_path=os.path.join(
            GENERAL_TRAINING_CONFIG['checkpoint_dir'],
            "ptb_stage1_training_history.png",
        ),
    )

# 3. Evaluate the Stage 1 Base Model on the Test Set
print("\n--- Evaluating PTB Stage 1 Base Model on Test Set (Overall Performance) ---")
ptb_stage1_eval_metrics = {} # Initialize
# Ensure trained_ptb_base_model is used here (it's the output of the training function)
if trained_ptb_base_model is not None and 'test' in ptb_loaders and ptb_loaders['test'] is not None:
    trained_ptb_base_model.to(DEVICE)
    trained_ptb_base_model.eval()

    ptb_stage1_eval_metrics = evaluate_model_with_adaptive_deferral(
        model=trained_ptb_base_model, # Use the model returned from training
        test_loader=ptb_loaders['test'],
        adaptive_threshold=float('inf'), # Intended to disable deferral for overall accuracy assessment
        device=DEVICE,
        dataset_name="PTB Test Set (Stage 1 Base Model - Overall)"
    )
    # Format each metric only when it is present: applying ':.4f' to the 'N/A'
    # fallback string would raise ValueError.
    ptb_stage1_test_accuracy = ptb_stage1_eval_metrics.get('accuracy_overall')
    ptb_stage1_test_auc = ptb_stage1_eval_metrics.get('auc_non_deferred')
    ptb_stage1_test_accuracy_text = (
        f"{ptb_stage1_test_accuracy:.4f}" if ptb_stage1_test_accuracy is not None else "N/A"
    )
    ptb_stage1_test_auc_text = (
        f"{ptb_stage1_test_auc:.4f}" if ptb_stage1_test_auc is not None else "N/A"
    )
    print(f"PTB Stage 1 Base Model - Test Accuracy Overall: {ptb_stage1_test_accuracy_text}")
    print(f"PTB Stage 1 Base Model - Test AUC Overall: {ptb_stage1_test_auc_text}")
else:
    print("Skipping Stage 1 PTB model evaluation (trained model or test loader not available).")

# 4. Visualizations for Stage 1 Overall Performance on Test Set
print("\n--- Visualizing PTB Stage 1 Test Set Overall Performance ---")
# The model returned by training (trained_ptb_base_model) is the one visualized here.
if (
    trained_ptb_base_model is None
    or 'test' not in ptb_loaders
    or ptb_loaders['test'] is None
    or not ptb_stage1_eval_metrics
):
    print("Skipping PTB Stage 1 visualizations (trained base model or test loader not available, or prior evaluation failed).")
else:
    # Run the test loader once more to collect labels and raw logits for the plots.
    all_y_true_s1_viz_ptb = []
    all_y_pred_logits_s1_viz_ptb = []
    trained_ptb_base_model.eval()
    with torch.no_grad():
        for X_b, y_b, p_mask_b in tqdm(ptb_loaders['test'], desc="Fetching PTB Test Data for Stage 1 Plots"):
            X_b = X_b.to(DEVICE)
            y_b = y_b.to(DEVICE)
            p_mask_b = p_mask_b.to(DEVICE)
            output = trained_ptb_base_model(X_b, p_mask_b)
            all_y_true_s1_viz_ptb.append(y_b.cpu())
            all_y_pred_logits_s1_viz_ptb.append(output['y_pred_logits'].cpu())

    if not all_y_true_s1_viz_ptb:
        print("Skipping PTB Stage 1 test visualizations as test data could not be re-fetched or was empty.")
    else:
        y_true_test_np_s1_ptb = torch.cat(all_y_true_s1_viz_ptb).numpy().astype(float)
        y_pred_logits_test_np_s1_ptb = torch.cat(all_y_pred_logits_s1_viz_ptb).numpy()

        # Sigmoid of the squeezed logits gives the positive-class probability; 0.5 is the decision cut-off.
        y_pred_probs_test_s1_ptb = 1 / (1 + np.exp(-y_pred_logits_test_np_s1_ptb.squeeze()))
        y_pred_classes_test_s1_ptb = (y_pred_probs_test_s1_ptb > 0.5).astype(int)

        plot_confusion_matrix_custom(
            y_true=y_true_test_np_s1_ptb.astype(int),
            y_pred_classes=y_pred_classes_test_s1_ptb,
            class_names=PTB_CLASS_NAMES,
            title="PTB Stage 1 Base Model CM (Overall Test)",
            save_path=os.path.join(GENERAL_TRAINING_CONFIG['checkpoint_dir'], "ptb_stage1_cm_overall_test.png"),
        )
        plot_roc_auc_curves(
            y_true_list=[y_true_test_np_s1_ptb],
            y_pred_probs_list=[y_pred_probs_test_s1_ptb],
            label_names_list=["PTB Stage 1 Base Model (Overall)"],
            output_dim=PTB_NUM_CLASSES,
            class_names=PTB_CLASS_NAMES,
            title="PTB Stage 1 Base Model ROC Curve (Overall Test)",
            save_path=os.path.join(GENERAL_TRAINING_CONFIG['checkpoint_dir'], "ptb_stage1_roc_overall_test.png"),
        )

# 5.  Explanations for the Stage 1 Base Model
# Captum classes are set to None by the import cell when Captum is unavailable.
captum_ready_s1_ptb = (
    'Saliency' in globals() and Saliency is not None
    and 'IntegratedGradients' in globals() and IntegratedGradients is not None
)
explainer_class_ready_s1_ptb = 'CGDExplainer' in globals() and callable(CGDExplainer)

if (
    not captum_ready_s1_ptb
    or not explainer_class_ready_s1_ptb
    or trained_ptb_base_model is None
    or 'ptb_loaders' not in globals()
    or ptb_loaders['test'] is None
):
    print("\nSkipping PTB Stage 1 explanation visualization (components missing).")
else:
    print("\n--- (Optional) Visualizing Explanations for PTB Stage 1 Base Model ---")
    ptb_base_explainer = CGDExplainer(trained_ptb_base_model, EXPLAINER_CONFIG)
    num_samples_to_explain_s1 = 2
    explained_count_s1 = 0
    try:
        for X_ex_b, y_ex_b, p_mask_ex_b in ptb_loaders['test']:
            # Take only as many samples from this batch as are still needed.
            samples_to_take = min(X_ex_b.size(0), num_samples_to_explain_s1 - explained_count_s1)

            for i_s1 in range(samples_to_take):
                # Slicing with i:i+1 keeps the batch dimension for the model call.
                X_s1 = X_ex_b[i_s1:i_s1+1].to(DEVICE)
                y_s1_true = y_ex_b[i_s1].to(DEVICE)
                p_mask_s1 = p_mask_ex_b[i_s1:i_s1+1].to(DEVICE)

                with torch.no_grad():
                    output_s1 = trained_ptb_base_model(X_s1, p_mask_s1)
                # Single logit -> sigmoid probability -> hard class at the 0.5 cut-off.
                pred_logit_s1 = output_s1['y_pred_logits'].item()
                pred_prob_s1 = torch.sigmoid(torch.tensor(pred_logit_s1)).item()
                pred_class_s1 = int(pred_prob_s1 > 0.5)

                true_lbl_name_s1 = PTB_CLASS_NAMES[int(y_s1_true.item())]
                pred_lbl_name_s1 = PTB_CLASS_NAMES[pred_class_s1]

                print(f"Explaining PTB Stage 1 Sample {explained_count_s1+1}: True='{true_lbl_name_s1}', Pred='{pred_lbl_name_s1}' (Prob={pred_prob_s1:.2f})")
                attrs_pred_s1 = ptb_base_explainer.attribute_input(
                    X_s1,
                    p_mask_s1,
                    explanation_target='prediction',
                    target_class_idx=0,
                    abs_attribution=True,
                )
                if attrs_pred_s1 is not None:
                    plot_ecg_with_saliency(
                        ecg_signal=X_s1.squeeze().cpu().numpy(),
                        attributions=attrs_pred_s1.squeeze(),
                        title=f"PTB Stage 1 Sample {explained_count_s1+1} - Prediction Explanation",
                        save_path=os.path.join(GENERAL_TRAINING_CONFIG['checkpoint_dir'], f"ptb_stage1_explain_pred_sample{explained_count_s1}.png"),
                        true_label_name=true_lbl_name_s1,
                        pred_label_name=pred_lbl_name_s1,
                    )
                explained_count_s1 += 1

            # Stop pulling batches once enough samples have been explained.
            if explained_count_s1 >= num_samples_to_explain_s1:
                break
    except Exception as e_s1_explain:
        print(f"Error during PTB Stage 1 explanation: {e_s1_explain}")

# --- Final Check and Message for Stage 1 PTB Model ---
# ptb_base_checkpoint_name was defined in section 1 of this cell
final_ptb_base_model_path = os.path.join(GENERAL_TRAINING_CONFIG['checkpoint_dir'], ptb_base_checkpoint_name)
if trained_ptb_base_model is not None and os.path.exists(final_ptb_base_model_path):
    print(f"\n--- PTB Stage 1 Analysis Complete. Best base model successfully saved to {final_ptb_base_model_path} ---")
    # The metrics dict is empty when evaluation was skipped; format only a real number,
    # because ':.4f' applied to the 'N/A' fallback string raises ValueError.
    ptb_stage1_final_accuracy = ptb_stage1_eval_metrics.get('accuracy_overall')
    ptb_stage1_final_accuracy_text = (
        f"{ptb_stage1_final_accuracy:.4f}" if ptb_stage1_final_accuracy is not None else "N/A"
    )
    print(f"This PTB base model achieved an overall test accuracy of: {ptb_stage1_final_accuracy_text}")
else:
    print(f"\n--- PTB Stage 1 Analysis Incomplete. ---")
    if trained_ptb_base_model is None: # Check if the model variable itself is None (e.g. if training call failed)
         print("   Reason: Base model training might not have completed or failed (trained_ptb_base_model is None).")
    elif not os.path.exists(final_ptb_base_model_path): # Check if the file wasn't saved
         print(f"   Reason: Best base model checkpoint NOT FOUND at {final_ptb_base_model_path}")
         print("   Please check training logs from 'train_universal_cgd_model' to ensure the model improved enough to be saved by early stopping.")



In [None]:
# --- Cell 18.A: Define Deferral Predictor Head ---

class DeferralPredictorHead(nn.Module):
    """
    A separate MLP that predicts a single deferral logit per sample.

    Takes features (e.g., latent representation z_original from a base model) as input.
    Every hidden layer is Linear -> BatchNorm1d -> ReLU -> Dropout, followed by a final
    Linear layer with one output.

    Args:
        input_feature_dim: Size of the feature vector fed to the head.
        config: Architecture settings. Recognised keys are
            'hidden_dims_deferral_head' (falls back to 'hidden_dims') and
            'dropout_deferral_head' (falls back to 'dropout', then 0.1).
            When `config` is empty/None the notebook-level DEFERRAL_HEAD_CONFIG is
            used, then PREDICTOR_CONFIG, then built-in defaults.

    NOTE(review): BatchNorm1d cannot normalise a single-sample batch in training mode;
    confirm the training loaders never yield a trailing batch of size 1.
    """
    def __init__(self, input_feature_dim: int, config: Dict[str, Any]):
        super().__init__()
        self.config = config

        # Honour the config handed in by the caller. Previously it was stored but
        # ignored in favour of the global DEFERRAL_HEAD_CONFIG, so a per-dataset config
        # passed to the constructor silently had no effect.
        if config:
            dh_config = config
        elif 'DEFERRAL_HEAD_CONFIG' in globals():
            dh_config = DEFERRAL_HEAD_CONFIG
        elif 'PREDICTOR_CONFIG' in globals():
            # Fallback to using PREDICTOR_CONFIG structure if DEFERRAL_HEAD_CONFIG not set
            print("Warning: DEFERRAL_HEAD_CONFIG not found in globals. Using PREDICTOR_CONFIG for DeferralPredictorHead structure.")
            dh_config = PREDICTOR_CONFIG
        else:
            # Nothing configured anywhere: rely on the size-derived defaults below.
            dh_config = {}

        hidden_dims = dh_config.get('hidden_dims_deferral_head',
                                    dh_config.get('hidden_dims', # Fallback to general hidden_dims
                                                  [max(input_feature_dim // 2, 32), max(input_feature_dim // 4, 16)]))
        dropout_rate = dh_config.get('dropout_deferral_head', dh_config.get('dropout', 0.1))

        layers = []
        effective_hidden_dims = []  # widths actually built (after fixing non-positive entries)
        current_dim = input_feature_dim
        for h_dim in hidden_dims:
            # A config derived from a small latent_dim (e.g. latent_dim // 4 == 0) can
            # request a non-positive width; replace it with a usable one.
            if h_dim <= 0:
                h_dim = max(16, current_dim // 2)
            layers.append(nn.Linear(current_dim, h_dim))
            layers.append(nn.BatchNorm1d(h_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_rate))
            effective_hidden_dims.append(h_dim)
            current_dim = h_dim

        layers.append(nn.Linear(current_dim, 1)) # Single output logit for deferral
        self.mlp = nn.Sequential(*layers)
        # Report the widths that were really built, not just the requested ones.
        print(f"DeferralPredictorHead initialized with input_dim={input_feature_dim}, hidden_dims={effective_hidden_dims}")

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            features: Input features for deferral prediction [batch_size, input_feature_dim]
                      (e.g., z_original from the base model's encoder)
        Returns:
            deferral_logits: Logits for deferral [batch_size, 1]
        """
        return self.mlp(features)

# Example: add DEFERRAL_HEAD_CONFIG to Cell 2 if you haven't already. Cell 18.C installs
# a default when it is missing; the keys it defines are:
# DEFERRAL_HEAD_CONFIG = {
#     'hidden_dims_deferral_head': [...], 'dropout_deferral_head': ..., 'learning_rate': ...,
#     'epochs': ..., 'batch_size': ..., 'patience': ..., 'target_type': 'error_prediction'
# }
print("\nCell 18.A: DeferralPredictorHead class defined.")

In [None]:
# --- Cell 18.B: Prepare Data for Deferral Head Training ---

def create_deferral_head_dataset_and_loader(
    base_model: UniversalCGDModel,      # Trained and frozen base model from Stage 1
    original_dataloader: DataLoader,    # e.g., mitbih_loaders['train'] or mitbih_loaders['val']
    device: torch.device,
    target_type: str = 'error_prediction', # 'error_prediction' or 'sensitivity_regression'
    # For 'sensitivity_regression', you might need these:
    # strong_perturb_config: Optional[Dict] = None,
    # strong_sensitivity_config: Optional[Dict] = None
    batch_size: int = 128, # Use batch_size from DEFERRAL_HEAD_CONFIG if available
    num_workers: int = 2
) -> Optional[DataLoader]:
    """
    Generates a new dataset and DataLoader for training the DeferralPredictorHead.

    The base model is run (without gradients) over `original_dataloader`; for every
    sample the returned dataset holds (features_for_head, deferral_target) where the
    features are the base model's 'z_original' output and the target is 1.0 if the
    base model's predicted class differs from the true label, 0.0 otherwise.

    Args:
        base_model: Base model; switched to eval mode and moved to `device`.
        original_dataloader: Yields (X, y_true, padding_mask) batches.
        device: Device used for the base-model forward passes.
        target_type: Only 'error_prediction' is currently supported.
        batch_size: Batch size of the returned DataLoader.
        num_workers: Worker count of the returned DataLoader.

    Returns:
        A shuffled DataLoader over a TensorDataset of (features [N, D], targets [N, 1]),
        or None if `original_dataloader` produced no batches.

    Raises:
        ValueError: If `target_type` is not supported.
    """
    # Fail fast: previously this was only detected after the first forward pass.
    if target_type != 'error_prediction':
        raise ValueError(f"Unsupported target_type for DeferralPredictorHead: {target_type}")

    base_model.eval() # Ensure base model is in eval mode
    base_model.to(device)

    all_features_for_head = []
    all_deferral_targets = []

    print(f"Creating dataset for DeferralPredictorHead using target_type='{target_type}'...")
    with torch.no_grad():
        for X_batch, y_true_batch, padding_mask_batch in tqdm(original_dataloader, desc="Generating Deferral Head Data"):
            X_batch = X_batch.to(device)
            y_true_batch = y_true_batch.to(device)
            padding_mask_batch = padding_mask_batch.to(device)

            # Get z_original and base model predictions
            base_model_output = base_model(X_batch, padding_mask_batch=padding_mask_batch)
            z_original = base_model_output['z_original'].detach().cpu() # [B, LatentDim]
            y_pred_logits_base = base_model_output['y_pred_logits'].detach().cpu() # [B, NumClasses]

            all_features_for_head.append(z_original)

            # Flatten labels to [B] so the comparison below is element-wise even when
            # the loader yields labels shaped [B, 1].
            y_true_flat = y_true_batch.cpu().reshape(-1)

            if base_model.output_dim_val == 1: # Binary
                # reshape(-1) instead of squeeze(): squeeze() turns a batch of one
                # sample into a 0-d tensor, which torch.cat cannot concatenate.
                preds_base_classes = (y_pred_logits_base.sigmoid() > 0.5).float().reshape(-1)
                errors = (preds_base_classes != y_true_flat.float()).float()
            else: # Multiclass
                preds_base_classes = torch.argmax(y_pred_logits_base, dim=1)
                errors = (preds_base_classes != y_true_flat.long()).float()
            all_deferral_targets.append(errors)

    if not all_features_for_head:
        print("No data generated for deferral head.")
        return None

    cat_features = torch.cat(all_features_for_head, dim=0)
    cat_targets = torch.cat(all_deferral_targets, dim=0)

    # Targets must be [N, 1] floats to match the head's [N, 1] logits in BCEWithLogitsLoss
    cat_targets = cat_targets.unsqueeze(1) # [N, 1]

    deferral_head_dataset = TensorDataset(cat_features, cat_targets) # Simpler Dataset for this
    deferral_head_loader = DataLoader(deferral_head_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    print(f"Deferral Head Dataset created: {len(deferral_head_dataset)} samples. Target '1' proportion (errors/high_sens): {cat_targets.mean().item():.3f}")
    return deferral_head_loader


# Import TensorDataset, which create_deferral_head_dataset_and_loader (defined above in
# this cell) uses. Importing after the def still works because the name is only looked
# up when the function is called; ideally this would sit with the other
# torch.utils.data imports in Cell 1.
from torch.utils.data import TensorDataset

print("\nCell 18.B: Utilities for Deferral Head Data Preparation defined.")

In [None]:
# --- Cell 18.C: Training Loop for the Deferral Head ---

def train_deferral_head_epoch_fn( # Renamed to avoid conflict with any previous train_epoch
    deferral_head: DeferralPredictorHead,
    dataloader: DataLoader,
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
    epoch_num: int,
    total_epochs: int
) -> float:
    """
    Run one optimisation pass of the DeferralPredictorHead over `dataloader`.

    Batches whose loss is NaN/Inf are reported and skipped (no backward/step, and
    they do not count towards the average).

    Args:
        deferral_head: Head to train (put into train mode here).
        dataloader: Yields (features, targets) pairs.
        optimizer: Optimizer stepping the head's parameters.
        criterion: Loss applied to (logits, targets).
        device: Device the batches are moved to.
        epoch_num: Zero-based epoch index (used for display only).
        total_epochs: Total number of epochs (used for display only).

    Returns:
        Sample-weighted mean loss over the processed batches, or 0.0 if none were processed.
    """
    deferral_head.train()
    loss_sum = 0.0
    samples_seen = 0

    progress = tqdm(dataloader, desc=f"Deferral Head Train Epoch {epoch_num+1}/{total_epochs}")
    for batch_features, batch_targets in progress:
        batch_features = batch_features.to(device)
        batch_targets = batch_targets.to(device)

        optimizer.zero_grad()
        # Head output is [B, 1]; targets are expected as [B, 1] floats.
        batch_loss = criterion(deferral_head(batch_features), batch_targets)

        if not torch.isfinite(batch_loss):
            print(f"Warning: NaN or Inf loss detected in deferral head training epoch {epoch_num+1}. Skipping batch.")
            continue

        batch_loss.backward()
        optimizer.step()

        batch_count = batch_features.size(0)
        loss_sum += batch_loss.item() * batch_count
        samples_seen += batch_count
        progress.set_postfix({'loss': f"{loss_sum/samples_seen:.4f}"})

    if samples_seen > 0:
        return loss_sum / samples_seen
    return 0.0

def validate_deferral_head_epoch_fn( # Renamed
    deferral_head: DeferralPredictorHead,
    dataloader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
    epoch_num: int,
    total_epochs: int
) -> Tuple[float, float]:
    """
    Evaluate the DeferralPredictorHead on `dataloader` without gradient tracking.

    Batches with a NaN/Inf loss are excluded from both the loss average and the
    accuracy computation.

    Args:
        deferral_head: Head to evaluate (put into eval mode here).
        dataloader: Yields (features, targets) pairs.
        criterion: Loss applied to (logits, targets).
        device: Device the batches are moved to.
        epoch_num: Zero-based epoch index (used for display only).
        total_epochs: Total number of epochs (used for display only).

    Returns:
        (avg_loss, accuracy): sample-weighted mean loss (inf if nothing was counted)
        and accuracy of sigmoid(logit) > 0.5 against the targets (0.0 if nothing was counted).
    """
    deferral_head.eval()
    loss_sum = 0.0
    samples_seen = 0
    pred_chunks = []
    target_chunks = []

    progress = tqdm(dataloader, desc=f"Deferral Head Val Epoch {epoch_num+1}/{total_epochs}")
    with torch.no_grad():
        for batch_features, batch_targets in progress:
            batch_features = batch_features.to(device)
            batch_targets = batch_targets.to(device)
            batch_logits = deferral_head(batch_features)
            batch_loss = criterion(batch_logits, batch_targets)

            if torch.isfinite(batch_loss):
                batch_count = batch_features.size(0)
                loss_sum += batch_loss.item() * batch_count
                samples_seen += batch_count

                # Binary decision (error vs. not error): sigmoid then 0.5 cut-off.
                pred_chunks.append((torch.sigmoid(batch_logits) > 0.5).float().cpu())
                target_chunks.append(batch_targets.cpu())

            shown_loss = f"{loss_sum/samples_seen:.4f}" if samples_seen > 0 else "N/A"
            progress.set_postfix({'val_loss': shown_loss})

    if samples_seen > 0:
        avg_loss = loss_sum / samples_seen
    else:
        avg_loss = float('inf')

    accuracy = 0.0
    if pred_chunks and target_chunks:
        stacked_preds = torch.cat(pred_chunks)
        stacked_targets = torch.cat(target_chunks)
        if stacked_targets.numel() > 0: # Ensure there are samples
            accuracy = accuracy_score(stacked_targets.numpy(), stacked_preds.numpy())

    return avg_loss, accuracy


def train_separate_deferral_head_model(
    base_model_frozen: UniversalCGDModel, # Trained and FROZEN base model
    deferral_head_input_dim: int,         # Typically base_model.latent_dim
    # DataLoaders for the *original* task, used to generate deferral head training data
    original_train_loader_for_deferral_data: DataLoader,
    original_val_loader_for_deferral_data: DataLoader,
    # Configs
    deferral_head_config: Dict[str, Any], # DEFERRAL_HEAD_CONFIG from Cell 2
    general_training_config: Dict[str, Any], # For checkpoint_dir
    device: torch.device,
    deferral_head_checkpoint_name: str = "deferral_head_best.pt"
) -> Tuple[Optional[DeferralPredictorHead], Dict[str, List[float]]]:
    """
    Trains a separate DeferralPredictorHead using features from a frozen base model.

    Steps: freeze the base model, derive (features, error-target) datasets from the
    original train/val loaders, train the head with Adam and early stopping on the
    validation loss, and finally restore the best state saved during this run.

    Args:
        base_model_frozen: Base model; all its parameters get requires_grad=False.
        deferral_head_input_dim: Feature size fed to the head.
        original_train_loader_for_deferral_data: Source loader for head training data.
        original_val_loader_for_deferral_data: Source loader for head validation data.
        deferral_head_config: Reads 'target_type', 'batch_size', 'learning_rate',
            'epochs' and 'patience' (each with a default); also passed to the head.
        general_training_config: Must contain 'checkpoint_dir'.
        device: Device used for training.
        deferral_head_checkpoint_name: File name of the best-state checkpoint.

    Returns:
        (head, history) where history has 'train_loss', 'val_loss' and 'val_accuracy'
        lists, or (None, {}) if the head datasets could not be created.

    Raises:
        ValueError: If the configured 'target_type' has no matching loss criterion.
    """
    print(f"--- Starting DeferralPredictorHead Training for {deferral_head_checkpoint_name} ---")

    # Resolve the target type once and pick the loss before doing any expensive work,
    # so a misconfiguration is reported before two full passes over the data.
    # For error_prediction (binary 0/1 targets), BCEWithLogitsLoss is appropriate
    # For sensitivity_regression, MSELoss would be used.
    target_type = deferral_head_config.get('target_type', 'error_prediction')
    if target_type == 'error_prediction':
        criterion = nn.BCEWithLogitsLoss()
    # elif target_type == 'sensitivity_regression':
    #     criterion = nn.MSELoss()
    else:
        raise ValueError("Unsupported target_type in DEFERRAL_HEAD_CONFIG for loss criterion.")

    # Ensure base model is frozen
    for param in base_model_frozen.parameters():
        param.requires_grad = False
    base_model_frozen.eval()

    # 1. Prepare data for the Deferral Head
    # The original training data yields the head's training set and the original
    # validation data yields the head's validation set.
    head_batch_size = deferral_head_config.get('batch_size', 128)
    print("Generating training data for Deferral Head...")
    dh_train_loader = create_deferral_head_dataset_and_loader(
        base_model=base_model_frozen,
        original_dataloader=original_train_loader_for_deferral_data,
        device=device,
        target_type=target_type,
        batch_size=head_batch_size
    )
    print("Generating validation data for Deferral Head...")
    dh_val_loader = create_deferral_head_dataset_and_loader(
        base_model=base_model_frozen,
        original_dataloader=original_val_loader_for_deferral_data,
        device=device,
        target_type=target_type,
        batch_size=head_batch_size
    )

    if dh_train_loader is None or dh_val_loader is None:
        print("Could not create DataLoaders for Deferral Head. Aborting training.")
        return None, {}

    # 2. Initialize Deferral Head and Optimizer
    deferral_head = DeferralPredictorHead(
        input_feature_dim=deferral_head_input_dim,
        config=deferral_head_config
    ).to(device)

    optimizer = optim.Adam(deferral_head.parameters(), lr=deferral_head_config.get('learning_rate', 1e-3))

    epochs = deferral_head_config.get('epochs', 10)
    patience = deferral_head_config.get('patience', 3)

    # Make sure the checkpoint directory exists; torch.save does not create it.
    checkpoint_dir = general_training_config['checkpoint_dir']
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, deferral_head_checkpoint_name)

    history = {'train_loss': [], 'val_loss': [], 'val_accuracy': []} # Accuracy of predicting base model's errors
    best_val_loss = float('inf')
    epochs_no_improve = 0
    # Tracks whether THIS run wrote the checkpoint, so a stale file left by an earlier
    # run (possibly with a different architecture) is never loaded by mistake.
    checkpoint_saved_this_run = False

    print(f"Training DeferralPredictorHead for up to {epochs} epochs...")
    print(f"Best DeferralPredictorHead will be saved to: {checkpoint_path}")

    for epoch in range(epochs):
        train_loss = train_deferral_head_epoch_fn(deferral_head, dh_train_loader, optimizer, criterion, device, epoch, epochs)
        val_loss, val_accuracy = validate_deferral_head_epoch_fn(deferral_head, dh_val_loader, criterion, device, epoch, epochs)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_accuracy'].append(val_accuracy)

        print(f"Deferral Head Epoch {epoch+1}/{epochs} Summary: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}, Val Acc(ErrorPred)={val_accuracy:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save(deferral_head.state_dict(), checkpoint_path)
            checkpoint_saved_this_run = True
            print(f"  -> New best val_loss for Deferral Head: {best_val_loss:.4f}. Checkpoint saved.")
        else:
            epochs_no_improve += 1
            print(f"  -> Val_loss for Deferral Head did not improve for {epochs_no_improve} epoch(s). Best: {best_val_loss:.4f}")

        if epochs_no_improve >= patience:
            print(f"Early stopping for Deferral Head triggered after {epoch+1} epochs.")
            break

    print("DeferralPredictorHead training finished.")
    if checkpoint_saved_this_run:
        print(f"Loading best DeferralPredictorHead state from {checkpoint_path}")
        deferral_head.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        print("Warning: No DeferralPredictorHead checkpoint was saved during this run. Using head from last training state.")

    return deferral_head, history

# DEFERRAL_HEAD_CONFIG
# Install a default deferral-head configuration only if no earlier cell defined one.
if 'DEFERRAL_HEAD_CONFIG' not in globals():
    DEFERRAL_HEAD_CONFIG = {
        # Two hidden layers sized from the base model's latent dimension, e.g. [32, 16]
        'hidden_dims_deferral_head': [
            CGD_MODEL_CONFIG['latent_dim'] // 2,
            CGD_MODEL_CONFIG['latent_dim'] // 4,
        ],
        'dropout_deferral_head': 0.2,
        'learning_rate': 1e-3,
        'epochs': 20,
        'batch_size': 256,
        'patience': 5,
        'target_type': 'error_prediction',
    }
    print("Using default DEFERRAL_HEAD_CONFIG for testing Cell 18.C")


# --- Conceptual Test (Actual run happens in dataset-specific cells later) ---
# Only reports whether the prerequisites (a base model and its loaders) exist in the
# kernel; no training is started here.
if __name__ == '__main__':
    print("\n--- Testing Deferral Head Training Loop (Conceptual) ---")
    if globals().get('cgd_model_mitbih') is None or globals().get('mitbih_loaders') is None:
        print("Skipping conceptual deferral head training test as base model/loaders not fully set up.")
    else:
        print("Conceptual test: Would train a separate deferral head for MIT-BIH.")
        print("Deferral Head training loop functions defined.")

print("\nCell 18.C: Training Loop for Deferral Head defined.")

In [None]:
# --- Cell 18.D: Evaluation of Full System (Frozen Base Model + Trained Deferral Head) ---

def compute_adaptive_threshold_for_deferral_head(
    base_model_frozen: UniversalCGDModel,
    deferral_head: DeferralPredictorHead,
    original_val_loader: DataLoader, # e.g., mitbih_loaders['val']
    adaptive_thresh_config: Dict[str, Any],
    device: torch.device
) -> float:
    """
    Pick a deferral threshold from the DeferralPredictorHead's validation scores.

    A sample is deferred when sigmoid(head logit) > threshold. The selection rule is
    chosen by adaptive_thresh_config['method']:
      * 'percentile' (default): the score at 'percentile_value_for_threshold' (default 90).
      * 'target_defer_rate': the candidate whose deferral rate is closest to
        'target_defer_rate_value' (default 0.10).
      * 'max_acc_under_budget': the candidate maximising base-model accuracy on the
        non-deferred samples while deferring at most 'max_defer_rate_budget' (default 0.20).
    Candidates are the unique scores, reduced to 'num_threshold_candidates' (default 200)
    evenly spaced percentiles when there are more than that.

    Returns:
        The chosen threshold as a float; 0.5 when no scores were collected or the
        method is unknown.
    """
    base_model_frozen.eval()
    deferral_head.eval()

    score_chunks = []
    label_chunks = []
    base_logit_chunks = [] # Base model's predictions

    print("Computing adaptive threshold for Deferral Head using validation set...")
    with torch.no_grad():
        for X_batch, y_true_batch, padding_mask_batch in tqdm(original_val_loader, desc="DH Adaptive Threshold Calc"):
            X_batch = X_batch.to(device)
            y_true_batch = y_true_batch.to(device)
            padding_mask_batch = padding_mask_batch.to(device)

            # The head is fed the base model's latent features.
            base_out = base_model_frozen(X_batch, padding_mask_batch=padding_mask_batch)
            latent_features = base_out['z_original']
            base_logits_batch = base_out['y_pred_logits']

            # Deferral score in (0, 1); drop the trailing singleton dim of [B, 1].
            head_scores = torch.sigmoid(deferral_head(latent_features)).squeeze(-1)

            score_chunks.append(head_scores.cpu())
            label_chunks.append(y_true_batch.cpu())
            base_logit_chunks.append(base_logits_batch.cpu())

    if not score_chunks:
        print("Warning: No deferral head scores collected. Returning default threshold 0.5.")
        return 0.5

    scores = torch.cat(score_chunks).numpy()
    labels = torch.cat(label_chunks).numpy()
    base_logits = torch.cat(base_logit_chunks).numpy()

    if scores.size == 0:
        print("Warning: Deferral head scores array is empty. Returning default 0.5.")
        return 0.5

    method = adaptive_thresh_config.get('method', 'percentile')
    max_candidates = adaptive_thresh_config.get('num_threshold_candidates', 200)

    # Candidate thresholds come from the observed head scores themselves.
    candidates = np.unique(scores)
    if len(candidates) > max_candidates:
        candidates = np.percentile(scores, np.linspace(0, 100, max_candidates))
    elif len(candidates) == 0:
        print("Warning: No unique deferral head scores found. Returning default 0.5")
        return 0.5

    if method == 'percentile':
        pct = adaptive_thresh_config.get('percentile_value_for_threshold', 90)
        # Threshold is the score at this percentile. Defer if score > threshold.
        chosen = np.percentile(scores, pct)
        print(f"Deferral Head adaptive threshold (percentile {pct}%): {chosen:.4f}")
        return float(chosen)

    if method == 'target_defer_rate':
        target_dr = adaptive_thresh_config.get('target_defer_rate_value', 0.10)
        chosen = candidates[0]
        smallest_gap = float('inf')
        for cand in candidates:
            cand_dr = np.mean(scores > cand)
            gap = abs(cand_dr - target_dr)
            if gap < smallest_gap:
                smallest_gap = gap
                chosen = cand
            elif gap == smallest_gap and cand_dr < target_dr:
                # On an exact tie, take the later candidate if it stays under the target.
                chosen = cand
        realised_dr = np.mean(scores > chosen)
        print(f"Deferral Head adaptive threshold (target DR ~{target_dr*100:.1f}%): {chosen:.4f} (results in actual DR: {realised_dr*100:.1f}%)")
        return float(chosen)

    if method == 'max_acc_under_budget':
        budget = adaptive_thresh_config.get('max_defer_rate_budget', 0.20)
        chosen = candidates[-1]
        best_acc = -1.0

        # Base model's predictions on the validation set
        if base_model_frozen.output_dim_val == 1: # Binary
            base_preds = (torch.sigmoid(torch.from_numpy(base_logits)).squeeze().numpy() > 0.5).astype(int)
            base_truth = labels.astype(float)
        else: # Multiclass
            base_preds = np.argmax(base_logits, axis=1)
            base_truth = labels.astype(int)

        for cand in candidates:
            deferred = scores > cand
            if not (np.mean(deferred) <= budget):
                continue  # defers more than the budget allows
            kept = ~deferred
            if not (np.sum(kept) > 0):
                continue  # nothing left to score
            kept_acc = accuracy_score(base_truth[kept], base_preds[kept])
            if kept_acc > best_acc:
                best_acc = kept_acc
                chosen = cand
            elif kept_acc == best_acc and cand > chosen:
                # Same accuracy: prefer the higher threshold.
                chosen = cand

        final_kept = ~(scores > chosen)
        realised_dr = np.mean(scores > chosen)
        final_acc = -1.0
        if np.sum(final_kept) > 0:
            final_acc = accuracy_score(base_truth[final_kept], base_preds[final_kept])
        print(f"Deferral Head adaptive threshold (max Acc_ND under {budget*100:.1f}% DR): {chosen:.4f}")
        print(f"  Results in Val: Actual DR={realised_dr*100:.1f}%, Base Model Acc_ND={final_acc*100:.2f}%")
        return float(chosen)

    print(f"Warning: Unknown adaptive threshold method '{method}'. Using default 0.5.")
    return float(0.5)


def evaluate_system_with_deferral_head(
    base_model_frozen: UniversalCGDModel,
    deferral_head: DeferralPredictorHead,
    test_loader: DataLoader,
    adaptive_threshold_for_head: float,
    device: torch.device,
    dataset_name: str = "Test Set"
) -> Dict[str, Any]:
    """
    Evaluates the combined system: frozen base model + trained deferral head.

    A sample is deferred when sigmoid(head logit) > `adaptive_threshold_for_head`.
    Reported metrics cover the base model on all samples, on the non-deferred samples
    (accuracy, precision, recall, F1, AUC), on the deferred samples (accuracy), and
    the correlation between the head's score and the base model's errors.

    Args:
        base_model_frozen: Base model (put into eval mode here).
        deferral_head: Trained head (put into eval mode here).
        test_loader: Yields (X, y_true, padding_mask) batches.
        adaptive_threshold_for_head: Deferral threshold on the head's sigmoid score.
        device: Device used for inference.
        dataset_name: Label used in log output and stored in the metrics.

    Returns:
        Dict of metrics (also printed); empty dict if the loader produced no batches.
        Metrics that cannot be computed are reported as NaN.
    """
    base_model_frozen.eval()
    deferral_head.eval()

    all_y_true_test = []
    all_y_pred_logits_base_test = [] # From the frozen base model
    all_deferral_head_scores_test = []

    print(f"\nEvaluating system on {dataset_name} with Deferral Head threshold: {adaptive_threshold_for_head:.4f}")
    with torch.no_grad():
        for X_batch, y_true_batch, padding_mask_batch in tqdm(test_loader, desc=f"Evaluating System on {dataset_name}"):
            X_batch, y_true_batch, padding_mask_batch = X_batch.to(device), y_true_batch.to(device), padding_mask_batch.to(device)

            # Get features for deferral head (z_original from base model) and base predictions
            base_model_output = base_model_frozen(X_batch, padding_mask_batch=padding_mask_batch)
            z_original = base_model_output['z_original']
            y_pred_logits_base = base_model_output['y_pred_logits']

            # Get deferral head's score
            deferral_logits_head = deferral_head(z_original) # Pass appropriate features
            deferral_scores_head = torch.sigmoid(deferral_logits_head).squeeze(-1)

            all_y_true_test.append(y_true_batch.cpu())
            all_y_pred_logits_base_test.append(y_pred_logits_base.cpu())
            all_deferral_head_scores_test.append(deferral_scores_head.cpu())

    if not all_y_true_test:
        print(f"Warning: No data processed during system evaluation for {dataset_name}. Returning empty metrics.")
        return {}

    # Flatten labels to 1-D so comparisons with the 1-D predictions below stay
    # element-wise even if the loader yields labels shaped [B, 1].
    y_true_np = torch.cat(all_y_true_test).numpy().reshape(-1)
    y_pred_logits_base_t = torch.cat(all_y_pred_logits_base_test)
    y_pred_logits_base_np = y_pred_logits_base_t.numpy()
    deferral_head_scores_np = torch.cat(all_deferral_head_scores_test).numpy()

    # --- Determine Base Model Predictions and Errors (Overall) ---
    is_binary_classification = base_model_frozen.output_dim_val == 1
    if is_binary_classification:
        # torch.sigmoid is numerically stable (the hand-written 1/(1+exp(-x)) overflows
        # for large negative logits). reshape(-1) rather than squeeze() keeps the array
        # 1-D even when the test set holds a single sample.
        y_pred_probs_base_overall = torch.sigmoid(y_pred_logits_base_t).reshape(-1).numpy()
        y_pred_classes_base_overall = (y_pred_probs_base_overall > 0.5).astype(int)
        y_true_labels_overall = y_true_np.astype(float)
    else: # Multiclass
        y_pred_probs_base_overall = F.softmax(y_pred_logits_base_t, dim=1).numpy()
        y_pred_classes_base_overall = np.argmax(y_pred_logits_base_np, axis=1)
        y_true_labels_overall = y_true_np.astype(int)

    errors_base_overall = (y_pred_classes_base_overall != y_true_labels_overall).astype(int)
    accuracy_base_overall = accuracy_score(y_true_labels_overall, y_pred_classes_base_overall)

    # --- Apply Deferral using Deferral Head scores and its adaptive threshold ---
    defer_mask_eval_head = deferral_head_scores_np > adaptive_threshold_for_head
    non_deferred_mask_eval_head = ~defer_mask_eval_head
    # Plain Python numbers keep the metrics dict easy to log/serialise.
    num_deferred = int(np.sum(defer_mask_eval_head))
    num_non_deferred = int(np.sum(non_deferred_mask_eval_head))

    metrics = {
        'dataset_name': dataset_name,
        'adaptive_deferral_threshold_head': adaptive_threshold_for_head,
        'accuracy_base_overall': accuracy_base_overall, # Base model's performance if no deferral
        'defer_rate_eval_head': float(np.mean(defer_mask_eval_head)),
        'total_samples': len(y_true_np),
        'num_deferred_eval_head': num_deferred,
        'num_non_deferred_eval_head': num_non_deferred,
    }

    # --- Metrics for Non-Deferred Samples (using Base Model's predictions) ---
    if num_non_deferred > 0:
        y_true_nd = y_true_labels_overall[non_deferred_mask_eval_head]
        y_pred_classes_base_nd = y_pred_classes_base_overall[non_deferred_mask_eval_head]
        y_pred_probs_base_nd = y_pred_probs_base_overall[non_deferred_mask_eval_head] # For AUC

        metrics['accuracy_non_deferred_system'] = accuracy_score(y_true_nd, y_pred_classes_base_nd)
        if is_binary_classification:
            metrics['precision_non_deferred_system'] = precision_score(y_true_nd, y_pred_classes_base_nd, zero_division=0)
            metrics['recall_non_deferred_system'] = recall_score(y_true_nd, y_pred_classes_base_nd, zero_division=0)
            metrics['f1_score_non_deferred_system'] = f1_score(y_true_nd, y_pred_classes_base_nd, zero_division=0)
            # AUC is undefined when only one class is present among the kept samples.
            if len(np.unique(y_true_nd)) > 1:
                metrics['auc_non_deferred_system'] = roc_auc_score(y_true_nd, y_pred_probs_base_nd)
            else:
                metrics['auc_non_deferred_system'] = float('nan')
        else: # Multiclass
            metrics['precision_non_deferred_system'] = precision_score(y_true_nd, y_pred_classes_base_nd, average='weighted', zero_division=0)
            metrics['recall_non_deferred_system'] = recall_score(y_true_nd, y_pred_classes_base_nd, average='weighted', zero_division=0)
            metrics['f1_score_non_deferred_system'] = f1_score(y_true_nd, y_pred_classes_base_nd, average='weighted', zero_division=0)
            # One-vs-rest AUC is only attempted when every class occurs among the kept samples.
            if len(np.unique(y_true_nd)) >= base_model_frozen.output_dim_val and base_model_frozen.output_dim_val > 1:
                try:
                    metrics['auc_non_deferred_system'] = roc_auc_score(y_true_nd, y_pred_probs_base_nd, multi_class='ovr', average='macro')
                except ValueError as e_auc:
                    print(f"Warning: Could not compute multiclass AUC for non-deferred (system): {e_auc}")
                    metrics['auc_non_deferred_system'] = float('nan')
            else:
                metrics['auc_non_deferred_system'] = float('nan')
    else:
        metrics.update({
            'accuracy_non_deferred_system': float('nan'), 'precision_non_deferred_system': float('nan'),
            'recall_non_deferred_system': float('nan'), 'f1_score_non_deferred_system': float('nan'),
            'auc_non_deferred_system': float('nan')
        })

    # --- Metrics for Deferred Samples (what the Base Model would have predicted) ---
    if num_deferred > 0:
        y_true_d = y_true_labels_overall[defer_mask_eval_head]
        y_pred_classes_base_d = y_pred_classes_base_overall[defer_mask_eval_head]
        metrics['accuracy_base_on_deferred_samples'] = accuracy_score(y_true_d, y_pred_classes_base_d)
    else:
        metrics['accuracy_base_on_deferred_samples'] = float('nan')

    # --- Correlation between Deferral Head Score and Base Model Error ---
    # Pearson correlation needs at least two samples and non-constant inputs.
    if len(deferral_head_scores_np) > 1 and len(errors_base_overall) > 1 and \
       np.std(deferral_head_scores_np) > 0 and np.std(errors_base_overall) > 0:
        metrics['deferral_head_score_error_correlation'] = float(np.corrcoef(deferral_head_scores_np, errors_base_overall)[0, 1])
    else:
        metrics['deferral_head_score_error_correlation'] = float('nan')

    print(f"\n--- System Evaluation Metrics for {dataset_name} (Base Model + Deferral Head) ---")
    for key, value in metrics.items():
        if isinstance(value, float):
            print(f"  {key:<45}: {value:.4f}")
        else:
            print(f"  {key:<45}: {value}")
    print("--------------------------------------------------------------------")

    return metrics

# --- Conceptual Test (Actual run happens in dataset-specific Stage 2 execution cells) ---
# Only reports whether the prerequisites exist in the kernel; nothing is evaluated here.
if __name__ == '__main__':
    print("\n--- Testing Deferral Head Evaluation Utilities (Conceptual) ---")
    # This requires a trained base_model, trained deferral_head, and dataloaders.
    if (
        globals().get('trained_mitbih_model') is None
        or globals().get('mitbih_loaders') is None
        or 'val' not in mitbih_loaders
        or 'DeferralPredictorHead' not in globals()
        or 'DEFERRAL_HEAD_CONFIG' not in globals()
    ):
        print("Skipping conceptual Deferral Head evaluation test: missing components.")
    else:
        print("Conceptual test: Would compute adaptive threshold for a dummy deferral head and evaluate system.")
        print("Deferral Head evaluation and thresholding functions defined.")

print("\nCell 18.D: Evaluation Utilities for System with Deferral Head defined.")

In [None]:
# --- Cell 19: Execute Stage 2 for MIT-BIH (Train Deferral Head & Evaluate System) ---

print("--- Starting Stage 2 for MIT-BIH Dataset ---")

# Ensure necessary variables from previous cells are available.
# Each check lists exactly which names are missing so the failing setup cell is easy to
# identify, and covers every global this cell goes on to use.
_missing_stage2_functions = [
    name for name in (
        'UniversalCGDModel',
        'DeferralPredictorHead',
        'create_deferral_head_dataset_and_loader',  # called by train_separate_deferral_head_model
        'train_separate_deferral_head_model',
        'compute_adaptive_threshold_for_deferral_head',
        'evaluate_system_with_deferral_head',
    )
    if name not in globals()
]
if _missing_stage2_functions:
    raise NameError(
        "One or more required model/training/evaluation functions are not defined: "
        f"{_missing_stage2_functions}. Please run Cells 9, 18.A, 18.B, 18.C, 18.D."
    )

if 'mitbih_loaders' not in globals() or mitbih_loaders is None or \
   not all(k in mitbih_loaders for k in ['train', 'val', 'test']):
    raise NameError("`mitbih_loaders` are not properly defined. Please ensure Cell 3 (Revised) was run successfully.")

_missing_stage2_configs = [
    name for name in (
        'CGD_MODEL_CONFIG', 'ENCODER_CONFIG', 'PREDICTOR_CONFIG',
        # The next four are used when rebuilding the base model below.
        'PERTURBATION_CONFIG', 'SENSITIVITY_CONFIG', 'STRUCTURAL_REGULARIZER_CONFIG', 'DEVICE',
        'DEFERRAL_HEAD_CONFIG', 'ADAPTIVE_THRESHOLD_CONFIG', 'GENERAL_TRAINING_CONFIG',
        'MITBIH_NUM_CLASSES', 'MITBIH_CLASS_NAMES',
    )
    if name not in globals()
]
if _missing_stage2_configs:
    raise NameError(
        "One or more required configurations or dataset-specific variables are missing: "
        f"{_missing_stage2_configs}. Please re-run relevant setup cells (2, 15, 18.A example)."
    )

# --- 1. Load the Best Trained Base Classifier from Stage 1 ---
# Outcome: `trained_base_mitbih_model` is an eval-mode model with every parameter
# frozen, or None when the checkpoint is missing or its weights fail to load.
base_model_mitbih_path = os.path.join(
    GENERAL_TRAINING_CONFIG['checkpoint_dir'],
    "mitbih_base_classifier_best.pt",
)
trained_base_mitbih_model = None

if not os.path.exists(base_model_mitbih_path):
    print(f"ERROR: Stage 1 MIT-BIH base model checkpoint not found at {base_model_mitbih_path}. Please complete Stage 1 training first.")
else:
    print(f"Loading Stage 1 base MIT-BIH model from: {base_model_mitbih_path}")
    # Rebuild the architecture with the same configs that Stage 1 used, so that the
    # saved state dict matches the module layout.
    trained_base_mitbih_model = UniversalCGDModel(
        model_config=CGD_MODEL_CONFIG,
        encoder_config=ENCODER_CONFIG,
        predictor_config=PREDICTOR_CONFIG,
        perturb_config=PERTURBATION_CONFIG,
        sensitivity_config=SENSITIVITY_CONFIG,
        regularizer_config=STRUCTURAL_REGULARIZER_CONFIG,
        output_dim=MITBIH_NUM_CLASSES,
    ).to(DEVICE)
    try:
        _mitbih_stage1_state = torch.load(base_model_mitbih_path, map_location=DEVICE)
        trained_base_mitbih_model.load_state_dict(_mitbih_stage1_state)
        trained_base_mitbih_model.eval()
        # Stage 2 only trains the deferral head; the base model must stay fixed.
        for _frozen_param in trained_base_mitbih_model.parameters():
            _frozen_param.requires_grad = False
        print("Stage 1 MIT-BIH base model loaded and frozen successfully.")
    except Exception as load_err:
        print(f"Error loading Stage 1 MIT-BIH base model: {load_err}. Cannot proceed with Stage 2.")
        trained_base_mitbih_model = None

# --- 2. Train the Separate Deferral Predictor Head ---
# Defaults remain in place when the base model is unavailable, so later sections
# can test them safely.
trained_deferral_head_mitbih = None
mitbih_deferral_head_history = {}

if trained_base_mitbih_model is None:
    print("Skipping Deferral Head training as Stage 1 base model was not loaded.")
else:
    print("\n--- Training Separate Deferral Head for MIT-BIH ---")
    deferral_head_checkpoint_name = "mitbih_deferral_head_best.pt"

    # The head's input width is the base model's latent dimension.
    dh_input_dim = trained_base_mitbih_model.latent_dim

    _mitbih_head_train_kwargs = dict(
        base_model_frozen=trained_base_mitbih_model,
        deferral_head_input_dim=dh_input_dim,
        original_train_loader_for_deferral_data=mitbih_loaders['train'],
        original_val_loader_for_deferral_data=mitbih_loaders['val'],
        deferral_head_config=DEFERRAL_HEAD_CONFIG,
        general_training_config=GENERAL_TRAINING_CONFIG,
        device=DEVICE,
        deferral_head_checkpoint_name=deferral_head_checkpoint_name,
    )
    trained_deferral_head_mitbih, mitbih_deferral_head_history = \
        train_separate_deferral_head_model(**_mitbih_head_train_kwargs)
    print("--- MIT-BIH Deferral Head Training Finished ---")

    # Plot curves only when training produced both a head and a non-empty history.
    if trained_deferral_head_mitbih and mitbih_deferral_head_history:
        _mitbih_loss_panel = {
            'train_key': 'train_loss', 'val_key': 'val_loss',
            'title': 'MIT-BIH Deferral Head Loss', 'ylabel': 'BCE Loss',
        }
        _mitbih_acc_panel = {
            'val_key': 'val_accuracy',
            'title': 'MIT-BIH Deferral Head Val Acc (Error Pred.)', 'ylabel': 'Accuracy',
        }
        _mitbih_history_png = os.path.join(
            GENERAL_TRAINING_CONFIG['checkpoint_dir'], "mitbih_deferral_head_history.png"
        )
        plot_training_history(
            mitbih_deferral_head_history,
            metrics_to_plot=[_mitbih_loss_panel, _mitbih_acc_panel],
            save_path=_mitbih_history_png,
        )

# --- 3. Compute Adaptive Threshold for the Trained Deferral Head ---
# The threshold is derived from the validation loader. It stays None when the base
# model or the head is unavailable.
mitbih_adaptive_threshold_for_head = None
if trained_base_mitbih_model is not None and trained_deferral_head_mitbih is not None:
    print("\n--- Computing Adaptive Threshold for MIT-BIH Deferral Head ---")
    mitbih_adaptive_threshold_for_head = compute_adaptive_threshold_for_deferral_head(
        base_model_frozen=trained_base_mitbih_model,
        deferral_head=trained_deferral_head_mitbih,
        original_val_loader=mitbih_loaders['val'],
        adaptive_thresh_config=ADAPTIVE_THRESHOLD_CONFIG,
        device=DEVICE
    )
    # Later sections of this cell treat None as a possible threshold value, so do not
    # apply a float format spec to it (that would raise TypeError and abort the cell).
    if mitbih_adaptive_threshold_for_head is None:
        print("WARNING: Adaptive threshold for the MIT-BIH Deferral Head could not be computed (got None).")
    else:
        print(f"Computed MIT-BIH Adaptive Threshold for Deferral Head: {mitbih_adaptive_threshold_for_head:.4f}")
else:
    print("Skipping adaptive threshold computation for deferral head (model or head missing).")


# --- 4. Evaluate the Full System (Base Model + Deferral Head) on Test Set ---
# Starts as an empty dict so the visualization section can test it for truthiness
# even when evaluation is skipped.
mitbih_system_eval_metrics = {}
_mitbih_system_parts = (
    trained_base_mitbih_model,
    trained_deferral_head_mitbih,
    mitbih_adaptive_threshold_for_head,
)
if any(part is None for part in _mitbih_system_parts):
    print("Skipping full system evaluation (components missing).")
else:
    print("\n--- Evaluating Full System (Base Model + Deferral Head) on MIT-BIH Test Set ---")
    mitbih_system_eval_metrics = evaluate_system_with_deferral_head(
        base_model_frozen=trained_base_mitbih_model,
        deferral_head=trained_deferral_head_mitbih,
        test_loader=mitbih_loaders['test'],
        adaptive_threshold_for_head=mitbih_adaptive_threshold_for_head,
        device=DEVICE,
        dataset_name="MIT-BIH Test System",
    )


# --- 5. Visualize Final System Performance on MIT-BIH Test Set ---
def _collect_mitbih_system_outputs(loader, desc):
    """Run the frozen base model and the deferral head over every batch of `loader`.

    Both modules are switched to eval mode and gradients are disabled. Each batch is
    unpacked as a 3-tuple `(X, y, p_mask)`; the base model is called as
    `model(X, p_mask)` and its output is indexed with 'z_original' (fed to the
    deferral head) and 'y_pred_logits'.

    Args:
        loader: iterable of `(X, y, p_mask)` batches.
        desc: label shown on the tqdm progress bar.

    Returns:
        Three lists with one CPU tensor per batch: labels, base-model logits, and
        sigmoid deferral-head scores with the trailing dimension squeezed.
    """
    trained_base_mitbih_model.eval()
    trained_deferral_head_mitbih.eval()

    y_true_batches, base_logit_batches, head_score_batches = [], [], []
    with torch.no_grad():
        for X_b, y_b, p_mask_b in tqdm(loader, desc=desc):
            X_b, p_mask_b = X_b.to(DEVICE), p_mask_b.to(DEVICE)
            base_output = trained_base_mitbih_model(X_b, p_mask_b)
            head_logits = trained_deferral_head_mitbih(base_output['z_original'])

            y_true_batches.append(y_b.cpu())
            base_logit_batches.append(base_output['y_pred_logits'].cpu())
            head_score_batches.append(torch.sigmoid(head_logits).squeeze(-1).cpu())
    return y_true_batches, base_logit_batches, head_score_batches


if mitbih_system_eval_metrics:
    print("\n--- Visualizing Final MIT-BIH System Test Set Performance ---")

    (all_y_true_test_viz_sys,
     all_y_pred_logits_base_test_viz_sys,
     all_deferral_head_scores_test_viz_sys) = _collect_mitbih_system_outputs(
        mitbih_loaders['test'], "Re-fetching Test Data for System Plots"
    )

    if all_y_true_test_viz_sys:
        y_true_test_np_sys = torch.cat(all_y_true_test_viz_sys).numpy().astype(int)
        y_pred_logits_base_test_np_sys = torch.cat(all_y_pred_logits_base_test_viz_sys).numpy()
        deferral_head_scores_test_np_sys = torch.cat(all_deferral_head_scores_test_viz_sys).numpy()

        # Multi-class outputs: softmax for probabilities, argmax for the predicted class.
        y_pred_probs_base_test_overall_sys = F.softmax(torch.from_numpy(y_pred_logits_base_test_np_sys), dim=1).numpy()
        y_pred_classes_base_test_overall_sys = np.argmax(y_pred_logits_base_test_np_sys, axis=1)

        # A sample is deferred when its head score exceeds the adaptive threshold.
        defer_mask_test_eval_head_sys = deferral_head_scores_test_np_sys > mitbih_adaptive_threshold_for_head
        non_deferred_mask_test_eval_head_sys = ~defer_mask_test_eval_head_sys

        if np.sum(non_deferred_mask_test_eval_head_sys) > 0:
            # Confusion matrix and ROC are restricted to the samples the system kept.
            plot_confusion_matrix_custom(
                y_true=y_true_test_np_sys[non_deferred_mask_test_eval_head_sys],
                y_pred_classes=y_pred_classes_base_test_overall_sys[non_deferred_mask_test_eval_head_sys],
                class_names=MITBIH_CLASS_NAMES,
                title="MIT-BIH System CM (Non-Deferred by Head, Preds by Base)",
                save_path=os.path.join(GENERAL_TRAINING_CONFIG['checkpoint_dir'], "mitbih_system_cm_nd_test.png")
            )
            plot_roc_auc_curves(
                y_true_list=[y_true_test_np_sys[non_deferred_mask_test_eval_head_sys]],
                y_pred_probs_list=[y_pred_probs_base_test_overall_sys[non_deferred_mask_test_eval_head_sys]],
                label_names_list=["System Non-Deferred Test Samples"],
                output_dim=MITBIH_NUM_CLASSES,
                class_names=MITBIH_CLASS_NAMES,
                title="MIT-BIH System ROC Curve (Non-Deferred by Head, Preds by Base)",
                save_path=os.path.join(GENERAL_TRAINING_CONFIG['checkpoint_dir'], "mitbih_system_roc_nd_test.png")
            )
        else:
            print("No non-deferred samples in MIT-BIH test set for system CM/ROC plots.")

        print("\n--- Plotting MIT-BIH System Deferral Performance Curve (using Validation Set data & Deferral Head scores) ---")
        (all_y_true_val_dh_viz,
         all_y_pred_logits_base_val_dh_viz,
         all_deferral_head_scores_val_viz) = _collect_mitbih_system_outputs(
            mitbih_loaders['val'], "Fetching Val Data for System Deferral Curve"
        )

        if all_deferral_head_scores_val_viz:  # At least one validation batch was seen
            y_true_val_np_dh_viz = torch.cat(all_y_true_val_dh_viz).numpy()
            y_pred_logits_base_val_np_dh_viz = torch.cat(all_y_pred_logits_base_val_dh_viz).numpy()
            deferral_head_scores_val_np_viz = torch.cat(all_deferral_head_scores_val_viz).numpy()

            # The deferral-head scores are passed through the `sensitivities_val` argument.
            plot_deferral_performance_vs_threshold(
                sensitivities_val=deferral_head_scores_val_np_viz,
                y_true_val=y_true_val_np_dh_viz,
                y_pred_logits_val=y_pred_logits_base_val_np_dh_viz,
                model_output_dim=MITBIH_NUM_CLASSES,
                title="MIT-BIH System: Acc_ND (Base Model) vs. Deferral Rate (Deferral Head)",
                save_path=os.path.join(GENERAL_TRAINING_CONFIG['checkpoint_dir'], "mitbih_system_deferral_curve_val.png")
            )
        else:
            print("Could not plot system deferral performance curve: No validation deferral head scores collected.")
    else:
        print("Skipping system-level test visualizations as test data could not be re-fetched for plots.")

# --- 6. Visualize Explanations for a few Test Samples for the Deferral Head's decision ---
# `traceback` is used by the error handler below; importing it here keeps this section
# self-contained, so a failure is reported instead of being masked by a NameError.
import traceback

# Both Captum explainers must have imported successfully (they are set to None otherwise).
captum_ready = (
    'Saliency' in globals() and Saliency is not None and
    'IntegratedGradients' in globals() and IntegratedGradients is not None
)

# The explainer class from an earlier cell must be defined and callable.
explainer_class_ready = 'CGDExplainer' in globals() and callable(CGDExplainer)

if captum_ready and explainer_class_ready and trained_base_mitbih_model is not None and trained_deferral_head_mitbih is not None:
    print("\n--- Visualizing Explanations for MIT-BIH Deferral Head Decisions ---")

    def mitbih_deferral_head_score_for_explainer(X_b, p_mask_b):
        """Forward function handed to Captum: inputs -> sigmoid deferral-head score.

        Moves both modules to the device of `X_b`, runs the base model, feeds its
        'z_original' output to the deferral head and returns the sigmoid of the head
        logits with the trailing dimension squeezed.
        """
        trained_base_mitbih_model.to(X_b.device)
        trained_deferral_head_mitbih.to(X_b.device)
        base_output = trained_base_mitbih_model(X_b, p_mask_b)
        z_orig = base_output['z_original']
        dh_logits = trained_deferral_head_mitbih(z_orig)
        return torch.sigmoid(dh_logits).squeeze(-1)

    try:
        mitbih_dh_saliency_explainer = Saliency(mitbih_deferral_head_score_for_explainer)
    except Exception as e_captum_init_mitbih_dh:
        print(f"Could not initialize Captum Saliency for MIT-BIH Deferral Head: {e_captum_init_mitbih_dh}")
        mitbih_dh_saliency_explainer = None

    num_samples_to_explain = 3
    explained_count = 0
    try:
        for X_explain_batch, y_explain_batch, p_mask_explain_batch in mitbih_loaders['test']:
            if explained_count >= num_samples_to_explain: break
            samples_to_take_from_batch = min(X_explain_batch.size(0), num_samples_to_explain - explained_count)

            X_explain_current = X_explain_batch[:samples_to_take_from_batch].to(DEVICE)
            y_explain_current = y_explain_batch[:samples_to_take_from_batch].to(DEVICE)
            p_mask_explain_current = p_mask_explain_batch[:samples_to_take_from_batch].to(DEVICE)

            for i_spl in range(X_explain_current.size(0)):
                if explained_count >= num_samples_to_explain: break
                # Slice with i:i+1 so the input and mask keep their batch dimension.
                X_s, y_s_true, p_mask_s = X_explain_current[i_spl:i_spl+1], y_explain_current[i_spl], p_mask_explain_current[i_spl:i_spl+1]

                with torch.no_grad():
                    base_out_s = trained_base_mitbih_model(X_s, p_mask_s)
                    z_s = base_out_s['z_original']
                    pred_logits_base_s = base_out_s['y_pred_logits']
                    pred_class_base_s = torch.argmax(pred_logits_base_s, dim=1).item()
                    dh_score_s = torch.sigmoid(trained_deferral_head_mitbih(z_s)).squeeze().item()

                # Fall back to 0.5 when no adaptive threshold was computed.
                is_deferred_by_head = dh_score_s > (mitbih_adaptive_threshold_for_head if mitbih_adaptive_threshold_for_head is not None else 0.5)
                true_lbl_name = MITBIH_CLASS_NAMES[y_s_true.item()]
                pred_lbl_name = MITBIH_CLASS_NAMES[pred_class_base_s]

                print(f"MIT-BIH System Test Sample {explained_count+1}: True='{true_lbl_name}', BasePred='{pred_lbl_name}', DH_Score={dh_score_s:.3f}, DeferredByHead={is_deferred_by_head}")

                attrs_dh_sens = None
                if mitbih_dh_saliency_explainer:  # Skipped when Saliency could not be built
                    X_s.requires_grad_(True)
                    try:
                        attrs_dh_sens = mitbih_dh_saliency_explainer.attribute(X_s, additional_forward_args=(p_mask_s,), abs=True)
                        attrs_dh_sens = attrs_dh_sens.squeeze().cpu().numpy()
                    except Exception as e_attr_dh:
                        print(f"  Error getting saliency for DH score on MIT-BIH sample: {e_attr_dh}")
                        attrs_dh_sens = None
                    X_s.requires_grad_(False)  # Reset so the tensor can be converted to numpy

                # The plot is still drawn when attribution failed (attributions=None).
                plot_ecg_with_saliency(
                    ecg_signal=X_s.squeeze().cpu().numpy(), attributions=attrs_dh_sens,
                    title=f"MIT-BIH Sample {explained_count+1} - Deferral Head Score Explanation",
                    save_path=os.path.join(GENERAL_TRAINING_CONFIG['checkpoint_dir'], f"mitbih_system_explain_dh_sample{explained_count}.png"),
                    true_label_name=true_lbl_name, pred_label_name=pred_lbl_name,
                    is_deferred=is_deferred_by_head, sensitivity_score=dh_score_s
                )
                explained_count += 1
            if explained_count >= num_samples_to_explain: break
    except Exception as e_final_explain:
        # Best-effort section: report the failure with its stack trace and carry on.
        print(f"Error during system explanation visualization for MIT-BIH: {e_final_explain}")
        traceback.print_exc()
else:
    print("\nSkipping MIT-BIH system explanation visualization (components missing or Captum not fully ready).")


print(f"\n--- MIT-BIH Stage 2 Analysis Complete. Check '{GENERAL_TRAINING_CONFIG['checkpoint_dir']}' for saved artifacts. ---")


In [None]:
# --- Cell 20: Execute Stage 2 for PTB (Train Deferral Head & Evaluate System) ---
# Pre-flight checks: fail fast with a NameError that names exactly what is missing,
# before any checkpoint loading or training starts.

print("--- Starting Stage 2 for PTB Dataset ---")

# Classes and Stage 2 routines that this cell calls and that earlier cells must define.
_ptb_required_callables = (
    'UniversalCGDModel', 'DeferralPredictorHead',
    'train_separate_deferral_head_model',
    'compute_adaptive_threshold_for_deferral_head',
    'evaluate_system_with_deferral_head',
)
_ptb_missing_callables = [n for n in _ptb_required_callables if n not in globals()]
if _ptb_missing_callables:
    raise NameError(
        "One or more required model/training/evaluation functions are not defined. "
        "Please run Cells 9, 18.A, 18.C, 18.D. "
        f"Missing: {', '.join(_ptb_missing_callables)}"
    )

# The loaders object must exist and expose all three splits used below.
if 'ptb_loaders' not in globals() or ptb_loaders is None or \
   not all(k in ptb_loaders for k in ['train', 'val', 'test']):
    raise NameError("`ptb_loaders` are not properly defined. Please ensure Cell 3 (Revised) was run successfully and created these DataLoaders for PTB.")

# Configuration objects and dataset constants read by this cell.
_ptb_required_settings = (
    'CGD_MODEL_CONFIG', 'ENCODER_CONFIG', 'PREDICTOR_CONFIG', 'DEFERRAL_HEAD_CONFIG',
    'ADAPTIVE_THRESHOLD_CONFIG', 'GENERAL_TRAINING_CONFIG',
    'PTB_NUM_CLASSES', 'PTB_CLASS_NAMES',
    # Also needed to rebuild the Stage 1 model and place it on a device; checking them
    # here avoids a bare NameError in the middle of model construction.
    'PERTURBATION_CONFIG', 'SENSITIVITY_CONFIG', 'STRUCTURAL_REGULARIZER_CONFIG', 'DEVICE',
)
_ptb_missing_settings = [n for n in _ptb_required_settings if n not in globals()]
if _ptb_missing_settings:
    raise NameError(
        "One or more required configurations or PTB-specific variables are missing. "
        "Please re-run relevant setup cells (2, 17, 18.A example). "
        f"Missing: {', '.join(_ptb_missing_settings)}"
    )

# --- 1. Load the Best Trained Base Classifier for PTB from Stage 1 ---
# The file name must match the one used when the Stage 1 PTB model was saved.
# Outcome: `trained_base_ptb_model` is an eval-mode model with every parameter
# frozen, or None when the checkpoint is missing or its weights fail to load.
base_model_ptb_path = os.path.join(
    GENERAL_TRAINING_CONFIG['checkpoint_dir'],
    "ptb_base_classifier_best.pt",
)
trained_base_ptb_model = None

if not os.path.exists(base_model_ptb_path):
    print(f"ERROR: Stage 1 PTB base model checkpoint not found at {base_model_ptb_path}. Please complete Stage 1 training for PTB first.")
else:
    print(f"Loading Stage 1 base PTB model from: {base_model_ptb_path}")
    # Rebuild the architecture with the same configs that Stage 1 used, so that the
    # saved state dict matches the module layout.
    trained_base_ptb_model = UniversalCGDModel(
        model_config=CGD_MODEL_CONFIG,
        encoder_config=ENCODER_CONFIG,
        predictor_config=PREDICTOR_CONFIG,
        perturb_config=PERTURBATION_CONFIG,
        sensitivity_config=SENSITIVITY_CONFIG,
        regularizer_config=STRUCTURAL_REGULARIZER_CONFIG,
        output_dim=PTB_NUM_CLASSES,  # expected to be 1 for PTB with BCE
    ).to(DEVICE)
    try:
        _ptb_stage1_state = torch.load(base_model_ptb_path, map_location=DEVICE)
        trained_base_ptb_model.load_state_dict(_ptb_stage1_state)
        trained_base_ptb_model.eval()
        # Stage 2 only trains the deferral head; the base model must stay fixed.
        for _frozen_param in trained_base_ptb_model.parameters():
            _frozen_param.requires_grad = False
        print("Stage 1 PTB base model loaded and frozen successfully.")
    except Exception as load_err:
        print(f"Error loading Stage 1 PTB base model: {load_err}. Cannot proceed with Stage 2.")
        trained_base_ptb_model = None

# --- 2. Train the Separate Deferral Predictor Head for PTB ---
# Defaults remain in place when the base model is unavailable, so later sections
# can test them safely.
trained_deferral_head_ptb = None
ptb_deferral_head_history = {}

if trained_base_ptb_model is None:
    print("Skipping PTB Deferral Head training as Stage 1 base model was not loaded.")
else:
    print("\n--- Training Separate Deferral Head for PTB ---")
    ptb_deferral_head_checkpoint_name = "ptb_deferral_head_best.pt"

    # The head's input width is the base model's latent dimension.
    dh_input_dim_ptb = trained_base_ptb_model.latent_dim

    _ptb_head_train_kwargs = dict(
        base_model_frozen=trained_base_ptb_model,
        deferral_head_input_dim=dh_input_dim_ptb,
        original_train_loader_for_deferral_data=ptb_loaders['train'],
        original_val_loader_for_deferral_data=ptb_loaders['val'],
        deferral_head_config=DEFERRAL_HEAD_CONFIG,
        general_training_config=GENERAL_TRAINING_CONFIG,
        device=DEVICE,
        deferral_head_checkpoint_name=ptb_deferral_head_checkpoint_name,
    )
    trained_deferral_head_ptb, ptb_deferral_head_history = \
        train_separate_deferral_head_model(**_ptb_head_train_kwargs)
    print("--- PTB Deferral Head Training Finished ---")

    # Plot curves only when training produced both a head and a non-empty history.
    if trained_deferral_head_ptb and ptb_deferral_head_history:
        _ptb_loss_panel = {
            'train_key': 'train_loss', 'val_key': 'val_loss',
            'title': 'PTB Deferral Head Loss', 'ylabel': 'BCE Loss',
        }
        _ptb_acc_panel = {
            'val_key': 'val_accuracy',
            'title': 'PTB Deferral Head Val Acc (Error Pred.)', 'ylabel': 'Accuracy',
        }
        _ptb_history_png = os.path.join(
            GENERAL_TRAINING_CONFIG['checkpoint_dir'], "ptb_deferral_head_history.png"
        )
        plot_training_history(
            ptb_deferral_head_history,
            metrics_to_plot=[_ptb_loss_panel, _ptb_acc_panel],
            save_path=_ptb_history_png,
        )

# --- 3. Compute Adaptive Threshold for the Trained PTB Deferral Head ---
# The threshold is derived from the validation loader. It stays None when the base
# model or the head is unavailable.
ptb_adaptive_threshold_for_head = None
if trained_base_ptb_model is not None and trained_deferral_head_ptb is not None:
    print("\n--- Computing Adaptive Threshold for PTB Deferral Head ---")
    ptb_adaptive_threshold_for_head = compute_adaptive_threshold_for_deferral_head(
        base_model_frozen=trained_base_ptb_model,
        deferral_head=trained_deferral_head_ptb,
        original_val_loader=ptb_loaders['val'],
        adaptive_thresh_config=ADAPTIVE_THRESHOLD_CONFIG,
        device=DEVICE
    )
    # The evaluation section treats None as a possible threshold value, so do not apply
    # a float format spec to it (that would raise TypeError and abort the cell).
    if ptb_adaptive_threshold_for_head is None:
        print("WARNING: Adaptive threshold for the PTB Deferral Head could not be computed (got None).")
    else:
        print(f"Computed PTB Adaptive Threshold for Deferral Head: {ptb_adaptive_threshold_for_head:.4f}")
else:
    print("Skipping adaptive threshold computation for PTB deferral head (model or head missing).")

# --- 4. Evaluate the Full System (Base Model + Deferral Head) on PTB Test Set ---
# Starts as an empty dict so the visualization section can test it for truthiness
# even when evaluation is skipped.
ptb_system_eval_metrics = {}
_ptb_system_parts = (
    trained_base_ptb_model,
    trained_deferral_head_ptb,
    ptb_adaptive_threshold_for_head,
)
if any(part is None for part in _ptb_system_parts):
    print("Skipping full system evaluation for PTB (components missing).")
else:
    print("\n--- Evaluating Full System (Base Model + Deferral Head) on PTB Test Set ---")
    ptb_system_eval_metrics = evaluate_system_with_deferral_head(
        base_model_frozen=trained_base_ptb_model,
        deferral_head=trained_deferral_head_ptb,
        test_loader=ptb_loaders['test'],
        adaptive_threshold_for_head=ptb_adaptive_threshold_for_head,
        device=DEVICE,
        dataset_name="PTB Test System",
    )

# --- 5. Visualize Final System Performance on PTB Test Set ---
def _collect_ptb_system_outputs(loader, desc):
    """Run the frozen PTB base model and its deferral head over every batch of `loader`.

    Both modules are switched to eval mode and gradients are disabled. Each batch is
    unpacked as a 3-tuple `(X, y, p_mask)`; the base model is called as
    `model(X, p_mask)` and its output is indexed with 'z_original' (fed to the
    deferral head) and 'y_pred_logits'.

    Args:
        loader: iterable of `(X, y, p_mask)` batches.
        desc: label shown on the tqdm progress bar.

    Returns:
        Three lists with one CPU tensor per batch: labels, base-model logits, and
        sigmoid deferral-head scores with the trailing dimension squeezed.
    """
    trained_base_ptb_model.eval()
    trained_deferral_head_ptb.eval()

    y_true_batches, base_logit_batches, head_score_batches = [], [], []
    with torch.no_grad():
        for X_b, y_b, p_mask_b in tqdm(loader, desc=desc):
            X_b, p_mask_b = X_b.to(DEVICE), p_mask_b.to(DEVICE)
            base_output = trained_base_ptb_model(X_b, p_mask_b)
            head_logits = trained_deferral_head_ptb(base_output['z_original'])

            y_true_batches.append(y_b.cpu())
            base_logit_batches.append(base_output['y_pred_logits'].cpu())
            head_score_batches.append(torch.sigmoid(head_logits).squeeze(-1).cpu())
    return y_true_batches, base_logit_batches, head_score_batches


if ptb_system_eval_metrics:
    print("\n--- Visualizing Final PTB System Test Set Performance ---")

    (all_y_true_test_ptb_sys,
     all_y_pred_logits_base_test_ptb_sys,
     all_deferral_head_scores_test_ptb_sys) = _collect_ptb_system_outputs(
        ptb_loaders['test'], "Re-fetching PTB Test Data for System Plots"
    )

    if all_y_true_test_ptb_sys:
        y_true_test_np_ptb_sys = torch.cat(all_y_true_test_ptb_sys).numpy().astype(float) # PTB labels are float
        y_pred_logits_base_test_np_ptb_sys = torch.cat(all_y_pred_logits_base_test_ptb_sys).numpy()
        deferral_head_scores_test_np_ptb_sys = torch.cat(all_deferral_head_scores_test_ptb_sys).numpy()

        # Single-logit output: use torch.sigmoid (no overflow for large-magnitude logits,
        # unlike a hand-written 1/(1+exp(-x))) and flatten with reshape(-1) so the result
        # stays 1-D even when only one sample is present.
        y_pred_probs_base_test_overall_ptb_sys = torch.sigmoid(
            torch.from_numpy(y_pred_logits_base_test_np_ptb_sys)
        ).numpy().reshape(-1)
        y_pred_classes_base_test_overall_ptb_sys = (y_pred_probs_base_test_overall_ptb_sys > 0.5).astype(int)

        # A sample is deferred when its head score exceeds the adaptive threshold.
        defer_mask_test_eval_head_ptb_sys = deferral_head_scores_test_np_ptb_sys > ptb_adaptive_threshold_for_head
        non_deferred_mask_test_eval_head_ptb_sys = ~defer_mask_test_eval_head_ptb_sys

        if np.sum(non_deferred_mask_test_eval_head_ptb_sys) > 0:
            # Confusion matrix and ROC are restricted to the samples the system kept.
            plot_confusion_matrix_custom(
                y_true=y_true_test_np_ptb_sys[non_deferred_mask_test_eval_head_ptb_sys].astype(int),
                y_pred_classes=y_pred_classes_base_test_overall_ptb_sys[non_deferred_mask_test_eval_head_ptb_sys],
                class_names=PTB_CLASS_NAMES,
                title="PTB System CM (Non-Deferred by Head, Preds by Base)",
                save_path=os.path.join(GENERAL_TRAINING_CONFIG['checkpoint_dir'], "ptb_system_cm_nd_test.png")
            )
            plot_roc_auc_curves(
                y_true_list=[y_true_test_np_ptb_sys[non_deferred_mask_test_eval_head_ptb_sys]],
                y_pred_probs_list=[y_pred_probs_base_test_overall_ptb_sys[non_deferred_mask_test_eval_head_ptb_sys]],
                label_names_list=["System Non-Deferred PTB Test Samples"],
                output_dim=PTB_NUM_CLASSES,  # expected to be 1 for PTB
                class_names=PTB_CLASS_NAMES,
                title="PTB System ROC Curve (Non-Deferred by Head, Preds by Base)",
                save_path=os.path.join(GENERAL_TRAINING_CONFIG['checkpoint_dir'], "ptb_system_roc_nd_test.png")
            )
        else:
            print("No non-deferred samples in PTB test set for system plots (CM, ROC).")

        # Deferral performance curve from the deferral head's scores on the validation set.
        print("\n--- Plotting PTB System Deferral Performance Curve (using Validation Set data & Deferral Head scores) ---")
        (all_y_true_val_dh_ptb_viz,
         all_y_pred_logits_base_val_dh_ptb_viz,
         all_deferral_head_scores_val_ptb_viz) = _collect_ptb_system_outputs(
            ptb_loaders['val'], "Fetching PTB Val Data for System Deferral Curve"
        )

        if all_deferral_head_scores_val_ptb_viz:
            y_true_val_np_dh_ptb_viz = torch.cat(all_y_true_val_dh_ptb_viz).numpy()
            y_pred_logits_base_val_np_dh_ptb_viz = torch.cat(all_y_pred_logits_base_val_dh_ptb_viz).numpy()
            deferral_head_scores_val_np_ptb_viz = torch.cat(all_deferral_head_scores_val_ptb_viz).numpy()

            # The deferral-head scores are passed through the `sensitivities_val` argument.
            plot_deferral_performance_vs_threshold(
                sensitivities_val=deferral_head_scores_val_np_ptb_viz,
                y_true_val=y_true_val_np_dh_ptb_viz,
                y_pred_logits_val=y_pred_logits_base_val_np_dh_ptb_viz,
                model_output_dim=PTB_NUM_CLASSES,  # expected to be 1 for PTB
                title="PTB System: Acc_ND (Base Model) vs. Deferral Rate (Deferral Head)",
                save_path=os.path.join(GENERAL_TRAINING_CONFIG['checkpoint_dir'], "ptb_system_deferral_curve_val.png")
            )
        else:
            print("Could not plot PTB system deferral performance curve: No validation deferral head scores collected.")
    else:
        print("Skipping PTB system-level test visualizations as test data could not be re-fetched for plots.")

# 6. Visualize Explanations for a few PTB Test Samples
# `traceback` is used by the error handler below; importing it here keeps this section
# self-contained, so a failure is reported instead of being masked by a NameError.
import traceback

if 'CGDExplainer' in globals() and 'Saliency' in globals() and Saliency is not None and \
   'IntegratedGradients' in globals() and IntegratedGradients is not None and \
   trained_base_ptb_model is not None and trained_deferral_head_ptb is not None:
    print("\n--- Visualizing Explanations for PTB Deferral Head Decisions ---")

    def ptb_deferral_head_score_for_explainer(X_b, p_mask_b):
        """Forward function handed to Captum: inputs -> sigmoid deferral-head score.

        Moves both modules to the device of `X_b`, runs the base model, feeds its
        'z_original' output to the deferral head and returns the sigmoid of the head
        logits with the trailing dimension squeezed.
        """
        trained_base_ptb_model.to(X_b.device)
        trained_deferral_head_ptb.to(X_b.device)
        base_output = trained_base_ptb_model(X_b, p_mask_b)
        z_orig = base_output['z_original']
        dh_logits = trained_deferral_head_ptb(z_orig)
        return torch.sigmoid(dh_logits).squeeze(-1)

    try:
        ptb_dh_saliency_explainer = Saliency(ptb_deferral_head_score_for_explainer)
    except Exception as e_captum_init:
        print(f"Could not initialize Captum Saliency for PTB Deferral Head: {e_captum_init}")
        ptb_dh_saliency_explainer = None

    # This section runs even when no adaptive threshold was computed; fall back to 0.5
    # in that case (same fallback as the MIT-BIH cell) instead of comparing against None.
    _ptb_effective_threshold = ptb_adaptive_threshold_for_head if ptb_adaptive_threshold_for_head is not None else 0.5

    num_samples_to_explain = 3
    explained_count = 0
    try:
        for X_explain_batch, y_explain_batch, p_mask_explain_batch in ptb_loaders['test']:
            if explained_count >= num_samples_to_explain: break
            samples_to_take_from_batch = min(X_explain_batch.size(0), num_samples_to_explain - explained_count)

            X_explain_current = X_explain_batch[:samples_to_take_from_batch].to(DEVICE)
            y_explain_current = y_explain_batch[:samples_to_take_from_batch].to(DEVICE)
            p_mask_explain_current = p_mask_explain_batch[:samples_to_take_from_batch].to(DEVICE)

            for i_spl in range(X_explain_current.size(0)):
                if explained_count >= num_samples_to_explain: break
                # Slice with i:i+1 so the input and mask keep their batch dimension.
                X_s, y_s_true, p_mask_s = X_explain_current[i_spl:i_spl+1], y_explain_current[i_spl], p_mask_explain_current[i_spl:i_spl+1]

                with torch.no_grad():
                    base_out_s = trained_base_ptb_model(X_s, p_mask_s)
                    z_s = base_out_s['z_original']
                    pred_logits_base_s = base_out_s['y_pred_logits']
                    # Binary case: sigmoid probability, thresholded at 0.5 for the class.
                    pred_prob_base_s = torch.sigmoid(pred_logits_base_s).item()
                    pred_class_base_s = 1 if pred_prob_base_s > 0.5 else 0
                    dh_score_s = torch.sigmoid(trained_deferral_head_ptb(z_s)).squeeze().item()

                is_deferred_by_head = dh_score_s > _ptb_effective_threshold
                true_lbl_name = PTB_CLASS_NAMES[int(y_s_true.item())]
                pred_lbl_name = PTB_CLASS_NAMES[pred_class_base_s]

                print(f"PTB System Test Sample {explained_count+1}: True='{true_lbl_name}', BasePred='{pred_lbl_name}' (Prob={pred_prob_base_s:.2f}), DH_Score={dh_score_s:.3f}, DeferredByHead={is_deferred_by_head}")

                attrs_dh_sens = None
                if ptb_dh_saliency_explainer:  # Skipped when Saliency could not be built
                    X_s.requires_grad_(True)
                    try:
                        attrs_dh_sens = ptb_dh_saliency_explainer.attribute(X_s, additional_forward_args=(p_mask_s,), abs=True)
                        attrs_dh_sens = attrs_dh_sens.squeeze().cpu().numpy()
                    except Exception as e_attr:
                        print(f"  Error getting saliency for DH score: {e_attr}")
                        attrs_dh_sens = None
                    X_s.requires_grad_(False)  # Reset so the tensor can be converted to numpy

                # The plot is still drawn when attribution failed (attributions=None).
                plot_ecg_with_saliency(
                    ecg_signal=X_s.squeeze().cpu().numpy(),
                    attributions=attrs_dh_sens,
                    title=f"PTB Sample {explained_count+1} - Deferral Head Score Explanation",
                    save_path=os.path.join(GENERAL_TRAINING_CONFIG['checkpoint_dir'], f"ptb_system_explain_dh_sample{explained_count}.png"),
                    true_label_name=true_lbl_name, pred_label_name=pred_lbl_name,
                    is_deferred=is_deferred_by_head, sensitivity_score=dh_score_s
                )
                explained_count += 1
            if explained_count >= num_samples_to_explain: break
    except Exception as e_final_explain:
        # Best-effort section: report the failure with its stack trace and carry on.
        print(f"Error during PTB system explanation visualization: {e_final_explain}")
        traceback.print_exc()
else:
    print("\nSkipping PTB system explanation visualization (components missing or Captum not fully ready).")

print(f"\n--- PTB Stage 2 Analysis Complete. Check '{GENERAL_TRAINING_CONFIG['checkpoint_dir']}' for saved artifacts. ---")