In [1]:
# Kernel/env quick check: Python path and rich availability
import sys
print("ENV_PY:", sys.executable)
try:
    import rich  # noqa: F401
    print("RICH_OK")
except Exception:
    print("RICH_MISSING -> installing to current kernel...")
    # Install into the active kernel if needed
    %pip install -q rich
    import rich  # noqa: F401
    print("RICH_INSTALLED")

ENV_PY: /Users/javier/Documents/202508_DeepL_lab1_B/.venv/bin/python
RICH_OK


# Enhanced PyTorch Image Classifier (Notebook)
Interactive version of `train_enhanced.py` with MPS (Apple Silicon) support, rich progress, and evaluation.

- Configure parameters in the Config cell.
- Run the Training cell to start.
- Results and artifacts go to `runs/<variant>/`.

In [2]:
# Optional: install requirements into the current kernel environment
# You can skip if your environment already has these.
# !pip install -r requirements.txt
pass

In [3]:
# Imports & setup
import os, json, time, warnings
from pathlib import Path
from typing import Dict, Tuple, Optional, List
from datetime import timedelta
from collections import deque

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
from torchvision import datasets, transforms

import timm
from timm.data import Mixup, create_transform
from timm.loss import SoftTargetCrossEntropy
from timm.utils import ModelEmaV2

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

from tqdm import tqdm
from rich.console import Console
from rich.panel import Panel

warnings.filterwarnings('ignore')
console = Console()

console.print(f'Using torch {torch.__version__}', style='cyan')
console.print(f'MPS available: {torch.backends.mps.is_available()}', style='cyan')

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Progress tracking helper
class TrainingTracker:
    def __init__(self, total_epochs: int, train_batches: int, val_batches: int):
        self.total_epochs = total_epochs
        self.train_batches = train_batches
        self.val_batches = val_batches
        self.start_time = time.time()
        self.epoch_start_time = None
        self.train_losses = deque(maxlen=100)
        self.val_losses = []
        self.val_accs = []
        self.lrs = []
        self.best_acc = 0
        self.best_epoch = 0
        self.epoch_times = deque(maxlen=5)

    def start_epoch(self, epoch: int):
        self.epoch_start_time = time.time()
        self.current_epoch = epoch

    def end_epoch(self, val_acc: float):
        epoch_time = time.time() - self.epoch_start_time
        self.epoch_times.append(epoch_time)
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            self.best_epoch = self.current_epoch

    def get_eta(self) -> str:
        if not self.epoch_times:
            return 'Calculating...'
        avg_epoch_time = sum(self.epoch_times) / len(self.epoch_times)
        remaining_epochs = self.total_epochs - self.current_epoch - 1
        return str(timedelta(seconds=int(avg_epoch_time * remaining_epochs)))

    def get_speed(self) -> str:
        if not self.epoch_times:
            return 'N/A'
        avg_epoch_time = sum(self.epoch_times) / len(self.epoch_times)
        samples_per_sec = (self.train_batches * 32) / avg_epoch_time
        return f'{samples_per_sec:.1f} img/s'

    def create_dashboard(self, epoch: int, train_loss: float, val_loss: float, val_acc: float, lr: float):
        from rich.table import Table
        table = Table(title=f'Training Progress - Epoch {epoch+1}/{self.total_epochs}')
        table.add_column('Metric', style='cyan', no_wrap=True)
        table.add_column('Current', style='magenta')
        table.add_column('Best', style='green')
        table.add_row('Train Loss', f'{train_loss:.4f}', '-')
        table.add_row('Val Loss', f'{val_loss:.4f}', f"{min(self.val_losses) if self.val_losses else 0:.4f}")
        table.add_row('Val Acc', f'{val_acc:.2f}%', f'{self.best_acc:.2f}% (E{self.best_epoch+1})')
        table.add_row('Learning Rate', f'{lr:.2e}', '-')
        table.add_row('Speed', self.get_speed(), '-')
        table.add_row('ETA', self.get_eta(), '-')
        elapsed = str(timedelta(seconds=int(time.time() - self.start_time)))
        table.add_row('Elapsed', elapsed, '-')
        return table

