This code is based on [this video on YouTube](https://www.youtube.com/watch?v=u1loyDCoGbE). However, I believe the concatenation step in the video contains a mistake, so that part of my implementation differs.

# Architecture
According to the [original paper](https://arxiv.org/pdf/1505.04597.pdf), U-Net has the following architecture:
![U-net Architecture](Images/unet-architecture.png)  
We can clearly see why it is called U-Net. I labeled some activation maps based on the code below.  
As the architecture above shows, each stage applies two consecutive convolutions. Hence, we define a function named `double_conv` in this code that does exactly that.  
**One important thing** to consider is that when an **up-conv** happens (green arrows), the sizes do not match. Take the first one as an example: x9 is 1024x28x28 before the up-conv and 512x56x56 after it. We then want to concatenate the result with x7, which is 512x64x64, and obviously 56 ≠ 64. The paper therefore crops x7 to 512x56x56. To do that, a function named `crop_image` is implemented.
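
The size mismatch described above can be traced with plain integer arithmetic. The following is a minimal sketch, assuming the configuration used in this notebook (unpadded 3x3 convolutions, 2x2 max pooling, 2x2 up-convs with stride 2). The helper names `double_conv_size` and `trace_unet_sizes` are hypothetical and not part of the notebook.

```python
def double_conv_size(size):
    # Each unpadded 3x3 convolution trims one pixel per border, so two of them remove 4.
    return size - 4


def trace_unet_sizes(input_size=572, depth=4):
    """Return (skip_sizes, bottleneck_size, crop_pairs, output_size)."""
    skips = []
    size = input_size
    for _ in range(depth):
        size = double_conv_size(size)  # x1, x3, x5, x7
        skips.append(size)
        size //= 2                     # 2x2 max pool with stride 2
    size = double_conv_size(size)      # x9
    bottleneck = size

    crop_pairs = []
    for skip in reversed(skips):
        size *= 2                      # 2x2 up-conv with stride 2
        crop_pairs.append((skip, size))  # the skip map must be cropped to `size`
        size = double_conv_size(size)
    return skips, bottleneck, crop_pairs, size


skips, bottleneck, crop_pairs, out = trace_unet_sizes()
print(skips)       # [568, 280, 136, 64]
print(bottleneck)  # 28
print(crop_pairs)  # [(64, 56), (136, 104), (280, 200), (568, 392)]
print(out)         # 388
```

Each pair in `crop_pairs` lists the size of a skip-connection map next to the size it has to be cropped to before concatenation.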

## Imports

In [1]:
import torch
import torch.nn as nn

## Functions described in the architecture section

In [2]:
def double_conv(in_c, out_c):
    conv = nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=3),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_c, out_c, kernel_size=3),
        nn.ReLU(inplace=True),
    )
    
    return conv

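def crop_image(tensor, target_tensor):
    # Centre-crop `tensor` to the spatial size of `target_tensor`.
    target_size = target_tensor.size()[2]  # since height and width are the same we just get one of them
    tensor_size = tensor.size()[2]
    
    delta = tensor_size - target_size
    delta = delta // 2
    # Slice by start + target_size so the result matches the target even when the difference is odd.
    return tensor[:, :, delta:delta+target_size, delta:delta+target_size]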
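
To see which rows and columns a centre crop keeps, the index arithmetic can be sketched in plain Python. This is only an illustration: `center_crop_bounds` is a hypothetical helper, not part of the notebook, and it assumes square feature maps as above.

```python
def center_crop_bounds(size, target_size):
    """Return (start, stop) indices of a centred crop of length target_size."""
    if target_size > size:
        raise ValueError("target_size must not exceed size")
    start = (size - target_size) // 2
    # Deriving stop from start keeps the crop exactly target_size long.
    return start, start + target_size


print(center_crop_bounds(64, 56))    # (4, 60)
print(center_crop_bounds(568, 392))  # (88, 480)
print(center_crop_bounds(65, 56))    # (4, 60)
```

In this network every size difference is even, so the crop is perfectly centred. The last call shows that an odd difference still yields a crop of the requested length, with the extra pixel dropped from the end.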


## Designing the U-Net 

In [3]:
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
    
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.down_conv_1 = double_conv(in_c=1, out_c=64)
        self.down_conv_2 = double_conv(in_c=64, out_c=128)
        self.down_conv_3 = double_conv(in_c=128, out_c=256)
        self.down_conv_4 = double_conv(in_c=256, out_c=512)
        self.down_conv_5 = double_conv(in_c=512, out_c=1024)
        
        self.up_trans_1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)
        self.up_conv_1 = double_conv(1024, 512)
        
        self.up_trans_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
        self.up_conv_2 = double_conv(512, 256)
        
        self.up_trans_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.up_conv_3 = double_conv(256, 128)
        
        self.up_trans_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
        self.up_conv_4 = double_conv(128, 64)
        
        self.out = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1)
    
    def forward(self, image):
        # encoder
        x1 = self.down_conv_1(image)
        x2 = self.max_pool_2x2(x1)
        x3 = self.down_conv_2(x2)
        x4 = self.max_pool_2x2(x3)
        x5 = self.down_conv_3(x4)
        x6 = self.max_pool_2x2(x5)
        x7 = self.down_conv_4(x6)
        x8 = self.max_pool_2x2(x7)
        x9 = self.down_conv_5(x8)
        
        # decoder
        x = self.up_trans_1(x9)
        y = crop_image(x7, x)
        x = self.up_conv_1(torch.cat([y, x], 1))
        
        x = self.up_trans_2(x)
        y = crop_image(x5, x)
        x = self.up_conv_2(torch.cat([y, x], 1))
        
        x = self.up_trans_3(x)
        y = crop_image(x3, x)
        x = self.up_conv_3(torch.cat([y, x], 1))
        
        x = self.up_trans_4(x)
        y = crop_image(x1, x)
        x = self.up_conv_4(torch.cat([y, x], 1))
        
        x = self.out(x)
        return x

## Testing the model

In [4]:
image = torch.rand((1, 1, 572, 572))  # batch_size=1, channel=1, width & height = 572
model = UNet()
result = model(image)
result.shape

torch.Size([1, 2, 388, 388])

As can be seen, the output size matches the one reported in the paper (388x388). The output has two channels for image segmentation, one for the foreground and the other for the background.
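
The two output channels are raw scores, so turning them into a segmentation mask needs one more step. The following is a minimal sketch, assuming a pixel-wise softmax over the channel dimension followed by an argmax. A random tensor stands in for the model output, and which channel index means foreground is an assumption that depends on how the training labels are encoded.

```python
import torch

# Stand-in for the model output above: random scores with the same shape.
logits = torch.rand((1, 2, 388, 388))

# Per-pixel class probabilities across the two channels.
probs = torch.softmax(logits, dim=1)

# Hard mask: index (0 or 1) of the larger channel at each pixel.
mask = probs.argmax(dim=1)

print(probs.shape)  # torch.Size([1, 2, 388, 388])
print(mask.shape)   # torch.Size([1, 388, 388])
```

Because softmax does not change the ordering of the scores, `logits.argmax(dim=1)` gives the same mask. The softmax step is only needed when the probabilities themselves are of interest.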