# MobileNet V3 in PyTorch

Based on this [paper](https://arxiv.org/pdf/1905.02244.pdf)

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torchinfo import summary

In [2]:
def conv_block(in_channels, out_channels, kernel_size=3, 
               stride=1, padding=0, groups=1,
               bias=False, bn=True, act = None):
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, 
                  padding=padding, groups=groups, bias=bias),
        nn.BatchNorm2d(out_channels) if bn else nn.Identity(),
        act if act else nn.Identity()
    ]
    return nn.Sequential(*layers)

In [3]:
class SEBlock(nn.Module):
    def __init__(self, c, r=4):
        super(SEBlock, self).__init__()
        self.squeeze = nn.AdaptiveMaxPool2d(1)
        self.excitation = nn.Sequential(
            nn.Conv2d(c, c // r, kernel_size=1),
            nn.SiLU(),
            nn.Conv2d(c // r, c, kernel_size=1),
            nn.Hardsigmoid()
        )
    def forward(self, x):
        s = self.squeeze(x)
        e = self.excitation(s)
        return x * e

In [4]:
class MBConv(nn.Module):
    """
    An implementation of the Inverted Residual from the paper. Borrowed from the EfficientNet code.
    """
    def __init__(self, n_in, n_out, exp_size, act, kernel_size=3, stride=1, dropout=0.1, use_se=True):
        super(MBConv, self).__init__()
        self.skip_connection = (n_in == n_out) and (stride == 1)
        padding = (kernel_size-1)//2
        
        self.expand_pw = conv_block(n_in, exp_size, kernel_size=1, act=act) if n_in != exp_size else nn.Identity()
        self.depthwise = conv_block(exp_size, exp_size, kernel_size=kernel_size, 
                                    stride=stride, padding=padding, groups=exp_size, act=act)
        self.se = SEBlock(exp_size) if use_se else nn.Identity()
        self.reduce_pw = conv_block(exp_size, n_out, kernel_size=1, act=None)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        residual = x
        x = self.expand_pw(x)
        x = self.depthwise(x)
        x = self.se(x)
        x = self.reduce_pw(x)
        if self.skip_connection:
            x = self.dropout(x)
            x = x + residual
        return x

In [5]:
act_map = [nn.ReLU6(), nn.Hardswish()]

In [6]:
CONFIGS = {
    "large": {
        "kernel_sizes": [3, 3, 3, 5, 5, 5, 3, 3, 3, 3, 3, 3, 5, 5, 5],
        "widths": [16, 16, 24, 24, 40, 40, 40, 80, 80, 80, 80, 112, 112, 160, 160, 160],
        "expansion_size": [16, 64, 72, 72, 120, 120, 240, 200, 184, 184, 480, 672, 672, 960, 960],
        "se": [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1],
        "act": [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        "strides": [1, 2, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1],
        "out_width": 1280
    },
    "small": {
        "kernel_sizes": [3, 3, 3, 5, 5, 5, 5, 5, 5, 5, 5],
        "widths": [16, 16, 24, 24, 40, 40, 40, 48, 48, 96, 96, 96],
        "expansion_size": [16, 72, 88, 96, 240, 240, 120, 144, 288, 576, 576],
        "se": [1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
        "act": [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        "strides": [2, 2, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1],
        "out_width": 1024
    }
}

In [7]:
class MobileNetV3(nn.Module):
    """
    Generic MobileNet V3 class. Model is determined by CONFIGS object defined above.
    """
    def __init__(self, cfg, n_classes=1000):
        super(MobileNetV3, self).__init__()
        kernel_sizes, widths, expansion_size, se, act, strides = cfg["kernel_sizes"], cfg["widths"], cfg["expansion_size"], cfg["se"], cfg["act"], cfg["strides"]
        
        n_layers = min(len(widths), len(se), len(act), len(kernel_sizes), len(strides))
        
        self.cfg = cfg
        self.stem = conv_block(3, widths[0], stride=2, padding=1)
        
        layers = []
        for i in range(n_layers):
            block = MBConv(widths[i], widths[i + 1], expansion_size[i], 
                           act=act_map[act[i]], kernel_size=kernel_sizes[i], 
                           stride=strides[i], use_se = se[i] == 1)

            layers.append(block)
            
        self.layers = nn.Sequential(*layers)
        self.pre = nn.Sequential(
            conv_block(widths[-1], expansion_size[-1], kernel_size=1, act=nn.Hardswish()),
        )
        self.head = nn.Sequential(
            nn.Linear(expansion_size[-1], cfg['out_width']),
            nn.Hardswish(),
            nn.Dropout(0.2),
            nn.Linear(cfg["out_width"], n_classes)

        )
            
    def forward(self, x):
        x = self.stem(x)
        x = self.layers(x)
        x = self.pre(x)
        x = F.avg_pool2d(x, 7)
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x

In [8]:
def mobilenetv3_large(n_classes = 1000):
    return MobileNetV3(CONFIGS["large"], n_classes=n_classes)
def mobilenetv3_small(n_classes = 1000):
    return MobileNetV3(CONFIGS["small"], n_classes=n_classes)

In [9]:
m_large = mobilenetv3_large()
m_small = mobilenetv3_small()

In [10]:
%%time
inp = torch.randn(16, 3, 224, 224)
m_small(inp).shape

CPU times: user 464 ms, sys: 232 ms, total: 695 ms
Wall time: 277 ms


torch.Size([16, 1000])

In [11]:
import os
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

In [12]:
print_size_of_model(m_large)
print_size_of_model(m_small)

Size (MB): 22.125713
Size (MB): 10.278745


In [13]:
def fmat(n):
    return "{:.2f}M".format(n / 1_000_000)

In [14]:
def params(model, f=True):
    s = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return fmat(s) if f else s

In [15]:
params(m_large), params(m_small)

('5.48M', '2.54M')

In [17]:
summary(m_large, (1, 3, 224, 224), depth=3)

Layer (type:depth-idx)                        Output Shape              Param #
MobileNetV3                                   --                        --
├─Sequential: 1-1                             [1, 16, 112, 112]         --
│    └─Conv2d: 2-1                            [1, 16, 112, 112]         432
│    └─BatchNorm2d: 2-2                       [1, 16, 112, 112]         32
│    └─Identity: 2-3                          [1, 16, 112, 112]         --
├─Sequential: 1-2                             [1, 160, 7, 7]            --
│    └─MBConv: 2-4                            [1, 16, 112, 112]         --
│    │    └─Identity: 3-1                     [1, 16, 112, 112]         --
│    │    └─Sequential: 3-2                   [1, 16, 112, 112]         176
│    │    └─Identity: 3-3                     [1, 16, 112, 112]         --
│    │    └─Sequential: 3-4                   [1, 16, 112, 112]         288
│    │    └─Dropout: 3-5                      [1, 16, 112, 112]         --
│    └─MBConv: 2-