In [45]:
import torch
import torch.nn as nn
import torch.nn.functional as F
!pip install torchinfo
import torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [44]:
def double_conv(in_channels, out_channels):
    """Build an encoder stage: BatchNorm, then two 3x3 convolution + ReLU pairs.

    Args:
        in_channels: Channel count of the incoming feature map; also the
            number of features normalised by the leading BatchNorm2d.
        out_channels: Channel count produced by both convolutions.

    Returns:
        An ``nn.Sequential`` holding, in order, BatchNorm2d, Conv2d, ReLU,
        Conv2d, ReLU. No padding is requested, so each 3x3 convolution trims
        2 pixels from the height and width (4 in total for the stage).
    """
    stages = [nn.BatchNorm2d(in_channels)]
    # The first convolution changes the channel count; the second keeps it.
    for conv_in in (in_channels, out_channels):
        stages.append(nn.Conv2d(conv_in, out_channels, kernel_size=3, stride=1))
        stages.append(nn.ReLU())
    return nn.Sequential(*stages)
def double_trans_conv(in_channels, out_channels):
    """Build a decoder stage: two 3x3 transposed convolutions, each followed by ReLU.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by both transposed convolutions.

    Returns:
        An ``nn.Sequential`` holding, in order, ConvTranspose2d, ReLU,
        ConvTranspose2d, ReLU. With stride 1 and no padding, each transposed
        convolution grows the height and width by 2 pixels (4 in total).
    """
    widen = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=1)
    refine = nn.ConvTranspose2d(out_channels, out_channels, kernel_size=3, stride=1)
    return nn.Sequential(widen, nn.ReLU(), refine, nn.ReLU())

class Network(nn.Module):
    """Encoder-decoder network with max-unpooling and concatenated skip connections.

    The encoder stacks six ``double_conv`` stages with four index-returning
    max-pools, a chain of four 1x1 convolutions forms the bottleneck
    (512 -> 1024 -> 2048 -> 1024 -> 512 channels), and the decoder mirrors the
    encoder with ``MaxUnpool2d`` (reusing the encoder's pooling indices) and
    six ``double_trans_conv`` stages. Before every unpooling step the decoder
    feature map is batch-normalised, and after it the matching encoder
    feature map is concatenated along the channel axis.

    The BatchNorm layers used in the decoder are registered as submodules
    (``unpool_bn1`` .. ``unpool_bn4``) so that their affine parameters are
    trained, their running statistics persist, and they follow
    ``.train()`` / ``.eval()`` / ``.to(device)`` together with the rest of
    the model. Note that this adds the corresponding ``unpool_bn*`` entries to
    ``state_dict()``; checkpoints saved without them need
    ``load_state_dict(..., strict=False)``.

    Args:
        in_channels: Number of channels of the input tensor (default 3).
        out_channels: Number of channels of the output probability map
            (default 5).
    """

    def __init__(self, in_channels=3, out_channels=5):
        super().__init__()
        # return_indices=True so the decoder can place values back where the
        # encoder's maxima were found.
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        self.max_unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)

        # Encoder stages.
        self.conv1 = double_conv(in_channels, 16)
        self.conv2 = double_conv(16, 32)
        self.conv3 = double_conv(32, 64)
        self.conv4 = double_conv(64, 128)
        self.conv5 = double_conv(128, 256)
        self.conv6 = double_conv(256, 512)

        # Bottleneck: channel expansion and contraction with 1x1 convolutions.
        self.conv_1x1_1 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1, stride=1)
        self.conv_1x1_2 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=1, stride=1)
        self.conv_1x1_3 = nn.Conv2d(in_channels=2048, out_channels=1024, kernel_size=1, stride=1)
        self.conv_1x1_4 = nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1, stride=1)

        # Final 1x1 projection to the requested number of output channels.
        self.conv_1x1_out = nn.Conv2d(in_channels=16, out_channels=out_channels, kernel_size=1, stride=1)

        # Decoder stages. Input widths of stages 1, 2, 4 and 5 are doubled by
        # the skip-connection concatenation that precedes them.
        self.trans_conv1 = double_trans_conv(1024, 256)
        self.trans_conv2 = double_trans_conv(512, 256)
        self.trans_conv3 = double_trans_conv(256, 64)
        self.trans_conv4 = double_trans_conv(128, 32)
        self.trans_conv5 = double_trans_conv(64, 32)
        self.trans_conv6 = double_trans_conv(32, 16)

        # BatchNorm applied right before each unpooling step; the numeric
        # suffix matches the pooling indices (i1..i4) used by that step.
        # These were previously created inside forward(), which left them
        # untrained, unsaved and always in training mode.
        self.unpool_bn4 = nn.BatchNorm2d(512)
        self.unpool_bn3 = nn.BatchNorm2d(256)
        self.unpool_bn2 = nn.BatchNorm2d(64)
        self.unpool_bn1 = nn.BatchNorm2d(32)

    def forward(self, inp):
        """Run the encoder, bottleneck and decoder and return channel-wise probabilities.

        Args:
            inp: Input batch with ``in_channels`` channels on dim 1. The
                notebook exercises it with a ``(1, 3, 256, 256)`` tensor.

        Returns:
            Tensor with ``out_channels`` channels on dim 1, normalised with a
            softmax over that dim. For the 256x256 input used in the notebook
            the spatial size of the output equals that of the input.
        """
        # ---- Encoder -------------------------------------------------------
        c1 = self.conv1(inp)
        c2 = self.conv2(c1)
        p1, i1 = self.max_pool(c2)
        c3 = self.conv3(p1)
        p2, i2 = self.max_pool(c3)
        c4 = self.conv4(p2)
        c5 = self.conv5(c4)
        p3, i3 = self.max_pool(c5)
        c6 = self.conv6(p3)
        p4, i4 = self.max_pool(c6)

        # ---- Bottleneck ----------------------------------------------------
        e1 = self.conv_1x1_1(p4)
        e2 = self.conv_1x1_2(e1)
        e3 = self.conv_1x1_3(e2)
        q4 = self.conv_1x1_4(e3)

        # ---- Decoder -------------------------------------------------------
        # output_size is passed explicitly so the unpooled map matches the
        # skip tensor exactly, even when the pooled dimension was odd.
        d6 = self.max_unpool(self.unpool_bn4(q4), i4, output_size=c6.size())
        d6 = torch.cat((c6, d6), dim=1)
        q3 = self.trans_conv1(d6)
        d5 = self.max_unpool(self.unpool_bn3(q3), i3, output_size=c5.size())
        d5 = torch.cat((c5, d5), dim=1)
        d4 = self.trans_conv2(d5)
        q2 = self.trans_conv3(d4)
        d3 = self.max_unpool(self.unpool_bn2(q2), i2, output_size=c3.size())
        d3 = torch.cat((c3, d3), dim=1)
        q1 = self.trans_conv4(d3)
        d2 = self.max_unpool(self.unpool_bn1(q1), i1, output_size=c2.size())
        d2 = torch.cat((c2, d2), dim=1)
        d1 = self.trans_conv5(d2)
        d0 = self.trans_conv6(d1)

        logits = self.conv_1x1_out(d0)
        return F.softmax(logits, dim=1)
        



