# **Biomarker regions segmentation with 3D U-net**

#### 1. Introduction

In this notebook we segment gray matter and subcortical nuclei from preprocessed MRI images.

*By proceeding with this notebook, you confirm that you have personal access [to the data](https://www.humanconnectome.org/study/hcp-young-adult/document/1200-subjects-data-release) and that you agree to the data [terms and conditions](https://www.humanconnectome.org/study/hcp-young-adult/data-use-terms).*


In [None]:
import scipy as sp
import scipy.misc
import matplotlib.pyplot as plt
import numpy as np
import os
%matplotlib inline

In [None]:
# !pip install --quiet --upgrade comet_ml
from comet_ml import Experiment
    
# Create a Comet experiment. The API key is read from the COMET_API_KEY
# environment variable instead of being hard-coded: a key committed to a
# notebook is readable by anyone who can see the notebook and should be revoked.
experiment = Experiment(
    api_key=os.environ.get('COMET_API_KEY'),
    project_name='segment-brain'
)

### Check the experiments:

**Baseline** - "6 classes, 4 encoding blocks, 8 out, Patch based, 64 batch, crop"
https://www.comet.ml/kondratevakate/mri-segmentation-2021/view/Q6KVHNSRxQ22hC0oM7NL1rF1B

#### 2. Mounting Google drive

Mount Google Drive in the Colab notebook. Follow the link and enter your personal authorization code:

from google.colab import drive
drive.mount('/content/drive')

data_dir = '/content/drive/My Drive/Skoltech Neuroimaging/NeuroML2020/data/seminars/anat/fs_segmentation'

data_list = os.listdir(data_dir)

#### 3. Defining the dataset

Defining the working dataset, where:

 1. `norm` is the normalised `T1` image processed with FreeSurfer 6.0,

 2. `aparc+aseg` is the segmentation mask of gray matter and subcortical nuclei from the FreeSurfer 6.0 `recon-all` pipeline.

 The U-Net model takes the `norm` image as input and the `aparc+aseg` mask as the target.

In [None]:
data_dir = '/home/neuro-ml-2002/anat/fs_segmentation'

In [None]:
import pandas as pd

labels_dir = '/home/neuro-ml-2002/anat/'

Defining a new `pd.DataFrame()` with `Subject`, `norm` and `aseg` columns:

In [None]:
data_list = pd.DataFrame(columns = ['Subject', 'norm', 'aseg'])

In [None]:
labels = pd.read_csv('./anat/unrestricted_hcp_freesurfer.csv')

In [None]:
data_list['Subject'] = labels['Subject']

Iterating through files and `Subjects` in ID list:

In [None]:
data_list 

In [None]:
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')


# Record, for every subject ID, the path of its `norm` and `aseg` files found
# in data_dir (a file belongs to a subject when the ID occurs in its name).
# Writes go through `.loc` on data_list itself: the chained form
# `data_list['norm'].iloc[j] = ...` assigns into a temporary Series, which under
# pandas Copy-on-Write silently leaves data_list unchanged.
subject_ids = [(idx, str(subject)) for idx, subject in data_list['Subject'].items()]

for fname in tqdm(os.listdir(data_dir)):
    fpath = data_dir + '/' + fname
    for idx, subject in subject_ids:
        if subject in fname:
            if 'norm' in fname:
                data_list.loc[idx, 'norm'] = fpath
            elif 'aseg' in fname:
                data_list.loc[idx, 'aseg'] = fpath

# keep only subjects for which both files were found
data_list.dropna(inplace=True)

In [None]:
data_list.head()

Let's take a closer look at the data:

In [None]:
# !pip install --quiet --upgrade nilearn
import nilearn
from nilearn import plotting

# visualising normalised image
img = nilearn.image.load_img(data_list['norm'].iloc[0])
plotting.plot_anat(img)

In [None]:
# visualising segmentation
img = nilearn.image.load_img(data_list['aseg'].iloc[0])
plotting.plot_anat(img)

In [None]:
np.unique(np.asanyarray(img.dataobj))

#### Defining a testing dataset

In [None]:
# Hold-out test subjects. Their volumes are expected in ./test, named
# HCP_T1_fs6_<subject>_norm.nii.gz (input) and
# HCP_T1_fs6_<subject>_aparc+aseg.nii.gz (segmentation target).
test_subjects = [100206, 100307, 100408]
test_norm_dir = './test'

# Same column layout as data_list: Subject, norm path, aseg path.
testing_data_list = pd.DataFrame({
    'Subject': test_subjects,
    'norm': [f'{test_norm_dir}/HCP_T1_fs6_{subject}_norm.nii.gz' for subject in test_subjects],
    'aseg': [f'{test_norm_dir}/HCP_T1_fs6_{subject}_aparc+aseg.nii.gz' for subject in test_subjects]
})


testing_data_list.head()

#### 4. Writing dataloader

We will use `TorchIO` library: https://torchio.readthedocs.io/

In [None]:
# !pip install --quiet --upgrade torchio

In [None]:
import torchio 
import enum
"""
    Code adapted from: https://github.com/fepegar/torchio#credits

        Credit: Pérez-García et al., 2020, TorchIO: 
        a Python library for efficient loading, preprocessing, 
        augmentation and patch-based sampling of medical images in deep learning.

"""

# Keys under which every torchio.Subject stores the input image and its labels.
MRI = 'MRI'
LABEL = 'LABEL'

# Epoch phase; each member's value is the readable name printed in logs.
Action = enum.Enum('Action', [
    ('TRAIN', 'Training'),
    ('VALIDATE', 'Validation'),
])

def get_torchio_dataset(inputs, targets, transform):
    """
    Build a torchio dataset from paired image / label file paths.

    Every (image_path, label_path) pair becomes a torchio.Subject holding the
    image under MRI (intensity image) and the labels under LABEL (label map).
    `transform` is attached to the dataset only when it is truthy.

    Returns (dataset, subjects), where subjects is the list wrapped by the dataset.
    """
    subjects = [
        torchio.Subject({
            MRI: torchio.Image(image_path, torchio.INTENSITY),
            LABEL: torchio.Image(label_path, torchio.LABEL),
        })
        for image_path, label_path in zip(inputs, targets)
    ]

    if not transform:
        return torchio.SubjectsDataset(subjects), subjects
    return torchio.SubjectsDataset(subjects, transform=transform), subjects

In [None]:
data, subjects = get_torchio_dataset(data_list['norm'], data_list['aseg'], False)
testing_data, testing_subjects = get_torchio_dataset(testing_data_list['norm'], testing_data_list['aseg'], False)

In [None]:
data[0]['MRI']

#### 5. Writing visualization tools for torch tensors


In [None]:
import matplotlib.pyplot as plt
import torch
import numpy as np
import nibabel

def plot_central_cuts(img, title=""):
    """
    Show the three orthogonal central slices of a 3D volume side by side.

    param img: torch.Tensor, np.ndarray or nibabel Nifti1Image. Leading
        channel/batch axes beyond three dimensions (e.g. CxDxHxW) are dropped
        by repeatedly taking index 0. Tensors may live on the GPU or require grad.
    param title: figure title.
    """
    if isinstance(img, torch.Tensor):
        # detach + cpu: .numpy() refuses tensors on the GPU or with grad enabled
        img = img.detach().cpu().numpy()
    elif isinstance(img, nibabel.nifti1.Nifti1Image):
        img = img.get_fdata()

    img = np.asarray(img)
    while img.ndim > 3:
        img = img[0]

    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(3 * 6, 6))
    axes[0].imshow(img[img.shape[0] // 2, :, :])
    axes[1].imshow(img[:, img.shape[1] // 2, :])
    axes[2].imshow(img[:, :, img.shape[2] // 2])
    fig.suptitle(title, fontsize=16)
    plt.show()
    
def plot_predicted(img, seg, gt, delta = 0, title=""):
    """
    Plot input, ground truth and predicted segmentation as a 3x3 grid.

    Columns are input, ground truth and prediction. Rows are the slices through
    the centre of each volume along axis 0, axis 1 and axis 2, shifted by
    `delta` voxels. Each axis uses its own centre, so non-cubic (e.g. cropped)
    volumes are still cut through their middle.

    param img: input volume. A torch.Tensor may be 3D, 4D (CxDxHxW, channel 0
        kept) or 5D (BxCxDxHxW, first item/channel kept); a nibabel
        Nifti1Image is converted with get_fdata().
    param seg: predicted labels; a torch.Tensor is indexed with [0] to drop
        its leading axis.
    param gt: ground-truth labels; handled like `seg`.
    param delta: voxel offset from the centre slice, applied on every axis.
    param title: optional figure title, drawn only when non-empty.
    Returns the matplotlib Figure so callers can log it (e.g. to Comet).
    """
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()
        if img.ndim == 5:
            img = img[0, 0]
        elif img.ndim == 4:
            img = img[0]
    elif isinstance(img, nibabel.nifti1.Nifti1Image):
        img = img.get_fdata()

    if isinstance(seg, torch.Tensor):
        seg = seg[0].detach().cpu().numpy()
    if isinstance(gt, torch.Tensor):
        gt = gt[0].detach().cpu().numpy()

    fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(20, 20))
    columns = ((img, 'input'), (gt, 'gt'), (seg, 'predicted segmentation'))
    for col, (volume, name) in enumerate(columns):
        # centre of *this* axis of *this* volume (plus delta)
        c0, c1, c2 = (volume.shape[axis] // 2 + delta for axis in range(3))
        axes[0, col].imshow(volume[c0, ...])
        axes[1, col].imshow(volume[:, c1, :])
        axes[2, col].imshow(volume[..., c2])
        axes[0, col].set_title(name, fontsize=16)

    for ax in axes.ravel():
        ax.axis('off')
    if title:
        fig.suptitle(title, fontsize=16)
    fig.tight_layout()

    return fig  # returned so the caller can log it to comet_ml
    

The class `torchio.SubjectsDataset` inherits from `torch.utils.data.Dataset`. It receives as input a list of `torchio.Subject` instances and an optional `torchio.transforms.Transform`.

The inputs to the subject class are instances of torchio.Image, such as torchio.ScalarImage or torchio.LabelMap. The image class will be used by the transforms to decide whether or not to perform the operation. For example, spatial transforms must apply to both, but intensity transforms must apply to scalar images only.

https://torchio.readthedocs.io/data/dataset.html

In [None]:
from torch.utils.data import DataLoader, Subset
from torchio import AFFINE, DATA, PATH, TYPE, STEM

In [None]:
print("Dataset size: {}".format(len(data)))
img = data[0][MRI]
seg = data[0][LABEL]
print("Image shape: {}".format(img.shape))
print("Segmentation shape: {}".format(seg.shape))
plot_central_cuts(img[DATA])
plot_central_cuts(seg[DATA])

In [None]:
print("Dataset size: {}".format(len(data)))
img = data[210][MRI]
seg = data[210][LABEL]
print("Image shape: {}".format(img.shape))
print("Segmentation shape: {}".format(seg.shape))
plot_central_cuts(img[DATA])
plot_central_cuts(seg[DATA])

In [None]:
data[0]['MRI']

In [None]:
testing_data[0]['MRI']

Let's choose the crop from the non-zero voxels of the MRI images, since all images may share some zero padding around the brain.

In [None]:
def get_crop(subjects, image_size=256):
    """
    Compute crop margins that remove the all-zero border shared by all subjects.

    For every spatial axis of subj['MRI']['data'][0], find the smallest index
    range containing every non-zero voxel of every subject. Turn it into a pair
    (voxels to drop at the start, voxels to drop at the end).

    param subjects: iterable of torchio.Subject (anything where
        subj['MRI']['data'] is a C x D x H x W tensor works).
    param image_size: spatial size of the (cubic) volumes; 256 for these data.
    Returns a 6-tuple ordered (axis1 start, axis1 end, axis0 start, axis0 end,
        axis2 start, axis2 end).
        NOTE(review): axes 0 and 1 are swapped relative to torchio's
        (w_ini, w_fin, h_ini, h_fin, d_ini, d_fin) Crop order — confirm this
        is intended before feeding the result to transforms.Crop.
    Subjects whose image is entirely zero are skipped; they have no extent.
    """
    crop = {axis: (image_size, 0) for axis in range(3)}

    for subj in tqdm(subjects):
        nonzero = (subj['MRI']['data'][0] != 0).cpu().numpy()
        if not nonzero.any():
            continue  # an empty volume has no first/last non-zero index
        for axis in range(3):
            other_axes = tuple(a for a in range(3) if a != axis)
            indices = np.flatnonzero(nonzero.any(axis=other_axes))
            lo, hi = indices[0], indices[-1] + 1
            crop[axis] = (min(crop[axis][0], lo), max(crop[axis][1], hi))

    # convert the kept range [lo, hi) into margins cut from each end
    for axis in range(3):
        crop[axis] = (crop[axis][0], image_size - crop[axis][1])

    return (crop[1][0], crop[1][1], crop[0][0], crop[0][1], crop[2][0], crop[2][1])

# get_crop scans every subject and is slow; the tuple below is presumably the
# cached output of an earlier get_crop(subjects) run — recompute it if the
# dataset changes.
# crop = get_crop(subjects)
crop = (49, 22, 49, 47, 19, 28)
print(crop)

## 2. Whole brain segmentation

Let's define the experiment for whole brain segmentation:

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.models.vgg import vgg11_bn
from torch.autograd import Function, Variable
from torch.utils.data import DataLoader, Subset
import torch.backends.cudnn as cudnn
from torch import optim


import random
import numpy as np
import pandas as pd

import sys
import os
from optparse import OptionParser
import time

import torchio
from torchio import transforms

# !pip install --quiet --upgrade unet 
from IPython.display import clear_output
import matplotlib.pyplot as plt
from unet import UNet

from sklearn.model_selection import train_test_split, StratifiedKFold, ShuffleSplit

import warnings
import multiprocessing

In [None]:
num_subjects = len(data)  # NOTE(review): not referenced anywhere visible in this notebook

# Fraction of the subjects used for training; the rest go to validation.
training_split_ratio = 0.9

# Remove the shared zero border with the precomputed `crop`, then pad 4 voxels
# back on every side.
train_transform = transforms.Compose([
    # several_transforms,
    transforms.Crop(crop),
    transforms.Pad(4)
])

# NOTE(review): unused — the get_sets call below passes val_transform=train_transform.
validation_transform = None

# Fixed random_state so the train/validation split is reproducible.
training_subjects, validation_subjects = train_test_split(
    subjects, train_size=training_split_ratio, shuffle=True, random_state=42
)

In [None]:
# training_subjects = subjects[:20]
# validation_subjects = subjects[20:40] # experimenting just on 20 first subjects
def get_sets(train_subjects, val_subjects, test_subjects, train_transform=None, val_transform=None):
    """
    Wrap three lists of subjects into torchio datasets.

    Training subjects get `train_transform`; validation and test subjects
    share `val_transform`.

    Returns (training_set, validation_set, testing_set).
    """
    specs = (
        (train_subjects, train_transform),
        (val_subjects, val_transform),
        (test_subjects, val_transform),
    )
    training_set, validation_set, testing_set = (
        torchio.SubjectsDataset(group, transform=group_transform)
        for group, group_transform in specs
    )
    return training_set, validation_set, testing_set

In [None]:
training_set, validation_set, testing_set = get_sets(training_subjects, validation_subjects, testing_subjects,
                                                    train_transform=train_transform, val_transform=train_transform)

print('Training set:', len(training_set), 'subjects')
print('Validation set:', len(validation_set), 'subjects')
print('Testing set:', len(testing_set), 'subjects')

In [None]:
CHANNELS_DIMENSION = 6  # class count: background (0) + the five groups below
SPATIAL_DIMENSIONS = 2, 3, 4

# FreeSurfer aparc+aseg label IDs grouped into the five foreground classes;
# the trailing comment is the class index each group is mapped to.
VENTRCL = [4, 5, 15, 43, 44, 72]  # 1
BRN_STEM = [16]  # 2
HIPPOCMPS = [17, 53]  # 3
AMYGDL = [18, 54]  # 4
# NOTE(review): right-hemisphere 2000/2001 are listed but left-hemisphere
# 1000/1001 are not — confirm this asymmetry is intended.
GM = [1002, 1003, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013,
      1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024,
      1025, 1026, 1027, 1028, 1029, 1030, 1031, 1032, 1033, 1034, 1035,
      2000, 2001, 2002, 2003, 2005, 2006, 2007, 2008, 2009, 2010, 2011,
      2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022,
      2023, 2024, 2025, 2026, 2027, 2028, 2029, 2030, 2031, 2032, 2033,
      2034, 2035]  # 5

LABELS = VENTRCL + BRN_STEM + HIPPOCMPS + AMYGDL + GM  # all labels of interest


def prepare_aseg(targets):
    """
    Map FreeSurfer aparc+aseg label IDs to the six training classes.

    Voxels whose label is not in LABELS become 0 (background). The remaining
    ones become 1 (ventricles), 2 (brain stem), 3 (hippocampus), 4 (amygdala)
    or 5 (cortical gray matter). `targets` may be a numpy array or anything
    numpy can convert (e.g. a CPU torch tensor); a numpy array of the same
    shape is returned.
    """
    remapped = np.where(np.isin(targets, LABELS), targets, 0)
    class_groups = (VENTRCL, BRN_STEM, HIPPOCMPS, AMYGDL, GM)
    for class_id, group in enumerate(class_groups, start=1):
        remapped = np.where(np.isin(remapped, group), class_id, remapped)
    return remapped

def prepare_batch(batch, device):
    """
    Pull (inputs, targets) out of a torchio batch and move both to `device`.

    Inputs are the MRI tensors as loaded. Targets are the LABEL tensors with
    aparc+aseg IDs remapped to class indices 0-5 by prepare_aseg (done on the
    CPU before the transfer).
    """
    images = batch[MRI][DATA].to(device)
    class_map = torch.from_numpy(prepare_aseg(batch[LABEL][DATA])).to(device)
    return images, class_map

In [None]:
plot_central_cuts(validation_set[1][MRI][DATA])

In [None]:
prepare_aseg(validation_set[1][LABEL][DATA][0]).shape

In [None]:
print(np.unique(prepare_aseg(seg[DATA])[0]))

plot_central_cuts(prepare_aseg(validation_set[1][LABEL][DATA][0]))

The data is very heavy, so let's start with one subject per batch:

#### Defining the model and optimizer for training

First, check whether a GPU is available:

In [None]:
torch.cuda.is_available()

In [None]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
class DiceLoss(nn.Module):
    """
    Multiclass soft Dice loss.

    forward(inputs, target, weight=None, softmax=False) expects `inputs` of
    shape (B, n_classes, ...) — logits when softmax=True, probabilities
    otherwise — and an integer `target` of shape (B, ...). Dice is computed
    per class over the whole batch. The loss is the (optionally weighted) sum
    of per-class (1 - Dice), divided by n_classes.

    Returns (loss, class_wise_dice), where class_wise_dice is a list of the
    per-class Dice scores as Python floats.
    Raises ValueError if the one-hot target and `inputs` shapes differ.
    """

    def __init__(self, n_classes):
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes

    def _one_hot_encoder(self, input_tensor):
        """Return a float one-hot encoding of `input_tensor` with classes on dim 1."""
        tensor_list = [(input_tensor == i).unsqueeze(1) for i in range(self.n_classes)]
        return torch.cat(tensor_list, dim=1).float()

    def _dice_loss(self, score, target):
        """Return 1 - soft Dice between `score` and `target`, summed over all elements."""
        target = target.float()
        smooth = 1e-5  # keeps the ratio defined when a class is absent from both tensors
        intersect = torch.sum(score * target)
        y_sum = torch.sum(target * target)
        z_sum = torch.sum(score * score)
        return 1 - (2 * intersect + smooth) / (z_sum + y_sum + smooth)

    def forward(self, inputs, target, weight=None, softmax=False):
        if softmax:
            inputs = torch.softmax(inputs, dim=1)
        target = self._one_hot_encoder(target)
        if weight is None:
            weight = [1] * self.n_classes
        # explicit check: an `assert` here would vanish under `python -O`
        if inputs.size() != target.size():
            raise ValueError('predict {} & target {} shape do not match'.format(inputs.size(), target.size()))
        class_wise_dice = []
        loss = 0.0
        for i in range(self.n_classes):
            dice = self._dice_loss(inputs[:, i], target[:, i])
            class_wise_dice.append(1.0 - dice.item())
            loss += dice * weight[i]
        return loss / self.n_classes, class_wise_dice

In [None]:
def get_model_and_optimizer(device, num_encoding_blocks=4, out_channels_first_layer=8, patience=3):
    """
    Build a seeded 3D U-Net plus an AdamW optimizer and a ReduceLROnPlateau scheduler.

    numpy and torch (CPU and every CUDA device) are seeded with 0 and cuDNN is
    made deterministic before the model is created. The author notes training
    works better with num_encoding_blocks >= 3 and out_channels_first_layer >= 4.

    param device: device the model is moved to.
    param num_encoding_blocks: number of U-Net encoder blocks.
    param out_channels_first_layer: channel count of the first layer.
    param patience: ReduceLROnPlateau patience.
    Returns (model, optimizer, scheduler).
    """
    # reproducibility
    np.random.seed(0)
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    unet_config = dict(
        in_channels=1,
        out_classes=6,
        dimensions=3,
        num_encoding_blocks=num_encoding_blocks,
        out_channels_first_layer=out_channels_first_layer,
        normalization='batch',
        upsampling_type='linear',
        padding=True,
        activation='PReLU',
    )
    model = UNet(**unet_config).to(device)

    optimizer = torch.optim.AdamW(model.parameters())
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.1, patience=patience, threshold=0.01,
    )
    return model, optimizer, scheduler

# model, optimizer, scheduler = get_model_and_optimizer(device)

In [None]:
def get_loaders(training_set, validation_set):
    """
    Wrap the datasets in torchio patch queues and return (training_loader, validation_loader).

    Both queues sample patches uniformly with the module-level `patch_size`.
    They also read the module-level max_queue_length, samples_per_volume,
    num_training_workers / num_validation_workers and
    training_batch_size / validation_batch_size.
    Only the training queue shuffles subjects and patches.
    """
    def make_queue(dataset, workers, shuffle):
        return torchio.Queue(
            subjects_dataset=dataset,
            max_length=max_queue_length,
            samples_per_volume=samples_per_volume,
            sampler=torchio.sampler.UniformSampler(patch_size),
            num_workers=workers,
            shuffle_subjects=shuffle,
            shuffle_patches=shuffle,
        )

    training_queue = make_queue(training_set, num_training_workers, True)
    validation_queue = make_queue(validation_set, num_validation_workers, False)

    training_loader = torch.utils.data.DataLoader(
        training_queue, batch_size=training_batch_size)
    validation_loader = torch.utils.data.DataLoader(
        validation_queue, batch_size=validation_batch_size, shuffle=False)
    return training_loader, validation_loader

### Code for training

In [None]:
import SimpleITK as sitk
from einops import rearrange
from collections import Counter
from torchvision.utils import make_grid

summary_freq_train = 50  # log training patch images every N batches
summary_freq_val = 10    # log a validation figure every N subjects


def evaluate(model, evaluation_set, patch_size=64, patch_overlap=0, epoch=None):
    """
    Run patch-based inference on every subject of `evaluation_set` and score it.

    Each subject is split into a grid of patches (torchio GridSampler). The
    model's logits are stitched back together with a GridAggregator. The
    per-class Dice scores of DiceLoss(6) (with softmax) against the
    prepare_aseg-mapped labels are then collected. For every
    `summary_freq_val`-th subject an input/gt/prediction figure is logged to Comet.

    param model: segmentation network; put into eval mode here.
    param evaluation_set: torchio SubjectsDataset to score.
    param patch_size, patch_overlap: grid sampling parameters.
    param epoch: only offsets the Comet logging step (None means 0).
    Returns {'dice': [per-class Dice list for each subject]}.
    """
    if epoch is None:
        epoch = 0
    dice_scores = []
    dice_loss = DiceLoss(6)
    model.eval()

    for i in tqdm(range(len(evaluation_set)), leave=False):
        sample = evaluation_set[i]
        input_tensor = sample[MRI][DATA][0]
        targets = torch.from_numpy(prepare_aseg(sample[LABEL][DATA]))

        grid_sampler = torchio.inference.GridSampler(
            sample,
            patch_size,
            patch_overlap,
        )
        # single-process loader (num_workers left at its default)
        patch_loader = torch.utils.data.DataLoader(
            grid_sampler, batch_size=validation_batch_size)
        aggregator = torchio.inference.GridAggregator(grid_sampler)

        with torch.no_grad():
            for patches_batch in patch_loader:
                inputs = patches_batch[MRI][DATA].to(device)
                locations = patches_batch['location']
                logits = model(inputs.float())
                aggregator.add_batch(logits, locations)

            prediction = aggregator.get_output_tensor()
            _, dice_score_ = dice_loss(prediction.unsqueeze(0),
                                       targets,
                                       softmax=True)
            if i % summary_freq_val == 0:
                pred = F.softmax(prediction, dim=0).argmax(0, True)
                fig = plot_predicted(input_tensor, pred, targets)
                experiment.log_figure('validation predictions', fig,
                                      step=i + epoch * len(evaluation_set))
                # close once logged so figures do not pile up in pyplot over epochs
                plt.close(fig)
            dice_scores.append(dice_score_)

    return {
        'dice': dice_scores
    }



def train(num_epochs, training_loader, validation_set, model, optimizer, start_epoch=0, scheduler=None,
          weights_stem='', patch_size=64, patch_overlap=0):
    """
    Train for `num_epochs` epochs, validating after each one and keeping the best weights.

    The model is evaluated once before training and that state is saved. The
    checkpoint is overwritten whenever a later epoch reaches a higher mean
    validation Dice. Weights are written to <model_dir>/model_<weights_stem>.pth,
    where `model_dir` is the module-level setting; the directory is created if
    missing.

    param start_epoch: epoch number of the initial evaluation; training epochs
        are numbered start_epoch + 1 .. start_epoch + num_epochs.
    param scheduler: optional LR scheduler forwarded to run_epoch.
    param patch_size, patch_overlap: grid parameters for evaluate().
    """
    os.makedirs(model_dir, exist_ok=True)  # torch.save does not create directories
    checkpoint_path = os.path.join(model_dir, f'model_{weights_stem}.pth')

    def _validate(epoch):
        """Evaluate on validation_set, log each score's mean to Comet and return the raw scores."""
        scores = evaluate(model, validation_set, patch_size=patch_size,
                          patch_overlap=patch_overlap, epoch=epoch)
        for key, values in scores.items():
            experiment.log_metric(f"avg_val_{key}", np.mean(values), step=epoch, epoch=epoch)
        return scores

    scores = _validate(start_epoch)
    best_dice = np.mean(scores['dice'])
    # scores['dice'] holds one list of per-class Dice values per subject
    print(f"Validation mean score: DICE {best_dice:0.3f}", "by class:",
          np.mean(scores['dice'], axis=0))

    step_counter = Counter()
    torch.save(model.state_dict(), checkpoint_path)
    for epoch_idx in range(start_epoch + 1, start_epoch + num_epochs + 1):
        print('\nStarting epoch', epoch_idx)
        run_epoch(epoch_idx, Action.TRAIN, training_loader, model, optimizer, step_counter,
                  scheduler=scheduler)

        avg_dice = np.mean(_validate(epoch_idx)['dice'])
        print(f"Validation mean score: DICE {avg_dice:0.3f}")

        if avg_dice > best_dice:
            best_dice = avg_dice
            torch.save(model.state_dict(), checkpoint_path)
        
        
# Hyper-parameters used when run_epoch receives a loss *class*. Keyed by class
# name so that recognising FocalLoss / TverskyLoss does not require those
# classes to exist yet (they are defined in a later cell).
_LOSS_KWARGS = {
    'DiceLoss': {'n_classes': 6},
    'FocalLoss': {'alpha': 0.99, 'gamma': 0.2},
    'TverskyLoss': {'alpha': 0.3, 'gamma': 0.75},
}


def _build_criterion(loss):
    """Return `loss` if it is already a module, else instantiate a known loss class."""
    if isinstance(loss, nn.Module):
        return loss
    if isinstance(loss, type):
        # walk the MRO so subclasses of a known loss are accepted as well
        for klass in loss.__mro__:
            if klass.__name__ in _LOSS_KWARGS:
                return loss(**_LOSS_KWARGS[klass.__name__])
    raise NotImplementedError(f'unsupported loss: {loss!r}')


def _log_train_patches(logits, targets, step):
    """
    Log the central slice of each patch's prediction and ground truth to Comet.

    Each sub-grid stacks the predicted class map above the ground-truth class
    map; both use the raw class indices.
    """
    patch_d = logits.shape[2]
    columns = []
    for sample_idx in range(logits.shape[0]):
        pred = torch.argmax(logits[sample_idx], dim=0)
        pred_slice = pred[patch_d // 2].float().unsqueeze(0)
        gt_slice = targets[sample_idx, 0, patch_d // 2].float().unsqueeze(0)
        columns.append(make_grid([pred_slice, gt_slice], nrow=1))
    grid_img = rearrange(make_grid(columns).cpu().numpy(), 'c h w -> h w c')
    experiment.log_image(grid_img, name='train patches (odd rows -- pred, even rows -- gt)',
                         step=step)


def run_epoch(epoch_idx, action, loader, model, optimizer, step_counter, scheduler=None, loss=DiceLoss):
    """
    Run one pass over `loader`, training or only scoring depending on `action`.

    For every batch, inputs/targets come from prepare_batch and the loss is
    computed on softmaxed logits. When training, the loss is back-propagated
    and the optimizer stepped. Each batch loss is logged to Comet. While
    training, central slices of predictions and ground truth are also logged
    every `summary_freq_train` batches. At the end the mean loss is logged and,
    when training, passed to `scheduler.step` (the ReduceLROnPlateau API).

    param epoch_idx: epoch number used for Comet logging.
    param action: Action.TRAIN or Action.VALIDATE.
    param step_counter: collections.Counter of global steps keyed by action;
        incremented once per batch.
    param loss: a loss class (DiceLoss, FocalLoss, TverskyLoss or a subclass),
        instantiated here with fixed hyper-parameters, or a ready nn.Module.
    Returns the mean batch loss of the epoch.
    Raises NotImplementedError for an unsupported `loss`.
    """
    is_training = action == Action.TRAIN
    criterion = _build_criterion(loss)

    epoch_losses = []
    model.train(is_training)

    for batch_idx, batch in enumerate(tqdm(loader, leave=False)):
        inputs, targets = prepare_batch(batch, device)
        optimizer.zero_grad()
        with torch.set_grad_enabled(is_training):
            logits = model(inputs.float())
            batch_losses, _ = criterion(logits,
                                        targets.squeeze(1),
                                        softmax=True)
            batch_loss = batch_losses.mean()
            if is_training:
                batch_loss.backward()
                optimizer.step()

            epoch_losses.append(batch_loss.item())
            if is_training:
                experiment.log_metric("train_dice_loss", batch_loss.item(),
                                      epoch=epoch_idx, step=step_counter[action])
                # log images with lower frequency
                if batch_idx % summary_freq_train == 0:
                    _log_train_patches(logits, targets, step_counter[action])
            elif action == Action.VALIDATE:
                experiment.log_metric("val_dice_loss", batch_loss.item(),
                                      epoch=epoch_idx, step=step_counter[action])
            step_counter[action] += 1

    avg_loss = np.array(epoch_losses).mean()

    if action == Action.TRAIN:
        experiment.log_metric("avg_train_dice_loss", avg_loss, step=epoch_idx, epoch=epoch_idx)
        if scheduler is not None:
            scheduler.step(avg_loss)
    elif action == Action.VALIDATE:
        experiment.log_metric("avg_val_dice", 1 - avg_loss, step=epoch_idx, epoch=epoch_idx)

    print(f'{action.value} mean loss: {avg_loss:0.3f}')
    return avg_loss


## 3. Patch-based segmentation

Let's define another experiment within the same workspace in `COMET ML`:

In [None]:
experiment.set_name("6 classes, 4 encoding blocks, 8 out, Patch based, 64 batch, crop with images")
model_dir = './logs/6cls_4enc_8out_patch_64batch_crop'

In [None]:
patch_size = 64
samples_per_volume = 8
max_queue_length = 240
training_batch_size = 16
validation_batch_size = 4
num_training_workers = 8
num_validation_workers = 1

In [None]:
training_loader, validation_loader = get_loaders(training_set, validation_set)

In [None]:
torch.cuda.empty_cache()

In [None]:
model, optimizer, scheduler = get_model_and_optimizer(device)
weights_stem = '6_classes_4_blocks_8_chanels'


In [None]:
num_epochs = 10

# Pass scheduler by keyword: train()'s sixth positional parameter is
# start_epoch, so a positional scheduler would be used as the epoch number.
train(num_epochs, training_loader, validation_set, model, optimizer, start_epoch=0, scheduler=scheduler,
      weights_stem=weights_stem, patch_overlap=0)

In [None]:
# checking
# Sanity check: run the model on one training batch and inspect one patch.
# Names are chosen so they do not shadow the `enum` module or the `labels` table.
batch = next(iter(training_loader))
inputs, targets = prepare_batch(batch, device)

sample_idx, mid_slice = 2, 32  # patch index in the batch (needs batch size >= 3), centre of a 64^3 patch

plt.imshow(inputs.cpu()[sample_idx, 0, mid_slice])
plt.show()
plt.imshow(targets.cpu()[sample_idx, 0, mid_slice])
plt.show()

# eval + no_grad: BatchNorm uses its running statistics (and does not update
# them), and no autograd graph is kept for this inspection-only forward pass.
model.eval()
with torch.no_grad():
    logits = model(inputs.float())
plt.imshow(logits.cpu().numpy()[sample_idx, 0, mid_slice])
plt.show()

print(targets.size(), logits.size())
dice_loss = DiceLoss(6)

dice_loss_, dice = dice_loss(logits, targets.squeeze(1), softmax=True)
print(targets.squeeze(1).size(), logits.size())

pred = logits.argmax(1, True)  # (B, 1, D, H, W)
plt.imshow(pred.cpu().numpy()[sample_idx, 0, mid_slice])
plt.show()

print(dice)

In [None]:
model.load_state_dict(torch.load(os.path.join(model_dir, f'model_{weights_stem}.pth'), map_location=device))
test_scores = evaluate(model, testing_set, patch_size=64, patch_overlap=0, epoch=0)
print(test_scores)
print(f"\nTesting mean score: DICE {np.mean(test_scores['dice']):0.3f}")  

In [None]:
experiment.log_metric("avg_test_dice", np.mean(test_scores['dice']))
for i, subject in enumerate(test_subjects):
    experiment.log_metric(f"test_subj_{subject}_dice", np.mean(test_scores['dice'][i]))

### Illustrate prediction for random sample from validation set

In [None]:
import nibabel as nib
sample = random.choice(validation_set)
input_tensor = sample[MRI][DATA][0]
patch_size = 64, 64, 64  # we can user larger or smaller patches for inference
patch_overlap = 20
grid_sampler = torchio.inference.GridSampler(
    sample,
    patch_size,
    patch_overlap,
)
patch_loader = torch.utils.data.DataLoader(
    grid_sampler, batch_size=validation_batch_size)
aggregator = torchio.inference.GridAggregator(grid_sampler)

model.eval()
with torch.no_grad():
    for patches_batch in patch_loader:
        inputs = patches_batch[MRI][DATA].to(device)
#         print(inputs.unique())
        locations = patches_batch['location']
        logits = model(inputs.float())
        labels = logits.argmax(dim=1, keepdim=True)
        aggregator.add_batch(labels, locations)
        
plot_central_cuts(aggregator.get_output_tensor())

## Augmentation experiments

### #1

In [None]:
### MY ADDITION: STARTS HERE ###
my_transforms = [transforms.RandomGamma(log_gamma=(-0.3, 0.3)),
                 # TODO: check augmentations examples
                 transforms.RandomBiasField(),
                 transforms.RandomMotion(),
                 transforms.RandomGhosting(),
                 transforms.OneOf({transforms.RandomNoise(): 0.5, 
                                   transforms.RandomBlur(): 0.5})
                ] # AUGMENTATIONS FOR TRAINING

### MY ADDITION: ENDS HERE ###

In [None]:
plot_central_cuts(training_set[0][MRI][DATA], title='No transforms')
for t in my_transforms:
    plot_central_cuts(t(training_set[0][MRI][DATA]), title=str(t))

In [None]:
train_transform = transforms.Compose([
    transforms.Crop(crop),
    transforms.Pad(4),
    *my_transforms
])
val_transform = transforms.Compose([
    # several_transforms,
    transforms.Crop(crop),
    transforms.Pad(4)
])

training_set_augs, validation_set_augs, testing_set_augs = get_sets(training_subjects, validation_subjects, testing_subjects,
                                                                    train_transform=train_transform, val_transform=val_transform)


In [None]:
patch_size = 64
samples_per_volume = 8
max_queue_length = 240
training_batch_size = 16
validation_batch_size = 4
num_training_workers = 8
num_validation_workers = 1

In [None]:
training_loader, validation_loader = get_loaders(training_set_augs, validation_set_augs)

In [None]:
experiment.set_name("augmentations + 6 classes, 4 encoding blocks, 8 out, Patch based, 64 batch, crop with images")
model_dir = './logs/augmentations_6cls_4enc_8out_patch_64batch_crop'

In [None]:
torch.cuda.empty_cache()

In [None]:
model, optimizer, scheduler = get_model_and_optimizer(device)
weights_stem = 'aug_6_classes_4_blocks_8_chanels'

In [None]:
num_epochs = 50
    
train(num_epochs, training_loader, validation_set, model, optimizer, start_epoch=0, scheduler=scheduler,
      weights_stem=weights_stem, patch_overlap=0)

### #2

In [None]:
### MY ADDITION: STARTS HERE ###
my_transforms = [torchio.transforms.RandomFlip(axes=(0, 1, 2), flip_probability=0.5),
                 torchio.transforms.RandomAffine(),
                 transforms.RandomGamma(log_gamma=(-0.3, 0.3)),
                 # TODO: check augmentations examples
                 transforms.RandomBiasField(),
                 transforms.RandomMotion(),
                 transforms.RandomGhosting(),
                 transforms.RandomNoise(mean=0, std=(0, 3)),
                 transforms.RandomBlur(std=(0, 1))
                ] # AUGMENTATIONS FOR TRAINING

### MY ADDITION: ENDS HERE ###

In [None]:
plot_central_cuts(training_set[0][MRI][DATA], title='No transforms')
for t in my_transforms:
    plot_central_cuts(t(training_set[0][MRI][DATA]), title=str(t))

In [None]:
train_transform = transforms.Compose([
    transforms.Crop(crop),
    transforms.Pad(4),
    *my_transforms
])
val_transform = transforms.Compose([
    # several_transforms,
    transforms.Crop(crop),
    transforms.Pad(4)
])

training_set_augs, validation_set_augs, testing_set_augs = get_sets(training_subjects, validation_subjects, testing_subjects,
                                                                    train_transform=train_transform, val_transform=val_transform)


In [None]:
patch_size = 64
samples_per_volume = 8
max_queue_length = 240
training_batch_size = 16
validation_batch_size = 4
num_training_workers = 8
num_validation_workers = 1

In [None]:
training_loader, validation_loader = get_loaders(training_set_augs, validation_set_augs)

In [None]:
experiment.set_name("second augmentations + 6 classes, 4 encoding blocks, 32 out, Patch based, 64 batch, crop with images")
# "32out" matches out_channels_first_layer=32 used for this experiment's model
model_dir = './logs/second_augmentations_6cls_4enc_32out_patch_64batch_crop'

In [None]:
torch.cuda.empty_cache()

In [None]:
model, optimizer, scheduler = get_model_and_optimizer(device, num_encoding_blocks=4, out_channels_first_layer=32)
weights_stem = 'second_aug_6_classes_4_blocks_32_chanels'

In [None]:
num_epochs = 50
    
train(num_epochs, training_loader, validation_set, model, optimizer, start_epoch=0, scheduler=scheduler,
      weights_stem=weights_stem, patch_overlap=0)

## Losses experiments

In [None]:
def _spatial_target(target):
    """Return the 3-D label volume contained in ``target``.

    A 5-D tensor yields ``target[0, 0]`` and a 4-D tensor yields ``target[0]``
    (presumably the leading axes are batch/channel of size 1). Any other tensor
    is returned unchanged.
    """
    if target.dim() == 5:
        return target[0, 0]
    if target.dim() == 4:
        return target[0]
    return target


def focal_loss(inputs, target, alpha, gamma, epsilon=1e-5):
    """Compute the un-reduced multi-class focal loss for one volume.

    Args:
        inputs: class scores of shape ``(C, H, W, D)``; softmax is taken over
            dim 0.
        target: integer-valued label map (5-D, 4-D or 3-D, see
            ``_spatial_target``) with labels in ``[0, C)``.
        alpha: scalar or tensor broadcastable to the ``(H, W, D)`` loss map.
        gamma: focusing exponent; ``gamma=0`` gives alpha-weighted
            cross-entropy.
        epsilon: class probabilities are clamped to
            ``[epsilon, 1 - epsilon]`` so ``log`` never sees 0.

    Returns:
        Tensor of shape ``(H, W, D)`` with the loss at every voxel.
    """
    tg = _spatial_target(target).to(torch.int64)
    # y_true must have as many classes as the predictions it is multiplied with.
    n_classes = inputs.shape[0]
    # one_hot yields (H, W, D, C); move the class axis to the front like inputs.
    y_true = F.one_hot(tg, n_classes).permute(3, 0, 1, 2)  # CxHxWxD
    y_pred = torch.clamp(F.softmax(inputs, dim=0), epsilon, 1. - epsilon)  # CxHxWxD
    # Probability assigned to the true class at every voxel.
    pt = (y_true * y_pred).sum(dim=0)

    return -1 * alpha * torch.pow((1 - pt), gamma) * torch.log(pt)


class FocalLoss(nn.Module):
    """Mean focal loss with optional per-class ``alpha`` weights.

    Args:
        alpha: ``None`` (every class weighted 1), a number (the same weight for
            every class) or a sequence/tensor with one weight per class, shape
            ``(C,)`` or ``(C, 1)``.
        gamma: focusing exponent forwarded to ``focal_loss``.
        epsilon: probability clamp forwarded to ``focal_loss``.
    """

    def __init__(self, alpha=None, gamma=1, epsilon=1e-5):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon

    def _class_weights(self, n_classes, device):
        """Return ``self.alpha`` as a 1-D weight tensor of length ``n_classes``."""
        if self.alpha is None:
            return torch.ones(n_classes, device=device)
        if isinstance(self.alpha, (int, float)):
            return torch.full((n_classes,), float(self.alpha), device=device)
        return torch.as_tensor(self.alpha, device=device).reshape(-1)

    def forward(self, inputs, target):
        """Return ``(mean focal loss, None)``; class-wise values are not computed."""
        labels = _spatial_target(target).to(torch.int64)
        weights = self._class_weights(inputs.shape[0], labels.device)
        # Per-voxel weight looked up by label: same (H, W, D) shape as the loss map,
        # so broadcasting cannot misalign even when a spatial extent is 1.
        alpha = weights[labels]

        loss = focal_loss(inputs, target, alpha, self.gamma, self.epsilon)

        return loss.mean(), None  # classwise is not computed
    
class TverskyLoss(nn.Module):
    """Class-averaged (focal) Tversky loss on soft predictions.

    For each class the Tversky index

        TI = (2*TP + s) / (2*TP + alpha*FP + (1 - alpha)*FN + s)

    is computed with soft counts, and the per-class loss is ``(1 - TI) ** gamma``.
    ``alpha = 0.5`` weighs FP and FN equally, which gives the soft Dice score.

    Args:
        alpha: weight of false positives; false negatives get ``1 - alpha``.
        gamma: exponent applied to ``1 - TI`` (``gamma = 1`` is plain Tversky).
        n_classes: number of classes; if ``None`` it is taken from
            ``inputs.size(1)`` in ``forward``.
    """

    def __init__(self, alpha=0.5, gamma=1, n_classes=None):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.n_classes = n_classes

    def _one_hot_encoder(self, input_tensor, n_classes=None):
        """One-hot encode an integer label tensor along a new axis 1 (as float)."""
        if n_classes is None:
            n_classes = self.n_classes
        tensor_list = [(input_tensor == i).unsqueeze(1) for i in range(n_classes)]
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()

    def _tversky_loss(self, score, target):
        """Return ``(1 - TI) ** gamma`` for one class' soft scores and binary target."""
        target = target.float()
        smooth = 1e-5
        intersect = torch.sum(score * target)  # TP
        fp = torch.sum(score * (1 - target))  # FP
        fn = torch.sum((1 - score) * target)  # FN

        # Tversky index: (2TP + s) / (2TP + alpha*FP + (1 - alpha)*FN + s)
        loss = (2 * intersect + smooth) / (
            2 * intersect + self.alpha * fp + (1 - self.alpha) * fn + smooth
        )

        loss = torch.pow((1 - loss), self.gamma)
        return loss

    def forward(self, inputs, target, weight=None, softmax=False):
        """Compute the weighted mean loss over classes.

        Args:
            inputs: scores with the class axis at dim 1 (``(B, C, ...)``).
            target: integer labels with the same shape minus the class axis.
            weight: optional per-class multipliers (default: all 1).
            softmax: apply softmax over dim 1 to ``inputs`` first.

        Returns:
            ``(loss, class_wise)`` where ``class_wise[i] = 1 - loss_i``; this
            equals the Tversky index only when ``gamma == 1``.
        """
        if softmax:
            inputs = torch.softmax(inputs, dim=1)
        n_classes = self.n_classes if self.n_classes is not None else inputs.size(1)
        target = self._one_hot_encoder(target, n_classes)
        if weight is None:
            weight = [1] * n_classes
        assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())
        class_wise_tv = []
        tv_loss = 0.0
        for i in range(n_classes):
            tv = self._tversky_loss(inputs[:, i], target[:, i])
            class_wise_tv.append(1 - tv.item())
            tv_loss += tv * weight[i]

        return tv_loss / n_classes, class_wise_tv

In [None]:
# Per-class focal weights 2, 3, ..., 7 (one per label class).
# NOTE(review): focal-loss alpha is usually in [0, 1]; weights up to 7 scale
# the loss accordingly — confirm this is intended.
loss = FocalLoss(alpha=torch.arange(2.0, 8.0))

In [None]:
# Re-create the trained architecture (4 encoding blocks, 32 channels in the
# first layer) on the CPU for inference; the returned optimizer and scheduler
# are not used by the cells below.
model, optimizer, scheduler = get_model_and_optimizer('cpu', num_encoding_blocks=4, out_channels_first_layer=32)

In [None]:
# Load the trained weights into the CPU model. map_location='cpu' lets a
# checkpoint saved from GPU tensors load on machines without CUDA and avoids
# materialising it on the GPU first.
model.load_state_dict(torch.load(
    './logs/second_augmentations_6cls_4enc_8out_patch_64batch_crop/'
    'model_second_aug_6_classes_4_blocks_32_chanels.pth',
    map_location='cpu',
))

In [None]:
import nibabel as nib  # NOTE(review): not used in this cell
# Pick a random validation subject for a qualitative inference check.
sample = random.choice(validation_set)
input_tensor = sample[MRI][DATA][0]  # NOTE(review): not used below
# Ground-truth labels after prepare_aseg, converted to a torch tensor.
targets = torch.from_numpy(prepare_aseg(sample[LABEL][DATA]))
# NOTE(review): these overwrite the notebook-global patch_size / patch_overlap.
patch_size = 64, 64, 64  # we can use larger or smaller patches for inference
patch_overlap = 20
# Tile the whole subject into overlapping patches of patch_size.
grid_sampler = torchio.inference.GridSampler(
    sample,
    patch_size,
    patch_overlap,
)
patch_loader = torch.utils.data.DataLoader(
    grid_sampler, batch_size=16)
# Collects per-patch outputs at their locations and stitches the full volume.
aggregator = torchio.inference.GridAggregator(grid_sampler)

model.eval()
with torch.no_grad():
    for patches_batch in tqdm(patch_loader):
        inputs = patches_batch[MRI][DATA].to('cpu')
#         print(inputs.unique())
        locations = patches_batch['location']
        logits = model(inputs.float())
#         labels = logits.argmax(dim=1, keepdim=True)
        # Raw logits (not argmax labels) are aggregated, so the stitched output
        # can be both argmax-ed for display and fed to the loss later.
        aggregator.add_batch(logits, locations)
        
output = aggregator.get_output_tensor()

In [None]:
# Slice index a quarter of the way along targets' axis 1; it is then used to
# slice the first axis of targets[0] (and of the prediction in the next cell).
# NOTE(review): assumes targets is (1, H, W, D) so both refer to the same axis.
i = targets.shape[1]//4
plt.imshow(targets[0][i, ...])
plt.colorbar()

In [None]:
# Predicted labels (argmax over the class axis, dim 0) for the same slice i.
plt.imshow(output.argmax(dim=0)[i, ...])
plt.colorbar()

In [None]:
# Loss of the stitched logits against the ground truth. unsqueeze(0) adds a
# leading axis to targets (presumably giving (1, 1, H, W, D)); FocalLoss
# returns (mean loss, None).
loss(output, targets.unsqueeze(0))

In [None]:
# Run name and log directory for the focal-loss experiment. The directory
# name follows the run name (32 channels in the first layer) and, like the
# other log directories, contains no spaces.
experiment.set_name("focal loss + 6 classes, 4 encoding blocks, 32 out, Patch based, 64 batch, crop with images")
model_dir = './logs/focal_loss_6cls_4enc_32out_patch_64batch_crop'

You now have a complete working pipeline in your hands, but keep in mind that it is only a baseline. Experiment with augmentations and losses, and build a story of how your model developed.