In [1]:
!pip install -q monai


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m25.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m99.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m74.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m40.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
"""
MONAI 2D contrastive pipeline for ADNI T1 MRI slices with UNet mask reconstruction targets.

Features
- 3D preprocessing: RAS orientation, 1.0 mm spacing, intensity clipping by percentiles inside brain, z score in mask,
  optional skull strip heuristic when masks are missing, persistent logging of mean and std per case.
- 2D axial slice extraction with brain area threshold, fixed center crop, and midline column recording.
- Contrastive dataset: each slice yields two stochastic views with synchronized spatial transforms for image and mask
  and image only intensity jitter and noise. Includes custom hemispheric mirror transform around anatomical midline.
- Reproducibility: global seeds and per view seeds. Logging of transform params and counts.
- Validation and test pipelines with deterministic transforms.
- Visual QC utilities to inspect original vs two views with mask overlay.

Requirements
- Python 3.9+
- torch >= 2.0, monai >= 1.3, nibabel, scipy, numpy, matplotlib

This file is structured to be imported as a module or run as a script for quick sanity checks.
"""
from __future__ import annotations

import os
import json
import math
import random
import shutil
import logging
import pathlib
from dataclasses import dataclass, asdict
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as np
import nibabel as nib
import torch
from torch.utils.data import Dataset, DataLoader

from monai.transforms import (
    Compose,
    MapTransform,
    LoadImaged,
    EnsureChannelFirstd,
    Orientationd,
    Spacingd,
    EnsureTyped,
    RandAffined,
    RandFlipd,
    Rand2DElasticd,
    RandShiftIntensityd,
    RandScaleIntensityd,
    RandGaussianNoised,
)
from monai.transforms.transform import Randomizable
from monai.data import list_data_collate
from monai.utils import InterpolateMode, set_determinism

from scipy.ndimage import gaussian_filter, binary_closing, binary_opening

# ----------------------------------------
# Config
# ----------------------------------------

@dataclass
class PipelineConfig:
    # paths
    data_csv: Optional[str] = None  # optional CSV with columns: image_path, mask_path (optional), patient_id
    preproc_dir: str = "./preproc"
    logs_dir: str = "./logs"

    # preprocessing
    target_spacing: Tuple[float, float, float] = (1.0, 1.0, 1.0)
    percentile_clip: Tuple[float, float] = (0.5, 99.5)
    skull_strip_when_missing: bool = True

    # 2D slicing
    crop_size: Tuple[int, int] = (192, 192)
    min_brain_area_ratio: float = 0.01  # exclude slices with less than 1 percent foreground

    # augmentation probabilities and ranges
    rot_deg: float = 15.0
    rot_prob: float = 0.5
    hflip_prob: float = 0.5
    vflip_prob: float = 0.2
    elastic_prob: float = 0.3
    elastic_spacing_range: Tuple[int, int] = (16, 40)
    elastic_magnitude_range: Tuple[float, float] = (2.0, 5.0)
    jitter_prob: float = 0.5
    jitter_brightness: float = 0.10
    jitter_contrast_range: Tuple[float, float] = (0.9, 1.1)
    gamma_range: Tuple[float, float] = (0.9, 1.1)
    noise_prob: float = 0.3
    noise_std_range: Tuple[float, float] = (0.01, 0.05)
    mirror_prob: float = 0.3

    # loader
    batch_size: int = 8
    num_workers: int = 4
    pin_memory: bool = True

    # reproducibility
    seed: int = 42


# ----------------------------------------
# Logging and seeds
# ----------------------------------------

def setup_logging(log_dir: str) -> logging.Logger:
    os.makedirs(log_dir, exist_ok=True)
    logger = logging.getLogger("monai2d")
    logger.setLevel(logging.INFO)
    # clear previous handlers
    logger.handlers.clear()
    fh = logging.FileHandler(os.path.join(log_dir, "run.log"))
    ch = logging.StreamHandler()
    fmt = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s")
    fh.setFormatter(fmt)
    ch.setFormatter(fmt)
    logger.addHandler(fh)
    logger.addHandler(ch)
    return logger


