This is not an exact implementation of the U-Net paper, but it carries over the core idea: this is a modified version of U-Net.

In [3]:
import torch.nn as nn
import numpy as np
import torch
import torch.nn.functional as F

In [4]:
class DownSample_Block(nn.Module):
    """Encoder stage: optional 2x2 max-pool, optional BatchNorm, then two 3x3 conv + ReLU.

    Args:
        in_ch: number of channels of the incoming feature map.
        out_ch: number of channels produced by both convolutions.
        down_sample: when True, halve H and W with a 2x2 max-pool first.
        batch_norm: when True, apply BatchNorm to the (possibly pooled) input
            before the first convolution.

    Note: ``pool`` and ``bn`` are constructed even when the matching flag is
    False, so they still show up in ``repr()`` and ``state_dict()``.
    """

    def __init__(self, in_ch, out_ch, down_sample=True, batch_norm=True):
        super().__init__()
        self.down_sample = down_sample
        self.batch_norm = batch_norm
        # Registration order (pool, bn, conv1, conv2) fixes the state_dict key
        # order and the order of RNG draws for weight init -- keep it.
        self.pool = nn.MaxPool2d(2)
        self.bn = nn.BatchNorm2d(in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)

    def forward(self, x):
        """Run the stage; padding=1 keeps the spatial size through both convs."""
        feat = self.pool(x) if self.down_sample else x
        feat = self.bn(feat) if self.batch_norm else feat
        for conv in (self.conv1, self.conv2):
            feat = F.relu(conv(feat))
        return feat



class Upsample_Block(nn.Module):
    """Decoder stage: upsample, concatenate with the skip, BatchNorm, two 3x3 conv + ReLU.

    Args:
        in_ch: channels of the lower-resolution decoder input ``x``.
        out_ch: channels produced by both convolutions.
        skip_ch: channels of the encoder skip connection ``skip``.
    """

    def __init__(self, in_ch, out_ch, skip_ch):
        super().__init__()
        # Kept for its configuration (mode / align_corners) and for repr();
        # forward() resizes to the skip's exact size with the same settings.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.bn = nn.BatchNorm2d(num_features=in_ch + skip_ch)
        self.conv1 = nn.Conv2d(in_channels=in_ch + skip_ch, out_channels=out_ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, padding=1)

    def forward(self, x, skip):
        """Upsample ``x`` to ``skip``'s spatial size, fuse, and convolve.

        Args:
            x: decoder features with ``in_ch`` channels.
            skip: encoder features with ``skip_ch`` channels; its H and W set
                the output resolution.

        Returns:
            Tensor with ``out_ch`` channels and the same H, W as ``skip``.
        """
        # Resize to the skip's size instead of a fixed 2x: an encoder max-pool
        # floors odd sizes (e.g. 25 -> 12), and 2x of that (24) would make the
        # concatenation below fail. With align_corners=True the sampling grid
        # depends only on input/output sizes, so this matches scale_factor=2
        # exactly whenever skip is twice the size of x.
        up = F.interpolate(
            x,
            size=skip.shape[-2:],
            mode=self.upsample.mode,
            align_corners=self.upsample.align_corners,
        )
        out = torch.cat((skip, up), dim=1)
        out = self.bn(out)
        out = F.relu(self.conv1(out))
        out = F.relu(self.conv2(out))
        return out




class Unet(nn.Module):
    """U-Net variant with six encoder stages and five decoder stages.

    Encoder widths double from 32 to 1024; decoder widths halve back to 32,
    and each decoder stage fuses the encoder output of matching resolution.

    Args:
        in_channels: channels of the input image (default 3, as before).
        out_channels: channels of the output map (default 1, as before).

    ``forward(x)`` takes ``x`` of shape (N, in_channels, H, W) and returns the
    raw output of the final 3x3 conv (no activation is applied). With five
    2x max-pool stages, H and W divisible by 32 keep every decoder stage
    aligned with its skip connection.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        # Submodule names/order unchanged so existing state_dicts still load.
        # The first stage neither pools nor normalises the raw input.
        self.eblock1 = DownSample_Block(in_channels, 32, down_sample=False, batch_norm=False)
        self.eblock2 = DownSample_Block(32, 64)
        self.eblock3 = DownSample_Block(64, 128)
        self.eblock4 = DownSample_Block(128, 256)
        self.eblock5 = DownSample_Block(256, 512)
        self.eblock6 = DownSample_Block(512, 1024)

        # Upsample_Block(in_ch, out_ch, skip_ch): skip_ch equals the width of
        # the encoder stage it is paired with in forward().
        self.dblock1 = Upsample_Block(1024, 512, 512)
        self.dblock2 = Upsample_Block(512, 256, 256)
        self.dblock3 = Upsample_Block(256, 128, 128)
        self.dblock4 = Upsample_Block(128, 64, 64)
        self.dblock5 = Upsample_Block(64, 32, 32)

        self.conv = nn.Conv2d(32, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        # Encoder: d1 is full resolution, each later stage halves H and W.
        d1 = self.eblock1(x)
        d2 = self.eblock2(d1)
        d3 = self.eblock3(d2)
        d4 = self.eblock4(d3)
        d5 = self.eblock5(d4)
        d6 = self.eblock6(d5)

        # Decoder: climb back up, fusing the encoder output of each resolution.
        u1 = self.dblock1(d6, d5)
        u2 = self.dblock2(u1, d4)
        u3 = self.dblock3(u2, d3)
        u4 = self.dblock4(u3, d2)
        u5 = self.dblock5(u4, d1)

        return self.conv(u5)



In [5]:
# Build the network with its default configuration. `model` is deliberately
# left as the cell's last expression so the notebook displays its module tree.
model = Unet()
model

Unet(
  (eblock1): DownSample_Block(
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (bn): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (eblock2): DownSample_Block(
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (eblock3): DownSample_Block(
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv1): Conv2d(64, 128, kernel_size=(3, 3),