In [1]:
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

!pip install --quiet torch torchvision webdataset tqdm pillow scikit-learn joblib


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
from tqdm import tqdm
import csv, logging, joblib, sys, importlib, yaml, torch, numpy as np
from pathlib import Path
from collections import defaultdict, Counter
from sklearn.metrics import classification_report, confusion_matrix, f1_score, precision_score, recall_score

# 📁 Configurazione percorso progetto
config_path = Path('/content/drive/MyDrive/Colab Notebooks/MLA_PROJECT/wsi-ssrl-rcc_project/config/training.yaml')
with config_path.open('r') as f:
    cfg = yaml.safe_load(f)

colab_root = Path(cfg['env_paths']['colab'])
local_root = Path(cfg['env_paths']['local'])
PROJECT_ROOT = colab_root if colab_root.exists() else local_root

sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "src"))


trainer_modules = [
    "trainers.simclr",
    "trainers.moco_v2",
    "trainers.rotation",
    "trainers.jigsaw",
    "trainers.supervised",
    "trainers.transfer",
]
for m in trainer_modules:
    if m in sys.modules:
        importlib.reload(sys.modules[m])
    else:
        importlib.import_module(m)
from utils.training_utils import TRAINER_REGISTRY, load_checkpoint
# Normalizza i path dei dati
for split in ['train','val','test']:
    rel = cfg['data'].get(split)
    if rel:
        cfg['data'][split] = str(PROJECT_ROOT / rel)

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("EVAL")

# 📁 Setup esperimento specifico
experiment_dir = PROJECT_ROOT / "data/processed2/dataset_9f30917e/experiments/prova"


In [3]:
# ─── Helpers ───────────────────────────────────────────────────────────────
def extract_patient_id(key: str) -> str:
    for p in key.split("_"):
        if p.startswith("HP") or p.startswith("H"):
            return p
    return "UNKNOWN"

def extract_label_from_key(key: str) -> str:
    return "not_tumor" if key.startswith("not_tumor") else key.split("_")[0]

def compute_metrics(keys, y_pred, le):
    # ground-truth per paziente
    all_labels = defaultdict(list)
    for k in keys:
        all_labels[extract_patient_id(k)].append(extract_label_from_key(k))
    true_labels = {}
    for pid, labs in all_labels.items():
        tumor = [l for l in labs if l!="not_tumor"]
        if len(set(tumor))==1:
            true_labels[pid]=tumor[0]
    # votazioni
    preds = defaultdict(list)
    for k,p in zip(keys, y_pred):
        pid=extract_patient_id(k)
        if le.classes_[p]!="not_tumor":
            preds[pid].append(p)
    y_true, y_maj, valid = [], [], []
    for pid,votes in preds.items():
        if pid in true_labels and votes:
            gt = true_labels[pid]
            maj = Counter(votes).most_common(1)[0][0]
            y_true.append(le.transform([gt])[0])
            y_maj.append(maj)
            valid.append(pid)
    if not y_true:
        raise RuntimeError("Nessun paziente valutabile")
    report = classification_report(y_true, y_maj, target_names=[c for c in le.classes_ if c!="not_tumor"])
    cm     = confusion_matrix(y_true, y_maj)
    acc    = np.mean(np.array(y_true)==np.array(y_maj))
    f1     = f1_score(y_true, y_maj, average="macro")
    prec   = precision_score(y_true, y_maj, average="macro")
    rec    = recall_score(y_true, y_maj, average="macro")
    return dict(
        y_true=y_true, y_maj=y_maj, valid=valid,
        report=report, cm=cm,
        metrics=(acc,f1,prec,rec),
    )

def write_md_log(save_dir: Path, model_name: str, cm, report: str, valid: list[str], metrics: tuple[float,float,float,float], y_true: list[int], y_maj: list[int], le):
    acc,f1,prec,rec = metrics
    md = save_dir/"evals_log.md"
    with open(md, "w") as f:
        f.write(f"# 🧠 Modello: {model_name}\n\n")
        f.write("## 📊 Risultati Majority Voting (paziente-level)\n```text\n")
        f.write(report)
        f.write("\n```\n\n")
        f.write("## 📉 Confusion Matrix\n```text\n")
        f.write(str(cm))
        f.write("\n```\n\n")
        f.write(f"✅ Totale pazienti classificati: {len(valid)}\n\n")
        # sezione per-paziente
        f.write("## 🧾 Predizione per paziente\n```text\n")
        for pid, t_enc, p_enc in zip(valid, y_true, y_maj):
            true_lbl = le.inverse_transform([t_enc])[0]
            pred_lbl = le.inverse_transform([p_enc])[0]
            f.write(f"Paziente {pid}: predetto = {pred_lbl} | reale = {true_lbl}\n")
        f.write("```\n\n")
        f.write("## 📈 Metriche sintetiche\n")
        f.write(f"- Accuracy        : {acc:.4f}\n")
        f.write(f"- Macro F1        : {f1:.4f}\n")
        f.write(f"- Macro Precision : {prec:.4f}\n")
        f.write(f"- Macro Recall    : {rec:.4f}\n")

