## Setup

In [1]:
# Mount Google Drive into the Colab VM; the trained checkpoints are read from
# a folder under this mount point later in the notebook.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [2]:
# !pip install py_distance_transforms
!pip install ipywidgets
!pip install -q "monai-weekly[gdown, nibabel, tqdm, ignite]"
import matplotlib.pyplot as plt
%matplotlib inline

Traceback (most recent call last):
  File "<string>", line 1, in <module>
ModuleNotFoundError: No module named 'monai'
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m266.5/266.5 kB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.3/21.3 MB[0m [31m54.2 MB/s[0m eta [36m0:00:00[0m
[?25h

**NOTE**: *The `py_distance_transforms` import is currently commented out. If you re-enable it, the first import can take a while (up to ~8 minutes).*

In [4]:
# from py_distance_transforms import transform_cuda
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    Invertd,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric, HausdorffDistanceMetric, compute_percent_hausdorff_distance, compute_iou, MeanIoU
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob

from scipy.ndimage import distance_transform_edt
import torch.nn.functional as F
import numpy as np
import time
import timeit
import pandas as pd

# from juliacall import Main as jl
# jl.seval("import CUDA")

# print_config()

**Set up data directory**

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  
This allows you to save results and reuse downloads.  
If it is not set, a temporary directory is used, so the dataset is downloaded again in every new session.

In [5]:
# Reuse $MONAI_DATA_DIRECTORY when it is set; otherwise create a fresh
# temporary directory (which means re-downloading the data every session).
directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory is not None:
    root_dir = directory
else:
    root_dir = tempfile.mkdtemp()
print(root_dir)

/tmp/tmps9exetfc


**Download dataset**

Downloads and extracts the dataset.  
The dataset comes from http://medicaldecathlon.com/.

In [6]:
# Spleen CT segmentation task (Task09) of the Medical Segmentation Decathlon.
resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
md5 = "410d4a301da4e5b2f6f86ec3ddba524e"  # expected checksum of the .tar archive

data_dir = os.path.join(root_dir, "Task09_Spleen")
compressed_file = os.path.join(root_dir, "Task09_Spleen.tar")

# Only download (~1.5 GB) and extract when the extracted folder is missing.
if not os.path.exists(data_dir):
    download_and_extract(
        url=resource,
        filepath=compressed_file,
        output_dir=root_dir,
        hash_val=md5,
    )

Task09_Spleen.tar: 1.50GB [01:27, 18.5MB/s]                            

2024-07-17 17:11:31,596 - INFO - Downloaded: /tmp/tmps9exetfc/Task09_Spleen.tar





2024-07-17 17:11:34,634 - INFO - Verified 'Task09_Spleen.tar', md5: 410d4a301da4e5b2f6f86ec3ddba524e.
2024-07-17 17:11:34,635 - INFO - Writing into directory: /tmp/tmps9exetfc.


**Set MSD Spleen dataset path**

In [7]:
# Pair every training CT volume with its label. Sorting both lists aligns them
# as long as each image has a label file with the same name (checked below).
train_images = sorted(glob.glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
train_labels = sorted(glob.glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))

# zip() would silently drop or mis-pair files, so verify the pairing explicitly.
assert train_images, f"no training images found under {data_dir}"
assert [os.path.basename(p) for p in train_images] == [os.path.basename(p) for p in train_labels], (
    "imagesTr/ and labelsTr/ do not contain the same set of file names"
)

data_dicts = [{"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)]
# Hold out the last 9 cases for validation.
# NOTE(review): presumably the same split used to train the checkpoints loaded below — confirm.
train_files, val_files = data_dicts[:-9], data_dicts[-9:]

**Set deterministic training for reproducibility**

In [8]:
# Fix the global random seed via MONAI so repeated runs are reproducible.
SEED = 0
set_determinism(seed=SEED)

**Set up transforms for training and validation**

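Here we use several transforms to preprocess and augment the dataset:
1. `LoadImaged` loads the spleen CT images and labels from NIfTI format files.
1. `EnsureChannelFirstd` reshapes the data so that the channel dimension comes first ("channel first" shape).
1. `Orientationd` unifies the data orientation (RAS) based on the affine matrix.
1. `Spacingd` resamples the data to a voxel spacing of `pixdim=(1.5, 1.5, 2.)` based on the affine matrix (bilinear interpolation for images, nearest-neighbour for labels).
1. `ScaleIntensityRanged` clips the intensities to the range [-57, 164] and rescales them to [0, 1].
1. `CropForegroundd` removes all zero borders to focus on the valid body area of the images and labels.
1. `RandCropByPosNegLabeld` randomly crops 96×96×96 patch samples (4 per volume) from the large image, based on a 1:1 positive/negative ratio.  
The image centers of negative samples must be in the valid body area.
1. `RandAffined` efficiently performs `rotate`, `scale`, `shear`, `translate`, etc. together based on PyTorch affine transform. *(Not included in the pipeline below.)*

Note that the code applies `ScaleIntensityRanged` and `CropForegroundd` before `Orientationd` and `Spacingd`, and that the validation pipeline uses only the deterministic steps (no random cropping).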

In [9]:
def _preprocessing_transforms():
    """Return fresh instances of the deterministic preprocessing shared by both pipelines.

    Steps, in order: load NIfTI -> channel-first -> clip intensities to
    [-57, 164] and rescale to [0, 1] -> crop to the image foreground ->
    reorient to RAS -> resample to (1.5, 1.5, 2.0) spacing (bilinear for the
    image, nearest for the label).
    """
    keys = ["image", "label"]
    return [
        LoadImaged(keys=keys),
        EnsureChannelFirstd(keys=keys),
        ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=keys, source_key="image"),
        Orientationd(keys=keys, axcodes="RAS"),
        Spacingd(keys=keys, pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ]


# Training: shared preprocessing, then 4 random 96^3 patches per volume whose
# centres are drawn with a 1:1 positive/negative ratio from the label.
train_transforms = Compose(
    _preprocessing_transforms()
    + [
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
    ]
)

# Validation: only the deterministic preprocessing, applied to whole volumes.
val_transforms = Compose(_preprocessing_transforms())



In [18]:
# Cache every preprocessed volume in memory (cache_rate=1.0), using 4 worker
# processes both for filling the cache and for loading batches.
# NOTE(review): only `val_loader` is used by the cells below; caching the 32
# training volumes takes ~30 s (see output) and could be skipped when only evaluating.
train_ds = CacheDataset(
    data=train_files,
    transform=train_transforms,
    cache_rate=1.0,
    num_workers=4,
)
train_loader = DataLoader(
    train_ds,
    batch_size=2,
    shuffle=True,
    num_workers=4,
)

val_ds = CacheDataset(
    data=val_files,
    transform=val_transforms,
    cache_rate=1.0,
    num_workers=4,
)
# One whole validation volume per batch, in a fixed order.
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    num_workers=4,
)

Loading dataset: 100%|██████████| 32/32 [00:30<00:00,  1.03it/s]
Loading dataset: 100%|██████████| 9/9 [00:06<00:00,  1.39it/s]


## Qualitative Comparison

In [11]:
# Folder on the mounted Google Drive that holds the trained model checkpoints.
# (Plain string: the previous f-prefix had no placeholders.)
data_path_dir = "/content/drive/MyDrive/dev/MolloiLab/distance-transforms-paper/data"

In [10]:
# Use the first GPU when one is available; fall back to CPU so the notebook
# still runs (slowly) on a runtime without CUDA instead of crashing.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [19]:
# Load the best models
best_model_dice = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
best_model_dice.load_state_dict(torch.load(os.path.join(data_path_dir, "best_metric_model_dice.pth")))

best_model_hd_dice = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
best_model_hd_dice.load_state_dict(torch.load(os.path.join(data_path_dir, "best_metric_model_hd_dice.pth")))

# Set the models to evaluation mode
best_model_dice.eval()
best_model_hd_dice.eval();

In [20]:
from ipywidgets import interact, IntSlider
import os

In [32]:
def visualize_slice(slice_num, images, labels, masks_list):
    """Plot one slice per case: input image, ground truth, and both model predictions.

    Args:
        slice_num: index along the last axis of every volume.
        images: sequence of 3D arrays (one per case), indexed as ``[:, :, slice_num]``.
        labels: sequence of 3D ground-truth arrays, aligned with ``images``.
        masks_list: sequence of ``(dice_mask, hd_dice_mask)`` pairs, aligned with ``images``.

    Draws one row of four panels per case. Nothing is drawn when ``images`` is empty.
    """
    n_cases = len(images)
    if n_cases == 0:
        return

    # squeeze=False keeps `axs` 2-D even for a single case, so axs[row] is
    # always a row of four Axes.
    _, axs = plt.subplots(n_cases, 4, figsize=(15, 5 * n_cases), squeeze=False)

    for idx, (image, label, masks) in enumerate(zip(images, labels, masks_list)):
        panels = (
            (image, "Input Image"),
            (label, "Ground Truth"),
            (masks[0], "Model Dice Prediction"),
            (masks[1], "Model HD Dice Prediction"),
        )
        for ax, (volume, title) in zip(axs[idx], panels):
            ax.imshow(volume[:, :, slice_num], cmap="gray", interpolation='lanczos')
            ax.set_title(f"{title} {idx+1}")
            ax.axis('off')

    plt.show()

def load_and_visualize(val_loader, best_model_dice, best_model_hd_dice, device,
                       num_batches=3, roi_size=(160, 160, 160), sw_batch_size=4):
    """Run both models on the first validation batches and show an interactive slice viewer.

    Args:
        val_loader: iterable of dicts with "image" and "label" entries; each is a
            5-D tensor (batch, channel, *spatial) and channel 0 is displayed.
        best_model_dice: model whose argmax prediction fills the third column.
        best_model_hd_dice: model whose argmax prediction fills the fourth column.
        device: device on which sliding-window inference runs.
        num_batches: how many batches to take from ``val_loader``.
        roi_size: sliding-window patch size passed to ``sliding_window_inference``.
        sw_batch_size: number of windows evaluated per forward pass.

    Raises:
        ValueError: if ``val_loader`` yields no batches within ``num_batches``.
    """
    images, labels, masks_list = [], [], []

    with torch.no_grad():
        for i, val_data in enumerate(val_loader):
            if i >= num_batches:
                break

            image = val_data["image"].to(device)
            # The label is only displayed, so it is never moved to `device`.
            label = val_data["label"]

            output_dice = sliding_window_inference(image, roi_size, sw_batch_size, best_model_dice)
            output_hd_dice = sliding_window_inference(image, roi_size, sw_batch_size, best_model_hd_dice)
            # Per-voxel class index over dim 1 (channels), computed once per batch.
            mask_dice = torch.argmax(output_dice, dim=1).cpu().numpy()
            mask_hd_dice = torch.argmax(output_hd_dice, dim=1).cpu().numpy()

            for b in range(image.shape[0]):
                images.append(image[b, 0].cpu().numpy())
                labels.append(label[b, 0].cpu().numpy())
                masks_list.append((mask_dice[b], mask_hd_dice[b]))

    if not images:
        raise ValueError("val_loader produced no batches to visualize")

    # Cases may have different depths; cap the slider at the shallowest one so
    # every row can be indexed at the selected slice.
    max_slice = min(img.shape[2] for img in images) - 1
    slice_slider = IntSlider(min=0, max=max_slice, step=1, value=0)
    interact(lambda slice_num: visualize_slice(slice_num, images, labels, masks_list), slice_num=slice_slider)

In [33]:
# Example usage
# Compare both models side by side on the first validation cases.
load_and_visualize(
    val_loader,
    best_model_dice=best_model_dice,
    best_model_hd_dice=best_model_hd_dice,
    device=device,
)

interactive(children=(IntSlider(value=0, description='slice_num', max=112), Output()), _dom_classes=('widget-i…