# 2.5D Segmentation of lungs affected by Covid-19 pneumonia

* * *
The notebook is meant to be run all at once: first choose the tasks to perform and their parameters in the settings cell below, then run every cell.

This notebook was made to test different pre-processing techniques: it pre-processes images on the fly and takes as input 2D slices stored in `.npy` format. These inputs are generated once, offline, from the original raw CT scans by setting `preprocess_dataset=True`.

Use the panel below to set the desired settings.

For fine-tuning or evaluation, additional settings must be edited in the corresponding section.

Dice Coefficient taken from https://kornia.readthedocs.io/en/v0.1.2/_modules/torchgeometry/losses/dice.html

In [1]:
#------------------------------------------------------DO---------------------------------------------------------------
# Settings panel: every cell below reads these module-level names, so edit them here and re-run the whole notebook.

save_training = False # Either save the model during the training phase or not
do_train = False # Perform the training
do_fine_tuning = False # Perform the fine-tuning. More settings for fine-tuning must be set in the corresponding section
do_predict = False # Perform evaluation. More settings for evaluation must be set in the corresponding section

preprocess_dataset = False # Do only once, prepare ct-scans by rescaling them to the desired shape and store them as .npy

#-------------------------------------------------NORMALIZATION---------------------------------------------------------

# NOTE: "None" below means the STRING 'None' - the normalization code compares against 'None', not the Python None.
eq_method = 'clahe+histeq' # Equalization method, only meaningful if images are normalized in [0, 255]: 'None', 'histeq', 'clahe', 'clahe+histeq', 'histeq+clahe'
norm_method = 'as_colab' # Normalization method. Possible values: 'None', 'best' or 'as_colab', 'mean_std', 'as_paper', 'as_paper_in_0_1', 'as_paper_in_0_255', 'in_range', 'custom', 'adjust_gamma', 'sharp'
lower, higher = 0, 255 # Lower and upper bounds, only used when norm_method = 'in_range'

#---------------------------------------------------APPROACH------------------------------------------------------------

approach = '2.5D' # Possible values: '2.5D', 'transformer'

#----------------------------------------------------NETWORK------------------------------------------------------------

# For Encoder-Decoder
encoder_name = 'resnet101' # Encoder to be used. Possible values: mobilenet_v2, resnet101, densenet169, vgg16 ...
encoder_weights = 'imagenet' # Use either a pre-trained encoder or not. Possible values: imagenet, None
decoder_name = 'Unet' # Decoder to be used. Possible values: Unet, MAnet, DeepLabV3, ...
attention = None # Apply attention to the decoder. Possible values: scse, None
batch_norm = True # Batch normalization technique. Possible values: True, inplace, False

# For TransUnet  
vit_name = 'R50+ViT-B_16' # Transformer configuration. Possible values: ViT-B_16, ViT-B_32, ViT-L_16, ViT-L_32, ViT-H_14, R50+ViT-B_16, testing
vit_patches_size = 16 # Size of the transformer's patches
n_skip = 2 # Number of skip connections to be used

#----------------------------------------------------DATASET------------------------------------------------------------

n_classes = 3 # Total number of classes. 0:Background, 1:Lungs, 2:Infection
shape = 512 # Size of every axis of the rescaled volume
target_shape = (shape, shape, shape) # Volume shape to which the original ct-scans are rescaled
target_resolution = (334/shape, 334/shape, 1) # Voxel size along (x, y, z) to which the original ct-scans are rescaled (mm, per the original comment)
merging_method = 'softmax' # Plurality voting merging method

#-------------------------------------------------HYPERPARAMETERS-------------------------------------------------------

loss_name = 'CE' # Loss function to be used. Possible values: WDL, DL, BCElogits, CE, WDL+CE, DL+CE, WCE
sobel_loss = False # Add a Sobel error contribution to the loss function

optimizer_name = 'ADAM' # Optimizer. Possible values: ADAM, ADAMW, SGD
lr = 1e-3 # Learning rate
weight_decay = 0 # Optimizer's weight decay

use_scheduler = False # Use either a learning rate scheduler or not

training_batch_size = 8 # Size of each train-set batch
validation_batch_size = 8 # Size of each valid-set batch
n_epochs = 30 # Total number of training epochs

freeze_encoder = True # For fine-tuning: freeze the encoder weights or not.

dataset_name = 'challenge' # Dataset to be used for the training. Possible values: zenodo, challenge, zenodo+challenge
#------------------------------------------------------PATHS------------------------------------------------------------
import os

model_save_path = "" # Insert here the desired path to store trained models otherwise a default one will be used

if model_save_path == "":
    # Default location: /home/<first entry listed in /home>/models/.
    # NOTE(review): os.listdir order is arbitrary, so this assumes /home holds a single user directory.
    model_save_path = "/home/" + os.listdir("/home/")[0] + "/models/"
# Other cells build paths by plain string concatenation (e.g. model_save_path + 'vit_checkpoint/...'),
# so a user-supplied directory must end with a separator.
if not model_save_path.endswith('/'):
    model_save_path += '/'

#---------------------------------------------------MODEL SAVE PREFIX---------------------------------------------------
custom_prefix = '' # Insert here a desired prefix followed by '_'

# Define the model prefix from the normalization / network settings.
# str() keeps this working when eq_method / norm_method are set to the Python None instead of the string 'None'
# (both produce the same 'None' text).
if norm_method == 'in_range':
    model_prefix_name = str(eq_method) + "_" + str(norm_method) + "[" + str(lower) + ',' + str(higher) + ']_'
else:
    model_prefix_name = str(eq_method) + "_" + str(norm_method) + "_"
if batch_norm == 'inplace':
    model_prefix_name = 'bn_inplace_' + model_prefix_name
if attention == 'scse':
    model_prefix_name = 'scse_' + model_prefix_name
# Prefix used when saving fine-tuned models.
if not freeze_encoder:
    finetuning_prefix_name = 'all_trainable_' + 'finetuning_'
else:
    finetuning_prefix_name = 'encoder_freezed_' + 'finetuning_'

model_prefix_name = custom_prefix + model_prefix_name # Add the custom prefix

#_______________________________________________________________________________________________________________________
#-----------------------------------------------------------------------------------------------------------------------
#----------------------------------------------------SUMMARY------------------------------------------------------------
#-----------------------------------------------------------------------------------------------------------------------
# Prints the active settings so that a run's configuration is visible in the cell output.

print("\n----------------PIPELINE----------------\n")
print("Do Training: {}".format(do_train))
print("Do Fine tuning: {}".format(do_fine_tuning))
print("Saving enabled: {}".format(save_training))
print("Do Evaluation: {}".format(do_predict))
print("\n---------------PROCESSING---------------\n")
if norm_method == 'in_range':
    print("Normalization method: in range [{},{}]".format(lower,higher))
else:
    print("Normalization method: {}".format(norm_method))
print("Equalization method: {}".format(eq_method))
print("\n---------------APPROACH---------------\n")
print("Approach: {}".format(approach))
# Plain if/else instead of a conditional expression evaluated only for its print() side effect.
if approach == 'transformer':
    print("Network: {}".format(vit_name))
else:
    print("Network: {}+{}".format(encoder_name, decoder_name))
if do_train or do_fine_tuning:
    # Training takes precedence in the headers/prefix when both flags are set.
    if do_train:
        print("\n---------------TRAINING---------------\n")
    else:
        print("\n--------------FINE TUNING-------------\n")
    if do_fine_tuning:
        print("Freeze encoder: {}".format(freeze_encoder))
    print("Training on: {}".format(dataset_name))
    if sobel_loss:
        print("Loss: {}+Sobel".format(loss_name))
    else:
        print("Loss: {}".format(loss_name))
    print("Learning Rate: {}".format(lr))
    print("Weight Decay: {}".format(weight_decay))
    print("Scheduler: {}".format(use_scheduler))
    print("Batch size Train/Valid: {}/{}".format(training_batch_size,validation_batch_size))
    print("Model savepath: {}".format(model_save_path))
    if do_train:
        print("Model save prefix: {}".format(model_prefix_name))
    else:
        print("Model save prefix: {}".format(finetuning_prefix_name))
print("\n------------------------------------------\n")


----------------PIPELINE----------------

Do Training: False
Do Fine tuning: False
Saving enabled: False
Do Evaluation: False

---------------PROCESSING---------------

Normalization method: as_colab
Equalization method: clahe+histeq

---------------APPROACH---------------

Approach: 2.5D
Network: resnet101+Unet

------------------------------------------



# Globals and utilities

## import

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import nibabel as nib
import warnings
import cv2
import time
import shutil
import datetime
import random
from torchvision import transforms
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from skimage import measure, color, exposure
from sklearn.metrics import confusion_matrix
import pickle
from scipy import ndimage
import scipy
import itertools
import math
from os.path import join as pjoin
from collections import OrderedDict
import copy
from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
import ml_collections

import libraries.segmentation_models_pytorch as smp

## load Transformer pre-trained weights

In [3]:
# For TransUNet approach
# Download pretrained weights for Transformer network.
# The checkpoint <vit_name>.npz is fetched into <model_save_path>vit_checkpoint/<collection>/ only if it is not
# already there: ViT-H_14 comes from the 'imagenet21k' collection, every other name from 'imagenet21k+imagenet2012'.
# NOTE: the condition depends only on vit_name, so this cell also downloads when approach == '2.5D'.
# IPython substitutes {vit_name} and $model_save_path with the Python variables; the paths are built by plain
# concatenation, so model_save_path is expected to end with '/'.
# The file is downloaded into the current working directory and then moved to its destination.
if vit_name == 'ViT-H_14':
    if not os.path.exists(model_save_path + 'vit_checkpoint/imagenet21k/' + vit_name + '.npz'):
        !mkdir -p `echo $model_save_path`vit_checkpoint/imagenet21k
        !wget https://storage.googleapis.com/vit_models/imagenet21k/{vit_name}.npz
        !mv {vit_name}.npz `echo $model_save_path`vit_checkpoint/imagenet21k/{vit_name}.npz
else:
    if not os.path.exists(model_save_path + 'vit_checkpoint/imagenet21k+imagenet2012/' + vit_name +'.npz'):
        !mkdir -p `echo $model_save_path`vit_checkpoint/imagenet21k+imagenet2012
        !wget https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/{vit_name}.npz
        !mv {vit_name}.npz `echo $model_save_path`vit_checkpoint/imagenet21k+imagenet2012/{vit_name}.npz

## Globals

In [4]:
# Select device
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print("Actual device: ", device)
if 'cuda' in device:
    # Strip the "_CudaDeviceProperties(" prefix and the trailing ")" from the properties repr.
    print("Device info: {}".format(str(torch.cuda.get_device_properties(device)).split("(")[1])[:-1])
    # memory_summary only accepts CUDA devices, so it must stay inside this branch (it raised on CPU-only machines).
    print(torch.cuda.memory_summary(device, abbreviated=True))

# SEED
seed = 169366
random.seed(seed)
# Only affects Python processes started after this point (e.g. dataloader workers), not the running interpreter.
os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)
# torch.manual_seed seeds the default generator on every device, CPU included;
# torch.cuda.manual_seed alone left the CPU generator (used by CPU-side random ops) unseeded.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
# cuDNN autotuning may pick different (non-deterministic) kernels from run to run.
torch.backends.cudnn.benchmark = False

Actual device:  cuda:0
Device info: name='Tesla V100-SXM2-32GB', major=7, minor=0, total_memory=32480MB, multi_processor_count=80
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------------------|
| Active memory         |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------------------|
| GPU reserved memory   |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------------------|
| Non-releas

# Dataset

In [5]:
# Natural language sorting method
def atoi(text):
    """
    Converts one chunk produced by natural_keys to int when it is a run of decimal digits.
    -----------
    Parameters:
        text: str,
            the chunk to convert
    Returns:
        int(text) if text consists only of decimal digits, text unchanged otherwise.
    """
    # isdecimal (not isdigit): isdigit is also True for characters such as '²', on which int() fails,
    # while isdecimal matches exactly the characters that `\d` matches and int() accepts.
    return int(text) if text.isdecimal() else text

def natural_keys(text):
    """
    Sort key for "natural" ordering, so that e.g. 'scan2' sorts before 'scan10'.
    -----------
    Parameters:
        text: str,
            the string to build the key for
    Returns:
        A list alternating non-digit chunks (str) and digit runs (int).
    """
    return [atoi(chunk) for chunk in re.split(r'(\d+)', text)]

## Datasets globals

### Original

In [6]:
# Raw ZENODO .nii.gz
zenodo_root = '/home/' + os.listdir("/home/")[0] + '/datasets/zenodo/'
zenodo_image_root = zenodo_root + 'image/'
zenodo_mask_root = zenodo_root + 'mask/'
# Layout: zenodo[kind][split] -> naturally sorted list of file paths, kind in {'image', 'mask'},
# split in {'training', 'validation'}; 'mean' / 'std' hold the dataset HU statistics.
zenodo = {
    kind: {
        split: sorted(
            (root + split + '/' + name for name in os.listdir(root + split + '/')),
            key=natural_keys,
        )
        for split in ('training', 'validation')
    }
    for kind, root in (('image', zenodo_image_root), ('mask', zenodo_mask_root))
}
zenodo['mean'] = -558.356
zenodo['std'] = 492.593
print("Zenodo -> files found: {}".format(sum(len(paths) for paths in zenodo['image'].values())))

# Raw CHALLENGE .nii.gz
challenge_root = '/home/' + os.listdir("/home/")[0] + '/datasets/challenge/'
challenge_image_root = challenge_root + 'image/'
challenge_mask_root = challenge_root + 'mask/'
# Layout: challenge[kind][split] -> naturally sorted list of file paths, kind in {'image', 'mask'},
# split in {'training', 'validation'}; 'mean' / 'std' hold the dataset HU statistics.
challenge = {
    kind: {
        split: sorted(
            (root + split + '/' + name for name in os.listdir(root + split + '/')),
            key=natural_keys,
        )
        for split in ('training', 'validation')
    }
    for kind, root in (('image', challenge_image_root), ('mask', challenge_mask_root))
}
challenge['mean'] = -882.439
challenge['std'] = 723.039
print("Challenge -> files found: {}".format(sum(len(paths) for paths in challenge['image'].values())))

Zenodo -> files found: 10
Challenge -> files found: 199


### Processed

In [7]:
# ZENODO .npy
zenodo_root_proc = '/home/' + os.listdir("/home/")[0] + '/datasets/processed/npy/zenodo2D/'
zenodo_image_root_proc = zenodo_root_proc + 'image/'
zenodo_mask_root_proc = zenodo_root_proc + 'mask/'
# zenodo_proc is only defined once the processed dataset exists on disk.
if os.path.isdir(zenodo_root_proc):
    zenodo_proc = {
        kind: {
            split: sorted(
                (root + split + '/' + name for name in os.listdir(root + split + '/')),
                key=natural_keys,
            )
            for split in ('training', 'validation')
        }
        for kind, root in (('image', zenodo_image_root_proc), ('mask', zenodo_mask_root_proc))
    }
    zenodo_proc['mean'] = -558.356
    zenodo_proc['std'] = 492.593
    print("Zenodo -> files found: {}".format(sum(len(paths) for paths in zenodo_proc['image'].values())))

# CHALLENGE .npy
challenge_root_proc = '/home/' + os.listdir("/home/")[0] + '/datasets/processed/npy/challenge2D/'
challenge_image_root_proc = challenge_root_proc + 'image/'
challenge_mask_root_proc = challenge_root_proc + 'mask/'
# challenge_proc is only defined once the processed dataset exists on disk.
if os.path.isdir(challenge_root_proc):
    challenge_proc = {
        kind: {
            split: sorted(
                (root + split + '/' + name for name in os.listdir(root + split + '/')),
                key=natural_keys,
            )
            for split in ('training', 'validation')
        }
        for kind, root in (('image', challenge_image_root_proc), ('mask', challenge_mask_root_proc))
    }
    challenge_proc['mean'] = -882.439
    challenge_proc['std'] = 723.039
    print("Challenge -> files found: {}".format(sum(len(paths) for paths in challenge_proc['image'].values())))

Zenodo -> files found: 10
Challenge -> files found: 18


## Dataset visualization

In [8]:
def load_scan(filepath):
    """
    Loads a volume from path. It can be either .npy or .nii.gz
    -----------
    Parameters:
        filepath: str or path-like,
            path of the file; names ending in 'gz' are opened with nibabel,
            names ending in 'npy' with numpy.
    Returns:
        A nibabel image object for 'gz' files, a numpy array for 'npy' files.
    Raises:
        ValueError if the file name has neither ending (previously None was returned
        silently, which only failed later with a confusing AttributeError).
    """
    path = os.fspath(filepath)
    if path.endswith('gz'):
        return nib.load(path)
    if path.endswith('npy'):
        return np.load(path)
    raise ValueError("Unsupported file extension: {}".format(path))

def slice_image(image, idx_slice, axis='z'):
    """
    Takes a volume slice.
    ----------
    Parameters:
        image: nibabel object
            The raw .nii.gz ct-scan
        idx_slice: int
            Index of the desired slice
        axis: str, optional
            Axis along which the image is being sliced: 'x', 'y' or 'z'
    Returns:
        The sliced 2D image as numpy array
    Raises:
        ValueError if axis is not 'x', 'y' or 'z' (previously an UnboundLocalError).
    """
    if axis == 'x':
        cropped_img = image.slicer[idx_slice:idx_slice+1, :, :]
    elif axis == 'y':
        cropped_img = image.slicer[:, idx_slice:idx_slice+1, :]
    elif axis == 'z':
        cropped_img = image.slicer[:, :, idx_slice:idx_slice+1]
    else:
        raise ValueError("axis must be 'x', 'y' or 'z', got {!r}".format(axis))
    # The slicer keeps a length-1 axis; squeeze it away to get a 2D array.
    return cropped_img.get_fdata().squeeze()
  
def show_sample_shape(dataset, show_each_file=True):
    """
    Shows shape, resolution and HU values of the image volumes inside a dataset.
    ----------
    Parameters:
        dataset: dict
            A dataset dict as created above. Only dataset['image']['training'] and
            dataset['image']['validation'] are read; neither list is modified.
        show_each_file: bool, optional
            Shows info about all the dataset samples if True, shows only a summary instead
    Notes:
        Min/Max shape and resolution are those with the smallest/largest sum of components.
        "HU Mean" and "HU STD" are averages of the per-volume mean and standard deviation,
        not statistics pooled over all voxels.
    """
    if show_each_file:
        print('File:\t\t\t\t\t Shape:\t\t\t\t Voxel dim:')
    # Build a new list: extending dataset['image']['training'] in place appended the
    # validation files to the caller's training split.
    all_files = list(dataset['image']['training']) + list(dataset['image']['validation'])
    if not all_files:
        # Avoid the division by zero in the averages below.
        print("No image files found in dataset")
        return
    inf = float('inf')
    min_shape, min_res = (inf, inf, inf), (inf, inf, inf)
    max_shape, max_res = (-inf, -inf, -inf), (-inf, -inf, -inf)
    min_HU, max_HU = inf, -inf
    mean, std = 0, 0
    for filename in all_files:
        scan = load_scan(filename)
        zooms = scan.header.get_zooms()
        img = scan.get_fdata()
        local_max, local_min = np.max(img), np.min(img)
        mean += np.mean(img)
        std += np.std(img)
        if show_each_file:
            print('{}\t\t {}\t\t {}'.format(filename.split("/")[-1], img.shape, zooms))
        if sum(img.shape) < sum(min_shape):
            min_shape = img.shape
        if sum(zooms) < sum(min_res):
            min_res = zooms
        if sum(img.shape) > sum(max_shape):
            max_shape = img.shape
        if sum(zooms) > sum(max_res):
            max_res = zooms
        max_HU = max(max_HU, local_max)
        min_HU = min(min_HU, local_min)

    print("----------------------------------------------------------------------------------------------")
    print("Min shape: {}\tMin Resolution: {}\nMax shape: {}\tMax Resolution: {}".format(min_shape,min_res,max_shape,max_res))
    print("Min HU val: {} \t\tMax HU val: {}".format(min_HU, max_HU))
    print("HU Mean: {:.3f}\t\tHU STD: {:.3f}".format(mean/len(all_files), std/len(all_files)))
    print("______________________________________________________________________________________________")

# Dataset overview. The show_sample_shape calls are commented out because each one
# loads every volume of the dataset with get_fdata(); uncomment them to print the statistics.
print("Zenodo Dataset\n")
#show_sample_shape(zenodo)
print("\n\nChallenge Dataset\n")
#show_sample_shape(challenge)

Zenodo Dataset



Challenge Dataset



## Normalization

### Funcs

