# **PROGETTO: CARCINOMA CLASSIFICATION**

## *Librerie + import del dataset*

In [None]:
import copy
import gc
import os, re
from glob import glob
import json
import textwrap
import math
import random
from pathlib import Path

import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
from PIL import Image

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

from torchmetrics import ConfusionMatrix
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassConfusionMatrix,
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

from densatio import CustomPooling2d
from tqdm import tqdm

# Select the compute device: the GPU when PyTorch can see one, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [None]:
# Report the CUDA status. get_device_name(0) raises when no GPU is available,
# so only query it when CUDA is actually usable.
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU name:", torch.cuda.get_device_name(0))
else:
    print("GPU name: n/a (running on CPU)")

In [None]:
# Global experiment configuration.
SEED = 42          # seed for RNGs and data splits
BATCH_SIZE = 16    # samples per DataLoader batch
NUM_EPOCHS = 250   # number of training epochs

# Text value of the "malignant" CSV column -> integer class id.
LABEL_MAP = dict(zip(["not detectable", "benign", "malignant"], range(3)))
NUM_CLASSES = 3

In [None]:
# Build paths with pathlib's "/" operator. This is portable across operating
# systems and avoids the invalid backslash escape "\d" that the original
# string literals contained.
img_dir = Path("dati") / "data" / "dataset"
labels_path = Path("dati") / "data" / "metadata.csv"

df = pd.read_csv(labels_path, sep=",")

# Map text labels to integer class ids. Fail loudly on values missing from
# LABEL_MAP instead of letting NaN reach astype(int) with an opaque error.
df["label"] = df["malignant"].map(LABEL_MAP)
unmapped = df.loc[df["label"].isna(), "malignant"].unique()
if len(unmapped) > 0:
    raise ValueError(f"Valori 'malignant' non presenti in LABEL_MAP: {list(unmapped)}")
df["label"] = df["label"].astype(int)

print(df[["id", "malignant", "label"]].head())
print(df["label"].value_counts())

Dopo la mappatura delle etichette, ogni riga contiene l’id dell’immagine, la label testuale originale e la corrispondente versione numerica. In totale ci sono 36 immagini ‘not detectable’ (classe 0), 14 ‘benign’ (classe 1) e 12 ‘malignant’ (classe 2): il dataset è quindi sbilanciato verso la classe 0, che rappresenta circa il 58% dei campioni.

In [None]:
def seed_everything(seed: int):
    """Seed every random number generator used by the notebook.

    Sets PYTHONHASHSEED, then seeds Python's `random`, NumPy's global RNG and
    torch's CPU and CUDA generators. Finally it forces deterministic cuDNN
    kernels and disables cuDNN autotuning.

    Note: changing PYTHONHASHSEED at runtime only affects child processes;
    the current interpreter's hash seed is fixed at start-up.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    # Trade speed for reproducibility: no autotuned, non-deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Use the shared SEED constant, so changing it in one place reseeds everything.
seed_everything(SEED)

In [None]:
def id2loc(id_):
    """Return, as a string, the path of `img_<id_>.png` inside the global `img_dir`."""
    filename = "img_{}.png".format(id_)
    return str(Path(img_dir).joinpath(filename))

def id2img(id_):
    """Open the image for `id_` from disk and return it as an RGB PIL Image."""
    path = id2loc(id_)
    rgb_image = Image.open(path).convert("RGB")
    return rgb_image

def img2arr(img):
    """Return `img` as a NumPy array via np.asarray (no copy when it already is one)."""
    array = np.asarray(img)
    return array

def id2arr(id_):
    """Load the image for `id_` (RGB) and return it as a NumPy array."""
    pil_image = id2img(id_)
    return img2arr(pil_image)

def show_images(df, n_fig=10,random=True, seed=42):
    """
    Show a grid of images per class: 3 rows (one per class) x n_fig columns.

    Args:
        df: DataFrame with columns "id" and "label" (0, 1, 2).
        n_fig: max images per class. It is capped at the size of the smallest
            class, so every row has the same number of examples.
        random: if True, shuffle each class before picking (the name shadows
            the `random` module inside this function only).
        seed: random_state passed to DataFrame.sample for reproducibility.
    """
    # Inverse mapping: class id -> name for the row labels. It is named
    # differently from the global LABEL_MAP (name -> id) to avoid confusion.
    class_names = {0: "not detectable", 1: "benign", 2: "malignant"}
    labels = list(class_names)

    dfs = {lab: df[df["label"] == lab].reset_index(drop=True) for lab in labels}

    if random:
        dfs = {
            lab: sub.sample(frac=1, random_state=seed).reset_index(drop=True)
            for lab, sub in dfs.items()
        }

    # n_fig cannot exceed the smallest class.
    n_fig_eff = min(n_fig, *(len(sub) for sub in dfs.values()))
    if n_fig_eff < n_fig:
        print(f"Ridotto da {n_fig} a {n_fig_eff}")
    n_fig = n_fig_eff
    if n_fig <= 0:
        # plt.subplots(3, 0) would raise: nothing to show.
        print("Nessuna immagine da mostrare.")
        return

    # squeeze=False keeps `axes` 2-D even when n_fig == 1.
    fig, axes = plt.subplots(
        len(labels), n_fig, figsize=(1.6 * n_fig, 6), constrained_layout=True, squeeze=False
    )

    for col in range(n_fig):
        for row, lab in enumerate(labels):
            img_id = dfs[lab].iloc[col]["id"]
            ax = axes[row, col]
            ax.imshow(id2img(img_id))
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_frame_on(True)

    # Row labels are set once, on the first column only.
    for row, lab in enumerate(labels):
        axes[row, 0].set_ylabel(class_names[lab], fontsize=12, rotation=90, labelpad=20)

    plt.show()

show_images(df, n_fig=10, random=True, seed=SEED)

In [None]:
# One (height, width, channels) row per image: np.asarray on a PIL image is (H, W, C).
dims = pd.DataFrame(
    data=[id2arr(image_id).shape for image_id in df["id"]],
    columns=["height", "width", "channels"],
)
print(dims.describe())

I numeri riassumono le dimensioni delle 62 immagini: tutte hanno 3 canali (quindi sono a colori, RGB) e le altezze/larghezze sono piuttosto variabili, con valori che vanno da circa 270 a 900 pixel. La media è intorno ai 465 pixel per lato, ma la deviazione standard alta indica che non sono tutte della stessa dimensione. Questo conferma che ha senso ridimensionarle a una risoluzione fissa (ad esempio 224×224) prima di usarle nel modello.

## *Data augmentation*

Si parte da un DataFrame contenente ID e label. Per ciascun campione:

1. viene costruito il percorso dell’immagine a partire dall’ID;

2. l’immagine viene caricata da disco e convertita in RGB;

3. viene applicata la pipeline di trasformazioni: in training include anche data augmentation, mentre in validation/test prevede solo resize e normalizzazione;

4. l’output viene convertito in un tensore nel formato (C, H, W) e associato alla label intera, pronta per il calcolo della loss.

In questo modo, il modello riceve input uniformi per dimensione e normalizzati con mean/std calcolate sul training set; inoltre, durante l’addestramento, osserva varianti realistiche dello stesso campione grazie alle augmentations, migliorando la generalizzazione.

La funzione *denorm* consente di invertire la normalizzazione (x · std + mean) e di ottenere immagini visualizzabili, utili per verificare che la pipeline di preprocessing produca gli input attesi.

Le augmentations adottate (rumore, flip, piccole rotazioni, variazioni di luminosità e contrasto) sono coerenti con pratiche consolidate nella letteratura sul deep learning per imaging medico, in particolare su lesioni cutanee.

In [None]:
def compute_mean_std(train_df, img_dir, img_size=224):
    """
    Compute per-channel RGB mean/std on the TRAINING split only, on the [0, 1] scale.

    Every image is converted to RGB and resized to (img_size, img_size), so all
    images contribute the same number of pixels. The std is the true pooled
    std over all training pixels: sqrt(E[x^2] - E[x]^2). The average of
    per-image stds would ignore the variance between image means and
    underestimate it.

    The returned mean/std must then be passed to all datasets (train/val/test).

    Args:
        train_df: DataFrame with an "id" column; images are img_dir/img_{id}.png.
        img_dir: directory containing the images.
        img_size: side length used for resizing before computing statistics.

    Returns:
        (mean, std): two lists of 3 floats (R, G, B).

    Raises:
        ValueError: if train_df is empty.
        FileNotFoundError: if an image cannot be read.
    """
    if len(train_df) == 0:
        raise ValueError("train_df vuoto: impossibile calcolare mean/std.")

    channel_sum = np.zeros(3, dtype=np.float64)
    channel_sq_sum = np.zeros(3, dtype=np.float64)
    n_pixels = 0

    for _, row in tqdm(train_df.iterrows(), total=len(train_df)):
        path = os.path.join(img_dir, f"img_{row['id']}.png")
        img = cv2.imread(path)

        if img is None:
            raise FileNotFoundError(f"Not found: {path}")

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (img_size, img_size))
        # float64 accumulation keeps E[x^2] - E[x]^2 numerically stable.
        pixels = img.reshape(-1, 3).astype(np.float64) / 255.0  # [0,1]

        channel_sum += pixels.sum(axis=0)
        channel_sq_sum += np.square(pixels).sum(axis=0)
        n_pixels += pixels.shape[0]

    mean = channel_sum / n_pixels
    # Clip tiny negative values caused by floating-point cancellation.
    var = np.maximum(channel_sq_sum / n_pixels - mean ** 2, 0.0)
    std = np.sqrt(var)
    return mean.tolist(), std.tolist()

In [None]:
class CustomDataset(Dataset):
    """
    Image classification dataset with optional Albumentations augmentations.

    Reads rows of a DataFrame with columns 'id' and 'label'. For each row it
    loads `img_{id}.png` from `img_dir` with OpenCV (converted to RGB), runs it
    through an Albumentations pipeline and returns `(image_tensor, label_tensor)`.

    Args:
        df: DataFrame with columns 'id' and 'label' (label is cast with int()).
        img_dir: directory containing the img_{id}.png files.
        img_size: target side length; outputs are always (img_size, img_size).
        is_train: training mode; augmentations are only added when True.
        augment: augmentations on/off (effective only if is_train=True).
        resize: if True -> direct A.Resize to (img_size, img_size) (aspect ratio not kept).
                if False -> A.LongestMaxSize + A.PadIfNeeded (aspect ratio kept,
                            constant-border padding), output still (img_size, img_size).
        mean: per-channel mean for A.Normalize, on the [0, 1] scale (3 values for RGB).
        std: per-channel std for A.Normalize, on the [0, 1] scale (3 values for RGB).

    Raises:
        ValueError: if mean or std is None (they must be computed on the training split).
    """

    def __init__(
        self,
        df,
        img_dir,
        img_size=224,
        is_train=True,   # train/eval mode
        augment=True,    # augmentation on/off
        resize=True,     # resize strategy (see class docstring)
        mean=None,
        std=None
    ):
        # Reset the index so the positional idx used in __getitem__ works with .loc.
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.img_size = img_size

        self.is_train = is_train
        self.augment = augment
        self.resize = resize

        # Normalization statistics are mandatory and must come from the training split.
        if mean is None or std is None:
            raise ValueError("Fornire mean e std calcolati SOLO sul TRAIN set con compute_mean_std().")

        self.mean = mean
        self.std = std

        # Albumentations pipelines. They are built last because _build_tf reads
        # is_train/augment/resize/mean/std set above.
        self.transform = self._build_tf(img_size)
        self.orig_transform = self._build_orig_tf(img_size)

    def __len__(self):
        """Number of rows in the DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """
        Load, transform and return sample `idx`.

        Returns:
            (image, label): image is a float32 tensor (3, img_size, img_size),
            normalized with mean/std; label is a 0-d torch.long tensor.

        Raises:
            FileNotFoundError: if cv2 cannot read the image file.
        """
        id_ = self.df.loc[idx, "id"]
        label = int(self.df.loc[idx, "label"])

        path = os.path.join(self.img_dir, f"img_{id_}.png")
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f"Image not found: {path}")

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # numpy uint8 (H,W,3), RGB order

        # main pipeline (aug + final resize/pad + float + normalize)
        img_aug = self.transform(image=img)["image"]
        img_aug = self._to_tensor(img_aug)

        y = torch.tensor(label, dtype=torch.long)
        return img_aug, y

    def _final_resize_ops(self, img_size):
        """
        Return the ops that always produce (img_size, img_size):
        - if resize=True: direct resize (aspect ratio not preserved)
        - if resize=False: keep aspect ratio, then pad to square with a constant border
        """
        if self.resize:
            return [A.Resize(img_size, img_size)]
        else:
            return [
                A.LongestMaxSize(img_size),
                A.PadIfNeeded(img_size, img_size, border_mode=cv2.BORDER_CONSTANT),
            ]

    def _build_orig_tf(self, img_size):
        """
        Resize/pad-only pipeline (no augmentation, no normalization) with a fixed output size.

        Stored as self.orig_transform; __getitem__ does not use it (useful for debugging).
        """
        tf = []
        tf.extend(self._final_resize_ops(img_size))
        return A.Compose(tf)

    def _build_tf(self, img_size):
        """
        Build the main pipeline used by __getitem__.

        Order:
        - augmentations (only if is_train and augment are both True)
        - final resize/pad (ALWAYS: output img_size x img_size)
        - ToFloat: uint8 [0, 255] -> float [0, 1]
        - Normalize with the provided mean/std (max_pixel_value=1.0 because
          the input is already on the [0, 1] scale)
        """
        tf = []

        # 1) Augmentations (train only)
        if self.is_train and self.augment:
            tf.append(self._aug())

        # 2) Final resize/pad (always)
        tf.extend(self._final_resize_ops(img_size))

        # 3) Float + Normalize
        tf.append(A.ToFloat(max_value=255.0))  # [0,1]
        tf.append(A.Normalize(mean=self.mean, std=self.std, max_pixel_value=1.0))

        return A.Compose(tf)

    def _to_tensor(self, img_np):
        """Convert a NumPy (H, W, C) array into a contiguous float32 torch tensor (C, H, W)."""
        if img_np.dtype != np.float32:
            img_np = img_np.astype(np.float32)
        t = torch.from_numpy(img_np).permute(2, 0, 1).contiguous()
        return t

    def _aug(self):
        """
        Build the random augmentation block.

        Six "families" (noise, flip/rot90, color/contrast, blur/sharpen,
        geometric, dropout), each an A.OneOf that picks one member. With
        overall probability 0.8, either 1 or 2 distinct families are applied
        (A.SomeOf with n=1 or n=2, replace=False); otherwise the image is left
        unchanged.

        NOTE(review): inside A.OneOf the members' `p` values act as selection
        weights. GaussNoise uses p=1.0 while its siblings keep the library
        default p, so it is likely picked more often — confirm this is intended.
        """
        families = [
            # Sensor-like noise
            A.OneOf([
                A.GaussNoise(noise_scale_factor=0.1, p=1.0),
                A.MultiplicativeNoise(multiplier=(0.85, 1.15), per_channel=True),
                A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5)),
            ], p=1.0),

            # Flips / 90-degree rotations
            A.OneOf([
                A.HorizontalFlip(),
                A.VerticalFlip(),
                A.RandomRotate90(),
            ], p=1.0),

            # Brightness / contrast / local contrast
            A.OneOf([
                A.RandomBrightnessContrast(0.15, 0.15),
                A.RandomGamma(gamma_limit=(80, 120)),
                A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8)),
            ], p=1.0),

            # Blur / sharpening
            A.OneOf([
                A.GaussianBlur(blur_limit=(3, 7)),
                A.MotionBlur(blur_limit=5),
                A.Sharpen(alpha=(0.1, 0.35), lightness=(0.9, 1.1)),
            ], p=1.0),

            # Small geometric changes (constant border where pixels are exposed)
            A.OneOf([
                A.Rotate(limit=15, border_mode=cv2.BORDER_CONSTANT),
                A.ShiftScaleRotate(
                    shift_limit=0.04,
                    scale_limit=0.06,
                    rotate_limit=0,
                    border_mode=cv2.BORDER_CONSTANT
                ),
                A.GridDistortion(num_steps=5, distort_limit=0.2),
            ], p=1.0),

            # Random rectangular holes filled with 0
            A.OneOf([
                A.CoarseDropout(num_holes_range=(1, 8), hole_height_range=(8, 24), hole_width_range=(8, 24), fill=0, p=1.0),
            ], p=1.0),
        ]

        # The same transform instances are shared by both SomeOf branches.
        return A.OneOf([
            A.SomeOf(families, n=1, replace=False, p=1.0),
            A.SomeOf(families, n=2, replace=False, p=1.0),
        ], p=0.8)

