In [11]:
import numpy as np
import PIL
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

In [2]:
def double_conv_layers(in_channels, out_channels, kernel_size, activation, padding=0, batch_norm=True):
  '''
  Return two stacked "Conv2d -> [BatchNorm2d] -> activation" stages as an nn.Sequential.

  in_channels: input channels of the first convolutional layer.
  out_channels: output channels of both convolutional layers (the second layer maps
                out_channels -> out_channels).
  kernel_size: kernel size used by both convolutional layers.
  activation: activation class applied after each convolution, e.g. nn.ReLU. Pass the
              class itself (not a string or an instance); it is instantiated with
              inplace=True, so it must accept that keyword.
  padding: padding applied by both convolutions; 0 (default) means no padding.
  batch_norm: if True, insert nn.BatchNorm2d(out_channels) after every convolution.

  Layer order (and therefore nn.Sequential indices / state_dict keys) is
  conv, [bn], act, conv, [bn], act.
  '''
  layers = []
  # The first conv consumes in_channels; the second refines out_channels.
  # (Previously the batch_norm=False path wrongly used out_channels for the first conv.)
  for conv_in_channels in (in_channels, out_channels):
    layers.append(nn.Conv2d(conv_in_channels, out_channels, kernel_size, padding=padding))
    if batch_norm:
      layers.append(nn.BatchNorm2d(out_channels))
    layers.append(activation(inplace=True))

  return nn.Sequential(*layers)

def inception_module(inputs, channel_input_dict):
  '''
  Apply a four-branch, Inception-style block to `inputs` and concatenate the branch
  outputs along the channel axis (dim=1).

  inputs: 4-D tensor laid out as (batch, channels, height, width), as nn.Conv2d expects.
  channel_input_dict: dict of channel counts with keys
      'c1_out'               - branch 1: 1x1 conv
      'c2_out1', 'c2_out2'   - branch 2: 1x1 conv -> 1x1 conv
      'c3_out1', 'c3_out2'   - branch 3: 5x5 conv -> 3x3 conv
      'c4_out'               - branch 4: 3x3 max-pool (stride 1) -> 1x1 conv
    and optional 'c1_in', 'c2_in', 'c3_in', 'c4_in' (each defaults to inputs.shape[1]).

  Returns a tensor of shape (batch, c1_out + c2_out2 + c3_out2 + c4_out, height, width).

  NOTE: the layers are created with fresh random weights on every call and are not
  registered with any nn.Module, so they are not trainable. This is suitable for shape
  experiments; wrap the branches in an nn.Module for training.
  '''
  cfg = channel_input_dict
  in_ch = inputs.shape[1]

  def conv(c_in, c_out, k):
    # padding = k // 2 keeps height/width unchanged for odd k at stride 1, which the
    # channel-wise concatenation below requires (the old padding=1 on 1x1/5x5 convs broke it).
    return nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=k, stride=1, padding=k // 2)

  channel_1 = conv(cfg.get('c1_in', in_ch), cfg['c1_out'], 1)

  # NOTE(review): GoogLeNet uses 1x1 -> 3x3 here; kept at 1x1 -> 1x1 as originally written.
  channel_2 = nn.Sequential(conv(cfg.get('c2_in', in_ch), cfg['c2_out1'], 1),
                            conv(cfg['c2_out1'], cfg['c2_out2'], 1))

  channel_3 = nn.Sequential(conv(cfg.get('c3_in', in_ch), cfg['c3_out1'], 5),
                            conv(cfg['c3_out1'], cfg['c3_out2'], 3))

  channel_4 = nn.Sequential(nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
                            conv(cfg.get('c4_in', in_ch), cfg['c4_out'], 1))

  branches = (channel_1, channel_2, channel_3, channel_4)
  # Put the freshly built layers on the same device/dtype as the input before applying them.
  outputs = [branch.to(device=inputs.device, dtype=inputs.dtype)(inputs) for branch in branches]
  return torch.cat(outputs, dim=1)

