# 04 — ViT Experiments: Transformer & Modern CNN Training

**Goal**: Train and evaluate four architectures on Euclid Q1 galaxy morphology,
using the same pipeline as the Zoobot baseline (notebook 03).

| Model | Architecture | Params | Pre-training | Input size |
|-------|-------------|--------|-------------|------------|
| ViT-Base/16 | Vanilla ViT | ~86M | ImageNet-21k | 224×224 |
| Swin-V2-Base | Hierarchical ViT | ~88M | ImageNet-21k | 256×256 |
| DINOv2 ViT-B/14 | SSL Transformer | ~87M | LVD-142M | 224×224 |
| ConvNeXt-Base | Modern CNN | ~89M | ImageNet-21k | 224×224 |

Each model follows the same two-phase training:
1. **Linear probe** (frozen backbone, 5 epochs, lr=1e-3)
2. **Full fine-tune** (all layers, up to 25 epochs, lr=5e-5, cosine + warmup, early stopping with patience 5)

Results saved per model for comparison in notebook 05.

**Runtime**: ~30-40 min per model on A100 (~2-3 hours total for all 4)

## 0. Colab Setup

Run this cell in every environment. On Google Colab it clones the repo, installs dependencies, and downloads the data. When running locally it only sets the `IN_COLAB` flag and skips the setup steps.

In [None]:
# Colab bootstrap: clone the repo, install packages, then fetch catalog, splits and images.
# The clone and data steps are skipped when their outputs already exist; pip install always runs.
# NOTE: lines starting with `!` are IPython shell escapes, so this cell only runs under IPython/Jupyter.
import os

# Colab detection via its env var or the IPython shell repr; IN_COLAB is also read by later cells.
IN_COLAB = 'COLAB_GPU' in os.environ or 'google.colab' in str(get_ipython())

if IN_COLAB:
    # 1. Clone the repo (dev branch has all source code)
    REPO_URL = 'https://github.com/Smooth-Cactus0/euclid-q1-vit-morphology.git'
    REPO_DIR = '/content/euclid-q1-vit-morphology'

    if not os.path.exists(REPO_DIR):
        print('Cloning repository (dev branch)...')
        !git clone --branch dev {REPO_URL} {REPO_DIR}
    # Subsequent relative paths (data/..., scripts/...) resolve against the repo root.
    os.chdir(REPO_DIR)
    print(f'Working directory: {os.getcwd()}')

    # 2. Install dependencies
    print('\nInstalling dependencies...')
    !pip install -q timm tqdm scipy

    # 3. Download catalog
    CATALOG_PATH = 'data/raw/morphology_catalogue.parquet'
    if not os.path.exists(CATALOG_PATH):
        print('\nDownloading catalog...')
        !python scripts/download_catalog.py

    # 4. Generate splits
    SPLIT_PATH = 'data/processed/split_indices.json'
    if not os.path.exists(SPLIT_PATH):
        print('\nGenerating train/val/test splits...')
        !python scripts/prepare_splits.py

    # 5. Download and extract images (~3.8 GB, takes ~5 min)
    IMAGE_DIR = 'data/raw/images'
    # Heuristic: a missing directory or fewer than 10 entries counts as "not downloaded yet".
    if not os.path.exists(IMAGE_DIR) or len(os.listdir(IMAGE_DIR)) < 10:
        print('\nDownloading images (~3.8 GB)... This takes ~5 minutes.')
        !python scripts/download_images.py
    else:
        # Only sub-directories are counted; each one is reported as a tile.
        n_tiles = len([d for d in os.listdir(IMAGE_DIR) if os.path.isdir(os.path.join(IMAGE_DIR, d))])
        print(f'\nImages already present: {n_tiles} tiles')

    print('\nColab setup complete!')
else:
    print('Not running on Colab — skipping setup.')

## 1. Setup

In [None]:
import sys
from pathlib import Path
# Re-imported so this cell also works when the Colab setup cell was not run.
import os

# Set project root — Colab runs from repo root, local runs from notebooks/
if 'COLAB_GPU' in os.environ or 'google.colab' in str(get_ipython()):
    PROJECT_ROOT = Path('/content/euclid-q1-vit-morphology')
else:
    PROJECT_ROOT = Path('..').resolve()

# Ensure project root is on sys.path
# Remove any existing copy first so re-running the cell leaves exactly one entry, at the front.
project_str = str(PROJECT_ROOT)
sys.path = [p for p in sys.path if p != project_str]
sys.path.insert(0, project_str)

# Diagnostics: confirm the `src` package is importable before the project imports below.
print(f'PROJECT_ROOT: {PROJECT_ROOT}')
print(f'src/ exists: {(PROJECT_ROOT / "src").exists()}')
print(f'src/models/architectures.py exists: {(PROJECT_ROOT / "src" / "models" / "architectures.py").exists()}')

