# MobileNet V1 in PyTorch

In [1]:
import torch
from torch import nn
from torchinfo import summary

In [2]:
def conv_block(in_channels, out_channels, kernel_size=3,
               stride=1, padding=0, groups=1,
               bias=False, bn=True, act=True):
    """Return ``nn.Sequential(conv, norm, activation)``.

    The first child is always an ``nn.Conv2d``. The second and third children
    are ``nn.BatchNorm2d`` / ``nn.ReLU`` when enabled and ``nn.Identity``
    otherwise, so the container always has exactly three children.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, groups, bias:
            passed straight through to ``nn.Conv2d``.
        bn: if true, normalise the conv output with
            ``nn.BatchNorm2d(out_channels)``.
        act: if true, follow the (optional) normalisation with ``nn.ReLU``.

    Returns:
        An ``nn.Sequential`` holding the three modules in order.
    """
    # The conv is created first, before the norm/activation slots.
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=padding,
        groups=groups,
        bias=bias,
    )

    if bn:
        norm = nn.BatchNorm2d(out_channels)
    else:
        norm = nn.Identity()

    if act:
        activation = nn.ReLU()
    else:
        activation = nn.Identity()

    return nn.Sequential(conv, norm, activation)

In [3]:
class MBConv(nn.Module):
    """Depthwise-separable convolution: a depthwise conv followed by a 1x1 conv.

    Layout (each conv built with ``conv_block``):
      * ``depthwise``: ``kernel_size`` x ``kernel_size`` conv with
        ``groups=n_in`` (one filter per input channel), BatchNorm, ReLU.
      * ``reduce_pw``: 1x1 pointwise conv ``n_in -> n_out`` with BatchNorm
        and no activation (``act=False``).

    NOTE(review): the MobileNet V1 paper's block applies ReLU after the
    pointwise BatchNorm as well; this block leaves the pointwise output
    linear. Confirm which variant is intended.

    Args:
        n_in: number of input channels.
        n_out: number of output channels.
        kernel_size: depthwise kernel size. Padding is
            ``(kernel_size - 1) // 2``, which keeps the spatial size for odd
            kernels at ``stride=1``.
        stride: stride of the depthwise convolution.
        dropout: drop probability for ``self.dropout`` (not applied in
            ``forward``).
    """

    def __init__(self, n_in, n_out, kernel_size=3, stride=1, dropout=0.1):
        super().__init__()
        # Recorded for inspection only; forward() adds no residual connection.
        self.skip_connection = (n_in == n_out) and (stride == 1)
        padding = (kernel_size - 1) // 2

        self.depthwise = conv_block(n_in, n_in, kernel_size=kernel_size,
                                    stride=stride, padding=padding,
                                    groups=n_in)
        self.reduce_pw = conv_block(n_in, n_out, kernel_size=1, act=False)
        # Kept as an attribute for compatibility; forward() does not use it.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the depthwise conv, then the pointwise projection."""
        x = self.depthwise(x)
        return self.reduce_pw(x)

In [None]:
### Stage settings ###
# NOTE(review): labelled as taken from "the paper", but `depths` and `strides`
# do not appear in MobileNet V1 (Howard et al., 2017, Table 1). They equal the
# n / s columns of MobileNetV2 Table 2 (Sandler et al., 2018). Confirm which
# paper is intended.
# NOTE(review): `widths` has 5 entries while `depths` and `strides` have 7.
# Confirm how downstream code pairs them.
widths = [32 * 2**k for k in range(5)]  # 32, 64, 128, 256, 512 (doubling)
depths = [1, 2, 3, 4, 3, 3, 1]
strides = [1, 2, 2, 2, 1, 2, 1]