In [240]:
import torch
# from torchsummary import summary
import torch.nn as nn


In [241]:
# e.g 512,512

class Downsample(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels=32,
                #  internal_ratio=4,
                 kernel_size=3,
                 padding=0,
                 dropout_prob=0.,
                 stride = 2,
                 bias=False,
                 relu=True,
                 out_channels_div= [11,9,9] 
            ):
        super(Downsample,self).__init__()


        if relu:
            activation = nn.ReLU()
        else:
            activation = nn.PReLU()

        self.conv2d_d1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels_div[0],kernel_size=kernel_size, stride=stride, padding=padding, bias=bias, dilation=1)
        self.conv2d_d2 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels_div[1],kernel_size=kernel_size, stride=stride, padding=1, bias=bias, dilation=2)
        self.conv2d_d5 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels_div[2],kernel_size=kernel_size, stride=stride, padding=4, bias=bias, dilation=5)
        self.ext_branch = nn.MaxPool2d(kernel_size, stride=2, padding=padding)
        # self.half = nn.MaxPool2d(2, stride=2)
        self.activation = activation
        self.batch_norm = nn.BatchNorm2d(out_channels)

    def forward(self, input):
        print('input:', input.shape)
        main1_d1= self.conv2d_d1(input)
        print('dialation 1:',main1_d1.shape)
        main2_d2 = self.conv2d_d2(input)
        print('dialation 2', main2_d2.shape)
        main3_d5 = self.conv2d_d5(input)
        print('dialation 5:',main3_d5.shape)
        # ext1=self.half(input)
        ext1 = self.ext_branch(input)
        print(ext1.shape)
        
        
        
        out = torch.cat((main1_d1,main2_d2,main3_d5,ext1), dim=1)
        print('out after concate', out.shape)
        out = self.batch_norm(out)
        out = self.activation(out)
        # print('final', out)


        return out



In [242]:
model = Downsample(in_channels=3)

In [243]:
x = torch.randn(1,3, 360, 640)
xx = model(x)

input: torch.Size([1, 3, 360, 640])
dialation 1: torch.Size([1, 11, 179, 319])
dialation 2 torch.Size([1, 9, 179, 319])
dialation 5: torch.Size([1, 9, 179, 319])
torch.Size([1, 3, 179, 319])
out after concate torch.Size([1, 32, 179, 319])


In [244]:
print(model)

Downsample(
  (conv2d_d1): Conv2d(3, 11, kernel_size=(3, 3), stride=(2, 2), bias=False)
  (conv2d_d2): Conv2d(3, 9, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), dilation=(2, 2), bias=False)
  (conv2d_d5): Conv2d(3, 9, kernel_size=(3, 3), stride=(2, 2), padding=(4, 4), dilation=(5, 5), bias=False)
  (ext_branch): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (activation): ReLU()
  (batch_norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)


In [285]:
# e.g 512,512

class Regular(nn.Module):
    def __init__(self,
                 in_channels=32,
                 out_channels=32,
                #  internal_ratio=4,
                 kernel_size=3,
                 padding=0,
                 dropout_prob=0.,
                 stride = 1,
                 bias=False,
                 relu=True, 
                 out_channels_div = [11,11,10]
                 
                 ):
        super(Regular,self).__init__()


        if relu:
            activation = nn.ReLU()
        else:
            activation = nn.PReLU()

        self.conv2d_d1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels_div[0],kernel_size=kernel_size, stride=stride, padding=1, bias=bias, dilation=1)
        self.conv2d_d2 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels_div[1],kernel_size=kernel_size, stride=stride, padding=2, bias=bias, dilation=2)
        self.conv2d_d5 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels_div[2],kernel_size=kernel_size, stride=stride, padding=5, bias=bias, dilation=5)
        # self.half = nn.MaxPool2d(2, stride=2)
        self.activation = activation
        self.batch_norm = nn.BatchNorm2d(out_channels)

    def forward(self, input):
        # print('input:', input.shape)
        main1_d1= self.conv2d_d1(input)
        print('dialation 1:',main1_d1.shape)
        main2_d2 = self.conv2d_d2(input)
        print('dialation 2', main2_d2.shape)
        main3_d5 = self.conv2d_d5(input)
        print('dialation 5:',main3_d5.shape)
        # ext1=self.half(input)
        
        
        
        out = torch.cat((main1_d1,main2_d2,main3_d5), dim=1)
        print('out after concate', out.shape)
        out = self.batch_norm(out)
        out = self.activation(out)
        # print('final', out)


        return out



