In [1]:
import os
import time
import numpy as np
import nibabel as nib
import random
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import pickle
# %matplotlib inline
import copy
from datetime import datetime
from imp import reload
import json
import logging
import SimpleITK as sitk
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import torchvision.transforms as transforms
import ipywidgets as widgets
from ipywidgets import interact, interactive
from IPython.display import display

# Instructions #

Major part of the code is implemented using Classes to avoid duplication of code. The first few blocks of the notebook defines the various components i.e. model structure, training pipeline, dataloaders, loss functions etc. In each of the marked sections afterwards, the required functions are called with the appropriate parameters.

In [2]:
# for reproducability
def set_seed(seed):
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

### Dataset and Dataloaders ###

In [17]:
class ToTorchTensor:
    '''
        Transforms a numpy ndarray to a torch tensor of the supplied datatype
    '''

    def __init__(self, dtype=torch.float32):
        self.dtype = dtype

    def __call__(self, input):
        return torch.tensor(input, dtype=self.dtype)

class DatasetHepatic(Dataset):
    '''
        min = 24
        max = 181
        median = 49
        mean = 69
    '''

    def __init__(self, run_mode='train',
                 transform_image=None,
                 transform_label=None,
                 patch_size_normal=25,
                 patch_size_low=19,
                 patch_size_out=9,
                 patch_low_factor=3,
                 label_percentage=0.1,
                 batch_size_inner=100,
                 use_probabilistic=False,
                 create_numpy_dataset=False,
                 dataset_variant='nib',
                 train_percentage=0.8,
                 use_elastic_deformation=False,
                 user_affine_transformation=False,
                 num_controlpoints=20, sigma=5, rotation=10, scale=(0.90, 1.10), shear=(0.01, 0.02)
                 ):

        self.run_mode = run_mode
        self.create_numpy_dataset_cond = create_numpy_dataset
        self.dataset_variant = dataset_variant
        self.patch_size_normal = patch_size_normal
        self.patch_size_low = patch_size_low
        self.patch_size_out = patch_size_out
        self.patch_low_factor = patch_low_factor
        self.batch_size_inner = batch_size_inner
        self.train_percentage = train_percentage
        self.patch_size_low_up = self.patch_size_low * self.patch_low_factor

        self.label_percentage = label_percentage
        self.use_probabilistic = use_probabilistic
        self.fetch_filenames()
        self.create_numpy_dataset()
        self.use_elastic_deformation = use_elastic_deformation
        self.use_affine_transformation = user_affine_transformation
        self.elastic_deformation = ElasticDeformation(num_controlpoints=num_controlpoints, sigma=sigma)
        self.affine_transformation = AffineTransformation(rotation=rotation, scale=scale, shear=shear)

        if transform_image is None:
            self.transform_image = transforms.Compose([
                ToTorchTensor(dtype=torch.float32),
                transforms.Normalize(mean=0.5, std=0.5)
            ])
        else:
            self.transform_image = transform_image

        if transform_label is None:
            self.transform_label = transforms.Compose([
                ToTorchTensor(torch.int64)
            ])
        else:
            self.transform_label = transform_label

    def __getitem__(self, index):
        if self.dataset_variant == 'nib':
            image = self.read_file_nib(self.filenames_image_nib[index])
            label = self.read_file_nib(self.filenames_label_nib[index])
        elif self.dataset_variant == 'npy':
            image = self.read_file_npy(self.filenames_image_npy[index])
            label = self.read_file_npy(self.filenames_label_npy[index])

        image = self.transform_image(image)
        label = self.transform_label(label)
        if self.use_elastic_deformation:
            image, label = self.elastic_deformation(image, label)
        if self.use_affine_transformation:
            image, label, _, _ = self.affine_transformation(image, label)

        image = image.detach().numpy()
        label = label.detach().numpy()

        # index of the original filenames as in the dataset folders
        index_filename = self.filenames_image_npy[index][25:28]

        if self.run_mode in ['train', 'val']:
            if self.batch_size_inner > 1:
                image_patch_normal_stack = torch.zeros(
                    (self.batch_size_inner, self.patch_size_normal, self.patch_size_normal, self.patch_size_normal),
                    dtype=torch.float32)
                image_patch_low_up_stack = torch.zeros(
                    (self.batch_size_inner, self.patch_size_low_up, self.patch_size_low_up, self.patch_size_low_up),
                    dtype=torch.float32)
                label_patch_out_stack = torch.zeros(
                    (self.batch_size_inner, self.patch_size_out, self.patch_size_out, self.patch_size_out),
                    dtype=torch.int64)

                for index_inner in range(self.batch_size_inner):
                    # extract the three different patches of labels
                    label_patch_normal, label_patch_low_up, label_patch_out = self.get_random_patch(label)

                    # extract the three different patches of images
                    image_patch_normal = self.get_3D_crop(image, self.coordinate_center, self.patch_size_normal)
                    image_patch_low_up = self.get_3D_crop(image, self.coordinate_center, self.patch_size_low_up)
                    image_patch_out = self.get_3D_crop(image, self.coordinate_center, self.patch_size_out)

                    image_patch_normal_stack[index_inner] = torch.tensor(image_patch_normal, dtype=torch.float32).unsqueeze(0)
                    image_patch_low_up_stack[index_inner] = torch.tensor(image_patch_low_up, dtype=torch.float32).unsqueeze(0)
                    label_patch_out_stack[index_inner] = torch.tensor(label_patch_out, dtype=torch.float32).unsqueeze(0)

                return image_patch_normal_stack.unsqueeze(1), image_patch_low_up_stack.unsqueeze(1), label_patch_out_stack.unsqueeze(1)
            else:
                # extract the three different patches of labels
                label_patch_normal, label_patch_low_up, label_patch_out = self.get_random_patch(label)

                # extract the three different patches of images
                image_patch_normal = self.get_3D_crop(image, self.coordinate_center, self.patch_size_normal)
                image_patch_low_up = self.get_3D_crop(image, self.coordinate_center, self.patch_size_low_up)
                image_patch_out = self.get_3D_crop(image, self.coordinate_center, self.patch_size_out)

                return torch.tensor(image_patch_normal, dtype=torch.float32).unsqueeze(0), \
                       torch.tensor(image_patch_low_up, dtype=torch.float32).unsqueeze(0), \
                       torch.tensor(label_patch_out, dtype=torch.int64).unsqueeze(0)

        elif self.run_mode == 'inference':
            # TODO fix uneven dimensions, otherwise run with batch size = 1
            image = self.transform_image(image)
            return image, label, index_filename

    def __len__(self):
        return self.num_samples

    def create_numpy_dataset(self):
        '''
            Converts the dataset to numpy format to increase i/o operations
        '''
        def convert_to_numpy_from_nib(target_dir, filenames):
            os.makedirs(target_dir, exist_ok=True)

            for filename in tqdm(filenames, leave=False):
                data_np = nib.load(filename).get_fdata()

                filename_new = f'{filename[11:-7]}.npy'
                save_path = os.path.join(target_dir, filename_new)
                np.save(save_path, data_np)

        save_dir_train_im = 'imagesTrNP'
        train_filenames_im = self.filenames_image_nib

        save_dir_train_labels = 'labelsTrNP'
        train_filenames_labels = self.filenames_label_nib

        if self.create_numpy_dataset_cond:
            convert_to_numpy_from_nib(target_dir=save_dir_train_im, filenames=train_filenames_im)
            convert_to_numpy_from_nib(target_dir=save_dir_train_labels, filenames=train_filenames_labels)

    def get_label_percentage(self, input, label):
        '''
            Returns the percentage of supplied label in the voxel
        '''
        eps = 1e-9
        denominator = input.shape[0] * input.shape[1] * input.shape[2]
        numerator = np.sum(np.where(input == label, 1, 0))

        return numerator / (denominator + eps)

    def get_rand_index_3D(self, input, height=512, width=512, depth=20, patch_size=57):
        '''
            Returns a random starting index (top-left) of a valid 3D volume
        '''
        patch_size_half = patch_size // 2

        if self.use_probabilistic:
            # nearby currently selected label
            loop_condition = True
            background_count = 0
            while loop_condition:
                # crop the image so that the patch does not go outside the area of the image
                input_cropped = input[patch_size_half:height - patch_size_half, patch_size_half:width - patch_size_half,
                                patch_size_half:depth - patch_size_half]
                # all indices of the cropped image equal to the current selected label category
                indices_all = np.array(np.where(input_cropped == self.current_selected_label))
                # print(indices_all.shape[1])
                if indices_all.shape[1] >= 1:
                    selected_index_w = np.random.randint(indices_all.shape[1])
                    selected_index = indices_all[:, selected_index_w]

                    index_h, index_w, index_d = (
                        selected_index[0] + patch_size_half, selected_index[1] + patch_size_half,
                        selected_index[2] + patch_size_half)
                    # index_h, index_w, index_d = selected_index
                    loop_condition = False
                else:
                    # print('here')
                    if background_count > 0:
                        # if none of the other two labels are present in the image, randomly pick a coordinate
                        index_h = np.random.randint(patch_size_half, height - patch_size_half)
                        index_w = np.random.randint(patch_size_half, width - patch_size_half)
                        index_d = np.random.randint(patch_size_half, depth - patch_size_half)
                        loop_condition = False

                    else:
                        if self.current_selected_label == 1:
                            self.current_selected_label = 2
                            background_count += 1
                        elif self.current_selected_label == 2:
                            self.current_selected_label = 1
                            background_count += 1
                        loop_condition = True
        else:
            #  complete random
            index_h = np.random.randint(patch_size_half, height - patch_size_half)
            index_w = np.random.randint(patch_size_half, width - patch_size_half)
            index_d = np.random.randint(patch_size_half, depth - patch_size_half)

        return (index_h, index_w, index_d)

    def get_3D_crop(self, input, coordinate, patch_size):
        '''
            Returns a 3D patch of an input 3D image given a valid top-left coordinate
        '''
        assert patch_size % 2 == 1, 'Patch size should be an odd number'
        patch_size_half = patch_size // 2

        if len(input.shape) == 3:
            height, width, depth = input.shape

        if depth <= self.patch_size_low * self.patch_low_factor:
            temp_array = np.zeros((height, width, self.patch_size_low * self.patch_low_factor))
            temp_array[:, :, :depth] = input
            input = temp_array
            depth = temp_array.shape[2]

        return input[
               coordinate[0] - patch_size_half: coordinate[0] + patch_size_half + 1,
               coordinate[1] - patch_size_half: coordinate[1] + patch_size_half + 1,
               coordinate[2] - patch_size_half: coordinate[2] + patch_size_half + 1,
               ]

    def set_probabilistic_label(self):
        '''
            Randomly with equal probability select one of the three labels to be the current label
        '''
        label_probability = np.random.rand()
        mode = 'major'  # equal, biased
        if mode == 'biased':
            if label_probability > 0.5:
                self.current_selected_label = 1
            else:
                self.current_selected_label = 2
        elif mode == 'equal':
            if label_probability > 0.66:
                self.current_selected_label = 2
            elif label_probability < 0.33:
                self.current_selected_label = 1
            else:
                self.current_selected_label = 0
        elif mode == 'major':
            if label_probability > 0.45:
                self.current_selected_label = 2
            elif label_probability < 0.45:
                self.current_selected_label = 1
            else:
                self.current_selected_label = 0

    def get_random_patch(self, input):
        '''
            Returns a valid cubic sub-volume with edge lenth = patch_size from a supplied 3D input volume image_input
        '''
        # a = copy.deepcopy(input)
        if len(input.shape) == 3:
            height, width, depth = input.shape

        if depth <= self.patch_size_low * self.patch_low_factor:
            temp_array = np.zeros((height, width, self.patch_size_low * self.patch_low_factor))
            temp_array[:, :, :depth] = input
            input = temp_array
            depth = temp_array.shape[2]

        loop_condition = True
        if self.use_probabilistic:
            self.set_probabilistic_label()

        # keep sampling a new patch until the current label meets the desired overall percentage
        while loop_condition:
            # get a valid coordinate and extract the patch
            self.coordinate_center = self.get_rand_index_3D(input, height, width, depth, self.patch_size_low_up)

            patch_normal = self.get_3D_crop(input, self.coordinate_center, self.patch_size_normal)
            patch_low_up = self.get_3D_crop(input, self.coordinate_center, self.patch_size_low_up)
            patch_out = self.get_3D_crop(input, self.coordinate_center, self.patch_size_out)

            loop_condition = False

        return patch_normal, patch_low_up, patch_out

    def read_file_nib(self, filename):
        '''
            Reads a nibabel file and returns it in numpy ndarray format
        '''
        try:
            data_nib = nib.load(filename).get_fdata()
        except FileNotFoundError:
            print(f'Error reading file: {filename}')

        return data_nib

    def read_file_npy(self, filename):
        '''
            Reads a npy file and returns it in numpy ndarray format
        '''
        try:
            data_npy = np.load(filename)
        except FileNotFoundError:
            print(f'Error reading file: {filename}')

        return data_npy

    def fetch_filenames(self, path_meta='dataset.json'):
        '''
            Reads the dataset.json file and extracts the training and test image and/or labels
        :return:
        '''
        try:
            with open(path_meta) as file_meta:
                data_meta = json.loads(file_meta.read())
        except FileNotFoundError:
            print(f'Meta file: {self.path_meta} not found')

        num_samples = len(data_meta['training'])

        if self.run_mode == 'train':
            num_samples = int(np.floor(self.train_percentage * num_samples))

            self.filenames_image_nib = [current_sample['image'] for current_sample in data_meta['training']][
                                       :num_samples]
            self.filenames_label_nib = [current_sample['label'] for current_sample in data_meta['training']][
                                       :num_samples]

            self.filenames_image_npy = [os.path.join('.', 'imagesTrNP', f'{filename[11:-7]}.npy') for filename in
                                        self.filenames_image_nib]
            self.filenames_label_npy = [os.path.join('.', 'labelsTrNP', f'{filename[11:-7]}.npy') for filename in
                                        self.filenames_label_nib]
            self.num_samples = num_samples
        else:
            num_train = int(np.floor(self.train_percentage * num_samples))

            self.filenames_image_nib = [current_sample['image'] for current_sample in data_meta['training']][
                                       num_train:]
            self.filenames_label_nib = [current_sample['label'] for current_sample in data_meta['training']][
                                       num_train:]

            self.filenames_image_npy = [os.path.join('.', 'imagesTrNP', f'{filename[11:-7]}.npy') for filename in
                                        self.filenames_image_nib]
            self.filenames_label_npy = [os.path.join('.', 'labelsTrNP', f'{filename[11:-7]}.npy') for filename in
                                        self.filenames_label_nib]
            # self.num_samples = int(np.ceil((1 - self.train_percentage) * num_samples))
            self.num_samples = len(self.filenames_image_nib)
        if (not len(self.filenames_image_nib) == len(self.filenames_label_nib)):
            raise Exception('Inconsistent training image/label combination')
        if len(self.filenames_image_nib) == 0:
            raise Exception(f'Error reading {self.run_mode} images')
        if len(self.filenames_label_nib) == 0:
            raise Exception(f'Error reading {self.run_mode} labels')

        elif self.run_mode == 'test':
            # 'TODO' correct the train and test and inference variants
            self.filenames_image_nib = [current_sample for current_sample in data_meta['test']]
            if len(self.filenames_image_nib) == 0:
                raise Exception(f'Error reading {self.run_mode} images')


### DeepMedic Model ###
#### Adapted from: https://github.com/pykao/BraTS2018-tumor-segmentation/blob/master/models/deepmedic.py ####

