

# Brain tumor 3D segmentation with MONAI and Weights & Biases

[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/main/colabs/monai/3d_brain_tumor_segmentation.ipynb)

This tutorial shows how to build a training workflow for 3D brain tumor (lesion) segmentation using [MONAI](https://github.com/Project-MONAI/MONAI), and how to use the experiment tracking and data visualization features of [Weights & Biases](https://wandb.ai/site). It covers the following:

1. Initialize a Weights & Biases run and synchronize all configs associated with the run for reproducibility.
2. MONAI transform API:
    1. MONAI Transforms for dictionary format data.
    2. How to define a new transform according to MONAI `transforms` API.
    3. How to randomly adjust intensity for data augmentation.
3. Data Loading and Visualization:
    1. Load NRRD images with metadata, load a list of images and stack them.
    2. Cache IO and transforms to accelerate training and validation.
    3. Visualize the data using `wandb.Table` and interactive segmentation overlay on Weights & Biases.
4. Training a 3D `SegResNet` model
    1. Using the `networks`, `losses`, and `metrics` APIs from MONAI.
    2. Training the 3D `SegResNet` model using a PyTorch training loop.
    3. Track the training experiment using Weights & Biases.
    4. Log and version model checkpoints as model artifacts on Weights & Biases.
5. Visualize and compare the predictions on the validation dataset using `wandb.Table` and interactive segmentation overlay on Weights & Biases.

## 🌴 Setup and Installation

First, install the latest versions of MONAI (with the ITK, NiBabel and tqdm extras), Weights & Biases and SimpleITK.

In [1]:
!pip -q install -U "monai[itk,nibabel,tqdm]" wandb simpleitk


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import os, json, glob, re, random
import numpy as np
from tqdm.auto import tqdm

import torch
from torch.utils.data import DataLoader

import wandb

from monai.config import print_config
from monai.utils import set_determinism
from monai.data import Dataset, decollate_batch
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.networks.nets import SegResNet
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, EnsureTyped, Orientationd, Spacingd,
    RandSpatialCropd, RandFlipd, NormalizeIntensityd, Activations, AsDiscrete,
    SaveImaged
)

print_config()

# ---------- EDIT THESE PATHS ----------
DATA_ROOT = "/content/drive/MyDrive/ARNavigation"  # folder that contains case_XX folders
CKPT_DIR  = "./checkpoints"
PRED_DIR  = "./predictions"
os.makedirs(CKPT_DIR, exist_ok=True)
os.makedirs(PRED_DIR, exist_ok=True)

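# Which modalities to use; keep this consistent across cases.
# The indexing cell below sets the final value; keep both in sync, since this one is logged to W&B.
# MODALITIES = ["T1"]                    # 1-channel (safest)
MODALITIES = ["T1", "T1CE"]             # 2-channel (T1WI + T1-CE)
# MODALITIES = ["T1", "T1CE", "T2"]     # 3-channel (if most cases have these)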

# ROI for patches and inference
ROI_TRAIN = (128, 128, 96)
ROI_INFER = (128, 128, 96)

# Training config
SEED = 0
BATCH_SIZE = 1
NUM_WORKERS = 4
MAX_EPOCHS = 20
LR = 1e-4
WEIGHT_DECAY = 1e-5

set_determinism(seed=SEED)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")



MONAI version: 1.5.1
Numpy version: 2.0.2
Pytorch version: 2.8.0+cu126
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 9c6d819f97e37f36c72f3bdfad676b455bd2fa0d
MONAI __file__: /usr/local/lib/python3.12/dist-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: 5.4.4
Nibabel version: 5.3.2
scikit-image version: 0.25.2
scipy version: 1.16.2
Pillow version: 11.3.0
Tensorboard version: 2.19.0
gdown version: 5.2.0
TorchVision version: 0.23.0+cu126
tqdm version: 4.67.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.5
pandas version: 2.2.2
einops version: 0.8.1
transformers version: 4.56.1
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#insta
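Both ROI sizes above are multiples of 8. That matters for the `SegResNet` defined later in this notebook with `blocks_down=[1,2,2,4]`: assuming one stride-2 downsampling per encoder stage after the first, the input is halved three times, so every spatial size should be divisible by 2**3 = 8 for the decoder's skip connections to line up. The helper below (`roi_is_compatible`, a hypothetical name) is a minimal sketch of that check under this assumption.

```python
def roi_is_compatible(roi, blocks_down):
    """Return True if every spatial size survives len(blocks_down) - 1 exact halvings.

    Assumes one stride-2 downsampling per encoder stage after the first.
    """
    factor = 2 ** (len(blocks_down) - 1)
    return all(size % factor == 0 for size in roi)

BLOCKS_DOWN = [1, 2, 2, 4]   # same value the SegResNet cell uses
print(roi_is_compatible((128, 128, 96), BLOCKS_DOWN))  # True
print(roi_is_compatible((100, 128, 96), BLOCKS_DOWN))  # False: 100 is not a multiple of 8
```

Running this check before training costs nothing and catches a shape mismatch that would otherwise only surface inside the decoder.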

In [4]:
# If you don't want W&B, set USE_WANDB=False
USE_WANDB = True
WANDB_PROJECT = "monai-lesion-segmentation-nrrd"

if USE_WANDB:
    wandb.login()
    wandb_run = wandb.init(project=WANDB_PROJECT, config={
        "seed": SEED, "roi_train": ROI_TRAIN, "roi_infer": ROI_INFER,
        "batch_size": BATCH_SIZE, "num_workers": NUM_WORKERS,
        "max_epochs": MAX_EPOCHS, "lr": LR, "weight_decay": WEIGHT_DECAY,
        "modalities": MODALITIES, "data_root": DATA_ROOT
    })





[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33msachin-daya13[0m ([33msachin-daya13-university-of-cape-town[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
# ----- Index the cases: find each requested modality and the lesion label -----
import os, glob, re, json

# Choose what you want to use
# MODALITIES = ["T1"]          # 1-channel (T1WI)
MODALITIES = ["T1", "T1CE"]    # 2-channel (T1WI + T1-CE)
# MODALITIES = ["T1","T1CE","DTI"]  # if you really want DTI as a channel

# Patterns tailored to your filenames
MOD_PATTERNS = {
    # T1WI.nrrd (and variants like T1W.nrrd). Excludes CE/GD.
    "T1":   [r".*\bT1WI\b.*\.nrrd$", r".*\bT1W\b.*\.nrrd$"],
    # T1-CE.nrrd (also T1_CE / T1GD)
    "T1CE": [r".*\bT1[-_]?CE\b.*\.nrrd$", r".*\bT1[-_]?GD\b.*\.nrrd$"],
    # Optional extras
    "T2":   [r".*\bT2(WI)?\b.*\.nrrd$"],
    "DTI":  [r".*\bDTI\b.*\.nrrd$"],
    "CT":   [r".*\bCT\b.*\.nrrd$"],
}

# Your label is "Lesion.nrrd" at the case root
LABEL_PATTERNS = [r"^Lesion.*\.nrrd$"]   # tighten to lesion only

def find_first(patterns, files, prefer_dirs=("imaging_data",)):
    """Return the first file matching any pattern; prefer files under prefer_dirs."""
    cand = []
    for f in files:
        base = os.path.basename(f)
        for pat in patterns:
            if re.search(pat, base, flags=re.IGNORECASE):
                cand.append(f); break
    if not cand:
        return None
    # Prefer files from imaging_data/
    norm = lambda p: p.replace("\\", "/").lower()
    for pref in prefer_dirs:
        for c in cand:
            if f"/{pref.lower()}/" in norm(c):
                return c
    # Otherwise return the shortest path (usually the least deeply nested file)
    return sorted(cand, key=len)[0]

def build_items(data_root, modalities):
    items = []
    # look for case_* folders
    case_dirs = sorted([d for d in glob.glob(os.path.join(data_root, "case_*")) if os.path.isdir(d)])
    if not case_dirs:  # fallback: any subfolder under root
        case_dirs = sorted([d for d in glob.glob(os.path.join(data_root, "*")) if os.path.isdir(d)])

    for cdir in case_dirs:
        # find all NRRDs recursively (root + imaging_data)
        files = sorted(glob.glob(os.path.join(cdir, "**", "*.nrrd"), recursive=True))
        if not files:
            continue

        # label: prefer exact Lesion.nrrd at root; else any matching pattern
        label = None
        root_label = os.path.join(cdir, "Lesion.nrrd")
        if os.path.exists(root_label):
            label = root_label
        else:
            label = find_first(LABEL_PATTERNS, [f for f in files if os.path.dirname(f) == cdir] or files)
        if label is None:
            print(f"[skip] no lesion label in {cdir}")
            continue

        # images by requested modalities
        imgs, ok = [], True
        for m in modalities:
            pats = MOD_PATTERNS.get(m, [])
            f = find_first(pats, files)
            if f is None:
                ok = False
                print(f"[skip] missing {m} in {cdir}")
                break
            imgs.append(f)
        if not ok:
            continue

        items.append({
            "image": imgs if len(imgs) > 1 else imgs[0],
            "label": label,
            "case_id": os.path.basename(cdir),
        })

    return items

items_all = build_items(DATA_ROOT, MODALITIES)
print(f"Found {len(items_all)} cases with modalities={MODALITIES} and a lesion label.")
for it in items_all:
    print("•", it["case_id"])
    print("   images:", it["image"])
    print("   label :", it["label"])

assert len(items_all) > 0, "No usable cases found. Check folder names or patterns."



[skip] missing T1 in /content/drive/MyDrive/ARNavigation/case_04
Found 8 cases with modalities=['T1', 'T1CE'] and a lesion label.
• case_01
   images: ['/content/drive/MyDrive/ARNavigation/case_01/imaging_data/T1WI.nrrd', '/content/drive/MyDrive/ARNavigation/case_01/imaging_data/T1-CE.nrrd']
   label : /content/drive/MyDrive/ARNavigation/case_01/Lesion.nrrd
• case_02
   images: ['/content/drive/MyDrive/ARNavigation/case_02/imaging_data/T1WI.nrrd', '/content/drive/MyDrive/ARNavigation/case_02/imaging_data/T1-CE.nrrd']
   label : /content/drive/MyDrive/ARNavigation/case_02/Lesion.nrrd
• case_03
   images: ['/content/drive/MyDrive/ARNavigation/case_03/imaging_data/T1WI.nrrd', '/content/drive/MyDrive/ARNavigation/case_03/imaging_data/T1-CE.nrrd']
   label : /content/drive/MyDrive/ARNavigation/case_03/Lesion.nrrd
• case_05
   images: ['/content/drive/MyDrive/ARNavigation/case_05/imaging_data/T1WI.nrrd', '/content/drive/MyDrive/ARNavigation/case_05/imaging_data/T1-CE.nrrd']
   label : /conte
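The modality patterns rely on `\b` word boundaries, and Python's `re` treats `_` as a word character, so a name such as `case01_T1WI.nrrd` would not match and the case would be skipped. It is worth checking the filenames of any skipped case (such as `case_04` above) against the patterns. The sketch below runs a subset of `MOD_PATTERNS` on a few made-up filenames; `classify` and `T1_SEP` are hypothetical names, and the lookaround pattern is just one possible alternative.

```python
import re

# Subset of the MOD_PATTERNS defined in the indexing cell
PATTERNS = {
    "T1":   [r".*\bT1WI\b.*\.nrrd$", r".*\bT1W\b.*\.nrrd$"],
    "T1CE": [r".*\bT1[-_]?CE\b.*\.nrrd$", r".*\bT1[-_]?GD\b.*\.nrrd$"],
    "T2":   [r".*\bT2(WI)?\b.*\.nrrd$"],
}

def classify(filename, patterns=PATTERNS):
    """Return the first modality whose pattern matches the basename (case-insensitive), else None."""
    for modality, pats in patterns.items():
        if any(re.search(p, filename, flags=re.IGNORECASE) for p in pats):
            return modality
    return None

for name in ["T1WI.nrrd", "T1-CE.nrrd", "T2.nrrd", "case01-T1WI.nrrd", "case01_T1WI.nrrd"]:
    print(f"{name:20s} -> {classify(name)}")   # the last one prints None

# Lookarounds reject only letters/digits next to the token, so "_" works as a separator
T1_SEP = r"(?<![a-z0-9])T1WI(?![a-z0-9]).*\.nrrd$"
print(bool(re.search(T1_SEP, "case01_T1WI.nrrd", flags=re.IGNORECASE)))  # True
```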

In [7]:
train_transform = Compose([
    LoadImaged(keys=["image","label"]),                 # ITKReader reads NRRD; a list of files is stacked
    EnsureChannelFirstd(keys=["image","label"]),        # image -> (C, *spatial); label -> (1, *spatial)
    EnsureTyped(keys=["image","label"]),
    Orientationd(keys=["image","label"], axcodes="RAS"),
    Spacingd(keys=["image","label"], pixdim=(1.0,1.0,1.0), mode=("bilinear","nearest")),
    RandSpatialCropd(keys=["image","label"], roi_size=ROI_TRAIN, random_size=False),
    RandFlipd(keys=["image","label"], prob=0.5, spatial_axis=0),
    RandFlipd(keys=["image","label"], prob=0.5, spatial_axis=1),
    RandFlipd(keys=["image","label"], prob=0.5, spatial_axis=2),
    NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
])

val_transform = Compose([
    LoadImaged(keys=["image","label"]),
    EnsureChannelFirstd(keys=["image","label"]),
    EnsureTyped(keys=["image","label"]),
    Orientationd(keys=["image","label"], axcodes="RAS"),
    Spacingd(keys=["image","label"], pixdim=(1.0,1.0,1.0), mode=("bilinear","nearest")),
    NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
])

post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])


monai.transforms.spatial.dictionary Orientationd.__init__:labels: Current default value of argument `labels=(('L', 'R'), ('P', 'A'), ('I', 'S'))` was changed in version None from `labels=(('L', 'R'), ('P', 'A'), ('I', 'S'))` to `labels=None`. Default value changed to None meaning that the transform now uses the 'space' of a meta-tensor, if applicable, to determine appropriate axis labels.
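`NormalizeIntensityd(nonzero=True, channel_wise=True)` z-scores each channel using only its nonzero voxels and leaves the zero background untouched. The pure-Python sketch below shows that arithmetic on tiny made-up channels. The function name is hypothetical, and the use of the population standard deviation and the fallback to 1.0 for a constant foreground are assumptions about the details, not a copy of MONAI's code.

```python
from statistics import fmean, pstdev

def normalize_nonzero_channelwise(channels):
    """Z-score each channel over its nonzero voxels; zero voxels stay 0.0.

    channels: list of channels, each a flat list of voxel intensities.
    """
    out = []
    for channel in channels:
        foreground = [v for v in channel if v != 0]
        if not foreground:
            out.append([0.0] * len(channel))   # all-background channel: nothing to normalize
            continue
        mean = fmean(foreground)
        std = pstdev(foreground) or 1.0         # assumed guard for a constant foreground
        out.append([(v - mean) / std if v != 0 else 0.0 for v in channel])
    return out

print(normalize_nonzero_channelwise([[0, 2, 4, 0], [10, 10]]))
# [[0.0, -1.0, 1.0, 0.0], [0.0, 0.0]]
```

Computing the statistics per channel matters here because T1 and T1-CE intensities live on different scales.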


Next, we convert each case to per-modality keys, split the cases into training and validation sets, and set up the transforms for each.

In [16]:
# Convert items of the form
#   {"image": [T1_path, T1CE_path, ...], "label": label_path, "case_id": "case_01"}
# into one key per modality ("img_T1", "img_T1CE", ...) so that each modality can be
# loaded and resampled separately before the channels are stacked.

assert "items_all" in globals(), "Run the indexing cell first to create items_all."
assert "MODALITIES" in globals(), "Define MODALITIES = ['T1','T1CE'] or similar first."

def split_items_by_modality(items, modalities):
    new_items = []
    for it in items:
        if "image" not in it:
            # already converted (e.g. when this cell is re-run): keep as-is
            new_items.append(it)
            continue
        imgs = it["image"] if isinstance(it["image"], list) else [it["image"]]
        if len(imgs) != len(modalities):
            # skip cases whose image count does not match MODALITIES
            print(f"[skip] {it.get('case_id','?')} expected {len(modalities)} imgs, got {len(imgs)}")
            continue
        new = {"label": it["label"], "case_id": it["case_id"]}
        for m, p in zip(modalities, imgs):
            new[f"img_{m}"] = p
        new_items.append(new)
    return new_items

items_all = split_items_by_modality(items_all, MODALITIES)

# Reuse an existing split (e.g. from the K-fold cell) and convert it too;
# otherwise take the first fold from `splits`, or fall back to a simple 80/20 split.
if "train_items" in globals() and "val_items" in globals():
    train_items = split_items_by_modality(train_items, MODALITIES)
    val_items   = split_items_by_modality(val_items,   MODALITIES)
elif "splits" in globals() and len(splits) > 0:
    tr, va = splits[0]
    train_items = split_items_by_modality(tr, MODALITIES)
    val_items   = split_items_by_modality(va, MODALITIES)
elif len(items_all) >= 2:
    k = max(1, int(0.8 * len(items_all)))
    train_items, val_items = items_all[:k], items_all[k:]
else:
    # a single case: train and validate on it (only useful as a smoke test)
    train_items = items_all
    val_items   = items_all

print(f"[info] train={len(train_items)}  val={len(val_items)}")



[info] train=7  val=1


In [17]:
# MONAI 1.5.1 is already installed (see print_config above). Re-running pip here would
# not change the already-imported package until the runtime is restarted.
# !pip -q install -U "monai[itk,nibabel,tqdm]>=1.3.0"

from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, EnsureTyped, Orientationd, ResampleToMatchd,
    ConcatItemsd, RandSpatialCropd, RandFlipd, NormalizeIntensityd,
    Activations, AsDiscrete
)

# Build the list of modality keys from MODALITIES, e.g. ["img_T1","img_T1CE"]
IMG_KEYS = [f"img_{m}" for m in MODALITIES]
REF_KEY  = IMG_KEYS[0]  # use the first modality (e.g., img_T1) as reference

# Modes for resampling: bilinear for images, nearest for labels
RESAMPLE_TO_REF_IMAGES = dict(keys=IMG_KEYS[1:], key_dst=REF_KEY, mode="bilinear") if len(IMG_KEYS) > 1 else None
RESAMPLE_TO_REF_LABEL  = dict(keys="label",      key_dst=REF_KEY, mode="nearest")

ROI_TRAIN = globals().get("ROI_TRAIN", (128,128,96))
ROI_INFER = globals().get("ROI_INFER", (128,128,96))

def make_transforms():
    t_train = [
        LoadImaged(keys=IMG_KEYS + ["label"]),           # read each modality separately
        EnsureChannelFirstd(keys=IMG_KEYS + ["label"]),  # scalar volumes -> (1, *spatial); Orientationd needs channel-first
        EnsureTyped(keys=IMG_KEYS + ["label"]),
        Orientationd(keys=IMG_KEYS + ["label"], axcodes="RAS"),
    ]
    # resample other modalities + label to match ref geometry (affine + shape)
    if RESAMPLE_TO_REF_IMAGES:
        t_train.append(ResampleToMatchd(**RESAMPLE_TO_REF_IMAGES))
    t_train.append(ResampleToMatchd(**RESAMPLE_TO_REF_LABEL))
    # stack aligned modalities -> "image" (C,D,H,W)
    t_train.append(ConcatItemsd(keys=IMG_KEYS, name="image", dim=0))
    # augment & normalize
    t_train += [
        RandSpatialCropd(keys=["image","label"], roi_size=ROI_TRAIN, random_size=False),
        RandFlipd(keys=["image","label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image","label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image","label"], prob=0.5, spatial_axis=2),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]

    t_val = [
        LoadImaged(keys=IMG_KEYS + ["label"]),
        EnsureChannelFirstd(keys=IMG_KEYS + ["label"]),
        EnsureTyped(keys=IMG_KEYS + ["label"]),
        Orientationd(keys=IMG_KEYS + ["label"], axcodes="RAS"),
    ]
    if RESAMPLE_TO_REF_IMAGES:
        t_val.append(ResampleToMatchd(**RESAMPLE_TO_REF_IMAGES))
    t_val.append(ResampleToMatchd(**RESAMPLE_TO_REF_LABEL))
    t_val.append(ConcatItemsd(keys=IMG_KEYS, name="image", dim=0))
    t_val.append(NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True))

    post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    return Compose(t_train), Compose(t_val), post_pred

train_transform, val_transform, post_pred = make_transforms()
print("Transforms ready. Modalities:", MODALITIES, "  Ref key:", REF_KEY)


Transforms ready. Modalities: ['T1', 'T1CE']   Ref key: img_T1




In [19]:
from monai.transforms import SaveImaged

PRED_DIR = globals().get("PRED_DIR", "./predictions")
os.makedirs(PRED_DIR, exist_ok=True)

saver = SaveImaged(
    keys="pred",
    output_dir=PRED_DIR,
    output_postfix="pred",
    output_ext=".nii.gz",
    separate_folder=False,
    resample=False,
    print_log=True,
    squeeze_end_dims=True,
)

# Example inference-save loop for val_loader (uncomment after training).
# Metadata lives on the MetaTensor itself (USE_META_DICT = False above; the smoke-test batch
# has no "*_meta_dict" keys), so there is no separate meta-dict entry to pass along.
# SaveImaged expects one channel-first item (C, D, H, W), hence the decollate step.
# with torch.no_grad():
#     for batch in val_loader:
#         x = batch["image"].to(device)
#         logits = sliding_window_inference(x, roi_size=ROI_INFER, sw_batch_size=1, predictor=model, overlap=0.5)
#         for pred, case_id in zip(decollate_batch(logits), batch["case_id"]):
#             pred = post_pred(pred).cpu()   # sigmoid + threshold 0.5 -> (1, D, H, W)
#             # name the file after the case; assumes pred is a MetaTensor carrying the reference image's meta
#             pred.meta["filename_or_obj"] = f"{case_id}.nii.gz"
#             saver({"pred": pred})



In [11]:
# Split the cases into training and validation sets with K-fold cross-validation.
# A simple 80/20 split is an alternative:
# train_items = items_all[:int(len(items_all)*0.8)]
# val_items   = items_all[int(len(items_all)*0.8):]
from sklearn.model_selection import KFold
kf = KFold(n_splits=5, shuffle=True, random_state=SEED)
splits = []
for train_idx, val_idx in kf.split(items_all):
    train_list = [items_all[i] for i in train_idx]
    val_list = [items_all[i] for i in val_idx]
    splits.append((train_list, val_list))

FOLD = 3  # which fold to validate on: 0..4 for n_splits=5
train_items, val_items = splits[FOLD]

train_ds = Dataset(train_items, transform=train_transform)
val_ds   = Dataset(val_items,   transform=val_transform)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS)
val_loader   = DataLoader(val_ds,   batch_size=1,          shuffle=False, num_workers=NUM_WORKERS)

len(train_ds), len(val_ds)

This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.


(7, 1)
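With 8 cases and `n_splits=5`, the folds cannot all be the same size. scikit-learn's `KFold` gives the first `n_samples % n_splits` folds one extra sample, so the validation sets hold 2, 2, 2, 1 and 1 cases. That is why fold 3 yields the `(7, 1)` split above and why valid `FOLD` values run from 0 to 4. The sketch below reproduces only the fold-size arithmetic; `kfold_sizes` is a hypothetical helper and does not reproduce scikit-learn's shuffled order.

```python
def kfold_sizes(n_samples, n_splits):
    """Validation-fold sizes: the first n_samples % n_splits folds get one extra sample."""
    base, extra = divmod(n_samples, n_splits)
    return [base + (1 if i < extra else 0) for i in range(n_splits)]

sizes = kfold_sizes(8, 5)
print(sizes)                                   # [2, 2, 2, 1, 1]
for fold, n_val in enumerate(sizes):
    print(f"fold {fold}: train={8 - n_val} val={n_val}")
```

With a single validation case, the Dice score of folds 3 and 4 rests on one volume, so averaging over all folds gives a steadier estimate.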

In [20]:
IN_CH = len(MODALITIES)
OUT_CH = 1  # binary lesion

model = SegResNet(
    blocks_down=[1,2,2,4], blocks_up=[1,1,1],
    init_filters=16, in_channels=IN_CH, out_channels=OUT_CH, dropout_prob=0.2
).to(device)

optimizer = torch.optim.Adam(model.parameters(), LR, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=MAX_EPOCHS)
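loss_function = DiceLoss(sigmoid=True)  # binary: sigmoid on the single output channel
# With OUT_CH = 1 the only channel is the lesion, so it must not be dropped as "background"
dice_metric = DiceMetric(include_background=True, reduction="mean")
scaler = torch.amp.GradScaler("cuda", enabled=(device.type == "cuda"))  # mixed-precision loss scaling; disabled on CPU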
torch.backends.cudnn.benchmark = True


`torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
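With `OUT_CH = 1` the single output channel is the lesion itself, and `DiceLoss(sigmoid=True)` scores it with a soft Dice computed on sigmoid probabilities. The sketch below spells out that formula for one flattened toy volume in pure Python. The function name is hypothetical, and the `1e-5` smoothing terms are assumed to match MONAI's `smooth_nr`/`smooth_dr` defaults.

```python
import math

def soft_dice_loss(logits, target, smooth_nr=1e-5, smooth_dr=1e-5):
    """1 - (2*sum(p*g) + smooth_nr) / (sum(p) + sum(g) + smooth_dr), with p = sigmoid(logits)."""
    probs = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    intersection = sum(p * g for p, g in zip(probs, target))
    denominator = sum(probs) + sum(target)
    return 1.0 - (2.0 * intersection + smooth_nr) / (denominator + smooth_dr)

target = [1, 1, 0, 0]                                # toy "volume": two lesion voxels
print(soft_dice_loss([20, 20, -20, -20], target))    # close to 0: confident and correct
print(soft_dice_loss([0, 0, 0, 0], target))          # about 0.5: every voxel at p = 0.5
print(soft_dice_loss([-20, -20, -20, -20], target))  # close to 1: lesion missed entirely
```

Because the loss is a ratio of overlaps, it is insensitive to how many background voxels the volume has, which suits small lesions.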


In [21]:
def infer_sw(x):
    with torch.autocast(device_type=device.type, enabled=(device.type == "cuda")):  # mixed precision on GPU only
        return sliding_window_inference(
            inputs=x, roi_size=ROI_INFER, sw_batch_size=1, predictor=model, overlap=0.5
        )
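With `overlap=0.5`, `sliding_window_inference` advances each window by about half of `ROI_INFER` along every axis and shifts the last window back so it stays inside the volume. The sketch below is a simplified, pure-Python illustration of that tiling along one axis; `window_starts` is a hypothetical helper, the volume shape is made up, and the start positions are assumed to approximate, not replicate, MONAI's internal scheme. The number of forward passes per volume is the product of the per-axis window counts.

```python
import math

def window_starts(size, roi, overlap):
    """Start indices of sliding windows along one axis (simplified sketch)."""
    if size <= roi:
        return [0]                              # one window; the volume would be padded up to roi
    step = max(int(roi * (1 - overlap)), 1)
    starts, s = [], 0
    while True:
        starts.append(min(s, size - roi))       # keep the last window inside the volume
        if s + roi >= size:
            break
        s += step
    return starts

shape, roi = (200, 256, 96), (128, 128, 96)     # made-up volume shape, ROI_INFER-sized windows
per_axis = [window_starts(n, r, 0.5) for n, r in zip(shape, roi)]
print(per_axis)                                 # [[0, 64, 72], [0, 64, 128], [0]]
print(math.prod(len(s) for s in per_axis))      # 9
```

Raising `overlap` smooths seams between windows at the cost of more forward passes; `sw_batch_size` only changes how many windows run per forward call.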


In [29]:
# ===== New transforms (channel-first -> orient -> resample -> concat) + rebuild loaders =====
import gc
from monai.transforms import (
    Compose, LoadImaged, EnsureTyped, EnsureChannelFirstd, Orientationd, ResampleToMatchd,
    ConcatItemsd, RandSpatialCropd, RandFlipd, NormalizeIntensityd,
    Activations, AsDiscrete
)
from monai.data import Dataset, DataLoader

assert 'MODALITIES' in globals() and 'train_items' in globals() and 'val_items' in globals(), "Run the indexing/split cells first."

IMG_KEYS = [f"img_{m}" for m in MODALITIES]
REF_KEY  = IMG_KEYS[0]  # use first modality as the reference grid
ROI_TRAIN = globals().get("ROI_TRAIN", (128,128,96))
ROI_INFER = globals().get("ROI_INFER", (128,128,96))
BATCH_SIZE = globals().get("BATCH_SIZE", 1)

# --- TRAIN transform
train_transform = Compose([
    # per-modality load
    LoadImaged(keys=IMG_KEYS + ["label"]),
    # make sure each is (C, *spatial): if no channel, add C=1
    EnsureChannelFirstd(keys=IMG_KEYS + ["label"]),
    EnsureTyped(keys=IMG_KEYS + ["label"]),
    # reorient to RAS (expects channel-first)
    Orientationd(keys=IMG_KEYS + ["label"], axcodes="RAS"),
    # resample all non-ref modalities to the ref geometry; label with nearest
    ResampleToMatchd(keys=IMG_KEYS[1:], key_dst=REF_KEY, mode="bilinear") if len(IMG_KEYS) > 1 else lambda x: x,
    ResampleToMatchd(keys="label",      key_dst=REF_KEY, mode="nearest"),
    # stack C=1 modalities -> "image" with C = len(MODALITIES)
    ConcatItemsd(keys=IMG_KEYS, name="image", dim=0),
    # crop/aug/normalize
    RandSpatialCropd(keys=["image","label"], roi_size=ROI_TRAIN, random_size=False),
    RandFlipd(keys=["image","label"], prob=0.5, spatial_axis=0),
    RandFlipd(keys=["image","label"], prob=0.5, spatial_axis=1),
    RandFlipd(keys=["image","label"], prob=0.5, spatial_axis=2),
    NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
])

# --- VAL transform
val_transform = Compose([
    LoadImaged(keys=IMG_KEYS + ["label"]),
    EnsureChannelFirstd(keys=IMG_KEYS + ["label"]),
    EnsureTyped(keys=IMG_KEYS + ["label"]),
    Orientationd(keys=IMG_KEYS + ["label"], axcodes="RAS"),
    ResampleToMatchd(keys=IMG_KEYS[1:], key_dst=REF_KEY, mode="bilinear") if len(IMG_KEYS) > 1 else lambda x: x,
    ResampleToMatchd(keys="label",      key_dst=REF_KEY, mode="nearest"),
    ConcatItemsd(keys=IMG_KEYS, name="image", dim=0),
    NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
])

post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

# --- Rebuild datasets/loaders (use num_workers=0 for clear errors first)
train_ds = Dataset(train_items, transform=train_transform)
val_ds   = Dataset(val_items,   transform=val_transform)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=0)
val_loader   = DataLoader(val_ds,   batch_size=1,          shuffle=False, num_workers=0)

gc.collect()

# --- Smoke test a single batch
b = next(iter(train_loader))
print("image shape:", tuple(b["image"].shape), " label shape:", tuple(b["label"].shape))
# expect: image shape (C, D, H, W) after dataset (when collated by DataLoader it becomes (B, C, D, H, W))




image shape: (1, 2, 128, 128, 96)  label shape: (1, 1, 128, 128, 96)
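`ResampleToMatchd` puts every modality and the label on the grid of the reference key (`img_T1`), which is what lets `ConcatItemsd` stack them. Conceptually, each reference voxel index is mapped to world (scanner) coordinates with the reference affine, then back into the moving image's voxel grid with the inverse of that image's affine, where the value is interpolated (bilinear for images, nearest for the label). The sketch below shows only that index mapping, restricted to axis-aligned affines given by a spacing and an origin; the function names and geometries are hypothetical, and this is not MONAI's implementation.

```python
def voxel_to_world(ijk, spacing, origin):
    """Axis-aligned affine: world = origin + spacing * index (per axis)."""
    return tuple(o + s * i for i, s, o in zip(ijk, spacing, origin))

def world_to_voxel(xyz, spacing, origin):
    """Inverse of voxel_to_world; results may be fractional and then need interpolation."""
    return tuple((x - o) / s for x, s, o in zip(xyz, spacing, origin))

# Made-up geometries: a 1 mm reference grid and a 0.5 mm moving image shifted by 10 mm
ref_spacing, ref_origin = (1.0, 1.0, 1.0), (0.0, 0.0, 0.0)
mov_spacing, mov_origin = (0.5, 0.5, 0.5), (10.0, 0.0, 0.0)

world = voxel_to_world((20, 5, 3), ref_spacing, ref_origin)   # (20.0, 5.0, 3.0)
print(world_to_voxel(world, mov_spacing, mov_origin))          # (20.0, 10.0, 6.0)
```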


In [30]:
# ===== Fix items -> per-modality keys, filter bad cases, rebuild loaders, smoke test =====
import os, gc
from monai.data import Dataset, DataLoader
import SimpleITK as sitk

assert 'MODALITIES' in globals(), "Define MODALITIES (e.g., ['T1','T1CE']) first."
assert 'items_all' in globals(),  "Run your indexing cell to create items_all first."

# 1) Helpers
def items_are_keyed(items, modalities):
    needed = [f"img_{m}" for m in modalities] + ["label"]
    for it in items:
        if not all(k in it for k in needed):
            return False
    return True

def split_items_by_modality(items, modalities):
    """Convert {'image':[p1,p2,...],'label':...} -> {'img_T1':p1, 'img_T1CE':p2, ...,'label':...}"""
    new_items = []
    for it in items:
        imgs = it["image"] if isinstance(it["image"], list) else [it["image"]]
        if len(imgs) != len(modalities):
            # skip if count mismatch
            print(f"[skip] {it.get('case_id','?')}: expected {len(modalities)} images, got {len(imgs)}")
            continue
        new = {"label": it["label"], "case_id": it.get("case_id","unknown")}
        for m, p in zip(modalities, imgs):
            new[f"img_{m}"] = p
        new_items.append(new)
    return new_items

def filter_cases_with_all_modalities(items, modalities):
    need = [f"img_{m}" for m in modalities] + ["label"]
    good, bad = [], []
    for it in items:
        if all(k in it and it[k] for k in need):
            good.append(it)
        else:
            bad.append(it.get("case_id","unknown"))
    if bad:
        print(f"[info] dropping {len(bad)} case(s) missing modalities: {bad}")
    return good

# 2) Ensure items_all has per-modality keys
if not items_are_keyed(items_all, MODALITIES):
    items_all = split_items_by_modality(items_all, MODALITIES)

# 3) Ensure train_items/val_items exist and are keyed
if 'train_items' not in globals() or 'val_items' not in globals():
    # simple 80/20 or same-case if only one
    if len(items_all) >= 2:
        k = max(1, int(0.8 * len(items_all)))
        train_items, val_items = items_all[:k], items_all[k:]
    else:
        train_items = items_all
        val_items   = items_all

if not items_are_keyed(train_items, MODALITIES):
    train_items = split_items_by_modality(train_items, MODALITIES)
if not items_are_keyed(val_items, MODALITIES):
    val_items   = split_items_by_modality(val_items,   MODALITIES)

# 4) Drop any cases missing required modalities
train_items = filter_cases_with_all_modalities(train_items, MODALITIES)
val_items   = filter_cases_with_all_modalities(val_items,   MODALITIES)

print(f"[info] train={len(train_items)}  val={len(val_items)}  modalities={MODALITIES}")

# 5) Build transforms only if an earlier cell has not defined them (expects img_<MOD> keys)
if 'train_transform' not in globals() or 'val_transform' not in globals():
    from monai.transforms import (
        Compose, LoadImaged, EnsureTyped, Orientationd, ResampleToMatchd,
        ConcatItemsd, RandSpatialCropd, RandFlipd, NormalizeIntensityd,
        EnsureChannelFirstd
    )
    IMG_KEYS  = [f"img_{m}" for m in MODALITIES]
    REF_KEY   = IMG_KEYS[0]
    ALL_KEYS  = IMG_KEYS + ["label"]
    ROI_TRAIN = globals().get("ROI_TRAIN", (128,128,96))

    def show_shapes(stage):
        """Debugging step: print the shape of every tensor key present, then pass the dict through."""
        def _show(x):
            print(f"[{stage}]", {k: tuple(x[k].shape) for k in ALL_KEYS + ["image"] if k in x})
            return x
        return _show

    def load_and_align():
        """Steps shared by train and val: load, add channel dim, reorient, resample to REF_KEY, stack."""
        steps = [
            LoadImaged(keys=ALL_KEYS, image_only=False, reader="ITKReader"),  # also keep metadata
            # scalar NRRD volumes carry no channel axis: add C=1
            EnsureChannelFirstd(keys=ALL_KEYS, channel_dim="no_channel"),
            EnsureTyped(keys=ALL_KEYS),
            Orientationd(keys=ALL_KEYS, axcodes="RAS"),
            show_shapes("after Orientationd"),
        ]
        if len(IMG_KEYS) > 1:
            steps.append(ResampleToMatchd(keys=IMG_KEYS[1:], key_dst=REF_KEY, mode="bilinear"))
        steps += [
            ResampleToMatchd(keys="label", key_dst=REF_KEY, mode="nearest"),
            ConcatItemsd(keys=IMG_KEYS, name="image", dim=0),
            show_shapes("after ConcatItemsd"),
        ]
        return steps

    train_transform = Compose(load_and_align() + [
        RandSpatialCropd(keys=["image","label"], roi_size=ROI_TRAIN, random_size=False),
        RandFlipd(keys=["image","label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image","label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image","label"], prob=0.5, spatial_axis=2),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ])
    val_transform = Compose(load_and_align() + [
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ])

# 6) Rebuild loaders (num_workers=0 first to surface issues clearly)
BATCH_SIZE  = globals().get("BATCH_SIZE", 1)
NUM_WORKERS = 0   # use 0 now; you can switch back to your NUM_WORKERS later

train_ds = Dataset(train_items, transform=train_transform)
val_ds   = Dataset(val_items,   transform=val_transform)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS)
val_loader   = DataLoader(val_ds,   batch_size=1,          shuffle=False, num_workers=NUM_WORKERS)

gc.collect()

# 7) Smoke test one batch
_batch = next(iter(train_loader))
print("Sample keys:", list(_batch.keys()))
print("image shape:", tuple(_batch["image"].shape), " label shape:", tuple(_batch["label"].shape))

[info] train=7  val=1  modalities=['T1', 'T1CE']
Sample keys: ['label', 'case_id', 'img_T1', 'img_T1CE', 'image']
image shape: (1, 2, 128, 128, 96)  label shape: (1, 1, 128, 128, 96)
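The inline `lambda x: (print(...), x)[1]` probes in the transform lists work with `num_workers=0`. Lambdas cannot be pickled, though, so they can break once `NUM_WORKERS > 0` if the DataLoader workers use the `spawn` start method, which is the default on macOS and Windows. Below is a minimal sketch of a picklable replacement: a small named class that prints shapes and dtypes and then returns the dict unchanged. `PrintShapes` and `FakeTensor` are hypothetical helpers for illustration, not MONAI APIs. MONAI's own `DataStatsd` transform covers similar logging needs.

```python
import pickle


class PrintShapes:
    """Picklable debug transform: prints shape/dtype of selected keys, returns the dict unchanged.

    Hypothetical helper (not part of MONAI); use it in a Compose list instead of a lambda.
    """

    def __init__(self, tag, keys=None):
        self.tag = tag    # label printed with each line, e.g. "after Orientationd"
        self.keys = keys  # None -> every key whose value has a .shape attribute

    def __call__(self, data):
        keys = self.keys if self.keys is not None else list(data.keys())
        parts = []
        for k in keys:
            v = data.get(k)
            if hasattr(v, "shape"):
                parts.append(f"{k}: shape={tuple(v.shape)} dtype={getattr(v, 'dtype', '?')}")
        print(f"[{self.tag}] " + " | ".join(parts))
        return data  # pass-through, like the original lambdas


class FakeTensor:
    """Stand-in for a tensor in this demo: only .shape and .dtype are read."""

    def __init__(self, shape, dtype="float32"):
        self.shape = shape
        self.dtype = dtype


probe = PrintShapes("after RandSpatialCropd", keys=["image", "label"])
sample = {
    "image": FakeTensor((2, 128, 128, 96)),
    "label": FakeTensor((1, 128, 128, 96), "uint8"),
    "case_id": "case_01",
}
out = probe(sample)

# A module-level class survives pickling (as spawn-based workers require); a lambda does not.
restored = pickle.loads(pickle.dumps(probe))
try:
    pickle.dumps(lambda x: x)
    lambda_picklable = True
except Exception:  # PicklingError on CPython; exact type/message varies by version
    lambda_picklable = False
```

In the transform lists, a probe such as `lambda x: (print(f"Shape after Orientationd: ..."), x)[1]` would become `PrintShapes("after Orientationd", keys=IMG_KEYS + ["label"])`.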


In [32]:
# ==== Train SegResNet end-to-end (works with your current loaders) ====
import os, gc, torch
from tqdm.auto import tqdm
from monai.networks.nets import SegResNet
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
from monai.data import decollate_batch
from monai.transforms import Activations, AsDiscrete

# --- sanity: need loaders from your previous cell
assert 'train_loader' in globals() and 'val_loader' in globals(), "Run the transforms + loaders cell first."

# --- config / defaults (reuse globals when present)
device        = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
IN_CH         = len(MODALITIES)                      # channels = number of modalities
OUT_CH        = 1                                    # binary lesion
MAX_EPOCHS    = globals().get("MAX_EPOCHS", 50)
LR            = globals().get("LR", 1e-4)
WEIGHT_DECAY  = globals().get("WEIGHT_DECAY", 1e-5)
ROI_INFER     = globals().get("ROI_INFER", (128,128,96))
CKPT_DIR      = globals().get("CKPT_DIR", "./checkpoints")
USE_WANDB     = globals().get("USE_WANDB", False)
os.makedirs(CKPT_DIR, exist_ok=True)

# --- model / opt / sched
model = SegResNet(
    blocks_down=[1,2,2,4], blocks_up=[1,1,1],
    init_filters=16, in_channels=IN_CH, out_channels=OUT_CH, dropout_prob=0.2
).to(device)

optimizer    = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scheduler    = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=MAX_EPOCHS)
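loss_function = DiceLoss(sigmoid=True)  # binary: sigmoid on the single output channel
# With one sigmoid channel, that channel *is* the lesion, so there is no background channel to drop.
dice_metric  = DiceMetric(include_background=True, reduction="mean")
scaler       = torch.amp.GradScaler("cuda", enabled=(device.type == "cuda"))  # non-deprecated API; AMP on GPU only
torch.backends.cudnn.benchmark = True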

# --- inference helper
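def infer_sw(x):
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast; mixed precision only on GPU
    with torch.autocast(device_type=device.type, enabled=(device.type == "cuda")):
        return sliding_window_inference(x, roi_size=ROI_INFER, sw_batch_size=1, predictor=model, overlap=0.5)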

# --- checkpoint path
fold_tag = str(globals().get("FOLD", "0"))
best_ckpt = os.path.join(CKPT_DIR, f"best_fold{fold_tag}.pth")
best_dice = -1.0

# --- W&B metric schema (optional)
if USE_WANDB:
    wandb.define_metric("epoch/step")
    wandb.define_metric("epoch/*", step_metric="epoch/step")

# ======================= TRAIN =======================
for epoch in range(MAX_EPOCHS):
    model.train()
    run_loss = 0.0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{MAX_EPOCHS} - train"):
        x = batch["image"].to(device)     # (B,C,D,H,W)
        y = batch["label"].to(device)     # (B,1,D,H,W)

        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device.type, enabled=(device.type == "cuda")):
            yhat = model(x)
            loss = loss_function(yhat, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        run_loss += loss.item()  # .item() returns a detached Python float (avoids the requires_grad warning)

    scheduler.step()
    mean_train_loss = run_loss / max(1, len(train_loader))

    # ----- Validation
    model.eval()
    dice_metric.reset()
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="  validate", leave=False):
            x = batch["image"].to(device)
            y = batch["label"].to(device)
            yhat = infer_sw(x)  # (B,1,D,H,W)

            # post-process each volume: sigmoid -> binary mask (threshold 0.5)
            yhat = [AsDiscrete(threshold=0.5)(Activations(sigmoid=True)(p)) for p in decollate_batch(yhat)]
            dice_metric(y_pred=yhat, y=decollate_batch(y))  # per-volume lists on both sides

    val_dice = dice_metric.aggregate().item()
    dice_metric.reset()

    if USE_WANDB:
        wandb.log({
            "epoch/step": epoch,
            "epoch/train_loss": mean_train_loss,
            "epoch/val_dice":   val_dice,
            "epoch/lr":         scheduler.get_last_lr()[0],
        })

    print(f"epoch {epoch:03d} | train_loss {mean_train_loss:.4f} | val_dice {val_dice:.4f}")

    # save best
    if val_dice > best_dice:
        best_dice = val_dice
        torch.save(model.state_dict(), best_ckpt)
        if USE_WANDB:
            wandb.save(best_ckpt)

print("Best val Dice:", best_dice, " | saved:", best_ckpt)
gc.collect()


`torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.
Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:835.)
Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:306.)

epoch 000 | train_loss 0.9958 | val_dice 0.0214
epoch 001 | train_loss 0.9953 | val_dice 0.0271
epoch 002 | train_loss 0.9980 | val_dice 0.0294
epoch 003 | train_loss 0.9948 | val_dice 0.0298
epoch 004 | train_loss 0.9944 | val_dice 0.0221
epoch 005 | train_loss 0.9979 | val_dice 0.0324
epoch 006 | train_loss 0.9972 | val_dice 0.0277
epoch 007 | train_loss 0.9970 | val_dice 0.0150
epoch 008 | train_loss 0.9957 | val_dice 0.0162
epoch 009 | train_loss 0.9992 | val_dice 0.0171
epoch 010 | train_loss 0.9943 | val_dice 0.0259
epoch 011 | train_loss 0.9978 | val_dice 0.0250
epoch 012 | train_loss 0.9923 | val_dice 0.0315
epoch 013 | train_loss 0.9899 | val_dice 0.0347
epoch 014 | train_loss 0.9955 | val_dice 0.0371
epoch 015 | train_loss 0.9918 | val_dice 0.0411
epoch 016 | train_loss 0.9901 | val_dice 0.0239
epoch 017 | train_loss 0.9890 | val_dice 0.0373
epoch 018 | train_loss 0.9919 | val_dice 0.0565
epoch 019 | train_loss 0.9903 | val_dice 0.0375
epoch 020 | train_loss 0.9912 | val_dice 0.0047
epoch 021 | train_loss 0.9972 | val_dice 0.0012
epoch 022 | train_loss 0.9996 | val_dice 0.0014
epoch 023 | train_loss 0.9949 | val_dice 0.0053
epoch 024 | train_loss 0.9954 | val_dice 0.0165
epoch 025 | train_loss 0.9987 | val_dice 0.0311
epoch 026 | train_loss 0.9928 | val_dice 0.0364
epoch 027 | train_loss 0.9926 | val_dice 0.0223
epoch 028 | train_loss 0.9934 | val_dice 0.0105
epoch 029 | train_loss 0.9994 | val_dice 0.0085
epoch 030 | train_loss 0.9953 | val_dice 0.0077
epoch 031 | train_loss 0.9944 | val_dice 0.0064
epoch 032 | train_loss 0.9872 | val_dice 0.0034
epoch 033 | train_loss 0.9936 | val_dice 0.0042
epoch 034 | train_loss 0.9938 | val_dice 0.0041
epoch 035 | train_loss 0.9890 | val_dice 0.0064
epoch 036 | train_loss 0.9907 | val_dice 0.0064
epoch 037 | train_loss 0.9921 | val_dice 0.0055
epoch 038 | train_loss 0.9917 | val_dice 0.0049
epoch 039 | train_loss 0.9978 | val_dice 0.0048
epoch 040 | train_loss 0.9970 | val_dice 0.0049
epoch 041 | train_loss 0.9900 | val_dice 0.0059
epoch 042 | train_loss 0.9903 | val_dice 0.0070
epoch 043 | train_loss 0.9878 | val_dice 0.0070
epoch 044 | train_loss 0.9953 | val_dice 0.0069
epoch 045 | train_loss 0.9996 | val_dice 0.0069
epoch 046 | train_loss 0.9943 | val_dice 0.0069
epoch 047 | train_loss 0.9945 | val_dice 0.0068
epoch 048 | train_loss 0.9918 | val_dice 0.0068
epoch 049 | train_loss 0.9901 | val_dice 0.0068
Best val Dice: 0.05651044845581055  | saved: ./checkpoints/best_fold3.pth


4334
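A `train_loss` that stays near 0.99 while `val_dice` stays below 0.06 means the predicted foreground and the label barely overlap. The sketch below computes a soft Dice loss on sigmoid probabilities and a hard Dice score on thresholded masks for a tiny 1-D "volume". It approximately follows the formula in MONAI's `DiceLoss` documentation, `1 - (2·Σpg + smooth_nr) / (Σp + Σg + smooth_dr)` with default smoothing of `1e-5`, but it is an illustration, not MONAI's implementation. `soft_dice_loss` and `hard_dice` are hypothetical helpers.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def soft_dice_loss(logits, target, smooth_nr=1e-5, smooth_dr=1e-5):
    """1 - (2*sum(p*g) + smooth_nr) / (sum(p) + sum(g) + smooth_dr), with p = sigmoid(logits)."""
    p = [sigmoid(z) for z in logits]
    intersection = sum(pi * gi for pi, gi in zip(p, target))
    denominator = sum(p) + sum(target)
    return 1.0 - (2.0 * intersection + smooth_nr) / (denominator + smooth_dr)


def hard_dice(logits, target, threshold=0.5):
    """2|A∩B| / (|A| + |B|) on binary masks obtained by thresholding sigmoid(logits)."""
    pred = [1 if sigmoid(z) > threshold else 0 for z in logits]
    inter = sum(a * b for a, b in zip(pred, target))
    total = sum(pred) + sum(target)
    return 2.0 * inter / total if total else 1.0  # both empty: a convention chosen for this sketch


# Toy "volume" of 10 voxels containing a 2-voxel lesion.
target = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
good = [-8] * 8 + [8, 8]   # confident and correct
bad = [8, 8] + [-8] * 8    # confident, but on the wrong voxels
timid = [-8] * 10          # predicts (almost) nothing

for name, logits in [("good", good), ("bad", bad), ("timid", timid)]:
    print(f"{name:5s} loss={soft_dice_loss(logits, target):.4f} dice={hard_dice(logits, target):.2f}")
```

Both the "bad" and the "timid" predictions give a loss above 0.99. A loss near 0.99 therefore does not say which failure mode the model is in. Inspecting the predicted mask volumes, and confirming that the labels really contain only 0 and 1, helps tell them apart.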

In [39]:
# ==== Save VAL predictions as NIfTI using the post-transform affine ====
# The network sees volumes reoriented to RAS (Orientationd) and resampled to REF_KEY, and MONAI
# arrays are indexed (x, y, z). Copying spacing/origin/direction from the raw file on disk and
# treating the array as (z, y, x) can misplace the mask, so this uses the affine carried by the
# input MetaTensor instead. nibabel shares MONAI's RAS, (x, y, z) convention, so no transpose is needed.
import os, torch, numpy as np
import nibabel as nib  # installed via monai[nibabel]
from tqdm.auto import tqdm
from monai.inferers import sliding_window_inference

assert 'model' in globals() and 'device' in globals(), "Train the model first."
assert 'val_loader' in globals() and 'val_items' in globals(), "Build loaders first."
assert 'best_ckpt' in globals(), "best_ckpt missing (train cell defines it)."

# config / paths
PRED_DIR  = globals().get("PRED_DIR", "./predictions")
ROI_INFER = globals().get("ROI_INFER", (128,128,96))
os.makedirs(PRED_DIR, exist_ok=True)

# load best weights
model.load_state_dict(torch.load(best_ckpt, map_location=device))
model.eval()

def to_uint8_mask(t):
    """(B,1,X,Y,Z) probabilities -> (X,Y,Z) uint8 numpy mask of the first batch item."""
    return (t[0, 0] > 0.5).detach().cpu().numpy().astype(np.uint8)

with torch.no_grad():
    for i, batch in enumerate(tqdm(val_loader, desc="Saving VAL predictions")):
        x = batch["image"].to(device)                                  # (B,C,X,Y,Z) MetaTensor
        logits = sliding_window_inference(x, roi_size=ROI_INFER, sw_batch_size=1, predictor=model, overlap=0.5)
        pred_mask = to_uint8_mask(torch.sigmoid(logits))               # (X,Y,Z)

        # pick output name
        case_id = batch.get("case_id", val_items[i]["case_id"])
        if isinstance(case_id, (list, tuple)) and len(case_id): case_id = case_id[0]
        out_path = os.path.join(PRED_DIR, f"{case_id}.nii.gz")

        # voxel -> RAS world affine after Orientationd/ResampleToMatchd; batched as (B,4,4).
        # Assumes batch["image"] is a MetaTensor (MONAI's default when metadata tracking is on).
        affine = batch["image"].affine.cpu().numpy()
        affine = affine[0] if affine.ndim == 3 else affine
        nib.save(nib.Nifti1Image(pred_mask, affine.astype(np.float64)), out_path)

# show what we wrote
!ls -lh "$PRED_DIR" | sed -n '1,120p'




total 100K
-rw-r--r-- 1 root root 97K Sep 25 02:24 case_08.nii.gz
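MONAI's `MetaTensor.affine` maps voxel indices, in MONAI's (x, y, z) array order, to RAS world coordinates. SimpleITK instead describes geometry in LPS as separate spacing, origin, and direction, and `sitk.GetImageFromArray` expects arrays in (z, y, x) order. Writing masks with SimpleITK therefore needs the affine converted and the array transposed first. Below is a minimal numpy sketch of the conversion, assuming a 4×4 RAS affine with no shear. `ras_affine_to_itk` is a hypothetical helper.

```python
import numpy as np


def ras_affine_to_itk(affine_ras):
    """Split a 4x4 RAS voxel-to-world affine into ITK-style (spacing, origin, direction) in LPS.

    Assumes the 3x3 block is (rotation @ diag(spacing)), i.e. no shear.
    """
    affine_ras = np.asarray(affine_ras, dtype=np.float64)
    lps = np.diag([-1.0, -1.0, 1.0, 1.0]) @ affine_ras  # flip world x and y: RAS -> LPS
    linear = lps[:3, :3]
    spacing = np.linalg.norm(linear, axis=0)            # column norms = voxel size along each index axis
    direction = linear / spacing                        # normalise each column to a unit direction cosine
    origin = lps[:3, 3]                                 # world position of voxel (0, 0, 0)
    # SimpleITK's SetDirection takes the 3x3 matrix flattened in row-major order.
    return (tuple(float(s) for s in spacing),
            tuple(float(o) for o in origin),
            tuple(float(d) for d in direction.flatten()))


# Example: 1 x 1 x 2 mm voxels whose first two index axes point Left and Posterior (seen in RAS).
affine = [[-1.0, 0.0, 0.0, 100.0],
          [0.0, -1.0, 0.0, 120.0],
          [0.0, 0.0, 2.0, -50.0],
          [0.0, 0.0, 0.0, 1.0]]
spacing, origin, direction = ras_affine_to_itk(affine)
print(spacing, origin, direction)
```

With SimpleITK, the mask would then be built from `pred_mask.transpose(2, 1, 0)` and given these values through `SetSpacing`, `SetOrigin`, and `SetDirection`.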


In [40]:
from google.colab import drive
drive.mount('/content/drive')

# Copy predictions to Drive
!mkdir -p "/content/drive/MyDrive/BrainPredictions"
!rsync -ah --info=progress2 "$PRED_DIR/" "/content/drive/MyDrive/BrainPredictions/"

# Verify
!ls -lh "/content/drive/MyDrive/BrainPredictions" | sed -n '1,120p'


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
         98.62K 100%    7.85MB/s    0:00:00 (xfr#1, to-chk=0/2)
total 97K
-rw------- 1 root root 97K Sep 25 02:24 case_08.nii.gz


In [34]:
if USE_WANDB:
    table = wandb.Table(columns=[
        "Case","SliceIdx", *[f"Img-Ch{c}" for c in range(len(MODALITIES))], "GT-Label", "Pred-Label"
    ])
    ds = Dataset(val_items, transform=val_transform)
    with torch.no_grad():
        for i in range(len(ds)):
            s = ds[i]
            img = s["image"].cpu().numpy()  # (C,D,H,W)
            lab = s["label"].cpu().numpy()  # (D,H,W) (or (1,D,H,W))
            if lab.ndim == 4: lab = lab[0]
            x = s["image"].unsqueeze(0).to(device)
            yhat = infer_sw(x)[0]
            yhat = AsDiscrete(threshold=0.5)(Activations(sigmoid=True)(yhat)).cpu().numpy()[0]  # (D,H,W)

            C, D, H, W = img.shape
            for z in range(0, D, max(1, D//16)):
                imgs = [wandb.Image(img[c, z, :, :]) for c in range(C)]
                gt   = wandb.Image(lab[z, :, :])
                pr   = wandb.Image(yhat[z, :, :])
                table.add_data(val_items[i]["case_id"], int(z), *imgs, gt, pr)
    wandb.log({"Predictions/LesionSeg": table})

if USE_WANDB:
    wandb.finish()


`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.


Run history:
epoch/lr          ███████▇▇▇▇▇▇▆▆▆▅▅▅▅▄▄▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁
epoch/step        ▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch/train_loss  ▆▆▇▅▅▇▇▆█▅▄▃▆▄▃▄▃▃▇█▆▇▄▄▄▆▅▁▅▅▃▄▄▇▇▁▆▅▅▃
epoch/val_dice    ▅▆▆▆▅▆▃▄▄▅▆▇▇█▇▇▂▁▁▂▆▇▅▃▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂

Run summary:
epoch/lr          0.0
epoch/step        49.0
epoch/train_loss  0.99008
epoch/val_dice    0.00681
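The `epoch/lr` curve and its final value of 0.0 come from `CosineAnnealingLR(optimizer, T_max=MAX_EPOCHS)` with the default `eta_min=0`. PyTorch documents the closed form $\eta_t = \eta_{min} + \tfrac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos(\pi t / T_{max})\right)$. The training loop logs `scheduler.get_last_lr()` after `scheduler.step()`, so the value logged at epoch `e` corresponds to `t = e + 1` and reaches `eta_min` at the last epoch. Below is a pure-Python sketch of that closed form, assuming the default `LR = 1e-4` and `MAX_EPOCHS = 50`. It is not PyTorch's recursive implementation.

```python
import math


def cosine_lr(t, base_lr=1e-4, t_max=50, eta_min=0.0):
    """Closed-form cosine annealing: eta_min + 0.5*(base_lr - eta_min)*(1 + cos(pi*t/t_max))."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * t / t_max))


# Values the training loop would log: scheduler.step() runs before get_last_lr(), so epoch e -> t = e + 1.
logged = [cosine_lr(epoch + 1) for epoch in range(50)]
print(f"epoch 0: {logged[0]:.3e}  epoch 24: {logged[24]:.3e}  epoch 49: {logged[49]:.1e}")
```

The learning rate is already near zero over the last several epochs. This fits the flat `val_dice` tail in the log above.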