In [9]:
def rescale_to_standard(array, resolution, target_resolution=(334/512, 334/512, 1), target_shape=(512, 512, 512)):
    """
    Pads/crops a 3D volume to a fixed physical extent and resamples it onto a fixed voxel grid.
    -----------
    Parameters:
        array: numpy array,
            input volume with shape (x, y, z).
        resolution: sequence of 3 floats,
            voxel size of `array` along (x, y, z), in the same unit as `target_resolution`.
        target_resolution: sequence of 3 floats, Optional,
            voxel size of the output grid.
        target_shape: sequence of 3 ints, Optional,
            shape of the output volume.
    Returns:
        A float32 numpy array with shape `target_shape`.
    """
    original_shape = np.shape(array)
    # Physical extent covered by the output grid, and how many input voxels span that extent.
    target_volume = (target_resolution[0]*target_shape[0], target_resolution[1]*target_shape[1], target_resolution[2]*target_shape[2])
    shape_of_target_volume = (int(target_volume[0]/resolution[0]), int(target_volume[1]/resolution[1]), int(target_volume[2]/resolution[2]))

    if original_shape[2] * resolution[2] > target_volume[2]:
        warnings.warn('z-axis is longer than expectation. Make sure lung is near the center of z-axis.', SyntaxWarning)
        # Heuristic: drop the first 100 slices along z; any remaining excess is cropped around the centre below.
        array = array[:, :, 100::]
        original_shape = np.shape(array)

    # Zero-pad into a canvas large enough for both the input and the target extent (+2 margin) ...
    x = max(shape_of_target_volume[0], original_shape[0]) + 2
    y = max(shape_of_target_volume[1], original_shape[1]) + 2
    z = max(shape_of_target_volume[2], original_shape[2]) + 2

    x_start = int(x/2)-int(original_shape[0]/2)
    x_end = x_start + original_shape[0]
    y_start = int(y/2)-int(original_shape[1]/2)
    y_end = y_start + original_shape[1]
    z_start = int(z / 2) - int(original_shape[2] / 2)
    z_end = z_start + original_shape[2]

    array_intermediate = np.zeros((x, y, z), 'float32')
    array_intermediate[x_start:x_end, y_start:y_end, z_start:z_end] = array

    # ... then cut the centred target extent out of it.
    x_start = int(x / 2) - int(shape_of_target_volume[0] / 2)
    x_end = x_start + shape_of_target_volume[0]
    y_start = int(y / 2) - int(shape_of_target_volume[1] / 2)
    y_end = y_start + shape_of_target_volume[1]
    z_start = int(z / 2) - int(shape_of_target_volume[2] / 2)
    z_end = z_start + shape_of_target_volume[2]

    array_intermediate = array_intermediate[x_start:x_end, y_start:y_end, z_start:z_end]  # Now the array is padded

    # Rescaling.
    # cv2.resize takes dsize as (width, height) == (n_columns, n_rows), and the interpolation flag must be
    # passed by keyword: the third positional parameter is `dst`, so the previous positional
    # cv2.INTER_LANCZOS4 was ignored (bilinear was used). Volumes preprocessed before this fix were
    # therefore resampled bilinearly in-plane.
    array_standard_xy = np.zeros((target_shape[0], target_shape[1], shape_of_target_volume[2]), 'float32')
    for s in range(shape_of_target_volume[2]):
        array_standard_xy[:, :, s] = cv2.resize(array_intermediate[:, :, s], (target_shape[1], target_shape[0]), interpolation=cv2.INTER_LANCZOS4)

    array_standard = np.zeros(target_shape, 'float32')
    for s in range(target_shape[0]):
        array_standard[s, :, :] = cv2.resize(array_standard_xy[s, :, :], (target_shape[2], target_shape[1]), interpolation=cv2.INTER_LINEAR)

    return array_standard

def normalize_range(array, lower, upper):
    """
    Linearly rescales an array so that its minimum becomes `lower` and its maximum `upper`.
    -----------
    Parameters:
        array: numpy array,
            input array to be normalized. It is modified IN PLACE (see Notes).
        lower: float,
            lower bound of normalization.
        upper: float,
            upper bound of normalization.
    Returns:
        `array` itself, rescaled to [lower, upper]; if all its values are equal,
        a new array of the same shape filled with `lower`.
    Notes:
        `newarray` is the same object as `array`, so the in-place operations below overwrite
        the caller's data; in the constant case the caller's array is left as all zeros.
        Callers may rely on this side effect - audit them before switching to a copy.
        With an integer-typed array the in-place `*=` by a float ratio is expected to fail
        (numpy will not cast the float product back into an int array in place), so pass float arrays.
    """
    newarray = array  # alias, not a copy

    newarray += (0 - np.min(newarray))  # shift so that the minimum is 0
    if np.amax(newarray) == 0:  # all values were equal
        return np.full(array.shape, lower)
    else:
        newarray *= ((upper - lower) / np.amax(newarray))  # stretch to width (upper - lower)
        newarray += lower
        
    return newarray

def normalize_mean_std(array, mean, std):
    """
    Standardizes an array: subtracts `mean`, then divides by `std`.
    -----------
    Parameters:
        array: numpy array,
            values to standardize.
        mean: float,
            value subtracted from every element.
        std: float,
            divisor applied after centering.
    Returns:
        A new numpy array (array - mean) / std; the input is not modified.
    """
    centered = array - mean
    return centered / std


def normalization_custom(img, apply_histeq=True, apply_clahe=True, adjust_gamma=True, gamma_coeff=0.6):
    """
    Normalization test: normalize the input 2D image with a custom normalization technique.
    The steps are: rescale to [-1250, 250], optional histogram equalization, paper
    normalization shifted to [0, 1], optional gamma adjustment, optional CLAHE.
    -----------
    Parameters:
        img: numpy array,
            the image to be normalized (not modified)
        apply_histeq: bool, Optional,
            apply or not the histogram equalization
        apply_clahe: bool, Optional,
            apply or not the CLAHE
        adjust_gamma: bool, Optional,
            adjust or not the gamma
        gamma_coeff: float, Optional,
            gamma adjustment coefficient
    Returns:
        The normalized input as numpy array inside the interval [0, 1]
        (all zeros if the input is constant).
    """
    if np.min(img) == np.max(img): # if the image is empty return all zeros
        return np.zeros(img.shape)
    arr = np.asarray(img)
    # normalize_range works in place: hand it a float copy so the caller's image is left untouched
    # (and integer images, which cannot be rescaled in place, work too).
    work_dtype = arr.dtype if np.issubdtype(arr.dtype, np.floating) else np.float64
    out = normalize_range(arr.astype(work_dtype), -1250, 250) # bring HU in a sweet interval for lung cts
    if apply_histeq:
        out = histeq(out, coeff=10) # equalize histogram in the interval [-1250, 250]
    out = paper_normalization(out) # apply the normalization presented in the original paper -> expected [-0.5, 0.5]
    out += 0.5 # bring in [0,1]
    if adjust_gamma:
        out = exposure.adjust_gamma(out, gamma_coeff)
    if apply_clahe:
        clahe = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8,8)) # clahe to enhance features
        out = (out*255).astype(np.uint8) # for clahe images must be in [0, 255]
        out = clahe.apply(out) # apply clahe
        out = out.astype(np.float32) / 255 # bring images back to [0,1]
    return out


def paper_normalization(img, in_range=None):
    """
    Window-normalizes a 2D image as suggested by a paper (see the report for details).
    The window centre is the midpoint of the image range and the window width is
    |min| + |max|, which equals max - min only when min <= 0 <= max.
    -----------
    Parameters:
        img: numpy array,
            the image to be normalized (not modified)
        in_range: str, Optional,
            '0_1' adds 0.5 to the result, '0_255' adds 0.5 and multiplies by 255,
            None returns the centred result as is
    Returns:
        The normalized input as a new numpy array; if |min| + |max| == 0,
        an array full of -0.5 whatever in_range is.
    """
    lo, hi = np.min(img), np.max(img)
    width = abs(lo) + abs(hi)
    if width == 0:
        return np.full(img.shape, -0.5)
    centre = (lo + hi)/2
    result = (np.array(img) - centre) / width
    if in_range == '0_1':
        result += 0.5
    elif in_range == '0_255':
        result = (result + 0.5) * 255
    return result

def normalization_best(img):
    """
    Normalizes the input 2D image. Inspired by paper_normalization, sees the report for more details.
    The image is rescaled to [-1250, 250] and then window-normalized to [0, 255].
    -----------
    Parameters:
        img: numpy array,
            the image to be normalized (not modified)
    Returns:
        The normalized input as numpy array inside the interval [0, 255]
        (all zeros if the input is constant).
    """
    arr = np.asarray(img)
    # normalize_range works in place, so give it a float copy. The window statistics are taken from the
    # rescaled data itself: the previous code read them back from `img`, which only gave the right
    # numbers because normalize_range had overwritten the caller's array.
    work_dtype = arr.dtype if np.issubdtype(arr.dtype, np.floating) else np.float64
    data_array = normalize_range(arr.astype(work_dtype), -1250, 250)
    lo, hi = np.min(data_array), np.max(data_array)
    if lo == hi:  # constant input: nothing to normalize
        return np.full(arr.shape, 0)
    ww = abs(lo) + abs(hi)  # window width (1500 for [-1250, 250])
    wc = (lo + hi)/2        # window centre
    data_array = data_array - wc
    data_array = data_array / ww
    data_array += 0.5
    data_array *= 255
    return data_array

def comp_h(im, n_bins):
    """
    Computes histogram
    -----------
    Parameters:
        im: numpy array of non-negative integers (e.g. uint8),
            image of which the histogram has to be computed
        n_bins: int,
            largest value expected in the image
    Returns:
        The image histogram as a float numpy array of length n_bins + 1,
        where entry v counts the pixels equal to v.
    Raises:
        IndexError if the image contains a value greater than n_bins.
    """
    n_entries = int(n_bins) + 1
    # np.bincount replaces the per-pixel Python double loop: same counts, computed in C.
    counts = np.bincount(np.asarray(im).ravel(), minlength=n_entries)
    if counts.size > n_entries:
        raise IndexError("image contains values greater than n_bins")
    return counts.astype(float)

def histeq(image, coeff=100):
    """
    Histogram equalization for images in ranges different from [0,255]
    -----------
    Parameters:
        image: numpy array,
            the image to be equalized (not modified)
        coeff: int, Optional,
            coefficient multiplied to the image to keep decimal approximations
            (values are truncated to integers after scaling)
    Returns:
        The equalized image, spanning the same [min, max] range as the input.
        A constant image is returned unchanged (as float).
    """
    scaled = np.asarray(image) * coeff  # new array: the previous in-place `*=` modified the caller's image
    offset = np.min(scaled)
    # Shift so that the minimum is exactly 0. The previous `+= abs(min)` did this only for
    # non-positive minima and mis-shifted images whose minimum is positive.
    shifted = (scaled - offset).astype(np.uint32)
    n_bins = int(np.max(shifted))
    hist = np.bincount(shifted.ravel(), minlength=n_bins + 1)
    cdf = np.cumsum(hist)
    cdf_min = np.amin(cdf)
    n_pixels = shifted.size
    if n_pixels == cdf_min:
        # Constant image: the classic formula would compute 0/0; the identity mapping is correct.
        lut = np.zeros(n_bins + 1, dtype=np.uint32)
    else:
        val = (cdf - cdf_min) / (n_pixels - cdf_min)
        lut = (n_bins * val).astype(np.uint32)
    # Vectorized lookup replaces the per-pixel Python loop.
    equalized = lut[shifted].astype(np.float32) + offset
    equalized /= coeff
    return equalized

def equalize(x, mode):
    """
    Equalizes the input 2D image.
    -----------
    Parameters:
        x: uint8 numpy array,
            the image to be equalized
        mode: str,
            the equalization method: 'histeq', 'clahe', 'clahe+histeq' or 'histeq+clahe'
            (steps applied left to right). None or 'None' returns x unchanged.
    Returns:
        The equalized image as uint8 numpy array
    Raises:
        ValueError for any other mode (previously an UnboundLocalError).
    """
    if mode is None or mode == 'None':
        return x
    if mode not in ('histeq', 'clahe', 'clahe+histeq', 'histeq+clahe'):
        raise ValueError("Unknown equalization mode: {}".format(mode))
    out = x
    for step in mode.split('+'):
        if step == 'histeq':
            out = cv2.equalizeHist(out)
        else:  # 'clahe'
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
            out = clahe.apply(out)
    return out

In [10]:
#---------------------------------------NORMALIZATION APPROACH---------------------------------------------------------

# Normalization and Equalization Handler
def norm_and_eq(x, mean=None, std=None):
    """
    Normalization and Equalization Handler: normalizes and then (optionally) equalizes a
    2D slice according to the global `norm_method` / `eq_method` settings.
    -----------
    Parameters:
        x: numpy array,
            the 2D slice to process
        mean: float, Optional,
            mean used when norm_method == 'mean_std'
        std: float, Optional,
            standard deviation used when norm_method == 'mean_std'
    Returns:
        The processed slice; float32 in [0, 1] when equalization is enabled.
    Raises:
        ValueError if norm_method is not a supported value (previously an UnboundLocalError).
    """
    # Accept the Python None as well as the string 'None' used in the settings panel.
    method = 'None' if norm_method is None else norm_method
    if method == 'None':
        out = np.array(x)
    elif method == 'custom':
        out = normalization_custom(x, apply_histeq=True, apply_clahe=True, adjust_gamma=True, gamma_coeff=0.6)
    elif method == 'mean_std':
        out = normalize_mean_std(x, mean, std)
    elif method == 'as_paper':
        out = paper_normalization(x)
    elif method == 'as_paper_in_0_1':
        out = paper_normalization(x, in_range='0_1')
    elif method == 'as_paper_in_0_255':
        out = paper_normalization(x, in_range='0_255')
    elif method == 'in_range':
        out = normalize_range(x, lower, higher)
    elif method == 'adjust_gamma':
        out = normalize_range(x, 0, 1)
        out = exposure.adjust_gamma(out, 0.6)
    elif method == 'sharp':
        out = normalize_range(x, 0, 255).astype(np.uint8)
        kernel = np.array([[-1,-1,-1], [-1,9,-1], [-1,-1,-1]])
        out = cv2.filter2D(out, -1, kernel)
        out = equalize(out, 'clahe+histeq').astype(np.float32)/255
    elif method in ('best', 'as_colab'):
        out = normalization_best(x)
    else:
        raise ValueError("Unknown normalization method: {}".format(norm_method))
    if eq_method is not None and eq_method != 'None':
        # Equalization works on uint8: clip first so out-of-range values saturate instead of wrapping around.
        # NOTE(review): methods returning values in [0, 1] ('custom', 'adjust_gamma', 'sharp',
        # 'as_paper_in_0_1') collapse to {0, 1} here; equalization only makes sense for [0, 255] outputs.
        out = equalize(np.clip(out, 0, 255).astype(np.uint8), eq_method).astype(np.float32)/255
    return out

## Preprocess dataset

In [11]:
def rescale_one_ct(scan):
    """
    Resamples one CT volume onto the standard grid defined by the global
    `target_resolution` and `target_shape` settings.
    -----------
    Parameters:
        scan: nibabel object,
            the ct-scan
    Returns:
        The rescaled ct-scan volume as numpy array
    """
    voxel_size = scan.header.get_zooms()
    volume = scan.get_fdata()
    return rescale_to_standard(volume, voxel_size, target_resolution, target_shape)

def rescale_one_gt(scan):
    """
    Resamples one segmentation mask onto the standard grid (global `target_resolution` /
    `target_shape`) and rounds the interpolated values back to integer labels.
    -----------
    Parameters:
        scan: nibabel object,
            the ct-scan mask
    Returns:
        The rescaled mask volume as numpy array
    """
    voxel_size = scan.header.get_zooms()
    resampled = rescale_to_standard(scan.get_fdata(), voxel_size, target_resolution, target_shape)
    return resampled.round()

def preprocess_one_patient(ct_scan, mask):
    """
    Rescales the CT volume and the mask of a single patient.
    -----------
    Parameters:
        ct_scan: nibabel object or str,
            the ct-scan, or its path (loaded with nibabel)
        mask: nibabel object or str,
            the ct-scan mask, or its path (loaded with nibabel)
    Returns:
        The rescaled ct-scan volume and relative mask as numpy arrays
    """
    ct_scan = nib.load(ct_scan) if isinstance(ct_scan, str) else ct_scan
    mask = nib.load(mask) if isinstance(mask, str) else mask
    # Tuple items are evaluated left to right: CT first, then the mask.
    return rescale_one_ct(ct_scan), rescale_one_gt(mask)

def preprocess_all(dataset, ct_dst, gt_dst, n_of_patients=-1):
    """
    Preprocess all the patients inside a dict by rescaling the ct-scan and relative mask
    and saving every slice of both volumes as .npy (see save_data).
    -----------
    Parameters:
        dataset: dict,
            dict with all the dataset paths; dataset['image'][split] and
            dataset['mask'][split] are lists of file paths for split in
            'training' and 'validation'
        ct_dst: str,
            save path for ct-scans (expected to end with '/')
        gt_dst: str,
            save path for masks (expected to end with '/')
        n_of_patients: int, Optional,
            number of patients to be processed per split; any value <= 0
            (default -1) processes all the patients
    Raises:
        ValueError: if a split has a different number of ct-scans and masks, or
            a ct-scan and its paired mask have different filenames
    """
    def _filename(path):
        # Last component of a '/'-separated path
        return path.split("/")[-1]

    def _elapsed(since):
        # Elapsed wall time formatted as H:MM:SS (fractional seconds dropped)
        return str(datetime.timedelta(seconds=time.time() - since)).split('.')[0]

    # Measure the whole run, not just the last split
    start_time = time.time()
    for split in ['training', 'validation']:
        # sorted() copies, so the caller's dataset lists are not reordered in place
        all_ct_paths = sorted(dataset['image'][split], key=natural_keys)
        all_gt_paths = sorted(dataset['mask'][split], key=natural_keys)

        # Check couples ct-gt before any (slow) processing; a mismatch would
        # silently pair a scan with the wrong mask
        if len(all_ct_paths) != len(all_gt_paths):
            raise ValueError("{} split has {} CT-scans but {} masks".format(
                split, len(all_ct_paths), len(all_gt_paths)))
        for ct_path, gt_path in zip(all_ct_paths, all_gt_paths):
            if _filename(ct_path) != _filename(gt_path):
                raise ValueError("CT and GT paths have different filenames inside: {} vs {}".format(
                    ct_path, gt_path))

        # Preprocess
        for n_done, (ct_path, gt_path) in enumerate(zip(all_ct_paths, all_gt_paths), start=1):
            patient_time = time.time()
            ct, gt = preprocess_one_patient(ct_path, gt_path)
            save_data(ct, ct_dst + split + '/', _filename(ct_path).split(".")[0])
            save_data(gt, gt_dst + split + '/', _filename(gt_path).split(".")[0])
            print("{} \t Done in {}".format(_filename(ct_path), _elapsed(patient_time)))
            # Stop after exactly n_of_patients patients when a positive limit is given
            if 0 < n_of_patients <= n_done:
                break
    print('Dataset processed in {}'.format(_elapsed(start_time)))

def save_data(img, dst, name, histeq=None):
    """
    Save every slice of a 3D volume as .npy, along each of its three axes.

    Slices taken along axis 0, 1 and 2 are written to the 'x/', 'y/' and 'z/'
    sub-folders of ``dst + name + '/'`` respectively, one '<index>.npy' file per
    slice. Each axis is iterated over its own length, so non-cubic volumes are
    saved completely.
    -----------
    Parameters:
        img: numpy array,
            volume to be sliced
        dst: str,
            destination save path (expected to end with '/')
        name: str,
            name of the folder containing all the slices
        histeq: str, Optional,
            unused; kept for backward compatibility (no equalization is applied here)
    """
    savepath = dst + name + '/'
    axis_dirs = ('x/', 'y/', 'z/')
    for axis_dir in axis_dirs:
        os.makedirs(savepath + axis_dir, exist_ok=True)
    for axis, axis_dir in enumerate(axis_dirs):
        for idx in range(img.shape[axis]):
            # np.take(img, idx, axis=0/1/2) == img[idx,:,:] / img[:,idx,:] / img[:,:,idx]
            np.save(savepath + axis_dir + str(idx) + '.npy', np.take(img, idx, axis=axis))
        
def prepare_dirs(root):
    """
    Prepare directories to store images for training and validation.

    Creates root/image/{training,validation}/ and root/mask/{training,validation}/.
    Safe to call repeatedly; if root already exists, any missing sub-folders are
    still created.
    -----------
    Parameters:
        root: str,
            main folder path (expected to end with '/')
    """
    # exist_ok=True makes every call idempotent, so no guard on root is needed
    for folder in ('image/', 'mask/'):
        for subfolder in ('training/', 'validation/'):
            os.makedirs(root + folder + subfolder, exist_ok=True)

# If the flag is True, rescale each dataset's ct-scans and masks and store every
# slice as .npy (preprocess_all -> save_data; no equalization happens at this step).
# Relies on the dataset dicts and *_root_proc path globals defined in earlier cells.
if preprocess_dataset:
    print("Processing Zenodo")
    prepare_dirs(zenodo_root_proc)
    preprocess_all(zenodo, zenodo_image_root_proc, zenodo_mask_root_proc)
    print("\n\nProcessing Challenge")
    prepare_dirs(challenge_root_proc)
    preprocess_all(challenge, challenge_image_root_proc, challenge_mask_root_proc)

## Data loader

In [12]:
class orthogonal_rotation:
    """
    Callable augmentation that rotates an image by an angle drawn at random
    from 90, 180 and 270 degrees.
    """
    def __init__(self):
        self.angles = [90, 180, 270]

    def __call__(self, x):
        """
        Rotate `x` by one of the configured angles, chosen uniformly at random.
        -----------
        Parameters:
            x: Tensor,
                image tensor to be augmented
        Returns:
            The rotated image tensor
        """
        chosen_angle = random.choice(self.angles)
        return transforms.functional.rotate(x, chosen_angle)
    
