In [1]:
# Mount Google Drive into the Colab filesystem at /content/drive.
# NOTE(review): none of the visible cells read from the mount - confirm it is still needed.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import torch
import torch.nn as nn
import torch.functional as F

In [3]:
def conv(in_channels, out_channels):
  """Build a double-convolution block: (3x3 Conv2d -> ReLU) applied twice.

  The first convolution maps `in_channels` to `out_channels`; the second
  keeps `out_channels`. No padding is passed to the convolutions, so each
  one trims 2 pixels from the height and width (4 in total for the block).

  Args:
    in_channels: number of channels of the incoming feature map.
    out_channels: number of channels produced by both convolutions.

  Returns:
    An `nn.Sequential` holding [Conv2d, ReLU, Conv2d, ReLU] in that order.
  """
  layers = []
  # First pass consumes the input channels, second pass keeps the width fixed.
  for channels_in in (in_channels, out_channels):
    layers.append(nn.Conv2d(channels_in, out_channels, kernel_size=3))
    layers.append(nn.ReLU())
  return nn.Sequential(*layers)


def up_conv(in_channels, out_channels):
  """Build a double-convolution block with 2x2 kernels: (Conv2d -> ReLU) twice.

  NOTE(review): despite the name, nothing here upsamples - both layers are
  plain stride-1 `nn.Conv2d` with `kernel_size=2`, so each one shrinks the
  height and width by 1. The function is not called anywhere in the visible
  cells; confirm whether it is still needed.

  Args:
    in_channels: number of channels of the incoming feature map.
    out_channels: number of channels produced by both convolutions.

  Returns:
    An `nn.Sequential` holding [Conv2d, ReLU, Conv2d, ReLU] in that order.
  """
  first = nn.Conv2d(in_channels, out_channels, kernel_size=2)
  second = nn.Conv2d(out_channels, out_channels, kernel_size=2)
  return nn.Sequential(first, nn.ReLU(), second, nn.ReLU())


def skip_connection(src, trg):
  """Concatenate `src` with a centre crop of `trg` along the channel axis.

  Dims 2 and 3 of both tensors are treated as spatial height and width.
  `trg` is cropped to the spatial size of `src` and the two are joined on
  dim 1, with `src` channels first.

  The crop is taken from the centre of `trg`. Unpadded convolutions trim the
  border symmetrically, so the centre of the larger map is the region that
  lines up with the smaller one; cropping from the top-left corner would
  shift the skip features relative to `src`.

  Args:
    src: tensor whose spatial size defines the crop (at least 4 dims).
    trg: tensor to crop; must be at least as large as `src` in dims 2 and 3.

  Returns:
    `torch.cat((src, cropped_trg), 1)`.

  Raises:
    RuntimeError: if `trg` is spatially smaller than `src`, or if
      `torch.cat` rejects the tensors for any other reason.
  """
  src_h, src_w = src.shape[2], src.shape[3]
  trg_h, trg_w = trg.shape[2], trg.shape[3]

  if trg_h < src_h or trg_w < src_w:
    # Same exception type torch.cat would raise for the size mismatch,
    # but with a message that points at the actual cause.
    raise RuntimeError(
        'skip_connection: trg spatial size ({}, {}) is smaller than src ({}, {})'
        .format(trg_h, trg_w, src_h, src_w))

  # Split the surplus evenly; any odd pixel is left on the bottom/right.
  top = (trg_h - src_h) // 2
  left = (trg_w - src_w) // 2

  cropped = trg[:, :, top:top + src_h, left:left + src_w]
  return torch.cat((src, cropped), 1)

