In [None]:
# Mount Google Drive so every "/content/gdrive/My Drive/..." path used below resolves.
from google.colab import drive

GDRIVE_ROOT = "/content/gdrive"
drive.mount(GDRIVE_ROOT)

Mounted at /content/gdrive


# Import libraries

In [None]:
# Import libraries
import tarfile
import imageio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
import numpy as np
import time
import os
import random
import matplotlib.pyplot as plt
from matplotlib import colors
import imageio.v2 as imageio

import nibabel as nib
from tqdm import tqdm
import glob
import shutil
import keras
from torch.utils.tensorboard import SummaryWriter

from scipy.ndimage import distance_transform_edt

# Convert segmentation label maps into distance-transform (SDF) volumes

In [None]:
import math


def convert_labels_to_sdf(input_dir, output_dir, labels=(1, 2)):
    """Convert every label map in ``input_dir`` into a stacked distance-transform volume.

    For each ``*.nii.gz`` file, one binary mask is built per value in ``labels``
    (by default 1 and 2, the classes this notebook calls LV and myocardium) and
    ``scipy.ndimage.distance_transform_edt`` is applied to it. This is the
    *unsigned* Euclidean distance, in voxels (image spacing is not passed), from
    each foreground voxel to the nearest background voxel; voxels outside the
    mask are 0.

    The per-label distance volumes are stacked on a new last axis and singleton
    axes are squeezed away, so a (X, Y, Z, 1) label map becomes a
    (X, Y, Z, len(labels)) volume. It is saved under the same file name in
    ``output_dir`` with the input affine.

    Parameters
    ----------
    input_dir : str
        Directory containing the cropped label maps (``*.nii.gz``).
    output_dir : str
        Destination directory; created if missing.
    labels : sequence of int, optional
        Label values to convert, in output-channel order.

    Returns
    -------
    list of str
        Paths of the files written.
    """
    os.makedirs(output_dir, exist_ok=True)
    written = []
    # sorted() makes the processing order deterministic across runs.
    for nii_file in sorted(glob.glob(os.path.join(input_dir, "*.nii.gz"))):
        img = nib.load(nii_file)
        data = img.get_fdata()

        distance_volumes = [distance_transform_edt(data == label) for label in labels]

        # Stack to (..., n_labels), then drop the input's singleton channel axis.
        sdf_stacked = np.stack(distance_volumes, axis=-1).squeeze()

        sdf_filename = os.path.join(output_dir, os.path.basename(nii_file))
        nib.save(nib.Nifti1Image(sdf_stacked, img.affine), sdf_filename)
        written.append(sdf_filename)
    return written


CROP_ROOT = '/content/gdrive/My Drive/crop_preprocessed_bai'
SDF_ROOT = '/content/gdrive/My Drive/sdf_preprocessed_bai'

# Same conversion for both splits; targets land in a parallel directory tree.
for split in ("train_3D", "test_3D"):
    convert_labels_to_sdf(os.path.join(CROP_ROOT, split), os.path.join(SDF_ROOT, split))


# Dataset loading class

