# Chapter 2.3: "UNet"

Originally developed by Olaf Ronneberger, Philipp Fischer, and Thomas Brox (with an implementation based on Caffe), the UNet has since been ported to Python frameworks and has become a de facto standard for segmentation tasks in medical machine learning. The name is derived from the U-shape of the network. The UNet uses convolutions and pooling to reduce the spatial size of the input image and create an information bottleneck. Afterwards, the image is restored to its original size, step by step, using transposed convolutions, which you will have heard about in the UNet team's presentation. Corresponding steps on the downward and upward slopes of the U are additionally linked by skip connections, which you have seen before in the ResNet.
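The mechanics described above can be sketched on a toy feature map. This is a minimal illustration, not the course implementation: the tensor sizes and channel counts are arbitrary assumptions chosen to keep it fast. Note that the UNet merges skip connections by concatenating along the channel dimension, whereas the ResNet adds them.

```python
import torch
import torch.nn as nn

# Toy feature map: B x C x H x W (sizes chosen for illustration only)
x = torch.randn(1, 8, 64, 64)

# Downward slope: max pooling halves height and width
down = nn.MaxPool2d(kernel_size=2)
# Upward slope: a transposed convolution with stride 2 doubles them again
up = nn.ConvTranspose2d(in_channels=8, out_channels=8, kernel_size=2, stride=2)

pooled = down(x)        # 1 x 8 x 32 x 32
restored = up(pooled)   # 1 x 8 x 64 x 64

# Skip connection: concatenate encoder features with the upsampled features
merged = torch.cat([x, restored], dim=1)   # 1 x 16 x 64 x 64

print(pooled.shape, restored.shape, merged.shape)
```

Because concatenation doubles the channel count, the convolution that follows a skip connection has to accept the combined number of channels.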

You can find the original paper here: https://arxiv.org/abs/1505.04597

Over the next sessions, we will recreate this milestone architecture in PyTorch.

In [None]:
import sys
sys.path.append("/datashare/MLCourse/Course_Materials") # Preferentially import from the datashare.
sys.path.append("../") # Otherwise, import from the local folder's parent folder, where your stuff lives.

import numpy as np
import time
import torch, torch.nn as nn
import torchvision, torchvision.transforms as tt
import albumentations
from torch.multiprocessing import Manager
torch.multiprocessing.set_sharing_strategy("file_system")
from typing import List

from utility import utils as uu
from utility.eval import evaluate_segmentation_model
from utility.unet import Example_UNet
from utility.segloss import ExampleSegmentationLoss

### TASK: Add some data augmentations of your choice (or None, if you want to test something else).

Only use the albumentations package for your augmentations. Why? Because albumentations applies each spatial transform to the targets with the same random parameters as to the image. This keeps the segmentation masks aligned with the augmented image, and albumentations guarantees it without any extra work on your part.

A typical example looks like this:
```
augments = albumentations.Compose([
    albumentations.RandomCrop(width=256, height=256),
    albumentations.HorizontalFlip(p=0.5),
    albumentations.RandomBrightnessContrast(p=0.2)
])
```
As you can see, the process is essentially the same as with regular torchvision transforms (although some of the names differ). You can find the list of available transforms here: https://albumentations.ai/docs/getting_started/transforms_and_targets/
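To illustrate what "the same parameters" means for image and mask, here is a NumPy-only sketch. It is not how albumentations is implemented. The helper `paired_flip_and_crop` is hypothetical and exists only for this illustration. With albumentations itself, you get this behavior by passing both arrays as keyword arguments, e.g. `augments(image=image, mask=mask)`, which returns a dictionary with the same keys.

```python
import numpy as np

def paired_flip_and_crop(image, mask, size, rng):
    """Hypothetical helper: apply one random flip and crop to image and mask alike."""
    # Draw the random parameters once ...
    flip = rng.random() < 0.5
    top = int(rng.integers(0, image.shape[0] - size + 1))
    left = int(rng.integers(0, image.shape[1] - size + 1))
    # ... and reuse them for both arrays, so they stay aligned
    results = []
    for arr in (image, mask):
        if flip:
            arr = arr[:, ::-1]
        results.append(arr[top:top + size, left:left + size])
    return results[0], results[1]

rng = np.random.default_rng(0)
image = np.arange(64, dtype=np.float32).reshape(8, 8)
mask = (image > 31).astype(np.uint8)   # toy mask derived from the image

aug_image, aug_mask = paired_flip_and_crop(image, mask, size=4, rng=rng)

# The augmented mask still matches the augmented image pixel by pixel
print(np.array_equal(aug_mask, (aug_image > 31).astype(np.uint8)))
```

If the flip or the crop offsets were drawn separately for the mask, the targets would no longer describe the pixels the network sees.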

In [None]:
# TODO: Your data augments go here
data_augments = None

In [None]:
# Train, Val, and Test datasets are all contained within this dataset.
# They can be selected by setting 'ds.set_mode(selection)'.

# We could also cache any data we read from disk to shared memory, or
# to regular memory, where each dataloader worker caches the entire
# dataset. Option 1 creates more overhead than gain for this problem,
# while option 2 requires more memory than we have. Hence, we still
# read everything from disk.

cache_me = False
if cache_me is True:
    cache_mgr = Manager()
    cache_mgr.data = cache_mgr.dict()
    cache_mgr.cached = cache_mgr.dict()
    for k in ["train", "val", "test"]:
        cache_mgr.data[k] = cache_mgr.dict()
        cache_mgr.cached[k] = False

ds = uu.LiTS_Segmentation_Dataset(
    #data_dir = "/home/coder/Course_Materials/data/Clean_LiTS/",
    data_dir = "../data/Clean_LiTS/",
    transforms = data_augments,
    verbose = True,
    cache_data = cache_me,
    cache_mgr = (cache_mgr if cache_me is True else None),
    debug = True,
)

# This time, our dataset spits out a tensor (our image), and a list of tensors (our targets)

### TASK: Play around with the hyperparameters (if you feel like it).

In [None]:
# Default settings
batch_size = 32
learning_rate = 1e-4
weight_decay = 5e-6
epochs = 10
run_name = "UNet"
device = ("cuda" if torch.cuda.is_available() else "cpu")
time_me = False

In [None]:
# Dataloader
dl = torch.utils.data.DataLoader(
    dataset = ds, 
    batch_size = batch_size, 
    num_workers = 4, 
    shuffle = True, 
    drop_last = False, 
    pin_memory = True,
    persistent_workers = (not cache_me),
    prefetch_factor = 1
    )
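A short note on batch counts: with `drop_last = False`, the dataloader yields one final, smaller batch whenever the dataset size is not a multiple of the batch size. The sketch below works through the arithmetic; the sample count of 100 is an arbitrary assumption for illustration, not the size of the course dataset.

```python
import math

def batches_per_epoch(num_samples, batch_size, drop_last):
    """Number of batches a dataloader yields per epoch."""
    if drop_last:
        return num_samples // batch_size           # full batches only
    return math.ceil(num_samples / batch_size)     # plus one partial batch, if any

num_samples, batch_size = 100, 32                  # illustrative numbers
full = batches_per_epoch(num_samples, batch_size, drop_last=True)
total = batches_per_epoch(num_samples, batch_size, drop_last=False)
last_batch_size = num_samples - full * batch_size

print(full, total, last_batch_size)
```

In this example there are 3 full batches (indices 0 to 2) and one partial batch of 4 samples at index 3.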

### TASK: Construct a UNet.

The input dimensions for the network will be the usual B x 1 x 256 x 256. The output dimensions should be B x 3 x 256 x 256. We have three output channels because we will still predict classes 0 (background), 1 (liver) and 2 (liver tumor) - this time, however, we predict the classes on a per-pixel basis.
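As a small sketch of what "per-pixel" prediction means: the three output channels hold one score per class for every pixel, and the predicted class is the channel with the highest score. The random logits and the small 8 x 8 resolution below are assumptions for illustration only.

```python
import torch

# Stand-in for the network output: B x 3 x H x W (small H and W for illustration)
logits = torch.randn(2, 3, 8, 8)

# Softmax over the channel dimension gives class probabilities per pixel
probs = torch.softmax(logits, dim=1)

# Argmax over the channel dimension gives one class index per pixel:
# 0 (background), 1 (liver), or 2 (liver tumor)
class_map = probs.argmax(dim=1)   # B x H x W

print(class_map.shape)
```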

Since our input images are much smaller than those used in the original UNet paper, we will opt for a smaller UNet. The general design remains the same as in the paper, except:
- We will only downsample 3 times by a factor of 2, using MaxPool (for a minimum resolution of 32x32).
- Our 3x3 convolutions will use padding. Consequently, there will be no cropping during skip connections.
- We will only have 3 skip connections.
- We will use fewer maximum channels (as we have only 3 downsampling steps, we will have 64, 128, 256, and 512 channels).
- Our final output will be 3 channels wide, not 2 (we predict background, liver, and liver tumors).
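Two of the points above can be checked quickly. The sketch below first shows that a padded 3x3 convolution keeps the spatial size (which is why no cropping is needed), and then lists the resolution and channel count at each level of the U. The toy input size and the 8 output channels in the first part are arbitrary assumptions. This is not a UNet implementation.

```python
import torch
import torch.nn as nn

# Part 1: padding preserves the spatial size of a 3x3 convolution
x = torch.randn(1, 1, 32, 32)
padded = nn.Conv2d(1, 8, kernel_size=3, padding=1)
unpadded = nn.Conv2d(1, 8, kernel_size=3, padding=0)
padded_shape = tuple(padded(x).shape)       # (1, 8, 32, 32)
unpadded_shape = tuple(unpadded(x).shape)   # (1, 8, 30, 30): one pixel lost per border

# Part 2: resolution and channels per level for 3 downsampling steps
resolution, channels = 256, 64
schedule = [(resolution, channels)]
for _ in range(3):
    resolution //= 2   # MaxPool halves height and width
    channels *= 2      # the channel count doubles at each level
    schedule.append((resolution, channels))

print(padded_shape, unpadded_shape)
print(schedule)   # [(256, 64), (128, 128), (64, 256), (32, 512)]
```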

Below, you can find an example UNet. If you want to test your loss module (or anything, really), it should be working with this UNet. Note that the example UNet does **not** follow the exact specifications of this task.

In [None]:
# Stand-in example model (if you want to test something else)
model = Example_UNet(in_channels = 1, out_classes = 3)
model = model.to(device)

In [None]:
# Your implementation
class UNet(torch.nn.Module):
    pass

In [None]:
# Create an instance of your model
model = UNet()
model = model.to(device)

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr = learning_rate, weight_decay = weight_decay)

### TASK: Create a DICE/XE loss

The loss you create should fulfill the following criteria:
- It subclasses torch.nn.Module.
- It is a class that implements an \_\_init\_\_ function and a forward function.
- The forward function takes as arguments the predictions from your model and the target masks from the dataset.
- The loss function should compute a Cross-Entropy (XE) loss and a DICE loss based on predictions and targets, and return their sum or weighted sum.
- The loss should accept a tensor of shape $B \times 3 \times H \times W$ as predictions, and a tensor of shape $B \times 2 \times H \times W$ as targets, where B is the batch size, and H and W are height and width of the images. The predictions have a channel dimension of length 3, because we predict background, liver, and tumor regions. Targets only have a channel dimension of length 2 because we only get segmentations for liver and tumor regions - we will derive our background from the other two segmentations.
- Note that the class you write should not use torch.nn.CrossEntropyLoss under the hood - we want to write the calculation of the loss ourselves.
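Two ingredients of such a loss can be sketched as plain functions. This is one possible variant under stated assumptions, not the reference solution. The function names are hypothetical. The sketch assumes binary float masks, treats pixels marked as both liver and tumor as tumor only, and computes a soft DICE score per class over the whole batch. The XE term and the module wrapper are left out.

```python
import torch

def one_hot_with_background(targets):
    """targets: B x 2 x H x W binary masks (liver, tumor) -> B x 3 x H x W one-hot."""
    liver, tumor = targets[:, 0], targets[:, 1]
    # Assumption: where both masks are set, the pixel counts as tumor only
    liver = liver * (1 - tumor)
    # Background is every pixel that is neither liver nor tumor
    background = 1 - torch.clamp(liver + tumor, max=1)
    return torch.stack([background, liver, tumor], dim=1)

def soft_dice_loss(probs, one_hot, eps=1e-6):
    """probs, one_hot: B x 3 x H x W. Returns 1 minus the mean DICE score over classes."""
    dims = (0, 2, 3)   # sum over batch and spatial dimensions, keep classes
    intersection = (probs * one_hot).sum(dim=dims)
    denominator = probs.sum(dim=dims) + one_hot.sum(dim=dims)
    dice = (2 * intersection + eps) / (denominator + eps)
    return 1 - dice.mean()

# Toy targets: a 2x2 liver region with one tumor pixel inside it
targets = torch.zeros(1, 2, 4, 4)
targets[0, 0, 1:3, 1:3] = 1
targets[0, 1, 2, 2] = 1
one_hot = one_hot_with_background(targets)

perfect = soft_dice_loss(one_hot, one_hot)                        # close to 0
uniform = soft_dice_loss(torch.full_like(one_hot, 1 / 3), one_hot)  # clearly above 0
print(perfect.item(), uniform.item())
```

The `probs` argument is expected to hold softmax probabilities, so the model's raw outputs would need a softmax over the channel dimension first.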

Below you will find a regular XE loss, designed for this segmentation task. You can use the example loss module to test your UNet implementation. If your UNet works correctly, this module should work, too.

In [None]:
# We create an instance of the loss module and put it onto the GPU as well.
# This is not necessary, but greatly speeds up the computation, if you have the space.
# For segmentation tasks, this can be a real time saver.
criterion = ExampleSegmentationLoss(
    classes = 3,
    weights = torch.Tensor([1, 3, 10]),
    on_the_fly_background = True,
    allow_multiclass = False
    )
criterion = criterion.to(device)

In [None]:
# Your implementation
class DICE_XE_Loss(torch.nn.Module):
    pass

In [None]:
# Create an instance of your loss module
criterion = DICE_XE_Loss()
criterion = criterion.to(device)

In [None]:
if time_me is True:
    c_start = time.time()

num_steps = len(ds.file_names['train'])//batch_size

for epoch in range(epochs):
    
    # If we are caching, we now have all data and let the (potentially non-persistent) workers know
    if cache_me is True and epoch > 0:
        dl.dataset.set_cached("train")
        dl.dataset.set_cached("val")
    
    # Time me
    if time_me is True:
        e_start = time.time()

    # Go to train mode
    ds.set_mode("train")
    model.train()

    # Train loop
    for step, (data, targets) in enumerate(dl):

        # Manually drop the last, incomplete batch (relevant e.g. for BatchNorm).
        # With drop_last = False, that batch has index num_steps (the full batches are 0 .. num_steps - 1).
        if step >= num_steps and (epoch > 0 or ds.cache_data is False):
            continue

        # Train loop: Zero gradients, forward step, evaluate, log, backward step
        optimizer.zero_grad()
        data = data.to(device)
        targets = [target.to(device) for target in targets]
        if time_me is True:
            c_end = time.time()
            if step % 20 == 0:
                print(f"CPU time: {c_end-c_start:.4f}s")
            g_start = time.time()
        predictions = model(data)
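        if time_me is True:
            if device == "cuda":
                # CUDA kernels run asynchronously; wait for them before stopping the clock
                torch.cuda.synchronize()
            g_end = time.time()
            c_start = time.time()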
        if step % 20 == 0 and time_me is True:
            print(f"GPU time: {g_end-g_start:.4f}s")
        loss = criterion(predictions, targets)
        if step % 20 == 0:
            print(f"Epoch [{epoch+1}/{epochs}]\t Step [{step+1}/{num_steps}]\t Train Loss: {loss.item():.4f}")
        uu.csv_logger(
            logfile = f"../logs/{run_name}_train.csv",
            content = {"epoch": epoch, "step": step, "loss": loss.item()},
            first = (epoch == 0 and step == 0),
            overwrite = (epoch == 0 and step == 0)
                )
        loss.backward()
        optimizer.step()

    # Go to eval mode
    ds.set_mode("val")
    model.eval()

    # Validation loop
    metrics = {"epoch": epoch}
    metrics.update(evaluate_segmentation_model(model = model, dataloader = dl, device = device))
    print('\n'.join([f'{m}: {v}' for m, v in metrics.items() if not m.startswith("#")]))
    uu.csv_logger(
        logfile = f"../logs/{run_name}_val.csv",
        content = {m: v for m, v in metrics.items() if not m.startswith("#")},
        first = (epoch == 0),
        overwrite = (epoch == 0)
            )
        
    if time_me is True:
        print(f"Epoch time: {time.time()-e_start:.4f}s")

# Finally, test time
ds.set_mode("test")
model.eval()

metrics = evaluate_segmentation_model(model = model, dataloader = dl, device = device)
print("Test-time metrics:")
print('\n'.join([f'{m}: {v}' for m, v in metrics.items() if not m.startswith("#")]))
uu.csv_logger(
    logfile = f"../logs/{run_name}_test.csv",
    content = {m: v for m, v in metrics.items() if not m.startswith("#")},
    first = True,
    overwrite = True
        )

# Visualization

If you have some time left and your model has successfully trained, try outputting a few images and the segmentations your model predicted, as well as the ground truth. Does your model do a good job? Is there something your model is particularly good or bad at?
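One way to prepare such a picture is to color the predicted classes and blend them over the grayscale input. The sketch below uses NumPy only. The function names and the color choices are hypothetical, and it assumes the image has been scaled to the range [0, 1] and the class map holds the indices 0, 1, and 2.

```python
import numpy as np

def colorize(class_map):
    """Map class indices (H x W) to RGB: background black, liver green, tumor red."""
    palette = np.array([[0, 0, 0], [0, 255, 0], [255, 0, 0]], dtype=np.uint8)
    return palette[class_map]

def overlay(image, class_map, alpha=0.4):
    """Blend a grayscale image (H x W, assumed in [0, 1]) with the colorized classes."""
    gray = np.repeat((image * 255.0)[..., None], 3, axis=2)
    colors = colorize(class_map).astype(np.float64)
    foreground = (class_map > 0)[..., None]
    # Only liver and tumor pixels are tinted; the background keeps its gray value
    blended = np.where(foreground, (1 - alpha) * gray + alpha * colors, gray)
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)

# Toy example: black image with one liver pixel and one tumor pixel
image = np.zeros((4, 4))
class_map = np.zeros((4, 4), dtype=np.int64)
class_map[1, 1] = 1
class_map[2, 2] = 2

rgb = overlay(image, class_map)
print(rgb.shape, rgb[1, 1], rgb[2, 2])
```

The resulting H x W x 3 array can be displayed with `matplotlib.pyplot.imshow`. Creating one overlay for the prediction and one for the ground truth and showing them side by side makes differences easy to spot.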