## Variations to explore

* Try skip connections within `double_conv_layers`.
* Try stacking inception modules.

In [5]:
import os
import PIL
import torch
import torchvision
import numpy as np
from torch import nn
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

In [24]:
def double_conv_layers(in_channels, out_channels, kernel_size, activation, padding='same', batch_norm=True):
  '''
  Return two stacked Conv2d layers, each followed by an optional BatchNorm2d and an activation.

  in_channels: input channels of the first convolutional layer.
  out_channels: output channels of both convolutional layers (the second layer maps
    out_channels -> out_channels).
  kernel_size: kernel size used by both layers.
  activation: activation class (e.g. nn.ReLU), not an instance and not a string. It is
    instantiated as activation(inplace=True), so it must accept the `inplace` keyword.
  padding: padding passed to both nn.Conv2d layers; defaults to 'same' (output spatial
    size equals input spatial size, since the convs use the default stride of 1).
  batch_norm: if True, insert nn.BatchNorm2d(out_channels) after every convolutional layer.

  Returns an nn.Sequential whose modules are ordered
  [conv, (bn), act, conv, (bn), act] -- the BatchNorm entries only when batch_norm is True.
  '''
  layers = []
  # The first conv consumes the block input; the second consumes the first conv's output.
  # (Previously the batch_norm=False path wrongly used out_channels for the first conv.)
  for conv_in_channels in (in_channels, out_channels):
    layers.append(nn.Conv2d(conv_in_channels, out_channels, kernel_size, padding=padding))
    if batch_norm:
      layers.append(nn.BatchNorm2d(out_channels))
    layers.append(activation(inplace=True))

  return nn.Sequential(*layers)


class InceptionModule(nn.Module):
  '''
  Four-branch Inception-style block whose output has the same channel count as its input.

  Branches (all stride 1, size-preserving padding), concatenated along dim 1:
    channel_1: 1x1 conv
    channel_2: 1x1 conv -> 3x3 conv
    channel_3: 1x1 conv -> 5x5 conv
    channel_4: 3x3 max-pool -> 1x1 conv
  The first three branches get input_channels // 4 channels each and the last branch
  takes the remainder, so the concatenation has exactly input_channels channels.
  No activation is applied inside the branches.

  forward() prints the input shape, each branch's output shape and the final shape.
  '''
  def __init__(self, input_channels):
    super().__init__()

    quarter = input_channels // 4
    # Even split; any remainder goes to the pooling branch.
    out_channels = [quarter, quarter, quarter, input_channels - 3 * quarter]

    def project_1x1(width):
      # 1x1 projection from the block input to one branch's channel budget.
      return nn.Conv2d(in_channels=input_channels, out_channels=width, kernel_size=1, stride=1, padding='same')

    def spatial_conv(width, size):
      # Width-preserving kxk convolution applied after a branch's 1x1 projection.
      return nn.Conv2d(in_channels=width, out_channels=width, kernel_size=size, stride=1, padding='same')

    # Construction order matches the branch numbering (keeps RNG-driven init order stable).
    self.channel_1 = project_1x1(out_channels[0])
    self.channel_2 = nn.Sequential(project_1x1(out_channels[1]), spatial_conv(out_channels[1], 3))
    self.channel_3 = nn.Sequential(project_1x1(out_channels[2]), spatial_conv(out_channels[2], 5))
    self.channel_4 = nn.Sequential(nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
                                   project_1x1(out_channels[3]))

  def forward(self, input):
    print(f'input shape : {input.shape}')
    branches = (self.channel_1, self.channel_2, self.channel_3, self.channel_4)
    outputs = []
    for number, branch in enumerate(branches, start=1):
      result = branch(input)
      print(f'Channel {number} : {result.shape}')
      outputs.append(result)
    merged = torch.cat(outputs, dim=1)
    print(f'Final shape : {merged.shape}')
    return merged