class COVID_dataset():
    """
    Class to handle datasets of 2D .npy slices of preprocessed ct-scans and masks.

    Every volume contributes shape[0] slices from its '/x/' folder, shape[1] from
    '/y/' and shape[2] from '/z/'. Items are (ct_scan, mask) float tensors with a
    leading channel dimension.
    -----------
    Parameters:
        dataset: dict or list of dict,
            dataset dict(s) with 'image'/'mask' path lists per split and 'mean'/'std'
            entries; with a list, mean and std are averaged over the dicts
        dts_type: str,
            defining the dataset split folder, 'training' or 'validation'
        aug: bool,
            apply or not augmentation
        shape: tuple,
            images target shape
        size: int, Optional,
            number of samples to work with (the first `size` after natural sorting)
    """
    # Slice sub-folder of each axis, in the order matching shape[0], shape[1], shape[2]
    _AXES = ('/x/', '/y/', '/z/')

    def __init__(self, dataset, dts_type, aug, shape, size=None):
        self.dataset = dataset
        self.dts_type = dts_type
        self.aug = aug
        self.x = []
        self.y = []
        self.shape = shape
        self.mean = 0
        self.std = 0

        # A single dict is handled as a one-element collection; only the list
        # case averages the statistics
        is_multi = isinstance(self.dataset, list)
        for subset in (self.dataset if is_multi else [self.dataset]):
            self._collect_slice_paths(subset)
            self.mean += subset['mean']
            self.std += subset['std']
        if is_multi:
            self.mean /= len(self.dataset)
            self.std /= len(self.dataset)

        if len(self.x) != len(self.y):
            raise SystemError('Problem with Img and Gt, no same size')

        self.x.sort(key=natural_keys)
        self.y.sort(key=natural_keys)

        if size is not None:
            self.x = self.x[:size]
            self.y = self.y[:size]

        self.info = len(self.x)

    def _collect_slice_paths(self, subset):
        """Append the per-slice .npy paths of every volume of `subset` for the current split."""
        images = subset['image'][self.dts_type]
        masks = subset['mask'][self.dts_type]
        for idx in range(len(images)):
            for axis_idx, axis in enumerate(self._AXES):
                for slice_n in range(self.shape[axis_idx]):
                    self.x.append(images[idx] + axis + str(slice_n) + '.npy')
                    self.y.append(masks[idx] + axis + str(slice_n) + '.npy')

    def __len__(self):
        return self.info

    def augmentation2D(self, scan, mask):
        """
        Apply one randomly chosen geometric augmentation jointly to a ct-scan slice
        and its mask. Slices with an empty mask are returned unchanged.
        -----------
        Parameters:
            scan: Tensor,
                ct-scan tensor of shape (1, H, W)
            mask: Tensor,
                mask tensor of shape (1, H, W)
        Returns:
            the augmented ct and mask tensors, each of shape (1, H, W)
        """
        if torch.sum(mask) == 0:
            return scan, mask

        w, h = scan.shape[1:]
        # Stack scan and mask as two channels so each transform applies the same
        # random parameters to both
        output = torch.zeros([2, w, h])
        output[0, :, :] = scan
        output[1, :, :] = mask

        aug_method = np.random.randint(0, 5)
        if aug_method == 0:  # horizontal flip
            output = transforms.RandomHorizontalFlip(p=1.0)(output)
        elif aug_method == 1:  # vertical flip
            output = transforms.RandomVerticalFlip(p=1.0)(output)
        elif aug_method == 2:  # random crop of half the area
            # Resize back to this slice's own size so augmented samples keep the
            # same shape as non-augmented ones in a batch
            rand_crop = transforms.RandomResizedCrop((w, h), scale=(0.5, 0.5), ratio=(1.0, 1.0),
                                                     interpolation=transforms.InterpolationMode.NEAREST)
            output = rand_crop(output)
        elif aug_method == 3:  # small random rotation (within +/- 5 degrees)
            rand_rotate = transforms.RandomRotation(5, interpolation=transforms.InterpolationMode.NEAREST,
                                                    expand=False, center=None, fill=0)
            output = rand_rotate(output)
        elif aug_method == 4:  # rotation by 90, 180 or 270 degrees
            output = orthogonal_rotation()(output)

        return output[0, :, :].unsqueeze(0), output[1, :, :].unsqueeze(0)

    def __load_nii__(self, filepath):
        return nib.load(filepath)

    def __getitem__(self, index=None):
        # With no index, serve a random sample
        if index is None:
            index = np.random.randint(0, self.info)

        # Load Image
        ct_scan = np.load(self.x[index])
        mask = np.load(self.y[index])

        # Normalize and Equalize with the dataset-level statistics
        ct_scan = norm_and_eq(ct_scan, self.mean, self.std)

        ct_scan = torch.from_numpy(ct_scan).unsqueeze(0)
        mask = torch.from_numpy(mask).unsqueeze(0)

        # Augmentation
        if self.aug:
            ct_scan, mask = self.augmentation2D(ct_scan, mask)

        return ct_scan.float(), mask.float()

    def get_file(self, filename):
        """
        Return the (ct_scan, mask) item whose ct-scan slice path equals `filename`.
        Raises IndexError if no such slice is in the dataset.
        """
        try:
            index = self.x.index(filename)
        except ValueError:
            raise IndexError('{} is not in the dataset'.format(filename)) from None
        return self.__getitem__(index)
    

def init_train_test_loader(train_dataset, train_folder, valid_dataset, valid_folder, num_workers, size_train=None, size_valid=None):
    """
    Build the training (augmented, shuffled) and validation (plain, ordered) data loaders.
    -----------
    Parameters:
        train_dataset: dict or list of dict,
            dataset dict with all the relative paths used for training
        train_folder: str,
            split folder for training samples, 'training' or 'validation'
        valid_dataset: dict or list of dict,
            dataset dict with all the relative paths used for validation
        valid_folder: str,
            split folder for validation samples, 'training' or 'validation'
        num_workers: int,
            number of workers for each data loader
        size_train: int, Optional,
            number of samples for the training split
        size_valid: int, Optional,
            number of samples for the validation split
    Returns:
        Dataloaders for both training and validation
    """
    def _make_loader(dataset, folder, augment, size, batch_size, shuffle):
        split_dataset = COVID_dataset(dataset, folder, aug=augment, shape=target_shape, size=size)
        return DataLoader(split_dataset, batch_size=batch_size, shuffle=shuffle,
                          num_workers=num_workers, pin_memory=True)

    train_loader = _make_loader(train_dataset, train_folder, True, size_train, training_batch_size, True)
    valid_loader = _make_loader(valid_dataset, valid_folder, False, size_valid, validation_batch_size, False)
    return train_loader, valid_loader

## Test data correctness

In [13]:
def dataset_processed_test(dataset_proc, dataset_root, filename=None):
    """
    Quick test to check dataset format.

    Plots a slice of the original scan/mask (via slice_image), the middle z-slice
    of the rescaled volumes and the matching z-slice served (augmented) by
    COVID_dataset, then prints value-range checks on the served slice and mask.
    -----------
    Parameters:
        dataset_proc: dict,
            processed dataset dict with all the relative paths
        dataset_root: str,
            root folder of the original dataset (expected to end with '/')
        filename: str, Optional,
            processed training image path, i.e. an entry of
            dataset_proc['image']['training']; if None it is randomly picked
    Raises:
        ValueError: if `filename` is not one of dataset_proc['image']['training']
    """
    ct_paths = dataset_proc['image']['training']
    gt_paths = dataset_proc['mask']['training']
    # Pick the sample by position so the matching mask is known in both cases
    if filename is None:
        idx = np.random.randint(0, len(ct_paths))
    else:
        idx = ct_paths.index(filename)
    filename = ct_paths[idx]
    filename_gt = gt_paths[idx]

    slice_id = target_shape[2] // 2

    # nii.gz
    scan = load_scan(dataset_root + 'image/training/' + filename.split("/")[-1] + '.nii.gz')
    mask = load_scan(dataset_root + 'mask/training/' + filename_gt.split("/")[-1] + '.nii.gz')

    # Original slice
    slice_mask = slice_image(mask, mask.shape[2]//2)
    slice_scan = slice_image(scan, scan.shape[2]//2)
    # Rescale and normalization
    scan_res, mask_res = preprocess_one_patient(scan, mask)
    slice_mask_res = mask_res[:,:,slice_id]
    slice_scan_res = scan_res[:,:,slice_id]

    # Dataloader: fetch the same z-slice that was just taken from the rescaled volume
    training_Dataset = COVID_dataset(dataset_proc, 'training', aug=True, shape=target_shape)
    scan_aug, mask_aug = training_Dataset.get_file(filename=filename + '/z/' + str(slice_id) + '.npy')
    scan_aug = scan_aug.squeeze().numpy()
    mask_aug = mask_aug.squeeze().numpy()

    images = [slice_scan, slice_scan_res, scan_aug, slice_mask, slice_mask_res, mask_aug]

    f, axarr = plt.subplots(1, 6, figsize=(20,20))
    for axis, i in zip(['CT-scan', 'CT Preprocessed', 'CT Augmented', 'Mask slice', 'Mask Normalized', 'Mask augmented'], range(6)):
        axarr[i].imshow(images[i], cmap='gray')
        axarr[i].set_xlabel(axis)
    plt.show()

    mask_values = np.unique(mask_aug)
    scan_min_check = 'OK' if np.min(scan_aug) >= 0 else 'NO'
    scan_max_check = 'OK' if (np.max(scan_aug) <= 1 and np.max(scan_aug) > 0) else 'NO'
    mask_min_check = 'OK' if np.min(mask_aug) == 0 else 'NO'
    mask_max_check = 'OK' if np.max(mask_aug) == 2  else 'NO'
    n_val_mask = 'OK' if len(mask_values) == 3 else 'NO'

    print('\nTest for DataLoader')
    print('Check:\nAugmented ct-scan:\n\t\t Expected min possible val: 0\t Actual min val: {}\t{}\n\t\t Expected max possible val: 1\t Actual max val: {:.2f}\t{}\n'.format(np.min(scan_aug), scan_min_check, np.max(scan_aug), scan_max_check))
    print('Check:\nAugmented mask:\n\t\t Expected min possible val: 0\t Actual min val: {}\t{}\n\t\t Expected max possible val: 2\t Actual max val: {}\t{}'.format(np.min(mask_aug), mask_min_check, np.max(mask_aug), mask_max_check))
    print('Check:\n\t\tPossible mask values: 3\t\t Actual val: {}\t{}'.format(mask_values,n_val_mask))
    
# Dataset sanity checks per dataset; the calls are commented out, so only the
# headers are printed. Uncomment them to inspect one processed sample each.
print('ZENODO')
#dataset_processed_test(zenodo_proc, zenodo_root)
print('\n\n\nCHALLENGE')
#dataset_processed_test(challenge_proc, challenge_root)

ZENODO



CHALLENGE


# Architectures

## TransUNet

### Configurations

In [14]:
def get_b16_config():
    """Return the ViT-B/16 configuration set up for segmentation (softmax head, n_classes outputs)."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 768
    config.transformer = ml_collections.ConfigDict(dict(
        mlp_dim=3072,
        num_heads=12,
        num_layers=12,
        attention_dropout_rate=0.0,
        dropout_rate=0.1,
    ))

    # Segmentation head and pretrained checkpoint
    config.classifier = 'seg'
    config.representation_size = None
    config.resnet_pretrained_path = None
    config.pretrained_path = 'models/vit_checkpoint/imagenet21k+imagenet2012/ViT-B_16.npz'
    config.patch_size = 16

    # Decoder / output
    config.decoder_channels = (256, 128, 64, 16)
    config.n_classes = n_classes
    config.activation = 'softmax'
    return config


def get_testing():
    """Return a tiny configuration (hidden size 1, one head, one layer) for quick tests."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 1
    config.transformer = ml_collections.ConfigDict(dict(
        mlp_dim=1,
        num_heads=1,
        num_layers=1,
        attention_dropout_rate=0.0,
        dropout_rate=0.1,
    ))
    config.classifier = 'token'
    config.representation_size = None
    return config

def get_r50_b16_config():
    """Return the hybrid ResNet + ViT-B/16 configuration (ResNet stages of 3, 4 and 9 units)."""
    config = get_b16_config()
    config.patches.grid = (16, 16)
    config.resnet = ml_collections.ConfigDict(dict(num_layers=(3, 4, 9), width_factor=1))

    # Segmentation head and hybrid checkpoint
    config.classifier = 'seg'
    config.pretrained_path = 'models/vit_checkpoint/imagenet21k+imagenet2012/R50+ViT-B_16.npz'

    # Decoder with skip connections from the ResNet features
    config.decoder_channels = (256, 128, 64, 16)
    config.skip_channels = [512, 256, 64, 16]
    config.n_classes = n_classes
    config.n_skip = 3
    config.activation = 'softmax'
    return config


def get_b32_config():
    """Return the ViT-B/32 configuration: ViT-B/16 with 32x32 patches and the B/32 checkpoint."""
    cfg = get_b16_config()
    cfg.patches.size = (32, 32)
    cfg.pretrained_path = 'models/vit_checkpoint/imagenet21k+imagenet2012/ViT-B_32.npz'
    return cfg


def get_l16_config():
    """Return the ViT-L/16 configuration set up for segmentation."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 1024
    config.transformer = ml_collections.ConfigDict(dict(
        mlp_dim=4096,
        num_heads=16,
        num_layers=24,
        attention_dropout_rate=0.0,
        dropout_rate=0.1,
    ))
    config.representation_size = None

    # custom: segmentation head, checkpoint and decoder
    config.classifier = 'seg'
    config.resnet_pretrained_path = None
    config.pretrained_path = 'models/vit_checkpoint/imagenet21k+imagenet2012/ViT-L_16.npz'
    config.decoder_channels = (256, 128, 64, 16)
    config.n_classes = n_classes
    config.activation = 'softmax'
    return config


def get_l32_config():
    """Returns the ViT-L/32 configuration (ViT-L/16 with 32x32 patches and the L/32 checkpoint)."""
    config = get_l16_config()
    config.patches.size = (32, 32)
    # The inherited ViT-L_16 checkpoint has a 16x16 patch-embedding kernel, which does
    # not match 32x32 patches; point at the matching L/32 weights (as get_b32_config does)
    config.pretrained_path = 'models/vit_checkpoint/imagenet21k+imagenet2012/ViT-L_32.npz'
    return config


def get_h14_config():
    """
    Returns the ViT-H/14 configuration.

    NOTE(review): unlike the B/L configs above, this one sets no segmentation fields
    (pretrained_path, decoder_channels, n_classes, activation) and uses the 'token'
    classifier; confirm before using it with the segmentation model.
    """
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (14, 14)})
    config.hidden_size = 1280
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 5120
    config.transformer.num_heads = 16
    config.transformer.num_layers = 32
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None

    return config

# Model name -> configuration, all built eagerly (in this order) when the cell runs
CONFIGS = {
    name: build()
    for name, build in (
        ('ViT-B_16', get_b16_config),
        ('ViT-B_32', get_b32_config),
        ('ViT-L_16', get_l16_config),
        ('ViT-L_32', get_l32_config),
        ('ViT-H_14', get_h14_config),
        ('R50+ViT-B_16', get_r50_b16_config),
        ('testing', get_testing),
    )
}

### Layers

In [15]:
def np2th(weights, conv=False):
    """
    Convert a numpy weight array to a torch tensor.

    When ``conv`` is True the array is treated as a conv kernel in HWIO layout and
    permuted to PyTorch's OIHW layout first.
    """
    array = weights.transpose([3, 2, 0, 1]) if conv else weights
    return torch.from_numpy(array)

def swish(x):
    """Swish activation: x * sigmoid(x)."""
    gate = torch.sigmoid(x)
    return x * gate


class StdConv2d(nn.Conv2d):
    """
    Conv2d with weight standardization: each output filter's weights are shifted
    to zero mean and divided by sqrt(biased variance + 1e-5) before convolving.
    """

    def forward(self, x):
        var, mean = torch.var_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
        standardized = (self.weight - mean) / torch.sqrt(var + 1e-5)
        return F.conv2d(
            x, standardized, self.bias,
            self.stride, self.padding, self.dilation, self.groups,
        )


def conv3x3(cin, cout, stride=1, groups=1, bias=False):
    """Weight-standardized 3x3 convolution with padding 1."""
    return StdConv2d(
        cin,
        cout,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=bias,
        groups=groups,
    )


def conv1x1(cin, cout, stride=1, bias=False):
    """Weight-standardized 1x1 convolution (no padding)."""
    return StdConv2d(
        cin,
        cout,
        kernel_size=1,
        stride=stride,
        padding=0,
        bias=bias,
    )


class PreActBottleneck(nn.Module):
    """
    Bottleneck residual block: 1x1 -> 3x3 -> 1x1 weight-standardized convolutions.

    Named "pre-activation (v2)", but as written each conv is followed by its
    GroupNorm (plus ReLU for the first two), and the sum with the shortcut goes
    through a final ReLU. A projected shortcut (1x1 conv + GroupNorm) is used
    whenever the stride or the channel count changes.
    """

    def __init__(self, cin, cout=None, cmid=None, stride=1):
        super().__init__()
        # Defaults: keep the channel count; bottleneck width is a quarter of it
        cout = cout or cin
        cmid = cmid or cout // 4

        self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv1 = conv1x1(cin, cmid, bias=False)
        self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6)
        # Stride is applied by the 3x3 conv (per the author's note, the original code has it on conv1)
        self.conv2 = conv3x3(cmid, cmid, stride, bias=False)
        self.gn3 = nn.GroupNorm(32, cout, eps=1e-6)
        self.conv3 = conv1x1(cmid, cout, bias=False)
        self.relu = nn.ReLU(inplace=True)

        needs_projection = stride != 1 or cin != cout
        if needs_projection:
            self.downsample = conv1x1(cin, cout, stride, bias=False)
            self.gn_proj = nn.GroupNorm(cout, cout)

    def forward(self, x):
        shortcut = self.gn_proj(self.downsample(x)) if hasattr(self, 'downsample') else x

        out = self.relu(self.gn1(self.conv1(x)))
        out = self.relu(self.gn2(self.conv2(out)))
        out = self.gn3(self.conv3(out))

        return self.relu(shortcut + out)

    def load_from(self, weights, n_block, n_unit):
        """
        Copy this unit's conv kernels and GroupNorm parameters from a checkpoint
        mapping keyed by '<n_block>/<n_unit>/<param>' paths.
        NOTE(review): copy_ into parameters needs a torch.no_grad() context,
        presumably provided by the caller; confirm.
        """
        def fetch(name, conv=False):
            return np2th(weights[pjoin(n_block, n_unit, name)], conv=conv)

        # Read everything first, then copy
        kernels = [fetch("conv{}/kernel".format(i), conv=True) for i in (1, 2, 3)]
        norms = [(fetch("gn{}/scale".format(i)), fetch("gn{}/bias".format(i))) for i in (1, 2, 3)]

        for conv, kernel in zip((self.conv1, self.conv2, self.conv3), kernels):
            conv.weight.copy_(kernel)
        for gn, (scale, bias) in zip((self.gn1, self.gn2, self.gn3), norms):
            gn.weight.copy_(scale.view(-1))
            gn.bias.copy_(bias.view(-1))

        if hasattr(self, 'downsample'):
            proj_kernel = fetch("conv_proj/kernel", conv=True)
            proj_scale = fetch("gn_proj/scale")
            proj_bias = fetch("gn_proj/bias")

            self.downsample.weight.copy_(proj_kernel)
            self.gn_proj.weight.copy_(proj_scale.view(-1))
            self.gn_proj.bias.copy_(proj_bias.view(-1))

class ResNetV2(nn.Module):
    """
    ResNet backbone built from PreActBottleneck units.

    A stride-2 7x7 stem ('root') is followed in forward by a stride-2 max-pool and
    three stages: block1 keeps the resolution, block2 and block3 each halve it.
    forward returns the last stage output and the skip features (root, block1,
    block2 outputs) ordered deepest-first.
    """

    def __init__(self, block_units, width_factor):
        super().__init__()
        width = int(64 * width_factor)
        self.width = width

        self.root = nn.Sequential(OrderedDict([
            ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)),
            ('gn', nn.GroupNorm(32, width, eps=1e-6)),
            ('relu', nn.ReLU(inplace=True)),
        ]))

        # (cin, cout, cmid, stride of the first unit) for each stage
        stage_specs = [
            (width, width * 4, width, 1),
            (width * 4, width * 8, width * 2, 2),
            (width * 8, width * 16, width * 4, 2),
        ]
        stages = []
        for stage_idx, (cin, cout, cmid, stride) in enumerate(stage_specs):
            units = [('unit1', PreActBottleneck(cin=cin, cout=cout, cmid=cmid, stride=stride))]
            units += [
                ('unit{:d}'.format(i), PreActBottleneck(cin=cout, cout=cout, cmid=cmid))
                for i in range(2, block_units[stage_idx] + 1)
            ]
            stages.append(('block{:d}'.format(stage_idx + 1), nn.Sequential(OrderedDict(units))))
        self.body = nn.Sequential(OrderedDict(stages))

    def forward(self, x):
        batch, _, in_size, _ = x.size()
        x = self.root(x)
        features = [x]
        x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x)

        for stage_idx in range(len(self.body) - 1):
            x = self.body[stage_idx](x)
            expected = int(in_size / 4 / (stage_idx + 1))
            current = x.size()[2]
            if current == expected:
                features.append(x)
                continue
            # Zero-pad (bottom/right) feature maps that came out 1-2 pixels short
            pad = expected - current
            assert 0 < pad < 3, "x {} should {}".format(x.size(), expected)
            padded = torch.zeros((batch, x.size()[1], expected, expected), device=x.device)
            padded[:, :, 0:current, 0:x.size()[3]] = x[:]
            features.append(padded)

        x = self.body[-1](x)
        return x, features[::-1]


class Attention(nn.Module):
    """
    Multi-head scaled dot-product self-attention over a (batch, tokens, hidden)
    sequence. When ``vis`` is True the attention probabilities (before dropout)
    are returned alongside the output.
    """

    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        self.out = Linear(config.hidden_size, config.hidden_size)
        # Both dropouts use the attention dropout rate
        dropout_rate = config.transformer["attention_dropout_rate"]
        self.attn_dropout = Dropout(dropout_rate)
        self.proj_dropout = Dropout(dropout_rate)

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        # (B, N, all_head_size) -> (B, heads, N, head_size)
        *leading, _ = x.size()
        x = x.view(*leading, self.num_attention_heads, self.attention_head_size)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        projected = [proj(hidden_states) for proj in (self.query, self.key, self.value)]
        q, k, v = [self.transpose_for_scores(p) for p in projected]

        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
        probs = self.softmax(scores)
        weights = probs if self.vis else None
        probs = self.attn_dropout(probs)

        # (B, heads, N, head_size) -> (B, N, all_head_size)
        context = torch.matmul(probs, v).permute(0, 2, 1, 3).contiguous()
        context = context.view(*context.size()[:-2], self.all_head_size)

        attention_output = self.proj_dropout(self.out(context))
        return attention_output, weights


class Mlp(nn.Module):
    """Transformer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = torch.nn.functional.gelu
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        # Xavier-uniform weights, then tiny (std 1e-6) normally distributed biases
        layers = (self.fc1, self.fc2)
        for layer in layers:
            nn.init.xavier_uniform_(layer.weight)
        for layer in layers:
            nn.init.normal_(layer.bias, std=1e-6)

    def forward(self, x):
        hidden = self.dropout(self.act_fn(self.fc1(x)))
        return self.dropout(self.fc2(hidden))


class Embeddings(nn.Module):
    """
    Construct the token embeddings: a strided conv patch embedding plus a learned
    position embedding. In hybrid mode (config.patches has a 'grid') the image is
    first passed through a ResNetV2 backbone, whose skip features are returned too.
    """
    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.config = config
        img_size = _pair(img_size)

        grid = config.patches.get("grid")
        self.hybrid = grid is not None

        if self.hybrid:
            # Patches tile the backbone's feature map so that `grid` patches cover the image
            patch_size = (img_size[0] // 16 // grid[0], img_size[1] // 16 // grid[1])
            patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
            n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])

        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        features = None
        if self.hybrid:
            x, features = self.hybrid_model(x)
        # (B, hidden, h, w) -> (B, n_patches, hidden)
        tokens = self.patch_embeddings(x).flatten(2).transpose(-1, -2)
        return self.dropout(tokens + self.position_embeddings), features


class Block(nn.Module):
    """
    Transformer encoder layer: LayerNorm -> Attention -> residual add, then
    LayerNorm -> Mlp -> residual add. Returns the output and the attention weights.
    """
    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        attn_out, weights = self.attn(self.attention_norm(x))
        x = attn_out + x
        x = self.ffn(self.ffn_norm(x)) + x
        return x, weights

    def load_from(self, weights, n_block):
        """Copy attention, MLP and LayerNorm parameters of encoder block `n_block` from a checkpoint mapping."""
        ROOT = f"Transformer/encoderblock_{n_block}"

        def attn_param(proj, kind):
            return np2th(weights[pjoin(ROOT, "MultiHeadDotProductAttention_1/" + proj, kind)])

        def mlp_param(dense, kind):
            return np2th(weights[pjoin(ROOT, "MlpBlock_3/" + dense, kind)])

        def norm_param(norm, kind):
            return np2th(weights[pjoin(ROOT, norm, kind)])

        with torch.no_grad():
            projections = ("query", "key", "value", "out")
            # Checkpoint kernels are stored input-major; transpose to Linear's (out, in)
            kernels = [attn_param(p, "kernel").view(self.hidden_size, self.hidden_size).t() for p in projections]
            biases = [attn_param(p, "bias").view(-1) for p in projections]

            for proj, kernel in zip(projections, kernels):
                getattr(self.attn, proj).weight.copy_(kernel)
            for proj, bias in zip(projections, biases):
                getattr(self.attn, proj).bias.copy_(bias)

            mlp_weight_0 = mlp_param("Dense_0", "kernel").t()
            mlp_weight_1 = mlp_param("Dense_1", "kernel").t()
            mlp_bias_0 = mlp_param("Dense_0", "bias").t()
            mlp_bias_1 = mlp_param("Dense_1", "bias").t()

            self.ffn.fc1.weight.copy_(mlp_weight_0)
            self.ffn.fc2.weight.copy_(mlp_weight_1)
            self.ffn.fc1.bias.copy_(mlp_bias_0)
            self.ffn.fc2.bias.copy_(mlp_bias_1)

            self.attention_norm.weight.copy_(norm_param("LayerNorm_0", "scale"))
            self.attention_norm.bias.copy_(norm_param("LayerNorm_0", "bias"))
            self.ffn_norm.weight.copy_(norm_param("LayerNorm_2", "scale"))
            self.ffn_norm.bias.copy_(norm_param("LayerNorm_2", "bias"))


class Encoder(nn.Module):
    """
    Transformer encoder: a stack of `config.transformer["num_layers"]` Blocks
    followed by a final LayerNorm.
    -----------
    Parameters:
        config: config object,
            network configuration (reads transformer["num_layers"] and hidden_size)
        vis: bool,
            if True, the attention weights returned by every Block are collected
    """
    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        depth = config.transformer["num_layers"]
        # Each Block is built fresh and then deep-copied (kept from the original implementation).
        self.layer = nn.ModuleList([copy.deepcopy(Block(config, vis)) for _ in range(depth)])
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)

    def forward(self, hidden_states):
        """
        Runs all Blocks sequentially and normalizes the result.
        Returns:
            (encoded hidden states, list of attention weights — empty unless vis)
        """
        collected = []
        for block in self.layer:
            hidden_states, block_weights = block(hidden_states)
            if self.vis:
                collected.append(block_weights)
        return self.encoder_norm(hidden_states), collected