In [None]:
# DENORMALIZATION
def denorm(t, mean, std):
    """Undo a channel-wise Normalize: return t * std + mean.

    t: tensor whose channel axis is third from the end, e.g. (C, H, W)
       or (B, C, H, W).
    mean, std: per-channel sequences, converted to tensors on t's device/dtype.
    """
    like = {"device": t.device, "dtype": t.dtype}
    mean_t = torch.tensor(mean, **like).reshape(-1, 1, 1)
    std_t = torch.tensor(std, **like).reshape(-1, 1, 1)
    return t * std_t + mean_t

In [None]:
train_ratio = 0.70
test_ratio = 0.30  # fraction of df moved to df_temp (later split into val + test)
if not math.isclose(train_ratio + test_ratio, 1.0):
    raise ValueError("train_ratio + test_ratio deve essere 1.0")

# First split: train vs (val+test), stratified to keep the class distribution.
df_train, df_temp = train_test_split(
    df,
    test_size=test_ratio,
    random_state=SEED,
    stratify=df["label"],
)

# Second split: temp (30%) -> val (~15%) + test (~15%).
# `test_size` is the fraction of df_temp that goes to df_test. sklearn rounds
# the test count UP, so with an odd df_temp size test gets one more sample.
test_ratio_inside_temp = 0.5
val_ratio_inside_temp = 1.0 - test_ratio_inside_temp  # share of df_temp kept as validation

df_val, df_test = train_test_split(
    df_temp,
    test_size=test_ratio_inside_temp,
    random_state=SEED,
    stratify=df_temp["label"],
)

print("Train:", len(df_train))
print("Val:", len(df_val))
print("Test:", len(df_test))

Il dataset totale contiene 62 immagini, suddivise così:

- Train: 43 immagini.
- Validation: 9 immagini.
- Test: 10 immagini.

In [None]:
# Input resolution shared by the mean/std computation and every dataset split.
IMG_SIZE = 224

# mean/std computed on the TRAIN split only, then reused for val/test (no leakage).
mean, std = compute_mean_std(df_train, img_dir, img_size=IMG_SIZE)

# datasets: augmentation only on train; val/test get resize + normalization only.
train_ds = CustomDataset(df_train, img_dir, img_size=IMG_SIZE, is_train=True, augment=True, resize=True, mean=mean, std=std)
val_ds = CustomDataset(df_val, img_dir, img_size=IMG_SIZE, is_train=False, augment=False, resize=True, mean=mean, std=std)
test_ds = CustomDataset(df_test, img_dir, img_size=IMG_SIZE, is_train=False, augment=False, resize=True, mean=mean, std=std)

I valori negativi osservati sono successivi alla normalizzazione.
Il controllo del range è stato effettuato prima della Normalize, dove le immagini risultano correttamente comprese tra 0 e 1.

In [None]:
# Sanity check: run the training pipeline up to (but excluding) Normalize and
# verify that pixel values lie in [0, 1].
id_ = train_ds.df.loc[0, "id"]
path = os.path.join(train_ds.img_dir, f"img_{id_}.png")

img = cv2.imread(path)
if img is None:
    raise FileNotFoundError(f"Not found: {path}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# pipeline pre-Normalize, reusing the dataset's own resize/pad ops so that
# resize=False (LongestMaxSize + PadIfNeeded) is handled too.
tf_pre = []
if train_ds.is_train and train_ds.augment:
    tf_pre.append(train_ds._aug())  # Augmentations
tf_pre.extend(train_ds._final_resize_ops(train_ds.img_size))  # Resize / pad
tf_pre.append(A.ToFloat(max_value=255.0))  # [0,1]

# Apply the pipeline and inspect the preprocessed image.
img_pre = A.Compose(tf_pre)(image=img)["image"]
print("Prima di applicare Normalize:")
print(img_pre.dtype, img_pre.min(), img_pre.max())

In [None]:
# One DataLoader per split; only the training split is shuffled.
train_loader, val_loader, test_loader = (
    DataLoader(split_ds, batch_size=BATCH_SIZE, shuffle=do_shuffle)
    for split_ds, do_shuffle in ((train_ds, True), (val_ds, False), (test_ds, False))
)

In [None]:
def plot_aug_images_by_class_multi(
    dataset,
    n_per_class=3,
    k=4,
    seed=42,
    show_id=True,
    pick="random"
):
    """
    Mosaic per class: a few images per class, each next to k random augmentations.

    Rows: one per selected image (3 classes * up to n_per_class each).
    Columns: 1 original (resized only) + k augmented versions.

    Args:
        dataset: CustomDataset-like object exposing .df ("id"/"label"), .img_dir,
            .img_size and ._aug() (._aug_viz() is used instead if defined).
        n_per_class: max images per class (capped at the class size), >= 1.
        k: number of augmented versions per image, >= 0.
        seed: seed for picking images (pandas random_state) and for the
            augmentation pipeline's RNG.
        show_id: if True, append the image id to the original tile's title.
        pick: "first" -> first images of each class; otherwise a random sample.

    Raises:
        ValueError: on invalid n_per_class/k or if a class has no images.
        FileNotFoundError: if an image file cannot be read.
    """
    if n_per_class < 1:
        raise ValueError("n_per_class deve essere >= 1")
    if k < 0:
        raise ValueError("k deve essere >= 0")

    label_names = {
        0: "not detectable",
        1: "benign",
        2: "malignant"
    }

    # resize for display (NO normalize)
    resize_tf = A.Compose([A.Resize(dataset.img_size, dataset.img_size)])

    # augmentation block; the seed goes to the Compose itself because
    # Albumentations 2.x keeps its own RNG (Python's global `random` state is
    # not used), so `random.seed` would not make the tiles reproducible.
    aug_block = dataset._aug_viz() if hasattr(dataset, "_aug_viz") else dataset._aug()
    aug_tf = A.Compose([aug_block, A.Resize(dataset.img_size, dataset.img_size)], seed=seed)

    # wrap titles for readability
    def wrap(txt, width=18):
        return textwrap.fill(txt, width=width)

    # pick images per class: list of (label, index in dataset.df)
    chosen = []
    for lab in label_names:
        sub = dataset.df[dataset.df["label"] == lab]
        if len(sub) == 0:
            raise ValueError(f"Nessuna immagine trovata per la classe {lab}")

        n_eff = min(n_per_class, len(sub))
        sub_sel = sub.head(n_eff) if pick == "first" else sub.sample(n=n_eff, random_state=seed)
        chosen.extend((lab, gi) for gi in sub_sel.index)

    n_rows = len(chosen)
    n_cols = k + 1

    # squeeze=False keeps `axes` 2-D for any n_rows / n_cols (including k=0).
    fig, axes = plt.subplots(
        n_rows, n_cols,
        figsize=(3.1 * n_cols, 2.9 * n_rows),
        constrained_layout=False,
        squeeze=False,
    )

    fig.subplots_adjust(top=0.88, hspace=0.25, wspace=0.05)

    fig.suptitle(
        "Data Augmentation per categoria\n",
        fontsize=15
    )

    for r, (lab, global_idx) in enumerate(chosen):
        id_ = dataset.df.loc[global_idx, "id"]
        path = os.path.join(dataset.img_dir, f"img_{id_}.png")

        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f"Not found: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        orig = resize_tf(image=img)["image"]

        # ---- ORIGINAL ----
        title_orig = f"{label_names[lab]} Originale"
        if show_id:
            title_orig += f" (id={id_})"

        axes[r, 0].imshow(orig)
        axes[r, 0].axis("off")
        axes[r, 0].set_title(wrap(title_orig), fontsize=10)

        # ---- AUGMENTATIONS ----
        for c in range(1, n_cols):
            augm = aug_tf(image=img)["image"]
            axes[r, c].imshow(augm)
            axes[r, c].axis("off")
            axes[r, c].set_title(wrap(f"Augmentazione {c}"), fontsize=9)

    plt.show()

plot_aug_images_by_class_multi(
    train_ds,
    n_per_class=1,
    k=4,
    seed=42,
    show_id=True,
    pick="random"
)

In [None]:
# Augmentation catalog
def build_aug_catalog():
    """
    Return a list of (name, transform) pairs, each meant to be applied ON ITS OWN.

    No Normalize/ToTensor: this is for VISUALIZATION only. Every transform
    mirrors the parameters used by CustomDataset._aug(), with p=1.0 so it
    always fires. This way the catalog shows exactly what training applies,
    using the same Albumentations 2.x argument names (GaussNoise
    noise_scale_factor, CoarseDropout *_range/fill).
    """
    catalog = []

    # --- NOISE ---
    catalog += [
        ("GaussNoise", A.GaussNoise(noise_scale_factor=0.1, p=1.0)),
        ("MultiplicativeNoise", A.MultiplicativeNoise(multiplier=(0.85, 1.15), per_channel=True, p=1.0)),
        ("ISONoise", A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5), p=1.0)),
    ]

    # --- FLIP / ROT ---
    catalog += [
        ("HorizontalFlip", A.HorizontalFlip(p=1.0)),
        ("VerticalFlip", A.VerticalFlip(p=1.0)),
        ("RandomRotate90", A.RandomRotate90(p=1.0)),
    ]

    # --- COLOR/CONTRAST ---
    catalog += [
        ("BrightnessContrast", A.RandomBrightnessContrast(0.15, 0.15, p=1.0)),
        ("RandomGamma", A.RandomGamma(gamma_limit=(80, 120), p=1.0)),
        ("CLAHE", A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=1.0)),
    ]

    # --- BLUR / SHARPEN ---
    catalog += [
        ("GaussianBlur", A.GaussianBlur(blur_limit=(3, 7), p=1.0)),
        ("MotionBlur", A.MotionBlur(blur_limit=5, p=1.0)),
        ("Sharpen", A.Sharpen(alpha=(0.1, 0.35), lightness=(0.9, 1.1), p=1.0)),
    ]

    # --- GEOMETRIC ---
    catalog += [
        ("Rotate15", A.Rotate(limit=15, border_mode=cv2.BORDER_CONSTANT, p=1.0)),
        ("ShiftScaleRotate", A.ShiftScaleRotate(
            shift_limit=0.04, scale_limit=0.06, rotate_limit=0,
            border_mode=cv2.BORDER_CONSTANT, p=1.0
        )),
        ("GridDistortion", A.GridDistortion(num_steps=5, distort_limit=0.2, p=1.0)),
    ]

    # --- DROPOUT ---
    catalog += [
        ("CoarseDropout", A.CoarseDropout(
            num_holes_range=(1, 8), hole_height_range=(8, 24), hole_width_range=(8, 24),
            fill=0, p=1.0
        )),
    ]

    return catalog


