**MobileNetV2: Inverted Residuals and Linear Bottlenecks**    
*Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen*   
[[paper](https://arxiv.org/abs/1801.04381)]   
CVPR 2018   

In [2]:
import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class InvertedResidualConv(nn.Module):
    """Inverted residual block with a linear bottleneck.

    The block applies, in order:

    1. ``expand``  - 1x1 convolution widening ``in_dim`` to
       ``int(in_dim * expand_ratio)`` channels, then BatchNorm and ReLU6.
    2. ``dwise``   - 3x3 depthwise convolution (``groups == channels``) that
       carries the block's ``stride``, then BatchNorm and ReLU6.
    3. ``project`` - 1x1 convolution narrowing to ``out_dim`` channels, then
       BatchNorm with no activation (the "linear" bottleneck).

    When the block preserves both channel count and resolution
    (``in_dim == out_dim`` and ``stride == 1``) the input is added to the
    output as a skip connection.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        expand_ratio: Multiplier applied to ``in_dim`` to get the hidden width.
        stride: Stride of the depthwise convolution.
    """

    def __init__(self, in_dim, out_dim, expand_ratio, stride=1) -> None:
        super().__init__()

        # The skip connection is only shape-compatible when neither the
        # channel count nor the spatial size changes.
        self.use_residual = in_dim == out_dim and stride == 1
        hidden_dim = int(in_dim * expand_ratio)

        self.expand = nn.Sequential(
            nn.Conv2d(in_channels=in_dim, out_channels=hidden_dim, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(),
        )

        self.dwise = nn.Sequential(
            # A 3x3 kernel needs padding=1 regardless of stride: that keeps
            # H x W at stride 1 and halves it (rounding up) at stride 2.
            nn.Conv2d(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                kernel_size=3,
                stride=stride,
                padding=1,
                groups=hidden_dim,
                bias=False,
            ),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(),
        )

        self.project = nn.Sequential(
            # A 1x1 kernel must not be padded, otherwise the feature map grows
            # by two pixels per side pair and the residual add fails.
            nn.Conv2d(in_channels=hidden_dim, out_channels=out_dim, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_dim),
        )

    def forward(self, x):
        """Run expand -> depthwise -> project, adding ``x`` back when allowed.

        Args:
            x: Input feature map with ``in_dim`` channels.

        Returns:
            Feature map with ``out_dim`` channels; the spatial size is reduced
            only by the depthwise convolution's stride.
        """
        h = self.expand(x)
        h = self.dwise(h)
        h = self.project(h)

        if self.use_residual:
            # Out-of-place add: never mutates a tensor autograd may still need.
            h = h + x

        return h

In [None]:
class MobileNetV2(nn.Module):
    """MobileNetV2 classifier assembled from ``InvertedResidualConv`` blocks.

    Layout: a stride-2 3x3 stem convolution, seven bottleneck stages, a 1x1
    convolution widening to ``init_dim * 40`` channels, global average
    pooling, and a 1x1 convolution acting as the classifier.

    All channel widths are derived from ``init_dim``; the per-stage comments
    below show the widths for the default ``init_dim=32`` together with the
    expansion ratio ``t``, block count ``n`` and first-block stride ``s``.

    Args:
        init_dim: Channel count produced by the stem convolution.
        num_classes: Number of output channels of the classifier.
    """

    def __init__(self, init_dim=32, num_classes=1000) -> None:
        super().__init__()

        self.dim = init_dim

        self.init_conv = nn.Sequential(
            # padding=1 with a 3x3/stride-2 kernel halves the input exactly
            # (e.g. 224 -> 112); padding=2 would give 113.
            nn.Conv2d(in_channels=3, out_channels=self.dim, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.dim),
            nn.ReLU6(),
        )

        # c 32 -> 16, t=1, n=1, s=1
        self.bottleneck1 = InvertedResidualConv(in_dim=self.dim, out_dim=self.dim // 2, expand_ratio=1, stride=1)

        # c 16 -> 24, t=6, n=2, s=2
        self.bottleneck2 = self._stage(self.dim // 2, self.dim * 3 // 4, expand_ratio=6, num_blocks=2, stride=2)

        # c 24 -> 32, t=6, n=3, s=2
        self.bottleneck3 = self._stage(self.dim * 3 // 4, self.dim, expand_ratio=6, num_blocks=3, stride=2)

        # c 32 -> 64, t=6, n=4, s=2
        self.bottleneck4 = self._stage(self.dim, self.dim * 2, expand_ratio=6, num_blocks=4, stride=2)

        # c 64 -> 96, t=6, n=3, s=1
        self.bottleneck5 = self._stage(self.dim * 2, self.dim * 3, expand_ratio=6, num_blocks=3, stride=1)

        # c 96 -> 160, t=6, n=3, s=2
        self.bottleneck6 = self._stage(self.dim * 3, self.dim * 5, expand_ratio=6, num_blocks=3, stride=2)

        # c 160 -> 320, t=6, n=1, s=1
        self.bottleneck7 = InvertedResidualConv(in_dim=self.dim * 5, out_dim=self.dim * 10, expand_ratio=6, stride=1)

        self.last_conv = nn.Sequential(
            nn.Conv2d(in_channels=self.dim * 10, out_channels=self.dim * 40, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(self.dim * 40),
            nn.ReLU6(),
        )

        self.pool = nn.AdaptiveAvgPool2d(1)

        # 1x1 convolution over the pooled 1x1 map, used as the classifier.
        self.fc = nn.Conv2d(in_channels=self.dim * 40, out_channels=num_classes, kernel_size=1, stride=1)

    @staticmethod
    def _stage(in_dim, out_dim, expand_ratio, num_blocks, stride):
        """Build a stage of ``num_blocks`` inverted residual blocks.

        Only the first block changes the channel count and applies ``stride``;
        the remaining ``num_blocks - 1`` blocks keep ``out_dim`` channels at
        stride 1.
        """
        blocks = [InvertedResidualConv(in_dim=in_dim, out_dim=out_dim, expand_ratio=expand_ratio, stride=stride)]
        blocks.extend(
            InvertedResidualConv(in_dim=out_dim, out_dim=out_dim, expand_ratio=expand_ratio, stride=1)
            for _ in range(num_blocks - 1)
        )
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Compute class scores for a batch of 3-channel images.

        Args:
            x: Input batch with 3 channels (the stem convolution requires it).

        Returns:
            Tensor with ``num_classes`` channels and a 1x1 spatial size, i.e.
            shape ``(N, num_classes, 1, 1)``; it is not flattened here.
        """
        h = self.init_conv(x)

        h = self.bottleneck1(h)
        h = self.bottleneck2(h)
        h = self.bottleneck3(h)
        h = self.bottleneck4(h)
        h = self.bottleneck5(h)
        h = self.bottleneck6(h)
        h = self.bottleneck7(h)

        # Widen to dim*40 channels before pooling; `fc` expects that width.
        h = self.last_conv(h)

        p = self.pool(h)

        out = self.fc(p)

        return out

