## Setup imports

In [1]:
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    RandRotated,
    RandZoomd,
    RandGaussianNoised,
    RandAdjustContrastd,
    Spacingd,
    Invertd,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
import torch
import matplotlib.pyplot as plt
import shutil
import os
import glob

## Set train/validation/test data file paths

In [2]:
data_dir = "./dataset"


def make_data_dicts(split, root=None):
    """Collect and pair the image / mask NIfTI files of one dataset split.

    Looks for ``<root>/<split>/image/*.nii.gz`` and ``<root>/<split>/mask/*.nii.gz``
    and pairs them by sorted filename order.

    Args:
        split: sub-directory name, e.g. "train", "val" or "test".
        root: dataset root directory; defaults to the notebook-level ``data_dir``.

    Returns:
        (images, labels, dicts): the sorted image paths, the sorted mask paths, and
        a list of ``{"image": ..., "label": ...}`` dicts in the format expected by
        the MONAI dictionary transforms.

    Raises:
        ValueError: if the number of images and masks differ. ``zip`` would
            otherwise silently drop the surplus files, and the sorted pairing of
            every later image/mask would be shifted.
    """
    root = data_dir if root is None else root
    images = sorted(glob.glob(os.path.join(root, split, "image", "*.nii.gz")))
    labels = sorted(glob.glob(os.path.join(root, split, "mask", "*.nii.gz")))
    if len(images) != len(labels):
        raise ValueError(
            f"{split}: found {len(images)} images but {len(labels)} masks under {root!r}"
        )
    dicts = [{"image": image_name, "label": label_name} for image_name, label_name in zip(images, labels)]
    return images, labels, dicts


train_images, train_labels, train_data_dicts = make_data_dicts("train")
val_images, val_labels, val_data_dicts = make_data_dicts("val")
test_images, test_labels, test_data_dicts = make_data_dicts("test")

## Setup data augmentation

The preprocessing / data-augmentation pipeline consists of the following basic steps:

1. `LoadImaged` loads the spleen CT images and labels from NIfTI files.
1. `EnsureChannelFirstd` reshapes the loaded data into a "channel first" layout.
1. `ScaleIntensityRanged` clips the CT intensities (Hounsfield units, HU) to the range [-57, 164] and rescales them to [0, 1].
1. `CropForegroundd` crops away the all-zero borders so that images and labels focus on the valid body region.
1. `RandCropByPosNegLabeld` randomly crops patches from the large volume, sampling foreground-centred (positive) and background-centred (negative) patches according to the pos/neg ratio.  
The centres of negative samples are restricted to the valid body region of the image.

In addition, the training pipeline applies random rotation, zoom, Gaussian noise and contrast adjustment, and `Orientationd` reorients the volumes to RAS. You can try more data augmentation techniques to further improve the performance.

In [3]:
# Spatial augmentations resample BOTH image and label. Images use a smooth
# interpolation, but labels must use nearest-neighbour so the mask stays binary.
# The transform defaults (bilinear for RandRotated, area for RandZoomd) would
# otherwise produce fractional label values at the organ boundary.
train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    # Reorient first, so every later transform (including the patch cropping)
    # works in the same RAS axis order as the validation/test pipeline.
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    # Clip HU to [-57, 164] and rescale to [0, 1].
    ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
    CropForegroundd(keys=["image", "label"], source_key="image"),
    RandRotated(
        keys=["image", "label"],
        prob=0.5,
        range_x=0.3,
        range_y=0.3,
        range_z=0.3,
        mode=("bilinear", "nearest"),
    ),
    RandZoomd(
        keys=["image", "label"],
        prob=0.5,
        min_zoom=0.8,
        max_zoom=1.2,
        mode=("area", "nearest"),
    ),
    # Intensity-only augmentations: applied to the image, never to the label.
    RandGaussianNoised(keys="image", prob=0.2, mean=0.0, std=0.05),
    RandAdjustContrastd(keys="image", prob=0.3, gamma=(0.7, 1.3)),
    # 4 patches of 32x32x16 per volume, 1:1 positive (label-centred) to negative.
    RandCropByPosNegLabeld(
        keys=["image", "label"],
        label_key="label",
        spatial_size=(32, 32, 16),
        pos=1,
        neg=1,
        num_samples=4,
        image_key="image",
    ),
])

# Deterministic pipeline used for validation and test (also inverted by Invertd
# when predictions are saved). Same orientation/intensity/crop steps as training.
val_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
    CropForegroundd(keys=["image", "label"], source_key="image"),
])




