# Dataset

### Import libraries

In [None]:
!pip install monai

In [None]:
import os
import numpy as np
import torch
import matplotlib.pyplot as plt

from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, EnsureTyped,
    Orientationd, Spacingd, NormalizeIntensityd, ScaleIntensityRanged,
    CropForegroundd, RandCropByPosNegLabeld,
    RandFlipd, RandRotated, RandZoomd, RandGaussianNoised,
    AsDiscreted, SpatialPadd
)
from monai.data import CacheDataset, DataLoader, load_decathlon_datalist
from monai.utils import set_determinism
from monai.apps import DecathlonDataset

### Get DecathlonDataset

**class monai.apps.DecathlonDataset**

- task – which task to download and load; one of `Task01_BrainTumour`, `Task02_Heart`, `Task03_Liver`, `Task04_Hippocampus`, `Task05_Prostate`, `Task06_Lung`, `Task07_Pancreas`, `Task08_HepaticVessel`, `Task09_Spleen`, `Task10_Colon`.

- transform – transforms to execute operations on input data.

- section – which data split to return: `training`, `validation`, or `test`.

In [None]:
# Download and extract Task05_Prostate under /content/dataset.
# DecathlonDataset raises ValueError if root_dir is not an existing directory, so
# create it first (a fresh runtime does not have it).
# Only the download side effect matters here: the pipeline below re-reads the file
# list from dataset.json and builds its own transforms.
os.makedirs("/content/dataset", exist_ok=True)
train_data = DecathlonDataset(
    root_dir="/content/dataset",
    task="Task05_Prostate",
    transform=None,
    section="training",
    download=True,
)

### Load Data

**Separate train/val transforms.** Add augmentation (flip/rotate/zoom/noise) only to train.

In [None]:
set_determinism(seed=42)  # seed the RNGs so augmentation and sampling are reproducible

# Deterministic preprocessing shared by the train and validation pipelines.
IMG_LBL = ["image", "label"]
common = [
    # Read each volume and its metadata from disk.
    LoadImaged(keys=IMG_LBL),
    # Guarantee a leading channel axis (channel-first layout).
    EnsureChannelFirstd(keys=IMG_LBL),
    # Reorient both volumes to the LPS (left-posterior-superior) axis convention.
    Orientationd(keys=IMG_LBL, axcodes="LPS"),
    # Resample to 1.0 x 1.0 x 3.0 voxel spacing: interpolate the image, nearest-neighbour
    # for the label so class indices are never blended.
    Spacingd(keys=IMG_LBL, pixdim=(1.0, 1.0, 3.0), mode=("bilinear", "nearest")),
    # Z-score each channel independently, computing mean/std over non-zero voxels only.
    NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
    # Crop image and label to the bounding box of the image foreground.
    CropForegroundd(keys=IMG_LBL, source_key="image"),
    # Make sure both entries come out as tensors.
    EnsureTyped(keys=IMG_LBL),
]

In [None]:
# Training pipeline: shared preprocessing, then patch sampling, then random 3D augmentation.
train_transforms = Compose(
    common
    + [
        # Pad each volume up to at least 192 x 192 x 32 (larger dimensions are left as is).
        SpatialPadd(keys=["image", "label"], spatial_size=(192, 192, 32)),
        # Sample one 96 x 96 x 32 patch per volume. Its centre is a foreground voxel with
        # probability pos / (pos + neg) = 0.5, otherwise a background voxel, which keeps
        # the classes balanced.
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            image_key="image",
            spatial_size=(96, 96, 32),
            pos=1,
            neg=1,
            num_samples=1,
        ),
        # Independent random flips along spatial axes 0, 1 and 2 (p = 0.5 each).
        *[RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=axis) for axis in (0, 1, 2)],
        # Small random rotation (up to +/-0.2 rad per axis, p = 0.2). On a 3D volume
        # "bilinear" interpolates over the 8 neighbouring voxels (i.e. trilinear);
        # "nearest" keeps label values integral.
        RandRotated(
            keys=["image", "label"],
            prob=0.2,
            range_x=0.2,
            range_y=0.2,
            range_z=0.2,
            mode=("bilinear", "nearest"),
        ),
        # Random zoom factor in [0.9, 1.1] (p = 0.2); trilinear for image, nearest for label.
        RandZoomd(
            keys=["image", "label"],
            prob=0.2,
            min_zoom=0.9,
            max_zoom=1.1,
            mode=("trilinear", "nearest"),
        ),
        # Additive Gaussian noise on the image only (p = 0.15).
        RandGaussianNoised(keys=["image"], prob=0.15, mean=0.0, std=0.05),
    ]
)

# Validation pipeline: only the shared deterministic preprocessing (no cropping or
# augmentation; full volumes are meant for sliding-window inference).
# Labels are deliberately left as single-channel class indices, matching the training
# pipeline. The loss (DiceCELoss(to_onehot_y=True)) and the metric post-processing
# (AsDiscrete(to_onehot=n_classes)) perform the one-hot conversion themselves.
# One-hot encoding here as well would hand them a 3-channel label they cannot convert.
val_transforms = Compose(common)