class Transformer(nn.Module):
    """
    Embeddings (patch or hybrid CNN stem) followed by the transformer Encoder.
    -----------
    Parameters:
        config: config object,
            network configuration
        img_size: int,
            input spatial size forwarded to Embeddings
        vis: bool,
            forwarded to Encoder to collect attention weights
    """
    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        """
        Returns:
            (encoded tokens (B, n_patch, hidden), attention weights, skip features from Embeddings)
        """
        tokens, skip_features = self.embeddings(input_ids)
        encoded, attention = self.encoder(tokens)
        return encoded, attention, skip_features


class Conv2dReLU(nn.Sequential):
    """
    Conv2d -> BatchNorm2d -> ReLU block.
    -----------
    Parameters:
        in_channels: int,
            number of input channels
        out_channels: int,
            number of output channels
        kernel_size: int or tuple,
            convolution kernel size
        padding: int, Optional,
            convolution padding
        stride: int, Optional,
            convolution stride
        use_batchnorm: bool, Optional,
            if True a BatchNorm2d follows the convolution (and the conv has no bias);
            if False the normalization slot is an nn.Identity and the conv has a bias
    """
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )
        relu = nn.ReLU(inplace=True)

        # Previously BatchNorm was added unconditionally, ignoring use_batchnorm.
        # nn.Identity keeps the Sequential indices (0=conv, 1=norm, 2=relu) stable,
        # so state_dict keys for use_batchnorm=True are unchanged.
        bn = nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity()

        super(Conv2dReLU, self).__init__(conv, bn, relu)


class DecoderBlock(nn.Module):
    """
    Decoder stage: 2x bilinear upsampling, optional skip concatenation on the
    channel axis, then two 3x3 Conv2dReLU layers.
    -----------
    Parameters:
        in_channels: int,
            channels of the incoming feature map
        out_channels: int,
            channels produced by the block
        skip_channels: int, Optional,
            channels of the skip tensor concatenated after upsampling (0 = no skip)
        use_batchnorm: bool, Optional,
            forwarded to both Conv2dReLU layers
    """
    def __init__(
            self,
            in_channels,
            out_channels,
            skip_channels=0,
            use_batchnorm=True,
    ):
        super().__init__()
        shared = dict(kernel_size=3, padding=1, use_batchnorm=use_batchnorm)
        self.conv1 = Conv2dReLU(in_channels + skip_channels, out_channels, **shared)
        self.conv2 = Conv2dReLU(out_channels, out_channels, **shared)
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x, skip=None):
        upsampled = self.up(x)
        merged = upsampled if skip is None else torch.cat([upsampled, skip], dim=1)
        return self.conv2(self.conv1(merged))


class SegmentationHead(nn.Sequential):
    """
    Final classifier: a same-padding Conv2d producing one channel per class,
    followed by bilinear upsampling when `upsampling > 1` (nn.Identity otherwise).
    -----------
    Parameters:
        in_channels: int,
            input feature channels
        out_channels: int,
            number of output classes
        kernel_size: int, Optional,
            convolution kernel size (padding = kernel_size // 2)
        upsampling: int, Optional,
            upsampling scale factor
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        half = kernel_size // 2
        stages = [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=half)]
        if upsampling > 1:
            stages.append(nn.UpsamplingBilinear2d(scale_factor=upsampling))
        else:
            stages.append(nn.Identity())
        super().__init__(*stages)


class DecoderCup(nn.Module):
    """
    TransUNet decoder: reshapes the token sequence back into a square feature
    map, projects it to 512 channels and runs a cascade of DecoderBlocks, using
    the first `config.n_skip` encoder features as skip connections.
    -----------
    Parameters:
        config: config object,
            reads hidden_size, decoder_channels, n_skip and skip_channels
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        head_channels = 512
        self.conv_more = Conv2dReLU(
            config.hidden_size,
            head_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        decoder_channels = config.decoder_channels
        in_channels = [head_channels] + list(decoder_channels[:-1])
        out_channels = decoder_channels

        if self.config.n_skip != 0:
            # Work on a copy: zeroing the unused skips in place used to mutate the
            # (shared) config, corrupting any model built later from it.
            skip_channels = list(self.config.skip_channels)
            for i in range(4 - self.config.n_skip):  # re-select the skip channels according to n_skip
                skip_channels[3 - i] = 0
        else:
            skip_channels = [0, 0, 0, 0]

        blocks = [
            DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, hidden_states, features=None):
        """
        Parameters:
            hidden_states: Tensor,
                (B, n_patch, hidden) encoder output; n_patch is assumed to be a perfect square
            features: list of Tensor, Optional,
                encoder skip features; index i feeds decoder block i while i < n_skip
        Returns:
            The decoded feature map
        """
        B, n_patch, hidden = hidden_states.size()  # reshape from (B, n_patch, hidden) to (B, hidden, h, w)
        h = w = int(np.sqrt(n_patch))
        x = hidden_states.permute(0, 2, 1).contiguous().view(B, hidden, h, w)
        x = self.conv_more(x)
        for i, decoder_block in enumerate(self.blocks):
            use_skip = features is not None and i < self.config.n_skip
            x = decoder_block(x, skip=features[i] if use_skip else None)
        return x


class VisionTransformer(nn.Module):
    """
    TransUNet-style segmentation network: Transformer encoder (patch or hybrid
    embeddings), DecoderCup and SegmentationHead.
    -----------
    Parameters:
        config: config object,
            network configuration (reads classifier, decoder_channels, n_classes, ...)
        img_size: int, Optional,
            input spatial size passed to the embeddings
        num_classes: int, Optional,
            stored as an attribute; the head's output channels come from config['n_classes']
        zero_head: bool, Optional,
            stored as an attribute only (not used by this class)
        vis: bool, Optional,
            if True the encoder collects attention weights
    """
    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier
        self.transformer = Transformer(config, img_size, vis)
        self.decoder = DecoderCup(config)
        self.segmentation_head = SegmentationHead(
            in_channels=config['decoder_channels'][-1],
            out_channels=config['n_classes'],
            kernel_size=3,
        )
        self.config = config

    def forward(self, x):
        """
        Computes segmentation logits. Single-channel inputs are repeated to 3
        channels before the encoder.
        """
        if x.size()[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        x, attn_weights, features = self.transformer(x)  # (B, n_patch, hidden)
        x = self.decoder(x, features)
        logits = self.segmentation_head(x)
        return logits

    def load_from(self, weights):
        """
        Copies pretrained ViT / R50+ViT weights (a mapping such as the result of
        np.load on an .npz checkpoint) into the encoder. The position embeddings
        are copied directly, stripped of their first token, or bilinearly resized
        to the new patch grid when the grid size differs.
        -----------
        Raises:
            ValueError: if the position embeddings must be resized and the
                classifier is not "seg" (previously a NameError).
        """
        with torch.no_grad():
            embeddings = self.transformer.embeddings
            embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
            embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))

            encoder_norm = self.transformer.encoder.encoder_norm
            encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
            posemb_new = embeddings.position_embeddings
            if posemb.size() == posemb_new.size():
                embeddings.position_embeddings.copy_(posemb)
            elif posemb.size()[1] - 1 == posemb_new.size()[1]:
                # Same grid but the checkpoint has one extra leading token: drop it.
                embeddings.position_embeddings.copy_(posemb[:, 1:])
            else:
                if self.classifier != "seg":
                    raise ValueError(
                        "Cannot resize position embeddings for classifier {!r}; "
                        "only 'seg' is supported".format(self.classifier))
                ntok_new = posemb_new.size(1)
                # Drop the leading token and keep only the patch grid.
                posemb_grid = posemb[0, 1:]
                gs_old = int(np.sqrt(len(posemb_grid)))
                gs_new = int(np.sqrt(ntok_new))
                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)  # th2np
                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
                embeddings.position_embeddings.copy_(np2th(posemb_grid))

            # Encoder whole: every child unit implementing load_from gets its block index.
            for _, block in self.transformer.encoder.named_children():
                for uname, unit in block.named_children():
                    unit.load_from(weights, n_block=uname)

            if embeddings.hybrid:
                root = embeddings.hybrid_model.root
                root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True))
                root.gn.weight.copy_(np2th(weights["gn_root/scale"]).view(-1))
                root.gn.bias.copy_(np2th(weights["gn_root/bias"]).view(-1))

                for bname, block in embeddings.hybrid_model.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(weights, n_block=bname, n_unit=uname)

In [16]:
def init_transunet(vit_name, n_skip, vit_patches_size, pretrained=True):
    """
    Builds a TransUNet VisionTransformer from the CONFIGS entry `vit_name`.
    -----------
    Parameters:
        vit_name: str,
            key into CONFIGS; names containing 'R50' get a patch grid derived from target_shape
        n_skip: int,
            number of skip connections used by the decoder
        vit_patches_size: int,
            patch size used to compute the R50 grid
        pretrained: bool, Optional,
            if True load weights from config.pretrained_path
    Returns:
        The model, cast to float and moved to the global device
    """
    config_vit = CONFIGS[vit_name]
    config_vit.n_classes = n_classes
    config_vit.n_skip = n_skip
    if 'R50' in vit_name:
        grid_rows = int(target_shape[0] / vit_patches_size)
        grid_cols = int(target_shape[1] / vit_patches_size)
        config_vit.patches.grid = (grid_rows, grid_cols)
    network = VisionTransformer(config_vit, target_shape[0], num_classes=config_vit.n_classes)
    network = network.float().to(device)
    if pretrained:
        network.load_from(weights=np.load(config_vit.pretrained_path))
    print(config_vit)
    return network

## Init model

In [17]:
def count_trainable_parameters(model):
    """
    Counts the scalar parameters that require gradients.
    -----------
    Parameters:
        model: nn.Module
    Returns:
        Number of trainable scalar parameters
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def count_parameters(model):
    """
    Counts all scalar parameters of a model, trainable or not.
    -----------
    Parameters:
        model: nn.Module
    Returns:
        Total number of scalar parameters
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
  
def save_checkpoint(model, name):
    """
    Saves a model's state_dict to `model_save_path + <folder>/ + name`.
    -----------
    Parameters:
        model: nn.Module,
            model whose state_dict is saved
        name: str,
            file name; the folder is the part before '_best' or '_checkpoint',
            or the name without its extension when neither marker is present
            (previously that case raised UnboundLocalError)
    """
    if '_best' in name:
        folder = name.split("_best")[0]
    elif '_checkpoint' in name:
        folder = name.split("_checkpoint")[0]
    else:
        folder = os.path.splitext(name)[0]
    save_path = model_save_path + folder + '/'
    # exist_ok makes a separate isdir() check unnecessary.
    os.makedirs(save_path, exist_ok=True)
    torch.save(model.state_dict(), save_path + name)
    
def load_model(model_config):
    """
    Loads a model
    -----------
    Parameters:
        model_config: dict,
            dictionary with all the information needed about the model
            (model_name, version, eq_method, norm_method, lower, higher,
            batch_norm, attention, fine_tuning, freeze_encoder, prefix,
            weights_path, plus whatever init_model reads)
    Returns:
        The loaded model and its name
    """
    model_name = model_config['model_name'] + '_' + model_config['version']
    # Use the config's own normalization method (not the global setting), like
    # every other component of the reconstructed checkpoint name.
    if model_config['norm_method'] == 'in_range':
        model_name = (model_config['eq_method'] + '_' + model_config['norm_method']
                      + '[' + str(model_config['lower']) + ',' + str(model_config['higher']) + ']_'
                      + model_name)
    else:
        model_name = model_config['eq_method'] + '_' + model_config['norm_method'] + '_' + model_name
    if model_config['batch_norm'] == 'inplace':
        model_name = 'bn_inplace_' + model_name
    if model_config['attention'] == 'scse':
        model_name = 'scse_' + model_name
    if model_config['fine_tuning']:
        if not model_config['freeze_encoder']:
            model_name = 'all_trainable_finetuning_' + model_name
        else:
            model_name = 'encoder_freezed_finetuning_' + model_name
    model_name = model_config['prefix'] + model_name
    # Checkpoints live in a folder named after everything before '_<version>'.
    model_folder = model_name.split('_' + model_config['version'])[0] + '/'

    # Init network
    model, _ = init_model(model_config)
    # Load model
    print("\nloading checkpoint for {}..\n".format(model_name))
    model_dict = torch.load(model_config['weights_path'] + model_folder + model_name, map_location=torch.device(device))
    model.load_state_dict(model_dict)
    print("checkpoint loaded\n")
    return model, model_name

# Model
def init_model(model_config=None):
    """
    Build a model
    -----------
    Parameters:
        model_config: dict, Optional
            if None build the model using global variables defined at the beginning,
            otherwise it uses the model_config dictionary with all the information
            needed about the model
    Returns:
        The built model and its name (the name is None when model_config is given)
    Raises:
        ValueError: if the approach or decoder name is not recognized
    """
    if model_config is None:
        # Built for training using global variables
        if approach == 'transformer':
            model = init_transunet(vit_name, n_skip, vit_patches_size)
            model_name = model_prefix_name + criterion.__class__.__name__ + '_' + model.__class__.__name__
        elif approach == '2.5D':
            if decoder_name == 'DeepLabV3':
                model = smp.DeepLabV3(encoder_name=encoder_name, encoder_weights=encoder_weights, in_channels=1, classes=n_classes, activation=None, encoder_depth=5, decoder_channels=256)
            elif decoder_name == 'MAnet':
                model = smp.MAnet(encoder_name=encoder_name, encoder_weights=encoder_weights, in_channels=1, classes=n_classes, activation=None, encoder_depth=5, decoder_channels=[256, 128, 64, 32, 16])
            elif decoder_name == 'Unet':
                model = smp.Unet(encoder_name=encoder_name, encoder_weights=encoder_weights, in_channels=1, classes=n_classes, activation=None, encoder_depth=5, decoder_channels=[256, 128, 64, 32, 16],
                                 decoder_attention_type=attention, decoder_use_batchnorm=batch_norm)
            else:
                raise ValueError("Unknown decoder_name: {}".format(decoder_name))

            if encoder_weights is None:
                model_name = model_prefix_name + 'scratch_' + encoder_name + '_' + criterion.__class__.__name__ + '_' + model.__class__.__name__
            else:
                model_name = model_prefix_name + encoder_weights + '_' + encoder_name + '_' + criterion.__class__.__name__ + '_' + model.__class__.__name__
        else:
            raise ValueError("Unknown approach: {}".format(approach))
    else:
        if model_config['approach'] == 'transformer':
            model = init_transunet(model_config['vit_name'], model_config['n_skip'], model_config['vit_patches_size'])
        elif model_config['approach'] == '2.5D':
            # All decoders take the encoder from the config (previously DeepLabV3
            # and MAnet silently used the global encoder settings instead).
            cfg_encoder_name = model_config['encoder_name']
            cfg_encoder_weights = model_config['encoder_weights']
            cfg_decoder_name = model_config['decoder_name']
            if cfg_decoder_name == 'DeepLabV3':
                model = smp.DeepLabV3(encoder_name=cfg_encoder_name, encoder_weights=cfg_encoder_weights, in_channels=1, classes=n_classes, activation=None, encoder_depth=5, decoder_channels=256)
            elif cfg_decoder_name == 'MAnet':
                model = smp.MAnet(encoder_name=cfg_encoder_name, encoder_weights=cfg_encoder_weights, in_channels=1, classes=n_classes, activation=None, encoder_depth=5, decoder_channels=[256, 128, 64, 32, 16])
            elif cfg_decoder_name == 'Unet':
                model = smp.Unet(encoder_name=cfg_encoder_name, encoder_weights=cfg_encoder_weights, in_channels=1, classes=n_classes, activation=None, encoder_depth=5, decoder_channels=[256, 128, 64, 32, 16],
                                 decoder_attention_type=model_config['attention'], decoder_use_batchnorm=model_config['batch_norm'])
            else:
                raise ValueError("Unknown decoder_name: {}".format(cfg_decoder_name))
        else:
            raise ValueError("Unknown approach: {}".format(model_config['approach']))
        model_name = None
    print("Model built")
    total_params = count_parameters(model)
    trainable_params = count_trainable_parameters(model)
    print('\nTotal params: {}\nTrainable params: {}\nNon-Trainable params: {}'.format(total_params, trainable_params, (total_params - trainable_params)))

    model.float().to(device)
    return model, model_name

