**IMPORTS AND TPU SETUP**

In [None]:
import os
import gc
import json
import warnings
import logging
from pathlib import Path
from collections import Counter
from typing import Dict, List, Tuple, Optional
import numpy as np
import pandas as pd
import cv2
import pydicom
from scipy import ndimage
from scipy.stats import zscore
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, 
    roc_auc_score, confusion_matrix, classification_report,
    roc_curve, precision_recall_curve
)
import matplotlib.pyplot as plt
import seaborn as sns
import albumentations as A
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from tqdm import tqdm
import threading
from concurrent.futures import ThreadPoolExecutor
import time

warnings.filterwarnings('ignore')
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Let cuDNN autotune convolution algorithms: faster, but not bitwise reproducible.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

# Optional XLA support. ``TPU_AVAILABLE`` gates every XLA-specific code path in
# this notebook (device selection, ``xm.optimizer_step``); ``xm`` is ``None``
# whenever XLA cannot be used.
try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    # NOTE(review): 'GPU' is accepted too, so an XLA-on-GPU runtime is treated
    # exactly like a TPU by the rest of the notebook.
    TPU_AVAILABLE = xm.xla_device_hw(xm.xla_device()) in ('TPU', 'GPU')
    if TPU_AVAILABLE:
        print(f"TPU detected: {xm.xla_device()}")
        # NOTE(review): these are set after the XLA device was created; confirm
        # the installed torch_xla version still honours them at this point.
        os.environ['XLA_USE_BF16'] = '1'
        os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '100000000'
except ImportError:
    TPU_AVAILABLE = False
    xm = None
    print("TPU libraries not available, using CPU/GPU")
except RuntimeError as err:
    # torch_xla is installed but no XLA device could be initialised; fall back
    # to CPU/GPU instead of aborting the whole notebook.
    TPU_AVAILABLE = False
    xm = None
    print(f"XLA device initialisation failed ({err}), using CPU/GPU")

**TPU-OPTIMIZED CONFIGURATION AND TRAINING PIPELINE**

In [3]:
class Config:
    """Central paths, hyper-parameters and runtime device for the pipeline.

    Every setting is a class attribute read as ``Config.<NAME>``. ``DEVICE`` is
    resolved once, while this class body executes.
    """
    TRAIN_CSV_PATH = '/kaggle/input/rsna-intracranial-aneurysm-detection/train.csv'
    SERIES_DIR = '/kaggle/input/rsna-intracranial-aneurysm-detection/series/'

    # Model parameters
    TARGET_SIZE = (16, 128, 128)  # (slices, height, width) every series is resized/padded to
    BATCH_SIZE = 4 if TPU_AVAILABLE else 8
    EPOCHS = 20
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-4
    WARMUP_EPOCHS = 2  # NOTE(review): the scheduler set up in this notebook does not read this

    # Training parameters
    GRADIENT_CLIP = 0.5  # max gradient norm passed to clip_grad_norm_
    ACCUMULATION_STEPS = 2  # NOTE(review): train_epoch_optimized steps every batch and ignores this
    EARLY_STOPPING_PATIENCE = 5  # epochs without a val-AUC improvement before stopping
    SCHEDULER_PATIENCE = 3  # ReduceLROnPlateau patience (monitors val loss)

    # Memory management
    MAX_CACHE_SIZE = 50  # max processed volumes memoised by OptimizedDICOMProcessor
    PREFETCH_FACTOR = 2  # DataLoader prefetch_factor for the training loader

    @staticmethod
    def get_device():
        """Return the XLA device if available, else CUDA, else CPU (printing the choice)."""
        if TPU_AVAILABLE:
            device = xm.xla_device()
            print(f"Using TPU device: {device}")
            return device
        elif torch.cuda.is_available():
            device = torch.device('cuda')
            torch.cuda.empty_cache()
            print(f"Using CUDA device: {device}")
            print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
            return device
        else:
            device = torch.device('cpu')
            print(f"Using CPU device: {device}")
            return device

    # Inside the class body ``get_device`` is still a raw ``staticmethod`` object,
    # which is only directly callable on Python >= 3.10; calling the wrapped
    # function keeps this working on older interpreters as well.
    DEVICE = get_device.__func__()

    # Data parameters
    ID_COL = 'SeriesInstanceUID'
    TARGET_COL = 'Aneurysm Present'

    # Debug settings
    DEBUG_MODE = True
    DEBUG_SAMPLES = 200  # total rows kept in debug mode (half per class)

    # Cross-validation: data is split into N_FOLDS stratified folds; when
    # CURRENT_FOLD >= 0 only that fold is trained (use -1 to train every fold).
    N_FOLDS = 5
    CURRENT_FOLD = 0

    # Data augmentation
    USE_AUGMENTATION = True
    AUGMENTATION_PROB = 0.2

class OptimizedDICOMProcessor:
    """Load a DICOM series directory into a fixed-size float32 volume in [0, 1].

    Output volumes have shape ``target_size`` = (slices, height, width).
    Up to ``Config.MAX_CACHE_SIZE`` processed volumes are memoised per
    instance. Any failure yields a random "dummy" volume instead of raising.
    """

    def __init__(self, target_size=None, hu_window=(-1000, 1000)):
        self.target_size = target_size or Config.TARGET_SIZE
        self.max_slices = self.target_size[0]
        self.hu_window = hu_window
        self.stats = {'processed': 0, 'failed': 0, 'dummy': 0}
        self.cache = {}  # series_path -> processed volume

    def load_dicom_series(self, series_path):
        """Return the series in ``series_path`` as a volume of shape ``target_size``.

        At most ``max_slices`` files, spread evenly over the sorted file list,
        are decoded. Each slice is converted to modality units with
        RescaleSlope/RescaleIntercept, clipped to ``hu_window`` and min-max
        normalised. Slices are then ordered by ``InstanceNumber`` when it is
        available. A missing directory, an empty series or an unexpected error
        returns ``_get_dummy_volume()``; nothing is raised.
        """
        try:
            cached = self.cache.get(series_path)
            if cached is not None:
                self.stats['processed'] += 1
                # Hand out a copy so in-place augmentation cannot corrupt the cache.
                return cached.copy()

            if not os.path.exists(series_path):
                logger.warning(f"Series path does not exist: {series_path}")
                return self._get_dummy_volume()

            dicom_files = self._select_slice_files(series_path)
            if not dicom_files:
                logger.warning(f"No DICOM files found in: {series_path}")
                return self._get_dummy_volume()

            target_shape = self.target_size[1:]
            loaded = []  # (sort key, preprocessed 2-D slice)
            for order, file_name in enumerate(dicom_files):
                try:
                    ds = pydicom.dcmread(os.path.join(series_path, file_name), force=True)
                    if not hasattr(ds, 'pixel_array'):
                        continue
                    arr = ds.pixel_array.astype(np.float32)
                    if arr.ndim != 2:
                        # NOTE(review): multi-frame (3-D) pixel data is skipped entirely.
                        continue
                    arr = self._apply_rescale(arr, ds)
                    loaded.append((self._slice_sort_key(ds, order),
                                   self._simple_preprocess(arr, target_shape)))
                except Exception as e:
                    logger.warning(f"Failed to load {file_name}: {e}")

            if not loaded:
                return self._get_dummy_volume()

            loaded.sort(key=lambda item: item[0])
            volume = self._create_volume_efficiently([slice_2d for _, slice_2d in loaded])

            if len(self.cache) < Config.MAX_CACHE_SIZE:
                self.cache[series_path] = volume.copy()

            self.stats['processed'] += 1
            return volume

        except Exception as e:
            logger.error(f"Error processing {series_path}: {e}")
            self.stats['failed'] += 1
            return self._get_dummy_volume()

    def _select_slice_files(self, series_path):
        """Return up to ``max_slices`` ``.dcm`` file names spread evenly over the sorted listing."""
        # os.listdir order is filesystem-dependent; sort for determinism.
        names = sorted(f for f in os.listdir(series_path) if f.lower().endswith('.dcm'))
        if len(names) <= self.max_slices:
            return names
        picks = np.linspace(0, len(names) - 1, self.max_slices).round().astype(int)
        return [names[i] for i in picks]

    @staticmethod
    def _slice_sort_key(ds, fallback):
        """Sort key: slices with a numeric InstanceNumber first (by that number), then the rest by ``fallback``."""
        try:
            return (0, float(ds.InstanceNumber))
        except (AttributeError, TypeError, ValueError):
            return (1, fallback)

    @staticmethod
    def _apply_rescale(arr, ds):
        """Map stored pixel values to modality units (HU for CT) via RescaleSlope/RescaleIntercept.

        Returns ``arr`` unchanged when the attributes are absent or unparsable.
        """
        try:
            slope = float(getattr(ds, 'RescaleSlope', 1.0))
            intercept = float(getattr(ds, 'RescaleIntercept', 0.0))
        except (TypeError, ValueError):
            return arr
        if slope == 1.0 and intercept == 0.0:
            return arr
        return (arr * slope + intercept).astype(np.float32)

    def _simple_preprocess(self, arr, target_shape):
        """Resize a 2-D slice to ``target_shape`` (h, w), clip to ``hu_window`` and min-max scale to [0, 1]."""
        if arr.shape != target_shape:
            # cv2 takes (width, height)
            arr = cv2.resize(arr, (target_shape[1], target_shape[0]),
                             interpolation=cv2.INTER_AREA)

        arr = np.clip(arr, *self.hu_window)
        lo, hi = arr.min(), arr.max()
        if hi > lo:
            arr = (arr - lo) / (hi - lo)

        return arr.astype(np.float32)

    def _create_volume_efficiently(self, pixel_arrays):
        """Stack slices into exactly ``max_slices`` slices (padding with the last one) and lightly smooth."""
        if len(pixel_arrays) < self.max_slices:
            # Pad by repeating the last slice; do not mutate the caller's list.
            pixel_arrays = list(pixel_arrays) + [pixel_arrays[-1]] * (self.max_slices - len(pixel_arrays))
        else:
            pixel_arrays = pixel_arrays[:self.max_slices]

        volume = np.stack(pixel_arrays, axis=0).astype(np.float32)
        return ndimage.gaussian_filter(volume, sigma=0.3)

    def _get_dummy_volume(self):
        """Return a random placeholder volume in [0, 1] and count it in ``stats['dummy']``."""
        self.stats['dummy'] += 1
        volume = np.random.normal(0.3, 0.1, self.target_size).astype(np.float32)
        return np.clip(volume, 0, 1)

    def clear_cache(self):
        """Drop all memoised volumes and run the garbage collector."""
        self.cache.clear()
        gc.collect()