In [None]:
# Read the image/label path pairs listed under "training" in the Decathlon dataset.json.
data_root = "/content/dataset"
datalist = load_decathlon_datalist(
    os.path.join(data_root, "Task05_Prostate", "dataset.json"),
    data_list_key="training",
)

# Hold out the last 10% of the (unshuffled) list for validation.
val_fraction = 0.1
split = int(len(datalist) * (1.0 - val_fraction))
train_files = datalist[:split]
val_files = datalist[split:]

# CacheDataset runs the deterministic transforms once and keeps their results in memory.
# Transforms from the first random one onwards (cropping/augmentation) are still
# re-run on every access, so each epoch sees new patches.
train_ds = CacheDataset(train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
val_ds = CacheDataset(val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)

# Batch the datasets; only the training loader shuffles.
loader_opts = {"num_workers": 4, "pin_memory": True}
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, **loader_opts)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, **loader_opts)

## Visualization

In [None]:
# Shape of the first validation image after the deterministic preprocessing (channel-first).
val_loader.dataset[0]['image'].shape

In [None]:
# Shape of one random training image patch. dataset[0] is the list produced by
# RandCropByPosNegLabeld (num_samples=1), so [0] selects its single dict.
# Each access re-runs the random transforms and yields a new patch.
train_loader.dataset[0][0]['image'].shape

In [None]:
# Shape of the label of a (freshly sampled) random training patch; same indexing as above.
train_loader.dataset[0][0]['label'].shape

In [None]:
# Plot slice 10 of one training patch: first image channel next to its label.
# The sample is fetched ONCE. Every dataset access re-runs the random crop/flip/
# rotate/zoom transforms, so separate accesses for image and label would show two
# different patches.
sample = train_loader.dataset[0][0]  # [0]: single dict from the num_samples=1 crop list
plt.figure(figsize=(6, 3))
plt.subplot(121)
plt.imshow(sample["image"][0, :, :, 10])
plt.title("Image")
plt.subplot(122)
plt.imshow(sample["label"][0, :, :, 10])
plt.title("Label")
plt.show()

In [None]:
# Extract slice 10 (first channel) of image and label from the SAME random training
# patch. Each dataset access re-runs the random transforms, so two separate accesses
# would not line up.
patch = train_loader.dataset[0][0]
img = patch["image"][0, :, :, 10].numpy()
label = patch["label"][0, :, :, 10].numpy()

In [None]:
# Shape of the 2D slice extracted above (first channel, index 10 along the last axis).
img.shape

In [None]:
# Side-by-side view of the extracted slice. The image is MRI intensity that was z-score
# normalised by NormalizeIntensityd (not CT Hounsfield units). The label is an integer
# class map (0 = background, 1 = PZ, 2 = TZ).
plt.figure(figsize=(10, 5))
plt.subplot(121)
plt.pcolormesh(img.T, cmap="Greys_r")
plt.colorbar(label="Normalized intensity")
plt.axis("off")
plt.subplot(122)
plt.pcolormesh(label.T, cmap="Greys_r")
plt.colorbar(label="Class index")
plt.axis("off")
plt.show()

# Model

MONAI 3D SwinUNETR (Swin Transformer encoder with a U-Net-style decoder)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
from monai.networks.nets import SwinUNETR
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
from monai.data.utils import decollate_batch
from monai.transforms import AsDiscrete

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Segmentation classes, from the dataset's "labels" map:
#   0 = background, 1 = PZ, 2 = TZ
n_classes = 3

# SwinUNETR hyper-parameters. Feature size, depths and heads stay at the values of the
# pre-trained Swin ViT checkpoint that is loaded further below.
swin_kwargs = dict(
    spatial_dims=3,
    in_channels=2,  # NOTE(review): presumably the two Task05 MRI channels - confirm in dataset.json
    out_channels=n_classes,
    feature_size=48,
    depths=(2, 2, 2, 2),
    num_heads=(3, 6, 12, 24),
    window_size=(7, 7, 7),  # local self-attention window
    mlp_ratio=4.0,
    qkv_bias=True,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dropout_path_rate=0.0,
    norm_name="instance",
    use_checkpoint=True,  # gradient checkpointing: recompute activations in backward to save memory
)
model = SwinUNETR(**swin_kwargs).to(device)

In [None]:
# Display the network architecture (module repr).
model

Load pre-trained Swin ViT weights

In [None]:
%cd model
!wget https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/model_swinvit.pt

In [None]:
# Load the checkpoint tensors onto the CPU, so this works whether or not a GPU is
# present, then copy them into the model with SwinUNETR.load_from.
weight = torch.load("./model_swinvit.pt", map_location="cpu", weights_only=True)
model.load_from(weights=weight)

Loss, Optimizer, Metrics

