In [1]:
import torch.nn as nn
import torch
import torchvision as vs
from torchsummary import summary

In [2]:
class DualConv(nn.Module):
    """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    Both convolutions use ``kernel_size=3, padding=1``, so height and width are
    preserved; only the channel count changes. Expects an NCHW tensor, as
    required by ``BatchNorm2d``.

    Args:
        input_channels: channel count of the incoming feature map.
        output_channels: channel count produced by both convolutions.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()

        stages = []
        # The first stage maps input_channels -> output_channels; the second keeps the width.
        for in_ch in (input_channels, output_channels):
            stages.extend([
                nn.Conv2d(in_ch, output_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            ])

        # Keep the attribute name and Sequential indices (dual_conv.0 .. dual_conv.5)
        # stable so saved state_dicts still load.
        self.dual_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv/BN/ReLU stages."""
        return self.dual_conv(x)

In [3]:
class DownConv(nn.Module):
    """Encoder step: 2x2 max pooling followed by a ``DualConv``.

    ``MaxPool2d(2)`` halves height and width (odd sizes are floored, since
    PyTorch's default is ``ceil_mode=False``). The ``DualConv`` then changes the
    channel count from ``input_channels`` to ``output_channels``.

    Args:
        input_channels: channel count of the incoming feature map.
        output_channels: channel count after the double convolution.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DualConv(input_channels, output_channels)
        # Single Sequential attribute keeps parameter names as down_conv.1.*.
        self.down_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Downsample ``x`` and apply the double convolution."""
        return self.down_conv(x)

In [4]:
class UpConv(nn.Module):
    """Decoder step: upsample, align with the skip tensor, concatenate, then ``DualConv``.

    ``up_conv`` is a ``ConvTranspose2d`` with ``kernel_size=2, stride=2``. It
    doubles height and width and maps ``input_channels -> output_channels``.
    The result is concatenated with the skip tensor along dim 1 and fed to a
    ``DualConv(input_channels, output_channels)``. The channel counts therefore
    only line up when the skip tensor carries
    ``input_channels - output_channels`` channels, e.g. 1024 -> 512 with a
    512-channel skip.

    Args:
        input_channels: channels of the deeper feature map (``x1`` in forward).
        output_channels: channels produced by the upsampling and by the DualConv.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.up_conv = nn.ConvTranspose2d(input_channels, output_channels, kernel_size=2, stride=2)
        self.conv = DualConv(input_channels, output_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s height/width, and fuse the two.

        Args:
            x1: deeper feature map to upsample (NCHW).
            x2: skip-connection feature map (NCHW).

        Returns:
            ``self.conv`` applied to ``cat([x2, upsampled_x1], dim=1)``.
        """
        upsampled = self.up_conv(x1)

        diff_h = x2.size(2) - upsampled.size(2)
        diff_w = x2.size(3) - upsampled.size(3)

        # F.pad order for the last two dims is (left, right, top, bottom).
        # Any odd remainder goes to the right/bottom; negative values crop.
        left = diff_w // 2
        top = diff_h // 2
        padding = [left, diff_w - left, top, diff_h - top]
        upsampled = nn.functional.pad(upsampled, padding)

        # Skip tensor first, upsampled second: the DualConv weights depend on this order.
        merged = torch.cat((x2, upsampled), dim=1)
        return self.conv(merged)