class LightweightAugmentation:
    """Randomly apply light 2-D albumentations to a subset of a volume's slices.

    With probability ``prob``, a third of the slices (at least one) are chosen
    at random and each is passed through ``self.transform``. Otherwise the
    volume is returned untouched. The volume is modified in place and also
    returned; values are assumed to lie in [0, 1].
    """

    def __init__(self, prob=0.2):
        self.prob = prob
        # The Compose always fires (p=1.0). The volume-level ``prob`` gate in
        # __call__ already decides whether to augment; gating here as well made
        # the effective probability prob**2.
        # NOTE(review): newer albumentations releases deprecate ShiftScaleRotate
        # and GaussNoise(var_limit=...); check against the installed version.
        self.transform = A.Compose([
            A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=10, p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.5),
            A.GaussNoise(var_limit=(5.0, 25.0), p=0.2),
        ], p=1.0)

    def __call__(self, volume):
        if np.random.random() > self.prob:
            return volume

        n_slices_to_augment = max(1, volume.shape[0] // 3)
        slice_indices = np.random.choice(volume.shape[0], n_slices_to_augment, replace=False)

        # NOTE(review): each slice gets an independent random geometric transform,
        # so augmented slices may no longer be spatially aligned with their neighbours.
        for i in slice_indices:
            # Clip first: values outside [0, 1] would wrap around in the uint8 cast.
            slice_uint8 = (np.clip(volume[i], 0.0, 1.0) * 255).astype(np.uint8)
            transformed = self.transform(image=slice_uint8)
            volume[i] = transformed['image'].astype(np.float32) / 255.0

        return volume

class OptimizedAneurysmDataset(Dataset):
    """Map-style dataset yielding one series volume and its binary label per row of ``df``.

    Each item is a dict with ``volume`` (the processor output as a float tensor
    with a leading channel axis), ``label`` (float32 scalar tensor),
    ``series_id`` (str) and ``idx`` (int). Augmentation is applied only when
    ``mode == 'train'``.
    """

    def __init__(self, df, series_dir, processor, augmentation=None, mode='train'):
        self.df = df.copy().reset_index(drop=True)
        self.series_dir = series_dir
        self.processor = processor
        self.augmentation = augmentation
        self.mode = mode

        self._validate_series_paths()

        logger.info(f"Dataset created with {len(self.df)} samples")
        logger.info(f"Positive cases: {self.df[Config.TARGET_COL].sum()}")

    def _validate_series_paths(self):
        """Warn when some rows point at a series directory that does not exist.

        Rows are kept either way. OptimizedDICOMProcessor.load_dicom_series
        returns a dummy volume for a missing directory.
        """
        series_ids = self.df[Config.ID_COL].astype(str)
        exists = series_ids.map(lambda sid: os.path.exists(os.path.join(self.series_dir, sid)))
        n_found = int(exists.sum())
        if n_found < len(self.df):
            logger.warning(f"Only {n_found}/{len(self.df)} series paths exist")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        try:
            row = self.df.iloc[idx]
            series_id = str(row[Config.ID_COL])
            label = float(row[Config.TARGET_COL])

            series_path = os.path.join(self.series_dir, series_id)
            volume = self.processor.load_dicom_series(series_path)

            if self.augmentation is not None and self.mode == 'train':
                volume = self.augmentation(volume)

            volume_tensor = torch.from_numpy(volume).float().unsqueeze(0)  # add channel axis
            label_tensor = torch.tensor(label, dtype=torch.float32)

            return {
                'volume': volume_tensor,
                'label': label_tensor,
                'series_id': series_id,
                'idx': idx
            }

        except Exception as e:
            # NOTE(review): the fallback sample is labelled 0.0 regardless of the
            # row's true label.
            logger.error(f"Error loading sample {idx}: {e}")
            return {
                'volume': torch.zeros((1, *Config.TARGET_SIZE), dtype=torch.float32),
                'label': torch.tensor(0.0, dtype=torch.float32),
                'series_id': f"DUMMY_{idx}",
                'idx': idx
            }

class EfficientAneurysmNet(nn.Module):
    """Compact 3-D CNN mapping a volume to ``num_classes`` raw logits.

    Input is a 5-D tensor (batch, in_channels, depth, height, width). The
    backbone ends with global average pooling, so the classifier always sees
    128 features.
    """

    def __init__(self, in_channels=1, num_classes=1, dropout_rate=0.3):
        super().__init__()

        def conv_bn_relu(c_in, c_out, **conv_kwargs):
            # One Conv3d -> BatchNorm3d -> ReLU triple, returned as a list so the
            # Sequential indices (and therefore state_dict keys) stay flat.
            return [nn.Conv3d(c_in, c_out, **conv_kwargs), nn.BatchNorm3d(c_out), nn.ReLU(inplace=True)]

        same_3x3 = dict(kernel_size=3, stride=1, padding=1)

        stem = conv_bn_relu(in_channels, 32, kernel_size=7, stride=2, padding=3)
        stem.append(nn.MaxPool3d(kernel_size=3, stride=2, padding=1))

        stage1 = conv_bn_relu(32, 64, **same_3x3) + conv_bn_relu(64, 64, **same_3x3)
        stage1.append(nn.MaxPool3d(kernel_size=2, stride=2))

        stage2 = conv_bn_relu(64, 128, **same_3x3) + conv_bn_relu(128, 128, **same_3x3)
        stage2.append(nn.AdaptiveAvgPool3d((1, 1, 1)))

        self.features = nn.Sequential(*stem, *stage1, *stage2)

        hidden = 64
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(128, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(hidden, num_classes),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-init conv weights, unit/zero BatchNorm affine, small-normal Linear weights with zero bias."""
        for module in self.modules():
            if isinstance(module, nn.Conv3d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm3d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.01)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        pooled = self.features(x)
        return self.classifier(torch.flatten(pooled, 1))

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, pos_weight=None):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.pos_weight = pos_weight
    
    def forward(self, inputs, targets):
        bce_loss = nn.functional.binary_cross_entropy_with_logits(
            inputs.view(-1), targets, pos_weight=self.pos_weight, reduction='none'
        )
        pt = torch.exp(-bce_loss)
        focal_loss = self.alpha * (1-pt)**self.gamma * bce_loss
        return focal_loss.mean()

def create_optimized_data_loaders(train_df, val_df, processor):
    """Build the training and validation DataLoaders for one CV fold.

    The training loader samples with replacement through a
    WeightedRandomSampler. Each row's weight is inversely proportional to its
    class frequency, so classes are drawn about equally often. The validation
    loader iterates in order. Both share ``processor`` (and its cache).
    """
    augmentation = LightweightAugmentation(prob=Config.AUGMENTATION_PROB) if Config.USE_AUGMENTATION else None

    train_dataset = OptimizedAneurysmDataset(
        train_df, Config.SERIES_DIR, processor, augmentation, mode='train'
    )
    val_dataset = OptimizedAneurysmDataset(
        val_df, Config.SERIES_DIR, processor, mode='val'
    )

    # Read labels straight from the dataset's (index-reset) dataframe. Indexing
    # the dataset here would decode and augment every DICOM series just to get
    # its label.
    targets = train_dataset.df[Config.TARGET_COL].astype(float).tolist()
    class_counts = Counter(targets)
    weights = [len(targets) / (len(class_counts) * class_counts[target]) for target in targets]
    sampler = WeightedRandomSampler(weights, num_samples=len(weights))

    # Pinned host memory only helps CUDA transfers.
    pin_memory = torch.cuda.is_available() and not TPU_AVAILABLE

    train_loader = DataLoader(
        train_dataset,
        batch_size=Config.BATCH_SIZE,
        sampler=sampler,
        num_workers=2,
        pin_memory=pin_memory,
        persistent_workers=False,
        prefetch_factor=Config.PREFETCH_FACTOR
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=Config.BATCH_SIZE,
        shuffle=False,
        num_workers=2,
        pin_memory=pin_memory,
        persistent_workers=False
    )

    return train_loader, val_loader

def train_epoch_optimized(model, loader, optimizer, criterion, device, epoch):
    """Run one training epoch and return metrics over the batches that succeeded.

    Returns a dict with:
    - ``loss``: mean batch loss;
    - ``accuracy`` and ``f1``: at a 0.5 probability threshold;
    - ``auc``: computed from the sigmoid probabilities, or 0.5 when only one
      class was seen.

    Batches that raise are logged and skipped. If no batch succeeds, ``loss``
    is NaN.
    """
    model.train()
    running_loss = 0.0
    n_batches = 0
    all_preds = []
    all_probs = []
    all_labels = []

    progress_bar = tqdm(loader, desc=f"Epoch {epoch+1}", leave=False)

    for batch_idx, batch in enumerate(progress_bar):
        try:
            volume = batch['volume'].to(device, non_blocking=True)
            label = batch['label'].to(device, non_blocking=True)

            optimizer.zero_grad()
            outputs = model(volume)
            loss = criterion(outputs, label)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), Config.GRADIENT_CLIP)

            if TPU_AVAILABLE:
                # A plain DataLoader issues no XLA barrier, so let optimizer_step
                # cut the graph (mark_step) after every update.
                xm.optimizer_step(optimizer, barrier=True)
            else:
                optimizer.step()

            batch_loss = loss.item()
            with torch.no_grad():
                probs = torch.sigmoid(outputs.detach()).float().cpu().numpy().flatten()
            batch_labels = label.cpu().numpy().flatten()

            running_loss += batch_loss
            n_batches += 1
            all_probs.extend(probs)
            all_preds.extend((probs > 0.5).astype(int))
            all_labels.extend(batch_labels)

            if batch_idx % 10 == 0:
                progress_bar.set_postfix({
                    'Loss': f'{running_loss/n_batches:.4f}',
                    'Acc': f'{np.mean(np.array(all_preds) == np.array(all_labels)):.3f}'
                })

            if batch_idx % 20 == 0 and torch.cuda.is_available():
                torch.cuda.empty_cache()

        except Exception as e:
            logger.error(f"Error in batch {batch_idx}: {e}")
            continue

    if n_batches == 0:
        logger.warning("No training batch completed this epoch")
        return {'loss': float('nan'), 'accuracy': 0.0, 'f1': 0.0, 'auc': 0.5}

    metrics = {
        'loss': running_loss / n_batches,
        'accuracy': accuracy_score(all_labels, all_preds),
        'f1': f1_score(all_labels, all_preds, zero_division=0),
    }

    # AUC must be computed from probabilities, not thresholded predictions.
    if len(np.unique(all_labels)) > 1:
        metrics['auc'] = roc_auc_score(all_labels, all_probs)
    else:
        metrics['auc'] = 0.5

    return metrics

def validate_epoch_optimized(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns a dict with:
    - ``loss``: mean over the batches that ran;
    - ``accuracy`` and ``f1``: at a 0.5 probability threshold;
    - ``auc``: from the sigmoid probabilities, or 0.5 when only one class was
      seen.

    Failing batches are logged and skipped. If none succeed, ``loss`` is NaN.
    """
    model.eval()
    running_loss = 0.0
    n_batches = 0
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(loader, desc="Validation", leave=False)):
            try:
                volume = batch['volume'].to(device, non_blocking=True)
                label = batch['label'].to(device, non_blocking=True)

                outputs = model(volume)
                loss = criterion(outputs, label)

                batch_loss = loss.item()
                # .float() guards against reduced-precision outputs, which .numpy() rejects.
                probs = torch.sigmoid(outputs).float().cpu().numpy().flatten()
                batch_labels = label.cpu().numpy().flatten()

                # Only count the batch once every conversion has succeeded, so the
                # loss average and the metric lists stay consistent.
                running_loss += batch_loss
                n_batches += 1
                all_probs.extend(probs)
                all_preds.extend((probs > 0.5).astype(int))
                all_labels.extend(batch_labels)

            except Exception as e:
                logger.error(f"Error in validation batch {batch_idx}: {e}")
                continue

    if n_batches == 0:
        logger.warning("No validation batch completed")
        return {'loss': float('nan'), 'accuracy': 0.0, 'f1': 0.0, 'auc': 0.5}

    metrics = {
        'loss': running_loss / n_batches,
        'accuracy': accuracy_score(all_labels, all_preds),
        'f1': f1_score(all_labels, all_preds, zero_division=0),
    }

    if len(np.unique(all_labels)) > 1:
        metrics['auc'] = roc_auc_score(all_labels, all_probs)
    else:
        metrics['auc'] = 0.5

    return metrics

def _load_training_dataframe():
    """Read train.csv, keep rows with an ID and a 0/1 label, and subsample in debug mode.

    Returns ``None`` (after printing the error) if the CSV cannot be read.
    """
    try:
        train_df = pd.read_csv(Config.TRAIN_CSV_PATH)
        print(f"Loaded {len(train_df)} samples")
    except Exception as e:
        print(f"Error loading data: {e}")
        return None

    train_df = train_df.dropna(subset=[Config.ID_COL, Config.TARGET_COL])
    train_df = train_df[train_df[Config.TARGET_COL].isin([0, 1])].copy()
    # The label column may be float after dropna; use ints so class counts and
    # DataFrame.sample(n=...) always receive integers.
    train_df[Config.TARGET_COL] = train_df[Config.TARGET_COL].astype(int)

    if Config.DEBUG_MODE:
        # Balanced subsample: up to DEBUG_SAMPLES // 2 rows per class, positives first.
        per_class = Config.DEBUG_SAMPLES // 2
        parts = []
        for cls in (1, 0):
            subset = train_df[train_df[Config.TARGET_COL] == cls]
            parts.append(subset.sample(n=min(per_class, len(subset)), random_state=42))
        train_df = pd.concat(parts).reset_index(drop=True)
        print(f"Debug mode: using {len(train_df)} samples")

    return train_df


def _save_checkpoint(checkpoint, path):
    """Write ``checkpoint`` to ``path``; on XLA, ``xm.save`` moves device tensors to CPU first."""
    if TPU_AVAILABLE:
        xm.save(checkpoint, path)
    else:
        torch.save(checkpoint, path)


def _train_single_fold(fold, fold_train_df, fold_val_df, processor):
    """Train one CV fold with early stopping on validation AUC.

    The best checkpoint is saved to ``fold_{fold}_best_model.pth``. Returns
    ``{'fold': fold + 1, 'best_val_auc': ..., 'history': ...}``.
    """
    print(f"Train: {len(fold_train_df)}, Val: {len(fold_val_df)}")
    print(f"Train pos: {fold_train_df[Config.TARGET_COL].sum()}, Val pos: {fold_val_df[Config.TARGET_COL].sum()}")

    train_loader, val_loader = create_optimized_data_loaders(fold_train_df, fold_val_df, processor)

    model = EfficientAneurysmNet().to(Config.DEVICE)

    pos_count = int(fold_train_df[Config.TARGET_COL].sum())
    neg_count = len(fold_train_df) - pos_count
    # NOTE(review): the training loader already balances classes with a
    # WeightedRandomSampler; weighting positives again here may over-correct
    # outside debug mode (the debug subsample is already ~50/50).
    # An explicit float32 avoids a float64 weight leaking into the loss.
    pos_weight = torch.tensor([neg_count / max(pos_count, 1)], dtype=torch.float32).to(Config.DEVICE)

    criterion = FocalLoss(alpha=1, gamma=2, pos_weight=pos_weight)

    optimizer = optim.AdamW(
        model.parameters(),
        lr=Config.LEARNING_RATE,
        weight_decay=Config.WEIGHT_DECAY
    )

    # ``verbose`` is deprecated for PyTorch LR schedulers; reductions are
    # reported manually after each step instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=Config.SCHEDULER_PATIENCE, factor=0.5
    )

    best_val_auc = 0.0
    patience_counter = 0
    history = {'train_loss': [], 'val_loss': [], 'train_auc': [], 'val_auc': []}

    for epoch in range(Config.EPOCHS):
        print(f"\nEpoch {epoch+1}/{Config.EPOCHS}")

        train_metrics = train_epoch_optimized(model, train_loader, optimizer, criterion, Config.DEVICE, epoch)
        val_metrics = validate_epoch_optimized(model, val_loader, criterion, Config.DEVICE)

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_metrics['loss'])
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

        history['train_loss'].append(train_metrics['loss'])
        history['val_loss'].append(val_metrics['loss'])
        history['train_auc'].append(train_metrics['auc'])
        history['val_auc'].append(val_metrics['auc'])

        print(f"Train - Loss: {train_metrics['loss']:.4f}, AUC: {train_metrics['auc']:.4f}, Acc: {train_metrics['accuracy']:.4f}")
        print(f"Val   - Loss: {val_metrics['loss']:.4f}, AUC: {val_metrics['auc']:.4f}, Acc: {val_metrics['accuracy']:.4f}")

        if val_metrics['auc'] > best_val_auc:
            best_val_auc = val_metrics['auc']
            patience_counter = 0

            _save_checkpoint({
                'fold': fold,
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_auc': best_val_auc,
                'history': history
            }, f'fold_{fold}_best_model.pth')

            print(f"New best AUC: {best_val_auc:.4f} - Model saved!")
        else:
            patience_counter += 1

        if patience_counter >= Config.EARLY_STOPPING_PATIENCE:
            print(f"Early stopping at epoch {epoch+1}")
            break

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    print(f"Fold {fold + 1} completed! Best AUC: {best_val_auc:.4f}")
    return {'fold': fold + 1, 'best_val_auc': best_val_auc, 'history': history}