In [4]:
import nibabel as nib
# Use MONAI's DataLoader, NOT torch.utils.data.DataLoader. The previous
# `from torch.utils.data import Dataset, DataLoader` shadowed the MONAI
# DataLoader imported in the first cell. MONAI's version manages the random
# state of MONAI randomizable transforms per worker process. With the plain
# torch loader and num_workers > 0, forked workers can start from identical
# augmentation RNG state.
from monai.data import CacheDataset, DataLoader, list_data_collate

# Half of the train/val volumes are cached (cache_rate=0.5) to bound memory;
# the test set uses CacheDataset's default cache_rate (full caching).
train_ds = CacheDataset(data=train_data_dicts, transform=train_transforms, cache_rate=0.5, num_workers=4)
# list_data_collate flattens the num_samples patches produced per volume by
# RandCropByPosNegLabeld into the batch dimension.
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4, collate_fn=list_data_collate)

val_ds = CacheDataset(data=val_data_dicts, transform=val_transforms, cache_rate=0.5, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4, collate_fn=list_data_collate)

test_ds = CacheDataset(data=test_data_dicts, transform=val_transforms)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=4, collate_fn=list_data_collate)


Loading dataset: 100%|██████████| 12/12 [00:05<00:00,  2.19it/s]
Loading dataset: 100%|██████████| 4/4 [00:01<00:00,  2.35it/s]
Loading dataset: 100%|██████████| 8/8 [00:06<00:00,  1.15it/s]


# Implement a 3D UNet for segmentation task

We give a possible network structure here, and you can modify it for a stronger performance.

In the block ```double_conv```, you can implement the following structure:

| Layer |
|-------|
| Conv3d |
| BatchNorm3d |
| PReLU |
| Conv3d |
| BatchNorm3d |
| PReLU |


In the overall UNet structure, you can implement the following structure. ```conv_down``` and ```conv_up``` refer to the function block you defined above.

| Layer | Input Channel | Output Channel |
|-------|-------------|--------------|
| conv_down1 | 1 | 16 |
| maxpool | 16 | 16 |
| conv_down2 | 16 | 32 |
| maxpool | 32 | 32 |
| conv_down3 | 32 | 64 |
| maxpool | 64 | 64 |
| conv_down4 | 64 | 128 |
| maxpool | 128 | 128 |
| conv_down5 | 128 | 256 |
| upsample | 256 | 256 |
| conv_up4 | 128+256 | 128 |
| upsample | 128 | 128 |
| conv_up3 | 64+128 | 64 |
| upsample | 64 | 64 |
| conv_up2 | 32+64 | 32 |
| upsample | 32 | 32 |
| conv_up1 | 16+32 | 16 |
| conv_out | 16 | 2 |


In [5]:
import torch
import torch.nn as nn

def double_conv(in_channels, out_channels):
    """Build two stacked (Conv3d 3x3x3 -> BatchNorm3d -> PReLU) stages.

    padding=1 with a 3x3x3 kernel keeps the spatial size unchanged, so the block
    only maps ``in_channels`` to ``out_channels``.
    """
    return nn.Sequential(
        nn.Conv3d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm3d(out_channels),
        nn.PReLU(),
        nn.Conv3d(out_channels, out_channels, 3, padding=1),
        nn.BatchNorm3d(out_channels),
        nn.PReLU(),
    )


