<h1>U-Net model<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Model-classes-creation" data-toc-modified-id="Model-classes-creation-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Model classes creation</a></span></li><li><span><a href="#Model-test" data-toc-modified-id="Model-test-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Model test</a></span><ul class="toc-item"><li><span><a href="#Base-model-(in_channels=3,-num_classes=1)" data-toc-modified-id="Base-model-(in_channels=3,-num_classes=1)-2.1"><span class="toc-item-num">2.1&nbsp;&nbsp;</span>Base model (in_channels=3, num_classes=1)</a></span></li><li><span><a href="#Model--with-(in_channels=1,-num_classes=5)" data-toc-modified-id="Model--with-(in_channels=1,-num_classes=5)-2.2"><span class="toc-item-num">2.2&nbsp;&nbsp;</span>Model  with (in_channels=1, num_classes=5)</a></span></li></ul></li></ul></div>

In [1]:
import torch
import torch.nn as nn

### Model classes creation

In [51]:
class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by an in-place ReLU.

    Both convolutions use ``padding=1`` with stride 1, so the output keeps the
    input's height/width and has ``out_channels`` channels.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        kernel = (3, 3)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=kernel, padding=1)
        # ReLU holds no parameters, so one module instance serves both activation sites.
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        # conv1 -> ReLU -> conv2 -> ReLU
        return self.act(self.conv2(self.act(self.conv1(x))))

In [52]:
class DownSample(nn.Module):
    """Encoder stage: ``DoubleConv`` followed by 2x2 max pooling.

    ``forward`` returns the pair ``(features, pooled)``. ``features`` is the
    ``DoubleConv`` output before pooling, and the ``UNet`` in this notebook uses it
    as a skip connection. ``pooled`` is that same tensor after ``MaxPool2d(2)``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = DoubleConv(in_channels, out_channels)
        self.down = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        features = self.double_conv(x)
        return features, self.down(features)

In [53]:
class UpSample(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concatenation, ``DoubleConv``.

    The transposed convolution maps ``in_channels`` -> ``out_channels``. Its output
    is concatenated with the skip tensor along dim 1 and passed to
    ``DoubleConv(in_channels, out_channels)``. The channel counts only line up when
    the skip tensor carries ``in_channels - out_channels`` channels, i.e.
    ``out_channels`` in the usual U-Net layout where ``in_channels == 2 * out_channels``.

    Args:
        in_channels: channels of the deeper feature map ``x1``.
        out_channels: channels produced by this stage.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, (2, 2), stride=2)
        self.double_conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, concatenate it with skip tensor ``x2`` and refine.

        Args:
            x1: deeper feature map; presumably (N, in_channels, H, W) — TODO confirm.
            x2: skip-connection feature map. Its last two dims set the output's
                spatial size.

        Returns:
            ``double_conv(cat([up(x1) zero-padded to x2's H/W, x2], dim=1))``.
        """
        x1 = self.up(x1)
        # MaxPool2d(2) floors odd sizes, so 2x upsampling can come back a pixel
        # short of the skip tensor (e.g. 125 -> 62 -> 124). Zero-pad the gap,
        # split across both sides, so torch.cat does not fail on such inputs.
        diff_h = x2.size(-2) - x1.size(-2)
        diff_w = x2.size(-1) - x1.size(-1)
        if diff_h or diff_w:
            # F.pad pads the last dim first: (left, right, top, bottom).
            x1 = nn.functional.pad(
                x1,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )
        x = torch.cat([x1, x2], dim=1)
        return self.double_conv(x)

