# **MS LESION SEGMENTATION TRAINING** 

## Install Libraries

In [None]:
!pip install monai==0.9.0
!pip install einops
!pip install torch
!pip install tqdm
!pip install scipy

## Import Libraries

In [None]:
import torch
from torch import nn
from monai.inferers import sliding_window_inference
from monai.losses import DiceLoss
from monai.networks.nets import SegResNet
import numpy as np
import random
import os
from glob import glob
import re
from monai.data import CacheDataset, DataLoader
from monai.transforms import (
    AddChanneld, Compose, LoadImaged, RandCropByPosNegLabeld,
    Spacingd, ToTensord, NormalizeIntensityd, RandFlipd,
    RandRotate90d, RandShiftIntensityd, RandAffined, RandSpatialCropd,
    RandScaleIntensityd)
from scipy import ndimage

## Setup Functions

In [3]:
def get_default_device():
    """Return the CUDA device if PyTorch can see a GPU, otherwise the CPU device.

    Prints "Got CUDA!" when a GPU is selected.
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')
    print("Got CUDA!")
    return torch.device('cuda')

In [4]:
def dice_metric(ground_truth, predictions):
    """
    Compute the Dice coefficient for a single example.

    Args:
      ground_truth: `numpy.ndarray`, binary ground truth segmentation target,
                    e.g. of shape [W, H, D]; must have the same shape as
                    `predictions`.
      predictions:  `numpy.ndarray`, binary segmentation predictions, same
                    shape as `ground_truth`.
    Returns:
      Dice overlap as a `numpy.float64` in [0.0, 1.0]. This is a subclass of
      Python `float`, so it also supports `.item()` / `.sum()`. Two empty
      masks are defined to overlap perfectly (1.0).
    """
    intersection = np.sum(predictions * ground_truth)
    # Dice denominator: the sum of both mask volumes (not the set union).
    denominator = np.sum(predictions) + np.sum(ground_truth)

    if denominator == 0:
        # Both masks are empty. Return a NumPy scalar so the type matches the
        # non-empty branch; callers may call NumPy scalar methods on it.
        return np.float64(1.0)

    return np.float64((2.0 * intersection) / denominator)

### Data Load Functions

In [5]:
def get_train_transforms():
    """ Build the training-time transform pipeline for FLAIR images and lesion masks.

    Stages, applied in this order:
    - load the "image" and "label" volumes and add a channel dimension;
    - z-score normalise the image over its nonzero voxels;
    - random intensity shift (offset in +/-0.1) and scaling (+/-10%), always applied;
    - draw 32 patches of shape [128, 128, 128] per volume, centred on lesion
      voxels 4 times out of 5 (pos=4, neg=1);
    - randomly crop each patch down to [96, 96, 96];
    - spatial augmentation: a flip along all three axes (p=0.5), 90-degree
      rotations in each axis pair (p=0.5 each), then a random affine
      (rotation up to +/-pi/12 per axis, scaling up to +/-10%, border padding;
      bilinear interpolation for the image, nearest for the label);
    - convert to torch.Tensor.
    """
    paired = ["image", "label"]

    loading = [
        LoadImaged(keys=paired),
        AddChanneld(keys=paired),
        NormalizeIntensityd(keys=["image"], nonzero=True),
    ]

    intensity_aug = [
        RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
        RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
    ]

    patching = [
        RandCropByPosNegLabeld(keys=paired, label_key="label", image_key="image",
                               spatial_size=(128, 128, 128), num_samples=32,
                               pos=4, neg=1),
        RandSpatialCropd(keys=paired, roi_size=(96, 96, 96),
                         random_center=True, random_size=False),
    ]

    quarter_turns = [
        RandRotate90d(keys=paired, prob=0.5, spatial_axes=axes)
        for axes in ((0, 1), (1, 2), (0, 2))
    ]
    spatial_aug = (
        [RandFlipd(keys=paired, prob=0.5, spatial_axis=(0, 1, 2))]
        + quarter_turns
        + [RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'),
                       prob=1.0, spatial_size=(96, 96, 96),
                       rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12),
                       scale_range=(0.1, 0.1, 0.1), padding_mode='border')]
    )

    return Compose(loading + intensity_aug + patching + spatial_aug
                   + [ToTensord(keys=paired)])

In [6]:
def get_val_transforms(keys=None, image_keys=None):
    """ Get deterministic transforms for validation/testing on FLAIR images:
    - Loads a 3D volume from file for every key in `keys`
    - Adds a channel dimension
    - Applies intensity normalisation (over nonzero voxels) to `image_keys` only
    - Converts to torch.Tensor()

    Args:
      keys:       list of dictionary keys to load and convert;
                  defaults to ["image", "label"].
      image_keys: subset of `keys` to intensity-normalise; defaults to ["image"].
    Returns:
      monai.transforms.Compose
    """
    # None sentinels instead of list defaults: a default list is created once
    # at definition time and shared by every call.
    if keys is None:
        keys = ["image", "label"]
    if image_keys is None:
        image_keys = ["image"]

    return Compose(
        [
            LoadImaged(keys=keys),
            AddChanneld(keys=keys),
            NormalizeIntensityd(keys=image_keys, nonzero=True),
            ToTensord(keys=keys),
        ]
    )

In [7]:
def get_train_dataloader(flair_path, gts_path, num_workers, cache_rate=0.1):
    """
    Get dataloader for training.

    FLAIR images and ground truths are paired by sort order. Each path is sorted
    by the integer formed from all of its digits, so the i-th image is paired
    with the i-th mask.

    Args:
      flair_path: `str`, path to directory with FLAIR images from Train set.
      gts_path:  `str`, path to directory with ground truth lesion segmentation
                    binary masks images from Train set.
      num_workers:  `int`, number of workers used both for caching and by the
                    DataLoader.
      cache_rate:  `float` in (0.0, 1.0], fraction of the dataset to cache.
    Returns:
      monai.data.DataLoader() class object (batch size 1, shuffled).
    Raises:
      ValueError: if the two directories contain different numbers of files.
    """
    def numeric_key(path):
        # Concatenate every digit in the path (directory included) into one
        # int, so e.g. subject 10 sorts after subject 9 rather than after 1.
        return int(re.sub(r'\D', '', path))

    flair = sorted(glob(os.path.join(flair_path, "*FLAIR_isovox.nii")),
                   key=numeric_key)  # Collect all flair images sorted
    segs = sorted(glob(os.path.join(gts_path, "*gt_isovox.nii")),
                  key=numeric_key)  # Collect all corresponding ground truths

    # zip() silently truncates: one missing file would shift every later
    # image/label pair out of alignment, so fail loudly instead.
    if len(flair) != len(segs):
        raise ValueError(f"Some files must be missing: {[len(flair), len(segs)]}")

    files = [{"image": fl, "label": seg} for fl, seg in zip(flair, segs)]

    print("Number of training files:", len(files))

    ds = CacheDataset(data=files, transform=get_train_transforms(),
                      cache_rate=cache_rate, num_workers=num_workers)
    return DataLoader(ds, batch_size=1, shuffle=True,
                      num_workers=num_workers)

In [8]:
def get_val_dataloader(flair_path, gts_path, num_workers, cache_rate=0.1, bm_path=None):
    """
    Get dataloader for validation and testing, either with or without brain masks.

    Files are paired by sort order. Each path is sorted by the integer formed
    from all of its digits, so the i-th FLAIR image is paired with the i-th
    ground truth (and the i-th brain mask).

    Args:
      flair_path: `str`, path to directory with FLAIR images.
      gts_path:  `str`, path to directory with ground truth lesion segmentation
                    binary masks images.
      num_workers:  `int`, number of workers used both for caching and by the
                    DataLoader.
      cache_rate:  `float` in (0.0, 1.0], fraction of the dataset to cache.
      bm_path:   `None|str`. If `str`, then defines path to directory with
                 brain masks. If `None`, dataloader does not return brain masks.
    Returns:
      monai.data.DataLoader() class object yielding dicts with keys "image",
      "label" (and "brain_mask" when `bm_path` is given); batch size 1,
      not shuffled.
    Raises:
      AssertionError: if the directories contain different numbers of files.
    """
    def numeric_key(path):
        # Concatenate every digit in the path (directory included) into one
        # int, so e.g. subject 10 sorts after subject 9 rather than after 1.
        return int(re.sub(r'\D', '', path))

    flair = sorted(glob(os.path.join(flair_path, "*FLAIR_isovox.nii")),
                   key=numeric_key)  # Collect all flair images sorted
    # NOTE(review): this pattern matches any "*_isovox.nii" file, so gts_path
    # must contain only ground-truth masks.
    segs = sorted(glob(os.path.join(gts_path, "*_isovox.nii")),
                  key=numeric_key)  # Collect all corresponding ground truths

    if bm_path is not None:
        bms = sorted(glob(os.path.join(bm_path, "*isovox_fg_mask.nii")),
                     key=numeric_key)  # Collect all corresponding brain masks

        assert len(flair) == len(segs) == len(bms), f"Some files must be missing: {[len(flair), len(segs), len(bms)]}"

        files = [
            {"image": fl, "label": seg, "brain_mask": bm} for fl, seg, bm
            in zip(flair, segs, bms)
        ]

        val_transforms = get_val_transforms(keys=["image", "label", "brain_mask"])
    else:
        assert len(flair) == len(segs), f"Some files must be missing: {[len(flair), len(segs)]}"

        files = [{"image": fl, "label": seg} for fl, seg in zip(flair, segs)]

        val_transforms = get_val_transforms()

    print("Number of validation files:", len(files))

    ds = CacheDataset(data=files, transform=val_transforms,
                      cache_rate=cache_rate, num_workers=num_workers)
    return DataLoader(ds, batch_size=1, shuffle=False,
                      num_workers=num_workers)


In [9]:
def get_flair_dataloader(flair_path, num_workers, cache_rate=0.1, bm_path=None):
    """
    Get dataloader with FLAIR images only (no ground truth), for inference.

    When brain masks are requested, they are paired with images by sort order.
    Each path is sorted by the integer formed from all of its digits.

    Args:
      flair_path: `str`, path to directory with the FLAIR images to run
                  inference on.
      num_workers:  `int`, number of workers used both for caching and by the
                    DataLoader.
      cache_rate:  `float` in (0.0, 1.0], fraction of the dataset to cache.
      bm_path:   `None|str`. If `str`, then defines path to directory with
                 brain masks. If `None`, dataloader does not return brain masks.
    Returns:
      monai.data.DataLoader() class object yielding dicts with key "image"
      (and "brain_mask" when `bm_path` is given); batch size 1, not shuffled.
    Raises:
      AssertionError: if image and brain-mask counts differ.
    """
    def numeric_key(path):
        # Concatenate every digit in the path (directory included) into one
        # int, so e.g. subject 10 sorts after subject 9 rather than after 1.
        return int(re.sub(r'\D', '', path))

    flair = sorted(glob(os.path.join(flair_path, "*FLAIR_isovox.nii")),
                   key=numeric_key)  # Collect all flair images sorted

    if bm_path is not None:
        bms = sorted(glob(os.path.join(bm_path, "*isovox_fg_mask.nii")),
                     key=numeric_key)  # Collect all corresponding brain masks

        assert len(flair) == len(bms), f"Some files must be missing: {[len(flair), len(bms)]}"

        files = [{"image": fl, "brain_mask": bm} for fl, bm in zip(flair, bms)]

        val_transforms = get_val_transforms(keys=["image", "brain_mask"])
    else:
        files = [{"image": fl} for fl in flair]

        val_transforms = get_val_transforms(keys=["image"])

    print("Number of FLAIR files:", len(files))

    ds = CacheDataset(data=files, transform=val_transforms,
                      cache_rate=cache_rate, num_workers=num_workers)
    return DataLoader(ds, batch_size=1, shuffle=False,
                      num_workers=num_workers)

In [10]:
def remove_connected_components(segmentation, l_min=9):
    """
    Remove all lesions (connected components) with `l_min` or fewer voxels
    from a binary segmentation mask.

    Args:
      segmentation: `numpy.ndarray` binary lesion mask, e.g. of shape
                    [H, W, D]; any dimensionality is accepted.
      l_min:  `int`, components whose voxel count is <= `l_min` are removed.
    Returns:
      `numpy.ndarray` with the same shape and dtype as `segmentation`. It is 1
      on the voxels of components with more than `l_min` voxels and 0
      elsewhere. The background is never marked as lesion.
    """
    seg = np.asarray(segmentation)
    # Label connected foreground regions (scipy's default structuring element).
    labeled_seg, num_labels = ndimage.label(seg)

    # Per-label sum of mask values (the voxel count for a binary mask), computed
    # in one pass instead of a full-volume scan per component. Index 0 is the
    # background.
    sizes = np.bincount(labeled_seg.ravel(),
                        weights=seg.astype(np.float64).ravel(),
                        minlength=num_labels + 1)

    keep = sizes > l_min
    keep[0] = False  # never turn background into lesion (matters if l_min < 0)

    # Use the per-label keep flags as a lookup table indexed by the label map.
    return keep[labeled_seg].astype(seg.dtype)

## Training Function

In [13]:
def focal_loss(logits, labels, gamma=2.0):
    """
    Mean multi-class focal loss.

    Args:
      logits: `torch.Tensor` of raw network outputs, shape [N, C, ...].
      labels: integer `torch.Tensor` of class indices with a singleton channel
              axis at dim 1, shape [N, 1, ...]; that axis is squeezed away.
      gamma:  `float`, focusing parameter; larger values down-weight voxels
              the model already classifies confidently (gamma=0 gives plain CE).
    Returns:
      Scalar `torch.Tensor`: the mean over all voxels of (1 - p_t) ** gamma * CE.
    """
    ce = nn.CrossEntropyLoss(reduction='none')(logits, torch.squeeze(labels, dim=1))
    pt = torch.exp(-ce)  # probability assigned to the true class
    return torch.mean((1 - pt) ** gamma * ce)


def _run_validation(model, val_loader, device, roi_size, sw_batch_size, threshold):
    """Return the mean Dice over `val_loader` after thresholding the class-1 softmax."""
    act = nn.Softmax(dim=1)
    model.eval()
    scores = []
    with torch.no_grad():
        for val_data in val_loader:
            val_inputs = val_data["image"].to(device)
            gt = np.squeeze(val_data["label"].cpu().numpy())

            val_outputs = sliding_window_inference(val_inputs, roi_size,
                                                   sw_batch_size,
                                                   model, mode='gaussian')

            # Take channel 1 of the first sample in the batch. This assumes
            # batch_size == 1 for the validation loader.
            prob = np.squeeze(act(val_outputs).cpu().numpy()[0, 1])
            seg = (prob >= threshold).astype(prob.dtype)

            value = dice_metric(ground_truth=gt.flatten(), predictions=seg.flatten())
            scores.append(float(value))

    if not scores:
        raise RuntimeError("Validation loader yielded no samples; cannot compute Dice.")
    return sum(scores) / len(scores)


def trainSegResNet(path_train_data, path_train_gts, path_val_data, path_val_gts,
                   learning_rate=1e-5, n_epochs=300, seed=42, threshold=0.4,
                   num_workers=10, path_save='', val_interval=5):
    """
    Train a 3D SegResNet (2 output classes) with a weighted Dice + focal loss.
    Checkpoints are saved whenever the validation Dice improves.

    Args:
        Data Args:
            path_train_data: `str`, path to directory with FLAIR images from Train set.
            path_train_gts: `str`, path to directory with ground truth lesion
                segmentations from Train set.
            path_val_data: `str`, path to directory with FLAIR images from Validation set.
            path_val_gts: `str`, path to directory with ground truth lesion
                segmentations from Validation set.
        Model Args:
            learning_rate: `float`, learning rate for the Adam optimizer.
            n_epochs: `int`, number of epochs to train the model.
            seed: `int`, random seed for `random`, NumPy and torch (CPU and CUDA).
            threshold: `float`, probability threshold for binarising the lesion
                channel during validation.
            num_workers: `int`, number of workers for the data loaders.
            path_save: `str`, directory where best-model checkpoints are written
                as "Best_model_finetuning_<epoch>_epoch.pth".
            val_interval: `int`, validate every `val_interval` epochs.
    """
    # setting up the seeds
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # setting up the device
    device = get_default_device()
    # NOTE(review): presumably avoids "too many open files" when sharing tensors
    # between many loader workers; confirm.
    torch.multiprocessing.set_sharing_strategy('file_system')

    # Initialise dataloaders
    train_loader = get_train_dataloader(flair_path=path_train_data,
                                        gts_path=path_train_gts,
                                        num_workers=num_workers)
    val_loader = get_val_dataloader(flair_path=path_val_data,
                                    gts_path=path_val_gts,
                                    num_workers=num_workers)
    # Initialise the model
    model = SegResNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2).to(device)

    loss_function = DiceLoss(to_onehot_y=True,
                             softmax=True, sigmoid=False,
                             include_background=False)
    optimizer = torch.optim.Adam(model.parameters(), learning_rate)

    gamma_focal = 2.0
    dice_weight = 0.5
    focal_weight = 1.0
    roi_size = (96, 96, 96)
    sw_batch_size = 4
    patch_batch = 2  # patches per optimiser step

    best_metric, best_metric_epoch = -1, -1
    epoch_loss_values, metric_values = [], []

    # Training Loop
    for epoch in range(n_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{n_epochs}")
        model.train()
        epoch_loss = 0.0
        n_iters = 0
        for batch_data in train_loader:
            # Each loader batch holds several patches along dim 0; train on
            # them `patch_batch` at a time.
            n_samples = batch_data["image"].size(0)
            for m in range(0, n_samples, patch_batch):
                inputs = batch_data["image"][m:m + patch_batch].to(device)
                labels = batch_data["label"][m:m + patch_batch].type(torch.LongTensor).to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = (dice_weight * loss_function(outputs, labels)
                        + focal_weight * focal_loss(outputs, labels, gamma_focal))
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                n_iters += 1
                if n_iters % 50 == 0:
                    total = (len(train_loader) * n_samples) // (train_loader.batch_size * patch_batch)
                    print(f"{n_iters}/{total}, train_loss: {loss.item():.4f}")

        # Average over every optimiser step of the epoch (guard against an empty loader).
        epoch_loss /= max(n_iters, 1)
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        # Validation
        if (epoch + 1) % val_interval == 0:
            metric = _run_validation(model, val_loader, device, roi_size,
                                     sw_batch_size, threshold)
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(path_save, "Best_model_finetuning_"+str(best_metric_epoch)+"_epoch.pth"))
                print("saved new best metric model")
            print(f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                  f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
                  )


## Train The Model

In [None]:
# Dataset locations (Kaggle input mount): FLAIR images and ground truths per split.
train_data = '/kaggle/input/sdcombinedextracted/ShiftsDatasetCombinedExtracted/Train/FLAIR'
train_gts = '/kaggle/input/sdcombinedextracted/ShiftsDatasetCombinedExtracted/Train/GroundTruth'
val_data = '/kaggle/input/sdcombinedextracted/ShiftsDatasetCombinedExtracted/Val/FLAIR'
val_data_gts = '/kaggle/input/sdcombinedextracted/ShiftsDatasetCombinedExtracted/Val/GroundTruth'

# 100 epochs, validation every 5 epochs; best checkpoints go to /kaggle/working/.
trainSegResNet(
    path_train_data=train_data,
    path_train_gts=train_gts,
    path_val_data=val_data,
    path_val_gts=val_data_gts,
    learning_rate=1e-5,
    n_epochs=100,
    seed=42,
    threshold=0.4,
    num_workers=4,
    path_save='/kaggle/working/',
    val_interval=5,
)