# PI-CAI prostate ROI cropping

Crop preprocessed T2W/ADC/HBV volumes to the **prostate bounding box** using whole-gland masks from [picai_labels](https://github.com/DIAGNijmegen/picai_labels) (`anatomical_delineations/whole_gland/AI/Bosma22b/`).

**Prerequisites:** Run `kaggle_picai_preprocessing.ipynb` first so you have `nnUNet_raw_data_fold0/1/2` with `imagesTr/*_0000.nii.gz` (T2W), `_0001` (ADC), `_0002` (HBV).

**Run order:** 1) Install deps (if needed) 2) Paths 3) (Optional) Get masks 4) Crop one case + visualize 5) Batch process all cases.

In [None]:
# Install if not present (Kaggle has SimpleITK; local may need it)
!pip install -q SimpleITK matplotlib

In [None]:
from pathlib import Path
import numpy as np

# --- Configuration: adjust these paths for Kaggle or for a local run ---

# Output of picai_prep: the directory holding nnUNet_raw_data_fold0, fold1 and fold2.
PREPROCESSED_ROOT = Path("/kaggle/input/notebooks/sananiroomand/pi-cai-preprocess")

# Whole-gland masks — see the "Where to get the masks" section below.
# An attached Kaggle dataset is preferred; otherwise fall back to a local
# picai_labels checkout (the next cell can clone it from GitHub).
_BOSMA22B_SUBDIR = Path("anatomical_delineations") / "whole_gland" / "AI" / "Bosma22b"
_kaggle_masks_dir = Path("/kaggle/input/picai_labels") / _BOSMA22B_SUBDIR
_local_masks_dir = Path("./picai_labels") / _BOSMA22B_SUBDIR
MASKS_DIR = _kaggle_masks_dir if _kaggle_masks_dir.exists() else _local_masks_dir

# Destination for the ROI crops, laid out like imagesTr:
# fold0/<case_id>_0000.nii.gz, _0001, _0002.
OUTPUT_ROI_ROOT = Path("/kaggle/working/picai_roi_crops")  # or Path("./picai_roi_crops")

# Cross-validation folds to process and the nnU-Net task folder prefix.
FOLDS = [0, 1, 2]
TASK_PREFIX = "Task2201_picai_fold"

OUTPUT_ROI_ROOT.mkdir(parents=True, exist_ok=True)
for _label, _location in (
    ("Preprocessed:", PREPROCESSED_ROOT),
    ("Masks:", MASKS_DIR),
    ("Output ROI:", OUTPUT_ROI_ROOT.resolve()),
):
    print(_label, _location)

## Where to get the masks

The **prostate whole-gland masks** are **not** in the main PI-CAI image dataset. They come from the separate **picai_labels** repo (official PI-CAI annotations):

