In [None]:
import torch.nn as nn
import torch

In [None]:
"""
Our UNet model uses two padded convolution layers per stage, followed by max pooling during downsampling, and transposed convolutions during upsampling.
"""

def double_conv(in_channels, out_channels):
  """Build two consecutive (3x3 conv -> ReLU -> BatchNorm) stages.

  Every convolution uses a 3x3 kernel with stride 1 and padding 1, so the
  spatial size of the input is preserved. Only the first convolution changes
  the channel count (``in_channels`` -> ``out_channels``); the second keeps
  ``out_channels``.

  Returns:
    An ``nn.Sequential`` of six modules: Conv2d, ReLU, BatchNorm2d, Conv2d,
    ReLU, BatchNorm2d (indices 0..5).
  """
  stages = []
  # The first stage reads the caller's channels; the second reads its own output.
  for source_channels in (in_channels, out_channels):
    stages.extend([
        nn.Conv2d(source_channels, out_channels, 3, stride=1, padding=1),
        nn.ReLU(inplace=True),
        nn.BatchNorm2d(out_channels),
    ])
  return nn.Sequential(*stages)

In [None]:
class UNet(nn.Module):
  """U-Net producing a per-pixel map squashed into (0, 1) by a sigmoid.

  Encoder: seven ``double_conv`` stages (16, 32, ..., 1024 channels) separated
  by six 2x2 max-pooling steps. Each stage's output is kept as a skip
  connection. Decoder: six 2x2 stride-2 transposed convolutions. Each is
  followed by concatenation with the matching encoder output and a
  ``double_conv``. A final 1x1 convolution maps 16 channels to
  ``out_channels``.

  Args:
    in_channels: number of channels of the input tensor (default 3).
    out_channels: number of channels of the output map (default 1).

  Input height/width must each be at least 64 so the bottleneck (after six
  halvings) is non-empty. Sizes that are not multiples of 64 are supported:
  upsampled maps are zero-padded to their skip connection's size. Because of
  this, the output always has the input's height and width.
  NOTE: with a 64x64 input the bottleneck is 1x1, so training-mode BatchNorm
  needs a batch size greater than 1.
  """

  def __init__(self, in_channels=3, out_channels=1):
    super().__init__()

    self.maxpool = nn.MaxPool2d(2, 2)

    # Encoder: each stage consumes the (pooled) output of the previous one.
    self.down_conv1 = double_conv(in_channels, 16)
    self.down_conv2 = double_conv(16, 32)
    self.down_conv3 = double_conv(32, 64)
    self.down_conv4 = double_conv(64, 128)
    self.down_conv5 = double_conv(128, 256)
    self.down_conv6 = double_conv(256, 512)
    self.down_conv7 = double_conv(512, 1024)

    # Decoder: the transposed conv halves the channels. Concatenating the skip
    # connection doubles them again, and the double_conv halves them back.
    self.up_conv1 = nn.ConvTranspose2d(1024, 512, 2, 2)
    self.up_conv1a = double_conv(1024, 512)

    self.up_conv2 = nn.ConvTranspose2d(512, 256, 2, 2)
    self.up_conv2a = double_conv(512, 256)

    self.up_conv3 = nn.ConvTranspose2d(256, 128, 2, 2)
    self.up_conv3a = double_conv(256, 128)

    self.up_conv4 = nn.ConvTranspose2d(128, 64, 2, 2)
    self.up_conv4a = double_conv(128, 64)

    self.up_conv5 = nn.ConvTranspose2d(64, 32, 2, 2)
    self.up_conv5a = double_conv(64, 32)

    self.up_conv6 = nn.ConvTranspose2d(32, 16, 2, 2)
    self.up_conv6a = double_conv(32, 16)

    # up_conv6a outputs 16 channels, so the 1x1 head must read 16 (not 32).
    self.final_conv = nn.Conv2d(16, out_channels, 1, 1)

  @staticmethod
  def _up_block(upsample, conv, x, skip):
    """Upsample ``x``, pad it to ``skip``'s size, concat on channels, apply ``conv``."""
    x = upsample(x)
    # Floor-halving in the encoder can leave the skip one pixel larger; pad to match.
    dh = skip.size(-2) - x.size(-2)
    dw = skip.size(-1) - x.size(-1)
    if dh or dw:
      # F.pad ordering for the last two dims is (left, right, top, bottom).
      x = nn.functional.pad(x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
    return conv(torch.cat([x, skip], dim=1))

  def forward(self, x):
    # ENCODER PASS: every stage after the first consumes the pooled previous output.
    skip1 = self.down_conv1(x)
    skip2 = self.down_conv2(self.maxpool(skip1))
    skip3 = self.down_conv3(self.maxpool(skip2))
    skip4 = self.down_conv4(self.maxpool(skip3))
    skip5 = self.down_conv5(self.maxpool(skip4))
    skip6 = self.down_conv6(self.maxpool(skip5))
    bottleneck = self.down_conv7(self.maxpool(skip6))

    # DECODER PASS: mirror the encoder, deepest skip connection first.
    y = self._up_block(self.up_conv1, self.up_conv1a, bottleneck, skip6)
    y = self._up_block(self.up_conv2, self.up_conv2a, y, skip5)
    y = self._up_block(self.up_conv3, self.up_conv3a, y, skip4)
    y = self._up_block(self.up_conv4, self.up_conv4a, y, skip3)
    y = self._up_block(self.up_conv5, self.up_conv5a, y, skip2)
    y = self._up_block(self.up_conv6, self.up_conv6a, y, skip1)

    # torch.sigmoid instead of F.sigmoid: F is not imported and F.sigmoid is deprecated.
    return torch.sigmoid(self.final_conv(y))
  
  
# Use the GPU when one is available; otherwise fall back to the CPU instead of
# crashing. Inputs fed to `unet` must be moved to the same `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"
unet = UNet()
unet.to(device)