In [None]:
class CardiacImageSet(keras.utils.Sequence):
    """In-memory paired dataset of cardiac volumes and their targets.

    Every ``*.nii.gz`` file in ``image_path`` is loaded eagerly and its axes are
    reordered from (X, Y, Z, C) to (C, X, Y, Z). Unless ``deploy`` is set, the
    file with the same name in ``label_path`` is loaded and reordered the same way.

    Parameters
    ----------
    image_path : str
        Directory with the input volumes.
    label_path : str, optional
        Directory with the matching target volumes (same file names).
    deploy : bool, optional
        If True, no labels are loaded and ``__getitem__`` returns ``None`` as label.
    """

    def __init__(self, image_path, label_path='', deploy=False):
        self.image_path = image_path
        self.deploy = deploy
        self.images = []
        self.labels = []

        # sorted(): os.listdir order is arbitrary; a fixed order keeps get_batch
        # (which walks the set sequentially) reproducible across runs.
        image_names = sorted(f for f in os.listdir(image_path) if f.endswith('.nii.gz'))
        for image_name in image_names:
            self.images.append(self._load_cxyz(os.path.join(image_path, image_name)))
            if not self.deploy:
                self.labels.append(self._load_cxyz(os.path.join(label_path, image_name)))

    @staticmethod
    def _load_cxyz(path):
        """Load a 4D NIfTI volume and reorder its axes from XYZC to CXYZ."""
        return np.transpose(nib.load(path).get_fdata(), (3, 0, 1, 2))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; each is a (C, X, Y, Z) array.

        In deploy mode no labels were loaded, so ``label`` is ``None``.
        """
        image = self.images[idx]
        label = None if self.deploy else self.labels[idx]
        return image, label

    def get_random_batch(self, batch_size):
        """Sample ``batch_size`` items uniformly at random, with replacement.

        Returns
        -------
        images : np.ndarray, shape (N, C, X, Y, Z)
        labels : np.ndarray, shape (N, C, X, Y, Z)
        """
        images, labels = [], []
        for _ in range(batch_size):
            image, label = self.__getitem__(random.randint(0, len(self) - 1))
            images.append(image)
            labels.append(label)
        return np.array(images), np.array(labels)

    def get_batch(self, batch_size, iteration_num):
        """Return the sequential batch for a 1-based ``iteration_num``.

        Iterations 1..num_batches map to batches 0..num_batches-1 and then wrap
        around, where ``num_batches = len(self) // batch_size``. The last
        ``len(self) % batch_size`` samples are never served.

        Raises
        ------
        ValueError
            If the dataset holds fewer than ``batch_size`` samples.
        """
        batch_num = len(self) // batch_size
        if batch_num == 0:
            raise ValueError(
                "dataset has {} samples, fewer than batch_size={}".format(len(self), batch_size))
        # (iteration_num - 1) keeps every start index non-negative.
        start = ((iteration_num - 1) % batch_num) * batch_size
        images, labels = [], []
        for idx in range(start, start + batch_size):
            image, label = self.__getitem__(idx)
            images.append(image)
            labels.append(label)
        return np.array(images), np.array(labels)

# train_set = CardiacImageSet('/content/gdrive/My Drive/crop_preprocessed_bai/train_2D', '/content/gdrive/My Drive/crop_preprocessed_bai/train_3D')
# test_set = CardiacImageSet('/content/gdrive/My Drive/crop_preprocessed_bai/test_2D', '/content/gdrive/My Drive/crop_preprocessed_bai/test_3D')

# image_1, label_1 = train_set.__getitem__(88)
# print(image_1.shape)
# print(label_1.shape)

# test_images, test_labels = train_set.get_random_batch(4)
# print(test_images.shape)
# print(test_labels.shape)
# print(test_images.dtype)


# Construct SDF U-Net architecture

In [None]:
class UNet3d(nn.Module):
    """Three-level 3D U-Net mapping (N, in_channel, D, H, W) to (N, out_channel, D, H, W).

    Each stage is Conv3d -> LeakyReLU(0.1) -> BatchNorm3d (activation before
    normalisation, kept as trained). Spatial dims must be divisible by 8: there
    are three 2x max-poolings, each undone by a stride-2 transposed convolution
    before the skip concatenation.
    """

    def contracting_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        """Two same-padded convolutions for the encoder path."""
        padding = kernel_size // 2  # "same" size for odd kernels
        block = torch.nn.Sequential(
            torch.nn.Conv3d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel, padding=padding),
            torch.nn.LeakyReLU(0.1),
            torch.nn.BatchNorm3d(mid_channel),
            torch.nn.Conv3d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=out_channels, padding=padding),
            torch.nn.LeakyReLU(0.1),
            torch.nn.BatchNorm3d(out_channels),
        )
        return block

    def expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        """Two same-padded convolutions followed by a 2x transposed-conv upsampling."""
        padding = kernel_size // 2
        block = torch.nn.Sequential(
            torch.nn.Conv3d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel, padding=padding),
            torch.nn.LeakyReLU(0.1),
            torch.nn.BatchNorm3d(mid_channel),
            torch.nn.Conv3d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel, padding=padding),
            torch.nn.LeakyReLU(0.1),
            torch.nn.BatchNorm3d(mid_channel),
            # kernel 3, stride 2, padding 1, output_padding 1 exactly doubles each spatial dim.
            torch.nn.ConvTranspose3d(in_channels=mid_channel, out_channels=out_channels, kernel_size=3, stride=2,
                                     padding=1, output_padding=1)
        )
        return block

    def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        """Output head; no final activation, so outputs are unbounded regression values."""
        padding = kernel_size // 2
        block = torch.nn.Sequential(
            torch.nn.Conv3d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel, padding=padding),
            torch.nn.LeakyReLU(0.1),
            torch.nn.BatchNorm3d(mid_channel),
            torch.nn.Conv3d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel, padding=padding),
            torch.nn.LeakyReLU(0.1),
            torch.nn.BatchNorm3d(mid_channel),
            torch.nn.Conv3d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=out_channels, padding=padding),
        )
        return block

    def __init__(self, in_channel, out_channel):
        super(UNet3d, self).__init__()
        # Encode
        self.conv_encode1 = self.contracting_block(in_channel, 16, 32)
        self.conv_maxpool1 = torch.nn.MaxPool3d(kernel_size=2)
        self.conv_encode2 = self.contracting_block(32, 32, 64)
        self.conv_maxpool2 = torch.nn.MaxPool3d(kernel_size=2)
        self.conv_encode3 = self.contracting_block(64, 64, 128)
        self.conv_maxpool3 = torch.nn.MaxPool3d(kernel_size=2)
        # Bottleneck (ends with a 2x upsampling back to the encode3 resolution)
        self.bottleneck = torch.nn.Sequential(
            torch.nn.Conv3d(kernel_size=3, in_channels=128, out_channels=128, padding=1),
            torch.nn.LeakyReLU(0.1),
            torch.nn.BatchNorm3d(128),
            torch.nn.Conv3d(kernel_size=3, in_channels=128, out_channels=256, padding=1),
            torch.nn.LeakyReLU(0.1),
            torch.nn.BatchNorm3d(256),
            torch.nn.ConvTranspose3d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1,
                                     output_padding=1)
        )
        # Decode: input channels = skip channels + upsampled channels
        self.conv_decode3 = self.expansive_block(128+256, 128, 128)
        self.conv_decode2 = self.expansive_block(64+128, 64, 64)
        self.final_layer = self.final_block(32+64, 32, out_channel)

    def crop_and_concat(self, upsampled, bypass, crop=False):
        """Concatenate decoder features with encoder skip features along channels (dim 1).

        With ``crop=True`` the skip tensor is centre-cropped on each spatial axis
        (dims 2, 3, 4) to the decoder tensor's size first; odd differences put the
        extra voxel on the high side.
        """
        if crop:
            # F.pad lists (before, after) pairs starting from the LAST dim, so all
            # three spatial axes of a 5D tensor need six (negative = crop) values.
            pads = []
            for dim in (4, 3, 2):
                diff = bypass.size(dim) - upsampled.size(dim)
                pads += [-(diff // 2), -(diff - diff // 2)]
            bypass = F.pad(bypass, tuple(pads))
        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        # Encode
        encode_block1 = self.conv_encode1(x)
        encode_pool1 = self.conv_maxpool1(encode_block1)
        encode_block2 = self.conv_encode2(encode_pool1)
        encode_pool2 = self.conv_maxpool2(encode_block2)
        encode_block3 = self.conv_encode3(encode_pool2)
        encode_pool3 = self.conv_maxpool3(encode_block3)
        # Bottleneck
        bottleneck1 = self.bottleneck(encode_pool3)
        # Decode
        decode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=False)
        cat_layer2 = self.conv_decode3(decode_block3)
        decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=False)
        cat_layer1 = self.conv_decode2(decode_block2)
        decode_block1 = self.crop_and_concat(cat_layer1, encode_block1, crop=False)
        final_layer = self.final_layer(decode_block1)
        return final_layer

    def count_parameters(self):
        """Number of trainable scalar parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