| Source | What you need |
|--------|----------------|
| **GitHub** | [github.com/DIAGNijmegen/picai_labels](https://github.com/DIAGNijmegen/picai_labels) → folder `anatomical_delineations/whole_gland/AI/Bosma22b/` (one `.nii.gz` per case, e.g. `10000_1000000.nii.gz`) |

**Easiest:** Run the cell below. It clones the repo into your notebook environment and sets `MASKS_DIR` to the Bosma22b folder. You don't need to add any extra Kaggle dataset.

**Alternative:** If you already have picai_labels as a Kaggle dataset (e.g. you uploaded it), add it as an input to this notebook and set `MASKS_DIR` in the Paths cell to the path inside that dataset (e.g. `/kaggle/input/your-dataset-name/anatomical_delineations/whole_gland/AI/Bosma22b`).

In [None]:
# Make sure the whole-gland masks are available. If MASKS_DIR does not exist,
# shallow-clone picai_labels next to the notebook and point MASKS_DIR at the
# Bosma22b folder inside the clone; otherwise just report how many masks exist.
if not MASKS_DIR.exists():
    import subprocess

    clone_dir = Path("./picai_labels")
    # Skip the clone when a previous run already fetched the repository.
    if not (clone_dir / "anatomical_delineations").exists():
        # check=True raises CalledProcessError if git fails (e.g. no network).
        subprocess.run(
            ["git", "clone", "--depth", "1", "https://github.com/DIAGNijmegen/picai_labels.git", str(clone_dir)],
            check=True,
        )
    MASKS_DIR = clone_dir / "anatomical_delineations" / "whole_gland" / "AI" / "Bosma22b"
    print("Using masks from:", MASKS_DIR)
else:
    # MASKS_DIR is known to exist in this branch, so count the masks directly.
    n_masks = sum(1 for _ in MASKS_DIR.glob("*.nii.gz"))
    print(f"Found {n_masks} mask files in", MASKS_DIR)

In [None]:
import SimpleITK as sitk


def get_images_tr_path(preprocessed_root: Path, fold: int):
    """Return the ``imagesTr`` directory of one fold.

    The path is built as
    ``<preprocessed_root>/nnUNet_raw_data_fold<fold>/<TASK_PREFIX><fold>/imagesTr``;
    its existence is not checked here.
    """
    fold_root = preprocessed_root / f"nnUNet_raw_data_fold{fold}"
    task_dir = fold_root / f"{TASK_PREFIX}{fold}"
    return task_dir / "imagesTr"


def load_case_volumes(images_tr: Path, case_id: str):
    """Read the three modality volumes of one case.

    Files are looked up as ``<images_tr>/<case_id>_0000.nii.gz`` (T2W),
    ``_0001`` (ADC) and ``_0002`` (HBV).

    Returns:
        ``(t2w, adc, hbv, ref)`` where the first three are numpy arrays and
        ``ref`` is the T2W SimpleITK image, kept so callers can reuse its
        spacing/origin/direction.
    """

    def _channel_file(channel: str) -> str:
        return str(images_tr / f"{case_id}_{channel}.nii.gz")

    # The T2W image doubles as the geometric reference.
    ref = sitk.ReadImage(_channel_file("0000"))
    t2w = sitk.GetArrayFromImage(ref)
    adc, hbv = (
        sitk.GetArrayFromImage(sitk.ReadImage(_channel_file(channel)))
        for channel in ("0001", "0002")
    )
    return t2w, adc, hbv, ref


def resample_mask_to_reference(mask_path: Path, ref_image: sitk.Image):
    """Read a mask and resample it onto the grid of ``ref_image``.

    An identity transform with nearest-neighbour interpolation is used, so
    label values are carried over without blending; voxels falling outside
    the mask's extent are filled with 0.

    Returns:
        The resampled mask as a numpy array.
    """
    mask_image = sitk.ReadImage(str(mask_path))
    identity = sitk.Transform(3, sitk.sitkIdentity)
    outside_value = 0.0
    mask_on_ref_grid = sitk.Resample(
        mask_image,
        ref_image,
        identity,
        sitk.sitkNearestNeighbor,
        outside_value,
    )
    return sitk.GetArrayFromImage(mask_on_ref_grid)


def bbox_3d(mask: np.ndarray, margin: int = 0):
    """Return the bounding box of the positive voxels of a 3D mask.

    Args:
        mask: 3D array; voxels with value > 0 are considered foreground.
        margin: Number of voxels added on every side of the tight box.

    Returns:
        ``((zmin, zmax), (ymin, ymax), (xmin, xmax))`` in array indices, with
        each max exclusive and every bound clamped to the array extent, or
        ``None`` when the mask has no positive voxel.
    """
    foreground = np.nonzero(mask > 0)
    if foreground[0].size == 0:
        return None

    extents = []
    for axis in range(3):
        coords = foreground[axis]
        # Expand by the margin, then clamp to the valid index range of this axis.
        start = max(0, coords.min() - margin)
        stop = min(mask.shape[axis], coords.max() + 1 + margin)
        extents.append((start, stop))
    return tuple(extents)


def crop_volumes_to_roi(t2w, adc, hbv, mask_resampled, margin: int = 2):
    """Cut the same mask-derived bounding box out of T2W, ADC and HBV.

    Args:
        t2w, adc, hbv: Volumes indexed (z, y, x) on the same grid as the mask.
        mask_resampled: Mask array used to compute the bounding box.
        margin: Extra voxels kept around the tight box on every side.

    Returns:
        ``((t2w_c, adc_c, hbv_c), box)`` with ``box`` as returned by
        ``bbox_3d``, or ``(None, None)`` when the mask has no positive voxel.
    """
    roi = bbox_3d(mask_resampled, margin=margin)
    if roi is None:
        return None, None

    # One slice per axis, applied identically to every modality.
    window = tuple(slice(start, stop) for start, stop in roi)
    crops = tuple(volume[window] for volume in (t2w, adc, hbv))
    return crops, roi


def save_cropped_case(out_dir: Path, case_id: str, t2w_c, adc_c, hbv_c, ref_image: sitk.Image, box):
    """Save cropped volumes as NIfTI (flat: out_dir/case_id_0000.nii.gz, etc.).

    Args:
        out_dir: Destination directory; created if missing.
        case_id: File name stem; ``_0000`` (T2W), ``_0001`` (ADC) and
            ``_0002`` (HBV) are appended.
        t2w_c, adc_c, hbv_c: Cropped arrays indexed (z, y, x).
        ref_image: Uncropped reference image supplying spacing and direction.
        box: ``((zmin, zmax), (ymin, ymax), (xmin, xmax))`` crop bounds in
            array indices of the reference volume.

    Returns:
        ``out_dir``.
    """
    # Only the lower corner of the box is needed to place the crop in space.
    (zmin, _), (ymin, _), (xmin, _) = box
    out_dir.mkdir(parents=True, exist_ok=True)

    spacing = ref_image.GetSpacing()
    direction = ref_image.GetDirection()
    # Physical position of the crop's first voxel. TransformIndexToPhysicalPoint
    # applies the direction matrix as well as the spacing, so the crop stays
    # aligned for oblique acquisitions; with an identity direction it equals
    # origin + index * spacing. SimpleITK indices are (x, y, z) while the
    # arrays are (z, y, x). int() converts numpy integers to plain Python ints.
    origin = ref_image.TransformIndexToPhysicalPoint((int(xmin), int(ymin), int(zmin)))

    for suffix, volume in (("_0000", t2w_c), ("_0001", adc_c), ("_0002", hbv_c)):
        img = sitk.GetImageFromArray(volume)
        img.SetSpacing(spacing)
        img.SetOrigin(origin)
        img.SetDirection(direction)
        sitk.WriteImage(img, str(out_dir / f"{case_id}{suffix}.nii.gz"))
    return out_dir

## Crop one case and visualize

Pick a case that exists in the preprocessed data and has a mask; crop it and show three axial slices through the crop.

In [None]:
# Find the first case (in sorted order, so the choice is reproducible) that has
# both preprocessed data and a mask.
case_id = None
fold_used = None
# Defined up front so later cells can test them even when nothing was cropped.
cropped = None
box = None
for fold in FOLDS:
    images_tr = get_images_tr_path(PREPROCESSED_ROOT, fold)
    if not images_tr.exists():
        continue
    for nii in sorted(images_tr.glob("*_0000.nii.gz")):
        cid = nii.name.replace("_0000.nii.gz", "")
        mask_path = MASKS_DIR / f"{cid}.nii.gz"
        if mask_path.exists():
            case_id = cid
            fold_used = fold
            break
    if case_id is not None:
        break

if case_id is None:
    print("No case found with both preprocessed data and mask. Check PREPROCESSED_ROOT and MASKS_DIR.")
else:
    print(f"Example case: {case_id} (fold {fold_used})")
    images_tr = get_images_tr_path(PREPROCESSED_ROOT, fold_used)
    t2w, adc, hbv, ref = load_case_volumes(images_tr, case_id)
    mask_resampled = resample_mask_to_reference(MASKS_DIR / f"{case_id}.nii.gz", ref)
    cropped, box = crop_volumes_to_roi(t2w, adc, hbv, mask_resampled)
    if cropped is not None:
        t2w_c, adc_c, hbv_c = cropped
        out_fold_dir = OUTPUT_ROI_ROOT / f"fold{fold_used}"
        save_cropped_case(out_fold_dir, case_id, t2w_c, adc_c, hbv_c, ref, box)
        # Files are written flat into the fold directory, not into a per-case folder.
        print(f"Saved {case_id}_0000/_0001/_0002.nii.gz to", out_fold_dir)
        # Slice index for visualization (middle of crop)
        z_slice = t2w_c.shape[0] // 2
    else:
        print(f"Mask for {case_id} is empty after resampling; nothing was cropped.")

In [None]:
# Show three axial slices (rows) of the example crop for T2W, ADC and HBV (columns).
if case_id is not None and cropped is not None:
    import matplotlib.pyplot as plt

    nz = t2w_c.shape[0]
    slices = [nz // 4, nz // 2, 3 * nz // 4]  # 3 axial slices through the crop
    fig, axes = plt.subplots(3, 3, figsize=(10, 10))
    for row, z in enumerate(slices):
        axes[row, 0].imshow(t2w_c[z], cmap="gray")
        axes[row, 0].set_ylabel(f"z={z}", fontsize=9)
        axes[row, 0].set_title("T2W")
        axes[row, 1].imshow(adc_c[z], cmap="gray")
        axes[row, 1].set_title("ADC")
        axes[row, 2].imshow(hbv_c[z], cmap="gray")
        axes[row, 2].set_title("HBV")
    for ax in axes.flat:
        # Hide ticks and the frame only: set_axis_off() would also hide the
        # "z=..." row labels set above.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    plt.suptitle(f"Example crop — {case_id} (ROI: {t2w_c.shape[0]}×{t2w_c.shape[1]}×{t2w_c.shape[2]} slices)")
    plt.tight_layout()
    plt.show()
else:
    print("Nothing to show (no case cropped).")

## Batch process all cases with masks

Iterate over all preprocessed cases (folds 0–2); for each case that has a Bosma22b mask, crop and save the three volumes flat as `OUTPUT_ROI_ROOT/fold{N}/<case_id>_0000.nii.gz` (T2W), `_0001` (ADC) and `_0002` (HBV).

In [None]:
# Batch-crop every preprocessed case that has a mask, fold by fold.
# Counters / log for the summary printed at the end.
done = 0
skipped_no_mask = 0
errors = []  # (case_id, message) pairs for cases that could not be cropped
for fold in FOLDS:
    images_tr = get_images_tr_path(PREPROCESSED_ROOT, fold)
    if not images_tr.exists():
        print(f"Fold {fold}: imagesTr not found, skip")
        continue
    out_fold_dir = OUTPUT_ROI_ROOT / f"fold{fold}"
    out_fold_dir.mkdir(parents=True, exist_ok=True)
    # Cases are discovered through their T2W file; sorted for a stable order.
    for nii in sorted(images_tr.glob("*_0000.nii.gz")):
        case_id = nii.name.replace("_0000.nii.gz", "")
        mask_path = MASKS_DIR / f"{case_id}.nii.gz"
        if not mask_path.exists():
            skipped_no_mask += 1
            continue
        # A failure in one case is recorded and must not abort the whole batch.
        try:
            t2w, adc, hbv, ref = load_case_volumes(images_tr, case_id)
            mask_resampled = resample_mask_to_reference(mask_path, ref)
            cropped, box = crop_volumes_to_roi(t2w, adc, hbv, mask_resampled)
            if cropped is None:
                # No positive voxel in the mask on the reference grid.
                errors.append((case_id, "empty mask after resample"))
                continue
            t2w_c, adc_c, hbv_c = cropped
            save_cropped_case(out_fold_dir, case_id, t2w_c, adc_c, hbv_c, ref, box)
            done += 1
        except Exception as e:
            errors.append((case_id, str(e)))
    print(f"Fold {fold}: done")

# Summary; only the first 10 errors are listed to keep the output short.
print(f"\nCropped {done} cases. Skipped (no mask): {skipped_no_mask}. Errors: {len(errors)}")
if errors:
    for cid, msg in errors[:10]:
        print(f"  {cid}: {msg}")
    if len(errors) > 10:
        print(f"  ... and {len(errors) - 10} more")