# Hippocampus U-Net Segmentation with PyTorch Lightning + MONAI

In this exercise we bring together everything we have learned during the course. In the previous exercises we built things from scratch to learn the basics; now we use convenient off-the-shelf Python modules that provide best-practice implementations of state-of-the-art models and training routines for quick prototyping!

We will be using these tools:

* **PyTorch** as the deep learning framework: https://pytorch.org/
* **Medical Segmentation Decathlon** for the hippocampus dataset: http://medicaldecathlon.com
* **Medical Open Network for AI (MONAI)** for transforms, downloading the data and models: https://docs.monai.io/en/latest/index.html
* **PyTorch Lightning** for data loading and training: https://pytorchlightning.ai/


![Hippocampus](https://upload.wikimedia.org/wikipedia/commons/9/99/Hippocampus.gif)


Have fun!




## Check for GPU Runtime

By default Google Colab runs without a GPU - we need one!
Go to the menu and change it there:

*   German: Laufzeit > Laufzeittyp aendern > Hardwarebeschleuniger = GPU
*   English: Runtime > Change runtime type > Hardware accelerator = GPU



In [None]:
import torch
assert torch.cuda.device_count(), 'This exercise is a lot faster with a GPU - thanks gamers'
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0))

## Hippocampus Segmentation Dataset

For this example we will be using the Hippocampus Segmentation Challenge data from the Medical Segmentation Decathlon: 

http://medicaldecathlon.com/#tasks

To make our life easier we will be using some methods from the Medical Open Network for AI (MONAI) module:

https://docs.monai.io/en/latest/index.html
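
To get a feeling for what the transforms in the next cell do, here is a minimal sketch in plain PyTorch. It mimics the effect of adding a channel axis and min-max scaling the intensities to [0, 1] on a tiny made-up volume. The helper names `add_channel` and `scale_intensity` are hypothetical and are not part of MONAI.

```python
import torch

# Hypothetical stand-in for a loaded MRI volume: shape (X, Y, Z), arbitrary intensities.
volume = torch.tensor([[[10.0, 20.0], [30.0, 40.0]],
                       [[50.0, 60.0], [70.0, 90.0]]])


def add_channel(x: torch.Tensor) -> torch.Tensor:
    """Prepend a channel axis: (X, Y, Z) -> (1, X, Y, Z)."""
    return x.unsqueeze(0)


def scale_intensity(x: torch.Tensor) -> torch.Tensor:
    """Min-max rescale all intensities to the range [0, 1]."""
    return (x - x.min()) / (x.max() - x.min())


image = scale_intensity(add_channel(volume))
print(image.shape)
print(image.min().item(), image.max().item())
```

Networks expect a channel axis even for single-channel images, and a fixed intensity range keeps the inputs comparable between scans.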


In [None]:
# install monai framework
!pip install tqdm --upgrade --quiet
!pip install monai --quiet

import logging

logging.disable(logging.WARNING)

In [None]:
from monai.transforms import Compose, LoadNiftiD, AddChannelD, ScaleIntensityD, ToTensorD
from monai.apps import DecathlonDataset

transform = Compose(
    [
        LoadNiftiD(keys=["image", "label"]),
        AddChannelD(keys=["image", "label"]),
        ScaleIntensityD(keys="image"),
        ToTensorD(keys=["image", "label"]),
    ]
)
train_data = DecathlonDataset(
    root_dir="./", task="Task04_Hippocampus", transform=transform, section="training", seed=12345, download=True
)
val_data = DecathlonDataset(
    root_dir="./", task="Task04_Hippocampus", transform=transform, section="validation", seed=12345, download=True
)

Check the dataset size, shape and value ranges

In [None]:
import random

# randrange excludes the upper bound; randint(0, len(...)) could return an out-of-range index
idx = random.randrange(len(train_data))
image = train_data[idx]['image']
label = train_data[idx]['label']
print(f'train dataset size: {len(train_data)}')
print(f'validation dataset size: {len(val_data)}')
print(f'\nrandom index: {idx}')
print(f'\nimage shape', image.shape)
print(f'label shape', label.shape)
print(f'\nimage range', image.min(), image.max())
print(f'label range', label.min(), label.max())
print(f'label unique values', label.unique())

Plot random images

In [None]:
from torchvision.transforms import ToPILImage
import matplotlib.pyplot as plt
import logging

logging.disable(logging.WARNING)

to_pillow = ToPILImage()

# randrange excludes the upper bound; randint(0, len(...)) could return an out-of-range index
idx = random.randrange(len(train_data))
# plot some input images
sample = train_data[idx]

input = sample['image'].squeeze()
input_imgs = [input.select(i, input.shape[i]//3) for i in range(3)]

label = sample['label'].squeeze()
label_imgs = [label.select(i, label.shape[i]//3) for i in range(3)]

fig,ax = plt.subplots(2,3, figsize=(12, 9))
fig.suptitle(f'Hippocampus Dataset Sample - index {idx}', fontsize=16)
for i in range(3):
  ax[0,i].imshow(to_pillow(input_imgs[i]), cmap='gray')
  if not i:
    ax[0,i].set_ylabel('input', rotation=90, fontsize=20)
for i in range(3):
  ax[1,i].imshow(to_pillow(label_imgs[i]), cmap='gnuplot')
  if not i:
    ax[1,i].set_ylabel('label', rotation=90, fontsize=20)
fig.show()

## U-Net Model

We will be using a 3D version of the famous U-Net first published here by Olaf Ronneberger et al. from the University of Freiburg (https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/).

Here's the original paper: https://arxiv.org/abs/1505.04597

![U-Net Architecture](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png)
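
The architecture figure above maps a 572×572 input tile to a 388×388 output map, because the original U-Net uses unpadded 3×3 convolutions. The following sketch reproduces that arithmetic. It assumes every block consists of two 3×3 convolutions, with 2×2 pooling on the way down and 2× upsampling on the way up; the function name `unet_output_size` is hypothetical.

```python
def unet_output_size(size: int, depth: int = 5, padding: bool = False) -> int:
    """Spatial output size of a U-Net along one axis for a given input size."""
    # each unpadded 3x3 convolution removes 2 pixels, and every block has two of them
    shrink = 0 if padding else 4

    # encoder: conv block, then 2x2 max pooling (no pooling after the last block)
    for level in range(depth):
        size -= shrink
        if level != depth - 1:
            if size % 2:
                raise ValueError(f"size {size} cannot be halved evenly at level {level}")
            size //= 2

    # decoder: 2x upsampling, then conv block
    for _ in range(depth - 1):
        size = size * 2 - shrink

    return size


print(unet_output_size(572))               # 388, as in the figure
print(unet_output_size(64, padding=True))  # 64, padding keeps the size
```

With `padding=True` the output has the same size as the input, which is usually more convenient for segmentation labels.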

This is a very basic implementation of the original 2D U-Net with minor improvements (e.g. optional batch normalization):


In [None]:
# adapted from: https://github.com/jvanvugt/pytorch-unet/blob/master/unet.py
# adapted from https://discuss.pytorch.org/t/unet-implementation/426

import torch
from torch import nn
import torch.nn.functional as F

from math import log

class UNet(nn.Module):
    def __init__(
        self,
        in_channels=1,
        n_classes=2,
        num_filters=64,
        depth=5,
        padding=False,
        batch_norm=False,
        up_mode='upconv',
    ):
        """
        Implementation of
        U-Net: Convolutional Networks for Biomedical Image Segmentation
        (Ronneberger et al., 2015)
        https://arxiv.org/abs/1505.04597
        Using the default arguments will yield the exact version used
        in the original paper
        Args:
            in_channels (int): number of input channels
            n_classes (int): number of output channels
            num_filters (int): number of filters in the first layer,
                               expected to be a power of two
            depth (int): depth of the network
            padding (bool): if True, apply padding such that the input shape
                            is the same as the output.
                            This may introduce artifacts
            batch_norm (bool): Use BatchNorm after layers with an
                               activation function
            up_mode (str): one of 'upconv' or 'upsample'.
                           'upconv' will use transposed convolutions for
                           learned upsampling.
                           'upsample' will use bilinear upsampling.
        """
        super(UNet, self).__init__()
        assert up_mode in ('upconv', 'upsample')
        wf=int(log(num_filters, 2))
        prev_channels = in_channels
        self.down_path = nn.ModuleList()
        for i in range(depth):
            self.down_path.append(
                UNetConvBlock(prev_channels, 2 ** (wf + i), padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        self.up_path = nn.ModuleList()
        for i in reversed(range(depth - 1)):
            self.up_path.append(
                UNetUpBlock(prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)

    def forward(self, x):
        blocks = []
        # encoder
        for i, down in enumerate(self.down_path):
            x = down(x)
            if i != len(self.down_path) - 1:
                blocks.append(x)
                x = F.max_pool2d(x, 2)
        # decoder
        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i - 1])

        return self.last(x)


class UNetConvBlock(nn.Module):
    def __init__(self, in_size, out_size, padding, batch_norm):
        super(UNetConvBlock, self).__init__()
        block = []

        block.append(nn.Conv2d(in_size, out_size, kernel_size=3, padding=int(padding)))
        block.append(nn.ReLU())
        if batch_norm:
            block.append(nn.BatchNorm2d(out_size))

        block.append(nn.Conv2d(out_size, out_size, kernel_size=3, padding=int(padding)))
        block.append(nn.ReLU())
        if batch_norm:
            block.append(nn.BatchNorm2d(out_size))

        self.block = nn.Sequential(*block)

    def forward(self, x):
        out = self.block(x)
        return out


class UNetUpBlock(nn.Module):
    def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == 'upconv':
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        elif up_mode == 'upsample':
            self.up = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(in_size, out_size, kernel_size=1),
            )

        self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[
            :, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])
        ]

    def forward(self, x, bridge):
        up = self.up(x)
        crop1 = self.center_crop(bridge, up.shape[2:])
        out = torch.cat([up, crop1], 1)
        out = self.conv_block(out)

        return out

Today we will not use this implementation. Instead we use an improved, yet still very basic, 3D version with residual connections; its core architecture is the same as the original U-Net. More details here https://docs.monai.io/en/latest/networks.html#basicunet and here https://www.nature.com/articles/s41592-018-0261-2.

## Segmentation loss function: DICE Loss

Before we can set up our training pipeline we need a loss function. For a semantic segmentation task, the DICE loss is a good first choice. The DICE loss is one of the most popular contributions of our chair and Prof. Dr. Nassir Navab to the Deep Learning community. It was first published by Fausto Milletari et al. in 2016 in their publication [*V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation*](https://arxiv.org/abs/1606.04797) and is based on the [Sørensen–Dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient). The DICE score is equivalent to the F1 score in a classification task.

![DICE Formula](https://wikimedia.org/api/rest_v1/media/math/render/svg/a80a97215e1afc0b222e604af1b2099dc9363d3b)

The core idea is to directly optimize for the DICE coefficient in a semantic image segmentation task by making it differentiable. It is calculated by dividing twice the intersection of prediction and groundtruth by the sum of their sizes (not by their union, which would count the overlap only once).
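
As a small worked example, the sketch below computes the DICE score for two made-up binary masks and checks that it matches the F1 score computed from true positives, false positives and false negatives. The masks are purely illustrative.

```python
import torch

# Made-up binary masks: 4 predicted voxels, 4 groundtruth voxels, 2 of them overlap.
prediction = torch.tensor([[0, 1, 1, 0],
                           [0, 1, 1, 0],
                           [0, 0, 0, 0]], dtype=torch.float32)
groundtruth = torch.tensor([[0, 0, 1, 1],
                            [0, 0, 1, 1],
                            [0, 0, 0, 0]], dtype=torch.float32)

# DICE = 2 * |A ∩ B| / (|A| + |B|) = 2 * 2 / (4 + 4)
intersection = (prediction * groundtruth).sum()
dice = 2 * intersection / (prediction.sum() + groundtruth.sum())

# F1 = 2 * TP / (2 * TP + FP + FN) = 4 / (4 + 2 + 2)
tp = intersection
fp = (prediction * (1 - groundtruth)).sum()
fn = ((1 - prediction) * groundtruth).sum()
f1 = 2 * tp / (2 * tp + fp + fn)

print(dice.item(), f1.item())  # 0.5 0.5
```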


A good visualization can be found here:

![DICE loss](https://miro.medium.com/max/486/1*yUd5ckecHjWZf6hGrdlwzA.png)

https://towardsdatascience.com/metrics-to-evaluate-your-semantic-segmentation-model-6bcb99639aa2
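
The loss in the next cell compares a network output with one channel per class against a label map that stores class indices. The following sketch shows one way to bring a label map into the same channel-first, one-hot layout. The tiny label map and the choice of three classes are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

# Hypothetical label map: batch of 1, one channel, 2x2x2 voxels, class indices 0..2.
labelmap = torch.tensor([[[[[0, 1], [2, 0]],
                           [[1, 1], [0, 2]]]]])  # shape (1, 1, 2, 2, 2)
n_classes = 3

# drop the singleton channel axis, then one-hot encode (the class axis is appended last)
one_hot = F.one_hot(labelmap.squeeze(dim=1), num_classes=n_classes)  # (1, 2, 2, 2, 3)
# move the class axis next to the batch axis: (batch, classes, X, Y, Z)
one_hot = one_hot.movedim(-1, 1).float()

print(one_hot.shape)  # torch.Size([1, 3, 2, 2, 2])
```

Each voxel now has exactly one channel set to 1, so summing a channel counts the voxels of that class.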

In [None]:
import torch.nn as nn
from torch import einsum
import torch.nn.functional as F

class DiceLoss():
  '''
  Calculates the DICE loss per channel
  '''
  def __call__(self, segmentation, labelmap):
    # convert label encoding e.g. 0,1,2 to one-hot-encoded vectors
    # we do this to bring the labelmaps from the dataset to the same format as the network's output (one channel per class)
    # labelmap (b, 1, x, y, z) -> drop the channel axis -> one-hot (b, x, y, z, c) -> (b, c, x, y, z)
    labelmap = F.one_hot(labelmap.long().squeeze(dim=1), num_classes=segmentation.shape[1]).movedim(-1, 1).float()
    # you could just use the monai DICE loss: https://docs.monai.io/en/latest/losses.html#monai.losses.DiceLoss
    smooth = 1e-5 # to avoid division by zero
    # summing up the overlap of prediction and groundtruth over all spatial dimensions and the batch
    #  -> vector of size of c
    intersection = einsum("bcxyz, bcxyz->c", segmentation, labelmap)
    # summing up the squared predictions for class c -> vector of size of c
    prediction = einsum("bcxyz, bcxyz->c", segmentation, segmentation)
    # summing up all ground truth voxels for class c -> vector of size of c
    groundtruth = einsum("bcxyz, bcxyz->c", labelmap, labelmap)
    # adding up both sums: this is the denominator of the DICE formula (not the set union)
    denominator = prediction + groundtruth
    # DICE formula for the differentiable DICE loss per class c
    # smooth goes into numerator and denominator so that an empty class scores 1 instead of dividing by zero
    DICE = (2.0 * intersection + smooth) / (denominator + smooth)
    # a loss needs to decrease as the segmentation gets better, so we turn the score around
    # we choose 1 - DICE just to keep the loss positive; -DICE would work just as well
    dice_loss = 1 - DICE
    # return the dice loss per channel (vector with the size of classes)
    return dice_loss

## PyTorch Lightning Module

This is our favorite framework for setting up PyTorch projects: https://pytorchlightning.ai/

Check out the documentation: https://pytorch-lightning.readthedocs.io/en/stable/

![Pytorch Lightning Idea](https://raw.githubusercontent.com/PyTorchLightning/pytorch-lightning/master/docs/source/_images/general/pl_quick_start_full_compressed.gif)



### Lightning module
An extension of PyTorch models with a simple training interface.
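
For comparison, here is a sketch of the kind of manual training loop that a Lightning module saves us from writing. It fits a toy linear model on random data with plain PyTorch; the model, data and hyperparameters are made up for illustration and have nothing to do with the hippocampus task.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy regression problem: targets are a fixed linear function of the inputs.
inputs = torch.randn(64, 3)
targets = inputs @ torch.tensor([[1.0], [-2.0], [0.5]])

toy_model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

losses = []
for epoch in range(50):
    optimizer.zero_grad()                        # reset gradients from the previous step
    loss = loss_fn(toy_model(inputs), targets)   # forward pass and loss
    loss.backward()                              # backpropagation
    optimizer.step()                             # parameter update
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

In the Lightning module below, only the forward pass and loss computation remain in `training_step`; the loop, the device handling and the logging are taken care of by the trainer.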

In [None]:
# first we need to install pytorch lightning
! pip install pytorch-lightning --quiet

In [None]:
import pytorch_lightning as pl
from monai.networks.nets import BasicUNet
import torch
from torch import nn  # needed for nn.Softmax below

class SegmentationModule(pl.LightningModule):

  def __init__(self, in_channels=1, n_classes=3):
    super().__init__()

    # load MONAI's BasicUNet (not the 2D UNet defined above) as the backbone of our segmentation model
    # see https://docs.monai.io/en/latest/networks.html#basicunet 
    # and here https://www.nature.com/articles/s41592-018-0261-2.
    self.unet = BasicUNet(
      dimensions=3, # using 3d images
      in_channels=in_channels, # we just have one grey channel as input
      out_channels=n_classes, # one output channel per class: background + the two target classes
      features=(32, 32, 64, 128, 256, 32),
    )

    # final layer for activation i.e. converting the logits to a value between 0 and 1
    self.activation = nn.Softmax(dim=1)
    
    # you could just use the monai DICE loss: https://docs.monai.io/en/latest/losses.html#monai.losses.DiceLoss
    self.dice_loss = DiceLoss()

  def forward(self, x):
    output = {}
    output['logits'] = self.unet(x)
    output['probabilities'] = self.activation(output['logits'])
    output['predictions'] = output['probabilities'].gt(0.5).float() # set threshold for positive prediction to > 0.5
    return output

  def training_step(self, batch, batch_idx):
    segmentation = self(batch['image'])
    labelmap = batch['label']
    # for the dice loss we take the mean of all channels including the background
    loss = self.dice_loss(segmentation['probabilities'], labelmap).mean()
    # in the metric we are only interested in the two target classes
    dice_metric = 1 - self.dice_loss(segmentation['predictions'], labelmap)[1:].mean()
    self.log('train_DICE', dice_metric, on_step=False, on_epoch=True, prog_bar=True)
    self.log('loss', loss, on_step=False, on_epoch=True, prog_bar=False)
    return loss

  def validation_step(self, batch, batch_idx):
    segmentation = self(batch['image'])
    labelmap = batch['label']
    loss = self.dice_loss(segmentation['probabilities'], labelmap).mean()
    dice_metric = 1 - self.dice_loss(segmentation['predictions'], labelmap)[1:].mean()
    self.log('val_loss', loss, prog_bar=True)
    self.log('val_DICE', dice_metric, prog_bar=True)

  def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr=5e-4)

### Data module 
Just a fancy dataloader with the same dataset we loaded above.
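
The sketch below shows what the dataloaders inside the data module hand to the training step: dictionary samples are stacked into a dictionary of batched tensors. The `ToyVolumes` dataset is a hypothetical stand-in with random data, so nothing needs to be downloaded.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyVolumes(Dataset):
    """Hypothetical dataset returning dict samples with 'image' and 'label' keys."""

    def __init__(self, n_samples=4, shape=(1, 4, 4, 4)):
        self.images = torch.rand(n_samples, *shape)
        self.labels = torch.randint(0, 3, (n_samples, *shape))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return {"image": self.images[idx], "label": self.labels[idx]}


# the default collate function stacks every dict entry along a new batch axis
loader = DataLoader(ToyVolumes(), batch_size=2, shuffle=False, num_workers=0)
batch = next(iter(loader))

print(batch["image"].shape)  # torch.Size([2, 1, 4, 4, 4])
print(batch["label"].shape)  # torch.Size([2, 1, 4, 4, 4])
```

This is why `training_step` can access `batch['image']` and `batch['label']` directly.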

In [None]:
from monai.apps import DecathlonDataset
from monai.transforms import Compose, LoadNiftiD, AddChannelD, ScaleIntensityD, ToTensorD
from torch.utils.data import DataLoader

class HippocampusData(pl.LightningDataModule):

  def __init__(self, data_dir="./", batch_size=1):
    super().__init__()
    self.data_dir = data_dir
    self.batch_size = batch_size

  def setup(self, stage=None):
    # very basic transformations check out: https://docs.monai.io/en/latest/transforms.html
    self.transform = Compose(
      [
        LoadNiftiD(keys=["image", "label"]),
        AddChannelD(keys=["image", "label"]),
        ScaleIntensityD(keys="image"),
        ToTensorD(keys=["image", "label"]),
      ]
    )
    self.train_data = DecathlonDataset(
      root_dir=self.data_dir, 
      task="Task04_Hippocampus", 
      transform=self.transform, # use the transform defined in setup, not a global variable
      section="training", 
      seed=42, 
      download=True
    )
    self.val_data = DecathlonDataset(
      root_dir=self.data_dir, 
      task="Task04_Hippocampus", 
      transform=self.transform, # use the transform defined in setup, not a global variable
      section="validation", 
      seed=42, 
      download=True
    )

  def train_dataloader(self):
      return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, num_workers=4)

  def val_dataloader(self):
      return DataLoader(self.val_data, batch_size=self.batch_size, num_workers=4)

### Training process

Training setup

In [None]:
from pytorch_lightning.loggers import TensorBoardLogger

# setting up our pytorch lightning modules
model = SegmentationModule()
data = HippocampusData()
logger = TensorBoardLogger("tb_logs", name="hippo 3D unet")

# setting up our pytorch lightning trainer
trainer = pl.Trainer(
  gpus=1, 
  max_epochs=10, 
  progress_bar_refresh_rate=20,
  limit_train_batches=50, # we just look at 50 random samples per epoch to shorten them
  logger=logger
  # weights_summary='full', # if you are curious about the model uncomment this
)
print('training setup finished')

Training loop

In [None]:
trainer.fit(model, data)

Plot some predictions with input and groundtruth
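
Before plotting, the per-class network output has to be turned into a single label map. The sketch below illustrates this on random logits: softmax gives per-voxel class probabilities, and argmax over the class axis picks the most likely class. The tensor sizes are made up for illustration.

```python
import torch

torch.manual_seed(0)

# Hypothetical network output: (batch, classes, X, Y, Z)
logits = torch.randn(1, 3, 2, 2, 2)

# softmax over the class axis: the probabilities of each voxel sum to 1
probabilities = torch.softmax(logits, dim=1)

# most likely class per voxel -> label map of shape (batch, X, Y, Z)
label_map = probabilities.argmax(dim=1)

# thresholding at 0.5 instead marks at most one class per voxel,
# and none at all where no class is more likely than 50%
thresholded = probabilities.gt(0.5).float()

print(label_map.shape)  # torch.Size([1, 2, 2, 2])
```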

In [None]:
from torchvision.transforms import ToPILImage
import matplotlib.pyplot as plt

to_pillow = ToPILImage()

# randrange excludes the upper bound; randint(0, len(...)) could return an out-of-range index
idx = random.randrange(len(train_data))
# plot some input images
sample = train_data[idx]

input = sample['image'].squeeze()
input_imgs = [input.select(i, input.shape[i]//3) for i in range(3)]

label = sample['label'].squeeze()
label_imgs = [label.select(i, label.shape[i]//3) for i in range(3)]

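# inference: switch to eval mode and skip gradient tracking
# the input is moved to the model's device and the prediction back to the CPU
model.eval()
with torch.no_grad():
  prediction = model(sample['image'].unsqueeze(dim=0).to(model.device))['predictions'].cpu()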

dice_score = 1 - model.dice_loss(prediction, sample['label'].unsqueeze(dim=0))[1:].mean()

prediction = prediction.squeeze().argmax(dim=0).float()

pred_imgs = [prediction.select(i, prediction.shape[i]//3) for i in range(3)]

fig,ax = plt.subplots(3,3, figsize=(12, 9))
fig.suptitle(f'Hippocampus 3D-Unet Segmentation - index {idx} - DICE {dice_score:.2%}', fontsize=16)
for i in range(3):
  ax[0,i].imshow(to_pillow(input_imgs[i]), cmap='gray')
  if not i:
    ax[0,i].set_ylabel('input', rotation=90, fontsize=20)
for i in range(3):
  ax[1,i].imshow(to_pillow(label_imgs[i]), cmap='gnuplot')
  if not i:
    ax[1,i].set_ylabel('label', rotation=90, fontsize=20)
for i in range(3):
  ax[2,i].imshow(to_pillow(pred_imgs[i]), cmap='gnuplot')
  if not i:
    ax[2,i].set_ylabel('prediction', rotation=90, fontsize=20)
fig.show()

## Tensorboard

In [None]:
%load_ext tensorboard
%tensorboard --logdir tb_logs

## What next?

Dive into the documentation of PyTorch Lightning and MONAI:

*   https://pytorch-lightning.readthedocs.io/en/stable/
*   https://docs.monai.io/en/latest/index.html

Some ideas:

*   Optimize the training and do some hyperparameter tuning!
*   Try your own models
*   Try other datasets



