In [1]:
import os
import sys
from glob import glob
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torchvision import transforms
import torchvision.transforms.functional as F

import cv2
import monai
from monai.losses import DiceLoss, DiceCELoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    RandFlipd,
    RandShiftIntensityd,
    RandRotate90d,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    Invertd,
)
import SimpleITK as sitk
from scipy.spatial.distance import directed_hausdorff

IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html


In [2]:
min_intensity = -175
max_intensity = 250

In [3]:
# Loading training set
random.seed(53)

images = glob('../AortaSeg/images' + "/*.mha", recursive=True)
masks = glob('../AortaSeg/masks' + "/*.mha", recursive=True)

images = sorted(images)
masks = sorted(masks)

data_list = [{"img": img, "mask": mask} for img, mask in zip(images, masks)]

num_train = int(0.8*len(data_list))

num_val = len(data_list)-num_train

train_list = random.sample(data_list, num_train)

for element in train_list:
    data_list.remove(element)

val_list = data_list

print("There are: {} train and {} validation samples.".format(len(train_list), len(val_list)))

There are: 40 train and 10 validation samples.


In [4]:
print(val_list)

[{'img': '../AortaSeg/images/subject001_CTA.mha', 'mask': '../AortaSeg/masks/subject001_label.mha'}, {'img': '../AortaSeg/images/subject006_CTA.mha', 'mask': '../AortaSeg/masks/subject006_label.mha'}, {'img': '../AortaSeg/images/subject010_CTA.mha', 'mask': '../AortaSeg/masks/subject010_label.mha'}, {'img': '../AortaSeg/images/subject017_CTA.mha', 'mask': '../AortaSeg/masks/subject017_label.mha'}, {'img': '../AortaSeg/images/subject026_CTA.mha', 'mask': '../AortaSeg/masks/subject026_label.mha'}, {'img': '../AortaSeg/images/subject029_CTA.mha', 'mask': '../AortaSeg/masks/subject029_label.mha'}, {'img': '../AortaSeg/images/subject045_CTA.mha', 'mask': '../AortaSeg/masks/subject045_label.mha'}, {'img': '../AortaSeg/images/subject050_CTA.mha', 'mask': '../AortaSeg/masks/subject050_label.mha'}, {'img': '../AortaSeg/images/subject052_CTA.mha', 'mask': '../AortaSeg/masks/subject052_label.mha'}, {'img': '../AortaSeg/images/subject053_CTA.mha', 'mask': '../AortaSeg/masks/subject053_label.mha'}]

In [5]:
val_transforms = Compose(
    [
        LoadImaged(keys=["img", "mask"], ensure_channel_first=True),
        ScaleIntensityRanged(
            keys=["img"],
            a_min=min_intensity,
            a_max=max_intensity,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["img", "mask"], source_key="img"),
        Orientationd(keys=["img", "mask"], axcodes="RAS"),
        Spacingd(keys=["img", "mask"], pixdim=(2, 2, 2), mode=("bilinear", "nearest"))

    ]
)

monai.transforms.croppad.dictionary CropForegroundd.__init__:allow_smaller: Current default value of argument `allow_smaller=True` has been deprecated since version 1.2. It will be changed to `allow_smaller=False` in version 1.5.


In [6]:
num_cpus = torch.get_num_threads()

In [7]:
val_ds = CacheDataset(val_list, val_transforms, cache_num=len(val_list), cache_rate=1, num_workers = num_cpus)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=num_cpus, pin_memory=True)

Loading dataset: 100%|██████████| 10/10 [01:04<00:00,  6.43s/it]


In [8]:
num_classes=24

### Unet Evaluation

In [9]:
unet = monai.networks.nets.UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=num_classes,
    channels=(32, 64, 128, 256, 512, 1024),
    strides=(2, 2, 2, 2, 2),
    num_res_units=0,
).cuda()

In [10]:
unet.load_state_dict(torch.load("3DUNet_model.pth", weights_only=True))

<All keys matched successfully>

In [22]:
from monai.metrics import compute_hausdorff_distance, compute_average_surface_distance

dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False, num_classes=num_classes)
mean_hd = 0
mean_sd = 0