In [5]:
# Device & memory management
def get_device(force_cpu: bool = False) -> torch.device:
    if force_cpu:
        return torch.device('cpu')
    if torch.backends.mps.is_available():
        console.print('✓ MPS device detected - using Apple Silicon GPU acceleration', style='green')
        return torch.device('mps')
    elif torch.cuda.is_available():
        console.print('✓ CUDA device detected', style='green')
        return torch.device('cuda')
    console.print('⚠️ No GPU found - using CPU', style='yellow')
    return torch.device('cpu')

def clear_memory(device: torch.device):
    if device.type == 'mps':
        torch.mps.empty_cache()
    elif device.type == 'cuda':
        torch.cuda.empty_cache()

In [6]:
# Model configurations
MODEL_CONFIGS = {
    'r18_base': {'timm_name': 'resnet18', 'input_size': 224, 'batch_size': 32, 'use_custom_head': False},
    'r34_base': {'timm_name': 'resnet34', 'input_size': 224, 'batch_size': 24, 'use_custom_head': False},
    'efficientnet_b0': {'timm_name': 'efficientnet_b0', 'input_size': 224, 'batch_size': 32, 'use_custom_head': False},
    'efficientnet_b1': {'timm_name': 'efficientnet_b1', 'input_size': 240, 'batch_size': 24, 'use_custom_head': False},
    'efficientnet_b2': {'timm_name': 'efficientnet_b2', 'input_size': 260, 'batch_size': 16, 'use_custom_head': False},
    'densenet121': {'timm_name': 'densenet121', 'input_size': 224, 'batch_size': 24, 'use_custom_head': False},
}
list(MODEL_CONFIGS.keys())

['r18_base',
 'r34_base',
 'efficientnet_b0',
 'efficientnet_b1',
 'efficientnet_b2',
 'densenet121']

In [7]:
# Custom model head
class EnhancedResNetHead(nn.Module):
    def __init__(self, in_features: int, num_classes: int):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 512)
        self.bn = nn.BatchNorm1d(512)
        self.dropout = nn.Dropout(0.2)
        self.fc2 = nn.Linear(512, num_classes)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.bn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [8]:
# Model builder
def build_model(variant: str, num_classes: int, use_se: bool = False, pretrained: bool = True):
    config = MODEL_CONFIGS[variant]
    model_name = config['timm_name']
    if use_se and 'resnet' in model_name:
        model_name = model_name.replace('resnet', 'seresnet')
    model = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes if not config.get('use_custom_head') else 0)
    if config.get('use_custom_head'):
        in_features = model.num_features
        model.fc = EnhancedResNetHead(in_features, num_classes)
    param_groups = []
    if config.get('two_phase'):
        backbone_early, backbone_late, head_params = [], [], []
        for name, p in model.named_parameters():
            if any(k in name for k in ['fc', 'head', 'classifier']):
                head_params.append(p)
            elif 'layer1' in name or 'layer2' in name:
                backbone_early.append(p)
            else:
                backbone_late.append(p)
        param_groups = [
            {'params': backbone_early, 'lr_scale': 0.5},
            {'params': backbone_late, 'lr_scale': 1.0},
            {'params': head_params, 'lr_scale': 1.5},
        ]
    else:
        param_groups = [{'params': model.parameters(), 'lr_scale': 1.0}]
    return model, param_groups

In [9]:
# Data transforms
def build_transforms(input_size: int, is_training: bool, use_augmentation: bool = True) -> transforms.Compose:
    if is_training and use_augmentation:
        return create_transform(input_size=input_size, is_training=True, auto_augment='rand-m9-mstd0.5-inc1', interpolation='bicubic', mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225))
    else:
        return transforms.Compose([
            transforms.Resize(int(input_size * 1.14)),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize((0.485,0.456,0.406),(0.229,0.224,0.225))
        ])