# Read raw RGB images (no transforms) from a CustomDataset
def get_rgb_images_from_dataset(dataset, n_images=2, seed=42, pick="random"):
    """
    Return (images, ids): up to n_images RGB uint8 arrays read with cv2, plus their ids.

    pick="first" takes the first rows of dataset.df; anything else draws a
    random sample of row positions, seeded with `seed`.

    Raises:
        ValueError: if the dataset is empty or n_images <= 0.
        FileNotFoundError: if an image file cannot be read.
    """
    n_images = min(n_images, len(dataset))
    if n_images <= 0:
        raise ValueError("Dataset vuoto.")

    if pick == "first":
        idxs = list(range(n_images))
    else:
        # Private RNG: identical draw to `random.seed(seed); random.sample(...)`,
        # but without resetting the global `random` state used elsewhere.
        idxs = random.Random(seed).sample(range(len(dataset)), k=n_images)

    imgs = []
    ids = []
    for idx in idxs:
        id_ = dataset.df.loc[idx, "id"]
        path = os.path.join(dataset.img_dir, f"img_{id_}.png")
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f"Not found: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        imgs.append(img)
        ids.append(id_)
    return imgs, ids


def _wrap(s, width=16):
    return "\n".join(textwrap.wrap(str(s), width=width))


# Mosaic: original + every catalog augmentation (one tile per type)
def mosaic_all_augs_for_images(
    images_rgb,
    ids=None,
    img_size=224,
    seed=42,
    n_cols=5,
    title_prefix="",
    save_dir=None,
    dpi=300,
    highlight_original=True
):
    """
    Draw one figure per image. Each figure has one tile for the resized
    original and one tile per transform in build_aug_catalog().

    Args:
        images_rgb: iterable of RGB arrays (cast to uint8 if needed).
        ids: optional ids shown in the titles (matched by position).
        img_size: side length every tile is resized to.
        seed: base seed; tile j uses seed + j, so figures are reproducible.
        n_cols: tiles per row.
        title_prefix: text prepended to each figure title.
        save_dir: if given, each figure is also saved there as PNG.
        dpi: resolution used when saving.
        highlight_original: draw a thicker border around the original tile.
    """
    catalog = build_aug_catalog()
    resize_tf = A.Compose([A.Resize(img_size, img_size)])

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    for i, img in enumerate(images_rgb):
        # NOTE(review): a float image in [0, 1] would collapse to 0/1 here.
        if img.dtype != np.uint8:
            img = img.astype(np.uint8)

        orig = resize_tf(image=img)["image"]

        variants = [("Original (baseline)", orig)]
        for j, (name, tform) in enumerate(catalog):
            # Reproducible tile: seed the Compose's own RNG. Albumentations 2.x
            # does not read NumPy's global state, so np.random.seed had no effect.
            aug_tf = A.Compose([tform, A.Resize(img_size, img_size)], seed=seed + j)
            aug = aug_tf(image=img)["image"]
            variants.append((name, aug))

        n = len(variants)
        n_rows = math.ceil(n / n_cols)

        fig = plt.figure(figsize=(3.0 * n_cols, 3.0 * n_rows))
        gs = fig.add_gridspec(n_rows, n_cols, wspace=0.03, hspace=0.18)

        # Title + subtitle
        id_txt = f" | id={ids[i]}" if ids is not None and i < len(ids) else ""
        fig.suptitle(
            f"{title_prefix}Augmentation catalog — image #{i+1}{id_txt}",
            fontsize=16, y=0.98
        )
        fig.text(
            0.5, 0.95,
            f"Original + {len(catalog)} transforms | img_size={img_size} | seed={seed}",
            ha="center", va="center", fontsize=11
        )

        for idx in range(n_rows * n_cols):
            r, c = divmod(idx, n_cols)
            ax = fig.add_subplot(gs[r, c])
            ax.set_xticks([]); ax.set_yticks([])
            ax.set_facecolor("white")

            # thin uniform border
            for sp in ax.spines.values():
                sp.set_visible(True)
                sp.set_linewidth(0.6)

            if idx < n:
                name, im = variants[idx]
                ax.imshow(im)
                ax.set_title(_wrap(name, width=16), fontsize=10, pad=6)

                # highlight the original
                if highlight_original and idx == 0:
                    for sp in ax.spines.values():
                        sp.set_linewidth(1.4)
            else:
                # unused trailing grid cells
                ax.axis("off")

        fig.tight_layout(rect=[0, 0, 1, 0.92])

        if save_dir is not None:
            out_path = os.path.join(save_dir, f"aug_catalog_img{i+1}.png")
            fig.savefig(out_path, dpi=dpi, bbox_inches="tight")
            print("Salvato:", out_path)

        plt.show()

imgs, ids = get_rgb_images_from_dataset(train_ds, n_images=1, seed=42, pick="random")
mosaic_all_augs_for_images(imgs, ids=ids, img_size=224, seed=42, n_cols=5)

## *Funzioni di pooling*

### **Max pooling**

In [None]:
def pool_max(window: torch.Tensor, **kwargs) -> torch.Tensor:
    """Max pooling: keep the largest value over the two trailing (window) axes.

    Extra keyword arguments are ignored, so every pooling function can be
    called with the same signature.
    """
    window_axes = (-2, -1)
    return window.amax(dim=window_axes, keepdim=True)

### **Learnable Max - Average Pooling**

In [None]:
def pooling_max_avg_2d(patches: torch.Tensor, w: torch.Tensor, g: torch.Tensor, epsilon=1e-8, **kwargs) -> torch.Tensor:
    """Learnable mix of max and average pooling, followed by a sigmoid gate.

    patches: (B, Ho, Wo, C, kH, kW); the window is the last two axes.
    w: mixing logits, scalar/(2,) or (C_eff, 2); softmax gives (w_max, w_avg).
    g: gate logit, scalar/(1,) or (C_eff, 1); sigmoid(g) scales the output.
    Returns (B, Ho, Wo, C, 1, 1) = (w_max * max + w_avg * mean) * sigmoid(g),
    with epsilon added to the window values before reducing.
    """
    def _as_2d(param: torch.Tensor) -> torch.Tensor:
        # Move to the patches' device/dtype and promote 0-D/1-D to 2-D rows.
        param = param.to(device=patches.device, dtype=patches.dtype)
        if param.dim() == 0:
            return param.view(1, 1)
        if param.dim() == 1:
            return param.view(1, -1)
        return param

    w = _as_2d(w)
    g = _as_2d(g)
    assert w.dim() == 2 and w.size(1) == 2, f"w must be (C_eff,2), got {tuple(w.shape)}"
    assert g.dim() == 2 and g.size(1) == 1, f"g must be (C_eff,1) or (1,1), got {tuple(g.shape)}"

    shifted = patches + epsilon
    window_axes = (-1, -2)
    x_max = torch.amax(shifted, dim=window_axes, keepdim=True)
    x_avg = torch.mean(shifted, dim=window_axes, keepdim=True)

    # Channel axis is dim 3 -> broadcast shape (1, 1, 1, C_eff, 1, 1).
    per_channel = (1, 1, 1, -1, 1, 1)
    mix = torch.softmax(w, dim=-1)
    w_max = mix[:, 0].view(per_channel)
    w_avg = mix[:, 1].view(per_channel)
    gate = torch.sigmoid(g).view(per_channel)

    return (w_max * x_max + w_avg * x_avg) * gate

### **Stochastic pooling**

In [None]:
def stochastic_pooling2d(patches: torch.Tensor, eps=1e-8, **kwargs) -> torch.Tensor:
    """Stochastic pooling: sample one value per window.

    Each value is drawn with probability softmax(window) + eps.
    patches: (B, Ho, Wo, C, kH, kW) -> returns (B, Ho, Wo, C, 1, 1).
    """
    batch, out_h, out_w, channels, k_h, k_w = patches.shape
    rows = patches.reshape(batch * out_h * out_w * channels, k_h * k_w)

    weights = F.softmax(rows, dim=-1) + eps
    picked_idx = torch.multinomial(weights, 1)               # (N, 1)
    picked = torch.gather(rows, 1, picked_idx).squeeze(1)   # (N,)

    return picked.view(batch, out_h, out_w, channels, 1, 1)

### **Weighted Max Pooling**

In [None]:
def weighted_max_pooling(patches, e_exponent, epsilon=1e-4, **kwargs):
    """Max over each window after raising the (clamped) values to a learnable power.

    patches: (B, Ho, Wo, C, kH, kW) -> returns (B, Ho, Wo, C, 1, 1).
    e_exponent: one value or one per channel (broadcast over dim 3).
    Values are shifted by epsilon and clamped to >= epsilon before the power,
    so the base is always positive.
    """
    exponent = e_exponent.to(device=patches.device, dtype=patches.dtype).view(1, 1, 1, -1, 1, 1)
    positive = torch.clamp(patches + epsilon, min=epsilon)
    powered = torch.pow(positive, exponent)

    # Reduce kW first, then kH (two single-axis max reductions).
    row_max, _ = torch.max(powered, dim=-1, keepdim=True)
    window_max, _ = torch.max(row_max, dim=-2, keepdim=True)
    return window_max

### **Log-Sum-Exp pooling**

In [None]:
def lse_pooling2d(patches: torch.Tensor, beta: torch.Tensor, g: torch.Tensor, eps=1e-12, **kwargs) -> torch.Tensor:
    """Log-Sum-Exp (mean form) pooling with a learnable sharpness and gain.

    out = (m + (1/b) * log(mean(exp(b * (x - m))) + eps)) * (1 + sigmoid(g)),
    with m = window max and b = softplus(beta) + eps. Large b approaches max
    pooling; small b approaches the mean.

    patches: (B, Ho, Wo, C, kH, kW) -> returns (B, Ho, Wo, C, 1, 1).
    """
    target = {"device": patches.device, "dtype": patches.dtype}
    beta = beta.to(**target)
    g = g.to(**target)

    B, Ho, Wo, C, kH, kW = patches.shape
    flat = patches.reshape(B, Ho, Wo, C, kH * kW)

    per_channel = (1, 1, 1, -1, 1)
    sharpness = (F.softplus(beta) + eps).view(per_channel)
    gain = torch.sigmoid(g).view(per_channel)

    # Subtract the max before exp for numerical stability, then add it back.
    peak = flat.max(dim=-1, keepdim=True).values
    stable_exp = torch.exp(sharpness * (flat - peak))
    pooled = peak + (1.0 / sharpness) * torch.log(torch.mean(stable_exp, dim=-1, keepdim=True) + eps)

    pooled = pooled * (1.0 + gain)   # (B,Ho,Wo,C,1)
    return pooled.unsqueeze(-1)      # (B,Ho,Wo,C,1,1)

### **GeM**

In [None]:
def gem_pooling2d(patches, p, eps=1e-6, **kwargs):
    """
    GeM (Generalized Mean) pooling per window: mean(x ** p_eff) ** (1 / p_eff).

    patches: (B, Ho, Wo, C, kH, kW)
    p: raw parameter, one value or one per channel; p_eff = 1 + softplus(p) >= 1
       (p_eff = 1 is average pooling; larger values approach max pooling).
    Negative inputs are clamped to 0, and eps keeps the base strictly positive.
    return: (B, Ho, Wo, C, 1, 1)
    """
    raw = p.to(device=patches.device, dtype=patches.dtype)
    exponent = (1.0 + F.softplus(raw)).view(1, 1, 1, -1, 1, 1)

    base = patches.clamp(min=0.0) + eps
    powered_mean = (base ** exponent).mean(dim=(-1, -2), keepdim=True)
    return powered_mean ** (1.0 / exponent)

### **Adaptive Statical Pooling (LSE - GeM - Std)**

