In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## 1. 🧠 Self-Supervised Models (SimCLR, MoCo, Rotation, JEPA)

### 1. Pre-training

* Allena l’**encoder una sola volta**, su tutti i pazienti di `train_fold0`.
* Salva i checkpoint (encoder + projector).

### 2. Estrazione delle feature

* Carica il **checkpoint fisso** del pretraining.
* Estrai le feature su tutti i fold (`train_fold{i}`, `val_fold{i}`, `test_holdout`).
* Salva i file come `{model}_features_fold{i}.pt`.

### 3. Linear-Probe per ogni fold `i ∈ [0, ..., N-1]`

* **Train**: allena la testa lineare (LogReg, MLP) sulle feature di `train_fold{i}`.
* **Validation**: usa `val_fold{i}` per:

  * early stopping,
  * salvare il miglior probe,
  * applicare **Temperature Scaling**.
* **Test**: valuta il probe calibrato su `test_holdout`, usando anche **MC-Dropout**.

### 4. Aggregazione finale

* Raccogli i risultati su `test_holdout` per ogni fold.
* Calcola **media ± deviazione standard** per ogni metrica (Accuracy, F1, AUC, ECE, incertezza).
* Questo è il risultato finale del tuo modello SSL + probe.

> 🔒 **Non riaddestrare l’encoder nei fold 1–N.**
> Lo scopo è testarne la **capacità di generalizzazione task-agnostica**, non adattarlo a ogni fold.

---

## 2. 🧪 Supervised & Transfer Learning

Qui **non hai un encoder fisso**. Ogni fold ha il proprio training da zero.

### Per ciascun fold `i ∈ [0, ..., N-1]`

* **Train**: allena l’intero modello (es. ResNet-50) su `train_fold{i}`.
* **Validation**: usa `val_fold{i}` per early-stopping e calibrazione.
* **Test**: valuta sempre su `test_holdout` (o sul val-fold se non esiste un holdout).
* Salva un checkpoint per ogni fold (`supervised_fold{i}.pt`, `transfer_fold{i}.pt`, ecc.).

### Aggregazione finale

* Come per gli SSL, calcola la **media ± deviazione** delle metriche su `test_holdout` per i vari fold.
* Non scegli il modello col miglior punteggio, ma riporti la **media aggregata**.

> ℹ️ **Facoltativo**: puoi evidenziare il fold più vicino alla media, **solo a scopo illustrativo**.

---

## 3. 🔁 Confronto tra pipeline SSL e SL

| Step                   | SSL models (SimCLR, MoCo, ...)  | SL/Transfer models              |
| ---------------------- | ------------------------------- | ------------------------------- |
| **Encoder training**   | Solo su fold0                   | Uno per ogni fold               |
| **Feature extraction** | Uno per ogni fold               | –                               |
| **Probe/classifier**   | Uno per ogni fold               | Uno per ogni fold               |
| **Inference**          | Su `test_holdout` per ogni fold | Su `test_holdout` per ogni fold |
| **Checkpoint**         | Encoder 1×                      | 1× per fold                     |
| **Output finale**      | Media ± std su tutti i fold     | Media ± std su tutti i fold     |

---

## 🎯 Perché questa architettura?

* Nei **modelli SSL**, vogliamo dimostrare che un encoder **generalizza** a nuovi pazienti come **feature extractor**, senza mai essere fine-tuned.
* Nei **modelli supervisionati**, alleniamo da capo su ogni fold per valutare una baseline comparabile (ma meno task-agnostica).

In entrambi i casi:

✅ **Non scegli il miglior fold**,
✅ **Riporti solo le medie cross-fold**,
✅ **Dimostri affidabilità, generalizzazione e riproducibilità**.

---

## 📊 Output finale

Le metriche aggregate sono presentate in una tabella riassuntiva:

| Model             | Accuracy (μ±σ) | Macro-F1 (μ±σ) | ECE (μ±σ)   | Uncertainty (μ±σ) |
| ----------------- | -------------- | -------------- | ----------- | ----------------- |
| SimCLR + probe    | 0.83 ± 0.04    | 0.79 ± 0.06    | 0.03 ± 0.01 | 0.12 ± 0.03       |
| MoCo-v2 + probe   | 0.78 ± 0.05    | …              | …           | …                 |
| Rotation + probe  | 0.74 ± 0.08    | …              | …           | …                 |
| Supervised        | 0.80 ± 0.03    | …              | …           | …                 |
| Transfer learning | 0.82 ± 0.04    | …              | …           | …                 |

> Ogni riga rappresenta una valutazione **completa e aggregata** del modello, utile per confronti clinici e accademici.


In [2]:
# ## Cell 1 – Environment Setup & Dependencies
# Compatibile con Google Colab (GPU/CPU) e ambiente locale (VS Code).

# %%
import os, sys, subprocess, importlib.util
from pathlib import Path

print("📦 [DEBUG] Avvio configurazione ambiente…")

# ────────────────────────────────────────────────────────────────────────
# 1) Rileva ambiente (Colab vs locale)
# ────────────────────────────────────────────────────────────────────────
IN_COLAB = Path("/content").exists()
if IN_COLAB:
    print("📍 [DEBUG] Google Colab rilevato.")
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive", force_remount=False)
else:
    print("💻 [DEBUG] Ambiente locale (VS Code / CLI) rilevato.")