class UNet(nn.Module):
    """3D U-Net with four down/up-sampling levels (16-32-64-128-256 channels).

    Encoder: ``double_conv`` blocks separated by 2x max-pooling.
    Decoder: 2x ``ConvTranspose3d`` upsampling, concatenation with the matching
    encoder feature map (skip connection), then a ``double_conv`` block.
    A final 1x1x1 convolution produces per-voxel logits.

    Args:
        in_channels: number of input channels (default 1).
        out_channels: number of output channels / classes (default 2).

    Sub-module attribute names are part of the saved ``state_dict`` keys and are
    kept stable so existing ``best_model_*.pth`` checkpoints still load.

    Note: this class deliberately shadows ``monai.networks.nets.UNet``.
    """

    # Four 2x max-pools: every spatial dim must be divisible by 2**4, otherwise
    # the upsampled decoder maps no longer match the skip connections in torch.cat.
    _SIZE_MULTIPLE = 2 ** 4

    def __init__(self, in_channels=1, out_channels=2):
        super().__init__()

        # Encoder (module creation order is unchanged, so initialisation under a
        # fixed seed matches the original).
        self.dconv_down1 = double_conv(in_channels, 16)
        self.dconv_down2 = double_conv(16, 32)
        self.dconv_down3 = double_conv(32, 64)
        self.dconv_down4 = double_conv(64, 128)
        self.dconv_down5 = double_conv(128, 256)

        self.maxpool = nn.MaxPool3d(2)

        # Decoder: each upsample keeps its channel count; the following
        # double_conv receives (skip channels + upsampled channels).
        self.upsample4 = nn.ConvTranspose3d(256, 256, kernel_size=2, stride=2)
        self.dconv_up4 = double_conv(128 + 256, 128)

        self.upsample3 = nn.ConvTranspose3d(128, 128, kernel_size=2, stride=2)
        self.dconv_up3 = double_conv(64 + 128, 64)

        self.upsample2 = nn.ConvTranspose3d(64, 64, kernel_size=2, stride=2)
        self.dconv_up2 = double_conv(32 + 64, 32)

        self.upsample1 = nn.ConvTranspose3d(32, 32, kernel_size=2, stride=2)
        self.dconv_up1 = double_conv(16 + 32, 16)

        self.conv_last = nn.Conv3d(16, out_channels, 1)

    def forward(self, x):
        """Return logits of shape (N, out_channels, D, H, W) for input (N, in_channels, D, H, W).

        Raises:
            ValueError: if any spatial dimension of ``x`` is not divisible by 16.
        """
        spatial = tuple(x.shape[2:])
        if any(size % self._SIZE_MULTIPLE for size in spatial):
            raise ValueError(
                f"UNet expects spatial dims divisible by {self._SIZE_MULTIPLE}, got {spatial}"
            )

        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)

        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)

        conv4 = self.dconv_down4(x)
        x = self.maxpool(conv4)

        x = self.dconv_down5(x)

        x = self.upsample4(x)
        x = torch.cat([x, conv4], dim=1)
        x = self.dconv_up4(x)

        x = self.upsample3(x)
        x = torch.cat([x, conv3], dim=1)
        x = self.dconv_up3(x)

        x = self.upsample2(x)
        x = torch.cat([x, conv2], dim=1)
        x = self.dconv_up2(x)

        x = self.upsample1(x)
        x = torch.cat([x, conv1], dim=1)
        x = self.dconv_up1(x)

        return self.conv_last(x)

## Create Model, Loss, Optimizer

In [6]:
import numpy as np

class _CELoss(torch.nn.CrossEntropyLoss):
    """Cross-entropy that accepts labels carrying a channel axis.

    The target's dimension 1 is squeezed away and the values are cast to int64
    class indices before delegating to ``torch.nn.CrossEntropyLoss``
    (presumably targets arrive as (B, 1, ...) label maps — TODO confirm).
    """

    def forward(self, pred, target):
        class_indices = target.squeeze(1).long()
        return super().forward(pred, class_indices)


class CombinedLoss(torch.nn.Module):
    """Weighted sum ``weights_dice * Dice + weights_ce * CrossEntropy``.

    The Dice term one-hot encodes the target, applies softmax to ``pred`` and
    ignores the background channel; the CE term is ``_CELoss``.

    Args:
        weights_dice: multiplier for the Dice loss term.
        weights_ce: multiplier for the cross-entropy term.
    """

    def __init__(self, weights_dice=0.5, weights_ce=0.5):
        super().__init__()
        self.dice = DiceLoss(to_onehot_y=True, softmax=True, include_background=False)
        self.ce = _CELoss()
        self.weights_dice = weights_dice
        self.weights_ce = weights_ce

    def forward(self, pred, target):
        dice_term = self.dice(pred, target) * self.weights_dice
        ce_term = self.ce(pred, target) * self.weights_ce
        return dice_term + ce_term

# Train on the GPU whenever CUDA is visible; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if device.type == "cuda":
    print("Using GPU")

# Network, Adam optimizer (lr 1e-3, L2 weight decay 1e-5) and the foreground-only
# mean Dice metric used by the training / validation loops.
model = UNet().to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=1e-5)
dice_metric = DiceMetric(reduction="mean", include_background=False)

# Loss configurations to compare. Each key also names the TensorBoard run and
# the checkpoint file best_model_<key>.pth written by the training loop.
# Earlier alternatives: plain DiceLoss(to_onehot_y=True, softmax=True,
# include_background=False), or _CELoss() on its own.
# NOTE(review): the keys say "300_epoch" while the training cell sets
# max_epochs = 200; the names are kept because they are checkpoint filenames.
loss_fns = {
    name: CombinedLoss(weights_dice=w_dice, weights_ce=w_ce)
    for name, w_dice, w_ce in (
        ("Combined_1_1_300_epoch", 1.0, 1.0),
        ("Combined_7_3_300_epoch", 0.7, 0.3),
        ("Combined_3_7_300_epoch", 0.3, 0.7),
    )
}