def optimized_training_pipeline():
    """Run stratified K-fold training and write a summary to ``optimized_results.json``.

    Only fold ``Config.CURRENT_FOLD`` is trained when it is >= 0. Returns the
    list of per-fold result dicts, or ``None`` if the training CSV could not
    be loaded.
    """
    print("OPTIMIZED ANEURYSM DETECTION TRAINING")
    print("=" * 60)
    print(f"Device: {Config.DEVICE}")
    print(f"Debug Mode: {Config.DEBUG_MODE}")
    print(f"Target Size: {Config.TARGET_SIZE}")
    print(f"Batch Size: {Config.BATCH_SIZE}")

    train_df = _load_training_dataframe()
    if train_df is None:
        return None

    print(f"Class distribution: {train_df[Config.TARGET_COL].value_counts().sort_index().to_dict()}")

    skf = StratifiedKFold(n_splits=Config.N_FOLDS, shuffle=True, random_state=42)
    fold_results = []

    processor = OptimizedDICOMProcessor()

    for fold, (train_idx, val_idx) in enumerate(skf.split(train_df, train_df[Config.TARGET_COL])):
        if Config.CURRENT_FOLD >= 0 and fold != Config.CURRENT_FOLD:
            continue

        print(f"\nFOLD {fold + 1}/{Config.N_FOLDS}")
        print("-" * 40)

        fold_train_df = train_df.iloc[train_idx].reset_index(drop=True)
        fold_val_df = train_df.iloc[val_idx].reset_index(drop=True)

        fold_results.append(_train_single_fold(fold, fold_train_df, fold_val_df, processor))

        processor.clear_cache()

        if Config.CURRENT_FOLD >= 0:
            break

    if fold_results:
        aucs = [result['best_val_auc'] for result in fold_results]
        print(f"\nCROSS-VALIDATION SUMMARY")
        print("=" * 40)
        print(f"Mean AUC: {np.mean(aucs):.4f} ± {np.std(aucs):.4f}")
        print(f"Best AUC: {max(aucs):.4f}")

        with open('optimized_results.json', 'w') as f:
            json.dump(fold_results, f, indent=2, default=str)

    return fold_results

Using TPU device: xla:0


In [None]:
if __name__ == "__main__":
    # Entry point: run the pipeline and report whether it produced any fold results.
    results = optimized_training_pipeline()

    if not results:
        print("\nTraining failed.")
    else:
        print("\nTraining completed successfully!")
        print("Check optimized_results.json for detailed results")

**Model Evaluation and Performance**

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,
    roc_curve, precision_recall_curve, confusion_matrix, classification_report,
    average_precision_score, balanced_accuracy_score, matthews_corrcoef,
    cohen_kappa_score, log_loss, brier_score_loss
)
from sklearn.calibration import calibration_curve
from scipy import stats
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.figure_factory as ff
from tqdm.auto import tqdm
import warnings
import json
from datetime import datetime
from collections import defaultdict
import itertools
from scipy.stats import bootstrap
warnings.simplefilter('ignore')  # silence every warning category for the evaluation cells