In [246]:
model2 = Regular(3)

In [247]:
model2(x)

dialation 1: torch.Size([1, 11, 360, 640])
dialation 2 torch.Size([1, 11, 360, 640])
dialation 5: torch.Size([1, 10, 360, 640])
out after concate torch.Size([1, 32, 360, 640])


tensor([[[[0.0000, 0.3802, 0.7758,  ..., 0.0000, 0.0000, 0.8525],
          [0.0620, 0.9355, 1.2865,  ..., 0.0000, 0.0000, 1.1768],
          [0.0000, 1.3972, 1.0617,  ..., 1.7849, 0.0000, 0.7831],
          ...,
          [0.0000, 1.2982, 0.0000,  ..., 2.6496, 0.0000, 0.0000],
          [0.0000, 0.0391, 0.0000,  ..., 0.3552, 0.0000, 0.1680],
          [0.8402, 0.0000, 0.0000,  ..., 0.0750, 0.0000, 0.0000]],

         [[0.3244, 0.5003, 0.2311,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.4729, 0.0000, 0.7840],
          [1.2017, 0.0000, 0.3602,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [1.5101, 0.6026, 0.0000,  ..., 0.0000, 0.0000, 0.2014],
          [0.0000, 0.0000, 1.3153,  ..., 0.6622, 0.2500, 0.1958],
          [0.8610, 0.0906, 0.0000,  ..., 0.7820, 0.0000, 0.0000]],

         [[0.1862, 0.1276, 0.7621,  ..., 0.0000, 0.0352, 0.0000],
          [0.2457, 0.9145, 0.0000,  ..., 0.0000, 0.4412, 0.7753],
          [0.0000, 0.0000, 0.0000,  ..., 0

In [248]:
print(model2)

Regular(
  (conv2d_d1): Conv2d(3, 11, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv2d_d2): Conv2d(3, 11, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
  (conv2d_d5): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1), padding=(5, 5), dilation=(5, 5), bias=False)
  (activation): ReLU()
  (batch_norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)


In [249]:
class Encoder(nn.Module):
    def __init__(self, classes, encoder_relu=False, decoder_relu=True):
        super().__init__()
        self.downsample_1 = Downsample(in_channels=3,out_channels=32, kernel_size=3, relu=True, out_channels_div=[11,9,9])
        self.regular_1 = Regular(in_channels=32, out_channels=32, kernel_size=3, out_channels_div=[11,11,10])
        self.downsample_2 = Downsample(in_channels=32, out_channels=64, kernel_size=3,relu=True, out_channels_div=[11,11,10])
        self.regular_2 = Regular(in_channels=64, out_channels=64,out_channels_div=[22,21,21])
        self.downsample_3 = Downsample(in_channels=64, out_channels=128, out_channels_div=[22,21,21])
        self.regular_3 = Regular(in_channels=128, out_channels=128, out_channels_div=[43,43,42])
        self.regular_4 = Regular(in_channels=128, out_channels=128, out_channels_div=[43,43,42] )

    


    def forward(self, input):
        x = self.downsample_1(input)
        x = self.regular_1(x)
        x = self.downsample_2(x)
        x = self.regular_2(x)
        x = self.downsample_3(x)
        x= self.regular_3(x)
        x= self.regular_4(x)

        return x

        


In [250]:
encoder = Encoder(3)

In [251]:
out = encoder(x)
print(out)

input: torch.Size([1, 3, 360, 640])
dialation 1: torch.Size([1, 11, 179, 319])
dialation 2 torch.Size([1, 9, 179, 319])
dialation 5: torch.Size([1, 9, 179, 319])
torch.Size([1, 3, 179, 319])
out after concate torch.Size([1, 32, 179, 319])
dialation 1: torch.Size([1, 11, 179, 319])
dialation 2 torch.Size([1, 11, 179, 319])
dialation 5: torch.Size([1, 10, 179, 319])
out after concate torch.Size([1, 32, 179, 319])
input: torch.Size([1, 32, 179, 319])
dialation 1: torch.Size([1, 11, 89, 159])
dialation 2 torch.Size([1, 11, 89, 159])
dialation 5: torch.Size([1, 10, 89, 159])
torch.Size([1, 32, 89, 159])
out after concate torch.Size([1, 64, 89, 159])
dialation 1: torch.Size([1, 22, 89, 159])
dialation 2 torch.Size([1, 21, 89, 159])
dialation 5: torch.Size([1, 21, 89, 159])
out after concate torch.Size([1, 64, 89, 159])
input: torch.Size([1, 64, 89, 159])
dialation 1: torch.Size([1, 22, 44, 79])
dialation 2 torch.Size([1, 21, 44, 79])
dialation 5: torch.Size([1, 21, 44, 79])
torch.Size([1, 64

In [252]:
print(encoder)

Encoder(
  (downsample_1): Downsample(
    (conv2d_d1): Conv2d(3, 11, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (conv2d_d2): Conv2d(3, 9, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), dilation=(2, 2), bias=False)
    (conv2d_d5): Conv2d(3, 9, kernel_size=(3, 3), stride=(2, 2), padding=(4, 4), dilation=(5, 5), bias=False)
    (ext_branch): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (activation): ReLU()
    (batch_norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (regular_1): Regular(
    (conv2d_d1): Conv2d(32, 11, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (conv2d_d2): Conv2d(32, 11, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
    (conv2d_d5): Conv2d(32, 10, kernel_size=(3, 3), stride=(1, 1), padding=(5, 5), dilation=(5, 5), bias=False)
    (activation): ReLU()
    (batch_norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, tra

In [300]:
class Upsample(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size = 3):
        super().__init__()
        
        self.conv = nn.ConvTranspose2d(in_channels= in_channels, out_channels=out_channels, stride=2, kernel_size=kernel_size, padding=0, output_padding=0, bias=True)
        # self.bn = nn.BatchNorm2d(noutput, eps=1e-3)     # ERFNet batchnorm with eps = e-5
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.regular_de_1=Regular(in_channels=64, out_channels=64, out_channels_div=[22,21,21])
    def forward(self, input):
        print('Input Shape: ', input.shape)
        output = self.conv(input)
        output = self.batch_norm(output)
        print('Output Shape: ', output.shape)
        output = nn.ReLU(output)
        
        output = self.regular_de_1(output)
        output = nn.ReLU(output)
        return output

In [301]:
test = Upsample(in_channels=128, out_channels=64) 

In [302]:
t= torch.rand(1,128,44,79)

In [303]:
y = test(t)

Input Shape:  torch.Size([1, 128, 44, 79])
Output Shape:  torch.Size([1, 64, 89, 159])


TypeError: conv2d() received an invalid combination of arguments - got (ReLU, Parameter, NoneType, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (!ReLU!, !Parameter!, !NoneType!, !tuple!, !tuple!, !tuple!, int)
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (!ReLU!, !Parameter!, !NoneType!, !tuple!, !tuple!, !tuple!, int)


In [299]:
class Decoder(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.upsample_1 = Upsample(in_channels=128, out_channels=64)
        self.regular_de_1 = Regular(in_channels=64 ,out_channels=64,kernel_size=3, padding=0, out_channels_div=[22,21,21], relu=True)
        
        self.upsample_2 = Upsample(in_channels=64, out_channels=32)
        self.regular_de_2 = Regular(in_channels=32, out_channels=32, kernel_size=3 ,out_channels_div=[12,11,11], padding=0,relu=True)
        self.upsample_3 = Regular(in_channels=32,out_channels=num_classes)

    def forward(self, input):
        x = self.upsample_1(input)
        x = self.regular_de_1(x)
        x = self.upsample_2(x)
        x = self.regular_de_2(x)
        x= self.upsample_3(x)

    
        return x


In [294]:
mod = Decoder(num_classes=3)

In [295]:
y = mod(t)

Input Shape:  torch.Size([1, 128, 44, 79])
Output Shape:  torch.Size([1, 64, 89, 159])


TypeError: conv2d() received an invalid combination of arguments - got (ReLU, Parameter, NoneType, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (!ReLU!, !Parameter!, !NoneType!, !tuple!, !tuple!, !tuple!, int)
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (!ReLU!, !Parameter!, !NoneType!, !tuple!, !tuple!, !tuple!, int)
