In [None]:
# Import libraries
import sys
sys.path.insert(0, '..')

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from pathlib import Path
import time
from collections import defaultdict
from scipy import stats
import pandas as pd

# Project imports
from data import CAMUSDataset
from models import MambaUNet, SwinMamba
from utils.metrics import dice_coefficient, hausdorff_distance

# Configuration
# Relative paths: they are resolved against the kernel's working directory.
DATA_ROOT = '../data/CAMUS'
CHECKPOINT_DIR = '../checkpoints'
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Number of segmentation labels; compute_metrics treats label 0 as background
# and reports labels 1-3 as LV, myocardium and left atrium.
NUM_CLASSES = 4

plt.style.use('seaborn-v0_8-whitegrid')
# IPython magic (not plain Python): render figures inside the notebook.
%matplotlib inline

print(f"Using device: {DEVICE}")

## 1. Model Loading

In [None]:
# Define models to compare
# Each entry is consumed by load_model():
#   'class'      - string used to pick the constructor
#   'checkpoint' - file name looked up under CHECKPOINT_DIR
#   'config'     - keyword arguments forwarded to the constructor
# Order matters: the significance tests and the error analysis below treat the
# first successfully loaded entry as the baseline.
MODELS = {
    'UNet (Baseline)': {
        'class': 'UNet',
        'checkpoint': 'unet_baseline.pth',
        'config': {'in_channels': 1, 'num_classes': NUM_CLASSES}
    },
    'Mamba-UNet': {
        'class': 'MambaUNet',
        'checkpoint': 'mamba_unet_best.pth',
        'config': {'in_channels': 1, 'num_classes': NUM_CLASSES, 'd_state': 16}
    },
    'Swin-Mamba': {
        'class': 'SwinMamba',
        'checkpoint': 'swin_mamba_best.pth',
        'config': {'in_channels': 1, 'num_classes': NUM_CLASSES}
    }
}

def load_model(model_info):
    """Build a model from a MODELS entry and load its weights if available.

    Args:
        model_info: Dict with keys 'class' (constructor selector), 'config'
            (keyword arguments for the constructor) and 'checkpoint' (file
            name under CHECKPOINT_DIR).

    Returns:
        The model moved to DEVICE and switched to eval mode. When the
        checkpoint file does not exist a warning is printed and the model
        keeps its freshly initialised weights.
    """
    # This is a placeholder - adapt to your actual model loading
    model_class = model_info['class']
    config = model_info['config']
    checkpoint_path = Path(CHECKPOINT_DIR) / model_info['checkpoint']

    if model_class == 'MambaUNet':
        model = MambaUNet(**config)
    elif model_class == 'SwinMamba':
        model = SwinMamba(**config)
    else:
        # Any other class name falls back to the baseline UNet. The import is
        # deferred so a missing baseline module only breaks this one model.
        from models.baseline import UNet
        model = UNet(**config)

    if checkpoint_path.exists():
        checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
        # Accept both layouts: a training checkpoint that wraps the weights
        # under 'model_state_dict', and a file that is the bare state dict.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)
        print(f"Loaded checkpoint: {checkpoint_path}")
    else:
        print(f"Warning: Checkpoint not found: {checkpoint_path}")

    return model.to(DEVICE).eval()

In [None]:
# Load all models
# Maps display name -> loaded model; every later cell iterates this dict.
models = {}
for name, info in MODELS.items():
    try:
        models[name] = load_model(info)
        print(f"✓ Loaded {name}")
    except Exception as e:
        # Best effort: a model that cannot be built or loaded is reported and
        # simply left out of `models`, so the rest of the notebook still runs.
        print(f"✗ Failed to load {name}: {e}")

In [None]:
# Model parameter counts
print("\nModel Parameters:")
print("-" * 40)
for name, model in models.items():
    # Walk the parameters once, remembering size and trainability together.
    sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
    params = sum(count for count, _ in sizes)
    trainable = sum(count for count, needs_grad in sizes if needs_grad)
    print(f"{name}:")
    print(f"  Total:     {params:,}")
    print(f"  Trainable: {trainable:,}")

## 2. Inference Comparison