In [None]:
def adp_stat_pooling2d(patches, w, tau, p, g, eps=1e-6, **kwargs):
    """
    Adaptive statistical pooling: learnable mix of LSE, GeM and std per window.

    patches: (B, Ho, Wo, C, kH, kW)
    returns: (B, Ho, Wo, C, 1, 1)

    Args:
        w: mixing logits; softmax over the last axis gives the (LSE, GeM, Std)
           weights. Supported shapes:
             - (3,) or (1, 3) -> shared across channels
             - (C, 3)         -> per channel
        tau: LSE temperature, tau_eff = softplus(tau) + eps.
        p: GeM exponent, p_eff = 1 + softplus(p) >= 1.
        g: output gain, g_eff = 1 + sigmoid(g), in (1, 2).
        tau/p/g: one value or one per channel (broadcast over C).
        eps: numerical floor; inputs are clamped to >= 0 and shifted by eps.

    Raises:
        ValueError: if w has an unsupported shape.
    """
    device = patches.device
    dtype  = patches.dtype

    w   = w.to(device=device, dtype=dtype)
    tau = tau.to(device=device, dtype=dtype)
    p   = p.to(device=device, dtype=dtype)
    g   = g.to(device=device, dtype=dtype)

    B, Ho, Wo, C, kH, kW = patches.shape
    K = kH * kW

    x = patches.reshape(B, Ho, Wo, C, K)
    x = torch.clamp(x, min=0.0) + eps

    # scalar / per-channel parameters (broadcast-safe)
    tau_eff = (F.softplus(tau) + eps).view(1, 1, 1, -1, 1)
    p_eff   = (1.0 + F.softplus(p)).view(1, 1, 1, -1, 1)
    g_eff   = (1.0 + torch.sigmoid(g)).view(1, 1, 1, -1, 1)

    # ---- LSE: tau * log(mean(exp(x / tau))), computed stably by factoring
    # out the max. Adding z_max back is required: without it the term equals
    # LSE - max(x) <= 0 instead of a value between mean(x) and max(x).
    z = x / tau_eff
    z_max = torch.max(z, dim=-1, keepdim=True)[0]
    lse = tau_eff * (z_max + torch.log(torch.mean(torch.exp(z - z_max), dim=-1, keepdim=True) + eps))

    # ---- GeM
    gem = torch.mean(x ** p_eff, dim=-1, keepdim=True) ** (1.0 / p_eff)

    # ---- Std (population std, eps inside sqrt keeps the gradient finite)
    std = torch.sqrt(torch.var(x, dim=-1, keepdim=True, unbiased=False) + eps)

    # w: (3,) | (1,3) | (C,3) -> (C,3)
    if w.dim() == 1 and w.shape[0] == 3:
        w = w.view(1, 3).expand(C, 3)
    elif w.dim() == 2 and w.shape == (1, 3):
        w = w.expand(C, 3)
    elif w.dim() == 2 and w.shape == (C, 3):
        pass
    else:
        raise ValueError(f"w ha shape non supportata: {w.shape}")

    w_mix = torch.softmax(w, dim=-1)  # (C,3)

    w_lse = w_mix[:, 0].view(1, 1, 1, C, 1)
    w_gem = w_mix[:, 1].view(1, 1, 1, C, 1)
    w_std = w_mix[:, 2].view(1, 1, 1, C, 1)

    out = (w_lse * lse + w_gem * gem + w_std * std) * g_eff
    return out.unsqueeze(-1)  # (B,Ho,Wo,C,1,1)


## *ResNet*

In [None]:
# Registry: pooling name -> (pooling function, learnable-parameter specs).
# Each spec gives the initial tensor ("init_value") and whether it is trained
# ("requires_grad").
# NOTE(review): presumably consumed by densatio.CustomPooling2d when the model
# builds its pooling layer — confirm against the ResNet code.
POOLING_MAP = dict(
    # Parameter-free poolings.
    max=(pool_max, {}),

    stochastic=(stochastic_pooling2d, {}),

    # w: logits softmax-ed into (max, avg) weights; g: sigmoid output gate.
    maxavg=(
        pooling_max_avg_2d,
        dict(
            w=dict(init_value=torch.ones(1, 2) * 0.5, requires_grad=True),
            g=dict(init_value=torch.ones(1) * 0.5, requires_grad=True),
        ),
    ),

    # e_exponent: power applied to the clamped window values before the max
    # (1.0 -> plain max of the shifted values, > 1 stretches large values more).
    weighted_max_pooling=(
        weighted_max_pooling,
        dict(
            e_exponent=dict(init_value=torch.ones(1) * 1.5, requires_grad=True),
        ),
    ),

    # beta: sharpness, softplus(beta) + eps; g: gain, output * (1 + sigmoid(g)).
    lse=(
        lse_pooling2d,
        dict(
            beta=dict(init_value=torch.ones(1) * 0.5, requires_grad=True),
            g=dict(init_value=torch.ones(1) * 0.5, requires_grad=True),
        ),
    ),

    # p: raw exponent, p_eff = 1 + softplus(0.5) ≈ 1.97 at init.
    gem=(
        gem_pooling2d,
        dict(
            p=dict(init_value=torch.ones(1) * 0.5, requires_grad=True),
        ),
    ),

    # w: (LSE, GeM, Std) mixing logits; tau / p / g: LSE temperature, GeM exponent, gain.
    adpstat=(
        adp_stat_pooling2d,
        dict(
            w=dict(init_value=torch.ones(3) * 0.5, requires_grad=True),
            tau=dict(init_value=torch.ones(1) * 0.5, requires_grad=True),
            p=dict(init_value=torch.ones(1) * 0.5, requires_grad=True),
            g=dict(init_value=torch.ones(1) * 0.5, requires_grad=True),
        ),
    ),
)

