In [225]:
import torch
import torch.nn as nn
import torch.onnx

from torchsummary import summary

- The authors disable bias for all convolutional layers
- PyTorch BatchNorm has fewer parameters
- explain depthwise separable convolution in thesis (mobilenet)
- A little bit more parameters because of SeparableConv2D (less total but more trainable)

In [226]:
class SeparableConv2D(nn.Module):
    def __init__(self, input_filters, output_filters, kernel_size=3):
        super(SeparableConv2D, self).__init__()
        self.depthwise = nn.Conv2d(input_filters, input_filters, kernel_size=kernel_size, padding=1, groups=input_filters, bias=False)
        self.pointwise = nn.Conv2d(input_filters, output_filters, kernel_size=1, bias=False)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out

In [227]:
def cba(input_filters=3, output_filters=32, kernel_size=(3,3), stride=(1,1), padding=1):
    return nn.Sequential(
        nn.Conv2d(input_filters, output_filters, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
        nn.BatchNorm2d(output_filters, momentum=0.99, eps=0.001),
        nn.ReLU()
    )

def acb(input_filters=3, output_filters=32, kernel_size=(3,3), stride=(1,1), padding=1):
    return nn.Sequential(
        nn.ReLU(),
        nn.Conv2d(input_filters, output_filters, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
        nn.BatchNorm2d(output_filters, momentum=0.99, eps=0.001)
    )

def sep_acb(input_filters=3, output_filters=32, kernel_size=(3,3), stride=(1,1), padding=1):
    return nn.Sequential(
        nn.ReLU(),
        SeparableConv2D(input_filters=384, output_filters=576, kernel_size=(3,3)),
        nn.BatchNorm2d(output_filters, momentum=0.99, eps=0.001)
    )

def cb(input_filters=3, output_filters=32, kernel_size=(3,3), stride=(1,1), padding=1):
    return nn.Sequential(
        nn.Conv2d(input_filters, output_filters, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
        nn.BatchNorm2d(output_filters, momentum=0.99, eps=0.001)
    )

In [228]:
class Stem(nn.Module):

    def __init__(self):
        super(Stem, self).__init__()
        
        self.cba1 = cba(input_filters=3, output_filters=32, kernel_size=(3,3), stride=(2,2))
        self.cba2 = cba(input_filters=32, output_filters=32, kernel_size=(3,3), stride=(1,1))
        self.cba3 = cba(input_filters=32, output_filters=64, kernel_size=(3,3), stride=(1,1))
        self.cba4 = cba(input_filters=64, output_filters=96, kernel_size=(3,3), stride=(2,2))
        
        self.maxpool1 = nn.MaxPool2d(kernel_size=(3,3), stride=(2,2), padding=1)
        
        self.cba5 = cba(input_filters=160, output_filters=64, kernel_size=(1,1), stride=(1,1), padding=0)
        
        self.cb1 = cb(input_filters=64, output_filters=96, kernel_size=(3,3), stride=(1,1))

        self.cba6 = cba(input_filters=160, output_filters=64, kernel_size=(1,1), stride=(1,1), padding=0)
        self.cba7 = cba(input_filters=64, output_filters=64, kernel_size=(5,1), stride=(1,1), padding=1)
        self.cba8 = cba(input_filters=64, output_filters=64, kernel_size=(1,5), stride=(1,1), padding=1)
        self.cb2 = cb(input_filters=64, output_filters=96, kernel_size=(3,3), stride=(1,1))
        
        self.acb1 = acb(input_filters=192, output_filters=192, kernel_size=(3,3), stride=(2,2))
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2,2), stride=(2,2), padding=0)
        
        self.acb2 = acb(input_filters=384, output_filters=576, kernel_size=(1,1), stride=(1,1), padding=0)
        self.sep_acb1 = sep_acb(input_filters=384, output_filters=576, kernel_size=(3,3), stride=(1,1), padding=1)

        
    def forward(self, x):
        out = self.cba1(x)
        out = self.cba2(out)
        out = self.cba3(out)
        
        a = self.cba4(out) #96, 61, 61
        b = self.maxpool1(out) #64,61,61
        out = torch.cat((a,b), 1)
        print("after first concat", out.shape)
        
        a = self.cba5(out)        
        a = self.cb1(a)
        
        b = self.cba6(out)
        b = self.cba7(b)
        b = self.cba8(b)
        b = self.cb2(b)

        out = torch.cat((a,b), 1)
        print("after second concat", out.shape)
        
        a = self.acb1(out)
        b = self.maxpool2(out)

        out = torch.cat((a,b), 1)
        print("after third concat", out.shape)
        
        # sep_conv_residual
        
        a = self.acb2(out)
        b = self.sep_acb1(out)
        
        #print("a", a.shape)
        #print("b", b.shape)
        out = a + b

        return out

In [229]:
summary(Stem(), input_size=(3, 256, 256))


after first concat torch.Size([2, 160, 64, 64])
after second concat torch.Size([2, 192, 64, 64])
after third concat torch.Size([2, 384, 32, 32])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 128, 128]             864
       BatchNorm2d-2         [-1, 32, 128, 128]              64
              ReLU-3         [-1, 32, 128, 128]               0
            Conv2d-4         [-1, 32, 128, 128]           9,216
       BatchNorm2d-5         [-1, 32, 128, 128]              64
              ReLU-6         [-1, 32, 128, 128]               0
            Conv2d-7         [-1, 64, 128, 128]          18,432
       BatchNorm2d-8         [-1, 64, 128, 128]             128
              ReLU-9         [-1, 64, 128, 128]               0
           Conv2d-10           [-1, 96, 64, 64]          55,296
      BatchNorm2d-11           [-1, 96, 64, 64]             192
             ReLU-12  

In [None]:
torch.onnx.export(Stem(), (3,256,256), "onnx_model_name.onnx")

In [208]:
summary(, input_size=(384, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 384, 32, 32]           3,456
            Conv2d-2          [-1, 576, 32, 32]         221,184
Total params: 224,640
Trainable params: 224,640
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.50
Forward/backward pass size (MB): 7.50
Params size (MB): 0.86
Estimated Total Size (MB): 9.86
----------------------------------------------------------------
