**IMPORTS AND TPU SETUP**

In [None]:
import subprocess
import sys
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import pydicom
from pydicom.errors import InvalidDicomError
import nibabel as nib
import cv2
from scipy import ndimage
from scipy.stats import zscore
from tqdm.auto import tqdm
import warnings
import gc
import time
import json
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (accuracy_score, roc_auc_score, classification_report, 
                           confusion_matrix, precision_recall_curve, roc_curve,
                           f1_score, precision_score, recall_score)
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.preprocessing import StandardScaler
import albumentations as A
from collections import Counter
import logging
warnings.filterwarnings('ignore')

# Module-wide logging: INFO and above, each record prefixed with time and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

def install_required_packages(packages=None):
    """Best-effort ``pip install`` of every dependency that cannot be imported.

    Args:
        packages: optional mapping of pip distribution name -> importable module
            name. Defaults to this notebook's dependencies. The two names differ
            for several packages (``opencv-python`` is imported as ``cv2``,
            ``scikit-learn`` as ``sklearn``), so the import probe must use the
            module name; guessing it from the pip name made those packages look
            missing and get reinstalled on every run.

    Returns:
        List of pip names whose installation failed (empty when all are available).
    """
    import importlib

    if packages is None:
        packages = {
            'pydicom': 'pydicom',
            'nibabel': 'nibabel',
            'opencv-python': 'cv2',
            'scikit-learn': 'sklearn',
            'albumentations': 'albumentations',
            'seaborn': 'seaborn',
            'matplotlib': 'matplotlib',
        }

    failed = []
    for pip_name, module_name in packages.items():
        try:
            importlib.import_module(module_name)
        except ImportError:
            pass
        else:
            print(f"{pip_name} already installed")
            continue

        print(f"Installing {pip_name}...")
        try:
            subprocess.check_call([sys.executable, '-m', 'pip', 'install', pip_name, '-q'])
            print(f"{pip_name} installed successfully")
        except (subprocess.CalledProcessError, OSError) as e:
            # Best effort: report and keep going with the remaining packages.
            print(f"Failed to install {pip_name}: {e}")
            failed.append(pip_name)
    return failed

# NOTE(review): the top-of-file imports already ran before this point, so a
# missing package fails there first; the bootstrap only helps once it is moved
# above those imports.
print("Checking and installing required packages...")
install_required_packages()

# TPU-SPECIFIC SETUP: torch_xla is optional; downstream code branches on
# TPU_AVAILABLE and uses `xm` only when it is True.
try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
except ImportError:
    TPU_AVAILABLE = False
    print("TPU libraries not available, falling back to GPU/CPU")
else:
    TPU_AVAILABLE = True
    print("TPU libraries loaded successfully")


**TPU-OPTIMIZED CONFIGURATION AND TRAINING PIPELINE**

In [None]:
class Config:
    """Notebook-wide settings, read as class attributes (``Config.NAME``).

    Evaluating the class body also selects the compute device and prints it.
    """

    # ---- Input data ----------------------------------------------------------
    TRAIN_CSV_PATH = '/kaggle/input/rsna-intracranial-aneurysm-detection/train.csv'
    SERIES_DIR = '/kaggle/input/rsna-intracranial-aneurysm-detection/series/'
    ID_COL = 'SeriesInstanceUID'      # column holding the series folder name
    TARGET_COL = 'Aneurysm Present'   # label column

    # ---- Volume geometry & optimisation --------------------------------------
    TARGET_SIZE = (32, 224, 224)      # (depth, height, width) of every volume
    # NOTE(review): a training batch of 1 fails in a plain nn.BatchNorm1d layer;
    # confirm the model tolerates BATCH_SIZE == 1 on TPU.
    BATCH_SIZE = 1 if TPU_AVAILABLE else 2
    EPOCHS = 30
    LEARNING_RATE = 5e-5
    WEIGHT_DECAY = 1e-5
    WARMUP_EPOCHS = 3
    GRADIENT_CLIP = 1.0               # max gradient norm for clip_grad_norm_
    ACCUMULATION_STEPS = 4            # gradient-accumulation factor
    EARLY_STOPPING_PATIENCE = 8
    SCHEDULER_PATIENCE = 4

    # ---- Device (resolved once, while the class body runs) --------------------
    if not TPU_AVAILABLE:
        DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {DEVICE}")
    else:
        DEVICE = xm.xla_device()
        print(f"Using TPU device: {DEVICE}")

    # ---- Debugging ------------------------------------------------------------
    DEBUG_MODE = False                # keep False for full training
    DEBUG_SAMPLES = 500

    # ---- Cross-validation -----------------------------------------------------
    N_FOLDS = 5
    CURRENT_FOLD = 0

    # ---- Augmentation ---------------------------------------------------------
    USE_AUGMENTATION = True
    AUGMENTATION_PROB = 0.3

