# 03 — Zoobot Baseline: EfficientNet-B0 Reproduction

**Goal**: Train and evaluate the Zoobot-style EfficientNet-B0 baseline on Euclid Q1 data.

This notebook:
1. Loads the data pipeline (splits, transforms, DataLoaders)
2. Creates the Zoobot model (EfficientNet-B0 + regression head)
3. Trains with Dirichlet-Multinomial loss (matching Zoobot)
4. Two-phase training: linear probe (5 epochs) → full fine-tune (25 epochs)
5. Evaluates with per-question metrics and TTA
6. Saves results for later model comparison (notebook 05)

**Runtime**: ~20 min on T4, ~10 min on A100

## 0. Colab Setup

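This cell is only needed on **Google Colab** — it clones the repo (dev branch), installs dependencies, downloads the catalog and images (~3.8 GB), and generates the train/val/test splits. It detects Colab automatically and does nothing when run locally, so it is safe to run in either environment.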

In [None]:
import os

# Colab bootstrap: clone the repo, install dependencies and fetch data.
# Colab is detected via the COLAB_GPU environment variable or the string form
# of the IPython shell object containing 'google.colab'; otherwise the whole
# cell only prints a skip message.
IN_COLAB = 'COLAB_GPU' in os.environ or 'google.colab' in str(get_ipython())

if IN_COLAB:
    # 1. Clone the repo (dev branch has all source code)
    REPO_URL = 'https://github.com/Smooth-Cactus0/euclid-q1-vit-morphology.git'
    REPO_DIR = '/content/euclid-q1-vit-morphology'

    # Clone only once; re-running the cell reuses the existing checkout.
    # `!` is an IPython shell escape; {REPO_URL}/{REPO_DIR} are interpolated.
    if not os.path.exists(REPO_DIR):
        print('Cloning repository (dev branch)...')
        !git clone --branch dev {REPO_URL} {REPO_DIR}
    # chdir so that the relative data/ and scripts/ paths below resolve
    # against the repository root.
    os.chdir(REPO_DIR)
    print(f'Working directory: {os.getcwd()}')

    # 2. Install dependencies
    # NOTE(review): unpinned install via `!pip`; `%pip install pkg==x.y` would
    # target the kernel's environment and make runs reproducible.
    print('\nInstalling dependencies...')
    !pip install -q timm tqdm scipy

    # 3. Download catalog (skipped if the parquet file already exists).
    # These string paths are relative to the repo root; the Configuration
    # cell later rebinds CATALOG_PATH / SPLIT_PATH / IMAGE_DIR as Path objects.
    CATALOG_PATH = 'data/raw/morphology_catalogue.parquet'
    if not os.path.exists(CATALOG_PATH):
        print('\nDownloading catalog...')
        !python scripts/download_catalog.py

    # 4. Generate splits (skipped if the split index file already exists)
    SPLIT_PATH = 'data/processed/split_indices.json'
    if not os.path.exists(SPLIT_PATH):
        print('\nGenerating train/val/test splits...')
        !python scripts/prepare_splits.py

    # 5. Download and extract images (~3.8 GB, takes ~5 min)
    # A missing directory, or one with fewer than 10 entries, triggers a
    # (re-)download — presumably a heuristic for an empty/partial extraction.
    IMAGE_DIR = 'data/raw/images'
    if not os.path.exists(IMAGE_DIR) or len(os.listdir(IMAGE_DIR)) < 10:
        print('\nDownloading images (~3.8 GB)... This takes ~5 minutes.')
        !python scripts/download_images.py
    else:
        # Count only sub-directories (one per tile) for the status message.
        n_tiles = len([d for d in os.listdir(IMAGE_DIR) if os.path.isdir(os.path.join(IMAGE_DIR, d))])
        print(f'\nImages already present: {n_tiles} tiles')

    print('\nColab setup complete!')
else:
    print('Not running on Colab — skipping setup.')

## 1. Setup

In [None]:
import sys
from pathlib import Path
import os

# Set project root — Colab runs from repo root, local runs from notebooks/.
# Colab detection is repeated here (instead of reusing IN_COLAB) because the
# Colab setup cell may be skipped when running locally.
if 'COLAB_GPU' in os.environ or 'google.colab' in str(get_ipython()):
    PROJECT_ROOT = Path('/content/euclid-q1-vit-morphology')
else:
    PROJECT_ROOT = Path('..').resolve()

# Ensure project root is on sys.path (remove stale entries first so that
# re-running this cell does not accumulate duplicates)
project_str = str(PROJECT_ROOT)
sys.path = [p for p in sys.path if p != project_str]
sys.path.insert(0, project_str)

print(f'PROJECT_ROOT: {PROJECT_ROOT}')
print(f'src/ exists: {(PROJECT_ROOT / "src").exists()}')
print(f'src/data/dataset.py exists: {(PROJECT_ROOT / "src" / "data" / "dataset.py").exists()}')

import json
import random
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

from src.data.dataset import EuclidDataset, MorphologySchema
from src.data.transforms import get_transforms, get_tta_transforms
from src.models.factory import create_model
from src.training.losses import DirichletMultinomialLoss
from src.training.trainer import Trainer, TrainConfig
from src.evaluation.metrics import (
    compute_metrics,
    compute_per_question_metrics,
    bootstrap_confidence_interval,
)

import warnings
warnings.filterwarnings('ignore')

plt.rcParams['figure.dpi'] = 120

# Reproducibility: seed every RNG the pipeline may draw from (Python's
# `random`, NumPy, torch CPU and CUDA) and force deterministic cuDNN kernels.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'\nDevice: {device}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name()}')
    # The device-properties attribute is `total_memory` (in bytes); there is
    # no `total_mem` attribute, which raised AttributeError on GPU runtimes.
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')

## 2. Configuration

In [None]:
# Paths
CATALOG_PATH = PROJECT_ROOT / 'data' / 'raw' / 'morphology_catalogue.parquet'
SPLIT_PATH = PROJECT_ROOT / 'data' / 'processed' / 'split_indices.json'
IMAGE_DIR = PROJECT_ROOT / 'data' / 'raw' / 'images'
CHECKPOINT_DIR = PROJECT_ROOT / 'results' / 'checkpoints'
FIGURES_DIR = PROJECT_ROOT / 'results' / 'figures'
TABLES_DIR = PROJECT_ROOT / 'results' / 'tables'

# Output directories are created up front so later cells can write freely
for output_dir in (CHECKPOINT_DIR, FIGURES_DIR, TABLES_DIR):
    output_dir.mkdir(parents=True, exist_ok=True)

# Training hyperparameters (from configs/base.yaml)
MODEL_NAME = 'zoobot'
INPUT_SIZE = 224
BATCH_SIZE = 32          # Default; overridden below from the detected GPU memory
LR_PROBE = 1e-3          # Linear probe learning rate
LR_FINETUNE = 5e-5       # Full fine-tune learning rate
WEIGHT_DECAY = 0.01
LINEAR_PROBE_EPOCHS = 5
FINETUNE_EPOCHS = 25
PATIENCE = 5
NUM_WORKERS = 4

# Adjust batch size for GPU memory.
# `total_memory` (bytes) is the attribute exposed by
# torch.cuda.get_device_properties; `total_mem` does not exist and raised
# AttributeError on every GPU runtime.
if torch.cuda.is_available():
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    if gpu_mem > 30:       # A100
        BATCH_SIZE = 64
    elif gpu_mem > 14:     # V100 / T4
        BATCH_SIZE = 32
    else:
        BATCH_SIZE = 16

print(f'Model: {MODEL_NAME}')
print(f'Batch size: {BATCH_SIZE}')
print(f'Training: {LINEAR_PROBE_EPOCHS} probe + {FINETUNE_EPOCHS} fine-tune epochs')

## 3. Data Loading

In [None]:
# Morphology schema: defines the output layout (questions and their answers)
schema = MorphologySchema.default()
print(f'Schema: {schema.num_outputs} outputs across {len(schema.questions)} questions')

# Load the three splits listed in split_indices.json, in train/val/test order
train_df, val_df, test_df = (
    EuclidDataset.load_split(CATALOG_PATH, SPLIT_PATH, split_name)
    for split_name in ('train', 'val', 'test')
)

for split_label, split_frame in (
    ('\nTrain: ', train_df),
    ('Val:   ', val_df),
    ('Test:  ', test_df),
):
    print(f'{split_label}{len(split_frame):,}')

In [None]:
# Augmentation config (matching configs/base.yaml)
augmentation_cfg = dict(
    random_horizontal_flip=True,
    random_vertical_flip=True,
    random_rotation=360,
    color_jitter=dict(brightness=0.1, contrast=0.1, saturation=0.0, hue=0.0),
    random_resized_crop=dict(scale=[0.85, 1.0], ratio=[0.95, 1.05]),
)

train_transform = get_transforms('train', INPUT_SIZE, augmentation_cfg=augmentation_cfg)
val_transform = get_transforms('val', INPUT_SIZE)

# Datasets — validation and test share the 'val' (evaluation) transform
train_ds, val_ds, test_ds = (
    EuclidDataset(train_df, IMAGE_DIR, schema, train_transform),
    EuclidDataset(val_df, IMAGE_DIR, schema, val_transform),
    EuclidDataset(test_df, IMAGE_DIR, schema, val_transform),
)

# DataLoaders: only training shuffles and drops the last partial batch
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = DataLoader(train_ds, shuffle=True, drop_last=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)

for split_label, split_ds, split_loader in (
    ('Train: ', train_ds, train_loader),
    ('Val:   ', val_ds, val_loader),
    ('Test:  ', test_ds, test_loader),
):
    print(f'{split_label}{len(split_ds):,} samples ({len(split_loader)} batches)')

## 4. Model

In [None]:
model = create_model(MODEL_NAME, num_outputs=schema.num_outputs, pretrained=True)
params = model.count_parameters()

print(f'Model: {MODEL_NAME} (EfficientNet-B0)')
print(f'Parameters: {params["total"]:,} total, {params["trainable"]:,} trainable')
print(f'Output dim: {schema.num_outputs}')

# Quick sanity check — forward pass with random input.
# Run it in eval mode: in train mode, BatchNorm layers update their running
# statistics even under torch.no_grad(), so a random-noise batch would
# perturb the pretrained statistics before training starts. The previous
# train/eval mode is restored afterwards.
was_training = model.training
model.eval()
with torch.no_grad():
    dummy = torch.randn(2, 3, INPUT_SIZE, INPUT_SIZE)
    out = model(dummy)
    print(f'\nForward pass OK: input {dummy.shape} → output {out.shape}')
model.train(was_training)

# Downstream code slices outputs per question along dim 1, so expect
# (batch, num_outputs)
assert out.shape == (2, schema.num_outputs), f'unexpected output shape {tuple(out.shape)}'

## 5. Training

Two-phase training strategy:
- **Phase 1 (Linear Probe)**: Freeze the EfficientNet backbone, train only the new regression head for 5 epochs at lr=1e-3. This quickly establishes a reasonable mapping from pretrained features to vote fractions.
- **Phase 2 (Full Fine-tune)**: Unfreeze all layers, train end-to-end for 25 epochs at lr=5e-5 with cosine annealing. Early stopping (patience=5) prevents overfitting.

In [None]:
# Dirichlet-Multinomial loss, as used by Zoobot
criterion = DirichletMultinomialLoss(schema)

# Two-phase schedule: LINEAR_PROBE_EPOCHS of probe training followed by
# FINETUNE_EPOCHS of full fine-tuning, with early stopping on `patience`
train_config = TrainConfig(
    seed=SEED,
    checkpoint_dir=str(CHECKPOINT_DIR),
    batch_size=BATCH_SIZE,
    epochs=LINEAR_PROBE_EPOCHS + FINETUNE_EPOCHS,
    linear_probe_epochs=LINEAR_PROBE_EPOCHS,
    full_finetune_epochs=FINETUNE_EPOCHS,
    lr_linear_probe=LR_PROBE,
    lr=LR_FINETUNE,
    weight_decay=WEIGHT_DECAY,
    warmup_fraction=0.05,
    patience=PATIENCE,
)

trainer = Trainer(
    model=model,
    criterion=criterion,
    config=train_config,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    model_name=MODEL_NAME,
)

print('Starting training...')
t_start = time.time()
summary = trainer.train()
t_total = time.time() - t_start

print(f'\nTotal training time: {t_total / 60:.1f} min')
print(
    f'Best val_loss: {summary["best_val_loss"]:.4f} '
    f'(epoch {summary["best_epoch"]})'
)

## 6. Training Curves

In [None]:
from matplotlib.patches import Patch

history = pd.DataFrame(trainer.history)
best_epoch = summary['best_epoch']
best_loss = summary['best_val_loss']

# Probe → fine-tune boundary, drawn half an epoch after the last probe epoch;
# None when the history contains no probe rows.
probe_epochs = history[history['phase'] == 'probe']
boundary = probe_epochs['epoch'].max() + 0.5 if len(probe_epochs) > 0 else None

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Loss curves
ax = axes[0]
ax.plot(history['epoch'], history['train_loss'], 'b-', label='Train', alpha=0.8)
ax.plot(history['epoch'], history['val_loss'], 'r-', label='Val', alpha=0.8)
ax.axvline(best_epoch, color='gray', linestyle='--', alpha=0.5, label=f'Best (epoch {best_epoch})')
ax.scatter([best_epoch], [best_loss], color='red', zorder=5, s=50)
if boundary is not None:
    ax.axvline(boundary, color='green', linestyle=':', alpha=0.5, label='Phase boundary')
ax.set_xlabel('Epoch')
ax.set_ylabel('Dirichlet-Multinomial Loss')
ax.set_title('Training & Validation Loss')
ax.legend()
ax.grid(True, alpha=0.3)

# Learning rate schedule
ax = axes[1]
ax.plot(history['epoch'], history['lr'], 'g-', linewidth=2)
if boundary is not None:
    ax.axvline(boundary, color='green', linestyle=':', alpha=0.5)
ax.set_xlabel('Epoch')
ax.set_ylabel('Learning Rate')
ax.set_title('LR Schedule')
ax.set_yscale('log')
ax.grid(True, alpha=0.3)

# Epoch timing, colour-coded by phase. The legend maps the colours to phases;
# without it the two colours are unexplained.
ax = axes[2]
colors = ['steelblue' if p == 'probe' else 'coral' for p in history['phase']]
ax.bar(history['epoch'], history['time_s'], color=colors, alpha=0.7)
ax.legend(handles=[
    Patch(facecolor='steelblue', alpha=0.7, label='Linear probe'),
    Patch(facecolor='coral', alpha=0.7, label='Full fine-tune'),
])
ax.set_xlabel('Epoch')
ax.set_ylabel('Time (s)')
ax.set_title('Epoch Duration')
ax.grid(True, alpha=0.3, axis='y')

fig.suptitle(f'Zoobot (EfficientNet-B0) Training — Best val_loss: {best_loss:.4f}', fontsize=14)
plt.tight_layout()
plt.savefig(FIGURES_DIR / 'zoobot_training_curves.pdf', bbox_inches='tight')
plt.show()

## 7. Test Set Evaluation

Load the best checkpoint and evaluate on the held-out test set.

In [None]:
# Restore the weights saved for the best validation epoch, then switch to
# inference mode.
# NOTE(review): assumes the Trainer writes a bare state_dict to this path — confirm.
best_ckpt = CHECKPOINT_DIR / f'{MODEL_NAME}_best.pt'
best_state = torch.load(best_ckpt, map_location=device, weights_only=True)
model.load_state_dict(best_state)
del best_state  # weights are now copied into the model; drop the extra reference
model.to(device)
model.eval()
print(f'Loaded best checkpoint: {best_ckpt}')

In [None]:
@torch.no_grad()
def predict(model, loader, device, apply_activation=True, schema=None):
    """Run inference and collect predictions, targets and masks.

    Args:
        model: callable mapping an image batch to a 2-D (batch, outputs) tensor.
        loader: iterable yielding ``(images, targets, masks)`` tensor batches.
        device: device the images are moved to before the forward pass.
        apply_activation: if True, apply a softmax within each question's
            slice of the output so each question's answers sum to 1; output
            columns not covered by any question stay 0. If False, return the
            raw model outputs.
        schema: object exposing ``question_slices`` (question -> (start, end)).
            Defaults to ``MorphologySchema.default()``; it is built once per
            call rather than once per batch.

    Returns:
        ``(preds, targets, masks)`` NumPy arrays concatenated over all batches.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    # NOTE(review): the model is trained with DirichletMultinomialLoss; if its
    # outputs are Dirichlet concentrations rather than logits, the expected
    # fractions would be alpha / alpha.sum() per question instead of a
    # softmax — confirm against src.training.losses.
    if apply_activation and schema is None:
        schema = MorphologySchema.default()
    slices = list(schema.question_slices.values()) if apply_activation else []

    all_preds, all_targets, all_masks = [], [], []

    for images, targets, masks in loader:
        logits = model(images.to(device))

        if apply_activation:
            # For evaluation: convert logits to per-question vote fractions
            preds = torch.zeros_like(logits)
            for start, end in slices:
                preds[:, start:end] = torch.softmax(logits[:, start:end], dim=1)
        else:
            preds = logits

        all_preds.append(preds.cpu().numpy())
        all_targets.append(targets.numpy())
        all_masks.append(masks.numpy())

    if not all_preds:
        raise ValueError('predict(): loader yielded no batches')

    return (
        np.concatenate(all_preds),
        np.concatenate(all_targets),
        np.concatenate(all_masks),
    )


# Standard evaluation (no TTA): one pass over the test loader
print('Running inference on test set...')
test_preds, test_targets, test_masks = predict(model, test_loader, device)

for shape_label, result_array in (
    ('Predictions shape: ', test_preds),
    ('Targets shape:     ', test_targets),
    ('Masks shape:       ', test_masks),
):
    print(f'{shape_label}{result_array.shape}')

In [None]:
# Aggregate metrics over the whole test set (no TTA)
metrics = compute_metrics(test_preds, test_targets, test_masks, schema)

rule = '=' * 60
print(rule)
print('ZOOBOT BASELINE — TEST SET METRICS (no TTA)')
print(rule)
# p-value entries (suffix '_p') are omitted from the display
for name, value in sorted(metrics.items()):
    if not name.endswith('_p'):
        print(f'  {name:25s}: {value:.4f}')
print(rule)

In [None]:
# Per-question metrics (no TTA), printed as a fixed-width table
per_q = compute_per_question_metrics(test_preds, test_targets, test_masks, schema)

header_cols = ('N', 'MSE', 'MAE', 'R²', 'Pearson', 'Acc')
print('\n' + f'{"Question":30s} ' + ' '.join(f'{col:>8s}' for col in header_cols))
print('-' * 90)
for question, qm in per_q.items():
    n_valid = qm.get('n_valid', 0)
    if n_valid == 0:
        print(f'{question:30s} {0:>8d}   (no valid samples)')
    else:
        cells = [
            f'{n_valid:>8,d}',
            f'{qm["mse"]:>8.4f}',
            f'{qm["mae"]:>8.4f}',
        ] + [
            f'{qm.get(key, float("nan")):>8.4f}'
            for key in ('r2', 'pearson_r', 'accuracy')
        ]
        print(f'{question:30s} ' + ' '.join(cells))

## 8. Test-Time Augmentation (TTA)

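Apply 7 geometric views (original + flips + rotations, as returned by `get_tta_transforms()`) and average the per-question predicted vote fractions across views.
TTA requires no additional training — its only cost is one extra inference pass over the test set per view.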

In [None]:
@torch.no_grad()
def predict_with_tta(model, dataset_df, image_dir, schema, device, batch_size=32, num_workers=4):
    """Average per-question softmax predictions over every TTA view.

    For each transform from ``get_tta_transforms()`` a fresh dataset and
    (unshuffled) loader are built over ``dataset_df``, the model is run on
    every batch, and outputs are converted to vote fractions with a softmax
    inside each question slice of ``schema.question_slices``. The per-view
    prediction arrays are then averaged element-wise.

    Returns:
        ``(tta_preds, targets, masks)`` as NumPy arrays; targets and masks are
        read from a dataset built on the same dataframe, so row order matches.
    """
    views = get_tta_transforms()
    n_views = len(views)
    print(f'TTA with {n_views} views')

    def _fractions(raw):
        # Softmax within each question's slice; uncovered columns stay 0
        out = torch.zeros_like(raw)
        for lo, hi in schema.question_slices.values():
            out[:, lo:hi] = torch.softmax(raw[:, lo:hi], dim=1)
        return out

    per_view = []
    for view_no, view_tfm in enumerate(views, start=1):
        view_loader = DataLoader(
            EuclidDataset(dataset_df, image_dir, schema, view_tfm),
            batch_size=batch_size, shuffle=False,
            num_workers=num_workers, pin_memory=True,
        )
        batch_fracs = [
            _fractions(model(imgs.to(device))).cpu().numpy()
            for imgs, _, _ in view_loader
        ]
        per_view.append(np.concatenate(batch_fracs))
        print(f'  View {view_no}/{n_views} done')

    # Element-wise mean across views
    tta_preds = np.mean(per_view, axis=0)

    # Targets/masks do not depend on the transform; read them from one dataset
    reference = EuclidDataset(dataset_df, image_dir, schema, get_transforms('val'))
    return tta_preds, reference.targets.numpy(), reference.masks.numpy()


print('Running TTA on test set...')
t0 = time.time()
tta_preds, tta_targets, tta_masks = predict_with_tta(
    model=model,
    dataset_df=test_df,
    image_dir=IMAGE_DIR,
    schema=schema,
    device=device,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)
tta_time = time.time() - t0
print(f'\nTTA inference time: {tta_time:.1f}s')

In [None]:
# TTA metrics, each shown with its change relative to the no-TTA run
tta_metrics = compute_metrics(tta_preds, tta_targets, tta_masks, schema)

# Metrics that get a direction arrow (↑ increased / ↓ decreased vs. no TTA)
arrow_metrics = {
    'r2', 'pearson_r', 'spearman_r', 'accuracy_mean', 'f1_weighted_mean',
    'mse', 'mae',
}

print('=' * 60)
print('ZOOBOT BASELINE — TEST SET METRICS (with TTA)')
print('=' * 60)
for name, value in sorted(tta_metrics.items()):
    if name.endswith('_p'):
        continue
    no_tta = metrics.get(name, float('nan'))
    delta = float('nan') if np.isnan(no_tta) else value - no_tta
    if name not in arrow_metrics:
        arrow = ''
    elif delta > 0:
        arrow = '↑'
    elif delta < 0:
        arrow = '↓'
    else:
        arrow = ''
    print(f'  {name:25s}: {value:.4f}  (Δ={delta:+.4f} {arrow})')
print('=' * 60)

## 9. Bootstrap Confidence Intervals

In [None]:
from sklearn.metrics import mean_squared_error, r2_score

def _masked_flat(preds, targets, masks):
    """Flatten the arrays and keep only entries whose mask is non-zero.

    Returns ``(y_true, y_pred)`` as 1-D arrays.
    """
    keep = masks.astype(bool).flatten()
    return targets.flatten()[keep], preds.flatten()[keep]


def mse_metric(preds, targets, masks):
    """Mean squared error over every unmasked (sample, output) entry."""
    y_true, y_pred = _masked_flat(preds, targets, masks)
    return mean_squared_error(y_true, y_pred)


def r2_metric(preds, targets, masks):
    """R² over every unmasked (sample, output) entry, pooled across questions."""
    y_true, y_pred = _masked_flat(preds, targets, masks)
    return r2_score(y_true, y_pred)

print('Computing bootstrap CIs (1000 iterations)...')
print('This may take a minute.\n')

# Both confidence intervals are computed on the TTA predictions
tta_arrays = (tta_preds, tta_targets, tta_masks)

mse_point, mse_lo, mse_hi = bootstrap_confidence_interval(*tta_arrays, mse_metric, n_iterations=1000)
print(f'MSE (TTA):  {mse_point:.4f}  [{mse_lo:.4f}, {mse_hi:.4f}]')

r2_point, r2_lo, r2_hi = bootstrap_confidence_interval(*tta_arrays, r2_metric, n_iterations=1000)
print(f'R² (TTA):   {r2_point:.4f}  [{r2_lo:.4f}, {r2_hi:.4f}]')

## 10. Predicted vs True Vote Fractions

In [None]:
# Scatter plots: predicted vs true for top-level questions
from scipy import stats

key_questions = ['smooth-or-featured', 'merging', 'disk-edge-on', 'has-spiral-arms']

fig, axes = plt.subplots(2, 2, figsize=(12, 12))

for ax, question in zip(axes.flat, key_questions):
    start, _ = schema.question_slices[question]
    # Rows are included when the mask of the question's first answer is set
    q_mask = tta_masks[:, start] > 0
    n_points = int(q_mask.sum())

    if n_points == 0:
        ax.set_title(f'{question} (no valid samples)')
        continue

    # Use the first answer column as the main fraction to plot
    pred_frac = tta_preds[q_mask, start]
    true_frac = tta_targets[q_mask, start]
    answer_name = schema.questions[question][0]

    ax.scatter(true_frac, pred_frac, alpha=0.05, s=1, color='steelblue')
    ax.plot([0, 1], [0, 1], 'r--', linewidth=1, alpha=0.8)

    # Pearson r is undefined for fewer than two points (scipy raises) or for
    # constant input; report NaN instead of aborting the whole figure.
    if n_points >= 2 and np.std(true_frac) > 0 and np.std(pred_frac) > 0:
        r, _ = stats.pearsonr(true_frac, pred_frac)
    else:
        r = float('nan')

    ax.set_xlabel(f'True {answer_name} fraction')
    ax.set_ylabel(f'Predicted {answer_name} fraction')
    ax.set_title(f'{question} (r={r:.3f}, N={n_points:,})')
    ax.set_xlim(-0.02, 1.02)
    ax.set_ylim(-0.02, 1.02)
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)

fig.suptitle('Zoobot Baseline: Predicted vs True Vote Fractions (TTA)', fontsize=14)
plt.tight_layout()
plt.savefig(FIGURES_DIR / 'zoobot_pred_vs_true.pdf', bbox_inches='tight')
plt.show()

## 11. Per-Question Performance Summary

In [None]:
# Per-question metrics (with TTA)
tta_per_q = compute_per_question_metrics(tta_preds, tta_targets, tta_masks, schema)

# One row per question that has at least one valid sample; optional metrics
# fall back to NaN when absent
rows = [
    {
        'question': question,
        'n_valid': qm['n_valid'],
        'mse': qm['mse'],
        'mae': qm['mae'],
        **{key: qm.get(key, np.nan) for key in ('r2', 'pearson_r', 'accuracy', 'f1_weighted')},
    }
    for question, qm in tta_per_q.items()
    if qm.get('n_valid', 0) != 0
]

results_df = pd.DataFrame(rows)
print(results_df.to_string(index=False, float_format='{:.4f}'.format))

In [None]:
# Bar chart: R² and accuracy per question

def _plot_metric_barh(ax, labels, values, colors, xlabel, title):
    """Draw one per-question metric as value-labelled horizontal bars on ``ax``.

    The x-axis spans [0, 1] but extends to the left when any value is negative
    (R² is negative when predictions are worse than predicting the mean), and
    value labels sit just right of max(value, 0) so negative bars stay
    readable. NaN values get no text label.
    """
    vals = np.asarray(values, dtype=float)
    y = np.arange(len(vals))
    bars = ax.barh(y, vals, color=colors, edgecolor='gray', alpha=0.8)
    ax.set_yticks(y)
    ax.set_yticklabels(labels, fontsize=10)
    ax.set_xlabel(xlabel)
    ax.set_title(title)
    finite = vals[np.isfinite(vals)]
    x_min = min(0.0, float(finite.min()) - 0.1) if finite.size else 0.0
    ax.set_xlim(x_min, 1)
    ax.grid(True, alpha=0.3, axis='x')
    for bar, val in zip(bars, vals):
        if not np.isfinite(val):
            continue
        ax.text(max(val, 0.0) + 0.02, bar.get_y() + bar.get_height() / 2,
                f'{val:.3f}', va='center', fontsize=9)


fig, axes = plt.subplots(1, 2, figsize=(16, 6))

questions = results_df['question']
colors = plt.cm.Set2(np.linspace(0, 1, len(questions)))

_plot_metric_barh(axes[0], questions, results_df['r2'], colors,
                  'R²', 'R² per Morphology Question')
_plot_metric_barh(axes[1], questions, results_df['accuracy'], colors,
                  'Accuracy (argmax)', 'Classification Accuracy per Question')

fig.suptitle('Zoobot Baseline: Per-Question Performance (with TTA)', fontsize=14)
plt.tight_layout()
plt.savefig(FIGURES_DIR / 'zoobot_per_question_metrics.pdf', bbox_inches='tight')
plt.show()

## 12. Save Results

Save all results for later comparison with ViT models in notebook 05.

In [None]:
# Save comprehensive results

def _json_default(obj):
    """Fallback serialiser for values ``json`` cannot encode natively.

    NumPy integers and bools keep their type (the previous ``default=float``
    turned e.g. NumPy integer counts into ``1234.0``), NumPy arrays become
    lists, and paths become strings. Anything else is coerced with ``float``,
    as before.
    """
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, Path):
        return str(obj)
    return float(obj)


results = {
    'model_name': MODEL_NAME,
    'architecture': 'EfficientNet-B0',
    'pretrained': 'ImageNet',
    'params_total': params['total'],
    'params_trainable': params['trainable'],
    'training': {
        'batch_size': BATCH_SIZE,
        'lr_probe': LR_PROBE,
        'lr_finetune': LR_FINETUNE,
        'linear_probe_epochs': LINEAR_PROBE_EPOCHS,
        'finetune_epochs': FINETUNE_EPOCHS,
        'total_epochs_trained': len(trainer.history),
        'best_epoch': summary['best_epoch'],
        'best_val_loss': summary['best_val_loss'],
        'training_time_min': t_total / 60,
    },
    'metrics_no_tta': metrics,
    'metrics_tta': tta_metrics,
    'per_question_tta': {q: dict(m) for q, m in tta_per_q.items()},
    'bootstrap_ci': {
        'mse': {'point': mse_point, 'ci_lower': mse_lo, 'ci_upper': mse_hi},
        'r2': {'point': r2_point, 'ci_lower': r2_lo, 'ci_upper': r2_hi},
    },
}

# Save JSON
results_path = TABLES_DIR / f'{MODEL_NAME}_results.json'
with open(results_path, 'w') as f:
    json.dump(results, f, indent=2, default=_json_default)
print(f'Results saved: {results_path}')

# Save TTA predictions, targets and masks for post-hoc analysis
predictions_path = TABLES_DIR / f'{MODEL_NAME}_predictions.npz'
np.savez_compressed(
    predictions_path,
    predictions=tta_preds,
    targets=tta_targets,
    masks=tta_masks,
)
print(f'Predictions saved: {predictions_path}')

In [None]:
# Per-question results table (CSV for LaTeX) and per-epoch training history
for frame, file_suffix, message in (
    (results_df, 'per_question', 'Per-question CSV saved'),
    (history, 'history', 'Training history saved'),
):
    csv_path = TABLES_DIR / f'{MODEL_NAME}_{file_suffix}.csv'
    frame.to_csv(csv_path, index=False)
    print(f'{message}: {csv_path}')

## 13. Summary

In [None]:
# Final recap of the run: model size, training length, and the no-TTA vs.
# TTA headline metrics (TTA MSE/R² shown with their bootstrap CIs).
print('=' * 60)
print('ZOOBOT BASELINE — FINAL SUMMARY')
print('=' * 60)
# The fine-tune epoch count is derived from the recorded history length, so
# it reflects the epochs actually trained (presumably fewer than
# FINETUNE_EPOCHS when early stopping triggers).
# NOTE(review): mse/mae/r2 are indexed directly and raise KeyError if
# compute_metrics omits them, whereas pearson_r/accuracy_mean fall back to NaN.
print(f'''
  Model:          {MODEL_NAME} (EfficientNet-B0)
  Parameters:     {params["total"]:,} total, {params["trainable"]:,} trainable
  Training:       {LINEAR_PROBE_EPOCHS} probe + {len(trainer.history) - LINEAR_PROBE_EPOCHS} fine-tune epochs
  Best epoch:     {summary["best_epoch"]}
  Training time:  {t_total/60:.1f} min

  --- Metrics (no TTA) ---
  MSE:            {metrics["mse"]:.4f}
  MAE:            {metrics["mae"]:.4f}
  R²:             {metrics["r2"]:.4f}
  Pearson r:      {metrics.get("pearson_r", float("nan")):.4f}
  Accuracy:       {metrics.get("accuracy_mean", float("nan")):.4f}

  --- Metrics (with TTA) ---
  MSE:            {tta_metrics["mse"]:.4f}  [{mse_lo:.4f}, {mse_hi:.4f}]
  MAE:            {tta_metrics["mae"]:.4f}
  R²:             {tta_metrics["r2"]:.4f}  [{r2_lo:.4f}, {r2_hi:.4f}]
  Pearson r:      {tta_metrics.get("pearson_r", float("nan")):.4f}
  Accuracy:       {tta_metrics.get("accuracy_mean", float("nan")):.4f}
''')
print('=' * 60)
# Pointer to the next notebook in the series
print('\nNext: notebooks/04_vit_experiments.ipynb')
print('  Train ViT-Base, Swin-V2, DINOv2, and ConvNeXt on the same data.')