# Advanced Image Processing (TM11005)
## *Week 1: Segmentation, Exercise 3: Deep Learning*
*Author: Karin van Garderen*
### Spleen 3D segmentation with MONAI

This exercise was adapted from the [MONAI tutorials on Spleen segmentation](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/master/3d_segmentation/spleen_segmentation_3d.ipynb).
MONAI is an open-source framework for deep learning in medical imaging. It contains a wide range of popular network architectures, loss functions and functions for loading and transforming images. This tutorial will make you familiar with the basics of deep learning for medical image segmentation using MONAI.

The Spleen dataset can be downloaded from http://medicaldecathlon.com/.

![spleen](http://medicaldecathlon.com/img/spleen0.png)

Target: Spleen  
Modality: CT  
Size: 61 3D volumes (41 Training + 20 Testing)  
Source: Memorial Sloan Kettering Cancer Center 

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/MStarmans91/AIP-exercises/blob/cnn_exercise/Week1_Segmentation/Week1.4_CNN/spleen_segmentation_3d.ipynb)

## **Before you start**

The CNNs used in this tutorial run much faster on a GPU. We advise running this notebook on Google Colab with GPU acceleration, which you can turn on via Runtime -> Change runtime type -> Hardware accelerator: GPU.

## **Handing in your answers**
For each exercise, you have to hand in answers to questions, and
for some also the code. Hence please only hand in two files in total
for this exercise set:

- Code.py (or .ipynb): a Python script / jupyter notebook.
- Answers.docx (or .PDF): a text file with the answers to the
questions (plots, text, ...).

In [None]:
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Setup environment

In [None]:
## No need to edit this
!python -c "import monai" || pip install -q "monai-weekly[gdown, nibabel, tqdm, ignite]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline

## Setup imports

In [None]:
## No need to edit this

from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    EnsureTyped,
    EnsureType,
    Invertd,
    SqueezeDimd,
    CenterSpatialCropd,
    DivisiblePadd
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
print_config()

## Setup data directory

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  
This allows you to save results and reuse downloads.  
If it is not set, a temporary directory is used.
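For example, to keep the downloaded data between sessions, you can set the variable before running the next cell. This is a minimal sketch. The folder name `monai_data` under the system temp directory is a hypothetical choice. On Colab you would more likely point it to a mounted Google Drive folder, because the runtime's local disk is wiped between sessions.

```python
import os
import tempfile

# Hypothetical persistent location for the dataset and results.
persistent_dir = os.path.join(tempfile.gettempdir(), "monai_data")
os.makedirs(persistent_dir, exist_ok=True)

# Set the variable before the cell below runs, so that it reuses this folder
# instead of creating a new temporary directory with tempfile.mkdtemp().
os.environ["MONAI_DATA_DIRECTORY"] = persistent_dir

# Same logic as the next cell:
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)
```

The download cell skips the download when `Task09_Spleen` already exists in `root_dir`. As a result, a later session that uses the same directory reuses the extracted data.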

In [None]:
## No need to edit this

directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

## Download dataset

Downloads and extracts the dataset.  
The dataset comes from http://medicaldecathlon.com/.

In [None]:
## No need to edit this

resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
md5 = "410d4a301da4e5b2f6f86ec3ddba524e"

compressed_file = os.path.join(root_dir, "Task09_Spleen.tar")
data_dir = os.path.join(root_dir, "Task09_Spleen")
if not os.path.exists(data_dir):
    download_and_extract(resource, compressed_file, root_dir, md5)

## Set MSD Spleen dataset path
First, we set up the dataset and split it into a training set and a validation set. Each patient has a CT image and a label image, stored as two separate files. For each patient, we store the paths to these two downloaded files in a dictionary. The last 9 patients form the validation set.

In [None]:
## No need to edit this

train_images = sorted(
    glob.glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
train_labels = sorted(
    glob.glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))
data_dicts = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(train_images, train_labels)
]
train_files, val_files = data_dicts[:-9], data_dicts[-9:]
print(data_dicts)
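The pairing above relies on `sorted()` putting images and labels in the same order. That holds when an image and its label share a file name, as in the Decathlon layout. Lexicographic order (`spleen_10` before `spleen_2`) is harmless as long as both lists are sorted the same way. Below is a small sanity check, shown on hypothetical file paths:

```python
import os

# Hypothetical file lists, mimicking the output of sorted(glob.glob(...)) above.
train_images = sorted([
    "/data/Task09_Spleen/imagesTr/spleen_10.nii.gz",
    "/data/Task09_Spleen/imagesTr/spleen_2.nii.gz",
    "/data/Task09_Spleen/imagesTr/spleen_6.nii.gz",
])
train_labels = sorted([
    "/data/Task09_Spleen/labelsTr/spleen_6.nii.gz",
    "/data/Task09_Spleen/labelsTr/spleen_10.nii.gz",
    "/data/Task09_Spleen/labelsTr/spleen_2.nii.gz",
])

# zip() pairs the i-th image with the i-th label, so the lists must have equal
# length and matching order; comparing file names catches a mismatch early.
assert len(train_images) == len(train_labels)
for image_path, label_path in zip(train_images, train_labels):
    assert os.path.basename(image_path) == os.path.basename(label_path), (image_path, label_path)

data_dicts = [{"image": i, "label": l} for i, l in zip(train_images, train_labels)]
train_files, val_files = data_dicts[:-1], data_dicts[-1:]  # hold out the last case
print(len(train_files), len(val_files))
```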

## Set deterministic training for reproducibility

In [None]:
set_determinism(seed=0)
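`set_determinism(seed=0)` fixes the seeds of the random number generators used during training (Python's `random`, NumPy and PyTorch, among other settings). As a result, random patch cropping and weight initialisation repeat from run to run. The sketch below shows the principle with only Python's and NumPy's generators. The helper `seed_everything` is hypothetical and not part of MONAI. Note that some GPU operations can still be non-deterministic even with a fixed seed.

```python
import random

import numpy as np


def seed_everything(seed: int) -> None:
    """Hypothetical helper: seed Python's and NumPy's global random number generators."""
    random.seed(seed)
    np.random.seed(seed)


seed_everything(0)
first_run = (random.random(), np.random.rand(3))

seed_everything(0)
second_run = (random.random(), np.random.rand(3))

# Same seed, same sequence of "random" numbers.
print(first_run[0] == second_run[0], np.array_equal(first_run[1], second_run[1]))  # True True
```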

## Setup transforms for training and validation

Here we use several transforms to prepare the data for training.
The following transforms will be applied:
1. `LoadImaged` loads the spleen CT images and labels from NIfTI format files.
1. `EnsureChannelFirstd` makes sure the channel dimension comes first. The original data is 3D, so a dimension of size 1 is added in front to serve as the channel dimension (this dataset has only one channel).
1. `Orientationd` unifies the data orientation based on the affine matrix defined in the NIfTI format.
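1. `Spacingd` resamples the data to a voxel spacing of `pixdim=(1.5, 1.5, 2.0)` based on the affine matrix. It uses bilinear interpolation for the image and nearest-neighbour interpolation for the label.
1. `ScaleIntensityRanged` clips the CT intensities to the range [-57, 164] and rescales them linearly to [0, 1].
1. `CropForegroundd` crops away the borders where the image is zero, to focus on the valid body area of the images and labels.
1. `RandCropByPosNegLabeld` randomly crops 96x96x96 patches from the large image. Each patch is centred on a foreground or a background voxel according to the pos/neg ratio (here 1:1, with 4 patches per image). The centres of negative (background) samples must lie in the valid body area.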
1. `EnsureTyped` converts the numpy array to PyTorch Tensor for further steps.
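Two of these steps boil down to simple arithmetic. The NumPy sketch below reproduces the linear mapping that `ScaleIntensityRanged` applies with the settings used here. Values below -57 end up at 0 and values above 164 end up at 1. The sketch also shows roughly how `Spacingd` changes the array size: the physical extent is kept, so the number of voxels along each axis scales with old spacing / new spacing. The helper names and the example scan are hypothetical. MONAI may round the output size slightly differently.

```python
import numpy as np


def scale_intensity_range(x, a_min, a_max, b_min, b_max, clip=True):
    """Map [a_min, a_max] linearly onto [b_min, b_max]; optionally clip to [b_min, b_max]."""
    y = (x - a_min) / (a_max - a_min) * (b_max - b_min) + b_min
    return np.clip(y, b_min, b_max) if clip else y


def resampled_shape(shape, old_spacing, new_spacing):
    """Approximate array size after resampling to `new_spacing` (physical extent is kept)."""
    return tuple(int(round(n * o / s)) for n, o, s in zip(shape, old_spacing, new_spacing))


# Air, the two window limits, their midpoint, and bone (CT values in HU).
hu = np.array([-1000.0, -57.0, 53.5, 164.0, 1000.0])
scaled = scale_intensity_range(hu, a_min=-57, a_max=164, b_min=0.0, b_max=1.0)
print(scaled)

# Hypothetical scan: 512 x 512 x 90 voxels with a spacing of 0.8 x 0.8 x 5.0.
new_shape = resampled_shape((512, 512, 90), (0.8, 0.8, 5.0), (1.5, 1.5, 2.0))
print(new_shape)
```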

Because the transforms work on dictionaries, MONAI applies each one to the entries listed in `keys`. In most cases this means both the 'image' and the 'label', so the two stay spatially aligned. The validation set uses the same transforms, except for the balanced random patch cropping: validation images are processed whole.
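The dictionary-based ("d"-suffixed) transforms follow one pattern: each transform is applied only to the entries named in `keys`. The pure NumPy sketch below illustrates that idea with two hypothetical helpers (`make_dict_transform`, `compose`). It is a simplification, not MONAI's implementation. A spatial step is applied to both image and label, while an intensity step touches only the image.

```python
from typing import Callable, Dict, List

import numpy as np


def make_dict_transform(fn: Callable, keys: List[str]) -> Callable[[Dict], Dict]:
    """Hypothetical helper: apply `fn` to the entries in `keys`, leave the rest untouched."""
    def transform(data: Dict) -> Dict:
        out = dict(data)  # shallow copy so the input dict is not modified
        for key in keys:
            out[key] = fn(out[key])
        return out
    return transform


def compose(transforms: List[Callable[[Dict], Dict]]) -> Callable[[Dict], Dict]:
    """Hypothetical helper: chain dictionary transforms in order."""
    def apply(data: Dict) -> Dict:
        for t in transforms:
            data = t(data)
        return data
    return apply


pipeline = compose([
    make_dict_transform(lambda a: a[np.newaxis], keys=["image", "label"]),  # add a channel dim to both
    make_dict_transform(lambda a: np.clip(a, 0, 1), keys=["image"]),        # intensity step: image only
])

sample = {"image": np.array([[-2.0, 0.5], [3.0, 1.0]]), "label": np.array([[0, 1], [1, 0]])}
out = pipeline(sample)
print(out["image"].shape, out["label"].shape)  # (1, 2, 2) (1, 2, 2)
```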

In [None]:
## No need to edit this for now, unless you are working on exercise 3

train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(
            1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        ScaleIntensityRanged(
            keys=["image"], a_min=-57, a_max=164,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        EnsureTyped(keys=["image", "label"]),
    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(
            1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        ScaleIntensityRanged(
            keys=["image"], a_min=-57, a_max=164,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        EnsureTyped(keys=["image", "label"]),
    ]
)

## Check transforms in DataLoader

In [None]:
## No need to edit this

check_ds = Dataset(data=val_files, transform=val_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)
image, label = (check_data["image"][0][0], check_data["label"][0][0])
print(f"image shape: {image.shape}, label shape: {label.shape}")
# plot the slice [:, :, 80]
plt.figure("check", (12, 6))
plt.subplot(1, 2, 1)
plt.title("image")
plt.imshow(image[:, :, 80], cmap="gray")
plt.subplot(1, 2, 2)
plt.title("label")
plt.imshow(label[:, :, 80])
plt.show()

## Define CacheDataset and DataLoader for training and validation

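Here we use `CacheDataset` to speed up training and validation. It runs the deterministic transforms once and keeps the results in memory, so only the random transforms have to be applied in each epoch. According to the MONAI tutorial, this is roughly 10x faster than the regular `Dataset`.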
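Conceptually, `CacheDataset` runs the transforms up to the first random one once per case (here, everything before `RandCropByPosNegLabeld`) and keeps the results. Only the remaining random transforms run each time an item is requested. The pure-Python sketch below illustrates that idea with a hypothetical `TinyCacheDataset` class and stand-in transforms. It is not MONAI's implementation, which also supports caching only part of the data (`cache_rate`) and multiple workers.

```python
import random
from typing import Callable, Dict, List


class TinyCacheDataset:
    """Hypothetical sketch: deterministic transforms run once, random ones on every access."""

    def __init__(self, data: List[Dict], deterministic: Callable, randomized: Callable):
        self.cache = [deterministic(item) for item in data]  # computed once, up front
        self.randomized = randomized

    def __len__(self) -> int:
        return len(self.cache)

    def __getitem__(self, index: int) -> Dict:
        return self.randomized(self.cache[index])


calls = {"load": 0}


def slow_load(item: Dict) -> Dict:
    """Stands in for LoadImaged, Spacingd, ... (expensive but deterministic)."""
    calls["load"] += 1
    return {"image": list(range(item["n"]))}


def random_crop(item: Dict) -> Dict:
    """Stands in for RandCropByPosNegLabeld: a random window of length 3."""
    start = random.randrange(len(item["image"]) - 2)
    return {"image": item["image"][start:start + 3]}


ds = TinyCacheDataset([{"n": 10}, {"n": 8}], slow_load, random_crop)
for epoch in range(5):
    for i in range(len(ds)):
        patch = ds[i]
print(calls["load"])  # 2: loading happened once per case, not once per epoch
```

The price is memory: with `cache_rate=1.0`, every preprocessed volume is held in RAM.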

In [None]:
## No need to edit this

train_ds = CacheDataset(
    data=train_files, transform=train_transforms,
    cache_rate=1.0, num_workers=2)
# train_ds = Dataset(data=train_files, transform=train_transforms)

# use batch_size=4 to load 4 images per batch; RandCropByPosNegLabeld
# crops 4 patches from each image, so each batch contains 4 x 4 = 16 patches
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=2)

val_ds = CacheDataset(
    data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=2)
# val_ds = Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=2)
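These settings translate into the following numbers per epoch. The calculation assumes the 41 labelled cases from the dataset description and the 9-case validation split made earlier. It uses plain Python arithmetic only:

```python
n_cases = 41          # labelled training volumes in Task09_Spleen
n_val = 9             # held out as validation set (data_dicts[-9:])
n_train = n_cases - n_val
batch_size = 4        # volumes per batch in train_loader
num_samples = 4       # patches per volume from RandCropByPosNegLabeld
patch_size = (96, 96, 96)

steps_per_epoch = -(-n_train // batch_size)  # ceiling division: a smaller last batch still counts
patches_per_batch = batch_size * num_samples
batch_shape = (patches_per_batch, 1) + patch_size  # (batch, channel, H, W, D)
print(n_train, steps_per_epoch, batch_shape)  # 32 8 (16, 1, 96, 96, 96)
```

This matches the `step/8` counter printed by the training loop below.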

## Exercise 1: execute a PyTorch training process and track the loss

- **Hand-In Code**: The training loop edited below.


This is the main training loop. For *n* epochs, we will pass the data through the model and optimize the weights using the loss function. The main loop is already there, but your task is to write the code for the actual training. 
Additionally, we would like to compute the loss on the validation set in order to track whether the model is overfitting.


We start with a 3D U-Net that is small enough to train quickly.
We use PyTorch for backpropagation and optimization of the network, so you can find all the information you need in their documentation. Specifically: https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html#optimizing-the-model-parameters

As a hint, you will need to implement the following steps, which are listed here in shuffled order:
- Backpropagation of the loss through the network
- Pass the data through the model to get the output
- Take one step using the optimizer
- Compute the loss
- Set the gradients to zero
- Compute the loss for the validation set

In [None]:
## Program your part in the indicated spaces

# create UNet, DiceLoss and Adam optimizer
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # falls back to the (slow) CPU without a GPU
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(4, 16, 32, 64),
    strides=(2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-3)
dice_metric = DiceMetric(include_background=False, reduction="mean")

## Initialize lists to track metrics, set number of epochs and interval for validation.
max_epochs = 10
val_interval = 1
epoch_loss_values = []
val_loss_values = []
metric_values = []
## Transforms to use after prediction
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    epoch_val_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1

        ## Transfer data to the GPU
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )

        #### START PROGRAMMING HERE
        
        loss = ...  # replace ... with your own code
        #### END PROGRAMMING HERE

        epoch_loss += loss.item()
        print(
            f"{step}/{len(train_ds) // train_loader.batch_size}, "
            f"train_loss: {loss.item():.4f}")
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    ## For monitoring, compute the performance on the validation set
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                roi_size = (160, 160, 160)
                sw_batch_size = 4

                ## Use 'sliding window' technique to achieve a prediction for the entire image
                val_outputs = sliding_window_inference(
                    val_inputs, roi_size, sw_batch_size, model)
                

                ###### START PROGRAMMING HERE ######
                
                epoch_val_loss += ...  # add the validation loss of this batch
                #### END PROGRAMMING HERE

                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                # compute metric for current iteration
                dice_metric(y_pred=val_outputs, y=val_labels)

            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()
            # reset the status for next validation round
            dice_metric.reset()

            epoch_val_loss /= len(val_loader)  # average over the validation batches
            val_loss_values.append(epoch_val_loss)
            
            print(f'Epoch validation loss: {epoch_val_loss}')

            metric_values.append(metric)
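Validation images are larger than the 96x96x96 training patches. `sliding_window_inference` therefore runs the model on overlapping `roi_size` windows and merges the window predictions into one full-size output. The sketch below is a simplified, hypothetical illustration of how window start positions along each axis can be chosen, assuming a 25% overlap (MONAI's default `overlap=0.25`). MONAI's own scan positions and blending may differ in detail, and the image shape is made up for illustration.

```python
from itertools import product
from typing import List, Tuple


def window_starts(size: int, roi: int, overlap: float = 0.25) -> List[int]:
    """Start indices of length-`roi` windows that together cover an axis of length `size`."""
    if size <= roi:
        return [0]  # one window; the image would be padded up to `roi`
    step = max(1, int(roi * (1 - overlap)))
    starts = list(range(0, size - roi + 1, step))
    if starts[-1] != size - roi:
        starts.append(size - roi)  # shift the last window so it ends exactly at the border
    return starts


def all_windows(shape: Tuple[int, ...], roi: Tuple[int, ...], overlap: float = 0.25):
    """All window corners in 3D: the Cartesian product of the per-axis start positions."""
    per_axis = [window_starts(s, r, overlap) for s, r in zip(shape, roi)]
    return list(product(*per_axis))


image_shape = (240, 200, 120)  # hypothetical validation volume after preprocessing
roi_size = (160, 160, 160)
windows = all_windows(image_shape, roi_size)
print(len(windows), windows)
```

With `sw_batch_size = 4`, these four windows would go through the model in a single forward pass. With MONAI's default `mode="constant"`, predictions in overlapping regions are averaged.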

## Exercise 2: Plot the loss over 10 epochs

Make sure you run the training loop above for 10 epochs, then plot the loss on the training and validation sets below. Also plot the Dice metric on the validation set.

- **Hand-In Answers**: The generated plots and answers to the following questions.

1. **Question:** Are the loss on the training and validation set similar? Which is higher? Name **two** reasons why one would be higher than the other.

2. **Question:** Is the Dice metric the same as the Dice loss? If not, why are they different?



In [None]:
### Program here


## Check model output with the input image and label

Plot the results on the validation set. 

In [None]:
### No need to edit this.

model.eval()
with torch.no_grad():
    for i, val_data in enumerate(val_loader):
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        val_outputs = sliding_window_inference(
            val_data["image"].to(device), roi_size, sw_batch_size, model
        )

        # plot the slice [:, :, 80] of the image, label and predicted segmentation
        plt.figure("check", (18, 6))
        plt.subplot(1, 3, 1)
        plt.title(f"image {i}")
        plt.imshow(val_data["image"][0, 0, :, :, 80], cmap="gray")
        plt.subplot(1, 3, 2)
        plt.title(f"label {i}")
        plt.imshow(val_data["label"][0, 0, :, :, 80])
        plt.subplot(1, 3, 3)
        plt.title(f"output {i}")
        plt.imshow(torch.argmax(
            val_outputs, dim=1).detach().cpu()[0, :, :, 80])
        plt.show()
        if i == 2:
            break

## Exercise 3: Experiment

From the results so far it is clear that there is room for improvement. Pick one of the following changes to the process and run an experiment to see how it affects the performance.


*   Make the UNet larger (See https://docs.monai.io/en/stable/networks.html#unet for the documentation)
*   Change the balance between foreground and background patches during training (See https://docs.monai.io/en/stable/transforms.html#randcropbyposneglabeld)
*   Change the loss function (See https://docs.monai.io/en/stable/losses.html and https://pytorch.org/docs/stable/nn.html#loss-functions )
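For the second option, the relevant parameters of `RandCropByPosNegLabeld` are `pos` and `neg`. Each patch is centred on a foreground voxel with probability `pos / (pos + neg)` and on a background voxel otherwise. The NumPy sketch below mimics that sampling on a toy label volume (the helper `sample_patch_centers` is hypothetical). Unlike MONAI, it ignores the `image_threshold` restriction on background centres and any adjustment MONAI makes to keep patches inside the image.

```python
import numpy as np


def sample_patch_centers(label, pos, neg, num_samples, rng):
    """Pick patch centres: foreground with probability pos / (pos + neg), else background."""
    fg = np.argwhere(label > 0)   # (N_fg, 3) voxel coordinates
    bg = np.argwhere(label == 0)  # (N_bg, 3) voxel coordinates
    centers = []
    for _ in range(num_samples):
        pool = fg if rng.random() < pos / (pos + neg) else bg
        centers.append(tuple(pool[rng.integers(len(pool))]))
    return centers


rng = np.random.default_rng(0)
label = np.zeros((20, 20, 20), dtype=np.uint8)
label[8:12, 8:12, 8:12] = 1  # a small cube standing in for the spleen

fractions = {}
for pos, neg in [(1, 1), (3, 1), (1, 3)]:
    centers = sample_patch_centers(label, pos, neg, num_samples=1000, rng=rng)
    fractions[(pos, neg)] = float(np.mean([label[c] > 0 for c in centers]))
    print(f"pos={pos}, neg={neg}: {fractions[(pos, neg)]:.2f} of centres in the foreground")
```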

You may edit this Jupyter notebook to run all the experiments and plot the results. Considering the time required to train one network, we do not expect any cross-validation or extensive experimentation. You must show **at least three different settings/models** and compare their performance after at least **10 epochs**.

- **Hand-In Code**: The full code (.ipynb file) to run the experiment and plot the results. 
- **Hand-In Answers**: The plots you made and the answer to the question below.

3. **Question:** Did you find an improvement with your changes? How can you explain the (lack of) improvement?

