# Lab 2: Encoder-Decoder Models

In this lab, we investigate the U-Net architecture and see two application examples.

<img src="https://upload.wikimedia.org/wikipedia/commons/2/2b/Example_architecture_of_U-Net_for_producing_k_256-by-256_image_masks_for_a_256-by-256_RGB_image.png" width="100%">

Image: Wikipedia

## Questions

* Why does `bilinear=False` build a U-Net with more parameters than `bilinear=True`?

#SOLUTIONSTART
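`bilinear=False` implements upsampling with transposed convolutions (`nn.ConvTranspose2d`), which introduce additional learnable weights, whereas bilinear interpolation has none. In the implementation below, `bilinear=True` also halves the channel widths of the bottleneck and the decoder (`factor = 2`), which reduces the parameter count further.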
#SOLUTIONEND
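
As a rough check of the transposed-convolution part of this answer, the extra weights can be counted by hand. The sketch below assumes the four `Up` layers of the `UNet` class defined later in this lab (1024, 512, 256, and 128 input channels, each with a 2x2 kernel); `conv_transpose_params` is a hypothetical helper introduced only for illustration.

```python
def conv_transpose_params(in_channels, out_channels, kernel_size):
    """Count the parameters of a 2D transposed convolution (weights plus bias)."""
    weights = in_channels * out_channels * kernel_size * kernel_size
    bias = out_channels  # one bias term per output channel
    return weights + bias

# With bilinear=False, each Up layer maps in_channels -> in_channels // 2
# using a 2x2 transposed convolution.
up_in_channels = [1024, 512, 256, 128]
per_layer = [conv_transpose_params(c, c // 2, 2) for c in up_in_channels]
total = sum(per_layer)

print(per_layer)
print(total)
```

Bilinear upsampling with `nn.Upsample` has no learnable parameters, so these weights exist only when `bilinear=False`.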

* U-Nets (and other encoder-decoder models) work very well for medical or remote sensing (satellite) images, but not for classic natural images (natural scenes like cats, dogs, humans, streets). Given their architecture, can you explain why they may be better suited for medical or remote sensing (satellite) images? Think about which inductive biases are encoded in the model architecture.

#SOLUTIONSTART
In general, U-Nets and encoder-decoder models are widely used for medical or remote sensing (satellite) images because of their inductive biases: here, relevant features in the images have the same scale. The skip connections in the U-Net implement this intuition natively by passing encoder features to the decoder at the same resolution.
#SOLUTIONEND

# U-Net Source Code

Let's investigate how a U-Net is concretely implemented in PyTorch.

For your projects, you can use existing packages like [`segmentation_models_pytorch`](https://smp.readthedocs.io/en/latest/quickstart.html) that provide various model implementations built on common classification backbones (i.e., encoders):
```python
import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=1,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=3,                      # model output channels (number of classes in your dataset)
)
```

But for now, let's investigate a vanilla PyTorch U-Net from [the Pytorch-UNet repository](https://github.com/milesial/Pytorch-UNet/tree/master).

**Task**:
* Implement the decoder forward pass through the up layers (see the `# TODO` in `UNet.forward()`)

In [53]:
import torch  # needed for torch.cat in Up.forward
import torch.nn as nn
import torch.nn.functional as F

#============== some parts of the U-Net model ===============#
""" Parts of the U-Net model """
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)


    def forward(self, x1, x2):
        x1 = self.up(x1)
        # inputs are NCHW: dim 2 is the height, dim 3 is the width
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

#=================== Assembling parts to form the network =================#
""" Full assembly of the parts to form the complete network """

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x.float())
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        # TODO: implement the forward pass through the decoder up layers up2, up3, up4
        # x = ...
        # x = ...
        # x = ...
        #SOLUTIONSTART
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        #SOLUTIONEND
        logits = self.outc(x)
        return logits
    
model = UNet(3,1, bilinear=True)
model
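
The padding step in `Up.forward` only matters when the input size is not divisible by 16. The sketch below traces the feature-map height through the four pooling and four upsampling stages in plain Python, assuming that `nn.MaxPool2d(2)` floors odd sizes and that both upsampling variants exactly double the size; `trace_padding` is a hypothetical helper, not part of the repository.

```python
def trace_padding(height, depth=4):
    """Return the (before, after) padding each Up layer applies along one axis."""
    # Encoder: record the size of every skip connection (x1 ... x5).
    sizes = [height]
    for _ in range(depth):
        sizes.append(sizes[-1] // 2)  # 2x2 max-pooling floors odd sizes

    current = sizes.pop()  # bottleneck size (x5)
    pads = []
    # Decoder: upsample, then pad to match the corresponding skip connection.
    while sizes:
        skip = sizes.pop()
        upsampled = 2 * current
        diff = skip - upsampled  # same role as diffY / diffX in Up.forward
        pads.append((diff // 2, diff - diff // 2))
        current = skip  # after padding, the sizes match the skip connection
    return pads

print(trace_padding(256))
print(trace_padding(100))
```

For a height of 256 no padding is needed, while a height of 100 yields an odd intermediate size (25), so one of the `Up` layers has to pad by one pixel before concatenating.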

### Investigating the U-Net Upsampling Mechanism

**Question**
* How is the upsampling implemented? Check and follow the `bilinear` argument in the code above.

In [54]:
def get_n_params(model):
    """Return the total number of parameters in `model`."""
    return sum(p.numel() for p in model.parameters())

print("bilinear = False")
print(get_n_params(UNet(3,1, bilinear=False)))

print("bilinear = True")
print(get_n_params(UNet(3,1, bilinear=True)))
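
Part of the gap between the two counts comes from the channel widths rather than from the upsampling layers: with `bilinear=True`, `factor = 2` halves the output channels of `down4` from 1024 to 512. A minimal sketch of this effect, using a hypothetical helper `double_conv_params` that counts the parameters of the `DoubleConv` block above (two 3x3 convolutions with bias, each followed by batch normalization with a learnable scale and shift):

```python
def double_conv_params(in_channels, out_channels, mid_channels=None):
    """Count the parameters of DoubleConv: (conv3x3 => BN => ReLU) * 2."""
    mid = mid_channels or out_channels
    conv1 = 9 * in_channels * mid + mid            # 3x3 weights + bias
    bn1 = 2 * mid                                  # scale and shift per channel
    conv2 = 9 * mid * out_channels + out_channels  # 3x3 weights + bias
    bn2 = 2 * out_channels
    return conv1 + bn1 + conv2 + bn2

bottleneck_transposed = double_conv_params(512, 1024)  # down4 with bilinear=False
bottleneck_bilinear = double_conv_params(512, 512)     # down4 with bilinear=True

print(bottleneck_transposed, bottleneck_bilinear)
print(bottleneck_transposed - bottleneck_bilinear)
```

Under these assumptions, the bottleneck block alone accounts for several million parameters of the difference.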

## Application Examples

### Example: FLAIR Abnormality Segmentation in Brain MRI

"Association of genomic subtypes of lower-grade gliomas with shape features automatically extracted by a deep learning algorithm" [Buda et al., 2019](https://www.sciencedirect.com/science/article/abs/pii/S0010482519301520?via%3Dihub)

This example uses a PyTorch U-Net for FLAIR abnormality segmentation in brain MRI, based on the deep learning segmentation algorithm used in that paper.

This code is based on [this TorchHub example](https://pytorch.org/hub/mateuszbuda_brain-segmentation-pytorch_unet/).

In [55]:
import torch
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
    in_channels=3, out_channels=1, init_features=32, pretrained=True)

# Download an example image
import urllib.request
url, filename = ("https://github.com/mateuszbuda/brain-segmentation-pytorch/raw/master/assets/TCGA_CS_4944.png", "TCGA_CS_4944.png")
urllib.request.urlretrieve(url, filename)  # Python 3 API; saves the image as `filename`

The U-Net model architecture:

In [56]:
model

In [57]:
import numpy as np
from PIL import Image
from torchvision import transforms

input_image = Image.open(filename)
# per-channel mean and standard deviation of the image (not used by the preprocessing below)
m, s = np.mean(input_image, axis=(0, 1)), np.std(input_image, axis=(0, 1))
preprocess = transforms.Compose([
    transforms.ToTensor(),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)

if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model = model.to('cuda')

with torch.no_grad():
    output = model(input_batch)

import matplotlib.pyplot as plt

fig, axs = plt.subplots(1,2)
axs[0].imshow(input_image)
axs[0].axis("off")
axs[0].set_title("input image")
axs[1].imshow(output[0, 0].cpu())  # move to the CPU for plotting (required when running on a GPU)
axs[1].axis("off")
axs[1].set_title("segmentation")

## Explore Model Architecture through Hooks

In PyTorch, hook callback functions are mechanisms used to intercept and observe the computation flow during the forward and backward passes of a neural network. They allow you to access and manipulate intermediate outputs, gradients, and parameters of the network at various stages of its execution.

Types of Hooks:

* Forward Hooks: These hooks are executed during the forward pass, right after a module has computed its output. They give you access to the inputs and outputs of that module, i.e., the intermediate activations, before the output is passed to the next layer.
* Backward Hooks: These hooks are executed during the backward pass, giving you access to the gradients of the loss with respect to a module's inputs and outputs. They are useful for tasks like gradient visualization or debugging.

Implementation:

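To implement a hook, you define a callback function that specifies what you want to do with the intermediate data or gradients.
You then attach this callback function to the desired module using `register_forward_hook()` or `register_full_backward_hook()` (the older `register_backward_hook()` is deprecated); tensors and parameters use `register_hook()` instead. Each call returns a handle whose `remove()` method detaches the hook again.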
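
A minimal sketch of both hook types on a tiny stand-in network. The network `net` and the callback names are hypothetical and not part of the lab's U-Net, and the sketch assumes a PyTorch version that provides `register_full_backward_hook`.

```python
import torch
import torch.nn as nn

# Tiny stand-in network, used only to illustrate the hook mechanics.
net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 1, kernel_size=1),
)
last = net[2]
shapes = {}

def save_output_shape(module, inputs, output):
    # Forward hook: called after `module` has computed `output`.
    shapes["forward"] = tuple(output.shape)

def save_grad_shape(module, grad_input, grad_output):
    # Backward hook: `grad_output` holds the gradients w.r.t. the module outputs.
    shapes["backward"] = tuple(grad_output[0].shape)

handles = [
    last.register_forward_hook(save_output_shape),
    last.register_full_backward_hook(save_grad_shape),
]
try:
    out = net(torch.randn(2, 3, 16, 16))
    out.mean().backward()  # triggers the backward hook
finally:
    for handle in handles:
        handle.remove()    # always detach the hooks again

print(shapes)
```

Storing results in a dictionary instead of printing them keeps the hook output available after the forward and backward passes have finished.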

In [52]:
import torch

def forward_hook(module, inputs, output):
    # forward hooks receive the positional inputs of the module as a tuple
    x = inputs[0] if isinstance(inputs, tuple) else inputs
    print(f"Forward pass - module {module} - Input shape: {x.shape}, Output shape: {output.shape}")

# Register the hook first; try/finally then makes sure that the hook is removed
# even if the forward pass or the hook function raises an error.
hook_handle = model.conv.register_forward_hook(forward_hook)
try:
    with torch.no_grad():
        output = model(input_batch)
finally:  # this code gets executed even if the code above raises an exception
    hook_handle.remove()

In [58]:
import matplotlib.pyplot as plt

def hook(module, inputs, outputs):
    print(inputs[0].shape)
    print(outputs.shape)
    print(module)
    
    fig, axs = plt.subplots(4,8, figsize=(3*8,3*4))
    #TODO: plot the 32 feature maps before the last 1x1 convolution
    #SOLUTIONSTART
    for i, (inp, ax) in enumerate(zip(inputs[0][0], axs.reshape(-1))):
        ax.imshow(inp.cpu())  # move to the CPU for plotting
        ax.set_title(f"dim {i+1}")
        ax.axis("off")
    plt.tight_layout()
    #SOLUTIONEND
    
    fig, ax = plt.subplots()
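    ax.imshow(outputs[0, 0].cpu())  # move to the CPU for plotting
    ax.set_title("outputs")

# Register outside the try block so that `hook_handle` always exists in `finally`.
hook_handle = model.conv.register_forward_hook(hook)
try:
    with torch.no_grad():
        output = model(input_batch)
finally:
    hook_handle.remove()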