# Losses

In [18]:
class Sobel(nn.Module):
    """
    Mean Absolute Sobel Error class

    For both prediction and label, computes the horizontal and vertical Sobel
    responses, spreads each with a 3x3 box filter, clamps to [0, 1] and returns
    the mean absolute difference per direction. The edge convolution has a
    single input channel, so inputs are expected to be (B, 1, H, W).
    All weights are frozen (requires_grad=False).
    """
    def __init__(self):
        super(Sobel, self).__init__()
        self.edge_conv = nn.Conv2d(1, 2, kernel_size=3, stride=1, padding=1, bias=False)
        edge_kx = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]])
        edge_ky = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])
        edge_k = np.stack((edge_kx, edge_ky))
        edge_k = torch.from_numpy(edge_k).float().view(2, 1, 3, 3)
        self.edge_conv.weight = nn.Parameter(edge_k)

        self.kernel = np.ones((3, 3), dtype=np.float32)
        # Registered as a non-persistent buffer so it follows .to()/.cuda() with the
        # module (previously it was pinned to the global `device` at construction)
        # without adding a key to state_dict.
        self.register_buffer(
            'kernel_tensor',
            torch.from_numpy(self.kernel[None, None].copy()),  # size: (1, 1, 3, 3)
            persistent=False,
        )

        for param in self.parameters():
            param.requires_grad = False

    def forward(self, inputs, labels):
        """
        Returns:
            (mean |dx_inputs - dx_labels|, mean |dy_inputs - dy_labels|)
        """
        inputs_x, inputs_y = self._spread_edges(inputs)
        labels_x, labels_y = self._spread_edges(labels)
        return torch.abs(inputs_x - labels_x).mean(), torch.abs(inputs_y - labels_y).mean()

    def _spread_edges(self, x):
        """Sobel x/y responses, each box-filtered and clamped to [0, 1]."""
        sobel = self.compute_sobel(x)
        maps = []
        for channel in range(2):
            spread = torch.nn.functional.conv2d(sobel[:, channel:channel + 1], self.kernel_tensor, padding=(1, 1))
            maps.append(torch.clamp(spread, 0, 1))
        return maps

    def compute_sobel(self, x):
        out = self.edge_conv(x)
        out = out.contiguous().view(-1, 2, x.size(2), x.size(3))
        return out


def one_hot2D(labels, num_classes, device=None, dtype=None, eps=1e-6):
    """
    Convert BxHxW image tensor to one hot encoding image tensor with shape BxNxHxW with
    B: batch size, H: height, W: width and N: number of classes
    -----------
    Parameters:
        labels: Tensor,
            BxHxW int64 Tensor with the ground truth labels
        num_classes: int,
            total number of classes (at least one)
        device: torch.device, Optional,
            device of the returned tensor
        dtype: torch.dtype, Optional,
            dtype of the zero tensor before `eps` is added
        eps: float, Optional,
            constant added to every element of the one-hot tensor
    Returns:
        the one hot encoded tensor (plus eps everywhere)
    Raises:
        TypeError / ValueError on invalid inputs
    """
    if not torch.is_tensor(labels):
        raise TypeError("Input labels type is not a torch.Tensor. Got {}"
                        .format(type(labels)))
    if not len(labels.shape) == 3:
        raise ValueError("Invalid depth shape, we expect BxHxW. Got: {}"
                         .format(labels.shape))
    if not labels.dtype == torch.int64:
        raise ValueError(
            "labels must be of the same dtype torch.int64. Got: {}".format(
                labels.dtype))
    if num_classes < 1:
        # The check accepts num_classes == 1, so the message says "at least one".
        raise ValueError("The number of classes must be at least one."
                         " Got: {}".format(num_classes))
    batch_size, height, width = labels.shape
    one_hot = torch.zeros(batch_size, num_classes, height, width,
                          device=device, dtype=dtype)
    return one_hot.scatter_(1, labels.unsqueeze(1), 1.0) + eps


class DiceLoss2D(nn.Module):
    """
    Dice Loss on 2D image tensor
    -----------
    Parameters:
     weighted: bool,
         Flag to apply or not weights to the labels (1 / class_count**2)
    """
    def __init__(self, weighted) -> None:
        super(DiceLoss2D, self).__init__()
        self.eps: float = 1e-6
        self.weighted = weighted
        # NOTE(review): this renames the class object itself (shared by all
        # instances); callers read criterion.__class__.__name__ for model names.
        if self.weighted:
            self.__class__.__name__ = 'WeightedDiceLoss'
        else:
            self.__class__.__name__ = 'DiceLoss'

    def forward(self, input, target):
        """
        Computes dice loss
        -----------
        Parameters:
            input: Tensor,
                BxNxHxW predictions tensor (logits; softmax is applied here)
            target: Tensor,
                BxHxW int64 ground truth masks tensor
        Returns:
            The dice loss value
        """
        if not torch.is_tensor(input):
            raise TypeError("Input type is not a torch.Tensor. Got {}"
                            .format(type(input)))
        if not len(input.shape) == 4:
            raise ValueError("Invalid input shape, we expect BxNxHxW. Got: {}"
                             .format(input.shape))
        if not input.shape[-2:] == target.shape[-2:]:
            raise ValueError("input and target shapes must be the same. Got: {} and {}"
                             .format(input.shape, target.shape))
        if not input.device == target.device:
            raise ValueError(
                "input and target must be in the same device. Got: {} and {}".format(
                    input.device, target.device))
        # Number of classes comes from the logits, not from a global setting.
        num_classes = input.shape[1]

        # compute softmax over the classes axis
        input_soft = torch.nn.functional.softmax(input, dim=1)
        input_soft = input_soft.view(input_soft.shape[0], num_classes, -1)  # (B, Nclasses, HxW)

        # create the labels one hot tensor
        target_one_hot = one_hot2D(target, num_classes=num_classes,
                                   device=input.device, dtype=input.dtype)
        target_one_hot = target_one_hot.view(target_one_hot.shape[0], num_classes, -1)  # (B, Nclasses, HxW)

        # compute the actual dice score factors
        intersection = torch.sum(input_soft * target_one_hot, 2)  # (B, Nclasses)
        cardinality = torch.sum(input_soft + target_one_hot, 2)  # (B, Nclasses)

        if self.weighted:
            # weight each class by the inverse squared pixel count
            counts = torch.sum(target_one_hot, dim=2)
            weights = torch.as_tensor(1. / (counts ** 2), dtype=float)
            weights = torch.where(torch.isfinite(weights), weights, torch.full_like(weights, self.eps))  # (B, Nclasses)
            # apply weights
            intersection = torch.sum(weights * intersection, axis=-1)
            cardinality = torch.sum(weights * cardinality, axis=-1)
            # compute dice score
            dice_score = 1 - 2. * intersection / cardinality
        else:
            # compute dice score
            dice_score = 1 - 2. * intersection / (cardinality + self.eps)
        dice_score = torch.where(torch.isfinite(dice_score), dice_score, torch.zeros_like(dice_score))

        return torch.mean(dice_score)

class DiceLoss2D_CE(nn.Module):
    """
    Dice Loss on 2D image tensor with Cross Entropy Loss
    -----------
    Parameters:
     weighted: bool,
         Flag to apply or not weights to the labels in the dice term (1 / class_count**2)
    """
    def __init__(self, weighted) -> None:
        super(DiceLoss2D_CE, self).__init__()
        self.eps: float = 1e-6
        self.weighted = weighted
        self.CE = nn.CrossEntropyLoss()
        # NOTE(review): this renames the class object itself (shared by all
        # instances); callers read criterion.__class__.__name__ for model names.
        if self.weighted:
            self.__class__.__name__ = 'WeightedDiceLoss+CE'
        else:
            self.__class__.__name__ = 'DiceLoss+CE'

    def forward(self, input, target):
        """
        Computes dice loss with cross entropy loss
        -----------
        Parameters:
            input: Tensor,
                BxNxHxW predictions tensor (logits)
            target: Tensor,
                BxHxW int64 ground truth masks tensor
        Returns:
            The dice loss + cross entropy loss value (both terms are differentiable)
        """
        if not torch.is_tensor(input):
            raise TypeError("Input type is not a torch.Tensor. Got {}"
                            .format(type(input)))
        if not len(input.shape) == 4:
            raise ValueError("Invalid input shape, we expect BxNxHxW. Got: {}"
                             .format(input.shape))
        if not input.shape[-2:] == target.shape[-2:]:
            raise ValueError("input and target shapes must be the same. Got: {} and {}"
                             .format(input.shape, target.shape))
        if not input.device == target.device:
            raise ValueError(
                "input and target must be in the same device. Got: {} and {}".format(
                    input.device, target.device))
        # Number of classes comes from the logits, not from a global setting.
        num_classes = input.shape[1]

        # compute softmax over the classes axis
        input_soft = torch.nn.functional.softmax(input, dim=1)
        input_soft = input_soft.view(input_soft.shape[0], num_classes, -1)  # (B, Nclasses, HxW)

        # create the labels one hot tensor
        target_one_hot = one_hot2D(target, num_classes=num_classes,
                                   device=input.device, dtype=input.dtype)
        target_one_hot = target_one_hot.view(target_one_hot.shape[0], num_classes, -1)  # (B, Nclasses, HxW)

        # compute the actual dice score factors
        intersection = torch.sum(input_soft * target_one_hot, 2)  # (B, Nclasses)
        cardinality = torch.sum(input_soft + target_one_hot, 2)  # (B, Nclasses)

        if self.weighted:
            # weight each class by the inverse squared pixel count
            counts = torch.sum(target_one_hot, dim=2)
            weights = torch.as_tensor(1. / (counts ** 2), dtype=float)
            weights = torch.where(torch.isfinite(weights), weights, torch.full_like(weights, self.eps))  # (B, Nclasses)
            # apply weights
            intersection = torch.sum(weights * intersection, axis=-1)
            cardinality = torch.sum(weights * cardinality, axis=-1)
            # compute dice score
            dice_score = 1 - 2. * intersection / cardinality
        else:
            # compute dice score
            dice_score = 1 - 2. * intersection / (cardinality + self.eps)
        dice_score = torch.where(torch.isfinite(dice_score), dice_score, torch.zeros_like(dice_score))

        # Keep CE as a tensor: calling .item() here detached it from the graph,
        # so the cross-entropy term contributed no gradient.
        return torch.mean(dice_score) + self.CE(input, target)
    
    
class CE_loss(nn.Module):
    """
    Cross entropy Loss wrapper around nn.CrossEntropyLoss.
    -----------
    Parameters:
        weighted: bool,
            if True the per-class `weight` tensor is passed to nn.CrossEntropyLoss
        weight: Tensor, Optional,
            tensor of shape (n_classes,) with the class weights
    """
    def __init__(self, weighted, weight=None) -> None:
        super(CE_loss, self).__init__()
        if weighted:
            self.__class__.__name__ = 'WeightedCrossEntropyLoss'
            self.CE = nn.CrossEntropyLoss(weight)
        else:
            self.__class__.__name__ = 'CrossEntropyLoss'
            self.CE = nn.CrossEntropyLoss()

    def forward(self, input, target):
        """
        Returns the cross entropy between `input` logits and `target` class indices.
        """
        criterion = self.CE
        return criterion(input, target)
    
class FocalLoss(nn.modules.loss._WeightedLoss):
    """
    Focal Loss: per-element cross entropy modulated by (1 - p_t) ** gamma and,
    optionally, by a per-class weight (alpha).
    -----------
    Parameters:
        weight: Tensor, Optional
            tensor with number of classes as shape and containing their weights (alpha)
        gamma: float, Optional
            gamma coefficient
        reduction: str, Optional,
            'sum' sums the per-element losses; any other value ('mean', 'none')
            returns their mean (the loss is always a scalar, as before)
    """
    def __init__(self, weight=None, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__(weight, reduction=reduction)
        self.gamma = gamma
        self.weight = weight  # weight parameter will act as the alpha parameter to balance class weights

    def forward(self, input, target):
        """
        Computes focal loss
        -----------
        Parameters:
            input: Tensor,
                predictions tensor (logits, classes on dim 1)
            target: Tensor,
                ground truth class-index tensor
        Returns:
            The focal loss value
        """
        # The modulating factor must be applied per element: previously the CE was
        # reduced first, so p_t was exp(-mean CE) and the loss was not focal at all.
        ce_loss = F.cross_entropy(input, target, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        if self.weight is not None:
            focal_loss = self.weight[target] * focal_loss
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss.mean()
    
def compute_loss(outputs, masks, criterion):
    """
    Function to handle different losses computation.
    -----------
    Parameters:
        outputs: Tensor,
            prediction tensor (logits)
        masks: Tensor,
            ground truth masks tensor; dim 1 is squeezed and values cast to int64
        criterion: loss object,
            the loss function class
    Returns:
        The loss value
    """
    targets = masks.squeeze(1).long()
    if loss_name != 'BCElogits':
        # Every other loss takes raw float logits and class indices.
        return criterion(outputs.float(), targets)
    # BCE: compare softmax probabilities with the (eps-smoothed) one-hot targets.
    targets_one_hot = one_hot2D(targets, n_classes, device, targets.dtype)
    probabilities = torch.nn.functional.softmax(outputs, dim=1).float()
    return criterion(probabilities, targets_one_hot)

# Training

## Aux functions

In [19]:
def pixel_accuracy(pred, mask):
    """
    Computes pixel accuracy: the fraction of pixels whose arg-max class
    matches the ground truth.
    -----------
    Parameters:
        pred: Tensor,
            predictions tensor (classes on dim 1)
        mask: Tensor,
            ground truth mask tensor
    Returns:
        The pixel accuracy value as a Python float
    """
    with torch.no_grad():
        probabilities = torch.nn.functional.softmax(pred, dim=1)
        predicted = probabilities.argmax(dim=1)
        hits = (predicted == mask).int()
        n_hits = float(hits.sum())
        n_total = float(hits.numel())
    return n_hits / n_total


def mIoU(pred, mask, smooth=1e-10, n_classes=3):
    """
    Computes mean Intersection over Union across classes.
    Classes absent from the ground truth are skipped (NaN, ignored by nanmean).
    -----------
    Parameters:
        pred: Tensor,
            predictions tensor (classes on dim 1)
        mask: Tensor,
            ground truth mask tensor
        smooth: float, Optional,
            smoothing factor added to intersection and union
        n_classes: int, Optional,
            number of classes
    Returns:
        The mean IoU value
    """
    with torch.no_grad():
        predicted = torch.argmax(torch.nn.functional.softmax(pred, dim=1), dim=1).contiguous().view(-1)
        labels = mask.contiguous().view(-1)

        def _class_iou(cls):
            in_pred = predicted == cls
            in_label = labels == cls
            if in_label.long().sum().item() == 0:
                return np.nan
            inter = torch.logical_and(in_pred, in_label).sum().float().item()
            union = torch.logical_or(in_pred, in_label).sum().float().item()
            return (inter + smooth) / (union + smooth)

        scores = [_class_iou(cls) for cls in range(n_classes)]
        return np.nanmean(scores)


def get_lr(optimizer):
    """
    Gets the learning rate of the optimizer's first parameter group
    -----------
    Parameters:
        optimizer: optimizer class
    Returns:
        The learning rate, or None if the optimizer has no parameter groups
    """
    return next((group['lr'] for group in optimizer.param_groups), None)

def slice_image(img, axis, slice_n):
    """
    Slices a numpy array image
    -----------
    Parameters:
        img: numpy array,
            the image to be sliced
        axis: str,
            'x', 'y' or 'z' — index along dimension 0, 1 or 2 respectively
        slice_n: int,
            index of the desired slice (converted with int())
    Returns:
        The 2D image slice, or None for an unrecognized axis
    """
    for name, position in (('x', 0), ('y', 1), ('z', 2)):
        if axis == name:
            selector = [slice(None)] * 3
            selector[position] = int(slice_n)
            return img[tuple(selector)]
    return None
    
            
def show_prediction_example(model, device, title='', filename='coronacases_009', slice_n=256):
    """
    Shows prediction example: input slice, ground truth, prediction and their
    absolute difference for one validation patient.
    -----------
    Parameters:
        model: model object,
            model used to predict
        device: str,
            device on which perform the computations
        title: str, Optional,
            title for the final plot
        filename: str, Optional
            name of the patient to be used as example
        slice_n: int, Optional,
            index of the desired z slice (previously ignored: 256 was hard-coded)
    Raises:
        ValueError: if no validation image path contains `filename`
    """
    validation_images = zenodo_proc['image']['validation']
    idx = next((i for i, s in enumerate(validation_images) if filename in s), None)
    if idx is None:
        raise ValueError("No validation image matches {!r}".format(filename))
    slice_file = '/z/{}.npy'.format(int(slice_n))

    test_ct_scan = np.load(validation_images[idx] + slice_file)
    test_ct_scan = norm_and_eq(test_ct_scan)

    test_gt = np.load(zenodo_proc['mask']['validation'][idx] + slice_file)

    test = np.zeros((1, 1, target_shape[0], target_shape[1]), dtype='float32')
    test[0, :, :, :] = test_ct_scan
    test = torch.from_numpy(test).to(device)
    # Predict
    model.eval()
    with torch.no_grad():
        test_pred = model(test)
    test_pred = torch.argmax(torch.nn.functional.softmax(test_pred, dim=1), dim=1).detach().cpu().numpy().squeeze(0)

    # Plot
    f, axarr = plt.subplots(1, 4, figsize=(9, 9))
    f.suptitle(title, y=0.62)
    axarr[0].imshow(test_ct_scan, cmap='gray')
    axarr[0].set_xlabel('input')
    axarr[1].imshow(test_gt, cmap='gray', vmin=0, vmax=2)
    axarr[1].set_xlabel('gt')
    axarr[2].imshow(test_pred, cmap='gray', vmin=0, vmax=2)
    axarr[2].set_xlabel('prediction')
    axarr[3].imshow(np.abs(np.subtract(test_gt, test_pred)), cmap='gray', vmin=0, vmax=2)
    axarr[3].set_xlabel('difference')
    plt.show()

def save_history(filepath, history):
    """
    Saves model history as a pickle file
    -----------
    Parameters:
        filepath: str,
            path of the file to write, without extension ('.pkl' is appended)
        history: dict,
            model history
    """
    # `with` closes the file even if pickling fails (the old open/close leaked it).
    with open(filepath + '.pkl', "wb") as tmp_file:
        pickle.dump(history, tmp_file)

def load_history(filepath):
    """
    Loads model history from a pickle file
    -----------
    Parameters:
        filepath: str,
            full path of the file to read (trusted: it is unpickled)
    Returns:
        The history dictionary
    """
    with open(filepath, 'rb') as history_file:
        return pickle.load(history_file)

def plot_graph(f, g, f_label, g_label, title):
    """
    Draws two curves on a shared epoch axis and shows the figure
    -----------
    Parameters:
        f: list of float,
            y-values of the first curve (also defines the number of epochs on the x-axis)
        g: list of float,
            y-values of the second curve
        f_label: str,
            legend entry for the first curve
        g_label: str,
            legend entry for the second curve
        title: str,
            title of the figure
    """
    x_axis = range(len(f))
    curves = ((f, 'b', f_label), (g, 'orange', g_label))
    for values, colour, name in curves:
        plt.plot(x_axis, values, colour, label=name)
    plt.title(title)
    plt.xlabel('Epochs')
    plt.legend()
    plt.show()

def plot_history(history):
    """
    Shows one training-vs-validation figure each for loss, pixel accuracy and mIoU
    -----------
    Parameters:
        history: dict,
            history dictionary with 'train_*' / 'val_*' lists for loss, acc and miou
    """
    panels = (
        ('train_loss', 'val_loss', 'Training loss', 'Validation loss', 'Training and Validation loss'),
        ('train_acc', 'val_acc', 'Training acc', 'Validation acc', 'Training and Validation pixel accuracy'),
        ('train_miou', 'val_miou', 'Training mIoU', 'Validation mIoU', 'Training and Validation mIoU'),
    )
    for train_key, val_key, train_label, val_label, panel_title in panels:
        plot_graph(history[train_key], history[val_key], train_label, val_label, panel_title)
    
# Optimizer
def init_optimizer(model, name=None, learning_rate=None, decay=None):
    """
    Initializes the optimizer.
    -----------
    Parameters:
        model: model class,
            the model to be trained
        name: str, Optional,
            'ADAM', 'ADAMW' or 'SGD'; defaults to the global optimizer_name setting
        learning_rate: float, Optional,
            learning rate; defaults to the global lr setting
        decay: float, Optional,
            weight decay; defaults to the global weight_decay setting
    Returns:
        The torch optimizer over model.parameters()
    Raises:
        ValueError: if the optimizer name is not supported (previously None was returned,
            which only failed later inside the scheduler/training loop)
    """
    # Fall back to the notebook-level settings only when an argument is not given
    name = optimizer_name if name is None else name
    learning_rate = lr if learning_rate is None else learning_rate
    decay = weight_decay if decay is None else decay

    if name == 'ADAM':
        return torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=decay)
    if name == 'ADAMW':
        return torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=decay)
    if name == 'SGD':
        return torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=decay)
    raise ValueError("Unsupported optimizer '{}': expected 'ADAM', 'ADAMW' or 'SGD'".format(name))

