In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class DoubleConv(nn.Module):
  """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

  Each convolution uses kernel_size=3, padding=1 (stride 1), so height and
  width are preserved; only the channel count changes.

  Args:
    in_channels: channels of the input feature map.
    out_channels: channels produced by the second convolution.
    mid_channels: channels produced by the first convolution; defaults to
      out_channels when None.
  """

  def __init__(self, in_channels, out_channels, mid_channels=None):
    super().__init__()

    if mid_channels is None:
      mid_channels = out_channels

    # Build both conv/BN/ReLU stages from a list of (in, out) channel pairs.
    # The flat layout (indices 0..5) keeps state_dict keys such as
    # "double_conv.0.weight" stable.
    stages = []
    for c_in, c_out in ((in_channels, mid_channels), (mid_channels, out_channels)):
      stages.extend([
          nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
          nn.BatchNorm2d(c_out),
          nn.ReLU(inplace=True),
      ])
    self.double_conv = nn.Sequential(*stages)

  def forward(self, x):
    """Apply both conv/BN/ReLU stages to x."""
    return self.double_conv(x)

In [3]:
class ContractBlock(nn.Module):
  """One step of the contracting path: 2x2 max-pool followed by a DoubleConv.

  The pool halves height and width (MaxPool2d(2)); the DoubleConv then maps
  in_channels -> out_channels at the reduced resolution.

  Args:
    in_channels: channels of the incoming feature map.
    out_channels: channels produced by the DoubleConv.
    mid_channels: optional intermediate width forwarded to DoubleConv.
  """

  def __init__(self, in_channels, out_channels, mid_channels=None):
    super().__init__()
    downsample = nn.MaxPool2d(2)
    convs = DoubleConv(in_channels, out_channels, mid_channels)
    # Pool first, then convolve; indices 0/1 keep state_dict keys unchanged.
    self.contract_path = nn.Sequential(downsample, convs)

  def forward(self, x):
    """Downsample x by 2 spatially and run the DoubleConv on the result."""
    pooled_features = self.contract_path(x)
    return pooled_features