# ────────────────────────────────────────────────────────────────────────
# 2) Definisci PROJECT_ROOT (ENV > default mapping)
# ────────────────────────────────────────────────────────────────────────
DEFAULT_ENV_PATHS = {
    "colab": "/content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project",
    "local": "/Users/stefanoroybisignano/Desktop/MLA/project/wsi-ssrl-rcc_project",
}
PROJECT_ROOT = Path(os.getenv(
    "PROJECT_ROOT",
    DEFAULT_ENV_PATHS["colab" if IN_COLAB else "local"])
).resolve()

sys.path.append(str(PROJECT_ROOT / "src"))
print(f"📁 [DEBUG] PROJECT_ROOT → {PROJECT_ROOT}")

# ────────────────────────────────────────────────────────────────────────
# 3) Utility per installare pacchetti mancanti
# ────────────────────────────────────────────────────────────────────────
def _missing(pkgs):
    return [p for p in pkgs if importlib.util.find_spec(p) is None]

def _install(pkgs, idx_url=None):
    if not pkgs:
        return
    cmd = [sys.executable, "-m", "pip", "install", "--quiet"]
    if idx_url:
        cmd += ["--index-url", idx_url]
    subprocess.check_call(cmd + pkgs)

# ────────────────────────────────────────────────────────────────────────
# 4) Verifica / installa PyTorch (se non presente)
#    • In Colab non sovrascrive la versione pre-installata
# ────────────────────────────────────────────────────────────────────────
TORCH_PKGS = ["torch", "torchvision", "torchaudio"]

if _missing(["torch"]):
    print("🔧 [DEBUG] PyTorch non trovato → installazione in corso…")
    if IN_COLAB:
        GPU = Path("/usr/local/cuda").exists()
        INDEX = "https://download.pytorch.org/whl/cu121" if GPU else "https://download.pytorch.org/whl/cpu"
        _install(TORCH_PKGS, INDEX)
    else:  # locale: lascia scegliere all'utente il build corretto
        _install(TORCH_PKGS)
else:
    import torch
    print(f"✅ [DEBUG] PyTorch già presente ({torch.__version__})")

# ────────────────────────────────────────────────────────────────────────
# 5) Installazione pacchetti ausiliari (sempre sicura)
# ────────────────────────────────────────────────────────────────────────
AUX_PKGS = ["webdataset", "tqdm", "pillow", "pyyaml", "joblib"]
missing_aux = _missing(AUX_PKGS)
if missing_aux:
    print(f"🔧 [DEBUG] Installazione pacchetti ausiliari mancanti: {missing_aux}")
    _install(missing_aux)
else:
    print("✅ [DEBUG] Pacchetti ausiliari già presenti.")

# ────────────────────────────────────────────────────────────────────────
# 6) Info dispositivo
# ────────────────────────────────────────────────────────────────────────
import torch
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🖥️ [DEBUG] Torch device disponibile → {DEVICE}")

# ────────────────────────────────────────────────────────────────────────
# 7) Costante path per Data Tarball (utilizzata negli step successivi)
# ────────────────────────────────────────────────────────────────────────
DATA_TARBALL = PROJECT_ROOT / "data" / "processed"
print(f"📦 [DEBUG] DATA_TARBALL → {DATA_TARBALL}")


📦 [DEBUG] Avvio configurazione ambiente…
📍 [DEBUG] Google Colab rilevato.
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
📁 [DEBUG] PROJECT_ROOT → /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project
✅ [DEBUG] PyTorch già presente (2.6.0+cu124)
🔧 [DEBUG] Installazione pacchetti ausiliari mancanti: ['pillow', 'pyyaml']
🖥️ [DEBUG] Torch device disponibile → cuda
📦 [DEBUG] DATA_TARBALL → /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/data/processed


In [3]:
# %% -------------------------------------------------------------------- #
# Cell 2 – Configurazione & Setup Esperimento                            #
# ----------------------------------------------------------------------- #
import yaml
import datetime
import os
import pprint
import random
import numpy as np
import torch
from pathlib import Path

# ─── 1) Carica il file di configurazione ───────────────────────────────
cfg_path = PROJECT_ROOT / "config" / "training.yaml"
assert cfg_path.exists(), f"❌ File mancante: {cfg_path}"
cfg = yaml.safe_load(cfg_path.read_text())
print(f"📄 [DEBUG] Config caricata da: {cfg_path}")

# ─── 2) Genera EXP_CODE ────────────────────────────────────────────────
yaml_exp = cfg.get("exp_code", "")
env_exp  = os.getenv("EXP_CODE", "")
if yaml_exp:
    EXP_CODE, src = yaml_exp, "YAML"
elif env_exp:
    EXP_CODE, src = env_exp, "ENV"
else:
    EXP_CODE, src = datetime.datetime.now().strftime("%Y%m%d%H%M%S"), "TIMESTAMP"
os.environ["EXP_CODE"] = EXP_CODE
cfg["exp_code"] = EXP_CODE
print(f"🔑 [DEBUG] EXP_CODE → {EXP_CODE} (fonte: {src})")

# ─── 3) Parametri dataset & folds ──────────────────────────────────────
DATASET_ID = cfg["data"]["dataset_id"]
FOLDS = cfg.get("folds", [0])
TRAIN_ENCODER_ONCE = cfg.get("train_encoder_once", False)

print(f"🧬 [DEBUG] DATASET_ID         = {DATASET_ID}")
print(f"🔁 [DEBUG] Folds              = {FOLDS}")
print(f"🔒 [DEBUG] train_encoder_once = {TRAIN_ENCODER_ONCE}")