# Sanity check: report the size of the network configuration used below.
toy_model = UNet3d(in_channel=1, out_channel=2)
total_parameters = toy_model.count_parameters()
print("Total number of parameters in the model:", total_parameters)

Total number of parameters in the model: 6406626


# Model Training

In [None]:
# CUDA device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device: {0}'.format(device))

# Build the model
num_class = 2
model = UNet3d(in_channel=1, out_channel=num_class)
model = model.to(device)
params = list(model.parameters())

model_dir = '/content/gdrive/My Drive/saved_model/SDF-2'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

# Optimizer
optimizer = optim.Adam(params, lr=1e-3)

# Segmentation loss
criterion = nn.MSELoss()

train_image_folder = '/content/gdrive/My Drive/crop_preprocessed_bai/train_2D'
train_label_folder = '/content/gdrive/My Drive/sdf_preprocessed_bai/train_3D'
test_image_folder = '/content/gdrive/My Drive/crop_preprocessed_bai/test_2D'
test_label_folder = '/content/gdrive/My Drive/sdf_preprocessed_bai/test_3D'

# Datasets
train_set = CardiacImageSet(train_image_folder, train_label_folder)
test_set = CardiacImageSet(test_image_folder, test_label_folder)

# Create a SummaryWriter object to write TensorBoard logs
log_dir = '/content/gdrive/My Drive/tensorboard_logs/SDF'
if not os.path.exists(log_dir):
    os.makedirs(log_dir)