import json
import time
import gc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

from src.data.dataset import EuclidDataset, MorphologySchema
from src.data.transforms import get_transforms, get_tta_transforms
from src.models.factory import create_model, list_models
from src.training.losses import DirichletMultinomialLoss
from src.training.trainer import Trainer, TrainConfig
from src.evaluation.metrics import (
    compute_metrics,
    compute_per_question_metrics,
    bootstrap_confidence_interval,
)

import warnings
# NOTE(review): silences *all* warnings for the rest of the session, including library deprecations.
warnings.filterwarnings('ignore')

plt.rcParams['figure.dpi'] = 120

# Reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Deterministic cuDNN kernels and no autotuning: more reproducible, potentially slower.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'\nDevice: {device}')
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    # Decimal GB (1e9 bytes), consistent with the batch-size thresholds in the config cell.
    gpu_mem_gb = props.total_memory / 1e9
    print(f'GPU: {torch.cuda.get_device_name()}')
    print(f'Memory: {gpu_mem_gb:.1f} GB')

print(f'\nAvailable models: {list_models()}')

## 2. Configuration

In [None]:
# Paths (everything lives under PROJECT_ROOT, set in the setup cell)
CATALOG_PATH = PROJECT_ROOT / 'data' / 'raw' / 'morphology_catalogue.parquet'
SPLIT_PATH = PROJECT_ROOT / 'data' / 'processed' / 'split_indices.json'
IMAGE_DIR = PROJECT_ROOT / 'data' / 'raw' / 'images'
CHECKPOINT_DIR = PROJECT_ROOT / 'results' / 'checkpoints'
FIGURES_DIR = PROJECT_ROOT / 'results' / 'figures'
TABLES_DIR = PROJECT_ROOT / 'results' / 'tables'

# Create output directories up front so later cells can write without checks.
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
TABLES_DIR.mkdir(parents=True, exist_ok=True)

# Models to train (in order)
MODELS_TO_TRAIN = ['vit-base', 'swin-v2', 'dinov2', 'convnext']

# Model-specific input sizes
# Swin-V2-Base was pretrained at 192→256, so we use 256 for best performance.
# DINOv2 patch14: 224/14=16 patches — works perfectly.
# ViT-Base patch16: 224/16=14 patches — standard.
MODEL_INPUT_SIZES = {
    'vit-base': 224,
    'swin-v2': 256,
    'dinov2': 224,
    'convnext': 224,
}

# Training hyperparameters (from configs/base.yaml)
LR_PROBE = 1e-3
LR_FINETUNE = 5e-5
WEIGHT_DECAY = 0.01
LINEAR_PROBE_EPOCHS = 5
FINETUNE_EPOCHS = 25
PATIENCE = 5

# Colab detection is repeated here (same test as the setup cells) instead of reading
# IN_COLAB: that flag only exists if the Colab setup cell was run, and local users
# are told they may skip it.
running_on_colab = 'COLAB_GPU' in os.environ or 'google.colab' in str(get_ipython())
NUM_WORKERS = 2 if running_on_colab else 4

# Base batch size — adjusted per model and GPU (32 when no GPU is available)
BASE_BATCH_SIZE = 32
if torch.cuda.is_available():
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    if gpu_mem > 30:       # e.g. A100 40/80 GB, V100 32 GB
        BASE_BATCH_SIZE = 64
    elif gpu_mem > 14:     # e.g. T4, V100 16 GB
        BASE_BATCH_SIZE = 32
    else:
        BASE_BATCH_SIZE = 16