In [None]:
# Load test dataset
# CAMUSDataset comes from the project's `data` module. The cells below index it
# by integer and read the 'image' and 'mask' entries of each sample.
test_dataset = CAMUSDataset(root_dir=DATA_ROOT, split='test')
print(f"Test samples: {len(test_dataset)}")

In [None]:
def run_inference(model, image):
    """Run a forward pass and return the per-pixel argmax class map.

    Args:
        model: Callable applied to the prepared input tensor; its output is
            reduced with argmax over dim 1.
        image: Tensor or array-like. A 2-D input gets batch and channel dims
            prepended, a 3-D input gets a batch dim; other ranks are passed
            through unchanged.

    Returns:
        NumPy array of predicted class indices. The leading batch axis is
        dropped when it has size 1.
    """
    # Other cells tolerate NumPy images (they check hasattr(image, 'numpy')),
    # so accept them here too; an existing tensor is returned as-is.
    image = torch.as_tensor(image)

    if image.ndim == 2:
        image = image.unsqueeze(0).unsqueeze(0)  # Add batch and channel dims
    elif image.ndim == 3:
        image = image.unsqueeze(0)  # Add batch dim

    with torch.no_grad():
        image = image.float().to(DEVICE)
        output = model(image)
        # squeeze(0) removes only the batch axis; a bare squeeze() would also
        # collapse a spatial axis of size 1 and change the mask's rank.
        pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
    return pred

# Run inference on sample
sample = test_dataset[0]
image = sample['image']
gt_mask = sample['mask']
if hasattr(gt_mask, 'numpy'):
    # Objects exposing .numpy() are converted to a NumPy array for plotting.
    gt_mask = gt_mask.numpy()

# One prediction per loaded model, keyed by display name (used by the plot below).
predictions = {}
for name, model in models.items():
    predictions[name] = run_inference(model, image)
    print(f"✓ {name} inference complete")

In [None]:
# Visualize predictions
# Grid layout: 2 rows x (input + ground truth + one column per loaded model).
n_models = len(models)
fig, axes = plt.subplots(2, n_models + 2, figsize=(4*(n_models+2), 8))

# Get image for display
# Convert to NumPy when the object exposes .numpy(); otherwise use it as-is.
img_display = image.numpy() if hasattr(image, 'numpy') else image
# A 3-D image is reduced to its first slice along axis 0
# (presumably the channel axis - TODO confirm against CAMUSDataset).
if img_display.ndim == 3:
    img_display = img_display[0]

# Row 1: Images and predictions
axes[0, 0].imshow(img_display, cmap='gray')
axes[0, 0].set_title('Input Image')
axes[0, 0].axis('off')

axes[0, 1].imshow(gt_mask, cmap='jet')
axes[0, 1].set_title('Ground Truth')
axes[0, 1].axis('off')

# Model columns start at index 2, after the input and ground-truth columns.
for i, (name, pred) in enumerate(predictions.items()):
    axes[0, i+2].imshow(pred, cmap='jet')
    axes[0, i+2].set_title(name)
    axes[0, i+2].axis('off')

# Row 2: Overlay
axes[1, 0].imshow(img_display, cmap='gray')
axes[1, 0].set_title('Input')
axes[1, 0].axis('off')

# Draw the grayscale image first, then the mask on top at 50% opacity.
axes[1, 1].imshow(img_display, cmap='gray')
axes[1, 1].imshow(gt_mask, cmap='jet', alpha=0.5)
axes[1, 1].set_title('GT Overlay')
axes[1, 1].axis('off')

for i, (name, pred) in enumerate(predictions.items()):
    axes[1, i+2].imshow(img_display, cmap='gray')
    axes[1, i+2].imshow(pred, cmap='jet', alpha=0.5)
    axes[1, i+2].set_title(f'{name} Overlay')
    axes[1, i+2].axis('off')

plt.tight_layout()
plt.show()

## 3. Quantitative Metrics