In [4]:
class ResBlock(nn.Module):
    '''
        Adapted from: https://github.com/pykao/BraTS2018-tumor-segmentation/blob/master/models/deepmedic.py
    '''

    def __init__(self, inplanes, planes):
        super(ResBlock, self).__init__()

        self.inplanes = inplanes
        self.conv1 = nn.Conv3d(inplanes, planes, 3, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(planes, planes, 3, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        x = x[:, :, 2:-2, 2:-2, 2:-2]
        y[:, :self.inplanes] += x
        y = self.relu(y)
        return y


def conv3x3(inplanes, planes, ksize=3):
    return nn.Sequential(
        nn.Conv3d(inplanes, planes, ksize, bias=False),
        nn.BatchNorm3d(planes),
        nn.ReLU(inplace=True))


def repeat(x, n=3):
    # nc333
    b, c, h, w, t = x.shape
    x = x.unsqueeze(5).unsqueeze(4).unsqueeze(3)
    x = x.repeat(1, 1, 1, n, 1, n, 1, n)
    return x.view(b, c, n * h, n * w, n * t)


class DeepMedic(nn.Module):
    '''
        Adapted from: https://github.com/pykao/BraTS2018-tumor-segmentation/blob/master/models/deepmedic.py
    '''

    def __init__(self, input_channels=1, n1=30, n2=40, n3=50, m=150, up=True):
        super(DeepMedic, self).__init__()
        # n1, n2, n3 = 30, 40, 50
        num_classes = 3
        n = 2 * n3
        self.branch1 = nn.Sequential(
            conv3x3(input_channels, n1),
            conv3x3(n1, n1),
            ResBlock(n1, n2),
            ResBlock(n2, n2),
            ResBlock(n2, n3))

        self.branch2 = nn.Sequential(
            conv3x3(input_channels, n1),
            conv3x3(n1, n1),
            conv3x3(n1, n2),
            conv3x3(n2, n2),
            conv3x3(n2, n2),
            conv3x3(n2, n2),
            conv3x3(n2, n3),
            conv3x3(n3, n3))

        self.up3 = nn.Upsample(scale_factor=3, mode='trilinear', align_corners=False) if up else repeat

        self.fc = nn.Sequential(
            conv3x3(n, m, 1),
            conv3x3(m, m, 1),
            nn.Conv3d(m, num_classes, 1)
        )

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputs):
        x1, x2 = inputs
        x1 = self.branch1(x1)
        x2 = self.branch2(x2)
        x2 = self.up3(x2)
        x = torch.cat([x1, x2], 1)
        x = self.fc(x)
        return x

### Generalized Dice Loss

In [5]:
class GeneralizedDiceLoss(nn.Module):
    '''
        Following the equation from https://arxiv.org/abs/1707.03237 page 3
    '''

    def __init__(self):
        super(GeneralizedDiceLoss, self).__init__()

    def forward(self, im_pred, im_real):
        if len(im_pred.shape) == 4:
            im_pred = im_pred.unsqueeze(0)

        if len(im_real.shape) == 4:
            im_real = im_real.unsqueeze(0)

        im_real = im_real.permute((1, 0, 2, 3, 4))
        im_pred = im_pred.permute((1, 0, 2, 3, 4))

        eps = 1e-12
        sum_1 = torch.sum(im_real[0])
        sum_2 = torch.sum(im_real[1])
        sum_3 = torch.sum(im_real[2])

        weight_1 = 1 / (sum_1 ** 2 + eps) if sum_1 > 0 else 1
        weight_2 = 1 / (sum_2 ** 2 + eps) if sum_2 > 0 else 1
        weight_3 = 1 / (sum_3 ** 2 + eps) if sum_3 > 0 else 1

        numerator_1 = torch.sum(im_real[0] * im_pred[0]) * weight_1
        numerator_2 = torch.sum(im_real[1] * im_pred[1]) * weight_2
        numerator_3 = torch.sum(im_real[2] * im_pred[2]) * weight_3

        numerator = numerator_1 + numerator_2 + numerator_3

        denominator_1 = (torch.sum(im_real[0]) + torch.sum(im_pred[0])) * weight_1
        denominator_2 = (torch.sum(im_real[1]) + torch.sum(im_pred[1])) * weight_2
        denominator_3 = (torch.sum(im_real[2]) + torch.sum(im_pred[2])) * weight_3

        denominator = denominator_1 + denominator_2 + denominator_3

        dice_loss = 1 - ((2 * numerator) / (denominator + eps))

        return dice_loss

### Elastic Deformation Class

In [6]:
class ElasticDeformation:
    """
        Classs containing Elastic Deformation Transformation
        Adapted from week 3 AML tutorial materials
    """

    def __init__(self, num_controlpoints=None, sigma=None):

        self.num_controlpoints = num_controlpoints
        self.sigma = sigma

        # Random parameters if not defined
        if self.sigma == None:
            self.sigma = np.random.uniform(low=1, high=5)

        if self.num_controlpoints == None:
            self.num_controlpoints = int(np.random.uniform(low=1, high=6))

    def create_elastic_deformation(self, image):
        """
            We need to parameterise our b-spline transform
            The transform will depend on such variables as image size and sigma
            Sigma modulates the strength of the transformation
            The number of control points controls the granularity of our transform
        """
        # Create an instance of a SimpleITK image of the same size as our image
        itkimg = sitk.GetImageFromArray(np.zeros(image.shape))

        # This parameter is just a list with the number of control points per image dimensions
        trans_from_domain_mesh_size = [self.num_controlpoints] * itkimg.GetDimension()

        # We initialise the transform here: Passing the image size and the control point specifications
        bspline_transformation = sitk.BSplineTransformInitializer(itkimg, trans_from_domain_mesh_size)

        # Isolate the transform parameters: They will be all zero at this stage
        params = np.asarray(bspline_transformation.GetParameters(), dtype=float)

        # Let's initialise the transform by randomly initialising each parameter according to sigma
        params = params + np.random.randn(params.shape[0]) * self.sigma

        # Let's initialise the transform by randomly displacing each control point by a random distance (magnitude sigma)
        bspline_transformation.SetParameters(tuple(params))

        return bspline_transformation

    def __call__(self, image, label):
        # We need to choose an interpolation method for our transformed image, let's just go with b-spline
        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(sitk.sitkBSpline)

        # Let's convert our image to an sitk image
        sitk_image = sitk.GetImageFromArray(image)

        # Specify the image to be transformed: This is the reference image
        resampler.SetReferenceImage(sitk_image)
        resampler.SetDefaultPixelValue(0)

        # Initialise the transform
        bspline_transform = self.create_elastic_deformation(image)

        # Set the transform in the initialiser
        resampler.SetTransform(bspline_transform)

        # Carry out the resampling according to the transform and the resampling method
        out_img_sitk = resampler.Execute(sitk_image)

        # Convert the image back into a python array
        out_img = sitk.GetArrayFromImage(out_img_sitk)

        # We need to choose an interpolation method for our transformed image, let's just go with b-spline
        resampler_label = sitk.ResampleImageFilter()
        resampler_label.SetInterpolator(sitk.sitkNearestNeighbor)

        # Let's convert our image to an sitk image
        sitk_label = sitk.GetImageFromArray(label)

        # Specify the image to be transformed: This is the reference image
        resampler_label.SetReferenceImage(sitk_label)
        resampler_label.SetDefaultPixelValue(0)

        # Initialise the transform
        bspline_transform = self.create_elastic_deformation(label)

        # Set the transform in the initialiser
        resampler_label.SetTransform(bspline_transform)

        # Carry out the resampling according to the transform and the resampling method
        out_label_sitk = resampler_label.Execute(sitk_label)

        # Convert the image back into a python array
        out_label = sitk.GetArrayFromImage(out_label_sitk)

        return torch.tensor(out_img.reshape(image.shape), dtype=torch.float32), torch.tensor(out_label.reshape(image.shape), dtype=torch.int64)

### Affine Transformation Class ###

In [7]:
class AffineTransformation:
    """
        Classs containing Elastic Deformation Transformation
    """

    def __init__(self, rotation=5, scale=(0.95, 1.05), shear=(0.01, 0.02), return_inverse=False, inverse_matrix=None):
        self.rotation = rotation
        self.scale = scale
        self.shear = shear
        self.return_inverse = return_inverse
        self.inverse_matrix = inverse_matrix

    def get_transformation_matrix(self):
        # apply rotation on the z-axis
        degree_rotation = torch.tensor(1, dtype=torch.float32).uniform_(-self.rotation, self.rotation)
        degree_rotation = (degree_rotation * torch.pi) / 180

        matrix_rotation = torch.zeros((1, 4, 4), dtype=torch.float32)
        matrix_rotation[0, 0, 0] = torch.cos(degree_rotation)
        matrix_rotation[0, 0, 1] = torch.sin(degree_rotation)
        matrix_rotation[0, 1, 0] = -torch.sin(degree_rotation)
        matrix_rotation[0, 1, 1] = torch.cos(degree_rotation)
        matrix_rotation[0, 2, 2] = 1
        matrix_rotation[0, 3, 3] = 1

        # apply scaling on each dimension
        matrix_scale = torch.zeros((1, 4, 4), dtype=torch.float32)
        matrix_scale[0, 0, 0] = torch.tensor(1, dtype=torch.float32).uniform_(self.scale[0], self.scale[1])
        matrix_scale[0, 1, 1] = torch.tensor(1, dtype=torch.float32).uniform_(self.scale[0], self.scale[1])
        matrix_scale[0, 2, 2] = torch.tensor(1, dtype=torch.float32).uniform_(self.scale[0], self.scale[1])
        matrix_scale[0, 3, 3] = 1
        # print(matrix_scale.shape)

        # shear
        degree_shear = torch.tensor((
            torch.tensor(1, dtype=torch.float32).uniform_(self.shear[0], self.shear[1]),
            torch.tensor(1, dtype=torch.float32).uniform_(self.shear[0], self.shear[1])
        ))

        matrix_shear = torch.zeros((1, 4, 4), dtype=torch.float32)
        matrix_shear[0, 0, 0] = 1
        matrix_shear[0, 0, 1] = degree_shear[0]
        matrix_shear[0, 1, 0] = degree_shear[1]
        matrix_shear[0, 1, 1] = 1
        matrix_shear[0, 2, 2] = 1
        matrix_shear[0, 3, 3] = 1

        # generate the combined affine transformation matrix
        self.matrix_affine = torch.matmul(matrix_shear, torch.matmul(matrix_rotation, matrix_scale))

        # generate the inverse transformation matrix
        self.matrix_affine_inv = torch.inverse(self.matrix_affine)

        # return to original coordinates
        self.matrix_affine = self.matrix_affine[:, 0:3, :]
        self.matrix_affine_inv = self.matrix_affine_inv[:, 0:3, :]

    def __call__(self, image, label):
        if len(image.shape) == 3:
            image = image.unsqueeze(0).unsqueeze(0)
        if len(label.shape) == 3:
            label = label.unsqueeze(0).unsqueeze(0)

        # obtain transformation matrix
        self.get_transformation_matrix()

        # define the affine grid and apply transformation on images and labels
        if self.return_inverse:
            grid_affine = F.affine_grid(self.matrix_affine_inv, image.shape, align_corners=False)
            image_at = F.grid_sample(image.float(), grid_affine, padding_mode="border", align_corners=False)
            label_at = F.grid_sample(label.float(), grid_affine, mode='nearest', padding_mode="zeros",
                                        align_corners=False)
        else:
            grid_affine = F.affine_grid(self.matrix_affine, image.shape, align_corners=False)
            image_at = F.grid_sample(image.float(), grid_affine, padding_mode="border", align_corners=False)
            label_at = F.grid_sample(label.float(), grid_affine, mode='nearest', padding_mode="zeros",
                                        align_corners=False)

        return image_at.squeeze(0).squeeze(0), label_at.squeeze(0).squeeze(0), self.matrix_affine, self.matrix_affine_inv

### Container class for TRAIN/VAL/EVALUATE model

In [8]:
class ModelConainer():
    def __init__(self, params_model):
        self.params_model = params_model

    def __init_train_params(self):
        self.run_mode = 'train'

        self.loss_dict_train = {
            'total': [],
            'dice': [],
            'mse': [],
            'ce': [],
            'dice_n_mse': [],
            'dice_n_mse_n_ce': []
        }

        self.loss_dict_val = {
            'total': [],
            'dice': [],
            'mse': [],
            'ce': [],
            'dice_n_mse': [],
            'dice_n_mse_n_ce': []
        }

        self.loss_best_train = {
            'total': np.inf,
            'dice': np.inf,
            'mse': np.inf,
            'ce': np.inf,
            'dice_n_mse': np.inf,
            'dice_n_mse_n_ce': np.inf
        }

        self.loss_best_val = {
            'total': np.inf,
            'dice': np.inf,
            'mse': np.inf,
            'ce': np.inf,
            'dice_n_mse': np.inf,
            'dice_n_mse_n_ce': np.inf
        }

    def __init_inference_params(self):
        self.run_mode = 'inference'

    def __setup_logger(self):

        if not self.params_train['resume_condition']:
            reload(logging)

        logging.basicConfig(filename=self.params_train['path_logger_full'], encoding='utf-8', level=logging.DEBUG)

    def __create_params_file(self):
        params_dict = {
            'params_model': self.params_model,
            'params_train': self.params_train
        }

        params_dict = json.dumps(params_dict, indent=4, sort_keys=False)

        with open(self.params_train['path_params_full'], 'w') as outfile:
            outfile.write(params_dict)

        self.__print(f'{"*" * 100}')
        self.__print('\t\tTraining starting with params:')
        self.__print(f'{"*" * 100}')
        self.__print(f'{params_dict}')
        self.__print(f'{"*" * 100}')

    def __create_checkpoint_dir(self):
        if self.params_train['resume_condition']:
            self.params_train['dirname_checkpoint'] = self.params_train['resume_dir'][:11]
            self.params_train['path_checkpoint_full'] = self.params_train['resume_dir']
        else:
            self.params_train['dirname_checkpoint'] = f'{self.params_model["experiment_name"]}__' \
                                                      f'{self.params_model["init_timestamp"]}__' \
                                                      f'{self.params_model["model_name"]}__' \
                                                      f'{self.params_train["loss_name"]}__' \
                                                      f'{self.params_train["optimizer_name"]}__' \
                                                      f'lr_{self.params_train["learning_rate"]}__' \
                                                      f'ep_{self.params_train["num_epochs"]}'

            self.params_train['path_checkpoint_full'] = os.path.join(self.params_train['path_checkpoint'],
                                                                     self.params_train['dirname_checkpoint'])

        self.params_train['path_params_full'] = os.path.join(self.params_train['path_checkpoint_full'],
                                                             self.params_train['filename_params'])
        self.params_train['path_logger_full'] = os.path.join(self.params_train['path_checkpoint_full'],
                                                             self.params_train['filename_logger'])
        os.makedirs(self.params_train['path_checkpoint_full'], exist_ok=True)

    def train(self, params_train, transform_train=None):
        self.params_train = params_train
        self.transform_train = transform_train
        self.__init_train_params()
        self.__create_checkpoint_dir()
        self.__create_params_file()
        self.__setup_logger()
        self.__fit_model()

    def inference(self, params_inference):

        self.params_inference = params_inference

        self.__init_inference_params()
        self.__run_inference()

    def __set_device(self):
        self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    def __get_dataloaders(self, run_mode):
        if run_mode == 'train':
            self.dataset_train = DatasetHepatic(
                run_mode='train',
                transform_image=self.transform_train,
                label_percentage=0.0001,
                use_probabilistic=True,
                patch_size_normal=self.params_model['patch_size_normal'],
                patch_size_low=self.params_model['patch_size_low'],
                patch_size_out=self.params_model['patch_size_out'],
                patch_low_factor=self.params_model['patch_low_factor'],
                create_numpy_dataset=self.params_model['create_numpy_dataset'],
                dataset_variant=self.params_model['dataset_variant'],
                batch_size_inner=self.params_train['batch_size_inner'],
                train_percentage=self.params_train['train_percentage'],
                use_elastic_deformation=self.params_train['use_elastic_deformation'],
                user_affine_transformation=self.params_train['user_affine_transformation'],
                num_controlpoints=self.params_train['num_controlpoints'],
                sigma=self.params_train['sigma'],
                rotation=self.params_train['rotation'],
                scale=self.params_train['scale'],
                shear=self.params_train['shear']
            )

            self.dataset_val = DatasetHepatic(
                run_mode='val',
                label_percentage=0.0001,
                transform_image=None,
                use_probabilistic=True,
                patch_size_normal=self.params_model['patch_size_normal'],
                patch_size_low=self.params_model['patch_size_low'],
                patch_size_out=self.params_model['patch_size_out'],
                patch_low_factor=self.params_model['patch_low_factor'],
                create_numpy_dataset=self.params_model['create_numpy_dataset'],
                dataset_variant=self.params_model['dataset_variant'],
                batch_size_inner=self.params_train['batch_size_inner'],
                train_percentage=self.params_train['train_percentage']
            )

            self.dataloader_train = DataLoader(
                self.dataset_train,
                batch_size=self.params_train['batch_size'],
                shuffle=True,
                num_workers=self.params_train['num_workers'],
                pin_memory=self.params_train['pin_memory'],
                prefetch_factor=self.params_train['prefetch_factor'],
                persistent_workers=self.params_train['persistent_workers']
            )

            self.dataloader_val = DataLoader(
                self.dataset_val,
                batch_size=self.params_train['batch_size'],
                shuffle=False,
                num_workers=self.params_train['num_workers'],
                pin_memory=self.params_train['pin_memory'],
                prefetch_factor=self.params_train['prefetch_factor'],
                persistent_workers=self.params_train['persistent_workers']
            )

        elif run_mode == 'inference':
            self.dataset_inference = DatasetHepatic(
                run_mode='inference',
                transform_image=None,
                label_percentage=0.0001,
                use_probabilistic=True,
                patch_size_normal=self.params_model['patch_size_normal'],
                patch_size_low=self.params_model['patch_size_low'],
                patch_size_out=self.params_model['patch_size_out'],
                patch_low_factor=self.params_model['patch_low_factor'],
                create_numpy_dataset=self.params_model['create_numpy_dataset'],
                dataset_variant=self.params_model['dataset_variant']
            )

            self.dataloader_inference = DataLoader(
                self.dataset_inference,
                batch_size=self.params_inference['batch_size'],
                shuffle=False,
                num_workers=self.params_inference['num_workers'],
                pin_memory=self.params_inference['pin_memory'],
                prefetch_factor=self.params_inference['prefetch_factor'],
                persistent_workers=self.params_inference['persistent_workers']
            )

    def __define_model(self):
        if self.params_model['model_name'] == 'deep_medic':
            self.model = DeepMedic().to(self.device)

    def __define_criterions(self):
        self.criterion_mse = nn.MSELoss()
        self.criterion_dice = GeneralizedDiceLoss()
        self.criterion_ce = nn.CrossEntropyLoss()

    def __define_optimizr(self):
        if self.params_train['optimizer_name'] == 'adam':
            self.optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.params_train['learning_rate'],
                betas=(self.params_train['beta_1'], self.params_train['beta_2']),
                amsgrad=self.params_train['use_amsgrad']
            )

        elif self.params_train['optimizer_name'] == 'sgd_w_momentum':
            self.optimizer = optim.SGD(
                self.model.parameters(),
                lr=self.params_train['learning_rate'],
                momentum=self.params_train['momentum']
            )
        else:
            raise NotImplementedError(f'Invalid choice of optimizer:\t{self.params_train["optimizer_name"]}')

    def __define_lr_scheduler(self):
        if self.params_train['lr_scheduler_name'] == 'plateau':
            self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                patience=self.params_train['patience_lr_scheduler'],
                factor=self.params_train['factor_lr_scheduler'],
                verbose=True)

    def __put_to_device(self, device, tensors):
        for index, tensor in enumerate(tensors):
            tensors[index] = tensor.to(device)
        return tensors

    def __get_one_hot_labels(self, input, labels, squeeze_dim=None):
        output = torch.zeros((input.shape[0], len(labels), input.shape[2], input.shape[3], input.shape[4]),
                             dtype=input.dtype).to(input.device)

        if not squeeze_dim is None:
            for index, label in enumerate(labels):
                output[:, index] = torch.where(input == label, 1, 0).squeeze(squeeze_dim)
        else:
            for index, label in enumerate(labels):
                output[:, label] = torch.where(input == label, 1, 0)

        return output

    def __criterion_generalized_dice(self, im_real, im_pred):
        '''
            Following the equation from https://arxiv.org/abs/1707.03237 page 3
        '''
        weights = torch.autograd.Variable(3, dtype=torch.float64, requires_grad=True)
        for index in range(3):
            count = torch.tensor(torch.sum(torch.where(im_real == index, 1, 0)), dtype=torch.double, requires_grad=True)
            # if none of the voxels are of the current category, set weight to 1
            if count == 0:
                weights[index] = torch.tensor(1, dtype=torch.double, requires_grad=True)
            else:
                weights[index] = 1 / count ** 2

        numerator = torch.zeros(3, dtype=torch.double, requires_grad=True)
        denominator = torch.zeros(3, dtype=torch.double, requires_grad=True)

        for index in range(3):
            r_l_n = torch.where(im_real == index, 1, 0)
            p_l_n = torch.where(im_pred == index, 1, 0)

            # numerator
            mult = r_l_n * p_l_n
            numerator[index] = weights[index] * torch.sum(mult)

            current_denominator = weights[index] * (torch.sum(r_l_n) + torch.sum(p_l_n))
            denominator[index] = current_denominator

        dice_loss = 1 - (2 * torch.sum(numerator) / torch.sum(denominator))

        return dice_loss

    def __save_model(self):
        '''
            Saves the model, best and the latest
        '''
        if self.params_train['save_condition']:

            save_dict = {
                'index_epoch': self.index_epoch + 1,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'lr_scheduler_state_dict': self.lr_scheduler.state_dict(),
                'params_model': self.params_model,
                'params_train': self.params_train,
                'loss_dict_train': self.loss_dict_train,
                'loss_dict_val': self.loss_dict_val,
                'loss_best_train': self.loss_best_train,
                'loss_best_val': self.loss_best_val
            }

            # save models at each epoch
            if self.params_train['save_every_epoch']:
                save_path = os.path.join(self.params_train['path_checkpoint_full'], f'{self.index_epoch + 1}.pth')
                torch.save(save_dict, save_path)

            # save the latest model
            save_path = os.path.join(self.params_train['path_checkpoint_full'], f'latest.pth')
            torch.save(save_dict, save_path)

            if self.loss_dict_val['total'][-1] <= min(self.loss_dict_val['total']):
                save_path = os.path.join(self.params_train['path_checkpoint_full'], f'best.pth')
                torch.save(save_dict, save_path)
                self.__print(f'{"*" * 10}\tNew best model saved at:\t{self.index_epoch + 1}\t{"*" * 10}')

    def __load_model(self):
        '''
            Loads the model
        '''
        if self.run_mode == 'train':
            if self.params_train['resume_condition']:
                filename_checkpoint = f'{self.params_train["resume_epoch"]}.pth'
                load_path = os.path.join(self.params_train['path_checkpoint'],
                                         self.params_train['resume_dir'],
                                         filename_checkpoint)

                if not os.path.exists(load_path):
                    raise FileNotFoundError(f'File {load_path} doesn\'t exist')

                checkpoint = torch.load(load_path)

                self.index_epoch = checkpoint['index_epoch']
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
                self.params_model = checkpoint['params_model']
                self.params_train = checkpoint['params_train']
                self.loss_dict_train = checkpoint['loss_dict_train']
                self.loss_dict_val = checkpoint['loss_dict_val']
                self.loss_best_train = checkpoint['loss_best_train']
                self.loss_best_val = checkpoint['loss_best_val']

                self.__print(f'Model loaded from epoch:\t{self.index_epoch}')
                self.index_epoch += 1
                self.start_epoch = self.index_epoch
                self.__print(f'Resuming training from epoch:\t{self.index_epoch}')

        elif self.run_mode == 'inference':
            filename_checkpoint = f'{self.params_inference["resume_epoch"]}.pth'
            load_path = os.path.join(self.params_inference['path_checkpoint'],
                                     self.params_inference['resume_dir'],
                                     filename_checkpoint)

            if not os.path.exists(load_path):
                raise FileNotFoundError(f'File {load_path} doesn\'t exist')

            checkpoint = torch.load(load_path)
            self.index_epoch = checkpoint['index_epoch']
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.params_model = checkpoint['params_model']
            self.index_epoch += 1
            self.__print(f'Model loaded from epoch:\t{self.index_epoch}')

    def __early_stop(self):
        '''
            If early stopping condition meets, break training
        '''
        # train for at least self.params_train['min_epochs_to_train'] epochs
        if self.index_epoch > self.params_train['min_epochs_to_train']:

            # if the latest loss is lower than the best, continue training,
            # otherwise check the last x losses loss_dict_val
            # if self.loss_best_val['total'][-1] < min(self.loss_best_val['total']):
            if self.loss_dict_val['total'][-1] < min(self.loss_dict_val['total']):
                self.break_training_condition = False
            else:
                index_start = len(self.loss_dict_val['total']) - 1
                index_stop = len(self.loss_dict_val['total']) - 1 - self.params_train['patience_early_stop']

                # if any of the last x losses are greater than the best loss, increase counter
                counter = 0
                for index in range(index_start, index_stop, -1):
                    if self.loss_dict_val['total'][index] > min(self.loss_dict_val['total']):
                        counter += 1

                # if counter equals the patience, break training
                if counter >= self.params_train['patience_early_stop']:
                    self.__print(f'Early stopping at epoch:\t{self.index_epoch + 1}')
                    self.break_training_condition = True

    def __run_epoch(self, dataloader, run_mode):
        '''
            Runs one epoch of training and validation loops
        '''
        loss_list_total = []
        loss_list_dice = []
        loss_list_ce = []
        loss_list_mse = []
        loss_list_dice_n_mse = []
        loss_list_dice_n_mse_n_ce = []

        loss_total, loss_dice, loss_mse, loss_ce, loss_dice_n_mse, loss_dice_n_mse_n_ce = np.Inf, np.Inf, np.Inf, np.Inf, np.Inf, np.inf

        for index_batch, batch in tqdm(enumerate(dataloader), leave=False, total=len(dataloader)):
            if run_mode == 'train':
                self.model.train()
                self.optimizer.zero_grad()
            else:
                self.model.eval()

            (image_patch_normal, image_patch_low_up, label_patch_out_real) = self.__put_to_device(self.device, batch)

            if self.params_train['batch_size_inner'] > 1:
                if len(image_patch_normal.shape) == 6:
                    batch_size_stacked = image_patch_normal.shape[0] * image_patch_normal.shape[1]

                    image_patch_normal = image_patch_normal.reshape(batch_size_stacked, image_patch_normal.shape[2],
                                                                    image_patch_normal.shape[3],
                                                                    image_patch_normal.shape[4],
                                                                    image_patch_normal.shape[5])
                    image_patch_low_up = image_patch_low_up.reshape(batch_size_stacked, image_patch_low_up.shape[2],
                                                                    image_patch_low_up.shape[3],
                                                                    image_patch_low_up.shape[4],
                                                                    image_patch_low_up.shape[5])
                    label_patch_out_real = label_patch_out_real.reshape(batch_size_stacked,
                                                                        label_patch_out_real.shape[2],
                                                                        label_patch_out_real.shape[3],
                                                                        label_patch_out_real.shape[4],
                                                                        label_patch_out_real.shape[5])

                    image_patch_low = torch.zeros((image_patch_low_up.shape[0],
                                                   self.params_model['patch_size_low'],
                                                   self.params_model['patch_size_low'],
                                                   self.params_model['patch_size_low'])).to(self.device)

                else:
                    image_patch_normal, image_patch_low_up, label_patch_out_real = image_patch_normal.squeeze(
                        0), image_patch_low_up.squeeze(0), label_patch_out_real.squeeze(0)

                image_patch_low = torch.zeros((image_patch_low_up.shape[0],
                                               image_patch_low_up.shape[1],
                                               self.params_model['patch_size_low'],
                                               self.params_model['patch_size_low'],
                                               self.params_model['patch_size_low'])).to(self.device)

                for index, current_low_up in enumerate(image_patch_low_up):
                    current_low = F.avg_pool3d(input=current_low_up, kernel_size=3, stride=None)
                    image_patch_low[index] = copy.deepcopy(current_low.detach())

            # forward pass
            label_patch_out_pred = self.model.forward((image_patch_normal, image_patch_low))

            # convert label_patch_out_real to one hot
            label_patch_out_real_one_hot = self.__get_one_hot_labels(label_patch_out_real, labels=[0, 1, 2],
                                                                     squeeze_dim=1)

            # generalized dice loss_total # dice, mse, dice_n_mse
            if self.params_train['loss_name'] == 'dice':
                loss_dice = self.criterion_dice(F.softmax(label_patch_out_pred.float(), dim=1),
                                                label_patch_out_real_one_hot.float())
                loss_list_dice.append(loss_dice.item())
                # print(f'loss_dice:\t{loss_dice.item():.5f}')
                loss_total = loss_dice
            elif self.params_train['loss_name'] == 'mse':
                loss_mse = self.criterion_mse(F.softmax(label_patch_out_pred.float(), dim=1),
                                              label_patch_out_real_one_hot.float())
                loss_list_mse.append(loss_mse.item())
                loss_total = loss_mse
            elif self.params_train['loss_name'] == 'ce':
                loss_ce = self.criterion_ce(label_patch_out_pred.float(), label_patch_out_real.squeeze(1).long())
                loss_list_ce.append(loss_ce.item())
                loss_total = loss_ce
            elif self.params_train['loss_name'] == 'dice_n_mse':
                loss_dice = self.criterion_dice(F.softmax(label_patch_out_pred.float(), dim=1),
                                                label_patch_out_real_one_hot.float())
                loss_mse = self.criterion_mse(F.softmax(label_patch_out_pred.float(), dim=1),
                                              label_patch_out_real_one_hot.float())
                loss_dice_n_mse = loss_dice + loss_mse
                loss_total = loss_dice_n_mse
                loss_list_dice.append(loss_dice.item())
                loss_list_mse.append(loss_mse.item())
                loss_list_dice_n_mse.append(loss_dice_n_mse.item())
            elif self.params_train['loss_name'] == 'dice_n_mse_n_ce':
                loss_dice = self.criterion_dice(F.softmax(label_patch_out_pred.float(), dim=1),
                                                label_patch_out_real_one_hot.float())
                loss_mse = self.criterion_mse(F.softmax(label_patch_out_pred.float(), dim=1),
                                              label_patch_out_real_one_hot.float())
                loss_ce = self.criterion_ce(label_patch_out_pred.float(), label_patch_out_real.squeeze(1).long())

                loss_dice_n_mse_n_ce = loss_dice + loss_mse + loss_ce

                loss_total = loss_dice_n_mse_n_ce
                loss_list_dice.append(loss_dice.item())
                loss_list_mse.append(loss_mse.item())
                loss_list_ce.append(loss_ce.item())
                loss_list_dice_n_mse_n_ce.append(loss_dice_n_mse_n_ce.item())
            else:
                raise NotImplementedError(f'Invalid criterion selected:\t{self.params_train["loss_name"]}')

            loss_list_total.append(loss_total)

            if run_mode == 'train':
                # calculate gradients and update weights
                loss_total.backward()
                self.optimizer.step()
            sep = '\t' if run_mode == 'train' else '\t\t'
            # self.__print(f'\tBatch:\t[{index_batch + 1} / {len(dataloader)}]'
            #                f'\n\t\t{str(run_mode).upper()}{sep}-->\t\tLoss ({self.params_train["loss_name"]}):\t\t{loss_total.item():.5f}')

            # ###############################
            # loss_mse = self.criterion_mse(F.softmax(label_patch_out_pred.float(), dim=1),
            #                               label_patch_out_real_one_hot.float())
            # loss_list_mse.append(loss_mse.item())
            #
            # print(f'Loss ({self.params_train["loss_name"]}):\t{loss_total.item():.5f}\t\tMSE:\t{loss_mse.item()}')
            # ###############################

            # print(loss_dice.item())
            # print(loss_mse.item())
            # print(loss_list_dice_n_mse.item())
            # break

        # loss_dice = sum(loss_list_dice) / len(loss_list_dice)
        # loss_mse = 0
        # loss_dice_n_mse = 0

        if len(loss_list_total) > 0:
            loss_total = sum(loss_list_total) / len(loss_list_total)
        if len(loss_list_dice) > 0:
            loss_dice = sum(loss_list_dice) / len(loss_list_dice)
        if len(loss_list_mse) > 0:
            loss_mse = sum(loss_list_mse) / len(loss_list_mse)
        if len(loss_list_ce) > 0:
            loss_ce = sum(loss_list_ce) / len(loss_list_ce)
        if len(loss_list_dice_n_mse) > 0:
            loss_dice_n_mse = sum(loss_list_dice_n_mse) / len(loss_list_dice_n_mse)
        if len(loss_list_dice_n_mse_n_ce) > 0:
            loss_dice_n_mse_n_ce = sum(loss_list_dice_n_mse_n_ce) / len(loss_list_dice_n_mse_n_ce)

        if run_mode == 'train':
            self.loss_dict_train['total'].append(loss_total.item())
            self.loss_dict_train['dice'].append(loss_dice)
            self.loss_dict_train['mse'].append(loss_mse)
            self.loss_dict_train['ce'].append(loss_ce)
            self.loss_dict_train['dice_n_mse'].append(loss_dice_n_mse)
            self.loss_dict_train['dice_n_mse_n_ce'].append(loss_dice_n_mse_n_ce)

        elif run_mode == 'val':
            self.loss_dict_val['total'].append(loss_total.item())
            self.loss_dict_val['dice'].append(loss_dice)
            self.loss_dict_val['mse'].append(loss_mse)
            self.loss_dict_val['ce'].append(loss_ce)
            self.loss_dict_val['dice_n_mse'].append(loss_dice_n_mse)
            self.loss_dict_val['dice_n_mse_n_ce'].append(loss_dice_n_mse_n_ce)

    def __update_best_losses(self):
        '''
            Updates the best loss found so far
        '''
        self.found_best_loss_flag = False

        for (key, value) in self.loss_dict_train.items():
            if self.loss_dict_train[key][-1] < min(self.loss_dict_train[key]):
                self.loss_best_train[key] = self.loss_dict_train[key][-1]
                # if self.params_train['loss_name'] == key:
                #     self.found_best_loss_flag = True

        for (key, value) in self.loss_dict_val.items():
            if self.loss_dict_val[key][-1] < min(self.loss_dict_val[key]):
                self.loss_best_val[key] = self.loss_dict_val[key][-1]
                if self.params_train['loss_name'] == key:
                    self.found_best_loss_flag = True

    def __print(self, message):
        print(message)
        logging.debug(message)
        logging.debug('working')

    def __fit_model(self):
        '''
            Trains and validates a model given hyperparameters, model and optimizer name, dataloaders, num_epochs
        '''

        # set up dataloaders, model, criterions, optimizers, schedulers
        self.__set_device()
        self.__get_dataloaders('train')
        self.__define_model()
        self.__define_criterions()
        self.__define_optimizr()
        self.__define_lr_scheduler()

        # variables to keep track of training progress
        self.start_epoch = 0
        self.break_training_condition = False
        self.end_epoch = self.params_train['num_epochs']

        self.__load_model()

        for index_epoch in range(self.start_epoch, self.end_epoch):
            time_start = time.time()
            self.index_epoch = index_epoch

            # train
            self.__run_epoch(
                dataloader=self.dataloader_train,
                run_mode='train'
            )

            # validation
            with torch.no_grad():
                self.__run_epoch(
                    dataloader=self.dataloader_val,
                    run_mode='val'
                )

            # choose which loss to put into the scheduler
            self.lr_scheduler.step(self.loss_dict_val['total'][-1])

            duration = time.time() - time_start

            self.__print(f'\n{"-" * 100}'
                         f'\nEpoch:\t[{index_epoch + 1} / {self.end_epoch}]\t\t'
                         f'Time:\t{duration:.2f} s'
                         f'\n\tTRAIN\t\t-->\t\tLoss Total:\t\t{self.loss_dict_train["total"][-1]:.5f}'
                         f'\n\tVAL\t\t\t-->\t\tLoss Total:\t\t{self.loss_dict_val["total"][-1]:.5f}'
                         f'\t\tBest:\t{min(self.loss_dict_val["total"]):.5f}'
                         f'\n{"-" * 100}\n')

            self.__update_best_losses()
            self.__save_model()
            self.__early_stop()
            # self.__early_stopper()

            if self.break_training_condition:
                break

    def __run_inference(self):
        '''
            Run 3D inference on the validation set. Generates a 3D volume of predicted labels with same shape as the original one
        '''
        self.__set_device()
        self.__get_dataloaders('inference')
        self.__define_model()
        self.__define_criterions()
        self.__load_model()

        self.__print(f'{"*" * 100}')
        self.__print('\t\tInference starting with params:')
        self.__print(f'{"*" * 100}')
        params_dict = {
            'params_model': self.params_model,
            'params_inference': self.params_inference
        }
        params_dict = json.dumps(params_dict, indent=4, sort_keys=False)
        self.__print(f'{params_dict}')
        self.__print(f'{"*" * 100}')

        for index_batch, batch in enumerate(self.dataloader_inference):
            # save only the last 30 samples to match with other group members
            if index_batch < 31:
                continue
            images, labels_real, index_filename = batch
            (images, labels_real) = self.__put_to_device(self.device, [images, labels_real])

            labels_pred, labels_pred_probabilistic, loss_dice, loss_mse = self.__stride_depth_and_inference(
                images_real=images,
                labels_real=labels_real
            )
            # assuming the batch size == 1
            index_filename = index_filename[0]
            self.__print(
                f'{index_batch + 1}: \t{index_batch}.npy\tLoss DICE:\t{loss_dice:.5f}\tLoss MSE:\t{loss_mse:.5f}')

            labels_real, labels_pred, labels_pred_probabilistic = labels_real.cpu().detach().numpy(), labels_pred.cpu().detach().numpy(), labels_pred_probabilistic.cpu().detach().numpy()

            predictions_path = os.path.join('.', 'Ensamble', 'Abhijit')
            labels_path = os.path.join('.', 'labels_true')
            os.makedirs(predictions_path, exist_ok=True)
            os.makedirs(labels_path, exist_ok=True)
            # saves the probabilistic outputs (bs x 3 x h x w x d)

            # save predictions
            np.save(os.path.join(predictions_path, f'{index_batch}.npy'), labels_pred_probabilistic,
                    allow_pickle=True)

            # save the actual labels
            np.save(os.path.join(labels_path, f'{index_batch}.npy'), labels_real,
                    allow_pickle=True)

    def __stride_depth_and_inference(self, images_real, labels_real):
        self.model.eval()
        patch_size_normal = self.params_model['patch_size_normal']
        patch_size_low = self.params_model['patch_size_low']
        patch_size_out = self.params_model['patch_size_out']
        patch_low_factor = self.params_model['patch_low_factor']

        with torch.no_grad():
            loss_list_dice = []
            loss_list_mse = []

            device = images_real.device
            batch_size, height, width, depth = images_real.shape

            # --------- loop through the whole image volume
            patch_size_low_up = patch_size_low * patch_low_factor

            patch_half_normal = (patch_size_normal - 1) // 2
            patch_half_low = (patch_size_low - 1) // 2
            patch_half_low_up = (patch_size_low_up - 1) // 2
            patch_half_out = (patch_size_out - 1) // 2

            height_new = height + patch_size_low_up
            width_new = width + patch_size_low_up
            depth_new = depth + patch_size_low_up

            # create a placeholder for the padded image
            images_padded = torch.zeros((batch_size, height_new, width_new, depth_new), dtype=torch.float32).to(device)
            labels_real_padded = torch.zeros((batch_size, height_new, width_new, depth_new), dtype=torch.float32).to(
                device)

            # labels_padded = torch.zeros((batch_size, height_new, width_new, depth_new), dtype=torch.float32).to(device)
            # print(f'images_real.shape:\t{images_real.shape}')
            # print(f'images_padded.shape:\t{images_padded.shape}')

            # copy the original image to the placeholder
            images_padded[
            :,
            patch_half_low_up: height + patch_half_low_up,
            patch_half_low_up: width + patch_half_low_up,
            patch_half_low_up: depth + patch_half_low_up
            ] = copy.deepcopy(images_real).to(device)

            labels_real_padded[
            :,
            patch_half_low_up: height + patch_half_low_up,
            patch_half_low_up: width + patch_half_low_up,
            patch_half_low_up: depth + patch_half_low_up
            ] = copy.deepcopy(labels_real).to(device)

            # print(f'{patch_half_low_up} -> {height + patch_half_low_up}')
            # print(f'{patch_half_low_up} -> {width + patch_half_low_up}')
            # print(f'{patch_half_low_up} -> {depth + patch_half_low_up}')

            # placeholder to store the inferred/reconstructed image labels
            labels_pred_whole_image = torch.zeros_like(images_real).to(device)
            labels_pred_whole_image_probabilistic = torch.zeros((batch_size, 3, height, width, depth),
                                                                dtype=torch.float32).to(device)
            # print(f'labels_pred_whole_image.shape:\t{labels_pred_whole_image.shape}')

            # indices of the original image
            h_start_orig = 0
            h_end_orig = h_start_orig + patch_size_out

            for index_h in tqdm(range(patch_half_low_up, height_new - patch_half_low_up, patch_size_out),
                                leave=False):
                # print(index_h)
                h_start_normal = index_h - patch_half_normal
                h_end_normal = index_h + patch_half_normal + 1

                h_start_low_up = index_h - patch_half_low_up
                h_end_low_up = index_h + patch_half_low_up + 1

                h_start_out = index_h - patch_half_out
                h_end_out = index_h + patch_half_out + 1

                # if the starting index of the out height > padded height; break
                if h_end_out > height_new:
                    break

                w_start_orig = 0
                w_end_orig = w_start_orig + patch_size_out

                for index_w in range(patch_half_low_up, width_new - patch_half_low_up, patch_size_out):

                    w_start_normal = index_w - patch_half_normal
                    w_end_normal = index_w + patch_half_normal + 1

                    w_start_low_up = index_w - patch_half_low_up
                    w_end_low_up = index_w + patch_half_low_up + 1

                    w_start_out = index_w - patch_half_out
                    w_end_out = index_w + patch_half_out + 1

                    if w_end_out > width_new:
                        break

                    d_start_orig = 0
                    d_end_orig = d_start_orig + patch_size_out

                    for index_d in range(patch_half_low_up, depth_new - patch_half_low_up, patch_size_out):

                        d_start_normal = index_d - patch_half_normal
                        d_end_normal = index_d + patch_half_normal + 1

                        d_start_low_up = index_d - patch_half_low_up
                        d_end_low_up = index_d + patch_half_low_up + 1

                        d_start_out = index_d - patch_half_out
                        d_end_out = index_d + patch_half_out + 1

                        if d_end_out > depth_new:
                            break

                        # extract the current patch of the expanded image
                        image_patch_normal = images_padded[
                                             :,
                                             h_start_normal: h_end_normal,
                                             w_start_normal: w_end_normal,
                                             d_start_normal: d_end_normal
                                             ]
                        # print('\nNormal')
                        # print(f'{h_start_normal} -> {h_end_normal}')
                        # print(f'{w_start_normal} -> {w_end_normal}')
                        # print(f'{d_start_normal} -> {d_end_normal}')

                        image_patch_low_up = images_padded[
                                             :,
                                             h_start_low_up: h_end_low_up,
                                             w_start_low_up: w_end_low_up,
                                             d_start_low_up: d_end_low_up
                                             ]

                        # print('\nlow_up')
                        # print(f'{h_start_low_up} -> {h_end_low_up}')
                        # print(f'{w_start_low_up} -> {w_end_low_up}')
                        # print(f'{d_start_low_up} -> {d_end_low_up}')

                        # extract the current output patch of the expanded label
                        label_patch_out_real = labels_real_padded[
                                               :,
                                               h_start_out: h_end_out,
                                               w_start_out: w_end_out,
                                               d_start_out: d_end_out
                                               ]
                        # print('\nout')
                        # print(f'{h_start_out} -> {h_end_out}')
                        # print(f'{w_start_out} -> {w_end_out}')
                        # print(f'{d_start_out} -> {d_end_out}')

                        # if d_start_out == 42:
                        #     print('sssss')

                        if not (label_patch_out_real.shape[1] * label_patch_out_real.shape[2] *
                                label_patch_out_real.shape[3] > 0):
                            # print('here')
                            continue

                        # pad uneven images (image patch normal)
                        image_patch_normal_temp = torch.zeros(
                            (batch_size, patch_size_normal, patch_size_normal, patch_size_normal)).to(device)
                        image_patch_normal_temp[:, :image_patch_normal.shape[1], :image_patch_normal.shape[2],
                        :image_patch_normal.shape[3]] = image_patch_normal
                        image_patch_normal = image_patch_normal_temp

                        # pad uneven images (image patch low_up)
                        image_patch_low_up_temp = torch.zeros(
                            (batch_size, patch_size_low_up, patch_size_low_up, patch_size_low_up)).to(device)
                        image_patch_low_up_temp[:, :image_patch_low_up.shape[1], :image_patch_low_up.shape[2],
                        :image_patch_low_up.shape[3]] = image_patch_low_up
                        image_patch_low_up = image_patch_low_up_temp

                        # resize (downsample) image_patch_low
                        image_patch_low = F.avg_pool3d(input=image_patch_low_up, kernel_size=3, stride=None)

                        # perform forward pass
                        label_patch_out_pred = self.model.forward(
                            (image_patch_normal.unsqueeze(0), image_patch_low.unsqueeze(0)))

                        # print(label_patch_out_real.shape)
                        # clip extra parts
                        if label_patch_out_real.shape[1] < patch_size_out:
                            label_patch_out_pred = label_patch_out_pred[:, :, :label_patch_out_real.shape[1], :, :]

                        if label_patch_out_real.shape[2] < patch_size_out:
                            label_patch_out_pred = label_patch_out_pred[:, :, :, :label_patch_out_real.shape[2], :]

                        if label_patch_out_real.shape[3] < patch_size_out:
                            label_patch_out_pred = label_patch_out_pred[:, :, :, :, :label_patch_out_real.shape[3]]

                        # # remove any dimensions with 0 elements
                        # if (label_patch_out_pred.shape[2] == 0) or (label_patch_out_pred.shape[3] == 0) or (label_patch_out_pred.shape[4] == 0) or (
                        #         label_patch_out_real.shape[2] == 0) or (label_patch_out_real.shape[3] == 0) or (
                        #         label_patch_out_real.shape[4] == 0):
                        #     break

                        # print(label_patch_out_pred.shape)
                        # convert label_patch_out_real to one hot
                        label_patch_out_real_one_hot = torch.zeros_like(label_patch_out_pred).to(device)
                        # print(label_patch_out_real_one_hot.shape)
                        label_patch_out_real_one_hot[:, 0] = torch.where(label_patch_out_real == 0, 1, 0)
                        label_patch_out_real_one_hot[:, 1] = torch.where(label_patch_out_real == 1, 1, 0)
                        label_patch_out_real_one_hot[:, 2] = torch.where(label_patch_out_real == 2, 1, 0)

                        # cross-entropy loss_dice
                        loss_dice = self.criterion_dice(F.softmax(label_patch_out_pred.float(), dim=1),
                                                        label_patch_out_real_one_hot.float())
                        loss_mse = self.criterion_mse(F.softmax(label_patch_out_pred.float(), dim=1),
                                                      label_patch_out_real_one_hot.float())
                        # print(loss_mse.item())
                        loss_list_dice.append(loss_dice)
                        loss_list_mse.append(loss_mse)
                        # print(loss_dice)

                        label_patch_out_pred_double = torch.argmax(label_patch_out_pred.detach(), dim=1)
                        label_patch_out_pred_double_temp = torch.zeros(batch_size, patch_size_out, patch_size_out,
                                                                       patch_size_out).to(device)
                        label_patch_out_pred_double_temp[:, :label_patch_out_pred_double.shape[1],
                        :label_patch_out_pred_double.shape[2],
                        :label_patch_out_pred_double.shape[3]] = label_patch_out_pred_double
                        label_patch_out_pred_double = label_patch_out_pred_double_temp

                        bs, h, w, d = labels_pred_whole_image[:, h_start_orig: h_end_orig,
                                      w_start_orig: w_end_orig,
                                      d_start_orig: d_end_orig].shape

                        # save the pixel wise predictions
                        labels_pred_whole_image[:, h_start_orig: h_end_orig, w_start_orig: w_end_orig,
                        d_start_orig: d_end_orig] = label_patch_out_pred_double[:, :h, :w, :d].detach()

                        # save the probabilistic predictions
                        labels_pred_whole_image_probabilistic[:, :, h_start_orig: h_end_orig, w_start_orig: w_end_orig,
                        d_start_orig: d_end_orig] = label_patch_out_pred[:, :, :h, :w, :d].detach()

                        d_start_orig = d_start_orig + patch_size_out
                        d_end_orig = d_end_orig + patch_size_out

                    w_start_orig = w_start_orig + patch_size_out
                    w_end_orig = w_end_orig + patch_size_out

                h_start_orig = h_start_orig + patch_size_out
                h_end_orig = h_end_orig + patch_size_out

                loss_dice = sum(loss_list_dice) / (len(loss_list_dice) + 1e-9)
                loss_mse = sum(loss_list_mse) / (len(loss_list_mse) + 1e-9)

        return labels_pred_whole_image, labels_pred_whole_image_probabilistic, loss_dice, loss_mse