writer = SummaryWriter(log_dir)

# Train the model
# num_iter = 500
train_batch_size = 4
eval_batch_size = 4
start = time.time()
running_loss = 0
#number of batches in an epoch
num_batches = int(train_set.__len__()/train_batch_size)
num_epoch = 100
for it in range(1, 1 + (num_batches * num_epoch)):
    # Set the modules in training mode, which will have effects on certain modules, e.g. dropout or batchnorm.
    start_iter = time.time()
    model.train()

    # Get a batch of images and labels
    images, labels = train_set.get_batch(train_batch_size, it)
    images, labels = torch.from_numpy(images), torch.from_numpy(labels)
    # image.to() convert the array from system RAM to GPU RAM
    images, labels = images.to(device, dtype=torch.float32), labels.to(device, dtype=torch.float32)
    #remove the channel dimension in the labels array
    labels = labels.squeeze(axis = 1)
    # print("Images shape:", images.shape)
    logits = model(images)

    # Perform optimisation and print out the training loss
    # print('logits shape:', logits.shape)
    # print('label shape:', labels.shape)

    loss = criterion(logits, labels)
    running_loss += loss

    if it % num_batches == 0:
        epoch_loss = running_loss/num_batches
        running_loss = 0
        print ("training loss for epoch {}:".format(it/num_batches))
        print(epoch_loss.item())

        # Write the training loss to TensorBoard
        writer.add_scalar('Loss', epoch_loss.item(), int(it / num_batches))

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    ###   ###

    # Evaluate
    if it % num_batches == 0:
        model.eval()
        # Disabling gradient calculation during reference to reduce memory consumption
        with torch.no_grad():
            # Evaluate on a batch of test images and print out the test loss
            ### Insert your code ###
            test_images, test_labels = test_set.get_random_batch(eval_batch_size)
            test_images, test_labels = torch.from_numpy(test_images), torch.from_numpy(test_labels)
            test_images, test_labels = test_images.to(device, dtype=torch.float32), test_labels.to(device, dtype=torch.long)
            test_labels = test_labels.squeeze(axis = 1)
            test_logits = model(test_images)
            test_loss = criterion(test_logits, test_labels)
            print ("test loss for epoch {}:".format(it/num_batches))
            print(test_loss.item())

            # Write the test loss to TensorBoard with a 'test' tag
            writer.add_scalar('Loss', test_loss.item(), int(it / num_batches))
            ### End of your code ###

    # # Save the model
    # if it % num_batches == 0:
    #     epoch = it/num_batches
    #     torch.save(model.state_dict(), os.path.join(model_dir, 'model_{0}.pt'.format(epoch)))
print('Training took {:.3f}s in total.'.format(time.time() - start))

Device: cuda
training loss for epoch 1.0:
0.16900870203971863
test loss for epoch 1.0:
0.7079792022705078
training loss for epoch 2.0:
0.0798821821808815
test loss for epoch 2.0:
0.1769847422838211
training loss for epoch 3.0:
0.07263513654470444
test loss for epoch 3.0:
0.06357325613498688
training loss for epoch 4.0:
0.06941662728786469
test loss for epoch 4.0:
0.05312737077474594
training loss for epoch 5.0:
0.06851514428853989
test loss for epoch 5.0:
0.10315446555614471
training loss for epoch 6.0:
0.06548455357551575
test loss for epoch 6.0:
0.25690215826034546
training loss for epoch 7.0:
0.0641406923532486
test loss for epoch 7.0:
0.09746898710727692
training loss for epoch 8.0:
0.0610945001244545
test loss for epoch 8.0:
0.07260861992835999
training loss for epoch 9.0:
0.0583309531211853
test loss for epoch 9.0:
0.04948240518569946
training loss for epoch 10.0:
0.05854983255267143
test loss for epoch 10.0:
0.037719469517469406
training loss for epoch 11.0:
0.05659528449177742


