<a href="https://colab.research.google.com/github/alirezaghl/Architectures/blob/main/Unet_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

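This is a simple implementation of U-Net based on the original paper ("U-Net: Convolutional Networks for Biomedical Image Segmentation", Ronneberger, Fischer and Brox, 2015).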

In [1]:
import torch
import torch.nn as nn

This helper function consists of two convolutions, each followed by a ReLU activation. We use it throughout the architecture. As in the paper, both are 2D convolutions with a 3x3 kernel and no padding, so each one trims 2 pixels from the height and width.

In [2]:
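def double_convolution(in_ch, out_ch):
  # Two unpadded 3x3 convolutions, each followed by ReLU.
  # Each convolution trims 1 pixel from every border (e.g. 572 -> 570 -> 568).
  output = nn.Sequential(
      nn.Conv2d(in_ch, out_ch, kernel_size=3),
      nn.ReLU(inplace=True),
      nn.Conv2d(out_ch, out_ch, kernel_size=3),
      nn.ReLU(inplace=True)
      )
  return output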


Let’s make sure this implementation works:


In [3]:
x = torch.randn(1, 1, 572, 572)
y = double_convolution(1,64)
z = y(x)
z.shape

torch.Size([1, 64, 568, 568])
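The 568x568 result follows from the output-size formula given in the PyTorch documentation for `nn.Conv2d`. `nn.MaxPool2d` uses the same formula with its default `ceil_mode=False`. The sketch below applies that formula along one spatial axis. The helper name `conv2d_out_size` is hypothetical and is not part of PyTorch.

```python
def conv2d_out_size(n, kernel_size=3, stride=1, padding=0, dilation=1):
    # Output length along one spatial axis, following the formula in the
    # PyTorch docs for nn.Conv2d (nn.MaxPool2d uses the same one by default).
    return (n + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


after_first = conv2d_out_size(572)             # first unpadded 3x3 conv
after_second = conv2d_out_size(after_first)    # second unpadded 3x3 conv
pooled = conv2d_out_size(after_second, kernel_size=2, stride=2)  # 2x2 max pool
print(after_first, after_second, pooled)  # 570 568 284
```

With `padding=0` and `kernel_size=3`, every convolution loses exactly 2 pixels per axis. One call to `double_convolution` therefore loses 4.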

The first part is the contracting path, which takes the number of channels from 1 to 1024. Every stage except the last applies `double_convolution` followed by a max-pooling operation. As in the paper, the pooling uses a 2x2 kernel with stride 2, which halves the height and width. The final feature map of the contracting path has size 1024x28x28.

The second part is the expansive path, which mirrors the contracting path. Instead of max pooling, each stage applies a 2x2 up-convolution (`nn.ConvTranspose2d` with stride 2). The up-convolution halves the number of channels and doubles the height and width. Its output is then concatenated with the corresponding feature map from the contracting path. There is one problem, though!

Because the unpadded convolutions shrink every feature map, the contracting-path outputs are larger than their expansive-path counterparts. They must be cropped to matching dimensions before concatenation, and the paper crops them at the center. The final feature map is 64x388x388. A final 1x1 convolution then reduces the number of channels from 64 to 2.
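You can derive the sizes quoted above by tracing one spatial dimension through the network. This includes 28 at the bottom, 388 at the output, and the crop sizes used in the forward pass. The pure-Python sketch below assumes the layer sequence used in this notebook: four pooling stages, unpadded 3x3 convolutions and stride-2 up-convolutions. The helpers `double_conv_out` and `trace_unet_sizes` are hypothetical names. The sketch also rejects input sizes for which a max-pooling step would receive an odd-sized map. The paper asks you to avoid such sizes when choosing the input tile size.

```python
def double_conv_out(n):
    # Two unpadded 3x3 convolutions each trim 2 pixels, 4 in total.
    n -= 4
    if n <= 0:
        raise ValueError("feature map shrank to nothing")
    return n


def trace_unet_sizes(n, depth=4):
    """Trace one spatial dimension of a square input through the U-Net.

    Returns (skip_sizes, crop_sizes, out_size): the sizes of the
    contracting-path maps kept for concatenation, the sizes they are
    cropped to, and the size of the final output map.
    """
    skips = []
    for _ in range(depth):
        n = double_conv_out(n)
        if n % 2:
            # 2x2 max pooling with stride 2 needs an even size to halve exactly.
            raise ValueError(f"max pooling would see an odd size ({n})")
        skips.append(n)
        n //= 2
    n = double_conv_out(n)      # bottleneck (1024 channels)
    crops = []
    for _ in range(depth):
        n *= 2                  # 2x2 up-convolution with stride 2 doubles the size
        crops.append(n)         # skip map is cropped to this size before concatenation
        n = double_conv_out(n)
    return skips, crops, n      # the final 1x1 convolution keeps the size


print(trace_unet_sizes(572))
# ([568, 280, 136, 64], [56, 104, 200, 392], 388)

# Input sizes in [100, 600) that pass every check.
valid = []
for size in range(100, 600):
    try:
        trace_unet_sizes(size)
        valid.append(size)
    except ValueError:
        pass
print(valid[:3], valid[-1])  # [188, 204, 220] 588
```

Under these assumptions, the valid input sizes are spaced 16 pixels apart, because four poolings each halve the size. The 572x572 input used in the paper is one of them.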

In [4]:
class Unet(nn.Module):
  def __init__(self):
    super(Unet, self).__init__()

    self.conv1 = double_convolution(1,64)
    self.conv2 = double_convolution(64, 128)
    self.conv3 = double_convolution(128, 256)
    self.conv4 = double_convolution(256, 512)
    self.conv5 = double_convolution(512, 1024)
    self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    self.deconv1 = nn.ConvTranspose2d(kernel_size=2, stride=2, in_channels=1024, out_channels=512)
    self.deconv2 = nn.ConvTranspose2d(kernel_size=2, stride=2, in_channels=512, out_channels=256)
    self.deconv3 = nn.ConvTranspose2d(kernel_size=2, stride=2, in_channels=256, out_channels=128)
    self.deconv4 = nn.ConvTranspose2d(kernel_size=2, stride=2, in_channels=128, out_channels=64)
    self.conv6 = double_convolution(1024,512)
    self.conv7 = double_convolution(512, 256)
    self.conv8 = double_convolution(256, 128)
    self.conv9 = double_convolution(128, 64)
    self.conv10 = nn.Conv2d(64, 2, kernel_size=1)
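
  @staticmethod
  def center_crop(skip, target):
    # Crop the center of a contracting-path feature map so that its height
    # and width match `target` (the up-convolved map it is concatenated with).
    h, w = target.shape[-2:]
    top = (skip.shape[-2] - h) // 2
    left = (skip.shape[-1] - w) // 2
    return skip[:, :, top:top + h, left:left + w]

  def forward(self, image):
    # Contracting path: each double-convolution output is computed once and
    # kept as a skip connection before pooling.
    s1 = self.conv1(image)
    s2 = self.conv2(self.Maxpool(s1))
    s3 = self.conv3(self.Maxpool(s2))
    s4 = self.conv4(self.Maxpool(s3))
    x = self.conv5(self.Maxpool(s4))  # bottleneck: 1024 channels

    # Expansive path: up-convolve, concatenate the center-cropped skip, convolve.
    x = self.deconv1(x)
    x = self.conv6(torch.cat([x, self.center_crop(s4, x)], 1))
    x = self.deconv2(x)
    x = self.conv7(torch.cat([x, self.center_crop(s3, x)], 1))
    x = self.deconv3(x)
    x = self.conv8(torch.cat([x, self.center_crop(s2, x)], 1))
    x = self.deconv4(x)
    x = self.conv9(torch.cat([x, self.center_crop(s1, x)], 1))

    # The 1x1 convolution maps 64 channels to 2 class scores per pixel.
    return self.conv10(x)

Running the full model on a 572x572 input, as in the paper, should give a 2x388x388 output map:

In [5]:
x = torch.randn(1, 1, 572, 572)
y = Unet()
z = y(x)
z.shape

torch.Size([1, 2, 388, 388])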