# Kaggle: Offline feature extraction (PI-CAI preprocessed)

1. **Add input:** Two preprocessed datasets: one with folds 0,1,2 (e.g. pi-cai-preprocess) and one with folds 3,4 (e.g. pi-cai-preprocess-2). Each has `nnUNet_raw_data_fold<N>/.../imagesTr/`.
2. **Optional:** Add the PI-CAI dataset that contains the metadata CSV (e.g. `Metadata(for ISUP).csv`; the paths cell searches `/kaggle/input` for it) to get real labels. Otherwise every preprocessed case gets a dummy label (cs_pca=0).
3. **Set** `PREPROCESSED_ROOTS` in the paths cell to your two input paths (folds 0,1,2 and 3,4).
4. **Crop (optional):** Run "Batch crop all cases" to save cropped T2W/ADC/HBV to `CROPPED_ROOT`; otherwise extraction crops on the fly when `MASKS_DIR` is set.
5. Run extraction. Features are written to `FEAT_DIR`. **Each slice** (one 2D slice from the 3 modalities) = **one patch**; **each case** = **one bag** for FPN-MIL.

**Prostate crop:** Clone picai_labels; set `MASKS_DIR` to Bosma22b. Then either run "Batch crop all cases" and use `CROPPED_ROOT` for extraction, or run extraction with per-fold roots from CSV + `MASKS_DIR` (crop on the fly).

In [1]:
# Install SimpleITK (used below for NIfTI I/O, mask resampling and writing crops). IPython shell escape.
!pip install -q SimpleITK

In [None]:
# Clone picai_labels for prostate whole-gland masks (Option 1: use on Kaggle)
# Shallow clone (--depth 1) into /kaggle/working; the existence check makes re-running this cell a no-op.
# NOTE(review): the check does not detect a partial/failed earlier clone — delete the folder to re-clone.
import os
if not os.path.exists("/kaggle/working/picai_labels"):
    # IPython shell escape; cloning from github.com needs network access in the notebook session.
    !git clone --depth 1 https://github.com/DIAGNijmegen/picai_labels.git /kaggle/working/picai_labels
else:
    print("picai_labels already cloned at /kaggle/working/picai_labels")

In [None]:
# Two preprocessed datasets: (path, list of folds). Process folds 0,1,2 from first and 3,4 from second together.
# Each root is expected to hold nnUNet_raw_data_fold<N>/Task2201_picai_fold<N>/imagesTr/ for its listed folds.
PREPROCESSED_ROOTS = [
    ("/kaggle/input/notebooks/sananiroomand/pi-cai-preprocess", [0, 1, 2]),
    ("/kaggle/input/notebooks/sananiroomand/pi-cai-preprocess-2", [3, 4]),
]
# Labels table written by this cell (patient_id, image_id, cs_pca, fold, preprocessed_root);
# presumably passed to the extractor's --labels_csv.
LABELS_CSV_PATH = "/kaggle/working/picai_labels.csv"
# Output directory for extracted features (presumably the extractor's --feat_dir).
FEAT_DIR = "/kaggle/working/picai_extracted_features"
# Crop to prostate ROI (whole-gland masks from picai_labels; run clone cell first)
MASKS_DIR = "/kaggle/working/picai_labels/anatomical_delineations/whole_gland/AI/Bosma22b"   # set to None to skip crop
CROP_MARGIN = 2   # voxels around prostate bbox
# After "Batch crop all cases", use this as input to extraction (optional; else extraction crops on the fly)
CROPPED_ROOT = "/kaggle/working/picai_roi_crops"

from pathlib import Path
import pandas as pd

from collections import Counter

# Discover preprocessed cases: case_to_fold maps case_id -> fold, fold_to_root maps fold -> dataset root.
# Expected layout per fold: <root>/nnUNet_raw_data_fold<N>/Task2201_picai_fold<N>/imagesTr/<case>_0000.nii.gz
case_to_fold = {}
fold_to_root = {}
for root_path, folds in PREPROCESSED_ROOTS:
    root = Path(root_path)
    if not root.exists():
        print(f"WARNING: Preprocessed root not found (skipped): {root_path}. Add this dataset as input for folds {folds}.")
        continue
    for fold in folds:
        images_tr = root / f"nnUNet_raw_data_fold{fold}" / f"Task2201_picai_fold{fold}" / "imagesTr"
        if not images_tr.exists():
            # Report instead of skipping silently: this usually means a wrong path or folder layout.
            print(f"WARNING: {images_tr} not found; fold {fold} skipped.")
            continue
        # Register the fold only when it has an imagesTr folder, so FOLDS lists folds that have data.
        fold_to_root[fold] = root
        # glob order is filesystem-dependent; sorting makes the case (and CSV row) order reproducible.
        for img_path in sorted(images_tr.glob("*_0000.nii.gz")):
            case_id = img_path.name[: -len("_0000.nii.gz")]
            if case_to_fold.get(case_id, fold) != fold:
                print(f"WARNING: case {case_id} appears in folds {case_to_fold[case_id]} and {fold}; keeping fold {fold}.")
            case_to_fold[case_id] = fold

FOLDS = sorted(fold_to_root)
print("Discovered folds:", FOLDS)
missing_folds = sorted(set(range(5)) - set(FOLDS))  # PI-CAI is split into 5 folds (0-4)
if missing_folds:
    print("WARNING: Only folds", FOLDS, "found; missing", missing_folds,
          "- add the preprocessed dataset(s) holding them as input to process all 5 folds.")
_fold_counts = Counter(case_to_fold.values())
cases_per_fold = {fd: _fold_counts[fd] for fd in FOLDS}
print("Cases per fold:", cases_per_fold)
if not case_to_fold:
    raise FileNotFoundError("No *_0000.nii.gz found under any PREPROCESSED_ROOTS. Check paths and folder layout.")

# Try to load PI-CAI metadata for real cs_pca labels
# METADATA_PATHS = [
#     Path("/kaggle/input/prostate-cancer-pi-cai-dataset/Metadata(for ISUP).csv"),
#     Path("/kaggle/input/prostate-cancer-pi-cai-dataset/Metadata(for ISUP without lesion info).csv"),
# ]
# df = None
# for meta_path in METADATA_PATHS:
#     if meta_path.exists():
#         df_meta = pd.read_csv(meta_path)
#         pid_col = next((c for c in df_meta.columns if "patient" in c.lower() or c.lower() == "id"), df_meta.columns[0])
#         isup_col = next((c for c in df_meta.columns if "isup" in c.lower()), None)
#         study_col = next((c for c in df_meta.columns if "study" in c.lower()), None)
#         df_meta["patient_id"] = df_meta[pid_col].astype(str)
#         df_meta["isup"] = pd.to_numeric(df_meta[isup_col], errors="coerce") if isup_col else 0
#         df_meta["case_id"] = df_meta["patient_id"] + "_" + df_meta[study_col].astype(str) if study_col else df_meta["patient_id"]
#         df_meta["cs_pca"] = (df_meta["isup"] >= 2).astype(int)
#         df_meta["fold"] = df_meta["case_id"].map(case_to_fold)
#         df = df_meta[df_meta["case_id"].isin(case_to_fold)][["patient_id", "case_id", "cs_pca", "fold"]].drop_duplicates()
#         df = df.rename(columns={"case_id": "image_id"})
#         break

# if df is None:
#     df = pd.DataFrame([
#         {"patient_id": cid.split("_")[0], "image_id": cid, "cs_pca": 0, "fold": fold}
#         for cid, fold in case_to_fold.items()
#     ])
#     print("No metadata found; using dummy labels (cs_pca=0). Add PI-CAI dataset for real labels.")
# --- Automatically find PI-CAI metadata file and build the labels table ---
from pathlib import Path

candidates = [
    "Metadata(for ISUP).csv",
    "Metadata(for ISUP without lesion info).csv",
    "Metadata with lesion info.csv",
    "Metadata without lesion info.csv",
]

# The first candidate name (in the order above) found anywhere under /kaggle/input wins.
# Hits are sorted so the choice does not depend on filesystem iteration order.
meta_path = None
for name in candidates:
    hits = sorted(Path("/kaggle/input").rglob(name))
    if hits:
        meta_path = hits[0]
        break

print("Using metadata:", meta_path)


def _find_col(columns, exact=(), substr=None, prefer=None):
    """Pick a metadata column by name (case-insensitive).

    Returns the first column whose lowercased name is in *exact* (tried in order); otherwise the
    first column containing *substr*, favouring one that also contains *prefer*; otherwise None.
    """
    by_lower = {}
    for c in columns:
        by_lower.setdefault(str(c).lower(), c)
    for wanted in exact:
        if wanted in by_lower:
            return by_lower[wanted]
    if substr is None:
        return None
    hits_ = [c for c in columns if substr in str(c).lower()]
    preferred = [c for c in hits_ if prefer and prefer in str(c).lower()]
    return (preferred or hits_ or [None])[0]


def _id_str(series):
    """Stringify an ID column, dropping the '.0' pandas appends when an integer column holds NaN."""
    return series.astype(str).str.strip().str.replace(r"\.0$", "", regex=True)


df = None

if meta_path is not None:
    df_meta = pd.read_csv(meta_path)
    cols = list(df_meta.columns)

    # Identify columns. Prefer exact names; for ISUP prefer a case-level column (e.g. "case_ISUP")
    # over a lesion-level one, which may hold non-numeric per-lesion lists.
    pid_col = _find_col(cols, exact=("patient_id", "id"), substr="patient") or cols[0]
    isup_col = _find_col(cols, exact=("case_isup",), substr="isup", prefer="case")
    study_col = _find_col(cols, exact=("study_id",), substr="study")
    print(f"Metadata columns used: patient={pid_col!r}, study={study_col!r}, isup={isup_col!r}")

    # Build ids matching the preprocessed file names (<patient>_<study>).
    df_meta["patient_id"] = _id_str(df_meta[pid_col])
    if isup_col:
        df_meta["isup"] = pd.to_numeric(df_meta[isup_col], errors="coerce")
        n_nan = int(df_meta["isup"].isna().sum())
        if n_nan:
            print(f"NOTE: {n_nan} metadata rows have a missing/non-numeric ISUP in {isup_col!r}; they count as cs_pca=0.")
    else:
        df_meta["isup"] = 0
        print("WARNING: no ISUP column in metadata; all cs_pca labels will be 0.")
    if study_col:
        df_meta["case_id"] = df_meta["patient_id"] + "_" + _id_str(df_meta[study_col])
    else:
        df_meta["case_id"] = df_meta["patient_id"]

    # Label (csPCa proxy from ISUP): ISUP >= 2. NaN compares False, so it becomes 0.
    df_meta["cs_pca"] = (df_meta["isup"] >= 2).astype(int)

    # Keep only cases that exist in the preprocessed set.
    matched = df_meta[df_meta["case_id"].isin(list(case_to_fold))]
    if not matched.empty:
        # A case may span several metadata rows (e.g. one per lesion). Collapse to one row per case,
        # positive if any row is, so image_id is unique (one bag per case) and labels don't conflict.
        df = (
            matched.groupby("case_id", as_index=False, sort=False)
            .agg(patient_id=("patient_id", "first"), cs_pca=("cs_pca", "max"))
            .rename(columns={"case_id": "image_id"})
        )
        df["fold"] = df["image_id"].map(case_to_fold)
        df = df[["patient_id", "image_id", "cs_pca", "fold"]].reset_index(drop=True)
        n_unmatched = len(case_to_fold) - len(df)
        if n_unmatched:
            print(f"NOTE: {n_unmatched} preprocessed cases have no metadata row and are left out of the labels table.")

# Fallback if no metadata or nothing matched (e.g. ID mismatch): dummy labels for every case.
if df is None or len(df) == 0:
    df = pd.DataFrame([
        {"patient_id": cid.split("_")[0], "image_id": cid, "cs_pca": 0, "fold": fold}
        for cid, fold in case_to_fold.items()
    ])
    if meta_path is None:
        print("WARNING: no PI-CAI metadata CSV found under /kaggle/input; using dummy labels (cs_pca=0).")
    else:
        print("WARNING: metadata did not match preprocessed case_ids; using dummy labels (cs_pca=0).")
df["preprocessed_root"] = df["fold"].map(fold_to_root).astype(str)
df.to_csv(LABELS_CSV_PATH, index=False)
print(f"Saved {LABELS_CSV_PATH} with {len(df)} rows. Fold counts:", df["fold"].value_counts().sort_index().to_dict())

KeyboardInterrupt: 

In [None]:
# Re-save the labels table and sanity-check the label distribution:
# an all-zero cs_pca column usually means the metadata was missing or did not match.
df.to_csv(LABELS_CSV_PATH, index=False)
print("cs_pca value counts (incl NaN):")
print(df["cs_pca"].value_counts(dropna=False))

_cs = df["cs_pca"]
all_zero = _cs.eq(0).all()
any_one = _cs.eq(1).any()
print("All rows cs_pca==0 ?", all_zero)
print("Any cs_pca==1 ?", any_one)

In [None]:
# Print dimensions of preprocessed volumes (run after paths + labels cell)
# Preprocessed slice size: 640×640 (H×W). Each slice is resized to 224×224 for patch extraction.
# nnUNet layout: 0000=T2W, 0001=ADC, 0002=HBV. Extractor uses all 3; notebook viz uses T2W only.
import numpy as np
from pathlib import Path
try:
    import SimpleITK as sitk

    def _load_nii(path):
        """Load a NIfTI volume as a float32 array ordered [z, y, x] = (D, H, W)."""
        return np.asarray(sitk.GetArrayFromImage(sitk.ReadImage(str(path))), dtype=np.float32)
except ImportError:
    import nibabel as nib

    def _load_nii(path):
        """nibabel fallback returning the same [z, y, x] order as the SimpleITK path.

        nibabel indexes voxel data as [x, y, z]; reversing all three axes gives [z, y, x]
        (a (2, 0, 1) transpose would give [z, x, y], i.e. H and W swapped).
        """
        a = nib.load(str(path)).get_fdata()
        return np.transpose(a, (2, 1, 0)).astype(np.float32)

MODALITIES = [("0000", "T2W"), ("0001", "ADC"), ("0002", "HBV")]
N_CASES = 10   # number of cases to print dimensions for
count = 0
for _, row in df.iterrows():
    if count >= N_CASES:
        break
    cid, f = str(row["image_id"]), int(row["fold"])
    root = fold_to_root[f]
    tr = root / f"nnUNet_raw_data_fold{f}" / f"Task2201_picai_fold{f}" / "imagesTr"
    if not (tr / f"{cid}_0000.nii.gz").exists():
        continue
    count += 1
    print(f"Case {cid} (fold {f}):")
    shapes = []
    for suffix, name in MODALITIES:
        p = tr / f"{cid}_{suffix}.nii.gz"
        if not p.exists():
            print(f"  {name} ({suffix}): not found")
            continue
        D, H, W = _load_nii(p).shape[:3]
        print(f"  {name} (_{suffix}) shape = (D, H, W) = ({D}, {H}, {W})  [slices, height, width]")
        shapes.append((D, H, W))
    if len(shapes) == 3:
        if len(set(shapes)) > 1:
            # The extractor np.stack()s the three modalities, which needs identical grids.
            print(f"  WARNING: modality shapes differ {shapes}; stacking them would fail.")
        D, H, W = shapes[0]
        print(f"  Full volume (3 mods stacked): (C, D, H, W) = (3, {D}, {H}, {W})")
    print()
if count == 0:
    print("No case found. Run paths cell and ensure PREPROCESSED_ROOTS have data.")
else:
    print(f"(Feature extraction uses all 3 modalities; notebook visualizations use T2W only. Showed {count} cases.)")

## Visualize patch extent (one patch per slice)

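The extractor takes **one patch per slice**. Each 2D slice (640×640 in the preprocessed data) is first zero-padded to at least 224 px per side, then bilinearly resized to 224×224 if it is larger; the aspect ratio is not preserved. The green box shows the full slice, and that whole region becomes one patch.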

In [None]:
# Visualize patch bounding box for a sample case (run after paths + labels cell)
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

PATCH_SIZE = 224   # must match extractor
# One patch per slice (full slice resized to 224×224)

try:
    import SimpleITK as sitk

    def _load_nii(path):
        """Load a NIfTI volume as a float32 array ordered [z, y, x] = (D, H, W)."""
        return np.asarray(sitk.GetArrayFromImage(sitk.ReadImage(str(path))), dtype=np.float32)
except ImportError:
    import nibabel as nib

    def _load_nii(path):
        """nibabel fallback: reverse nibabel's [x, y, z] voxel order to SimpleITK's [z, y, x]."""
        a = nib.load(str(path)).get_fdata()
        return np.transpose(a, (2, 1, 0)).astype(np.float32)

N_SHOW_HALF = 2  # show slices mid-2 .. mid+2 (up to 5 distinct slices)

case_id, fold = None, None
for _, row in df.head(20).iterrows():
    cid, f = str(row["image_id"]), int(row["fold"])
    root = fold_to_root[f]
    tr = root / f"nnUNet_raw_data_fold{f}" / f"Task2201_picai_fold{f}" / "imagesTr"
    if (tr / f"{cid}_0000.nii.gz").exists():
        case_id, fold = cid, f
        break
if case_id is None:
    print("No case found under PREPROCESSED_ROOTS. Run paths cell and ensure data exists.")
else:
    root = fold_to_root[fold]
    tr = root / f"nnUNet_raw_data_fold{fold}" / f"Task2201_picai_fold{fold}" / "imagesTr"
    t2w = _load_nii(tr / f"{case_id}_0000.nii.gz")  # shape (D, H, W)
    n_slices, H, W = t2w.shape[:3]
    mid = n_slices // 2
    # Distinct central slices, clamped to [0, n_slices - 1]. An unclamped mid-1 is -1 for a
    # single-slice volume, which numpy would silently wrap to the last slice.
    indices = sorted({min(max(0, mid + k), n_slices - 1) for k in range(-N_SHOW_HALF, N_SHOW_HALF + 1)})
    fig, axes = plt.subplots(1, len(indices), figsize=(4 * len(indices), 4), squeeze=False)
    for ax, zi in zip(axes[0], indices):
        sl = t2w[zi]  # shape (H, W), indexed [y, x]
        # No .T: an (H, W) array drawn with extent [0, W, 0, H] puts rows on y and columns on x, the
        # same convention as the crop/alignment cells. sl.T is (W, H) and would be shown transposed.
        ax.imshow(sl, cmap="gray", origin="lower", extent=[0, W, 0, H], aspect="equal")
        # One patch per slice: full slice → resized to 224×224
        ax.add_patch(plt.Rectangle((0, 0), W, H, fill=False, edgecolor="lime", linewidth=2))
        ax.set_xlim(0, W)
        ax.set_ylim(0, H)
        ax.set_title(f"slice {zi} → one patch")
        ax.set_axis_off()
    plt.suptitle(f"Case {case_id} — one {PATCH_SIZE}×{PATCH_SIZE} patch per slice")
    plt.tight_layout()
    plt.show()

## Visualize crop: before (full volume) vs after (prostate ROI)

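Requires `MASKS_DIR` to be set. The figure has three rows:
- **Before** = full slice with the prostate mask contour (green) and the ROI box (cyan).
- **After** = the cropped volume.
- **Mask** = the whole-gland mask for the same slices.

The code cell for this section comes further down, after the "Show cropped images" cell.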

## Batch crop all cases to prostate ROI

Crop **all** cases (T2W, ADC, HBV) to the prostate bounding box and save to `CROPPED_ROOT` with the same fold layout. Run this once (after paths + clone + df). Then run extraction: **each slice** (one 2D slice from the 3-channel volume) = **one patch**, and **each case** = **one bag** for the FPN-MIL model.

In [None]:
# Batch crop all cases: save cropped T2W, ADC, HBV to CROPPED_ROOT (same fold layout)
# Run after paths cell and clone; requires MASKS_DIR and df.
import numpy as np
from pathlib import Path

try:
    import SimpleITK as sitk
except ImportError:
    sitk = None

def resample_mask_to_ref(mask_path, ref_sitk):
    """Resample the mask stored at *mask_path* onto the voxel grid of *ref_sitk*.

    Uses an identity transform and nearest-neighbour interpolation, with 0 outside the
    mask's extent, and returns the result as a numpy array (SimpleITK [z, y, x] order).
    """
    identity = sitk.Transform(3, sitk.sitkIdentity)
    on_ref_grid = sitk.Resample(
        sitk.ReadImage(str(mask_path)),
        ref_sitk,
        identity,
        sitk.sitkNearestNeighbor,
        0.0,
    )
    return np.asarray(sitk.GetArrayFromImage(on_ref_grid))

def bbox_3d(mask, margin=0):
    """Return the margin-padded bounding box of the positive voxels of a 3D array.

    The result is ((zmin, zmax), (ymin, ymax), (xmin, xmax)) with half-open bounds clipped
    to the array extent, or None when *mask* has no voxel > 0.
    """
    hits = np.where(mask > 0)
    if len(hits[0]) == 0:
        return None

    def _span(axis):
        lo = hits[axis].min() - margin
        hi = hits[axis].max() + 1 + margin
        return max(0, lo), min(mask.shape[axis], hi)

    return _span(0), _span(1), _span(2)

out_root = Path(CROPPED_ROOT)
masks_dir = Path(MASKS_DIR) if MASKS_DIR else None
_MODALITY_SUFFIXES = ("_0000", "_0001", "_0002")  # T2W, ADC, HBV (nnUNet channel order)

if sitk is None:
    print("SimpleITK is required for batch crop (pip install SimpleITK). Skipping.")
elif masks_dir is None or not masks_dir.exists():
    print("Set MASKS_DIR (and run clone cell) to batch crop. Skipping.")
elif df is None or df.empty:
    print("Run paths cell first so df exists. Skipping.")
else:
    out_root.mkdir(parents=True, exist_ok=True)
    done, skip_no_mask, err = 0, 0, 0
    for _, row in df.iterrows():
        cid, f = str(row["image_id"]), int(row["fold"])
        root = fold_to_root[f]
        tr = root / f"nnUNet_raw_data_fold{f}" / f"Task2201_picai_fold{f}" / "imagesTr"
        out_tr = out_root / f"nnUNet_raw_data_fold{f}" / f"Task2201_picai_fold{f}" / "imagesTr"
        mask_path = masks_dir / f"{cid}.nii.gz"
        if not (tr / f"{cid}_0000.nii.gz").exists():
            continue
        if not mask_path.exists():
            skip_no_mask += 1
            continue
        out_tr.mkdir(parents=True, exist_ok=True)
        # Resume: skip only when all three outputs exist (an interrupted run can leave partial output).
        # NOTE: crops written by an earlier version that computed the origin per axis are also skipped;
        # delete CROPPED_ROOT to regenerate them with the corrected geometry.
        if all((out_tr / f"{cid}{s}.nii.gz").exists() for s in _MODALITY_SUFFIXES):
            done += 1
            continue
        try:
            ref = sitk.ReadImage(str(tr / f"{cid}_0000.nii.gz"))
            vols = [np.asarray(sitk.GetArrayFromImage(ref), dtype=np.float32)]
            for s in _MODALITY_SUFFIXES[1:]:
                img_s = sitk.ReadImage(str(tr / f"{cid}{s}.nii.gz"))
                vols.append(np.asarray(sitk.GetArrayFromImage(img_s), dtype=np.float32))
            mask_arr = resample_mask_to_ref(mask_path, ref)
            box = bbox_3d(mask_arr, margin=CROP_MARGIN)
            if box is None:
                skip_no_mask += 1
                continue
            (zmin, zmax), (ymin, ymax), (xmin, xmax) = box
            # New origin = physical position of the first kept voxel. TransformIndexToPhysicalPoint
            # applies the direction cosines; adding index*spacing per axis is only correct for an
            # identity direction matrix. SimpleITK indices are (x, y, z); numpy arrays are [z, y, x].
            origin = ref.TransformIndexToPhysicalPoint((int(xmin), int(ymin), int(zmin)))
            spacing = ref.GetSpacing()
            direction = ref.GetDirection()
            for suffix, vol in zip(_MODALITY_SUFFIXES, vols):
                img = sitk.GetImageFromArray(vol[zmin:zmax, ymin:ymax, xmin:xmax])
                img.SetSpacing(spacing)
                img.SetOrigin(origin)
                img.SetDirection(direction)
                sitk.WriteImage(img, str(out_tr / f"{cid}{suffix}.nii.gz"))
            done += 1
            if done % 50 == 0:
                print("Cropped", done, "cases...")
        except Exception as e:
            # Best-effort batch: report and continue with the next case.
            err += 1
            print("Error", cid, e)
    print(f"Batch crop done. Saved to {out_root}. Cropped: {done}, skipped (no mask): {skip_no_mask}, errors: {err}")

## Show cropped images — all alignments

**2 cases × 6 slices** through the cropped stack; **every slice** has the mask contour (green) overlaid so you can verify alignment from top to bottom. Same convention: array `(H,W)` with `[y,x]`, no `.T`. If the green outline follows the prostate on every slice, the crop and mask are aligned.

In [None]:
# Show cropped images (from CROPPED_ROOT if available, else crop in memory). Alignment: same (H,W) [y,x], no .T.
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

try:
    import SimpleITK as sitk
except ImportError:
    sitk = None

def _load_nii(path):
    """Load a NIfTI volume as a float32 array ordered [z, y, x] = (D, H, W).

    Uses SimpleITK when the module-level ``sitk`` is available (not None), otherwise nibabel.
    """
    if sitk:
        return np.asarray(sitk.GetArrayFromImage(sitk.ReadImage(str(path))), dtype=np.float32)
    import nibabel as nib
    # nibabel indexes voxel data as [x, y, z]; reverse all axes to match SimpleITK's [z, y, x].
    # (A (2, 0, 1) transpose gives [z, x, y], swapping H and W relative to the sitk path.)
    a = nib.load(str(path)).get_fdata()
    return np.transpose(a, (2, 1, 0)).astype(np.float32)

def resample_mask_to_ref(mask_path, ref_sitk):
    """Resample the mask stored at *mask_path* onto the voxel grid of *ref_sitk*.

    Uses an identity transform and nearest-neighbour interpolation, with 0 outside the
    mask's extent, and returns the result as a numpy array (SimpleITK [z, y, x] order).
    """
    identity = sitk.Transform(3, sitk.sitkIdentity)
    on_ref_grid = sitk.Resample(
        sitk.ReadImage(str(mask_path)),
        ref_sitk,
        identity,
        sitk.sitkNearestNeighbor,
        0.0,
    )
    return np.asarray(sitk.GetArrayFromImage(on_ref_grid))

def bbox_3d(mask, margin=0):
    """Return the margin-padded bounding box of the positive voxels of a 3D array.

    The result is ((zmin, zmax), (ymin, ymax), (xmin, xmax)) with half-open bounds clipped
    to the array extent, or None when *mask* has no voxel > 0.
    """
    hits = np.where(mask > 0)
    if len(hits[0]) == 0:
        return None

    def _span(axis):
        lo = hits[axis].min() - margin
        hi = hits[axis].max() + 1 + margin
        return max(0, lo), min(mask.shape[axis], hi)

    return _span(0), _span(1), _span(2)

cropped_root = Path(CROPPED_ROOT)
masks_dir = Path(MASKS_DIR) if MASKS_DIR else None

# Show ALL alignments: 2 cases × 6 cropped slices each, every slice with mask contour overlay (same [y,x], no .T).
# Uses the saved crop from CROPPED_ROOT when "Batch crop all cases" has written it. The mask is then
# resampled onto the saved crop's own grid, so wrong geometry stored with the crop shows up as misalignment.
# Otherwise the preprocessed T2W is cropped in memory.
N_SLICES_PER_CASE = 6
if masks_dir and masks_dir.exists() and df is not None and not df.empty and sitk:
    cases_to_show = []
    for _, row in df.head(40).iterrows():
        cid, f = str(row["image_id"]), int(row["fold"])
        root = fold_to_root[f]
        tr = root / f"nnUNet_raw_data_fold{f}" / f"Task2201_picai_fold{f}" / "imagesTr"
        if (tr / f"{cid}_0000.nii.gz").exists() and (masks_dir / f"{cid}.nii.gz").exists():
            cases_to_show.append((cid, f))
            if len(cases_to_show) >= 2:
                break
    if not cases_to_show:
        print("No case with both preprocessed data and mask found.")
    else:
        n_cases = len(cases_to_show)
        n_cols = N_SLICES_PER_CASE
        fig, axes = plt.subplots(n_cases, n_cols, figsize=(2.5 * n_cols, 2.5 * n_cases), squeeze=False)
        for row_idx, (cid, f) in enumerate(cases_to_show):
            rel_tr = Path(f"nnUNet_raw_data_fold{f}") / f"Task2201_picai_fold{f}" / "imagesTr"
            mask_path = masks_dir / f"{cid}.nii.gz"
            saved_t2w = cropped_root / rel_tr / f"{cid}_0000.nii.gz"
            if saved_t2w.exists():
                crop_ref = sitk.ReadImage(str(saved_t2w))
                t2w_crop = np.asarray(sitk.GetArrayFromImage(crop_ref), dtype=np.float32)
                mask_crop = resample_mask_to_ref(mask_path, crop_ref)
                source = "saved"
            else:
                ref = sitk.ReadImage(str(fold_to_root[f] / rel_tr / f"{cid}_0000.nii.gz"))
                # Reuse the loaded image instead of reading the same file a second time.
                t2w_full = np.asarray(sitk.GetArrayFromImage(ref), dtype=np.float32)
                mask_arr = resample_mask_to_ref(mask_path, ref)
                box = bbox_3d(mask_arr, margin=CROP_MARGIN)
                if box is None:
                    # Empty mask: nothing to crop; blank the row instead of leaving bare axes frames.
                    for ax in axes[row_idx]:
                        ax.set_axis_off()
                    continue
                (zmin, zmax), (ymin, ymax), (xmin, xmax) = box
                t2w_crop = t2w_full[zmin:zmax, ymin:ymax, xmin:xmax]
                mask_crop = mask_arr[zmin:zmax, ymin:ymax, xmin:xmax]
                source = "mem"
            Dc, Hc, Wc = t2w_crop.shape
            # n_cols evenly spaced slice indices from first to last cropped slice.
            if Dc > 1 and n_cols > 1:
                indices = [int(round(i * (Dc - 1) / (n_cols - 1))) for i in range(n_cols)]
            else:
                indices = [0] * n_cols
            indices = [min(z, Dc - 1) for z in indices]
            x_1d = np.arange(Wc)
            y_1d = np.arange(Hc)
            for col_idx, zi in enumerate(indices):
                ax = axes[row_idx, col_idx]
                ax.imshow(t2w_crop[zi], cmap="gray", origin="lower", extent=[0, Wc, 0, Hc], aspect="equal")
                ax.contour(x_1d, y_1d, mask_crop[zi], levels=[0.5], colors="lime", linewidths=1.5)
                ax.set_title(f"{cid} z={zi} [{source}]")
                ax.set_axis_off()
        plt.suptitle("All alignments: cropped slice + mask contour (same [y,x], no .T) — verify green outline on prostate")
        plt.tight_layout()
        plt.show()
else:
    print("Set MASKS_DIR and run paths cell to show alignments.")

In [None]:
# Before/after prostate crop (run after paths cell; set MASKS_DIR to use)
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

try:
    import SimpleITK as sitk
except ImportError:
    sitk = None

def _load_nii(path):
    """Load a NIfTI volume as a float32 array ordered [z, y, x] = (D, H, W).

    Uses SimpleITK when the module-level ``sitk`` is available (not None), otherwise nibabel.
    """
    if sitk:
        return np.asarray(sitk.GetArrayFromImage(sitk.ReadImage(str(path))), dtype=np.float32)
    import nibabel as nib
    # nibabel indexes voxel data as [x, y, z]; reverse all axes to match SimpleITK's [z, y, x].
    # (A (2, 0, 1) transpose gives [z, x, y], swapping H and W relative to the sitk path.)
    a = nib.load(str(path)).get_fdata()
    return np.transpose(a, (2, 1, 0)).astype(np.float32)

def resample_mask_to_ref(mask_path, ref_sitk):
    """Resample the mask stored at *mask_path* onto the voxel grid of *ref_sitk*.

    Uses an identity transform and nearest-neighbour interpolation, with 0 outside the
    mask's extent, and returns the result as a numpy array (SimpleITK [z, y, x] order).
    """
    identity = sitk.Transform(3, sitk.sitkIdentity)
    on_ref_grid = sitk.Resample(
        sitk.ReadImage(str(mask_path)),
        ref_sitk,
        identity,
        sitk.sitkNearestNeighbor,
        0.0,
    )
    return np.asarray(sitk.GetArrayFromImage(on_ref_grid))

def bbox_3d(mask, margin=0):
    """Return the margin-padded bounding box of the positive voxels of a 3D array.

    The result is ((zmin, zmax), (ymin, ymax), (xmin, xmax)) with half-open bounds clipped
    to the array extent, or None when *mask* has no voxel > 0.
    """
    hits = np.where(mask > 0)
    if len(hits[0]) == 0:
        return None

    def _span(axis):
        lo = hits[axis].min() - margin
        hi = hits[axis].max() + 1 + margin
        return max(0, lo), min(mask.shape[axis], hi)

    return _span(0), _span(1), _span(2)

masks_dir = Path(MASKS_DIR) if MASKS_DIR else None
if sitk is None:
    # The mask has to be resampled onto the image grid, which needs SimpleITK.
    print("SimpleITK is required to visualize the crop (pip install SimpleITK).")
elif masks_dir is None or not masks_dir.exists():
    print("Set MASKS_DIR to the prostate whole-gland masks folder (e.g. picai_labels/.../Bosma22b) to visualize crop.")
elif df is None or df.empty:
    print("Run the paths cell first so df is built (case list and labels).")
else:
    case_id, fold = None, None
    for _, row in df.head(30).iterrows():
        cid, f = str(row["image_id"]), int(row["fold"])
        root = fold_to_root[f]
        tr = root / f"nnUNet_raw_data_fold{f}" / f"Task2201_picai_fold{f}" / "imagesTr"
        mask_path = masks_dir / f"{cid}.nii.gz"
        if (tr / f"{cid}_0000.nii.gz").exists() and mask_path.exists():
            case_id, fold = cid, f
            break
    if case_id is None:
        print("No case found with both preprocessed data and mask.")
    else:
        root = fold_to_root[fold]
        tr = root / f"nnUNet_raw_data_fold{fold}" / f"Task2201_picai_fold{fold}" / "imagesTr"
        ref = sitk.ReadImage(str(tr / f"{case_id}_0000.nii.gz"))
        # Reuse the loaded reference image instead of reading the T2W file a second time.
        t2w_full = np.asarray(sitk.GetArrayFromImage(ref), dtype=np.float32)
        mask_arr = resample_mask_to_ref(masks_dir / f"{case_id}.nii.gz", ref)
        box = bbox_3d(mask_arr, margin=CROP_MARGIN)
        if box is None:
            print("Empty mask; cannot show crop.")
        else:
            (zmin, zmax), (ymin, ymax), (xmin, xmax) = box
            t2w_crop = t2w_full[zmin:zmax, ymin:ymax, xmin:xmax]
            D, H, W = t2w_full.shape
            Dc, Hc, Wc = t2w_crop.shape
            # Show 3 central slices: before (full + bbox + mask contour), after (crop), mask
            indices_full = [max(0, D // 2 - 1), D // 2, min(D - 1, D // 2 + 1)]
            x_1d = np.arange(W)
            y_1d = np.arange(H)
            fig, axes = plt.subplots(3, 3, figsize=(12, 12))
            for col, zi in enumerate(indices_full):
                sl_full = t2w_full[zi]
                mask_sl = mask_arr[zi]
                ax_before = axes[0, col]
                # Image and contour: sl_full, mask_sl are (H, W) with [y,x]. Extent [0,W,0,H] = (x,y); no .T so they align.
                ax_before.imshow(sl_full, cmap="gray", origin="lower", extent=[0, W, 0, H], aspect="equal")
                ax_before.contour(x_1d, y_1d, mask_sl, levels=[0.5], colors="lime", linewidths=1.5)
                rect = plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, edgecolor="cyan", linewidth=2)
                ax_before.add_patch(rect)
                ax_before.set_title(f"Before (slice {zi})")
                ax_before.set_axis_off()
                z_crop = zi - zmin
                in_crop = 0 <= z_crop < Dc
                if in_crop:
                    axes[1, col].imshow(t2w_crop[z_crop], cmap="gray", origin="lower", extent=[0, Wc, 0, Hc], aspect="equal")
                axes[1, col].set_title(f"After crop (slice {zi})" if in_crop else "After (n/a)")
                axes[1, col].set_axis_off()
                axes[2, col].imshow(mask_sl, cmap="Greens", origin="lower", extent=[0, W, 0, H], aspect="equal", vmin=0, vmax=1)
                axes[2, col].set_title(f"Mask (slice {zi})")
                axes[2, col].set_axis_off()
            # Row labels: set_ylabel is hidden by set_axis_off(), so draw them as axes-relative text.
            for r, label in enumerate(("Before (full)", "After (ROI)", "Mask")):
                axes[r, 0].text(-0.06, 0.5, label, transform=axes[r, 0].transAxes,
                                rotation=90, ha="right", va="center", fontsize=11)
            plt.suptitle(f"Case {case_id} — prostate crop (box: z=[{zmin}:{zmax}] y=[{ymin}:{ymax}] x=[{xmin}:{xmax}])")
            plt.tight_layout()
            plt.show()

In [None]:
# Write the extractor script to disk so the next cell can run it (skip this cell if you already uploaded offline_feature_extraction_picai.py)
_SCRIPT = r'''
"""Offline FPN feature extraction for PI-CAI (backbone + FPN top-down/lateral)."""
import argparse, sys
from pathlib import Path
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
    import SimpleITK as sitk
    HAS_SITK = True
except ImportError:
    HAS_SITK = False

class FeaturePyramidNetwork(nn.Module):
    def __init__(self, in_channels_list, out_channels, top_down_pathway=True, upsample_method="nearest"):
        super().__init__()
        self.top_down_pathway = top_down_pathway
        self.upsample_method = upsample_method
        self.inner_blocks = nn.ModuleDict({
            f"inner_block_{idx}": nn.Conv2d(in_ch, out_channels, kernel_size=1, bias=True)
            for idx, in_ch in enumerate(in_channels_list)
        })
        self.layer_blocks = nn.ModuleDict({
            f"layer_block_{idx}": nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=True)
            for idx in range(len(in_channels_list))
        })
    def forward(self, selected_fmaps):
        last_inner = self.inner_blocks["inner_block_1"](selected_fmaps[-1])
        results = [self.layer_blocks["layer_block_1"](last_inner)]
        if self.top_down_pathway:
            inner_lateral = self.inner_blocks["inner_block_0"](selected_fmaps[0])
            feat_shape = inner_lateral.shape[-2:]
            inner_top_down = F.interpolate(last_inner, size=feat_shape, mode=self.upsample_method)
            last_inner = inner_lateral + inner_top_down
            results.insert(0, self.layer_blocks["layer_block_0"](last_inner))
        else:
            inner_lateral = self.inner_blocks["inner_block_0"](selected_fmaps[0])
            results.insert(0, self.layer_blocks["layer_block_0"](inner_lateral))
        results.append(F.max_pool2d(results[-1], kernel_size=1, stride=4, padding=0))
        return OrderedDict([(f"feat_{i}", fmap) for i, fmap in enumerate(results)])

def get_backbone_and_fpn(out_dim=256):
    from torchvision.models import resnet18
    resnet = resnet18(weights=None)
    class Backbone(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1, self.bn1, self.relu = resnet.conv1, resnet.bn1, resnet.relu
            self.maxpool, self.layer1, self.layer2 = resnet.maxpool, resnet.layer1, resnet.layer2
            self.layer3, self.layer4 = resnet.layer3, resnet.layer4
        def forward(self, x):
            x = self.relu(self.bn1(self.conv1(x)))
            x = self.maxpool(x)
            x = self.layer1(x)
            x = self.layer2(x)
            f1 = self.layer3(x)
            f2 = self.layer4(f1)
            return [f1, f2]
    backbone = Backbone()
    fpn = FeaturePyramidNetwork(in_channels_list=[256, 512], out_channels=out_dim, top_down_pathway=True)
    class Encoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.backbone = backbone
            self.fpn = fpn
        def forward(self, x):
            fmaps = self.backbone(x)
            refined = self.fpn(fmaps)
            return refined["feat_0"], refined["feat_1"]
    return Encoder()

def _load(path):
    if HAS_SITK:
        return np.asarray(sitk.GetArrayFromImage(sitk.ReadImage(str(path))), dtype=np.float32)
    import nibabel as nib
    a = nib.load(str(path)).get_fdata()
    return np.transpose(a, (2,0,1)).astype(np.float32)

def load_preprocessed(root, case_id, fold):
    tr = root / f"nnUNet_raw_data_fold{fold}" / f"Task2201_picai_fold{fold}" / "imagesTr"
    return np.stack([_load(tr/f"{case_id}_0000.nii.gz"), _load(tr/f"{case_id}_0001.nii.gz"), _load(tr/f"{case_id}_0002.nii.gz")], axis=0)

def resample_mask_to_ref(mask_path, ref_sitk):
    mask = sitk.ReadImage(str(mask_path))
    res = sitk.Resample(mask, ref_sitk, sitk.Transform(3, sitk.sitkIdentity), sitk.sitkNearestNeighbor, 0.0)
    return np.asarray(sitk.GetArrayFromImage(res))

def bbox_3d(mask, margin=0):
    inds = np.where(mask > 0)
    if len(inds[0]) == 0: return None
    zmin, zmax = inds[0].min() - margin, inds[0].max() + 1 + margin
    ymin, ymax = inds[1].min() - margin, inds[1].max() + 1 + margin
    xmin, xmax = inds[2].min() - margin, inds[2].max() + 1 + margin
    zmin, ymin, xmin = max(0, zmin), max(0, ymin), max(0, xmin)
    zmax = min(mask.shape[0], zmax)
    ymax = min(mask.shape[1], ymax)
    xmax = min(mask.shape[2], xmax)
    return (zmin, zmax), (ymin, ymax), (xmin, xmax)

def crop_vol_to_roi(vol, masks_dir, case_id, ref_path, margin=2):
    masks_dir = Path(masks_dir)
    mask_path = masks_dir / f"{case_id}.nii.gz"
    if not mask_path.exists():
        return vol, None
    ref = sitk.ReadImage(str(ref_path))
    mask_arr = resample_mask_to_ref(mask_path, ref)
    box = bbox_3d(mask_arr, margin=margin)
    if box is None: return vol, None
    (zmin, zmax), (ymin, ymax), (xmin, xmax) = box
    vol = vol[:, zmin:zmax, ymin:ymax, xmin:xmax].copy()
    return vol, np.array([zmin, zmax, ymin, ymax, xmin, xmax], dtype=np.int32)

def extract(vol, patch_size=224):
    """Extract one 224x224 patch per slice: each slice is resized/padded to patch_size x patch_size."""
    c, d, h, w = vol.shape
    out, coords = [], []
    for z in range(d):
        sl = vol[:, z, :, :]
        if sl.shape[1] < patch_size or sl.shape[2] < patch_size:
            sl = np.pad(sl, ((0,0),(0,max(0,patch_size-sl.shape[1])),(0,max(0,patch_size-sl.shape[2]))), mode="constant", constant_values=0)
        if sl.shape[1] > patch_size or sl.shape[2] > patch_size:
            t = torch.nn.functional.interpolate(torch.from_numpy(sl).unsqueeze(0).float(), (patch_size, patch_size), mode="bilinear", align_corners=False)
            sl = t.squeeze(0).numpy()
        out.append(sl)
        coords.append([z, 0, 0])
    return np.stack(out).astype(np.float32), np.array(coords, dtype=np.float32)

def main(argv=None):
    argv = argv or sys.argv[1:]
    p = argparse.ArgumentParser()
    p.add_argument("--source", default="preprocessed")
    p.add_argument("--preprocessed_root", default=None)
    p.add_argument("--labels_csv", default=None)
    p.add_argument("--feat_dir", default="picai_extracted_features")
    p.add_argument("--patch_size", type=int, default=224)
    p.add_argument("--feat_dim", type=int, default=256)
    p.add_argument("--device", default="cuda")
    p.add_argument("--batch_slices", type=int, default=32)
    p.add_argument("--stride", type=int, default=224, help="Stride for patch grid (224=no overlap, 112=50%% overlap)")
    p.add_argument("--masks_dir", default=None, help="Prostate whole-gland masks (picai_labels/Bosma22b); crop to ROI before extraction")
    p.add_argument("--crop_margin", type=int, default=2)
    args = p.parse_args(argv)
    if not args.labels_csv or not args.preprocessed_root:
        raise ValueError("Need --labels_csv and --preprocessed_root")
    if not HAS_SITK:
        raise ImportError("pip install SimpleITK")
    import pandas as pd
    import h5py
    df = pd.read_csv(args.labels_csv)
    if "image_id" not in df.columns:
        df["image_id"] = df["patient_id"]
    root_col = "preprocessed_root" if "preprocessed_root" in df.columns else None
    feat_dir = Path(args.feat_dir)
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    encoder = get_backbone_and_fpn(args.feat_dim).to(device).eval()
    for _, row in df.iterrows():
        case_id, fold = str(row["image_id"]), int(row["fold"])
        root = Path(row[root_col]) if root_col else Path(args.preprocessed_root)
        tr = root / f"nnUNet_raw_data_fold{fold}" / f"Task2201_picai_fold{fold}" / "imagesTr"
        if not (tr / f"{case_id}_0000.nii.gz").exists():
            print("Skip", case_id)
            continue
        out_bag = feat_dir / "multi_scale" / case_id / case_id
        out_bag.mkdir(parents=True, exist_ok=True)
        if (out_bag / "C4_patch_features.pt").exists():
            print("Exists", case_id)
            continue
        try:
            vol = load_preprocessed(root, case_id, fold)
        except Exception as e:
            print("Load failed", case_id, e)
            continue
        roi_bbox = None
        if args.masks_dir:
            ref_path = tr / f"{case_id}_0000.nii.gz"
            vol, roi_bbox = crop_vol_to_roi(vol, args.masks_dir, case_id, ref_path, margin=args.crop_margin)
            if roi_bbox is not None:
                print("  Cropped to ROI", vol.shape)
        X, coords = extract(vol, args.patch_size)
        X = (X - X.mean()) / (X.std() + 1e-5)
        X = torch.from_numpy(X).float().to(device)
        c4l, c5l = [], []
        for i in range(0, len(X), args.batch_slices):
            with torch.no_grad():
                c4, c5 = encoder(X[i:i+args.batch_slices])
            c4l.append(c4.cpu())
            c5l.append(c5.cpu())
        C4, C5 = torch.cat(c4l), torch.cat(c5l)
        torch.save(C4, out_bag / "C4_patch_features.pt")
        torch.save(C5, out_bag / "C5_patch_features.pt")
        with h5py.File(out_bag / "info_patches.h5", "w") as f:
            f.create_dataset("coords", data=coords)
            f.attrs["patch_size"] = args.patch_size
            f.attrs["extract_mode"] = "one_patch_per_slice"
            f.attrs["img_height"], f.attrs["img_width"] = vol.shape[2], vol.shape[3]
            if roi_bbox is not None:
                f.create_dataset("roi_bbox", data=roi_bbox)
        print("Saved", case_id, C4.shape[0], "slices")
    print("Done.")

if __name__ == "__main__":
    main()
'''
_script_path = Path("/kaggle/working/offline_feature_extraction_picai.py")
# Write UTF-8 explicitly (the encoding Python assumes for source files) instead of the
# locale default, and end the file with a newline like any text file.
_script_path.write_text(_SCRIPT.strip() + "\n", encoding="utf-8")
print(f"Script written to {_script_path}")

In [None]:
# Run offline feature extraction
# Input: CROPPED_ROOT when "Batch crop all cases" produced every fold; otherwise the per-row
# preprocessed_root from the CSV (all 5 folds), cropped on the fly when MASKS_DIR is set.
# Output: one patch per slice (T2W+ADC+HBV), one bag per case → multi_scale/<case_id>/<case_id>/ C4, C5, info_patches.h5
import sys
import subprocess
import pandas as pd
from pathlib import Path

cropped_root = Path(CROPPED_ROOT)
cropped_folds = [f for f in FOLDS if (cropped_root / f"nnUNet_raw_data_fold{f}").exists()]
# Only switch to the cropped data when EVERY fold is present. The run CSV below points all
# rows at CROPPED_ROOT, so a partial crop would make the extractor "Skip" the missing folds.
use_cropped = bool(FOLDS) and len(cropped_folds) == len(FOLDS)
if cropped_folds and not use_cropped:
    print("CROPPED_ROOT only has folds", cropped_folds, "of", FOLDS, "- ignoring it and cropping on the fly.")
preprocessed_root = CROPPED_ROOT if use_cropped else str(fold_to_root[FOLDS[0]])
labels_csv = LABELS_CSV_PATH
if use_cropped:
    # Script must read all cases from CROPPED_ROOT; CSV has per-row roots, so override.
    df_run = pd.read_csv(LABELS_CSV_PATH)
    df_run["preprocessed_root"] = CROPPED_ROOT
    labels_csv = "/kaggle/working/picai_labels_run.csv"
    df_run.to_csv(labels_csv, index=False)
    print("Using single root for all folds (cropped):", CROPPED_ROOT)
# --stride is not passed: the extractor always produces exactly one patch per slice.
cmd = [
    sys.executable, "/kaggle/working/offline_feature_extraction_picai.py",
    "--source", "preprocessed",
    "--preprocessed_root", preprocessed_root,
    "--labels_csv", labels_csv,
    "--feat_dir", FEAT_DIR,
    "--device", "cuda",
]
if not use_cropped and MASKS_DIR:
    cmd += ["--masks_dir", str(MASKS_DIR), "--crop_margin", str(CROP_MARGIN)]
if not use_cropped:
    print("Using per-row preprocessed_root from CSV (folds", FOLDS, "). Extractor will use root for each case from CSV.")
subprocess.run(cmd, check=True)
print("Features saved under", FEAT_DIR)