# Prediction generation

In [None]:
import datetime

device = torch.device("cpu")
def model_load():

    unet = UNet3d(in_channel=1, out_channel=2)
    unet.to(device, dtype=torch.float)

    model_list = ['/content/gdrive/My Drive/saved_model/SDF/model_10.0.pt']
    model_i = model_list[0]
    checkpoint = torch.load(model_i, map_location=torch.device('cpu'))
    unet.load_state_dict(checkpoint)
    return unet

def mr_sdf_inference(unet, slice_img):

    img = nib.load(slice_img)
    affine = img.affine
    data = img.get_fdata()
    data = np.transpose(data, (3, 0, 1, 2))

    # print("sliced image shape:", data.shape)

    data = np.expand_dims(data, axis=(0))

    # print(data.shape)

    data = torch.from_numpy(data)

    # print(data.size())
    data = data.to(device, dtype=torch.float32)

    # print(data.size())

    unet.eval()
    output = unet(data)
    pred = output.detach().cpu().numpy()

    # print(pred.shape)

    pred = pred.squeeze()
    pred = np.transpose(pred, (1, 2, 3, 0))

    distance_field_1 = pred[..., 0]
    distance_field_2 = pred[..., 1]

    # Create a segmentation map where values greater than 0.5 are labeled accordingly
    segmentation_map = np.zeros_like(distance_field_1)

    # Label distance field 1 as 1 where distance > 0.5
    segmentation_map[distance_field_1 > 0.5] = 1

    # Label distance field 2 as 2 where distance > 0.5
    segmentation_map[distance_field_2 > 0.5] = 2

    # Add a new axis to the segmentation map to make it 128x128x64x1
    pred = segmentation_map = np.expand_dims(segmentation_map, axis=-1)

    # print(pred.shape)
    pred_nifti = nib.Nifti1Image(segmentation_map, affine=affine)

    #pred_nifti is a Nifti object, pred is a numpy array
    return pred_nifti, pred

# sliced_image = '/content/gdrive/My Drive/crop_preprocessed_bai/test_2D/crop_14AB01345_segmentation_ES.nii.gz'
# unet = model_load()
# pred_nifti, pred = mr_sdf_inference(unet, sliced_image)

# pred_save_dir = '/content/gdrive/My Drive/pred_output/SDF/pred_SDF.nii.gz'
# nib.save(pred_nifti, pred_save_dir)

# Evaluation Functions

In [None]:
from functools import partial

import numpy as np
!pip install SimpleITK
import SimpleITK as sitk
from SimpleITK import GetArrayViewFromImage as ArrayView

# dice_scores = []
def dice_score(ground_truth, predicted, num_labels):
    """Pooled (micro-averaged) Dice over foreground labels ``1 .. num_labels - 1``.

    Intersections and voxel counts are summed over all foreground labels before
    computing ``2 * |A & B| / (|A| + |B|)``, so larger structures weigh more.
    Label 0 is treated as background and ignored.

    Returns 1.0 when no voxel in either array carries a foreground label, rather
    than dividing by zero.
    """
    total_intersection = 0
    total_count = 0
    for label in range(1, num_labels):
        gt_mask = ground_truth == label
        pred_mask = predicted == label
        total_intersection += np.logical_and(gt_mask, pred_mask).sum()
        total_count += gt_mask.sum() + pred_mask.sum()

    if total_count == 0:
        return 1.0  # both empty: perfect agreement
    return (2.0 * total_intersection) / total_count

distance_map = partial(sitk.SignedMaurerDistanceMap, squaredDistance=False, useImageSpacing=True)