Using GPU


## Define your training/val/test loop

In [7]:
from monai.metrics import DiceMetric, HausdorffDistanceMetric, MeanIoU, SurfaceDistanceMetric
from monai.data import MetaTensor
from tqdm.notebook import tqdm
from monai.transforms import EnsureTyped
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

# Post-processing for the Dice metric: argmax + one-hot (2 classes) for network
# outputs, one-hot for labels. Built once at cell level; the test cell also
# refers to these names.
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])

# NOTE(review): the loss_fns keys end in "_300_epoch" but this trains for 200
# epochs; the keys are left unchanged because they name the saved checkpoints.
max_epochs = 200

for loss_name, loss_fn in loss_fns.items():
    print(f"Training with {loss_name} loss")

    # Start every loss configuration from a freshly initialised network and a
    # fresh optimizer. Reusing one global model/optimizer makes each run
    # continue from the previous run's weights and Adam state, so the losses
    # would not be compared fairly. Hyper-parameters match the model-creation cell.
    model = UNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3, weight_decay=1e-5)

    best_metric = -1
    best_metric_epoch = -1
    train_loss_values = []
    train_dice_values = []
    val_loss_values = []
    val_dice_values = []

    # The context manager flushes and closes the TensorBoard event file when the
    # run ends, including when it is interrupted by an exception.
    with SummaryWriter(f"runs/{loss_name}_{datetime.now().strftime('%Y%m%d-%H%M%S')}") as writer:
        for epoch in range(max_epochs):
            # ---- Training ----
            model.train()
            train_loss = 0.0
            dice_metric.reset()

            for batch_data in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{max_epochs}", unit="batch"):
                # Batch-local names: don't overwrite the train_images / train_labels
                # file lists created in the data cell.
                batch_inputs = batch_data["image"].to(device)
                batch_targets = batch_data["label"].to(device)
                optimizer.zero_grad()
                outputs = model(batch_inputs)

                loss = loss_fn(outputs, batch_targets)
                loss.backward()
                optimizer.step()

                # Detach: the metric bookkeeping must not keep the autograd graph alive.
                pred_list = [post_pred(i) for i in decollate_batch(outputs.detach())]
                target_list = [post_label(i) for i in decollate_batch(batch_targets)]
                dice_metric(y_pred=pred_list, y=target_list)
                train_loss += loss.item()

            train_loss /= len(train_loader)
            epoch_dice = dice_metric.aggregate().item()

            train_loss_values.append(train_loss)
            train_dice_values.append(epoch_dice)
            writer.add_scalar("Loss/train", train_loss, epoch)
            writer.add_scalar("Dice/train", epoch_dice, epoch)

            print(f"Epoch {epoch + 1}/{max_epochs}, Loss: {train_loss:.4f}, Dice: {epoch_dice:.4f}")

            # ---- Validation ----
            model.eval()
            val_loss = 0.0
            dice_metric.reset()

            with torch.no_grad():
                for val_data in tqdm(val_loader, desc="Validation", unit="batch"):
                    # Batch-local names: don't overwrite the val_images / val_labels
                    # file lists created in the data cell.
                    val_inputs = val_data["image"].to(device)
                    val_targets = val_data["label"].to(device)

                    # Whole-volume prediction by tiling 32x32x16 windows, 4 per forward pass.
                    val_outputs = sliding_window_inference(val_inputs, (32, 32, 16), 4, model)
                    loss = loss_fn(val_outputs, val_targets)

                    val_pred_list = [post_pred(i) for i in decollate_batch(val_outputs)]
                    val_target_list = [post_label(i) for i in decollate_batch(val_targets)]
                    dice_metric(y_pred=val_pred_list, y=val_target_list)
                    val_loss += loss.item()

            val_dice = dice_metric.aggregate().item()
            val_loss /= len(val_loader)
            val_loss_values.append(val_loss)
            val_dice_values.append(val_dice)
            writer.add_scalar("Loss/val", val_loss, epoch)
            writer.add_scalar("Dice/val", val_dice, epoch)

            # Keep the checkpoint with the best validation Dice for this loss.
            if val_dice > best_metric:
                best_metric = val_dice
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), f"best_model_{loss_name}.pth")

            print(f"Validation - Loss: {val_loss:.4f}, Dice: {val_dice:.4f}")

    print(f"Best validation Dice for {loss_name}: {best_metric:.4f} at epoch {best_metric_epoch}")

