In [1]:
import os, sys
import torch
import torchvision
import torch.nn as nn

### Block definitions

In [39]:
###Reference: https://github.com/xiaolai-sqlai/mobilenetv3/blob/master/mobilenetv3.py

class Hswish(nn.Module):
    """Hard-swish activation: ``x * relu6(x + 3) / 6``."""

    def forward(self, x):
        # `x + 3.` produces a new tensor, so the in-place ReLU6 clamps that
        # temporary and leaves the caller's `x` untouched.
        gate = nn.functional.relu6(x + 3., inplace=True)
        return x * gate / 6.
    
class Hsigmoid(nn.Module):
    """Hard-sigmoid activation: ``relu6(x + 3) / 6``, a value in ``[0, 1]``."""

    def forward(self, x):
        # The in-place ReLU6 operates on the temporary `x + 3.`, not on `x`.
        shifted = x + 3.
        return nn.functional.relu6(shifted, inplace=True) / 6.
    
class depthwise_conv(nn.Module):
    """Depthwise 2-D convolution: each input channel is filtered on its own.

    Args:
        in_channels: Number of channels in the input.
        kernels_per_layer: Number of filters applied to every input channel;
            the layer emits ``kernels_per_layer * in_channels`` channels.
        out_channels: Accepted for signature compatibility but not used by
            this implementation.
        groups: Only checked for divisibility against ``in_channels``; the
            convolution itself is always built with ``groups=in_channels``.
        **kwarg: Forwarded unchanged to ``nn.Conv2d`` (for example
            ``kernel_size``, ``stride``, ``padding``, ``bias``).
    """

    def __init__(self, in_channels, kernels_per_layer, out_channels, groups, **kwarg):
        super().__init__()

        assert in_channels % groups == 0, "Groups Error: in_channels should be divisible by groups"

        # One group per input channel is what makes the convolution depthwise.
        expanded_channels = kernels_per_layer * in_channels
        self.depthwise_conv = nn.Conv2d(
            in_channels,
            expanded_channels,
            groups=in_channels,
            **kwarg
        )

    def forward(self, x):
        return self.depthwise_conv(x)
    
class SE_Block(nn.Module):
    """Squeeze-and-excitation gate that rescales each channel of its input.

    The input is globally average-pooled to 1x1, passed through a two-layer
    1x1-conv bottleneck (ReLU in between), squashed by ``Hsigmoid`` and then
    multiplied back onto the input channel by channel.

    Args:
        in_channels (int): Number of channels of the input (and output).
        squeeze_ratio (float): Fraction of ``in_channels`` kept in the
            bottleneck; must lie in ``(0, 1]``.

    Raises:
        AssertionError: If ``squeeze_ratio`` is outside ``(0, 1]``.
    """

    def __init__(self, in_channels, squeeze_ratio):
        super(SE_Block, self).__init__()

        assert 0 < squeeze_ratio <= 1, 'Squeeze ratio for squeeze and excitation should be within (0,1]'

        # int() truncates, so a small `in_channels * squeeze_ratio` (e.g.
        # 2 * 0.25) would give a zero-channel bottleneck. Keep at least one
        # channel so the gate stays a usable layer.
        squeeze_channel = max(1, int(in_channels * squeeze_ratio))

        self.se_block = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            # 1x1 convolutions are used in place of dense (fully connected) layers.
            nn.Conv2d(in_channels, squeeze_channel, kernel_size=(1, 1), stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeeze_channel, in_channels, kernel_size=(1, 1), stride=1, padding=0),
            Hsigmoid()
        )

    def forward(self, x):
        # The gate has spatial size 1x1 and broadcasts over height and width.
        return x * self.se_block(x)
    