def hausdorf(gold, prediction, num_labels = 1):
    """95th-percentile surface distance between two SimpleITK label images.

    For each label ``1 .. num_labels``, the contours of ``gold == label`` and
    ``prediction == label`` are extracted. The 95th percentile of
    contour-to-contour distances (in physical units, via image spacing) is
    taken in each direction, and the two are averaged. The classic HD95 takes
    the max of the two directions instead.

    Returns the mean over labels. Labels whose contour is empty in either image
    are skipped; if every label is skipped the result is nan.
    """
    per_label = []
    for label in range(1, num_labels + 1):
        gold_surface = sitk.LabelContour(gold == label, False)
        prediction_surface = sitk.LabelContour(prediction == label, False)

        gold_on_surface = ArrayView(gold_surface) == 1
        prediction_on_surface = ArrayView(prediction_surface) == 1
        # Percentiles of an empty distance set are undefined.
        if not gold_on_surface.any() or not prediction_on_surface.any():
            continue

        # Unsigned distance from every voxel to the nearest contour voxel.
        prediction_distance_map = sitk.Abs(distance_map(prediction_surface))
        gold_distance_map = sitk.Abs(distance_map(gold_surface))

        # Sample each map at the other image's contour voxels (both directions).
        gold_to_prediction = ArrayView(prediction_distance_map)[gold_on_surface]
        prediction_to_gold = ArrayView(gold_distance_map)[prediction_on_surface]

        per_label.append(
            (np.percentile(prediction_to_gold, 95) + np.percentile(gold_to_prediction, 95)) / 2.0)

    if not per_label:
        return np.nan
    return np.mean(per_label)

def segmentation_accuracy(groundtruth, prediction):
    """Voxel-wise accuracy per foreground category and overall.

    For each non-zero category present in ``groundtruth``, accuracy is the
    fraction of voxels where the binary masks ``groundtruth == c`` and
    ``prediction == c`` agree. For binary masks this equals
    ``agree / (union + agree - intersect)``, because ``union - intersect`` is the
    number of disagreeing voxels. Overall accuracy is the fraction of voxels
    whose labels are identical. This is valid for any number of labels, unlike
    products of raw label values.

    Returns
    -------
    per_category_accuracies : list of float
    overall_acc : float
    """
    categories = np.unique(groundtruth)
    categories = categories[categories != 0]  # Remove background 0

    per_category_accuracies = [
        np.mean((groundtruth == category) == (prediction == category))
        for category in categories
    ]
    overall_acc = np.mean(groundtruth == prediction)
    return per_category_accuracies, overall_acc



def jaccard_score(ground_truth, predicted, num_labels = 3):
    """Per-label and mean Jaccard index (IoU) for foreground labels ``1 .. num_labels - 1``.

    Label 0 is background and is skipped. A label absent from both arrays
    (empty union) scores 1.0.

    Parameters
    ----------
    ground_truth, predicted : numpy.ndarray
        Label maps to compare element-wise.
    num_labels : int, optional
        Number of labels including background.

    Returns
    -------
    jaccard_scores : numpy.ndarray of float, shape (num_labels - 1,)
        IoU of each foreground label, in label order.
    average_jaccard_score : float
        Unweighted mean of ``jaccard_scores``.
    """
    scores = []
    for label in range(1, num_labels):
        gt_mask = ground_truth == label
        pred_mask = predicted == label
        union = np.logical_or(gt_mask, pred_mask).sum()
        if union:
            scores.append(float(np.logical_and(gt_mask, pred_mask).sum()) / union)
        else:
            scores.append(1.0)  # empty in both: perfect agreement, avoids 0/0

    jaccard_scores = np.asarray(scores, dtype=float)
    return jaccard_scores, np.mean(jaccard_scores)




# Generate predictions and evaluate them on the test set

In [None]:
sliced_images = '/content/gdrive/My Drive/crop_preprocessed_bai/test_2D'
pred_save_dir = '/content/gdrive/My Drive/pred_output/SDF'
pred_load_dir = '/content/gdrive/My Drive/pred_output/08-13_15-32'
ground_truth_dir = '/content/gdrive/My Drive/crop_preprocessed_bai/test_3D'
# False: run the model now and save predictions; True: reuse files in pred_load_dir.
predicted = False
num_labels = 3  # background + two foreground labels

# Get the current time
current_time = datetime.datetime.now()
time_string = current_time.strftime("%m-%d_%H-%M")

# Predictions for a fresh run go in a new timestamped sub-directory.
if not predicted:
    pred_save_dir = os.path.join(pred_save_dir, time_string)
    os.makedirs(pred_save_dir, exist_ok=True)

# sorted(): deterministic sample order, so the metric lists line up run to run.
sliced_nii_files = sorted(glob.glob(os.path.join(sliced_images, "*.nii.gz")))

dice_scores = []
hausdorfs = []
per_catogory_jaccards = []
jaccards = []

# The network is only needed when predictions are generated in this run.
unet = None if predicted else model_load()