In [10]:
# Dataset preparation
def prepare_datasets(data_root: str, val_split: Optional[float], variant: str, use_augmentation: bool = True):
    data_path = Path(data_root)
    train_path, val_path = data_path / 'train', data_path / 'val'
    input_size = MODEL_CONFIGS[variant]['input_size']
    train_transform = build_transforms(input_size, True, use_augmentation)
    val_transform = build_transforms(input_size, False, False)
    if val_path.exists():
        train_dataset = datasets.ImageFolder(train_path, transform=train_transform)
        val_dataset = datasets.ImageFolder(val_path, transform=val_transform)
    else:
        full_dataset = datasets.ImageFolder(train_path)
        if val_split:
            val_size = int(len(full_dataset) * val_split)
            train_size = len(full_dataset) - val_size
            train_indices, val_indices = random_split(range(len(full_dataset)), [train_size, val_size])
            train_dataset = torch.utils.data.Subset(full_dataset, train_indices)
            val_dataset = torch.utils.data.Subset(full_dataset, val_indices)
            train_dataset.dataset.transform = train_transform
            val_dataset.dataset.transform = val_transform
        else:
            raise ValueError('No validation folder found and val_split not specified')
    return train_dataset, val_dataset

In [11]:
# Dataloaders
def create_dataloaders(train_dataset, val_dataset, batch_size: int, num_workers: int = 4, balance_sampler: bool = False):
    train_sampler = None
    if balance_sampler:
        targets = train_dataset.targets if hasattr(train_dataset, 'targets') else [train_dataset.dataset.targets[i] for i in train_dataset.indices]
        class_counts = np.bincount(targets)
        class_weights = 1.0 / class_counts
        sample_weights = [float(class_weights[t]) for t in targets]
        train_sampler = WeightedRandomSampler(sample_weights, len(sample_weights))
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=num_workers, pin_memory=False, persistent_workers=True if num_workers > 0 else False)
    val_loader = DataLoader(val_dataset, batch_size=batch_size*2, shuffle=False, num_workers=num_workers, pin_memory=False, persistent_workers=True if num_workers > 0 else False)
    return train_loader, val_loader

In [12]:
# SAM optimizer wrapper
class SAM:
    def __init__(self, base_optimizer, rho=0.05):
        self.base_optimizer = base_optimizer
        self.rho = rho
    @torch.no_grad()
    def first_step(self):
        grad_norm = self._grad_norm()
        for group in self.base_optimizer.param_groups:
            scale = self.rho / (grad_norm + 1e-12)
            for p in group['params']:
                if p.grad is None: continue
                p.requires_grad_(False)
                p.add_(p.grad * scale)
                p.requires_grad_(True)
    @torch.no_grad()
    def second_step(self):
        for group in self.base_optimizer.param_groups:
            for p in group['params']:
                if p.grad is None: continue
                p.requires_grad_(False)
                p.sub_(p.grad * (self.rho / (self._grad_norm() + 1e-12)))
                p.requires_grad_(True)
        self.base_optimizer.step()
    def _grad_norm(self):
        return torch.norm(torch.stack([p.grad.norm(p=2) for g in self.base_optimizer.param_groups for p in g['params'] if p.grad is not None]), p=2)
    def zero_grad(self):
        self.base_optimizer.zero_grad()
    @property
    def param_groups(self):
        return self.base_optimizer.param_groups
    def state_dict(self):
        return self.base_optimizer.state_dict()
    def load_state_dict(self, state_dict):
        self.base_optimizer.load_state_dict(state_dict)