In [None]:
def compute_metrics(pred, target, num_classes=4):
    """Compute per-class and mean Dice scores for a label map.

    Label 0 is treated as background and skipped. Labels 1, 2 and 3 are
    reported as 'dice_lv', 'dice_myo' and 'dice_la' respectively.

    Args:
        pred: Predicted label map (NumPy array or anything np.asarray accepts).
        target: Ground-truth label map, broadcast-compatible with `pred`.
        num_classes: Total number of labels including background.

    Returns:
        Dict with float entries 'dice_mean', 'dice_lv', 'dice_myo' and
        'dice_la'. A class absent from both maps scores 1.0. A class index
        that does not exist for the given `num_classes` is reported as 0.0,
        and 'dice_mean' is 0.0 when there is no foreground class at all.
    """
    pred = np.asarray(pred)
    target = np.asarray(target)

    dice_scores = []
    for c in range(1, num_classes):  # Skip background
        pred_c = pred == c
        target_c = target == c

        intersection = np.count_nonzero(pred_c & target_c)
        denom = np.count_nonzero(pred_c) + np.count_nonzero(target_c)

        # denom == 0 means the class appears in neither map: perfect agreement.
        dice_scores.append(2.0 * intersection / denom if denom > 0 else 1.0)

    def _score(position):
        # Guard every class lookup so num_classes < 4 never raises IndexError.
        return float(dice_scores[position]) if position < len(dice_scores) else 0.0

    return {
        'dice_mean': float(np.mean(dice_scores)) if dice_scores else 0.0,
        'dice_lv': _score(0),
        'dice_myo': _score(1),
        'dice_la': _score(2),
    }

In [None]:
# Evaluate all models on test set
# results[model_name][metric_name] -> list with one value per test sample.
results = defaultdict(lambda: defaultdict(list))
n_samples = len(test_dataset)

print("Evaluating models on test set...")
for i in range(n_samples):
    sample = test_dataset[i]
    image = sample['image']
    gt = sample['mask']
    if hasattr(gt, 'numpy'):
        gt = gt.numpy()

    for name, model in models.items():
        sample_metrics = compute_metrics(run_inference(model, image), gt)
        for metric_name, value in sample_metrics.items():
            results[name][metric_name].append(value)

    # Report progress every 50 samples.
    processed = i + 1
    if processed % 50 == 0:
        print(f"  Processed {processed}/{n_samples} samples")

print("Done!")

In [None]:
# Create results table
metric_keys = ('dice_mean', 'dice_lv', 'dice_myo', 'dice_la')

summary = []
for name in models:
    row = {'Model': name}
    for metric in metric_keys:
        scores = np.asarray(results[name][metric])
        row[f'{metric}_mean'] = scores.mean()
        row[f'{metric}_std'] = scores.std()
    summary.append(row)

df = pd.DataFrame(summary)

# Display formatted table
rule = "="*80
print("\n" + rule)
print("MODEL COMPARISON RESULTS")
print(rule)
print(f"{'Model':<20} {'Mean Dice':<15} {'LV Dice':<15} {'Myo Dice':<15} {'LA Dice':<15}")
print("-"*80)
for _, row in df.iterrows():
    # One "mean±std" cell per metric, separated by two spaces.
    cells = []
    for metric in metric_keys:
        mean_value = row[metric + '_mean']
        std_value = row[metric + '_std']
        cells.append(f"{mean_value:.4f}±{std_value:.4f}")
    print(f"{row['Model']:<20} " + "  ".join(cells))
print(rule)

## 4. Per-Class Analysis

In [None]:
# Box plot comparison
CLASS_NAMES = ['LV Endocardium', 'Myocardium', 'Left Atrium']

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Materialise the names once: a concrete list is needed for tick labels.
box_labels = list(models.keys())
colors = ['#ff9999', '#66b3ff', '#99ff99']

for i, (metric, class_name) in enumerate(zip(['dice_lv', 'dice_myo', 'dice_la'], CLASS_NAMES)):
    data = [results[name][metric] for name in box_labels]
    bp = axes[i].boxplot(data, patch_artist=True)

    # Label the ticks explicitly instead of boxplot(labels=...), which is
    # deprecated since Matplotlib 3.9; this form works on old and new releases.
    axes[i].set_xticks(list(range(1, len(box_labels) + 1)))
    axes[i].set_xticklabels(box_labels)

    # Cycle the palette so models beyond the third are still coloured.
    for j, patch in enumerate(bp['boxes']):
        patch.set_facecolor(colors[j % len(colors)])

    axes[i].set_ylabel('Dice Score')
    axes[i].set_title(f'{class_name}')
    # Scores below 0.5 fall outside the visible range.
    axes[i].set_ylim([0.5, 1.0])
    plt.setp(axes[i].get_xticklabels(), rotation=45, ha='right')