# ─── 4) Crea la directory dell’esperimento ─────────────────────────────
exp_dir_rel = cfg["output"]["exp_dir"].format(
    dataset_id=DATASET_ID,
    exp_code=EXP_CODE
)
# Use PROJECT_ROOT without early resolve to ensure Drive path
EXP_BASE = PROJECT_ROOT / exp_dir_rel
# Safety check: ensure EXP_BASE is under PROJECT_ROOT
if not str(EXP_BASE.resolve()).startswith(str(PROJECT_ROOT.resolve())):
    raise RuntimeError(f"🚨 EXP_BASE NON sotto PROJECT_ROOT! → {EXP_BASE}")
# Create experiment directory
EXP_BASE.mkdir(parents=True, exist_ok=True)
print(f"📂 [DEBUG] EXP_BASE → {EXP_BASE.resolve()}")

# ─── 5) Salva snapshot del file YAML ───────────────────────────────────
snap = EXP_BASE / f"training_{EXP_CODE}.yaml"
if not snap.exists():
    snap.write_text(yaml.dump(cfg, sort_keys=False))
    print(f"💾 [DEBUG] Salvato snapshot → {snap.resolve()}")
else:
    print(f"ℹ️  [DEBUG] Snapshot già presente → {snap.resolve()}")

# ─── 6) Imposta seed per la riproducibilità ────────────────────────────
SEED = cfg.get("seed", 42)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
print(f"🎲 [DEBUG] Seed globale impostato a: {SEED}")


📄 [DEBUG] Config caricata da: /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/config/training.yaml
🔑 [DEBUG] EXP_CODE → 20250723180604 (fonte: TIMESTAMP)
🧬 [DEBUG] DATASET_ID         = dataset_9f30917e
🔁 [DEBUG] Folds              = [0, 1, 2, 3]
🔒 [DEBUG] train_encoder_once = True
📂 [DEBUG] EXP_BASE → /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/data/processed/dataset_9f30917e/experiments/20250723180604
💾 [DEBUG] Salvato snapshot → /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/data/processed/dataset_9f30917e/experiments/20250723180604/training_20250723180604.yaml
🎲 [DEBUG] Seed globale impostato a: 42


In [4]:
# ## Cell 3 – Import & Registrazione dei Trainer
# Import dinamico dei modelli definiti in YAML e verifica del registry + tipo (SSL/SL).
import sys
import importlib
from utils.training_utils.registry import TRAINER_REGISTRY

# 1) Leggi i nomi dei modelli e i loro tipi da YAML
model_cfgs = cfg["models"]
model_names = list(model_cfgs.keys())
print(f"🔄 [DEBUG] Modelli configurati in YAML: {model_names}")

# 2) Import dinamico di ciascun modulo trainers.{model}
for name in model_names:
    module_name = f"trainers.{name}"
    try:
        if module_name in sys.modules:
            importlib.reload(sys.modules[module_name])
            print(f"✅ [DEBUG] Ricaricato {module_name}")
        else:
            importlib.import_module(module_name)
            print(f"✅ [DEBUG] Importato {module_name}")
    except ImportError as e:
        print(f"❌ [DEBUG] Errore importazione {module_name}: {e}")

# 3) Verifica che tutti i modelli siano registrati e mostra il loro tipo
missing = [n for n in model_names if n not in TRAINER_REGISTRY]
if missing:
    print(f"❌ [DEBUG] Trainer mancanti nel registry: {missing}")
else:
    print("📚 [DEBUG] Trainer registry contiene e corrispondenti tipi:")
    for name in model_names:
        # Usa solo il type dal file YAML per ogni modello
        ttype = model_cfgs[name].get("type", "unknown").upper()
        print(f"  • {name:<12} → {ttype}")

# 3) Verifica funzioni in sottomoduli
import importlib, pkgutil, inspect

print("🔎 [DEBUG] Sotto-moduli caricati da utils.training_utils:")
for mod in ("registry", "device_io", "data_utils", "model_utils", "metrics"):
    m = importlib.import_module(f"utils.training_utils.{mod}")
    print(f"  • {mod:<12} → functions:", [n for n, o in inspect.getmembers(m) if inspect.isfunction(o)])

🔄 [DEBUG] Modelli configurati in YAML: ['simclr', 'moco_v2', 'rotation', 'jepa', 'supervised', 'transfer']
✅ [DEBUG] Importato trainers.simclr
✅ [DEBUG] Importato trainers.moco_v2
✅ [DEBUG] Importato trainers.rotation
✅ [DEBUG] Importato trainers.jepa
✅ [DEBUG] Importato trainers.supervised
✅ [DEBUG] Importato trainers.transfer
📚 [DEBUG] Trainer registry contiene e corrispondenti tipi:
  • simclr       → SSL
  • moco_v2      → SSL
  • rotation     → SSL
  • jepa         → SSL
  • supervised   → SL
  • transfer     → SL