def append_summary_csv(common_csv:Path, model_name:str, metrics):
    if not common_csv.exists():
        with open(common_csv, "w", newline="") as f:
            writer=csv.writer(f)
            writer.writerow(["Model","Accuracy","Macro F1","Macro Precision","Macro Recall","N_Patients"])
    acc,f1,prec,rec = metrics
    # N_Patients è già nel report: la lunghezza di valid
    writer=csv.writer(open(common_csv,"a",newline=""))
    writer.writerow([model_name,acc,f1,prec,rec])


In [4]:
# ─── Evaluation ─────────────────────────────────────────────────────────────
from utils.training_utils import TRAINER_REGISTRY, load_checkpoint
from trainers.extract_features import extract_features
import webdataset as wds
import torchvision.transforms as T
from PIL import Image

def make_loader(test_path):
    ds = (
        wds.WebDataset(
            test_path,
            shardshuffle=False,
            handler=wds.warn_and_continue,
            empty_check=False,            # evita generator vuoti
        )
        .decode("pil")
        .map(lambda s: {
            # se non trova immagini, next torna None → fallirà convert, ma almeno non StopIteration
            "img": T.ToTensor()(
                next((v for v in s.values() if isinstance(v, Image.Image)), None)
                .convert("RGB")
            ),
            # se non trova key *.jpg, torna stringa vuota
            "key": s["__key__"]
                   + "."
                   + next((k for k in s.keys() if k.endswith(".jpg")), "")
        })
    )
    return torch.utils.data.DataLoader(
        ds,
        batch_size=64,
        shuffle=False,
        num_workers=0,
        pin_memory=True
    )


def evaluate_model(trainer, test_path: str, save_dir: Path, model_name: str, ssl: bool):
    # 1) carica il checkpoint migliore
    ckpts = sorted(save_dir.glob(f"*Trainer_best_epoch*.pt"))
    ckpt  = ckpts[-1]

    # 2) ricarica i pesi
    if ssl:
        import torch.nn as nn
        # --- SimCLR / MoCo: encoder + projector
        if hasattr(trainer, "encoder") and hasattr(trainer, "projector"):
            seq = nn.Sequential(trainer.encoder, trainer.projector)
            load_checkpoint(ckpt, model=seq)
            trainer.encoder, trainer.projector = seq[0], seq[1]
        # --- Jigsaw: encoder + head
        elif hasattr(trainer, "encoder") and hasattr(trainer, "head"):
            seq = nn.Sequential(trainer.encoder, trainer.head)
            load_checkpoint(ckpt, model=seq)
            trainer.encoder, trainer.head = seq[0], seq[1]
        # --- Rotation: usa direttamente trainer.model
        else:
            load_checkpoint(ckpt, model=trainer.model)
    else:
        # supervised/transfer
        load_checkpoint(ckpt, model=trainer.model)

    # 3) carica il classificatore per ssl
    if ssl:
        clf_data = joblib.load(save_dir / f"{model_name}_classifier.joblib")
        clf, le = clf_data["model"], clf_data["label_encoder"]
        # sceglie feature extractor: encoder se presente, altrimenti model
        if hasattr(trainer, "encoder"):
            feat_mod = trainer.encoder
        else:
            feat_mod = trainer.model
    else:
        clf, le = None, trainer.label_encoder
        feat_mod  = trainer.model

    feat_mod = feat_mod.to(trainer.device)

    # 4) estrai le predizioni
    loader = make_loader(test_path)
    if ssl:
        feats = extract_features(feat_mod, loader, trainer.device)
        X, keys = feats["features"].numpy(), feats["keys"]
        y_pred = clf.predict(X)
    else:
        y_pred, keys = [], []
        for batch in loader:
            imgs = batch["img"].to(trainer.device)
            logits = feat_mod(imgs)
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            y_pred += preds.tolist()
            keys   += batch["key"]

    # 5) metriche, log e CSV
    out = compute_metrics(keys, y_pred, le)
    write_md_log(save_dir, model_name, out["cm"], out["report"], out["valid"], out["metrics"], out["y_true"], out["y_maj"], le)
    append_summary_csv(save_dir.parent / "evaluation_summary_all_models.csv", model_name, out["metrics"])
    print(f"✔️ Finished eval for {model_name}")


In [5]:
# ─── Main loop ─────────────────────────────────────────────────────────────
run_model = cfg.get("run_model","all").lower()
models = cfg["models"].items() if run_model=="all" else [(run_model,cfg["models"][run_model])]
for name, m_cfg in models:
    if name not in TRAINER_REGISTRY:
        logger.warning(f"Trainer '{name}' non trovato, skip.")
        continue
    trainer = TRAINER_REGISTRY[name](m_cfg, cfg["data"])
    test_p = str(cfg["data"]["test"])
    model_dir = experiment_dir/name
    if not model_dir.exists():
        logger.warning(f"Dir {model_dir} mancante, skip.")
        continue
    is_ssl = (name not in ["supervised","transfer"])
    evaluate_model(trainer, test_p, model_dir, name, ssl=is_ssl)

Extracting features: 8it [01:06,  8.27s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


✔️ Finished eval for simclr


Extracting features: 8it [00:54,  6.79s/it]


✔️ Finished eval for moco_v2


Extracting features: 8it [00:53,  6.65s/it]


✔️ Finished eval for rotation


Extracting features: 8it [00:52,  6.51s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


✔️ Finished eval for jigsaw
✔️ Finished eval for supervised
✔️ Finished eval for transfer