In [5]:
class OutputConv(nn.Module):
    """Final 1x1 convolution projecting features to ``output_channels`` maps.

    No activation is applied, so the outputs are raw values. They are presumably
    logits consumed by the loss; TODO confirm against the training code.

    Args:
        input_channels: channel count of the incoming feature map.
        output_channels: number of output maps (e.g. classes).
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        # A 1x1 kernel mixes channels only; height and width are unchanged.
        self.conv = nn.Conv2d(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=1,
        )

    def forward(self, x):
        """Return the 1x1-convolved ``x``."""
        projected = self.conv(x)
        return projected

In [6]:
class UNet(nn.Module):
    """U-Net: a 4-level encoder/decoder with skip connections.

    Channel widths double at every encoder level (``base_channels`` x1, x2, x4,
    x8, x16) and halve again in the decoder. A final 1x1 convolution maps the
    features to ``n_classes`` channels. The decoder pads each upsampled map to
    its skip tensor's size, so the output has the input's height and width.

    Args:
        n_channels: number of channels of the input image (e.g. 3 for RGB).
        n_classes: number of output channels produced by the final 1x1 conv.
        base_channels: width of the first level. The default of 64 gives the
            64-128-256-512-1024 layout; smaller values give a lighter model.

    Raises:
        ValueError: if ``base_channels`` is not a positive integer.
    """

    def __init__(self, n_channels, n_classes, base_channels=64):
        super().__init__()

        if base_channels < 1:
            raise ValueError(f"base_channels must be positive, got {base_channels}")

        self.n_channels = n_channels
        self.n_classes = n_classes
        self.base_channels = base_channels

        # Level widths: c1 = base, c2 = 2*base, ..., c5 = 16*base.
        c1, c2, c3, c4, c5 = (base_channels * 2 ** level for level in range(5))

        self.inp = DualConv(n_channels, c1)

        # Encoder: each DownConv halves height/width and doubles the channel count.
        self.down_conv_1 = DownConv(c1, c2)
        self.down_conv_2 = DownConv(c2, c3)
        self.down_conv_3 = DownConv(c3, c4)
        self.down_conv_4 = DownConv(c4, c5)

        # Decoder: each UpConv halves the channel count. Its skip tensor has the
        # same (halved) width, so the concatenation restores the UpConv's
        # input width, which is what its DualConv expects.
        self.up_conv_1 = UpConv(c5, c4)
        self.up_conv_2 = UpConv(c4, c3)
        self.up_conv_3 = UpConv(c3, c2)
        self.up_conv_4 = UpConv(c2, c1)

        self.op_conv = OutputConv(c1, n_classes)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections.

        Args:
            x: input batch; NCHW with ``n_channels`` channels (required by BatchNorm2d).

        Returns:
            Tensor with ``n_classes`` channels and the same height/width as ``x``.
        """
        # Encoder: keep every level's output for the skip connections.
        x1 = self.inp(x)
        x2 = self.down_conv_1(x1)
        x3 = self.down_conv_2(x2)
        x4 = self.down_conv_3(x3)
        x5 = self.down_conv_4(x4)

        # Decoder: pair each upsampling step with the matching encoder level.
        x6 = self.up_conv_1(x5, x4)
        x7 = self.up_conv_2(x6, x3)
        x8 = self.up_conv_3(x7, x2)
        x9 = self.up_conv_4(x8, x1)

        return self.op_conv(x9)

In [7]:
# Seed PyTorch's global RNG before building the model. The Conv2d/ConvTranspose2d
# weights are initialised randomly, and the seed makes them identical on every
# Restart & Run All.
torch.manual_seed(0)
model = UNet(n_channels=3, n_classes=1)

In [8]:
# Use the GPU when available, otherwise fall back to the CPU.
# The earlier hard-coded `.cuda()` (and torchsummary's default device="cuda")
# fails on CPU-only kernels.
device = "cuda" if torch.cuda.is_available() else "cpu"
# input_size is (C, H, W) of one sample; torchsummary adds the batch dim (shown as -1).
# 360 is not a multiple of 16, so the decoder pads upsampled maps to realign them.
summary(model.to(device), (3, 480, 360), device=device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 480, 360]           1,792
       BatchNorm2d-2         [-1, 64, 480, 360]             128
              ReLU-3         [-1, 64, 480, 360]               0
            Conv2d-4         [-1, 64, 480, 360]          36,928
       BatchNorm2d-5         [-1, 64, 480, 360]             128
              ReLU-6         [-1, 64, 480, 360]               0
          DualConv-7         [-1, 64, 480, 360]               0
         MaxPool2d-8         [-1, 64, 240, 180]               0
            Conv2d-9        [-1, 128, 240, 180]          73,856
      BatchNorm2d-10        [-1, 128, 240, 180]             256
             ReLU-11        [-1, 128, 240, 180]               0
           Conv2d-12        [-1, 128, 240, 180]         147,584
      BatchNorm2d-13        [-1, 128, 240, 180]             256
             ReLU-14        [-1, 128, 2