In [13]:
# Training epoch
def train_epoch(model, loader, criterion, optimizer, device, scaler, mixup_fn: Optional[Mixup], accumulation_steps: int, use_sam: bool, ema_model: Optional[ModelEmaV2] = None, epoch: int = 0, show_progress: bool = True):
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    pbar = tqdm(loader, desc=f'Epoch {epoch+1} [Train]', bar_format='{l_bar}{bar:30}{r_bar}{bar:-10b}', colour='green') if show_progress else loader
    for batch_idx, (inputs, targets) in enumerate(pbar):
        inputs, targets = inputs.to(device), targets.to(device)
        if mixup_fn is not None:
            inputs, targets = mixup_fn(inputs, targets)
        # Forward
        if device.type == 'cuda':
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                outputs = model(inputs)
                loss = criterion(outputs, targets)
        else:
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        loss = loss / accumulation_steps
        # Backward
        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()
        if (batch_idx + 1) % accumulation_steps == 0:
            if use_sam:
                optimizer.first_step()
                optimizer.zero_grad()
                # second forward-backward
                if device.type == 'cuda':
                    with torch.autocast(device_type='cuda', dtype=torch.float16):
                        outputs = model(inputs)
                        loss2 = criterion(outputs, targets) / accumulation_steps
                else:
                    outputs = model(inputs)
                    loss2 = criterion(outputs, targets) / accumulation_steps
                if scaler is not None:
                    scaler.scale(loss2).backward()
                else:
                    loss2.backward()
                optimizer.second_step()
            else:
                if scaler is not None:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
            optimizer.zero_grad()
            if ema_model is not None:
                ema_model.update(model)
        running_loss += loss.item() * accumulation_steps
        if mixup_fn is None:
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        if show_progress and batch_idx % 10 == 0:
            current_loss = running_loss / (batch_idx + 1)
            current_acc = 100.0 * correct / total if total > 0 else 0.0
            if hasattr(pbar, 'set_postfix'):
                pbar.set_postfix({'loss': f'{current_loss:.4f}', 'acc': f'{current_acc:.2f}%'})
    clear_memory(device)
    avg_loss = running_loss / len(loader)
    accuracy = 100.0 * correct / total if total > 0 else 0.0
    return avg_loss, accuracy, 0.0

In [14]:
# Validation epoch
def validate(model, loader, criterion, device, calc_top5: bool = True, epoch: int = 0, show_progress: bool = True):
    model.eval()
    running_loss, correct_top1, correct_top5, total = 0.0, 0, 0, 0
    pbar = tqdm(loader, desc=f'Epoch {epoch+1} [Val]  ', bar_format='{l_bar}{bar:30}{r_bar}{bar:-10b}', colour='blue') if show_progress else loader
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(pbar):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct_top1 += predicted.eq(targets).sum().item()
            if calc_top5:
                _, top5_pred = outputs.topk(5, 1, True, True)
                correct_top5 += top5_pred.eq(targets.view(-1,1).expand_as(top5_pred)).sum().item()
            if show_progress and batch_idx % 10 == 0 and hasattr(pbar, 'set_postfix'):
                current_loss = running_loss / (batch_idx + 1)
                current_acc = 100.0 * correct_top1 / total
                pbar.set_postfix({'loss': f'{current_loss:.4f}', 'acc': f'{current_acc:.2f}%'})
    avg_loss = running_loss / len(loader)
    top1_acc = 100.0 * correct_top1 / total
    top5_acc = 100.0 * correct_top5 / total if calc_top5 else 0.0
    return avg_loss, top1_acc, top5_acc

In [15]:
# Evaluation & visualization
def evaluate_model(model, val_loader, device, class_names: List[str], output_dir: Path):
    console.print("\n[bold cyan]Running final evaluation...[/bold cyan]")
    model.eval()
    all_preds, all_targets = [], []
    with torch.no_grad():
        for inputs, targets in tqdm(val_loader, desc='Evaluating'):
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, preds = outputs.max(1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(targets.numpy())
    cm = confusion_matrix(all_targets, all_preds)
    plt.figure(figsize=(12,10))
    sns.heatmap(cm, annot=False, fmt='d', cmap='Blues')
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    output_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_dir / 'confusion_matrix.png', dpi=150)
    plt.close()
    report = classification_report(all_targets, all_preds, target_names=class_names[:len(np.unique(all_targets))], output_dict=True)
    metrics_df = pd.DataFrame(report).transpose()
    metrics_df.to_csv(output_dir / 'per_class_metrics.csv')
    summary = {
        'accuracy': report['accuracy'],
        'macro_avg': report['macro avg'],
        'weighted_avg': report['weighted avg'],
        'total_samples': len(all_targets)
    }
    with open(output_dir / 'model_summary.json', 'w') as f:
        json.dump(summary, f, indent=2)
    console.print(f"[green]✓ Evaluation complete. Results saved to {output_dir}[/green]")
    console.print(f"[yellow]  Final Accuracy: {report['accuracy']*100:.2f}%[/yellow]")

In [16]:
# Training runner used by the experiments section
from types import SimpleNamespace