# STEP 1 #

Write a segmentation algorithm pipeline. Train on the training set, defined as a proportion of the data in imagesTr, and validate the algorithm performance on the remaining images. Do not use any auwgmentation for now. Use any choice of optimiser.
Describe how the algorithm was trained, and what were the final results using standard image segmentation validation metrics such as Dice Score or Hausdorff Distance.

Answer Marks:

[10] Working algorithmic implementation

[ 3] Comments on the code

[ 9] Description of the training process

[ 8] Validation presentation and description

### Explanation

The for all training variants, the data was split `80/20` for training and validation. This resulted in `242` training images and `61` validation images. As I used the DeepMedic model, I extracted 3D patches from the whole volume and passed them through the network. As there are much more voxels representing the background class, I used a biased sampling strategy to extract the patches. First I randomly selected the center voxel with `40%` probability to be hepatic vessel or tumour classes each and `20%` for the background class. This voxel was determined as the center pixel and the three different crops of patches were extracted from here. I used a batch size of `8` with 16 random patches extracted from each sample. Additionally, to speed up the `i/o` operation, I converted the provided dataset to numpy ndarray format.

To extract the three different crops of the image, the following parameters were used. `patch_size_normal`=`25` refers to original crop. `patch_size_low`=`19` and `patch_low_factor`=`3` refer to the additional large but downsampled crop. i.e. a `57x57x57` patch was downsampled by a factor of `3` to produce a patch of size `19x19x19`. `patch_size_out`=`9` refers to the output produced by the network.