# Swin-V2 at 256×256 uses more VRAM per image, so its batch is halved — but never
# below 16, which means it is NOT reduced on the smallest (<14 GB) GPUs.
MODEL_BATCH_SIZES = {
    'vit-base': BASE_BATCH_SIZE,
    'swin-v2': max(BASE_BATCH_SIZE // 2, 16),  # 256×256 needs more VRAM
    'dinov2': BASE_BATCH_SIZE,
    'convnext': BASE_BATCH_SIZE,
}

# Augmentation config (matching configs/base.yaml)
AUGMENTATION_CFG = {
    'random_horizontal_flip': True,
    'random_vertical_flip': True,
    'random_rotation': 360,
    'color_jitter': {
        'brightness': 0.1,
        'contrast': 0.1,
        'saturation': 0.0,
        'hue': 0.0,
    },
    'random_resized_crop': {
        'scale': [0.85, 1.0],
        'ratio': [0.95, 1.05],
    },
}

print('Models to train:', MODELS_TO_TRAIN)
print('Batch sizes:', MODEL_BATCH_SIZES)
print('Input sizes:', MODEL_INPUT_SIZES)
print(f'Num workers: {NUM_WORKERS}')
print(f'Training: {LINEAR_PROBE_EPOCHS} probe + {FINETUNE_EPOCHS} fine-tune epochs per model')

## 3. Data Loading

Load the catalog splits once. DataLoaders are rebuilt per model, since input size and batch size differ between models.

In [None]:
# Schema
schema = MorphologySchema.default()
print(f'Schema: {schema.num_outputs} outputs across {len(schema.questions)} questions')

# Load splits (once — reused across all models)
train_df = EuclidDataset.load_split(CATALOG_PATH, SPLIT_PATH, 'train')
val_df = EuclidDataset.load_split(CATALOG_PATH, SPLIT_PATH, 'val')
test_df = EuclidDataset.load_split(CATALOG_PATH, SPLIT_PATH, 'test')

print(f'\nTrain: {len(train_df):,}')
print(f'Val:   {len(val_df):,}')
print(f'Test:  {len(test_df):,}')

In [None]:
def make_dataloaders(input_size, batch_size):
    """Create the train/val/test DataLoaders for one model configuration.

    Args:
        input_size: Square image side length handed to the transforms.
        batch_size: Batch size shared by all three loaders.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        shuffles and drops its final incomplete batch.
    """
    # Training images are augmented; val and test share the deterministic eval transform.
    augment_tfm = get_transforms('train', input_size, augmentation_cfg=AUGMENTATION_CFG)
    eval_tfm = get_transforms('val', input_size)

    train_ds, val_ds, test_ds = [
        EuclidDataset(split_df, IMAGE_DIR, schema, tfm)
        for split_df, tfm in ((train_df, augment_tfm), (val_df, eval_tfm), (test_df, eval_tfm))
    ]

    loader_kwargs = dict(batch_size=batch_size, num_workers=NUM_WORKERS, pin_memory=True)
    return (
        DataLoader(train_ds, shuffle=True, drop_last=True, **loader_kwargs),
        DataLoader(val_ds, shuffle=False, **loader_kwargs),
        DataLoader(test_ds, shuffle=False, **loader_kwargs),
    )


print('DataLoader factory ready.')

## 4. Helper Functions

In [None]:
def logits_to_vote_fractions(logits, schema):
    """Convert raw model outputs into per-question vote fractions.

    Each question's slice of the output vector gets its own softmax, so the
    answers to one question sum to 1. ``schema.question_slices`` maps
    question name -> ``(start, end)``. Columns outside every slice stay 0.

    Args:
        logits: Tensor of raw outputs with at least two dims, one row per sample.
        schema: Object exposing ``question_slices``.

    Returns:
        Tensor with the same shape, dtype and device as ``logits``.
    """
    fractions = torch.zeros_like(logits)
    for start, end in schema.question_slices.values():
        fractions[:, start:end] = torch.softmax(logits[:, start:end], dim=1)
    return fractions


@torch.no_grad()
def predict(model, loader, device):
    """Run single-view inference and collect predictions, targets and masks.

    Puts ``model`` in eval mode and converts logits to vote fractions using
    the notebook-level ``schema``.

    Args:
        model: Network mapping an image batch to logits.
        loader: Iterable of ``(images, targets, masks)`` batches.
        device: Device the images are moved to before the forward pass.

    Returns:
        ``(preds, targets, masks)`` as NumPy arrays concatenated over all
        batches, in loader order.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    pred_chunks, target_chunks, mask_chunks = [], [], []

    for images, targets, masks in loader:
        logits = model(images.to(device))
        pred_chunks.append(logits_to_vote_fractions(logits, schema).cpu().numpy())
        target_chunks.append(targets.numpy())
        mask_chunks.append(masks.numpy())

    if not pred_chunks:
        # np.concatenate would fail with an opaque message on an empty list.
        raise ValueError('predict: loader yielded no batches')

    return (
        np.concatenate(pred_chunks),
        np.concatenate(target_chunks),
        np.concatenate(mask_chunks),
    )


@torch.no_grad()
def predict_with_tta(model, dataset_df, image_dir, schema, device,
                     input_size=224, batch_size=32, num_workers=2):
    """Test-time augmentation: predict on every TTA view, then average.

    Args:
        model: Network mapping an image batch to logits.
        dataset_df: Catalog rows to evaluate.
        image_dir: Root directory of the image cutouts.
        schema: Morphology schema (output layout and question slices).
        device: Device the images are moved to before the forward pass.
        input_size: Square image side length for the TTA transforms.
        batch_size: Inference batch size.
        num_workers: DataLoader worker processes.

    Returns:
        ``(tta_preds, targets, masks)``. ``tta_preds`` is the equal-weight mean
        of the per-view vote fractions.

    Raises:
        ValueError: If ``get_tta_transforms`` returns no transforms.
    """
    model.eval()
    tta_transforms = get_tta_transforms(input_size=input_size)
    n_views = len(tta_transforms)
    if n_views == 0:
        # np.mean over an empty list would silently produce a NaN scalar.
        raise ValueError('predict_with_tta: get_tta_transforms returned no views')
    print(f'  TTA with {n_views} views')

    view_preds = []
    for view_idx, tfm in enumerate(tta_transforms, start=1):
        view_loader = DataLoader(
            EuclidDataset(dataset_df, image_dir, schema, tfm),
            batch_size=batch_size, shuffle=False,
            num_workers=num_workers, pin_memory=True,
        )
        batch_preds = [
            logits_to_vote_fractions(model(images.to(device)), schema).cpu().numpy()
            for images, _, _ in view_loader
        ]
        view_preds.append(np.concatenate(batch_preds))
        print(f'    View {view_idx}/{n_views} done')

    # Average across TTA views (equal weights)
    tta_preds = np.mean(view_preds, axis=0)

    # Targets/masks are read from a dataset built with the standard eval transform
    # (presumably they do not depend on the image transform).
    ref_ds = EuclidDataset(dataset_df, image_dir, schema, get_transforms('val', input_size))
    return tta_preds, ref_ds.targets.numpy(), ref_ds.masks.numpy()


def plot_training_curves(history_df, model_name, best_epoch, best_loss, save_path):
    """Draw a 1×3 panel of loss, learning rate and per-epoch wall time.

    Args:
        history_df: Per-epoch history with columns ``epoch``, ``train_loss``,
            ``val_loss``, ``lr``, ``time_s`` and ``phase`` (``'probe'`` marks
            linear-probe epochs).
        model_name: Shown in the figure title.
        best_epoch: Epoch highlighted on the loss panel.
        best_loss: Validation loss at ``best_epoch``.
        save_path: File the figure is saved to before being shown.
    """
    fig, (loss_ax, lr_ax, time_ax) = plt.subplots(1, 3, figsize=(18, 5))
    epochs = history_df['epoch']

    # --- Panel 1: train/val loss with the best epoch marked ---
    loss_ax.plot(epochs, history_df['train_loss'], 'b-', label='Train', alpha=0.8)
    loss_ax.plot(epochs, history_df['val_loss'], 'r-', label='Val', alpha=0.8)
    loss_ax.axvline(best_epoch, color='gray', linestyle='--', alpha=0.5, label=f'Best (epoch {best_epoch})')
    loss_ax.scatter([best_epoch], [best_loss], color='red', zorder=5, s=50)

    # Half-way between the last probe epoch and the first fine-tune epoch (None if no probe phase)
    probe_rows = history_df[history_df['phase'] == 'probe']
    phase_boundary = probe_rows['epoch'].max() + 0.5 if len(probe_rows) > 0 else None

    if phase_boundary is not None:
        loss_ax.axvline(phase_boundary, color='green', linestyle=':', alpha=0.5, label='Phase boundary')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Dirichlet-Multinomial Loss')
    loss_ax.set_title('Training & Validation Loss')
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    # --- Panel 2: learning rate (log scale) ---
    lr_ax.plot(epochs, history_df['lr'], 'g-', linewidth=2)
    if phase_boundary is not None:
        lr_ax.axvline(phase_boundary, color='green', linestyle=':', alpha=0.5)
    lr_ax.set_xlabel('Epoch')
    lr_ax.set_ylabel('Learning Rate')
    lr_ax.set_title('LR Schedule')
    lr_ax.set_yscale('log')
    lr_ax.grid(True, alpha=0.3)

    # --- Panel 3: epoch duration, coloured by phase ---
    bar_colors = ['steelblue' if phase == 'probe' else 'coral' for phase in history_df['phase']]
    time_ax.bar(epochs, history_df['time_s'], color=bar_colors, alpha=0.7)
    time_ax.set_xlabel('Epoch')
    time_ax.set_ylabel('Time (s)')
    time_ax.set_title('Epoch Duration')
    time_ax.grid(True, alpha=0.3, axis='y')

    fig.suptitle(f'{model_name} Training — Best val_loss: {best_loss:.4f}', fontsize=14)
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight')
    plt.show()


print('Helper functions defined.')

## 5. Training Loop

Train each model sequentially. For each model:
1. Create model and DataLoaders (with model-specific input size)
2. Two-phase training (linear probe → full fine-tune)
3. Plot training curves
4. Evaluate on test set (with and without TTA)
5. Bootstrap confidence intervals
6. Save all results
7. Free GPU memory before the next model

In [None]:
from sklearn.metrics import mean_squared_error, r2_score

def _masked_flat(preds, targets, masks):
    """Flatten all three arrays and keep only entries whose mask is truthy.

    Returns:
        ``(y_true, y_pred)`` 1-D arrays, in the argument order sklearn expects.
    """
    keep = masks.astype(bool).ravel()
    return targets.ravel()[keep], preds.ravel()[keep]


def mse_metric(preds, targets, masks):
    """Mean squared error over the valid (masked-in) entries of every output."""
    y_true, y_pred = _masked_flat(preds, targets, masks)
    return mean_squared_error(y_true, y_pred)


def r2_metric(preds, targets, masks):
    """Coefficient of determination over the valid (masked-in) entries of every output."""
    y_true, y_pred = _masked_flat(preds, targets, masks)
    return r2_score(y_true, y_pred)


# Store all results for comparison
# (model_name -> the same results dict that is written to <model>_results.json)
all_results = {}

# For each model: build loaders, train (probe + fine-tune), evaluate on test with and
# without TTA, bootstrap CIs, save artefacts, then release GPU memory.
for model_name in MODELS_TO_TRAIN:
    print(f'\n{"#" * 70}')
    print(f'# MODEL: {model_name}')
    print(f'{"#" * 70}')
    
    input_size = MODEL_INPUT_SIZES[model_name]
    batch_size = MODEL_BATCH_SIZES[model_name]
    
    # --- Seed reset for fair comparison ---
    # Every model starts from the same torch/NumPy RNG state (Python's `random` is not reseeded).
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)
    
    # --- 1. DataLoaders ---
    print(f'\n  Input size: {input_size}×{input_size}, batch_size: {batch_size}')
    train_loader, val_loader, test_loader = make_dataloaders(input_size, batch_size)
    print(f'  Train: {len(train_loader)} batches, Val: {len(val_loader)}, Test: {len(test_loader)}')
    
    # --- 2. Model ---
    print(f'\n  Creating model: {model_name}...')
    model = create_model(model_name, num_outputs=schema.num_outputs, pretrained=True)
    params = model.count_parameters()
    print(f'  Parameters: {params["total"]:,} total, {params["trainable"]:,} trainable')
    
    # Quick forward-pass sanity check
    # The dummy batch is created on CPU, so this assumes create_model returns a CPU model.
    # torch.randn draws from the freshly reseeded RNG, shifting the state training starts from.
    with torch.no_grad():
        dummy = torch.randn(2, 3, input_size, input_size)
        out = model(dummy)
        print(f'  Forward pass OK: {dummy.shape} → {out.shape}')
    
    # --- 3. Train ---
    criterion = DirichletMultinomialLoss(schema)
    train_config = TrainConfig(
        lr=LR_FINETUNE,
        lr_linear_probe=LR_PROBE,
        weight_decay=WEIGHT_DECAY,
        batch_size=batch_size,
        epochs=LINEAR_PROBE_EPOCHS + FINETUNE_EPOCHS,
        linear_probe_epochs=LINEAR_PROBE_EPOCHS,
        full_finetune_epochs=FINETUNE_EPOCHS,
        warmup_fraction=0.05,
        patience=PATIENCE,
        checkpoint_dir=str(CHECKPOINT_DIR),
        seed=SEED,
    )
    
    trainer = Trainer(
        model=model,
        criterion=criterion,
        train_loader=train_loader,
        val_loader=val_loader,
        config=train_config,
        device=device,
        model_name=model_name,
    )
    
    print(f'\n  Starting training...')
    t_start = time.time()
    summary = trainer.train()
    t_total = time.time() - t_start
    
    best_epoch = summary['best_epoch']
    best_loss = summary['best_val_loss']
    print(f'\n  Training time: {t_total/60:.1f} min')
    print(f'  Best val_loss: {best_loss:.4f} (epoch {best_epoch})')
    
    # --- 4. Training curves ---
    history = pd.DataFrame(trainer.history)
    plot_training_curves(
        history, model_name, best_epoch, best_loss,
        FIGURES_DIR / f'{model_name}_training_curves.pdf',
    )
    
    # --- 5. Load best checkpoint & evaluate ---
    # Assumes the Trainer saved a plain state_dict at <CHECKPOINT_DIR>/<model_name>_best.pt — TODO confirm.
    best_ckpt = CHECKPOINT_DIR / f'{model_name}_best.pt'
    model.load_state_dict(torch.load(best_ckpt, map_location=device, weights_only=True))
    model.to(device)
    
    # Standard evaluation (no TTA)
    print(f'\n  Evaluating on test set (no TTA)...')
    test_preds, test_targets, test_masks = predict(model, test_loader, device)
    metrics = compute_metrics(test_preds, test_targets, test_masks, schema)
    per_q = compute_per_question_metrics(test_preds, test_targets, test_masks, schema)
    
    print(f'  MSE={metrics["mse"]:.4f}  MAE={metrics["mae"]:.4f}  '
          f'R²={metrics["r2"]:.4f}  Acc={metrics.get("accuracy_mean", float("nan")):.4f}')
    
    # TTA evaluation
    print(f'\n  Evaluating with TTA...')
    tta_preds, tta_targets, tta_masks = predict_with_tta(
        model, test_df, IMAGE_DIR, schema, device,
        input_size=input_size, batch_size=batch_size, num_workers=NUM_WORKERS,
    )
    tta_metrics = compute_metrics(tta_preds, tta_targets, tta_masks, schema)
    tta_per_q = compute_per_question_metrics(tta_preds, tta_targets, tta_masks, schema)
    
    print(f'  MSE={tta_metrics["mse"]:.4f}  MAE={tta_metrics["mae"]:.4f}  '
          f'R²={tta_metrics["r2"]:.4f}  Acc={tta_metrics.get("accuracy_mean", float("nan")):.4f}')
    
    # --- 6. Bootstrap CIs ---
    # CIs are computed on the TTA predictions (the numbers reported in the comparison tables).
    print(f'\n  Computing bootstrap CIs (1000 iterations)...')
    mse_point, mse_lo, mse_hi = bootstrap_confidence_interval(
        tta_preds, tta_targets, tta_masks, mse_metric, n_iterations=1000,
    )
    r2_point, r2_lo, r2_hi = bootstrap_confidence_interval(
        tta_preds, tta_targets, tta_masks, r2_metric, n_iterations=1000,
    )
    print(f'  MSE (TTA): {mse_point:.4f}  [{mse_lo:.4f}, {mse_hi:.4f}]')
    print(f'  R²  (TTA): {r2_point:.4f}  [{r2_lo:.4f}, {r2_hi:.4f}]')
    
    # --- 7. Save results ---
    model_results = {
        'model_name': model_name,
        'architecture': model.__class__.__name__,
        'pretrained': 'ImageNet-21k' if model_name != 'dinov2' else 'LVD-142M (DINO)',
        'input_size': input_size,
        'params_total': params['total'],
        'params_trainable': params['trainable'],
        'training': {
            'batch_size': batch_size,
            'lr_probe': LR_PROBE,
            'lr_finetune': LR_FINETUNE,
            'linear_probe_epochs': LINEAR_PROBE_EPOCHS,
            'finetune_epochs': FINETUNE_EPOCHS,
            # Actual epochs run (may be fewer than planned because of early stopping)
            'total_epochs_trained': len(trainer.history),
            'best_epoch': best_epoch,
            'best_val_loss': best_loss,
            'training_time_min': t_total / 60,
        },
        'metrics_no_tta': metrics,
        'metrics_tta': tta_metrics,
        'per_question_tta': {q: dict(m) for q, m in tta_per_q.items()},
        'bootstrap_ci': {
            'mse': {'point': mse_point, 'ci_lower': mse_lo, 'ci_upper': mse_hi},
            'r2': {'point': r2_point, 'ci_lower': r2_lo, 'ci_upper': r2_hi},
        },
    }
    
    # Save JSON
    # default=float casts non-JSON-native values (e.g. NumPy scalars) to float.
    # NOTE(review): NumPy ints therefore become floats, and array values would raise TypeError.
    results_path = TABLES_DIR / f'{model_name}_results.json'
    with open(results_path, 'w') as f:
        json.dump(model_results, f, indent=2, default=float)
    print(f'  Results saved: {results_path}')
    
    # Save predictions
    # TTA predictions (not the single-view ones) are kept for post-hoc analysis.
    np.savez_compressed(
        TABLES_DIR / f'{model_name}_predictions.npz',
        predictions=tta_preds,
        targets=tta_targets,
        masks=tta_masks,
    )
    
    # Save per-question CSV
    # Questions without valid test labels are skipped, so the CSV can have fewer rows than the schema.
    pq_rows = []
    for question, qm in tta_per_q.items():
        if qm.get('n_valid', 0) == 0:
            continue
        pq_rows.append({
            'question': question,
            'n_valid': qm['n_valid'],
            'mse': qm['mse'],
            'mae': qm['mae'],
            'r2': qm.get('r2', np.nan),
            'pearson_r': qm.get('pearson_r', np.nan),
            'accuracy': qm.get('accuracy', np.nan),
            'f1_weighted': qm.get('f1_weighted', np.nan),
        })
    pd.DataFrame(pq_rows).to_csv(TABLES_DIR / f'{model_name}_per_question.csv', index=False)
    
    # Save training history
    history.to_csv(TABLES_DIR / f'{model_name}_history.csv', index=False)
    
    all_results[model_name] = model_results
    
    # --- 8. Cleanup GPU memory ---
    # Drop references to the model/trainer/loaders, then return cached CUDA blocks to the driver.
    del model, trainer, criterion
    del train_loader, val_loader, test_loader
    del test_preds, test_targets, test_masks
    del tta_preds, tta_targets, tta_masks
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    print(f'\n  {model_name} complete! GPU memory freed.')

print(f'\n{"=" * 70}')
print(f'ALL {len(MODELS_TO_TRAIN)} MODELS TRAINED SUCCESSFULLY')
print(f'{"=" * 70}')

## 6. Quick Comparison Table

Side-by-side comparison of all trained models. The full benchmarking analysis is in notebook 05.

In [None]:
def _comparison_row(display_name, res):
    """Flatten one results dict (TTA metrics) into a comparison-table row.

    Zoobot's JSON (notebook 03) and the ViT results share this layout, so the
    same helper serves both.
    """
    tta_metrics = res['metrics_tta']
    return {
        'Model': display_name,
        'Params (M)': res['params_total'] / 1e6,
        'MSE': tta_metrics['mse'],
        'MAE': tta_metrics['mae'],
        'R²': tta_metrics['r2'],
        'Pearson r': tta_metrics.get('pearson_r', np.nan),
        'Accuracy': tta_metrics.get('accuracy_mean', np.nan),
        'Train time (min)': res['training']['training_time_min'],
    }


comp_rows = []

# Zoobot baseline goes first, when its results file exists
zoobot_path = TABLES_DIR / 'zoobot_results.json'
if zoobot_path.exists():
    with open(zoobot_path) as f:
        zoobot_results = json.load(f)
    comp_rows.append(_comparison_row('Zoobot (EfficientNet-B0)', zoobot_results))

# Human-readable names for the models trained in this notebook
model_display_names = {
    'vit-base': 'ViT-Base/16',
    'swin-v2': 'Swin-V2-Base',
    'dinov2': 'DINOv2 ViT-B/14',
    'convnext': 'ConvNeXt-Base',
}

for model_name, results in all_results.items():
    comp_rows.append(_comparison_row(model_display_names.get(model_name, model_name), results))

comp_df = pd.DataFrame(comp_rows)
print(comp_df.to_string(index=False, float_format='{:.4f}'.format))

In [None]:
# Side-by-side horizontal bar charts of R² and MSE (TTA) for every row of comp_df.
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

models = comp_df['Model']
x = np.arange(len(models))
colors = plt.cm.Set2(np.linspace(0, 1, len(models)))

# (axis, comp_df column, x-axis label, title, horizontal offset of the value label)
panel_specs = [
    (axes[0], 'R²', 'R² (higher is better)', 'R² — All Models (with TTA)', 0.005),
    (axes[1], 'MSE', 'MSE (lower is better)', 'MSE — All Models (with TTA)', 0.0005),
]

for ax, column, xlabel, title, label_pad in panel_specs:
    bars = ax.barh(x, comp_df[column], color=colors, edgecolor='gray', alpha=0.8)
    ax.set_yticks(x)
    ax.set_yticklabels(models, fontsize=10)
    ax.set_xlabel(xlabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3, axis='x')
    # Write each value just past the end of its bar
    for bar, val in zip(bars, comp_df[column]):
        ax.text(bar.get_width() + label_pad, bar.get_y() + bar.get_height() / 2,
                f'{val:.4f}', va='center', fontsize=9)

fig.suptitle('Model Comparison — Test Set Metrics (with TTA)', fontsize=14)
plt.tight_layout()
plt.savefig(FIGURES_DIR / 'model_comparison_r2_mse.pdf', bbox_inches='tight')
plt.show()

## 7. Per-Question Comparison

Which model is best at each morphology question?

In [None]:
# Load per-question CSVs for all models (Zoobot baseline first, when available)
all_model_names = ['zoobot'] + MODELS_TO_TRAIN
all_display_names = {
    'zoobot': 'Zoobot',
    **model_display_names,
}

per_q_data = {}
for mn in all_model_names:
    csv_path = TABLES_DIR / f'{mn}_per_question.csv'
    if csv_path.exists():
        per_q_data[mn] = pd.read_csv(csv_path)

if len(per_q_data) > 1:
    # Union of questions across all CSVs, in first-seen order. Each CSV omits questions
    # without valid labels (and Zoobot's comes from another notebook), so row sets and
    # row order can differ between models; every model is aligned to this axis below.
    questions = list(dict.fromkeys(
        q for df in per_q_data.values() for q in df['question']
    ))

    fig, axes = plt.subplots(1, 2, figsize=(18, 7))

    x = np.arange(len(questions))
    width = 0.8 / len(per_q_data)
    model_colors = plt.cm.Set2(np.linspace(0, 1, len(per_q_data)))

    for i, (mn, df) in enumerate(per_q_data.items()):
        # Centre the group of bars for each question on its tick
        offset = (i - len(per_q_data)/2 + 0.5) * width
        # Row k now belongs to questions[k]; questions this model lacks become NaN (no bar)
        aligned = df.set_index('question').reindex(questions)
        label = all_display_names.get(mn, mn)

        # R² per question
        axes[0].barh(x + offset, aligned['r2'].to_numpy(), height=width,
                     label=label,
                     color=model_colors[i], alpha=0.8, edgecolor='gray')

        # Accuracy per question
        axes[1].barh(x + offset, aligned['accuracy'].to_numpy(), height=width,
                     label=label,
                     color=model_colors[i], alpha=0.8, edgecolor='gray')

    for ax, metric_name in zip(axes, ['R²', 'Accuracy']):
        ax.set_yticks(x)
        ax.set_yticklabels(questions, fontsize=10)
        ax.set_xlabel(metric_name)
        ax.set_title(f'{metric_name} per Morphology Question')
        ax.legend(loc='lower right', fontsize=9)
        ax.grid(True, alpha=0.3, axis='x')

    fig.suptitle('Per-Question Model Comparison (with TTA)', fontsize=14)
    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'per_question_model_comparison.pdf', bbox_inches='tight')
    plt.show()
else:
    print('Not enough models with per-question results for comparison.')

## 8. Training Dynamics Comparison

Overlay validation loss curves for all models to see convergence speed.

In [None]:
# Overlay every available model's validation-loss curve (models without a history CSV are skipped).
fig, ax = plt.subplots(figsize=(12, 6))

line_styles = ['-', '--', '-.', ':', '-']
# tab10 is a 10-colour qualitative map. Integer inputs index its colours directly, so each
# model gets a distinct colour (up to 10). Evenly spaced floats can collapse neighbouring
# models onto the same colour once more models are added.
model_colors = plt.cm.tab10(np.arange(len(all_model_names)) % 10)

for i, mn in enumerate(all_model_names):
    hist_path = TABLES_DIR / f'{mn}_history.csv'
    if not hist_path.exists():
        continue
    hist = pd.read_csv(hist_path)
    ax.plot(hist['epoch'], hist['val_loss'],
            linestyle=line_styles[i % len(line_styles)],
            color=model_colors[i],
            label=all_display_names.get(mn, mn),
            linewidth=2, alpha=0.8)

# Assumes every plotted model used LINEAR_PROBE_EPOCHS probe epochs, numbered from 1.
ax.axvline(LINEAR_PROBE_EPOCHS + 0.5, color='gray', linestyle=':', alpha=0.4,
           label='Probe → Fine-tune')
ax.set_xlabel('Epoch')
ax.set_ylabel('Validation Loss (DM)')
ax.set_title('Validation Loss Convergence — All Models')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(FIGURES_DIR / 'all_models_val_loss.pdf', bbox_inches='tight')
plt.show()

## 9. Summary

In [None]:
# Final console summary of every model trained in this session (reads all_results only;
# the Zoobot baseline is not included here).
print('=' * 70)
print('NOTEBOOK 04 — ViT EXPERIMENTS SUMMARY')
print('=' * 70)

for model_name, results in all_results.items():
    tta = results['metrics_tta']
    # Bracketed ranges below are the bootstrap CIs computed on the TTA predictions.
    ci = results['bootstrap_ci']
    t = results['training']
    print(f'''
  {model_display_names.get(model_name, model_name)}
    Parameters:   {results['params_total']:,}
    Input size:   {results['input_size']}×{results['input_size']}
    Best epoch:   {t['best_epoch']} / {t['total_epochs_trained']}
    Train time:   {t['training_time_min']:.1f} min
    MSE (TTA):    {tta['mse']:.4f}  [{ci['mse']['ci_lower']:.4f}, {ci['mse']['ci_upper']:.4f}]
    R² (TTA):     {tta['r2']:.4f}  [{ci['r2']['ci_lower']:.4f}, {ci['r2']['ci_upper']:.4f}]
    Pearson r:    {tta.get('pearson_r', float('nan')):.4f}
    Accuracy:     {tta.get('accuracy_mean', float('nan')):.4f}''')

# Footer: where artefacts were written and what the next notebook does.
print(f'''
{'=' * 70}
All results saved in: {TABLES_DIR}/
  - <model>_results.json       (full metrics + CIs)
  - <model>_predictions.npz    (raw predictions for post-hoc analysis)
  - <model>_per_question.csv   (per-question breakdown)
  - <model>_history.csv        (training curves)

Next: notebooks/05_benchmarking.ipynb
  - Aggregate comparison tables (LaTeX-ready)
  - Statistical significance tests (paired bootstrap)
  - Inference speed benchmarking
{'=' * 70}''')