🔎 [DEBUG] Sotto-moduli caricati da utils.training_utils:
  • registry     → functions: ['register_trainer']
  • device_io    → functions: ['choose_device', 'get_latest_checkpoint', 'load_checkpoint', 'load_json', 'save_checkpoint', 'save_joblib', 'save_json']
  • data_utils   → functions: ['build_loader', 'count_samples', 'default_transforms', 'discover_classes', 'extract_labels_from_keys', 'load_classifier', 'load_features', 'parse_label_from_filename', 'save_classifier

In [5]:
# %% -------------------------------------------------------------------- #
# Cell 4 – Helper utilities (Tee, paths, selezione, …)                    #
# ----------------------------------------------------------------------- #
import contextlib
import sys
import time
import math
import inspect
from pathlib import Path
from typing import Any

from utils.training_utils.registry import TRAINER_REGISTRY
from utils.training_utils.device_io import (
    get_latest_checkpoint,
    load_checkpoint,
    save_json,
    save_joblib,
)
from utils.training_utils.data_utils import count_samples
from trainers.train_classifier import train_classifier
from utils.training_utils.metrics import apply_temperature_scaling


class _Tee:
    """
    Classe per duplicare lo stdout/stderr su più target (es. console + file).
    """
    def __init__(self, *tgts):
        self.tgts = tgts

    def write(self, data: str):
        for t in self.tgts:
            t.write(data)
            t.flush()

    def flush(self):
        for t in self.tgts:
            t.flush()


def _global_experiments_append(line: str):
    """
    Appende una riga al file esperimenti globale (esperiments.md).
    """
    global_file = EXP_BASE.parent.parent / "experiments.md"
    with open(global_file, "a") as f:
        f.write(line.rstrip() + "\n")


# %% ----------------------------------------------------------------------- #
# Cell 4 – Helper utilities (Tee, paths, selezione, …)                    #
# ------------------------------------------------------------------------ #
import os
from pathlib import Path
from typing import Any

from utils.training_utils.registry import TRAINER_REGISTRY
from utils.training_utils.device_io import (
    get_latest_checkpoint,
    load_checkpoint,
    save_json,
    save_joblib,
)
from utils.training_utils.data_utils import count_samples
from trainers.train_classifier import train_classifier
from utils.training_utils.metrics import apply_temperature_scaling


def _paths(cfg: dict, model: str, fold: int) -> dict[str, Path]:
    """
    Costruisce tutti i path di output (training, inference, explain, aggregate, ecc.)
    per uno specifico modello e fold, sempre ancorati a PROJECT_ROOT.
    """
    # Parametri di formattazione iniziali
    ph = {
        'dataset_id': cfg['data']['dataset_id'],
        'exp_code': cfg['exp_code'],
        'model_name': model,
        'fold_idx': fold,
    }
    # 1) exp_dir relativo
    rel_exp_dir = cfg['output']['exp_dir'].format(**ph)
    ph['exp_dir'] = rel_exp_dir
    # 2) exp_model_dir relativo (usa ph['exp_dir'])
    rel_exp_model_dir = cfg['output']['exp_model_dir'].format(**ph)

    # Costruisci path assoluti sotto PROJECT_ROOT
    exp_dir       = (PROJECT_ROOT / rel_exp_dir)
    exp_model_dir = (PROJECT_ROOT / rel_exp_model_dir)

    # Directory per training, inference, explain, aggregate, experiment level
    tr = exp_model_dir / f"fold{fold}" / "training"
    inf = exp_model_dir / f"fold{fold}" / "inference"
    ex = exp_model_dir / f"fold{fold}" / "explain"
    ag = exp_model_dir / "_aggregate"
    el = exp_dir / "_experiment_level"

    # Creazione cartelle
    for d in (tr, inf, ex, ag, el):
        d.mkdir(parents=True, exist_ok=True)

    # Ritorna i path assoluti per tutti gli artefatti
    return {
        'ckpt_dir':       tr.resolve(),
        'ckpt_tpl':       (tr / f"{model}_bestepoch{{epoch:03d}}_fold{fold}.pt").resolve(),
        'features':       (tr / f"{model}_features_fold{fold}.pt").resolve(),
        'features_train': (tr / f"{model}_features_train_fold{fold}.pt").resolve(),
        'features_val':   (tr / f"{model}_features_val_fold{fold}.pt").resolve(),
        'clf':            (tr / f"{model}_classifier_fold{fold}.joblib").resolve(),
        'scaler':         (tr / f"{model}_ts_scaler_fold{fold}.joblib").resolve(),
        'log':            (tr / f"{model}_train_log_fold{fold}.md").resolve(),
        'loss_json':      (tr / f"{model}_train_valid_loss_fold{fold}.json").resolve(),

        'patch_preds':    (inf / f"{model}_patch_preds_fold{fold}.pt").resolve(),
        'patient_preds':  (inf / f"{model}_patient_preds_fold{fold}.csv").resolve(),
        'mc_logits':      (inf / f"{model}_mc_logits_fold{fold}.npy").resolve(),
        'metrics':        (inf / f"{model}_metrics_fold{fold}.json").resolve(),

        'gradcam_dir':    ex.resolve(),
        'metadata_csv':   (ex / f"{model}_metadata_gradcam_fold{fold}.csv").resolve(),

        'aggregate_metrics': (ag / f"{model}_metrics.json").resolve(),
        'aggregate_summary': (ag / f"{model}_summary_agg.jpg").resolve(),

        'comparison_json':    (el / "models_comparison.json").resolve(),
        'comparison_img':     (el / "models_comparison.jpg").resolve(),

        'readme':            (exp_dir / "README_EXPERIMENT.md").resolve(),
    }




def _completed(paths: dict[str, Path], is_ssl: bool) -> bool:
    """
    Verifica se il training + artefatti SSL sono già stati completati.
    """
    if not get_latest_checkpoint(paths["ckpt_dir"]):
        return False
    if is_ssl:
        return all(paths[k].exists() for k in (
            "features_train", "features_val", "clf", "scaler", "loss_json"
        ))
    return True


def _select_models(cfg: dict) -> dict[str, dict[str, Any]]:
    """
    Seleziona i modelli da eseguire in base a `run_models` o tutti i modelli.
    """
    wanted = cfg.get("run_models") or list(cfg["models"].keys())
    return {name: cfg["models"][name] for name in wanted}


def _init_trainer(name: str, m_cfg: dict, data_cfg: dict, ckpt_dir: Path):
    """
    Inizializza il trainer registrato per nome.
    """
    tr = TRAINER_REGISTRY[name](m_cfg, data_cfg)
    tr.ckpt_dir = ckpt_dir
    tr.m_cfg = m_cfg
    tr.data_cfg = data_cfg
    tr.is_ssl = m_cfg.get("type", "").lower() == "ssl"
    return tr

In [6]:
# %% ----------------------------------------------------------------------- #
# Cell 5 – Training loop                                                     #
# ----------------------------------------------------------------------- #
import math
import time
import inspect
from pathlib import Path

import torch
import joblib
import numpy as np

from utils.training_utils.device_io import get_latest_checkpoint, load_checkpoint, save_json
from utils.training_utils.data_utils import count_samples
from utils.training_utils.metrics import TemperatureScaler
from utils.training_utils.data_utils import extract_labels_from_keys
from trainers.train_classifier import train_classifier


import time
import inspect
from pathlib import Path
from typing import Optional

import torch
import joblib
import numpy as np

from utils.training_utils.device_io import get_latest_checkpoint, load_checkpoint, save_json
from utils.training_utils.data_utils import count_samples
from utils.training_utils.metrics import TemperatureScaler
from utils.training_utils.data_utils import extract_labels_from_keys
from trainers.train_classifier import train_classifier


def _get_total_batches(loader) -> Optional[int]:
    """
    Try to infer the number of batches from the loader.
    Returns ``None`` if the loader/dataset has no valid length (e.g. WebDataset).
    """
    try:
        return len(loader)  # may raise TypeError if undefined
    except TypeError:
        return None


def _run_full_training(trainer, paths: dict[str, Path], epochs: int) -> None:
    """
    Full training loop with per-batch logging.

    * If ``len(trainer.train_loader)`` is available, a percentage and ETA are shown.
    * Otherwise, only the current batch index is logged, avoiding wrong totals.
    """
    is_ssl = getattr(trainer, "is_ssl", False)
    history: list[dict] = []

    for epoch in range(1, epochs + 1):
        start_time = time.time()
        loss_sum = 0.0
        corr_sum = 0
        seen = 0

        # ── try to compute total batches on-the-fly ───────────────────────────
        total_batches = _get_total_batches(trainer.train_loader)
        print(f"[fold{trainer.cfg_fold}] ── Epoch {epoch}/{epochs} ──")

        for i, batch in enumerate(trainer.train_loader, start=1):
            # if i == 3:
            #   break
            sig = inspect.signature(trainer.train_step)
            result = (
                trainer.train_step(*batch)
                if len(sig.parameters) > 1
                else trainer.train_step(batch)
            )

            if len(result) == 4:
                _, loss, correct, bs = result
            else:
                loss, bs = result
                correct = 0

            loss_sum += loss * bs
            corr_sum += correct
            seen += bs

            # ── logging ───────────────────────────────────────────────────────
            msg = f"[fold{trainer.cfg_fold}] Batch {i}"
            if total_batches:
                pct = (i / total_batches) * 100
                eta = (time.time() - start_time) / i * (total_batches - i)
                msg += f"/{total_batches} ({pct:5.1f}%) | ETA {eta:6.1f}s"
            msg += f" | Loss {loss_sum/seen:.4f}"
            if not is_ssl:
                msg += f" | Acc {corr_sum/seen:.3f}"
            print(msg, flush=True)

        # ── end-of-epoch bookkeeping ─────────────────────────────────────────
        train_loss = loss_sum / seen
        if not is_ssl:
            val_loss, val_acc = trainer.validate_epoch()
            trainer.post_epoch(epoch, val_acc)
            history.append(
                {"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss, "val_acc": val_acc}
            )
            print(f"[fold{trainer.cfg_fold}] Val → Loss {val_loss:.4f} | Acc {val_acc:.3f}")
        else:
            trainer.post_epoch(epoch, train_loss)
            history.append({"epoch": epoch, "train_loss": train_loss})
            print(f"[fold{trainer.cfg_fold}] Train → Loss {train_loss:.4f}")

        print(f"[fold{trainer.cfg_fold}] ⏱  {time.time() - start_time:.1f}s\n")

    save_json(history, paths["loss_json"])


def _resume_or_train(trainer, paths: dict[str, Path], epochs: int):
    """
    Resume from last checkpoint if available, otherwise run full training.
    """
    ckpt = get_latest_checkpoint(paths["ckpt_dir"])
    if ckpt:
        print(f"[fold{trainer.cfg_fold}] ⏩  Resuming from {ckpt.name}")
        model, optimizer = trainer.get_resume_model_and_optimizer()
        load_checkpoint(ckpt, model, optimizer)
    _run_full_training(trainer, paths, epochs)


def _ensure_ssl_artifacts(trainer, paths: dict[str, Path]):
    """
    For SSL models:
      1) Extract train and val features
      2) Train a linear probe on train features
      3) Calibrate the probe via temperature scaling on val logits
    """
    # 1) Extract train features
    if not paths["features_train"].exists():
        trainer.extract_features_to(paths["features_train"], split="train")

    # 2) Extract val features
    if not paths["features_val"].exists():
        trainer.extract_features_to(paths["features_val"], split="val")

    # 3) Train the linear classifier
    if not paths["clf"].exists():
        train_classifier(str(paths["features_train"]), str(paths["clf"]))

    # 4) Temperature scaling on validation logits
    if not paths["scaler"].exists():
        # Load validation features
        val_data = torch.load(paths["features_val"], map_location="cpu")
        X_val = val_data["features"].numpy()
        keys_val = val_data["keys"]

        # Load classifier and label encoder
        bundle = joblib.load(paths["clf"])
        clf = bundle["model"]
        le = bundle["label_encoder"]

        # Obtain logits or convert probs to logits
        if hasattr(clf, "decision_function"):
            logits = clf.decision_function(X_val)
        else:
            probs = clf.predict_proba(X_val)
            logits = np.log(probs + 1e-12)

        # Extract labels from keys
        labels = extract_labels_from_keys(keys_val, le)

        # Fit the TemperatureScaler
        scaler = TemperatureScaler().fit(logits, labels)
        joblib.dump(scaler, str(paths["scaler"]))


In [None]:
# %% -------------------------------------------------------------------- #
# Cell 6 – Modular Launch & Auto-Recover (cross-fold)                     #
# ------------------------------------------------------------------------#
from pathlib import Path
import contextlib, sys, torch, os
from utils.training_utils.device_io import get_latest_checkpoint, load_checkpoint

def launch_training(cfg: dict) -> None:
    """Esegue training + generazione artefatti per tutti i modelli e fold."""
    for name, m_cfg in _select_models(cfg).items():
        is_ssl = (m_cfg.get("type") == "ssl")
        epochs = int(m_cfg["training"]["epochs"])
        print(f"\n🚀  Modello '{name}'  ({'SSL' if is_ssl else 'SL'}) – epochs={epochs}")

        for fold in cfg["folds"]:
            # 1) Configurazione dati
            data_cfg = {
                "train": str((PROJECT_ROOT / cfg["data"]["train"]
                              .format(dataset_id=cfg["data"]["dataset_id"], fold_idx=fold))
                             .resolve()),
                "val":   str((PROJECT_ROOT / cfg["data"]["val"]
                              .format(dataset_id=cfg["data"]["dataset_id"], fold_idx=fold))
                             .resolve()),
                "test":  str((PROJECT_ROOT / cfg["data"]["test"]
                              .format(dataset_id=cfg["data"]["dataset_id"]))
                             .resolve())
            }
            if not Path(data_cfg["train"]).exists():
                print(f"[fold{fold}] ⚠️  shard train mancante → skip")
                continue

            # 2) Path output + trainer
            paths   = _paths(cfg, name, fold)
            trainer = _init_trainer(name, m_cfg, data_cfg, paths["ckpt_dir"])
            trainer.cfg_fold     = fold
            trainer.train_loader = trainer.build_loader("train")
            if hasattr(trainer, "validate_epoch"):
                trainer.val_loader = trainer.build_loader("val")

            # 3) Logging su stdout + file
            with open(paths["log"], "a") as logf, \
                 contextlib.redirect_stdout(_Tee(sys.stdout, logf)), \
                 contextlib.redirect_stderr(_Tee(sys.stderr, logf)):

                print(f"[fold{fold}] 📂  ckpt dir      → {paths['ckpt_dir'].resolve()}")
                print(f"[fold{fold}] 🏷   train shards  → {Path(data_cfg['train']).resolve()}")
                print(f"[fold{fold}] 🚀  avvio trainer  → '{name}'")

                # 3.1) Se abbiamo già tutti gli artefatti, salta interamente
                if _completed(paths, is_ssl):
                    print(f"[fold{fold}] ⚡  Artefatti già presenti → skip training + SSL pipeline")
                else:
                    # 4) SSL: encoder once su fold0 oppure training completo
                    if is_ssl and cfg.get("train_encoder_once", False) and fold > 0:
                        fold0_ckpt = get_latest_checkpoint(_paths(cfg, name, 0)["ckpt_dir"])
                        if fold0_ckpt:
                            mdl, _ = trainer.get_resume_model_and_optimizer()
                            load_checkpoint(fold0_ckpt, mdl, None)
                            print(f"[fold{fold}] ✅ encoder da fold0 → {fold0_ckpt.resolve()}")
                        else:
                            print(f"[fold{fold}] ⚠️  encoder fold0 mancante → train completo")
                            _resume_or_train(trainer, paths, epochs)
                    else:
                        _resume_or_train(trainer, paths, epochs)

                    # 5) SSL – Feature extraction, probe, T-scaling
                    if is_ssl:
                        _ensure_ssl_artifacts(trainer, paths)

                # 6) Append a esperimenti globali (anche in caso di skip)
                latest_ckpt = get_latest_checkpoint(paths["ckpt_dir"])
                if latest_ckpt:
                    base_dir = EXP_BASE.parent.parent.resolve()
                    try:
                        rel_path = os.path.relpath(str(latest_ckpt.resolve()), str(base_dir))
                    except Exception:
                        rel_path = str(latest_ckpt.resolve())
                else:
                    rel_path = "-"
                _global_experiments_append(
                    f"| {cfg['exp_code']} | {name} | fold{fold} | {epochs} | {rel_path} |"
                )

                print(f"[fold{fold}] ✅  completato\n")

# ─── Avvio automatico in Colab ───────────────────────────────────────────
if IN_COLAB:
    launch_training(cfg)
else:
    print("⏩  Ambiente locale: esegui manualmente  launch_training(cfg) per partire.")



🚀  Modello 'rotation'  (SSL) – epochs=25
[fold0] 📂  ckpt dir      → /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/data/processed/dataset_9f30917e/experiments/20250723180604/rotation/fold0/training
[fold0] 🏷   train shards  → /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/data/processed/dataset_9f30917e/webdataset/fold0/train/patches-0000.tar
[fold0] 🚀  avvio trainer  → 'rotation'
[fold0] ── Epoch 1/25 ──
[fold0] Batch 1 | Loss 1.3691
[fold0] Batch 2 | Loss 1.3961
[fold0] Batch 3 | Loss 1.4439
[fold0] Batch 4 | Loss 1.4594
[fold0] Batch 5 | Loss 1.4723
[fold0] Batch 6 | Loss 1.4619
[fold0] Batch 7 | Loss 1.4791
[fold0] Batch 8 | Loss 1.4834
[fold0] Batch 9 | Loss 1.4735
[fold0] Batch 10 | Loss 1.4779
[fold0] Batch 11 | Loss 1.4835
[fold0] Batch 12 | Loss 1.4779
[fold0] Batch 13 | Loss 1.4768
[fold0] Batch 14 | Loss 1.4763
[fold0] Batch 15 | Loss 1.4715
[fold0] Batch 16 | Loss 1.4715
[fold0] Batch 17 | Loss 1.4695
[fold0] Batch 18 | Loss 1.4693
[fold0] Batch 19 

Extracting features: 24it [00:12,  1.92it/s]


✅ Rotation features (train) saved → /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/data/processed/dataset_9f30917e/experiments/20250723180604/rotation/fold0/training/rotation_features_train_fold0.pt


Extracting features: 8it [00:04,  1.98it/s]


✅ Rotation features (val) saved → /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/data/processed/dataset_9f30917e/experiments/20250723180604/rotation/fold0/training/rotation_features_val_fold0.pt
✅ Loaded 1473 keys and (1473, 4) features
📊 Class distribution:
Counter({np.str_('CHROMO'): 300, np.str_('ONCO'): 298, np.str_('ccRCC'): 297, np.str_('pRCC'): 294, np.str_('not_tumor'): 284})
✅ Filtered dataset: 1473 samples
              precision    recall  f1-score   support

      CHROMO       0.53      0.90      0.67        60
        ONCO       0.74      0.33      0.46        60
       ccRCC       0.56      0.25      0.35        59
   not_tumor       0.45      0.42      0.44        57
        pRCC       0.54      0.80      0.64        59

    accuracy                           0.54       295
   macro avg       0.56      0.54      0.51       295
weighted avg       0.57      0.54      0.51       295

Confusion Matrix:
 [[54  4  0  0  2]
 [29 20  3  1  7]
 [13  2 15 25  4]
 [ 2  

Extracting features: 24it [00:12,  1.86it/s]


✅ Rotation features (train) saved → /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/data/processed/dataset_9f30917e/experiments/20250723180604/rotation/fold1/training/rotation_features_train_fold1.pt


Extracting features: 8it [00:05,  1.59it/s]


✅ Rotation features (val) saved → /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/data/processed/dataset_9f30917e/experiments/20250723180604/rotation/fold1/training/rotation_features_val_fold1.pt
✅ Loaded 1506 keys and (1506, 4) features
📊 Class distribution:
Counter({np.str_('pRCC'): 309, np.str_('ONCO'): 308, np.str_('CHROMO'): 303, np.str_('ccRCC'): 293, np.str_('not_tumor'): 293})
✅ Filtered dataset: 1506 samples


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


              precision    recall  f1-score   support

      CHROMO       0.68      0.90      0.77        61
        ONCO       0.50      0.44      0.47        62
       ccRCC       0.00      0.00      0.00        59
   not_tumor       0.40      0.90      0.55        58
        pRCC       0.50      0.29      0.37        62

    accuracy                           0.50       302
   macro avg       0.42      0.50      0.43       302
weighted avg       0.42      0.50      0.43       302

Confusion Matrix:
 [[55  3  0  2  1]
 [19 27  0 14  2]
 [ 5 12  0 31 11]
 [ 2  0  0 52  4]
 [ 0 12  0 32 18]]
💾 Classifier saved to /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/data/processed/dataset_9f30917e/experiments/20250723180604/rotation/fold1/training/rotation_classifier_fold1.joblib
[fold1] ✅  completato

[fold2] 📂  ckpt dir      → /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/data/processed/dataset_9f30917e/experiments/20250723180604/rotation/fold2/training
[fold2] 🏷   

Extracting features: 23it [00:12,  1.91it/s]


✅ Rotation features (train) saved → /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/data/processed/dataset_9f30917e/experiments/20250723180604/rotation/fold2/training/rotation_features_train_fold2.pt


Extracting features: 9it [00:04,  2.05it/s]


✅ Rotation features (val) saved → /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/data/processed/dataset_9f30917e/experiments/20250723180604/rotation/fold2/training/rotation_features_val_fold2.pt
✅ Loaded 1452 keys and (1452, 4) features
📊 Class distribution:
Counter({np.str_('CHROMO'): 298, np.str_('ccRCC'): 297, np.str_('ONCO'): 291, np.str_('pRCC'): 283, np.str_('not_tumor'): 283})
✅ Filtered dataset: 1452 samples
              precision    recall  f1-score   support

      CHROMO       0.58      0.88      0.70        60
        ONCO       0.47      0.14      0.21        58
       ccRCC       0.00      0.00      0.00        59
   not_tumor       0.34      0.68      0.45        57
        pRCC       0.52      0.60      0.55        57

    accuracy                           0.46       291
   macro avg       0.38      0.46      0.38       291
weighted avg       0.38      0.46      0.38       291

Confusion Matrix:
 [[53  0  1  3  3]
 [26  8  0 20  4]
 [ 8  5  0 34 12]
 [ 5  

Extracting features: 23it [00:12,  1.83it/s]


✅ Rotation features (train) saved → /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/data/processed/dataset_9f30917e/experiments/20250723180604/rotation/fold3/training/rotation_features_train_fold3.pt


Extracting features: 8it [00:04,  1.88it/s]


✅ Rotation features (val) saved → /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/data/processed/dataset_9f30917e/experiments/20250723180604/rotation/fold3/training/rotation_features_val_fold3.pt
✅ Loaded 1466 keys and (1466, 4) features
📊 Class distribution:
Counter({np.str_('CHROMO'): 300, np.str_('ONCO'): 298, np.str_('pRCC'): 294, np.str_('ccRCC'): 290, np.str_('not_tumor'): 284})
✅ Filtered dataset: 1466 samples
              precision    recall  f1-score   support

      CHROMO       0.54      0.95      0.69        60
        ONCO       0.12      0.02      0.03        60
       ccRCC       0.51      0.36      0.42        58
   not_tumor       0.39      0.53      0.45        57
        pRCC       0.52      0.56      0.54        59

    accuracy                           0.48       294
   macro avg       0.42      0.48      0.43       294
weighted avg       0.42      0.48      0.43       294

Confusion Matrix:
 [[57  0  1  2  0]
 [32  1  6 12  9]
 [ 9  5 21 15  8]
 [ 7  



[fold0] Batch 1 | Loss 4.8367
[fold0] Batch 2 | Loss 4.8761
[fold0] Batch 3 | Loss 4.8965
[fold0] Batch 4 | Loss 4.8769
[fold0] Batch 5 | Loss 4.8648
[fold0] Batch 6 | Loss 4.8552
[fold0] Batch 7 | Loss 4.8368
[fold0] Batch 8 | Loss 4.8172
[fold0] Batch 9 | Loss 4.7884
[fold0] Batch 10 | Loss 4.7861
[fold0] Batch 11 | Loss 4.7845
[fold0] Batch 12 | Loss 4.7922
[fold0] Batch 13 | Loss 4.7979
[fold0] Batch 14 | Loss 4.8059
[fold0] Batch 15 | Loss 4.8073
[fold0] Batch 16 | Loss 4.8070
[fold0] Batch 17 | Loss 4.8037
[fold0] Batch 18 | Loss 4.8101
[fold0] Batch 19 | Loss 4.8181
[fold0] Batch 20 | Loss 4.8185
[fold0] Batch 21 | Loss 4.8181
[fold0] Batch 22 | Loss 4.8177
[fold0] Batch 23 | Loss 4.8164
[fold0] Train → Loss 4.8164
[fold0] ⏱  173.8s

[fold0] ── Epoch 2/25 ──
[fold0] Batch 1 | Loss 4.7789
[fold0] Batch 2 | Loss 4.8016
[fold0] Batch 3 | Loss 4.8104
[fold0] Batch 4 | Loss 4.8027
[fold0] Batch 5 | Loss 4.7956
[fold0] Batch 6 | Loss 4.7885
[fold0] Batch 7 | Loss 4.7795
[fold0] Batch 

In [None]:
from pathlib import Path

# Configura esperimento
EXP_CODE = ""
ROOT = Path("/content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project")
EXP_DIR = ROOT / "data/processed/dataset_9f30917e/experiments" / EXP_CODE

# Modelli
SSL_MODELS = {"simclr", "rotation", "moco_v2", "jepa"}
SL_MODELS  = {"supervised", "transfer"}
ALL_MODELS = SSL_MODELS | SL_MODELS
N_FOLDS = 2

# File attesi per tipo di modello
FILES_PER_MODEL = {
    "SSL": [
        "{model}_bestepoch*_fold{i}.pt",
        "{model}_train_log_fold{i}.md",
        "{model}_train_valid_loss_fold{i}.json",
    ],
    "SL": [
        "{model}_bestepoch*_fold{i}.pt",
        "{model}_train_log_fold{i}.md",
        "{model}_train_valid_loss_fold{i}.json",
    ],
}

missing = []

for model in sorted(ALL_MODELS):
    model_type = "SSL" if model in SSL_MODELS else "SL"
    folds = [0] if model_type == "SSL" else list(range(N_FOLDS))

    for i in folds:
        train_dir = EXP_DIR / model / f"fold{i}" / "training"
        for pattern in FILES_PER_MODEL[model_type]:
            pattern_path = pattern.format(model=model, i=i)
            matched = list(train_dir.glob(pattern_path))  # usa direttamente il pattern
            if not matched:
                missing.append(str(train_dir / pattern_path))

# Stampa risultato
print("📂 Verifica artefatti TRAINING\n")
if not missing:
    print("✅ Tutti i file richiesti sono presenti.")
else:
    print(f"❌ Mancano {len(missing)} artefatti TRAINING:\n")
    for m in missing:
        print(" •", m)
