In [25]:
####

In [26]:
import torch
from torch import nn

In [27]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [28]:
class Conv(nn.Module):
    """A 2-D convolution optionally followed by BatchNorm2d and ReLU6.

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of channels produced by the convolution.
        kernel_size: kernel size forwarded to ``nn.Conv2d``.
        stride: stride forwarded to ``nn.Conv2d``.
        padding: padding forwarded to ``nn.Conv2d``.
        use_norm: when true, a ``BatchNorm2d`` over ``out_channels`` is applied
            after the convolution.
        use_activation: when true, ``ReLU6`` is applied last.
        use_grps: when true, the convolution is built with
            ``groups=out_channels``; otherwise with a single group.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1),
                 padding=1, use_norm=True, use_activation=True, use_grps=False):
        super().__init__()

        # The flags are kept on the instance because forward() consults them.
        self.use_norm = use_norm
        self.use_activation = use_activation

        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=out_channels if use_grps else 1,
        )
        # The optional stages only exist as attributes when they are enabled.
        if use_norm:
            self.norm = nn.BatchNorm2d(out_channels)
        if use_activation:
            self.relu = nn.ReLU6()

    def forward(self, x):
        """Apply the convolution, then whichever optional stages are enabled."""
        out = self.conv1(x)
        out = self.norm(out) if self.use_norm else out
        return self.relu(out) if self.use_activation else out

In [29]:
class Inverted_Res_Block(nn.Module):
    """Inverted residual block built from three ``Conv`` layers.

    Layout: a 1x1 expansion (``in_channels -> in_channels * t``), a grouped
    ``Conv`` with its default 3x3 kernel that carries the block's stride, and a
    1x1 projection to ``out_channels`` with no activation (linear bottleneck).
    The input is added back onto the output only when the block keeps both the
    channel count and the spatial resolution.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the projection layer.
        t: expansion factor; the hidden width is ``in_channels * t``.
        stride: stride of the grouped conv, given as an int or a sequence of
            ints.
    """
    def __init__(self , 
                 in_channels , 
                 out_channels , 
                 t , 
                 stride = (1 , 1)):
        super(Inverted_Res_Block , self).__init__()

        # `stride` may be an int or a tuple (the default). Comparing a tuple
        # with 1 is always False, so normalise it before deciding whether the
        # skip connection applies.
        strides = (stride ,) if isinstance(stride , int) else tuple(stride)
        self.use_residual = in_channels == out_channels and all(s == 1 for s in strides)

        hidden_dim = in_channels * t
        self.conv1 = Conv(in_channels , hidden_dim , kernel_size=(1 , 1) , stride=(1 , 1) , padding=0)
        self.conv2 = Conv(hidden_dim , hidden_dim , stride=stride , use_grps=True)
        # Linear bottleneck: the projection keeps its normalisation but has no
        # ReLU6, so the block output is not clamped to [0, 6].
        self.conv3 = Conv(hidden_dim , out_channels , kernel_size=(1 , 1) , stride=(1 , 1) , padding=0 , 
                          use_activation=False)

    def forward(self , x):
        """Run expand -> grouped conv -> project, adding the input when enabled."""
        out = self.conv3(self.conv2(self.conv1(x)))
        # The layers produce new tensors rather than writing into `x`, so the
        # input can be reused for the skip connection without a clone.
        if self.use_residual:
            return out + x
        return out

In [None]:
# Smoke test: push a random batch through a single stride-2 block and show
# the shape of what comes out.
x = torch.randn(2, 32, 112, 112)
x = x.to(device)
inverted_res_block = Inverted_Res_Block(in_channels=32, out_channels=16, t=6, stride=2)
inverted_res_block = inverted_res_block.to(device)
z = inverted_res_block(x)
z.shape

In [31]:
# Bottleneck stages of the network, one row per stage. Each row is a list
# (rows are unpacked positionally) holding:
#   t            - expansion factor of the stage's blocks
#   out_channels - output channels of the stage
#   repeats      - number of blocks in the stage
#   stride       - stride of the stage
# The 24-channel stage uses repeats=2, stride=2, matching the reference
# MobileNetV2 stage table.
config = [
          # t , out_channels , repeats , stride
          [1 , 16 , 1 , 1] , 
          [6 , 24 , 2 , 2] , 
          [6 , 32 , 3 , 2] , 
          [6 , 64 , 4 , 2] , 
          [6 , 96 , 3 , 1] , 
          [6 , 160 , 3 , 2] , 
          [6 , 320 , 1 , 1] , 
]

In [32]:
class MobileNet_v2(nn.Module):
    """MobileNetV2-style feature extractor built from ``Conv`` and ``Inverted_Res_Block``.

    A stride-2 ``Conv`` stem (``in_channels -> 32``) is followed by the stages
    listed in ``config``, global average pooling, and a final 1x1 ``Conv``
    that emits ``out_channels_model`` features per sample.

    Args:
        in_channels: channels of the input tensor fed to the stem.
        out_channels_model: number of features returned per sample.
        config: iterable of ``[t, out_channels, repeats, stride]`` lists, one
            per stage. Entries that are not lists are skipped.

    NOTE(review): pooling runs before ``conv_last``, so that layer (including
    its BatchNorm) sees a 1x1 map. The reference MobileNetV2 applies the 1x1
    conv before pooling - confirm this ordering is intended.
    """
    def __init__(self , 
                 in_channels , 
                 out_channels_model , 
                 config = config):
        super(MobileNet_v2 , self).__init__()

        self.layers = nn.ModuleList()

        # Stem: honour the caller's channel count instead of assuming RGB.
        self.layers.append(Conv(in_channels , 32 , stride=(2 , 2)))
        in_channels = 32
        for layer in config:
            if isinstance(layer , list):
                t , out_channels , repeats , stride = layer
                for block_idx in range(repeats):
                    # Only the first block of a stage downsamples. Repeating
                    # the stride would shrink the map once per repeat.
                    block_stride = stride if block_idx == 0 else 1
                    self.layers.append(
                        Inverted_Res_Block(
                            in_channels , 
                            out_channels , 
                            t , 
                            block_stride
                        )
                    )
                    # Later blocks of the stage consume what this one produced.
                    in_channels = out_channels

        self.adp_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_last = Conv(in_channels , out_channels_model , kernel_size=(1 , 1) , stride=(1 , 1) , padding=0)

    def forward(self , x):
        """Return a ``(batch, out_channels_model)`` tensor for a 4-D input batch."""
        for layer in self.layers:
            x = layer(x)
        x = self.adp_avg_pool(x)
        x = self.conv_last(x)
        # (batch, C, 1, 1) -> (batch, C)
        return x.flatten(1)

In [None]:
# Smoke test: run a random 224x224 three-channel batch through the full
# network and show the shape of the resulting feature tensor.
x = torch.randn(2, 3, 224, 224)
x = x.to(device)
mobile_net = MobileNet_v2(in_channels=3, out_channels_model=1280)
mobile_net = mobile_net.to(device)
z = mobile_net(x)
z.shape