# 3D Segmentation of Brain Tumor


## Loading the Dataset
The dataset used is the BraTS 2020 Dataset.

The dataset contains the following scans per case:

T1: T1-weighted, native image, sagittal or axial 2D acquisitions, with 1–6 mm slice thickness.

T1c: T1-weighted, contrast-enhanced (Gadolinium) image, with 3D acquisition and 1 mm isotropic voxel size for most patients.

T2: T2-weighted image, axial 2D acquisition, with 2–6 mm slice thickness.

FLAIR: T2-weighted FLAIR image, axial, coronal, or sagittal 2D acquisitions, 2–6 mm slice thickness.



## Importing Required Libraries

**Essential**
1. [PyTorch with CUDA](https://pytorch.org/get-started/locally/)
1. MONAI (all extras, or at least with nibabel), e.g. `pip install 'monai[all]'`
1. pandas
1. NumPy
1. scikit-learn

*Note: If any library is missing at run time, the error message names the package to install.*


In [1]:
# import matplotlib.pyplot as plt
from os import path, listdir
import os
import monai
from monai.data import (Dataset, list_data_collate, DataLoader, decollate_batch, PersistentDataset)
from monai.transforms import (
    LoadImaged,
    Compose,
    MapTransform,
    Orientationd,
    ToMetaTensord,  # ? same as ensuretyped?
    EnsureChannelFirstd,
    NormalizeIntensityd,
    Spacingd,
    RandCropByPosNegLabeld,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    Activations,
    AsDiscrete,
)

from monai.networks.nets import  UNet
from monai.losses import  DiceCELoss
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from monai.inferers import sliding_window_inference
from monai.optimizers import Novograd, WarmupCosineSchedule

from torch.utils.tensorboard import SummaryWriter

import torch

from tqdm import tqdm

Seeding to ensure reproducibility

In [2]:
monai.utils.misc.set_determinism(1607)

Optional function to free allocated GPU memory between training different models, since PyTorch does not release it automatically

In [3]:
import gc

def clear_gpu_cache():
    """Clear the PyTorch GPU Allocation if an OOM error occurs.
    """

    try:
        print("Deleting Model")
        global model
        del model
    except NameError:
        # No global `model` exists, so there is nothing to delete.
        print("Model Already Cleared")

    print("Collecting Garbage")
    gc.collect()
    print("Clearing CUDA Cache")
    torch.cuda.empty_cache()
    print("Done")


## Preparing the Dataset

As per [University of Pennsylvania](https://www.med.upenn.edu/cbica/brats2020/data.html):

All BraTS multimodal scans are available as NIfTI files (.nii.gz) and describe a) native (T1) and b) post-contrast T1-weighted (T1Gd), c) T2-weighted (T2), and d) T2 Fluid Attenuated Inversion Recovery (T2-FLAIR) volumes, and were acquired with different clinical protocols and various scanners from multiple (n=19) institutions, mentioned as data contributors here.

All the imaging datasets have been segmented manually, by one to four raters, following the same annotation protocol, and their annotations were approved by experienced neuro-radiologists. Annotations comprise the GD-enhancing tumor (ET — label 4), the peritumoral edema (ED — label 2), and the necrotic and non-enhancing tumor core (NCR/NET — label 1), as described both in the BraTS 2012-2013 TMI paper and in the latest BraTS summarizing paper. The provided data are distributed after their pre-processing, i.e., *co-registered to the same anatomical template*, interpolated to the *same resolution (1 mm^3)* and *skull-stripped*.

Below is a function that randomly splits the dataset into training and validation sets while ensuring that the ratio of high-grade (HGG) to low-grade (LGG) cases is the same in both splits.

In [4]:
import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np

def encode_lgg_hgg(x):
    """Encode LGG and HGG as 0 and 1 for stratification

    Args:
        x (str): LGG or HGG in Dataframe

    Returns:
        int: encodes 0 for LGG and 1 for HGG
    """
    return 0 if x == "LGG" else 1


def train_val_test_dataset(data_path: str):
    """Split the cases into 80% training and 20% validation.

    The split is stratified on tumor grade so that the LGG:HGG ratio
    is preserved in both subsets.

    Args:
        data_path (str): Path to the name mapping file provided by BraTS.

    Returns:
        training, validation: lists of case names for the training and validation sets.
    """
    data = pd.read_csv(data_path)
    data = data[["Grade", "BraTS_2020_subject_ID"]]
    data.Grade = data["Grade"].map(encode_lgg_hgg)
    (training, validation, train_check, val_check,) = train_test_split(
        data.BraTS_2020_subject_ID.to_list(),
        data.Grade.to_numpy(),
        test_size=0.2,
        random_state=42,
        stratify=data.Grade.to_numpy(),
        shuffle=True,
    )

    print(
        f"""
    Total Samples = {len(training)+len(validation)}\n
    Ratio of LGG:HGG in {len(training)} Training Samples:
    \t Ratio = {np.count_nonzero(train_check==0)/np.count_nonzero(train_check==1):.2f}\n

    Ratio of LGG:HGG in {len(validation)} Validation Samples:
    \t Ratio = {np.count_nonzero(val_check==0)/np.count_nonzero(val_check==1):.2f}\n
    
    """
    )

    return (training, validation)

In [5]:
# Prepare list of All training Cases
TRAINING_DATASET_PATH = r"./MICCAI_BraTS2020_TrainingData/"
NAME_MAPPING = r"./MICCAI_BraTS2020_TrainingData/name_mapping.csv"

# Function returns names of cases to be used
train_cases, val_cases = train_val_test_dataset(NAME_MAPPING)


    Total Samples = 369

    Ratio of LGG:HGG in 295 Training Samples:
    	 Ratio = 0.26


    Ratio of LGG:HGG in 74 Validation Samples:
    	 Ratio = 0.25
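
To illustrate what stratification does, the sketch below splits each grade separately so that both subsets keep roughly the same class ratio. It is a simplified stand-in, not scikit-learn's implementation; the helper name `stratified_split` and the toy case list are hypothetical.

```python
import random
from collections import defaultdict


def stratified_split(cases, labels, test_fraction=0.2, seed=42):
    """Split cases so each label keeps about the same share in both subsets."""
    by_label = defaultdict(list)
    for case, label in zip(cases, labels):
        by_label[label].append(case)

    rng = random.Random(seed)
    train, val = [], []
    for label in sorted(by_label):
        members = by_label[label]
        rng.shuffle(members)
        # Take the same fraction out of every class.
        n_val = round(len(members) * test_fraction)
        val.extend(members[:n_val])
        train.extend(members[n_val:])
    return train, val


# Toy data (hypothetical): 20 "LGG" (0) and 80 "HGG" (1) cases.
cases = [f"case_{i:03d}" for i in range(100)]
labels = [0] * 20 + [1] * 80
train, val = stratified_split(cases, labels)
print(len(train), len(val))
```

Because the fraction is applied per class, the minority grade cannot end up missing from the smaller subset, which a plain random split does not guarantee.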

    
    


Prepare mapping to convert into PyTorch Dataset

In [6]:
train_cases = [
    {
        "image": [
            path.join(TRAINING_DATASET_PATH, case, f"{case}_t1.nii.gz"),
            path.join(TRAINING_DATASET_PATH, case, f"{case}_t1ce.nii.gz"),
            path.join(TRAINING_DATASET_PATH, case, f"{case}_t2.nii.gz"),
            path.join(TRAINING_DATASET_PATH, case, f"{case}_flair.nii.gz"),
        ],
        "seg": path.join(
            TRAINING_DATASET_PATH, case, f"{case}_seg.nii.gz"
        ),
    }
    for case in train_cases
]

val_cases = [
    {
        "image": [
            path.join(TRAINING_DATASET_PATH, case, f"{case}_t1.nii.gz"),
            path.join(TRAINING_DATASET_PATH, case, f"{case}_t1ce.nii.gz"),
            path.join(TRAINING_DATASET_PATH, case, f"{case}_t2.nii.gz"),
            path.join(TRAINING_DATASET_PATH, case, f"{case}_flair.nii.gz"),
        ],
        "seg": path.join(
            TRAINING_DATASET_PATH, case, f"{case}_seg.nii.gz"
        ),
    }
    for case in val_cases
]


The input to the data loader is a list of dictionaries, each mapping `image` and `seg` to the file paths of the scans for one case.

In [7]:
train_cases[:1]

[{'image': ['./MICCAI_BraTS2020_TrainingData/BraTS20_Training_219\\BraTS20_Training_219_t1.nii.gz',
   './MICCAI_BraTS2020_TrainingData/BraTS20_Training_219\\BraTS20_Training_219_t1ce.nii.gz',
   './MICCAI_BraTS2020_TrainingData/BraTS20_Training_219\\BraTS20_Training_219_t2.nii.gz',
   './MICCAI_BraTS2020_TrainingData/BraTS20_Training_219\\BraTS20_Training_219_flair.nii.gz'],
  'seg': './MICCAI_BraTS2020_TrainingData/BraTS20_Training_219\\BraTS20_Training_219_seg.nii.gz'}]

Since the goal is to segment Whole Tumor, Tumor Core and Enhancing Tumor, a custom transformation is applied to the segmentation NIfTI file, as defined in the class below

In [8]:
class ConvertLabelsIntoOneHotd(MapTransform):
    """Convert a BraTS label map into three binary region channels.

    The input segmentation contains values in the set (0, 1, 2, 4) where
        0 -> Background/Normal
        1 -> Non-Enhancing Tumor Core
        2 -> Edema
        4 -> Enhancing Tumor Core

    The output is a 3-channel tensor where
        1st channel -> Whole Tumor (1, 2 and 4)
        2nd channel -> Tumor Core (1 and 4)
        3rd channel -> Enhancing Tumor Core (4)
    """

    def __call__(self, data):
        data_dict = dict(data)
        for key in self.keys:
            one_hot_encode_array = [
                torch.logical_or(
                    torch.logical_or(data_dict[key] == 1, data_dict[key] == 2),
                    data_dict[key] == 4,
                ),  # Whole Tumor
                torch.logical_or(data_dict[key] == 1, data_dict[key] == 4),  # Tumor Core
                data_dict[key] == 4,  # Enhancing Core
            ]
            # Stack inside the loop so that every requested key is converted.
            # .float() works for plain tensors and MetaTensors alike.
            data_dict[key] = torch.stack(one_hot_encode_array, dim=0).float()
        return data_dict
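
As a small worked example of the mapping described in the docstring, the sketch below applies the same region logic to a toy array with NumPy. The 1D array stands in for a real 3D segmentation volume and is purely illustrative.

```python
import numpy as np

# Toy 1D "segmentation" containing every BraTS label once.
seg = np.array([0, 1, 2, 4])

whole_tumor = np.isin(seg, [1, 2, 4])  # labels 1, 2 and 4
tumor_core = np.isin(seg, [1, 4])      # labels 1 and 4
enhancing = seg == 4                   # label 4 only

# Channel-first stack: one binary channel per region.
one_hot = np.stack([whole_tumor, tumor_core, enhancing], axis=0).astype(np.float32)
print(one_hot)
```

Note that the three channels are nested regions rather than mutually exclusive classes: a voxel with label 4 is set in all three channels.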


Transformations during training and later during validation are declared below:

In [9]:
transform_validation_dataset = Compose(
    [
        LoadImaged(keys=["image", "seg"]),
        EnsureChannelFirstd(keys=["image"]),
        ConvertLabelsIntoOneHotd(keys="seg"),
        ToMetaTensord(["image", "seg"]),
        Orientationd(keys=["image", "seg"], axcodes="RAS"),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)

val_data_loader = DataLoader(
    Dataset(
        val_cases, transform_validation_dataset,
    ),
    shuffle=False,
    batch_size=1,
)

transform_training_dataset = Compose(
        [
            LoadImaged(keys=["image", "seg"]),
            EnsureChannelFirstd(keys=["image"]),
            ConvertLabelsIntoOneHotd(keys="seg"),
            ToMetaTensord(["image", "seg"]),
            Orientationd(keys=["image", "seg"], axcodes="RAS"),
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            RandCropByPosNegLabeld(
                keys=["image", "seg"],
                spatial_size=(128, 128, 128),
                label_key="seg",
                neg=0,
                num_samples=2,
            ),
            RandFlipd(keys=["image", "seg"], prob=0.5, spatial_axis=0),
            RandFlipd(keys=["image", "seg"], prob=0.5, spatial_axis=1),
            RandFlipd(keys=["image", "seg"], prob=0.5, spatial_axis=2),
            RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
            RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
        ]
    )

train_data_loader = DataLoader(
    Dataset(
        train_cases,
        transform_training_dataset,
    ),
    shuffle=True,
    batch_size=2,
)



Some notes about the above transformations:
- During validation, the pipeline stops at normalisation; no random cropping or augmentation is applied.
- During training, two random patches of size (128, 128, 128) are taken from each input in every epoch.
- The batch size and the number of samples per input may be adjusted to suit the available system resources.
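
Both pipelines end their deterministic part with `NormalizeIntensityd(nonzero=True, channel_wise=True)`. The sketch below shows the idea in NumPy, under the assumption that background voxels are exactly zero (as in skull-stripped scans): each channel is z-scored using only its non-zero voxels, and zeros stay zero. The helper name `normalize_nonzero_channelwise` is hypothetical, and this is not MONAI's code.

```python
import numpy as np


def normalize_nonzero_channelwise(image):
    """Z-score each channel using only its non-zero voxels; zeros stay zero."""
    out = image.astype(np.float32).copy()
    for c in range(out.shape[0]):
        mask = out[c] != 0
        if not mask.any():
            continue  # an all-zero channel is left untouched
        values = out[c][mask]
        std = values.std()
        # Guard against a constant channel (std == 0).
        out[c][mask] = (values - values.mean()) / (std if std > 0 else 1.0)
    return out


# Toy "scan": 2 channels, first voxel of each is background (0).
image = np.array([[0.0, 2.0, 4.0, 6.0], [0.0, 10.0, 10.0, 40.0]])
normalized = normalize_nonzero_channelwise(image)
print(normalized)
```

Normalising per channel matters here because the four MRI modalities have unrelated intensity scales.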

A TensorBoard Logger is instantiated to log metrics and other features during training

In [10]:
logger = SummaryWriter(log_dir="./logs")

In [13]:
clear_gpu_cache()

Deleting Model
Model Already Cleared
Collecting Garbage
Clearing CUDA Cache
Done


The number of training epochs is set to 100

In [11]:
epochs = 100

Validation runs every 2 epochs

In [12]:
val_interval = 2

The training device is set automatically to GPU or CPU depending on availability. Note that some steps are hard-coded to use the CPU, such as holding the full volume during sliding window inference, because my system cannot fit it in GPU memory

In [13]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

The UNet Model is instantiated from MONAI

In [14]:
model = UNet(
    spatial_dims=3,
    in_channels=4,
    out_channels=3,
    strides=(2, 2, 2),
    channels=[16,32,64,128],
    num_res_units=2,
).to(device)

The optimiser is initialised. I use Novograd, but you may use any other as you see fit

In [15]:
optimizer = Novograd(model.parameters(), weight_decay=0.0001)

A learning rate scheduler is used, with the first 10% of epochs serving as warm-up

In [16]:
lr_scheduler = WarmupCosineSchedule(
    optimizer,
    warmup_steps=int(epochs / 10),
    warmup_multiplier=1e-10,
    t_total=epochs,
)
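
To make the schedule's shape concrete, here is a minimal sketch of a linear warm-up followed by a half-cosine decay, written with the standard library only. It mirrors the general idea behind such schedulers; the function name `warmup_cosine_factor` is hypothetical, and MONAI's implementation may differ in details.

```python
import math


def warmup_cosine_factor(epoch, warmup_steps, t_total, warmup_multiplier=0.0):
    """Multiplier applied to the base learning rate at a given epoch."""
    if epoch < warmup_steps:
        # Linear ramp from warmup_multiplier up to 1.
        fraction = epoch / max(1, warmup_steps)
        return warmup_multiplier + (1.0 - warmup_multiplier) * fraction
    # Half-cosine decay from 1 down to 0 over the remaining epochs.
    progress = (epoch - warmup_steps) / max(1, t_total - warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


epochs = 100
factors = [
    warmup_cosine_factor(e, epochs // 10, epochs, warmup_multiplier=1e-10)
    for e in range(epochs + 1)
]
print(factors[0], factors[10], factors[55], factors[100])
```

With these settings the multiplier climbs during the first 10 epochs, peaks at the end of warm-up, and then decays smoothly towards zero by the final epoch.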

Loss Function is initialised

In [17]:
loss_function = DiceCELoss(sigmoid=True, squared_pred=True)

PyTorch Automatic Mixed Precision is used to increase training speed and reduce memory consumption. 

In [18]:
scaler = torch.cuda.amp.GradScaler()

The function for inference is defined

In [19]:
def inference(input):
    """Run sliding window inference on the input tensor.

    To avoid OOM errors, the full input and the stitched output stay on the CPU.
    Each patch taken from the input is moved to the GPU for the forward pass
    to speed up inference time.

    Args:
        input: Full input to pass through the model. For this project the
        shape is (1, 4, 240, 240, 155): batch, four MRI modalities, then
        the spatial dimensions.
    """

    def _compute(input):
        return sliding_window_inference(
            inputs=input.to("cpu"),
            roi_size=(128,128,128),
            sw_batch_size=1,
            predictor=model,
            overlap=0.5,
            padding_mode="constant",
            sw_device="cuda:0",
            device="cpu",
            mode="constant",
        )

    with torch.cuda.amp.autocast():
        return _compute(input)
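
For a sense of how much work one validation case involves, the sketch below estimates how a 240 x 240 x 155 volume is tiled with 128-voxel windows at 50% overlap. It assumes a step of `roi * (1 - overlap)` with the last window clamped to the volume edge; the helper `window_starts` is hypothetical and only approximates the tiling a sliding-window inferer performs.

```python
from itertools import product


def window_starts(size, roi, overlap):
    """Start indices of windows along one axis, last window clamped to the edge."""
    if size <= roi:
        return [0]
    step = max(1, int(roi * (1 - overlap)))
    starts = list(range(0, size - roi + 1, step))
    if starts[-1] != size - roi:
        # Add a final window that ends exactly at the volume border.
        starts.append(size - roi)
    return starts


volume_shape = (240, 240, 155)
roi_size = (128, 128, 128)
per_axis = [window_starts(s, r, overlap=0.5) for s, r in zip(volume_shape, roi_size)]
windows = list(product(*per_axis))
print(per_axis, len(windows))
```

With `sw_batch_size=1`, each of these windows is a separate forward pass on the GPU, which is the trade-off accepted here to keep memory usage low.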

During inference, the output from the model undergoes post-processing, namely:
- Sigmoid activation
- Conversion of the sigmoid values to discrete 0 or 1 based on a threshold value, which can be adjusted

In [20]:
post_processing_validation = Compose(
    [Activations(sigmoid=True), AsDiscrete(threshold=0.5)]
)
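
The two post-processing steps amount to the following, shown here on a handful of toy logits with NumPy. This is an illustrative sketch of the arithmetic, not the MONAI transforms themselves.

```python
import numpy as np

# Toy raw model outputs (logits) for four voxels of one channel.
logits = np.array([-2.0, -0.1, 0.1, 3.0])

probabilities = 1.0 / (1.0 + np.exp(-logits))  # sigmoid activation
threshold = 0.5
mask = (probabilities >= threshold).astype(np.float32)  # discrete 0 or 1
print(probabilities, mask)
```

Raising the threshold makes the predicted mask more conservative, while lowering it makes the mask more inclusive.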

The Dice and Hausdorff distance metrics are prepared. Other metrics may be instantiated as needed.
Metrics with the `_batch` suffix are averaged over the batch dimension of [Batch, Channel, Dims...], which yields one aggregated value per channel

In [21]:
dice_metric = DiceMetric(include_background=True, reduction="mean")
dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")

hausdorff_metric = HausdorffDistanceMetric(
    include_background=True, distance_metric='euclidean',
    reduction="mean"
)

hausdorff_metric_batch = HausdorffDistanceMetric(
    include_background=True, distance_metric='euclidean',
    reduction="mean_batch"
)
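
The difference between the two reductions can be sketched in NumPy as follows. The helper `dice_per_channel` and the toy masks are hypothetical; the code only illustrates how a per-case, per-channel Dice table is collapsed either to a single number or to one number per channel.

```python
import numpy as np


def dice_per_channel(pred, target):
    """Dice score for each channel of binary arrays shaped [batch, channel, ...]."""
    axes = tuple(range(2, pred.ndim))
    intersection = (pred * target).sum(axis=axes)
    denominator = pred.sum(axis=axes) + target.sum(axis=axes)
    # Cases where both masks are empty give 0/0; they are marked NaN here.
    with np.errstate(invalid="ignore", divide="ignore"):
        return np.where(denominator > 0, 2.0 * intersection / denominator, np.nan)


# Toy data: 2 cases, 2 channels, 4 voxels each.
pred = np.array([[[1, 1, 0, 0], [1, 0, 0, 0]],
                 [[1, 1, 1, 1], [0, 1, 1, 0]]], dtype=np.float32)
target = np.array([[[1, 0, 0, 0], [1, 0, 0, 0]],
                   [[1, 1, 0, 0], [0, 1, 1, 0]]], dtype=np.float32)

scores = dice_per_channel(pred, target)   # shape [batch, channel]
overall = np.nanmean(scores)              # analogous to reduction="mean"
per_channel = np.nanmean(scores, axis=0)  # analogous to reduction="mean_batch"
print(scores, overall, per_channel)
```

The per-channel values are what allow Whole Tumor, Tumor Core and Enhancing Tumor to be tracked separately during validation.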

Optionally, if you already have a checkpoint file (model weights from a previous session), you may load it to continue training

In [None]:
# try:
#     chk = torch.load("nameoffile.pth")
#     model.load_state_dict(chk['model'])
#     optimizer.load_state_dict(chk['optimiser'])
# except Exception as e:
#     print(e)

Initialise a global variable to keep track of best_metric score so that the model state can be saved

In [23]:
best_metric = -1

Train the model for the configured number of epochs

In [None]:
for epoch in range(epochs):
    model.train()
    epoch_loss = 0
    step = 0

    logger.add_scalar("Learning_Rate", optimizer.param_groups[0]["lr"], epoch)

    for batch_data in tqdm(train_data_loader):

        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["seg"].to(device),
        )

        optimizer.zero_grad()

        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = loss_function(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        step += 1
        epoch_loss += loss.item()

    epoch_loss_value = epoch_loss / step

    lr_scheduler.step()

    logger.add_scalar("Training/Loss", epoch_loss_value, epoch)

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in tqdm(val_data_loader):
                val_inputs, val_labels = (
                    val_data["image"],
                    val_data["seg"].to(device),
                )

                val_outputs = inference(val_inputs)

                val_outputs = [
                    post_processing_validation(i)
                    for i in decollate_batch(val_outputs.to(device))
                ]
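                dice_metric(y_pred=val_outputs, y=val_labels)
                dice_metric_batch(y_pred=val_outputs, y=val_labels)
                # The Hausdorff metrics must also be fed every case,
                # otherwise aggregate() below has nothing to reduce.
                hausdorff_metric(y_pred=val_outputs, y=val_labels)
                hausdorff_metric_batch(y_pred=val_outputs, y=val_labels)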

            metric = dice_metric.aggregate().item()

            metric_batch = dice_metric_batch.aggregate()
            metric_wt = metric_batch[0].item()
            metric_tc = metric_batch[1].item()
            metric_et = metric_batch[2].item()


            hausdorff_avg = hausdorff_metric.aggregate().item()
            hausdorff_batch = hausdorff_metric_batch.aggregate()
            hausdorff_wt = hausdorff_batch[0].item()
            hausdorff_tc = hausdorff_batch[1].item()
            hausdorff_et = hausdorff_batch[2].item()


            

            if metric > best_metric:
                best_metric = metric
                torch.save(
                    {
                        "model": model.state_dict(),
                        "optimiser": optimizer.state_dict(),
                        "scheduler": lr_scheduler.state_dict(),
                        "best_metric_epoch": epoch,
                        "best_metric": best_metric
                    },
                    f"./best_model_DICE_{int(metric*100)}.pth",
                )

            logger.add_scalar("DICE/Average", metric, epoch)
            logger.add_scalar("DICE/WT", metric_wt, epoch)
            logger.add_scalar("DICE/TC", metric_tc, epoch)
            logger.add_scalar("DICE/ET", metric_et, epoch)

            logger.add_scalar("Hausdorff/Average", hausdorff_avg, epoch)
            logger.add_scalar("Hausdorff/WT", hausdorff_wt, epoch)
            logger.add_scalar("Hausdorff/TC", hausdorff_tc, epoch)
            logger.add_scalar("Hausdorff/ET", hausdorff_et, epoch)


            dice_metric.reset()
            dice_metric_batch.reset()
            hausdorff_metric.reset()
            hausdorff_metric_batch.reset()


For inference and analysis of the trained model, refer to the [Analysis Notebook](Analysis.ipynb)