In [None]:
class BasicBlock(nn.Module):
    """ResNet basic residual block.

    Main path: two 3x3 conv + BatchNorm layers (the first may be strided).
    The shortcut is identity, or `downsample(x)` when a projection module is
    given. The output is ReLU(main + shortcut).
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()
        # Layers are created in the same order (conv1, bn1, conv2, bn2), so the
        # weight initialization consumes the RNG identically.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)), inplace=True)
        hidden = self.bn2(self.conv2(hidden))
        shortcut = x if self.downsample is None else self.downsample(x)
        return F.relu(hidden + shortcut, inplace=True)


def _make_layer(in_planes, planes, blocks, stride):
    """Stack ``blocks`` BasicBlocks into one ResNet stage.

    Only the first block may change resolution or width. It gets a 1x1
    conv + BN projection on its skip path whenever ``stride != 1`` or the
    channel count changes. The remaining blocks map planes -> planes.
    """
    needs_projection = stride != 1 or in_planes != planes
    projection = (
        nn.Sequential(
            nn.Conv2d(in_planes, planes, 1, stride=stride, bias=False),
            nn.BatchNorm2d(planes),
        )
        if needs_projection
        else None
    )
    first = BasicBlock(in_planes, planes, stride=stride, downsample=projection)
    rest = [BasicBlock(planes, planes) for _ in range(blocks - 1)]
    return nn.Sequential(first, *rest)

class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block (ResNet50/101/152 style).

    Every convolution uses stride 1, so the spatial size is preserved. The
    output has ``out_channels * expansion`` channels. The skip path is an
    identity when that equals ``in_channels`` and a 1x1 conv + BN
    projection otherwise.
    """

    expansion = 4

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        mid = out_channels
        wide = out_channels * self.expansion

        # 1x1 reduction
        self.conv1 = nn.Conv2d(in_channels, mid, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(mid)

        # 3x3 spatial conv; stride is intentionally not configurable (always 1)
        self.conv2 = nn.Conv2d(mid, mid, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(mid)

        # 1x1 expansion
        self.conv3 = nn.Conv2d(mid, wide, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(wide)

        self.skip = (
            nn.Identity()
            if in_channels == wide
            else nn.Sequential(
                nn.Conv2d(in_channels, wide, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(wide),
            )
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = F.relu(self.bn1(self.conv1(x)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        return F.relu(h + self.skip(x))

class ResNet18(nn.Module):
    """ResNet18 whose final global pooling is a densatio ``CustomPooling2d``.

    Args:
        num_classes: number of output logits.
        pooling_choice: a key of ``POOLING_MAP`` or a pooling callable. A
            ``POOLING_MAP`` entry is either ``(fn, params)`` or a bare ``fn``.
        stem_pool_size, stem_stride, stem_padding: initial pooling
            configuration. NOTE: ``forward`` overwrites all three with
            (H, W) / (1, 1) / "valid", so the pooling is always global and
            the values passed here are effectively ignored.

    The pooling module is built lazily in ``forward`` because its window
    depends on the feature-map size. Any parameters it owns (the
    ``POOLING_MAP`` entries request ``requires_grad=True``) therefore only
    exist after a first forward pass. Run one before creating an optimizer
    or calling ``load_state_dict``; otherwise those parameters are not
    optimized and not restored.

    Raises:
        ValueError: if ``pooling_choice`` is a string not in ``POOLING_MAP``.
    """

    def __init__(
        self,
        num_classes=3,
        pooling_choice="max",      # key in POOLING_MAP
        stem_pool_size=(3, 3),
        stem_stride=(2, 2),
        stem_padding="same",
    ):
        super().__init__()

        # --- pooling selection
        if isinstance(pooling_choice, str):
            if pooling_choice not in POOLING_MAP:
                raise ValueError(f"pooling_choice sconosciuta: {pooling_choice}. Valide: {list(POOLING_MAP.keys())}")

            entry = POOLING_MAP[pooling_choice]

            if isinstance(entry, tuple):
                self.pooling_method, self.pooling_params_base = entry
                # Private copy: models built from the same entry must not
                # share the init tensors.
                self.pooling_params_base = copy.deepcopy(self.pooling_params_base)
            else:
                self.pooling_method = entry
                self.pooling_params_base = {}
        else:
            self.pooling_method = pooling_choice
            self.pooling_params_base = {}

        # --- ResNet stem
        self.conv1 = nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)

        # --- ResNet18 stages
        self.layer1 = _make_layer(64,  64, 2, 1)
        self.layer2 = _make_layer(64,  128, 2, 2)
        self.layer3 = _make_layer(128, 256, 2, 2)
        self.layer4 = _make_layer(256, 512, 2, 2)

        self.fc = nn.Linear(512, num_classes)

        # --- densatio pooling, built on the first forward
        self._dense_pool = None
        self._pool_sig = None

        # pooling configuration (overwritten in forward, see class docstring)
        self._stem_pool_size = stem_pool_size
        self._stem_stride = stem_stride
        self._stem_padding = stem_padding

    def _build_pool(self):
        """Instantiate the densatio pooling from the current configuration."""
        return CustomPooling2d(
            pool_size=self._stem_pool_size,
            stride=self._stem_stride,
            padding=self._stem_padding,
            pooling_method=self.pooling_method,
            pooling_params=self.pooling_params_base,
        )

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)  # (B, 512, H, W)

        # GLOBAL POOLING: window = (H, W) so the output is (B, 512, 1, 1).
        H, W = x.shape[-2], x.shape[-1]

        # functools.partial and callable objects have no __name__.
        method_name = getattr(self.pooling_method, "__name__", repr(self.pooling_method))
        sig = (x.device.type, x.dtype, H, W, method_name)
        if self._dense_pool is None or self._pool_sig != sig:
            if self._dense_pool is not None:
                import warnings
                warnings.warn(
                    f"ResNet18: rebuilding pooling for new signature {sig}; "
                    "learned pooling parameters are reset.",
                    RuntimeWarning,
                )
            self._stem_pool_size = (H, W)
            self._stem_stride = (1, 1)
            self._stem_padding = "valid"
            pool = self._build_pool()
            # The pool may be created after model.to(device) already ran, so
            # move it to the feature map's device explicitly.
            if isinstance(pool, nn.Module):
                pool = pool.to(x.device)
            self._dense_pool = pool
            self._pool_sig = sig

        x = self._dense_pool(x)           # (B, 512, 1, 1)
        x = torch.flatten(x, 1)           # (B, 512)
        return self.fc(x)

In [None]:
# Reference model, then one model per pooling in POOLING_MAP as a pure
# construction check. NOTE(review): the message says "Training", but no
# training happens in this cell.
model = ResNet18(num_classes=NUM_CLASSES, pooling_choice="max").to(device)

for key in POOLING_MAP:
    model = ResNet18(pooling_choice=key, num_classes=NUM_CLASSES).to(device)
    print("Training with pooling:", key)

## *Ciclo di addestramento*

In [None]:
def clean():
    """Run the garbage collector and, when CUDA is available, release cached GPU memory."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

def ensure_dir(path: str):
    """Create ``path`` and any missing parents; do nothing if it already exists."""
    target = os.fspath(path)
    os.makedirs(target, exist_ok=True)

def save_cm_csv(cm, out_path: str):
    """Write a confusion matrix to ``out_path`` as CSV without an index column.

    Accepts a torch tensor on any device, with or without grad (it is
    detached and copied to CPU first), or anything ``np.asarray``
    understands (numpy arrays, nested lists). The column headers are the
    column indices 0..C-1.
    """
    if isinstance(cm, torch.Tensor):
        cm = cm.detach().cpu().numpy()
    pd.DataFrame(np.asarray(cm)).to_csv(out_path, index=False)

In [None]:
# Generic output folder for results (created eagerly).
RESULTS_DIR = "results"
Path(RESULTS_DIR).mkdir(parents=True, exist_ok=True)

In [None]:
# Experiment-grid configuration (re-binds the values set at the top of the notebook).
NUM_CLASSES = 3
NUM_EPOCHS = 250
# NOTE(review): the grid launch cell passes BATCH_SIZE=16 explicitly, so this
# value is not the one used there; confirm which one is intended.
BATCH_SIZE = 8

# Axes of the experiment grid.
RESIZES = [False, True]
CHANNELS = [False, True]
AUGS = [False, True]
POOLINGS = [*POOLING_MAP]

In [None]:
def cm_metrics(model, loader, num_classes=3):
    """Evaluate ``model`` on ``loader`` and return test metrics.

    Uses the module-level ``device`` and leaves the model in eval mode.

    Returns:
        (cm, acc, bal, wacc): confusion matrix as a CPU tensor, micro
        accuracy, macro ("balanced") accuracy and weighted accuracy.
    """
    model.eval()

    confusion = MulticlassConfusionMatrix(num_classes=num_classes).to(device)
    averages = ("micro", "macro", "weighted")
    accuracies = {
        avg: MulticlassAccuracy(num_classes=num_classes, average=avg).to(device)
        for avg in averages
    }

    with torch.no_grad():
        for inputs, targets in tqdm(loader, desc="test", leave=False):
            inputs = inputs.to(device)
            targets = targets.to(device)
            predictions = model(inputs).argmax(dim=1)

            confusion.update(predictions, targets)
            for metric in accuracies.values():
                metric.update(predictions, targets)

    cm = confusion.compute().cpu()
    acc, bal, wacc = (accuracies[avg].compute().item() for avg in averages)
    return cm, acc, bal, wacc

In [None]:
def _materialize_lazy_modules(model, *loaders):
    """Run one eval-mode, no-grad forward so lazily created submodules exist.

    Tries the loaders in order and uses the first batch of the first
    non-empty one. Eval mode leaves BatchNorm running statistics untouched.
    The model's previous train/eval mode is restored afterwards.
    """
    for loader in loaders:
        batch = next(iter(loader), None)
        if batch is None:
            continue
        was_training = model.training
        model.eval()
        with torch.no_grad():
            model(batch[0].to(device))
        model.train(was_training)
        return


def train_val_bestloss(model, train_loader, val_loader, NUM_EPOCHS: int, NUM_CLASSES: int, out_dir=None):
    """Train for NUM_EPOCHS epochs and keep the weights with the lowest val loss.

    Each epoch runs a train pass (AdamW, lr=1e-4, cross-entropy) followed by
    a val pass. The mean per-sample loss of every pass is recorded and
    printed.

    Args:
        model: network to train, already on ``device``.
        train_loader, val_loader: iterables of ``(x, y)`` batches.
        NUM_EPOCHS: number of epochs.
        NUM_CLASSES: unused; kept for interface compatibility.
        out_dir: if given, ``history.csv`` and ``best_model.pt`` (the best
            state_dict) are written there.

    Returns:
        history: list of ``{"epoch", "phase", "loss"}`` dicts.
        best_val: ``{"best_epoch", "best_val_loss"}``. The values are 0 / inf
            if the val loss never improved.

    Note: ``model`` keeps the weights of the LAST epoch. The best weights are
    only written to ``out_dir``.
    """
    # Build lazily created submodules (e.g. the pooling layer ResNet18
    # creates in forward) BEFORE the optimizer exists. Otherwise their
    # parameters are never handed to AdamW and never trained. The val loader
    # is tried first because it is typically not shuffled, so the train
    # sampler's RNG is not consumed.
    _materialize_lazy_modules(model, val_loader, train_loader)

    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    crit = nn.CrossEntropyLoss()

    best_w = copy.deepcopy(model.state_dict())
    best_val_loss = float("inf")
    best_epoch = 0
    history = []

    for epoch in range(NUM_EPOCHS):
        for phase, loader in [("train", train_loader), ("val", val_loader)]:
            is_train = phase == "train"
            model.train(is_train)

            tot = 0
            loss_sum = 0.0

            for x, y in tqdm(loader, desc=f"{phase} e{epoch+1}", leave=False):
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)

                if is_train:
                    opt.zero_grad(set_to_none=True)

                with torch.set_grad_enabled(is_train):
                    out = model(x)
                    loss = crit(out, y)

                    if is_train:
                        loss.backward()
                        opt.step()

                bs = y.size(0)
                tot += bs
                loss_sum += loss.item() * bs

            avg_loss = loss_sum / max(tot, 1)
            history.append({"epoch": epoch + 1, "phase": phase, "loss": float(avg_loss)})

            print(f"epoch {epoch+1}/{NUM_EPOCHS} | {phase} | loss {avg_loss:.6f}")

            if phase == "val" and avg_loss < best_val_loss:
                best_val_loss = float(avg_loss)
                best_epoch = epoch + 1
                best_w = copy.deepcopy(model.state_dict())

        # Memory cleanup (e.g. clean()) is left to the caller.

    best_val = {"best_epoch": int(best_epoch), "best_val_loss": float(best_val_loss)}

    if out_dir is not None:
        os.makedirs(out_dir, exist_ok=True)
        pd.DataFrame(history).to_csv(os.path.join(out_dir, "history.csv"), index=False)
        torch.save(best_w, os.path.join(out_dir, "best_model.pt"))

    return history, best_val

In [None]:
RESULTS_GRID = "results_bestvalloss"


def _grid_loaders(CustomDataset, df_train, df_val, df_test, img_dir, mean, std, resize, aug, batch_size):
    """Build the (train, val, test) DataLoaders; only train is augmented and shuffled."""
    common = dict(img_size=224, resize=bool(resize), mean=mean, std=std)
    train_ds = CustomDataset(df_train, img_dir, is_train=True, augment=bool(aug), **common)
    val_ds = CustomDataset(df_val, img_dir, is_train=False, augment=False, **common)
    test_ds = CustomDataset(df_test, img_dir, is_train=False, augment=False, **common)
    return (
        DataLoader(train_ds, batch_size=batch_size, shuffle=True),
        DataLoader(val_ds, batch_size=batch_size, shuffle=False),
        DataLoader(test_ds, batch_size=batch_size, shuffle=False),
    )


def _grid_model(pooling, num_classes):
    """Instantiate the grid's ResNet18 for one pooling choice on ``device``."""
    # ResNet18.forward overwrites the stem_* values with a global window;
    # they are passed only to keep the original configuration.
    stem_pool_size = (3, 3) if pooling == "max" else (2, 2)
    return ResNet18(
        num_classes=num_classes,
        pooling_choice=pooling,
        stem_pool_size=stem_pool_size,
        stem_stride=(2, 2),
        stem_padding="same",
    ).to(device)


def _grid_warm_up(model, loader):
    """One no-grad, eval-mode forward so lazily built layers (ResNet18's pooling) exist."""
    batch = next(iter(loader), None)
    if batch is None:
        return
    model.eval()
    with torch.no_grad():
        model(batch[0].to(device))


def run_grid_resize_channels_aug_poolings(
    RESIZES, CHANNELS, AUGS, POOLINGS,
    POOLING_MAP,
    CustomDataset,
    df_train, df_val, df_test,
    img_dir, mean, std,
    BATCH_SIZE: int,
    NUM_EPOCHS: int,
    NUM_CLASSES: int,
):
    """Train and test every combination of the grid.

    The combinations are:
      resize ∈ RESIZES
      shared_window_across_channels ∈ CHANNELS
      pooling ∈ POOLINGS (keys missing from POOLING_MAP are skipped)
      augment_train ∈ AUGS

    For each run, the model with the lowest val loss is reloaded and
    evaluated on the test set. Artefacts go to
    ``RESULTS_GRID/resize_*/channels_*/pooling_*/aug_*``. The summary
    ``results_all.csv`` is rewritten after every run, so a crash does not
    lose finished runs.

    Returns:
        Path of ``results_all.csv``.
    """
    ensure_dir(RESULTS_GRID)
    out_csv = os.path.join(RESULTS_GRID, "results_all.csv")
    rows = []

    for resize in RESIZES:
        rdir = os.path.join(RESULTS_GRID, f"resize_{int(resize)}")
        ensure_dir(rdir)

        for ch in CHANNELS:
            # NOTE(review): `ch` only selects the output folder and the CSV
            # column. It is not passed to the dataset or the model, so both
            # CHANNELS settings currently train identical configurations.
            cdir = os.path.join(rdir, f"channels_{int(ch)}")
            ensure_dir(cdir)

            for pooling in tqdm(POOLINGS, desc=f"poolings | resize={int(resize)} ch={int(ch)}", leave=False):
                if pooling not in POOLING_MAP:
                    print(f"[SKIP] pooling='{pooling}' non trovato in POOLING_MAP.")
                    continue

                pdir = os.path.join(cdir, f"pooling_{pooling}")
                ensure_dir(pdir)

                for aug in AUGS:
                    adir = os.path.join(pdir, f"aug_{int(aug)}")
                    ensure_dir(adir)

                    print(f"\n>>> START resize={resize} shared_channels={ch} pooling={pooling} aug={aug} | path={adir}")

                    tr_loader, va_loader, te_loader = _grid_loaders(
                        CustomDataset, df_train, df_val, df_test,
                        img_dir, mean, std, resize, aug, BATCH_SIZE,
                    )

                    # TRAIN / VAL
                    model = _grid_model(pooling, NUM_CLASSES)
                    _history, best_val = train_val_bestloss(
                        model, tr_loader, va_loader,
                        NUM_EPOCHS=NUM_EPOCHS,
                        NUM_CLASSES=NUM_CLASSES,
                        out_dir=adir,
                    )
                    del model
                    clean()

                    # TEST: reload the best checkpoint. The lazily built
                    # pooling must exist first, otherwise its saved weights
                    # are "unexpected keys" and strict=False drops them.
                    model = _grid_model(pooling, NUM_CLASSES)
                    _grid_warm_up(model, te_loader)
                    state = torch.load(os.path.join(adir, "best_model.pt"), map_location=device)
                    incompatible = model.load_state_dict(state, strict=False)
                    if incompatible.missing_keys or incompatible.unexpected_keys:
                        print(f"[WARN] load_state_dict: missing={incompatible.missing_keys} "
                              f"unexpected={incompatible.unexpected_keys}")
                    model.eval()

                    cm, acc, bal, wacc = cm_metrics(model, te_loader, num_classes=NUM_CLASSES)
                    save_cm_csv(cm, os.path.join(adir, "cm_test.csv"))

                    rows.append({
                        "resize": bool(resize),
                        "channels": bool(ch),
                        "pooling": str(pooling),
                        "augment_train": bool(aug),
                        "best_epoch": int(best_val["best_epoch"]),
                        "best_val_loss": float(best_val["best_val_loss"]),
                        "test_acc": float(acc),
                        "test_bal_acc": float(bal),
                        "test_weighted_acc": float(wacc),
                        "path": adir,
                    })
                    # checkpoint the summary after every finished run
                    pd.DataFrame(rows).to_csv(out_csv, index=False)

                    del model
                    clean()

    pd.DataFrame(rows).to_csv(out_csv, index=False)
    print("Done:", RESULTS_GRID, "| saved:", out_csv)
    return out_csv

In [None]:
def sanity_check_poolings(poolings_to_test, num_classes=2, img_size=224, device=None):
    """Smoke-test each pooling choice with one forward/backward pass.

    For every key, a ResNet18 is built and run in train mode on a random
    batch of 2 images. The check covers output shape (2, num_classes), a
    finite loss and finite gradients. Failures, including construction
    errors, are collected rather than raised, and both lists are printed.

    Returns:
        (ok, bad): the passing keys, and a list of (key, repr(error)).
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    print("device:", device)
    # Accepts "cuda", "cuda:0" or a torch.device.
    on_cuda = torch.device(device).type == "cuda"

    # dummy batch
    x = torch.randn(2, 3, img_size, img_size, device=device)
    y = torch.randint(0, num_classes, (2,), device=device)

    ok, bad = [], []

    for p in poolings_to_test:
        model = None  # bound up-front so cleanup works even if construction fails
        try:
            model = ResNet18(
                num_classes=num_classes,
                pooling_choice=p,
                stem_pool_size=(2, 2) if p != "texture" else (3, 3),
                stem_stride=(2, 2),
                stem_padding="same",
            ).to(device)

            model.train()
            out = model(x)
            # explicit raises: `assert` is stripped under `python -O`
            if not (out.dim() == 2 and out.shape[0] == x.shape[0] and out.shape[1] == num_classes):
                raise AssertionError(f"Output shape strana: {tuple(out.shape)}")

            loss = F.cross_entropy(out, y)
            if not torch.isfinite(loss).item():
                raise AssertionError("Loss non finita (NaN/Inf)")

            model.zero_grad(set_to_none=True)
            loss.backward()

            # finite-gradient check
            for name, param in model.named_parameters():
                if param.grad is not None and not torch.isfinite(param.grad).all().item():
                    raise RuntimeError(f"Grad non finito in {name}")

            ok.append(p)
        except Exception as e:
            bad.append((p, repr(e)))
        finally:
            del model
            if on_cuda:
                torch.cuda.empty_cache()

    print("\nOK:", ok)
    print("\nFAIL:")
    for p, e in bad:
        print(" -", p, "->", e)

    return ok, bad

# Smoke-test every pooling except "max" and "stochastic".
poolings_to_test = [name for name in POOLING_MAP if name not in ("max", "stochastic")]
sanity_check_poolings(poolings_to_test, num_classes=globals().get("NUM_CLASSES", 2))

In [None]:
# Launch the full grid; the best checkpoint is chosen by minimum val loss.
# NOTE(review): BATCH_SIZE is passed as 16 here although the grid
# configuration cell sets BATCH_SIZE = 8; confirm which value is intended.
run_grid_resize_channels_aug_poolings(
    RESIZES=RESIZES,
    CHANNELS=CHANNELS,
    AUGS=AUGS,
    POOLINGS=POOLINGS,
    POOLING_MAP=POOLING_MAP,
    CustomDataset=CustomDataset,
    df_train=df_train,
    df_val=df_val,
    df_test=df_test,
    img_dir=img_dir,
    mean=mean,
    std=std,
    BATCH_SIZE=16,
    NUM_EPOCHS=250,
    NUM_CLASSES=3,
)

## *Risultati*

### **LOSS**

In [None]:
# Aggregated grid results (one row per run).
base_dir, res = "results_bestvalloss", "results_all.csv"
path = os.path.join(base_dir, res)
results_df = pd.read_csv(path, sep=",")

In [None]:
# Column names inside each run's history.csv.
EPOCH_COL = "epoch"
PHASE_COL = "phase"   # "train" / "val"
LOSS_COL = "loss"

POOLINGS = ["max", "stochastic", "maxavg", "weighted_max_pooling", "lse", "gem", "adpstat"]

# TRAIN curves: one fixed colour per pooling.
POOL_COLOR = dict(
    max="#1f77b4",
    stochastic="#ff7f0e",
    maxavg="#2ca02c",
    weighted_max_pooling="#efff09",
    lse="#f01405",
    gem="#35f8ff",
    adpstat="#ED80C9",
)

# VAL curves: the same colour for every pooling.
VAL_COLOR = "#71067B"

# The four (resize, channels) panels, in fixed order.
COMBOS = [(True, True), (False, False), (True, False), (False, True)]

# Coerce the flag columns to real booleans ("True"/"False"/"1"/"0" strings included).
for col in ["resize", "channels", "augment_train"]:
    if col not in results_df.columns:
        continue
    results_df[col] = results_df[col].map(
        {"True": True, "False": False, "1": True, "0": False}
    ).fillna(results_df[col]).astype(bool)

required_cols = {"pooling", "path", "resize", "channels"}
missing_cols = sorted(required_cols.difference(results_df.columns))
if missing_cols:
    raise ValueError(f"Mancano colonne in results_all.csv: {missing_cols}")

# Compare the poolings present in the CSV with the expected list.
present_poolings = set(results_df["pooling"].astype(str).unique())
missing_poolings = [p for p in POOLINGS if p not in present_poolings]
extra_poolings = sorted(present_poolings.difference(POOLINGS))
print("Pooling mancanti rispetto a POOLINGS:", missing_poolings)
print("Pooling extra presenti nel CSV:", extra_poolings)

def resolve_run_dir(p: str) -> str:
    """Map a run path stored in results_all.csv to a usable directory.

    Surrounding whitespace is stripped and separators are normalised to the
    host OS. An existing path is returned as-is; otherwise the path is
    treated as relative to ``base_dir``.
    """
    normalized = str(p).strip().replace("\\", os.sep).replace("/", os.sep)
    if not os.path.exists(normalized):
        return os.path.join(base_dir, normalized)
    return normalized

def load_history(run_dir: str) -> pd.DataFrame:
    """Read the per-epoch ``history.csv`` stored inside ``run_dir``."""
    history_csv = os.path.join(run_dir, "history.csv")
    return pd.read_csv(history_csv)

def aggregate_histories(df_subset: pd.DataFrame, phase: str):
    """Mean ± std loss curve across the runs in ``df_subset`` for one phase.

    A run is skipped if its directory is missing, its history cannot be
    read, it lacks the EPOCH/PHASE/LOSS columns, or it has no rows for
    ``phase`` (case-insensitive). Curves are aligned on the epochs shared by
    every remaining run.

    Returns:
        (epochs, mean, std, n_runs) with std computed with ddof=0, or None
        if no run or no shared epoch is available.
    """
    curves = []
    for _, row in df_subset.iterrows():
        run_dir = resolve_run_dir(row["path"])
        if not os.path.exists(run_dir):
            continue
        try:
            hist = load_history(run_dir)
        except Exception:
            continue
        if not {EPOCH_COL, PHASE_COL, LOSS_COL}.issubset(hist.columns):
            continue
        selected = hist[hist[PHASE_COL].astype(str).str.lower() == phase.lower()]
        if selected.empty:
            continue
        selected = selected.sort_values(EPOCH_COL)
        curves.append(pd.Series(selected[LOSS_COL].values, index=selected[EPOCH_COL].values))

    if not curves:
        return None

    shared = set.intersection(*(set(curve.index) for curve in curves))
    if not shared:
        return None

    epochs = np.array(sorted(shared))
    stacked = np.vstack([curve.loc[epochs].values for curve in curves])
    return epochs, stacked.mean(axis=0), stacked.std(axis=0, ddof=0), stacked.shape[0]

In [None]:
# Column names inside each run's history.csv.
EPOCH_COL = "epoch"
PHASE_COL = "phase"   # "train" / "val"
LOSS_COL = "loss"

POOLINGS = [
    "max", "stochastic", "maxavg", "weighted_max_pooling", "lse", "gem", "adpstat",
]

# Train curves: one colour per pooling. Val curves: one shared colour.
POOL_COLOR = {
    "max": "#1f77b4",
    "stochastic": "#ff7f0e",
    "maxavg": "#2ca02c",
    "weighted_max_pooling": "#efff09",
    "lse": "#f01405",
    "gem": "#35f8ff",
    "adpstat": "#ED80C9",
}
VAL_COLOR = "#71067B"

# (resize, channels) panels, in plotting order.
COMBOS = [
    (True, True), (False, False),
    (True, False), (False, True),
]

# Robust bool coercion: strings "True"/"False"/"1"/"0" and ints 1/0.
for col in ["resize", "channels", "augment_train"]:
    if col not in results_df.columns:
        continue
    results_df[col] = (
        results_df[col]
        .map({"True": True, "False": False, "1": True, "0": False, 1: True, 0: False})
        .fillna(results_df[col])
        .astype(bool)
    )

required_cols = {"pooling", "path", "resize", "channels"}
missing_cols = sorted(required_cols - set(results_df.columns))
if missing_cols:
    raise ValueError(f"Mancano colonne in results_all.csv: {missing_cols}")

def resolve_run_dir(p: str) -> str:
    """Turn a stored run path into a directory usable on this machine.

    Existing paths (after stripping and separator normalisation) are
    returned unchanged; anything else is joined under ``base_dir``.
    """
    candidate = str(p).strip()
    for sep in ("\\", "/"):
        candidate = candidate.replace(sep, os.sep)
    return candidate if os.path.exists(candidate) else os.path.join(base_dir, candidate)

def load_history(run_dir: str) -> pd.DataFrame:
    """Load ``<run_dir>/history.csv`` into a DataFrame."""
    return pd.read_csv(Path(run_dir) / "history.csv")

def aggregate_histories(df_subset: pd.DataFrame, phase: str):
    """Average the ``phase`` loss curves of all runs listed in ``df_subset``.

    Runs without a readable ``history.csv``, without the expected columns,
    or without rows for ``phase`` (case-insensitive) are ignored. Only
    epochs present in every kept run are used.

    Returns:
        (epochs, mean, std, n_runs) with std computed with ddof=0, or None
        when nothing can be aggregated.
    """
    per_run = []

    for _, row in df_subset.iterrows():
        csv_file = os.path.join(resolve_run_dir(row["path"]), "history.csv")
        if not os.path.exists(csv_file):
            continue

        try:
            hist = pd.read_csv(csv_file)
        except Exception:
            continue

        if not {EPOCH_COL, PHASE_COL, LOSS_COL}.issubset(hist.columns):
            continue

        is_phase = hist[PHASE_COL].astype(str).str.lower() == phase.lower()
        if not is_phase.any():
            continue

        ordered = hist.loc[is_phase].sort_values(EPOCH_COL)
        per_run.append(pd.Series(ordered[LOSS_COL].values, index=ordered[EPOCH_COL].values))

    if not per_run:
        return None

    shared = set(per_run[0].index)
    for curve in per_run[1:]:
        shared &= set(curve.index)
    if not shared:
        return None

    epochs = np.array(sorted(shared))
    stacked = np.vstack([curve.loc[epochs].values for curve in per_run])
    return epochs, stacked.mean(axis=0), stacked.std(axis=0, ddof=0), stacked.shape[0]

def plot_pooling_grid_train_val_2x2(pooling: str, out_dir: str, show_std=True, min_runs=1):
    """Save a 2x2 grid of mean train/val loss curves for one pooling.

    There is one panel per (resize, channels) pair in COMBOS. Each panel
    aggregates all matching runs of ``results_df``: a mean line, plus an
    optional ±std band when more than one run contributes. The PNG is always
    written to ``out_dir/<pooling>.png``.

    Returns:
        The PNG path if at least one curve was drawn, otherwise None (a
        warning is printed).
    """
    train_color = POOL_COLOR.get(pooling, "#000000")
    val_color = VAL_COLOR

    fig, axes = plt.subplots(2, 2, figsize=(10.5, 8.0), sharey=True)
    fig.suptitle(f"{pooling} — Loss vs Epoch", fontsize=14, fontweight="bold")
    axes = axes.ravel()

    any_plotted = False

    for ax, (r, c) in zip(axes, COMBOS):
        df_sub = results_df[
            (results_df["pooling"].astype(str) == pooling) &
            (results_df["resize"] == r) &
            (results_df["channels"] == c)
        ].copy()

        n_total = len(df_sub)
        agg_tr = aggregate_histories(df_sub, phase="train")
        agg_va = aggregate_histories(df_sub, phase="val")

        n_tr = agg_tr[3] if agg_tr is not None else 0
        n_va = agg_va[3] if agg_va is not None else 0

        ax.set_title(f"resize={r}, channels={c}\nruns: train={n_tr}, val={n_va} (tot={n_total})", fontsize=10)

        # Every curve is labelled; duplicates are removed when the legend is built.
        if agg_tr is not None and n_tr >= min_runs:
            e, m, s, n = agg_tr
            ax.plot(e, m, color=train_color, linewidth=2.5, label="train")
            if show_std and n > 1:
                ax.fill_between(e, m - s, m + s, color=train_color, alpha=0.12)
            any_plotted = True

        if agg_va is not None and n_va >= min_runs:
            e, m, s, n = agg_va
            ax.plot(e, m, color=val_color, linewidth=2.5, label="val")
            if show_std and n > 1:
                ax.fill_between(e, m - s, m + s, color=val_color, alpha=0.10)
            any_plotted = True

        ax.set_xlabel("epoch")
        ax.grid(True, alpha=0.25)

    # y-label only on the left column
    for ax in axes[::2]:
        ax.set_ylabel("loss")

    # Single figure legend gathered from ALL panels, so it still appears when
    # the first panel has no data.
    legend = {}
    for ax in axes:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            legend.setdefault(label, handle)
    if legend:
        fig.legend(list(legend.values()), list(legend.keys()), loc="lower center", ncol=2, frameon=False)

    # leave room at the bottom for the legend
    fig.tight_layout(rect=[0, 0.08, 1, 0.92])

    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"{pooling}.png")
    fig.savefig(out_path, dpi=240)
    plt.close(fig)

    if any_plotted:
        print("Salvato:", out_path)
        return out_path
    print(f"[WARN] Nessun dato plottabile per pooling={pooling}")
    return None

out_dir = os.path.join(base_dir, "loss_plots")
saved = []

# One 2x2 loss figure per pooling; keep the paths that were actually produced.
for p in POOLINGS:
    outp = plot_pooling_grid_train_val_2x2(p, out_dir=out_dir, show_std=True, min_runs=1)
    if outp is None:
        continue
    saved.append(outp)

print("\nout_dir =", out_dir)
print("Creati:", len(saved), "file")

# Display the saved PNGs inline.
for p in saved:
    img = Image.open(p)
    plt.figure(figsize=(10, 7))
    plt.imshow(img)
    plt.title(os.path.basename(p), fontsize=10)
    plt.axis("off")
    plt.show()

### **Tabelle**

In [None]:
# bool: coerce the grouping columns to real booleans
results_df["resize"] = (
    results_df["resize"]
    .map({"True": True, "False": False, "1": True, "0": False})
    .fillna(results_df["resize"])
    .astype(bool)
)
results_df["channels"] = (
    results_df["channels"]
    .map({"True": True, "False": False, "1": True, "0": False})
    .fillna(results_df["channels"])
    .astype(bool)
)

ACC_COL = "test_bal_acc"

def make_block(df, resize_val, channels_val):
    """Return a copy of the rows of ``df`` for one (resize, channels) pair.

    If ``ACC_COL`` is a column, the rows are stably sorted by it in
    ascending order.
    """
    mask = (df["resize"] == resize_val) & (df["channels"] == channels_val)
    block = df.loc[mask].copy()
    if ACC_COL not in block.columns:
        return block
    return block.sort_values(ACC_COL, ascending=True, kind="mergesort")

# One subset per (resize, channels) combination: tRC with R = resize, C = channels.
t00, t01 = make_block(results_df, False, False), make_block(results_df, False, True)
t10, t11 = make_block(results_df, True, False), make_block(results_df, True, True)

row_counts = {"r0c0": len(t00), "r0c1": len(t01), "r1c0": len(t10), "r1c1": len(t11)}
print("Righe:", row_counts)

In [None]:
ACC_COL = "test_bal_acc"          # main ranking metric
# Tie-break candidates for the ranking (the first one present is used).
# Deliberately NOT named LOSS_COL: that global holds the history column name
# ("loss") used by aggregate_histories, and rebinding it to a list breaks
# that function (lists are unhashable inside its column set).
TABLE_LOSS_COLS = ["best_val_loss"]

DROP_COLS = {"path", "test_acc", "test_weighted_acc"}

def table_rc(df, title, out_path, top_n=25, max_cols=14, dpi=220):
    """Render the top ``top_n`` runs of ``df`` as a PNG table.

    The steps are:
    - Drop DROP_COLS.
    - Sort by ACC_COL descending, then by the first available
      TABLE_LOSS_COLS column ascending (stable).
    - Keep at most ``max_cols`` columns.
    - Format floats with 2 decimals.
    - Highlight the first (best) row and show the best ACC_COL in the title.

    Returns:
        ``out_path`` when the PNG is written. ``None`` when, after
        filtering, there are no rows or no columns to draw (a warning is
        printed and no file is created; matplotlib cannot render an empty
        table).
    """
    d = df.drop(columns=[c for c in df.columns if c in DROP_COLS])

    loss_col = next((c for c in TABLE_LOSS_COLS if c in d.columns), None)

    if ACC_COL in d.columns:
        if loss_col is not None:
            d = d.sort_values([ACC_COL, loss_col], ascending=[False, True], kind="mergesort")
        else:
            d = d.sort_values(ACC_COL, ascending=False, kind="mergesort")

    d = d.head(top_n).copy()

    if max_cols is not None and d.shape[1] > max_cols:
        d = d.iloc[:, :max_cols].copy()

    if d.empty:
        print(f"[WARN] table_rc: nessuna riga da disegnare per '{title}'")
        return None

    # best_epoch as an integer
    if "best_epoch" in d.columns:
        d["best_epoch"] = pd.to_numeric(d["best_epoch"], errors="coerce").astype("Int64")

    # string formatting for rendering
    r = d.copy()
    for col in r.columns:
        if pd.api.types.is_float_dtype(r[col]):
            r[col] = r[col].map(lambda x: f"{x:.2f}" if pd.notnull(x) else "")
        elif col == "best_epoch":
            r[col] = r[col].map(lambda x: f"{int(x)}" if pd.notnull(x) else "")
        else:
            r[col] = r[col].map(lambda x: "" if pd.isna(x) else str(x))

    nrows, ncols = r.shape

    # figure sized to the table
    fig_w = max(10, 1.05 * ncols + 2)
    fig_h = max(3.2, 0.42 * (nrows + 2))

    fig, ax = plt.subplots(figsize=(fig_w, fig_h), dpi=dpi)
    ax.axis("off")

    # title with the best test_bal_acc
    subtitle = ""
    if ACC_COL in d.columns and pd.notnull(d[ACC_COL].iloc[0]):
        subtitle = f" — best {ACC_COL}: {d[ACC_COL].iloc[0]:.4f}"
    ax.set_title(f"{title}{subtitle}", pad=14, fontsize=14, fontweight="bold")

    tbl = ax.table(
        cellText=r.values,
        colLabels=r.columns,
        loc="center",
        cellLoc="center",
        colLoc="center",
    )

    tbl.auto_set_font_size(False)
    tbl.set_fontsize(9)

    try:
        tbl.auto_set_column_width(col=list(range(ncols)))
    except Exception:
        # cosmetic only; keep the default widths if this fails
        pass

    header_h = 0.06
    body_h = 0.048

    # after sorting, the best run is the first body row (row 0 is the header)
    best_row = 1

    for (row, col), cell in tbl.get_celld().items():
        cell.set_linewidth(0.4)

        if row == 0:
            cell.set_height(header_h)
            cell.set_text_props(fontweight="bold")
            cell.set_facecolor("#f0f0f0")
        else:
            cell.set_height(body_h)
            cell.set_facecolor("#fbfbfb" if row % 2 == 0 else "white")

            if row == best_row:
                cell.set_facecolor("#a4a8f3")

    tbl.scale(1.05, 1.15)

    fig.tight_layout()
    # os.makedirs("") raises, so only create a parent when there is one
    out_parent = os.path.dirname(out_path)
    if out_parent:
        os.makedirs(out_parent, exist_ok=True)
    fig.savefig(out_path, bbox_inches="tight", pad_inches=0.2)
    plt.close(fig)
    return out_path

out_dir = os.path.join(base_dir, "tables_png")
os.makedirs(out_dir, exist_ok=True)

# One PNG per (resize, channels) block, top 25 runs each.
for _tbl, _title, _fname in (
    (t00, "resize=False, channels=False", "t00.png"),
    (t01, "resize=False, channels=True", "t01.png"),
    (t10, "resize=True,  channels=False", "t10.png"),
    (t11, "resize=True,  channels=True", "t11.png"),
):
    table_rc(_tbl, _title, os.path.join(out_dir, _fname), top_n=25)

print("PNG salvati in:", out_dir)

### **CONFUSION MATRIX**

In [None]:
# Fresh copy of the aggregated grid results (no bool coercion applied).
base_dir = "results_bestvalloss"
res = "results_all.csv"
results_df = pd.read_csv(path := os.path.join(base_dir, res))

def pick_best_row(results_df):
    """Return the best run in ``results_df``.

    Ranking is by ``test_bal_acc`` rounded to 2 decimals, descending. Ties
    are broken by ``best_val_loss``, ascending (stable sort). Rounding lets
    near-equal accuracies tie so that the val-loss tie-break applies. The
    returned row carries the ROUNDED ``test_bal_acc``.

    Raises:
        ValueError: if no row has a ``test_bal_acc`` value.
    """
    d = results_df.dropna(subset=["test_bal_acc"]).copy()
    if d.empty:
        raise ValueError("pick_best_row: nessuna riga con test_bal_acc valorizzato")

    d["test_bal_acc"] = d["test_bal_acc"].round(2)

    d = d.sort_values(["test_bal_acc", "best_val_loss"],
                      ascending=[False, True],
                      kind="mergesort")
    return d.iloc[0]

def best_run_dir_from_row(best):
    """Return the run directory stored in the ``path`` field of a results row.

    Parameters
    ----------
    best : pandas.Series
        A row of the results table (e.g. the output of ``pick_best_row``).

    Returns
    -------
    str
        The run directory.

    Raises
    ------
    ValueError
        If the row has no ``path`` field or it is null. Previously ``None`` was
        returned implicitly, which only failed later inside ``os.path.join``.
    """
    if "path" in best.index and pd.notna(best["path"]):
        return str(best["path"])
    raise ValueError("La riga selezionata non ha un valore valido nella colonna 'path'.")

# Overall best run (any pooling) and the path of its test confusion matrix.
best = pick_best_row(results_df)
best_run_dir = best_run_dir_from_row(best)

cm_path = os.path.join(best_run_dir, "cm_test.csv")
print("Best run:", best_run_dir)
print("Uso CM da:", cm_path)
# pick_best_row ranks on test_bal_acc rounded to 2 decimals (desc), ties by best_val_loss (asc).
print("Criterio sort:", "test_bal_acc (2 decimali) desc, best_val_loss asc")

In [None]:
ID2LABEL = {v: k for k, v in LABEL_MAP.items()}
n_classes = 3
class_names = [ID2LABEL[i] for i in range(n_classes)]

cm_df = pd.read_csv(cm_path)

# Isolate the numeric 3x3 block even if the CSV has extra columns.
# If the CSV was saved with its pandas index (to_csv default), that index
# shows up as an integer "Unnamed: 0" column; select_dtypes would keep it and
# shift the matrix by one column, so drop such columns explicitly first.
cm_df = cm_df[[c for c in cm_df.columns if not str(c).startswith("Unnamed")]]
cm_df = cm_df.select_dtypes(include=[np.number])
cm_df = cm_df.iloc[:n_classes, :n_classes]
cm = cm_df.to_numpy(dtype=int)


def _run_flag(row, *cols):
    """Return the first present, non-null column of ``row`` among ``cols`` as int, else None."""
    for col in cols:
        if col in row.index and pd.notna(row[col]):
            return int(row[col])
    return None


# Title derived from the run actually selected in `best` (overall best over
# all poolings), so it cannot drift from the selection.
pool_txt = str(best["pooling"]).strip().capitalize() if "pooling" in best.index else "?"
rz = _run_flag(best, "resize")
ch = _run_flag(best, "channels")
aug = _run_flag(best, "augment_train", "aug")  # the column is named either augment_train or aug
rz_txt = "Resize: ?" if rz is None else f"Resize: {rz == 1}"
ch_txt = "Channel ?" if ch is None else f"Channel {ch == 1}"
aug_txt = "Aug: ?" if aug is None else ("Aug: True" if aug == 1 else "No Aug")

plt.figure(figsize=(3.3, 3.0), dpi=260)
im = plt.imshow(cm, cmap="Blues", vmin=0, vmax=max(1, cm.max()))
plt.colorbar(im, fraction=0.05, pad=0.03)

plt.xticks(range(n_classes), class_names, rotation=20, ha="right", fontsize=7)
plt.yticks(range(n_classes), class_names, fontsize=7)
plt.xlabel("Predicted Label", fontsize=8)
plt.ylabel("True Label", fontsize=8)
plt.suptitle(
    f"CM — {pool_txt}\n{rz_txt} / {ch_txt} / {aug_txt}",
    fontsize=9, fontweight="bold", y=1.02,
)

ax = plt.gca()
ax.set_aspect("equal")

# cell borders drawn via minor ticks at half-integer positions
ax.set_xticks(np.arange(-.5, n_classes, 1), minor=True)
ax.set_yticks(np.arange(-.5, n_classes, 1), minor=True)
ax.grid(which="minor", color="black", linestyle="-", linewidth=0.8)
ax.tick_params(which="minor", bottom=False, left=False)

# white text on dark cells, black text otherwise
thr = cm.max() * 0.55
for i in range(n_classes):
    for j in range(n_classes):
        val = cm[i, j]
        color = "white" if (cm.max() > 0 and val >= thr) else "black"
        ax.text(j, i, str(val), ha="center", va="center",
                fontsize=8, fontweight="bold", color=color)

for s in ax.spines.values():
    s.set_visible(True)
    s.set_color("black")
    s.set_linewidth(1.2)

plt.tight_layout(pad=0.3)
plt.show()

In [None]:
ACC_COL = "test_bal_acc"

# Exact filter: only runs whose pooling is "max" (case/whitespace-insensitive).
dmax = results_df[
    results_df["pooling"].astype(str).str.strip().str.lower().eq("max")
].copy()

if dmax.empty:
    raise ValueError("Nessun run con pooling == 'max' trovato (controlla i valori della colonna 'pooling').")

# Accuracy descending; tie-breaker best_val_loss ascending when the column exists.
has_val_loss = "best_val_loss" in dmax.columns
sort_cols = [ACC_COL] + (["best_val_loss"] if has_val_loss else [])
ascending = [False] + ([True] if has_val_loss else [])

dmax_valid = dmax.dropna(subset=[ACC_COL])
if dmax_valid.empty:
    raise ValueError(f"Nessun run 'max' con {ACC_COL} valido.")

best_max = (
    dmax_valid
    .sort_values(sort_cols, ascending=ascending, kind="mergesort")
    .iloc[0]
)

best_run_dir = str(best_max["path"])
cm_path = os.path.join(best_run_dir, "cm_test.csv")

print("Best MAX run:", best_run_dir)
print("Best MAX test_bal_acc:", float(best_max[ACC_COL]))
print("CM file:", cm_path)

In [None]:
ID2LABEL = {v: k for k, v in LABEL_MAP.items()}
n_classes = 3
class_names = [ID2LABEL[i] for i in range(n_classes)]

cm_df = pd.read_csv(cm_path)

# Isolate the numeric 3x3 block even if the CSV has extra columns.
# If the CSV was saved with its pandas index (to_csv default), that index
# shows up as an integer "Unnamed: 0" column; select_dtypes would keep it and
# shift the matrix by one column, so drop such columns explicitly first.
cm_df = cm_df[[c for c in cm_df.columns if not str(c).startswith("Unnamed")]]
cm_df = cm_df.select_dtypes(include=[np.number])
cm_df = cm_df.iloc[:n_classes, :n_classes]
cm = cm_df.to_numpy(dtype=int)

# --- plot ---
plt.figure(figsize=(3.3, 3.0), dpi=260)
im = plt.imshow(cm, cmap="Blues", vmin=0, vmax=max(1, cm.max()))
plt.colorbar(im, fraction=0.05, pad=0.03)

plt.xticks(range(n_classes), class_names, rotation=20, ha="right", fontsize=7)
plt.yticks(range(n_classes), class_names, fontsize=7)
plt.xlabel("Predicted Label", fontsize=8)
plt.ylabel("True Label", fontsize=8)


def _run_flag(row, *cols):
    """Return the first present, non-null column of ``row`` among ``cols`` as int, else None."""
    for col in cols:
        if col in row.index and pd.notna(row[col]):
            return int(row[col])
    return None


# Title consistent with the selected best_max run; missing or null flags are
# shown as "?" instead of crashing on int(NaN).
rz = _run_flag(best_max, "resize")
ch = _run_flag(best_max, "channels")
# the augmentation column is named either augment_train or aug
aug = _run_flag(best_max, "augment_train", "aug")

aug_txt = ("Aug: True" if aug == 1 else "No Aug") if aug is not None else "Aug: ?"
rz_txt  = ("Resize: True" if rz == 1 else "Resize: False") if rz is not None else "Resize: ?"
ch_txt  = ("Channel True" if ch == 1 else "Channel False") if ch is not None else "Channel ?"

plt.suptitle(
    f"CM — MAX\n{rz_txt} / {ch_txt} / {aug_txt}",
    fontsize=9, fontweight="bold", y=1.02
)

ax = plt.gca()
ax.set_aspect("equal")

# cell borders drawn via minor ticks at half-integer positions
ax.set_xticks(np.arange(-.5, n_classes, 1), minor=True)
ax.set_yticks(np.arange(-.5, n_classes, 1), minor=True)
ax.grid(which="minor", color="black", linestyle="-", linewidth=0.8)
ax.tick_params(which="minor", bottom=False, left=False)

# white text on dark cells, black text otherwise
thr = cm.max() * 0.55
for i in range(n_classes):
    for j in range(n_classes):
        val = cm[i, j]
        color = "white" if (cm.max() > 0 and val >= thr) else "black"
        ax.text(j, i, str(val), ha="center", va="center",
                fontsize=8, fontweight="bold", color=color)

for s in ax.spines.values():
    s.set_visible(True)
    s.set_color("black")
    s.set_linewidth(1.2)

plt.tight_layout(pad=0.3)
plt.show()

### **GRAD-CAM**

In [None]:
ACC_COL = "test_bal_acc"

# Only runs whose pooling is "max" (case/whitespace-insensitive).
dmax = results_df[
    results_df["pooling"].astype(str).str.strip().str.lower().eq("max")
].copy()

if dmax.empty:
    raise ValueError("Nessun run con pooling == 'max' trovato.")

# Accuracy descending; tie-breaker best_val_loss ascending when the column exists.
has_val_loss = "best_val_loss" in dmax.columns
sort_cols = [ACC_COL] + (["best_val_loss"] if has_val_loss else [])
ascending = [False] + ([True] if has_val_loss else [])

dmax_valid = dmax.dropna(subset=[ACC_COL])
if dmax_valid.empty:
    raise ValueError(f"Nessun run 'max' con {ACC_COL} valido.")

best_max = (
    dmax_valid
    .sort_values(sort_cols, ascending=ascending, kind="mergesort")
    .iloc[0]
)

best_run_dir = str(best_max["path"])
ckpt_path = os.path.join(best_run_dir, "best_model.pt")
print("Grad-CAM: BEST MAX run =", best_run_dir)
print("Checkpoint:", ckpt_path)
print("Best MAX test_bal_acc:", float(best_max[ACC_COL]))

# Rebuild the architecture and load the selected MAX checkpoint.
model = ResNet18(num_classes=NUM_CLASSES).to(device)
state = torch.load(ckpt_path, map_location=device)

# Unwrap the common {"state_dict": ...} / {"model_state_dict": ...} layouts.
if isinstance(state, dict) and "state_dict" in state:
    state = state["state_dict"]
elif isinstance(state, dict) and "model_state_dict" in state:
    state = state["model_state_dict"]

# Strip a leading "module." (typical of checkpoints saved from a
# DataParallel/DDP-wrapped model). Only the prefix is removed; str.replace
# would also rewrite a "module." occurring in the middle of a key.
state = {
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in state.items()
}
model.load_state_dict(state, strict=True)
model.eval()

In [None]:
def pick_test_image_for_label(df_test, label_id, seed=None):
    """Sample one test row of class ``label_id`` and resolve its image path.

    Parameters
    ----------
    df_test : pandas.DataFrame
        Test split; the class column is ``label_id`` if present, else ``label``.
    label_id : int
        Class id to sample.
    seed : int or None
        ``random_state`` for the draw; ``None`` gives a different row each call.

    Returns
    -------
    tuple
        ``(img_path, row)``. The path comes from the ``path`` or ``img_path``
        field, or is built as ``img_dir/img_<id>.png`` from ``id``.

    Raises
    ------
    ValueError
        If no row has that class, or no image path can be derived.
    """
    label_col = "label" if "label_id" not in df_test.columns else "label_id"
    candidates = df_test.loc[df_test[label_col] == label_id]
    if len(candidates) == 0:
        raise ValueError(f"Nessuna immagine con label {label_id}")

    sample_kwargs = {} if seed is None else {"random_state": seed}
    row = candidates.sample(1, **sample_kwargs).iloc[0]

    for key in ("path", "img_path"):
        if key in row:
            return row[key], row
    if "id" in row:
        return os.path.join(img_dir, f"img_{int(row['id'])}.png"), row
    raise ValueError("Non trovo il path immagine.")


# Grad-CAM hooked on the last residual block of the currently loaded model
# (the best MAX run loaded in the previous cell). One instance is enough.
target_layers = [model.layer4[-1]]
cam = GradCAM(model=model, target_layers=target_layers)

def gradcam_for_class(label_id):
    """Run Grad-CAM on one random test image of class ``label_id``.

    Relies on the module-level ``model``, ``cam``, ``df_test``, ``best_max``,
    ``img_dir``, ``mean``/``std`` and ``device``. The CAM targets the class the
    model *predicts*, not the true class.

    Returns
    -------
    tuple
        ``(img_rgb, overlay, y_true, pred_class)``: the raw RGB image (uint8)
        resized to the CAM resolution, the CAM overlay, the true label id and
        the predicted label id.

    Raises
    ------
    FileNotFoundError
        If OpenCV cannot read the selected image.
    """
    # pick one test image of this class (random on every call)
    img_path, row = pick_test_image_for_label(df_test, label_id)

    dataset_cam = CustomDataset(
        df_test.loc[[row.name]],
        img_dir=img_dir,
        img_size=224,
        is_train=False,
        augment=False,
        resize=bool(best_max["resize"]),
        mean=mean,
        std=std,
    )

    input_tensor, y_true = dataset_cam[0]
    input_tensor = input_tensor.unsqueeze(0).to(device)

    # prediction drives the CAM target
    with torch.no_grad():
        logits = model(input_tensor)
        pred_class = int(torch.argmax(logits, dim=1).item())

    targets = [ClassifierOutputTarget(pred_class)]
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]

    # Unnormalized image for the overlay. It is resized to the CAM's own (H, W)
    # (Grad-CAM scales the map to the input tensor size) rather than a fixed
    # 224x224, so show_cam_on_image never receives mismatched shapes if the
    # model input is not 224x224 (possibly with resize=False — depends on
    # CustomDataset).
    img_bgr = cv2.imread(str(img_path))
    if img_bgr is None:
        raise FileNotFoundError(f"Impossibile leggere l'immagine: {img_path}")
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    cam_h, cam_w = grayscale_cam.shape[:2]
    img_rgb = cv2.resize(img_rgb, (cam_w, cam_h))
    rgb_float = img_rgb.astype(np.float32) / 255.0

    overlay = show_cam_on_image(rgb_float, grayscale_cam, use_rgb=True, image_weight=0.35)

    return img_rgb, overlay, int(y_true), pred_class

ID2LABEL = {v: k for k, v in LABEL_MAP.items()}

# One figure per class: raw image next to its MAX-pool Grad-CAM overlay.
for label_id in range(3):
    orig, cam_img, y_true, pred = gradcam_for_class(label_id)
    panels = (
        (orig, f"Original | true={ID2LABEL[y_true]}"),
        (cam_img, f"Grad-CAM MAXPOOL | pred={ID2LABEL[pred]}"),
    )

    plt.figure(figsize=(8, 4))
    for pos, (panel_img, panel_title) in enumerate(panels, start=1):
        plt.subplot(1, 2, pos)
        plt.title(panel_title)
        plt.imshow(panel_img)
        plt.axis("off")

    plt.tight_layout()
    plt.show()


In [None]:
df0 = results_df.copy()

# Only stochastic-pooling runs (case/whitespace-insensitive match).
dstoch = df0[df0["pooling"].astype(str).str.strip().str.lower().eq("stochastic")].copy()
if dstoch.empty:
    raise ValueError("Nessun run con pooling == 'stochastic' trovato.")

# Drop runs that cannot be ranked or located on disk.
dstoch = dstoch.dropna(subset=[ACC_COL, "path"]).copy()
if "best_val_loss" in dstoch.columns:
    dstoch = dstoch.dropna(subset=["best_val_loss"]).copy()
if dstoch.empty:
    raise ValueError(f"Nessun run 'stochastic' con {ACC_COL}/path validi.")

# Accuracy rounded to 2 decimals: runs that differ only beyond that are ties.
dstoch["_acc2"] = dstoch[ACC_COL].round(2)

best_acc2 = dstoch["_acc2"].max()
cand = dstoch[dstoch["_acc2"] == best_acc2].copy()

# Among ties: lowest best_val_loss first, then highest unrounded accuracy.
if "best_val_loss" in cand.columns:
    cand = cand.sort_values(["best_val_loss", ACC_COL], ascending=[True, False], kind="mergesort")
else:
    cand = cand.sort_values([ACC_COL], ascending=[False], kind="mergesort")

best_stoch = cand.iloc[0].drop(labels=["_acc2"])

best_run_dir = str(best_stoch["path"])
ckpt_path = os.path.join(best_run_dir, "best_model.pt")

print(f"BEST STOCHASTIC (acc2={best_acc2:.2f}):")
print("run:", best_run_dir)
print("test_bal_acc:", float(best_stoch[ACC_COL]))
if "best_val_loss" in best_stoch:
    print("best_val_loss:", float(best_stoch["best_val_loss"]))
print("ckpt:", ckpt_path)

# Rebuild the architecture and load the best stochastic checkpoint.
model = ResNet18(num_classes=NUM_CLASSES).to(device)
state = torch.load(ckpt_path, map_location=device)

# Unwrap the common {"state_dict": ...} / {"model_state_dict": ...} layouts.
if isinstance(state, dict) and "state_dict" in state:
    state = state["state_dict"]
elif isinstance(state, dict) and "model_state_dict" in state:
    state = state["model_state_dict"]

# Strip a leading "module." (typical of checkpoints saved from a
# DataParallel/DDP-wrapped model). Only the prefix is removed; str.replace
# would also rewrite a "module." occurring in the middle of a key.
state = {
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in state.items()
}
model.load_state_dict(state, strict=True)
model.eval()

In [None]:
def tensor_to_rgb(x, mean, std):
    """Undo per-channel normalization and return a channels-last image in [0, 1].

    Parameters
    ----------
    x : torch.Tensor
        Channels-first (C, H, W) tensor, or a 4-D batch from which only the
        first item is used.
    mean, std : sequence of float
        Per-channel statistics of the normalization being inverted.

    Returns
    -------
    numpy.ndarray
        (H, W, C) float image, clamped to [0, 1].
    """
    img = x[0] if x.dim() == 4 else x
    img = img.detach().cpu().float()

    channel_mean = torch.tensor(mean).view(-1, 1, 1)
    channel_std = torch.tensor(std).view(-1, 1, 1)

    restored = (img * channel_std + channel_mean).clamp(0, 1)
    return restored.permute(1, 2, 0).numpy()


def pick_test_image_for_label(df_test, label_id, seed=None):
    """Draw one test row with class ``label_id`` and return ``(img_path, row)``.

    The class column is ``label_id`` when present, otherwise ``label``.
    ``seed`` is forwarded as ``random_state``; ``None`` means a fresh random
    draw on each call. The image path is taken from ``path``, then
    ``img_path``, else built as ``img_dir/img_<id>.png``.

    Raises
    ------
    ValueError
        If the class has no rows or no image path can be derived.
    """
    label_col = "label_id" if "label_id" in df_test.columns else "label"
    matches = df_test[df_test[label_col].eq(label_id)]
    if matches.empty:
        raise ValueError(f"Nessuna immagine con label {label_id}")

    # random_state=None is pandas' default, i.e. an unseeded draw
    row = matches.sample(n=1, random_state=seed).iloc[0]

    path_col = next((c for c in ("path", "img_path") if c in row), None)
    if path_col is not None:
        return row[path_col], row
    if "id" not in row:
        raise ValueError("Non trovo il path immagine.")
    return os.path.join(img_dir, f"img_{int(row['id'])}.png"), row

# Grad-CAM bound to the last residual block of the freshly loaded stochastic model.
last_block = model.layer4[-1]
target_layers = [last_block]
cam = GradCAM(model=model, target_layers=target_layers)

def gradcam_for_class_stochastic(label_id):
    """Grad-CAM for one random test image of class ``label_id`` (stochastic model).

    The overlay background is the model input itself, de-normalized with
    ``tensor_to_rgb``, so it is pixel-aligned with the CAM. The CAM targets the
    predicted class. Uses module-level ``model``, ``cam``, ``df_test``,
    ``best_stoch``, ``img_dir``, ``mean``/``std`` and ``device``.

    Returns
    -------
    tuple
        ``(rgb_float, overlay, y_true, pred_class)``.
    """
    _img_path, row = pick_test_image_for_label(df_test, label_id)
    single_row_df = df_test.loc[[row.name]]

    cam_ds = CustomDataset(
        single_row_df,
        img_dir=img_dir,
        img_size=224,
        is_train=False,
        augment=False,
        resize=bool(best_stoch["resize"]),
        mean=mean,
        std=std,
    )
    sample, y_true = cam_ds[0]
    batch = sample.unsqueeze(0).to(device)

    # background for the overlay: the de-normalized network input
    background = tensor_to_rgb(batch, mean, std)

    with torch.no_grad():
        pred_class = int(model(batch).argmax(dim=1).item())

    heatmap = cam(input_tensor=batch, targets=[ClassifierOutputTarget(pred_class)])[0]
    overlay = show_cam_on_image(background, heatmap, use_rgb=True, image_weight=0.35)

    return background, overlay, int(y_true), pred_class

ID2LABEL = {v: k for k, v in LABEL_MAP.items()}

# One figure per class: de-normalized input next to its stochastic Grad-CAM overlay.
for label_id in range(3):
    orig, cam_img, y_true, pred = gradcam_for_class_stochastic(label_id)

    plt.figure(figsize=(8, 4))
    for pos, panel_img, panel_title in (
        (1, orig, f"Original | true={ID2LABEL[y_true]}"),
        (2, cam_img, f"Grad-CAM STOCHASTIC | pred={ID2LABEL[pred]}"),
    ):
        plt.subplot(1, 2, pos)
        plt.title(panel_title)
        plt.imshow(panel_img)
        plt.axis("off")

    plt.tight_layout()
    plt.show()