for sliced_nii_file in sliced_nii_files:
    sample_name = os.path.basename(sliced_nii_file)

    ground_truth_file = os.path.join(ground_truth_dir, sample_name)
    ground_truth_sitk = sitk.ReadImage(ground_truth_file)
    ground_truth = nib.load(ground_truth_file).get_fdata()

    if not predicted:
        pred_filename = os.path.join(pred_save_dir, "pred" + sample_name)
        pred_nifti, pred = mr_sdf_inference(unet, sliced_nii_file)
        nib.save(pred_nifti, pred_filename)
    else:
        pred_filename = os.path.join(pred_load_dir, "pred" + sample_name)
        pred = nib.load(pred_filename).get_fdata()
        # mr_sdf_inference saves (X, Y, Z, 1); add the channel axis only if it is
        # missing, otherwise comparisons with the ground truth broadcast wrongly.
        if pred.ndim == ground_truth.ndim - 1:
            pred = np.expand_dims(pred, axis=-1)
    # Read back from disk so SimpleITK sees the saved image geometry.
    pred_sitk = sitk.ReadImage(pred_filename)

    total_dice_score = dice_score(ground_truth, pred, num_labels)
    per_catogory_jaccard, jaccard = jaccard_score(ground_truth, pred, num_labels)
    hausdorf_distance = hausdorf(ground_truth_sitk, pred_sitk)

    dice_scores.append(total_dice_score)
    hausdorfs.append(hausdorf_distance)
    per_catogory_jaccards.append(per_catogory_jaccard)
    jaccards.append(jaccard)


print("Dice scores:", dice_scores)
mean_dice = np.mean(dice_scores)
print("Mean Dice Score:", mean_dice)

print("Hausdorf distances:", hausdorfs)
mean_haus = np.mean(hausdorfs)
print("Mean Hausdorf distance:", mean_haus)

print("per cat jaccard scores:", per_catogory_jaccards)
mean_cat_jac = np.mean(per_catogory_jaccards)
print("Mean per cat jaccard score:", mean_cat_jac)
print("Mean jaccard score per category:", np.mean(per_catogory_jaccards, axis=0))


Dice scores: [0.9004480691515988, 0.9224785680832434, 0.8855808680630534, 0.8838227318509336, 0.9196916566352891, 0.8615268142095405, 0.9190192857241465, 0.872276237346457, 0.9161947968299309, 0.9102704938159918, 0.8921133000406328, 0.9004477571922674, 0.8883714963761145, 0.9260482189961446, 0.9229819883803514, 0.9062439923367076, 0.867508568252694, 0.9051546831965726, 0.9135940941349213, 0.880861217468509, 0.936171205499479, 0.9197476474112712, 0.9264137760973205, 0.9054669619512484, 0.9066125559565669, 0.88412804967589, 0.9075094707219395, 0.9298547871031257, 0.9125549185816715, 0.9445099646191056, 0.8759668637091301, 0.9320653538325584, 0.9110324703545043, 0.9268823556847686, 0.8426421450643339, 0.9089834190337394, 0.9127319019896726, 0.9122249278499278, 0.8878048780487805, 0.9082516672742479, 0.9161330921104381, 0.9237963794235485, 0.9344575842170803, 0.9108625307527298, 0.892262058582295, 0.8923886314428192, 0.9253039243348014, 0.8957348056575967, 0.9256055100272157, 0.92165285691

# Save evaluation metrics as a CSV file

In [None]:
import csv
# Path to the CSV file
csv_file_path = os.path.join(pred_save_dir, 'eva_metrics.csv')
max_categories = 2

# One row per test sample, in evaluation-loop order.
header = ["Dice Score", "Hausdorff Distance"]
header += [f"Category {i + 1} Jaccard" for i in range(max_categories)]
# The last column holds the mean Jaccard returned by jaccard_score.
header.append("Mean Jaccard")

# newline="" lets the csv module control line endings.
with open(csv_file_path, mode="w", newline="") as csv_file:
    # Named csv_writer so the TensorBoard SummaryWriter `writer` is not shadowed.
    csv_writer = csv.writer(csv_file)
    csv_writer.writerow(header)

    for dice, hausdorff, category_jaccards, overall_jac in zip(dice_scores, hausdorfs, per_catogory_jaccards, jaccards):
        row = [dice, hausdorff, *category_jaccards]
        # Pad with None (written as an empty cell) if fewer categories were scored.
        row += [None] * (max_categories - len(category_jaccards))
        row.append(overall_jac)
        csv_writer.writerow(row)


# Visualise TensorBoard charts