plt.suptitle('Per-Class Dice Score Comparison', fontsize=14)
plt.tight_layout()
plt.show()

## 5. Statistical Significance Testing

In [None]:
# Paired t-tests between models
print("Statistical Significance Tests (Paired t-test)")
print("="*60)

model_names = list(models)
baseline = model_names[0]  # First model is baseline

for name in model_names[1:]:
    baseline_dice = results[baseline]['dice_mean']
    model_dice = results[name]['dice_mean']

    # Paired test: both lists hold one Dice value per test sample, same order.
    t_stat, p_value = stats.ttest_rel(model_dice, baseline_dice)
    improvement = np.mean(model_dice) - np.mean(baseline_dice)
    verdict = 'Yes' if p_value < 0.05 else 'No'

    report = [
        f"\n{name} vs {baseline}:",
        f"  Mean improvement: {improvement*100:+.2f}%",
        f"  t-statistic: {t_stat:.3f}",
        f"  p-value: {p_value:.6f}",
        f"  Significant (α=0.05): {verdict}",
    ]
    print("\n".join(report))

In [None]:
# Wilcoxon signed-rank test (non-parametric alternative)
# Relies on `model_names` and `baseline` defined by the paired t-test cell.
print("\nWilcoxon Signed-Rank Test (non-parametric)")
print("="*60)

for name in model_names[1:]:
    baseline_dice = results[baseline]['dice_mean']
    model_dice = results[name]['dice_mean']

    stat, p_value = stats.wilcoxon(model_dice, baseline_dice)
    verdict = 'Yes' if p_value < 0.05 else 'No'

    report = [
        f"\n{name} vs {baseline}:",
        f"  W-statistic: {stat:.3f}",
        f"  p-value: {p_value:.6f}",
        f"  Significant (α=0.05): {verdict}",
    ]
    print("\n".join(report))

## 6. Efficiency Analysis

In [None]:
# Inference time comparison
def benchmark_model(model, input_size=(1, 1, 256, 256), n_runs=100):
    """Time repeated forward passes of `model` on a random input.

    Args:
        model: Model to benchmark; it is switched to eval mode.
        input_size: Shape of the random input tensor created on DEVICE.
        n_runs: Number of timed forward passes (after 10 untimed warm-ups).

    Returns:
        Tuple (mean, std) of the per-pass wall-clock time in milliseconds.
    """
    model.eval()
    dummy_input = torch.randn(input_size).to(DEVICE)
    on_cuda = DEVICE.type == 'cuda'

    # Warmup: ten untimed passes.
    warmup_left = 10
    while warmup_left > 0:
        with torch.no_grad():
            _ = model(dummy_input)
        warmup_left -= 1

    # Make sure queued GPU work is finished before the clock starts.
    if on_cuda:
        torch.cuda.synchronize()

    # Benchmark: synchronise inside the timed region so GPU time is included.
    times = []
    for _run in range(n_runs):
        start = time.perf_counter()
        with torch.no_grad():
            _ = model(dummy_input)
        if on_cuda:
            torch.cuda.synchronize()
        stop = time.perf_counter()
        times.append(stop - start)

    mean_ms = np.mean(times) * 1000
    std_ms = np.std(times) * 1000
    return mean_ms, std_ms  # Return in ms

print("Inference Time Benchmark")
print("="*50)
# timing_results[model_name] -> {'mean': ms, 'std': ms}; read again by later cells.
timing_results = {}
for name, model in models.items():
    mean_time, std_time = benchmark_model(model)
    timing_results[name] = dict(mean=mean_time, std=std_time)
    print(f"{name}: {mean_time:.2f} ± {std_time:.2f} ms")

