In [2]:
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import torch
import matplotlib.pyplot as plt

In [3]:
class Stdinp(nn.Module):
  """Two 'same'-padded 2D convolutions, each followed by ReLU and dropout.

  This is the basic node used throughout Unetplus. It maps ``in_ch`` channels
  to ``out_ch`` channels and keeps the spatial size (H, W) unchanged for odd
  kernel sizes. The skip-connection concatenations in Unetplus rely on this.

  Args:
    in_ch: number of input channels.
    out_ch: number of output channels (used by both convolutions).
    k_size: kernel size, either an int or an (h, w) tuple. Odd sizes preserve
      H and W exactly; even sizes grow each side by one pixel.
    drop_p: dropout probability applied after each ReLU (default 0.2).
  """

  def __init__(self, in_ch, out_ch, k_size=(3, 3), drop_p=0.2):
    super().__init__()
    # Derive 'same' padding from the kernel. A fixed padding=1 would only
    # preserve H and W for 3x3 kernels.
    kernel = (k_size, k_size) if isinstance(k_size, int) else tuple(k_size)
    padding = tuple(k // 2 for k in kernel)
    self.conv1 = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=k_size, padding=padding)
    self.dropout = nn.Dropout(drop_p)
    self.conv2 = nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=k_size, padding=padding)

  def forward(self, x):
    """Apply (conv -> ReLU -> dropout) twice; the output has ``out_ch`` channels."""
    x = self.dropout(F.relu(self.conv1(x)))
    x = self.dropout(F.relu(self.conv2(x)))
    return x


In [4]:
class Unetplus(nn.Module):
  """UNet++-style nested U-Net with optional deep supervision.

  Node ``x_ij`` sits at depth ``i`` (1 = full resolution, 5 = bottleneck) and
  is the ``j``-th node along that depth. Column j = 1 is the encoder. Every
  later node concatenates all earlier nodes at its own depth with a
  transposed-conv upsampling of node ``x_(i+1)(j-1)``, then applies a Stdinp
  block.

  The transposed convolutions are shared per depth. For example, ``tconv2``
  upsamples x_21, x_22, x_23 and x_24. A single 1x1 head ``convf`` produces
  every prediction.

  Args:
    num_classes: output channels of the final 1x1 convolution.
    deep_supervision: if True, ``forward`` returns a list with the
      predictions of the four full-resolution decoder nodes (x_12 .. x_15).
      If False, it returns only the prediction from x_15.
    in_channels: channels of the input image (default 3).

  The encoder halves H and W four times and the decoder doubles them back.
  The input H and W must therefore be divisible by 16; otherwise
  ``torch.cat`` raises on mismatched spatial sizes.
  """

  def __init__(self, num_classes, deep_supervision=True, in_channels=3):
    super().__init__()
    self.channels = [32, 64, 128, 256, 512]
    self.num_classes = num_classes
    self.deep_supervision = deep_supervision
    c = self.channels
    k = (3, 3)
    self.maxpool = nn.MaxPool2d(stride=2, kernel_size=(2, 2))

    # Encoder (column j = 1).
    self.conv11 = Stdinp(in_ch=in_channels, out_ch=c[0], k_size=k)
    self.conv21 = Stdinp(in_ch=c[0], out_ch=c[1], k_size=k)
    self.conv31 = Stdinp(in_ch=c[1], out_ch=c[2], k_size=k)
    self.conv41 = Stdinp(in_ch=c[2], out_ch=c[3], k_size=k)
    self.conv51 = Stdinp(in_ch=c[3], out_ch=c[4], k_size=k)

    # 2x upsamplers. tconvN upsamples depth-N features and keeps channel count.
    self.tconv2 = nn.ConvTranspose2d(in_channels=c[1], out_channels=c[1], kernel_size=2, stride=2)
    self.tconv3 = nn.ConvTranspose2d(in_channels=c[2], out_channels=c[2], kernel_size=2, stride=2)
    self.tconv4 = nn.ConvTranspose2d(in_channels=c[3], out_channels=c[3], kernel_size=2, stride=2)
    self.tconv5 = nn.ConvTranspose2d(in_channels=c[4], out_channels=c[4], kernel_size=2, stride=2)

    # Decoder nodes. Input = (j-1) same-depth skips + one upsampled deeper map.
    self.conv12 = Stdinp(in_ch=c[0] + c[1], out_ch=c[0], k_size=k)
    self.conv22 = Stdinp(in_ch=c[1] + c[2], out_ch=c[1], k_size=k)
    self.conv32 = Stdinp(in_ch=c[2] + c[3], out_ch=c[2], k_size=k)
    self.conv42 = Stdinp(in_ch=c[3] + c[4], out_ch=c[3], k_size=k)

    self.conv13 = Stdinp(in_ch=c[0] * 2 + c[1], out_ch=c[0], k_size=k)
    self.conv23 = Stdinp(in_ch=c[1] * 2 + c[2], out_ch=c[1], k_size=k)
    self.conv33 = Stdinp(in_ch=c[2] * 2 + c[3], out_ch=c[2], k_size=k)

    self.conv14 = Stdinp(in_ch=c[0] * 3 + c[1], out_ch=c[0], k_size=k)
    self.conv24 = Stdinp(in_ch=c[1] * 3 + c[2], out_ch=c[1], k_size=k)

    self.conv15 = Stdinp(in_ch=c[0] * 4 + c[1], out_ch=c[0], k_size=k)

    # Shared 1x1 prediction head for all full-resolution decoder nodes.
    self.convf = nn.Conv2d(in_channels=c[0], out_channels=self.num_classes, kernel_size=(1, 1))

  def _node(self, conv, tconv, skips, deeper):
    """Concatenate ``skips`` with ``tconv(deeper)`` along channels, then apply ``conv``."""
    return conv(torch.cat([*skips, tconv(deeper)], dim=1))

  def forward(self, x):
    """Run the network on ``x`` of shape (N, in_channels, H, W), with H and W divisible by 16.

    Returns a list of four (N, num_classes, H, W) maps when deep_supervision
    is True. Otherwise returns the single map computed from x_15.
    """
    # Encoder (column j = 1).
    x_11 = self.conv11(x)
    x_21 = self.conv21(self.maxpool(x_11))
    x_31 = self.conv31(self.maxpool(x_21))
    x_41 = self.conv41(self.maxpool(x_31))
    x_51 = self.conv51(self.maxpool(x_41))

    # Column j = 2.
    x_12 = self._node(self.conv12, self.tconv2, [x_11], x_21)
    x_22 = self._node(self.conv22, self.tconv3, [x_21], x_31)
    x_32 = self._node(self.conv32, self.tconv4, [x_31], x_41)
    x_42 = self._node(self.conv42, self.tconv5, [x_41], x_51)

    # Column j = 3.
    x_13 = self._node(self.conv13, self.tconv2, [x_11, x_12], x_22)
    x_23 = self._node(self.conv23, self.tconv3, [x_21, x_22], x_32)
    x_33 = self._node(self.conv33, self.tconv4, [x_31, x_32], x_42)

    # Column j = 4.
    x_14 = self._node(self.conv14, self.tconv2, [x_11, x_12, x_13], x_23)
    x_24 = self._node(self.conv24, self.tconv3, [x_21, x_22, x_23], x_33)

    # Column j = 5.
    x_15 = self._node(self.conv15, self.tconv2, [x_11, x_12, x_13, x_14], x_24)

    if self.deep_supervision:
      return [self.convf(x_12), self.convf(x_13), self.convf(x_14), self.convf(x_15)]
    # Without deep supervision only the deepest nested decoder node is used.
    return self.convf(x_15)