class UNet(nn.Module):
  '''
  Configurable U-Net: four encoder double-conv stages (three 2x2 max-pools between them),
  three decoder stages (2x transposed conv + skip concatenation + double conv) and a final
  1x1 conv. Optionally runs an InceptionModule after each of the first three encoder stages.

  down_conv_out: output channels of the four encoder stages.
  down_conv_ks: kernel sizes of the four encoder stages.
  down_conv_activation: activation class for the encoder (instantiated with inplace=True).
  up_conv_out: output channels of the three decoder stages.
  up_conv_ks: kernel sizes of the three decoder stages.
  up_conv_activation: activation class for the decoder (instantiated with inplace=True).
  pad: padding passed to every double-conv layer. The skip concatenations need encoder and
    decoder feature maps of identical spatial size, so this must be size-preserving
    (e.g. 'same').
  add_inception: if True, apply InceptionModule after down_conv1, down_conv2 and down_conv3.
  in_channels: channels of the input tensor (default 3).
  out_channels: channels produced by the final 1x1 conv (default 3).

  The input height and width must be divisible by 8: three floor-ing 2x2 max-pools followed
  by three exact 2x transposed convs must land back on the skip tensors' sizes.
  forward() prints the intermediate tensor shapes.
  '''
  def __init__(self,
               down_conv_out=(64, 128, 256, 512),
               down_conv_ks=(3, 3, 3, 3),
               down_conv_activation=nn.ReLU,
               up_conv_out=(256, 128, 64),
               up_conv_ks=(3, 3, 3),
               up_conv_activation=nn.ReLU,
               pad='same',
               add_inception=False,
               in_channels=3,
               out_channels=3):
    super().__init__()

    # Per-instance list copies: defaults are immutable tuples (list defaults would be shared
    # by every instance), and copying protects against later mutation of the caller's lists.
    down_conv_out = list(down_conv_out)
    down_conv_ks = list(down_conv_ks)
    up_conv_out = list(up_conv_out)
    up_conv_ks = list(up_conv_ks)

    self.down_conv_out = down_conv_out
    self.down_conv_ks = down_conv_ks
    self.down_conv_activation = down_conv_activation
    self.up_conv_out = up_conv_out
    self.up_conv_ks = up_conv_ks
    self.up_conv_activation = up_conv_activation
    self.pad = pad
    self.add_inception = add_inception
    self.in_channels = in_channels
    self.out_channels = out_channels

    # Down Conv Layers
    self.down_conv1 = double_conv_layers(in_channels, down_conv_out[0], down_conv_ks[0], down_conv_activation, padding=pad)
    self.down_conv2 = double_conv_layers(down_conv_out[0], down_conv_out[1], down_conv_ks[1], down_conv_activation, padding=pad)
    self.down_conv3 = double_conv_layers(down_conv_out[1], down_conv_out[2], down_conv_ks[2], down_conv_activation, padding=pad)
    self.down_conv4 = double_conv_layers(down_conv_out[2], down_conv_out[3], down_conv_ks[3], down_conv_activation, padding=pad)

    # Inception Modules. NOTE: built even when add_inception is False (unchanged from before),
    # so the state_dict keys do not depend on the flag; forward() only uses them when it is True.
    self.inception_module_1 = InceptionModule(down_conv_out[0])
    self.inception_module_2 = InceptionModule(down_conv_out[1])
    self.inception_module_3 = InceptionModule(down_conv_out[2])

    # Conv Transpose layers (kernel 2, stride 2: exactly doubles H and W)
    self.up_transpose1 = nn.ConvTranspose2d(down_conv_out[3], up_conv_out[0], 2, 2)
    self.up_transpose2 = nn.ConvTranspose2d(up_conv_out[0], up_conv_out[1], 2, 2)
    self.up_transpose3 = nn.ConvTranspose2d(up_conv_out[1], up_conv_out[2], 2, 2)

    # Up Conv Layers. Each one consumes the transposed-conv output concatenated with the
    # matching encoder skip tensor, so its input width is the sum of both widths.
    # (Previously hard-coded, which only worked when e.g. up_conv_out[0] + down_conv_out[2]
    # happened to equal down_conv_out[3].)
    self.up_conv1 = double_conv_layers(up_conv_out[0] + down_conv_out[2], up_conv_out[0], up_conv_ks[0], up_conv_activation, padding=pad)
    self.up_conv2 = double_conv_layers(up_conv_out[1] + down_conv_out[1], up_conv_out[1], up_conv_ks[1], up_conv_activation, padding=pad)
    self.up_conv3 = double_conv_layers(up_conv_out[2] + down_conv_out[0], up_conv_out[2], up_conv_ks[2], up_conv_activation, padding=pad)

    # final output conv
    self.output_conv = nn.Conv2d(up_conv_out[2], out_channels, 1)

    # Maxpooling
    self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)


  def forward(self, input):
    '''Run the encoder/decoder and return the output of the final 1x1 conv, printing shapes along the way.'''

    # Down Conv Encoder Part
    print(f'Start : {input.shape}')
    x1 = self.down_conv1(input)
    if self.add_inception:
      x1 = self.inception_module_1(x1)
    print(f'After Down Conv 1 : {x1.shape}')
    x = self.maxpool(x1)
    print(f'After maxpool : {x.shape}')
    x2 = self.down_conv2(x)
    if self.add_inception:
      x2 = self.inception_module_2(x2)
    print(f'After Down Conv 2 : {x2.shape}')
    x = self.maxpool(x2)
    print(f'After maxpool : {x.shape}')
    x3 = self.down_conv3(x)
    if self.add_inception:
      x3 = self.inception_module_3(x3)
    print(f'After Down Conv 3 : {x3.shape}')
    x = self.maxpool(x3)
    print(f'After maxpool : {x.shape}')
    x4 = self.down_conv4(x)
    print(f'After Down Conv 4 : {x4.shape}')

    # Up Conv Decoder Part
    x = self.up_transpose1(x4)
    print(f'After Up Transpose 1 : {x.shape}')
    x = self.up_conv1(torch.cat([x, x3], 1)) # skip connection from down_conv3
    print(f'After Up Conv 1 : {x.shape}')
    x = self.up_transpose2(x)
    print(f'After Up Transpose 2 : {x.shape}')
    x = self.up_conv2(torch.cat([x, x2], 1)) # skip connection from down_conv2
    print(f'After Up Conv 2 : {x.shape}')
    x = self.up_transpose3(x)
    print(f'After Up Transpose 3 : {x.shape}')
    x = self.up_conv3(torch.cat([x, x1], 1)) # skip connection from down_conv1
    print(f'After Up Conv 3 : {x.shape}')

    # final output conv layer
    x = self.output_conv(x)
    print(f'After Final output conv : {x.shape}')

    return x

