# Overview

This notebook compares the computational cost of different loss functions for training a 3D segmentation model using MONAI. We focus on three loss function combinations (HD = Hausdorff distance; PyDTs = `py_distance_transforms`):

1. Pure Dice Loss (Gold Standard)
2. PyDTs HD Loss + Dice Loss (Proposed method)
3. Scipy HD Loss + Dice Loss

After the setup, the notebook has two main benchmarking sections:

**Loss Function Timings**

We benchmark the computation time of each loss function in isolation, outside the training loop:

1. Pure Dice Loss
2. PyDTs HD Loss
3. Scipy HD Loss

This shows the computational cost of each loss function on its own. Results are saved to CSV files.
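The isolated benchmarks below use `timeit.repeat`, which returns one total per run, each covering `number` consecutive calls; the reported minimum is therefore the time for `number` calls rather than for a single call. Because CUDA kernels launch asynchronously, GPU timings are usually only reliable if the device is synchronized (e.g. with `torch.cuda.synchronize()`) inside the timed callable. A minimal stdlib sketch of per-call statistics, where `per_call_stats` and its `sync` hook are hypothetical names used only for illustration:

```python
import statistics
import timeit
from typing import Callable, Optional, Tuple


def per_call_stats(
    fn: Callable[[], object],
    number: int = 5,
    repeat: int = 10,
    sync: Optional[Callable[[], None]] = None,
) -> Tuple[float, float]:
    """Return (min, sample stdev) of the per-call time of `fn`, in seconds.

    Hypothetical helper. If given, `sync` (e.g. torch.cuda.synchronize) runs
    after every call so that pending asynchronous work is included in the time.
    """

    def timed_call() -> None:
        fn()
        if sync is not None:
            sync()

    # Each entry of `totals` is the wall time of `number` consecutive calls.
    totals = timeit.repeat(timed_call, number=number, repeat=repeat)
    per_call = [t / number for t in totals]
    return min(per_call), statistics.stdev(per_call)


# Usage with a cheap CPU-only workload.
best, spread = per_call_stats(lambda: sum(range(1000)), number=5, repeat=10)
print(f"min {best:.2e} s/call, stdev {spread:.2e} s")
```

Passing `sync=torch.cuda.synchronize` would be one way to apply this to the GPU losses; note that `statistics.stdev` needs `repeat >= 2`.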

**Training Loop Timings**

We benchmark the training loop timings for each loss function combination:

1. Pure Dice Loss
2. PyDTs HD Loss + Dice Loss (proposed method)
3. Scipy HD Loss + Dice Loss

We run the training loop for 10 epochs and measure the total training time, average epoch time, and standard deviation of epoch time. This helps us assess the computational efficiency of each loss function combination within the context of a training loop. Results are saved to CSV files.
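The per-epoch summary uses `np.std`, which defaults to the population standard deviation (`ddof=0`), whereas the isolated loss benchmarks use `statistics.stdev`, the sample standard deviation (`ddof=1`). With only a handful of measurements the two differ noticeably, so the convention matters when comparing the CSV files. A small stdlib illustration with made-up epoch times (not measured results):

```python
import statistics

# Illustrative epoch times in seconds (made-up numbers, not benchmark results).
epoch_times = [10.0, 12.0, 11.0, 13.0]

total_time = sum(epoch_times)
avg_time = statistics.mean(epoch_times)
pop_std = statistics.pstdev(epoch_times)    # same convention as np.std (ddof=0)
sample_std = statistics.stdev(epoch_times)  # same convention as np.std(..., ddof=1)

print(f"total={total_time:.1f}s avg={avg_time:.2f}s "
      f"pstdev={pop_std:.3f}s stdev={sample_std:.3f}s")
```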

## Setup

In [None]:
from google.colab import drive
drive.mount('/content/drive')


In [None]:
!pip install py_distance_transforms
!python -c "import monai" || pip install -q "monai-weekly[gdown, nibabel, tqdm, ignite]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline

**NOTE**: *The first import of `py_distance_transforms` can take a while (up to ~8 minutes).*

In [None]:
from py_distance_transforms import transform_cuda
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    Invertd,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric, HausdorffDistanceMetric, compute_percent_hausdorff_distance, compute_iou, MeanIoU
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob

from scipy.ndimage import distance_transform_edt
import torch.nn.functional as F
import numpy as np
import time
import timeit
import pandas as pd

from juliacall import Main as jl
jl.seval("import CUDA")

# print_config()

**Setup data directory**

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  
This allows you to save results and reuse downloads.  
If not specified, a temporary directory will be used.

In [None]:
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

**Download dataset**

Downloads and extracts the dataset.  
The dataset comes from http://medicaldecathlon.com/.

In [None]:
resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
md5 = "410d4a301da4e5b2f6f86ec3ddba524e"

compressed_file = os.path.join(root_dir, "Task09_Spleen.tar")
data_dir = os.path.join(root_dir, "Task09_Spleen")
if not os.path.exists(data_dir):
    download_and_extract(resource, compressed_file, root_dir, md5)

**Set MSD Spleen dataset path**

In [None]:
train_images = sorted(glob.glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
train_labels = sorted(glob.glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))
data_dicts = [{"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)]
train_files, val_files = data_dicts[:-9], data_dicts[-9:]

**Set deterministic training for reproducibility**

In [None]:
set_determinism(seed=0)

**Setup transforms for training and validation**

Here we use several transforms to preprocess and augment the dataset (listed in the order they are applied):
1. `LoadImaged` loads the spleen CT images and labels from NIfTI files.
1. `EnsureChannelFirstd` reshapes the data into "channel first" format.
1. `ScaleIntensityRanged` maps the intensity range [-57, 164] to [0, 1], clipping values outside it.
1. `CropForegroundd` removes all-zero borders to focus on the valid body area of the images and labels.
1. `Orientationd` unifies the data orientation (RAS) based on the affine matrix.
1. `Spacingd` resamples the data to a voxel spacing of `pixdim=(1.5, 1.5, 2.)` based on the affine matrix.
1. `RandCropByPosNegLabeld` randomly crops patch samples from the full image based on a positive/negative ratio.  
The centers of negative samples must lie in the valid body area.

The validation transforms apply the same steps except the random cropping.
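The intensity scaling and crop sampling described above reduce to simple arithmetic. A minimal sketch, assuming the linear mapping with clipping used by `ScaleIntensityRanged` and the `pos / (pos + neg)` foreground-center probability described in the `RandCropByPosNegLabeld` documentation; both helper names are hypothetical:

```python
def scale_intensity_range(x, a_min=-57.0, a_max=164.0, b_min=0.0, b_max=1.0, clip=True):
    """Linearly map [a_min, a_max] to [b_min, b_max], optionally clipping the result."""
    y = (x - a_min) / (a_max - a_min) * (b_max - b_min) + b_min
    if clip:
        y = min(max(y, b_min), b_max)
    return y


def pos_center_probability(pos=1.0, neg=1.0):
    """Probability that a random crop is centred on a foreground voxel."""
    return pos / (pos + neg)


# Air (-1000 HU) clips to 0, the window midpoint maps to 0.5.
print(scale_intensity_range(-1000.0), scale_intensity_range(53.5))
# With pos=1, neg=1 (as in train_transforms), half the crops are centred on foreground.
print(pos_center_probability(1, 1))
```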

In [None]:
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        )
    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ]
)

## Check setup in DataLoader

In [None]:
check_ds = Dataset(data=val_files, transform=val_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)
image, label = (check_data["image"][0][0], check_data["label"][0][0])
print(f"image shape: {image.shape}, label shape: {label.shape}")
# plot the slice [:, :, 80]
plt.figure("check", (12, 6))
plt.subplot(1, 2, 1)
plt.title("image")
plt.imshow(image[:, :, 80], cmap="gray")
plt.subplot(1, 2, 2)
plt.title("label")
plt.imshow(label[:, :, 80])
plt.show()

## CacheDataset

Here we use `CacheDataset` to accelerate training and validation. It caches the output of the deterministic (non-random) transforms and can be around 10x faster than the regular `Dataset`, depending on the data and hardware.  
For the best performance, set `cache_rate=1.0` to cache all the data; if there is not enough memory, use a lower value.  
Alternatively, set `cache_num`; if both are given, the smaller of the two resulting counts is used.  
Set `num_workers` to cache with multiple threads.  
To try the regular `Dataset` instead, replace `CacheDataset(...)` with `Dataset(data=..., transform=...)`.

In [None]:
train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)

val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)
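With these settings, caching and batching work out as follows. A sketch with hypothetical helper names, assuming the caching rule stated above (the smaller of `cache_num` and `len(data) * cache_rate` wins) and that MONAI's default collate flattens the `num_samples=4` crops from each of the `batch_size=2` volumes into a single training batch:

```python
def effective_cache_num(num_items, cache_rate=1.0, cache_num=None):
    """Number of items cached (hypothetical helper mirroring the rule above)."""
    limits = [num_items, int(num_items * cache_rate)]
    if cache_num is not None:
        limits.append(int(cache_num))
    return min(limits)


def patches_per_step(batch_size=2, num_samples=4):
    """Patches per training step when each volume yields `num_samples` random crops."""
    return batch_size * num_samples


print(effective_cache_num(32, cache_rate=1.0))               # cache everything
print(effective_cache_num(32, cache_rate=1.0, cache_num=10))  # cache_num is smaller
print(patches_per_step())
```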

# Loss Function Timings
We implement the Hausdorff distance (HD) loss as custom functions, adapted from [this implementation](https://github.com/JunMa11/SegWithDistMap/blob/master/code/train_LA_HD.py).
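For reference, the HD loss computed by `hd_loss` below can be written as follows, where $s_v$ is the foreground softmax probability at voxel $v$, $g_v$ the binary label, $d_{S,v}$ and $d_{G,v}$ the distance maps of the thresholded prediction ($s > 0.5$) and of the ground truth, and the mean runs over all $N$ voxels in the batch. The HD + Dice training loops combine it with the Dice loss using a weight $\alpha$:

$$
\mathcal{L}_{\mathrm{HD}} = \frac{1}{N}\sum_{v=1}^{N}\left(s_v - g_v\right)^2\left(d_{S,v}^2 + d_{G,v}^2\right),
\qquad
\mathcal{L} = \alpha\,\mathcal{L}_{\mathrm{Dice}} + (1-\alpha)\,\mathcal{L}_{\mathrm{HD}}
$$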

In [None]:
data_path_dir = "/content/drive/MyDrive/dev/MolloiLab/distance-transforms-paper/data"

In [None]:
np_pred = np.random.choice([0, 1], size=(1, 1, 224, 224, 112)).astype(np.float32)
np_label = np.random.choice([0, 1], size=(1, 1, 224, 224, 112)).astype(np.float32)

torch_pred = torch.tensor(np_pred).cuda()
torch_label = torch.tensor(np_label).cuda()

## Dice Loss
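As a reference point for the timings, a per-channel sketch of the soft Dice loss, assuming MONAI's default smoothing terms (`smooth_nr=smooth_dr=1e-5`) and no squared predictions. `DiceLoss` computes this per batch element and channel and then averages (with the default `reduction="mean"`); the function name below is hypothetical:

```python
def soft_dice_loss(pred, target, smooth_nr=1e-5, smooth_dr=1e-5):
    """Soft Dice loss for one channel, given flattened predictions and labels."""
    intersection = sum(p * g for p, g in zip(pred, target))
    denominator = sum(pred) + sum(target)
    return 1.0 - (2.0 * intersection + smooth_nr) / (denominator + smooth_dr)


print(soft_dice_loss([1.0, 0.0, 1.0, 1.0], [1.0, 0.0, 1.0, 1.0]))  # perfect overlap
print(soft_dice_loss([1.0, 1.0, 0.0, 0.0], [1.0, 0.0, 1.0, 0.0]))  # half overlap
```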

In [None]:
dice_loss = DiceLoss(to_onehot_y=False, softmax=False)

In [None]:
%%timeit
dice_loss(torch_pred, torch_label)

In [None]:
import statistics

In [None]:
# Timing settings: `number` calls per run, `repeat` runs (shared by all loss benchmarks)
number = 5
repeat = 10

In [None]:
# Use timeit.repeat to get multiple timing results (same settings as the HD loss benchmarks)
times_dice = timeit.repeat(lambda: dice_loss(torch_pred, torch_label), number=number, repeat=repeat)

In [None]:
# Calculate statistics
min_dice, std_dice = min(times_dice), statistics.stdev(times_dice)
min_dice, std_dice # seconds

## Scipy HD Loss

In [None]:
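def compute_dtm(img_gt, out_shape):
    """
    Compute the foreground distance transform map (DTM) of a binary mask on the CPU.

    input: img_gt: binary segmentation, shape = (batch_size, x, y, z)
           out_shape: shape of the output, (batch_size, num_channels, x, y, z)
    output: foreground DTM of shape out_shape (channel 0 is left as zeros), where
            dtm(x) = 0 if x is background
            dtm(x) = Euclidean distance from x to the nearest background voxel if x is foreground
    """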

    fg_dtm = np.zeros(out_shape)

    for b in range(out_shape[0]): # batch size
        for c in range(1, out_shape[1]):
            posmask = img_gt[b].astype(bool)
            if posmask.any():
                posdis = distance_transform_edt(posmask)
                fg_dtm[b][c] = posdis

    return fg_dtm
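The foreground distance map that `compute_dtm` builds with `scipy.ndimage.distance_transform_edt` can be illustrated with a brute-force 2D version. This is a hypothetical O(N²) helper for intuition only, assuming SciPy's EDT convention (distance from each non-zero pixel to the nearest zero pixel); it is not how SciPy or `py_distance_transforms` compute the transform:

```python
import math


def brute_force_edt(mask):
    """Distance from each foreground (non-zero) pixel to the nearest background pixel.

    Background pixels get 0. Requires at least one background pixel.
    """
    rows, cols = len(mask), len(mask[0])
    background = [(i, j) for i in range(rows) for j in range(cols) if not mask[i][j]]
    if not background:
        raise ValueError("mask has no background pixels")
    out = [[0.0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            if mask[i][j]:
                out[i][j] = min(math.hypot(i - bi, j - bj) for bi, bj in background)
    return out


# A 3x3 foreground square inside a 5x5 image: the centre is 2 pixels from the background.
square = [[0, 0, 0, 0, 0],
          [0, 1, 1, 1, 0],
          [0, 1, 1, 1, 0],
          [0, 1, 1, 1, 0],
          [0, 0, 0, 0, 0]]
for row in brute_force_edt(square):
    print(row)
```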

In [None]:
def hd_loss(seg_soft, gt, seg_dtm, gt_dtm):
    """
    Compute the Hausdorff distance (HD) loss for binary segmentation.

    input: seg_soft: softmax results, shape=(b,2,x,y,z)
           gt: ground truth, shape=(b,x,y,z)
           seg_dtm: segmentation distance transform map, shape=(b,2,x,y,z)
           gt_dtm: ground truth distance transform map, shape=(b,2,x,y,z)
    output: HD loss; scalar
    """

    # Squared error between the foreground probability and the label
    delta_s = (seg_soft[:, 1, ...] - gt.float()) ** 2
    # Squared distance maps of the prediction and the ground truth
    s_dtm = seg_dtm[:, 1, ...] ** 2
    g_dtm = gt_dtm[:, 1, ...] ** 2
    dtm = s_dtm + g_dtm

    # Element-wise product (same as einsum('bxyz,bxyz->bxyz')), averaged over all voxels
    multiplied = delta_s * dtm
    loss = multiplied.mean()

    return loss

In [None]:
# Build a two-channel (background, foreground) tensor of shape (b, 2, x, y, z) that mimics a softmax output
torch_pred_soft = torch.stack((1 - torch_pred, torch_pred), dim=1).squeeze(2)

In [None]:
def time_hd_loss_scipy():
    # Distance transforms on the CPU (SciPy), then copy the maps back to the GPU
    gt_dtm_npy = compute_dtm(torch_label.cpu().numpy(), torch_pred_soft.shape)
    gt_dtm = torch.from_numpy(gt_dtm_npy).float().cuda(torch_pred_soft.device.index)
    seg_dtm_npy = compute_dtm(torch_pred_soft[:, 1, :, :, :].cpu().numpy() > 0.5, torch_pred_soft.shape)
    seg_dtm = torch.from_numpy(seg_dtm_npy).float().cuda(torch_pred_soft.device.index)

    hd_loss(torch_pred_soft, torch_label[:, 0, :, :, :], seg_dtm, gt_dtm)

In [None]:
times_hd_scipy = timeit.repeat(lambda: time_hd_loss_scipy(), number=number, repeat=repeat)

In [None]:
# Calculate statistics
min_hd_scipy, std_hd_scipy = min(times_hd_scipy), statistics.stdev(times_hd_scipy)
min_hd_scipy, std_hd_scipy # seconds

## PyDTs HD Loss (Proposed)

In [None]:
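def compute_dtm_gpu(img_gt, out_shape):
    """
    GPU counterpart of compute_dtm, using py_distance_transforms.transform_cuda.

    input: img_gt: binary segmentation tensor on the GPU, shape = (batch_size, x, y, z)
           out_shape: shape of the output, (batch_size, num_channels, x, y, z)
    output: foreground distance map of shape out_shape as a float32 tensor on the
            same device (channel 0 is left as zeros), intended to match compute_dtm
    """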

    # Convert img_gt to float if not already float
    if img_gt.dtype != torch.float32:
        img_gt = img_gt.float()

    fg_dtm = torch.zeros(out_shape, dtype=torch.float32, device=img_gt.device)

    for b in range(out_shape[0]):  # batch size
        for c in range(1, out_shape[1]):
            posmask = img_gt[b]
            if posmask.bool().any():
                posdis = transform_cuda(posmask)
                fg_dtm[b, c] = posdis

    return fg_dtm.to(img_gt.dtype)

In [None]:
def time_hd_loss_pydt():
    # Distance transforms computed on the GPU with py_distance_transforms
    gt_dtm = compute_dtm_gpu(torch_label, torch_pred_soft.shape)
    seg_dtm = compute_dtm_gpu(torch_pred_soft[:, 1, :, :, :] > 0.5, torch_pred_soft.shape)
    hd_loss(torch_pred_soft, torch_label[:, 0, :, :, :], seg_dtm, gt_dtm)

In [None]:
times_hd_pydt = timeit.repeat(lambda: time_hd_loss_pydt(), number=number, repeat=repeat)

In [None]:
# IMPORTANT: force a full Julia garbage collection and release cached CUDA memory; otherwise GPU RAM can overflow
jl.seval("CUDA.GC.gc(true); CUDA.reclaim()")

In [None]:
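# Calculate statistics, excluding the first and last runs (e.g. to drop warm-up and Julia JIT compilation overhead);
# note that the Dice and SciPy statistics use all runs
min_hd_pydt, std_hd_pydt = min(times_hd_pydt[1:-1]), statistics.stdev(times_hd_pydt[1:-1])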
min_hd_pydt, std_hd_pydt # seconds

In [None]:
# Create a dictionary with the results
results = {
    'Loss Function': ['Dice Loss', 'HD Loss (SciPy)', 'HD Loss (PyDT)'],
    'Minimum Time (s)': [min_dice, min_hd_scipy, min_hd_pydt],
    'Standard Deviation (s)': [std_dice, std_hd_scipy, std_hd_pydt]
}

# Create the DataFrame
df_pure_losses = pd.DataFrame(results)

# Set 'Loss Function' as the index
df_pure_losses = df_pure_losses.set_index('Loss Function')

# Format the numbers to 6 decimal places
# df_pure_losses = df_pure_losses.applymap(lambda x: f"{x:.6f}")

# Display the DataFrame
print(df_pure_losses)

In [None]:
# Save the DataFrame as a CSV file
df_pure_losses_path = f"{data_path_dir}/hd_loss_pure_losses_timings.csv"
df_pure_losses.to_csv(df_pure_losses_path)

# Training Loop Timings

In [None]:
dice_loss = DiceLoss(to_onehot_y=True, softmax=True)

## Pure Dice Loss (Gold Standard)

In [None]:
# standard PyTorch program style: create UNet and Adam optimizer (the Dice loss is defined above)
device = torch.device("cuda:0")
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)

max_epochs = 10
model_data = []
epoch_loss_values = []
epoch_times = []

for epoch in range(max_epochs):
    start_time = time.time()
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = dice_loss(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    end_time = time.time()
    epoch_time = end_time - start_time
    epoch_times.append(epoch_time)
    print(f"Epoch {epoch + 1} training time: {epoch_time:.2f} seconds")

# Print total training time, average time per epoch, and standard deviation
total_training_time = sum(epoch_times)
avg_epoch_time = np.mean(epoch_times)
std_epoch_time = np.std(epoch_times)
print(f"\nTotal training time: {total_training_time:.2f} seconds")
print(f"Average training time per epoch: {avg_epoch_time:.2f} seconds")
print(f"Standard deviation of training time per epoch: {std_epoch_time:.2f} seconds")

# Append the model's details and timings to the list
model_data.append({
    'Model': 'Plain Dice Loss',
    'Total Training Time (s)': total_training_time,
    'Avg Epoch Time (s)': avg_epoch_time,
    'Std Epoch Time (s)': std_epoch_time
})

# Create a DataFrame from the model_data list
df_plain_dice = pd.DataFrame(model_data)

# Save the DataFrame as a CSV file
df_dice_path = f"{data_path_dir}/hd_loss_plain_dice_timing.csv"
df_plain_dice.to_csv(df_dice_path)

In [None]:
torch.cuda.empty_cache()

## Scipy HD Loss + Dice Loss

We run this for only 10 epochs, just to showcase the difference in training speed. Accuracy differences between `scipy.ndimage.distance_transform_edt` and `py_distance_transforms.transform_cuda` should be negligible.
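Both HD + Dice loops use the same weighting schedule: the loss is `alpha * dice + (1 - alpha) * hd`, with `alpha` starting at 1.0 and dropping by 0.001 per epoch (floored at 0.001). In the first epoch the HD term therefore has zero weight, although it is still computed and so still counts toward the epoch time. A small sketch of the schedule (the helper name is hypothetical):

```python
def alpha_schedule(num_epochs, start=1.0, step=0.001, floor=0.001):
    """Dice weight alpha used in each epoch; the HD term is weighted by 1 - alpha."""
    alphas = []
    alpha = start
    for _ in range(num_epochs):
        alphas.append(alpha)
        alpha -= step
        if alpha <= floor:
            alpha = floor
    return alphas


alphas = alpha_schedule(10)
hd_weights = [1 - a for a in alphas]
print(f"first alpha={alphas[0]:.3f}, last alpha={alphas[-1]:.3f}, last HD weight={hd_weights[-1]:.3f}")
```

Over 10 epochs the HD weight only reaches about 0.009, which is consistent with this run being a timing benchmark rather than an accuracy comparison.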

In [None]:
# standard PyTorch program style: create UNet and Adam optimizer (the Dice loss is defined above)
device = torch.device("cuda:0")
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)

max_epochs = 10
model_data = []
epoch_loss_values = []
epoch_times = []

alpha = 1.0

for epoch in range(max_epochs):
    start_time = time.time()
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss_seg_dice = dice_loss(outputs, labels)
        outputs_soft = F.softmax(outputs, dim=1)

        with torch.no_grad():
            gt_dtm_npy = compute_dtm(labels.cpu().numpy(), outputs_soft.shape)
            gt_dtm = torch.from_numpy(gt_dtm_npy).float().cuda(outputs_soft.device.index)
            seg_dtm_npy = compute_dtm(outputs_soft[:, 1, :, :, :].cpu().numpy()>0.5, outputs_soft.shape)
            seg_dtm = torch.from_numpy(seg_dtm_npy).float().cuda(outputs_soft.device.index)

        loss_hd = hd_loss(outputs_soft, labels[:, 0, :, :, :], seg_dtm, gt_dtm)
        loss = alpha*loss_seg_dice + (1 - alpha) * loss_hd
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    end_time = time.time()
    epoch_time = end_time - start_time
    epoch_times.append(epoch_time)
    print(f"Epoch {epoch + 1} training time: {epoch_time:.2f} seconds")

    alpha -= 0.001
    if alpha <= 0.001:
        alpha = 0.001

# Print total training time, average time per epoch, and standard deviation
total_training_time = sum(epoch_times)
avg_epoch_time = np.mean(epoch_times)
std_epoch_time = np.std(epoch_times)
print(f"\nTotal training time: {total_training_time:.2f} seconds")
print(f"Average training time per epoch: {avg_epoch_time:.2f} seconds")
print(f"Standard deviation of training time per epoch: {std_epoch_time:.2f} seconds")

# Append the model's details and timings to the list
model_data.append({
    'Model': 'Scipy HD Loss + Dice Loss',
    'Total Training Time (s)': total_training_time,
    'Avg Epoch Time (s)': avg_epoch_time,
    'Std Epoch Time (s)': std_epoch_time
})

# Create a DataFrame from the model_data list
df_hd_dice_scipy = pd.DataFrame(model_data)

# Save the DataFrame as a CSV file
df_hd_dice_scipy_path = f"{data_path_dir}/hd_loss_hd_dice_scipy_timing.csv"
df_hd_dice_scipy.to_csv(df_hd_dice_scipy_path)

In [None]:
torch.cuda.empty_cache()

## PyDTs HD Loss + Dice Loss (Proposed)

In [None]:
# standard PyTorch program style: create UNet and Adam optimizer (the Dice loss is defined above)
device = torch.device("cuda:0")
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)

max_epochs = 10
model_data = []
epoch_loss_values = []
epoch_times = []

alpha = 1.0

for epoch in range(max_epochs):
    start_time = time.time()
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss_seg_dice = dice_loss(outputs, labels)
        outputs_soft = F.softmax(outputs, dim=1)

        with torch.no_grad():
            gt_dtm = compute_dtm_gpu(labels, outputs_soft.shape)
            seg_dtm = compute_dtm_gpu(outputs_soft[:, 1, :, :, :]>0.5, outputs_soft.shape)

        loss_hd = hd_loss(outputs_soft, labels[:, 0, :, :, :], seg_dtm, gt_dtm)
        loss = alpha*loss_seg_dice + (1 - alpha) * loss_hd
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    # IMPORTANT: free Julia-held CUDA memory once per epoch, otherwise GPU RAM overflows (the slowdown penalty is small)
    jl.seval("CUDA.GC.gc(true); CUDA.reclaim()")
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    end_time = time.time()
    epoch_time = end_time - start_time
    epoch_times.append(epoch_time)
    print(f"Epoch {epoch + 1} training time: {epoch_time:.2f} seconds")

    alpha -= 0.001
    if alpha <= 0.001:
        alpha = 0.001

# Print total training time, average time per epoch, and standard deviation
total_training_time = sum(epoch_times)
avg_epoch_time = np.mean(epoch_times)
std_epoch_time = np.std(epoch_times)
print(f"\nTotal training time: {total_training_time:.2f} seconds")
print(f"Average training time per epoch: {avg_epoch_time:.2f} seconds")
print(f"Standard deviation of training time per epoch: {std_epoch_time:.2f} seconds")

# Append the model's details and timings to the list
model_data.append({
    'Model': 'PyDTs HD Loss + Dice Loss',
    'Total Training Time (s)': total_training_time,
    'Avg Epoch Time (s)': avg_epoch_time,
    'Std Epoch Time (s)': std_epoch_time
})

# Create a DataFrame from the model_data list
df_hd_dice_pydt = pd.DataFrame(model_data)

# Save the DataFrame as a CSV file
df_hd_dice_pydt_path = f"{data_path_dir}/hd_loss_hd_dice_pydt_timing.csv"
df_hd_dice_pydt.to_csv(df_hd_dice_pydt_path)

In [None]:
torch.cuda.empty_cache()