In [None]:
import numpy as np 
import matplotlib.pyplot as plt
import torch as th 
import torch.nn as nn
import torchvision.transforms as transforms
import pickle
import gzip
from tabulate import tabulate


**Image segmentation** is an important task in computer vision. The goal is to find multiple areas in an image and to assign a label to each of these areas. It provides a different kind of information than:
- **image classification**, which characterizes a whole image with global labels;
- **object detection**, which usually relies on finding bounding boxes around detected objects.

Segmentation is used in real-world applications such as medical imaging, clothes segmentation, flood mapping, self-driving cars, etc. There are two types of image segmentation:
- Semantic segmentation: classify each pixel with a label.
- Instance segmentation: classify each pixel and differentiate each object instance.

U-Net is a semantic segmentation architecture [originally proposed for medical image segmentation](https://arxiv.org/abs/1505.04597). It is one of the earlier deep learning segmentation models, and the architecture is still widely used as a component of more advanced models such as generative adversarial networks or diffusion models.

The model architecture is fairly simple: an encoder (for downsampling) and a decoder (for upsampling) with skip connections. U-Net is built only from convolutions. More specifically, the output classification is done at the pixel level with a *(1,1)* convolution. It therefore has the following advantages:
- parameter and data efficiency,
- independence from the input size.
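
To make the second point concrete, here is a minimal sketch with a hypothetical toy network (`toy` is not U-Net, just one padded 3×3 convolution followed by a *(1,1)* convolution). Because no layer depends on the spatial size, the same parameters can process images of different sizes:

```python
import torch as th
import torch.nn as nn

# Hypothetical toy model (an assumption for illustration, not U-Net itself)
toy = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),  # padding=1 keeps the spatial size
    nn.ReLU(),
    nn.Conv2d(4, 1, kernel_size=1),             # (1,1) convolution: one score per pixel
)

shapes = {}
for size in (64, 128):
    x = th.zeros(1, 1, size, size)
    shapes[size] = tuple(toy(x).shape)
    print(size, shapes[size])

# The number of parameters does not depend on the image size
n_params = sum(p.numel() for p in toy.parameters())
print("parameters:", n_params)
```

The output always has the same spatial size as the input, and the parameter count only depends on the kernel sizes and the number of channels.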

The following image is taken from the original paper:

<img src="https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png"  width="60%" height="30%">


The goal of this lab session is to implement the U-Net architecture for semantic image segmentation. We will consider a binary segmentation task.

#  Cell nuclei segmentation: the dataset

Cell nuclei segmentation is an essential step in the biological analysis of microscopy images. 
This segmentation can be done manually with dedicated software, but it is very costly.
In this lab session, the starting point is this [paper published in *Scientific Data* (Nature Portfolio)](https://www.nature.com/articles/s41597-020-00608-w). To quote part of the paper:

> Fully-automated nuclear image segmentation is the prerequisite to ensure statistically significant, quantitative analyses of tissue preparations, applied in digital pathology or quantitative microscopy. The design of segmentation methods that work independently of the tissue type or preparation is complex, due to variations in nuclear morphology, staining intensity, cell density and nuclei aggregations. Machine learning-based segmentation methods can overcome these challenges, however high quality expert-annotated images are required for training. Currently, the limited number of annotated fluorescence image datasets publicly available do not cover a broad range of tissues and preparations. We present a comprehensive, annotated dataset including tightly aggregated nuclei of multiple tissues for the training of machine learning-based nuclear segmentation algorithms. The proposed dataset covers sample preparation methods frequently used in quantitative immunofluorescence microscopy.

To save some preprocessing time, this lab session starts with this pickle (download it and make it available to your notebook).

In [None]:
fn = "nuclei_cells_segmentations.pck"
# Use a context manager so that the file is closed once loaded
with open(fn, 'rb') as f:
    X, Y = pickle.load(f)
print(X.shape, Y.shape)
N = X.shape[0]  # number of images

In [None]:
print("Original image / Binary segmentation")
for i in (3,14):
    figs, axs = plt.subplots(1,2)
    axs[0].imshow(X[i].squeeze())
    axs[1].imshow(Y[i].squeeze())

This pickle contains a modified version of the dataset:
- the same number of images;
- all the images are resized to 128×128;
- the segmentation task is converted into a binary pixel classification: nucleus or not.

The goal is now to train a U-Net on this dataset (70 images for training and 9 for "test"). 
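
One possible way to make this split is sketched below. The helper name `split_first_n` is hypothetical, and the stand-in arrays only mimic the sizes given above (79 images of 128×128); in the notebook you would pass the `X` and `Y` loaded from the pickle, whose exact layout may differ.

```python
import numpy as np

def split_first_n(X, Y, n_train=70):
    """Split along the first axis: first n_train samples for training, the rest for test."""
    return (X[:n_train], Y[:n_train]), (X[n_train:], Y[n_train:])

# Stand-in arrays (assumption): same number of images and size as described above
X_demo = np.zeros((79, 128, 128), dtype=np.float32)
Y_demo = np.zeros((79, 128, 128), dtype=np.float32)

(train_X, train_Y), (test_X, test_Y) = split_first_n(X_demo, Y_demo)
print(train_X.shape, test_X.shape)
```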

# U-Net overview

Following the previous picture of U-Net, the network is composed of 3 parts: encoder, bottleneck, and decoder. All three rely on the same convolutional block (convolution, ReLU, convolution, ReLU).

The first step is the **encoder**. The goal is to compress the  "geometrical" information with local features (output channels). The encoder first applies a convolution of kernel size (3,3) to extract $F=64$ features. Then the spatial information is compressed using max-pooling (factor 2). The next step does the same:  extract $2\times F=128$ features from the $F=64$, then compression with max-pooling. This operation is repeated 4 times in total to get at the end $F\times 8 = 512$ channels that represent global features extracted from the input image. 
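
The channel and spatial arithmetic of the encoder can be checked with a few lines of plain Python. This sketch assumes a 128×128 input (the size used in this lab) and convolutions that preserve the spatial size; the original paper uses unpadded convolutions, so its exact sizes differ.

```python
F, size = 64, 128            # base number of features, input height/width (assumed)
progression = []
for level in range(4):
    channels = F * 2**level  # features produced by the convolutional block
    size = size // 2         # spatial size after max-pooling (factor 2)
    progression.append((channels, size))

print(progression)  # [(64, 64), (128, 32), (256, 16), (512, 8)]
```

After four steps, the image is described by 512 channels on an 8×8 grid.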

The **bottleneck** layer is a convolutional layer which doubles the number of channels. The idea is to create a "dense" representation of the image to gather both global and local features. 

The **decoder** part is similar to the encoder part but reversed. While max-pooling is used for downsampling in the encoder, upsampling in the decoder relies on **transposed convolution**. The goal is to increase (i.e. upsample) the spatial dimensions of the intermediate feature maps.

The last peculiarity is the **output layer for classification** at the pixel level. In U-Net this last layer is (once again) a convolutional layer. This means that with the last hidden layer, we recover the same spatial dimension as the input with $F$ feature maps. The classification is carried out for each pixel independently, but the decision is based on $F$ features that encode global information. 

Before creating a U-Net model, we first study a new kind of layer: `ConvTranspose2d`.

# ConvTranspose2d

In PyTorch, transposed convolution is provided by the module `ConvTranspose2d` (for images or other 2D inputs).
To better understand how it works, it is useful to play with it.


Since we will print many matrices to better understand this new operation, we first provide a helper function to better visualize the content of a torch tensor: 

In [None]:
def ppmatrix(m, message=None):
    """Pretty print for matrices
    Args: 
    - expect a torch Tensor
    
    Output: 
    The print 
    
    Apply detach, squeeze, and numpy to the input tensor (not inplace) 
    """
    if message is not None: 
        print(message)
    if len(m.shape) == 1: 
        print(m.squeeze().detach().numpy())
    else: 
        print(tabulate(m.squeeze().detach().numpy(),tablefmt="fancy_grid",floatfmt=".3f"))
    

Now we can use it to look at the parameters of `ConvTranspose2d`:

In [None]:
c = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=2, stride=1, bias=True)
ppmatrix(c.weight,"weights: ")
print("the bias: ")
ppmatrix(c.bias)

As you can see, the operation is parametrized by a convolution mask $\mathbf{W}$ and one bias term $b$. Each value $v$ in the input contributes a patch $v\times\mathbf{W}$ to the output; where patches overlap they are summed, and the bias $b$ is added to every output value. As an illustration, we can consider a simple image with one channel. It is easier to start without the bias term:


In [None]:
im = th.zeros(1,1,2,2)
c = nn.ConvTranspose2d(in_channels=1, out_channels=1, 
                       kernel_size=2, stride=1, 
                       bias=False)
print("Weights: ",c.weight)
print("Bias: ",c.bias)

for i in range(4): 
    im = th.zeros(1,1,2,2)
    im[0,0,i%2,i//2] = 1
    print("-------------------")
    ppmatrix(im,"image: ")
    ppmatrix(c(im),"output: ")


Now try with an image full of ones and explain the result:

In [None]:
im = th.ones(1,1,2,2)
ppmatrix(c(im))


Now we can consider a different stride:

In [None]:
c = nn.ConvTranspose2d(in_channels=1, out_channels=1, 
                       kernel_size=2, stride=2,  
                       bias=False)
print(c.weight)
for i in range(4): 
    print("----------------")
    im = th.zeros(1,1,2,2)
    im[0,0,i%2,i//2] = 1
    ppmatrix(im,"image:")
    ppmatrix(c(im),"output:")

Try to understand the previous and the next examples, and how this operation can be used to upsample.

In [None]:
im = th.ones(1,1,2,2)
ppmatrix(c(im))

More details can be found on this [blog post](https://towardsdatascience.com/understand-transposed-convolutions-and-build-your-own-transposed-convolution-layer-from-scratch-4f5d97b2967). 
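
The behaviour observed in the cells above can also be reproduced by hand. The sketch below is a hypothetical helper (`manual_conv_transpose`, written for one channel and no bias) that places one patch $v\times\mathbf{W}$ per input value, `stride` pixels apart, and sums the overlaps. It is compared with `nn.ConvTranspose2d` for both strides.

```python
import torch as th
import torch.nn as nn

def manual_conv_transpose(image, W, stride):
    """Sum of patches v * W, one per input value, placed `stride` pixels apart."""
    h, w = image.shape
    k = W.shape[0]
    # Output size along each dimension: (input - 1) * stride + kernel
    out = th.zeros((h - 1) * stride + k, (w - 1) * stride + k)
    for i in range(h):
        for j in range(w):
            out[i * stride:i * stride + k, j * stride:j * stride + k] += image[i, j] * W
    return out

th.manual_seed(0)
image = th.rand(2, 2)
results = {}
for stride in (1, 2):
    layer = nn.ConvTranspose2d(1, 1, kernel_size=2, stride=stride, bias=False)
    W = layer.weight.detach()[0, 0]                 # weight shape is (1, 1, 2, 2)
    expected = layer(image[None, None]).detach()[0, 0]
    got = manual_conv_transpose(image, W, stride)
    results[stride] = (tuple(got.shape), th.allclose(got, expected, atol=1e-6))
    print("stride", stride, results[stride])
```

With a kernel of size 2 and a stride of 2, the patches do not overlap and each spatial dimension is exactly doubled, which is why this configuration is convenient for upsampling.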

# Simple U-Net 
Now the goal is to implement a simplified version of U-Net (a short version of the "U"). 
The first step is the **encoder** that compresses (or reduces) the spatial dimension to create rich features that represent global information. 

## Step by step
The encoder is composed of successive blocks. One block is made of:
- twice the sequence: convolution (kernel size 3, stride 1, and $F$ output channels), ReLU, batchnorm;
- followed by a max-pooling that halves each spatial dimension.

The input image has one channel and the convolutions generate $F$ output channels. $F$ is a hyper-parameter; start with $F=4$.
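
As a starting point, one possible sketch of such a block is given below. The class name `EncoderBlock` and the choice `padding=1` (so that the convolutions keep the spatial size) are assumptions, not requirements of the lab. The block also returns the features before pooling, since the decoder needs them later.

```python
import torch as th
import torch.nn as nn

class EncoderBlock(nn.Module):
    """Two (convolution, ReLU, batchnorm) sequences followed by max-pooling."""
    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(c_out),
            nn.Conv2d(c_out, c_out, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(c_out),
        )
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        features = self.conv(x)              # same spatial size as the input
        return self.pool(features), features # pooled output, features kept for the decoder

block = EncoderBlock(1, 4)                   # one input channel, F=4
x = th.zeros(2, 1, 128, 128)                 # dummy batch of 2 images
down, skip = block(x)
print(down.shape, skip.shape)
```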

**TODO**: write the corresponding module and test if it works properly. Check the output dimensions. 

The **bottleneck** layer is a convolutional layer which doubles the number of channels. The idea is to create a "dense" representation of the image to gather both global and local features.

**TODO**: write the corresponding module and test if it works properly. Check the output dimensions. 

The **decoder** part is similar to the encoder part but reversed. While max-pooling is used for downsampling in the encoder, upsampling relies on **transposed convolution**. The goal is to increase (i.e. upsample) the spatial dimensions of the intermediate feature maps while halving their number of channels. The important point is the skip connection: the upsampled feature maps are combined with the encoder feature maps of the same spatial size.
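
A minimal sketch of one decoder step is shown below, under several assumptions: the class name `DecoderBlock` is hypothetical, the encoder features are merged by concatenation along the channel dimension (as in the U-Net figure), and the encoder features have half as many channels as the decoder input.

```python
import torch as th
import torch.nn as nn

class DecoderBlock(nn.Module):
    """Upsample by 2 with a transposed convolution, merge with encoder features, then convolve."""
    def __init__(self, c_in):
        super().__init__()
        c_out = c_in // 2
        # kernel_size=2 with stride=2 doubles each spatial dimension
        self.up = nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            nn.Conv2d(2 * c_out, c_out, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(c_out),
            nn.Conv2d(c_out, c_out, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(c_out),
        )

    def forward(self, x, skip):
        x = self.up(x)
        x = th.cat([x, skip], dim=1)   # concatenate along the channel dimension
        return self.conv(x)

dec = DecoderBlock(8)
x = th.zeros(2, 8, 64, 64)             # dummy bottleneck output (2F channels, F=4)
skip = th.zeros(2, 4, 128, 128)        # dummy encoder features (F channels)
out = dec(x, skip)
print(out.shape)
```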

**TODO**: write the corresponding module and test if it works properly. Check the output dimensions. 

The last peculiarity is the output layer for classification at the pixel level. In U-Net this last layer is (once again) a convolutional layer. This means that with the last hidden layer, we recover the same spatial dimension as the input with $F$ feature maps. The classification is carried out for each pixel independently, but the decision is based on $F$ features that encode global information. 
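
To see what "for each pixel independently" means, the following sketch checks that a convolution with kernel size 1 applies the same linear classifier to the $F$ features of every pixel. The tensor sizes and the pixel position are arbitrary choices for illustration.

```python
import torch as th
import torch.nn as nn

F = 4
head = nn.Conv2d(F, 1, kernel_size=1)   # output layer: one score per pixel
features = th.rand(1, F, 8, 8)          # dummy last hidden layer
scores = head(features)
print(scores.shape)

# Same computation by hand for one pixel: dot product with the F features, plus bias
w = head.weight.view(F)                 # weight shape is (1, F, 1, 1)
b = head.bias
i, j = 3, 5
manual = (w * features[0, :, i, j]).sum() + b
same = th.allclose(manual, scores[0, 0, i, j], atol=1e-6)
print(same)
```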

**TODO**: write the corresponding module and test if it works properly. Check the output dimensions. 

## A class for  simple U-Net

Now we can merge all we did in the previous section to create a U-Net model (light version). 

**TODO**: 
- Write the `Module` that takes $F$ as hyper-parameter
- Train it on the first 70 training images with $F=4$


# Evaluation 

Evaluation is important to really assess what we achieved, and the right way to evaluate depends on the task and on our purpose.

Of course we can compute the accuracy (the percentage of correctly classified pixels), but this may not be enough. Here we can also visualize the predictions.

**TODO**: 
- make a function that plots the evaluation result of an input image
- try it on all the test images
- looking at the correctly classified pixels is a start, but what if you want to see which pixels belonging to a nucleus are missed?
- and which pixels are wrongly assigned to the nucleus class?

## Precision and recall

For segmentation, it can be meaningful to look at precision and recall. These two terms can have different meanings depending on the purpose. Here we use the following definitions:
- The precision for class 1 (a nucleus) is the ratio between the number of true positives for class 1 and the total number of pixels classified as nuclei by the model.
- The recall is the ratio between the number of pixels of class 1 correctly classified and the total number of pixels that should be classified as nuclei.

These measures depend on a threshold of the output score. While the "natural" threshold is $0$ on the output score (or $0.5$ if the model outputs probabilities), we can consider different trade-offs between precision and recall by varying the threshold.
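
With TP, FP and FN the numbers of true positives, false positives and false negatives, these definitions read precision = TP / (TP + FP) and recall = TP / (TP + FN). A possible sketch is given below; the function name `precision_recall` is hypothetical, the model is assumed to output raw scores (logits), and returning 1.0 when a denominator is zero is an arbitrary convention.

```python
import torch as th

def precision_recall(scores, target, threshold=0.0):
    """Precision and recall of class 1 for raw output scores and a binary target."""
    pred = scores > threshold
    truth = target > 0.5
    tp = (pred & truth).sum().item()    # nucleus pixels correctly detected
    fp = (pred & ~truth).sum().item()   # background pixels classified as nucleus
    fn = (~pred & truth).sum().item()   # nucleus pixels that are missed
    precision = tp / (tp + fp) if tp + fp > 0 else 1.0  # convention (assumption)
    recall = tp / (tp + fn) if tp + fn > 0 else 1.0     # convention (assumption)
    return precision, recall

# Tiny hand-made example: 4 pixels, 2 of them belong to a nucleus
scores = th.tensor([[2.0, -1.0], [0.5, -3.0]])
target = th.tensor([[1.0, 1.0], [0.0, 0.0]])
p0, r0 = precision_recall(scores, target, threshold=0.0)
p1, r1 = precision_recall(scores, target, threshold=-2.0)
print(p0, r0)
print(p1, r1)
```

Lowering the threshold from $0$ to $-2$ recovers the missed nucleus pixel, so the recall increases from 0.5 to 1.0 on this toy example.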

**TODO:**
- Make a function which computes precision and recall for a given threshold
- Plot the precision *vs* recall curve for a threshold varying between -5 and +5
- Compare models with $F=4,16,32$


# U-Net

Now the goal is to implement the full U-Net. We propose the following roadmap:
- a function to create a convolutional block
- a module for the encoder
- a module for the decoder
- and a U-Net module to wrap everything

The number of feature maps ($F=64$ in the original work) must be a parameter of the U-Net. For the first round of experiments, we can use $F=8$.


**TODO:**
- run the training on the first 70 images and keep the last 9 for evaluation
- after training, look at the results on some training images and on the evaluation ones.


# Data Augmentation 
When data is scarce, we can try data augmentation. The idea is to introduce diversity in the dataset by applying basic transformations to the images. [Look at this page for more information](https://pytorch.org/vision/main/transforms.html).
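
For segmentation, a geometric transformation must be applied identically to the image and to its mask, otherwise the labels no longer match the pixels. The sketch below illustrates this with plain `torch` operations rather than `torchvision.transforms`, to keep the pairing explicit. The function name `random_flip_rot` and the choice of transformations (horizontal flip and rotations by multiples of 90 degrees) are assumptions for illustration.

```python
import torch as th

def random_flip_rot(image, mask, generator=None):
    """Apply the same random flip and 90-degree rotation to an image and its mask."""
    if th.rand(1, generator=generator).item() < 0.5:
        image = th.flip(image, dims=(-1,))   # horizontal flip
        mask = th.flip(mask, dims=(-1,))
    k = int(th.randint(0, 4, (1,), generator=generator).item())
    image = th.rot90(image, k, dims=(-2, -1))  # k quarter turns
    mask = th.rot90(mask, k, dims=(-2, -1))
    return image, mask

g = th.Generator().manual_seed(0)
image = th.rand(1, 128, 128, generator=g)    # dummy image with one channel
mask = (image > 0.5).float()                 # dummy mask aligned with the image
aug_image, aug_mask = random_flip_rot(image, mask, generator=g)

# The mask still matches the transformed image pixel by pixel
aligned = th.equal(aug_mask, (aug_image > 0.5).float())
print(aug_image.shape, aligned)
```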

**TODO:**
- Select a couple of transformations
- Evaluate the impact of data-augmentation on our task.


