# ResNet‑18 training notebook

This notebook trains and evaluates an image classifier (ResNet‑18) for Amsterdam insect species.

**What this notebook does**
- Loads an `ImageFolder` dataset with a **train / val / test** directory structure.
- Trains **N independent runs** (different random seeds) with early stopping.
- Evaluates each run on a **temporal holdout** test split (e.g. `test2` = year 2025).
- Saves weights, per-run metrics, predictions, and confusion matrices.
- Produces an aggregated summary across runs (mean ± std).

**Repository-aligned outputs (your current structure)**

- **Models** (weights): `models/vision/`
- **Run artifacts** (figures / metrics / predictions):
  - `outputs/vision_resnet/figures/`
  - `outputs/vision_resnet/metrics/`
  - `outputs/vision_resnet/preds/`

**Recommended dataset folder structure**

```
data/amsterdam/images_no_vespula/
  train/
    Species_A/
      img_001.jpg
      ...
  val/
    Species_A/
      ...
  test2/           # temporal holdout (e.g. 2025)
    Species_A/
      ...
```

> This notebook is robust to being launched from `notebooks/` by auto-detecting the repo root using `pyproject.toml`.


In [None]:
# Environment sanity check (helps reproducibility)
import sys, platform
import torch, torchvision

# Report interpreter, platform and accelerator availability in one pass.
_has_mps = getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available()
for _label, _value in (
    ("Python:", sys.version.split()[0]),
    ("Platform:", platform.platform()),
    ("Torch:", torch.__version__),
    ("Torchvision:", torchvision.__version__),
    ("MPS available:", _has_mps),
    ("CUDA available:", torch.cuda.is_available()),
):
    print(_label, _value)


## 1. Configuration

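Edit the hyperparameters in `ExperimentConfig` below. Paths are not set in this cell: they are read from `configs/paths.yaml` via `digital_naturalist.paths.load_paths`, and the resolved values are printed and saved to `outputs/vision_resnet/metrics/config.json`.

For portability, `configs/paths.yaml` is passed as a relative path, so the defaults assume you run this notebook from the repository root (unless `load_paths` resolves the root itself), with data under `./data/...`.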


In [None]:
from __future__ import annotations

from dataclasses import dataclass, asdict
from pathlib import Path
import json
import numpy as np
import torch
import random

from digital_naturalist.paths import load_paths


# -----------------------------
# Repo paths: configs/paths.yaml is the single source of truth
# -----------------------------
P = load_paths("configs/paths.yaml")

# Data locations (ImageFolder splits)
REPO_ROOT = P["REPO_ROOT"]
IMAGE_ROOT = P["IMAGE_ROOT"]              # data/.../images_no_vespula
IMAGE_TRAIN, IMAGE_VAL, IMAGE_TEST2 = (
    P[key] for key in ("IMAGE_TRAIN_DIR", "IMAGE_VAL_DIR", "IMAGE_TEST2_DIR")
)

# Model artefacts
MODELS_DIR = P["VISION_MODEL_DIR"]        # models/vision
TEMPS_DIR = P["VISION_TEMPS_DIR"]         # models/vision/temperatures

# Run outputs (outputs/vision_resnet) split into figures / metrics / preds
OUT_DIR = P["OUT_VISION_RESNET"]
FIG_DIR, MET_DIR, PRD_DIR = (OUT_DIR / sub for sub in ("figures", "metrics", "preds"))

# Create every output folder up front; safe to re-run.
for _folder in [FIG_DIR, MET_DIR, PRD_DIR, MODELS_DIR, TEMPS_DIR]:
    _folder.mkdir(exist_ok=True, parents=True)


# -----------------------------
# Experiment config (non-path hyperparams only)
# -----------------------------
@dataclass(frozen=True)
class ExperimentConfig:
    test_split_name: str = "test2"   # temporal holdout split
    image_size: int = 224

    num_classes: int = 8
    batch_size: int = 32
    lr: float = 1e-3
    epochs: int = 20
    patience: int = 4
    num_runs: int = 10
    base_seed: int = 42
    num_workers: int = 0   # macOS/MPS safest
    use_pretrained: bool = True


def get_device() -> torch.device:
    '''Pick the best available accelerator: Apple MPS first, then CUDA, else CPU.'''
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        name = "mps"
    elif torch.cuda.is_available():
        name = "cuda"
    else:
        name = "cpu"
    return torch.device(name)


CFG = ExperimentConfig()
DEVICE = get_device()

# Human-readable summary of where data comes from and where results go.
print("Repo root        :", REPO_ROOT)
for _label, _path in (
    ("Image root       :", IMAGE_ROOT),
    ("Train images     :", IMAGE_TRAIN),
    ("Val images       :", IMAGE_VAL),
    ("Test images      :", IMAGE_TEST2),
):
    print(_label, _path, "| exists:", _path.exists())
for _label, _path in (
    ("Outputs root     :", OUT_DIR),
    ("  figures        :", FIG_DIR),
    ("  metrics        :", MET_DIR),
    ("  preds          :", PRD_DIR),
    ("Vision models    :", MODELS_DIR),
    ("Temperatures     :", TEMPS_DIR),
    ("Device           :", DEVICE),
):
    print(_label, _path)

# Persist hyperparameters + resolved paths so every run is reproducible.
config_snapshot = {
    "paths": {name: str(value) for name, value in P.items()},
    "cfg": asdict(CFG),
    "device": str(DEVICE),
}
with open(MET_DIR / "config.json", "w", encoding="utf-8") as fh:
    json.dump(config_snapshot, fh, indent=2)


In [None]:
# Aliases for the configured split folders
TRAIN_DIR = IMAGE_TRAIN
VAL_DIR = IMAGE_VAL
TEST_DIR = IMAGE_TEST2  # your temporal holdout split

# Sanity checks. Explicit raises rather than `assert`, which is stripped when
# Python runs with -O.
for _name, _split_dir in (("TRAIN_DIR", TRAIN_DIR), ("VAL_DIR", VAL_DIR), ("TEST_DIR", TEST_DIR)):
    if not _split_dir.exists():
        raise FileNotFoundError(f"Missing {_name}: {_split_dir}")


### Where files will be saved (quick reference)

- **Weights** → `models/vision/`
- **Per-run predictions** → `outputs/vision_resnet/preds/<run_tag>/`
- **Per-run metrics** → `outputs/vision_resnet/metrics/<run_tag>/`
- **Per-run figures** → `outputs/vision_resnet/figures/<run_tag>/`
- **Aggregates** → `outputs/vision_resnet/metrics/` and `outputs/vision_resnet/figures/`

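Folder locations come from `configs/paths.yaml` (loaded with `load_paths`). If your `load_paths` implementation supports environment-variable overrides, these variables can redirect them (nothing in this notebook reads them directly):

- `DATA_DIR`
- `OUTPUTS_DIR`
- `MODELS_DIR`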


## 2. Data loading

We use `torchvision.datasets.ImageFolder`. This requires each split folder to contain one subfolder per class.


In [None]:
from pathlib import Path
from typing import Dict, Tuple, List

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def _assert_split_exists(split_dir: Path) -> None:
    '''Raise if `split_dir` is missing or has no class sub-folders (ImageFolder layout).'''
    if not split_dir.exists():
        raise FileNotFoundError(f"Missing split folder: {split_dir}")
    has_class_dir = any(entry.is_dir() for entry in split_dir.iterdir())
    if not has_class_dir:
        raise RuntimeError(f"No class subfolders found in: {split_dir}")


def build_transforms(image_size: int) -> Dict[str, transforms.Compose]:
    '''
    Standard ImageNet normalization with mild augmentation for training.

    - "train": RandomResizedCrop(image_size) + random horizontal flip.
    - "val" / "test": resize the shorter side to round(image_size * 256 / 224),
      then CenterCrop(image_size).

    The eval resize keeps the usual 256/224 ratio, so image_size=224 gives the
    familiar Resize(256) -> CenterCrop(224). Scaling it with image_size avoids
    CenterCrop zero-padding when image_size would exceed a fixed resize of 256.
    '''
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    eval_resize = int(round(image_size * 256 / 224))

    def _eval_pipeline() -> transforms.Compose:
        # Deterministic preprocessing shared by val and test (separate instances).
        return transforms.Compose([
            transforms.Resize(eval_resize),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(imagenet_mean, imagenet_std),
        ])

    return {
        "train": transforms.Compose([
            transforms.RandomResizedCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(imagenet_mean, imagenet_std),
        ]),
        "val": _eval_pipeline(),
        "test": _eval_pipeline(),
    }


def build_datasets(cfg: ExperimentConfig) -> Tuple[Dict[str, datasets.ImageFolder], List[str]]:
    '''
    Load train / val / test ImageFolder datasets from IMAGE_ROOT.

    The test split folder is IMAGE_ROOT / cfg.test_split_name. All three split
    folders are validated before any dataset is built, and val/test must expose
    exactly the same class list as train.

    Returns (datasets keyed "train"/"val"/"test", train class names).
    '''
    split_dirs = {
        "train": IMAGE_ROOT / "train",
        "val": IMAGE_ROOT / "val",
        "test": IMAGE_ROOT / cfg.test_split_name,
    }
    for split_dir in split_dirs.values():
        _assert_split_exists(split_dir)

    tfms = build_transforms(cfg.image_size)
    ds = {
        name: datasets.ImageFolder(folder, transform=tfms[name])
        for name, folder in split_dirs.items()
    }

    reference = ds["train"].classes
    for name in ("val", "test"):
        other = ds[name].classes
        if other != reference:
            raise RuntimeError(
                f"Class mismatch between train and {name}. "
                f"train={reference}, {name}={other}"
            )
    return ds, reference


image_datasets, class_names = build_datasets(CFG)

# Split sizes and class list for a quick visual check.
for _split, _prefix in (("train", "Train: "), ("val", "Val:   ")):
    print(f"{_prefix}{len(image_datasets[_split])} images")
print(f"Test:  {len(image_datasets['test'])} images (split='{CFG.test_split_name}')")
print("Classes:", class_names)


## 3. Model

ResNet‑18 with a replaced final fully‑connected layer.


In [None]:
from typing import Optional
import torch
import torch.nn as nn
from torchvision import models


def create_resnet18(num_classes: int, pretrained: bool = True) -> nn.Module:
    '''
    Create ResNet-18 with a custom classification head.

    Parameters
    ----------
    num_classes : int
        Number of outputs of the replacement final fully-connected layer.
    pretrained : bool
        If True, load ImageNet-1k weights (IMAGENET1K_V1) into the backbone.

    Uses the torchvision `Weights` API (torchvision >= 0.13) when it exists and
    falls back to the legacy `pretrained=` flag only on older torchvision.
    Any other failure (e.g. a weight download error) propagates instead of being
    silently retried through the deprecated code path.
    '''
    if hasattr(models, "ResNet18_Weights"):  # torchvision >= 0.13
        weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        model = models.resnet18(weights=weights)
    else:  # older torchvision only understands the boolean flag
        model = models.resnet18(pretrained=pretrained)

    # Replace the 1000-way ImageNet head with a num_classes-way head.
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_classes)
    return model


## 4. Training & evaluation utilities

- Early stopping on **validation loss**
- Accuracy + macro precision/recall/F1 on the test set
- Top‑K accuracy (K ∈ {1,3,5})
- Confusion matrix saved per run


In [None]:
from __future__ import annotations

import copy
from typing import Dict, Tuple, Any

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix,
    classification_report,
)


def set_seed(seed: int) -> None:
    '''Seed Python's `random`, NumPy and PyTorch (plus every CUDA device when present).'''
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def make_dataloaders(cfg: ExperimentConfig, datasets_dict) -> Dict[str, DataLoader]:
    '''
    Wrap the train/val/test datasets in DataLoaders.

    Only the train loader shuffles. Pinned memory is enabled only when the
    module-level DEVICE is CUDA, the one backend where it speeds up transfers.
    '''
    use_pinned = DEVICE.type == "cuda"
    loaders: Dict[str, DataLoader] = {}
    for split_name in ("train", "val", "test"):
        loaders[split_name] = DataLoader(
            datasets_dict[split_name],
            batch_size=cfg.batch_size,
            shuffle=split_name == "train",
            num_workers=cfg.num_workers,
            pin_memory=use_pinned,
        )
    return loaders


def topk_accuracy(y_true: np.ndarray, probs: np.ndarray, k: int) -> float:
    '''
    Fraction of samples whose true label is among the `k` highest-scoring classes.

    Parameters
    ----------
    y_true : array of int, shape (n,)
        True label indices.
    probs : array, shape (n, num_classes)
        Per-class scores/probabilities for each sample.
    k : int
        Number of top classes to consider; must be >= 1. If k >= num_classes
        every class is included and the result is 1.0.

    Returns NaN for empty input. Raises ValueError for k < 1, because the
    slice `[:, -0:]` would otherwise select every column and report 1.0.
    '''
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    labels = np.asarray(y_true)
    scores = np.asarray(probs)
    if labels.size == 0:
        return float("nan")
    topk = np.argsort(scores, axis=1)[:, -k:]
    # Vectorised membership test: compare each row of top-k indices to its label.
    hits = np.any(topk == labels[:, None], axis=1)
    return float(np.mean(hits))


def train_one_run(
    model: nn.Module,
    dataloaders: Dict[str, DataLoader],
    cfg: ExperimentConfig,
) -> Tuple[nn.Module, float, float]:
    '''
    Trains a model with early stopping on validation loss.
    Returns (best_model, best_val_loss, best_val_acc).

    Each epoch runs a "train" phase and then a "val" phase over `dataloaders`,
    moving batches to the module-level DEVICE. Optimisation uses Adam
    (lr=cfg.lr) and a StepLR schedule that multiplies the LR by 0.1 every 7
    epochs (scheduler.step() is called once per epoch, after the train phase).

    Early stopping: a val-loss drop of more than 1e-6 counts as an improvement
    and snapshots the weights. Training stops after `cfg.patience` consecutive
    non-improving epochs, or after `cfg.epochs` epochs. The best snapshot is
    loaded back into `model` (in place) before returning. `best_val_acc` is the
    validation accuracy from that same epoch.
    '''
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=cfg.lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    # Initial snapshot keeps load_state_dict valid even if val loss never
    # improves (e.g. it is NaN, which never compares lower).
    best_wts = copy.deepcopy(model.state_dict())
    best_val_loss = float("inf")
    best_val_acc = 0.0
    epochs_no_improve = 0

    for epoch in range(cfg.epochs):
        for phase in ("train", "val"):
            is_train = phase == "train"
            model.train(is_train)  # eval mode during "val" (BatchNorm uses running stats)

            running_loss = 0.0
            running_corrects = 0
            n = 0

            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                # Also runs in the val phase; harmless there since no backward pass follows.
                optimizer.zero_grad(set_to_none=True)

                # Build the autograd graph only while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    preds = outputs.argmax(dim=1)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # `loss` is a batch mean (CrossEntropyLoss default reduction);
                # re-weight by batch size so the epoch loss is a per-sample mean.
                bs = inputs.size(0)
                running_loss += loss.item() * bs
                running_corrects += (preds == labels).sum().item()
                n += bs

            # max(n, 1) avoids division by zero for an empty loader.
            # NOTE(review): an empty val loader yields epoch_loss = 0.0, which
            # is then always treated as an improvement.
            epoch_loss = running_loss / max(n, 1)
            epoch_acc = running_corrects / max(n, 1)

            if is_train:
                scheduler.step()  # one LR-schedule step per epoch

            if phase == "val":
                if epoch_loss < best_val_loss - 1e-6:
                    best_val_loss = epoch_loss
                    best_val_acc = epoch_acc
                    best_wts = copy.deepcopy(model.state_dict())
                    epochs_no_improve = 0
                else:
                    epochs_no_improve += 1

        if epochs_no_improve >= cfg.patience:
            break

    model.load_state_dict(best_wts)
    return model, float(best_val_loss), float(best_val_acc)


@torch.no_grad()
def evaluate(model: nn.Module, dataloader: DataLoader, num_classes: int) -> Dict[str, Any]:
    '''
    Run `model` over `dataloader` (on the module-level DEVICE) and score it.

    Returns a dict with:
      - "y_true", "y_pred": int label arrays, "probs": softmax probabilities
        of shape (n, num_classes);
      - "accuracy", macro "precision_macro" / "recall_macro" / "f1_macro",
        "top1" / "top3" / "top5";
      - "cm": confusion matrix over labels 0..num_classes-1;
      - "report": sklearn classification_report as a dict.

    If the loader's dataset exposes a `classes` list of length `num_classes`
    (e.g. torchvision ImageFolder), the report is keyed by those class names.
    Otherwise sklearn's default string indices ("0", "1", ...) are used.

    An empty loader returns NaN metrics, an all-zero confusion matrix and an
    empty report, instead of raising inside sklearn.
    '''
    model.eval()
    all_preds: list[int] = []
    all_labels: list[int] = []
    all_probs: list[np.ndarray] = []

    for inputs, labels in dataloader:
        inputs = inputs.to(DEVICE)
        outputs = model(inputs)
        all_probs.append(torch.softmax(outputs, dim=1).cpu().numpy())
        all_preds.extend(outputs.argmax(dim=1).cpu().numpy().tolist())
        all_labels.extend(labels.numpy().tolist())

    y_true = np.asarray(all_labels, dtype=np.int64)
    y_pred = np.asarray(all_preds, dtype=np.int64)

    if y_true.size == 0:
        # Nothing to score: guard every metric consistently, not just accuracy.
        nan = float("nan")
        return {
            "y_true": y_true,
            "y_pred": y_pred,
            "probs": np.zeros((0, num_classes), dtype=float),
            "accuracy": nan,
            "precision_macro": nan,
            "recall_macro": nan,
            "f1_macro": nan,
            "top1": nan,
            "top3": nan,
            "top5": nan,
            "cm": np.zeros((num_classes, num_classes), dtype=np.int64),
            "report": {},
        }

    probs = np.vstack(all_probs)

    # Use human-readable class names in the report when they line up with the
    # model's output indices.
    class_names = getattr(getattr(dataloader, "dataset", None), "classes", None)
    if class_names is not None and len(class_names) == num_classes:
        report = classification_report(
            y_true,
            y_pred,
            labels=list(range(num_classes)),
            target_names=[str(c) for c in class_names],
            output_dict=True,
            zero_division=0,
        )
    else:
        report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)

    return {
        "y_true": y_true,
        "y_pred": y_pred,
        "probs": probs,
        "accuracy": float(accuracy_score(y_true, y_pred)),
        "precision_macro": float(precision_score(y_true, y_pred, average="macro", zero_division=0)),
        "recall_macro": float(recall_score(y_true, y_pred, average="macro", zero_division=0)),
        "f1_macro": float(f1_score(y_true, y_pred, average="macro", zero_division=0)),
        "top1": topk_accuracy(y_true, probs, 1),
        "top3": topk_accuracy(y_true, probs, 3),
        "top5": topk_accuracy(y_true, probs, 5),
        "cm": confusion_matrix(y_true, y_pred, labels=list(range(num_classes))),
        "report": report,
    }


## 5. Plotting & saving artifacts

The helpers below save each run's outputs under `outputs/vision_resnet/{figures,metrics,preds}/<run_tag>/`, where `run_tag` has the form `resnet18_run_XX_seed_YY`.


In [None]:
from __future__ import annotations

from pathlib import Path
from typing import Dict, Any, List

import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def _pretty_class_name(name: str) -> str:
    '''Turn a folder-style class name ("Genus_species") into a label ("Genus species").'''
    return " ".join(name.split("_"))


def save_confusion_matrix_png(
    cm: np.ndarray,
    class_names: List[str],
    out_path: Path,
    normalize: bool = True,
    cmap: str = "Blues",
) -> None:
    """
    Save a confusion matrix PNG (no seaborn).
    - normalize=False: shows counts (int if possible; otherwise 1 decimal for mean CMs)
    - normalize=True: shows row-normalized percentages (all-zero rows stay 0)
    - cmap: matplotlib colormap. It must map high values to DARK colours,
      because cell text switches to white above half the maximum value.
      The matplotlib default (viridis) maps high values to bright yellow, which
      would make the white annotations unreadable; hence "Blues".

    Parent folders of `out_path` are created; the figure is saved at 300 dpi and closed.
    """
    cm = np.asarray(cm)

    if normalize:
        cm_plot = cm.astype(float)
        row_sums = cm_plot.sum(axis=1, keepdims=True)
        cm_plot = np.divide(cm_plot, np.maximum(row_sums, 1e-12))
        fmt = ".1%"  # show as percentage
    else:
        # If counts are integer-like (typical per-run CM), display as ints.
        # If not (e.g. mean across runs), display with 1 decimal.
        is_intlike = np.all(np.isclose(cm, np.round(cm)))
        if is_intlike:
            cm_plot = np.round(cm).astype(int)
            fmt = "d"
        else:
            cm_plot = cm.astype(float)
            fmt = ".1f"

    fig, ax = plt.subplots(figsize=(10, 8))
    im = ax.imshow(cm_plot, interpolation="nearest", cmap=cmap)
    fig.colorbar(im, ax=ax)

    tick_labels = [_pretty_class_name(c) for c in class_names]
    ax.set_xticks(np.arange(len(class_names)), labels=tick_labels, rotation=45, ha="right")
    ax.set_yticks(np.arange(len(class_names)), labels=tick_labels)

    ax.set_ylabel("True label")
    ax.set_xlabel("Predicted label")
    ax.set_title("Confusion Matrix (normalized)" if normalize else "Confusion Matrix (counts)")

    # White text on the dark (high-value) cells, black on the light ones.
    thresh = (cm_plot.max() / 2.0) if cm_plot.size else 0.0
    for i in range(cm_plot.shape[0]):
        for j in range(cm_plot.shape[1]):
            val = cm_plot[i, j]
            ax.text(
                j, i, f"{val:{fmt}}",
                ha="center", va="center",
                color="white" if val > thresh else "black",
                fontsize=9,
            )

    fig.tight_layout()
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.close(fig)


def save_run_artifacts(
    run_tag: str,
    class_names: List[str],
    eval_dict: Dict[str, Any],
    *,
    fig_root: Path,
    met_root: Path,
    pred_root: Path,
) -> None:
    """
    Save per-run artifacts aligned to outputs/vision_resnet/{figures,metrics,preds}.

    Writes under `<root>/<run_tag>/`:
      metrics : metrics_<tag>.json / .csv (scalar test metrics),
                classification_report_<tag>.json, per_class_metrics_<tag>.csv,
                confusion_matrix_<tag>.npy
      preds   : predictions_<tag>.csv (true/pred label and class, correctness,
                one prob_<class> column per class)
      figures : confusion_matrix_<tag>_counts.png / _normalized.png

    `eval_dict` is the dict produced by `evaluate`. `class_names[i]` must name
    label index i.
    """
    run_fig = fig_root / run_tag
    run_met = met_root / run_tag
    run_prd = pred_root / run_tag
    for d in (run_fig, run_met, run_prd):
        d.mkdir(parents=True, exist_ok=True)

    # ---- Metrics (JSON + small CSV) ----
    metrics = {
        "accuracy": float(eval_dict["accuracy"]),
        "precision_macro": float(eval_dict["precision_macro"]),
        "recall_macro": float(eval_dict["recall_macro"]),
        "f1_macro": float(eval_dict["f1_macro"]),
        "top1": float(eval_dict["top1"]),
        "top3": float(eval_dict["top3"]),
        "top5": float(eval_dict["top5"]),
    }
    with open(run_met / f"metrics_{run_tag}.json", "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    pd.DataFrame([{"run": run_tag, **metrics}]).to_csv(run_met / f"metrics_{run_tag}.csv", index=False)

    # Full classification report as JSON (handy for thesis tables)
    report = eval_dict["report"]
    with open(run_met / f"classification_report_{run_tag}.json", "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)

    # Per-class metrics CSV (precision/recall/f1/support).
    # sklearn keys the report by class name only when `target_names` were passed.
    # Otherwise the keys are stringified label indices ("0", "1", ...).
    # Accept either, so the table is not silently empty.
    per_class = []
    for idx, cls in enumerate(class_names):
        entry = report.get(cls, report.get(str(idx)))
        if not isinstance(entry, dict):
            continue
        per_class.append({
            "class": cls,
            "precision": float(entry["precision"]),
            "recall": float(entry["recall"]),
            "f1": float(entry["f1-score"]),
            "support": int(entry["support"]),
        })
    pd.DataFrame(
        per_class, columns=["class", "precision", "recall", "f1", "support"]
    ).to_csv(run_met / f"per_class_metrics_{run_tag}.csv", index=False)

    # ---- Predictions ----
    y_true = np.asarray(eval_dict["y_true"], dtype=int)
    y_pred = np.asarray(eval_dict["y_pred"], dtype=int)
    probs = np.asarray(eval_dict["probs"], dtype=float)

    pred_df = pd.DataFrame({
        "true_label": y_true,
        "true_class": [class_names[i] for i in y_true],
        "pred_label": y_pred,
        "pred_class": [class_names[i] for i in y_pred],
        "correct": (y_true == y_pred),
    })
    for i, cls in enumerate(class_names):
        pred_df[f"prob_{cls}"] = probs[:, i]
    pred_df.to_csv(run_prd / f"predictions_{run_tag}.csv", index=False)

    # ---- Confusion matrix (array + figures) ----
    cm = np.asarray(eval_dict["cm"])
    np.save(run_met / f"confusion_matrix_{run_tag}.npy", cm)

    save_confusion_matrix_png(cm, class_names, run_fig / f"confusion_matrix_{run_tag}_counts.png", normalize=False)
    save_confusion_matrix_png(cm, class_names, run_fig / f"confusion_matrix_{run_tag}_normalized.png", normalize=True)


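## 6. Main experiment: train **`CFG.num_runs` independent** ResNet‑18 models (default: 10)

This is the *primary* training block (ported from the last block of the original notebook, cleaned and modularised). Section 6.1 first defines the post‑hoc calibration helpers; the training loop itself is the code cell that follows them.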


## 6.1 Post‑hoc temperature scaling (per run)

For each trained ResNet‑18 model, we learn a single scalar temperature **T\*** on the **validation** split by minimising
negative log-likelihood (cross‑entropy) on fixed logits (Guo et al., 2017).

**Where outputs go**
- Model weights: `models/vision/<run_tag>.pth`
- Temperatures: `models/vision/temperatures/temperature_<run_tag>.npy`

The learned T\* is used later in fusion to calibrate the CNN probabilities without changing the predicted class (argmax is invariant under positive temperature scaling).


In [None]:
from __future__ import annotations

from pathlib import Path
from typing import Tuple, Dict, Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim



def multiclass_brier_score(y_true: np.ndarray, y_proba: np.ndarray) -> float:
    '''
    Vectorised multi-class Brier score (lower is better).

    Computes mean((p - onehot(y))**2) over BOTH samples and classes. This equals
    the classic Brier score (sum over classes, mean over samples) divided by the
    number of classes, so its range is [0, 2/K] rather than [0, 2].
    '''
    labels = np.asarray(y_true, dtype=int)
    proba = np.asarray(y_proba, dtype=float)
    _, n_classes = proba.shape
    target = np.eye(n_classes, dtype=float)[labels]
    return float(np.mean((proba - target) ** 2))


def expected_calibration_error(y_true: np.ndarray, y_proba: np.ndarray, n_bins: int = 15) -> float:
    '''
    Standard ECE: confidence vs. accuracy over `n_bins` equal-width bins.

    Confidence is the max class probability and bins are right-closed (lo, hi].
    Each non-empty bin contributes |accuracy - mean confidence| weighted by the
    fraction of samples it holds.
    '''
    labels = np.asarray(y_true, dtype=int)
    proba = np.asarray(y_proba, dtype=float)

    conf = proba.max(axis=1)
    correct = (proba.argmax(axis=1) == labels).astype(float)

    edges = np.linspace(0.0, 1.0, n_bins + 1)
    total = 0.0
    for lower, upper in zip(edges[:-1], edges[1:]):
        mask = (conf > lower) & (conf <= upper)
        weight = mask.mean()
        if not weight > 0:
            continue  # empty bin (or no samples at all)
        total += weight * abs(correct[mask].mean() - conf[mask].mean())
    return float(total)


class TemperatureScaler(nn.Module):
    '''Single-parameter temperature scaling: logits / T, with T = exp(log_T) > 0.'''
    def __init__(self) -> None:
        super().__init__()
        # Learn log T so T stays strictly positive; zeros(1) means T starts at 1.
        self.log_T = nn.Parameter(torch.zeros(1))

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        return logits / self.log_T.exp()


@torch.no_grad()
def collect_logits_and_labels_cpu(
    model: nn.Module, loader: torch.utils.data.DataLoader, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
    '''
    Run `model` (in eval mode) over `loader` and return (logits, labels) as CPU tensors.

    Inputs are moved to `device` for the forward pass. Results are brought back
    to CPU because the temperature is optimised there (LBFGS can be finicky on MPS).
    '''
    model.eval()
    logit_chunks = []
    label_chunks = []
    for batch_inputs, batch_labels in loader:
        logit_chunks.append(model(batch_inputs.to(device)).detach().cpu())
        label_chunks.append(batch_labels.detach().cpu())
    return torch.cat(logit_chunks, dim=0), torch.cat(label_chunks, dim=0)


def calibrate_temperature_from_val_logits(
    logits_cpu: torch.Tensor, labels_cpu: torch.Tensor, max_iter: int = 50
) -> float:
    '''
    Fit T* by minimising cross-entropy of logits / T on validation data (CPU).

    Uses LBFGS with a strong-Wolfe line search on the single log-temperature
    parameter, so the returned T* = exp(log_T) is always > 0.
    '''
    scaler = TemperatureScaler().cpu()
    criterion = nn.CrossEntropyLoss()
    lbfgs = optim.LBFGS(
        scaler.parameters(),
        lr=0.01,
        max_iter=max_iter,
        line_search_fn="strong_wolfe",
    )

    def _nll_closure():
        lbfgs.zero_grad()
        nll_value = criterion(scaler(logits_cpu), labels_cpu)
        nll_value.backward()
        return nll_value

    lbfgs.step(_nll_closure)
    return float(scaler.log_T.exp().item())


@torch.no_grad()
def evaluate_calibration(
    model: nn.Module,
    loader: torch.utils.data.DataLoader,
    device: torch.device,
    T_star: float | None = None,
) -> Dict[str, float]:
    '''
    Return {"brier", "ece"} for `model` on `loader`.

    When `T_star` is given, logits are divided by it before the softmax
    (temperature scaling). This leaves the argmax, and hence accuracy, unchanged.
    '''
    model.eval()
    chunk_probs = []
    chunk_labels = []

    for batch_inputs, batch_labels in loader:
        logits = model(batch_inputs.to(device))
        if T_star is not None:
            logits = logits / float(T_star)
        chunk_probs.append(F.softmax(logits, dim=1).detach().cpu().numpy())
        chunk_labels.append(batch_labels.numpy())

    stacked_probs = np.vstack(chunk_probs)
    targets = np.concatenate(chunk_labels)

    return {
        "brier": multiclass_brier_score(targets, stacked_probs),
        "ece": expected_calibration_error(targets, stacked_probs, n_bins=15),
    }


In [None]:
from __future__ import annotations

import numpy as np
import pandas as pd
import torch

results_rows = []
cms = []

# Fail fast on configuration problems before any (expensive) training starts.
if CFG.num_runs < 1:
    raise ValueError(f"CFG.num_runs must be >= 1, got {CFG.num_runs}")
if CFG.num_classes != len(class_names):
    # The classifier head and every metric are sized by CFG.num_classes, while the
    # labels come from the dataset's class folders. A mismatch either fails inside
    # the loss or yields confusion matrices / probability columns that no longer
    # line up with `class_names`.
    raise ValueError(
        f"CFG.num_classes={CFG.num_classes} but the dataset has "
        f"{len(class_names)} classes: {class_names}"
    )

# Aggregate file tags follow the configured number of runs
# (num_runs=10 reproduces the "MEAN_10runs" / "10models" file names).
MEAN_TAG = f"MEAN_{CFG.num_runs}runs"
SUMMARY_ALIAS_NAME = f"resnet18_{CFG.num_runs}models_summary.csv"

for run in range(1, CFG.num_runs + 1):
    seed = CFG.base_seed + (run - 1)
    set_seed(seed)

    run_tag = f"resnet18_run_{run:02d}_seed_{seed}"

    print("\n" + "="*80)
    print(f"RUN {run}/{CFG.num_runs}  |  seed={seed}  |  tag={run_tag}")
    print("="*80)

    # Fresh loaders for this run, created after set_seed.
    dataloaders = make_dataloaders(CFG, image_datasets)

    model = create_resnet18(num_classes=CFG.num_classes, pretrained=CFG.use_pretrained).to(DEVICE)
    model, best_val_loss, best_val_acc = train_one_run(model, dataloaders, CFG)

    # Save the best (early-stopping) weights to models/vision/
    weights_path = MODELS_DIR / f"{run_tag}.pth"
    torch.save(model.state_dict(), weights_path)

    # --- Temperature scaling (learn T* on val; save to models/vision/temperatures) ---
    val_logits_cpu, val_labels_cpu = collect_logits_and_labels_cpu(model, dataloaders["val"], DEVICE)
    T_star = calibrate_temperature_from_val_logits(val_logits_cpu, val_labels_cpu, max_iter=50)
    temperature_path = TEMPS_DIR / f"temperature_{run_tag}.npy"
    np.save(temperature_path, np.array([T_star], dtype=np.float32))

    # Optional calibration metrics on the test split (does not change argmax / accuracy)
    calib_before = evaluate_calibration(model, dataloaders["test"], DEVICE, T_star=None)
    calib_after = evaluate_calibration(model, dataloaders["test"], DEVICE, T_star=T_star)

    # Evaluate on temporal holdout
    eval_dict = evaluate(model, dataloaders["test"], num_classes=CFG.num_classes)

    # Save artifacts to outputs/vision_resnet/{figures,metrics,preds}/
    save_run_artifacts(
        run_tag,
        class_names,
        eval_dict,
        fig_root=FIG_DIR,
        met_root=MET_DIR,
        pred_root=PRD_DIR,
    )

    row = {
        "run": run,
        "seed": seed,
        "run_tag": run_tag,
        "best_val_loss": float(best_val_loss),
        "best_val_acc": float(best_val_acc),
        "test_accuracy": float(eval_dict["accuracy"]),
        "test_precision_macro": float(eval_dict["precision_macro"]),
        "test_recall_macro": float(eval_dict["recall_macro"]),
        "test_f1_macro": float(eval_dict["f1_macro"]),
        "test_top1": float(eval_dict["top1"]),
        "test_top3": float(eval_dict["top3"]),
        "test_top5": float(eval_dict["top5"]),
        "weights_path": str(weights_path),
        "temperature_T": float(T_star),
        "temperature_path": str(temperature_path),
        "calib_brier_before": float(calib_before["brier"]),
        "calib_ece_before": float(calib_before["ece"]),
        "calib_brier_after": float(calib_after["brier"]),
        "calib_ece_after": float(calib_after["ece"]),
    }

    print(f"Temp scaling: T*={T_star:.4f} | ECE {calib_before['ece']:.3f} -> {calib_after['ece']:.3f}")

    results_rows.append(row)
    cms.append(eval_dict["cm"])

    print(
        f"Val acc={best_val_acc:.3f} | "
        f"Test acc={eval_dict['accuracy']:.3f} | "
        f"Top-3={eval_dict['top3']:.3f} | "
        f"F1={eval_dict['f1_macro']:.3f}"
    )

# -------------------------
# Aggregate summary (metrics)
# -------------------------
summary_df = pd.DataFrame(results_rows)

# Match your existing naming convention (keep both for convenience)
summary_df.to_csv(MET_DIR / "summary_runs.csv", index=False)
summary_df.to_csv(MET_DIR / SUMMARY_ALIAS_NAME, index=False)

print("\n" + "-"*80)
print("AGGREGATE RESULTS")
print("-"*80)

def _mean_std(x):
    # NaN-aware mean and population std (ddof=0) across runs.
    x = np.asarray(x, dtype=float)
    return float(np.nanmean(x)), float(np.nanstd(x))

for col in ["best_val_acc", "test_accuracy", "test_top3", "test_top5", "test_f1_macro"]:
    m, s = _mean_std(summary_df[col].values)
    print(f"{col:>15}: {m:.3f} ± {s:.3f}")

# Mean confusion matrix across runs
mean_cm = np.mean(np.stack(cms, axis=0), axis=0)
np.save(MET_DIR / f"confusion_matrix_resnet18_{MEAN_TAG}.npy", mean_cm)

save_confusion_matrix_png(mean_cm, class_names, FIG_DIR / f"confusion_matrix_resnet18_{MEAN_TAG}_counts.png", normalize=False)
save_confusion_matrix_png(mean_cm, class_names, FIG_DIR / f"confusion_matrix_resnet18_{MEAN_TAG}_normalized.png", normalize=True)

print("\nSaved:")
print("-", MET_DIR / "summary_runs.csv")
print("-", MET_DIR / SUMMARY_ALIAS_NAME)
print("-", FIG_DIR / f"confusion_matrix_resnet18_{MEAN_TAG}_counts.png")
print("-", FIG_DIR / f"confusion_matrix_resnet18_{MEAN_TAG}_normalized.png")
print("-", MET_DIR / f"confusion_matrix_resnet18_{MEAN_TAG}.npy")
print("-", MODELS_DIR, "(weights)")



If you want, the next step would be to move the utilities into a small `src/` package (e.g. `src/data.py`, `src/train.py`) and keep the notebook as a thin “experiment driver”.


## Appendix A: dataset utilities (optional)

The original notebook contained several one‑off scripts (splitting folders, resizing images, aligning parquet splits).
These are useful, but they should not run by default during training.

Set `RUN_UTILS = True` to execute any of them.


In [None]:
from __future__ import annotations

import shutil
from pathlib import Path
from typing import Tuple

# Guard flag: the helpers in this appendix are always defined, but the example
# call at the end of this cell only runs inside `if RUN_UTILS:`.
RUN_UTILS = False  # <-- change to True if you want to execute these helpers


def split_imagefolder_train_val(
    train_dir: Path,
    out_train_dir: Path,
    out_val_dir: Path,
    val_split: float = 0.2,
    seed: int = 42,
    copy_files: bool = True,
    exts: Tuple[str, ...] = (".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"),
) -> None:
    '''
    Create a deterministic train/val split from an ImageFolder-style `train_dir`.

    For each class sub-folder, the image files (suffix matched case-insensitively
    against `exts`) are sorted by name and shuffled with `random.Random(seed)`.
    The first round(len * val_split) go to `out_val_dir/<class>/`, the rest to
    `out_train_dir/<class>/`.

    Notes
    - Files are sorted before shuffling, so the split depends only on `seed` and
      the file names, not on the filesystem's directory-listing order.
    - We write into new directories (no in-place overwrite). An output folder
      equal to `train_dir` is rejected, and output folders nested inside
      `train_dir` are never treated as classes.
    - `copy_files=False` will MOVE files (faster, but destructive).

    Raises
    ------
    ValueError
        If `val_split` is outside [0, 1] or an output folder is `train_dir`.
    FileNotFoundError
        If `train_dir` does not exist.
    '''
    if not 0.0 <= val_split <= 1.0:
        raise ValueError(f"val_split must be in [0, 1], got {val_split}")

    rng = random.Random(seed)

    train_dir = Path(train_dir)
    out_train_dir = Path(out_train_dir)
    out_val_dir = Path(out_val_dir)

    if not train_dir.exists():
        raise FileNotFoundError(train_dir)

    out_resolved = {out_train_dir.resolve(), out_val_dir.resolve()}
    if train_dir.resolve() in out_resolved:
        raise ValueError(f"Output folders must differ from train_dir: {train_dir}")

    # Snapshot class folders before creating outputs, skipping the outputs themselves.
    class_dirs = sorted(
        p for p in train_dir.iterdir()
        if p.is_dir() and p.resolve() not in out_resolved
    )

    out_train_dir.mkdir(parents=True, exist_ok=True)
    out_val_dir.mkdir(parents=True, exist_ok=True)

    transfer = shutil.copy2 if copy_files else shutil.move
    allowed_exts = {e.lower() for e in exts}

    for cls_dir in class_dirs:
        imgs = sorted(
            p for p in cls_dir.iterdir()
            if p.is_file() and p.suffix.lower() in allowed_exts
        )
        rng.shuffle(imgs)

        n_val = int(round(len(imgs) * val_split))
        val_imgs = imgs[:n_val]
        train_imgs = imgs[n_val:]

        (out_train_dir / cls_dir.name).mkdir(exist_ok=True)
        (out_val_dir / cls_dir.name).mkdir(exist_ok=True)

        for p in train_imgs:
            transfer(p, out_train_dir / cls_dir.name / p.name)

        for p in val_imgs:
            transfer(p, out_val_dir / cls_dir.name / p.name)

    print("Done.")
    print("Train split:", out_train_dir)
    print("Val split  :", out_val_dir)


if RUN_UTILS:
    # Example (adjust paths). NOTE: ExperimentConfig has no `data_dir` field;
    # use the dataset root loaded from configs/paths.yaml (IMAGE_ROOT) instead.
    # split_imagefolder_train_val(
    #     train_dir=IMAGE_ROOT / "train",
    #     out_train_dir=IMAGE_ROOT / "train_split",
    #     out_val_dir=IMAGE_ROOT / "val",
    #     val_split=0.2,
    #     seed=42,
    #     copy_files=True,
    # )
    pass


## Appendix B: align parquet splits with image splits (optional)

If you maintain a metadata table (e.g. `train_with_image_paths.parquet`) and want it split **exactly** like your image folders,
this helper will create `train.parquet` and `val.parquet` based on which image filenames appear in each split folder.

This refactors the original “split parquet file to match image split” block into a reusable function.


In [None]:
from __future__ import annotations

from pathlib import Path
from typing import Set, Tuple

import pandas as pd


def collect_filenames(split_dir: Path) -> Set[str]:
    '''
    Return the basenames of all files inside the class folders of one split.

    The expected layout is ImageFolder-style: ``split_dir/<class_name>/<file>``.
    Loose files directly in ``split_dir`` are ignored, and so are sub-directories
    nested inside a class folder. Only basenames are kept, so identical file
    names living in different class folders collapse into a single entry.
    '''
    root = Path(split_dir)
    return {
        entry.name
        for class_folder in root.iterdir()
        if class_folder.is_dir()
        for entry in class_folder.iterdir()
        if entry.is_file()
    }


def split_parquet_by_image_splits(
    parquet_path: Path,
    image_path_col: str,
    train_split_dir: Path,
    val_split_dir: Path,
    out_train_path: Path,
    out_val_path: Path,
) -> Tuple[Path, Path]:
    '''
    Split a parquet table into train/val tables by matching image filenames.

    A row is assigned to a split when the basename of its ``image_path_col`` value
    (e.g. ``"a/b/img_001.jpg"`` -> ``"img_001.jpg"``, parsed with ``pathlib.Path``
    using the host OS's path rules) appears in that split's class folders.
    Because matching is by basename only:

    - a filename present in both split folders puts its rows in BOTH outputs
      (a warning is printed);
    - rows whose filename is in neither split, or whose path is missing
      (None/NaN), are dropped (a note with the count is printed).

    The original row index and column set are preserved in both outputs.

    Parameters
    ----------
    parquet_path : Path
        Input parquet file (contains a column with image paths).
    image_path_col : str
        Column name holding image paths.
    train_split_dir / val_split_dir : Path
        Folders containing the split images.
    out_train_path / out_val_path : Path
        Output parquet paths (parent folders are created if needed).

    Returns
    -------
    (out_train_path, out_val_path)

    Raises
    ------
    KeyError
        If ``image_path_col`` is not a column of the parquet table.
    '''
    parquet_path = Path(parquet_path)
    df = pd.read_parquet(parquet_path)

    if image_path_col not in df.columns:
        cols = list(df.columns)
        preview = f"{cols[:10]} ..." if len(cols) > 10 else f"{cols}"
        raise KeyError(f"'{image_path_col}' not found in parquet columns: {preview}")

    train_files = collect_filenames(train_split_dir)
    val_files = collect_filenames(val_split_dir)

    # Basename per row, kept as a standalone Series. This avoids copying the whole
    # frame and cannot clobber an existing column. Missing paths map to None so
    # they match no split instead of making Path() raise TypeError.
    filenames = df[image_path_col].map(lambda x: Path(x).name if pd.notna(x) else None)
    in_train = filenames.isin(train_files)
    in_val = filenames.isin(val_files)

    train_df = df[in_train]
    val_df = df[in_val]

    overlap = train_files & val_files
    if overlap:
        print(
            f"Warning: {len(overlap)} filename(s) exist in both train and val folders; "
            "their rows are written to both splits."
        )
    n_unmatched = int((~(in_train | in_val)).sum())
    if n_unmatched:
        print(f"Note: {n_unmatched} of {len(df)} rows matched no image in either split and were dropped.")

    out_train_path = Path(out_train_path)
    out_val_path = Path(out_val_path)
    out_train_path.parent.mkdir(parents=True, exist_ok=True)
    out_val_path.parent.mkdir(parents=True, exist_ok=True)

    train_df.to_parquet(out_train_path)
    val_df.to_parquet(out_val_path)

    print("Saved:")
    print("-", out_train_path, f"({len(train_df)} rows)")
    print("-", out_val_path, f"({len(val_df)} rows)")

    return out_train_path, out_val_path


# Example (disabled by default):
# RUN_UTILS = True
# if RUN_UTILS:
#     split_parquet_by_image_splits(
#         parquet_path=CFG.project_root / "data/amsterdam/train_with_image_paths.parquet",
#         image_path_col="image_path",
#         train_split_dir=CFG.data_dir / "train",
#         val_split_dir=CFG.data_dir / "val",
#         out_train_path=MET_DIR / "metadata_train.parquet",
#         out_val_path=MET_DIR / "metadata_val.parquet",
#     )


## Appendix C: blur + resize a dataset copy (optional)

This refactors the original “create blurred + resized dataset” block into a function.

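- Applies a Gaussian blur (optional; set `blur_radius=0` to disable). The blur is applied **before** resizing, so the radius is measured in source-image pixels.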
- Resizes so the **shorter edge** equals `min_edge` (keeps aspect ratio).
- Mirrors the ImageFolder structure (`split/class_name/image.jpg`).

In [None]:
from __future__ import annotations

from pathlib import Path
from typing import Iterable, Tuple

from PIL import Image, ImageFilter
from torchvision import transforms


# Default set of image file suffixes (lower-case, with leading dot) recognised by
# iter_images; a file's suffix is lower-cased before this membership test.
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}


def iter_images(root: Path, *, exts: Iterable[str] | None = None) -> Iterable[Path]:
    '''
    Yield every image file below ``root``, searching all sub-directories.

    Parameters
    ----------
    root : Path
        Directory to search recursively.
    exts : iterable of str, optional
        Allowed file suffixes including the leading dot (e.g. ``".jpg"``); a
        single string is treated as one suffix. Matching is case-insensitive.
        Defaults to ``IMG_EXTS``.

    Yields
    ------
    Path
        Regular files whose lower-cased suffix is allowed, in ``Path.rglob``
        order (not sorted).
    '''
    if exts is None:
        allowed = IMG_EXTS
    else:
        if isinstance(exts, str):
            # Guard: iterating a bare string would yield its characters.
            exts = [exts]
        allowed = {e.lower() for e in exts}

    root = Path(root)
    for p in root.rglob("*"):
        if p.is_file() and p.suffix.lower() in allowed:
            yield p


def blur_resize_imagefolder_copy(
    src_dir: Path,
    dst_dir: Path,
    min_edge: int = 50,
    blur_radius: float = 2.0,
    jpg_quality: int = 95,
) -> None:
    '''
    Create a processed copy of an ImageFolder dataset.

    Every image found recursively under ``src_dir`` is processed in four steps:
    it is converted to RGB, optionally Gaussian-blurred, resized so its shorter
    edge equals ``min_edge`` (aspect ratio kept), and saved to the same relative
    path under ``dst_dir``. The output format follows the file extension.
    Existing destination files are overwritten. Images that fail to load or save
    are reported and skipped.

    Parameters
    ----------
    src_dir : Path
        Source dataset root (contains train/val/test...).
    dst_dir : Path
        Destination root. Must not be ``src_dir`` itself or lie inside it.
    min_edge : int
        Shorter side after resizing.
    blur_radius : float
        Gaussian blur radius in *source* pixels (the blur is applied before
        resizing). Use 0.0 to disable.
    jpg_quality : int
        Passed to ``PIL.Image.save`` as ``quality``; relevant for lossy formats
        such as JPEG.

    Raises
    ------
    FileNotFoundError
        If ``src_dir`` does not exist.
    ValueError
        If ``dst_dir`` equals ``src_dir`` or is located inside it.
    '''
    src_dir = Path(src_dir)
    dst_dir = Path(dst_dir)

    if not src_dir.exists():
        raise FileNotFoundError(src_dir)

    # Writing into the source tree would either overwrite the originals
    # (dst == src) or feed freshly written outputs back into the recursive scan
    # (dst inside src).
    src_resolved = src_dir.resolve()
    dst_resolved = dst_dir.resolve()
    if dst_resolved == src_resolved or src_resolved in dst_resolved.parents:
        raise ValueError(f"dst_dir ({dst_dir}) must not be src_dir or lie inside it ({src_dir})")

    resize = transforms.Resize(min_edge)

    # Snapshot the file list before writing anything; sorted for a reproducible
    # processing and log order.
    img_paths = sorted(iter_images(src_dir))
    n_ok = 0
    n_skipped = 0

    for img_path in img_paths:
        out_path = dst_dir / img_path.relative_to(src_dir)
        out_path.parent.mkdir(parents=True, exist_ok=True)

        try:
            with Image.open(img_path) as img:
                img = img.convert("RGB")
                if blur_radius and blur_radius > 0:
                    img = img.filter(ImageFilter.GaussianBlur(radius=float(blur_radius)))
                img = resize(img)
                img.save(out_path, quality=jpg_quality)
        except Exception as e:
            # Deliberate best-effort: one unreadable or corrupt file should not
            # abort the whole copy.
            print(f"Skipping {img_path}: {e}")
            n_skipped += 1
        else:
            n_ok += 1

    print("Done.")
    print("Source      :", src_dir)
    print("Destination :", dst_dir)
    print(f"Processed   : {n_ok} image(s), skipped {n_skipped}")


# Example (disabled by default):
# RUN_UTILS = True
# if RUN_UTILS:
#     blur_resize_imagefolder_copy(
#         src_dir=CFG.data_dir,
#         dst_dir=CFG.data_dir.parent / (CFG.data_dir.name + "_processed"),
#         min_edge=50,
#         blur_radius=2.0,
#     )
