# U-NET implementation from scratch for segmentation tasks

In [3]:
import torch
import torch.nn as nn

In [9]:
def double_conv(in_c, out_c):
    conv = nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=3),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_c, out_c, kernel_size=3),
        nn.ReLU(inplace=True),
    )
    
    return conv


def crop_img(tensor, target_tensor):
    target_size = target_tensor.size()[2]
    tensor_size = tensor.size()[2]
    delta = tensor_size - target_size
    delta = delta // 2
    return tensor[:,:, delta:tensor_size-delta, delta:tensor_size-delta]



In [17]:
class UNet(nn.Module):
    
    def __init__(self):
        super(UNet, self).__init__()
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down_conv_1 = double_conv(1,64)
        self.down_conv_2 = double_conv(64, 128)
        self.down_conv_3 = double_conv(128, 256)
        self.down_conv_4 = double_conv(256, 512)
        self.down_conv_5 = double_conv(512, 1024)
        
        self.up_trans_1 = nn.ConvTranspose2d(in_channels=1024, 
                                             out_channels=512, 
                                             kernel_size=2, 
                                             stride=2)
        
        self.up_conv_1 = double_conv(1024, 512)
        
        
        self.up_trans_2 = nn.ConvTranspose2d(in_channels=512, 
                                             out_channels=256, 
                                             kernel_size=2, 
                                             stride=2)
        
        self.up_conv_2 = double_conv(512, 256)
        
        self.up_trans_3 = nn.ConvTranspose2d(in_channels=256, 
                                             out_channels=128, 
                                             kernel_size=2, 
                                             stride=2)
        
        self.up_conv_3 = double_conv(256, 128)
        
        self.up_trans_4 = nn.ConvTranspose2d(in_channels=128, 
                                             out_channels=64, 
                                             kernel_size=2, 
                                             stride=2)
        
        
        self.up_conv_4 = double_conv(128, 64)
        
        
        self.out = nn.Conv2d(
            in_channels = 64,
            out_channels=2, 
            kernel_size=1
            
        
        )
        
    def forward(self, image):
        # bs, c, h, w
        # encoder
        x1 = self.down_conv_1(image)
        print("size of x1 : ", x1.size())
        x2 = self.max_pool_2x2(x1)
        x3 = self.down_conv_2(x2)
        x4 = self.max_pool_2x2(x3)
        x5 = self.down_conv_3(x4)
        x6 = self.max_pool_2x2(x5)
        x7 = self.down_conv_4(x6)
        x8 = self.max_pool_2x2(x7)
        x9 = self.down_conv_5(x8)
        print("size of x9 : ", x9.size())
        
        # decoder
        x = self.up_trans_1(x9)
        y = crop_img(x7, x)
        x = self.up_conv_1(torch.cat([x, y], 1))
        
        x = self.up_trans_2(x)
        y = crop_img(x5, x)
        x = self.up_conv_2(torch.cat([x, y], 1))
        
        x = self.up_trans_3(x)
        y = crop_img(x3, x)
        x = self.up_conv_3(torch.cat([x, y], 1))
        
        
        x = self.up_trans_4(x)
        y = crop_img(x1, x)
        x = self.up_conv_4(torch.cat([x, y], 1))
        
        print("size of x7 : ", x7.size())
        print("size of y : ", y.size())
        
        
        x = self.out(x)
        print("size of x : ", x.size())
        return x

if __name__ == "__main__":
    image = torch.rand((1, 1, 572, 572))
    model = UNet()
    print(model(image))

size of x1 :  torch.Size([1, 64, 568, 568])
size of x9 :  torch.Size([1, 1024, 28, 28])
size of x7 :  torch.Size([1, 512, 64, 64])
size of y :  torch.Size([1, 64, 392, 392])
size of x :  torch.Size([1, 2, 388, 388])
tensor([[[[0.0324, 0.0328, 0.0294,  ..., 0.0336, 0.0315, 0.0323],
          [0.0349, 0.0291, 0.0374,  ..., 0.0361, 0.0351, 0.0328],
          [0.0348, 0.0360, 0.0360,  ..., 0.0382, 0.0329, 0.0378],
          ...,
          [0.0381, 0.0310, 0.0305,  ..., 0.0336, 0.0281, 0.0364],
          [0.0317, 0.0295, 0.0310,  ..., 0.0334, 0.0299, 0.0336],
          [0.0364, 0.0325, 0.0372,  ..., 0.0351, 0.0337, 0.0303]],

         [[0.0184, 0.0191, 0.0194,  ..., 0.0182, 0.0197, 0.0229],
          [0.0218, 0.0227, 0.0210,  ..., 0.0189, 0.0187, 0.0168],
          [0.0220, 0.0217, 0.0190,  ..., 0.0201, 0.0172, 0.0204],
          ...,
          [0.0219, 0.0200, 0.0185,  ..., 0.0199, 0.0192, 0.0194],
          [0.0205, 0.0174, 0.0224,  ..., 0.0212, 0.0175, 0.0207],
          [0.0198, 0.0216,