## Training Config

In [20]:
#--------------------------------------------Configuration Dicts---------------------------------------------------------
# A loss counts as "weighted" when its name mentions weighted Dice (WDL) or weighted cross-entropy (WCE)
weighted = 'WDL' in loss_name or 'WCE' in loss_name

# Per-class weights used by the weighted cross-entropy entry
wce_class_weights = torch.as_tensor([0.15, 0.35, 0.50]).to(device)

# Losses config dict: maps the loss_name setting to an instantiated criterion
LOSSES = {
    'WDL': DiceLoss2D(weighted),
    'DL': DiceLoss2D(weighted),
    'CE': CE_loss(weighted),
    'BCElogits': nn.BCEWithLogitsLoss(),
    'DL+CE': DiceLoss2D_CE(weighted),
    'WDL+CE': DiceLoss2D_CE(weighted),
    'WCE': CE_loss(weighted, weight=wce_class_weights),
    'FOCAL': FocalLoss(gamma=5),
    'Sobel': Sobel(),
}

# Datasets config dict: maps the dataset_name setting to one processed dataset or a list of them
DATASETS = {
    'zenodo': zenodo_proc,
    'challenge': challenge_proc,
    'zenodo+challenge': [zenodo_proc, challenge_proc],
}

## Training

In [21]:
# Do train if flag is True
if do_train:
    torch.cuda.empty_cache()

    # Per-epoch statistics collected during training
    history = {'train_loss': [], 'val_loss': [],
               'train_miou': [], 'val_miou': [],
               'train_acc': [], 'val_acc': [],
               'lrs': [lr]}
    min_loss = float('inf')
    max_iou = float('-inf')

    # Data Loaders
    chosen_dataset = DATASETS[dataset_name]
    training_DataLoader, test_DataLoader = init_train_test_loader(chosen_dataset, 'training',
                                                                  chosen_dataset, 'validation',
                                                                  num_workers=4, size_valid=512*3*1)

    # Loss
    criterion = LOSSES[loss_name]

    # `sobel_loss` starts as a boolean setting and is replaced by the Sobel module when enabled
    if sobel_loss:
        sobel_loss = LOSSES['Sobel'].to(device)

    # Model
    model, model_name = init_model()

    # Optimizer
    optimizer = init_optimizer(model)

    # Scheduler. OneCycleLR is sized for n_epochs * len(training_DataLoader) steps, so it is
    # stepped once per batch in the training loop below (stepping per epoch never completes the cycle).
    if use_scheduler:
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, lr, epochs=n_epochs, steps_per_epoch=len(training_DataLoader))

    print("Training: {}\n".format(model_name))
    time.sleep(0.3)
    for epoch in range(n_epochs):
        # Validation and show_prediction_example() leave the model in eval mode at the end of
        # each epoch, so training mode (dropout, batch-norm statistics) must be re-enabled here
        model.train()
        running_loss = 0
        iou_score = 0
        accuracy = 0
        with tqdm(training_DataLoader, unit="step", position=0, leave=True) as tepoch:
            tepoch.set_description(f"Epoch {epoch+1}/{n_epochs} - Training")
            for step, batch in enumerate(tepoch, start=1):
                # Load data
                inputs, masks = batch[0].to(device), batch[1].to(device)
                # Forward
                outputs = model(inputs.float())
                # Compute loss
                loss = compute_loss(outputs, masks, criterion)
                # Sobel
                if sobel_loss:
                    sobel_x, sobel_y = sobel_loss(torch.argmax(outputs, dim=1).unsqueeze(1).float(), masks.float())
                    # Update loss with Sobel
                    loss += sobel_x + sobel_y
                # Update loss for stats
                running_loss += loss.item()
                # Evaluation metrics
                iou_score += mIoU(outputs, masks)
                accuracy += pixel_accuracy(outputs, masks)
                # Backward
                loss.backward()
                optimizer.step() # Update weight
                optimizer.zero_grad() # Empty gradient
                if use_scheduler:
                    scheduler.step() # OneCycleLR advances once per batch
                tepoch.set_postfix({'Loss':running_loss/step,'Acc':accuracy/step,'IoU':iou_score/step})

        # Validation
        model.eval()
        test_loss = 0
        test_accuracy = 0
        test_iou_score = 0
        with tqdm(test_DataLoader, unit="step", position=0, leave=True) as tepoch:
            tepoch.set_description(f"Epoch {epoch+1}/{n_epochs} - Validation")
            for step, batch in enumerate(tepoch, start=1):
                inputs, masks = batch[0].to(device), batch[1].to(device)
                # Validation loop
                with torch.no_grad():
                    outputs = model(inputs.float())
                    # Evaluation metrics
                    test_iou_score += mIoU(outputs, masks)
                    test_accuracy += pixel_accuracy(outputs, masks)
                    # Loss
                    loss = compute_loss(outputs, masks, criterion)
                    # Sobel
                    if sobel_loss:
                        sobel_x, sobel_y = sobel_loss(torch.argmax(outputs, dim=1).unsqueeze(1).float(), masks.float())
                        # Update loss with Sobel
                        loss += sobel_x + sobel_y
                    # Update loss for stats
                    test_loss += loss.item()
                tepoch.set_postfix({'Loss':test_loss/step,'Acc':test_accuracy/step,'IoU':test_iou_score/step})

        # Learning rate reached at the end of this epoch (the scheduler is stepped per batch above)
        history['lrs'].append(get_lr(optimizer))

        #calculation mean for each batch
        history['train_loss'].append(running_loss/len(training_DataLoader))
        history['val_loss'].append(test_loss/len(test_DataLoader))

        #iou
        history['val_miou'].append(test_iou_score/len(test_DataLoader))
        history['train_miou'].append(iou_score/len(training_DataLoader))
        history['train_acc'].append(accuracy/len(training_DataLoader))
        history['val_acc'].append(test_accuracy/len(test_DataLoader))

        # Save model
        '''
        # Save by best loss
        if min_loss >= test_loss/len(test_DataLoader):
            min_loss = test_loss/len(test_DataLoader)
            if save_training:
                save_checkpoint(model, model_name + '_best')
                print('New best Val Loss: {:.6f} at epoch {}'.format(min_loss, epoch+1))
        elif save_training:
            save_checkpoint(model, model_name + '_checkpoint')
        '''
        # Save by best mIoU
        if max_iou <= test_iou_score/len(test_DataLoader):
            max_iou = test_iou_score/len(test_DataLoader)
            if save_training:
                save_checkpoint(model, model_name + '_best')
                print('New best Val IoU: {:.6f} at epoch {}'.format(max_iou, epoch+1))
        elif save_training:
            save_checkpoint(model, model_name + '_checkpoint')

        show_prediction_example(model, device, 'Result for epoch: '+str(epoch+1), 'coronacases_009')

        if save_training:
            save_history(model_save_path + model_name + '/' + model_name + '_history', history)

    print('Finished Training')
    plot_history(history)

## Fine Tuning

In [22]:
# Do fine tuning if flag is True
if do_fine_tuning:
    model_config = {
        'model_name': '',  # name of the model without prefix and processing to be loaded
        'approach': approach,  # 2.5D, transformer
        'version': 'checkpoint', # By default uses the saved checkpoint
        'fine_tuning': False, # True if the model has been already fine tuned once
        'criterion': LOSSES[loss_name].__class__.__name__ , # Loss function
        'vit_name': vit_name, # Vit model name, used only if approach:'transformer'
        'n_skip': n_skip, # Vit n of skip, used only if approach:'transformer'
        'vit_patches_size': vit_patches_size, # vit size of patches, used only if approach:'transformer'
        'encoder_name': encoder_name, # name of the encoder, used only if approach:'2.5D'
        'encoder_weights': encoder_weights, # weights for the encoder, used only if approach:'2.5D'
        'decoder_name': 'Unet', # name of the decoder, used only if approach:'2.5D'
        'attention': None, # uses or not the attention layers, used only if approach:'2.5D'
        'batch_norm': True, # batch_normalization technique, used only if approach:'2.5D'
        'weights_path': model_save_path + '', # path to weights to be fine tuned
        'norm_method': norm_method, # normalization method
        'lower': lower, # lower bound for normalization, used only if norm_method:'in_range'
        'higher': higher, # upper bound for normalization, used only if norm_method:'in_range'
        'eq_method': eq_method, # equalization method
        'prefix': '' # custom prefix of the model to be loaded
    }

    # Stores the previous normalization and equalization methods and replaces them with those
    # defined in the model config, so the model is fine-tuned with the preprocessing it was trained on
    prev_eq_method = eq_method
    prev_norm_method = norm_method
    prev_lower, prev_higher = lower, higher

    eq_method = model_config['eq_method']
    norm_method = model_config['norm_method']
    lower, higher = model_config['lower'], model_config['higher']

    # The finally clause restores the global preprocessing settings even if fine-tuning fails
    # or is interrupted, so later cells never run with the temporary ones
    try:
        torch.cuda.empty_cache()

        # Per-epoch statistics collected during fine-tuning
        history = {'train_loss': [], 'val_loss': [],
                   'train_miou': [], 'val_miou': [],
                   'train_acc': [], 'val_acc': [],
                   'lrs': [lr]}
        min_loss = float('inf')
        max_IoU = float('-inf')

        # Data Loaders
        chosen_dataset = DATASETS[dataset_name]
        training_DataLoader, test_DataLoader = init_train_test_loader(chosen_dataset, 'training',
                                                                      chosen_dataset, 'validation',
                                                                      num_workers=4, size_valid=512*3*2)

        # Loss
        criterion = LOSSES[loss_name]

        # `sobel_loss` starts as a boolean setting and is replaced by the Sobel module when enabled
        if sobel_loss:
            sobel_loss = LOSSES['Sobel'].to(device)

        # Model
        model, model_name = load_model(model_config)
        if freeze_encoder:
            print("Freezing encoders..")
            for param in model.encoder.parameters():
                param.requires_grad = False
        model_name = finetuning_prefix_name + model_name.split('_'+model_config['version'])[0]
        # Optimizer
        optimizer = init_optimizer(model)
        # Scheduler. OneCycleLR is sized for n_epochs * len(training_DataLoader) steps, so it is
        # stepped once per batch in the training loop below.
        if use_scheduler:
            scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, lr, epochs=n_epochs, steps_per_epoch=len(training_DataLoader))

        print("Training: {}\n".format(model_name))
        time.sleep(0.3)
        for epoch in range(n_epochs):
            # Validation and show_prediction_example() leave the model in eval mode at the end of
            # each epoch, so training mode (dropout, batch-norm statistics) must be re-enabled here
            model.train()
            running_loss = 0
            iou_score = 0
            accuracy = 0
            with tqdm(training_DataLoader, unit="step", position=0, leave=True) as tepoch:
                tepoch.set_description(f"Epoch {epoch+1}/{n_epochs} - Training")
                for step, batch in enumerate(tepoch, start=1):
                    # Load data
                    inputs, masks = batch[0].to(device), batch[1].to(device)
                    # Forward
                    outputs = model(inputs.float())
                    # Compute loss
                    loss = compute_loss(outputs, masks, criterion)
                    # Sobel
                    if sobel_loss:
                        sobel_x, sobel_y = sobel_loss(torch.argmax(outputs, dim=1).unsqueeze(1).float(), masks.float())
                        # Update loss with Sobel
                        loss += sobel_x + sobel_y
                    # Update loss for stats
                    running_loss += loss.item()
                    # Evaluation metrics
                    iou_score += mIoU(outputs, masks)
                    accuracy += pixel_accuracy(outputs, masks)
                    # Backward
                    loss.backward()
                    optimizer.step() # Update weight
                    optimizer.zero_grad() # Empty gradient
                    if use_scheduler:
                        scheduler.step() # OneCycleLR advances once per batch
                    tepoch.set_postfix({'Loss':running_loss/step,'Acc':accuracy/step,'IoU':iou_score/step})

            # Validation
            model.eval()
            test_loss = 0
            test_accuracy = 0
            test_iou_score = 0
            with tqdm(test_DataLoader, unit="step", position=0, leave=True) as tepoch:
                tepoch.set_description(f"Epoch {epoch+1}/{n_epochs} - Validation")
                for step, batch in enumerate(tepoch, start=1):
                    inputs, masks = batch[0].to(device), batch[1].to(device)
                    # Validation loop
                    with torch.no_grad():
                        outputs = model(inputs.float())
                        # Evaluation metrics
                        test_iou_score += mIoU(outputs, masks)
                        test_accuracy += pixel_accuracy(outputs, masks)
                        # Loss
                        loss = compute_loss(outputs, masks, criterion)
                        # Sobel
                        if sobel_loss:
                            sobel_x, sobel_y = sobel_loss(torch.argmax(outputs, dim=1).unsqueeze(1).float(), masks.float())
                            # Update loss with Sobel
                            loss += sobel_x + sobel_y
                        # Update loss for stats
                        test_loss += loss.item()
                    tepoch.set_postfix({'Loss':test_loss/step,'Acc':test_accuracy/step,'IoU':test_iou_score/step})

            # Learning rate reached at the end of this epoch (the scheduler is stepped per batch above)
            history['lrs'].append(get_lr(optimizer))

            #calculation mean for each batch
            history['train_loss'].append(running_loss/len(training_DataLoader))
            history['val_loss'].append(test_loss/len(test_DataLoader))

            #iou
            history['val_miou'].append(test_iou_score/len(test_DataLoader))
            history['train_miou'].append(iou_score/len(training_DataLoader))
            history['train_acc'].append(accuracy/len(training_DataLoader))
            history['val_acc'].append(test_accuracy/len(test_DataLoader))

            # Save model
            '''
            # Save by best loss
            if min_loss >= test_loss/len(test_DataLoader):
                min_loss = test_loss/len(test_DataLoader)
                if save_training:
                    save_checkpoint(model, model_name + '_best')
                    print('New best Val Loss: {:.6f} at epoch {}'.format(min_loss, epoch+1))
            elif save_training:
                save_checkpoint(model, model_name + '_checkpoint')
            '''
            # Save by best mIoU
            if max_IoU <= test_iou_score/len(test_DataLoader):
                max_IoU = test_iou_score/len(test_DataLoader)
                if save_training:
                    save_checkpoint(model, model_name + '_best')
                    print('New best Val IoU: {:.6f} at epoch {}'.format(max_IoU, epoch+1))

            show_prediction_example(model, device, 'Result for epoch: '+str(epoch+1), 'coronacases_009')

            if save_training:
                save_history(model_save_path + model_name + '/' + model_name + '_history', history)

        print('Finished Training')
        plot_history(history)
    finally:
        eq_method = prev_eq_method
        norm_method = prev_norm_method
        lower, higher = prev_lower, prev_higher

# Prediction

## DataLoader

In [23]:
class COVID_PredictionDataLoader():
    """
    Class to handle datasets for predictions. Different from the training loader, this class
    is meant to pass, as input, 2D images in the correct order to recreate the original volumes.
    Call get_one_axis() to select a patient and an axis before indexing the object.
    -----------
    Parameters:
        dataset: dict or list of dict,
            dataset dict with all the relative paths ('image' and 'mask' lists keyed by split
            folder, plus 'mean' and 'std'); a list combines several datasets
        dts_type: str or list of str,
            defining the dataset split folder, 'training' or 'validation'
        shape: tuple,
            images target shape, used as the number of slices along x, y and z
        size_n_patients: int, Optional,
            number of patients to work with (applied per dataset and per split folder),
            with default value -1 it uses all the patients
    """
    def __init__(self, dataset, dts_type, shape, size_n_patients=-1):
        self.dataset = dataset
        self.dts_type = dts_type
        self.shape = shape
        # patient folder path -> {'x': [...], 'y': [...], 'z': [...]} lists of slice file paths
        self.all_x = {}
        self.all_y = {}
        self.size_n_patients = size_n_patients
        # slice paths of the currently selected patient/axis (filled by get_one_axis)
        self.x = []
        self.y = []
        self.mean = 0
        self.std = 0
        self.n_patients = 0
        self.info = 0
        # number of (dataset, split folder) pairs whose mean/std were summed into self.mean/self.std
        self._n_stats = 0
        self.__getdataset__()

    def __getdataset__(self):
        folders = [self.dts_type] if isinstance(self.dts_type, str) else self.dts_type
        for folder in folders:
            self.__populate_dataset__(folder)
        # Average the statistics once, over every (dataset, folder) pair that contributed
        if self._n_stats:
            self.mean /= self._n_stats
            self.std /= self._n_stats

    def _slice_paths(self, patient):
        # One '<patient>/<axis>/<slice>.npy' path per slice, with shape giving the count per axis
        return {axis: [patient + '/' + axis + '/' + str(slice_n) + '.npy' for slice_n in range(n_samp)]
                for axis, n_samp in zip(('x', 'y', 'z'), self.shape)}

    def __populate_dataset__(self, folder):
        # A single dataset dict is handled as a one-element list of datasets
        subsets = self.dataset if isinstance(self.dataset, list) else [self.dataset]
        for subset in subsets:
            pairs = zip(subset['image'][folder], subset['mask'][folder])
            for counter, (patient_x, patient_y) in enumerate(pairs, start=1):
                self.all_x[patient_x] = self._slice_paths(patient_x)
                self.all_y[patient_y] = self._slice_paths(patient_y)
                if self.size_n_patients == counter:
                    break
            # Summed here, averaged in __getdataset__
            self.mean += subset['mean']
            self.std += subset['std']
            self._n_stats += 1

        self.info = len(self.all_x)
        self.n_patients = len(self.all_x)

        if len(self.all_x) != len(self.all_y): raise SystemError('Problem with Img and Gt, no same size')

    def __len__(self):
        # Number of patients after construction, number of slices after get_one_axis()
        return self.info

    def get_n_patients(self):
        return self.n_patients

    def get_patient_path(self, patient_index):
        # Get one patient only from the list of all patients (path relative to the 'mask/' folder)
        return list(self.all_y.keys())[patient_index].split("mask/")[-1]

    def get_one_axis(self, patient_index, axis):
        # Select the slices of one patient along one axis, in natural (numeric) slice order
        patient_x = list(self.all_x.keys())[patient_index]
        patient_y = list(self.all_y.keys())[patient_index]
        self.x = self.all_x[patient_x][axis]
        self.y = self.all_y[patient_y][axis]
        self.x.sort(key=natural_keys)
        self.y.sort(key=natural_keys)
        self.info = len(self.x)

    def __getitem__(self, index=None):
        # A random slice is returned when no index is given
        if index is None:
            index = np.random.randint(0, self.info)
        ct_scan = np.load(self.x[index])
        ct_scan = norm_and_eq(ct_scan, self.mean, self.std)

        mask = np.load(self.y[index])
        ct_scan = torch.from_numpy(ct_scan).unsqueeze(0)
        mask = torch.from_numpy(mask).unsqueeze(0)

        return ct_scan.float(), mask.float()

## Aux functions

### Prediction functions

In [24]:
def predict_one_axis(model, test_Dataset, axis, bs=32):
    """
    Predicts all slices along one axis and stacks them back into a label volume.
    ---------
    Parameters:
        model: model object,
            the model used to predict
        test_Dataset: COVID_PredictionDataLoader,
            dataset already positioned on one patient/axis (see get_one_axis), yielding
            (ct_scan, mask) pairs in slice order
        axis: str,
            axis to be predicted, one of 'x', 'y', 'z'
        bs: int, Optional,
            batch size
    Returns:
        The predicted label volume (shape target_shape) and the mean per-batch weighted 2D dice score
    Raises:
        ValueError: if axis is not 'x', 'y' or 'z'
    """
    if axis not in ('x', 'y', 'z'):
        raise ValueError("axis must be one of 'x', 'y', 'z', got {!r}".format(axis))
    # Prepare Dataloader (no shuffling: slice order is needed to rebuild the volume)
    test_DataLoader = DataLoader(test_Dataset, batch_size=bs, num_workers=4)
    x_size, y_size, z_size = target_shape
    model.eval()
    time.sleep(0.3)
    with torch.no_grad():
        dice_score = 0
        prediction = np.zeros([x_size, y_size, z_size])
        start = 0
        with tqdm(test_DataLoader, unit="step", position=0, leave=True) as tepoch:
            tepoch.set_description(f"Epoch [1/1] - Predicting axis {axis}")
            for step, batch in enumerate(tepoch, start=1):
                predict = model(batch[0].float().to(device))
                dice_score += weighted_dice_score_2D(predict.float(), batch[1].squeeze(1).long().to(device))
                # One label map per slice in the batch, indexed [slice, row, col]
                predict = torch.argmax(predict, dim=1).cpu().numpy()
                # The slice range covered by this batch comes from its own length
                stop = start + predict.shape[0]
                if axis == 'x':
                    prediction[start:stop, :, :] = predict
                elif axis == 'y':
                    prediction[:, start:stop, :] = predict.transpose([1, 0, 2])
                else:
                    prediction[:, :, start:stop] = predict.transpose([1, 2, 0])
                start = stop
                # Running mean over the batches processed so far
                tepoch.set_postfix({'Weighted Dice Accuracy': dice_score/step})
    return prediction, dice_score/len(test_DataLoader)