class AdvancedDICOMProcessor:
    """Turn a DICOM series directory into a fixed-size, normalised float32 volume.

    Slices are read, ordered along the acquisition axis, converted with the
    DICOM rescale/window tags, resized in-plane, resampled or padded to
    ``target_size[0]`` slices and finally normalised to [0, 1]. When a series
    cannot be read, a synthetic volume from ``_get_dummy_volume`` is returned
    instead and the outcome is counted in ``self.stats``.
    """

    def __init__(self, target_size=None, hu_window=(-1000, 1000)):
        # target_size is (depth, height, width); defaults to Config.TARGET_SIZE.
        self.target_size = target_size or Config.TARGET_SIZE
        self.max_slices = self.target_size[0]
        # NOTE(review): hu_window is stored but not used by any method below.
        self.hu_window = hu_window
        self.stats = {'processed': 0, 'failed': 0, 'dummy': 0}

    def load_dicom_series(self, series_path):
        """Load one series directory as a ``target_size`` float32 volume.

        Errors are not propagated. A missing directory, a directory without
        usable ``.dcm`` files, or no decodable slices yields
        ``_get_dummy_volume()`` and increments ``stats['dummy']``. An unexpected
        error does the same but increments ``stats['failed']``.
        """
        try:
            if not os.path.exists(series_path):
                logger.warning(f"Series path does not exist: {series_path}")
                self.stats['dummy'] += 1
                return self._get_dummy_volume()

            # Case-insensitive extension match (.dcm, .DCM, ...).
            dicom_files = [f for f in os.listdir(series_path) if f.lower().endswith('.dcm')]
            if not dicom_files:
                logger.warning(f"No DICOM files found in: {series_path}")
                self.stats['dummy'] += 1
                return self._get_dummy_volume()

            dicom_data = self._read_datasets(series_path, dicom_files)
            if not dicom_data:
                self.stats['dummy'] += 1
                return self._get_dummy_volume()

            # Order slices along the acquisition axis before stacking them.
            dicom_data.sort(key=lambda item: item['sort_key'])

            pixel_arrays = self._extract_slices(dicom_data)
            if not pixel_arrays:
                self.stats['dummy'] += 1
                return self._get_dummy_volume()

            volume = self._create_volume_from_slices(pixel_arrays)
            volume = self._advanced_preprocessing(volume)

            self.stats['processed'] += 1
            return volume

        except Exception as e:
            logger.error(f"Critical error processing {series_path}: {e}")
            self.stats['failed'] += 1
            return self._get_dummy_volume()

    def _read_datasets(self, series_path, dicom_files):
        """Read every file that carries pixel data; unreadable files are logged and skipped."""
        dicom_data = []
        for file_name in dicom_files:
            try:
                ds = pydicom.dcmread(os.path.join(series_path, file_name), force=True)
                # Presence check only: hasattr(ds, 'pixel_array') would decode every
                # image here and silently hide decoding errors.
                if 'PixelData' not in ds:
                    continue
                # ImagePositionPatient is NOT required: enhanced multi-frame files
                # keep positions in per-frame sequences, and requiring the
                # top-level tag dropped them entirely.
                dicom_data.append({
                    'dataset': ds,
                    'sort_key': self._slice_sort_key(ds),
                    'filename': file_name,
                })
            except Exception as e:
                logger.warning(f"Failed to load {file_name}: {e}")
        return dicom_data

    def _extract_slices(self, dicom_data):
        """Decode, transform and resize every frame of every dataset, in list order."""
        height, width = self.target_size[1:]
        pixel_arrays = []
        for item in dicom_data:
            # Drop our reference so each dataset (and its decoded pixels) can be freed.
            ds = item.pop('dataset')
            try:
                arr = ds.pixel_array.astype(np.float32)
                # Colour data keeps samples on the last axis; average to grey.
                if self._as_float(getattr(ds, 'SamplesPerPixel', 1), 1.0) > 1:
                    arr = arr.mean(axis=-1)

                if arr.ndim == 2:
                    frames = [arr]
                elif arr.ndim == 3:
                    frames = list(arr)  # multi-frame: (frames, rows, cols)
                else:
                    frames = []

                for frame in frames:
                    frame = self._apply_dicom_transforms(ds, frame)
                    if frame.shape != (height, width):
                        frame = cv2.resize(frame, (width, height), interpolation=cv2.INTER_AREA)
                    pixel_arrays.append(frame)
            except Exception as e:
                logger.warning(f"Failed to process slice: {e}")
        return pixel_arrays

    @staticmethod
    def _as_float(value, default=0.0):
        """Convert a DICOM value to float; multi-valued elements yield their first item.

        pydicom's MultiValue does not subclass ``list``, so an
        ``isinstance(value, list)`` test misses it; indexing works for any sequence.
        Returns ``default`` when the value is missing or malformed.
        """
        try:
            return float(value)
        except TypeError:
            try:
                return float(value[0])
            except (TypeError, ValueError, IndexError):
                return default
        except ValueError:
            return default

    @staticmethod
    def _slice_sort_key(ds):
        """Return a sortable ``(position, instance_number)`` key for one dataset.

        Position is ImagePositionPatient projected onto the slice normal (cross
        product of the two ImageOrientationPatient direction cosines) when both
        tags are usable; otherwise SliceLocation, or 0 when that is absent.
        """
        instance_number = AdvancedDICOMProcessor._as_float(getattr(ds, 'InstanceNumber', 0))
        try:
            ipp = np.array([float(v) for v in getattr(ds, 'ImagePositionPatient', None)])
            iop = np.array([float(v) for v in getattr(ds, 'ImageOrientationPatient', None)])
            if ipp.shape == (3,) and iop.shape == (6,):
                normal = np.cross(iop[:3], iop[3:])
                return (float(ipp @ normal), instance_number)
        except (TypeError, ValueError):
            pass  # tag missing (None is not iterable) or malformed
        slice_location = AdvancedDICOMProcessor._as_float(getattr(ds, 'SliceLocation', 0))
        return (slice_location, instance_number)

    def _apply_dicom_transforms(self, ds, arr):
        """Apply the rescale slope/intercept, then clip to the first VOI window if present."""
        if hasattr(ds, 'RescaleSlope') and hasattr(ds, 'RescaleIntercept'):
            slope = self._as_float(ds.RescaleSlope, 1.0)
            intercept = self._as_float(ds.RescaleIntercept, 0.0)
            arr = arr * slope + intercept

        if hasattr(ds, 'WindowCenter') and hasattr(ds, 'WindowWidth'):
            center = self._as_float(ds.WindowCenter, None)
            width = self._as_float(ds.WindowWidth, None)
            # A missing/malformed or non-positive width leaves the data unclipped.
            if center is not None and width is not None and width > 0:
                arr = np.clip(arr, center - width / 2, center + width / 2)

        return arr

    def _create_volume_from_slices(self, pixel_arrays):
        """Resample a slice list to exactly ``max_slices`` slices and stack them.

        Longer lists are subsampled at evenly spaced indices; shorter ones are
        padded by repeating the first and the last slice alternately.
        """
        if len(pixel_arrays) > self.max_slices:
            indices = np.linspace(0, len(pixel_arrays) - 1, self.max_slices, dtype=int)
            selected_arrays = [pixel_arrays[i] for i in indices]
        else:
            selected_arrays = list(pixel_arrays)

        while len(selected_arrays) < self.max_slices:
            if selected_arrays:
                selected_arrays.insert(0, selected_arrays[0])
                if len(selected_arrays) < self.max_slices:
                    selected_arrays.append(selected_arrays[-1])
            else:
                selected_arrays.append(np.zeros(self.target_size[1:], dtype=np.float32))

        return np.stack(selected_arrays[:self.max_slices], axis=0).astype(np.float32)

    def _advanced_preprocessing(self, volume):
        """Normalise a (depth, H, W) volume to [0, 1] and lightly smooth it.

        Steps: clip to the volume's 1st-99th percentile; z-score each slice
        (clipped to +/-3); global min-max to [0, 1]; Gaussian filter (sigma 0.5).
        """
        p1, p99 = np.percentile(volume, [1, 99])
        volume = np.clip(volume, p1, p99)

        normalized_slices = []
        for slice_data in volume:
            if slice_data.std() > 0:
                normalized = zscore(slice_data.flatten()).reshape(slice_data.shape)
                normalized = np.clip(normalized, -3, 3)
            else:
                # A constant slice has no contrast: map it to 0 (the z-scored mean).
                # Leaving its raw intensity here let it dominate the min-max step
                # below and squash every other slice into a sliver of [0, 1].
                normalized = np.zeros_like(slice_data)
            normalized_slices.append(normalized)
        volume = np.stack(normalized_slices, axis=0)

        vol_min, vol_max = volume.min(), volume.max()
        if vol_max > vol_min:
            volume = (volume - vol_min) / (vol_max - vol_min)

        volume = ndimage.gaussian_filter(volume, sigma=0.5)
        return volume.astype(np.float32)

    def _get_dummy_volume(self):
        """Generate a synthetic placeholder volume in [0, 1].

        Gaussian noise (mean 0.3, std 0.15) with a brighter centred disc on
        every slice. NOTE(review): the content is random; a caller that pairs it
        with the series' real label trains on noise - consider skipping such series.
        """
        _, height, width = self.target_size
        volume = np.random.normal(0.3, 0.15, self.target_size).astype(np.float32)

        # The disc is identical on every slice, so build its mask once.
        yy, xx = np.ogrid[:height, :width]
        radius = min(height, width) // 3
        mask = (xx - width // 2) ** 2 + (yy - height // 2) ** 2 < radius ** 2
        volume[:, mask] += 0.2

        return np.clip(volume, 0, 1)

    def get_stats(self):
        """Return the live counters dict (processed / failed / dummy)."""
        return self.stats

class VolumetricAugmentation:
    """3D-aware augmentation for (depth, height, width) volumes with values in [0, 1].

    One set of random 2D parameters is sampled per volume and replayed on every
    slice, so shifts/rotations stay consistent along the depth axis. Optional
    slice dropout and a depth flip follow.

    Args:
        prob: probability that a volume is augmented at all. The 2D pipeline is
            additionally gated by ``p=prob``, so it fires with probability
            ``prob**2`` overall.
        slice_dropout_prob: probability of zeroing a few random slices.
        depth_flip_prob: probability of reversing the slice order.
    """

    def __init__(self, prob=0.3, slice_dropout_prob=0.3, depth_flip_prob=0.2):
        self.prob = prob
        self.slice_dropout_prob = slice_dropout_prob
        self.depth_flip_prob = depth_flip_prob

        # ReplayCompose records the sampled parameters so they can be re-applied
        # unchanged to the remaining slices of the same volume.
        self.slice_transform = A.ReplayCompose([
            A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),
            A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.3),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
            A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
            A.GaussianBlur(blur_limit=3, p=0.2),
        ], p=self.prob)

    def _augment_slices(self, volume):
        """Apply one sampled 2D transform to every slice (uint8 round-trip for albumentations)."""
        augmented_volume = np.zeros_like(volume)
        replay = None
        for i, slice_2d in enumerate(volume):
            # Clip before the cast so out-of-range values cannot wrap around in uint8.
            slice_uint8 = (np.clip(slice_2d, 0.0, 1.0) * 255).astype(np.uint8)
            if replay is None:
                transformed = self.slice_transform(image=slice_uint8)
                replay = transformed['replay']
            else:
                transformed = A.ReplayCompose.replay(replay, image=slice_uint8)
            augmented_volume[i] = transformed['image'].astype(np.float32) / 255.0
        return augmented_volume

    def __call__(self, volume):
        if np.random.random() > self.prob:
            return volume

        augmented_volume = self._augment_slices(volume)
        depth = augmented_volume.shape[0]

        if np.random.random() < self.slice_dropout_prob:
            # randint's upper bound is exclusive: drops 1..min(3, depth//4 - 1)
            # slices. For depth < 8 the range would be empty (randint raises), so skip.
            high = min(4, depth // 4)
            if high > 1:
                n_dropout = np.random.randint(1, high)
                dropout_indices = np.random.choice(depth, n_dropout, replace=False)
                augmented_volume[dropout_indices] = 0

        if np.random.random() < self.depth_flip_prob:
            # np.flip returns a negative-stride view, which torch.from_numpy rejects.
            augmented_volume = np.ascontiguousarray(np.flip(augmented_volume, axis=0))

        return augmented_volume

class ImprovedAneurysmDataset(Dataset):
    """Dataset yielding one ``(1, D, H, W)`` volume tensor and a float label per table row.

    Each row of ``df`` names a series folder (``Config.ID_COL``) under
    ``series_dir`` and its label (``Config.TARGET_COL``). Volumes come from
    ``processor.load_dicom_series`` and, in ``'train'`` mode, are passed through
    ``augmentation``.
    """

    def __init__(self, df, series_dir, processor, augmentation=None, mode='train'):
        self.df = df.copy().reset_index(drop=True)
        self.series_dir = series_dir
        self.processor = processor
        self.augmentation = augmentation
        self.mode = mode

        self._validate_data()

        # In-memory volume cache, enabled only for small tables.
        self.cache = {}
        self.use_cache = len(self.df) < 100

        logger.info(f"Dataset created with {len(self.df)} samples")
        logger.info(f"Positive cases: {self.df[Config.TARGET_COL].sum()}")
        logger.info(f"Mode: {mode}, Use cache: {self.use_cache}")

    def _validate_data(self):
        """Drop rows with a missing ID or label and log unexpected label values."""
        missing_ids = self.df[Config.ID_COL].isnull().sum()
        missing_labels = self.df[Config.TARGET_COL].isnull().sum()

        if missing_ids > 0:
            logger.warning(f"Found {missing_ids} missing IDs")
            self.df = self.df.dropna(subset=[Config.ID_COL])

        if missing_labels > 0:
            logger.warning(f"Found {missing_labels} missing labels")
            self.df = self.df.dropna(subset=[Config.TARGET_COL])

        # Re-number rows after dropping so positional and label indices agree again.
        self.df = self.df.reset_index(drop=True)

        unique_labels = self.df[Config.TARGET_COL].unique()
        logger.info(f"Unique labels: {unique_labels}")

        if not all(label in [0, 1] for label in unique_labels):
            logger.error(f"Invalid labels found: {unique_labels}")

    def __len__(self):
        return len(self.df)

    def _load_volume(self, series_id):
        """Return the processed volume for ``series_id``, using the cache when enabled."""
        if self.use_cache and series_id in self.cache:
            return self.cache[series_id]
        volume = self.processor.load_dicom_series(os.path.join(self.series_dir, series_id))
        if self.use_cache:
            self.cache[series_id] = volume
        return volume

    def __getitem__(self, idx):
        # Defaults used by the fallback sample if reading the row itself fails.
        label = 0.0
        try:
            row = self.df.iloc[idx]
            series_id = str(row[Config.ID_COL])
            label = float(row[Config.TARGET_COL])

            volume = self._load_volume(series_id)
            if self.augmentation is not None and self.mode == 'train':
                volume = self.augmentation(volume)

            # torch.from_numpy rejects negative strides (e.g. from np.flip); copy to a
            # contiguous float32 buffer only when needed.
            volume = np.ascontiguousarray(volume, dtype=np.float32)

            return {
                'volume': torch.from_numpy(volume).unsqueeze(0),  # add channel dim
                'label': torch.tensor(label, dtype=torch.float32),
                'series_id': series_id,
                'idx': idx,
            }

        except Exception as e:
            logger.error(f"Error loading sample {idx}: {e}")
            # Zero volume, but keep the row's label when it was read successfully.
            return {
                'volume': torch.zeros((1, *Config.TARGET_SIZE), dtype=torch.float32),
                'label': torch.tensor(label, dtype=torch.float32),
                'series_id': f"DUMMY_{idx}",
                'idx': idx,
            }

class ResidualBlock3D(nn.Module):
    """Basic 3D residual block: two 3x3x3 conv + BatchNorm stages and a shortcut.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input so the shortcut matches
            the residual branch's shape; ``None`` means identity.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)

        self.conv1 = nn.Conv3d(in_channels, out_channels, stride=stride, **conv_kwargs)
        self.bn1 = nn.BatchNorm3d(out_channels)

        self.conv2 = nn.Conv3d(out_channels, out_channels, stride=1, **conv_kwargs)
        self.bn2 = nn.BatchNorm3d(out_channels)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return self.relu(y + shortcut)

class AttentionBlock3D(nn.Module):
    """Channel gating followed by spatial gating for 5D ``(N, C, D, H, W)`` features.

    Channel gate: global average pool -> 1x1x1 bottleneck -> ReLU -> 1x1x1 expand
    -> sigmoid (squeeze-and-excitation style). Spatial gate: 7x7x7 conv over the
    channel-wise mean and max maps -> sigmoid. The output has the input's shape.

    Args:
        channels: number of input/output channels.
        reduction: bottleneck ratio for the channel gate. The hidden width is
            clamped to at least 1; previously ``channels // 8`` was 0 for fewer
            than 8 channels, which left an empty bottleneck.
    """

    def __init__(self, channels, reduction=8):
        super(AttentionBlock3D, self).__init__()
        self.channels = channels
        hidden = max(1, channels // reduction)

        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Conv3d(channels, hidden, 1),
            nn.ReLU(inplace=True),
            nn.Conv3d(hidden, channels, 1),
            nn.Sigmoid()
        )

        self.spatial_attention = nn.Sequential(
            nn.Conv3d(2, 1, kernel_size=7, padding=3),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x * self.channel_attention(x)

        avg_pool = torch.mean(x, dim=1, keepdim=True)
        max_pool, _ = torch.max(x, dim=1, keepdim=True)
        sa = self.spatial_attention(torch.cat([avg_pool, max_pool], dim=1))
        return x * sa

class AdvancedAneurysmNet(nn.Module):
    """Advanced 3D CNN with residual blocks and attention"""
    
    def __init__(self, in_channels=1, num_classes=1, dropout_rate=0.3):
        super(AdvancedAneurysmNet, self).__init__()
        
        # Initial convolution
        self.conv1 = nn.Conv3d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        
        # Residual layers
        self.layer1 = self._make_layer(64, 64, 2)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)
        self.layer3 = self._make_layer(128, 256, 2, stride=2)
        self.layer4 = self._make_layer(256, 512, 2, stride=2)
        
        # Attention mechanisms
        self.attention1 = AttentionBlock3D(128)
        self.attention2 = AttentionBlock3D(256)
        self.attention3 = AttentionBlock3D(512)
        
        # Global pooling and classifier
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.dropout = nn.Dropout(dropout_rate)
        
        # Multi-layer classifier with batch norm
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate * 0.5),
            
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate * 0.3),
            
            nn.Linear(64, num_classes)
        )
        
        self._initialize_weights()
    
    def _make_layer(self, in_channels, out_channels, blocks, stride=1):
        downsample = None
        if stride != 1 or in_channels != out_channels:
            downsample = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(out_channels)
            )
        
        layers = []
        layers.append(ResidualBlock3D(in_channels, out_channels, stride, downsample))
        
        for _ in range(1, blocks):
            layers.append(ResidualBlock3D(out_channels, out_channels))
        
        return nn.Sequential(*layers)
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d) or isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
    
    def forward(self, x):
        # Initial layers
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        
        # Residual layers with attention
        x = self.layer1(x)
        
        x = self.layer2(x)
        x = self.attention1(x)
        
        x = self.layer3(x)
        x = self.attention2(x)
        
        x = self.layer4(x)
        x = self.attention3(x)
        
        # Global pooling and classification
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.classifier(x)
        
        return x

class FocalLoss(nn.Module):
    """Binary focal loss on logits: ``alpha * (1 - p_t)**gamma * BCE``, averaged.

    Args:
        alpha: scalar multiplier applied to every element.
        gamma: focusing exponent; ``gamma=0`` reduces to alpha-scaled BCE.
        pos_weight: optional weight for positive targets (tensor or number), with
            the same meaning as in ``binary_cross_entropy_with_logits`` for 0/1 targets.
    """

    def __init__(self, alpha=1, gamma=2, pos_weight=None):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.pos_weight = pos_weight

    def forward(self, inputs, targets):
        # reshape (unlike view) also accepts non-contiguous inputs.
        logits = inputs.reshape(-1)
        targets = targets.reshape(-1).to(logits.dtype)

        bce = nn.functional.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        # p_t must come from the UNWEIGHTED BCE: with pos_weight folded in,
        # exp(-bce) would be p**pos_weight rather than p.
        pt = torch.exp(-bce)
        focal = self.alpha * (1 - pt) ** self.gamma * bce

        if self.pos_weight is not None:
            # For 0/1 targets this is pos_weight on positives and 1 on negatives,
            # matching BCE's own pos_weight scaling.
            focal = focal * (1 + (self.pos_weight - 1) * targets)
        return focal.mean()

class EarlyStopping:
    """Signal a stop once the validation loss has not improved for ``patience`` calls.

    Args:
        patience: number of consecutive non-improving calls that triggers a stop.
        min_delta: minimum decrease below the best loss that counts as improvement.
        restore_best_weights: keep a snapshot of the best model for ``restore``.
    """

    def __init__(self, patience=7, min_delta=0, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = None
        self.counter = 0
        self.best_weights = None

    def __call__(self, val_loss, model):
        """Record ``val_loss``; return True when training should stop."""
        if self.best_loss is None or val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)
        else:
            self.counter += 1

        return self.counter >= self.patience

    def save_checkpoint(self, model):
        """Snapshot the model's current state when best-weight restoring is enabled."""
        if self.restore_best_weights:
            # state_dict() returns references to the live tensors, so a shallow
            # dict copy is overwritten by every later optimizer step. Clone them.
            self.best_weights = {
                name: tensor.detach().clone() for name, tensor in model.state_dict().items()
            }

    def restore(self, model):
        """Load the best snapshot back into ``model`` (no-op if none was taken)."""
        if self.best_weights is not None:
            model.load_state_dict(self.best_weights)

class MetricsTracker:
    """Accumulate per-batch predictions, probabilities, labels and losses.

    ``compute_metrics`` summarises everything seen since the last ``reset``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget all accumulated values."""
        self.predictions, self.probabilities, self.labels, self.losses = [], [], [], []

    def update(self, preds, probs, labels, loss):
        """Append one batch: per-sample preds/probs/labels and the batch's scalar loss."""
        self.predictions += list(preds)
        self.probabilities += list(probs)
        self.labels += list(labels)
        self.losses.append(loss)

    def compute_metrics(self):
        """Return loss/accuracy/precision/recall/f1/auc, or ``{}`` if nothing was recorded.

        AUC is reported as 0.5 when only one class is present in the labels.
        """
        if not self.labels:
            return {}

        y_true = np.array(self.labels)
        y_pred = np.array(self.predictions)
        y_proba = np.array(self.probabilities)

        metrics = {'loss': np.mean(self.losses), 'accuracy': accuracy_score(y_true, y_pred)}
        for name, scorer in (('precision', precision_score),
                             ('recall', recall_score),
                             ('f1', f1_score)):
            metrics[name] = scorer(y_true, y_pred, zero_division=0)

        has_both_classes = np.unique(y_true).size > 1
        metrics['auc'] = roc_auc_score(y_true, y_proba) if has_both_classes else 0.5
        return metrics

def create_balanced_sampler(dataset, labels=None):
    """Create a WeightedRandomSampler under which every class is drawn equally often in expectation.

    Each sample gets weight ``n_samples / (n_classes * count(its class))``.

    Args:
        dataset: source of labels when ``labels`` is None. If it exposes a ``df``
            table (as ImprovedAneurysmDataset does), labels are read from
            ``df[Config.TARGET_COL]``. Otherwise each ``dataset[i]['label']`` is read.
        labels: optional precomputed per-sample labels; skips dataset access.

    Returns:
        WeightedRandomSampler drawing ``len(labels)`` indices with replacement.
    """
    if labels is None:
        table = getattr(dataset, 'df', None)
        if table is not None:
            # Reading the table avoids dataset[i], which decodes (and augments) a
            # whole DICOM series per sample just to obtain its label.
            labels = table[Config.TARGET_COL].astype(float).tolist()
        else:
            labels = [float(dataset[i]['label']) for i in range(len(dataset))]

    labels = [float(label) for label in labels]
    class_counts = Counter(labels)
    total_samples = len(labels)
    n_classes = len(class_counts)

    weights = [total_samples / (n_classes * class_counts[label]) for label in labels]
    return WeightedRandomSampler(weights, len(weights))

def train_epoch(model, loader, optimizer, criterion, device, epoch, accumulation_steps=1):
    """Run one training epoch with gradient accumulation and return its metrics.

    Losses are divided by ``accumulation_steps`` and the optimizer steps every
    ``accumulation_steps`` batches (plus once at the end for a leftover partial
    group). Gradients are clipped to ``Config.GRADIENT_CLIP`` right before each step.

    Returns:
        The dict from ``MetricsTracker.compute_metrics()`` for the whole epoch.
    """
    model.train()
    metrics_tracker = MetricsTracker()
    num_batches = len(loader)
    # compute_metrics re-scores everything seen so far; refreshing the progress
    # bar on every batch made the epoch O(n^2) in sklearn work.
    refresh_every = 10

    progress_bar = tqdm(
        enumerate(loader),
        total=num_batches,
        desc=f"Training Epoch {epoch+1}",
        leave=False
    )

    def _optimizer_step():
        # Clip once, on the fully accumulated gradient. Clipping after every
        # micro-batch repeatedly rescaled the partial sums instead.
        torch.nn.utils.clip_grad_norm_(model.parameters(), Config.GRADIENT_CLIP)
        if TPU_AVAILABLE:
            xm.optimizer_step(optimizer)
        else:
            optimizer.step()
        optimizer.zero_grad()

    optimizer.zero_grad()

    for batch_idx, batch in progress_bar:
        volume = batch['volume'].to(device, non_blocking=True)
        label = batch['label'].to(device, non_blocking=True)

        outputs = model(volume)
        loss = criterion(outputs, label) / accumulation_steps
        loss.backward()

        if (batch_idx + 1) % accumulation_steps == 0:
            _optimizer_step()

        if TPU_AVAILABLE:
            # Cut the lazy XLA graph every batch so it never spans several steps.
            xm.mark_step()

        with torch.no_grad():
            probs = torch.sigmoid(outputs).cpu().numpy().flatten()
            preds = (probs > 0.5).astype(int)
            labels_np = label.cpu().numpy()
            # Undo the accumulation scaling so the reported loss is per batch.
            metrics_tracker.update(preds, probs, labels_np, loss.item() * accumulation_steps)

        if batch_idx % refresh_every == 0 or batch_idx == num_batches - 1:
            current_metrics = metrics_tracker.compute_metrics()
            progress_bar.set_postfix({
                'Loss': f'{current_metrics.get("loss", 0):.4f}',
                'Acc': f'{current_metrics.get("accuracy", 0):.3f}',
                'AUC': f'{current_metrics.get("auc", 0):.3f}'
            })

    # Apply the gradients of a trailing, incomplete accumulation group.
    if num_batches % accumulation_steps != 0:
        _optimizer_step()

    progress_bar.close()

    if TPU_AVAILABLE:
        xm.mark_step()

    return metrics_tracker.compute_metrics()

def validate_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        The dict from ``MetricsTracker.compute_metrics()`` over all batches.
    """
    model.eval()
    metrics_tracker = MetricsTracker()
    num_batches = len(loader)
    # Progress-bar metrics re-score everything seen so far; refresh periodically
    # instead of every batch to avoid O(n^2) sklearn work.
    refresh_every = 10

    progress_bar = tqdm(
        enumerate(loader),
        total=num_batches,
        desc="Validation",
        leave=False
    )

    with torch.no_grad():
        for batch_idx, batch in progress_bar:
            volume = batch['volume'].to(device, non_blocking=True)
            label = batch['label'].to(device, non_blocking=True)

            outputs = model(volume)
            loss = criterion(outputs, label)

            probs = torch.sigmoid(outputs).cpu().numpy().flatten()
            preds = (probs > 0.5).astype(int)
            labels_np = label.cpu().numpy()
            metrics_tracker.update(preds, probs, labels_np, loss.item())

            if batch_idx % refresh_every == 0 or batch_idx == num_batches - 1:
                current_metrics = metrics_tracker.compute_metrics()
                progress_bar.set_postfix({
                    'Val Loss': f'{current_metrics.get("loss", 0):.4f}',
                    'Val Acc': f'{current_metrics.get("accuracy", 0):.3f}',
                    'Val AUC': f'{current_metrics.get("auc", 0):.3f}'
                })

    progress_bar.close()

    if TPU_AVAILABLE:
        xm.mark_step()

    return metrics_tracker.compute_metrics()

def plot_training_history(history):
    """Plot train vs. validation curves for six metrics on a 2x3 grid.

    Args:
        history: mapping with ``train_<metric>`` / ``val_<metric>`` sequences for
            loss, accuracy, auc, precision, recall and f1. A metric missing either
            key leaves its panel empty.

    Side effects: saves ``training_history.png`` (300 dpi) to the working
    directory and calls ``plt.show()``.
    """
    metric_names = ('loss', 'accuracy', 'auc', 'precision', 'recall', 'f1')
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))

    # axes.flat walks the grid row by row, matching the metric order above.
    for ax, name in zip(axes.flat, metric_names):
        train_key, val_key = f'train_{name}', f'val_{name}'
        if train_key not in history or val_key not in history:
            continue

        title = name.upper()
        ax.plot(history[train_key], label=f'Train {title}', marker='o')
        ax.plot(history[val_key], label=f'Val {title}', marker='s')
        ax.set_title(f'{title} History')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(title)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.savefig('training_history.png', dpi=300, bbox_inches='tight')
    plt.show()

def comprehensive_evaluation(model, loader, device, save_plots=True,
                             save_path='evaluation_results.png', threshold=0.5):
    """Comprehensive model evaluation with detailed metrics and plots.

    Runs *model* in eval mode over every batch of *loader* and turns the
    outputs into probabilities with a sigmoid. It prints accuracy, precision,
    recall, F1 and ROC-AUC, followed by a per-class classification report.
    When ``save_plots`` is set it also draws a confusion matrix, ROC and
    precision-recall curves and the per-class probability histograms, then
    writes the figure to *save_path*.

    Args:
        model: Binary classifier. Its outputs are flattened, so it presumably
            emits one logit per sample.
        loader: Iterable of batch dicts with ``'volume'`` and ``'label'``.
        device: Device the volumes are moved to before the forward pass.
        save_plots: Whether to draw, save and show the diagnostic figure.
        save_path: Image file for the diagnostic figure.
        threshold: A sample is predicted positive when its probability is
            strictly greater than this value.

    Returns:
        Dict with ``accuracy``, ``precision``, ``recall``, ``f1`` and ``auc``.
        ``auc`` is 0.5 when only one class occurs in the labels.

    Raises:
        ValueError: If *loader* yields no batches.
    """
    model.eval()
    class_names = ['No Aneurysm', 'Aneurysm']
    prob_batches, label_batches = [], []

    print("🔍 Running comprehensive evaluation...")

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            outputs = model(batch['volume'].to(device))
            prob_batches.append(torch.sigmoid(outputs).cpu().numpy().reshape(-1))
            # Flatten labels like the probabilities, so (B,) and (B, 1) label
            # tensors both line up element-wise with them.
            label_batches.append(batch['label'].cpu().numpy().reshape(-1))

    if not prob_batches:
        raise ValueError("comprehensive_evaluation: loader yielded no batches")

    y_true = np.concatenate(label_batches)
    y_proba = np.concatenate(prob_batches)
    y_pred = (y_proba > threshold).astype(int)
    has_both_classes = np.unique(y_true).size > 1

    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, zero_division=0),
        'recall': recall_score(y_true, y_pred, zero_division=0),
        'f1': f1_score(y_true, y_pred, zero_division=0),
        'auc': roc_auc_score(y_true, y_proba) if has_both_classes else 0.5
    }

    print("\nCOMPREHENSIVE EVALUATION RESULTS")
    print("="*50)
    for metric, value in metrics.items():
        print(f"{metric.upper():>12}: {value:.4f}")

    # labels=[0, 1] keeps the report two-class even when one class is absent.
    # Without it, classification_report raises because target_names has two
    # entries but only one label is present.
    print("\nClassification Report:")
    print(classification_report(y_true, y_pred,
                                labels=[0, 1],
                                target_names=class_names,
                                zero_division=0))

    if not save_plots:
        return metrics

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # 1. Confusion matrix (always 2x2 so it matches the two tick labels).
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    sns.heatmap(cm, annot=True, fmt='d', ax=axes[0, 0],
                xticklabels=class_names, yticklabels=class_names)
    axes[0, 0].set_title('Confusion Matrix')
    axes[0, 0].set_ylabel('True Label')
    axes[0, 0].set_xlabel('Predicted Label')

    # 2. ROC curve and 3. precision-recall curve need both classes.
    if has_both_classes:
        fpr, tpr, _ = roc_curve(y_true, y_proba)
        axes[0, 1].plot(fpr, tpr, color='darkorange', lw=2,
                        label=f'ROC curve (AUC = {metrics["auc"]:.3f})')
        axes[0, 1].plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        axes[0, 1].set_xlim([0.0, 1.0])
        axes[0, 1].set_ylim([0.0, 1.05])
        axes[0, 1].set_xlabel('False Positive Rate')
        axes[0, 1].set_ylabel('True Positive Rate')
        axes[0, 1].set_title('ROC Curve')
        axes[0, 1].legend(loc="lower right")
        axes[0, 1].grid(True)

        precision, recall, _ = precision_recall_curve(y_true, y_proba)
        axes[1, 0].plot(recall, precision, color='blue', lw=2)
        axes[1, 0].set_xlabel('Recall')
        axes[1, 0].set_ylabel('Precision')
        axes[1, 0].set_title('Precision-Recall Curve')
        axes[1, 0].grid(True)

    # 4. Predicted-probability distribution per true class.
    axes[1, 1].hist(y_proba[y_true == 0], bins=30, alpha=0.7,
                    label='No Aneurysm', density=True)
    axes[1, 1].hist(y_proba[y_true == 1], bins=30, alpha=0.7,
                    label='Aneurysm', density=True)
    axes[1, 1].set_xlabel('Predicted Probability')
    axes[1, 1].set_ylabel('Density')
    axes[1, 1].set_title('Probability Distribution')
    axes[1, 1].legend()
    axes[1, 1].grid(True)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure; this function runs once per fold.
    plt.close(fig)

    return metrics

def cross_validation_training(train_df, series_dir):
    """K-fold cross-validation training.

    Trains one model per fold of a shuffled
    ``StratifiedKFold(Config.N_FOLDS, random_state=42)`` split on
    ``Config.TARGET_COL``. When ``Config.CURRENT_FOLD >= 0``, only that fold
    is trained. For each fold it:

    * builds the train/val datasets and loaders;
    * trains ``AdvancedAneurysmNet`` with ``FocalLoss``, AdamW and an
      epoch-level OneCycleLR schedule, with early stopping on validation loss;
    * saves the epoch with the best validation AUC to
      ``fold_<k>_best_model.pth``;
    * evaluates those best weights on the validation split and saves the
      training-history plot as ``fold_<k>_history.png``.

    Args:
        train_df: DataFrame containing at least ``Config.TARGET_COL``. It is
            passed to ``ImprovedAneurysmDataset`` together with *series_dir*.
        series_dir: Directory handed to ``ImprovedAneurysmDataset``.

    Returns:
        A list with one dict per trained fold. Each dict has ``fold``
        (1-based), ``best_val_auc``, ``final_metrics`` and ``history``.
    """
    import shutil  # used only to keep a per-fold copy of the history plot

    skf = StratifiedKFold(n_splits=Config.N_FOLDS, shuffle=True, random_state=42)
    fold_results = []

    processor = AdvancedDICOMProcessor()
    augmentation = VolumetricAugmentation(prob=Config.AUGMENTATION_PROB) if Config.USE_AUGMENTATION else None
    pin_memory = Config.DEVICE.type == 'cuda'
    metric_names = ['loss', 'accuracy', 'auc', 'precision', 'recall', 'f1']

    for fold, (train_idx, val_idx) in enumerate(skf.split(train_df, train_df[Config.TARGET_COL])):
        print(f"\n{'='*20} FOLD {fold + 1}/{Config.N_FOLDS} {'='*20}")

        if Config.CURRENT_FOLD >= 0 and fold != Config.CURRENT_FOLD:
            print(f"Skipping fold {fold + 1}")
            continue

        fold_train_df = train_df.iloc[train_idx].reset_index(drop=True)
        fold_val_df = train_df.iloc[val_idx].reset_index(drop=True)

        print(f"Train: {len(fold_train_df)}, Val: {len(fold_val_df)}")
        print(f"Train pos: {fold_train_df[Config.TARGET_COL].sum()}, "
              f"Val pos: {fold_val_df[Config.TARGET_COL].sum()}")

        # Create datasets
        train_dataset = ImprovedAneurysmDataset(
            fold_train_df, series_dir, processor, augmentation, mode='train'
        )
        val_dataset = ImprovedAneurysmDataset(
            fold_val_df, series_dir, processor, mode='val'
        )

        # NOTE(review): the sampler (class-balanced, judging by its name) and
        # ``pos_weight`` below both compensate for class imbalance. Confirm
        # this double correction is intended.
        sampler = create_balanced_sampler(train_dataset)

        train_loader = DataLoader(
            train_dataset,
            batch_size=Config.BATCH_SIZE,
            sampler=sampler,
            num_workers=0,
            pin_memory=pin_memory,
            persistent_workers=False
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=Config.BATCH_SIZE,
            shuffle=False,
            num_workers=0,
            pin_memory=pin_memory
        )

        # Positive-class weight: negatives per positive in this fold's train split.
        pos_count = fold_train_df[Config.TARGET_COL].sum()
        neg_count = len(fold_train_df) - pos_count
        # Explicit float32: the pandas-derived ratio is a NumPy float64, and the
        # loss presumably combines it with float32 logits. float64 is also a
        # poor fit for TPUs.
        pos_weight = torch.tensor(
            [neg_count / pos_count if pos_count > 0 else 1.0], dtype=torch.float32
        )
        print(f"Positive weight: {pos_weight.item():.2f}")

        model = AdvancedAneurysmNet().to(Config.DEVICE)
        criterion = FocalLoss(alpha=1, gamma=2, pos_weight=pos_weight.to(Config.DEVICE))
        optimizer = optim.AdamW(
            model.parameters(),
            lr=Config.LEARNING_RATE,
            weight_decay=Config.WEIGHT_DECAY,
            betas=(0.9, 0.999),
            eps=1e-8
        )

        # ``train_epoch`` is not given the scheduler, so the schedule is
        # advanced here once per epoch. A per-batch schedule
        # (steps_per_epoch=len(train_loader)) would never advance from this
        # loop and would leave the LR pinned at max_lr / div_factor.
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=Config.LEARNING_RATE * 10,
            epochs=Config.EPOCHS,
            steps_per_epoch=1,
            pct_start=0.3,
            div_factor=25,
            final_div_factor=1000
        )

        early_stopping = EarlyStopping(patience=Config.EARLY_STOPPING_PATIENCE, min_delta=0.001)

        # Keys: train_loss, val_loss, train_accuracy, val_accuracy, ...
        history = {f'{split}_{metric}': [] for metric in metric_names for split in ('train', 'val')}

        best_val_auc = 0
        best_state = None  # CPU copy of the weights that achieved best_val_auc

        print(f"\nStarting training for fold {fold + 1}...")

        for epoch in range(Config.EPOCHS):
            print(f"\nEpoch {epoch+1}/{Config.EPOCHS}")
            print("-" * 30)

            train_metrics = train_epoch(
                model, train_loader, optimizer, criterion,
                Config.DEVICE, epoch, Config.ACCUMULATION_STEPS
            )
            val_metrics = validate_epoch(model, val_loader, criterion, Config.DEVICE)

            # Epoch-level OneCycleLR step (see the scheduler construction above).
            scheduler.step()

            for metric in metric_names:
                history[f'train_{metric}'].append(train_metrics.get(metric, 0))
                history[f'val_{metric}'].append(val_metrics.get(metric, 0))

            print(f"Train - Loss: {train_metrics['loss']:.4f}, "
                  f"Acc: {train_metrics['accuracy']:.4f}, "
                  f"AUC: {train_metrics['auc']:.4f}")
            print(f"Val   - Loss: {val_metrics['loss']:.4f}, "
                  f"Acc: {val_metrics['accuracy']:.4f}, "
                  f"AUC: {val_metrics['auc']:.4f}")

            if val_metrics['auc'] > best_val_auc:
                best_val_auc = val_metrics['auc']
                best_state = {name: tensor.detach().cpu().clone()
                              for name, tensor in model.state_dict().items()}

                checkpoint = {
                    'fold': fold,
                    'epoch': epoch,
                    'model_state_dict': best_state,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'train_metrics': train_metrics,
                    'val_metrics': val_metrics,
                    'best_val_auc': best_val_auc,
                    'history': history,
                    # Dunder entries of Config.__dict__ (__module__, __doc__,
                    # the __dict__/__weakref__ descriptors) are class plumbing,
                    # not settings, and may not pickle portably.
                    'config': {key: value for key, value in vars(Config).items()
                               if not (key.startswith('__') and key.endswith('__'))},
                }

                checkpoint_path = f'fold_{fold}_best_model.pth'
                if TPU_AVAILABLE:
                    # xm.save moves XLA tensors to CPU before writing.
                    xm.save(checkpoint, checkpoint_path)
                else:
                    torch.save(checkpoint, checkpoint_path)
                print(f"Saved best model (AUC: {val_metrics['auc']:.4f})")

            if early_stopping(val_metrics['loss'], model):
                print(f"Early stopping at epoch {epoch+1}")
                early_stopping.restore(model)
                break

            # Memory cleanup
            if not TPU_AVAILABLE:
                torch.cuda.empty_cache()
            gc.collect()

        # Evaluate the checkpointed (best-AUC) weights, so final_metrics
        # describe the saved model rather than the last epoch or the
        # early-stopping restore.
        if best_state is not None:
            model.load_state_dict(best_state)

        print(f"\n🔍 Final evaluation for fold {fold + 1}")
        final_metrics = comprehensive_evaluation(model, val_loader, Config.DEVICE)

        fold_results.append({
            'fold': fold + 1,
            'best_val_auc': best_val_auc,
            'final_metrics': final_metrics,
            'history': history
        })

        # plot_training_history saves 'training_history.png' itself and then
        # calls plt.show(). Keep a per-fold copy of that file. Titling and
        # re-saving through pyplot afterwards would target whatever figure is
        # current after show(), which can be a new, empty one.
        plot_training_history(history)
        shutil.copyfile('training_history.png', f'fold_{fold}_history.png')

        print(f"Fold {fold + 1} completed!")
        print(f"Best AUC: {best_val_auc:.4f}")

        processor_stats = processor.get_stats()
        print(f"📊 DICOM Processing Stats: {processor_stats}")

        # Release this fold's training objects before the next fold starts.
        del model, optimizer, scheduler, criterion, early_stopping, best_state
        gc.collect()
        if not TPU_AVAILABLE:
            torch.cuda.empty_cache()

        if Config.CURRENT_FOLD >= 0:
            break  # Only train specified fold

    return fold_results

def main_training():
    """Load and sanity-check the training CSV, then run K-fold training.

    Steps:

    1. Read ``Config.TRAIN_CSV_PATH`` and check that ``Config.ID_COL`` and
       ``Config.TARGET_COL`` exist.
    2. Drop rows where either column is NaN, and subsample in debug mode.
    3. Report the class distribution and how many of the first 100 series
       folders exist under ``Config.SERIES_DIR``.
    4. Call ``cross_validation_training`` and write a summary to
       ``cv_results.json``.

    Returns:
        The list of per-fold result dicts. Returns ``None`` when the data
        cannot be loaded or validated, or when no fold produced results.
    """
    print("ADVANCED ANEURYSM DETECTION TRAINING")
    print("=" * 60)
    print(f"Device: {Config.DEVICE}")
    print(f"TPU Available: {TPU_AVAILABLE}")
    print(f"Debug Mode: {Config.DEBUG_MODE}")

    # Load and validate data
    print("\nLoading training data...")
    try:
        train_df = pd.read_csv(Config.TRAIN_CSV_PATH)
    except Exception as e:  # top-level boundary: report and abort
        print(f"Error loading CSV: {e}")
        return None
    print(f"Loaded {len(train_df)} samples from CSV")

    print("\nData validation...")
    print(f"Original samples: {len(train_df)}")
    print(f"Columns: {list(train_df.columns)}")

    for column, role in ((Config.ID_COL, "ID"), (Config.TARGET_COL, "target")):
        if column not in train_df.columns:
            print(f"Missing {role} column: {column}")
            return None

    initial_len = len(train_df)
    train_df = train_df.dropna(subset=[Config.ID_COL, Config.TARGET_COL])
    print(f"After removing NaN: {len(train_df)} (removed {initial_len - len(train_df)})")

    if Config.DEBUG_MODE:
        train_df = train_df.sample(n=min(Config.DEBUG_SAMPLES, len(train_df)),
                                   random_state=42).reset_index(drop=True)
        print(f"Debug mode: using {len(train_df)} samples")

    # Guard before any per-sample ratio below (e.g. the existence rate).
    if train_df.empty:
        print("No usable samples left after cleaning; aborting")
        return None

    class_dist = train_df[Config.TARGET_COL].value_counts().sort_index()
    print("\nClass Distribution:")
    for label, count in class_dist.items():
        percentage = (count / len(train_df)) * 100
        print(f"   Class {label}: {count} samples ({percentage:.1f}%)")

    if len(class_dist) < 2:
        print("Error: only one class present in the target column; "
              "a binary classifier cannot be trained")
        return None

    min_class_count = class_dist.min()
    if min_class_count < 10:
        print(f"Warning: Very few samples in minority class ({min_class_count})")
        print("Consider collecting more data or adjusting class weights")
    if min_class_count < Config.N_FOLDS:
        print(f"Warning: minority class has fewer samples than N_FOLDS ({Config.N_FOLDS}); "
              "some validation folds will contain no minority-class samples")

    # Spot-check that series directories exist (first 100 IDs only).
    print("\nVerifying series directories...")
    sample_check = train_df[Config.ID_COL].head(100)
    existing_series = sum(
        os.path.exists(os.path.join(Config.SERIES_DIR, str(series_id)))
        for series_id in sample_check
    )
    existence_rate = (existing_series / len(sample_check)) * 100
    print(f"Series existence rate: {existence_rate:.1f}% ({existing_series}/{len(sample_check)} checked)")

    if existence_rate < 50:
        print("Warning: Many series directories are missing!")
        print("This may significantly impact training performance.")

    print("\nStarting cross-validation training...")
    fold_results = cross_validation_training(train_df, Config.SERIES_DIR)

    if not fold_results:
        print("No results to summarize")
        return None

    print("\nCROSS-VALIDATION RESULTS SUMMARY")
    print("=" * 50)

    all_aucs = [result['best_val_auc'] for result in fold_results]
    mean_auc = np.mean(all_aucs)
    std_auc = np.std(all_aucs)

    print(f"Mean AUC: {mean_auc:.4f} ± {std_auc:.4f}")
    print(f"Best AUC: {max(all_aucs):.4f}")
    print(f"Worst AUC: {min(all_aucs):.4f}")

    for result in fold_results:
        print(f"Fold {result['fold']}: AUC = {result['best_val_auc']:.4f}")

    results_summary = {
        'mean_auc': mean_auc,
        'std_auc': std_auc,
        'fold_results': fold_results,
        # Skip Config's dunder plumbing (__module__, __dict__, ...), which is
        # not configuration.
        'config': {key: value for key, value in vars(Config).items()
                   if not (key.startswith('__') and key.endswith('__'))},
    }

    with open('cv_results.json', 'w', encoding='utf-8') as f:
        # default=str turns non-JSON values (devices, NumPy ints, ...) into strings.
        json.dump(results_summary, f, default=str, indent=2)

    print("\nResults saved to cv_results.json")
    return fold_results

In [None]:
if __name__ == "__main__":
    # Run the full pipeline; main_training() returns None (falsy) on failure.
    results = main_training()

    if not results:
        print("\nTraining failed. Please check the error messages above.")
    else:
        # Point the user at the artifacts written during training.
        for message in (
            "\nTraining completed successfully!",
            "Check the following files:",
            "   - fold_*_best_model.pth: Best model checkpoints",
            "   - fold_*_history.png: Training history plots",
            "   - evaluation_results.png: Evaluation plots",
            "   - cv_results.json: Cross-validation summary",
        ):
            print(message)

**Model Evaluation and Performance**

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,
    roc_curve, precision_recall_curve, confusion_matrix, classification_report,
    average_precision_score, balanced_accuracy_score, matthews_corrcoef,
    cohen_kappa_score, log_loss, brier_score_loss
)
from sklearn.calibration import calibration_curve
from scipy import stats
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.figure_factory as ff
from tqdm.auto import tqdm
import warnings
import json
from datetime import datetime
from collections import defaultdict
import itertools
from scipy.stats import bootstrap
warnings.filterwarnings('ignore')

class AdvancedModelEvaluator:
    """
    Comprehensive model evaluation suite with advanced metrics and visualizations
    """
    
    def __init__(self, model_paths, data_loader, device, class_names=None, save_dir="evaluation_results"):
        self.model_paths = model_paths if isinstance(model_paths, list) else [model_paths]
        self.data_loader = data_loader
        self.device = device
        self.class_names = class_names or ['No Aneurysm', 'Aneurysm']
        self.save_dir = save_dir
        
        # Create save directory
        os.makedirs(save_dir, exist_ok=True)
        
        # Storage for results
        self.results = {}
        self.predictions = {}
        self.raw_outputs = {}
        
    def load_model_predictions(self, model_class):
        """Load every checkpoint in ``self.model_paths`` and run inference.

        For each existing path this instantiates ``model_class()``, loads the
        weights and runs the model over ``self.data_loader`` under
        ``torch.no_grad``. The weights may be a training checkpoint dict with
        ``'model_state_dict'`` or a bare state dict.

        Results go to ``self.predictions[model_name]``: NumPy arrays of hard
        predictions (probability > 0.5), sigmoid probabilities, labels and raw
        logits, plus the list of series ids and the checkpoint path. Missing
        paths are reported and skipped.

        Args:
            model_class: Zero-argument callable that returns the model.
        """
        print("Loading models and generating predictions...")

        for i, model_path in enumerate(self.model_paths):
            print(f"\nProcessing model: {model_path}")

            if not os.path.exists(model_path):
                print(f"Model file not found: {model_path}")
                continue

            model = model_class().to(self.device)

            # Training checkpoints also contain metrics, history and config
            # objects, which recent PyTorch refuses to unpickle under its
            # ``weights_only=True`` default.
            # SECURITY: weights_only=False runs pickle; only load trusted files.
            try:
                checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
            except TypeError:
                # PyTorch < 1.13 has no ``weights_only`` argument.
                checkpoint = torch.load(model_path, map_location=self.device)

            is_training_ckpt = isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint
            model.load_state_dict(checkpoint['model_state_dict'] if is_training_ckpt else checkpoint)
            epoch = checkpoint.get('epoch', 'unknown') if is_training_ckpt else 'unknown'
            print(f"Loaded model from epoch {epoch}")

            model.eval()
            predictions, probabilities, labels, logits, series_ids = [], [], [], [], []
            seen = 0  # running sample count, used for unique fallback ids

            with torch.no_grad():
                for batch in tqdm(self.data_loader, desc=f"Model {i+1} Inference"):
                    volume = batch['volume'].to(self.device)
                    label = batch['label']
                    batch_size = len(label)

                    series_id = batch.get('series_id')
                    if series_id is None:
                        # Number across the whole loader, not per batch, so
                        # fallback ids stay unique.
                        series_id = [f'sample_{seen + j}' for j in range(batch_size)]
                    elif torch.is_tensor(series_id):
                        series_id = series_id.tolist()
                    seen += batch_size

                    outputs = model(volume)
                    probs = torch.sigmoid(outputs).cpu().numpy().reshape(-1)

                    predictions.extend((probs > 0.5).astype(int))
                    probabilities.extend(probs)
                    # Flatten so (B,) and (B, 1) labels both give 1-D arrays.
                    labels.extend(label.cpu().numpy().reshape(-1))
                    logits.extend(outputs.cpu().numpy().reshape(-1))
                    series_ids.extend(series_id)

            del model  # free device memory before loading the next model

            model_name = f"Model_{i+1}" if len(self.model_paths) > 1 else "Model"

            self.predictions[model_name] = {
                'predictions': np.array(predictions),
                'probabilities': np.array(probabilities),
                'labels': np.array(labels),
                'logits': np.array(logits),
                'series_ids': series_ids,
                'model_path': model_path
            }
    
    def calculate_comprehensive_metrics(self):
        """Compute the full metric suite for every model in ``self.predictions``.

        ``self.results[model_name]`` becomes one flat dict that merges the
        following groups, in this order (later groups win on key collisions
        such as ``specificity``):

        * basic classification metrics on the stored 0/1 predictions;
        * probability-based metrics;
        * the threshold sweep;
        * calibration and confidence metrics;
        * confusion-matrix rates;
        * when more than one model is loaded, the ensemble metrics. These
          describe the ensemble, so every model gets the same values.
        """
        print("\nCalculating comprehensive metrics...")

        # The ensemble summary does not depend on which model is being
        # summarised, so compute it once rather than once per model.
        ensemble_metrics = self._calculate_ensemble_metrics() if len(self.predictions) > 1 else {}

        for model_name, data in self.predictions.items():
            y_true = data['labels']
            y_pred = data['predictions']
            y_proba = data['probabilities']

            basic_metrics = {
                'accuracy': accuracy_score(y_true, y_pred),
                'balanced_accuracy': balanced_accuracy_score(y_true, y_pred),
                'precision': precision_score(y_true, y_pred, zero_division=0),
                'recall': recall_score(y_true, y_pred, zero_division=0),
                # Specificity is the recall of the negative class.
                'specificity': recall_score(1 - y_true, 1 - y_pred, zero_division=0),
                'f1_score': f1_score(y_true, y_pred, zero_division=0),
                'matthews_corr': matthews_corrcoef(y_true, y_pred),
                'cohen_kappa': cohen_kappa_score(y_true, y_pred),
            }

            if len(np.unique(y_true)) > 1:
                prob_metrics = {
                    'roc_auc': roc_auc_score(y_true, y_proba),
                    'pr_auc': average_precision_score(y_true, y_proba),
                    'log_loss': log_loss(y_true, y_proba),
                    'brier_score': brier_score_loss(y_true, y_proba),
                }
            else:
                # Single-class labels: chance-level AUC, the positive rate as
                # the PR baseline, and inf for the losses that were not computed.
                prob_metrics = {
                    'roc_auc': 0.5,
                    'pr_auc': np.mean(y_true),
                    'log_loss': float('inf'),
                    'brier_score': float('inf'),
                }

            self.results[model_name] = {
                **basic_metrics,
                **prob_metrics,
                **self._calculate_threshold_metrics(y_true, y_proba),
                **self._calculate_calibration_metrics(y_true, y_proba),
                **self._calculate_class_wise_metrics(y_true, y_pred, y_proba),
                **ensemble_metrics,
            }
    
    def _calculate_threshold_metrics(self, y_true, y_proba):
        """Evaluate fixed decision thresholds and find the Youden-optimal one.

        For each threshold t in 0.1, 0.2, ..., 0.9 (a sample is positive when
        ``y_proba >= t``), records ``f1_thresh_<t>``,
        ``precision_thresh_<t>`` and ``recall_thresh_<t>``.

        When both classes are present it also records:

        * ``optimal_threshold``: the ROC threshold that maximises TPR - FPR
          (Youden's J);
        * ``optimal_f1``: the F1 score at that threshold.

        Args:
            y_true: Binary ground-truth labels.
            y_proba: Predicted positive-class probabilities.

        Returns:
            Dict of the metrics above.
        """
        threshold_results = {}

        # k / 10 gives exact decimal thresholds. np.arange(0.1, 1.0, 0.1)
        # accumulates error (e.g. 0.30000000000000004).
        for threshold in np.arange(1, 10) / 10:
            y_pred_thresh = (y_proba >= threshold).astype(int)
            tag = f'{threshold:.1f}'
            # Every threshold is always reported (zero_division=0 covers the
            # degenerate all-one-class cases), so all models share the same keys.
            threshold_results[f'f1_thresh_{tag}'] = f1_score(y_true, y_pred_thresh, zero_division=0)
            threshold_results[f'precision_thresh_{tag}'] = precision_score(y_true, y_pred_thresh, zero_division=0)
            threshold_results[f'recall_thresh_{tag}'] = recall_score(y_true, y_pred_thresh, zero_division=0)

        if len(np.unique(y_true)) > 1:
            fpr, tpr, thresholds_roc = roc_curve(y_true, y_proba)
            optimal_threshold = thresholds_roc[np.argmax(tpr - fpr)]

            threshold_results['optimal_threshold'] = optimal_threshold
            threshold_results['optimal_f1'] = f1_score(
                y_true, (y_proba >= optimal_threshold).astype(int), zero_division=0
            )

        return threshold_results
    
    def _calculate_calibration_metrics(self, y_true, y_proba):
        # Calibration curve
        if len(np.unique(y_true)) > 1:
            prob_true, prob_pred = calibration_curve(y_true, y_proba, n_bins=10)
            calibration_error = np.mean(np.abs(prob_true - prob_pred))
            
            # Expected Calibration Error (ECE)
            bin_boundaries = np.linspace(0, 1, 11)
            bin_lowers = bin_boundaries[:-1]
            bin_uppers = bin_boundaries[1:]
            
            ece = 0
            for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
                in_bin = (y_proba > bin_lower) & (y_proba <= bin_upper)
                prop_in_bin = in_bin.mean()
                
                if prop_in_bin > 0:
                    accuracy_in_bin = y_true[in_bin].mean()
                    avg_confidence_in_bin = y_proba[in_bin].mean()
                    ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
        else:
            calibration_error = float('inf')
            ece = float('inf')
        
        # Confidence metrics
        confidence_metrics = {
            'mean_confidence': np.mean(y_proba),
            'std_confidence': np.std(y_proba),
            'confidence_range': np.ptp(y_proba),
            'calibration_error': calibration_error,
            'expected_calibration_error': ece,
        }
        
        return confidence_metrics
    
    def _calculate_class_wise_metrics(self, y_true, y_pred, y_proba):
        cm = confusion_matrix(y_true, y_pred)
        
        if cm.shape == (2, 2):
            tn, fp, fn, tp = cm.ravel()
            
            class_metrics = {
                'true_positives': tp,
                'true_negatives': tn,
                'false_positives': fp,
                'false_negatives': fn,
                'sensitivity': tp / (tp + fn) if (tp + fn) > 0 else 0,
                'specificity': tn / (tn + fp) if (tn + fp) > 0 else 0,
                'positive_predictive_value': tp / (tp + fp) if (tp + fp) > 0 else 0,
                'negative_predictive_value': tn / (tn + fn) if (tn + fn) > 0 else 0,
                'false_positive_rate': fp / (fp + tn) if (fp + tn) > 0 else 0,
                'false_negative_rate': fn / (fn + tp) if (fn + tp) > 0 else 0,
                'positive_likelihood_ratio': (tp/(tp+fn)) / (fp/(fp+tn)) if fp > 0 else float('inf'),
                'negative_likelihood_ratio': (fn/(fn+tp)) / (tn/(tn+fp)) if tn > 0 else float('inf'),
            }
        else:
            class_metrics = {}
        
        return class_metrics
    
    def _calculate_ensemble_metrics(self):
        if len(self.predictions) < 2:
            return {}
        
        # Average predictions across models
        all_probas = []
        all_preds = []
        y_true = None
        
        for model_name, data in self.predictions.items():
            all_probas.append(data['probabilities'])
            all_preds.append(data['predictions'])
            if y_true is None:
                y_true = data['labels']
        
        # Ensemble by averaging probabilities
        ensemble_proba = np.mean(all_probas, axis=0)
        ensemble_pred = (ensemble_proba > 0.5).astype(int)
        
        # Ensemble by majority voting
        majority_pred = np.round(np.mean(all_preds, axis=0)).astype(int)
        
        ensemble_metrics = {
            'ensemble_accuracy': accuracy_score(y_true, ensemble_pred),
            'ensemble_f1': f1_score(y_true, ensemble_pred),
            'ensemble_auc': roc_auc_score(y_true, ensemble_proba) if len(np.unique(y_true)) > 1 else 0.5,
            'majority_vote_accuracy': accuracy_score(y_true, majority_pred),
            'majority_vote_f1': f1_score(y_true, majority_pred),
        }
        
        # Model agreement
        agreement_matrix = np.array(all_preds)
        model_agreement = np.mean(np.std(agreement_matrix, axis=0) == 0)  # Fraction of samples where all models agree
        
        ensemble_metrics['model_agreement'] = model_agreement
        
        return ensemble_metrics
    
    def create_comprehensive_visualizations(self):
        """Render the evaluation figures into ``self.save_dir``.

        Calls, in order: ``_create_performance_dashboard``,
        ``_create_roc_comparison``, ``_create_pr_comparison``,
        ``_create_calibration_plots``, ``_create_confusion_matrices``,
        ``_create_threshold_analysis``, ``_create_prediction_distributions``,
        ``_create_error_analysis``, ``_create_interactive_plots`` and
        ``_create_statistical_tests``. An exception raised by any helper
        propagates to the caller.
        """
        print("\nCreating comprehensive visualizations...")

        # 'seaborn-v0_8' is the matplotlib >= 3.6 name; older releases only
        # provide the style as 'seaborn'.
        try:
            plt.style.use('seaborn-v0_8')
        except OSError:
            plt.style.use('seaborn')

        # Plotting helpers such as _create_calibration_plots index a bare
        # ``colors`` name that they never define, so Python resolves it as a
        # module-level global. Publish the palette there; a plain local
        # assignment would be invisible to them and raise NameError.
        global colors
        colors = ['#2E8B57', '#DC143C', '#4169E1', '#FFD700', '#8A2BE2']

        self._create_performance_dashboard()
        self._create_roc_comparison()
        self._create_pr_comparison()
        self._create_calibration_plots()
        self._create_confusion_matrices()
        self._create_threshold_analysis()
        self._create_prediction_distributions()
        self._create_error_analysis()
        self._create_interactive_plots()  # Plotly figures
        self._create_statistical_tests()
    
    def _create_performance_dashboard(self):
        """Bar-chart eight headline metrics across models.

        Reads ``self.results`` (filled by ``calculate_comprehensive_metrics``)
        and saves ``performance_dashboard.png`` in ``self.save_dir``.
        """
        # Own palette, cycled per bar. This method therefore does not depend on
        # a module-level ``colors`` name and never runs out of colours.
        palette = ['#2E8B57', '#DC143C', '#4169E1', '#FFD700', '#8A2BE2']
        metrics_to_plot = ['accuracy', 'f1_score', 'roc_auc', 'pr_auc',
                           'precision', 'recall', 'specificity', 'matthews_corr']
        model_names = list(self.results.keys())
        bar_colors = [palette[k % len(palette)] for k in range(len(model_names))]

        # 2x4 grid: exactly one panel per metric, no empty panels to remove.
        fig, axes = plt.subplots(2, 4, figsize=(20, 10))
        fig.suptitle('Model Performance Dashboard', fontsize=16, fontweight='bold')

        for ax, metric in zip(axes.flat, metrics_to_plot):
            values = [self.results[name][metric] for name in model_names]
            bars = ax.bar(model_names, values, color=bar_colors)
            ax.set_title(f'{metric.replace("_", " ").title()}', fontweight='bold')
            ax.set_ylabel('Score')
            # Matthews correlation lies in [-1, 1]; the other metrics lie in
            # [0, 1]. The headroom past 1 keeps the value labels inside the axes.
            ax.set_ylim(-1.1 if metric == 'matthews_corr' else 0, 1.1)

            for bar, value in zip(bars, values):
                height = bar.get_height()
                above = height >= 0
                ax.text(bar.get_x() + bar.get_width() / 2.,
                        height + (0.01 if above else -0.01),
                        f'{value:.3f}', ha='center', va='bottom' if above else 'top')

            ax.tick_params(axis='x', rotation=45)

        fig.tight_layout()
        fig.savefig(f'{self.save_dir}/performance_dashboard.png', dpi=300, bbox_inches='tight')
        plt.close(fig)
    
    def _create_roc_comparison(self):
        """Overlay every model's ROC curve and save ``roc_comparison.png``.

        Models whose labels contain a single class are skipped, because ROC
        is undefined for them.
        """
        # Own palette so this method does not depend on a module-level
        # ``colors`` name.
        palette = ['#2E8B57', '#DC143C', '#4169E1', '#FFD700', '#8A2BE2']
        fig = plt.figure(figsize=(10, 8))

        for i, (model_name, data) in enumerate(self.predictions.items()):
            y_true, y_proba = data['labels'], data['probabilities']
            if len(np.unique(y_true)) <= 1:
                continue

            fpr, tpr, _ = roc_curve(y_true, y_proba)
            auc = roc_auc_score(y_true, y_proba)
            # Cycle the palette so more than five models do not raise IndexError.
            plt.plot(fpr, tpr, color=palette[i % len(palette)], lw=2,
                     label=f'{model_name} (AUC = {auc:.3f})')

        plt.plot([0, 1], [0, 1], color='gray', lw=2, linestyle='--', alpha=0.8)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate', fontsize=12)
        plt.ylabel('True Positive Rate', fontsize=12)
        plt.title('ROC Curve Comparison', fontsize=14, fontweight='bold')
        plt.legend(loc="lower right")
        plt.grid(True, alpha=0.3)

        fig.savefig(f'{self.save_dir}/roc_comparison.png', dpi=300, bbox_inches='tight')
        plt.close(fig)
    
    def _create_pr_comparison(self):
        """Overlay every model's precision-recall curve and save ``pr_comparison.png``.

        Models whose labels contain a single class are skipped. A horizontal
        baseline at the positive rate of the first model's labels is drawn;
        all models are scored on the same ``self.data_loader``.
        """
        # Own palette so this method does not depend on a module-level
        # ``colors`` name.
        palette = ['#2E8B57', '#DC143C', '#4169E1', '#FFD700', '#8A2BE2']
        fig = plt.figure(figsize=(10, 8))

        for i, (model_name, data) in enumerate(self.predictions.items()):
            y_true, y_proba = data['labels'], data['probabilities']
            if len(np.unique(y_true)) <= 1:
                continue

            precision, recall, _ = precision_recall_curve(y_true, y_proba)
            avg_precision = average_precision_score(y_true, y_proba)
            # Cycle the palette so more than five models do not raise IndexError.
            plt.plot(recall, precision, color=palette[i % len(palette)], lw=2,
                     label=f'{model_name} (AP = {avg_precision:.3f})')

        # Baseline: the positive rate, which is the AP of an uninformative
        # classifier. It is skipped when no model was loaded.
        first_run = next(iter(self.predictions.values()), None)
        if first_run is not None:
            baseline = np.mean(first_run['labels'])
            plt.axhline(y=baseline, color='gray', linestyle='--', alpha=0.8,
                        label=f'Baseline (AP = {baseline:.3f})')

        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('Recall', fontsize=12)
        plt.ylabel('Precision', fontsize=12)
        plt.title('Precision-Recall Curve Comparison', fontsize=14, fontweight='bold')
        plt.legend()
        plt.grid(True, alpha=0.3)

        fig.savefig(f'{self.save_dir}/pr_comparison.png', dpi=300, bbox_inches='tight')
        plt.close(fig)
    
    def _create_calibration_plots(self):
        """Draw a reliability diagram and a predicted-probability histogram.

        Left panel: 10-bin calibration curve per model (skipped for models
        whose labels hold a single class) plus the perfect-calibration
        diagonal. Right panel: density histogram of each model's predicted
        probabilities. Saved to ``<save_dir>/calibration_plots.png``.
        """
        _, (rel_ax, hist_ax) = plt.subplots(1, 2, figsize=(15, 6))

        # One pass over the models fills both panels; each panel still
        # receives its artists in model order.
        for idx, name in enumerate(self.predictions):
            record = self.predictions[name]
            labels, scores = record['labels'], record['probabilities']
            tint = colors[idx]

            if len(np.unique(labels)) > 1:
                frac_pos, mean_pred = calibration_curve(labels, scores, n_bins=10)
                rel_ax.plot(mean_pred, frac_pos, marker='o', color=tint,
                            label=name, linewidth=2, markersize=6)

            hist_ax.hist(scores, bins=20, alpha=0.7, color=tint,
                         label=name, density=True)

        rel_ax.plot([0, 1], [0, 1], 'k--', label='Perfect Calibration')
        rel_ax.set_xlabel('Mean Predicted Probability')
        rel_ax.set_ylabel('Fraction of Positives')
        rel_ax.set_title('Calibration Plot (Reliability Diagram)')
        rel_ax.legend()
        rel_ax.grid(True, alpha=0.3)

        hist_ax.set_xlabel('Predicted Probability')
        hist_ax.set_ylabel('Density')
        hist_ax.set_title('Confidence Distribution')
        hist_ax.legend()
        hist_ax.grid(True, alpha=0.3)

        plt.tight_layout()
        out_path = f'{self.save_dir}/calibration_plots.png'
        plt.savefig(out_path, dpi=300, bbox_inches='tight')
        plt.close()
    
    def _create_confusion_matrices(self):
        """Draw one row-normalised confusion-matrix heatmap per model.

        Models are laid out on a grid with at most three columns. Each cell
        shows the fraction of the true class (row-normalised) with the raw
        count in red beneath it. Unused grid cells are removed and the figure
        is saved to ``<save_dir>/confusion_matrices.png``. Does nothing when
        there are no predictions.
        """
        n_models = len(self.predictions)
        if n_models == 0:
            return
        cols = min(3, n_models)
        rows = (n_models + cols - 1) // cols

        # squeeze=False always yields a 2-D array, so flattening gives one Axes
        # per slot regardless of grid shape. (Previously a single-row grid was
        # wrapped as [array], making axes[1] raise IndexError for 2-3 models.)
        fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows),
                                 squeeze=False)
        axes = axes.ravel()

        # Fix the label set so every matrix is len(class_names) square even if
        # a class is absent from both y_true and y_pred. Labels are assumed to
        # be 0..len(class_names)-1, as _create_prediction_distributions does.
        label_ids = list(range(len(self.class_names)))

        for ax, (model_name, data) in zip(axes, self.predictions.items()):
            y_true = data['labels']
            y_pred = data['predictions']

            cm = confusion_matrix(y_true, y_pred, labels=label_ids)

            # Row-normalise; rows with no true samples stay 0 instead of NaN.
            row_totals = cm.sum(axis=1, keepdims=True)
            cm_normalized = np.divide(cm.astype(float), row_totals,
                                      out=np.zeros(cm.shape, dtype=float),
                                      where=row_totals > 0)

            sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues',
                        xticklabels=self.class_names, yticklabels=self.class_names,
                        ax=ax, cbar_kws={'shrink': 0.8})

            ax.set_title(f'{model_name}\nNormalized Confusion Matrix')
            ax.set_ylabel('True Label')
            ax.set_xlabel('Predicted Label')

            # Raw counts, drawn slightly below the normalised annotation.
            for j in range(cm.shape[0]):
                for k in range(cm.shape[1]):
                    ax.text(k + 0.5, j + 0.7, f'({cm[j, k]})',
                            ha='center', va='center', fontsize=10, color='red')

        # Remove grid cells that received no model.
        for ax in axes[n_models:]:
            fig.delaxes(ax)

        plt.tight_layout()
        plt.savefig(f'{self.save_dir}/confusion_matrices.png', dpi=300, bbox_inches='tight')
        plt.close()
    
    def _create_threshold_analysis(self):
        """Plot how classification metrics change with the decision threshold.

        Builds a 2x2 figure saved to ``<save_dir>/threshold_analysis.png``:
          * F1 vs threshold, with each model's best threshold starred;
          * precision (solid) and recall (dashed) vs threshold;
          * predicted-probability histograms split by true class;
          * sensitivity (solid) and specificity (dashed) vs threshold.
        Thresholds sweep ``np.arange(0.1, 1.0, 0.01)``.
        """
        _, axes = plt.subplots(2, 2, figsize=(15, 12))
        f1_ax, pr_ax = axes[0, 0], axes[0, 1]
        dist_ax, ss_ax = axes[1, 0], axes[1, 1]
        thresholds = np.arange(0.1, 1.0, 0.01)

        for idx, name in enumerate(self.predictions):
            record = self.predictions[name]
            labels, scores = record['labels'], record['probabilities']
            tint = colors[idx]

            # Sweep once; sensitivity is the same quantity as recall.
            f1s, precs, recs, specs = [], [], [], []
            for cut in thresholds:
                hard = (scores >= cut).astype(int)
                f1s.append(f1_score(labels, hard))
                precs.append(precision_score(labels, hard, zero_division=0))
                recs.append(recall_score(labels, hard, zero_division=0))
                # Specificity = recall of the flipped (negative) class.
                specs.append(recall_score(1 - labels, 1 - hard, zero_division=0))

            f1_ax.plot(thresholds, f1s, color=tint, label=name, linewidth=2)
            best = np.argmax(f1s)
            f1_ax.scatter(thresholds[best], f1s[best],
                          color=tint, s=100, marker='*', zorder=5)

            pr_ax.plot(thresholds, precs, color=tint, linestyle='-',
                       label=f'{name} Precision', linewidth=2)
            pr_ax.plot(thresholds, recs, color=tint, linestyle='--',
                       label=f'{name} Recall', linewidth=2)

            dist_ax.hist(scores[labels == 0], bins=20, alpha=0.5, color=tint,
                         label=f'{name} Negative', density=True)
            dist_ax.hist(scores[labels == 1], bins=20, alpha=0.5, color=tint,
                         label=f'{name} Positive', density=True, hatch='///')

            ss_ax.plot(thresholds, recs, color=tint, linestyle='-',
                       label=f'{name} Sensitivity', linewidth=2)
            ss_ax.plot(thresholds, specs, color=tint, linestyle='--',
                       label=f'{name} Specificity', linewidth=2)

        panel_text = (
            (f1_ax, 'Threshold', 'F1 Score', 'F1 Score vs Threshold'),
            (pr_ax, 'Threshold', 'Score', 'Precision/Recall vs Threshold'),
            (dist_ax, 'Predicted Probability', 'Density',
             'Probability Distribution by True Class'),
            (ss_ax, 'Threshold', 'Score', 'Sensitivity/Specificity vs Threshold'),
        )
        for ax, x_text, y_text, heading in panel_text:
            ax.set_xlabel(x_text)
            ax.set_ylabel(y_text)
            ax.set_title(heading)
            ax.legend()
            ax.grid(True, alpha=0.3)

        plt.tight_layout()
        out_path = f'{self.save_dir}/threshold_analysis.png'
        plt.savefig(out_path, dpi=300, bbox_inches='tight')
        plt.close()
    
    def _create_prediction_distributions(self):
        """Summarise each model's predicted probabilities in a 2x2 figure.

        Panels:
          * box plot of probabilities split by model and true class;
          * violin plot of the same groups (groups with fewer than two
            distinct values are left empty, since a KDE cannot be fit);
          * accuracy within ten equal-width probability bins on [0, 1];
          * grouped bar counts of TP / TN / FP / FN per model.
        Saved to ``<save_dir>/prediction_distributions.png``.
        """
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        n_models = len(self.predictions)

        # Box plots of predictions by true class
        ax1 = axes[0, 0]
        box_data = []
        box_labels = []

        for model_name, data in self.predictions.items():
            y_true = data['labels']
            y_proba = data['probabilities']

            for class_idx, class_name in enumerate(self.class_names):
                box_data.append(y_proba[y_true == class_idx])
                box_labels.append(f'{model_name}\n{class_name}')

        positions = np.arange(1, len(box_data) + 1)
        # Tick labels are set explicitly: boxplot's ``labels=`` keyword was
        # renamed ``tick_labels`` in Matplotlib 3.9 and is deprecated there.
        ax1.boxplot(box_data, positions=positions)
        ax1.set_xticks(positions)
        ax1.set_xticklabels(box_labels)
        ax1.set_title('Prediction Distribution by Model and True Class')
        ax1.set_ylabel('Predicted Probability')
        ax1.tick_params(axis='x', rotation=45)
        ax1.grid(True, alpha=0.3)

        # Violin plots: violinplot fits a KDE per group and raises on groups
        # that are empty or constant, so draw only groups with >= 2 distinct
        # values while keeping every tick label in place.
        ax2 = axes[0, 1]
        drawable = [k for k, sample in enumerate(box_data) if np.unique(sample).size > 1]
        if drawable:
            ax2.violinplot([box_data[k] for k in drawable], positions=positions[drawable],
                           showmeans=True, showmedians=True)
        ax2.set_xticks(positions)
        ax2.set_xticklabels(box_labels, rotation=45)
        ax2.set_title('Prediction Distribution (Violin Plot)')
        ax2.set_ylabel('Predicted Probability')
        ax2.grid(True, alpha=0.3)

        # Prediction confidence vs accuracy
        ax3 = axes[1, 0]
        confidence_bins = np.linspace(0, 1, 11)
        n_bins = len(confidence_bins) - 1
        for i, (model_name, data) in enumerate(self.predictions.items()):
            y_true = data['labels']
            y_pred = data['predictions']
            y_proba = data['probabilities']

            bin_accuracies = []
            bin_centers = []
            for j in range(n_bins):
                lo, hi = confidence_bins[j], confidence_bins[j + 1]
                # Bins are [lo, hi) except the last, which is closed so that a
                # probability of exactly 1.0 is counted instead of dropped.
                upper_ok = (y_proba <= hi) if j == n_bins - 1 else (y_proba < hi)
                mask = (y_proba >= lo) & upper_ok
                if mask.sum() > 0:
                    bin_accuracies.append((y_pred[mask] == y_true[mask]).mean())
                    bin_centers.append((lo + hi) / 2)

            # NOTE(review): `colors` comes from outside this method — confirm
            # it has at least one entry per model.
            ax3.plot(bin_centers, bin_accuracies, 'o-', color=colors[i],
                     label=model_name, linewidth=2, markersize=6)

        ax3.set_xlabel('Prediction Confidence (Binned)')
        ax3.set_ylabel('Accuracy')
        ax3.set_title('Confidence vs Accuracy')
        ax3.legend()
        ax3.grid(True, alpha=0.3)

        # Error distribution: grouped bars, one group per outcome type.
        ax4 = axes[1, 1]
        categories = ['True Pos', 'True Neg', 'False Pos', 'False Neg']
        group_x = np.arange(len(categories))
        # Share 0.8 of each slot between the models so groups never overlap
        # (fixed 0.2-wide offsets overlapped the next group from 5 models on).
        bar_width = 0.8 / max(n_models, 1)
        for i, (model_name, data) in enumerate(self.predictions.items()):
            y_true = data['labels']
            y_pred = data['predictions']

            tp = ((y_true == 1) & (y_pred == 1)).sum()
            tn = ((y_true == 0) & (y_pred == 0)).sum()
            fp = ((y_true == 0) & (y_pred == 1)).sum()
            fn = ((y_true == 1) & (y_pred == 0)).sum()

            ax4.bar(group_x + i * bar_width, [tp, tn, fp, fn], width=bar_width,
                    color=colors[i], label=model_name, alpha=0.8)

        ax4.set_xlabel('Prediction Type')
        ax4.set_ylabel('Count')
        ax4.set_title('Error Type Distribution')
        # Ticks at the centre of each group of bars, for any number of models.
        ax4.set_xticks(group_x + bar_width * (n_models - 1) / 2)
        ax4.set_xticklabels(categories)
        ax4.legend()
        ax4.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(f'{self.save_dir}/prediction_distributions.png', dpi=300, bbox_inches='tight')
        plt.close()
    
    def _create_error_analysis(self):
        """Scatter predicted probability against true label, split by correctness.

        One panel per model on a 2x3 grid; only the first six models (in
        insertion order) are drawn. Correct predictions are green dots,
        incorrect ones red crosses, with a dashed line at probability 0.5.
        Unused panels are removed and the figure is saved to
        ``<save_dir>/error_analysis.png``.
        """
        _, axes = plt.subplots(2, 3, figsize=(18, 12))

        # Limit to 6 models for visualization.
        shown = list(self.predictions.items())[:6]
        for slot, (name, record) in enumerate(shown):
            labels = record['labels']
            hard = record['predictions']
            scores = record['probabilities']

            ax = axes[divmod(slot, 3)]
            hit = hard == labels
            miss = ~hit

            ax.scatter(scores[hit], labels[hit],
                       alpha=0.6, c='green', label='Correct', s=20)
            ax.scatter(scores[miss], labels[miss],
                       alpha=0.6, c='red', label='Incorrect', s=20, marker='x')

            ax.set_xlabel('Predicted Probability')
            ax.set_ylabel('True Label')
            ax.set_title(f'{name} - Error Analysis')
            ax.legend()
            ax.grid(True, alpha=0.3)

            # Reference line at the 0.5 probability mark.
            ax.axvline(x=0.5, color='black', linestyle='--', alpha=0.5)

        # Drop the panels no model was drawn into.
        for slot in range(len(self.predictions), 6):
            axes[divmod(slot, 3)].remove()

        plt.tight_layout()
        out_path = f'{self.save_dir}/error_analysis.png'
        plt.savefig(out_path, dpi=300, bbox_inches='tight')
        plt.close()
    
    def _create_interactive_plots(self):
        """Write three standalone Plotly HTML files into ``save_dir``.

        * ``interactive_roc.html``: ROC curve per model plus a chance diagonal.
        * ``interactive_threshold.html``: 2x2 grid of F1, precision/recall,
          sensitivity/specificity and positive-prediction count versus the
          decision threshold (``np.arange(0.1, 1.0, 0.01)``).
        * ``interactive_calibration.html``: 10-bin reliability curve per model.

        ROC and calibration traces are skipped for models whose labels
        contain a single class.
        """
        print("Creating interactive visualizations...")
        entries = list(self.predictions.items())

        # ---- ROC curves ----
        roc_fig = go.Figure()
        for name, record in entries:
            labels, scores = record['labels'], record['probabilities']
            if len(np.unique(labels)) <= 1:
                continue
            fpr, tpr, _ = roc_curve(labels, scores)
            area = roc_auc_score(labels, scores)
            roc_fig.add_trace(go.Scatter(
                x=fpr, y=tpr,
                mode='lines',
                name=f'{name} (AUC = {area:.3f})',
                line=dict(width=3),
                hovertemplate='FPR: %{x:.3f}<br>TPR: %{y:.3f}<extra></extra>',
            ))

        roc_fig.add_trace(go.Scatter(
            x=[0, 1], y=[0, 1],
            mode='lines',
            name='Random Classifier',
            line=dict(dash='dash', color='gray', width=2),
        ))
        roc_fig.update_layout(
            title='Interactive ROC Curve Comparison',
            xaxis_title='False Positive Rate',
            yaxis_title='True Positive Rate',
            width=800, height=600,
            hovermode='closest',
        )
        roc_fig.write_html(f'{self.save_dir}/interactive_roc.html')

        # ---- Threshold sweep ----
        sweep_fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=('F1 Score vs Threshold', 'Precision/Recall vs Threshold',
                            'Sensitivity/Specificity vs Threshold', 'Sample Count vs Threshold'),
        )
        cuts = np.arange(0.1, 1.0, 0.01)

        for name, record in entries:
            labels, scores = record['labels'], record['probabilities']

            f1s, precs, recs, specs, positives = [], [], [], [], []
            for cut in cuts:
                hard = (scores >= cut).astype(int)
                f1s.append(f1_score(labels, hard))
                precs.append(precision_score(labels, hard, zero_division=0))
                recs.append(recall_score(labels, hard, zero_division=0))
                # Specificity = recall of the flipped (negative) class.
                specs.append(recall_score(1 - labels, 1 - hard, zero_division=0))
                positives.append(hard.sum())

            # (values, trace name, line style, row, col) in drawing order;
            # sensitivity is plotted from the recall values.
            traces = (
                (f1s, f'{name} F1', dict(width=2), 1, 1),
                (precs, f'{name} Precision', dict(width=2), 1, 2),
                (recs, f'{name} Recall', dict(width=2, dash='dash'), 1, 2),
                (recs, f'{name} Sensitivity', dict(width=2), 2, 1),
                (specs, f'{name} Specificity', dict(width=2, dash='dash'), 2, 1),
                (positives, f'{name} Positive Predictions', dict(width=2), 2, 2),
            )
            for values, trace_name, line_style, row, col in traces:
                sweep_fig.add_trace(
                    go.Scatter(x=cuts, y=values, name=trace_name, line=line_style),
                    row=row, col=col,
                )

        sweep_fig.update_layout(
            title='Interactive Threshold Analysis',
            width=1200, height=800,
            hovermode='x unified',
        )
        sweep_fig.write_html(f'{self.save_dir}/interactive_threshold.html')

        # ---- Calibration ----
        cal_fig = go.Figure()
        for name, record in entries:
            labels, scores = record['labels'], record['probabilities']
            if len(np.unique(labels)) <= 1:
                continue
            frac_pos, mean_pred = calibration_curve(labels, scores, n_bins=10)
            cal_fig.add_trace(go.Scatter(
                x=mean_pred, y=frac_pos,
                mode='markers+lines',
                name=name,
                marker=dict(size=8),
                line=dict(width=3),
                hovertemplate='Predicted: %{x:.3f}<br>Actual: %{y:.3f}<extra></extra>',
            ))

        cal_fig.add_trace(go.Scatter(
            x=[0, 1], y=[0, 1],
            mode='lines',
            name='Perfect Calibration',
            line=dict(dash='dash', color='gray', width=2),
        ))
        cal_fig.update_layout(
            title='Interactive Calibration Plot',
            xaxis_title='Mean Predicted Probability',
            yaxis_title='Fraction of Positives',
            width=800, height=600,
        )
        cal_fig.write_html(f'{self.save_dir}/interactive_calibration.html')
    
    def _create_statistical_tests(self, n_bootstrap=1000, random_state=None):
        """Compare models with McNemar's test and bootstrap AUC intervals.

        For every pair of models, McNemar's test (continuity-corrected
        chi-square, 1 d.o.f.) is applied to their per-sample correctness.
        For every model whose labels contain both classes, a percentile
        bootstrap gives the mean AUC and a 95% interval. The text summary is
        written to ``<save_dir>/statistical_tests.txt`` and, when at least one
        interval could be computed, an error-bar chart to
        ``<save_dir>/statistical_comparison.png``.

        The bootstrap runs once per model and the same numbers feed both the
        text file and the chart (two independent bootstraps used to produce
        two different, inconsistent sets of intervals).

        Args:
            n_bootstrap: number of bootstrap resamples per model.
            random_state: seed for the bootstrap RNG. ``None`` draws fresh
                entropy (non-reproducible); an int makes runs repeatable.

        Returns:
            None. Returns early without writing anything when fewer than two
            models are available.
        """
        from scipy import stats  # chi-square survival function for McNemar

        print("Performing statistical significance tests...")

        if len(self.predictions) < 2:
            print("Need at least 2 models for statistical comparison")
            return

        rng = np.random.default_rng(random_state)

        # McNemar's test for paired predictions
        results_text = "STATISTICAL SIGNIFICANCE TESTS\n" + "="*50 + "\n\n"

        model_names = list(self.predictions.keys())

        for i, model1_name in enumerate(model_names):
            for model2_name in model_names[i + 1:]:
                y_true = self.predictions[model1_name]['labels']
                pred1 = self.predictions[model1_name]['predictions']
                pred2 = self.predictions[model2_name]['predictions']

                # 2x2 contingency table of per-sample correctness.
                correct1 = (pred1 == y_true)
                correct2 = (pred2 == y_true)

                both_correct = np.sum(correct1 & correct2)
                model1_only = np.sum(correct1 & ~correct2)
                model2_only = np.sum(~correct1 & correct2)
                both_wrong = np.sum(~correct1 & ~correct2)

                # Only the discordant cells enter McNemar's statistic.
                discordant = model1_only + model2_only
                if discordant > 0:
                    mcnemar_stat = (abs(model1_only - model2_only) - 1)**2 / discordant
                    # sf == 1 - cdf, but keeps precision for very small p-values.
                    p_value = stats.chi2.sf(mcnemar_stat, 1)
                else:
                    mcnemar_stat = 0
                    p_value = 1.0

                results_text += f"McNemar's Test: {model1_name} vs {model2_name}\n"
                results_text += f"  Contingency Table:\n"
                results_text += f"    Both correct: {both_correct}\n"
                results_text += f"    {model1_name} only: {model1_only}\n"
                results_text += f"    {model2_name} only: {model2_only}\n"
                results_text += f"    Both wrong: {both_wrong}\n"
                results_text += f"  Chi-square statistic: {mcnemar_stat:.4f}\n"
                results_text += f"  P-value: {p_value:.4f}\n"

                if p_value < 0.05:
                    results_text += f"  Result: Significant difference (p < 0.05)\n"
                else:
                    results_text += f"  Result: No significant difference (p >= 0.05)\n"
                results_text += "\n"

        def bootstrap_auc(y_true, y_proba):
            """Return (mean, 2.5th pct, 97.5th pct) of bootstrap AUCs, or None."""
            n = len(y_true)
            aucs = []
            for _ in range(n_bootstrap):
                # Sample with replacement; skip resamples with a single class.
                idx = rng.integers(0, n, size=n)
                y_true_boot = y_true[idx]
                if len(np.unique(y_true_boot)) > 1:
                    aucs.append(roc_auc_score(y_true_boot, y_proba[idx]))
            if not aucs:
                return None
            return np.mean(aucs), np.percentile(aucs, 2.5), np.percentile(aucs, 97.5)

        # Bootstrap confidence intervals for AUC (computed once, reused below).
        auc_ci = {}
        for model_name, data in self.predictions.items():
            y_true = data['labels']
            if len(np.unique(y_true)) > 1:
                ci = bootstrap_auc(y_true, data['probabilities'])
                if ci is not None:
                    auc_ci[model_name] = ci

        results_text += "BOOTSTRAP CONFIDENCE INTERVALS (AUC)\n" + "-"*40 + "\n"
        for model_name, (mean_auc, ci_lower, ci_upper) in auc_ci.items():
            results_text += f"{model_name}:\n"
            results_text += f"  Mean AUC: {mean_auc:.4f}\n"
            results_text += f"  95% CI: [{ci_lower:.4f}, {ci_upper:.4f}]\n\n"

        # Save results
        with open(f'{self.save_dir}/statistical_tests.txt', 'w', encoding='utf-8') as f:
            f.write(results_text)

        if not auc_ci:
            return

        # Plot AUC with confidence intervals (same numbers as the text file).
        plot_names = list(auc_ci)
        means = np.array([auc_ci[m][0] for m in plot_names])
        lowers = np.array([auc_ci[m][1] for m in plot_names])
        uppers = np.array([auc_ci[m][2] for m in plot_names])
        y_pos = np.arange(len(plot_names))

        fig = plt.figure(figsize=(10, 8))
        try:
            # The bootstrap mean can fall outside its percentile interval for
            # very skewed samples; clip so errorbar never sees negative extents.
            plt.errorbar(means, y_pos,
                         xerr=[np.clip(means - lowers, 0, None),
                               np.clip(uppers - means, 0, None)],
                         fmt='o', capsize=5, capthick=2, markersize=8)

            plt.yticks(y_pos, plot_names)
            plt.xlabel('AUC Score')
            plt.title('Model Performance with 95% Bootstrap Confidence Intervals')
            plt.grid(True, alpha=0.3)

            # Add vertical line at 0.5 (random performance)
            plt.axvline(x=0.5, color='red', linestyle='--', alpha=0.5, label='Random')
            plt.legend()

            plt.tight_layout()
            plt.savefig(f'{self.save_dir}/statistical_comparison.png', dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)
    
    def generate_comprehensive_report(self):
        """Write an HTML summary to ``<save_dir>/comprehensive_report.html``.

        The report contains:
          * a header with timestamp, model count and dataset size (length of
            the first model's labels);
          * an executive summary for the model with the highest ``roc_auc``;
          * a metrics table in which the best and worst value of each metric
            are highlighted when more than one model is present;
          * every known plot PNG that exists in ``save_dir``;
          * links to whichever interactive HTML plots exist in ``save_dir``;
          * rule-of-thumb recommendations based on AUC and, when present,
            ``calibration_error``.

        Model names are HTML-escaped so names containing ``<``, ``>`` or
        ``&`` cannot break the markup. The file is written as UTF-8.
        """
        # Local imports keep this method self-contained.
        import html
        from datetime import datetime

        def esc(text):
            return html.escape(str(text))

        print("Generating comprehensive report...")

        html_content = """
        <!DOCTYPE html>
        <html>
        <head>
            <meta charset="utf-8">
            <title>Model Evaluation Report</title>
            <style>
                body { font-family: Arial, sans-serif; margin: 20px; }
                .header { background-color: #f0f0f0; padding: 20px; border-radius: 10px; }
                .section { margin: 20px 0; padding: 15px; border-left: 4px solid #007acc; }
                .metrics-table { border-collapse: collapse; width: 100%; margin: 20px 0; }
                .metrics-table th, .metrics-table td { border: 1px solid #ddd; padding: 8px; text-align: center; }
                .metrics-table th { background-color: #f2f2f2; }
                .image-container { text-align: center; margin: 20px 0; }
                .best-score { background-color: #90EE90; }
                .worst-score { background-color: #FFB6C1; }
            </style>
        </head>
        <body>
        """

        # Header
        dataset_size = (len(next(iter(self.predictions.values()))['labels'])
                        if self.predictions else 0)
        html_content += f"""
        <div class="header">
            <h1>Advanced Model Evaluation Report</h1>
            <p><strong>Generated on:</strong> {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
            <p><strong>Number of Models:</strong> {len(self.predictions)}</p>
            <p><strong>Dataset Size:</strong> {dataset_size}</p>
        </div>
        """

        # Executive Summary
        html_content += """
        <div class="section">
            <h2>Executive Summary</h2>
        """

        # Best model = first with the highest AUC strictly above 0 (NaN
        # AUCs never win because NaN comparisons are False).
        best_model = None
        best_auc = 0

        for model_name in self.results:
            if self.results[model_name]['roc_auc'] > best_auc:
                best_auc = self.results[model_name]['roc_auc']
                best_model = model_name

        if best_model:
            html_content += f"""
            <p><strong>Best Performing Model:</strong> {esc(best_model)}</p>
            <p><strong>Best AUC Score:</strong> {best_auc:.4f}</p>
            <p><strong>Best F1 Score:</strong> {self.results[best_model]['f1_score']:.4f}</p>
            <p><strong>Best Accuracy:</strong> {self.results[best_model]['accuracy']:.4f}</p>
            """

        html_content += "</div>"

        # Detailed Metrics Table
        html_content += """
        <div class="section">
            <h2>Detailed Performance Metrics</h2>
            <table class="metrics-table">
                <tr>
                    <th>Model</th>
                    <th>Accuracy</th>
                    <th>Precision</th>
                    <th>Recall</th>
                    <th>F1 Score</th>
                    <th>AUC</th>
                    <th>Specificity</th>
                    <th>Matthews Corr</th>
                </tr>
        """

        # Find best/worst for each metric; highlighting is only meaningful
        # (and max/min only defined) when there is more than one model.
        metrics_to_highlight = ['accuracy', 'precision', 'recall', 'f1_score', 'roc_auc', 'specificity', 'matthews_corr']
        highlight = len(self.results) > 1
        best_values = {}
        worst_values = {}

        if highlight:
            for metric in metrics_to_highlight:
                values = [self.results[model][metric] for model in self.results]
                best_values[metric] = max(values)
                worst_values[metric] = min(values)

        for model_name in self.results:
            html_content += f"<tr><td><strong>{esc(model_name)}</strong></td>"

            for metric in metrics_to_highlight:
                value = self.results[model_name][metric]
                css_class = ""

                if highlight and value == best_values[metric]:
                    css_class = "best-score"
                elif highlight and value == worst_values[metric]:
                    css_class = "worst-score"

                html_content += f'<td class="{css_class}">{value:.4f}</td>'

            html_content += "</tr>"

        html_content += "</table></div>"

        # Visualizations
        html_content += """
        <div class="section">
            <h2>Visualizations</h2>
        """

        # Embed every known plot that was actually written.
        plot_files = [
            'performance_dashboard.png',
            'roc_comparison.png',
            'pr_comparison.png',
            'calibration_plots.png',
            'confusion_matrices.png',
            'threshold_analysis.png',
            'prediction_distributions.png',
            'error_analysis.png',
            'statistical_comparison.png'
        ]

        for plot_file in plot_files:
            if os.path.exists(f'{self.save_dir}/{plot_file}'):
                html_content += f"""
                <div class="image-container">
                    <h3>{plot_file.replace('_', ' ').replace('.png', '').title()}</h3>
                    <img src="{plot_file}" style="max-width: 100%; height: auto;" alt="{plot_file}">
                </div>
                """

        html_content += "</div>"

        # Interactive Plots: only link files that exist, so the report never
        # contains dead links when interactive plotting was skipped or failed.
        interactive_links = [
            ('interactive_roc.html', 'Interactive ROC Curves'),
            ('interactive_threshold.html', 'Interactive Threshold Analysis'),
            ('interactive_calibration.html', 'Interactive Calibration Plot'),
        ]
        existing_links = [(fname, text) for fname, text in interactive_links
                          if os.path.exists(f'{self.save_dir}/{fname}')]
        if existing_links:
            html_content += """
        <div class="section">
            <h2>Interactive Visualizations</h2>
            <p>The following interactive plots have been generated:</p>
            <ul>
            """
            for fname, text in existing_links:
                html_content += f'<li><a href="{fname}">{text}</a></li>\n'
            html_content += """
            </ul>
        </div>
        """

        # Model Recommendations
        html_content += """
        <div class="section">
            <h2>Recommendations</h2>
        """

        if best_model and self.results[best_model]['roc_auc'] > 0.8:
            html_content += f"<p><strong>{esc(best_model)}</strong> shows excellent performance with AUC > 0.8. This model is ready for deployment consideration.</p>"
        elif best_model and self.results[best_model]['roc_auc'] > 0.7:
            html_content += f"<p><strong>{esc(best_model)}</strong> shows good performance with AUC > 0.7, but there's room for improvement.</p>"
        else:
            html_content += "<p>All models show suboptimal performance (AUC < 0.7). Consider:</p>"
            html_content += "<ul><li>Collecting more training data</li><li>Feature engineering</li><li>Different model architectures</li><li>Hyperparameter tuning</li></ul>"

        # Flag poorly calibrated models when a calibration error was computed.
        for model_name in self.results:
            if 'calibration_error' in self.results[model_name]:
                cal_error = self.results[model_name]['calibration_error']
                if cal_error > 0.1:
                    html_content += f"<p><strong>{esc(model_name)}</strong> shows poor calibration (error: {cal_error:.3f}). Consider calibration techniques.</p>"

        html_content += "</div>"

        html_content += """
        </body>
        </html>
        """

        # Save HTML report
        with open(f'{self.save_dir}/comprehensive_report.html', 'w', encoding='utf-8') as f:
            f.write(html_content)

        print(f"Comprehensive report saved to {self.save_dir}/comprehensive_report.html")
    
    def run_complete_evaluation(self, model_class):
        """Run the full evaluation pipeline and persist the results.

        Calls ``load_model_predictions(model_class)``; if that leaves
        ``self.predictions`` empty, prints a message and returns ``None``.
        Otherwise calls ``calculate_comprehensive_metrics``,
        ``create_comprehensive_visualizations`` and
        ``generate_comprehensive_report`` in that order, then dumps
        ``self.results`` to ``<save_dir>/detailed_results.json``.

        JSON conversion: numeric values (Python or NumPy ints/floats) are
        stored as floats, with NaN/inf stored as ``null`` because JSON has no
        literal for them; any other value is stored as its ``str()``.

        Args:
            model_class: passed straight through to ``load_model_predictions``.

        Returns:
            ``self.results`` on success, or ``None`` if no predictions were made.
        """
        print("Starting comprehensive model evaluation...")
        print("="*60)

        # Load models and generate predictions
        self.load_model_predictions(model_class)

        if not self.predictions:
            print("No valid predictions generated. Check model paths and data loader.")
            return None

        # Calculate comprehensive metrics
        self.calculate_comprehensive_metrics()

        # Create visualizations
        self.create_comprehensive_visualizations()

        # Generate report
        self.generate_comprehensive_report()

        def to_json_value(value):
            # json.dump would otherwise emit the non-standard token NaN /
            # Infinity, which strict JSON parsers reject.
            if isinstance(value, (int, float, np.integer, np.floating)):
                number = float(value)
                return number if np.isfinite(number) else None
            return str(value)

        results_for_json = {
            model_name: {metric: to_json_value(value) for metric, value in metrics.items()}
            for model_name, metrics in self.results.items()
        }

        with open(f'{self.save_dir}/detailed_results.json', 'w', encoding='utf-8') as f:
            json.dump(results_for_json, f, indent=2)

        print(f"\nEvaluation completed successfully!")
        print(f"Results saved in: {self.save_dir}/")
        print(f"Open {self.save_dir}/comprehensive_report.html for full report")

        return self.results


# Usage Example
def run_evaluation_pipeline(model_paths, data_loader, model_class, device,
                            save_dir="evaluation_results", class_names=None):
    """
    Run the complete evaluation pipeline

    Args:
        model_paths: List of model checkpoint paths or single path
        data_loader: PyTorch DataLoader with test/validation data
        model_class: Model class (not instantiated)
        device: torch.device
        save_dir: Directory to save results
        class_names: Display names indexed by class label. Defaults to
            ['No Aneurysm', 'Aneurysm'] (the previously hard-coded names).

    Returns:
        Dictionary with evaluation results, or None if no predictions
        could be generated.
    """
    if class_names is None:
        class_names = ['No Aneurysm', 'Aneurysm']

    # Create evaluator
    evaluator = AdvancedModelEvaluator(
        model_paths=model_paths,
        data_loader=data_loader,
        device=device,
        class_names=class_names,
        save_dir=save_dir
    )

    # Run complete evaluation
    return evaluator.run_complete_evaluation(model_class)