class MobileV3_Block(nn.Module):
    """MobileNetV3-style inverted-residual block.

    Pipeline: 1x1 expansion conv -> depthwise conv -> 1x1 projection conv, each
    followed by batch norm (with ``activation_fn`` after the first two stages),
    then an optional squeeze-and-excitation module and, when the stride is 1,
    a residual connection.

    Args:
        kernel_size (int): Depthwise kernel size; padding is ``kernel_size // 2``.
        in_channels (int): Number of channels of the input.
        expand_ratio (float): Expansion factor; the hidden width is
            ``round(in_channels * expand_ratio)``.
        out_channels (int): Number of channels produced by the block.
        activation_fn (nn.Module): Activation module. The same instance is
            placed after both the expansion and the depthwise stage.
        se_block (nn.Module or None): Module applied to the projected
            features, or ``None`` to skip it.
        strides (int or pair of ints): Stride of the depthwise convolution.
            ``1``, ``(1, 1)`` and ``[1, 1]`` all enable the residual path.
        bias (bool): Whether the convolutions carry a bias term.

    Raises:
        AssertionError: If ``kernel_size`` is not an integer.
    """

    def __init__(self, kernel_size, in_channels, expand_ratio, out_channels, activation_fn, se_block, strides, bias):
        super(MobileV3_Block, self).__init__()

        assert isinstance(kernel_size, int), 'kernel_size must be an int'

        self.strides = strides
        self.se_block = se_block
        # nn.Conv2d accepts a list stride as well as an int or tuple, so
        # normalise sequences before comparing; otherwise `[1, 1]` would
        # silently disable the skip connection.
        self.use_residual = strides == 1 or (
            isinstance(strides, (list, tuple)) and tuple(strides) == (1, 1)
        )
        hidden_dim = round(in_channels * expand_ratio)

        self.v2_block = nn.Sequential(
            # pointwise: expansion
            nn.Conv2d(in_channels, hidden_dim, kernel_size=(1, 1), stride=1, padding=0, bias=bias),
            nn.BatchNorm2d(hidden_dim),
            activation_fn,

            # depthwise: one filter group per hidden channel
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=kernel_size, stride=self.strides,
                      padding=kernel_size // 2, groups=hidden_dim, bias=bias),
            nn.BatchNorm2d(hidden_dim),
            activation_fn,

            # pointwise: compression (linear, no activation)
            nn.Conv2d(hidden_dim, out_channels, kernel_size=(1, 1), stride=1, padding=0, bias=bias),
            nn.BatchNorm2d(out_channels),
        )

        # Identity shortcut by default; a 1x1 conv + BN projection is used only
        # when the residual is active and the channel counts differ.
        self.shortcut_block = nn.Sequential()

        if self.use_residual and in_channels != out_channels:
            self.shortcut_block = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1), stride=1, padding=0, bias=bias),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = self.v2_block(x)
        if self.se_block is not None:
            out = self.se_block(out)

        if self.use_residual:
            out = out + self.shortcut_block(x)

        return out

### Testing on MNIST

In [40]:
# Load the MNIST training split (download=True is passed so the data is
# fetched into ./data) and take a single batch of 64 as a smoke-test input.
mnist_trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
mnist_loader = torch.utils.data.DataLoader(dataset=mnist_trainset, batch_size=64)

# First (images, labels) batch from the loader.
x, y = next(iter(mnist_loader))

In [49]:
# SE gate operating on the block's 64 output channels with a 1/4 squeeze ratio.
squeeze_excitation_block = SE_Block(in_channels=64, squeeze_ratio=0.25)

# One block mapping 1 input channel to 64 output channels at stride 1,
# using hard-swish activations and the SE gate above.
model = MobileV3_Block(
    kernel_size=3,
    in_channels=1,
    expand_ratio=4,
    out_channels=64,
    activation_fn=Hswish(),
    se_block=squeeze_excitation_block,
    strides=1,
    bias=False,
)

In [50]:
# Run the batch through the block and display the shape of the result.
out = model(x)
out.shape

torch.Size([64, 64, 28, 28])

### References
1. [MobileNet V3: Paper](https://arxiv.org/pdf/1905.02244.pdf)
2. [MobileNet V3: PyTorch Implementation](https://github.com/xiaolai-sqlai/mobilenetv3/blob/master/mobilenetv3.py)
3. [MobileNet V3: Review Article](https://www.jiqizhixin.com/articles/2019-05-09-2)