def predict_one_patient(model, dataset, dts_type, patient_index, dataset_class=None):
    """
    Predicts one patient considering all the slices along all the axes.
    ---------
    Parameters:
        model: model object,
            the model used to predict
        dataset: dict,
            dataset dictionary with all the relative paths
        dts_type: str,
            dataset split folder, 'training' or 'validation'
        patient_index: int,
            index of the patient to be predicted
        dataset_class: COVID_PredictionDataLoader class, Optional,
            an already built COVID_PredictionDataLoader; built from dataset/dts_type when None
    Returns:
        The three predictions along the three axes, the dice score along each axis and the path
        to the predicted patient data
    """
    if dataset_class is None:
        dataset_class = COVID_PredictionDataLoader(dataset, dts_type, target_shape)
    patient_path = dataset_class.get_patient_path(patient_index)
    print("Predicting...")
    predictions, dices = [], []
    for axis in ('x', 'y', 'z'):
        # Point the shared dataset at this patient's slices along `axis`, then predict them
        dataset_class.get_one_axis(patient_index, axis)
        axis_pred, axis_dice = predict_one_axis(model, dataset_class, axis, 32)
        predictions.append(axis_pred)
        dices.append(axis_dice)
    xpred, ypred, zpred = predictions
    return xpred, ypred, zpred, dices, patient_path


def predict_all_patients(dataset_mask_root, dataset, dts_type, model_config, size_n_patients=-1):
    """
    Predicts all the patients considering all the slices along all the axes.
    ---------
    Parameters:
        dataset_mask_root: str,
            path to the raw ground truth masks
        dataset: dict,
            dataset dictionary with all the relative paths
        dts_type: str,
            dataset split folder, 'training' or 'validation'
        model_config: dict,
            dict with all the information needed about the model
        size_n_patients: int, Optional,
            number of patients to be predicted
    Returns:
        A dict with the 2D dice score along each axis, the 3D dice score along each axis, the
        total 3D dice score, the total 3D weighted dice score, the 3D dice score of each label,
        the Pearson Correlation Coefficient, the precision and recall values and the
        confusion matrix (rows = ground truth, columns = prediction), all averaged over patients
    Raises:
        ValueError: if the global merging_method is not 'bagging', 'boosting' or 'threshold bagging'
    """
    test_Dataset = COVID_PredictionDataLoader(dataset, dts_type, target_shape, size_n_patients)
    tot_x_dice2D, tot_y_dice2D, tot_z_dice2D = 0, 0, 0
    tot_x_dice3D, tot_y_dice3D, tot_z_dice3D = 0, 0, 0
    dice3D, weighted_dice3D, pcc = 0, 0, 0
    each_label_dice = np.zeros(n_classes)
    cm = np.zeros((n_classes, n_classes))
    # Fixed label set keeps every per-patient matrix n_classes x n_classes, even when a class is absent
    labels = list(range(n_classes))
    # Predictions are produced at this voxel size and rescaled back to each GT's resolution
    source_resolution = (334/512, 334/512, 1)
    # Load model
    model, _ = load_model(model_config)
    n_patients = test_Dataset.get_n_patients()
    for i in range(n_patients):
        print("Predicting patient ", i+1)
        xpred, ypred, zpred, scores, patient_path = predict_one_patient(model, dataset, dts_type, i, test_Dataset)
        print("Combining predictions")
        if merging_method == 'bagging':
            print("    Bagging")
            pred = bagging(xpred, ypred, zpred)
        elif merging_method == 'boosting':
            print("    Boosting")
            pred = boosting(xpred, ypred, zpred, scores)
        elif merging_method == 'threshold bagging':
            print("    Bagging with Threshold")
            pred = bagging_with_threshold(xpred, ypred, zpred)
        else:
            # Otherwise `pred` would be undefined, or silently reused from the previous patient
            raise ValueError("Unknown merging_method '{}': expected 'bagging', 'boosting' "
                             "or 'threshold bagging'".format(merging_method))

        gt = load_scan(dataset_mask_root + patient_path + '.nii.gz')

        # Option 1: rescale the GT to the prediction shape
        #print("Rescaling GT")
        #gt = rescale_one_gt(gt)

        # Option 2: rescale the predictions back to the original GT resolution and shape
        print("Rescaling predictions")
        rescale_kwargs = {'target_resolution': gt.header.get_zooms(), 'target_shape': gt.shape}
        pred = rescale_to_original(pred, source_resolution, **rescale_kwargs)
        xpred = rescale_to_original(xpred, source_resolution, **rescale_kwargs).round()
        ypred = rescale_to_original(ypred, source_resolution, **rescale_kwargs).round()
        zpred = rescale_to_original(zpred, source_resolution, **rescale_kwargs).round()

        gt = gt.get_fdata().round()
        pred = pred.round()

        tot_x_dice2D += scores[0]
        tot_y_dice2D += scores[1]
        tot_z_dice2D += scores[2]

        print("Computing each prediction 3D Weighted Dice Score")
        # Accumulated (not overwritten) so the division by n_patients below is a true mean
        tot_x_dice3D += weighted_dice_score_3D(xpred, gt)
        tot_y_dice3D += weighted_dice_score_3D(ypred, gt)
        tot_z_dice3D += weighted_dice_score_3D(zpred, gt)
        print("Computing final prediction 3D Dice Scores")
        weighted_dice3D += weighted_dice_score_3D(pred, gt)
        dice3D += dice_score_3D(pred, gt)
        print("Computing each label Dice score")
        each_label_dice += each_label_dice_score_3D(pred, gt)
        print("Computing final prediction Pearson Correlation Coefficient")
        pcc += np.corrcoef(pred.flatten(), gt.flatten())[0][1]
        print("Computing final prediction Confusion Matrix\n")
        # sklearn expects (y_true, y_pred): rows are ground-truth classes, so normalize='true'
        # normalises over the GT and the precision/recall formulas below hold
        cm += confusion_matrix(gt.flatten().astype(int), pred.flatten().astype(int),
                               labels=labels, normalize='true')

    cm /= n_patients

    print("Computing Precision & Recall\n")
    true_pos = np.diag(cm)
    false_pos = np.sum(cm, axis=0) - true_pos  # column sums: everything predicted as the class
    false_neg = np.sum(cm, axis=1) - true_pos  # row sums: everything that truly is the class

    precision = true_pos / (true_pos + false_pos)
    recall = true_pos / (true_pos + false_neg)

    each_label_dice /= n_patients

    return {'2D Dice x':tot_x_dice2D/n_patients, '2D Dice y':tot_y_dice2D/n_patients, '2D Dice z':tot_z_dice2D/n_patients,
            '3D Dice x':tot_x_dice3D/n_patients, '3D Dice y':tot_y_dice3D/n_patients, '3D Dice z':tot_z_dice3D/n_patients,
            '3D Dice':dice3D/n_patients, '3D Weighted Dice':weighted_dice3D/n_patients, 'Each Label Dice': each_label_dice,
            'PCC':pcc/n_patients, 'precision':precision, 'recall':recall, 'CM':cm}
    

def predict_one_axis_softmax(model, test_Dataset, axis, axis_w, class_w, bs=4, prediction=None):
    """
    Predicts all slices along one axis and accumulates their weighted softmax probabilities,
    for the probabilities combination method.
    ---------
    Parameters:
        model: model object,
            the model used to predict
        test_Dataset: COVID_PredictionDataLoader,
            dataset already positioned on one patient/axis (see get_one_axis)
        axis: str,
            axis to be predicted, one of 'x', 'y', 'z'
        axis_w: list of float,
            weight for each axis, in the order x, y, z
        class_w: list of float,
            weight for each class, one entry per model output channel
        bs: int, Optional,
            batch size
        prediction: Tensor, Optional,
            (n_classes, x, y, z) tensor of probabilities accumulated by previous calls; a zero
            tensor on `device` is allocated when None
    Returns:
        The accumulated (n_classes, x, y, z) probability tensor
    Raises:
        ValueError: if axis is not 'x', 'y' or 'z'
    """
    # axis -> (index into axis_w, permutation of a (batch, class, h, w) batch into the
    # (class, x, y, z) layout of `prediction`)
    axis_layout = {'x': (0, (1, 0, 2, 3)), 'y': (1, (1, 2, 0, 3)), 'z': (2, (1, 2, 3, 0))}
    if axis not in axis_layout:
        raise ValueError("axis must be one of 'x', 'y', 'z', got {!r}".format(axis))
    weight_idx, order = axis_layout[axis]
    x_size, y_size, z_size = target_shape
    model.eval()
    time.sleep(0.3)
    with torch.no_grad():
        if prediction is None:
            prediction = torch.zeros([n_classes, x_size, y_size, z_size]).to(device)
        # Broadcastable over the (class, x, y, z) layout
        class_weights = torch.as_tensor(class_w, dtype=prediction.dtype,
                                        device=prediction.device).view(-1, 1, 1, 1)
        # Load test data (no shuffling: slice order is needed to rebuild the volume)
        test_DataLoader = DataLoader(test_Dataset, batch_size=bs, num_workers=2)
        start = 0
        with tqdm(test_DataLoader, unit="step", position=0, leave=True) as tepoch:
            tepoch.set_description(f"Epoch [1/1] - Predicting axis {axis}")
            for batch in tepoch:
                predict = torch.softmax(model(batch[0].float().to(device)), dim=1)
                stop = start + predict.shape[0]
                # Weight by axis first, then by class
                contribution = (predict * axis_w[weight_idx]).permute(*order) * class_weights
                if axis == 'x':
                    prediction[:, start:stop, :, :] += contribution
                elif axis == 'y':
                    prediction[:, :, start:stop, :] += contribution
                else:
                    prediction[:, :, :, start:stop] += contribution
                start = stop
    return prediction

def predict_one_patient_softmax(model, dataset, dts_type, patient_index, axis_w, class_w, dataset_class, bs=32):
    """
    Predicts one patient considering all the slices along the x, y and z axes and merges them
    with the probabilities combination method: each axis pass adds its weighted softmax
    probabilities into the same accumulator volume, and the final label of each voxel is the
    class with the highest accumulated value.
    ---------
    Parameters:
        model: model object,
            the model used to predict
        dataset: dict,
            dataset dictionary with all the relative paths
        dts_type: str,
            dataset split folder, 'training' or 'validation'
        patient_index: int,
            index of the patient to be predicted
        axis_w: list of float,
            list of weights for each axis (x, y, z)
        class_w: list of float,
            list of weights for each class
        dataset_class: COVID_PredictionDataLoader instance or None,
            prediction data loader to reuse; when None a new one is built from
            dataset, dts_type and the global target_shape
        bs: int,
            batch size
    Returns:
        A tuple (pred, patient_path): pred is the merged label volume as a numpy array
        (argmax over the class axis of the accumulated probabilities) and patient_path
        is the value returned by dataset_class.get_patient_path(patient_index)
    """
    if dataset_class is None:
        dataset_class = COVID_PredictionDataLoader(dataset, dts_type, target_shape)
    patient_path = dataset_class.get_patient_path(patient_index)
    print("Predicting...")
    # Predict x: the first pass creates the probability accumulator
    dataset_class.get_one_axis(patient_index, 'x')
    pred = predict_one_axis_softmax(model, dataset_class, 'x', axis_w=axis_w, class_w=class_w, bs=bs)
    # Predict y: accumulate into the same volume
    dataset_class.get_one_axis(patient_index, 'y')
    pred = predict_one_axis_softmax(model, dataset_class, 'y', axis_w=axis_w, class_w=class_w, bs=bs, prediction=pred)
    # Predict z: accumulate into the same volume
    dataset_class.get_one_axis(patient_index, 'z')
    pred = predict_one_axis_softmax(model, dataset_class, 'z', axis_w=axis_w, class_w=class_w, bs=bs, prediction=pred)

    # Collapse the class axis to a label volume
    pred = torch.argmax(pred, dim=0).cpu().numpy()
    return pred, patient_path

def predict_all_patients_softmax(dataset_mask_root, dataset, dts_type, model_config, axis_w=(1, 1, 1),
                                 class_w=(1, 1, 1), size_n_patients=-1):
    """
    Predicts all the patients considering all the slices along all the axes,
    using the probabilities combination method, and averages the metrics over patients.
    ---------
    Parameters:
        dataset_mask_root: str,
            path to the raw ground truth masks
        dataset: dict,
            dataset dictionary with all the relative paths
        dts_type: str,
            dataset split folder, 'training' or 'validation'
        model_config: dict,
            dict with all the information needed about the model
        axis_w: sequence of float, Optional,
            weights for each axis (x, y, z)
        class_w: sequence of float, Optional,
            weights for each class
        size_n_patients: int, Optional,
            number of patients to be predicted
    Returns:
        A dict with the total 3D dice score, the total 3D weighted dice score, the 3D dice score of each label,
        the Pearson Correlation Coefficient, the precision and recall values and the
        confusion matrix. The confusion matrix follows the sklearn convention (rows = ground truth,
        columns = prediction), is row-normalized per patient and then averaged over patients.
    Raises:
        ValueError: if there are no patients to predict.
    """
    test_Dataset = COVID_PredictionDataLoader(dataset, dts_type, target_shape, size_n_patients)
    n_patients = test_Dataset.get_n_patients()
    if n_patients == 0:
        raise ValueError("No patients to predict for dataset split {!r}".format(dts_type))

    dice3D, weighted_dice3D, pcc = 0, 0, 0
    each_label_dice = np.zeros(n_classes)
    cm = np.zeros((n_classes, n_classes))
    # Fixed label set: every per-patient matrix is n_classes x n_classes even when a class
    # is absent from both the GT and the prediction of that patient.
    cm_labels = list(range(n_classes))
    # Load model
    model, _ = load_model(model_config)
    for i in range(n_patients):
        print("Predicting patient ", i+1)
        pred, patient_path = predict_one_patient_softmax(model, dataset, dts_type, i, axis_w=axis_w, class_w=class_w,
                                                         dataset_class=test_Dataset)

        gt = load_scan(dataset_mask_root + patient_path + '.nii.gz')

        # Option 1 (disabled): rescale the GT to the prediction grid instead
        #print("Rescaling GT")
        #gt = rescale_one_gt(gt)

        # Option 2: rescale the prediction back to the GT resolution and shape
        print("Rescaling predictions")
        pred = rescale_to_original(pred, (334/512, 334/512, 1), target_resolution=gt.header.get_zooms(), target_shape=gt.shape)
        gt = gt.get_fdata()

        gt = gt.round()
        # Interpolation during rescaling can over/undershoot the label range; clamp so every
        # voxel is a valid class index for the one-hot based dice metrics below.
        pred = np.clip(pred.round(), 0, n_classes - 1)

        print("Computing final prediction 3D Dice Scores")
        weighted_dice3D += weighted_dice_score_3D(pred, gt)
        dice3D += dice_score_3D(pred, gt)
        print("Computing each label Dice score")
        each_label_dice += each_label_dice_score_3D(pred, gt)
        print("Computing final prediction Pearson Correlation Coefficient")
        pcc += np.corrcoef(pred.flatten(), gt.flatten())[0][1]
        print("Computing final prediction Confusion Matrix\n")
        # sklearn signature is confusion_matrix(y_true, y_pred): rows are GT classes and
        # normalize='true' normalizes each GT row.
        cm += confusion_matrix(gt.astype(np.int64).ravel(), pred.astype(np.int64).ravel(),
                               labels=cm_labels, normalize='true')

    cm /= n_patients

    print("Computing Precision & Recall\n")
    # Computed on the patient-averaged, row-normalized matrix (rows = GT, columns = prediction),
    # not on raw voxel counts.
    true_pos = np.diag(cm)
    false_pos = np.sum(cm, axis=0) - true_pos
    false_neg = np.sum(cm, axis=1) - true_pos

    precision = true_pos / (true_pos + false_pos)
    recall = true_pos / (true_pos + false_neg)

    each_label_dice /= n_patients

    return {'3D Dice':dice3D/n_patients, '3D Weighted Dice':weighted_dice3D/n_patients, 'Each Label Dice': each_label_dice,
            'PCC':pcc/n_patients, 'precision':precision, 'recall':recall, 'CM':cm}

def rescale_to_original(array, resolution, target_resolution=(334/512, 334/512, 1), target_shape=(512, 512, 512)):
    """
    Resamples a 3D volume from its voxel spacing to a target spacing and shape.

    The array is first centre-padded/cropped so it spans the same physical extent as the
    target volume (target_resolution * target_shape) measured in its own voxels. It is then
    resized slice by slice: each slice along the third axis in-plane to
    (target_shape[0], target_shape[1]) with Lanczos interpolation, then each slice along the
    first axis to (target_shape[1], target_shape[2]) with linear interpolation.
    -----------
    Parameters:
        array: numpy array,
            volume of shape (x, y, z)
        resolution: sequence of 3 floats,
            voxel spacing of array along x, y, z
        target_resolution: sequence of 3 floats, Optional,
            voxel spacing of the output volume
        target_shape: sequence of 3 ints, Optional,
            shape of the output volume
    Returns:
        A float32 numpy array of shape target_shape
    """
    original_shape = np.shape(array)
    # Physical extent of the target volume, expressed in voxels of the input spacing
    shape_of_target_volume = tuple(
        int(target_resolution[i] * target_shape[i] / resolution[i]) for i in range(3))
    # Canvas large enough to hold both the input and the target extent, plus a 2-voxel margin
    canvas_shape = tuple(max(shape_of_target_volume[i], original_shape[i]) + 2 for i in range(3))

    def _centred(size, canvas):
        # Slice of length `size` centred in an axis of length `canvas`
        start = int(canvas / 2) - int(size / 2)
        return slice(start, start + size)

    array_intermediate = np.zeros(canvas_shape, 'float32')
    array_intermediate[tuple(_centred(original_shape[i], canvas_shape[i]) for i in range(3))] = array
    # Crop to the target extent: the array is now padded/cropped around its centre
    array_intermediate = array_intermediate[
        tuple(_centred(shape_of_target_volume[i], canvas_shape[i]) for i in range(3))]

    # Rescaling. cv2.resize takes dsize as (width, height) = (n_columns, n_rows), and its third
    # positional parameter is `dst`, so the interpolation flag must be passed by keyword.
    array_standard_xy = np.zeros((target_shape[0], target_shape[1], shape_of_target_volume[2]), 'float32')
    for s in range(shape_of_target_volume[2]):
        array_standard_xy[:, :, s] = cv2.resize(array_intermediate[:, :, s], (target_shape[1], target_shape[0]),
                                                interpolation=cv2.INTER_LANCZOS4)
    array_standard = np.zeros(target_shape, 'float32')
    for s in range(target_shape[0]):
        array_standard[s, :, :] = cv2.resize(array_standard_xy[s, :, :], (target_shape[2], target_shape[1]),
                                             interpolation=cv2.INTER_LINEAR)

    return array_standard

### Merging functions

