# FPN-MIL training (self-contained)

All code in this notebook. No clone, no imports from the GitHub repo.

**Input:** Features from the **pi-cai-feature-extraction** notebook (cropped pipeline):
- **Patches are cropped** to the prostate ROI (batch crop or on-the-fly with `MASKS_DIR`).
- **One slice** (from the 3-channel cropped volume) = **one patch**; **one case** = **one bag**.
- Layout: `multi_scale/<case_id>/<case_id>/` with `C4_patch_features.pt`, `C5_patch_features.pt`, `info_patches.h5`.
- Labels CSV: `image_id` (= case/bag id), `fold`, and a 0/1 label column (e.g. `cs_pca`).

In [164]:
# Config â€” expects features from pi-cai-feature-extraction (cropped ROI, one patch per slice, one bag per case)
from pathlib import Path

INPUT_ROOT = Path("/kaggle/input/notebooks/sananiroomand/pi-cai-feature-extraction")
FEAT_FOLDER = "picai_extracted_features"
LABELS_CSV = INPUT_ROOT / "picai_labels.csv"
LABEL_COL = "cs_pca"

WORK_DIR = Path("/kaggle/working") if Path("/kaggle/working").exists() else Path(".")
# Features from pi-cai-feature-extraction output: .../picai_extracted_features/multi_scale/<case_id>/<case_id>/...
FEAT_DIR = INPUT_ROOT / FEAT_FOLDER
MS_DIR = FEAT_DIR / "multi_scale"  # e.g. .../multi_scale/10000_1000000/10000_1000000/info_patches.h5

BATCH_SIZE = 8
EPOCHS = 30
LR = 5e-5
SEED = 42
FPN_DIM = 256
ENCODER_DIM = 256
SCALES = [16, 32, 128]  # used as scale ids; we have 2 levels C4, C5
NUM_WORKERS = 2

In [165]:
# Load labels and prepare train/val split (features read from input path)
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

if not MS_DIR.exists():
    raise FileNotFoundError(f"Feature dir not found: {MS_DIR}. Add pi-cai-feature-extraction output as notebook input.")

df = pd.read_csv(LABELS_CSV)
if "patient_id" not in df.columns:
    c = "image_id" if "image_id" in df.columns else ("case_id" if "case_id" in df.columns else df.columns[0])
    df["patient_id"] = df[c].astype(str)
if "image_id" not in df.columns:
    df["image_id"] = df["patient_id"].astype(str)
df["patient_id"] = df["patient_id"].astype(str)
df["image_id"] = df["image_id"].astype(str)

df[LABEL_COL] = (pd.to_numeric(df[LABEL_COL], errors="coerce").fillna(0) > 0.5).astype(int)

if "split" not in df.columns:
    if "fold" in df.columns:
        df["split"] = df["fold"].map(lambda f: "training" if f in (0, 1) else "test")
    else:
        tr_idx, te_idx = train_test_split(df.index, test_size=0.2, stratify=df[LABEL_COL], random_state=SEED)
        df["split"] = "test"
        df.loc[tr_idx, "split"] = "training"

CSV_PATH = WORK_DIR / "labels_with_split.csv"
df.to_csv(CSV_PATH, index=False)
train_df = df[df["split"] == "training"].reset_index(drop=True)
val_df = df[df["split"] == "test"].reset_index(drop=True)
print("Train:", len(train_df), "Val:", len(val_df))
print("Label dist train:", train_df[LABEL_COL].value_counts().to_dict())

Train: 342 Val: 177
Label dist train: {0: 200, 1: 142}


In [166]:
# Dataset: load C4, C5 from multi_scale/<case_id>/<case_id>/ (cropped patches, one per slice; case = bag)
import os
import numpy as np
import h5py
import torch
from torch.utils.data import Dataset
from pathlib import Path

def load_bag(bag_dir, level=None):
    fname = f"{level}_patch_features.pt" if level else "patch_features.pt"
    x = torch.load(os.path.join(bag_dir, fname))
    with h5py.File(os.path.join(bag_dir, "info_patches.h5"), "r") as f:
        coords = np.array(f["coords"])
    idx = np.lexsort((coords[:, 0], coords[:, 1])) if coords.ndim >= 2 else np.arange(len(x))
    return x[idx]

class FPNMILDataset(Dataset):
    def __init__(self, dataframe, feat_root, label_col):
        self.df = dataframe.reset_index(drop=True)
        self.feat_root = Path(feat_root)
        self.label_col = label_col

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        # Bag = case; features at multi_scale/<case_id>/<case_id>/ (from cropped extraction)
        bag_id = str(row["image_id"])
        bag_dir = self.feat_root / bag_id / bag_id
        c4 = load_bag(bag_dir, "C4")
        c5 = load_bag(bag_dir, "C5")
        y = torch.tensor(row[self.label_col], dtype=torch.float32)
        return {"x": [c4, c5], "y": y}

