# U-NET Implementation

A U-Net implementation based on the original paper, using PyTorch.

## References:

Original paper at [arxiv](https://arxiv.org/abs/1505.04597).

Great [video explaining the architecture](https://www.youtube.com/watch?v=NhdzGfB1q74).

Implementation [coded along with this video](https://www.youtube.com/watch?v=IHq1t7NxS8k).

## Architecture:

U-Net (named after the U shape of its architecture) uses an encoder-decoder design. The encoder extracts features from the image with convolutional layers (each followed by ReLU) and downsizes it with max pooling; at each step down, the number of channels doubles. The decoder upsamples the feature maps with transposed convolutions, halving the number of channels at each step up.
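
As a rough illustration of the doubling and halving described above, the sketch below traces the `(channels, spatial size)` pairs through the network. The helper `unet_shape_schedule` is hypothetical (it is not part of the implementation), and it assumes size-preserving 3x3 convolutions (`padding=1`, as used in this notebook) and an input size divisible by 16.

```python
def unet_shape_schedule(size, features=(64, 128, 256, 512)):
    """Return (channels, spatial_size) after each stage of a U-Net.

    Assumes size-preserving 3x3 convolutions and an input size divisible
    by 2 ** len(features), so every pooling step halves the size exactly.
    """
    encoder = []
    for channels in features:
        encoder.append((channels, size))  # output of the double convolution
        size //= 2                        # 2x2 max pooling halves the size
    bottleneck = (2 * features[-1], size)
    decoder = []
    for channels in reversed(features):
        size *= 2                         # transposed convolution doubles the size
        decoder.append((channels, size))  # channels halve at every step up
    return encoder, bottleneck, decoder


encoder, bottleneck, decoder = unet_shape_schedule(160)
print(encoder)     # [(64, 160), (128, 80), (256, 40), (512, 20)]
print(bottleneck)  # (1024, 10)
print(decoder)     # [(512, 20), (256, 40), (128, 80), (64, 160)]
```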

Each skip connection (the gray arrows in the figure below) concatenates the encoder feature map with the decoder feature map at the same resolution. This step is important for tasks such as image segmentation, where the net has to learn to delimit objects in an image pixel by pixel, given their masks.
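
The concatenation also explains the channel counts used in the decoder. The sketch below is only bookkeeping, with a hypothetical helper `decoder_channel_plan`: at each step the transposed convolution halves the channels, and concatenating the skip connection along the channel axis doubles them again before the double convolution reduces them.

```python
def decoder_channel_plan(features=(64, 128, 256, 512)):
    """List (up_in, up_out, concat, out) channel counts per decoder step."""
    plan = []
    for feature in reversed(features):
        up_in = 2 * feature        # channels entering the transposed convolution
        up_out = feature           # the transposed convolution halves them
        concat = feature + up_out  # skip connection channels + upsampled channels
        plan.append((up_in, up_out, concat, feature))
    return plan


plan = decoder_channel_plan()
for up_in, up_out, concat, out in plan:
    print(f"up: {up_in} -> {up_out}, concat: {concat}, double conv: {concat} -> {out}")
```

Each step's output feeds the next step's transposed convolution, so the plan chains without gaps: 1024 to 512, 512 to 256, and so on down to 64.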


![Architecture for U-NET](architecture.png)

In [1]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

### Observations:

- The original paper did not include Batch Normalization in its architecture, but it is used here because it stabilizes training, allows higher learning rates, and speeds up convergence.
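
Batch Normalization is also why the convolutions in `DoubleConv` use `bias=False`. A minimal pure-Python sketch of the idea (not PyTorch's implementation; `batch_norm` here is a hypothetical stand-in that omits the learnable scale and shift): normalization subtracts the batch mean, so a constant bias added beforehand drops out.

```python
from statistics import fmean, pvariance


def batch_norm(values, eps=1e-5):
    """Normalize to zero mean and unit variance (no learnable scale/shift)."""
    mean = fmean(values)
    var = pvariance(values, mu=mean)  # biased (population) variance
    return [(v - mean) / (var + eps) ** 0.5 for v in values]


activations = [0.5, -1.2, 3.0, 0.7]
bias = 2.5  # arbitrary constant, standing in for a convolution bias

without_bias = batch_norm(activations)
with_bias = batch_norm([a + bias for a in activations])

# The two results agree up to floating-point rounding.
max_diff = max(abs(a - b) for a, b in zip(without_bias, with_bias))
print(max_diff < 1e-9)
```

In the real layer, Batch Normalization's own learnable shift takes over the role of the bias.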

In [2]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 
                      out_channels, 
                      kernel_size=3, 
                      stride=1,
                      padding=1,
                      bias=False), # bias is cancelled by BatchNorm
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, 
                      out_channels, 
                      kernel_size=3, 
                      stride=1,
                      padding=1,
                      bias=False), # bias is cancelled by BatchNorm
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        return self.conv(x)

class UNET(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super(UNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Going downwards
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Going Upwards
        for feature in reversed(features):
            # Upsampling
            self.ups.append(nn.ConvTranspose2d(feature*2, 
                                               feature, 
                                               kernel_size=2,
                                               stride=2))
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], 2*features[-1]) # bottommost double convolution
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1) # conv 1x1
    
    def forward(self, x):
        skip_connections = []

        # Encoding
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        # At this point, skip_connections has copies of the encoding steps
        # The Decoding part needs to concatenate these copies onto the
        # upsampling steps

        # Decoding
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

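            # Without this check, the net only accepts input dimensions
            # that are multiples of 16 (4 max pools -> 2^4).
            # To solve this, resize the upsampled tensor to the spatial
            # size of its skip connection before concatenating.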
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:]) # [2:] skips batch size & number of channels

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)
        
        return self.final_conv(x)
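
To see when the resize in the decoding loop is actually triggered, the sketch below traces only the spatial size through four pooling and four upsampling steps. The helper `trace_size_mismatches` is hypothetical; it assumes that 2x2 max pooling floors odd sizes and that the transposed convolution exactly doubles the size, which matches the layer settings used above.

```python
def trace_size_mismatches(size, depth=4):
    """Return (upsampled_size, skip_size) pairs where a resize is needed."""
    skips = []
    for _ in range(depth):
        skips.append(size)  # size stored with the skip connection
        size //= 2          # MaxPool2d(kernel_size=2, stride=2) floors odd sizes
    mismatches = []
    for skip in reversed(skips):
        size *= 2           # ConvTranspose2d(kernel_size=2, stride=2) doubles the size
        if size != skip:
            mismatches.append((size, skip))
            size = skip     # stands in for resizing x to the skip's size
    return mismatches


print(trace_size_mismatches(160))  # [] (160 is a multiple of 16)
print(trace_size_mismatches(161))  # [(160, 161)]
print(trace_size_mismatches(100))  # [(24, 25)]
```

An input of size 100 pools down to 50, 25, 12, and 6; upsampling 12 gives 24, one pixel short of the 25-pixel skip connection, so that is the step where the resize applies.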