In [4]:
class ExpansiveBlock(nn.Module):
  """One step of the expansive (decoder) path of the U-Net.

  Upsamples the deeper feature map with a stride-2 transposed convolution
  (in_channels -> in_channels // 2). It then concatenates the result with the
  skip-connection tensor along dim=1 (channels) and fuses the two with a
  DoubleConv (in_channels -> out_channels).

  Args:
    in_channels: channels of the feature map to upsample. The skip tensor is
      expected to carry in_channels - in_channels // 2 channels, so that the
      concatenation has exactly in_channels channels for the DoubleConv.
    out_channels: channels produced by the DoubleConv.
    mid_channels: optional intermediate width forwarded to DoubleConv.
  """

  def __init__(self, in_channels, out_channels, mid_channels=None):
    super().__init__()

    self.upsample = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
    self.double_conv = DoubleConv(in_channels, out_channels, mid_channels)

  def forward(self, x_to_upsample, x_residual):
    """Upsample, align to the skip tensor, concatenate, and convolve.

    Args:
      x_to_upsample: batched feature map from the deeper level (channels on
        dim 1, height/width on dims 2/3).
      x_residual: skip-connection tensor from the contracting path, same
        layout.

    Returns:
      Tensor with out_channels channels at x_residual's height and width.
    """
    x_upsampled = self.upsample(x_to_upsample)

    # If an encoder stage saw an odd height/width, MaxPool2d floored it and the
    # 2x upsample comes back one pixel short of the skip tensor, which would
    # make torch.cat fail. Zero-pad (split as evenly as possible) to match.
    diff_h = x_residual.size(2) - x_upsampled.size(2)
    diff_w = x_residual.size(3) - x_upsampled.size(3)
    if diff_h or diff_w:
      x_upsampled = F.pad(
          x_upsampled,
          [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
      )

    x_concat = torch.cat([x_residual, x_upsampled], dim=1)

    return self.double_conv(x_concat)

In [5]:
class Unet(nn.Module):
  """U-Net for dense, per-pixel prediction.

  Architecture:
    * Encoder: an input DoubleConv (-> 64 channels) and four ContractBlocks
      (128, 256, 512, 1024 channels), each halving height and width.
    * Decoder: four ExpansiveBlocks back down to 64 channels, each fed the
      matching encoder output as a skip connection.
    * Head: a 1x1 Conv2d to n_classes channels.

  One shared Dropout2d is applied to the bottleneck input, before every
  decoder stage, and before the output convolution.

  Args:
    in_channels: channels of the input image (e.g. 3 for RGB).
    n_classes: number of output channels (one score map per class).
    dropout_p: drop probability for the shared Dropout2d (default 0.3, the
      value previously hard-coded).
  """

  def __init__(self, in_channels, n_classes, dropout_p=0.3):
    super().__init__()

    # Contracting path (encoder).
    self.input_conv = DoubleConv(in_channels, 64)
    self.contract1 = ContractBlock(64, 128)
    self.contract2 = ContractBlock(128, 256)
    self.contract3 = ContractBlock(256, 512)
    self.contract4 = ContractBlock(512, 1024)  # bottleneck

    # Expansive path (decoder); each block consumes the matching skip tensor.
    self.expansive3 = ExpansiveBlock(1024, 512)
    self.expansive2 = ExpansiveBlock(512, 256)
    self.expansive1 = ExpansiveBlock(256, 128)
    self.expansive0 = ExpansiveBlock(128, 64)

    self.out_conv = nn.Conv2d(64, n_classes, kernel_size=1)

    # Parameter-free, so a single instance is safely reused at every site.
    self.dropout = nn.Dropout2d(dropout_p)

  def forward(self, x):
    """Run the full encoder/decoder and return per-class score maps.

    Args:
      x: batched input image tensor with in_channels channels on dim 1.

    Returns:
      Tensor with n_classes channels; raw scores straight from the 1x1
      output convolution (no softmax/sigmoid applied).
    """
    x0 = self.input_conv(x)
    x1 = self.contract1(x0)
    x2 = self.contract2(x1)
    x3 = self.contract3(x2)

    # Dropout only feeds the next stage; the skip connections receive the
    # un-dropped x0..x3.
    x = self.contract4(self.dropout(x3))

    x = self.expansive3(self.dropout(x), x3)
    x = self.expansive2(self.dropout(x), x2)
    x = self.expansive1(self.dropout(x), x1)
    x = self.expansive0(self.dropout(x), x0)

    # The head's output must be returned; without this forward() yields None.
    return self.out_conv(self.dropout(x))


In [6]:
from torchsummary import summary

In [7]:
# Build a 3-channel-input, 7-class U-Net and print a layer-by-layer summary for
# a 3x512x256 input. torchsummary's `summary()` defaults to device="cuda"; on a
# CUDA machine that would feed CUDA tensors into a CPU model. So pick the device
# explicitly and put the model on the same one.
IN_CHANNELS, N_CLASSES = 3, 7
INPUT_SIZE = (IN_CHANNELS, 512, 256)  # (C, H, W); summary adds the batch dim
device = "cuda" if torch.cuda.is_available() else "cpu"

unet = Unet(IN_CHANNELS, N_CLASSES).to(device)
summary(unet, INPUT_SIZE, device=device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 512, 256]           1,792
       BatchNorm2d-2         [-1, 64, 512, 256]             128
              ReLU-3         [-1, 64, 512, 256]               0
            Conv2d-4         [-1, 64, 512, 256]          36,928
       BatchNorm2d-5         [-1, 64, 512, 256]             128
              ReLU-6         [-1, 64, 512, 256]               0
        DoubleConv-7         [-1, 64, 512, 256]               0
         MaxPool2d-8         [-1, 64, 256, 128]               0
            Conv2d-9        [-1, 128, 256, 128]          73,856
      BatchNorm2d-10        [-1, 128, 256, 128]             256
             ReLU-11        [-1, 128, 256, 128]               0
           Conv2d-12        [-1, 128, 256, 128]         147,584
      BatchNorm2d-13        [-1, 128, 256, 128]             256
             ReLU-14        [-1, 128, 2