In [None]:
# Dice + cross-entropy on raw logits: the loss applies softmax to the network output
# and one-hot encodes the integer labels itself.
loss_fn = DiceCELoss(softmax=True, to_onehot_y=True)
optimizer = optim.AdamW(params=model.parameters(), weight_decay=1e-5, lr=1e-4)
# NOTE(review): T_max=200, but the training cell runs epoch_num = 20 and steps the
# scheduler once per epoch, so the learning rate barely decays. Set T_max to the
# actual number of epochs if a full cosine schedule is intended.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=200)

# Mean Dice over foreground classes only (background excluded).
dice_metric = DiceMetric(reduction="mean", include_background=False)
# Metric post-processing: predictions -> argmax -> one-hot; labels -> one-hot.
post_pred, post_label = AsDiscrete(argmax=True, to_onehot=n_classes), AsDiscrete(to_onehot=n_classes)

# Training Loop

In [None]:
import gc

# Free GPU memory held by a previous (e.g. interrupted or out-of-memory) training run.
# Only the transient per-batch tensors are dropped. The model, optimizer and data
# loaders must survive because the training cell below uses them.
# globals().pop(name, None) makes this safe to run before any batch has been
# processed, when these names do not exist yet and a plain `del` raises NameError.
for _name in ("images", "labels", "outputs", "loss"):
    globals().pop(_name, None)
del _name

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()


In [None]:
epoch_num = 20

# Mixed precision only on GPU. fp16 autocast needs a GradScaler so that small
# gradients do not underflow; with enabled=False the scaler is a pass-through.
use_amp = device.type == "cuda"
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

for epoch in range(epoch_num):
    print(f"Epoch {epoch+1}/{epoch_num}")
    model.train()
    running_loss = 0.0

    for batch in train_loader:
        images, labels = (batch["image"].to(device), batch["label"].to(device))

        # Reset gradients from the previous step; without this they accumulate
        # across iterations and every update uses the running sum.
        optimizer.zero_grad(set_to_none=True)

        with torch.autocast(device_type=device.type, enabled=use_amp):
            outputs = model(images)
            loss = loss_fn(outputs, labels)

        scaler.scale(loss).backward()
        # Unscale first so the clipping threshold applies to the true gradient norm.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), 12.0)
        scaler.step(optimizer)  # skips the step if inf/NaN gradients were found
        scaler.update()
        running_loss += loss.item()

    scheduler.step()
    print(f"  Train Loss: {running_loss/len(train_loader):.4f}")

In [None]:
# Persist only the learned parameters (the state_dict), not the module object itself.
PATH = "my_model.pth"  # relative to the current working directory
state = model.state_dict()
torch.save(state, f=PATH)

# Test & Visualization (sanity check on a training batch)

In [None]:
# Run the trained model on the first training batch for a qualitative sanity check.
# eval() switches the network to inference mode (training above left it in train mode).
# torch.autocast with the current device type replaces the deprecated, CUDA-only
# torch.cuda.amp.autocast and is disabled on CPU so the cell also runs without a GPU.
model.eval()
batch = next(iter(train_loader))  # just the first batch
with torch.no_grad(), torch.autocast(device_type=device.type, enabled=device.type == "cuda"):
    img, label = batch["image"].to(device), batch["label"].to(device)
    pred = model(img)  # [B, C, H, W, D] logits
    pred_soft = torch.softmax(pred, dim=1)  # per-voxel class probabilities
    pred_mask = torch.argmax(pred_soft, dim=1)  # [B, H, W, D] class indices

In [None]:
# Bring the first sample of the batch back to host memory as NumPy arrays.
img_np = img.cpu().numpy()[0, 0]  # first input channel, spatial H x W x D
label_np = label.cpu().numpy()[0, 0]  # ground-truth class indices
pred_np = pred_mask.cpu().numpy()[0]  # predicted class indices

# Display the middle slice along the last spatial axis.
slice_idx = img_np.shape[-1] // 2
img_slice = img_np[:, :, slice_idx]

fig, (ax_img, ax_gt, ax_pred) = plt.subplots(1, 3, figsize=(12, 4))

ax_img.imshow(img_slice, cmap="gray")
ax_img.set_title("Image")

ax_gt.imshow(label_np[:, :, slice_idx], cmap="gray")
ax_gt.set_title("Ground Truth")

# Prediction drawn semi-transparently on top of the same image slice.
ax_pred.imshow(img_slice, cmap="gray")
ax_pred.imshow(pred_np[:, :, slice_idx], cmap="jet", alpha=0.5)
ax_pred.set_title("Prediction Overlay")

plt.show()

# Useful Links

[swin_unetr_btcv_segmentation_3d](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/swin_unetr_btcv_segmentation_3d.ipynb)

[Prostate_MRI_Anatomy_Model](https://github.com/Project-MONAI/tutorials/blob/main/model_zoo/TCIA_PROSTATEx_Prostate_MRI_Anatomy_Model.ipynb)

[MONAI Applications](https://docs.monai.io/en/stable/apps.html)