In [25]:
def bagging(x, y, z, filter=True):
    """
    Merges the three axis predictions with the bagging method: the voxel-wise sum of
    the three label volumes is floor-divided by 3 and returned as float32.
    -----------
    Parameters:
        x: numpy array,
            array with the prediction along the x axis
        y: numpy array,
            array with the prediction along the y axis
        z: numpy array,
            array with the prediction along the z axis
        filter: bool, Optional,
            when True, a 3-wide maximum filter is applied to the merged result
    Returns:
        The merged final prediction
    """
    summed = x + y + z
    merged = np.array(summed // 3, 'float32')
    if not filter:
        return merged
    return scipy.ndimage.maximum_filter(merged, 3)

def bagging_with_threshold(x, y, z, thresholds=(1, 3), filter=True):
    """
    Merges the three predictions according to the bagging method
    with thresholds on the voxel-wise sum of the three label volumes:
    sum <= thresholds[0] -> 0, thresholds[0] < sum <= thresholds[1] -> 1, sum > thresholds[1] -> 2.
    -----------
    Parameters:
        x: numpy array,
            array with the prediction along the x axis
        y: numpy array,
            array with the prediction along the y axis
        z: numpy array,
            array with the prediction along the z axis
        thresholds: sequence of two numbers, Optional,
            (low, high) threshold values
        filter: bool, Optional,
            apply either or not the maximum_filter
    Returns:
        The merged final prediction (float32)
    """
    low, high = thresholds[0], thresholds[1]
    # Combining x-y-z predictions
    pred = np.array((x + y + z), 'float32')
    # Build every mask from the unmodified sum *before* relabelling; relabelling in place
    # would let already-assigned voxels be caught again by a later mask.
    low_mask = pred <= low
    mid_mask = (pred > low) & (pred <= high)
    high_mask = pred > high
    pred[low_mask] = 0
    pred[mid_mask] = 1
    pred[high_mask] = 2
    if filter:
        # Filter result
        pred = scipy.ndimage.maximum_filter(pred, 3)
    return pred

def boosting(x_pred, y_pred, z_pred, w, filter=True, num_classes=None):
    """
    Merges the three predictions according to the boosting method (weighted voting):
    every axis prediction votes for its label with the weight of that axis, and each voxel
    takes the label with the highest total vote (ties go to the lowest label).
    -----------
    Parameters:
        x_pred: numpy array,
            array with the prediction along the x axis
        y_pred: numpy array,
            array with the prediction along the y axis
        z_pred: numpy array,
            array with the prediction along the z axis
        w: sequence of numbers,
            one weight per axis: w[0] for x_pred, w[1] for y_pred, w[2] for z_pred
        filter: bool, Optional,
            apply either or not the maximum_filter
        num_classes: int, Optional,
            number of labels to vote on; defaults to the global n_classes
    Returns:
        The merged final prediction as an int64 numpy array
    """
    if num_classes is None:
        num_classes = n_classes
    axis_preds = (np.asarray(x_pred), np.asarray(y_pred), np.asarray(z_pred))
    axis_weights = (w[0], w[1], w[2])
    shape = axis_preds[0].shape

    # Running argmax over classes: keeps only the best score and its label instead of
    # materializing an (n_classes, *shape) vote volume.
    vote = np.zeros(shape, dtype=np.int64)
    best_score = None
    for label in range(num_classes):
        score = np.zeros(shape, dtype=np.float32)
        for pred, weight in zip(axis_preds, axis_weights):
            score[pred == label] += weight
        if best_score is None:
            best_score = score
            continue
        # Strict '>' keeps the earlier (lower) label on ties, like argmax
        better = score > best_score
        vote[better] = label
        best_score = np.maximum(best_score, score)

    if filter:
        # Filter result
        vote = scipy.ndimage.maximum_filter(vote, 3)
    return vote

### Visualization functions

In [26]:
def save_nii_mask(filepath, mask_array, affine=np.eye(4), header=None):
    """
    Saves a mask array as a NIfTI-1 image at filepath + '.nii.gz'.
    -----------
    Parameters:
        filepath: str,
            destination path without extension ('.nii.gz' is appended)
        mask_array: numpy array,
            mask numpy array to be saved
        affine: numpy array, Optional,
            affine matrix for the .nii.gz file; the 4x4 identity by default
        header: nibabel header, Optional,
            header to copy .nii.gz parameters from
    """
    # The previous `if header == None: img.header.get_xyzt_units()` only called a getter
    # and discarded its result, so it had no effect and has been dropped.
    img = nib.Nifti1Image(mask_array, affine=affine, header=header)
    img.to_filename(filepath + '.nii.gz')

def save_obj_from_nii(filepath, nii_file):
    """
    Extracts the level-0 isosurface of a .nii.gz volume with marching cubes and saves it
    as a Wavefront .obj file (vertices, vertex normals and faces).
    -----------
    Parameters:
        filepath: str,
            destination path without extension ('.obj' is appended)
        nii_file: nibabel object,
            .nii.gz object to be converted as .obj
    """
    # marching_cubes_lewiner is deprecated/removed in newer scikit-image releases; there
    # measure.marching_cubes runs the same Lewiner algorithm by default.
    marching_cubes = getattr(measure, 'marching_cubes_lewiner', None) or measure.marching_cubes
    verts, faces, normals, _values = marching_cubes(nii_file.get_fdata(), 0)
    faces = faces + 1  # .obj vertex indices are 1-based
    # `with` closes the file even if a write fails
    with open(filepath + '.obj', 'w') as thefile:
        for item in verts:
            thefile.write("v {0} {1} {2}\n".format(item[0], item[1], item[2]))
        for item in normals:
            thefile.write("vn {0} {1} {2}\n".format(item[0], item[1], item[2]))
        for item in faces:
            thefile.write("f {0}//{0} {1}//{1} {2}//{2}\n".format(item[0], item[1], item[2]))

def show_3D_plot(nii_img):
    """
    Plots the level-0 isosurface of a .nii.gz volume as a 3D triangulated surface.
    -----------
    Parameters:
        nii_img: nibabel object,
            .nii.gz file to be plotted
    """
    # marching_cubes_lewiner is deprecated/removed in newer scikit-image releases; there
    # measure.marching_cubes runs the same Lewiner algorithm by default.
    marching_cubes = getattr(measure, 'marching_cubes_lewiner', None) or measure.marching_cubes
    verts, faces, _normals, _values = marching_cubes(nii_img.get_fdata(), 0)
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_trisurf(verts[:, 0], verts[:, 1], faces, verts[:, 2],
                    linewidth=0.2, antialiased=True)
    plt.show()

def decode_segmap(image, nc=3, colors=[(0, 0, 0), (0, 128, 0), (128, 0, 0)]):
    """
    Turns a label mask into an RGB image by painting every label with its colour.
    -----------
    Parameters:
        image: numpy array,
            label mask
        nc: int,
            number of labels to paint (labels 0..nc-1); other values stay black
        colors: list of int tuples,
            RGB tuple for each label (default: 0=background, 1=lung, 2=infection)
    Returns:
        The coloured mask as uint8, with the RGB channels stacked on axis 2
    """
    palette = np.array(colors)
    channels = [np.zeros_like(image).astype(np.uint8) for _ in range(3)]
    for label in range(0, nc):
        where_label = image == label
        for c in range(3):
            channels[c][where_label] = palette[label, c]
    return np.stack(channels, axis=2)


def show_final_results(ct, gt, pred, n):
    """
    Shows n random slices comparing the CT scan, the ground truth and the prediction.

    Each sample picks a random axis and a random slice index along that axis, retrying
    until the predicted slice contains more than one label.
    -----------
    Parameters:
        ct: numpy array,
            original ct-scan numpy array (3D)
        gt: numpy array,
            ground truth numpy array, same shape as ct
        pred: numpy array,
            final prediction, same shape as ct
        n: int,
            number of samples
    Raises:
        ValueError: if pred holds a single value everywhere; no slice could then satisfy
            the retry condition and the loop would never end.
    """
    if np.min(pred) == np.max(pred):
        raise ValueError("pred contains a single label; there is no informative slice to show")
    axis_index = {'x': 0, 'y': 1, 'z': 2}
    title_offset = {'x': 0.64, 'y': 0.64, 'z': 0.58}
    shape = ct.shape
    for _ in range(n):
        while True:
            # Choose a random axis, then a slice index within that axis' extent
            axis = np.random.choice(['x', 'y', 'z'])
            dim = axis_index[axis]
            slice_id = np.random.randint(0, shape[dim])
            mask_slice = np.take(pred, slice_id, axis=dim)
            if len(np.unique(mask_slice)) != 1:
                break
        img_slice = np.take(ct, slice_id, axis=dim)
        gt_slice = np.take(gt, slice_id, axis=dim)
        # Plot
        gt_color = decode_segmap(gt_slice)
        pred_color = decode_segmap(mask_slice)
        f, axarr = plt.subplots(1, 6, figsize=(19,19))
        f.suptitle('Slice n. {} on axis {}'.format(slice_id, axis), y=title_offset[axis], fontsize=14)
        axarr[0].imshow(img_slice, cmap='gray')
        axarr[0].set_xlabel('Original image')
        axarr[1].imshow(gt_slice, cmap='gray', vmin=0, vmax=2)
        axarr[1].set_xlabel('Ground truth')
        axarr[2].imshow(img_slice, cmap='gray')
        axarr[2].imshow(gt_color, alpha=0.3)
        axarr[2].set_xlabel('GT mask applied')
        axarr[3].imshow(mask_slice, cmap='gray', vmin=0, vmax=2)
        axarr[3].set_xlabel('Prediction')
        axarr[4].imshow(img_slice, cmap='gray')
        axarr[4].imshow(pred_color, alpha=0.3)
        axarr[4].set_xlabel('Predicted mask applied')
        axarr[5].imshow(np.abs(gt_slice-mask_slice), cmap='gray', vmin=0, vmax=2)
        axarr[5].set_xlabel('Difference')
        plt.show()

def compare_combination_method(gt, pred1, pred2, n):
    """
    Shows n random slices comparing two predictions (e.g. from two merging methods)
    with the ground truth and with each other.

    Each sample picks a random axis and a random slice index along that axis, retrying
    until the ground-truth slice contains more than one label.
    -----------
    Parameters:
        gt: numpy array,
            ground truth numpy array (3D)
        pred1: numpy array,
            first prediction to be compared, same shape as gt
        pred2: numpy array,
            second prediction to be compared, same shape as gt
        n: int,
            number of samples
    Raises:
        ValueError: if gt holds a single value everywhere; no slice could then satisfy
            the retry condition and the loop would never end.
    """
    if np.min(gt) == np.max(gt):
        raise ValueError("gt contains a single label; there is no informative slice to compare")
    axis_index = {'x': 0, 'y': 1, 'z': 2}
    title_offset = {'x': 0.64, 'y': 0.64, 'z': 0.58}
    shape = gt.shape
    for _ in range(n):
        while True:
            # Choose a random axis, then a slice index within that axis' extent
            axis = np.random.choice(['x', 'y', 'z'])
            dim = axis_index[axis]
            slice_id = np.random.randint(0, shape[dim])
            gt_slice = np.take(gt, slice_id, axis=dim)
            if len(np.unique(gt_slice)) != 1:
                break
        pred1_slice = np.take(pred1, slice_id, axis=dim)
        pred2_slice = np.take(pred2, slice_id, axis=dim)
        # Plot
        f, axarr = plt.subplots(1, 6, figsize=(19,19))
        f.suptitle('Slice n. {} on axis {}'.format(slice_id, axis), y=title_offset[axis], fontsize=14)
        axarr[0].imshow(gt_slice, cmap='gray', vmin=0, vmax=2)
        axarr[0].set_xlabel('Ground truth')
        axarr[1].imshow(pred1_slice, cmap='gray', vmin=0, vmax=2)
        axarr[1].set_xlabel('Prediction method 1')
        axarr[2].imshow(pred2_slice, cmap='gray', vmin=0, vmax=2)
        axarr[2].set_xlabel('Prediction method 2')
        axarr[3].imshow(np.abs(gt_slice-pred1_slice), cmap='gray', vmin=0, vmax=2)
        axarr[3].set_xlabel('Difference GT-Method1')
        axarr[4].imshow(np.abs(gt_slice-pred2_slice), cmap='gray', vmin=0, vmax=2)
        axarr[4].set_xlabel('Difference GT-Method2')
        axarr[5].imshow(np.abs(pred1_slice-pred2_slice), cmap='gray', vmin=0, vmax=2)
        axarr[5].set_xlabel('Difference Method1-Method2')
        plt.show()
    
def plot_confusion_matrix(cm, classes, title='Confusion matrix', cmap=plt.cm.Blues):
    """
    Draws a confusion matrix as a heat map and writes each cell's value on it.

    The matrix is plotted as given (no normalization happens here). Rows are labelled
    'True label' and columns 'Predicted label'. Values are printed with two decimals,
    in white on cells above half of the maximum and in black elsewhere.
    -----------
    Parameters:
        cm: numpy array,
            confusion matrix
        classes: list,
            list of labels used as tick labels on both axes
        title: str, Optional,
            title of the plot
        cmap: matplotlib.cm
            colormap for the plot
    """
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    positions = np.arange(len(classes))
    plt.xticks(positions, classes, rotation=45)
    plt.yticks(positions, classes)

    half_max = cm.max() / 2.
    for (row, col), value in np.ndenumerate(cm):
        text_color = "white" if value > half_max else "black"
        plt.text(col, row, format(value, '.2f'),
                 horizontalalignment="center",
                 color=text_color)

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()

## Evaluation

In [27]:
def weighted_dice_score_2D(input, target, eps=1e-6):
    """
    Computes the 2D weighted dice score between network logits and a label mask.
    Each class is weighted by 1 / count**2 of its ground-truth one-hot entries.
    -----------
    Parameters:
        input: Tensor,
            prediction logits with classes on dim 1; softmax is applied here
        target: Tensor,
            ground truth mask tensor, one-hot encoded with one_hot2D
        eps: float, Optional,
            weight used for classes whose 1 / count**2 weight is not finite
    Returns:
        The batch-averaged 2D weighted dice score as a Python float
    """
    with torch.no_grad():
        # Take the class count from the logits so softmax, one-hot encoding and the
        # reshapes below all use the same number of classes.
        num_classes = input.shape[1]

        # compute softmax over the classes axis
        input_soft = torch.nn.functional.softmax(input, dim=1)
        input_soft = input_soft.view(input_soft.shape[0], num_classes, -1) # (B, Nclasses, HxW)

        # create the labels one hot tensor
        target_one_hot = one_hot2D(target, num_classes=num_classes,
                                   device=input.device, dtype=input.dtype)
        target_one_hot = target_one_hot.view(target_one_hot.shape[0], num_classes, -1) # (B, Nclasses, HxW)

        # count n element for each class
        counts = torch.sum(target_one_hot, dim=2)
        weights = torch.as_tensor(1. / (counts ** 2), dtype=float)
        weights = torch.where(torch.isfinite(weights), weights, eps) # (B, Nclasses)

        # compute the actual dice score factors
        intersection = torch.sum(input_soft * target_one_hot, 2) # (B, Nclasses)
        cardinality = torch.sum(input_soft + target_one_hot, 2) # (B, Nclasses)

        # apply weights
        intersection = torch.sum(weights*intersection, axis=-1) # (B)
        cardinality = torch.sum(weights*cardinality, axis=-1) # (B)
        # compute dice score; non-finite results (e.g. 0/0) count as 0
        dice_score = 2. * intersection / cardinality
        dice_score = torch.where(torch.isfinite(dice_score), dice_score, torch.zeros_like(dice_score))

        return torch.mean(dice_score).item()

def one_hot_3D(labels, num_classes, device=None, dtype=None, eps=1e-6):
    """
    One-hot encodes a BxHxWxD integer label tensor into a BxNxHxWxD float tensor
    (B: batch size, H: height, W: width, D: depth, N: number of classes).
    eps is added to every entry of the result.
    -----------
    Parameters:
        labels: Tensor,
            BxHxWxD int64 tensor with the ground truth labels
        num_classes: int,
            total number of classes
        device: torch device, Optional,
            device on which the result is allocated
        dtype: Optional,
            accepted for API compatibility but not used; the result has torch's default dtype
        eps: float, Optional,
            value added to every entry of the encoding
    Returns:
        the one hot encoded tensor
    """
    b, h, w, d = labels.shape
    encoded = torch.zeros((b, num_classes, h, w, d), device=device)
    # Put 1.0 in the channel given by each voxel's label
    encoded.scatter_(1, labels.unsqueeze(1), 1.0)
    return encoded + eps

def weighted_dice_score_3D(input, target, eps=1e-6):
    """
    Computes the 3D weighted dice score between two label volumes.
    Each class is weighted by 1 / count**2 of its ground-truth one-hot entries.
    -----------
    Parameters:
        input: numpy array,
            predicted label volume (converted with torch.from_numpy)
        target: numpy array,
            ground truth label volume, same shape as input
        eps: float, Optional,
            weight used for classes whose 1 / count**2 weight is not finite
    Returns:
        The 3D weighted dice score as a Python float
    """
    def _encode(volume):
        # numpy label volume -> (1, Nclasses, HxWxD) one-hot tensor on `device`
        labels = torch.from_numpy(volume).unsqueeze(0).long().to(device)
        encoded = one_hot_3D(labels, num_classes=n_classes, device=device, dtype=input.dtype)
        return encoded.view(encoded.shape[0], n_classes, -1)

    with torch.no_grad():
        pred_oh = _encode(input)
        true_oh = _encode(target)

        class_weights = torch.as_tensor(1. / (true_oh.sum(dim=2) ** 2), dtype=float)
        class_weights = torch.where(torch.isfinite(class_weights), class_weights, eps)  # (B, Nclasses)

        # weighted overlap and weighted total volume, summed over classes -> (B)
        overlap = (class_weights * (pred_oh * true_oh).sum(2)).sum(dim=-1)
        total = (class_weights * (pred_oh + true_oh).sum(2)).sum(dim=-1)

        score = 2. * overlap / total
        # non-finite results (e.g. 0/0) count as 0
        score = torch.where(torch.isfinite(score), score, torch.zeros_like(score))
        return score.mean().item()

def dice_score_3D(input, target, eps=1e-6):
    """
    Computes the 3D dice score between two label volumes, averaged over the classes.
    -----------
    Parameters:
        input: numpy array,
            predicted label volume (converted with torch.from_numpy)
        target: numpy array,
            ground truth label volume, same shape as input
        eps: float, Optional,
            accepted for API compatibility; not used in this function
    Returns:
        The 3D dice score as a Python float
    """
    def _encode(volume):
        # numpy label volume -> (1, Nclasses, HxWxD) one-hot tensor on `device`
        labels = torch.from_numpy(volume).unsqueeze(0).long().to(device)
        encoded = one_hot_3D(labels, num_classes=n_classes, device=device, dtype=input.dtype)
        return encoded.view(encoded.shape[0], n_classes, -1)

    with torch.no_grad():
        pred_oh = _encode(input)
        true_oh = _encode(target)

        overlap = (pred_oh * true_oh).sum(2)  # (B, Nclasses)
        total = (pred_oh + true_oh).sum(2)    # (B, Nclasses)

        per_class = 2. * overlap / total
        # non-finite results (e.g. 0/0) count as 0
        per_class = torch.where(torch.isfinite(per_class), per_class, torch.zeros_like(per_class))
        return per_class.mean().item()
    
def each_label_dice_score_3D(input, target, eps=1e-6):
    """
    Computes the 3D dice score of each label between two label volumes.
    -----------
    Parameters:
        input: numpy array,
            predicted label volume (converted with torch.from_numpy)
        target: numpy array,
            ground truth label volume, same shape as input
        eps: float, Optional,
            accepted for API compatibility; not used in this function
    Returns:
        A numpy array of length n_classes with the 3D dice score of each label
    """
    def _encode(volume):
        # numpy label volume -> (1, Nclasses, HxWxD) one-hot tensor on `device`
        labels = torch.from_numpy(volume).unsqueeze(0).long().to(device)
        encoded = one_hot_3D(labels, num_classes=n_classes, device=device, dtype=input.dtype)
        return encoded.view(encoded.shape[0], n_classes, -1)

    scores = np.zeros(n_classes)
    with torch.no_grad():
        pred_oh = _encode(input)
        true_oh = _encode(target)

        for label in range(n_classes):
            pred_label = pred_oh[:, label, :]
            true_label = true_oh[:, label, :]

            overlap = (pred_label * true_label).sum(1)  # (B)
            total = (pred_label + true_label).sum(1)    # (B)

            label_dice = 2. * overlap / total
            # non-finite results (e.g. 0/0) count as 0
            label_dice = torch.where(torch.isfinite(label_dice), label_dice, torch.zeros_like(label_dice))
            scores[label] = label_dice.mean().item()

        return scores

## Predict

### Predict and print stats

In [28]:
# Do prediction if flag is True

model_config = {
    'model_name': '', # name of the model without prefix and processing to be loaded
    'approach': '2.5D', # 2.5D, transformer
    'version': 'best',  # best, checkpoint
    'fine_tuning': False, # True if the model has been fine tuned
    'freeze_encoder': True, # True if the model has been fine tuned with the encoder weights freezed, used only if fine_tuning:True
    'vit_name': 'R50+ViT-B_16', # Vit model name, used only if approach:'transformer'
    'n_skip': 2, # Vit n of skip, used only if approach:'transformer'
    'vit_patches_size': 16, # vit size of patches, used only if approach:'transformer'
    'encoder_name': 'resnet101', # name of the encoder, used only if approach:'2.5D'
    'encoder_weights': 'imagenet',  # weights for the encoder, used only if approach:'2.5D'
    'decoder_name': 'Unet', # name of the decoder, used only if approach:'2.5D'
    'attention': None, # to use or not the attention layers, used only if approach:'2.5D'
    'batch_norm': True, # batch_normalization technique, used only if approach:'2.5D'
    'weights_path': model_save_path + '', # path to weights to be loaded
    'norm_method': 'as_colab', # normalization method
    'lower': 0, # lower bound for normalization, used only if norm_method:'in_range'
    'higher': 255, # upper bound for normalization, used only if norm_method:'in_range'
    'eq_method': 'clahe+histeq', # equalization method
    'prefix': '' # custom prefix of the model to be loaded
}


prediction_dataset = 'challenge' # dataset on which to perform the evaluation: 'zenodo' or 'challenge'

if do_predict:

    # Store the current normalization and equalization settings and replace them with those
    # defined in model_config for the duration of the prediction, so the wrong ones are not used.
    prev_eq_method = eq_method
    prev_norm_method = norm_method
    prev_lower, prev_higher = lower, higher

    eq_method = model_config['eq_method']
    norm_method = model_config['norm_method']
    lower, higher = model_config['lower'], model_config['higher']

    try:
        if prediction_dataset not in ('zenodo', 'challenge'):
            # Without this check `scores` would be undefined below
            raise ValueError("Unknown prediction_dataset: {!r}".format(prediction_dataset))

        # Predict
        if merging_method == 'softmax':
            axis_w = [1, 1, 1]
            class_w = [1, 1, 1]
            if prediction_dataset == 'zenodo':
                scores = predict_all_patients_softmax(zenodo_mask_root, zenodo_proc, ['training','validation'], model_config, axis_w, class_w)
            else:
                scores = predict_all_patients_softmax(challenge_mask_root, challenge_proc, 'validation', model_config, axis_w, class_w, size_n_patients=10)
        else:
            if prediction_dataset == 'zenodo':
                scores = predict_all_patients(zenodo_mask_root, zenodo_proc, ['training','validation'], model_config)
            else:
                scores = predict_all_patients(challenge_mask_root, challenge_proc, 'validation', model_config, size_n_patients=10)

        print("\n\n\nResults:\n")
        for key, value in scores.items():
            if key != 'CM':
                print(key, value)

        plot_confusion_matrix(scores['CM'], [0,1,2],
                              title='Confusion matrix',
                              cmap=plt.cm.Blues)
    finally:
        # Always restore the previous settings, even if the prediction fails
        eq_method = prev_eq_method
        norm_method = prev_norm_method
        lower, higher = prev_lower, prev_higher