In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

`in_sh` (int): number of input channels (e.g. 3 for RGB images)

`num_classes` (int): number of segmentation classes; this is also the number of output channels of the network

In [2]:
class SegNet(nn.Module):
    """SegNet-style encoder-decoder for semantic segmentation.

    Encoder: five stages of (conv 3x3 -> BatchNorm -> ReLU) blocks with
    64/128/256/512/512 channels (2, 2, 3, 3, 3 convs per stage). Each stage ends
    with a 2x2 max-pool whose argmax indices are kept.

    Decoder: mirrors the encoder. Each stage upsamples with ``max_unpool2d``
    using the stored indices, then applies the mirrored conv blocks. The final
    conv (``d_conv11``) maps 64 channels to ``num_classes`` and its output is
    returned as raw per-pixel class scores (no BatchNorm / activation).

    Args:
        in_sh (int): number of input channels (e.g. 3 for RGB).
        num_classes (int): number of segmentation classes (output channels).
    """

    def __init__(self, in_sh, num_classes):
        super(SegNet, self).__init__()

        # ---------------- Encoder ----------------
        self.conv11 = nn.Conv2d(in_sh, 64, kernel_size=3, stride=1, padding=1)
        self.batch11 = nn.BatchNorm2d(64)
        self.conv12 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.batch12 = nn.BatchNorm2d(64)

        self.conv21 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.batch21 = nn.BatchNorm2d(128)
        self.conv22 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.batch22 = nn.BatchNorm2d(128)

        self.conv31 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.batch31 = nn.BatchNorm2d(256)
        self.conv32 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.batch32 = nn.BatchNorm2d(256)
        self.conv33 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.batch33 = nn.BatchNorm2d(256)

        self.conv41 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.batch41 = nn.BatchNorm2d(512)
        self.conv42 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.batch42 = nn.BatchNorm2d(512)
        self.conv43 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.batch43 = nn.BatchNorm2d(512)

        self.conv51 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.batch51 = nn.BatchNorm2d(512)
        self.conv52 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.batch52 = nn.BatchNorm2d(512)
        self.conv53 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.batch53 = nn.BatchNorm2d(512)

        # ---------------- Decoder ----------------
        # Each d_batchXY must match the output channels of d_convXY.
        self.d_conv53 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.d_batch53 = nn.BatchNorm2d(512)
        self.d_conv52 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.d_batch52 = nn.BatchNorm2d(512)
        self.d_conv51 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.d_batch51 = nn.BatchNorm2d(512)

        self.d_conv43 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.d_batch43 = nn.BatchNorm2d(512)
        self.d_conv42 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        # Was mistakenly named d_batch22 (and then overwritten below).
        self.d_batch42 = nn.BatchNorm2d(512)
        self.d_conv41 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)
        self.d_batch41 = nn.BatchNorm2d(256)

        self.d_conv33 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.d_batch33 = nn.BatchNorm2d(256)
        self.d_conv32 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.d_batch32 = nn.BatchNorm2d(256)
        self.d_conv31 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.d_batch31 = nn.BatchNorm2d(128)

        self.d_conv22 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.d_batch22 = nn.BatchNorm2d(128)
        self.d_conv21 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.d_batch21 = nn.BatchNorm2d(64)

        self.d_conv12 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        # d_conv12 outputs 64 channels (was wrongly BatchNorm2d(128)).
        self.d_batch12 = nn.BatchNorm2d(64)
        # Classifier head: raw class scores, no BatchNorm/activation afterwards.
        self.d_conv11 = nn.Conv2d(64, num_classes, kernel_size=3, stride=1, padding=1)

    @staticmethod
    def _conv_bn_relu(conv, bn, x):
        """Apply ``conv`` then ``bn`` then ReLU to ``x``."""
        return F.relu(bn(conv(x)))

    def forward(self, x):
        """Run the encoder-decoder.

        Args:
            x (Tensor): input of shape (N, in_sh, H, W). Five 2x2 poolings are
                applied, so H and W must each be at least 32. They do not need
                to be divisible by 32.

        Returns:
            Tensor: raw class scores of shape (N, num_classes, H, W).
        """
        cbr = self._conv_bn_relu

        # ---------------- Encoder ----------------
        x11 = cbr(self.conv11, self.batch11, x)
        x12 = cbr(self.conv12, self.batch12, x11)
        x1p, id1 = F.max_pool2d(x12, kernel_size=2, stride=2, return_indices=True)

        x21 = cbr(self.conv21, self.batch21, x1p)
        x22 = cbr(self.conv22, self.batch22, x21)
        x2p, id2 = F.max_pool2d(x22, kernel_size=2, stride=2, return_indices=True)

        x31 = cbr(self.conv31, self.batch31, x2p)
        x32 = cbr(self.conv32, self.batch32, x31)
        x33 = cbr(self.conv33, self.batch33, x32)
        x3p, id3 = F.max_pool2d(x33, kernel_size=2, stride=2, return_indices=True)

        x41 = cbr(self.conv41, self.batch41, x3p)
        x42 = cbr(self.conv42, self.batch42, x41)
        x43 = cbr(self.conv43, self.batch43, x42)
        x4p, id4 = F.max_pool2d(x43, kernel_size=2, stride=2, return_indices=True)

        x51 = cbr(self.conv51, self.batch51, x4p)
        x52 = cbr(self.conv52, self.batch52, x51)
        x53 = cbr(self.conv53, self.batch53, x52)
        x5p, id5 = F.max_pool2d(x53, kernel_size=2, stride=2, return_indices=True)

        # ---------------- Decoder ----------------
        # output_size restores each stage's exact pre-pool spatial size; without
        # it, odd intermediate sizes (H or W not divisible by 32) shrink by one
        # pixel per stage and the output no longer matches the input.
        x5d = F.max_unpool2d(x5p, id5, kernel_size=2, stride=2, output_size=x53.size())
        x53d = cbr(self.d_conv53, self.d_batch53, x5d)
        x52d = cbr(self.d_conv52, self.d_batch52, x53d)
        x51d = cbr(self.d_conv51, self.d_batch51, x52d)

        x4d = F.max_unpool2d(x51d, id4, kernel_size=2, stride=2, output_size=x43.size())
        x43d = cbr(self.d_conv43, self.d_batch43, x4d)
        x42d = cbr(self.d_conv42, self.d_batch42, x43d)
        x41d = cbr(self.d_conv41, self.d_batch41, x42d)

        x3d = F.max_unpool2d(x41d, id3, kernel_size=2, stride=2, output_size=x33.size())
        x33d = cbr(self.d_conv33, self.d_batch33, x3d)
        x32d = cbr(self.d_conv32, self.d_batch32, x33d)
        x31d = cbr(self.d_conv31, self.d_batch31, x32d)

        x2d = F.max_unpool2d(x31d, id2, kernel_size=2, stride=2, output_size=x22.size())
        x22d = cbr(self.d_conv22, self.d_batch22, x2d)
        x21d = cbr(self.d_conv21, self.d_batch21, x22d)

        x1d = F.max_unpool2d(x21d, id1, kernel_size=2, stride=2, output_size=x12.size())
        x12d = cbr(self.d_conv12, self.d_batch12, x1d)

        # Final projection to num_classes channels (previously skipped entirely).
        return self.d_conv11(x12d)