def set_global_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    set_determinism(seed=seed)


# ----------------------------------------
# Utilities
# ----------------------------------------


def load_nifti(path: str) -> Tuple[np.ndarray, np.ndarray]:
    img = nib.load(path)
    data = img.get_fdata(dtype=np.float32)
    return data, img.affine


def save_nifti(path: str, data: np.ndarray, affine: np.ndarray) -> None:
    nib.save(nib.Nifti1Image(data.astype(np.float32), affine), path)


def compute_brain_mask_simple(vol: np.ndarray) -> np.ndarray:
    """Simple foreground heuristic mask on 3D volume (expects channel last).
    Steps: Gaussian smooth, global percentile threshold, morphological open/close.
    """
    v = gaussian_filter(vol, sigma=1.0)
    thr = np.percentile(v[v > np.percentile(v, 5)], 60)
    mask = v > thr
    mask = binary_opening(mask, iterations=2)
    mask = binary_closing(mask, iterations=2)
    return mask


def clip_and_normalize_in_mask(vol: np.ndarray, mask: np.ndarray, pmin: float, pmax: float) -> Tuple[np.ndarray, Dict[str, float]]:
    assert vol.shape == mask.shape
    vox = vol[mask > 0]
    lo = np.percentile(vox, pmin)
    hi = np.percentile(vox, pmax)
    v = np.clip(vol, lo, hi)
    mu = float(v[mask > 0].mean())
    sigma = float(v[mask > 0].std() + 1e-8)
    v = (v - mu) / sigma
    stats = {"pmin": float(lo), "pmax": float(hi), "mean": mu, "std": sigma}
    return v.astype(np.float32), stats


def get_bbox_2d(mask2d: np.ndarray) -> Tuple[int, int, int, int]:
    ys, xs = np.where(mask2d > 0)
    if len(xs) == 0:
        return 0, mask2d.shape[0], 0, mask2d.shape[1]
    return int(ys.min()), int(ys.max()) + 1, int(xs.min()), int(xs.max()) + 1