def collate(batch):
    # Bags have different N; pad to max N per scale and pass mask
    x_out, mask_out = [], []
    for i in range(2):
        tensors = [b["x"][i] for b in batch]
        max_n = max(t.size(0) for t in tensors)
        C, H, W = tensors[0].shape[1], tensors[0].shape[2], tensors[0].shape[3]
        padded = torch.zeros(len(batch), max_n, C, H, W, dtype=tensors[0].dtype)
        mask = torch.zeros(len(batch), max_n, dtype=torch.float32)
        for b_idx, t in enumerate(tensors):
            n = t.size(0)
            padded[b_idx, :n] = t
            mask[b_idx, :n] = 1.0
        x_out.append(padded)
        mask_out.append(mask)
    y = torch.stack([b["y"] for b in batch], dim=0).unsqueeze(1)
    return {"x": x_out, "mask": mask_out, "y": y}

train_ds = FPNMILDataset(train_df, MS_DIR, LABEL_COL)
val_ds = FPNMILDataset(val_df, MS_DIR, LABEL_COL)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, collate_fn=collate, drop_last=True)
val_loader = torch.utils.data.DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, collate_fn=collate)
print("Loaders ready.")

Loaders ready.


In [167]:
# Model: ISAB encoder + gated attention per scale, then gated scale aggregation + classifier
import math
import torch.nn as nn
import torch.nn.functional as F

class MAB(nn.Module):
    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False, dropout=0.0):
        super().__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        self.ln0 = nn.LayerNorm(dim_V) if ln else None
        self.ln1 = nn.LayerNorm(dim_V) if ln else None
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K, key_mask=None):
        Q, K, V = self.fc_q(Q), self.fc_k(K), self.fc_v(K)
        d = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(d, 2), 0)
        K_ = torch.cat(K.split(d, 2), 0)
        V_ = torch.cat(V.split(d, 2), 0)
        A = Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V)
        if key_mask is not None:
            # A: (B*num_heads, num_inds, N), key_mask: (B, N) -> expand to match A
            km = key_mask.unsqueeze(1).expand(-1, A.size(1), -1).unsqueeze(1).expand(-1, self.num_heads, -1, -1).reshape(A.size(0), A.size(1), A.size(2))
            A = A.masked_fill(km == 0, -1e9)
        A = F.softmax(A, dim=-1)
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = self.ln0(O) if self.ln0 is not None else O
        O = O + F.relu(self.fc_o(O))
        O = self.ln1(O) if self.ln1 is not None else O
        return O, A

class ISAB(nn.Module):
    def __init__(self, d_model, d_hidden, num_inds, heads, ln=True):
        super().__init__()
        self.I = nn.Parameter(torch.Tensor(1, num_inds, d_hidden))
        nn.init.xavier_uniform_(self.I)
        self.mab0 = MAB(d_hidden, d_model, d_hidden, heads, ln=ln)
        self.mab1 = MAB(d_model, d_hidden, d_hidden, heads, ln=ln)

    def forward(self, X, key_mask=None):
        H, _ = self.mab0(self.I.repeat(X.size(0), 1, 1), X, key_mask=key_mask)
        return self.mab1(X, H)[0]

class GatedAttn(nn.Module):
    def __init__(self, L, D, dropout=0.25):
        super().__init__()
        self.V = nn.Sequential(nn.Linear(L, D), nn.Tanh(), nn.Dropout(dropout))
        self.U = nn.Sequential(nn.Linear(L, D), nn.Sigmoid(), nn.Dropout(dropout))
        self.w = nn.Linear(D, 1)

    def forward(self, x, mask=None):
        scores = self.w(self.V(x) * self.U(x))
        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(-1) == 0, -1e9)
        A = F.softmax(scores, dim=1)
        return (A.transpose(1, 2) @ x).squeeze(1), A.squeeze(2)

class FPNMIL(nn.Module):
    def __init__(self, feat_dim=256, encoder_dim=256, num_scales=2, num_inds=20, heads=4, dropout=0.25):
        super().__init__()
        self.num_scales = num_scales
        self.encoders = nn.ModuleList([
            nn.Sequential(
                ISAB(feat_dim, encoder_dim, num_inds, heads, ln=True),
                ISAB(encoder_dim, encoder_dim, num_inds, heads, ln=True),
            ) for _ in range(num_scales)
        ])
        self.aggregators = nn.ModuleList([GatedAttn(encoder_dim, encoder_dim, dropout) for _ in range(num_scales)])
        self.scale_agg = GatedAttn(encoder_dim, encoder_dim, dropout)
        self.side_heads = nn.ModuleList([nn.Linear(encoder_dim, 1) for _ in range(num_scales)])
        self.head = nn.Sequential(nn.Dropout(dropout), nn.Linear(encoder_dim, 1))

    def forward(self, x_list, mask_list=None, deep_sup=True):
        scale_embs = []
        side_logits = []
        for i, x in enumerate(x_list):
            # x: (B, N, C, H, W) -> (B, N, C)
            if x.dim() == 5:
                x = x.mean(dim=(3, 4))
            mask = mask_list[i] if mask_list is not None else None
            h = self.encoders[i][0](x, key_mask=mask)
            h = self.encoders[i][1](h, key_mask=mask)
            emb, _ = self.aggregators[i](h, mask=mask)
            scale_embs.append(emb)
            side_logits.append(self.side_heads[i](emb))
        x = torch.stack(scale_embs, dim=1)
        emb, _ = self.scale_agg(x)
        logits = self.head(emb)
        if deep_sup:
            return logits, side_logits
        return logits

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FPNMIL(feat_dim=FPN_DIM, encoder_dim=ENCODER_DIM, num_scales=2, num_inds=20, heads=4, dropout=0.25).to(device)
print("Model:", sum(p.numel() for p in model.parameters()))