class UNet(nn.Module):
  '''
  Three-level U-Net.

  Encoder: double conv + 2x2 max-pool at each level. Bottleneck: double conv.
  Decoder: 2x2 transposed conv, concatenation with the matching encoder features
  (skip connection), then double conv. A final 1x1 conv produces the output channels.

  in_channels: channels of the input image (default 3).
  out_channels: channels produced by the final 1x1 conv (default 3).
  verbose: if True (default, matching the original notebook behaviour) print the
           tensor shape after every stage of forward(); pass False for training.

  Submodule attribute names are unchanged so previously saved state_dicts still load.
  '''
  def __init__(self, in_channels=3, out_channels=3, verbose=True):
    super().__init__()
    self.verbose = verbose

    # Down Conv Layers
    self.down_conv1 = double_conv_layers(in_channels, 64, 3, nn.ReLU, padding=1)
    self.down_conv2 = double_conv_layers(64, 128, 3, nn.ReLU, padding=1)
    self.down_conv3 = double_conv_layers(128, 256, 3, nn.ReLU, padding=1)
    self.down_conv4 = double_conv_layers(256, 512, 3, nn.ReLU, padding=1)

    # Conv Transpose layers (kernel 2, stride 2: double height/width)
    self.up_transpose1 = nn.ConvTranspose2d(512, 256, 2, 2)
    self.up_transpose2 = nn.ConvTranspose2d(256, 128, 2, 2)
    self.up_transpose3 = nn.ConvTranspose2d(128, 64, 2, 2)

    # Up Conv Layers (input = upsampled channels + skip channels)
    self.up_conv1 = double_conv_layers(512, 256, 3, nn.ReLU, padding=1)
    self.up_conv2 = double_conv_layers(256, 128, 3, nn.ReLU, padding=1)
    self.up_conv3 = double_conv_layers(128, 64, 3, nn.ReLU, padding=1)

    # final output conv
    self.output_conv = nn.Conv2d(64, out_channels, 1)

    # Maxpooling
    self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

  def _log(self, label, tensor):
    # Debug trace of intermediate shapes; silenced with verbose=False.
    if self.verbose:
      print(f'{label} : {tensor.shape}')

  @staticmethod
  def _match_spatial(x, skip):
    '''Zero-pad `x` on height/width so it matches `skip` before concatenation.

    Max-pooling floors odd sizes (e.g. 25 -> 12), so the 2x transposed conv gives 24,
    one short of the skip tensor; without this padding torch.cat raises.
    '''
    diff_h = skip.shape[-2] - x.shape[-2]
    diff_w = skip.shape[-1] - x.shape[-1]
    if diff_h == 0 and diff_w == 0:
      return x
    return nn.functional.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                                 diff_h // 2, diff_h - diff_h // 2])

  def forward(self, input):

    # Down Conv Encoder Part (x1..x3 are kept for the skip connections)
    self._log('Start', input)
    x1 = self.down_conv1(input)
    self._log('After Down Conv 1', x1)
    x = self.maxpool(x1)
    self._log('After maxpool', x)
    x2 = self.down_conv2(x)
    self._log('After Down Conv 2', x2)
    x = self.maxpool(x2)
    self._log('After maxpool', x)
    x3 = self.down_conv3(x)
    self._log('After Down Conv 3', x3)
    x = self.maxpool(x3)
    self._log('After maxpool', x)
    x4 = self.down_conv4(x)
    self._log('After Down Conv 4', x4)

    # Up Conv Decoder Part
    x = self.up_transpose1(x4)
    self._log('After Up Transpose 1', x)
    x = self.up_conv1(torch.cat([self._match_spatial(x, x3), x3], 1))  # skip connection from down_conv3
    self._log('After Up Conv 1', x)
    x = self.up_transpose2(x)
    self._log('After Up Transpose 2', x)
    x = self.up_conv2(torch.cat([self._match_spatial(x, x2), x2], 1))  # skip connection from down_conv2
    self._log('After Up Conv 2', x)
    x = self.up_transpose3(x)
    self._log('After Up Transpose 3', x)
    x = self.up_conv3(torch.cat([self._match_spatial(x, x1), x1], 1))  # skip connection from down_conv1
    self._log('After Up Conv 3', x)

    # final output conv layer
    x = self.output_conv(x)
    self._log('After Final output conv', x)

    return x

In [44]:
image = torch.zeros(1, 3, 128, 128)
model = UNet()
# Shape smoke test only: no_grad avoids building an autograd graph that is never used.
with torch.no_grad():
  x = model(image)
# The U-Net should preserve spatial size and map to 3 output channels.
assert x.shape == (1, 3, 128, 128), x.shape

Start : torch.Size([1, 3, 128, 128])
After Down Conv 1 : torch.Size([1, 64, 128, 128])
After maxpool : torch.Size([1, 64, 64, 64])
After Down Conv 2 : torch.Size([1, 128, 64, 64])
After maxpool : torch.Size([1, 128, 32, 32])
After Down Conv 3 : torch.Size([1, 256, 32, 32])
After maxpool : torch.Size([1, 256, 16, 16])
After Down Conv 4 : torch.Size([1, 512, 16, 16])
After Up Transpose 1 : torch.Size([1, 256, 32, 32])
After Up Conv 1 : torch.Size([1, 256, 32, 32])
After Up Transpose 2 : torch.Size([1, 128, 64, 64])
After Up Conv 2 : torch.Size([1, 128, 64, 64])
After Up Transpose 3 : torch.Size([1, 64, 128, 128])
After Up Conv 3 : torch.Size([1, 64, 128, 128])
After Final output conv : torch.Size([1, 3, 128, 128])