unet.eval()
print('Begin to validate!!')
with torch.no_grad():
    for batch in tqdm(val_loader):
        val_img = batch['img'].cuda()
        val_mask = batch['mask'].cuda()
        val_mask_one_hot = AsDiscrete(to_onehot = num_classes)(val_mask)

        roi_size = (128,128,128)
        sw_batch_size = 4
        val_output = sliding_window_inference(val_img, roi_size, sw_batch_size, unet)        

        val_output = torch.argmax(val_output, dim=1, keepdim=True)
        val_output_one_hot = AsDiscrete(to_onehot = num_classes)(val_output)
        
        dice_metric(y_pred=val_output, y=val_mask)
        
        hd = compute_hausdorff_distance(y_pred = val_output_one_hot, y = val_mask_one_hot, 
                                        percentile=95, spacing = [2.0,2.0,2.0]).squeeze()
        
        
        sd = compute_average_surface_distance(y_pred = val_output_one_hot, y = val_mask_one_hot,
                                              spacing = [2.0,2.0,2.0]).squeeze()
        
        valid_mask_hd = torch.isfinite(hd)
        
        hd_filtered = hd[valid_mask_hd]
        
        valid_mask_sd = torch.isfinite(sd)
        
        sd_filtered = sd[valid_mask_sd]
        
        mean_hd += (hd_filtered.sum() / len(hd_filtered))
        mean_sd += (sd_filtered.sum() / len(sd_filtered))
        
    mean_dice = dice_metric.aggregate().item()
    dice_metric.reset()
    mean_hd = mean_hd / 10
    mean_sd = mean_sd / 10

    print('Validation ended !!') 
    print("Mean Validation Dice is ",mean_dice)
    print("Mean Surface distance is ", mean_sd.item())
    print("Mean Hausdorff distance is ", mean_hd.item())

Begin to validate!!


100%|██████████| 10/10 [01:41<00:00, 10.13s/it]

Validation ended !!
Mean Validation Dice is  0.6445080637931824
Mean Surface distance is  4.335171699523926
Mean Hausdorff distance is  17.256668090820312





### VNet Evaluation

In [25]:
vnet = monai.networks.nets.VNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=num_classes
).cuda()

In [26]:
vnet.load_state_dict(torch.load("3DVNet_model.pth", weights_only=True))

<All keys matched successfully>

In [27]:
from monai.metrics import compute_hausdorff_distance, compute_average_surface_distance

dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False, num_classes=num_classes)
mean_hd = 0
mean_sd = 0

vnet.eval()
print('Begin to validate!!')
with torch.no_grad():
    for batch in tqdm(val_loader):
        val_img = batch['img'].cuda()
        val_mask = batch['mask'].cuda()
        val_mask_one_hot = AsDiscrete(to_onehot = num_classes)(val_mask)

        roi_size = (128,128,128)
        sw_batch_size = 4
        val_output = sliding_window_inference(val_img, roi_size, sw_batch_size, vnet)        

        val_output = torch.argmax(val_output, dim=1, keepdim=True)
        val_output_one_hot = AsDiscrete(to_onehot = num_classes)(val_output)
        
        dice_metric(y_pred=val_output, y=val_mask)
        
        hd = compute_hausdorff_distance(y_pred = val_output_one_hot, y = val_mask_one_hot, 
                                        percentile=95, spacing = [2.0,2.0,2.0]).squeeze()
        
        
        sd = compute_average_surface_distance(y_pred = val_output_one_hot, y = val_mask_one_hot,
                                              spacing = [2.0,2.0,2.0]).squeeze()
        
        valid_mask_hd = torch.isfinite(hd)
        
        hd_filtered = hd[valid_mask_hd]
        
        valid_mask_sd = torch.isfinite(sd)
        
        sd_filtered = sd[valid_mask_sd]
        
        mean_hd += (hd_filtered.sum() / len(hd_filtered))
        mean_sd += (sd_filtered.sum() / len(sd_filtered))
        
    mean_dice = dice_metric.aggregate().item()
    dice_metric.reset()
    mean_hd = mean_hd / 10
    mean_sd = mean_sd / 10

    print('Validation ended !!') 
    print("Mean Validation Dice is ",mean_dice)
    print("Mean Surface distance is ", mean_sd.item())
    print("Mean Hausdorff distance is ", mean_hd.item())

Begin to validate!!


100%|██████████| 10/10 [01:46<00:00, 10.70s/it]

Validation ended !!
Mean Validation Dice is  0.6744836568832397
Mean Surface distance is  3.983412981033325
Mean Hausdorff distance is  18.931427001953125





### CIS-UNET Evaluation

In [9]:
from CIS_UNet import CIS_UNet

encoder_channels = [64, 64, 128, 256]

cisunet = CIS_UNet(spatial_dims=3, in_channels=1, num_classes=num_classes, encoder_channels=encoder_channels).cuda()

In [10]:
cisunet.load_state_dict(torch.load("CISUNet_model.pth", weights_only=True))

<All keys matched successfully>

In [11]:
from monai.metrics import compute_hausdorff_distance, compute_average_surface_distance

dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False, num_classes=num_classes)
mean_hd = 0
mean_sd = 0