In [None]:
# Memory usage
def get_model_memory(model, input_size=(1, 1, 256, 256)):
    """Report CUDA allocator statistics around one forward pass.

    Args:
        model: Model to run once under torch.no_grad().
        input_size: Shape of the random input tensor created on DEVICE.

    Returns:
        Tuple (peak_mb, current_mb) read from torch.cuda.max_memory_allocated()
        and torch.cuda.memory_allocated(), or (0, 0) when DEVICE is not CUDA.
        NOTE(review): these are allocator-wide figures, so they presumably
        include every other tensor already resident on the GPU - confirm.
    """
    if DEVICE.type != 'cuda':
        return 0, 0

    bytes_per_mb = 1024**2

    torch.cuda.reset_peak_memory_stats()
    dummy_input = torch.randn(input_size).to(DEVICE)

    with torch.no_grad():
        # Keep the output referenced so it is still allocated when stats are read.
        _ = model(dummy_input)

    peak_memory = torch.cuda.max_memory_allocated() / bytes_per_mb  # MB
    current_memory = torch.cuda.memory_allocated() / bytes_per_mb  # MB
    return peak_memory, current_memory

print("\nGPU Memory Usage")
print("="*50)
for name in models:
    # get_model_memory returns (peak, current) in MB, or (0, 0) off-GPU.
    peak, current = get_model_memory(models[name])
    print(f"{name}: Peak={peak:.1f}MB, Current={current:.1f}MB")

In [None]:
# Efficiency plot: Dice vs Parameters vs Speed
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Shared y values and the two alternative x values, one entry per model.
params = [sum(p.numel() for p in m.parameters())/1e6 for m in models.values()]
dice_scores = [np.mean(results[n]['dice_mean']) for n in models.keys()]
times = [timing_results[n]['mean'] for n in models.keys()]

# (axis, x values, x label, title): left is Dice vs Parameters, right is Dice vs Inference Time.
panels = [
    (axes[0], params, 'Parameters (M)', 'Accuracy vs Model Size'),
    (axes[1], times, 'Inference Time (ms)', 'Accuracy vs Speed'),
]

for ax, x_values, x_label, panel_title in panels:
    ax.scatter(x_values, dice_scores, s=200, c=['red', 'blue', 'green'][:len(models)])
    for i, name in enumerate(models.keys()):
        # Put each model name slightly above its marker.
        ax.annotate(name, (x_values[i], dice_scores[i]),
                    textcoords="offset points", xytext=(0,10), ha='center')
    ax.set_xlabel(x_label)
    ax.set_ylabel('Mean Dice Score')
    ax.set_title(panel_title)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 7. Qualitative Comparison

In [None]:
# Show multiple example comparisons
# Never ask for more examples than the test set holds: sampling without
# replacement would raise otherwise.
n_examples = min(5, len(test_dataset))
# Fixed seed so a fresh "Restart & Run All" shows the same examples.
rng = np.random.default_rng(42)
indices = rng.choice(len(test_dataset), n_examples, replace=False)

# squeeze=False keeps `axes` 2-D even when only one row is drawn.
fig, axes = plt.subplots(n_examples, len(models) + 2,
                         figsize=(4*(len(models)+2), 4*n_examples), squeeze=False)

for row, idx in enumerate(indices):
    # Plain int: the dataset may not accept NumPy integer indices.
    sample = test_dataset[int(idx)]
    image = sample['image']
    gt = sample['mask'].numpy() if hasattr(sample['mask'], 'numpy') else sample['mask']

    img_display = image.numpy() if hasattr(image, 'numpy') else image
    if img_display.ndim == 3:
        img_display = img_display[0]

    # Input
    axes[row, 0].imshow(img_display, cmap='gray')
    axes[row, 0].set_title('Input' if row == 0 else '')
    axes[row, 0].axis('off')

    # Ground truth
    axes[row, 1].imshow(img_display, cmap='gray')
    axes[row, 1].imshow(gt, cmap='jet', alpha=0.5)
    axes[row, 1].set_title('Ground Truth' if row == 0 else '')
    axes[row, 1].axis('off')

    # Predictions: one column per model, annotated with that sample's mean Dice.
    for col, (name, model) in enumerate(models.items()):
        pred = run_inference(model, image)
        dice = compute_metrics(pred, gt)['dice_mean']

        axes[row, col+2].imshow(img_display, cmap='gray')
        axes[row, col+2].imshow(pred, cmap='jet', alpha=0.5)
        if row == 0:
            axes[row, col+2].set_title(name)
        axes[row, col+2].text(5, 20, f'Dice: {dice:.3f}', fontsize=10,
                              color='white', backgroundcolor='black')
        axes[row, col+2].axis('off')