In [54]:
class UNet(nn.Module):
    """U-Net with 4 encoder stages, a bottleneck and 4 decoder stages.

    Encoder widths are ``base_channels * (1, 2, 4, 8)`` and the bottleneck is
    ``base_channels * 16``. The default ``base_channels=64`` gives the original
    64/128/256/512/1024 layout. A final 1x1 convolution maps to ``num_classes``
    channels, and no activation is applied to its output.

    Args:
        in_channels: number of channels in the input tensor.
        num_classes: number of output channels.
        base_channels: width of the first encoder stage; each deeper stage doubles it.
    """

    def __init__(self, in_channels=3, num_classes=1, base_channels=64):
        super().__init__()
        c = base_channels
        self.down1 = DownSample(in_channels, c)
        self.down2 = DownSample(c, 2 * c)
        self.down3 = DownSample(2 * c, 4 * c)
        self.down4 = DownSample(4 * c, 8 * c)

        self.bottleneck = DoubleConv(8 * c, 16 * c)

        # Each UpSample halves the channels; its concat with the matching skip
        # (same width) restores the input width expected by its DoubleConv.
        self.up1 = UpSample(16 * c, 8 * c)
        self.up2 = UpSample(8 * c, 4 * c)
        self.up3 = UpSample(4 * c, 2 * c)
        self.up4 = UpSample(2 * c, c)

        self.out = nn.Conv2d(c, num_classes, (1, 1))

    def forward(self, x):
        """Run the encoder-decoder and return un-activated per-pixel scores.

        Args:
            x: NCHW batch with ``in_channels`` channels. Spatial sizes must let
                each decoder stage concatenate with its skip tensor; multiples of
                16 always do (four 2x poolings).

        Returns:
            Tensor with ``num_classes`` channels at the spatial size of the
            first encoder stage's skip output.
        """
        # Encoder: keep each stage's pre-pooling features for the skip connections.
        sk1, x = self.down1(x)
        sk2, x = self.down2(x)
        sk3, x = self.down3(x)
        sk4, x = self.down4(x)

        x = self.bottleneck(x)

        # Decoder: consume the skips deepest-first.
        x = self.up1(x, sk4)
        x = self.up2(x, sk3)
        x = self.up3(x, sk2)
        x = self.up4(x, sk1)

        out = self.out(x)

        return out

In [37]:
# Base configuration, spelled out: RGB input, a single output channel.
model = UNet(in_channels=3, num_classes=1)

In [38]:
# Display the module tree: every layer with its channel counts, kernel sizes and padding.
model

UNet(
  (down1): DownSample(
    (double_conv): DoubleConv(
      (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): ReLU(inplace=True)
    )
    (down): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (down2): DownSample(
    (double_conv): DoubleConv(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): ReLU(inplace=True)
    )
    (down): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (down3): DownSample(
    (double_conv): DoubleConv(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): ReLU(inplace=True)
    )
    (down): MaxPool2d(kernel_size=2, stride=

### Model test

#### Base model (in_channels=3, num_classes=1)

In [61]:
# Seed the RNG so the random test input is the same on every Restart & Run All.
torch.manual_seed(0)
# One 3-channel 512x512 image (NCHW) matching the base model's in_channels=3.
inp = torch.rand([1, 3, 512, 512])

In [62]:
# Inference-only check: no_grad skips recording the autograd graph (and the
# intermediate activations it would keep alive) for this large forward pass.
with torch.no_grad():
    pred = model(inp)
# Base model: one output channel at the input's batch size and resolution.
assert pred.shape == (inp.shape[0], 1, *inp.shape[-2:]), pred.shape
pred.shape

torch.Size([1, 1, 512, 512])

#### Model  with (in_channels=1, num_classes=5)

In [63]:
# Variant: single-channel input, five output channels.
model_ = UNet(num_classes=5, in_channels=1)

In [64]:
# Display the variant's module tree; only the first conv (1 input channel) and the final 1x1 conv differ from the base model.
model_

UNet(
  (down1): DownSample(
    (double_conv): DoubleConv(
      (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): ReLU(inplace=True)
    )
    (down): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (down2): DownSample(
    (double_conv): DoubleConv(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): ReLU(inplace=True)
    )
    (down): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (down3): DownSample(
    (double_conv): DoubleConv(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): ReLU(inplace=True)
    )
    (down): MaxPool2d(kernel_size=2, stride=

In [65]:
# Seeded single-channel 512x512 test input for the 1-in / 5-out variant.
torch.manual_seed(0)
inp = torch.rand([1, 1, 512, 512])
# Inference only: skip autograd bookkeeping for this forward pass.
with torch.no_grad():
    pred = model_(inp)
# Variant model: five output channels at the input's batch size and resolution.
assert pred.shape == (inp.shape[0], 5, *inp.shape[-2:]), pred.shape
pred.shape

torch.Size([1, 5, 512, 512])