### PSA:

This notebook is going to have a laid-back tone so as to ensure that whoever goes through this not only understands what is happening, but also knows of my personal thoughts on the paper as I read through it. A good way to think of this is me giving a semi-live commentary on the paper as I read through it while coding it up. I will not be covering every single detail since it would lead to too much distraction, but do expect a few out-of-context comments here and there. 

# U-Nets

U-Nets were among the earliest segmentation architectures, arriving shortly after fully convolutional networks, and they helped pioneer this particular subdomain of image segmentation. Initially created for biomedical image segmentation, this architecture is a staple in computer vision and has found many uses over the years, ranging from autonomous vehicles and (the obvious) biomedical segmentation to diffusion frameworks like Dall-E and Midjourney.

This model makes use of the **encoder-decoder** architecture alongside **skip connections**, leading to an impressive level of clarity in the segmentation maps. I really want to see this in action, which is why I will be doing what I usually do with deep learning networks: breaking it open to see how it functions between layers.

In this notebook I will be looking at how this model works following the paper as I go along with the code. I will be looking at segmenting images, and down the line, will eventually explore the generative properties of this model in another project. Going by the theme of the original paper, I will be experimenting with the [DRIVE](https://drive.grand-challenge.org/DRIVE/) dataset, which is openly accessible (you will need to sign up in the link above to be able to download the dataset).

Unzip the training and test datasets from the link and add them to the `data/` folder (next to this notebook) for this code to work.

Now, let's finally get to it! We load up the packages we will need for the project.

In [None]:
from pathlib import Path
import numpy as np
import gc
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from matplotlib import pyplot as plt 

I will also be using CUDA, when it is available, to speed up the training process. This is easier to set up on Linux.

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

From the paper by [Ronneberger et al. (2015)](https://arxiv.org/pdf/1505.04597) we are aware that historically CNNs have been good at classification tasks. However, in use cases like biomedical imaging we might need to:
- Localize parts of our image
- Take into account the fact that we might not have a huge dataset to work with.

The U-Net builds upon the fully convolutional network to ensure that it can work with a very small dataset and yield precise segmentations.

**Personal Note:** This paper has a lot of exposition describing efforts by researchers to get to this point. I personally really like this approach, and it lays things out in a very easy-to-follow manner. I will, however, be avoiding minutiae from this point forward unless it is relevant to our use case.

## U-Net Architecture

The most common diagram of the U-Net is taken directly from the paper, as shown below:

![image.png](images/image.png)

Here I notice a few main patterns that will help code up this model.
- The first pattern is:
  $$(2 \times (\text{conv} + \text{RELU}) + \text{Maxpool}) \xRightarrow{out(resize)}$$
  - This block is repeated **four** times to bring the image from its original dimensions down to lower-dimensional representations.
  - Each convolution slightly shrinks the spatial dimensions (the paper uses unpadded $3 \times 3$ convolutions, which trim 2 pixels from each dimension), and each Maxpool halves them.
  - The first block maps the input image to 64 channels; every block after that doubles the number of channels.
  - This is the first part of the U or the "encoder".

- Then we have:
  $$2 \times (\text{conv} + \text{RELU})$$
  - The same pair of convolutions is applied to the lowest-resolution representation, shrinking the spatial dimensions a little further while doubling the number of channels.
  - This can be seen as the bottom part of the U or the "latent space".

- The third pattern is as follows:
  $$\xRightarrow{(resize)in}(\text{transpose conv} + (2 \times \text{conv}))$$
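  - The transpose convolution doubles the spatial dimensions of its input while halving the number of channels.
  - After the matching encoder output is concatenated through the skip connection (which doubles the channels again), the following convolutions halve the number of channels while also slightly decreasing the spatial dimensions.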
  - This block is repeated **four** times.
  - This represents the last part of the U or the "decoder".

- The final output block is a $1 \times 1$ convolution that maps each pixel's feature vector to class scores, giving us a segmentation map.
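
The encoder, bottleneck, and decoder patterns above can be checked with a bit of arithmetic. The sketch below traces the side length of the feature map, assuming the paper's setup: unpadded $3 \times 3$ convolutions, $2 \times 2$ max pooling, and up-convolutions that double the size. The function name `unet_spatial_trace` is made up for this illustration, and the numbers will differ if the implementation pads its convolutions.

```python
def unet_spatial_trace(size=572, depth=4, convs_per_block=2, kernel=3):
    """Trace the feature-map side length through a paper-style U-Net.

    Assumes unpadded convolutions, 2x2 max pooling and size-doubling
    up-convolutions, as in Ronneberger et al. (2015).
    """
    shrink = kernel - 1  # an unpadded k x k convolution trims k - 1 pixels
    trace = [("input", size)]

    # Encoder: 2 x (conv + ReLU), then max pool, repeated `depth` times
    for level in range(1, depth + 1):
        size -= convs_per_block * shrink
        trace.append((f"encoder {level} convs", size))
        if size % 2:
            raise ValueError(f"odd size {size} cannot be pooled cleanly")
        size //= 2
        trace.append((f"encoder {level} pool", size))

    # Bottom of the U: 2 x (conv + ReLU)
    size -= convs_per_block * shrink
    trace.append(("bottleneck convs", size))

    # Decoder: transpose conv (doubles the size), then 2 x conv
    for level in range(depth, 0, -1):
        size *= 2
        trace.append((f"decoder {level} up-conv", size))
        size -= convs_per_block * shrink
        trace.append((f"decoder {level} convs", size))

    return trace


for name, side in unet_spatial_trace():
    print(f"{name:>20}: {side} x {side}")
```

With the default arguments the trace ends at $388 \times 388$, which is why the segmentation map in the paper's diagram is smaller than the $572 \times 572$ input.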

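Throughout the above, skip connections pass the output of each encoder block to the decoder block at the same depth, giving the decoder the fine spatial detail that was lost during downsampling. The paper crops the encoder feature maps so that they match the decoder's; here, resizing is used instead to make sure the inputs comply with the given layer's expected dimensions.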
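
To make the skip connection concrete, here is a minimal sketch of the resize-then-concatenate step. It is an assumption about how the resizing could be done (bilinear interpolation with `torch.nn.functional.interpolate`), not necessarily what the final model in this notebook uses, and `resize_and_concat` is a hypothetical helper name. The tensor shapes follow the deepest skip connection in the paper's diagram.

```python
import torch
import torch.nn.functional as F


def resize_and_concat(encoder_feat, decoder_feat):
    """Resize the encoder features to the decoder's spatial size,
    then stack the two along the channel dimension."""
    resized = F.interpolate(
        encoder_feat,
        size=decoder_feat.shape[-2:],  # target (height, width)
        mode="bilinear",
        align_corners=False,
    )
    return torch.cat([resized, decoder_feat], dim=1)


# Shapes from the deepest skip connection in the paper's diagram:
# the encoder output is 512 x 64 x 64, the up-convolved decoder input is 512 x 56 x 56
encoder_feat = torch.randn(1, 512, 64, 64)
decoder_feat = torch.randn(1, 512, 56, 56)

merged = resize_and_concat(encoder_feat, decoder_feat)
print(merged.shape)  # torch.Size([1, 1024, 56, 56])
```

The concatenation is what doubles the channel count before the decoder's convolutions halve it again.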

## Loading and Studying the Dataset

We will now load and study what we are working with. Since the dataset does not come in the form of an h5 file, loading it will be a bit more hands-on.

Any image we get as input will need to be resized to the input size used in the original paper, $(572, 572)$.

<!-- ModuleList -->

In [None]:
class eyeballDataset(Dataset):
    def __init__(self, image_path: Path, mask_path: Path):
        # Accept plain strings as well as Path objects
        image_path, mask_path = Path(image_path), Path(mask_path)

        self.images = {p.stem: p for p in image_path.iterdir()}

        # [:-5] gets rid of the "_mask" in the key values
        self.masks = {p.stem[:-5]: p for p in mask_path.iterdir()}

        self.ids = list(self.images.keys())

        # Resize the image to the paper's (572, 572) input size and convert it to a tensor
        self.transforms = transforms.Compose([
            transforms.Resize((572, 572)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.ids)

    # Loads up the image and mask according to the ID provided
    def __getitem__(self, idx):
        id_ = self.ids[idx]
        image = Image.open(self.images[id_])
        mask = Image.open(self.masks[id_])
        image = self.transforms(image)
        mask = self.transforms(mask)

        return image, mask
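
The dataset class above pairs each image with its mask by comparing file stems. Here is a small standalone sketch of that matching, useful for spotting images that have no mask before training starts. The file names are assumptions that follow an `<id>.tif` / `<id>_mask.gif` pattern; check them against the files you actually unzipped.

```python
from pathlib import Path

# Hypothetical file names following the "<id>.tif" / "<id>_mask.gif" pattern
image_files = [Path("21_training.tif"), Path("22_training.tif")]
mask_files = [Path("21_training_mask.gif"), Path("22_training_mask.gif")]

# Same keying as in the dataset class: the stem is the file name without its extension
images = {p.stem: p for p in image_files}
masks = {p.stem[:-5]: p for p in mask_files}  # drop the trailing "_mask"

# IDs that have an image but no mask would raise a KeyError in __getitem__
missing = sorted(set(images) - set(masks))

pairs = {id_: (images[id_], masks[id_]) for id_ in images if id_ in masks}

print(missing)
for id_, (image, mask) in pairs.items():
    print(id_, image.name, mask.name)
```

Slicing with `[:-5]` silently mangles any stem that does not end in `_mask`, so a check like `missing` is a cheap way to catch stray files in the mask folder.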

Now, let's load our data and the related masks to see what we are dealing with. This is also a way to ensure that we have implemented the dataset class properly.

In [None]:
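train_dataset = eyeballDataset(Path("data/training/images"), Path("data/training/mask/"))
# The test set needs both an image folder and a mask folder, mirroring the training layout
test_dataset = eyeballDataset(Path("data/test/images"), Path("data/test/mask/"))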