cisunet.eval()
print('Begin to validate!!')
with torch.no_grad():
    for batch in tqdm(val_loader):
        val_img = batch['img'].cuda()
        val_mask = batch['mask'].cuda()
        val_mask_one_hot = AsDiscrete(to_onehot = num_classes)(val_mask)

        roi_size = (128,128,128)
        sw_batch_size = 4
        val_output = sliding_window_inference(val_img, roi_size, sw_batch_size, cisunet)        

        val_output = torch.argmax(val_output, dim=1, keepdim=True)
        val_output_one_hot = AsDiscrete(to_onehot = num_classes)(val_output)
        
        dice_metric(y_pred=val_output, y=val_mask)
        
        hd = compute_hausdorff_distance(y_pred = val_output_one_hot, y = val_mask_one_hot, 
                                        percentile=95, spacing = [2.0,2.0,2.0]).squeeze()
        
        
        sd = compute_average_surface_distance(y_pred = val_output_one_hot, y = val_mask_one_hot,
                                              spacing = [2.0,2.0,2.0]).squeeze()
        
        valid_mask_hd = torch.isfinite(hd)
        
        hd_filtered = hd[valid_mask_hd]
        
        valid_mask_sd = torch.isfinite(sd)
        
        sd_filtered = sd[valid_mask_sd]
        
        mean_hd += (hd_filtered.sum() / len(hd_filtered))
        mean_sd += (sd_filtered.sum() / len(sd_filtered))
        
    mean_dice = dice_metric.aggregate().item()
    dice_metric.reset()
    mean_hd = mean_hd / 10
    mean_sd = mean_sd / 10

    print('Validation ended !!') 
    print("Mean Validation Dice is ",mean_dice)
    print("Mean Surface distance is ", mean_sd.item())
    print("Mean Hausdorff distance is ", mean_hd.item())

Begin to validate!!


100%|██████████| 10/10 [01:55<00:00, 11.54s/it]

Validation ended !!
Mean Validation Dice is  0.6957284212112427
Mean Surface distance is  4.5089898109436035
Mean Hausdorff distance is  23.2863826751709





### SWIN-UNETR Evaluation

In [12]:
swin_unetr = monai.networks.nets.SwinUNETR(
    img_size=128,
    spatial_dims=3,
    in_channels=1,
    out_channels=num_classes
).cuda()


monai.networks.nets.swin_unetr SwinUNETR.__init__:img_size: Argument `img_size` has been deprecated since version 1.3. It will be removed in version 1.5. The img_size argument is not required anymore and checks on the input size are run during forward().


In [13]:
swin_unetr.load_state_dict(torch.load("SWIN_UNETR_model.pth", weights_only=True))

<All keys matched successfully>

In [14]:
from monai.metrics import compute_hausdorff_distance, compute_average_surface_distance

dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False, num_classes=num_classes)
mean_hd = 0
mean_sd = 0

swin_unetr.eval()
print('Begin to validate!!')
with torch.no_grad():
    for batch in tqdm(val_loader):
        val_img = batch['img'].cuda()
        val_mask = batch['mask'].cuda()
        val_mask_one_hot = AsDiscrete(to_onehot = num_classes)(val_mask)

        roi_size = (128,128,128)
        sw_batch_size = 4
        val_output = sliding_window_inference(val_img, roi_size, sw_batch_size, swin_unetr)        

        val_output = torch.argmax(val_output, dim=1, keepdim=True)
        val_output_one_hot = AsDiscrete(to_onehot = num_classes)(val_output)
        
        dice_metric(y_pred=val_output, y=val_mask)
        
        hd = compute_hausdorff_distance(y_pred = val_output_one_hot, y = val_mask_one_hot, 
                                        percentile=95, spacing = [2.0,2.0,2.0]).squeeze()
        
        
        sd = compute_average_surface_distance(y_pred = val_output_one_hot, y = val_mask_one_hot,
                                              spacing = [2.0,2.0,2.0]).squeeze()
        
        valid_mask_hd = torch.isfinite(hd)
        
        hd_filtered = hd[valid_mask_hd]
        
        valid_mask_sd = torch.isfinite(sd)
        
        sd_filtered = sd[valid_mask_sd]
        
        mean_hd += (hd_filtered.sum() / len(hd_filtered))
        mean_sd += (sd_filtered.sum() / len(sd_filtered))
        
    mean_dice = dice_metric.aggregate().item()
    dice_metric.reset()
    mean_hd = mean_hd / 10
    mean_sd = mean_sd / 10

    print('Validation ended !!') 
    print("Mean Validation Dice is ",mean_dice)
    print("Mean Surface distance is ", mean_sd.item())
    print("Mean Hausdorff distance is ", mean_hd.item())

Begin to validate!!


100%|██████████| 10/10 [02:00<00:00, 12.07s/it]

Validation ended !!
Mean Validation Dice is  0.659633457660675
Mean Surface distance is  7.83279275894165
Mean Hausdorff distance is  35.4498291015625





For obtaining 3D Visualizations, I used SimpleITK to write predicted output images to disk and then loaded in 3D Slicer