In [57]:
# import necessary libraries
import torch
import torch.nn as nn

In [58]:
# create CNNBlock to simplify implementation
class CNNBlock(nn.Module):
    """Standard convolution block: Conv2d -> BatchNorm2d -> ReLU.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the convolution.
        kernel_size: side length of the square convolution kernel (default 3).
        stride: stride of the convolution (default 2).
        padding: zero-padding added on every spatial side (default 1).

    Input:  tensor of shape (N, in_channels, H, W).
    Output: tensor of shape (N, out_channels, H_out, W_out) with non-negative
            values (ReLU is the last operation).
    """

    def __init__(self,in_channels,out_channels,kernel_size=3,stride=2,padding=1):
        super(CNNBlock,self).__init__()
        # padding=1 makes the 3x3 / stride-2 convolution halve the resolution
        # exactly (e.g. 224 -> 112). Without padding the output would be
        # 111x111 for a 224x224 input.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self,x):
        """Apply convolution, batch normalisation and ReLU, in that order."""
        return self.relu(self.bn(self.conv(x)))

In [59]:
# We separate the standard conv layer into
# a depthwise conv and a pointwise conv.
# This reduces computation time and model size.
class DepthWiseCNNBlock(nn.Module):
    """Depthwise-separable convolution block, optionally repeated.

    One repetition consists of:
        3x3 depthwise conv (groups == channels) -> BatchNorm2d -> ReLU ->
        1x1 pointwise conv -> BatchNorm2d -> ReLU

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the pointwise conv.
        stride: stride of the depthwise conv of the first repetition.
        count: how many times the block is repeated. Every repetition owns
            its own parameters and batch-norm statistics.

    When ``count > 1`` only the first repetition changes the channel count
    (``in_channels -> out_channels``) and applies ``stride``; the remaining
    repetitions map ``out_channels -> out_channels`` with stride 1, so the
    repetitions can always be chained. With ``count == 0`` the block is the
    identity.
    """

    def __init__(self,in_channels,out_channels,stride=1,count=1):
        super(DepthWiseCNNBlock,self).__init__()
        self.layers = nn.ModuleList()

        # Build fresh modules for every repetition. Reusing a single set of
        # conv / batch-norm instances would make all repetitions share the
        # same weights instead of being independent layers.
        for i in range(count):
            is_first = i == 0
            self.layers.append(self._make_layer(
                in_channels if is_first else out_channels,
                out_channels,
                stride if is_first else 1,
            ))

    @staticmethod
    def _make_layer(in_channels,out_channels,stride):
        """Build one depthwise + pointwise stage with its own parameters."""
        return nn.Sequential(
            # depthwise: one 3x3 filter per input channel
            nn.Conv2d(in_channels,in_channels,kernel_size=3,stride=stride,padding=1,groups=in_channels,bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            # pointwise: 1x1 conv that mixes channels
            nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=1,padding=0,bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self,x):
        """Run the input through every repetition in order."""
        for layer in self.layers:
            x = layer(x)

        return x


In [60]:
# Distinctive features of MobileNet:
# 1) no Maxpooling layers
# 2) DepthWise and pointwise convolutional layers
# To see more, read the paper https://arxiv.org/pdf/1704.04861.pdf
class MobileNet_v1(nn.Module):
    """MobileNet v1 classifier built from depthwise-separable convolutions.

    The feature extractor is a strided standard convolution followed by a
    stack of depthwise-separable blocks that ends with 1024 channels. The
    features are globally average-pooled and fed to a linear classifier.

    Args:
        img_channels: number of channels of the input images.
        num_classes: number of output classes (size of the logits vector).

    Input:  tensor of shape (N, img_channels, H, W).
    Output: tensor of shape (N, num_classes) containing unnormalised logits.
    """

    def __init__(self,img_channels,num_classes):
        super(MobileNet_v1,self).__init__()

        self.model = nn.Sequential(
            CNNBlock(img_channels,32),
            DepthWiseCNNBlock(32,64),
            DepthWiseCNNBlock(64,128,stride=2),
            DepthWiseCNNBlock(128,128),
            DepthWiseCNNBlock(128,256,stride=2),
            DepthWiseCNNBlock(256,256),
            DepthWiseCNNBlock(256,512,stride=2),
            DepthWiseCNNBlock(512,512,count=5),
            DepthWiseCNNBlock(512,1024,stride=2),
            DepthWiseCNNBlock(1024,1024,stride=1),
            # Global average pooling to 1x1 regardless of the final feature
            # map size. For a 7x7 map (224x224 input) this equals
            # AvgPool2d(7), but other input resolutions also work.
            nn.AdaptiveAvgPool2d(1)
        )
        self.fc = nn.Linear(1024,num_classes)

    def forward(self,x):
        """Compute class logits for a batch of images."""
        x = self.model(x)
        # Flatten everything except the batch dimension. This keeps one row
        # per sample, whereas view(-1, 1024) could fold leftover spatial
        # positions into the batch dimension.
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


In [61]:
# create MobileNet_v1
def mobile_net_v1(img_channels=3,num_classes=1000):
    """Factory for a ``MobileNet_v1`` model.

    Args:
        img_channels: number of input image channels (default 3).
        num_classes: number of output classes (default 1000).

    Returns:
        A newly constructed ``MobileNet_v1`` instance.
    """
    network = MobileNet_v1(
        img_channels=img_channels,
        num_classes=num_classes,
    )
    return network

In [62]:
# test the net architecture
def test():
    """Smoke test: push a random batch through the default model.

    Builds the default network, feeds it a random batch of shape
    (2, 3, 224, 224) and prints the shape of the result.
    """
    model = mobile_net_v1()
    batch = torch.rand(2,3,224,224)
    logits = model(batch)
    print(logits.shape)

In [63]:
# Run the smoke test; it prints the output shape for a batch of two
# 3-channel 224x224 inputs.
test()

torch.Size([2, 1000])
