In [1]:
import os, sys
import torch
import torchvision
import torch.nn as nn

### Setting block

In [24]:
# Reference: https://github.com/tonylins/pytorch-mobilenet-v2/blob/master/MobileNetV2.py
class depthwise_separable_conv(nn.Module):
    def __init__(self, in_channels, kernels_per_layer, out_channels):
        super(depthwise_separable_conv, self).__init__()
        
        self.depthwise_conv = nn.Conv2d(in_channels=in_channels,
                                        out_channels=kernels_per_layer * in_channels, 
                                        kernel_size=(3,3), 
                                        stride = (1,1),
                                        padding = (1,1),
                                        groups=in_channels)
        self.pointwise_conv = nn.Conv2d(in_channels=kernels_per_layer * in_channels,
                                        out_channels=out_channels,
                                        kernel_size=(1,1),
                                        stride=(1,1),
                                        padding=(0,0))
    def forward(self, x):
        out = self.depthwise_conv(x)
        out = self.pointwise_conv(out)
        
        return out
    
class depthwise_conv(nn.Module):
    def __init__(self, in_channels, kernels_per_layer, out_channels, groups, **kwarg):
        super(depthwise_conv, self).__init__()
        
        assert in_channels%groups == 0, "Groups Error: in_channels should be divisible by groups"
            
        self.depthwise_conv = nn.Conv2d(in_channels=in_channels,
                                        out_channels=kernels_per_layer*in_channels,
                                        groups=in_channels,
                                        **kwarg)
        
    def forward(self, x):
        out = self.depthwise_conv(x)
        
        return out
    
class Basic_conv2d_ReLU6(nn.Module):
    def __init__(self, in_channels, out_channels, inplace, **kwarg):
        super(Basic_conv2d_ReLU6, self).__init__()
        
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **kwarg)
        self.bn = nn.BatchNorm2d(num_features=out_channels)
        self.relu6 = nn.ReLU6(inplace=inplace)
        
    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu6(out)
        
        return out
    
class Inverted_Residual_Block(nn.Module):
    def __init__(self, in_channels, out_channels, expand_ratio, strides, bias):
        super(Inverted_Residual_Block, self).__init__()
        
        self.strides = strides
        self.use_residual = strides == 1 or strides == (1,1)
        hidden_dim = round(in_channels * expand_ratio)
        
        self.net = nn.Sequential(
            #pw
            Basic_conv2d_ReLU6(in_channels=in_channels,
                               out_channels=hidden_dim,
                               inplace=True,
                               bias = bias,
                               kernel_size=(1,1),
                               stride=1,
                               padding=0),
            
            #dw
            depthwise_conv(in_channels=hidden_dim,
                           kernels_per_layer=1,
                           out_channels=hidden_dim,
                           groups=hidden_dim,
                           bias=bias,
                           kernel_size=(3,3),
                           padding=1,
                           stride=strides),
            nn.BatchNorm2d(num_features=hidden_dim),
            nn.ReLU6(inplace=True),
            
            #pw
            nn.Conv2d(in_channels=hidden_dim,
                      out_channels=out_channels,
                      bias=bias,
                      kernel_size=(1,1),
                      stride=1,
                      padding=0),
            nn.BatchNorm2d(num_features=out_channels)
        )
        
        if in_channels != out_channels:
            self.match_conv = nn.Conv2d(in_channels, out_channels, kernel_size=(1,1), stride=1, padding=0)
        
    def forward(self, x):
        if self.use_residual:
            x_shortcut = x
            out = self.net(x)
            if out.size()[1] != x_shortcut.size()[1]:
                return self.match_conv(x_shortcut) + out
            else:
                return x_shortcut + out
            
        else:
            return self.net(x)

### Testing on Mnist

In [25]:
mnist_trainset = torchvision.datasets.MNIST(root='./data',
                                            train=True,
                                            download=True,
                                            transform=torchvision.transforms.ToTensor())
mnist_loader = torch.utils.data.DataLoader(dataset=mnist_trainset, batch_size=64)
x, y = mnist_loader.__iter__().__next__()

In [26]:
model = Inverted_Residual_Block(in_channels=1, out_channels=64, expand_ratio=6, strides=2, bias=False)
model

Inverted_Residual_Block(
  (net): Sequential(
    (0): Basic_conv2d_ReLU6(
      (conv): Conv2d(1, 6, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu6): ReLU6(inplace)
    )
    (1): depthwise_conv(
      (depthwise_conv): Conv2d(6, 6, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=6, bias=False)
    )
    (2): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU6(inplace)
    (4): Conv2d(6, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (match_conv): Conv2d(1, 64, kernel_size=(1, 1), stride=(1, 1))
)

In [27]:
out = model(x)
out.size()

torch.Size([64, 64, 14, 14])

### Reference:
1. [MobileNet V2: Paper](https://arxiv.org/pdf/1801.04381.pdf)
2. [MobileNet V2: Pytorch Tutorial](https://github.com/tonylins/pytorch-mobilenet-v2/blob/master/MobileNetV2.py)
3. [MobileNet V2: Review Article](https://blog.csdn.net/mzpmzk/article/details/82976871)