class AdvancedModelEvaluator:
    def __init__(self, model_paths, data_loader, device, class_names=None, save_dir="evaluation_results"):
        self.model_paths = model_paths if isinstance(model_paths, list) else [model_paths]
        self.data_loader = data_loader
        self.device = device
        self.class_names = class_names or ['No Aneurysm', 'Aneurysm']
        self.save_dir = save_dir
        
        # Create save directory
        os.makedirs(save_dir, exist_ok=True)
        
        # Storage for results
        self.results = {}
        self.predictions = {}
        self.raw_outputs = {}
        
    def load_model_predictions(self, model_class):
        """Load each checkpoint in ``self.model_paths`` and run inference.

        For every checkpoint file that exists, a fresh ``model_class()`` is
        built, its weights are restored, and the model is run over
        ``self.data_loader`` in eval mode without gradients. Results are stored
        in ``self.predictions`` under ``"Model_<k>"`` (or ``"Model"`` when only
        one path was given). Each entry holds flat numpy arrays:
        ``predictions`` (0/1 at probability > 0.5), ``probabilities`` (sigmoid
        of the outputs), ``labels`` and raw ``logits``. It also holds the
        ``series_ids`` list and the ``model_path``.

        Args:
            model_class: zero-argument callable returning an ``nn.Module``.

        Missing checkpoint files are reported and skipped.
        """
        print("Loading models and generating predictions...")

        for i, model_path in enumerate(self.model_paths):
            print(f"\nProcessing model: {model_path}")

            # Check the file first so a missing checkpoint does not allocate a
            # full model on the device for nothing.
            if not os.path.exists(model_path):
                print(f"Model file not found: {model_path}")
                continue

            model = model_class().to(self.device)
            checkpoint = torch.load(model_path, map_location=self.device)
            # Accept both the training-loop format ({'model_state_dict': ...})
            # and a bare state dict saved via torch.save(model.state_dict()).
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
                print(f"Loaded model from epoch {checkpoint.get('epoch', 'unknown')}")
            else:
                model.load_state_dict(checkpoint)
                print("Loaded model from epoch unknown")

            model.eval()
            predictions, probabilities, labels, logits, series_ids = [], [], [], [], []
            seen = 0  # running sample count, keeps fallback IDs unique across batches

            with torch.no_grad():
                for batch in tqdm(self.data_loader, desc=f"Model {i+1} Inference"):
                    volume = batch['volume'].to(self.device)
                    # Flatten so (B,) and (B, 1) labels both give a 1-D array
                    # aligned with the flattened probabilities below.
                    label = torch.as_tensor(batch['label']).cpu().numpy().reshape(-1)
                    batch_size = len(label)

                    series_id = batch.get('series_id')
                    if series_id is None:
                        series_id = [f'sample_{seen + j}' for j in range(batch_size)]
                    seen += batch_size

                    # Cast to float32 on the host: numpy has no bfloat16 dtype.
                    outputs = model(volume).float().cpu()
                    probs = torch.sigmoid(outputs).numpy().reshape(-1)
                    preds = (probs > 0.5).astype(int)

                    predictions.extend(preds)
                    probabilities.extend(probs)
                    labels.extend(label)
                    logits.extend(outputs.numpy().reshape(-1))
                    series_ids.extend(series_id)

            model_name = f"Model_{i+1}" if len(self.model_paths) > 1 else "Model"

            self.predictions[model_name] = {
                'predictions': np.array(predictions),
                'probabilities': np.array(probabilities),
                'labels': np.array(labels),
                'logits': np.array(logits),
                'series_ids': series_ids,
                'model_path': model_path
            }

            # Release this model before the next checkpoint is instantiated.
            del model, checkpoint
    
    def calculate_comprehensive_metrics(self):
        """Compute a flat metrics dict per model and store it in ``self.results``.

        Works on the arrays gathered by ``load_model_predictions``. Each
        model's dict merges, in this order:
        - hard-prediction metrics at the stored 0.5 cut-off;
        - probability-based metrics;
        - the threshold sweep;
        - calibration/confidence metrics;
        - confusion-matrix metrics;
        - ensemble metrics shared by all models, when more than one model was
          loaded.
        Later groups overwrite earlier keys of the same name (e.g.
        ``specificity``).
        """
        print("\nCalculating comprehensive metrics...")

        # Ensemble metrics depend on all models jointly, not on the model
        # being scored, so compute them once instead of once per model.
        ensemble_metrics = self._calculate_ensemble_metrics() if len(self.predictions) > 1 else {}

        for model_name, data in self.predictions.items():
            y_true = data['labels']
            y_pred = data['predictions']
            y_proba = data['probabilities']

            # Basic classification metrics (hard predictions)
            basic_metrics = {
                'accuracy': accuracy_score(y_true, y_pred),
                'balanced_accuracy': balanced_accuracy_score(y_true, y_pred),
                'precision': precision_score(y_true, y_pred, zero_division=0),
                'recall': recall_score(y_true, y_pred, zero_division=0),
                # Specificity = recall of the negative class.
                'specificity': recall_score(1 - y_true, 1 - y_pred, zero_division=0),
                'f1_score': f1_score(y_true, y_pred, zero_division=0),
                'matthews_corr': matthews_corrcoef(y_true, y_pred),
                'cohen_kappa': cohen_kappa_score(y_true, y_pred),
            }

            # Probabilistic metrics
            if len(np.unique(y_true)) > 1:
                prob_metrics = {
                    'roc_auc': roc_auc_score(y_true, y_proba),
                    'pr_auc': average_precision_score(y_true, y_proba),
                    'log_loss': log_loss(y_true, y_proba),
                    'brier_score': brier_score_loss(y_true, y_proba),
                }
            else:
                # Ranking metrics (ROC/PR AUC) are undefined with one class, so
                # keep placeholder values. Log loss and the Brier score are
                # still well defined and are computed instead of set to inf.
                prob_metrics = {
                    'roc_auc': 0.5,
                    'pr_auc': np.mean(y_true),
                    'log_loss': log_loss(y_true, y_proba, labels=[0, 1]),
                    'brier_score': float(np.mean((y_proba - y_true) ** 2)),
                }

            threshold_metrics = self._calculate_threshold_metrics(y_true, y_proba)
            calibration_metrics = self._calculate_calibration_metrics(y_true, y_proba)
            class_metrics = self._calculate_class_wise_metrics(y_true, y_pred, y_proba)

            self.results[model_name] = {
                **basic_metrics,
                **prob_metrics,
                **threshold_metrics,
                **calibration_metrics,
                **class_metrics,
                **ensemble_metrics,
            }
    
    def _calculate_threshold_metrics(self, y_true, y_proba):
        """Sweep fixed probability cut-offs and pick a Youden-optimal one.

        For each threshold t in 0.1, 0.2, ..., 0.9, predictions are
        ``y_proba >= t``. F1/precision/recall are stored under
        ``f1_thresh_<t>``, ``precision_thresh_<t>`` and ``recall_thresh_<t>``,
        but only when those predictions contain both classes. When ``y_true``
        has both classes:
        - ``optimal_threshold`` is the ROC threshold maximizing TPR - FPR
          (Youden's J);
        - ``optimal_f1`` is the F1 at that threshold, not the best achievable
          F1.
        """
        threshold_results = {}

        for threshold in np.arange(0.1, 1.0, 0.1):
            y_pred_thresh = (y_proba >= threshold).astype(int)

            if len(np.unique(y_pred_thresh)) > 1:
                key = f'{threshold:.1f}'
                threshold_results[f'f1_thresh_{key}'] = f1_score(y_true, y_pred_thresh, zero_division=0)
                threshold_results[f'precision_thresh_{key}'] = precision_score(y_true, y_pred_thresh, zero_division=0)
                threshold_results[f'recall_thresh_{key}'] = recall_score(y_true, y_pred_thresh, zero_division=0)

        if len(np.unique(y_true)) > 1:
            fpr, tpr, thresholds_roc = roc_curve(y_true, y_proba)
            # roc_curve prepends a sentinel threshold (inf, or max score + 1 in
            # older scikit-learn) that predicts nothing positive. Skip it so
            # the reported cut-off is always a real, finite score value.
            youden = tpr[1:] - fpr[1:]
            optimal_threshold = float(thresholds_roc[1:][np.argmax(youden)])

            threshold_results['optimal_threshold'] = optimal_threshold
            threshold_results['optimal_f1'] = f1_score(
                y_true, (y_proba >= optimal_threshold).astype(int), zero_division=0
            )

        return threshold_results
    
    def _calculate_calibration_metrics(self, y_true, y_proba):
        # Calibration curve
        if len(np.unique(y_true)) > 1:
            prob_true, prob_pred = calibration_curve(y_true, y_proba, n_bins=10)
            calibration_error = np.mean(np.abs(prob_true - prob_pred))
            
            # Expected Calibration Error (ECE)
            bin_boundaries = np.linspace(0, 1, 11)
            bin_lowers = bin_boundaries[:-1]
            bin_uppers = bin_boundaries[1:]
            
            ece = 0
            for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
                in_bin = (y_proba > bin_lower) & (y_proba <= bin_upper)
                prop_in_bin = in_bin.mean()
                
                if prop_in_bin > 0:
                    accuracy_in_bin = y_true[in_bin].mean()
                    avg_confidence_in_bin = y_proba[in_bin].mean()
                    ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
        else:
            calibration_error = float('inf')
            ece = float('inf')
        
        # Confidence metrics
        confidence_metrics = {
            'mean_confidence': np.mean(y_proba),
            'std_confidence': np.std(y_proba),
            'confidence_range': np.ptp(y_proba),
            'calibration_error': calibration_error,
            'expected_calibration_error': ece,
        }
        
        return confidence_metrics
    
    def _calculate_class_wise_metrics(self, y_true, y_pred, y_proba):
        cm = confusion_matrix(y_true, y_pred)
        
        if cm.shape == (2, 2):
            tn, fp, fn, tp = cm.ravel()
            
            class_metrics = {
                'true_positives': tp,
                'true_negatives': tn,
                'false_positives': fp,
                'false_negatives': fn,
                'sensitivity': tp / (tp + fn) if (tp + fn) > 0 else 0,
                'specificity': tn / (tn + fp) if (tn + fp) > 0 else 0,
                'positive_predictive_value': tp / (tp + fp) if (tp + fp) > 0 else 0,
                'negative_predictive_value': tn / (tn + fn) if (tn + fn) > 0 else 0,
                'false_positive_rate': fp / (fp + tn) if (fp + tn) > 0 else 0,
                'false_negative_rate': fn / (fn + tp) if (fn + tp) > 0 else 0,
                'positive_likelihood_ratio': (tp/(tp+fn)) / (fp/(fp+tn)) if fp > 0 else float('inf'),
                'negative_likelihood_ratio': (fn/(fn+tp)) / (tn/(tn+fp)) if tn > 0 else float('inf'),
            }
        else:
            class_metrics = {}
        
        return class_metrics
    
    def _calculate_ensemble_metrics(self):
        """Score probability-averaging and majority-vote ensembles.

        Returns an empty dict when fewer than two models are loaded, or when
        the models' label vectors differ. Differing labels mean their rows are
        not aligned, e.g. because the data loader shuffles.
        """
        if len(self.predictions) < 2:
            return {}

        runs = list(self.predictions.values())
        y_true = runs[0]['labels']

        # Row-wise averaging is only meaningful if every model scored the same
        # samples in the same order; identical label vectors are a cheap check.
        if any(not np.array_equal(run['labels'], y_true) for run in runs[1:]):
            print("Warning: model predictions are not aligned; skipping ensemble metrics.")
            return {}

        all_probas = np.stack([run['probabilities'] for run in runs])
        all_preds = np.stack([run['predictions'] for run in runs])

        # Ensemble by averaging probabilities
        ensemble_proba = all_probas.mean(axis=0)
        ensemble_pred = (ensemble_proba > 0.5).astype(int)

        # Ensemble by majority voting. np.round rounds half to even, so a tied
        # vote (even number of models) resolves to the negative class.
        majority_pred = np.round(all_preds.mean(axis=0)).astype(int)

        ensemble_metrics = {
            'ensemble_accuracy': accuracy_score(y_true, ensemble_pred),
            'ensemble_f1': f1_score(y_true, ensemble_pred, zero_division=0),
            'ensemble_auc': roc_auc_score(y_true, ensemble_proba) if len(np.unique(y_true)) > 1 else 0.5,
            'majority_vote_accuracy': accuracy_score(y_true, majority_pred),
            'majority_vote_f1': f1_score(y_true, majority_pred, zero_division=0),
        }

        # Fraction of samples on which every model made the same hard call.
        ensemble_metrics['model_agreement'] = np.mean(np.all(all_preds == all_preds[0], axis=0))

        return ensemble_metrics
    
    def create_comprehensive_visualizations(self):
        """Render every evaluation figure into ``self.save_dir``.

        Each ``_create_*`` helper draws and saves its own output from
        ``self.predictions`` / ``self.results``. Run
        ``load_model_predictions`` and ``calculate_comprehensive_metrics``
        first.
        """
        print("\nCreating comprehensive visualizations...")

        # Matplotlib 3.6 renamed the bundled seaborn style to 'seaborn-v0_8'.
        # Use whichever name this installation has so older versions don't
        # raise; if neither exists, keep the current style.
        for style_name in ('seaborn-v0_8', 'seaborn'):
            if style_name in plt.style.available:
                plt.style.use(style_name)
                break

        # 1. Performance Overview Dashboard
        self._create_performance_dashboard()

        # 2. ROC Curves Comparison
        self._create_roc_comparison()

        # 3. Precision-Recall Curves
        self._create_pr_comparison()

        # 4. Calibration Plots
        self._create_calibration_plots()

        # 5. Confusion Matrices
        self._create_confusion_matrices()

        # 6. Threshold Analysis
        self._create_threshold_analysis()

        # 7. Prediction Distribution Analysis
        self._create_prediction_distributions()

        # 8. Error Analysis
        self._create_error_analysis()

        # 9. Interactive Plotly Visualizations
        self._create_interactive_plots()

        # 10. Statistical Significance Tests
        self._create_statistical_tests()
    
    def _create_performance_dashboard(self):
        """Save bar charts comparing headline metrics across models.

        There is one panel per metric, four per row. Each bar is a model from
        ``self.results`` and is annotated with its value. Output:
        ``<save_dir>/performance_dashboard.png``.
        """
        # Palette defined here (and cycled) so the method is self-contained
        # and supports any number of models.
        colors = ['#2E8B57', '#DC143C', '#4169E1', '#FFD700', '#8A2BE2']
        metrics_to_plot = ['accuracy', 'f1_score', 'roc_auc', 'pr_auc',
                           'precision', 'recall', 'specificity', 'matthews_corr']
        n_cols = 4
        # Size the grid to the metric list rather than a fixed 3x4 grid whose
        # last row always ended up empty.
        n_rows = (len(metrics_to_plot) + n_cols - 1) // n_cols

        fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 5 * n_rows), squeeze=False)
        fig.suptitle('Model Performance Dashboard', fontsize=16, fontweight='bold')

        model_names = list(self.results.keys())
        bar_colors = [colors[k % len(colors)] for k in range(len(model_names))]

        for i, metric in enumerate(metrics_to_plot):
            ax = axes[i // n_cols, i % n_cols]
            values = [self.results[model][metric] for model in model_names]

            bars = ax.bar(model_names, values, color=bar_colors)
            ax.set_title(f'{metric.replace("_", " ").title()}', fontweight='bold')
            ax.set_ylabel('Score')
            # Matthews correlation spans [-1, 1]: extend below zero when
            # needed, and leave headroom above 1 for the value labels.
            finite_values = [v for v in values if np.isfinite(v)]
            ax.set_ylim(min(0.0, min(finite_values, default=0.0)), 1.1)

            for bar, value in zip(bars, values):
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                        f'{value:.3f}', ha='center', va='bottom')

            ax.tick_params(axis='x', rotation=45)

        # Remove unused panels
        for j in range(len(metrics_to_plot), n_rows * n_cols):
            axes[j // n_cols, j % n_cols].remove()

        plt.tight_layout()
        plt.savefig(f'{self.save_dir}/performance_dashboard.png', dpi=300, bbox_inches='tight')
        plt.close(fig)
    
    def _create_roc_comparison(self):
        """Overlay the ROC curves of all models and save ``roc_comparison.png``.

        Models whose labels contain a single class are skipped, since ROC is
        undefined for them.
        """
        # Palette defined here (and cycled) so the method is self-contained
        # and supports any number of models.
        colors = ['#2E8B57', '#DC143C', '#4169E1', '#FFD700', '#8A2BE2']
        plt.figure(figsize=(10, 8))

        for i, (model_name, data) in enumerate(self.predictions.items()):
            y_true = data['labels']
            y_proba = data['probabilities']
            if len(np.unique(y_true)) < 2:
                continue

            fpr, tpr, _ = roc_curve(y_true, y_proba)
            auc = roc_auc_score(y_true, y_proba)
            plt.plot(fpr, tpr, color=colors[i % len(colors)], lw=2,
                     label=f'{model_name} (AUC = {auc:.3f})')

        # Chance diagonal
        plt.plot([0, 1], [0, 1], color='gray', lw=2, linestyle='--', alpha=0.8)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate', fontsize=12)
        plt.ylabel('True Positive Rate', fontsize=12)
        plt.title('ROC Curve Comparison', fontsize=14, fontweight='bold')
        plt.legend(loc="lower right")
        plt.grid(True, alpha=0.3)

        plt.savefig(f'{self.save_dir}/roc_comparison.png', dpi=300, bbox_inches='tight')
        plt.close()
    
    def _create_pr_comparison(self):
        """Overlay the precision-recall curves of all models and save ``pr_comparison.png``.

        The dashed baseline is the positive-class prevalence of the first
        model's labels, i.e. the AP of a random classifier.
        """
        # Palette defined here (and cycled) so the method is self-contained
        # and supports any number of models.
        colors = ['#2E8B57', '#DC143C', '#4169E1', '#FFD700', '#8A2BE2']
        plt.figure(figsize=(10, 8))

        for i, (model_name, data) in enumerate(self.predictions.items()):
            y_true = data['labels']
            y_proba = data['probabilities']
            if len(np.unique(y_true)) < 2:
                continue

            precision, recall, _ = precision_recall_curve(y_true, y_proba)
            avg_precision = average_precision_score(y_true, y_proba)
            plt.plot(recall, precision, color=colors[i % len(colors)], lw=2,
                     label=f'{model_name} (AP = {avg_precision:.3f})')

        # Baseline: prevalence of the first model's labels (skipped when no
        # model was loaded, which previously raised an IndexError).
        if self.predictions:
            baseline = np.mean(next(iter(self.predictions.values()))['labels'])
            plt.axhline(y=baseline, color='gray', linestyle='--', alpha=0.8,
                        label=f'Baseline (AP = {baseline:.3f})')

        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('Recall', fontsize=12)
        plt.ylabel('Precision', fontsize=12)
        plt.title('Precision-Recall Curve Comparison', fontsize=14, fontweight='bold')
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.savefig(f'{self.save_dir}/pr_comparison.png', dpi=300, bbox_inches='tight')
        plt.close()
    
    def _create_calibration_plots(self):
        """Save a reliability diagram and a confidence histogram per model.

        Left panel: ``calibration_curve`` with 10 bins for every model whose
        labels contain both classes. Right panel: density histogram of
        predicted probabilities. Output: ``<save_dir>/calibration_plots.png``.
        """
        # Palette defined here (and cycled) so the method is self-contained
        # and supports any number of models.
        colors = ['#2E8B57', '#DC143C', '#4169E1', '#FFD700', '#8A2BE2']
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

        # Calibration curve
        for i, (model_name, data) in enumerate(self.predictions.items()):
            y_true = data['labels']
            y_proba = data['probabilities']
            if len(np.unique(y_true)) < 2:
                continue

            prob_true, prob_pred = calibration_curve(y_true, y_proba, n_bins=10)
            ax1.plot(prob_pred, prob_true, marker='o', color=colors[i % len(colors)],
                     label=model_name, linewidth=2, markersize=6)

        ax1.plot([0, 1], [0, 1], 'k--', label='Perfect Calibration')
        ax1.set_xlabel('Mean Predicted Probability')
        ax1.set_ylabel('Fraction of Positives')
        ax1.set_title('Calibration Plot (Reliability Diagram)')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Confidence histogram
        for i, (model_name, data) in enumerate(self.predictions.items()):
            ax2.hist(data['probabilities'], bins=20, alpha=0.7, color=colors[i % len(colors)],
                     label=model_name, density=True)

        ax2.set_xlabel('Predicted Probability')
        ax2.set_ylabel('Density')
        ax2.set_title('Confidence Distribution')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(f'{self.save_dir}/calibration_plots.png', dpi=300, bbox_inches='tight')
        plt.close(fig)
    
    def _create_confusion_matrices(self):
        """Save row-normalized confusion-matrix heatmaps, one panel per model.

        Cells show the fraction of each true class. The raw count is written
        in red below each fraction. Panels are laid out up to three per row.
        Output: ``<save_dir>/confusion_matrices.png``. Does nothing when no
        predictions are loaded.
        """
        n_models = len(self.predictions)
        if n_models == 0:
            return

        cols = min(3, n_models)
        rows = (n_models + cols - 1) // cols

        # squeeze=False always yields a 2-D array, so one ravel() handles the
        # 1-model, 1-row and multi-row layouts uniformly.
        fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows), squeeze=False)
        axes = axes.ravel()

        # Fix the label set so the matrix is always square and matches
        # class_names, even if a class is absent from labels and predictions.
        label_ids = list(range(len(self.class_names)))

        for ax, (model_name, data) in zip(axes, self.predictions.items()):
            cm = confusion_matrix(data['labels'], data['predictions'], labels=label_ids)

            # Normalize rows; a class with no true samples gets a zero row
            # instead of NaN from 0/0.
            row_totals = cm.sum(axis=1, keepdims=True)
            cm_normalized = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float),
                                      where=row_totals > 0)

            sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues',
                        xticklabels=self.class_names, yticklabels=self.class_names,
                        ax=ax, cbar_kws={'shrink': 0.8})

            ax.set_title(f'{model_name}\nNormalized Confusion Matrix')
            ax.set_ylabel('True Label')
            ax.set_xlabel('Predicted Label')

            # Add raw counts beneath the normalized values
            for j in range(cm.shape[0]):
                for k in range(cm.shape[1]):
                    ax.text(k + 0.5, j + 0.7, f'({cm[j, k]})',
                            ha='center', va='center', fontsize=10, color='red')

        # Remove unused panels
        for ax in axes[n_models:]:
            fig.delaxes(ax)

        plt.tight_layout()
        plt.savefig(f'{self.save_dir}/confusion_matrices.png', dpi=300, bbox_inches='tight')
        plt.close(fig)
    
    def _create_threshold_analysis(self):
        """Save a 2x2 figure describing how metrics vary with the cut-off.

        Panels:
        - F1 vs threshold, with the best F1 marked by a star;
        - precision/recall vs threshold;
        - predicted-probability histograms split by true class;
        - sensitivity/specificity vs threshold.
        Thresholds run from 0.10 to 0.99 in steps of 0.01, and predictions
        are ``y_proba >= t``. Output: ``<save_dir>/threshold_analysis.png``.
        """
        # Palette defined here (and cycled) so the method is self-contained
        # and supports any number of models.
        colors = ['#2E8B57', '#DC143C', '#4169E1', '#FFD700', '#8A2BE2']
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        ax1, ax2, ax3, ax4 = axes[0, 0], axes[0, 1], axes[1, 0], axes[1, 1]
        thresholds = np.arange(0.1, 1.0, 0.01)

        for i, (model_name, data) in enumerate(self.predictions.items()):
            color = colors[i % len(colors)]
            y_true = data['labels']
            y_proba = data['probabilities']

            # A single sweep per model feeds every panel. Sensitivity is the
            # positive-class recall, so the recall curve is reused for it.
            f1_scores, precisions, recalls, specificities = [], [], [], []
            for threshold in thresholds:
                y_pred_thresh = (y_proba >= threshold).astype(int)
                f1_scores.append(f1_score(y_true, y_pred_thresh, zero_division=0))
                precisions.append(precision_score(y_true, y_pred_thresh, zero_division=0))
                recalls.append(recall_score(y_true, y_pred_thresh, zero_division=0))
                specificities.append(recall_score(1 - y_true, 1 - y_pred_thresh, zero_division=0))

            # F1 vs threshold, with the F1-optimal point starred
            ax1.plot(thresholds, f1_scores, color=color, label=model_name, linewidth=2)
            optimal_idx = int(np.argmax(f1_scores))
            ax1.scatter(thresholds[optimal_idx], f1_scores[optimal_idx],
                        color=color, s=100, marker='*', zorder=5)

            # Precision / recall
            ax2.plot(thresholds, precisions, color=color, linestyle='-',
                     label=f'{model_name} Precision', linewidth=2)
            ax2.plot(thresholds, recalls, color=color, linestyle='--',
                     label=f'{model_name} Recall', linewidth=2)

            # Probability distribution by true class. Skip an empty class:
            # a density histogram of zero samples is 0/0.
            neg_probs = y_proba[y_true == 0]
            pos_probs = y_proba[y_true == 1]
            if neg_probs.size:
                ax3.hist(neg_probs, bins=20, alpha=0.5, color=color,
                         label=f'{model_name} Negative', density=True)
            if pos_probs.size:
                ax3.hist(pos_probs, bins=20, alpha=0.5, color=color,
                         label=f'{model_name} Positive', density=True, hatch='///')

            # Sensitivity / specificity
            ax4.plot(thresholds, recalls, color=color, linestyle='-',
                     label=f'{model_name} Sensitivity', linewidth=2)
            ax4.plot(thresholds, specificities, color=color, linestyle='--',
                     label=f'{model_name} Specificity', linewidth=2)

        ax1.set_xlabel('Threshold')
        ax1.set_ylabel('F1 Score')
        ax1.set_title('F1 Score vs Threshold')

        ax2.set_xlabel('Threshold')
        ax2.set_ylabel('Score')
        ax2.set_title('Precision/Recall vs Threshold')

        ax3.set_xlabel('Predicted Probability')
        ax3.set_ylabel('Density')
        ax3.set_title('Probability Distribution by True Class')

        ax4.set_xlabel('Threshold')
        ax4.set_ylabel('Score')
        ax4.set_title('Sensitivity/Specificity vs Threshold')

        for ax in (ax1, ax2, ax3, ax4):
            ax.legend()
            ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(f'{self.save_dir}/threshold_analysis.png', dpi=300, bbox_inches='tight')
        plt.close(fig)
    
    def _create_prediction_distributions(self):
        """Save a 2x2 figure describing predicted-probability distributions.

        Panels:
        - box plots per (model, true class);
        - violin plots of the same groups;
        - accuracy of hard predictions within ten equal-width probability bins;
        - grouped bars of TP/TN/FP/FN counts per model.
        Output: ``<save_dir>/prediction_distributions.png``.
        """
        # Palette defined here (and cycled) so the method is self-contained
        # and supports any number of models.
        colors = ['#2E8B57', '#DC143C', '#4169E1', '#FFD700', '#8A2BE2']
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        n_models = len(self.predictions)

        # One group of probabilities per (model, true class)
        box_data = []
        box_labels = []
        for model_name, data in self.predictions.items():
            y_true = data['labels']
            y_proba = data['probabilities']
            for class_idx, class_name in enumerate(self.class_names):
                box_data.append(y_proba[y_true == class_idx])
                box_labels.append(f'{model_name}\n{class_name}')
        positions = np.arange(1, len(box_data) + 1)

        # Box plots. Tick labels are set separately because boxplot's
        # `labels=` keyword was renamed to `tick_labels` in Matplotlib 3.9.
        ax1 = axes[0, 0]
        if box_data:
            ax1.boxplot(box_data, positions=positions)
            ax1.set_xticks(positions)
            ax1.set_xticklabels(box_labels)
        ax1.set_title('Prediction Distribution by Model and True Class')
        ax1.set_ylabel('Predicted Probability')
        ax1.tick_params(axis='x', rotation=45)
        ax1.grid(True, alpha=0.3)

        # Violin plots: violinplot cannot handle an empty group (a class
        # absent from the labels), so draw only the non-empty ones.
        ax2 = axes[0, 1]
        filled = [k for k, group in enumerate(box_data) if len(group) > 0]
        if filled:
            ax2.violinplot([box_data[k] for k in filled], positions=positions[filled],
                           showmeans=True, showmedians=True)
        ax2.set_xticks(positions)
        ax2.set_xticklabels(box_labels, rotation=45)
        ax2.set_title('Prediction Distribution (Violin Plot)')
        ax2.set_ylabel('Predicted Probability')
        ax2.grid(True, alpha=0.3)

        # Prediction confidence vs accuracy
        ax3 = axes[1, 0]
        confidence_bins = np.linspace(0, 1, 11)
        last_bin = len(confidence_bins) - 2
        for i, (model_name, data) in enumerate(self.predictions.items()):
            y_true = data['labels']
            y_pred = data['predictions']
            y_proba = data['probabilities']

            bin_accuracies = []
            bin_centers = []
            for j in range(len(confidence_bins) - 1):
                lo, hi = confidence_bins[j], confidence_bins[j + 1]
                # Bins are [lo, hi); close the last one so p == 1.0 counts.
                below_hi = (y_proba <= hi) if j == last_bin else (y_proba < hi)
                mask = (y_proba >= lo) & below_hi
                if mask.sum() > 0:
                    bin_accuracies.append((y_pred[mask] == y_true[mask]).mean())
                    bin_centers.append((lo + hi) / 2)

            ax3.plot(bin_centers, bin_accuracies, 'o-', color=colors[i % len(colors)],
                     label=model_name, linewidth=2, markersize=6)

        ax3.set_xlabel('Prediction Confidence (Binned)')
        ax3.set_ylabel('Accuracy')
        ax3.set_title('Confidence vs Accuracy')
        ax3.legend()
        ax3.grid(True, alpha=0.3)

        # Error distribution: bars grouped by outcome type and centered on
        # the tick for any number of models.
        ax4 = axes[1, 1]
        categories = ['True Pos', 'True Neg', 'False Pos', 'False Neg']
        group_x = np.arange(len(categories))
        bar_width = 0.8 / max(n_models, 1)
        for i, (model_name, data) in enumerate(self.predictions.items()):
            y_true = data['labels']
            y_pred = data['predictions']

            tp = ((y_true == 1) & (y_pred == 1)).sum()
            tn = ((y_true == 0) & (y_pred == 0)).sum()
            fp = ((y_true == 0) & (y_pred == 1)).sum()
            fn = ((y_true == 1) & (y_pred == 0)).sum()

            offset = (i - (n_models - 1) / 2) * bar_width
            ax4.bar(group_x + offset, [tp, tn, fp, fn], width=bar_width,
                    color=colors[i % len(colors)], label=model_name, alpha=0.8)

        ax4.set_xlabel('Prediction Type')
        ax4.set_ylabel('Count')
        ax4.set_title('Error Type Distribution')
        ax4.set_xticks(group_x)
        ax4.set_xticklabels(categories)
        ax4.legend()
        ax4.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(f'{self.save_dir}/prediction_distributions.png', dpi=300, bbox_inches='tight')
        plt.close(fig)
    
    def _create_error_analysis(self):
        """Save per-model scatter plots of predicted probability against true label.

        Correct predictions are drawn as green dots and wrong ones as red
        crosses. A dashed vertical line marks 0.5. At most six models fit on
        the 2x3 grid, and any unused panels are removed. Output:
        ``<save_dir>/error_analysis.png``.
        """
        fig, grid = plt.subplots(2, 3, figsize=(18, 12))
        max_panels = 6  # 2 rows x 3 columns

        shown = list(self.predictions.items())[:max_panels]
        for idx, (name, run) in enumerate(shown):
            truth = run['labels']
            proba = run['probabilities']
            hit = run['predictions'] == truth

            panel = grid[idx // 3, idx % 3]
            panel.scatter(proba[hit], truth[hit],
                          alpha=0.6, c='green', label='Correct', s=20)
            panel.scatter(proba[~hit], truth[~hit],
                          alpha=0.6, c='red', label='Incorrect', s=20, marker='x')

            panel.set_xlabel('Predicted Probability')
            panel.set_ylabel('True Label')
            panel.set_title(f'{name} - Error Analysis')
            panel.legend()
            panel.grid(True, alpha=0.3)
            panel.axvline(x=0.5, color='black', linestyle='--', alpha=0.5)

        for idx in range(len(self.predictions), max_panels):
            grid[idx // 3, idx % 3].remove()

        plt.tight_layout()
        plt.savefig(f'{self.save_dir}/error_analysis.png', dpi=300, bbox_inches='tight')
        plt.close()
    
    def _create_interactive_plots(self):
        print("Creating interactive visualizations...")
        
        # Interactive ROC curve
        fig_roc = go.Figure()
        
        for i, (model_name, data) in enumerate(self.predictions.items()):
            y_true = data['labels']
            y_proba = data['probabilities']
            
            if len(np.unique(y_true)) > 1:
                fpr, tpr, thresholds = roc_curve(y_true, y_proba)
                auc = roc_auc_score(y_true, y_proba)
                
                fig_roc.add_trace(go.Scatter(
                    x=fpr, y=tpr,
                    mode='lines',
                    name=f'{model_name} (AUC = {auc:.3f})',
                    line=dict(width=3),
                    hovertemplate='FPR: %{x:.3f}<br>TPR: %{y:.3f}<extra></extra>'
                ))
        
        # Add diagonal line
        fig_roc.add_trace(go.Scatter(
            x=[0, 1], y=[0, 1],
            mode='lines',
            name='Random Classifier',
            line=dict(dash='dash', color='gray', width=2)
        ))
        
        fig_roc.update_layout(
            title='Interactive ROC Curve Comparison',
            xaxis_title='False Positive Rate',
            yaxis_title='True Positive Rate',
            width=800, height=600,
            hovermode='closest'
        )
        
        fig_roc.write_html(f'{self.save_dir}/interactive_roc.html')
        
        # Interactive threshold analysis
        fig_threshold = make_subplots(
            rows=2, cols=2,
            subplot_titles=('F1 Score vs Threshold', 'Precision/Recall vs Threshold',
                          'Sensitivity/Specificity vs Threshold', 'Sample Count vs Threshold')
        )
        
        thresholds = np.arange(0.1, 1.0, 0.01)
        
        for i, (model_name, data) in enumerate(self.predictions.items()):
            y_true = data['labels']
            y_proba = data['probabilities']
            
            # Calculate metrics for each threshold
            f1_scores = []
            precisions = []
            recalls = []
            specificities = []
            sample_counts = []
            
            for threshold in thresholds:
                y_pred_thresh = (y_proba >= threshold).astype(int)
                f1_scores.append(f1_score(y_true, y_pred_thresh))
                precisions.append(precision_score(y_true, y_pred_thresh, zero_division=0))
                recalls.append(recall_score(y_true, y_pred_thresh, zero_division=0))
                specificities.append(recall_score(1 - y_true, 1 - y_pred_thresh, zero_division=0))
                sample_counts.append(y_pred_thresh.sum())
            
            # F1 Score
            fig_threshold.add_trace(
                go.Scatter(x=thresholds, y=f1_scores, name=f'{model_name} F1',
                          line=dict(width=2)),
                row=1, col=1
            )
            
            # Precision/Recall
            fig_threshold.add_trace(
                go.Scatter(x=thresholds, y=precisions, name=f'{model_name} Precision',
                          line=dict(width=2)),
                row=1, col=2
            )
            fig_threshold.add_trace(
                go.Scatter(x=thresholds, y=recalls, name=f'{model_name} Recall',
                          line=dict(width=2, dash='dash')),
                row=1, col=2
            )
            
            # Sensitivity
            fig_threshold.add_trace(
                go.Scatter(x=thresholds, y=recalls, name=f'{model_name} Sensitivity',
                          line=dict(width=2)),
                row=2, col=1
            )
            fig_threshold.add_trace(
                go.Scatter(x=thresholds, y=specificities, name=f'{model_name} Specificity',
                          line=dict(width=2, dash='dash')),
                row=2, col=1
            )
            
            # Sample count
            fig_threshold.add_trace(
                go.Scatter(x=thresholds, y=sample_counts, name=f'{model_name} Positive Predictions',
                          line=dict(width=2)),
                row=2, col=2
            )
        
        fig_threshold.update_layout(
            title='Interactive Threshold Analysis',
            width=1200, height=800,
            hovermode='x unified'
        )
        
        fig_threshold.write_html(f'{self.save_dir}/interactive_threshold.html')
        
        # Interactive calibration plot
        fig_cal = go.Figure()
        
        for i, (model_name, data) in enumerate(self.predictions.items()):
            y_true = data['labels']
            y_proba = data['probabilities']
            
            if len(np.unique(y_true)) > 1:
                prob_true, prob_pred = calibration_curve(y_true, y_proba, n_bins=10)
                
                fig_cal.add_trace(go.Scatter(
                    x=prob_pred, y=prob_true,
                    mode='markers+lines',
                    name=model_name,
                    marker=dict(size=8),
                    line=dict(width=3),
                    hovertemplate='Predicted: %{x:.3f}<br>Actual: %{y:.3f}<extra></extra>'
                ))
        
        # Perfect calibration line
        fig_cal.add_trace(go.Scatter(
            x=[0, 1], y=[0, 1],
            mode='lines',
            name='Perfect Calibration',
            line=dict(dash='dash', color='gray', width=2)
        ))
        
        fig_cal.update_layout(
            title='Interactive Calibration Plot',
            xaxis_title='Mean Predicted Probability',
            yaxis_title='Fraction of Positives',
            width=800, height=600
        )
        
        fig_cal.write_html(f'{self.save_dir}/interactive_calibration.html')
    
    def _create_statistical_tests(self):
        print("Performing statistical significance tests...")
        
        if len(self.predictions) < 2:
            print("Need at least 2 models for statistical comparison")
            return
        
        # McNemar's test for paired predictions
        results_text = "STATISTICAL SIGNIFICANCE TESTS\n" + "="*50 + "\n\n"
        
        model_names = list(self.predictions.keys())
        n_models = len(model_names)
        
        # Create comparison matrix
        comparison_results = pd.DataFrame(index=model_names, columns=model_names)
        
        for i in range(n_models):
            for j in range(i+1, n_models):
                model1_name = model_names[i]
                model2_name = model_names[j]
                
                y_true = self.predictions[model1_name]['labels']
                pred1 = self.predictions[model1_name]['predictions']
                pred2 = self.predictions[model2_name]['predictions']
                
                # McNemar's test
                # Create contingency table
                correct1 = (pred1 == y_true)
                correct2 = (pred2 == y_true)
                
                both_correct = np.sum(correct1 & correct2)
                model1_only = np.sum(correct1 & ~correct2)
                model2_only = np.sum(~correct1 & correct2)
                both_wrong = np.sum(~correct1 & ~correct2)
                
                # McNemar's test statistic
                if (model1_only + model2_only) > 0:
                    mcnemar_stat = (abs(model1_only - model2_only) - 1)**2 / (model1_only + model2_only)
                    p_value = 1 - stats.chi2.cdf(mcnemar_stat, 1)
                else:
                    mcnemar_stat = 0
                    p_value = 1.0
                
                comparison_results.loc[model1_name, model2_name] = f"p={p_value:.4f}"
                comparison_results.loc[model2_name, model1_name] = f"χ²={mcnemar_stat:.4f}"
                
                results_text += f"McNemar's Test: {model1_name} vs {model2_name}\n"
                results_text += f"  Contingency Table:\n"
                results_text += f"    Both correct: {both_correct}\n"
                results_text += f"    {model1_name} only: {model1_only}\n"
                results_text += f"    {model2_name} only: {model2_only}\n"
                results_text += f"    Both wrong: {both_wrong}\n"
                results_text += f"  Chi-square statistic: {mcnemar_stat:.4f}\n"
                results_text += f"  P-value: {p_value:.4f}\n"
                
                if p_value < 0.05:
                    results_text += f"  Result: Significant difference (p < 0.05)\n"
                else:
                    results_text += f"  Result: No significant difference (p >= 0.05)\n"
                results_text += "\n"
        
        # Bootstrap confidence intervals for AUC
        results_text += "BOOTSTRAP CONFIDENCE INTERVALS (AUC)\n" + "-"*40 + "\n"
        
        for model_name, data in self.predictions.items():
            y_true = data['labels']
            y_proba = data['probabilities']
            
            if len(np.unique(y_true)) > 1:
                # Bootstrap sampling
                n_bootstrap = 1000
                bootstrap_aucs = []
                
                for _ in range(n_bootstrap):
                    # Sample with replacement
                    indices = np.random.choice(len(y_true), size=len(y_true), replace=True)
                    y_true_boot = y_true[indices]
                    y_proba_boot = y_proba[indices]
                    
                    if len(np.unique(y_true_boot)) > 1:
                        auc_boot = roc_auc_score(y_true_boot, y_proba_boot)
                        bootstrap_aucs.append(auc_boot)
                
                if bootstrap_aucs:
                    ci_lower = np.percentile(bootstrap_aucs, 2.5)
                    ci_upper = np.percentile(bootstrap_aucs, 97.5)
                    mean_auc = np.mean(bootstrap_aucs)
                    
                    results_text += f"{model_name}:\n"
                    results_text += f"  Mean AUC: {mean_auc:.4f}\n"
                    results_text += f"  95% CI: [{ci_lower:.4f}, {ci_upper:.4f}]\n\n"
        
        # Save results
        with open(f'{self.save_dir}/statistical_tests.txt', 'w') as f:
            f.write(results_text)
        
        # Create visualization of statistical tests
        plt.figure(figsize=(10, 8))
        
        # Plot AUC with confidence intervals
        model_names = []
        auc_values = []
        ci_lowers = []
        ci_uppers = []
        
        for model_name, data in self.predictions.items():
            y_true = data['labels']
            y_proba = data['probabilities']
            
            if len(np.unique(y_true)) > 1:
                # Bootstrap for CI
                n_bootstrap = 1000
                bootstrap_aucs = []
                
                for _ in range(n_bootstrap):
                    indices = np.random.choice(len(y_true), size=len(y_true), replace=True)
                    y_true_boot = y_true[indices]
                    y_proba_boot = y_proba[indices]
                    
                    if len(np.unique(y_true_boot)) > 1:
                        auc_boot = roc_auc_score(y_true_boot, y_proba_boot)
                        bootstrap_aucs.append(auc_boot)
                
                if bootstrap_aucs:
                    model_names.append(model_name)
                    auc_values.append(np.mean(bootstrap_aucs))
                    ci_lowers.append(np.percentile(bootstrap_aucs, 2.5))
                    ci_uppers.append(np.percentile(bootstrap_aucs, 97.5))
        
        if model_names:
            y_pos = np.arange(len(model_names))
            
            plt.errorbar(auc_values, y_pos, 
                        xerr=[np.array(auc_values) - np.array(ci_lowers), 
                              np.array(ci_uppers) - np.array(auc_values)],
                        fmt='o', capsize=5, capthick=2, markersize=8)
            
            plt.yticks(y_pos, model_names)
            plt.xlabel('AUC Score')
            plt.title('Model Performance with 95% Bootstrap Confidence Intervals')
            plt.grid(True, alpha=0.3)
            
            # Add vertical line at 0.5 (random performance)
            plt.axvline(x=0.5, color='red', linestyle='--', alpha=0.5, label='Random')
            plt.legend()
            
            plt.tight_layout()
            plt.savefig(f'{self.save_dir}/statistical_comparison.png', dpi=300, bbox_inches='tight')
            plt.close()
    
    def generate_comprehensive_report(self):
        """Write ``<save_dir>/comprehensive_report.html`` summarizing every model.

        The report contains:

        * a header with the timestamp, model count and dataset size;
        * an executive summary for the model with the highest ``roc_auc``;
        * a metrics table that highlights the best/worst value per column when
          more than one model is present;
        * the static PNG plots that already exist in ``save_dir``;
        * links to the interactive HTML plots;
        * recommendations driven by the best AUC and any ``calibration_error``
          above 0.1.

        Reads ``self.results`` (model name -> metric dict; every entry must
        provide the keys in ``metrics_to_highlight``). Reads
        ``self.predictions`` only for the model count and the length of the
        first model's ``labels``. Model names are HTML-escaped before being
        embedded. Empty results produce a report without a table body instead
        of raising.
        """
        import html

        print("Generating comprehensive report...")

        html_content = """
        <!DOCTYPE html>
        <html>
        <head>
            <meta charset="utf-8">
            <title>Model Evaluation Report</title>
            <style>
                body { font-family: Arial, sans-serif; margin: 20px; }
                .header { background-color: #f0f0f0; padding: 20px; border-radius: 10px; }
                .section { margin: 20px 0; padding: 15px; border-left: 4px solid #007acc; }
                .metrics-table { border-collapse: collapse; width: 100%; margin: 20px 0; }
                .metrics-table th, .metrics-table td { border: 1px solid #ddd; padding: 8px; text-align: center; }
                .metrics-table th { background-color: #f2f2f2; }
                .image-container { text-align: center; margin: 20px 0; }
                .best-score { background-color: #90EE90; }
                .worst-score { background-color: #FFB6C1; }
            </style>
        </head>
        <body>
        """

        # Header. Guard against an empty predictions dict instead of indexing [0].
        first_prediction = next(iter(self.predictions.values()), None)
        dataset_size = len(first_prediction['labels']) if first_prediction is not None else "N/A"
        html_content += f"""
        <div class="header">
            <h1>Advanced Model Evaluation Report</h1>
            <p><strong>Generated on:</strong> {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
            <p><strong>Number of Models:</strong> {len(self.predictions)}</p>
            <p><strong>Dataset Size:</strong> {dataset_size}</p>
        </div>
        """

        # Executive Summary
        html_content += """
        <div class="section">
            <h2>Executive Summary</h2>
        """

        # Best model = highest roc_auc; NaN never wins because NaN > x is False.
        best_model = None
        best_auc = 0
        for model_name, metrics in self.results.items():
            if metrics['roc_auc'] > best_auc:
                best_auc = metrics['roc_auc']
                best_model = model_name

        safe_best = html.escape(str(best_model))
        if best_model is not None:
            html_content += f"""
            <p><strong>Best Performing Model:</strong> {safe_best}</p>
            <p><strong>Best AUC Score:</strong> {best_auc:.4f}</p>
            <p><strong>Best F1 Score:</strong> {self.results[best_model]['f1_score']:.4f}</p>
            <p><strong>Best Accuracy:</strong> {self.results[best_model]['accuracy']:.4f}</p>
            """

        html_content += "</div>"

        # Detailed Metrics Table
        html_content += """
        <div class="section">
            <h2>Detailed Performance Metrics</h2>
            <table class="metrics-table">
                <tr>
                    <th>Model</th>
                    <th>Accuracy</th>
                    <th>Precision</th>
                    <th>Recall</th>
                    <th>F1 Score</th>
                    <th>AUC</th>
                    <th>Specificity</th>
                    <th>Matthews Corr</th>
                </tr>
        """

        # All of these metrics are "higher is better".
        metrics_to_highlight = ['accuracy', 'precision', 'recall', 'f1_score', 'roc_auc', 'specificity', 'matthews_corr']
        # Highlighting only makes sense with at least two models; computing the
        # extremes only then also avoids max()/min() on an empty sequence.
        highlight = len(self.results) > 1
        best_values = {}
        worst_values = {}
        if highlight:
            for metric in metrics_to_highlight:
                values = np.asarray([self.results[model][metric] for model in self.results], dtype=float)
                # nanmax/nanmin ignore NaN; plain max() was order-dependent with NaN.
                best_values[metric] = np.nanmax(values)
                worst_values[metric] = np.nanmin(values)

        for model_name, metrics in self.results.items():
            row = [f"<tr><td><strong>{html.escape(str(model_name))}</strong></td>"]

            for metric in metrics_to_highlight:
                value = metrics[metric]
                css_class = ""

                if highlight and value == best_values[metric]:
                    css_class = "best-score"
                elif highlight and value == worst_values[metric]:
                    css_class = "worst-score"

                row.append(f'<td class="{css_class}">{value:.4f}</td>')

            row.append("</tr>")
            html_content += "".join(row)

        html_content += "</table></div>"

        # Visualizations
        html_content += """
        <div class="section">
            <h2>Visualizations</h2>
        """

        # Static plots; only those already present in save_dir are embedded.
        plot_files = [
            'performance_dashboard.png',
            'roc_comparison.png',
            'pr_comparison.png',
            'calibration_plots.png',
            'confusion_matrices.png',
            'threshold_analysis.png',
            'prediction_distributions.png',
            'error_analysis.png',
            'statistical_comparison.png'
        ]

        for plot_file in plot_files:
            if os.path.exists(f'{self.save_dir}/{plot_file}'):
                html_content += f"""
                <div class="image-container">
                    <h3>{plot_file.replace('_', ' ').replace('.png', '').title()}</h3>
                    <img src="{plot_file}" style="max-width: 100%; height: auto;" alt="{plot_file}">
                </div>
                """

        html_content += "</div>"

        # Interactive Plots
        html_content += """
        <div class="section">
            <h2>Interactive Visualizations</h2>
            <p>The following interactive plots have been generated:</p>
            <ul>
                <li><a href="interactive_roc.html">Interactive ROC Curves</a></li>
                <li><a href="interactive_threshold.html">Interactive Threshold Analysis</a></li>
                <li><a href="interactive_calibration.html">Interactive Calibration Plot</a></li>
            </ul>
        </div>
        """

        # Model Recommendations
        html_content += """
        <div class="section">
            <h2>Recommendations</h2>
        """

        if not self.results:
            html_content += "<p>No model metrics were available, so no recommendation can be made.</p>"
        elif best_model is not None and best_auc > 0.8:
            html_content += f"<p><strong>{safe_best}</strong> shows excellent performance with AUC > 0.8. This model is ready for deployment consideration.</p>"
        elif best_model is not None and best_auc > 0.7:
            html_content += f"<p><strong>{safe_best}</strong> shows good performance with AUC > 0.7, but there's room for improvement.</p>"
        else:
            html_content += "<p>All models show suboptimal performance (AUC < 0.7). Consider:</p>"
            html_content += "<ul><li>Collecting more training data</li><li>Feature engineering</li><li>Different model architectures</li><li>Hyperparameter tuning</li></ul>"

        # Flag poorly calibrated models (calibration_error > 0.1), when reported.
        for model_name, metrics in self.results.items():
            if 'calibration_error' in metrics:
                cal_error = metrics['calibration_error']
                if cal_error > 0.1:
                    html_content += f"<p><strong>{html.escape(str(model_name))}</strong> shows poor calibration (error: {cal_error:.3f}). Consider calibration techniques.</p>"

        html_content += "</div>"

        html_content += """
        </body>
        </html>
        """

        # Explicit UTF-8 matches the <meta charset> declared in the head.
        with open(f'{self.save_dir}/comprehensive_report.html', 'w', encoding='utf-8') as f:
            f.write(html_content)

        print(f"Comprehensive report saved to {self.save_dir}/comprehensive_report.html")
    
    def run_complete_evaluation(self, model_class):
        """Run every evaluation stage and persist the metric results as JSON.

        Stages, in order: ``load_model_predictions(model_class)``,
        ``calculate_comprehensive_metrics()``,
        ``create_comprehensive_visualizations()``,
        ``generate_comprehensive_report()``. Afterwards, ``self.results`` is
        written to ``<save_dir>/detailed_results.json``.

        Numeric scalars are stored as floats, and non-finite values (NaN/inf)
        become ``null`` so the file is strict JSON. NumPy arrays, lists, tuples
        and dicts are converted recursively rather than stringified. Any other
        value falls back to ``str``.

        Args:
            model_class: Model class forwarded to ``load_model_predictions``.

        Returns:
            ``self.results``, or None when no predictions were generated.
        """
        print("Starting comprehensive model evaluation...")
        print("=" * 60)

        # Load models and generate predictions
        self.load_model_predictions(model_class)

        if not self.predictions:
            print("No valid predictions generated. Check model paths and data loader.")
            return None

        self.calculate_comprehensive_metrics()
        self.create_comprehensive_visualizations()
        self.generate_comprehensive_report()

        def to_json_safe(value):
            # Convert one metric value into something json.dump emits as valid JSON.
            if isinstance(value, np.ndarray):
                return [to_json_safe(item) for item in value.tolist()]
            if isinstance(value, dict):
                return {str(key): to_json_safe(item) for key, item in value.items()}
            if isinstance(value, (list, tuple)):
                return [to_json_safe(item) for item in value]
            if isinstance(value, (int, float, np.integer, np.floating)):
                number = float(value)
                # json.dump would otherwise write bare NaN/Infinity (invalid JSON).
                return number if np.isfinite(number) else None
            return str(value)

        results_for_json = {
            model_name: {metric: to_json_safe(value) for metric, value in metrics.items()}
            for model_name, metrics in self.results.items()
        }

        with open(f'{self.save_dir}/detailed_results.json', 'w', encoding='utf-8') as f:
            json.dump(results_for_json, f, indent=2)

        print("\nEvaluation completed successfully!")
        print(f"Results saved in: {self.save_dir}/")
        print(f"Open {self.save_dir}/comprehensive_report.html for full report")

        return self.results


# Convenience entry point: build the evaluator and run the full evaluation pipeline
def run_evaluation_pipeline(model_paths, data_loader, model_class, device, save_dir="evaluation_results"):
    """Construct an ``AdvancedModelEvaluator`` for the aneurysm task and run it.

    The evaluator is configured with the two class names
    ``'No Aneurysm'`` and ``'Aneurysm'``.

    Args:
        model_paths: A single model checkpoint path or a list of them.
        data_loader: PyTorch DataLoader providing the test/validation data.
        model_class: Model class (not instantiated), forwarded to
            ``run_complete_evaluation``.
        device: ``torch.device`` handed to the evaluator.
        save_dir: Directory where the evaluator writes its outputs.

    Returns:
        The value of ``run_complete_evaluation``: the per-model results dict,
        or None when no predictions could be generated.
    """
    binary_labels = ['No Aneurysm', 'Aneurysm']

    evaluator = AdvancedModelEvaluator(
        model_paths=model_paths,
        data_loader=data_loader,
        device=device,
        class_names=binary_labels,
        save_dir=save_dir,
    )
    return evaluator.run_complete_evaluation(model_class)