In [5]:
# Single-channel output head. With deep supervision enabled (the default),
# the forward pass returns a list of four predictions, one per
# full-resolution decoder node.
model = Unetplus(num_classes=1, deep_supervision=True)
model

Unetplus(
  (maxpool): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv11): Stdinp(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (dropout): Dropout(p=0.2, inplace=False)
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (conv21): Stdinp(
    (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (dropout): Dropout(p=0.2, inplace=False)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (conv31): Stdinp(
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (dropout): Dropout(p=0.2, inplace=False)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (conv41): Stdinp(
    (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (dropout): Dropout(p=0.2, inplace=False)
    (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), pad

In [6]:
# Random stand-in for one RGB image: batch of 1, 3 channels, 256x256, with
# integer pixel values in [0, 255]. torch.randint's `high` is exclusive, so
# it must be 256 to include 255. A local seeded generator makes the input
# reproducible without touching the global RNG.
input_rng = torch.Generator().manual_seed(0)
inp = torch.randint(0, 256, size=(1, 3, 256, 256), generator=input_rng).float()
inp.shape

torch.Size([1, 3, 256, 256])

In [7]:
# Shape smoke test only: disable autograd so the many full-resolution
# intermediate activations are not retained for a backward pass.
# NOTE(review): the model is still in training mode, so dropout is active here.
with torch.no_grad():
  out = model(inp)

In [8]:
# With deep supervision, `out` is a list of per-node predictions; otherwise it
# is a single tensor. Iterating a bare tensor would walk its batch dimension
# and print per-sample shapes, so wrap it in a list first.
predictions = out if isinstance(out, (list, tuple)) else [out]
for pred in predictions:
  print(pred.shape)

torch.Size([1, 1, 256, 256])
torch.Size([1, 1, 256, 256])
torch.Size([1, 1, 256, 256])
torch.Size([1, 1, 256, 256])