def center_crop_with_bbox(img: np.ndarray, mask: np.ndarray, size: Tuple[int, int]) -> Tuple[np.ndarray, np.ndarray, Tuple[int, int]]:
    h, w = img.shape
    th, tw = size
    y0, y1, x0, x1 = get_bbox_2d(mask)
    cy = (y0 + y1) // 2
    cx = (x0 + x1) // 2
    sy = max(0, cy - th // 2)
    sx = max(0, cx - tw // 2)
    sy = min(sy, h - th)
    sx = min(sx, w - tw)
    pad_y0 = max(0, -sy)
    pad_x0 = max(0, -sx)
    sy = max(0, sy)
    sx = max(0, sx)
    img_crop = img[sy: sy + th, sx: sx + tw]
    mask_crop = mask[sy: sy + th, sx: sx + tw]
    offset = (sy, sx)
    # zero pad if needed
    if img_crop.shape != (th, tw):
        out_img = np.zeros((th, tw), dtype=img.dtype)
        out_msk = np.zeros((th, tw), dtype=mask.dtype)
        out_img[: img_crop.shape[0], : img_crop.shape[1]] = img_crop
        out_msk[: mask_crop.shape[0], : mask_crop.shape[1]] = mask_crop
        img_crop, mask_crop = out_img, out_msk
    return img_crop, mask_crop, offset


# ----------------------------------------
# 3D preprocessing
# ----------------------------------------

def preprocess_subject(
    image_path: str,
    output_dir: str,
    mask_path: Optional[str] = None,
    cfg: PipelineConfig = PipelineConfig(),
    logger: Optional[logging.Logger] = None,
) -> Dict[str, Any]:
    os.makedirs(output_dir, exist_ok=True)
    pid = pathlib.Path(image_path).stem.split("_")[1] if "_" in pathlib.Path(image_path).stem else pathlib.Path(image_path).stem

    data_dict: Dict[str, Any] = {"image": image_path}
    if mask_path is not None and os.path.exists(mask_path):
        data_dict["mask"] = mask_path

    # Build transform keys and per-key interpolation modes dynamically (so it works when mask is missing)
    keys_pre = list(data_dict.keys())
    modes_pre = tuple(InterpolateMode.BILINEAR if k == "image" else InterpolateMode.NEAREST for k in keys_pre)

    t_pre = Compose([
        LoadImaged(keys=keys_pre, allow_missing_keys=True),
        EnsureChannelFirstd(keys=keys_pre, allow_missing_keys=True),
        Orientationd(keys=keys_pre, axcodes="RAS", allow_missing_keys=True),
        Spacingd(keys=keys_pre, pixdim=cfg.target_spacing, mode=modes_pre, allow_missing_keys=True),
        EnsureTyped(keys=keys_pre, allow_missing_keys=True),
    ])

    d = t_pre(data_dict)

    img_t = d["image"]  # MetaTensor/Tensor with channel-first [1, H, W, D]
    # Robust affine extraction across MONAI versions
    _aff = None
    try:
        if hasattr(img_t, "meta") and isinstance(getattr(img_t, "meta"), dict) and "affine" in img_t.meta:
            _aff = img_t.meta["affine"]
        elif hasattr(img_t, "affine") and img_t.affine is not None:
            _aff = img_t.affine
    except Exception:
        _aff = None
    if _aff is None and "image_meta_dict" in d and isinstance(d["image_meta_dict"], dict) and "affine" in d["image_meta_dict"]:
        _aff = d["image_meta_dict"]["affine"]
    if _aff is None:
        _aff = np.eye(4, dtype=np.float32)
    
    aff = _aff
    img = img_t[0].cpu().numpy()  # 3D

    if "mask" in d:
        msk = (d["mask"][0].cpu().numpy() > 0.5).astype(np.uint8)
    else:
        if cfg.skull_strip_when_missing:
            msk = compute_brain_mask_simple(img).astype(np.uint8)
        else:
            # fallback to nonzero
            msk = (img != 0).astype(np.uint8)

    # clip and z score inside mask
    img_n, stats = clip_and_normalize_in_mask(img, msk, cfg.percentile_clip[0], cfg.percentile_clip[1])

    # save
    case_dir = os.path.join(output_dir, pid)
    os.makedirs(case_dir, exist_ok=True)
    out_img = os.path.join(case_dir, "image_preproc.nii.gz")
    out_msk = os.path.join(case_dir, "mask_preproc.nii.gz")
    save_nifti(out_img, img_n, aff)
    save_nifti(out_msk, msk.astype(np.uint8), aff)

    meta = {
        "patient_id": pid,
        "image_path": out_img,
        "mask_path": out_msk,
        "affine": aff.tolist(),
        "stats": stats,
        "shape": list(img_n.shape),
    }

    if logger is not None:
        with open(os.path.join(os.path.dirname(out_img), "stats.json"), "w") as f:
            json.dump(meta, f, indent=2)
        logger.info(f"Preprocessed {pid} saved to {case_dir} with stats {stats}")

    return meta


# ----------------------------------------
# Slice indexing
# ----------------------------------------

def build_slice_index(preproc_cases: List[Dict[str, Any]], cfg: PipelineConfig, index_json: str) -> List[Dict[str, Any]]:
    index: List[Dict[str, Any]] = []
    for meta in preproc_cases:
        img3d, _ = load_nifti(meta["image_path"])
        msk3d, _ = load_nifti(meta["mask_path"])
        h, w, d = img3d.shape
        mid_col = w // 2
        for z in range(d):
            m2 = msk3d[:, :, z] > 0
            area_ratio = float(m2.mean())
            if area_ratio < cfg.min_brain_area_ratio:
                continue
            index.append({
                "patient_id": meta["patient_id"],
                "image_path": meta["image_path"],
                "mask_path": meta["mask_path"],
                "slice_index": int(z),
                "midline_col": int(mid_col),
                "affine": meta["affine"],
                "shape2d": [int(h), int(w)],
            })
    os.makedirs(os.path.dirname(index_json), exist_ok=True)
    with open(index_json, "w") as f:
        json.dump(index, f, indent=2)
    return index


# ----------------------------------------
# Custom transforms
# ----------------------------------------

class RandGammaIntensityd(MapTransform, Randomizable):
    """Random gamma correction for images in [C, H, W] assuming intensities are roughly standard normal.
    Applies image <- sign(image) * |image|**gamma with gamma in a range around 1.
    """

    def __init__(self, keys: Sequence[str], gamma_range: Tuple[float, float] = (0.9, 1.1), prob: float = 0.5) -> None:
        super().__init__(keys)
        self.gamma_range = gamma_range
        self.prob = prob
        self._do = False
        self._gamma = 1.0

    def randomize(self) -> None:
        self._do = self.R.random() < self.prob
        self._gamma = self.R.uniform(self.gamma_range[0], self.gamma_range[1])

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        d = dict(data)
        self.randomize()
        if not self._do:
            return d
        for k in self.keys:  # type: ignore
            x = d[k]
            # keep sign to be stable around 0 after z score
            d[k] = torch.sign(x) * torch.pow(torch.abs(x) + 1e-8, self._gamma)
        return d


class MirrorAroundMidline2Dd(MapTransform, Randomizable):
    """Reflect image and mask around anatomical midline column stored in data["midline_col"].
    This is distinct from a simple flip since it mirrors around mid column rather than image boundary.
    """

    def __init__(self, keys: Sequence[str], prob: float = 0.3) -> None:
        super().__init__(keys)
        self.prob = prob
        self._do = False

    def randomize(self) -> None:
        self._do = self.R.random() < self.prob

    @staticmethod
    def _mirror_np(arr: np.ndarray, mid_col: int) -> np.ndarray:
        # arr shape [H, W]
        h, w = arr.shape
        j = np.arange(w)
        j_m = np.clip(2 * mid_col - j, 0, w - 1)
        return arr[:, j_m]

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        d = dict(data)
        self.randomize()
        if not self._do:
            return d
        mid = int(d.get("midline_col", d["image"].shape[-1] // 2))
        for k in self.keys:  # type: ignore
            x = d[k]
            if isinstance(x, torch.Tensor):
                x2d = x[0].cpu().numpy()
                y2d = self._mirror_np(x2d, mid)
                if k == "mask":
                    # nearest like behavior by discrete copy
                    y2d = (y2d > 0.5).astype(np.float32)
                d[k] = torch.from_numpy(y2d[None, ...]).to(x.dtype)
            else:
                raise ValueError("Transform expects tensors with channel first")
        return d


class RandGaussianNoiseRanged(MapTransform, Randomizable):
    """Additive Gaussian noise with std sampled from a range per-call.
    Applies to image only; keeps mask untouched by choosing appropriate keys at call site.
    """

    def __init__(self, keys: Sequence[str], std_range: Tuple[float, float] = (0.01, 0.05), prob: float = 0.3, mean: float = 0.0) -> None:
        super().__init__(keys)
        self.std_range = std_range
        self.mean = mean
        self.prob = prob
        self._do = False
        self._std = float(np.mean(std_range))

    def randomize(self) -> None:
        self._do = self.R.random() < self.prob
        self._std = self.R.uniform(self.std_range[0], self.std_range[1])

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        d = dict(data)
        self.randomize()
        if not self._do:
            return d
        for k in self.keys:  # type: ignore
            x = d[k]
            noise = torch.randn_like(x) * self._std + self.mean
            d[k] = x + noise
        return d


# ----------------------------------------
# Transform builders
# ----------------------------------------

def make_spatial_aug(cfg: PipelineConfig) -> Compose:
    rad = math.radians(cfg.rot_deg)
    # dictionary transforms expect a tuple of modes aligned with keys order ["image", "mask"], not a dict
    mode_imgmask = (InterpolateMode.BILINEAR, InterpolateMode.NEAREST)
    return Compose([
        RandAffined(
            keys=["image", "mask"],
            prob=cfg.rot_prob,
            rotate_range=(-rad, rad),  # 2D in plane
            translate_range=None,
            scale_range=None,
            mode=mode_imgmask,
            padding_mode="zeros",
        ),
        Rand2DElasticd(
            keys=["image", "mask"],
            prob=cfg.elastic_prob,
            spacing=cfg.elastic_spacing_range,
            magnitude_range=cfg.elastic_magnitude_range,
            mode=mode_imgmask,
            padding_mode="zeros",
        ),
        RandFlipd(keys=["image", "mask"], spatial_axis=1, prob=cfg.hflip_prob),  # LR (W axis for 2D [H,W])
        RandFlipd(keys=["image", "mask"], spatial_axis=0, prob=cfg.vflip_prob),  # AP (H axis for 2D [H,W])
        MirrorAroundMidline2Dd(keys=["image", "mask"], prob=cfg.mirror_prob),
    ])


def make_intensity_aug(cfg: PipelineConfig) -> Compose:
    return Compose([
        RandShiftIntensityd(keys=["image"], offsets=cfg.jitter_brightness, prob=cfg.jitter_prob),
        RandScaleIntensityd(keys=["image"], factors=cfg.jitter_contrast_range, prob=cfg.jitter_prob),
        RandGammaIntensityd(keys=["image"], gamma_range=cfg.gamma_range, prob=cfg.jitter_prob),
        RandGaussianNoiseRanged(keys=["image"], std_range=cfg.noise_std_range, prob=cfg.noise_prob, mean=0.0),
    ])


# ----------------------------------------
# Dataset
# ----------------------------------------

class ContrastiveSliceDataset(Dataset):
    def __init__(
        self,
        slice_index: List[Dict[str, Any]],
        cfg: PipelineConfig,
        logger: Optional[logging.Logger] = None,
        deterministic: bool = False,
    ) -> None:
        self.idx = slice_index
        self.cfg = cfg
        self.logger = logger
        self.det = deterministic
        self.vol_cache: Dict[str, np.ndarray] = {}
        self.msk_cache: Dict[str, np.ndarray] = {}

        # deterministic transforms for base crop
        self.base_crop_size = cfg.crop_size

        # augmentation pipelines
        self.spatial_aug_a = make_spatial_aug(cfg)
        self.intensity_aug_a = make_intensity_aug(cfg)
        self.spatial_aug_b = make_spatial_aug(cfg)
        self.intensity_aug_b = make_intensity_aug(cfg)

        if self.det:
            # no random in validation or test
            self.spatial_aug_a = Compose([])
            self.intensity_aug_a = Compose([])
            self.spatial_aug_b = Compose([])
            self.intensity_aug_b = Compose([])

    def __len__(self) -> int:
        return len(self.idx)

    def _load_volume(self, path: str) -> np.ndarray:
        if path not in self.vol_cache:
            arr, _ = load_nifti(path)
            self.vol_cache[path] = arr
        return self.vol_cache[path]

    def _load_mask(self, path: str) -> np.ndarray:
        if path not in self.msk_cache:
            arr, _ = load_nifti(path)
            self.msk_cache[path] = (arr > 0.5).astype(np.uint8)
        return self.msk_cache[path]

    @staticmethod
    def _seed_transforms(t: Compose, seed: int) -> None:
        # propagate seed to all randomizable transforms
        for tr in t.transforms:
            if isinstance(tr, Randomizable):
                tr.set_random_state(seed=seed)

    def __getitem__(self, i: int) -> Dict[str, Any]:
        item = self.idx[i]
        img3d = self._load_volume(item["image_path"])  # H W D
        msk3d = self._load_mask(item["mask_path"])     # H W D
        z = int(item["slice_index"])
        img2d = img3d[:, :, z]
        msk2d = msk3d[:, :, z]

        # center crop on brain bbox with zero pad
        img2d, msk2d, offset = center_crop_with_bbox(img2d, msk2d, self.base_crop_size)
        # adjust midline column into cropped coordinates
        _, sx = offset
        midline_cropped = int(np.clip(int(item["midline_col"]) - sx, 0, self.base_crop_size[1] - 1))

        # prepare dict
        base = {
            "image": torch.from_numpy(img2d[None, ...].astype(np.float32)),
            "mask": torch.from_numpy(msk2d[None, ...].astype(np.float32)),
            "midline_col": int(midline_cropped),
        }

        # view A
        if not self.det:
            seed_a = np.random.randint(0, 2**31 - 1)
            self._seed_transforms(self.spatial_aug_a, seed_a)
            self._seed_transforms(self.intensity_aug_a, seed_a + 1337)
        da = self.spatial_aug_a(dict(base))
        da = self.intensity_aug_a(da)

        # view B
        if not self.det:
            seed_b = np.random.randint(0, 2**31 - 1)
            self._seed_transforms(self.spatial_aug_b, seed_b)
            self._seed_transforms(self.intensity_aug_b, seed_b + 7331)
        db = self.spatial_aug_b(dict(base))
        db = self.intensity_aug_b(db)

        # mask safety: keep binary
        da["mask"] = (da["mask"] > 0.5).float()
        db["mask"] = (db["mask"] > 0.5).float()

        # foreground ratio check
        fg_ratio = float(da["mask"].mean().item())
        if fg_ratio < 1e-4 and self.logger is not None:
            self.logger.warning(f"Very low foreground ratio in sample {i} pid={item['patient_id']} slice={z}")

        return {
            "image_a": da["image"],
            "mask_a": da["mask"],
            "image_b": db["image"],
            "mask_b": db["mask"],
            "patient_id": item["patient_id"],
            "slice_index": z,
            "midline_col": int(item["midline_col"]),
            "affine": np.array(item["affine"], dtype=np.float32),
        }


# ----------------------------------------
# Data utilities
# ----------------------------------------

def split_by_patient(cases: List[Dict[str, Any]], train: float = 0.7, val: float = 0.15, seed: int = 42) -> Tuple[List[str], List[str], List[str]]:
    pids = sorted({c["patient_id"] for c in cases})
    rng = np.random.RandomState(seed)
    rng.shuffle(pids)
    n = len(pids)
    n_tr = int(train * n)
    n_vl = int(val * n)
    tr = pids[:n_tr]
    vl = pids[n_tr:n_tr + n_vl]
    te = pids[n_tr + n_vl:]
    return tr, vl, te


def filter_index_by_pids(index: List[Dict[str, Any]], pids: List[str]) -> List[Dict[str, Any]]:
    pidset = set(pids)
    return [it for it in index if it["patient_id"] in pidset]


# ----------------------------------------
# Validation and test dataset wrappers
# ----------------------------------------

def build_dataloaders(
    index: List[Dict[str, Any]],
    cases_meta: List[Dict[str, Any]],
    cfg: PipelineConfig,
    logger: Optional[logging.Logger] = None,
) -> Dict[str, DataLoader]:
    tr_pids, vl_pids, te_pids = split_by_patient(cases_meta, train=0.7, val=0.15, seed=cfg.seed)
    idx_tr = filter_index_by_pids(index, tr_pids)
    idx_vl = filter_index_by_pids(index, vl_pids)
    idx_te = filter_index_by_pids(index, te_pids)

    if logger is not None:
        logger.info(f"Slices by split — train:{len(idx_tr)} val:{len(idx_vl)} test:{len(idx_te)}")

    ds_tr = ContrastiveSliceDataset(idx_tr, cfg, logger=logger, deterministic=False)
    ds_vl = ContrastiveSliceDataset(idx_vl, cfg, logger=logger, deterministic=True)
    ds_te = ContrastiveSliceDataset(idx_te, cfg, logger=logger, deterministic=True)

    if len(ds_tr) == 0:
        raise ValueError(
            "Training dataset is empty. Provide subjects in demo_subjects or from your CSV, "
            "lower cfg.min_brain_area_ratio if many slices are filtered, and ensure preprocessing ran."
        )

    dl_tr = DataLoader(ds_tr, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers, pin_memory=cfg.pin_memory)
    dl_vl = DataLoader(ds_vl, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers, pin_memory=cfg.pin_memory)
    dl_te = DataLoader(ds_te, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers, pin_memory=cfg.pin_memory)

    return {"train": dl_tr, "val": dl_vl, "test": dl_te}


# ----------------------------------------
# Visualization utilities
# ----------------------------------------

def _overlay_mask(img: np.ndarray, msk: np.ndarray, alpha: float = 0.4) -> np.ndarray:
    img_n = (img - img.min()) / (img.ptp() + 1e-8)
    rgb = np.stack([img_n, img_n, img_n], axis=-1)
    color = np.array([1.0, 0.0, 0.0])[None, None, :]
    rgb = (1 - alpha * msk[..., None]) * rgb + alpha * msk[..., None] * color
    return np.clip(rgb, 0, 1)


def show_qc_grid(dataset: ContrastiveSliceDataset, indices: Sequence[int] = (0, 1, 2)) -> None:
    import matplotlib.pyplot as plt

    n = len(indices)
    fig, axes = plt.subplots(nrows=n, ncols=3, figsize=(12, 4 * n))
    if n == 1:
        axes = np.expand_dims(axes, axis=0)
    for r, i in enumerate(indices):
        sample = dataset[i]
        img0 = sample["image_a"][0].numpy()  # using view A pre spatial, but already augmented; for original, you can bypass aug in code if needed
        # For a true original view, you can modify dataset to return base before aug. For brevity, we display A and B plus A with overlay.
        img_a = sample["image_a"][0].numpy()
        msk_a = sample["mask_a"][0].numpy()
        img_b = sample["image_b"][0].numpy()
        msk_b = sample["mask_b"][0].numpy()

        axes[r, 0].imshow(_overlay_mask(img_a, msk_a > 0.5))
        axes[r, 0].set_title(f"View A pid={sample['patient_id']} z={sample['slice_index']}")
        axes[r, 0].axis("off")
        axes[r, 1].imshow(_overlay_mask(img_b, msk_b > 0.5))
        axes[r, 1].set_title("View B")
        axes[r, 1].axis("off")
        axes[r, 2].imshow(img_a, cmap="gray")
        axes[r, 2].set_title("View A gray")
        axes[r, 2].axis("off")
    plt.tight_layout()
    plt.show()


# ----------------------------------------
# Script entry: example usage
# ----------------------------------------


def _make_dummy_case(root: str, pid: str = "DUMMY01", shape: Tuple[int, int, int] = (192, 192, 64)) -> Dict[str, Any]:
    os.makedirs(root, exist_ok=True)
    h, w, d = shape
    # synthetic brain-like blob
    yy, xx, zz = np.meshgrid(np.linspace(-1, 1, h), np.linspace(-1, 1, w), np.linspace(-1, 1, d), indexing="ij")
    r = np.sqrt(yy**2 + (xx*0.9)**2 + (zz*1.4)**2)
    img = np.exp(-3*r**2).astype(np.float32)
    img = (img * 2000 + 300)  # shift/scale to mimic T1 range
    mask = (r < 0.9).astype(np.uint8)
    aff = np.diag([1.0, 1.0, 1.0, 1.0]).astype(np.float32)
    case_dir = os.path.join(root, pid)
    os.makedirs(case_dir, exist_ok=True)
    img_path = os.path.join(case_dir, f"{pid}_T1.nii.gz")
    msk_path = os.path.join(case_dir, f"{pid}_mask.nii.gz")
    save_nifti(img_path, img, aff)
    save_nifti(msk_path, mask, aff)
    return {"image_path": img_path, "mask_path": msk_path, "patient_id": pid}


if __name__ == "__main__":
    cfg = PipelineConfig()
    logger = setup_logging(cfg.logs_dir)
    set_global_seed(cfg.seed)

    # demo subject list example (replace with your actual paths)
    demo_subjects = [
        {
            "image_path": "/kaggle/input/1-mri-samples/002_S_0295/002_S_0295/MT1__GradWarp__N3m/2010-05-13_06_37_21.0/I291867/ADNI_002_S_0295_MR_MT1__GradWarp__N3m_Br_20120322162736575_S84944_I291867.nii", 
            "mask_path": None, 
            "patient_id": "001"
        },
        {
            "image_path": "/kaggle/input/1-mri-samples/002_S_0295/002_S_0295/MT1__GradWarp__N3m/2010-05-13_06_45_21.0/I291869/ADNI_002_S_0295_MR_MT1__GradWarp__N3m_Br_20120322162842799_S84948_I291869.nii", 
            "mask_path": None, 
            "patient_id": "001"
        },
        {
            "image_path": "/kaggle/input/1-mri-samples/002_S_0413/002_S_0413/MT1__GradWarp__N3m/2010-05-06_12_37_46.0/I291872/ADNI_002_S_0413_MR_MT1__GradWarp__N3m_Br_20120322163151051_S84763_I291872.nii", 
            "mask_path": None, 
            "patient_id": "002"
        },
        {
            "image_path": "/kaggle/input/1-mri-samples/002_S_0413/002_S_0413/MT1__GradWarp__N3m/2010-05-06_12_46_10.0/I291873/ADNI_002_S_0413_MR_MT1__GradWarp__N3m_Br_20120322163254826_S84764_I291873.nii", 
            "mask_path": None, 
            "patient_id": "002"
        },
    ]


    if not demo_subjects:
        logger.info("No demo subjects provided — creating a synthetic dummy case for a smoke test.")
        demo_subjects = [_make_dummy_case("./synthetic_cases", pid="SYN001")]  # quick sanity dataset

    preproc_metas: List[Dict[str, Any]] = []
    for s in demo_subjects:
        meta = preprocess_subject(
            image_path=s["image_path"],
            mask_path=s.get("mask_path"),
            output_dir=cfg.preproc_dir,
            cfg=cfg,
            logger=logger,
        )
        if "patient_id" in s:
            meta["patient_id"] = s["patient_id"]
        preproc_metas.append(meta)

    index_path = os.path.join(cfg.logs_dir, "slice_index.json")
    slice_index = build_slice_index(preproc_metas, cfg, index_path)

    loaders = build_dataloaders(slice_index, preproc_metas, cfg, logger=logger)

    for batch in loaders["train"]:
        logger.info(f"Train batch shapes A: {batch['image_a'].shape}, B: {batch['image_b'].shape}")
        break

    # Optional QC display
    # ds_train = loaders["train"].dataset  # type: ignore
    # show_qc_grid(ds_train, indices=[0, 1, 2])

    index_path = os.path.join(cfg.logs_dir, "slice_index.json")
    slice_index = build_slice_index(preproc_metas, cfg, index_path)

    loaders = build_dataloaders(slice_index, preproc_metas, cfg, logger=logger)

    # quick sanity check iteration
    for batch in loaders["train"]:
        # batch contains keys: image_a, mask_a, image_b, mask_b, patient_id, slice_index, midline_col, affine
        logger.info(f"Train batch shapes A: {batch['image_a'].shape}, B: {batch['image_b'].shape}")
        break

    # Optional QC display
    # ds_train = loaders["train"].dataset  # type: ignore
    # show_qc_grid(ds_train, indices=[0, 1, 2])


2025-11-05 10:39:00.956624: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1762339141.161197      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1762339141.231123      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-11-05 10:39:18,598 | INFO | Preprocessed 002 saved to ./preproc/002 with stats {'pmin': 487.17303833007816, 'pmax': 13756.887910156245, 'mean': 3683.96484375, 'std': 2340.64062501}
2025-11-05 10:39:23,032 | INFO | Preprocessed 002 saved to ./preproc/002 with stats {'pmin': 397.98068359375003, 'pmax': 13387.878632812452, 'mean': 3646.12353515625, 'std': 2330.84423829125}
2025-11-05 10:39:27,747 | INFO | Preprocessed 002 saved to 