def run_training(args: SimpleNamespace):
    # Reproducibility
    torch.manual_seed(getattr(args, 'seed', 42))
    np.random.seed(getattr(args, 'seed', 42))

    # Setup device and output
    device = get_device(getattr(args, 'force_cpu', False))
    out_root = Path(getattr(args, 'outdir', './runs'))
    output_dir = out_root / args.variant
    output_dir.mkdir(parents=True, exist_ok=True)

    # Persist args
    try:
        with open(output_dir / 'args.json', 'w') as f:
            json.dump(vars(args), f, indent=2)
    except Exception:
        pass

    # Data
    config = MODEL_CONFIGS[args.variant]
    batch_size = getattr(args, 'batch_size', None) or config['batch_size']
    train_ds, val_ds = prepare_datasets(
        getattr(args, 'data_root', './data'), getattr(args, 'val_split', None), args.variant, not getattr(args, 'no_augs', False)
    )
    class_names = train_ds.classes if hasattr(train_ds, 'classes') else train_ds.dataset.classes
    num_classes = len(class_names)

    train_loader, val_loader = create_dataloaders(
        train_ds, val_ds, batch_size=batch_size, num_workers=getattr(args, 'num_workers', 4), balance_sampler=getattr(args, 'balance_sampler', False)
    )

    # Model
    model, param_groups = build_model(args.variant, num_classes, getattr(args, 'use_se', False))
    model = model.to(device)

    # Optimizer (support per-group lr scaling)
    base_lr = float(getattr(args, 'lr', 2e-4))
    wd = float(getattr(args, 'weight_decay', 0.1))
    opt_param_groups = []
    for g in param_groups:
        lr_scale = g.pop('lr_scale', 1.0)
        opt_param_groups.append({'params': g['params'], 'lr': base_lr * lr_scale})
    base_opt = AdamW(opt_param_groups, weight_decay=wd)
    use_sam = bool(getattr(args, 'sam', False))
    optimizer = SAM(base_opt, rho=0.05) if use_sam else base_opt

    # Scheduler (per-epoch)
    warmup_epochs = int(getattr(args, 'warmup_epochs', 0))
    total_epochs = int(getattr(args, 'epochs', 10))
    sched_opt = optimizer.base_optimizer if use_sam else optimizer
    scheduler = CosineAnnealingLR(sched_opt, T_max=max(1, total_epochs - warmup_epochs))

    # Criterion and mixup/cutmix
    mixup_alpha = float(getattr(args, 'mixup', 0.0))
    cutmix_alpha = float(getattr(args, 'cutmix', 0.0))
    label_smoothing = float(getattr(args, 'label_smoothing', 0.0))
    do_aug = not getattr(args, 'no_augs', False)
    if do_aug and (mixup_alpha > 0 or cutmix_alpha > 0):
        mixup_fn = Mixup(mixup_alpha=mixup_alpha, cutmix_alpha=cutmix_alpha, label_smoothing=label_smoothing, num_classes=num_classes)
        criterion = SoftTargetCrossEntropy()
    else:
        mixup_fn = None
        try:
            criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
        except TypeError:
            criterion = nn.CrossEntropyLoss()

    # EMA
    ema_model = ModelEmaV2(model, decay=0.999) if bool(getattr(args, 'ema', False)) else None

    # AMP scaler (CUDA only)
    scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None

    # Train loop
    tracker = TrainingTracker(total_epochs, len(train_loader), len(val_loader))
    patience = int(getattr(args, 'patience', 10))
    best_top1, patience_ctr = 0.0, 0
    metrics = []

    # Gradient accumulation to approximate larger batch
    eff_bs = 128
    accumulation_steps = max(1, eff_bs // batch_size)

    console.print(Panel.fit(
        f"Variant: {args.variant}\nDevice: {device}\nEpochs: {total_epochs}\nBatch size: {batch_size} (accum {accumulation_steps})",
        title='Training'
    ))

    for epoch in range(total_epochs):
        tracker.start_epoch(epoch)

        # Warmup LR
        if epoch < warmup_epochs:
            warmup_scale = (epoch + 1) / float(max(1, warmup_epochs))
            for pg in sched_opt.param_groups:
                base_group_lr = base_lr  # already scaled above per group when created
                # Keep current per-group factor by ratio of current lr to base_lr
                scale_factor = pg['lr'] / max(1e-12, base_lr)
                pg['lr'] = base_lr * scale_factor * warmup_scale

        # Train
        train_loss, train_acc, _ = train_epoch(
            model, train_loader, criterion, optimizer, device, scaler, mixup_fn,
            accumulation_steps=accumulation_steps, use_sam=use_sam, ema_model=ema_model, epoch=epoch, show_progress=not getattr(args, 'no_progress', False)
        )

        # Validate (EMA model for eval if present)
        eval_model = ema_model.module if ema_model is not None else model
        val_loss, val_top1, val_top5 = validate(
            eval_model, val_loader, criterion, device, calc_top5=True, epoch=epoch, show_progress=not getattr(args, 'no_progress', False)
        )

        if epoch >= warmup_epochs:
            scheduler.step()

        # Track
        current_lr = sched_opt.param_groups[0]['lr']
        tracker.end_epoch(val_top1)
        tracker.train_losses.append(train_loss)
        tracker.val_losses.append(val_loss)
        tracker.val_accs.append(val_top1)
        tracker.lrs.append(current_lr)
        console.print(tracker.create_dashboard(epoch, train_loss, val_loss, val_top1, current_lr))

        metrics.append({
            'epoch': epoch + 1,
            'lr': current_lr,
            'train_loss': float(train_loss),
            'val_loss': float(val_loss),
            'top1': float(val_top1),
            'top5': float(val_top5),
        })

        # Early stopping on best top1
        if val_top1 > best_top1:
            best_top1 = val_top1
            patience_ctr = 0
        else:
            patience_ctr += 1
            if patience_ctr >= patience:
                console.print(f"[yellow]Early stopping at epoch {epoch+1}[/yellow]")
                break

    # Persist metrics
    pd.DataFrame(metrics).to_csv(output_dir / 'metrics_log.tsv', sep='\t', index=False)

    # Final evaluation and artifacts
    evaluate_model(eval_model, val_loader, device, class_names, output_dir)

    return {'best_acc': float(best_top1), 'output_dir': str(output_dir)}

## Run all experiments

This section mirrors `run_all_experiments.py` to launch multiple variants sequentially, collect metrics, and write an `EXPERIMENT_SUMMARY.md`. Edit the main Config cell above if you need different data/epochs.

In [17]:
# Run-all-experiments utilities (inline version of run_all_experiments.py)
from dataclasses import dataclass
import time

# Default variants (you can edit this list)
VARIANTS = [
    "r18_base",
    "r34_base",
    "efficientnet_b0",
    "efficientnet_b1",
    "efficientnet_b2",
    "densenet121",
]

@dataclass
class ExpBaseConfig:
    data_root: str = "./data"
    epochs: int = 10
    num_workers: int = 4
    seed: int = 42

BASE_CONFIG = ExpBaseConfig()


def run_experiment_inline(variant: str) -> dict:
    """Run a single experiment using the in-notebook training function."""
    console.print("\n" + "="*60)
    console.print(f"Starting experiment: {variant}")
    start_ts = time.time()

    # Build args for the existing run_training function
    from types import SimpleNamespace
    base_bs = MODEL_CONFIGS[variant]["batch_size"]
    args_exp = SimpleNamespace(
        variant=variant,
        use_se=False,
        data_root=BASE_CONFIG.data_root,
        val_split=None,
        num_workers=BASE_CONFIG.num_workers,
        balance_sampler=False,
        epochs=BASE_CONFIG.epochs,
        batch_size=None,
        lr=2e-4,
        weight_decay=0.1,
        warmup_epochs=3,
        patience=10,
        no_augs=False,
        mixup=0.2,
        cutmix=0.2,
        label_smoothing=0.1,
        ema=("plus" in variant),
        sam=("plus" in variant),
        freeze_epochs=5,
        final_resolution=288,
        device='mps',
        force_cpu=False,
        seed=BASE_CONFIG.seed,
        outdir='./runs',
        resume=None,
        no_progress=False,
    )

    try:
        result = run_training(args_exp)
        elapsed = time.time() - start_ts

        # Read metrics
        out_dir = Path(result["output_dir"])  # runs/<variant>
        metrics_file = out_dir / 'metrics_log.tsv'
        if metrics_file.exists():
            mdf = pd.read_csv(metrics_file, sep='\t')
            best_top1 = float(mdf['top1'].max())
            best_top5 = float(mdf['top5'].max())
            epochs_trained = len(mdf)
        else:
            best_top1 = 0.0
            best_top5 = 0.0
            epochs_trained = 0

        # Read summary if exists
        summary_file = out_dir / 'model_summary.json'
        if summary_file.exists():
            with open(summary_file) as f:
                summary = json.load(f)
        else:
            summary = {"accuracy": best_top1/100.0}

        return {
            "variant": variant,
            "status": "completed",
            "accuracy": float(summary.get("accuracy", 0.0)) * 100,
            "best_top1": best_top1,
            "best_top5": best_top5,
            "epochs_trained": epochs_trained,
            "training_time": elapsed,
            "error": None,
        }
    except Exception as e:
        console.print(f"[red]Error training {variant}: {e}[/red]")
        return {
            "variant": variant,
            "status": "error",
            "accuracy": 0.0,
            "best_top1": 0.0,
            "best_top5": 0.0,
            "epochs_trained": 0,
            "training_time": time.time() - start_ts,
            "error": str(e),
        }


def generate_summary_md(results: list) -> str:
    from datetime import datetime
    summary = "# Aircraft Classification Experiments Summary\n\n"
    summary += f"**Date**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n"
    summary += f"**Dataset**: 100 Aircraft Classes (80/10/10 split)\n"
    summary += f"**Device**: Mac with MPS acceleration (if available)\n"
    summary += f"**Epochs**: {BASE_CONFIG.epochs}\n\n"
    summary += "## Performance Comparison\n\n"
    summary += "| Model Variant | Status | Top-1 Acc (%) | Top-5 Acc (%) | Epochs | Time (min) |\n"
    summary += "|--------------|---------|---------------|---------------|--------|------------|\n"
    for r in sorted(results, key=lambda x: x['best_top1'], reverse=True):
        status_emoji = "✅" if r['status'] == "completed" else "❌"
        time_min = r['training_time'] / 60
        summary += f"| {r['variant']} | {status_emoji} | {r['best_top1']:.2f} | {r['best_top5']:.2f} | {r['epochs_trained']} | {time_min:.1f} |\n"
    top_completed = [r for r in results if r['status'] == 'completed']
    if top_completed:
        summary += "\n## Top Performers\n\n"
        top3 = sorted(top_completed, key=lambda x: x['best_top1'], reverse=True)[:3]
        for i, r in enumerate(top3, 1):
            summary += f"{i}. **{r['variant']}**: {r['best_top1']:.2f}% Top-1 accuracy\n"
    return summary


In [18]:
# Orchestrate all experiments from the notebook
from pprint import pprint

all_results = []
for i, v in enumerate(VARIANTS, 1):
    console.print(f"\n[i={i}] Running {v} ...", style="cyan")
    r = run_experiment_inline(v)
    all_results.append(r)
    # Save intermediate results
    with open("experiment_results.json", "w") as f:
        json.dump(all_results, f, indent=2)
    console.print(f"Done {v}: {r.get('best_top1',0):.2f}%", style="green")

# Summary markdown and a small table view
summary_md = generate_summary_md(all_results)
with open("EXPERIMENT_SUMMARY.md", "w") as f:
    f.write(summary_md)

pd.DataFrame(all_results).sort_values("best_top1", ascending=False)

Epoch 1 [Train]: 100%|[32m██████████████████████████████[0m| 250/250 [01:47<00:00,  2.32it/s, loss=4.6183, acc=0.00%][32m[0m
Epoch 1 [Val]  :   0%|[34m                              [0m| 0/16 [00:06<?, ?it/s][34m[0m


Epoch 1 [Train]:  57%|[32m█████████████████▏            [0m| 191/334 [02:40<02:00,  1.19it/s, loss=4.6159, acc=0.00%][32m[0m


KeyboardInterrupt: 