I used the Adam optimizer with `beta1=0.9` and `beta2=0.999` with an initial learning rate=`0.0002`. I also set `AMSGrad` to true for better convergence. Validation was performed after each epoch. The initial total number of epochs was set to `100`, while the model was trained using early stopping and learning rate decay. If the best validation loss did not improve (compared to the best found so far) for two consecutive epochs, the learning rate decreased by a factor of `10`. If the validation loss did not improve (compared to the best) for `5` consecutive epochs, the training was terminated. But regardless, the model was trained for at least `10`. The latest and best models were saved after each epoch.

At inference time, I used a sliding window approach across the whole volume to obtain the final prediction. I first padded the whole image so that the `output` patch is flushed at the "`top-left`" corner while the largest `57x57x57` patch does not go out of bounds. Then I strided by the size of the `output` size (`9`) to make predictions for each patch. The predictions for each patch were saved to an empty 3D output array.

At this step only the `Generalized DICE loss` was used, which is `1-DICE` Score (because we want to minimize).

The following block runs the training and outputs the training and validation losses after each epoch, what is the best found validation loss so far, whether the learning rate was decreased and when the model training was stopped. The model was trained for `33` epochs before early stopping. The best validation loss of `0.21903` was obtained on epoch `28`.

In [22]:
params_model = {
    'experiment_name': 'step_1',
    'model_name': 'deep_medic',
    'patch_size_normal': 25,
    'patch_size_low': 19,
    'patch_size_out': 9,
    'patch_low_factor': 3,
    'run_mode': None,
    'dataset_variant': 'npy',  # npy, nib
    'create_numpy_dataset': False,
    'init_timestamp': datetime.now().strftime("%H-%M-%S__%d-%m-%Y")
}