Training with Combined_1_1_300_epoch loss


Training Epoch 1/200:   0%|          | 0/13 [00:00<?, ?batch/s]

KeyboardInterrupt: 

## Inference and report performance on the test set

In [9]:
# Post-processing for the metrics (same definitions as in the training cell),
# repeated here so this cell also runs on a fresh kernel when only the saved
# checkpoints exist and the training cell was skipped.
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])

# Prediction written to disk: a single-channel argmax label map (no one-hot).
post_pred_save = Compose([
    AsDiscreted(keys="pred", argmax=True),
    EnsureTyped(keys="pred", data_type="tensor"),
])

test_results = {}  # loss_name -> aggregated test metrics, for side-by-side comparison

for loss_name in loss_fns:
    print(f"Testing with {loss_name} loss")
    output_dir = f"./output/{loss_name}"

    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
    model.load_state_dict(torch.load(f"best_model_{loss_name}.pth", map_location=device))
    model.eval()

    # Undo val_transforms on the prediction (back to the original image grid)
    # and save it as NIfTI next to the other predictions for this loss.
    post_trans = Compose([
        EnsureTyped(keys="pred", data_type="tensor"),
        Invertd(
            keys="pred",
            transform=val_transforms,
            orig_keys="image",
            # The prediction is a discrete label map: never interpolate it.
            nearest_interp=True,
            to_tensor=True,
        ),
        SaveImaged(
            keys="pred",
            meta_keys="image_meta_dict",
            output_dir=output_dir,
            output_postfix="pred",
            output_ext=".nii.gz",
            resample=False,
            separate_folder=False,
        ),
    ])
    os.makedirs(output_dir, exist_ok=True)

    # Fresh metric objects per checkpoint; foreground class only.
    dice_metric = DiceMetric(include_background=False, reduction="mean")
    jaccard_metric = MeanIoU(include_background=False, reduction="mean")
    hd95_metric = HausdorffDistanceMetric(percentile=95, include_background=False)
    asd_metric = SurfaceDistanceMetric(include_background=False, symmetric=True)

    with torch.no_grad():
        for test_data in tqdm(test_loader, desc="Testing", unit="batch"):
            test_image = test_data["image"].to(device)
            test_label = test_data["label"].to(device)

            # NOTE(review): ROI (64, 64, 16) / sw_batch_size 16 differ from the
            # validation loop's (32, 32, 16) / 4; kept to reproduce recorded results.
            test_logits = sliding_window_inference(test_image, (64, 64, 16), 16, model)

            test_output_save = [post_pred_save({"pred": i}) for i in decollate_batch(test_logits)]
            test_output = [post_pred(i) for i in decollate_batch(test_logits)]
            # Batch-local name: don't overwrite the test_labels file list from the data cell.
            test_targets = [post_label(i) for i in decollate_batch(test_label)]

            dice_metric(y_pred=test_output, y=test_targets)
            jaccard_metric(y_pred=test_output, y=test_targets)
            hd95_metric(y_pred=test_output, y=test_targets)
            asd_metric(y_pred=test_output, y=test_targets)

            for i, saved in enumerate(test_output_save):
                # Attach the source image's metadata so Invertd/SaveImaged can
                # restore the original geometry and file name.
                sample_data = {
                    "image": test_data["image"][i],
                    "pred": MetaTensor(saved["pred"], meta=test_data["image"][i].meta),
                }
                post_trans(sample_data)

    dice_score = dice_metric.aggregate().item()
    jaccard_score = jaccard_metric.aggregate().item()
    hd95_score = hd95_metric.aggregate().item()
    asd_score = asd_metric.aggregate().item()
    test_results[loss_name] = {
        "dice": dice_score,
        "jaccard": jaccard_score,
        "hd95": hd95_score,
        "asd": asd_score,
    }

    print(f"\nTest Results for {loss_name} loss")
    print(f"Dice Score: {dice_score:.4f}")
    print(f"Jaccard Index: {jaccard_score:.4f}")
    print(f"95% Hausdorff Distance: {hd95_score:.4f}")
    print(f"Average Surface Distance: {asd_score:.4f}")

    # Clear accumulated state so the metric objects are clean if reused.
    dice_metric.reset()
    jaccard_metric.reset()
    hd95_metric.reset()
    asd_metric.reset()