In [25]:
# without inception modules: trace tensor shapes through the plain U-Net on a dummy 3-channel 128x128 batch
model = UNet(add_inception=False)
image = torch.zeros(1, 3, 128, 128)
x = model(image)

Start : torch.Size([1, 3, 128, 128])
After Down Conv 1 : torch.Size([1, 64, 128, 128])
After maxpool : torch.Size([1, 64, 64, 64])
After Down Conv 2 : torch.Size([1, 128, 64, 64])
After maxpool : torch.Size([1, 128, 32, 32])
After Down Conv 3 : torch.Size([1, 256, 32, 32])
After maxpool : torch.Size([1, 256, 16, 16])
After Down Conv 4 : torch.Size([1, 512, 16, 16])
After Up Transpose 1 : torch.Size([1, 256, 32, 32])
After Up Conv 1 : torch.Size([1, 256, 32, 32])
After Up Transpose 2 : torch.Size([1, 128, 64, 64])
After Up Conv 2 : torch.Size([1, 128, 64, 64])
After Up Transpose 3 : torch.Size([1, 64, 128, 128])
After Up Conv 3 : torch.Size([1, 64, 128, 128])
After Final output conv : torch.Size([1, 3, 128, 128])


In [26]:
# with inception modules: same dummy input, with an InceptionModule after each of the first three encoder stages
model = UNet(add_inception=True)
image = torch.zeros(1, 3, 128, 128)
x = model(image)

Start : torch.Size([1, 3, 128, 128])
input shape : torch.Size([1, 64, 128, 128])
Channel 1 : torch.Size([1, 16, 128, 128])
Channel 2 : torch.Size([1, 16, 128, 128])
Channel 3 : torch.Size([1, 16, 128, 128])
Channel 4 : torch.Size([1, 16, 128, 128])
Final shape : torch.Size([1, 64, 128, 128])
After Down Conv 1 : torch.Size([1, 64, 128, 128])
After maxpool : torch.Size([1, 64, 64, 64])
input shape : torch.Size([1, 128, 64, 64])
Channel 1 : torch.Size([1, 32, 64, 64])
Channel 2 : torch.Size([1, 32, 64, 64])
Channel 3 : torch.Size([1, 32, 64, 64])
Channel 4 : torch.Size([1, 32, 64, 64])
Final shape : torch.Size([1, 128, 64, 64])
After Down Conv 2 : torch.Size([1, 128, 64, 64])
After maxpool : torch.Size([1, 128, 32, 32])
input shape : torch.Size([1, 256, 32, 32])
Channel 1 : torch.Size([1, 64, 32, 32])
Channel 2 : torch.Size([1, 64, 32, 32])
Channel 3 : torch.Size([1, 64, 32, 32])
Channel 4 : torch.Size([1, 64, 32, 32])
Final shape : torch.Size([1, 256, 32, 32])
After Down Conv 3 : torch.S

In [27]:
# Narrower U-Net with SELU activations and a 5x5 first encoder stage; reuses `image` from the previous cell.
selu_unet_config = dict(
  down_conv_out=[32, 64, 128, 256],
  down_conv_ks=[5, 3, 3, 3],
  down_conv_activation=nn.SELU,
  up_conv_out=[128, 64, 32],
  up_conv_activation=nn.SELU,
)
model = UNet(**selu_unet_config)
x = model(image)

Start : torch.Size([1, 3, 128, 128])
After Down Conv 1 : torch.Size([1, 32, 128, 128])
After maxpool : torch.Size([1, 32, 64, 64])
After Down Conv 2 : torch.Size([1, 64, 64, 64])
After maxpool : torch.Size([1, 64, 32, 32])
After Down Conv 3 : torch.Size([1, 128, 32, 32])
After maxpool : torch.Size([1, 128, 16, 16])
After Down Conv 4 : torch.Size([1, 256, 16, 16])
After Up Transpose 1 : torch.Size([1, 128, 32, 32])
After Up Conv 1 : torch.Size([1, 128, 32, 32])
After Up Transpose 2 : torch.Size([1, 64, 64, 64])
After Up Conv 2 : torch.Size([1, 64, 64, 64])
After Up Transpose 3 : torch.Size([1, 32, 128, 128])
After Up Conv 3 : torch.Size([1, 32, 128, 128])
After Final output conv : torch.Size([1, 3, 128, 128])
