In [1]:
import torch
import torch.nn as nn

In [34]:
def double_conv(in_c, out_c):
    conv = nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=3),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_c, out_c, kernel_size=3),
        nn.ReLU(inplace=True)
    )
    return conv
"""Convolution decreases image size, if we don't use padding."""

In [35]:
def crop_img(tensor, target_tensor):
    target_size = target_tensor.size()[2] ## 2 is the height and width position in B,C,H,W)
    tensor_size = tensor.size()[2]
    delta = tensor_size - target_size
    delta = delta // 2
    return tensor[:, :, delta:tensor_size-delta, delta:tensor_size-delta] ## represents B,C,H,W

In [38]:
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        ### 1st part encoder
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down_conv_1 = double_conv(1, 64)
        self.down_conv_2 = double_conv(64, 128)
        self.down_conv_3 = double_conv(128, 256)
        self.down_conv_4 = double_conv(256, 512)
        self.down_conv_5 = double_conv(512, 1024)
        """"When performing upconvolution, we are taking image from the previous double convolution layer.
        We are cropping it to match the size of the upconvolution layer and concatenate them before passing to the next double convolution layer.
        We are doinng it at 4 stages. 
        """
        
        ### 2nd part decoder
        self.uptrans_1 = nn.ConvTranspose2d(
            in_channels=1024, 
            out_channels=512, 
            kernel_size=2, 
            stride=2)
        
        """While Transpose convolution increase the image size depending on the stride and kernel size."""
        
        self.upconv_1 = double_conv(1024, 512)
        
        self.uptrans_2 = nn.ConvTranspose2d(
            in_channels=512, 
            out_channels=256, 
            kernel_size=2, 
            stride=2)
        
        self.upconv_2 = double_conv(512, 256)
        
        self.uptrans_3 = nn.ConvTranspose2d(
            in_channels=256, 
            out_channels=128, 
            kernel_size=2, 
            stride=2)
        
        self.upconv_3 = double_conv(256, 128)
        
        self.uptrans_4 = nn.ConvTranspose2d(
            in_channels=128, 
            out_channels=64, 
            kernel_size=2, 
            stride=2)
        
        self.upconv_4 = double_conv(128, 64)
        
        self.out = nn.Conv2d(
            in_channels=64, 
            out_channels=2, 
            kernel_size=1
            )
        
    def forward(self, image):
        # encoder
        x1 = self.down_conv_1(image) # we will need x1, x3, x5, x7 in the decoder part. We don't need x9 in the decoder part because there is no maxpooling.
        print(x1.size())
        x2 = self.max_pool_2x2(x1)
        x3 = self.down_conv_2(x2) #
        x4 = self.max_pool_2x2(x3)
        x5 = self.down_conv_3(x4) #
        x6 = self.max_pool_2x2(x5)
        x7 = self.down_conv_4(x6) #
        x8 = self.max_pool_2x2(x7)
        x9 = self.down_conv_5(x8)
        # print(x9.size()) ## (Batch_size, Channel, Height, Width)
        
        # In decoder part we need to concatenate x to x7, x to x5, x to x3, x to x1 
        x = self.uptrans_1(x9)
        # print(x.size())
        y = crop_img(x7, x)
        x = self.upconv_1(torch.cat([x, y], 1))
        # print(x.size()) # #torch.Size([1, 512, 56, 56])
        # print(x7.size()) ##torch.Size([1, 512, 64, 64]) 
        ####(For concatenation, both the left and right side images (i.e., x7 and x) should have the same size or tensors should have same size)
        ### There are different ways to make the size same, one way is to use padding, another way is to use interpolation, cropping, etc. 
        # print(y.size()) #torch.Size([1, 512, 56, 56])
        
        x = self.uptrans_2(x)
        y = crop_img(x5, x)
        x = self.upconv_2(torch.cat([x, y], 1))
        
        x = self.uptrans_3(x)
        y = crop_img(x3, x)
        x = self.upconv_3(torch.cat([x, y], 1))
        
        x = self.uptrans_4(x)
        y = crop_img(x1, x)
        x = self.upconv_4(torch.cat([x, y], 1))
        
        
        x = self.out(x)
        print(x.size()) 
        return x
        
        
        
        
        
        
        
        
        

In [39]:
if __name__ == '__main__':
        image = torch.rand((1, 1, 572, 572))
        model = UNet()
        print(model(image))

torch.Size([1, 64, 568, 568])
torch.Size([1, 2, 388, 388])
tensor([[[[-0.0712, -0.0680, -0.0710,  ..., -0.0736, -0.0666, -0.0636],
          [-0.0670, -0.0713, -0.0716,  ..., -0.0725, -0.0673, -0.0702],
          [-0.0630, -0.0731, -0.0647,  ..., -0.0710, -0.0749, -0.0740],
          ...,
          [-0.0740, -0.0726, -0.0694,  ..., -0.0647, -0.0728, -0.0694],
          [-0.0667, -0.0682, -0.0694,  ..., -0.0695, -0.0648, -0.0711],
          [-0.0689, -0.0697, -0.0713,  ..., -0.0719, -0.0699, -0.0708]],

         [[-0.1199, -0.1179, -0.1152,  ..., -0.1151, -0.1170, -0.1062],
          [-0.1163, -0.1140, -0.1129,  ..., -0.1123, -0.1185, -0.1165],
          [-0.1112, -0.1185, -0.1136,  ..., -0.1151, -0.1193, -0.1145],
          ...,
          [-0.1156, -0.1193, -0.1118,  ..., -0.1104, -0.1141, -0.1142],
          [-0.1122, -0.1131, -0.1191,  ..., -0.1224, -0.1132, -0.1167],
          [-0.1108, -0.1239, -0.1166,  ..., -0.1180, -0.1160, -0.1149]]]],
       grad_fn=<ConvolutionBackward0>)
