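# Calculate Segmentation Metrics

Evaluates a trained 3D UNet on the ACDC `testing` split: per-structure Dice scores (RV, MYO, LV) and the Hausdorff distance between predicted and ground-truth masks.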

In [1]:
import os

# Root folder of the ACDC dataset; it must contain the 'training' and 'testing' subfolders.
# Set the ACDC_DATA_PATH environment variable to point elsewhere without editing the notebook.
data_path = os.environ.get("ACDC_DATA_PATH", "/home/jovyan/Project/ACDC/database")


Update `data_path` to the ACDC `database` folder on your machine. It must contain the `training` and `testing` subfolders; the `testing` split is used as the validation set in this notebook.


In [2]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import glob
import monai
from PIL import Image
import torch
import wandb
import nibabel as nib
from skimage import metrics
from sklearn.metrics import jaccard_score
import scipy.ndimage as ndi
from monai.config import print_config
from monai.utils import first
from monai.config import KeysCollection
from monai.data import Dataset, ArrayDataset, create_test_image_3d, DataLoader
from monai.transforms import (
    Transform,
    MapTransform,
    Randomizable,
    AddChannel,
    AddChanneld,
    CastToTyped,
    Compose,
    EnsureChannelFirst,
    LoadImage,
    LoadImaged,
    Lambda,
    Lambdad,
    RandSpatialCrop,
    RandSpatialCropd,
    Resize,
    ToTensor,
    ToTensord,
    Orientation, 
    Rotate,
    RandFlipd,
    RandAffined,
    RandGaussianNoised,
    RandRotated
)
import random
import math
import torch

In [3]:
# Build function that writes list of paths to images automatically, based on the root folder of the data and the filename structure:
def build_dict(data_path, mode):
    dicts = [] 
    patient_folders = glob.glob(os.path.join(data_path, mode, 'patient*'))
    for patient_folder in patient_folders:
        patient_id = os.path.basename(patient_folder)
        all_paths = glob.glob(os.path.join(patient_folder, f'{patient_id}_frame*.nii.gz')) # path to all 'frame' image files (both mask and img)
        time_path = os.path.join(patient_folder,f'{patient_id}_4d.nii.gz') # path to the '4d' image, so the image acquired over time
        frame_indices = sorted(list(set([os.path.basename(path).split('_frame')[-1].split('.')[0] for path in all_paths])))
        for i in range(0,len(frame_indices),2): # loop over every other frame index since you only want the number (and also got '_gt' in frame_indices)
            frame_index = frame_indices[i]
            if frame_index == '01':
                img_ED = f'{patient_id}_frame{frame_index}.nii.gz' # diastole images are always frame 01
                mask_ED = f'{patient_id}_frame{frame_index}_gt.nii.gz'
                ED_img_path = os.path.join(patient_folder,img_ED)
                ED_mask_path = os.path.join(patient_folder,mask_ED)
            else:
                img_ES = f'{patient_id}_frame{frame_index}.nii.gz' # systole images are the other frame (number varies)
                mask_ES = f'{patient_id}_frame{frame_index}_gt.nii.gz'
                ES_img_path = os.path.join(patient_folder,img_ES)
                ES_mask_path = os.path.join(patient_folder,mask_ES)
        dicts.append({'ED_img': ED_img_path, 'ED_mask': ED_mask_path, 'ES_img': ES_img_path, 'ES_mask': ES_mask_path})#,'time_img':time_path})  
    return dicts

# Path dicts for the 'training' split and for the 'testing' split (used as the validation set below).
train_dict, val_dict = (build_dict(data_path, split) for split in ('training', 'testing'))

In [4]:
# Make a monai transform that loads the data from the dataset (so it retrieves the image via the path)
class LoadData(monai.transforms.Transform):
    def __init__(self, keys=None):
        pass

    def __call__(self, sample):
        ES_image = nib.load(sample['ES_img']).get_fdata().astype(np.float64) # load Nifti image and transform to numpy
        ED_image = nib.load(sample['ED_img']).get_fdata().astype(np.float64) 
        ES_mask = nib.load(sample['ES_mask']).get_fdata().astype(np.float64)
        ED_mask = nib.load(sample['ED_mask']).get_fdata().astype(np.float64)

        # The function then returns the images and corresponding binary masks containing all 4 labels
        return {'ES_img': ES_image, 
                'ES_mask': ES_mask,
                'ED_img': ED_image, 
                'ED_mask': ED_mask}

# Plain (uncached) datasets that only load the files; each split gets its own LoadData instance.
train_dataset, val_dataset = (
    monai.data.Dataset(split_dicts, transform=LoadData())
    for split_dicts in (train_dict, val_dict)
)

In [5]:
# Function with which we can easily keep track of the transforms: necessary for wandb to recognise them (compose to list and the otherway around)
def from_compose_to_list(transform_compose):
    """
    Transform an object monai.transforms.Compose in a list fully describing the transform.
    /!\ Random seed is not saved, then reproducibility is not enabled.
    """
    from copy import deepcopy
        
    if not isinstance(transform_compose, monai.transforms.Compose):
        raise TypeError("transform_compose should be a monai.transforms.Compose object.")
    
    output_list = list()
    for transform in transform_compose.transforms:
        kwargs = deepcopy(vars(transform))
        
        # Remove attributes which are not arguments
        args = list(transform.__init__.__code__.co_varnames[1: transform.__init__.__code__.co_argcount])
        for key, obj in vars(transform).items():
            if key not in args:
                del kwargs[key]

        output_list.append({"class": transform.__class__, "kwargs": kwargs})
    return output_list

def from_list_to_compose(transform_list):
    """
    Transform a list in the corresponding monai.transforms.Compose object.
    """
    
    if not isinstance(transform_list, list):
        raise TypeError("transform_list should be a list.")
    
    pre_compose_list = list()
    
    for transform_dict in transform_list:
        if not isinstance(transform_dict, dict) or 'class' not in transform_dict or 'kwargs' not in transform_dict:
            raise TypeError("transform_list should only contains dicts with keys ['class', 'kwargs']")
        
        try:
            transform = transform_dict['class'](**transform_dict['kwargs'])
        except TypeError: # Classes have been converted to str after saving
            transform = eval(transform_dict['class'].replace("__main__.", ""))(**transform_dict['kwargs'])
            
        pre_compose_list.append(transform)
        
    return monai.transforms.Compose(pre_compose_list)

In [6]:
# Evaluation preprocessing only -- no data augmentation is applied in this notebook.
keys_img = ['ES_img', 'ED_img']  # intensity scaling is applied to the images only
keys_mask = ['ES_mask', 'ED_mask']
maxlen = 512  # maximum dimension of the original samples
max_size = [512, 512, 16]  # every volume (images and masks) is first resized to this shape ...
roi_size = [336, 336, 16]  # ... and then center-cropped to this shape

sample_keys = val_dict[0].keys()  # all image and mask keys produced by build_dict

val_steps = [
    LoadData(),                                        # file paths -> voxel arrays
    monai.transforms.AddChanneld(keys=sample_keys),    # add a channel axis
    monai.transforms.ScaleIntensityd(keys=keys_img, minv=0, maxv=1),
    # NOTE(review): 'nearest' interpolation is used for the images as well as the masks;
    # presumably this matches the training pipeline -- TODO confirm.
    monai.transforms.Resized(keys=sample_keys, size_mode='all', spatial_size=max_size, mode='nearest'),
    monai.transforms.CenterSpatialCropd(keys=sample_keys, roi_size=roi_size),
]
T_val = monai.transforms.Compose(val_steps)

# CacheDataset pre-computes the transforms when it is created (the "Loading dataset" progress bar).
val_dataset_T = monai.data.CacheDataset(val_dict, transform=T_val)




<class 'monai.transforms.utility.array.AddChannel'>: Class `AddChannel` has been deprecated since version 0.8. please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead.
Loading dataset: 100%|██████████| 50/50 [00:04<00:00, 10.37it/s]


In [7]:
batch_size = 1
# One volume per batch. The metrics below do not depend on the order, so the data is not
# shuffled; this keeps the evaluation and visualization order reproducible between runs.
val_dataloader = DataLoader(val_dataset_T, batch_size=batch_size, shuffle=False)


In [8]:
import torch

# Preferred GPU on the shared machine. Constructing torch.device('cuda:5') does not check that
# the GPU exists (the error only appears once tensors are moved), so verify it here and fall
# back to the default GPU, or to the CPU when CUDA is unavailable.
# To force the CPU, set: device = torch.device("cpu")
PREFERRED_CUDA_INDEX = 5
if torch.cuda.is_available():
    cuda_index = PREFERRED_CUDA_INDEX if PREFERRED_CUDA_INDEX < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{cuda_index}")
else:
    device = torch.device("cpu")
print(f'The used device is {device}')

The used device is cuda:5


In [9]:
# Rebuild the network with the architecture used for training, then restore the trained
# weights from the checkpoint below.
model = monai.networks.nets.UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=3,
    channels=(8, 16, 32, 64, 128),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)

# The optimizer and the loss are not used for evaluation in this notebook; the optimizer is
# only recreated so that its saved state can be restored from the checkpoint.
lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_function =  monai.losses.DiceCELoss(sigmoid=True, batch=True)


# Trained checkpoint: a dict holding 'model_state_dict' and 'optimizer_state_dict'.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code -- only load
# checkpoints from a trusted source.
checkpoint = torch.load('/home/jovyan/Project/Val/ACDCtrainedUNet_ED_lr_0.001_3_13_41_56_60 (1).pt',map_location = torch.device(device))
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])



In [None]:
# Visualize predictions on one z-slice of each volume. Each figure shows the input image, one
# raw output channel, and the thresholded predictions of all three structures overlaid on the
# image. The ground-truth mask is also present in each batch but is not displayed.

def visualize_results(state, slice, model, dataloader, label_index=0, max_batches=None):
    """Plot the model's predictions on one z-slice, one figure per batch of ``dataloader``.

    Parameters
    ----------
    state : str
        Cardiac phase, ``'ED'`` or ``'ES'``; selects the ``f'{state}_img'`` entry of each batch.
    slice : int
        Index along the last spatial axis to display. (The name shadows the builtin ``slice``
        inside this function; it is kept so that existing calls keep working.)
    model : torch.nn.Module
        Segmentation network; it is put into eval mode and called on each image batch.
    dataloader : iterable of dict
        Batches holding ``f'{state}_img'``; assumed shape (B, 1, X, Y, Z) -- TODO confirm.
    label_index : int, optional
        Output channel shown raw in the middle panel: 0 = RV, 1 = MYO, 2 = LV (default 0).
    max_batches : int or None, optional
        Stop after this many batches; ``None`` (default) plots every batch.
    """
    label_names = ['RV', 'MYO', 'LV']
    overlay_cmaps = ['summer', 'autumn', 'winter']  # one colormap per output channel
    sigmoid = torch.nn.Sigmoid()

    model.eval()  # set the model to evaluation mode
    with torch.no_grad():  # inference only: no gradients needed
        for batch_number, batch in enumerate(dataloader):
            if max_batches is not None and batch_number >= max_batches:
                break

            images = batch[f'{state}_img'].float().to(device)
            output = model(images)

            # Random sample within the batch (always 0 when batch_size == 1).
            i = random.randrange(images.shape[0])
            img_plane = np.squeeze(images[i].cpu().numpy())[:, :, slice]
            pred_mask_plane = np.squeeze(output[i, label_index].cpu().numpy())[:, :, slice]

            fig, axes = plt.subplots(1, 3, figsize=(12, 4))
            axes[0].imshow(img_plane, cmap='gray')
            axes[0].set_title('OG image')

            axes[1].imshow(pred_mask_plane)
            axes[1].set_title(f'Pred mask with label {label_names[label_index]}')

            axes[2].imshow(img_plane, cmap='gray')
            for channel, cmap in enumerate(overlay_cmaps):
                # Threshold the channel at sigmoid 0.5; voxels predicted as background are hidden.
                binary = torch.round(sigmoid(output[i, channel])).cpu().numpy()
                # With clim=[0, 1], the RV overlay (value 1) uses the top colour of its colormap
                # and the MYO/LV overlays (value 0) use the bottom colour of theirs, so the three
                # structures are drawn in distinct colours.
                overlay = np.ma.masked_where(binary == 0, np.full(binary.shape, channel == 0))
                axes[2].imshow(overlay[:, :, slice], cmap, alpha=0.7, clim=[0, 1], interpolation='nearest')
            axes[2].set_title('Total mask')
            plt.show()

state = 'ED'  # cardiac phase to display
slice = 3  # z-slice index used for visualization (note: this shadows the builtin `slice`)
visualize_results(state, slice, model=model, dataloader=val_dataloader)

In [11]:
def label_to_mask(labels, gt, label_values=(1, 2, 3)):
    """Split integer label maps into one channel per structure.

    For every batch item, output channel ``k`` equals ``label_values[k]`` where the label map
    has that value and 0 elsewhere. With the defaults, the channels hold 1 (RV), 2 (MYO) and
    3 (LV).

    Parameters
    ----------
    labels : sequence of array-like
        One label map per batch item. When ``gt`` is true, each item is expected to carry a
        leading singleton channel axis, which is dropped.
    gt : bool
        Whether ``labels`` are ground-truth masks with that leading channel axis.
    label_values : sequence of int, optional
        Label values to split out, one output channel each.

    Returns
    -------
    monai.data.MetaTensor
        Integer masks of shape (batch, len(label_values), *spatial).
    """
    masks = []
    for item in labels:
        label = np.array(item)
        if gt:
            label = label[0]  # drop the singleton channel axis added for the transforms
        masks.append([value * (label == value) for value in label_values])
    return monai.data.MetaTensor(np.array(masks))

In [22]:
from sklearn.metrics import classification_report

# Make a classification report of the segmentation
def calculate_report(gt, seg):
    gt = gt[0].flatten()
    seg = seg[0].flatten()
    
    target_names = ['BG', 'RV', 'MYO', 'LV']
    return classification_report(gt, seg,  target_names=target_names, output_dict = True)

In [23]:
from scipy.spatial.distance import directed_hausdorff

# Calculate the Huausdorff distance
def calculate_DH(gt, seg, slice):
    return directed_hausdorff(gt[:,:,slice], seg[:,:,slice])[0]
    
    

In [24]:
# Add all the masks
def post_processing(output):
    sigmoid = torch.nn.Sigmoid()
    output_softmax = torch.nn.functional.softmax(output, dim=1) # probability # dimension of the mask layers is 1 (counting from 0)
    output_argmax = torch.argmax(output_softmax, dim=1) # max van probability, output is indices over die as
    output_argmax = (output_argmax + 1).astype(np.int64)
    
    binary_masks = torch.round(sigmoid(output[0])).cpu().detach().numpy() # make binary masks (still different layers)
    binary_mask = np.sum([binary_masks[0,:,:,:],binary_masks[1,:,:,:], binary_masks[2,:,:,:]],axis = 0)
    binary_mask = torch.from_numpy(binary_mask)
    binary_mask = torch.round(sigmoid(binary_mask)).cpu().numpy()
    binary_mask = np.expand_dims(binary_mask, axis=0)
    
    argbin = output_argmax*binary_mask
    return argbin

In [44]:
model.eval()  # set the model to evaluation mode

rep_array = [[], [], []]
HD_array = []

for batch in val_dataloader:
    images = batch[f'ED_img'].float().to(device)
    labels = batch[f'ED_mask'].float()
    output = model(images)

    img = images[0,:,:,:].cpu().numpy()

    pred_mask = output[0,0,:,:,:].detach().cpu().numpy()
    # Squeeze the mask where you specified j for
    pred_mask = np.squeeze(pred_mask)

    img_plane = img[:,:,slice]
    pred_mask_plane = pred_mask[:,:,slice]

    sigmoid = torch.nn.Sigmoid()

    pp_output = post_processing(output)

    rep = calculate_report(labels[0], pp_output)
    rep_array[0].append(rep["RV"]["f1-score"])
    rep_array[1].append(rep["MYO"]["f1-score"])
    rep_array[2].append(rep["LV"]["f1-score"])
    HD_array.append(calculate_DH(labels[0][0], pp_output[0], 0))

In [46]:
# Summary statistics of the per-volume Hausdorff distances.
for stat_label, stat_fn in (("min", min), ("max", max), ("avg", np.average)):
    print(f"Hausdorff {stat_label}: ", stat_fn(HD_array))

Hausdorff min:  0.0
Hausdorff max:  12.84523257866513
Hausdorff avg:  6.793926837319098


In [47]:
# Summary statistics of the per-volume RV Dice scores.
for stat_label, stat_fn in (("min", min), ("max", max), ("avg", np.average)):
    print(f"RV Dice {stat_label}: ", stat_fn(rep_array[0]))

RV Dice min:  0.5289883890487442
RV Dice max:  0.9385712801689808
RV Dice avg:  0.8230100600337217


In [48]:
# Summary statistics of the per-volume MYO Dice scores.
for stat_label, stat_fn in (("min", min), ("max", max), ("avg", np.average)):
    print(f"MYO Dice {stat_label}: ", stat_fn(rep_array[1]))

MYO Dice min:  0.45842958288729885
MYO Dice max:  0.8644414786006317
MYO Dice avg:  0.7405527295192103


In [49]:
# Summary statistics of the per-volume LV Dice scores.
for stat_label, stat_fn in (("min", min), ("max", max), ("avg", np.average)):
    print(f"LV Dice {stat_label}: ", stat_fn(rep_array[2]))

LV Dice min:  0.522568112440621
LV Dice max:  0.9594361408960964
LV Dice avg:  0.9101154498015966
