In [113]:
# import necessary libraries
import torch
import torch.nn as nn

In [114]:
# create CNNBlock to simplify implementation
# unlike MobileNet_v1 we use ReLU6
class CNNBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU6 building block.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the convolution.
        kernel_size: size of the convolution kernel.
        stride: stride of the convolution.
        padding: zero-padding added to both sides of the input.
            Defaults to 0 (no padding); pass ``kernel_size // 2`` to keep
            the spatial size of an odd-sized kernel at stride 1.

    forward() returns the activated feature map; because of ReLU6 every
    output value lies in the range [0, 6].
    """

    def __init__(self,in_channels,out_channels,kernel_size,stride,padding=0):
        super(CNNBlock,self).__init__()
        self.conv = nn.Conv2d(in_channels,out_channels,kernel_size=kernel_size,
                              stride=stride,padding=padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU6()

    def forward(self,x):
        # convolution, then normalisation, then the clipped activation
        return self.relu(self.bn(self.conv(x)))

In [115]:
# Pointwise (1x1) conv - expansion layer:
# it maps the input tensor into a higher-dimensional space (the "manifold of interest").
# Then comes the depthwise 3x3 conv.
# Finally, another 1x1 conv projects
# the high-dimensional manifold of interest back into
# a lower-dimensional subspace without loss of information.

# Residuals improve the ability of a gradient to propagate
# across multiplier layers

class BottleNeck(nn.Module):
    """Inverted-residual block: (1x1 expansion) -> 3x3 depthwise -> 1x1 linear projection.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the final 1x1 projection.
        expansion: factor by which the 1x1 expansion conv widens the input
            before the depthwise conv; when it is 1 the expansion conv is skipped.
        stride: stride of the depthwise conv.

    The input is added back to the output (skip connection) only when
    ``stride == 1`` and ``in_channels == out_channels``.
    """

    def __init__(self,in_channels,out_channels,expansion,stride):
        super(BottleNeck,self).__init__()
        self.expansion = expansion
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels

        hidden_channels = in_channels*expansion
        self.relu6 = nn.ReLU6()
        layers = []

        # if expansion==1, it is the simple depthwise block from MobileNet_v1,
        # so the expansion conv is not created at all: building it anyway would
        # register parameters that forward() never touches.
        if expansion!=1:
            self.pointwise = nn.Conv2d(in_channels,hidden_channels,kernel_size=1,stride=1,padding=0)
            self.bn1 = nn.BatchNorm2d(hidden_channels)
            layers.append(nn.Sequential(
                self.pointwise,
                self.bn1,
                self.relu6,
                ))

        # groups == channels makes this a depthwise conv (one filter per channel)
        self.depthwise = nn.Conv2d(hidden_channels,hidden_channels,kernel_size=3,stride=stride,padding=1,
                                    groups=hidden_channels)
        self.bn2 = nn.BatchNorm2d(hidden_channels)
        self.conv = nn.Conv2d(hidden_channels,out_channels,kernel_size=1,stride=1)
        self.bn3 = nn.BatchNorm2d(out_channels)

        # linear bottleneck: the projection (conv + bn3) is NOT followed by an
        # activation, so the low-dimensional output is not clipped by a ReLU.
        layers.extend([
            self.depthwise,
            self.bn2,
            self.relu6,

            self.conv,
            self.bn3,
        ])

        self.layers = nn.Sequential(*layers)

        # condition of using the skip connection: shapes of input and output match
        self.use_residual = stride==1 and in_channels==out_channels

    def forward(self,x):
        out = self.layers(x)
        if self.use_residual:
            return x + out
        return out

In [116]:
# This version introduces
# inverted residuals and linear bottlenecks.
# For details, see the paper: https://arxiv.org/pdf/1801.04381.pdf
class MobileNet_v2(nn.Module):
    """MobileNetV2 classifier assembled from CNNBlock and BottleNeck.

    Args:
        img_channels: number of channels of the input images.
        num_classes: number of output classes (length of each logits vector).

    forward() returns a tensor of shape (batch_size, num_classes).
    """

    def __init__(self,img_channels,num_classes):
        super(MobileNet_v2,self).__init__()

        # stem: 3x3 stride-2 conv
        # NOTE(review): CNNBlock is called without padding here, so the stem
        # output is one pixel smaller than a "same"-padded stem would give.
        self.in_channels = 32
        self.conv1 = CNNBlock(img_channels,self.in_channels,kernel_size=3,stride=2)

        # t,c,n,s - expansion factor,output channels,number of blocks,stride
        self.bottlenecks = [
            [1,16,1,1],
            [6,24,2,2],
            [6,32,3,2],
            [6,64,4,2],
            [6,96,3,1],
            [6,160,3,2],
            [6,320,1,1]
        ]

        model = []
        for t,c,n,s in self.bottlenecks:
            for i in range(n):
                # only the first block of a stage uses stride s; the repeated
                # blocks use stride 1, otherwise every repeat would downsample
                # again and the skip connection could never be taken.
                block_stride = s if i==0 else 1
                model.append(BottleNeck(self.in_channels,c,t,block_stride))
                self.in_channels = c

        self.model = nn.Sequential(*model)

        # head: 1x1 conv to 1280 channels, global average pool, 1x1 classifier
        self.conv2 = CNNBlock(self.in_channels,1280,kernel_size=1,stride=1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # the classifier is a plain conv: sending logits through
        # BatchNorm + ReLU6 would clamp them to the range [0, 6]
        self.conv3 = nn.Conv2d(1280,num_classes,kernel_size=1,stride=1)

    def forward(self,x):
        x = self.conv1(x)
        x = self.model(x)
        x = self.conv2(x)
        x = self.avg_pool(x)
        x = self.conv3(x)
        # flatten (N, num_classes, 1, 1) -> (N, num_classes); keeping the batch
        # dimension fixed makes this correct for any num_classes
        return x.flatten(1)


In [117]:
# create MobileNet_v2
def mobile_net_v2(img_channels=3,num_classes=1000):
    """Construct a MobileNet_v2 for the given number of input image channels
    and classes (defaults: 3 channels, 1000 classes)."""
    network = MobileNet_v2(img_channels=img_channels,num_classes=num_classes)
    return network

In [118]:
# test the net architecture
def test():
    """Smoke test: feed a random batch of two 3x224x224 images through the
    default network and print the shape of the result."""
    model = mobile_net_v2()
    batch = torch.rand(2,3,224,224)
    print(model(batch).shape)

In [119]:
# run the smoke test; it prints the shape of the network output
test()

torch.Size([2, 1000])
