In [5]:
import torch
import torch.nn as nn

In [6]:
class DepthWiseConv2d(nn.Module):
    """Depthwise-separable convolution block.

    A 3x3 depthwise convolution (``groups=in_channels``, so each input
    channel gets its own filter) is followed by a 1x1 pointwise convolution
    that maps ``in_channels`` to ``out_channels``. Both convolutions are
    bias-free and each is followed by batch normalisation and ReLU.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count produced by the pointwise convolution.
        stride: stride of the depthwise convolution; the pointwise
            convolution always uses stride 1.
    """

    def __init__(self, in_channels=3, out_channels=3, stride=1):
        super().__init__()

        # Spatial filtering: padding=1 with a 3x3 kernel, one group per channel.
        depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=in_channels,
            bias=False,
        )
        # Channel mixing: 1x1 kernel, stride 1, no padding.
        pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)

        # Positions 0..5 are kept in this exact order so state_dict keys
        # ("model.0.weight", "model.3.weight", ...) stay the same.
        self.model = nn.Sequential(
            depthwise,
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            pointwise,
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, feature):
        """Run ``feature`` through the depthwise and pointwise stages."""
        return self.model(feature)

    
class Conv2D(nn.Module):
    """Standard 3x3 convolution followed by batch normalisation and ReLU.

    The convolution uses ``padding=1`` and no bias term.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count produced by the convolution.
        stride: convolution stride (defaults to 2).
    """

    def __init__(self, in_channels, out_channels, stride=2):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        # Order conv -> BN -> ReLU is fixed so state_dict keys are unchanged.
        self.model = nn.Sequential(conv, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, feature):
        """Apply the conv / batch-norm / ReLU stack to ``feature``."""
        return self.model(feature)


In [7]:
class MobileNet(nn.Module):
    """MobileNet-style classifier built from depthwise-separable blocks.

    The network is a ``Conv2D`` stem, a stack of 13 ``DepthWiseConv2d``
    blocks, global average pooling, and a linear classifier.

    Args:
        image_channels: channel count of the input images.
        num_classes: number of output logits.
        width_multiplier: uniform scale applied to every internal channel
            width (stem, all blocks, and the classifier input). The default
            ``1.0`` yields the original widths (32 ... 1024). Scaled widths
            are truncated to ``int`` and clamped to at least 1.

    Raises:
        ValueError: if ``width_multiplier`` is not strictly positive.
    """

    # Unscaled stem width and (unscaled output channels, stride) per block.
    _STEM_CHANNELS = 32
    _BLOCK_PLAN = (
        (64, 1),
        (128, 2),
        (128, 1),
        (256, 2),
        (256, 1),
        (512, 2),
        (512, 1),
        (512, 1),
        (512, 1),
        (512, 1),
        (512, 1),
        (1024, 2),
        (1024, 1),
    )

    def __init__(self, image_channels, num_classes, width_multiplier=1.0):
        super(MobileNet, self).__init__()
        if width_multiplier <= 0:
            raise ValueError(
                "width_multiplier must be positive, got {!r}".format(width_multiplier)
            )

        self.image_channels = image_channels
        self.num_classes = num_classes
        self.width_multiplier = width_multiplier

        def scaled(base_channels):
            # Never let a layer collapse to zero channels for tiny multipliers.
            return max(1, int(base_channels * width_multiplier))

        stem_channels = scaled(self._STEM_CHANNELS)
        head_channels = scaled(self._BLOCK_PLAN[-1][0])

        # Submodules are created in the same order as before (pool, linear,
        # stem, blocks) so state_dict ordering and seeded init are unchanged.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(in_features=head_channels, out_features=self.num_classes)
        self.layer1 = Conv2D(self.image_channels, stem_channels)

        blocks = []
        current_channels = stem_channels
        for base_out, stride in self._BLOCK_PLAN:
            next_channels = scaled(base_out)
            blocks.append(DepthWiseConv2d(current_channels, next_channels, stride))
            current_channels = next_channels
        self.layer2 = nn.Sequential(*blocks)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        ``x`` is expected to be a batched image tensor whose channel
        dimension equals ``image_channels``.
        """
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.avg_pool(x)
        # flatten(1) keeps the batch dimension and, unlike view, does not
        # require the pooled tensor to be contiguous.
        x = x.flatten(1)
        x = self.linear(x)
        return x

In [8]:
# Smoke test: feed one random 3-channel 224x224 image through a 1000-class
# model and display the shape of the resulting logits.
x = torch.randn((1, 3, 224, 224))
model = MobileNet(image_channels=3, num_classes=1000)
model(x).shape

torch.Size([1, 1000])

In [9]:
# Total number of trainable scalars (parameters with requires_grad=True).
sum(p.numel() for p in model.parameters() if p.requires_grad)

4231976