# Smoke test: push one random 3-channel 256x256 batch through the network and
# check that the output has 5 channels at the same spatial size.
torch.manual_seed(0)  # fixed seed so the random input and weight init are reproducible
inp = torch.randn(1, 3, 256, 256)
model = Network()
out = model(inp)
assert out.shape == (1, 5, 256, 256), f"unexpected output shape: {tuple(out.shape)}"
out.shape

torch.Size([1, 5, 256, 256])

In [46]:
# Per-layer output shapes and parameter counts for a single 3-channel 256x256 input.
torchinfo.summary(model, input_size=(1, 3, 256, 256))

Layer (type:depth-idx)                   Output Shape              Param #
Network                                  [1, 5, 256, 256]          --
├─Sequential: 1-1                        [1, 16, 252, 252]         --
│    └─BatchNorm2d: 2-1                  [1, 3, 256, 256]          6
│    └─Conv2d: 2-2                       [1, 16, 254, 254]         448
│    └─ReLU: 2-3                         [1, 16, 254, 254]         --
│    └─Conv2d: 2-4                       [1, 16, 252, 252]         2,320
│    └─ReLU: 2-5                         [1, 16, 252, 252]         --
├─Sequential: 1-2                        [1, 32, 248, 248]         --
│    └─BatchNorm2d: 2-6                  [1, 16, 252, 252]         32
│    └─Conv2d: 2-7                       [1, 32, 250, 250]         4,640
│    └─ReLU: 2-8                         [1, 32, 250, 250]         --
│    └─Conv2d: 2-9                       [1, 32, 248, 248]         9,248
│    └─ReLU: 2-10                        [1, 32, 248, 248]         --
├─MaxP