#### Implementation of MobileNetV2 architecture using PyTorch
Paper: [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381)

MobileNetV2 improves on MobileNetV1 by introducing two ideas: inverted residuals and linear bottlenecks. It keeps MobileNetV1's depthwise separable convolutions and builds the network from a new unit, the inverted residual block.

##### Components of MobileNetV2:
1. **Inverted Residual Block**: A residual block where the number of channels is expanded first, then reduced (inverted from traditional ResNet).
2. **Linear Bottleneck**: The final layer of the block uses a linear activation (no ReLU), because a non-linearity in the low-dimensional bottleneck would destroy information.
3. **Depthwise Separable Convolution**: Similar to MobileNetV1, it uses depthwise separable convolutions to reduce the number of parameters and computations.
4. **Skip Connections**: Similar to ResNet, it uses skip connections to allow gradients to flow through the network more easily.

A depthwise separable convolution, a key component of MobileNetV2, factors a standard convolution into two layers:
- Depthwise Convolution: Applies a single filter to each input channel.
- Pointwise Convolution: A 1x1 convolution that combines the outputs of the depthwise convolution.
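
To make the savings concrete, the weight counts of a standard convolution and its depthwise separable factorization can be compared with simple arithmetic. The sketch below is illustrative only: the helper names and the 32 -> 64 channel example are assumptions, not taken from the paper, and biases are ignored.

```python
def standard_conv_params(c_in: int, c_out: int, k: int = 3) -> int:
    """Weights in a standard k x k convolution (bias ignored)."""
    return k * k * c_in * c_out


def depthwise_separable_params(c_in: int, c_out: int, k: int = 3) -> int:
    """Weights in a k x k depthwise conv followed by a 1x1 pointwise conv (bias ignored)."""
    depthwise = k * k * c_in   # one k x k filter per input channel
    pointwise = c_in * c_out   # 1x1 conv that mixes channels
    return depthwise + pointwise


# Hypothetical layer: 32 input channels, 64 output channels, 3x3 kernel.
standard = standard_conv_params(32, 64)          # 3 * 3 * 32 * 64 = 18432
separable = depthwise_separable_params(32, 64)   # 3 * 3 * 32 + 32 * 64 = 2336
print(standard, separable)
```

For this hypothetical layer the factorized version needs roughly 8 times fewer weights; the ratio grows with the number of output channels.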

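The linear bottleneck is the final 1x1 projection of the inverted residual block, which maps the expanded features back to a small number of channels. It is deliberately left without a non-linearity: according to the paper, applying ReLU in such a low-dimensional space destroys information, so keeping the projection linear preserves the representational power of the bottleneck.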
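
The reason the projection layer is kept linear can be illustrated with a toy example: ReLU maps every negative coordinate to zero, so two different low-dimensional feature vectors can collapse into the same output, while a linear activation keeps them distinct. The vectors and the `relu` helper below are made up for illustration and are not part of the paper.

```python
def relu(vector: list[float]) -> list[float]:
    """Element-wise ReLU on a plain Python list."""
    return [max(0.0, value) for value in vector]


# Two distinct 2-channel feature vectors (hypothetical values).
a = [-1.0, 2.0]
b = [-3.0, 2.0]

# After ReLU the first channel is zeroed in both, so the difference is lost.
print(relu(a) == relu(b))  # True

# With a linear activation the vectors pass through unchanged and stay distinguishable.
print(a == b)  # False
```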

Inverted Residual Block:
- A classical residual block follows a wide -> narrow -> wide pattern: the input is reduced to a bottleneck, processed, and then expanded again, and the shortcut connects the thick layers (layers with many channels).
- MobileNetV2 inverts this. Based on the intuition that the bottlenecks already contain all the necessary information, while the expansion layer merely accompanies a non-linear transformation, the shortcuts connect the bottlenecks (thin layers) directly. The last layer of each block (the projection) uses a linear activation.
- ReLU6 is used as the non-linear activation because of its robustness when used with low-precision computation.
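
For reference, ReLU6 is ReLU with its output capped at 6, that is `min(max(x, 0), 6)`. A minimal pure-Python sketch follows; the `relu6` helper is only illustrative, and in PyTorch the same operation is available as `nn.ReLU6`.

```python
def relu6(x: float) -> float:
    """Clamp x to the range [0, 6]."""
    return min(max(x, 0.0), 6.0)


for value in (-2.0, 3.5, 10.0):
    print(value, "->", relu6(value))
# -2.0 -> 0.0
# 3.5 -> 3.5
# 10.0 -> 6.0
```

Because the output range is bounded, activations need only a fixed, small number of bits for their integer part, which is one way to read the robustness claim for low-precision arithmetic.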


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

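class InvertedResidual(nn.Module):
    """Inverted residual block: 1x1 expansion -> 3x3 depthwise -> 1x1 linear projection."""

    def __init__(self, in_channels, out_channels, stride, expand_ratio):
        super().__init__()
        if stride not in (1, 2):
            raise ValueError(f"stride must be 1 or 2, got {stride}")
        self.stride = stride
        hidden_dim = int(round(in_channels * expand_ratio))
        # The shortcut is only valid when input and output shapes match.
        self.use_res_connect = self.stride == 1 and in_channels == out_channels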

        layers = []
        if expand_ratio != 1:
            # 1. Expansion (1x1 conv): widens the channels; skipped when expand_ratio == 1
            layers.append(nn.Conv2d(in_channels, hidden_dim, kernel_size=1, bias=False))
            layers.append(nn.BatchNorm2d(hidden_dim))
            layers.append(nn.ReLU6(inplace=True))
        
        # 2. Depthwise convolution (3x3): groups == channels, so each channel gets its own filter
        layers.append(nn.Conv2d(
            in_channels=hidden_dim,
            out_channels=hidden_dim,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=hidden_dim,
            bias=False
        ))
        layers.append(nn.BatchNorm2d(hidden_dim))
        layers.append(nn.ReLU6(inplace=True))

        # 3. Projection (1x1 conv): linear bottleneck, no activation after BatchNorm
        layers.append(nn.Conv2d(hidden_dim, out_channels, kernel_size=1, bias=False))
        layers.append(nn.BatchNorm2d(out_channels))

        self.block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_res_connect:
            return x + self.block(x)
        else:
            return self.block(x)
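
As a sanity check on the block structure, its learnable parameters can be counted by hand. The sketch below is an assumption-laden helper (the name `inverted_residual_params` is hypothetical): it assumes bias-free convolutions and two learnable parameters per `BatchNorm2d` channel, as in the class above. The four example blocks are the first four of the network, and the results agree with the `torchinfo` summary printed further down.

```python
def inverted_residual_params(in_channels: int, out_channels: int, expand_ratio: int) -> int:
    """Count learnable parameters of one block, assuming bias-free convs
    and BatchNorm2d layers with 2 parameters (weight, bias) per channel."""
    hidden_dim = in_channels * expand_ratio
    total = 0
    if expand_ratio != 1:
        total += in_channels * hidden_dim + 2 * hidden_dim   # 1x1 expansion + BN
    total += 3 * 3 * hidden_dim + 2 * hidden_dim             # 3x3 depthwise + BN
    total += hidden_dim * out_channels + 2 * out_channels    # 1x1 projection + BN
    return total


# (in_channels, out_channels, expand_ratio) of the first four blocks of the network
for args in [(32, 16, 1), (16, 24, 6), (24, 24, 6), (24, 32, 6)]:
    print(args, inverted_residual_params(*args))
# (32, 16, 1) 896
# (16, 24, 6) 5136
# (24, 24, 6) 8832
# (24, 32, 6) 10000
```

The stride does not appear in the formula because it changes the spatial resolution, not the number of weights.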

In [None]:
class MobileNetV2(nn.Module):
    """MobileNetV2 image classifier (width multiplier 1.0)."""

    def __init__(self, num_classes=1000):
        super().__init__()

        # (t, c, n, s): (expand_ratio, out_channels, num_blocks, stride of the first block)
        self.configs = [
            (1, 16, 1, 1),
            (6, 24, 2, 2),
            (6, 32, 3, 2),
            (6, 64, 4, 2),
            (6, 96, 3, 1),
            (6, 160, 3, 2),
            (6, 320, 1, 1),
        ]

        # initial layer
        input_channels = 32
        layers = [
            nn.Conv2d(3, input_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(input_channels),
            nn.ReLU6(inplace=True)
        ]

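        # Inverted residual blocks: each config row is repeated n times;
        # only the first block of a stage uses stride s, the rest use stride 1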
        for t, c, n, s in self.configs:
            output_channel = c
            for i in range(n):
                stride = s if i == 0 else 1
                layers.append(InvertedResidual(input_channels, output_channel, stride, t))
                input_channels = output_channel

        # Final Conv 1x1
        last_channel = 1280
        layers.append(nn.Conv2d(input_channels, last_channel, kernel_size=1, bias=False))
        layers.append(nn.BatchNorm2d(last_channel))
        layers.append(nn.ReLU6(inplace=True))
        self.features = nn.Sequential(*layers)

        # Classifier
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(last_channel, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.avgpool(x)  # output: [B, 1280, 1, 1]
        x = torch.flatten(x, 1)  # flatten to [B, 1280]
        x = self.classifier(x)
        return x
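
To see how the configuration table shapes the feature maps, the spatial size can be traced with the standard convolution output formula, without instantiating the model. This is a sketch under stated assumptions: a 224x224 input, 3x3 kernels with padding 1 wherever a stride is applied, and a hypothetical helper named `conv_out_size`.

```python
def conv_out_size(size: int, stride: int, kernel: int = 3, padding: int = 1) -> int:
    """Output height/width of a convolution: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


# (t, c, n, s) rows copied from the configuration in the model above
configs = [(1, 16, 1, 1), (6, 24, 2, 2), (6, 32, 3, 2), (6, 64, 4, 2),
           (6, 96, 3, 1), (6, 160, 3, 2), (6, 320, 1, 1)]

size = conv_out_size(224, stride=2)  # stem: 3x3 conv, stride 2
print("stem", 32, size)
for t, c, n, s in configs:
    # Only the first block of each stage uses stride s; the remaining
    # blocks use stride 1 and keep the spatial size unchanged.
    size = conv_out_size(size, stride=s)
    print(f"stage t={t} n={n}", c, size)
print("head", 1280, size)  # the final 1x1 conv keeps the spatial size
```

Under these assumptions the resolution goes 224 -> 112 -> 56 -> 28 -> 14 -> 7, so the feature extractor ends with a `[B, 1280, 7, 7]` tensor before global average pooling.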

In [None]:
from torchinfo import summary  # third-party package: pip install torchinfo

model = MobileNetV2(num_classes=1000)
x = torch.randn(1, 3, 224, 224)  # dummy input
out = model(x)
print(out.shape)  # -> torch.Size([1, 1000])

summary(model, input_size=(1, 3, 224, 224))

torch.Size([1, 1000])


Layer (type:depth-idx)                   Output Shape              Param #
MobileNetV2                              [1, 1000]                 --
├─Sequential: 1-1                        [1, 1280, 7, 7]           --
│    └─Conv2d: 2-1                       [1, 32, 112, 112]         864
│    └─BatchNorm2d: 2-2                  [1, 32, 112, 112]         64
│    └─ReLU6: 2-3                        [1, 32, 112, 112]         --
│    └─InvertedResidual: 2-4             [1, 16, 112, 112]         --
│    │    └─Sequential: 3-1              [1, 16, 112, 112]         896
│    └─InvertedResidual: 2-5             [1, 24, 56, 56]           --
│    │    └─Sequential: 3-2              [1, 24, 56, 56]           5,136
│    └─InvertedResidual: 2-6             [1, 24, 56, 56]           --
│    │    └─Sequential: 3-3              [1, 24, 56, 56]           8,832
│    └─InvertedResidual: 2-7             [1, 32, 28, 28]           --
│    │    └─Sequential: 3-4              [1, 32, 28, 28]           10,000
│  