params_train = {
    'optimizer_name': 'adam',  # adam, sgd_w_momentum
    'loss_name': 'dice',  # dice, mse, ce, dice_n_mse, dice_n_mse_n_ce
    'beta_1': 0.9,
    'beta_2': 0.999,
    'momentum': 0.9,
    'use_amsgrad': True,
    'learning_rate': 0.0002,  # 0.0002
    'lr_scheduler_name': 'plateau',
    'patience_lr_scheduler': 2,
    'factor_lr_scheduler': 0.1,
    'early_stop_condition': True,
    'patience_early_stop': 5,
    'early_stop_patience_counter': 0,
    'min_epochs_to_train': 10,
    'num_epochs': 100,
    'save_every_epoch': True,

    'save_condition': True,  # whether to save the model
    'resume_condition': False,  # whether to resume training

    'resume_dir': 'step_1__15-08-38__05-04-2022__deep_medic__dice__adam__lr_0.0002__ep_100',
    'resume_epoch': 'latest',

    'batch_size': 8,  # 8
    'batch_size_inner': 16,  # 16 (how many patches to generate per sample)
    'train_percentage': 0.8,
    'num_workers': 8,  # 8
    'pin_memory': True,
    'prefetch_factor': 2,
    'persistent_workers': True,

    'path_checkpoint': os.path.join('.', 'checkpoints'),
    'path_checkpoint_full': '',
    'dirname_checkpoint': '',
    'filename_params': 'params.json',
    'filename_logger': 'logger.txt',
    'path_params_full': '',
    'path_logger_full': '',

    'use_elastic_deformation': False,
    'user_affine_transformation': False,

    'num_controlpoints': 20,
    'sigma': 5,

    'rotation': 10,
    'scale': (0.90, 1.10),
    'shear': (0.01, 0.02)
}

# instanciate model
set_seed(1)
params_model['experiment_name'] = 'step_1'
model_container = ModelConainer(params_model)

# train the model
model_container.train(params_train=params_train)

****************************************************************************************************
		Training starting with params:
****************************************************************************************************
{
    "params_model": {
        "experiment_name": "step_1",
        "model_name": "deep_medic",
        "patch_size_normal": 25,
        "patch_size_low": 19,
        "patch_size_out": 9,
        "patch_low_factor": 3,
        "run_mode": null,
        "dataset_variant": "npy",
        "create_numpy_dataset": false,
        "init_timestamp": "15-08-38__05-04-2022"
    },
    "params_train": {
        "optimizer_name": "adam",
        "loss_name": "dice",
        "beta_1": 0.9,
        "beta_2": 0.999,
        "momentum": 0.9,
        "use_amsgrad": true,
        "learning_rate": 0.0002,
        "lr_scheduler_name": "plateau",
        "patience_lr_scheduler": 2,
        "factor_lr_scheduler": 0.1,
        "early_stop_condition": true,
        "patience_ea

                                               


----------------------------------------------------------------------------------------------------
Epoch:	[1 / 100]		Time:	433.39 s
	TRAIN		-->		Loss Total:		0.54777
	VAL			-->		Loss Total:		0.43942		Best:	0.43942
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	1	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[2 / 100]		Time:	436.38 s
	TRAIN		-->		Loss Total:		0.44131
	VAL			-->		Loss Total:		0.37737		Best:	0.37737
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	2	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[3 / 100]		Time:	451.40 s
	TRAIN		-->		Loss Total:		0.40754
	VAL			-->		Loss Total:		0.40235		Best:	0.37737
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[4 / 100]		Time:	436.17 s
	TRAIN		-->		Loss Total:		0.40538
	VAL			-->		Loss Total:		0.35128		Best:	0.35128
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	4	**********


                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[5 / 100]		Time:	446.50 s
	TRAIN		-->		Loss Total:		0.35426
	VAL			-->		Loss Total:		0.33229		Best:	0.33229
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	5	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[6 / 100]		Time:	442.12 s
	TRAIN		-->		Loss Total:		0.38096
	VAL			-->		Loss Total:		0.33042		Best:	0.33042
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	6	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[7 / 100]		Time:	463.92 s
	TRAIN		-->		Loss Total:		0.34211
	VAL			-->		Loss Total:		0.30315		Best:	0.30315
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	7	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[8 / 100]		Time:	434.22 s
	TRAIN		-->		Loss Total:		0.33924
	VAL			-->		Loss Total:		0.33920		Best:	0.30315
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[9 / 100]		Time:	438.07 s
	TRAIN		-->		Loss Total:		0.32699
	VAL			-->		Loss Total:		0.28170		Best:	0.28170
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	9	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[10 / 100]		Time:	476.14 s
	TRAIN		-->		Loss Total:		0.32182
	VAL			-->		Loss Total:		0.27461		Best:	0.27461
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	10	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[11 / 100]		Time:	448.98 s
	TRAIN		-->		Loss Total:		0.30836
	VAL			-->		Loss Total:		0.28935		Best:	0.27461
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[12 / 100]		Time:	466.97 s
	TRAIN		-->		Loss Total:		0.31678
	VAL			-->		Loss Total:		0.25937		Best:	0.25937
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	12	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[13 / 100]		Time:	458.61 s
	TRAIN		-->		Loss Total:		0.29453
	VAL			-->		Loss Total:		0.28866		Best:	0.25937
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[14 / 100]		Time:	445.32 s
	TRAIN		-->		Loss Total:		0.32577
	VAL			-->		Loss Total:		0.25669		Best:	0.25669
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	14	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[15 / 100]		Time:	448.86 s
	TRAIN		-->		Loss Total:		0.27631
	VAL			-->		Loss Total:		0.30026		Best:	0.25669
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[16 / 100]		Time:	444.38 s
	TRAIN		-->		Loss Total:		0.28012
	VAL			-->		Loss Total:		0.27703		Best:	0.25669
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[17 / 100]		Time:	436.28 s
	TRAIN		-->		Loss Total:		0.29427
	VAL			-->		Loss Total:		0.24486		Best:	0.24486
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	17	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[18 / 100]		Time:	435.14 s
	TRAIN		-->		Loss Total:		0.28402
	VAL			-->		Loss Total:		0.29928		Best:	0.24486
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[19 / 100]		Time:	484.84 s
	TRAIN		-->		Loss Total:		0.29069
	VAL			-->		Loss Total:		0.25004		Best:	0.24486
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[20 / 100]		Time:	460.05 s
	TRAIN		-->		Loss Total:		0.26589
	VAL			-->		Loss Total:		0.23152		Best:	0.23152
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	20	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[21 / 100]		Time:	433.85 s
	TRAIN		-->		Loss Total:		0.25900
	VAL			-->		Loss Total:		0.25875		Best:	0.23152
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[22 / 100]		Time:	438.11 s
	TRAIN		-->		Loss Total:		0.26296
	VAL			-->		Loss Total:		0.25247		Best:	0.23152
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[23 / 100]		Time:	457.75 s
	TRAIN		-->		Loss Total:		0.26025
	VAL			-->		Loss Total:		0.23113		Best:	0.23113
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	23	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[24 / 100]		Time:	470.61 s
	TRAIN		-->		Loss Total:		0.25454
	VAL			-->		Loss Total:		0.23959		Best:	0.23113
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[25 / 100]		Time:	455.79 s
	TRAIN		-->		Loss Total:		0.27519
	VAL			-->		Loss Total:		0.24448		Best:	0.23113
----------------------------------------------------------------------------------------------------



                                               

Epoch 00026: reducing learning rate of group 0 to 2.0000e-05.

----------------------------------------------------------------------------------------------------
Epoch:	[26 / 100]		Time:	486.99 s
	TRAIN		-->		Loss Total:		0.26176
	VAL			-->		Loss Total:		0.24437		Best:	0.23113
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[27 / 100]		Time:	428.83 s
	TRAIN		-->		Loss Total:		0.24735
	VAL			-->		Loss Total:		0.23922		Best:	0.23113
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[28 / 100]		Time:	453.69 s
	TRAIN		-->		Loss Total:		0.23339
	VAL			-->		Loss Total:		0.21903		Best:	0.21903
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	28	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[29 / 100]		Time:	440.86 s
	TRAIN		-->		Loss Total:		0.24715
	VAL			-->		Loss Total:		0.22200		Best:	0.21903
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[30 / 100]		Time:	467.61 s
	TRAIN		-->		Loss Total:		0.25004
	VAL			-->		Loss Total:		0.22617		Best:	0.21903
----------------------------------------------------------------------------------------------------



                                               

Epoch 00031: reducing learning rate of group 0 to 2.0000e-06.

----------------------------------------------------------------------------------------------------
Epoch:	[31 / 100]		Time:	435.11 s
	TRAIN		-->		Loss Total:		0.22246
	VAL			-->		Loss Total:		0.23199		Best:	0.21903
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[32 / 100]		Time:	477.63 s
	TRAIN		-->		Loss Total:		0.23143
	VAL			-->		Loss Total:		0.23155		Best:	0.21903
----------------------------------------------------------------------------------------------------



                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[33 / 100]		Time:	471.26 s
	TRAIN		-->		Loss Total:		0.21114
	VAL			-->		Loss Total:		0.22660		Best:	0.21903
----------------------------------------------------------------------------------------------------

Early stopping at epoch:	33


# STEP 2: Now run the same training process but now using Affine transformations and Elastic Deformations as augmentation techniques. Describe what the augmentation is doing, what parameters were used and why, and what was the outcome of the training/testing process when augmentation was used. Did you observe a smaller train/test performance gap? [10]
Answer Marks:
[ 6] 3 points for each implementation of the augmentation
[ 2] Description of the augmentation and parameters
[ 2] Description of the performance gains

### Explanation

Both `Elastic Deformation` and `Affine Transformation` are implemented as separate classes. The objects are instantiated from within the model container class. The model was instructed to use these transformations by setting the `use_elastic_deformation` and `use_affine_transformation` parameters of the `params_train` to `True`.

For elastic deformation, I set the number of control points to `20` and the sigma to `5`. The number of control points define how fine grained the transformation is going to be, where more control points mean that there will be high-frequency transformations while sigma determines the sense of smoothing. As I am using b-spline interpolation, `sigma` refers to how many control points away from the center of a control point should influence. A larger number would result in a smoother image.

For affine transformation, I randomly rotated the images by +/- `10` degrees (either clockwise or anti-clockwise) on the volume axis, which is meaningful because a person might not lie perfectly straight when taking the scan. +/- `10%` scaling was applied to the image while `0.01` and `0.02` shear in the height and width dimensions.

The model was trained for `29` epochs before early stopping. The best validation loss (`DICE`) `0.21015` was found at epoch `24`. This loss was slightly better (`0.0008`) than the loss from the previous step (`0.21095`).


In [24]:
params_model = {
    'experiment_name': 'step_1',
    'model_name': 'deep_medic',
    'patch_size_normal': 25,
    'patch_size_low': 19,
    'patch_size_out': 9,
    'patch_low_factor': 3,
    'run_mode': None,
    'dataset_variant': 'npy',  # npy, nib
    'create_numpy_dataset': False,
    'init_timestamp': datetime.now().strftime("%H-%M-%S__%d-%m-%Y")
}

params_train = {
    'optimizer_name': 'adam',  # adam, sgd_w_momentum
    'loss_name': 'dice',  # dice, mse, ce, dice_n_mse, dice_n_mse_n_ce
    'beta_1': 0.9,
    'beta_2': 0.999,
    'momentum': 0.9,
    'use_amsgrad': True,
    'learning_rate': 0.0002,  # 0.0002
    'lr_scheduler_name': 'plateau',
    'patience_lr_scheduler': 2,
    'factor_lr_scheduler': 0.1,
    'early_stop_condition': True,
    'patience_early_stop': 5,
    'early_stop_patience_counter': 0,
    'min_epochs_to_train': 10,
    'num_epochs': 100,
    'save_every_epoch': True,

    'save_condition': True,  # whether to save the model
    'resume_condition': False,  # whether to resume training

    'resume_dir': 'step_2__19-32-16__05-04-2022__deep_medic__dice__adam__lr_0.0002__ep_100',
    'resume_epoch': 'latest',

    'batch_size': 8,  # 8
    'batch_size_inner': 16,  # 16 (how many patches to generate per sample)
    'train_percentage': 0.8,
    'num_workers': 8,  # 8
    'pin_memory': True,
    'prefetch_factor': 2,
    'persistent_workers': True,

    'path_checkpoint': os.path.join('.', 'checkpoints'),
    'path_checkpoint_full': '',
    'dirname_checkpoint': '',
    'filename_params': 'params.json',
    'filename_logger': 'logger.txt',
    'path_params_full': '',
    'path_logger_full': '',

    'use_elastic_deformation': False,
    'user_affine_transformation': False,

    'num_controlpoints': 20,
    'sigma': 5,

    'rotation': 10,
    'scale': (0.90, 1.10),
    'shear': (0.01, 0.02)
}

# instanciate model
set_seed(1)
params_model['experiment_name'] = 'step_2'

# using elastic deformation
params_model['use_elastic_deformation'] = True
params_model['num_controlpoints'] = 20
params_model['sigma'] = 5

# using affine transformation
params_model['user_affine_transformation'] = True
params_model['rotation'] = 10
params_model['scale'] = (0.90, 1.10)
params_model['shear'] = (0.01, 0.02)

model_container = ModelConainer(params_model)

# train the model
model_container.train(params_train=params_train)

****************************************************************************************************
		Training starting with params:
****************************************************************************************************
{
    "params_model": {
        "experiment_name": "step_2",
        "model_name": "deep_medic",
        "patch_size_normal": 25,
        "patch_size_low": 19,
        "patch_size_out": 9,
        "patch_low_factor": 3,
        "run_mode": null,
        "dataset_variant": "npy",
        "create_numpy_dataset": false,
        "init_timestamp": "19-32-16__05-04-2022",
        "use_elastic_deformation": true,
        "num_controlpoints": 20,
        "sigma": 5,
        "user_affine_transformation": true,
        "rotation": 10,
        "scale": [
            0.9,
            1.1
        ],
        "shear": [
            0.01,
            0.02
        ]
    },
    "params_train": {
        "optimizer_name": "adam",
        "loss_name": "dice",
        "beta_1

                                               


----------------------------------------------------------------------------------------------------
Epoch:	[1 / 100]		Time:	458.88 s
	TRAIN		-->		Loss Total:		0.54768
	VAL			-->		Loss Total:		0.44073		Best:	0.44073
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	1	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[2 / 100]		Time:	470.35 s
	TRAIN		-->		Loss Total:		0.44151
	VAL			-->		Loss Total:		0.37718		Best:	0.37718
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	2	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[3 / 100]		Time:	471.91 s
	TRAIN		-->		Loss Total:		0.40489
	VAL			-->		Loss Total:		0.40777		Best:	0.37718
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[4 / 100]		Time:	452.36 s
	TRAIN		-->		Loss Total:		0.40420
	VAL			-->		Loss Total:		0.35287		Best:	0.35287
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	4	**********


                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[5 / 100]		Time:	555.81 s
	TRAIN		-->		Loss Total:		0.35344
	VAL			-->		Loss Total:		0.33996		Best:	0.33996
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	5	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[6 / 100]		Time:	519.59 s
	TRAIN		-->		Loss Total:		0.38305
	VAL			-->		Loss Total:		0.30759		Best:	0.30759
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	6	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[7 / 100]		Time:	501.14 s
	TRAIN		-->		Loss Total:		0.34097
	VAL			-->		Loss Total:		0.35565		Best:	0.30759
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[8 / 100]		Time:	496.80 s
	TRAIN		-->		Loss Total:		0.33303
	VAL			-->		Loss Total:		0.39398		Best:	0.30759
----------------------------------------------------------------------------------------------------



                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[9 / 100]		Time:	593.79 s
	TRAIN		-->		Loss Total:		0.32447
	VAL			-->		Loss Total:		0.28083		Best:	0.28083
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	9	**********


                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[10 / 100]		Time:	601.48 s
	TRAIN		-->		Loss Total:		0.31451
	VAL			-->		Loss Total:		0.26961		Best:	0.26961
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	10	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[11 / 100]		Time:	660.61 s
	TRAIN		-->		Loss Total:		0.30092
	VAL			-->		Loss Total:		0.39474		Best:	0.26961
----------------------------------------------------------------------------------------------------



                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[12 / 100]		Time:	647.37 s
	TRAIN		-->		Loss Total:		0.31724
	VAL			-->		Loss Total:		0.31914		Best:	0.26961
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[13 / 100]		Time:	630.68 s
	TRAIN		-->		Loss Total:		0.30178
	VAL			-->		Loss Total:		0.25256		Best:	0.25256
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	13	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[14 / 100]		Time:	518.63 s
	TRAIN		-->		Loss Total:		0.32304
	VAL			-->		Loss Total:		0.31157		Best:	0.25256
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[15 / 100]		Time:	532.93 s
	TRAIN		-->		Loss Total:		0.28392
	VAL			-->		Loss Total:		0.24674		Best:	0.24674
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	15	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[16 / 100]		Time:	814.09 s
	TRAIN		-->		Loss Total:		0.28136
	VAL			-->		Loss Total:		0.27104		Best:	0.24674
----------------------------------------------------------------------------------------------------



                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[17 / 100]		Time:	617.02 s
	TRAIN		-->		Loss Total:		0.28540
	VAL			-->		Loss Total:		0.30630		Best:	0.24674
----------------------------------------------------------------------------------------------------



                                                 

Epoch 00018: reducing learning rate of group 0 to 2.0000e-05.

----------------------------------------------------------------------------------------------------
Epoch:	[18 / 100]		Time:	926.29 s
	TRAIN		-->		Loss Total:		0.28499
	VAL			-->		Loss Total:		0.25281		Best:	0.24674
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[19 / 100]		Time:	951.48 s
	TRAIN		-->		Loss Total:		0.28139
	VAL			-->		Loss Total:		0.22692		Best:	0.22692
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	19	**********


                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[20 / 100]		Time:	975.75 s
	TRAIN		-->		Loss Total:		0.25543
	VAL			-->		Loss Total:		0.22926		Best:	0.22692
----------------------------------------------------------------------------------------------------



                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[21 / 100]		Time:	790.55 s
	TRAIN		-->		Loss Total:		0.24044
	VAL			-->		Loss Total:		0.23426		Best:	0.22692
----------------------------------------------------------------------------------------------------



                                                 

Epoch 00022: reducing learning rate of group 0 to 2.0000e-06.

----------------------------------------------------------------------------------------------------
Epoch:	[22 / 100]		Time:	890.60 s
	TRAIN		-->		Loss Total:		0.24628
	VAL			-->		Loss Total:		0.22822		Best:	0.22692
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[23 / 100]		Time:	849.21 s
	TRAIN		-->		Loss Total:		0.23753
	VAL			-->		Loss Total:		0.21492		Best:	0.21492
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	23	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[24 / 100]		Time:	905.11 s
	TRAIN		-->		Loss Total:		0.24468
	VAL			-->		Loss Total:		0.21015		Best:	0.21015
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	24	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[25 / 100]		Time:	656.93 s
	TRAIN		-->		Loss Total:		0.26441
	VAL			-->		Loss Total:		0.22856		Best:	0.21015
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[26 / 100]		Time:	887.41 s
	TRAIN		-->		Loss Total:		0.24436
	VAL			-->		Loss Total:		0.24515		Best:	0.21015
----------------------------------------------------------------------------------------------------



                                                 

Epoch 00027: reducing learning rate of group 0 to 2.0000e-07.

----------------------------------------------------------------------------------------------------
Epoch:	[27 / 100]		Time:	673.95 s
	TRAIN		-->		Loss Total:		0.25809
	VAL			-->		Loss Total:		0.22876		Best:	0.21015
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[28 / 100]		Time:	683.02 s
	TRAIN		-->		Loss Total:		0.24454
	VAL			-->		Loss Total:		0.22038		Best:	0.21015
----------------------------------------------------------------------------------------------------



                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[29 / 100]		Time:	766.25 s
	TRAIN		-->		Loss Total:		0.26121
	VAL			-->		Loss Total:		0.22527		Best:	0.21015
----------------------------------------------------------------------------------------------------

Early stopping at epoch:	29


# STEP 3: To obtain even better results, you will need to optimise the parameters of the loss function, the optimiser learning rate, or the network parameters. Report how these hyperparameters were optimised and what performance gains were observed. [10]
Answer Marks:
[ 6] 2 points for each parameter that was optimised
[ 2] Description of the optimisation process
[ 2] Description of the performance gains

### Explanation

Although I have the code implemented to perform grid search in the following block, I could not perform a very wide range of parameter search due to computational resource limitations, specially the slow i/o times to load the large files in the dataset. I was also limited by PyTorch's dataloader's bug, which does not allow me to run threaded operations on an windows environment, unless I run from inside the if `__name__=='__main__'` block, so I had to make alternative arrangements. As I used learning rate decay, extensive search of the learning rate parameter was avoided.

Due to these limitations, I ran separate `.py` scripts on a separate machine, and was not able to attach the output of the results here. The semi-optimal parameters I found at this step were also used in the previous steps to maximize runtime and performance.

In [None]:
params_model = {
    'experiment_name': 'step_1',
    'model_name': 'deep_medic',
    'patch_size_normal': 25,
    'patch_size_low': 19,
    'patch_size_out': 9,
    'patch_low_factor': 3,
    'run_mode': None,
    'dataset_variant': 'npy',  # npy, nib
    'create_numpy_dataset': False,
    'init_timestamp': datetime.now().strftime("%H-%M-%S__%d-%m-%Y")
}

params_train = {
    'optimizer_name': 'adam',  # adam, sgd_w_momentum
    'loss_name': 'dice',  # dice, mse, ce, dice_n_mse, dice_n_mse_n_ce
    'beta_1': 0.9,
    'beta_2': 0.999,
    'momentum': 0.9,
    'use_amsgrad': True,
    'learning_rate': 0.0002,  # 0.0002
    'lr_scheduler_name': 'plateau',
    'patience_lr_scheduler': 2,
    'factor_lr_scheduler': 0.1,
    'early_stop_condition': True,
    'patience_early_stop': 5,
    'early_stop_patience_counter': 0,
    'min_epochs_to_train': 10,
    'num_epochs': 100,
    'save_every_epoch': True,

    'save_condition': True,  # whether to save the model
    'resume_condition': False,  # whether to resume training

    'resume_dir': 'step_3__09-08-25__05-04-2022__deep_medic__dice__adam__lr_0.0002__ep_100',
    'resume_epoch': 'latest',

    'batch_size': 8,  # 8
    'batch_size_inner': 16,  # 16 (how many patches to generate per sample)
    'train_percentage': 0.8,
    'num_workers': 8,  # 8
    'pin_memory': True,
    'prefetch_factor': 2,
    'persistent_workers': True,

    'path_checkpoint': os.path.join('.', 'checkpoints'),
    'path_checkpoint_full': '',
    'dirname_checkpoint': '',
    'filename_params': 'params.json',
    'filename_logger': 'logger.txt',
    'path_params_full': '',
    'path_logger_full': '',

    'use_elastic_deformation': False,
    'user_affine_transformation': False,

    'num_controlpoints': 20,
    'sigma': 5,

    'rotation': 10,
    'scale': (0.90, 1.10),
    'shear': (0.01, 0.02)
}

# instanciate model
set_seed(1)
params_model['experiment_name'] = 'step_3'
params_model['loss_name'] = 'dice'
params_model['use_elastic_deformation'] = True
params_model['user_affine_transformation'] = True

list_lr = [0.0002, 0.001, 0.01]
list_control_points = [10, 15, 20]
list_sigma = [5, 10, 20]
list_rotation = [5, 10]
list_scale = [(0.90, 1.10), (0.95, 1.05)]
list_shear = [(0.01, 0.02), (0.05, 0.10)]


index = 0
num_train_models = len(list_lr) * len(list_control_points) * len(list_sigma) * len(list_rotation) * len(list_scale) * len(list_shear)

print(f'Starting grid search:\n')

for lr in list_lr:
    for cp in list_control_points:
        for sigma in list_sigma:
            for rotation in list_rotation:
                for scale in list_scale:
                    for shear in list_shear: 
                        params_model['learning_rate'] = lr
                        params_model['num_controlpoints'] = cp
                        params_model['sigma'] = sigma


                        params_model['rotation'] = rotation
                        params_model['scale'] = scale
                        params_model['shear'] = shear
                        
                        print(f'Training Model:\t[{index+1}/{num_train_models}]')
                        
                        if index > 0:
                            # to avoid memory leak, but does not solve the issue altogether
                            del model_container
                            torch.cuda.empty_cache()
                            
                        model_container = ModelConainer(params_model)

                        # train the model
                        model_container.train(params_train=params_train)
                        index += 1


# STEP 4: Vessels are small and thin. This commonly results in a disconnected vessel three. Predicting distance maps (e.g. https://arxiv.org/pdf/1908.05099.pdf) is a good auxiliary
task to force a network to understand vessel geometry. Implement this auxiliary task and
assess the performance with and without the auxiliary task. [20]
Answer Marks:
[13] Correct implementation of the task
[ 2] Code comments
[ 5] Description of the performance gains

### Explanation

From the paper what I understood was that, in order to predict the distance maps, the Generalized Dice loss has to be optimized alongside the Mean Squared Loss (MSE). The performance gains from this is that, MSE helps the network localize and focus on the very small regions of the hepatic vessels and tumours (but does not handle class imbalance) while the Generalized Dice loss helps to tackle the class imbalance and segmentation performance.

To implement the complete loss from the paper (segmentation + distance + contour) we can additionally optimize the Cross-Entropy loss alongside the previous two losses. The necessary code is implemented after the following block.

#### DICE + MSE

In [9]:
params_model = {
    'experiment_name': 'step_1',
    'model_name': 'deep_medic',
    'patch_size_normal': 25,
    'patch_size_low': 19,
    'patch_size_out': 9,
    'patch_low_factor': 3,
    'run_mode': None,
    'dataset_variant': 'npy',  # npy, nib
    'create_numpy_dataset': False,
    'init_timestamp': datetime.now().strftime("%H-%M-%S__%d-%m-%Y")
}

params_train = {
    'optimizer_name': 'adam',  # adam, sgd_w_momentum
    'loss_name': 'dice',  # dice, mse, ce, dice_n_mse, dice_n_mse_n_ce
    'beta_1': 0.9,
    'beta_2': 0.999,
    'momentum': 0.9,
    'use_amsgrad': True,
    'learning_rate': 0.0002,  # 0.0002
    'lr_scheduler_name': 'plateau',
    'patience_lr_scheduler': 2,
    'factor_lr_scheduler': 0.1,
    'early_stop_condition': True,
    'patience_early_stop': 5,
    'early_stop_patience_counter': 0,
    'min_epochs_to_train': 10,
    'num_epochs': 100,
    'save_every_epoch': True,

    'save_condition': True,  # whether to save the model
    'resume_condition': False,  # whether to resume training

    'resume_dir': 'step_4__09-03-12__06-04-2022__deep_medic__dice__adam__lr_0.0002__ep_100',
    'resume_epoch': 'latest',

    'batch_size': 8,  # 8
    'batch_size_inner': 16,  # 16 (how many patches to generate per sample)
    'train_percentage': 0.8,
    'num_workers': 8,  # 8
    'pin_memory': True,
    'prefetch_factor': 2,
    'persistent_workers': True,

    'path_checkpoint': os.path.join('.', 'checkpoints'),
    'path_checkpoint_full': '',
    'dirname_checkpoint': '',
    'filename_params': 'params.json',
    'filename_logger': 'logger.txt',
    'path_params_full': '',
    'path_logger_full': '',

    'use_elastic_deformation': False,
    'user_affine_transformation': False,

    'num_controlpoints': 20,
    'sigma': 5,

    'rotation': 10,
    'scale': (0.90, 1.10),
    'shear': (0.01, 0.02)
}

# instanciate model
set_seed(1)
params_model['experiment_name'] = 'step_4'
params_model['loss_name'] = 'dice_n_mse'

params_model['use_elastic_deformation'] = True
params_model['num_controlpoints'] = 20
params_model['sigma'] = 5

params_model['user_affine_transformation'] = True
params_model['rotation'] = 10
params_model['scale'] = (0.90, 1.10)
params_model['shear'] = (0.01, 0.02)

model_container = ModelConainer(params_model)

# train the model
model_container.train(params_train=params_train)

****************************************************************************************************
		Training starting with params:
****************************************************************************************************
{
    "params_model": {
        "experiment_name": "step_4",
        "model_name": "deep_medic",
        "patch_size_normal": 25,
        "patch_size_low": 19,
        "patch_size_out": 9,
        "patch_low_factor": 3,
        "run_mode": null,
        "dataset_variant": "npy",
        "create_numpy_dataset": false,
        "init_timestamp": "09-03-12__06-04-2022",
        "loss_name": "dice_n_mse",
        "use_elastic_deformation": true,
        "num_controlpoints": 20,
        "sigma": 5,
        "user_affine_transformation": true,
        "rotation": 10,
        "scale": [
            0.9,
            1.1
        ],
        "shear": [
            0.01,
            0.02
        ]
    },
    "params_train": {
        "optimizer_name": "adam",
        "

                                               


----------------------------------------------------------------------------------------------------
Epoch:	[1 / 100]		Time:	456.72 s
	TRAIN		-->		Loss Total:		0.54791
	VAL			-->		Loss Total:		0.43348		Best:	0.43348
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	1	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[2 / 100]		Time:	495.79 s
	TRAIN		-->		Loss Total:		0.44145
	VAL			-->		Loss Total:		0.37291		Best:	0.37291
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	2	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[3 / 100]		Time:	512.69 s
	TRAIN		-->		Loss Total:		0.40581
	VAL			-->		Loss Total:		0.40209		Best:	0.37291
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[4 / 100]		Time:	562.41 s
	TRAIN		-->		Loss Total:		0.39930
	VAL			-->		Loss Total:		0.34903		Best:	0.34903
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	4	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[5 / 100]		Time:	626.50 s
	TRAIN		-->		Loss Total:		0.34937
	VAL			-->		Loss Total:		0.30827		Best:	0.30827
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	5	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[6 / 100]		Time:	753.72 s
	TRAIN		-->		Loss Total:		0.37694
	VAL			-->		Loss Total:		0.29229		Best:	0.29229
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	6	**********


                                               


----------------------------------------------------------------------------------------------------
Epoch:	[7 / 100]		Time:	705.96 s
	TRAIN		-->		Loss Total:		0.34502
	VAL			-->		Loss Total:		0.38347		Best:	0.29229
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[8 / 100]		Time:	736.27 s
	TRAIN		-->		Loss Total:		0.34238
	VAL			-->		Loss Total:		0.33099		Best:	0.29229
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[9 / 100]		Time:	758.10 s
	TRAIN		-->		Loss Total:		0.31946
	VAL			-->		Loss Total:		0.27836		Best:	0.27836
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	9	**********


                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[10 / 100]		Time:	948.15 s
	TRAIN		-->		Loss Total:		0.31934
	VAL			-->		Loss Total:		0.28000		Best:	0.27836
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[11 / 100]		Time:	908.63 s
	TRAIN		-->		Loss Total:		0.30784
	VAL			-->		Loss Total:		0.29417		Best:	0.27836
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[12 / 100]		Time:	1211.46 s
	TRAIN		-->		Loss Total:		0.30960
	VAL			-->		Loss Total:		0.25787		Best:	0.25787
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	12	**********


                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[13 / 100]		Time:	1028.96 s
	TRAIN		-->		Loss Total:		0.29819
	VAL			-->		Loss Total:		0.30882		Best:	0.25787
----------------------------------------------------------------------------------------------------



                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[14 / 100]		Time:	1026.18 s
	TRAIN		-->		Loss Total:		0.33662
	VAL			-->		Loss Total:		0.32230		Best:	0.25787
----------------------------------------------------------------------------------------------------



                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[15 / 100]		Time:	949.85 s
	TRAIN		-->		Loss Total:		0.28269
	VAL			-->		Loss Total:		0.24181		Best:	0.24181
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	15	**********


                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[16 / 100]		Time:	934.00 s
	TRAIN		-->		Loss Total:		0.28256
	VAL			-->		Loss Total:		0.34878		Best:	0.24181
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[17 / 100]		Time:	943.54 s
	TRAIN		-->		Loss Total:		0.29130
	VAL			-->		Loss Total:		0.23611		Best:	0.23611
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	17	**********


                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[18 / 100]		Time:	983.46 s
	TRAIN		-->		Loss Total:		0.27954
	VAL			-->		Loss Total:		0.25617		Best:	0.23611
----------------------------------------------------------------------------------------------------



                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[19 / 100]		Time:	934.92 s
	TRAIN		-->		Loss Total:		0.29714
	VAL			-->		Loss Total:		0.25195		Best:	0.23611
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[20 / 100]		Time:	1177.66 s
	TRAIN		-->		Loss Total:		0.26778
	VAL			-->		Loss Total:		0.23050		Best:	0.23050
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	20	**********


                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[21 / 100]		Time:	920.18 s
	TRAIN		-->		Loss Total:		0.25666
	VAL			-->		Loss Total:		0.25889		Best:	0.23050
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[22 / 100]		Time:	952.53 s
	TRAIN		-->		Loss Total:		0.25525
	VAL			-->		Loss Total:		0.28110		Best:	0.23050
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[23 / 100]		Time:	871.72 s
	TRAIN		-->		Loss Total:		0.25715
	VAL			-->		Loss Total:		0.22780		Best:	0.22780
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	23	**********


                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[24 / 100]		Time:	837.13 s
	TRAIN		-->		Loss Total:		0.25507
	VAL			-->		Loss Total:		0.23877		Best:	0.22780
----------------------------------------------------------------------------------------------------



                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[25 / 100]		Time:	928.81 s
	TRAIN		-->		Loss Total:		0.27627
	VAL			-->		Loss Total:		0.26304		Best:	0.22780
----------------------------------------------------------------------------------------------------



                                                 

Epoch 00026: reducing learning rate of group 0 to 2.0000e-05.

----------------------------------------------------------------------------------------------------
Epoch:	[26 / 100]		Time:	885.86 s
	TRAIN		-->		Loss Total:		0.25396
	VAL			-->		Loss Total:		0.23466		Best:	0.22780
----------------------------------------------------------------------------------------------------



                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[27 / 100]		Time:	898.64 s
	TRAIN		-->		Loss Total:		0.25077
	VAL			-->		Loss Total:		0.23698		Best:	0.22780
----------------------------------------------------------------------------------------------------



                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[28 / 100]		Time:	806.23 s
	TRAIN		-->		Loss Total:		0.22900
	VAL			-->		Loss Total:		0.21811		Best:	0.21811
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	28	**********


                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[29 / 100]		Time:	847.17 s
	TRAIN		-->		Loss Total:		0.24899
	VAL			-->		Loss Total:		0.21365		Best:	0.21365
----------------------------------------------------------------------------------------------------

**********	New best model saved at:	29	**********


                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[30 / 100]		Time:	872.10 s
	TRAIN		-->		Loss Total:		0.24870
	VAL			-->		Loss Total:		0.22354		Best:	0.21365
----------------------------------------------------------------------------------------------------



                                                 


----------------------------------------------------------------------------------------------------
Epoch:	[31 / 100]		Time:	819.09 s
	TRAIN		-->		Loss Total:		0.22305
	VAL			-->		Loss Total:		0.22730		Best:	0.21365
----------------------------------------------------------------------------------------------------



                                               

Epoch 00032: reducing learning rate of group 0 to 2.0000e-06.

----------------------------------------------------------------------------------------------------
Epoch:	[32 / 100]		Time:	792.36 s
	TRAIN		-->		Loss Total:		0.23197
	VAL			-->		Loss Total:		0.22747		Best:	0.21365
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[33 / 100]		Time:	755.72 s
	TRAIN		-->		Loss Total:		0.20683
	VAL			-->		Loss Total:		0.22076		Best:	0.21365
----------------------------------------------------------------------------------------------------



                                               


----------------------------------------------------------------------------------------------------
Epoch:	[34 / 100]		Time:	766.89 s
	TRAIN		-->		Loss Total:		0.22011
	VAL			-->		Loss Total:		0.21798		Best:	0.21365
----------------------------------------------------------------------------------------------------

Early stopping at epoch:	34


#### DCIE + MSE + CE

In [None]:
params_model = {
    'experiment_name': 'step_1',
    'model_name': 'deep_medic',
    'patch_size_normal': 25,
    'patch_size_low': 19,
    'patch_size_out': 9,
    'patch_low_factor': 3,
    'run_mode': None,
    'dataset_variant': 'npy',  # npy, nib
    'create_numpy_dataset': False,
    'init_timestamp': datetime.now().strftime("%H-%M-%S__%d-%m-%Y")
}

params_train = {
    'optimizer_name': 'adam',  # adam, sgd_w_momentum
    'loss_name': 'dice',  # dice, mse, ce, dice_n_mse, dice_n_mse_n_ce
    'beta_1': 0.9,
    'beta_2': 0.999,
    'momentum': 0.9,
    'use_amsgrad': True,
    'learning_rate': 0.0002,  # 0.0002
    'lr_scheduler_name': 'plateau',
    'patience_lr_scheduler': 2,
    'factor_lr_scheduler': 0.1,
    'early_stop_condition': True,
    'patience_early_stop': 5,
    'early_stop_patience_counter': 0,
    'min_epochs_to_train': 10,
    'num_epochs': 100,
    'save_every_epoch': True,

    'save_condition': True,  # whether to save the model
    'resume_condition': False,  # whether to resume training

    'resume_dir': 'step_4__09-03-12__06-04-2022__deep_medic__dice__adam__lr_0.0002__ep_100',
    'resume_epoch': 'latest',

    'batch_size': 8,  # 8
    'batch_size_inner': 16,  # 16 (how many patches to generate per sample)
    'train_percentage': 0.8,
    'num_workers': 8,  # 8
    'pin_memory': True,
    'prefetch_factor': 2,
    'persistent_workers': True,

    'path_checkpoint': os.path.join('.', 'checkpoints'),
    'path_checkpoint_full': '',
    'dirname_checkpoint': '',
    'filename_params': 'params.json',
    'filename_logger': 'logger.txt',
    'path_params_full': '',
    'path_logger_full': '',

    'use_elastic_deformation': False,
    'user_affine_transformation': False,

    'num_controlpoints': 20,
    'sigma': 5,

    'rotation': 10,
    'scale': (0.90, 1.10),
    'shear': (0.01, 0.02)
}

# instanciate model
set_seed(1)
params_model['experiment_name'] = 'step_4'
params_model['loss_name'] = 'dice_n_mse_n_ce'

model_container = ModelConainer(params_model)

# train the model
model_container.train(params_train=params_train)

# STEP 5: The process of image segmentation is naturally uncertain. Because of this, we would like to estimate how uncertaint the models are when segmenting the target labels. To achieve this, please implement the Augmentation-based Aleatoric method, as decribed here (https://arxiv.org/pdf/1807.07356.pdf). Optimise any parameters the method might have. Visualise the uncertainty estimates. Are the areas of high uncertainty areas where errors are more likeley to be made? [10]
Answer Marks:
[ 5] Implementation of the uncertainty estimation method
[ 2] Visualisation
[ 3] Describe relationship between error and uncertainty

# STEP 6: Ensemble all the models by averaging their probalility. You can achieve this by either sharing the models themselves among the team, or by sharing the probabilistic outputs of
the models. Comment on the algorithmic performance of the ensemble compared to your own method. [10]
Answer Marks:
[ 5] Implementation of the average ensemble
[ 5] Describe the differences in performance

### Save the outputs from the inference: I am using the model from the trained mdoel from step 4, which uses DICE loss as well as the distance maps

In [20]:
params_model = {
    'experiment_name': 'step_1',
    'model_name': 'deep_medic',
    'patch_size_normal': 25,
    'patch_size_low': 19,
    'patch_size_out': 9,
    'patch_low_factor': 3,
    'run_mode': None,
    'dataset_variant': 'npy',  # npy, nib
    'create_numpy_dataset': False,
    'init_timestamp': datetime.now().strftime("%H-%M-%S__%d-%m-%Y")
}

params_inference = {
    'loss_name': 'dice',
    'batch_size': 1,
    'train_percentage': 0.8,
    'num_workers': 0,  # 8
    'pin_memory': False,
    'prefetch_factor': 2,
    'persistent_workers': False,

    'resume_dir': 'step_4__09-03-12__06-04-2022__deep_medic__dice__adam__lr_0.0002__ep_100',
    'resume_epoch': 'best',
    'path_checkpoint': os.path.join('.', 'checkpoints'),
    'path_checkpoint_full': '',
    'dirname_checkpoint': '',
}

params_model['loss_name'] = 'dice_n_mse'

# instanciate model
set_seed(1)
model_container = ModelConainer(params_model)

# inference
model_container.inference(params_inference=params_inference)

Model loaded from epoch:	30
****************************************************************************************************
		Inference starting with params:
****************************************************************************************************
{
    "params_model": {
        "experiment_name": "step_4",
        "model_name": "deep_medic",
        "patch_size_normal": 25,
        "patch_size_low": 19,
        "patch_size_out": 9,
        "patch_low_factor": 3,
        "run_mode": null,
        "dataset_variant": "npy",
        "create_numpy_dataset": false,
        "init_timestamp": "09-03-12__06-04-2022",
        "loss_name": "dice_n_mse",
        "use_elastic_deformation": true,
        "num_controlpoints": 20,
        "sigma": 5,
        "user_affine_transformation": true,
        "rotation": 10,
        "scale": [
            0.9,
            1.1
        ],
        "shear": [
            0.01,
            0.02
        ]
    },
    "params_inference": {
        "l

                                               

31: 	30.npy	Loss DICE:	0.39051	Loss MSE:	0.03766


                                               

32: 	31.npy	Loss DICE:	0.38103	Loss MSE:	0.03800


                                               

33: 	32.npy	Loss DICE:	0.57124	Loss MSE:	0.03101


                                               

34: 	33.npy	Loss DICE:	0.35475	Loss MSE:	0.03918


                                               

35: 	34.npy	Loss DICE:	0.50300	Loss MSE:	0.03211


                                               

36: 	35.npy	Loss DICE:	0.40539	Loss MSE:	0.04297


                                               

37: 	36.npy	Loss DICE:	0.52183	Loss MSE:	0.03593


                                               

38: 	37.npy	Loss DICE:	0.51899	Loss MSE:	0.03090


                                               

39: 	38.npy	Loss DICE:	0.38306	Loss MSE:	0.03462


                                               

40: 	39.npy	Loss DICE:	0.58272	Loss MSE:	0.04245


                                               

41: 	40.npy	Loss DICE:	0.59060	Loss MSE:	0.04261


                                               

42: 	41.npy	Loss DICE:	0.45669	Loss MSE:	0.03946


                                               

43: 	42.npy	Loss DICE:	0.44385	Loss MSE:	0.02936


                                               

44: 	43.npy	Loss DICE:	0.21941	Loss MSE:	0.02203


                                               

45: 	44.npy	Loss DICE:	0.55297	Loss MSE:	0.03464


                                               

46: 	45.npy	Loss DICE:	0.50854	Loss MSE:	0.02983


                                               

47: 	46.npy	Loss DICE:	0.31312	Loss MSE:	0.02055


                                               

48: 	47.npy	Loss DICE:	0.53860	Loss MSE:	0.04269


                                               

49: 	48.npy	Loss DICE:	0.42796	Loss MSE:	0.02662


                                               

50: 	49.npy	Loss DICE:	0.51611	Loss MSE:	0.03639


                                               

51: 	50.npy	Loss DICE:	0.40107	Loss MSE:	0.02452


                                               

52: 	51.npy	Loss DICE:	0.37862	Loss MSE:	0.03711


                                               

53: 	52.npy	Loss DICE:	0.49972	Loss MSE:	0.04058


                                               

54: 	53.npy	Loss DICE:	0.44210	Loss MSE:	0.03262


                                               

55: 	54.npy	Loss DICE:	0.50054	Loss MSE:	0.03323


                                               

56: 	55.npy	Loss DICE:	0.43322	Loss MSE:	0.03024


                                               

57: 	56.npy	Loss DICE:	0.38116	Loss MSE:	0.02084


                                               

58: 	57.npy	Loss DICE:	0.54903	Loss MSE:	0.03079


                                               

59: 	58.npy	Loss DICE:	0.52512	Loss MSE:	0.03185


                                               

60: 	59.npy	Loss DICE:	0.34767	Loss MSE:	0.01777


                                               

61: 	60.npy	Loss DICE:	0.39607	Loss MSE:	0.03667


## Visualization of the predictions

At this stage we can visualize the predictions made by my model

In [2]:
index_vis = 0
im_real = np.load(os.path.join('.', 'images_true', f'{index_vis}.npy'))
label_real = np.load(os.path.join('.', 'labels_true', f'{index_vis}.npy'))
label_pred = np.load(os.path.join('.', 'Ensamble', 'Abhijit', f'{index_vis}.npy')).squeeze(0)

label_pred = torch.argmax(torch.softmax(torch.tensor(label_pred), axis=0), axis=0)

height, width, depth = label_pred.shape

def visualize_depth(index=0, label=-1):
    # cmap = 'gray'
    cmap = 'viridis'
    # cmap = 'inferno'
    # cmap = 'plasma'

    index_depth = index;
    plt.figure(figsize=(12, 6))
    plt.title('height x width')

    plt.subplot(1, 3, 1)
    plt.imshow(im_real[:, :, index], cmap=cmap)
    plt.title('Real Image')
    plt.xticks([])
    plt.yticks([])

    plt.subplot(1, 3, 2)
    if label == 0:
        current_label = ' (Background)'
    elif label == 1:
        current_label = ' (Vessel)'
    elif label == 2:
        current_label = ' (Tumour)'
    else:
        current_label = ''

    plt.title(f'Real Label{current_label}')
    if label < 0:
        plt.imshow(label_real[:, :, index], cmap=cmap)
    else:
        plt.imshow(np.where(label_real[:, :, index] == label, 1, 0), cmap=cmap)

    plt.xticks([])
    plt.yticks([])

    plt.subplot(1, 3, 3)
    plt.title(f'Predicted Label{current_label}')
    if label < 0:
        plt.imshow(label_pred[:, :, index], cmap=cmap)
    else:
        plt.imshow(np.where(label_pred[:, :, index] == label, 1, 0), cmap=cmap)
    plt.xticks([])
    plt.yticks([])
    plt.show()

interact(visualize_depth, index=widgets.IntSlider(min=0, max=depth-1, step=1, value=0), label=widgets.IntSlider(min=-1, max=2, step=1, value=-1));

interactive(children=(IntSlider(value=0, description='index', max=133), IntSlider(value=-1, description='label…

### Analysis of the visualization

The interactive plot above presents a 2D slice through the volume. Here, for the labels, the cyan colour represents the hepatic vessels and the yellow colour represents the tumours. We can see that although our model converged based on the criteria we defined, the localozation capability of our model is severly restricted.

# Team Section

# STEP 6: Ensemble all the models by averaging their probalility. You can achieve this by either sharing the models themselves among the team, or by sharing the probabilistic outputs of the models. Comment on the algorithmic performance of the ensemble compared to your own method. [10]
Answer Marks:
[ 5] Implementation of the average ensemble
[ 5] Describe the differences in performance

### Explanation

Although I performed full-sized volume predictions on `61` images of the validation set, my teammates performed prediction on subset of 30 images and a subsection of the volume which was `512x512x16`. The size `16` in the depth dimension was extracted by taking a center crop in the depth dimension. As different people saved their probabilistic predicitons in different formats (numpy, tensor, gpu tensor) and used different naming conventions, I manually renamed all the files, converted to standard numpy format and saved them as `[0-29].npy` format. We only shared the probabilistic outputs from the models among ourselves.

The non-weighted ensamble model `(DICE+MSE: 1.098875)` performed better than three of the five group members performance in terms of DICE+MSE loss, while my and Piyalitt's models performed better. The performance of the average ensamble `(1.098875)` is also much lower than the average of all our losses `(1.1304388)`.



In [11]:
# convert predictions to numpy
def prepare_npy_files(path_ensamble):
    '''
        Assumes that probabilistic predictions are placed in the Ensambled directiory,
        with the files renamed with indices between 0 and 30 (done manually)
    '''
    name_list = os.listdir(path_ensamble)
    for name in name_list:
        if name == 'Abhijit':
            continue
        for index in tqdm(range(30), leave=False):
            file_path_input = os.path.join(path_ensamble, name, f'{index}.pkl')
            with open(file_path_input, 'rb') as pickle_file:
                try:
                    content = pickle.load(pickle_file)
                    if name == 'Diego':
                        content = content.squeeze(0).permute(0, 2, 3, 1).detach().cpu().numpy()
                    elif name == 'Traudi-Beatrice':
                        content = content.cpu().detach().numpy()
                    elif name == 'Kate':
                        content = content.squeeze(0).cpu().detach().numpy()
                    np.save(os.path.join(path_ensamble, name, f'{index}.npy'), content)
                except Exception as e:
                    print(f'({name}) Exception at: {index}: {e}')

In [12]:
def get_ensamble_losses(class_weights=None, save_condition=False):
    path_ensamble = os.path.join('.', 'Ensamble')
    convert_to_numpy = False

    if convert_to_numpy:
        prepare_npy_files(path_ensamble)

    # names of all the group members
    names_list = os.listdir(path_ensamble)

    num_inference = 29  # not 30 because the last file of Piyalitt can not be read, thows exception

    # equally weight all the ensambles if no weight is supplied
    if class_weights is None:
        class_weights = np.ones(len(names_list)) / len(names_list)

    print(f'class_weights:\t{class_weights}')

    # directory to store the average ensambles
    path_avg_ensamble = os.path.join('.', 'Ensamble_avg')
    os.makedirs(path_avg_ensamble, exist_ok=True)

    # define the losses
    criterion_dice = GeneralizedDiceLoss()
    criterion_mse = nn.MSELoss()

    # arrays to hold the different losses
    num_people = len(names_list)
    loss_array_dice = np.zeros((num_inference, num_people+1))
    loss_array_mse = np.zeros((num_inference, num_people+1))
    loss_array_dice_n_mse = np.zeros((num_inference, num_people+1))

    for index_filename in range(num_inference):
        # placeholder to hold the data of each person of the current index
        labels_pred_list = []

        path_labels_true = os.path.join('.', 'labels_true', f'{index_filename}.npy')
        labels_true = np.load(path_labels_true)

        # convert the real labels to one-hot labels
        labels_true_oh = np.zeros((1, 3, labels_true.shape[0], labels_true.shape[1], labels_true.shape[2]))
        labels_true_oh[0, 0, :] = np.where(labels_true==0, 1, 0)
        labels_true_oh[0, 1, :] = np.where(labels_true==1, 1, 0)
        labels_true_oh[0, 2, :] = np.where(labels_true==2, 1, 0)

        # getting the center crop of size 16 in the depth dimension to match with other's predictions
        depth = labels_true_oh.shape[-1]
        depth_center = (depth // 2)
        depth_new = 16
        depth_new_half = depth_new // 2

        labels_true_oh = labels_true_oh[:, :, :, :, depth_center-depth_new_half:depth_center+depth_new_half]
        labels_true_oh = torch.tensor(labels_true_oh, dtype=torch.float32).squeeze(0)

        # read the current index file of each person and append to the list
        for index_name, name in enumerate(names_list):
            path_labels_pred = os.path.join(path_ensamble, name, f'{index_filename}.npy')

            labels_pred_proba = np.load(path_labels_pred)

            # getting the center crop of size 16 in the depth dimension to match with other's predictions
            if name=='Abhijit':
                labels_pred_proba = labels_pred_proba[:, :, :, :, depth_center-depth_new_half:depth_center+depth_new_half].squeeze(0)

            labels_pred_proba = torch.softmax(torch.tensor(labels_pred_proba * class_weights[index_name]), dim=0).detach().cpu().numpy()
            labels_pred_list.append(labels_pred_proba)

            # convert to torch tensors for loss calculation
            labels_pred_proba = torch.tensor(labels_pred_proba, dtype=torch.float32)

            loss_dice = criterion_dice(labels_pred_proba, labels_true_oh).item()
            loss_mse = criterion_mse(labels_pred_proba, labels_true_oh).item()
            loss_dice_n_mse = loss_dice + loss_mse

            loss_array_dice[index_filename, index_name] = loss_dice
            loss_array_mse[index_filename, index_name] = loss_mse
            loss_array_dice_n_mse[index_filename, index_name] = loss_dice_n_mse

            # print(f'{name}\t\t\tDICE:\t{loss_dice:.5f}\t\tMSE:\t{loss_mse:.5f}\t\tDICE+MSE:\t{loss_dice_n_mse:.5f}')

        # convert to array and calculate the mean
        labels_pred_list = np.array(labels_pred_list)
        labels_pred_mean = np.mean(labels_pred_list, axis=0)

        # loss of the Ensamble
        # convert to torch tensors for loss calculation
        labels_pred_mean = torch.tensor(labels_pred_mean, dtype=torch.float32)

        loss_dice = criterion_dice(labels_pred_mean, labels_true_oh).item()
        loss_mse = criterion_mse(labels_pred_mean, labels_true_oh).item()
        loss_dice_n_mse = loss_dice + loss_mse

        loss_array_dice[index_filename, index_name+1] = loss_dice
        loss_array_mse[index_filename, index_name+1] = loss_mse
        loss_array_dice_n_mse[index_filename, index_name+1] = loss_dice_n_mse

        # name = 'Ensamble'
        # print(f'{name}\t\t\tDICE:\t{loss_dice:.5f}\t\tMSE:\t{loss_mse:.5f}\t\tDICE+MSE:\t{loss_dice_n_mse:.5f}')

        # save the average of the ensamble
        if save_condition:
            path_out_current = os.path.join(path_avg_ensamble, f'{index_filename}.npy')
            np.save(path_out_current, labels_pred_mean)

    names_list_df = os.listdir(path_ensamble)
    names_list_df.append('Ensamble')

    # create dataframes of losses for easy interpretability
    df_dice = pd.DataFrame(loss_array_dice, columns=names_list_df)
    df_mse = pd.DataFrame(loss_array_mse, columns=names_list_df)
    df_dice_n_mse = pd.DataFrame(loss_array_dice_n_mse, columns=names_list_df)

    return df_dice, df_mse, df_dice_n_mse

In [14]:
df_dice, df_mse, df_dice_n_mse = get_ensamble_losses()
names_list_df = list(df_dice.columns)
loss_names = ['Dice', 'MSE', 'DICE+MSE']

df_mean_losses = pd.DataFrame(np.vstack((df_dice.mean().to_numpy(),
                                      df_mse.mean().to_numpy(),
                                      df_dice_n_mse.mean().to_numpy())))

df_mean_losses.columns = names_list_df
df_mean_losses.index = loss_names

print(f'Mean of all the losses with uniform weighting')
display(df_mean_losses)

class_weights:	[0.2 0.2 0.2 0.2 0.2]
Mean of all the losses with uniform weighting


Unnamed: 0,Abhijit,Diego,Kate,Piyalitt,Traudi-Beatrice,Ensamble
Dice,0.996092,0.998731,0.998809,0.995885,0.998883,0.998366
MSE,0.034254,0.193745,0.202906,0.037914,0.194976,0.100509
DICE+MSE,1.030346,1.192475,1.201715,1.033799,1.193859,1.098875


# STEP 7: Sometimes, one or several of the methods used to ensemble can be underperforming, so a simple avegare might be non-ideal. Instead, you can chose to weight the different models. Do this either by manually chosing the weights between different models or by using the performance on the training set to define these. Does a weighted ensemble perform better? Describe why. [10]
Answer Marks:
(If followed the performance-based route)
[ 7] Implementation of the performance-based weighted ensemble
[ 3] Describe the differences in performance

### Explanation

I weighted our models based on the respective performances in the previous step, which is proportional to to `1-LOSS`. I used all three variants of losses (`DICE`, `MSE`, `DICE+MSE`) available at this stage and compared their performance against one another. Then I applied a `softmax` function on them to represent probabilities. As Piyalitt's model had the best (lowest) loss in the previous step, it was given the highest weight. These weights were them multiplied with the raw probabilities to obtain the final performance.

From the table below it can be observed that regardless which loss we use to weight our models, the performance improves. But the highest gain (`0.2%`) is achieved when we use the `DICE+MSE` loss



In [15]:
name_weights_dice = torch.softmax(torch.tensor(1-df_dice.mean()[:5]).unsqueeze(0), dim=1).detach().numpy()[0]
name_weights_mse = torch.softmax(torch.tensor(1-df_mse.mean()[:5]).unsqueeze(0), dim=1).detach().numpy()[0]
name_weights_dice_n_mse = torch.softmax(torch.tensor(1-df_dice_n_mse.mean()[:5]).unsqueeze(0), dim=1).detach().numpy()[0]

df_weights = pd.DataFrame(np.vstack((name_weights_dice, name_weights_mse, name_weights_dice_n_mse)))
df_weights.columns = names_list_df[:5]
df_weights.index = loss_names

print(f'Three different sets of weights are displayed, based on each type of loss on the previous step')
display(df_weights)

Three different sets of weights are displayed, based on each type of loss on the previous step


Unnamed: 0,Abhijit,Diego,Kate,Piyalitt,Traudi-Beatrice
Dice,0.200318,0.19979,0.199774,0.200359,0.199759
MSE,0.22001,0.187575,0.185865,0.219206,0.187344
DICE+MSE,0.220335,0.187357,0.185634,0.219575,0.187098


In [16]:
for index, name_weights in enumerate([name_weights_dice, name_weights_mse, name_weights_dice_n_mse]):
    print(f'{"*"*100}')
    print(f'Current weighting selected based on:\t{loss_names[index]}')
    df_dice_w, df_mse_w, df_dice_n_mse_w = get_ensamble_losses(class_weights=name_weights)
    names_list_df = list(df_dice_w.columns)

    print(f'Mean of Different Losses after applying class weights to different models')

    df_mean_losses_w = pd.DataFrame(np.vstack((df_dice_w.mean().to_numpy(),
                                      df_mse_w.mean().to_numpy(),
                                      df_dice_n_mse_w.mean().to_numpy())))
    df_mean_losses_w.columns = df_dice_w.columns
    df_mean_losses_w.index = loss_names

    df_mean_losses_w['Ensamble_old'] = df_mean_losses['Ensamble']
    df_mean_losses_w['Ensamble_difference %'] = ((df_mean_losses['Ensamble'] - df_mean_losses_w['Ensamble']) / df_mean_losses['Ensamble']) * 100

    display(df_mean_losses_w)

****************************************************************************************************
Current weighting selected based on:	Dice
class_weights:	[0.20031758 0.1997898  0.19977408 0.20035917 0.19975936]
Mean of Different Losses after applying class weights to different models


Unnamed: 0,Abhijit,Diego,Kate,Piyalitt,Traudi-Beatrice,Ensamble,Ensamble_old,Ensamble_difference %
Dice,0.996091,0.998731,0.998809,0.995874,0.998883,0.998366,0.998366,4.1e-05
MSE,0.034252,0.193775,0.202927,0.037775,0.195009,0.100465,0.100509,0.043896
DICE+MSE,1.030343,1.192505,1.201737,1.033649,1.193892,1.09883,1.098875,0.004052


****************************************************************************************************
Current weighting selected based on:	MSE
class_weights:	[0.22000962 0.18757531 0.18586478 0.21920581 0.18734448]
Mean of Different Losses after applying class weights to different models


Unnamed: 0,Abhijit,Diego,Kate,Piyalitt,Traudi-Beatrice,Ensamble,Ensamble_old,Ensamble_difference %
Dice,0.996015,0.998741,0.998815,0.995293,0.998883,0.998347,0.998366,0.001958
MSE,0.034201,0.195499,0.204253,0.031127,0.196682,0.09831,0.100509,2.1879
DICE+MSE,1.030216,1.19424,1.203068,1.026421,1.195565,1.096656,1.098875,0.201896


****************************************************************************************************
Current weighting selected based on:	DICE+MSE
class_weights:	[0.22033462 0.18735747 0.18563431 0.2195752  0.18709839]
Mean of Different Losses after applying class weights to different models


Unnamed: 0,Abhijit,Diego,Kate,Piyalitt,Traudi-Beatrice,Ensamble,Ensamble_old,Ensamble_difference %
Dice,0.996014,0.998741,0.998815,0.995281,0.998883,0.998346,0.998366,0.001995
MSE,0.034201,0.19553,0.204275,0.031009,0.196715,0.098271,0.100509,2.226377
DICE+MSE,1.030214,1.194271,1.20309,1.026291,1.195598,1.096617,1.098875,0.205448
