# 2 parts: 
- Encoder
- Decoder
## Encoder
- Encoder is a normal feature extractor.  
- Encoder module can be broken down into subparts, we can call it as EncoderPart.  

- ### EncoderPart
    - Each EncoderPart is made up of 2 3x3 convolutions with RELU.  
- ### Maxpool Layer
    - After every EncoderPart we do 2x2 maxpool

- ### UpConvolution Layer
    - Before every DecoderPart 2x2 Upconvolution is performed
- ### DecoderPart
    - Each DecoderPart is made from concatinating from a previous layer having same number of channels and then 2 3x3 convolution followed by RELU

- ### OutputLayer
    - 1x1 conv applied to the last DecoderPart

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [26]:
TEST = True

In [15]:
class EncoderPart(nn.Module):
    def __init__(self, input_channel, output_channel):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channel, output_channel, kernel_size=3)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(output_channel, output_channel, kernel_size=3)
    
    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.relu(x1)
        x3 = self.conv2(x2)
        x4 = self.relu(x3)
        return x4

In [27]:
encoder_part = EncoderPart(1, 64)

## Test
if TEST:
    x = torch.randn(1, 1, 572, 572)
    print(encoder_part(x).shape)
    print(encoder_part.parameters)

torch.Size([1, 64, 568, 568])
<bound method Module.parameters of EncoderPart(
  (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
  (relu): ReLU()
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
)>


In [29]:
class Maxpool(nn.Module):
    def __init__(self):
        super().__init__()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
    
    def forward(self, x):
        return self.maxpool(x)

In [30]:
maxpool = Maxpool()

if TEST:
    x = torch.randn(1, 64, 568, 568)
    print(maxpool(x).shape)

torch.Size([1, 64, 284, 284])


## nn.ConvTranspose2d()
Shape:
- Input: (N,Cin,Hin,Win)
- Output: (N,Cout,Hout,Wout)

Where,
- Hout=(Hin−1)×stride[0]−2×padding[0]+dilation[0]×(kernel_size[0]−1)+output_padding[0]+1
- Wout=(Win−1)×stride[1]−2×padding[1]+dilation[1]×(kernel_size[1]−1)+output_padding[1]+1



In [34]:
class Upconvolution(nn.Module):
    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.upconv = nn.ConvTranspose2d(input_channels, output_channels, kernel_size=2, stride=2) #input 28x28 out 56x56 || put 28 in above eqn

    def forward(self, x):
        return self.upconv(x)

In [41]:
upconv = Upconvolution(1024, 512)

if TEST:
    x = torch.randn(1, 1024, 28, 28)
    print(upconv(x).shape)

torch.Size([1, 512, 56, 56])


In [48]:
def concat(tensor_1, tensor_2):
    '''
    concatenate tensor_2 to tensor_1 
    '''
    dim_1 = tensor_1.shape[2]
    dim_2 = tensor_2.shape[2]

    part_to_remove = int((dim_2 - dim_1)/2)

    cropped_tensor_2 = tensor_2[:, :, part_to_remove-1:(part_to_remove+dim_1-1), part_to_remove-1:(part_to_remove+dim_1-1)]

    after_concat = torch.cat((tensor_1, cropped_tensor_2), dim=1) # dim=1 means concat along the channels 

    return after_concat

if TEST:
    print(concat(torch.randn(1,512,56,56), torch.randn(1,512,64,64)).shape)

torch.Size([1, 1024, 56, 56])


In [52]:
class DecoderPart(nn.Module):
    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv_block = EncoderPart(input_channels, output_channels)

    def forward(self, x):
        return self.conv_block(x)

In [53]:
decoder_part = DecoderPart(512, 128)

if TEST:
    x = torch.randn(1,512,104,104)
    print(decoder_part(x).shape)

torch.Size([1, 128, 100, 100])


In [54]:
class OutputLayer(nn.Module):
    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv_layer = nn.Conv2d(input_channels, output_channels, kernel_size=1)

    def forward(self, x):
        return self.conv_layer(x)

In [55]:
output = OutputLayer(64, 2)

if TEST:
    x = torch.randn(1, 64, 388, 388)
    print(output(x).shape)

torch.Size([1, 2, 388, 388])
