In [17]:
from __future__ import annotations

import sys
from importlib import reload
from pathlib import Path
from typing import Dict, Tuple, List
from itertools import zip_longest

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score

from torch.amp.grad_scaler import GradScaler
from sklearn.model_selection import StratifiedShuffleSplit

# Put the parent of the current working directory on the import path so the
# `src` package imported in the next cell can be resolved.
# NOTE(review): this depends on the kernel's working directory (presumably the
# notebook's own folder inside the repo) — confirm when running from elsewhere.
sys.path.append(str(Path().resolve().parent))

In [18]:
from src.config import SERIES_DIR, TRAIN_CSV
from src import dataloader as dl
from src import partition as pt
from src import slice_bag as sb
from src import baseline_25D as b25d

In [19]:
# Re-execute the project modules so edits made to their source files are
# picked up without restarting the kernel. The last call's return value (the
# reloaded module object) is what the cell displays as its output.
reload(dl)
reload(pt)
reload(sb)
reload(b25d)

<module 'src.baseline_25D' from '/home/aaron/aneurisym-detection/src/baseline_25D.py'>

In [None]:
# ----------------------------
# Config
# ----------------------------
# Every tunable used by the helper and training cells lives here.

# small, reproducible run
SEED = 117          # passed to set_seed() and to the train/val split
EPOCHS = 3
BATCH_SIZE = 1      # keep 1 because bags have variable N
NUM_WORKERS = 0
LR = 1e-4           # AdamW learning rate
WEIGHT_DECAY = 1e-5 # AdamW weight decay
Z_THRESH_MM = 1.5   # split thin vs thick

# 2.5D bag params (must match model k): K is passed both to
# sb.make_slice_bag(k=...) and to b25d.ResNet25D(k=...).
K = 5
STRIDE = 2
RESIZE = (224, 224)
POOL = "mean"       # or "max"

# thin-series params, passed to pt.make_loaders as thin_out / thin_patch
# (patch axis order is presumably (Z, Y, X) — TODO confirm against src.partition)
THIN_OUT_SPACING = (1.0, 1.0, 1.0)
THIN_PATCH = (96, 192, 192)

# thick-series params, passed to pt.make_loaders as thick_xy / thick_patch
THICK_XY_SPACING = 1.0
THICK_PATCH = (16, 192, 192)

# Checkpoint directory (relative to the kernel's working directory); created
# eagerly so torch.save in the training loop cannot fail on a missing folder.
OUTDIR = Path("checkpoints_25D")
OUTDIR.mkdir(parents=True, exist_ok=True)


In [None]:
# ----------------------------
# Helpers
# ----------------------------
def set_seed(seed: int = 1337):
    """Seed every random number generator this notebook relies on.

    Seeds, in order: Python's ``random``, NumPy's legacy global RNG, the
    PyTorch CPU generator, and all PyTorch CUDA generators.

    Args:
        seed: Integer seed applied to each generator.
    """
    import random

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def extract_targets(sample: Dict, target_cols: List[str]) -> torch.Tensor:
    """Return the target vector for one sample.

    Args:
        sample: Sample dict; its ``"label"`` entry is used when present.
        target_cols: Label names; only their count is used, to size the
            fallback vector.

    Returns:
        ``sample["label"]`` unchanged if the key exists, otherwise a float32
        zero tensor of length ``len(target_cols)``.
    """
    if "label" not in sample:
        # Unlabeled sample: an all-zero target keeps the pipeline running,
        # but nothing useful is learned from it.
        n_targets = len(target_cols)
        return torch.zeros(n_targets, dtype=torch.float32)
    return sample["label"]

def alt_iter(dl_a, dl_b):
    """Interleave two iterables, tagging each item with its source.

    Items from ``dl_a`` are yielded as ``(item, "thin")`` and items from
    ``dl_b`` as ``(item, "thick")``, alternating a, b, a, b, ... Once the
    shorter iterable runs out, the remaining items of the longer one follow.
    ``None`` is the padding value, so an item that is itself ``None`` is
    skipped.
    """
    kinds = ("thin", "thick")
    for pair in zip_longest(dl_a, dl_b, fillvalue=None):
        for item, kind in zip(pair, kinds):
            if item is None:
                continue
            yield item, kind

def unpack_sample_from_batch(batch):
    """Turn a batch-size-1 batch dict into a single-sample dict.

    Three steps, in this order:
      1. Every tensor whose leading dimension has size 1 loses that dimension.
      2. If ``"volume"`` is still a 5-D tensor, one more leading dimension is
         dropped (the original note gives the result as [1, Z, Y, X]).
      3. Every non-tensor list/tuple of length 1 is replaced by its only
         element (e.g. uid becomes str rather than [str]).

    Args:
        batch: Mapping as produced by a DataLoader; must contain ``"volume"``.

    Returns:
        A new dict with the same keys in the same order; ``batch`` itself is
        not modified.
    """
    sample = {}
    for key, value in batch.items():
        has_unit_batch_dim = (
            isinstance(value, torch.Tensor)
            and value.dim() >= 1
            and value.size(0) == 1
        )
        sample[key] = value[0] if has_unit_batch_dim else value

    volume = sample["volume"]
    if isinstance(volume, torch.Tensor) and volume.ndim == 5:
        sample["volume"] = volume[0]  # [1,Z,Y,X]

    # Unwrap singleton containers last, so a list-wrapped volume is not
    # subjected to the 5-D check above.
    for key in list(sample):
        value = sample[key]
        if isinstance(value, torch.Tensor):
            continue
        if isinstance(value, (list, tuple)) and len(value) == 1:
            sample[key] = value[0]

    return sample

# ----------------------------
# One pass: make bag -> forward -> loss
# ----------------------------
def forward_one(sample: Dict, model: b25d.ResNet25D, device: torch.device, criterion, target_cols: List[str]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run one sample through the model and compute its bag-level loss.

    Builds a slice bag from the sample with ``sb.make_slice_bag`` (using the
    notebook-level ``K``, ``STRIDE`` and ``RESIZE``), feeds it to ``model``,
    and scores the bag-level logits against the sample's targets.

    Args:
        sample: Unbatched sample dict (see ``unpack_sample_from_batch``).
            The original note describes ``sample["volume"]`` as a
            [1, Z, Y, X] float tensor in [0, 1].
        model: Network returning ``(instance_logits, bag_logits)``; per the
            original notes these are [N, C] and [1, C].
        device: Device the bag and targets are moved to.
        criterion: Loss called as ``criterion(bag_logits, targets)``.
        target_cols: Label names, forwarded to ``extract_targets``.

    Returns:
        ``(loss, bag_logits_cpu, targets_cpu)``: the loss (still attached to
        the autograd graph), then detached CPU copies of the bag logits (with
        the leading dimension squeezed out) and of the targets.
    """
    bag, _centers = sb.make_slice_bag(sample, k=K, stride=STRIDE, resize=RESIZE)
    bag = bag.to(device)  # [N, K, H, W]
    _logits_inst, logits_bag = model(bag)  # [N, C], [1, C]
    bag_logits = logits_bag.squeeze(0)  # [C]
    y = extract_targets(sample, target_cols).to(device)  # [C]
    loss = criterion(bag_logits, y)
    return loss, bag_logits.detach().cpu(), y.detach().cpu()

def weighted_auc_14(target_cols, y_true: np.ndarray, y_pred: np.ndarray,
                    ap_name: str = "Aneurysm Present", ap_weight: float = 13.0):
    """Weighted mean of per-label ROC AUCs.

    Kaggle weights: 13x for "Aneurysm Present", 1x for the others (14 labels
    in total with the default columns).

    Labels whose AUC is undefined are dropped from both the numerator and the
    denominator. A label's AUC is undefined when its ground-truth column does
    not contain both 0 and 1, or when ``roc_auc_score`` rejects the input.

    Args:
        target_cols: Label names, one per column of ``y_true`` / ``y_pred``.
        y_true: Ground truth, shape [N, C].
        y_pred: Predicted scores, shape [N, C].
        ap_name: Name of the label that gets the heavier weight.
        ap_weight: Weight of ``ap_name``; every other label weighs 1.

    Returns:
        The weighted AUC as a float, or NaN if no label has a defined AUC.

    Raises:
        ValueError: If ``ap_name`` is not in ``target_cols``.
    """
    cols = list(target_cols)
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    weights = np.ones(len(cols), dtype=float)
    weights[cols.index(ap_name)] = ap_weight

    aucs = np.full(len(cols), np.nan, dtype=float)
    for j in range(len(cols)):
        tj = y_true[:, j]
        # AUC needs both classes present; this also covers empty columns.
        if not (np.any(tj == 0) and np.any(tj == 1)):
            continue
        try:
            aucs[j] = roc_auc_score(tj, y_pred[:, j])
        except ValueError:
            # Input rejected by sklearn (e.g. non-binary targets): leave NaN.
            # Any other exception is a real bug and is allowed to propagate.
            pass

    # ignore NaN aucs in both numerator and denominator
    valid = ~np.isnan(aucs)
    if not valid.any():
        return np.nan
    return float(np.sum(aucs[valid] * weights[valid]) / np.sum(weights[valid]))


@torch.no_grad()
def eval_loader_full(loader, model, device, target_cols):
    """Compute the weighted AUC of ``model`` over a single loader.

    Each batch is unpacked with ``unpack_sample_from_batch`` (the same helper
    the training loop uses), turned into a slice bag, and scored. The
    bag-level sigmoid probabilities are stacked with the targets and passed
    to ``weighted_auc_14``.

    Args:
        loader: Iterable of batch-size-1 batch dicts.
        model: Network returning ``(instance_logits, bag_logits)``; it is put
            into eval mode.
        device: Device the slice bag is moved to.
        target_cols: Label names, one per output column.

    Returns:
        The weighted AUC, or NaN when ``loader`` yields no batches.
    """
    model.eval()
    all_probs = []
    all_targets = []
    for batch in loader:
        # Shared unpacking helper: handles 0-dim tensors and singleton lists.
        sample = unpack_sample_from_batch(batch)
        bag, _ = sb.make_slice_bag(sample, k=K, stride=STRIDE, resize=RESIZE)
        bag = bag.to(device)
        _, logit_bag = model(bag)                     # [1,C]
        y = extract_targets(sample, target_cols)      # [C]
        all_probs.append(torch.sigmoid(logit_bag.squeeze(0)).cpu().numpy())
        all_targets.append(y.cpu().numpy())

    if not all_probs:
        # Nothing to evaluate; np.stack would raise on an empty list.
        return float("nan")

    y_pred = np.stack(all_probs, 0)
    y_true = np.stack(all_targets, 0)
    return weighted_auc_14(target_cols, y_true, y_pred)

@torch.no_grad()
def eval_val_union(dl_list, model, device, target_cols):
    """Weighted AUC over the union of several validation loaders.

    Loaders that are ``None`` are skipped. Predictions are grouped by UID;
    when a UID shows up more than once (e.g. in both the thin and the thick
    loader) its predicted probabilities are averaged, and the target seen
    last is kept (targets are assumed identical across duplicates).

    Prints a compact per-label summary (three worst / three best defined
    AUCs) and the "Aneurysm Present" AUC when it is defined.

    Args:
        dl_list: Iterable of loaders (or ``None`` entries) yielding
            batch-size-1 batch dicts that carry a ``"uid"``.
        model: Network returning ``(instance_logits, bag_logits)``; it is put
            into eval mode.
        device: Device the slice bag is moved to.
        target_cols: Label names; must include "Aneurysm Present".

    Returns:
        ``weighted_auc_14`` on the UID-averaged predictions, or NaN if no
        sample was seen.
    """
    model.eval()
    preds_by_uid = {}    # uid -> list of [C] probability vectors
    target_by_uid = {}   # uid -> [C] target vector

    for loader in (ld for ld in dl_list if ld is not None):
        for batch in loader:
            sample = unpack_sample_from_batch(batch)
            uid = sample["uid"]
            bag, _ = sb.make_slice_bag(sample, k=K, stride=STRIDE, resize=RESIZE)
            _, logit_bag = model(bag.to(device))                          # [1,C]
            probs = torch.sigmoid(logit_bag.squeeze(0)).cpu().numpy()     # [C]
            target = extract_targets(sample, target_cols).cpu().numpy()   # [C]

            preds_by_uid.setdefault(uid, []).append(probs)
            target_by_uid[uid] = target

    if len(preds_by_uid) == 0:
        return float("nan")

    # One row per UID, in sorted UID order.
    ordered_uids = sorted(preds_by_uid)
    y_pred = np.stack([np.mean(preds_by_uid[u], axis=0) for u in ordered_uids], axis=0)  # [N,C]
    y_true = np.stack([target_by_uid[u] for u in ordered_uids], axis=0)                  # [N,C]

    # Keep only labels with a defined AUC, in column order before sorting.
    scored = []
    for j, name in enumerate(target_cols):
        col_true = y_true[:, j]
        if len(np.unique(col_true)) <= 1:
            continue
        try:
            auc_j = roc_auc_score(col_true, y_pred[:, j])
        except Exception:
            continue
        if not np.isnan(auc_j):
            scored.append((name, auc_j))
    scored.sort(key=lambda item: item[1])

    if scored:
        worst = ", ".join(f"{n}:{a:.2f}" for n, a in scored[:3])
        best = ", ".join(f"{n}:{a:.2f}" for n, a in scored[-3:])
        print(f"Per-label AUCs — worst: {worst} | best: {best}")

    ap_idx = target_cols.index("Aneurysm Present")
    ap_auc = float("nan")
    try:
        if len(set(y_true[:, ap_idx])) > 1:
            ap_auc = roc_auc_score(y_true[:, ap_idx], y_pred[:, ap_idx])
    except Exception:
        ap_auc = float("nan")

    if not np.isnan(ap_auc):
        print(f"AP AUC (val-union)={ap_auc:.3f}")

    return weighted_auc_14(target_cols, y_true, y_pred)


def compute_pos_weight_from_labels(labels_map, uids, cap: float = 50.0, laplace: float = 1.0, device="cpu"):
    """Per-label positive-class weights for a BCE-with-logits loss.

    For each label j: ``(neg_j + laplace) / (pos_j + laplace)``, clipped to
    ``[1, cap]``. UIDs missing from ``labels_map`` are ignored.

    Args:
        labels_map: Mapping uid -> label vector.
        uids: UIDs to count over (typically the training split).
        cap: Upper clip bound, to avoid extreme weights.
        laplace: Additive smoothing on both counts; 1.0 gives stronger
            smoothing on tiny datasets.
        device: Device of the returned tensor.

    Returns:
        A float32 tensor with one weight per label.
    """
    known_labels = [labels_map[uid] for uid in uids if uid in labels_map]
    label_matrix = np.stack(known_labels, axis=0).astype(np.float32)

    n_pos = label_matrix.sum(axis=0)
    n_neg = label_matrix.shape[0] - n_pos

    smoothed_ratio = (n_neg + laplace) / (n_pos + laplace)
    weights = np.clip(smoothed_ratio, 1.0, cap)
    return torch.as_tensor(weights, dtype=torch.float32, device=device)

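# data split
# TODO: `safe_stratified_split(uids, mods, aps, val_frac=..., seed=...)` is called by the
# training cell but is not defined in any visible cell; define (or import) it here so the
# notebook survives Restart Kernel -> Run All.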


# ----------------------------
# Train / Eval loops
# ----------------------------
def run_epoch(
    loader,                         # DataLoader OR a tuple/list of (dl_thin, dl_thick)
    model: b25d.ResNet25D,
    device,
    optimizer,
    scaler,
    train: bool,
    target_cols: List[str],
    *,
    criterion: nn.Module,          # <-- pass in (e.g., BCEWithLogitsLoss with pos_weight)
) -> Tuple[float, float | None]:
    """Run one training or evaluation pass and report loss and AP AUC.

    ``loader`` is either a single loader, iterated as-is, or a 2-element
    tuple/list ``(dl_thin, dl_thick)``, whose batches are interleaved through
    ``alt_iter``.

    In training mode each sample takes one optimizer step under autocast
    (enabled only on CUDA): scale the loss, backprop, unscale, clip the
    gradient norm to 1.0, step, update the scaler. In evaluation mode the
    forward pass runs under ``torch.no_grad()``; ``optimizer`` and ``scaler``
    are not used.

    Args:
        loader: See above.
        model: Network; switched to train/eval mode according to ``train``.
        device: ``torch.device`` used for the forward pass and autocast.
        optimizer: Optimizer (training only).
        scaler: Gradient scaler (training only).
        train: True for a training pass, False for evaluation.
        target_cols: Label names; must include "Aneurysm Present".
        criterion: Loss module, forwarded to ``forward_one``.

    Returns:
        ``(avg_loss, auc_ap)``. ``avg_loss`` is the summed loss divided by
        the loader length(s). ``auc_ap`` is the ROC AUC of the
        "Aneurysm Present" column, or ``None`` when only one class was seen
        or the AUC could not be computed.
    """
    model.train(train)
    ap_col = target_cols.index("Aneurysm Present")
    loss_sum = 0.0
    ap_probs, ap_labels = [], []

    # Pick the batch stream and the step count used to average the loss.
    is_pair = isinstance(loader, (tuple, list)) and len(loader) == 2
    if is_pair:
        # alt_iter yields (batch, "thin"/"thick"); only the batch is needed.
        batches = (batch for batch, _kind in alt_iter(loader[0], loader[1]))
        num_steps = len(loader[0]) + len(loader[1])
    else:
        batches = loader
        num_steps = len(loader)

    def amp_context():
        # Built per step, so `device` is only touched when a batch exists.
        return torch.autocast(device_type=device.type, dtype=torch.float16, enabled=(device.type == "cuda"))

    for batch in batches:
        sample = unpack_sample_from_batch(batch)

        if train:
            optimizer.zero_grad(set_to_none=True)
            with amp_context():
                loss, logits_cpu, y_cpu = forward_one(sample, model, device, criterion, target_cols)
            # backprop with AMP: gradients are unscaled before clipping so the
            # norm threshold applies to the true gradient magnitudes
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            with torch.no_grad():
                with amp_context():
                    loss, logits_cpu, y_cpu = forward_one(sample, model, device, criterion, target_cols)

        loss_sum += float(loss.item())
        # AP AUC bookkeeping (logits_cpu / y_cpu are already on CPU)
        ap_logit = torch.as_tensor(logits_cpu[ap_col])
        ap_probs.append(torch.sigmoid(ap_logit).item())
        ap_labels.append(float(y_cpu[ap_col].item()))

    avg_loss = loss_sum / max(1, num_steps)

    auc_ap = None
    try:
        from sklearn.metrics import roc_auc_score
        if len(set(ap_labels)) > 1:
            auc_ap = roc_auc_score(ap_labels, ap_probs)
        else:
            auc_ap = None
    except Exception:
        auc_ap = None
    return avg_loss, auc_ap


In [26]:
# End-to-end training cell: seed -> label/metadata loading -> train/val split
# -> loaders -> model/optimizer/loss -> epoch loop with per-epoch checkpoint.
set_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# labels map
target_cols = list(dl.TARGET_COLS_DEFAULT)
labels_map = dl.build_labels_map_from_csv(TRAIN_CSV, target_cols=target_cols)

# Per-series metadata used only for the stratified split below.
meta = pd.read_csv(TRAIN_CSV, usecols=["SeriesInstanceUID","Modality","Aneurysm Present"])
meta["SeriesInstanceUID"] = meta["SeriesInstanceUID"].astype(str)

# restrict to UIDs that actually exist under SERIES_DIR (thin or thick)
thin_dirs, thick_dirs = pt.split_series_by_z(SERIES_DIR, z_thresh_mm=Z_THRESH_MM)
have = {p.name for p in thin_dirs + thick_dirs}
meta = meta[meta["SeriesInstanceUID"].isin(have)].reset_index(drop=True)

uids = meta["SeriesInstanceUID"].tolist()
mods = meta["Modality"].astype(str).tolist()
aps  = meta["Aneurysm Present"].astype(int).tolist()

# NOTE(review): `safe_stratified_split` is not defined or imported in any
# visible cell; this line only works if the name is still in the kernel from
# an earlier session. It must be restored for Restart Kernel -> Run All.
train_uids, val_uids = safe_stratified_split(uids, mods, aps, val_frac=0.25, seed=SEED)
# NOTE(review): `labels_map[u][-1]` assumes "Aneurysm Present" is the LAST
# entry of target_cols — TODO confirm against dl.TARGET_COLS_DEFAULT.
print(f"train={len(train_uids)}, val={len(val_uids)}; AP counts -> train:{sum(labels_map[u][-1] for u in train_uids)}, val:{sum(labels_map[u][-1] for u in val_uids)}")

# --- build loaders with your existing function, using uid_filter ---
# Training loaders: shuffled, restricted to the training UIDs.
dl_thin_tr, dl_thick_tr, ds_thin_tr, ds_thick_tr = pt.make_loaders(
    SERIES_DIR, labels_map,
    z_thresh_mm=Z_THRESH_MM,
    thin_out=THIN_OUT_SPACING,
    thick_xy=THICK_XY_SPACING,
    thin_patch=THIN_PATCH,
    thick_patch=THICK_PATCH,
    batch_size_thin=BATCH_SIZE,
    batch_size_thick=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=True,
    uid_filter=set(train_uids),
)

# Validation loaders: same preprocessing, no shuffling, validation UIDs only.
dl_thin_val, dl_thick_val, ds_thin_val, ds_thick_val = pt.make_loaders(
    SERIES_DIR, labels_map,
    z_thresh_mm=Z_THRESH_MM,
    thin_out=THIN_OUT_SPACING,
    thick_xy=THICK_XY_SPACING,
    thin_patch=THIN_PATCH,
    thick_patch=THICK_PATCH,
    batch_size_thin=BATCH_SIZE,
    batch_size_thick=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False,
    uid_filter=set(val_uids),
)

# model
model = b25d.ResNet25D(num_classes=len(target_cols), k=K, pool=POOL, pretrained=True).to(device)

# opt / amp (the scaler is a no-op when not running on CUDA)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scaler = GradScaler("cuda", enabled=(device.type == "cuda"))

# Class-imbalance weights are computed from the TRAINING split only.
pos_weight = compute_pos_weight_from_labels(labels_map, train_uids, cap=50.0, laplace=1.0, device=str(device))
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

print(f"train thin={len(ds_thin_tr)} thick={len(ds_thick_tr)} | val thin={len(ds_thin_val)} thick={len(ds_thick_val)}")

for epoch in range(1, EPOCHS + 1):
    # Train on thin and thick batches interleaved; the validation loss below
    # is computed on the THIN validation loader only.
    tr_loss, tr_auc = run_epoch((dl_thin_tr, dl_thick_tr), model, device, optimizer, scaler, True, target_cols, criterion=criterion)
    va_loss, va_auc = run_epoch(dl_thin_val, model, device, optimizer, scaler, False, target_cols, criterion=criterion)

    # AUC parts are omitted from the log line when run_epoch returned None
    # (only one class seen, or the AUC could not be computed).
    msg = f"[Epoch {epoch:02d}] train_loss={tr_loss:.4f}"
    if tr_auc is not None: msg += f" | train_AUC(AP)={tr_auc:.3f}"
    msg += f" || val_loss(thin)={va_loss:.4f}"
    if va_auc is not None: msg += f" | val_AUC(AP)={va_auc:.3f}"
    print(msg)

    # wauc = eval_loader_full(dl_thin_val, model, device, target_cols)
    # print(f"weightedAUC(14, val-thin)={wauc:.3f}")

    # Weighted AUC over thin + thick validation series, averaged per UID.
    wauc = eval_val_union([dl_thin_val, dl_thick_val], model, device, target_cols)
    print(f"weightedAUC(14, val-union)={wauc:.3f}")

    # One checkpoint per epoch, stored with the config needed to rebuild the
    # model and the slice bags.
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "config": {
            "K": K, "POOL": POOL, "RESIZE": RESIZE,
            "THIN_OUT_SPACING": THIN_OUT_SPACING, "THIN_PATCH": THIN_PATCH,
            "THICK_XY_SPACING": THICK_XY_SPACING, "THICK_PATCH": THICK_PATCH,
            "STRIDE": STRIDE, "target_cols": target_cols,
        },
    }, OUTDIR / f"resnet25d_epoch{epoch:02d}.pt")

print("Done ✅")

Device: cuda
train=7, val=3; AP counts -> train:2.0, val:1.0


  


train thin=7 thick=0 | val thin=1 thick=2
[Epoch 01] train_loss=0.7227 | train_AUC(AP)=0.300 || val_loss(thin)=0.8856
AP AUC (val-union)=1.000
weightedAUC(14, val-union)=1.000
[Epoch 02] train_loss=0.5591 | train_AUC(AP)=0.500 || val_loss(thin)=0.8075
AP AUC (val-union)=0.000
weightedAUC(14, val-union)=0.071
[Epoch 03] train_loss=0.4250 | train_AUC(AP)=0.700 || val_loss(thin)=0.7558
AP AUC (val-union)=1.000
weightedAUC(14, val-union)=1.000
Done ✅


In [None]:
# Device: cuda
# Thin series: 8 | Thick series: 2
# Training on THIN only for now (8 iters/epoch)
# [Epoch 01] train_loss=0.6521 | train_AUC(AP)=0.533 || val_loss=0.6259 | val_AUC(AP)=0.733
# [Epoch 02] train_loss=0.4617 | train_AUC(AP)=0.733 || val_loss=0.3998 | val_AUC(AP)=0.533
# [Epoch 03] train_loss=0.3232 | train_AUC(AP)=0.667 || val_loss=0.3329 | val_AUC(AP)=0.400
# Done ✅

In [None]:
# Device: cuda
# Thin series: 8 | Thick series: 2
# Training on THIN only for now (8 iters/epoch)
# [Epoch 01] train_loss=0.6492 | train_AUC(AP)=0.633 || val_loss=0.6287 | val_AUC(AP)=0.733
# weightedAUC(14)=0.599
# [Epoch 02] train_loss=0.4638 | train_AUC(AP)=0.733 || val_loss=0.5045 | val_AUC(AP)=0.600
# weightedAUC(14)=0.601
# [Epoch 03] train_loss=0.3269 | train_AUC(AP)=0.333 || val_loss=0.4203 | val_AUC(AP)=0.500
# weightedAUC(14)=0.310
# Done ✅

In [None]:
# Device: cuda
# Thin series: 8 | Thick series: 2
# Training by alternating THIN and THICK batches
# [Epoch 01] train_loss=0.7400 | train_AUC(AP)=0.381 || val_loss(thin)=0.7221 | val_AUC(AP)=0.000
# weightedAUC(14, thin)=0.561
# [Epoch 02] train_loss=0.5189 | train_AUC(AP)=0.714 || val_loss(thin)=0.5326 | val_AUC(AP)=0.800
# weightedAUC(14, thin)=0.327
# [Epoch 03] train_loss=0.4395 | train_AUC(AP)=0.238 || val_loss(thin)=0.4649 | val_AUC(AP)=0.600
# weightedAUC(14, thin)=0.428
# Done ✅