Model: 2530310


In [168]:
# Training loop
import random

def seed_all(s):
    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(s)

pos = train_df[LABEL_COL].sum()
neg = len(train_df) - pos
bce_weight = torch.tensor([neg / max(pos, 1)], device=device, dtype=torch.float32)
criterion = nn.BCEWithLogitsLoss(reduction="none")

def criterion_fn(logits, side_logits, y):
    y = y.to(device)
    w = torch.where(y > 0.5, bce_weight, torch.ones_like(bce_weight))
    loss = (w * criterion(logits, y)).mean()
    if side_logits is not None:
        for s in side_logits:
            loss = loss + 0.5 * (w.squeeze() * criterion(s, y)).mean()
    return loss

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

seed_all(SEED)
out_dir = WORK_DIR / "checkpoints"
out_dir.mkdir(parents=True, exist_ok=True)
best_auc = 0.0

In [169]:
from sklearn.metrics import roc_auc_score, balanced_accuracy_score, f1_score

def evaluate(loader):
    model.eval()
    all_y, all_p = [], []
    with torch.no_grad():
        for batch in loader:
            x = [t.to(device) for t in batch["x"]]
            mask_list = [m.to(device) for m in batch["mask"]]
            y = batch["y"].to(device)
            logits, _ = model(x, mask_list=mask_list, deep_sup=True)
            all_y.append(y.cpu().numpy())
            all_p.append(torch.sigmoid(logits).cpu().numpy())
    y = np.vstack(all_y).ravel()
    p = np.vstack(all_p).ravel()
    pred = (p >= 0.5).astype(int)
    auc = roc_auc_score(y, p) if len(np.unique(y)) > 1 else 0.0
    return {"auc": auc, "bacc": balanced_accuracy_score(y, pred), "f1": f1_score(y, pred, zero_division=0)}

for epoch in range(EPOCHS):
    model.train()
    running = 0.0
    for batch in train_loader:
        x = [t.to(device) for t in batch["x"]]
        mask_list = [m.to(device) for m in batch["mask"]]
        y = batch["y"]
        optimizer.zero_grad()
        logits, side = model(x, mask_list=mask_list, deep_sup=True)
        loss = criterion_fn(logits, side, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        running += loss.item()
    scheduler.step()
    metrics = evaluate(val_loader)
    if metrics["auc"] > best_auc:
        best_auc = metrics["auc"]
        torch.save({"model": model.state_dict(), "epoch": epoch}, out_dir / "best.pth")
    print(f"Epoch {epoch+1}/{EPOCHS} loss={running/len(train_loader):.4f} val_auc={metrics['auc']:.4f} bacc={metrics['bacc']:.4f} f1={metrics['f1']:.4f}")

print("Done. Best checkpoint:", out_dir / "best.pth")

Epoch 1/30 loss=1.6704 val_auc=0.5633 bacc=0.5016 f1=0.0256
Epoch 2/30 loss=1.6571 val_auc=0.5775 bacc=0.5077 f1=0.3852
Epoch 3/30 loss=1.6168 val_auc=0.5905 bacc=0.5000 f1=0.0000
Epoch 4/30 loss=1.6219 val_auc=0.5834 bacc=0.4864 f1=0.3089
Epoch 5/30 loss=1.6207 val_auc=0.5887 bacc=0.5296 f1=0.1429
Epoch 6/30 loss=1.6109 val_auc=0.5891 bacc=0.5520 f1=0.4762
Epoch 7/30 loss=1.6108 val_auc=0.5900 bacc=0.5246 f1=0.1412
Epoch 8/30 loss=1.6010 val_auc=0.5888 bacc=0.5314 f1=0.6116
Epoch 9/30 loss=1.6024 val_auc=0.5960 bacc=0.5992 f1=0.6321
Epoch 10/30 loss=1.5974 val_auc=0.5978 bacc=0.5764 f1=0.5761
Epoch 11/30 loss=1.5632 val_auc=0.5952 bacc=0.5345 f1=0.1446
Epoch 12/30 loss=1.6044 val_auc=0.5959 bacc=0.6092 f1=0.6224
Epoch 13/30 loss=1.6215 val_auc=0.5943 bacc=0.5860 f1=0.6267
Epoch 14/30 loss=1.6040 val_auc=0.5969 bacc=0.5767 f1=0.5101
Epoch 15/30 loss=1.5914 val_auc=0.5989 bacc=0.6009 f1=0.6280
Epoch 16/30 loss=1.5880 val_auc=0.5985 bacc=0.6009 f1=0.6280
Epoch 17/30 loss=1.5877 val_auc=0