plt.suptitle('Qualitative Comparison', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# Error analysis: Show cases where Mamba improves most over baseline
# NOTE(review): the first two loaded models are taken as baseline and Mamba;
# if a model failed to load above, these positions shift - confirm.
if len(models) >= 2:
    baseline_name = list(models.keys())[0]
    mamba_name = list(models.keys())[1]

    # Per-sample Dice difference (positive means the second model is better).
    improvements = np.array(results[mamba_name]['dice_mean']) - np.array(results[baseline_name]['dice_mean'])
    top_indices = np.argsort(improvements)[-5:][::-1]  # Top 5 improvements

    print(f"\nCases where {mamba_name} improved most over {baseline_name}:")
    for idx in top_indices:
        # ':+' prints the real sign; a hard-coded '+' would show '+-1.2%'.
        print(f"  Sample {idx}: {improvements[idx]*100:+.1f}% Dice")

    # Visualize
    # squeeze=False keeps `axes` 2-D even when only one case is shown.
    fig, axes = plt.subplots(len(top_indices), 4, figsize=(16, 4*len(top_indices)), squeeze=False)

    for row, idx in enumerate(top_indices):
        # Plain int: the dataset may not accept NumPy integer indices.
        sample = test_dataset[int(idx)]
        image = sample['image']
        gt = sample['mask'].numpy() if hasattr(sample['mask'], 'numpy') else sample['mask']

        img_display = image.numpy() if hasattr(image, 'numpy') else image
        if img_display.ndim == 3:
            img_display = img_display[0]

        baseline_pred = run_inference(models[baseline_name], image)
        mamba_pred = run_inference(models[mamba_name], image)

        axes[row, 0].imshow(img_display, cmap='gray')
        axes[row, 0].set_title('Input')
        axes[row, 0].axis('off')

        axes[row, 1].imshow(img_display, cmap='gray')
        axes[row, 1].imshow(gt, cmap='jet', alpha=0.5)
        axes[row, 1].set_title('Ground Truth')
        axes[row, 1].axis('off')

        axes[row, 2].imshow(img_display, cmap='gray')
        axes[row, 2].imshow(baseline_pred, cmap='jet', alpha=0.5)
        axes[row, 2].set_title(f'{baseline_name}: {results[baseline_name]["dice_mean"][idx]:.3f}')
        axes[row, 2].axis('off')

        axes[row, 3].imshow(img_display, cmap='gray')
        axes[row, 3].imshow(mamba_pred, cmap='jet', alpha=0.5)
        axes[row, 3].set_title(f'{mamba_name}: {results[mamba_name]["dice_mean"][idx]:.3f} ({improvements[idx]*100:+.1f}%)')
        axes[row, 3].axis('off')

    plt.suptitle('Cases with Largest Improvement', fontsize=14)
    plt.tight_layout()
    plt.show()

## Summary

In [None]:
# Final summary
rule = "="*70
print(rule)
print("MODEL COMPARISON SUMMARY")
print(rule)

for name, model in models.items():
    params = sum(p.numel() for p in model.parameters()) / 1e6
    dice = np.mean(results[name]['dice_mean'])
    time_ms = timing_results[name]['mean']
    # Frames per second follows from the mean per-image latency in ms.
    fps = 1000/time_ms

    lines = [
        f"\n{name}:",
        f"  Parameters:     {params:.2f}M",
        f"  Mean Dice:      {dice:.4f}",
        f"  Inference Time: {time_ms:.2f}ms",
        f"  FPS:            {fps:.1f}",
    ]
    print("\n".join(lines))

print("\n" + rule)