In [49]:
class Encoder(nn.Module):
  """Contracting path: five double-convolution stages with 2x2 max-pooling
  between consecutive stages.

  Channel widths go 3 -> 64 -> 128 -> 256 -> 512 -> 1024. The forward pass
  prints the shape of every stage's pre-pooling output and returns all five
  of them so a caller can use the earlier ones as skip connections.
  """

  def __init__(self):
    super().__init__()

    # Stage 1: 3 -> 64 channels
    self.conv_1 = conv(3, 64)
    self.max_pool_1 = nn.MaxPool2d(kernel_size=2, stride=2)

    # Stage 2: 64 -> 128 channels
    self.conv_2 = conv(64, 128)
    self.max_pool_2 = nn.MaxPool2d(kernel_size=2, stride=2)

    # Stage 3: 128 -> 256 channels
    self.conv_3 = conv(128, 256)
    self.max_pool_3 = nn.MaxPool2d(kernel_size=2, stride=2)

    # Stage 4: 256 -> 512 channels
    self.conv_4 = conv(256, 512)
    self.max_pool_4 = nn.MaxPool2d(kernel_size=2, stride=2)

    # Stage 5: 512 -> 1024 channels, not followed by pooling
    self.conv_5 = conv(512, 1024)

  def forward(self, img):
    """Run the five stages on `img`.

    Args:
      img: input tensor; `conv_1` expects 3 channels.

    Returns:
      Tuple of the five stage outputs taken before pooling, deepest first:
      (stage 5, stage 4, stage 3, stage 2, stage 1).
    """
    # Stage 1
    stage_1 = self.conv_1(img)
    pooled = self.max_pool_1(stage_1)
    print('First Convolution Block', stage_1.shape)

    # Stage 2
    stage_2 = self.conv_2(pooled)
    pooled = self.max_pool_2(stage_2)
    print('Second Convolution Block', stage_2.shape)

    # Stage 3
    stage_3 = self.conv_3(pooled)
    pooled = self.max_pool_3(stage_3)
    print('Third Convolution Block', stage_3.shape)

    # Stage 4
    stage_4 = self.conv_4(pooled)
    pooled = self.max_pool_4(stage_4)
    print('Fourth Convolution Block', stage_4.shape)

    # Stage 5: deepest features
    stage_5 = self.conv_5(pooled)
    print('Fifth Convolution Block', stage_5.shape)

    return stage_5, stage_4, stage_3, stage_2, stage_1


class Decoder(nn.Module):
  """Expanding path plus the encoder it reads from.

  Despite the name, this module is the whole network: it constructs its own
  `Encoder`, runs it on the input, then upsamples four times. Each
  upsampling stage is: transposed convolution (halves the channels, doubles
  the spatial size) -> concatenation with the matching encoder output via
  `skip_connection` -> double 3x3 convolution. A final 1x1 convolution maps
  the 64 channels to 2 output channels.

  The forward pass prints a progress line per stage and the final output
  shape.
  """

  def __init__(self):
    super().__init__()

    # Stage 1: 1024 -> 512 channels
    self.trans_1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
    self.up_conv_1 = conv(1024, 512)

    # Stage 2: 512 -> 256 channels
    self.trans_2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
    self.up_conv_2 = conv(512, 256)

    # Stage 3: 256 -> 128 channels
    self.trans_3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
    self.up_conv_3 = conv(256, 128)

    # Stage 4: 128 -> 64 channels
    self.trans_4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
    self.up_conv_4 = conv(128, 64)

    # Output head: 1x1 convolution, 64 -> 2 channels
    self.up_conv_5 = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1)

    # The contracting path is owned by this module and run inside forward().
    self.encoder = Encoder()

  def forward(self, x):
    """Run the encoder and then the four upsampling stages.

    Args:
      x: input tensor, passed straight to the encoder.

    Returns:
      Output of the final 1x1 convolution (2 channels).
    """
    # Encoder outputs, deepest first; a7..a1 feed the skip connections.
    a9, a7, a5, a3, a1 = self.encoder(x)

    # Stage 1
    up = self.trans_1(a9)
    merged = skip_connection(up, a7)
    features = self.up_conv_1(merged)
    print('1st Upsampled Convolution')

    # Stage 2
    up = self.trans_2(features)
    merged = skip_connection(up, a5)
    features = self.up_conv_2(merged)
    print('2nd Upsampled Convolution')

    # Stage 3
    up = self.trans_3(features)
    merged = skip_connection(up, a3)
    features = self.up_conv_3(merged)
    print('3rd Upsampled Convolution')

    # Stage 4
    up = self.trans_4(features)
    merged = skip_connection(up, a1)
    features = self.up_conv_4(merged)
    print('4th Upsampled Convolution')

    # Output head
    out = self.up_conv_5(features)
    print('Finally the output', out.shape)

    return out

In [50]:
# Build the model. Decoder creates its own Encoder in __init__, so this single
# object holds the whole network.
decoder = Decoder()

In [54]:
# Random dummy input for a shape check: 1 sample, 3 channels (matches the
# encoder's first convolution), 572 x 572 spatial size.
x = torch.rand(1, 3, 572, 572)

In [55]:
# Single forward pass; the model prints the shape of each encoder stage and of
# the final output as it runs.
output = decoder(x)

First Convolution Block torch.Size([1, 64, 568, 568])
Second Convolution Block torch.Size([1, 128, 280, 280])
Third Convolution Block torch.Size([1, 256, 136, 136])
Fourth Convolution Block torch.Size([1, 512, 64, 64])
Fifth Convolution Block torch.Size([1, 1024, 28, 28])
1st Upsampled Convolution
2nd Upsampled Convolution
3rd Upsampled Convolution
4th Upsampled Convolution
Finally the output (FUCK YOU BITCH) torch.Size([1, 2, 388, 388])