Testing with Combined_1_1_300_epoch loss


Testing:   0%|          | 0/8 [00:00<?, ?batch/s]

2025-05-12 23:11:25,309 INFO image_writer.py:197 - writing: output\Combined_1_1_300_epoch\spleen_52_pred.nii.gz
2025-05-12 23:11:35,522 INFO image_writer.py:197 - writing: output\Combined_1_1_300_epoch\spleen_53_pred.nii.gz
2025-05-12 23:11:38,225 INFO image_writer.py:197 - writing: output\Combined_1_1_300_epoch\spleen_56_pred.nii.gz
2025-05-12 23:11:40,725 INFO image_writer.py:197 - writing: output\Combined_1_1_300_epoch\spleen_59_pred.nii.gz
2025-05-12 23:11:46,551 INFO image_writer.py:197 - writing: output\Combined_1_1_300_epoch\spleen_60_pred.nii.gz
2025-05-12 23:11:54,676 INFO image_writer.py:197 - writing: output\Combined_1_1_300_epoch\spleen_61_pred.nii.gz
2025-05-12 23:11:59,393 INFO image_writer.py:197 - writing: output\Combined_1_1_300_epoch\spleen_62_pred.nii.gz
2025-05-12 23:12:04,257 INFO image_writer.py:197 - writing: output\Combined_1_1_300_epoch\spleen_63_pred.nii.gz

Test Results for Combined_1_1_300_epoch loss
Dice Score: 0.7022
Jaccard Index: 0.5649
95% Hausdorff Dis

Testing:   0%|          | 0/8 [00:00<?, ?batch/s]

2025-05-12 23:12:20,179 INFO image_writer.py:197 - writing: output\Combined_7_3_300_epoch\spleen_52_pred.nii.gz
2025-05-12 23:12:28,551 INFO image_writer.py:197 - writing: output\Combined_7_3_300_epoch\spleen_53_pred.nii.gz
2025-05-12 23:12:30,971 INFO image_writer.py:197 - writing: output\Combined_7_3_300_epoch\spleen_56_pred.nii.gz
2025-05-12 23:12:34,901 INFO image_writer.py:197 - writing: output\Combined_7_3_300_epoch\spleen_59_pred.nii.gz
2025-05-12 23:12:44,070 INFO image_writer.py:197 - writing: output\Combined_7_3_300_epoch\spleen_60_pred.nii.gz
2025-05-12 23:12:53,956 INFO image_writer.py:197 - writing: output\Combined_7_3_300_epoch\spleen_61_pred.nii.gz
2025-05-12 23:12:59,811 INFO image_writer.py:197 - writing: output\Combined_7_3_300_epoch\spleen_62_pred.nii.gz
2025-05-12 23:13:05,738 INFO image_writer.py:197 - writing: output\Combined_7_3_300_epoch\spleen_63_pred.nii.gz

Test Results for Combined_7_3_300_epoch loss
Dice Score: 0.7350
Jaccard Index: 0.6051
95% Hausdorff Dis

Testing:   0%|          | 0/8 [00:00<?, ?batch/s]

2025-05-12 23:13:18,598 INFO image_writer.py:197 - writing: output\Combined_3_7_300_epoch\spleen_52_pred.nii.gz
2025-05-12 23:13:25,249 INFO image_writer.py:197 - writing: output\Combined_3_7_300_epoch\spleen_53_pred.nii.gz
2025-05-12 23:13:27,577 INFO image_writer.py:197 - writing: output\Combined_3_7_300_epoch\spleen_56_pred.nii.gz
2025-05-12 23:13:29,182 INFO image_writer.py:197 - writing: output\Combined_3_7_300_epoch\spleen_59_pred.nii.gz
2025-05-12 23:13:36,142 INFO image_writer.py:197 - writing: output\Combined_3_7_300_epoch\spleen_60_pred.nii.gz
2025-05-12 23:13:44,750 INFO image_writer.py:197 - writing: output\Combined_3_7_300_epoch\spleen_61_pred.nii.gz
2025-05-12 23:13:50,422 INFO image_writer.py:197 - writing: output\Combined_3_7_300_epoch\spleen_62_pred.nii.gz
2025-05-12 23:13:55,255 INFO image_writer.py:197 - writing: output\Combined_3_7_300_epoch\spleen_63_pred.nii.gz

Test Results for Combined_3_7_300_epoch loss
Dice Score: 0.7503
Jaccard Index: 0.6224
95% Hausdorff Dis