In [133]:
import torch
from torch import nn
from torchinfo import summary

In [134]:
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel gate as used in MobileNetV3.

    Global-average-pools each channel, feeds the pooled vector through a
    two-layer bottleneck MLP (ReLU, then Hardsigmoid), and rescales the input
    channels by the resulting per-channel weights in [0, 1].

    Args:
        in_channles: number of channels C of the input feature map. (The
            misspelled name is kept so existing keyword callers keep working.)
        r: reduction ratio; the hidden layer has ``C // r`` units, clamped to
            at least 1 so that ``C < r`` does not create a zero-width layer.
    """
    def __init__(self, in_channles, r=4):
        super().__init__()

        # Clamp so a narrow input (C < r) still gets a usable 1-unit bottleneck
        # instead of an empty Linear layer whose gate would ignore the input.
        hidden = max(1, in_channles // r)

        self.squeeze = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Linear(in_channles, hidden),
            nn.ReLU(),
            nn.Linear(hidden, in_channles),
            nn.Hardsigmoid()
        )

    def forward(self, x):
        """Scale ``x`` channel-wise; expects (N, C, H, W) and returns the same shape."""
        SE = self.squeeze(x)                       # (N, C, 1, 1)
        SE = SE.reshape(x.shape[0], x.shape[1])    # (N, C) for the Linear layers
        SE = self.fc(SE)                           # per-channel weights in [0, 1]
        SE = SE.unsqueeze(dim=2).unsqueeze(dim=3)  # (N, C, 1, 1) to broadcast over H, W
        return x * SE

In [135]:
class DepSepConv(nn.Module):
    """Depthwise-separable convolution stage of a MobileNetV3 bottleneck.

    Pipeline: depthwise k x k conv (given stride, padding ``(k - 1) // 2``) ->
    BatchNorm -> Hardswish if ``use_hs == True`` else ReLU -> optional SEBlock
    (r=4) when ``use_se == True`` -> 1x1 pointwise projection -> BatchNorm.
    No activation follows the projection.
    """
    def __init__(self, in_channels, out_channels, kernel_size, use_se, use_hs, stride):
        super().__init__()

        half_kernel = (kernel_size - 1) // 2
        if use_hs == True:
            nonlinearity = nn.Hardswish()
        else:
            nonlinearity = nn.ReLU()

        self.depthwise = nn.Sequential(
            nn.Conv2d(
                in_channels,
                in_channels,
                kernel_size,
                stride=stride,
                padding=half_kernel,
                groups=in_channels,  # one filter per input channel
                bias=False,          # BatchNorm follows, so a conv bias is redundant
            ),
            nn.BatchNorm2d(in_channels),
            nonlinearity,
        )

        if use_se == True:
            self.SE_block = SEBlock(in_channels, r=4)
        else:
            self.SE_block = None

        # Pointwise projection; MobileNetV3 applies no activation here.
        self.seperate = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        features = self.depthwise(x)
        if self.SE_block is not None:
            features = self.SE_block(features)
        return self.seperate(features)

In [136]:
class InvertedBottleneck(nn.Module):
    """MobileNetV3 inverted-residual ("bneck") block.

    Optionally expands ``in_channels`` -> ``exp_channels`` with a 1x1 conv,
    BatchNorm and activation (skipped when the two widths are equal), then runs
    a DepSepConv that filters at ``exp_channels`` and projects to
    ``out_channels``. An identity shortcut is added only when ``stride == 1``
    and ``in_channels == out_channels``.
    """
    def __init__(self, in_channels, exp_channels, out_channels, kernel_size, use_se, use_hs, stride):
        super().__init__()

        stages = []
        if in_channels != exp_channels:
            if use_hs:
                expand_act = nn.Hardswish()
            else:
                expand_act = nn.ReLU()
            stages.append(nn.Sequential(
                nn.Conv2d(in_channels, exp_channels, 1, bias=False),
                nn.BatchNorm2d(exp_channels),
                expand_act
            ))
        stages.append(
            DepSepConv(exp_channels, out_channels, kernel_size,
                       use_se=use_se, use_hs=use_hs, stride=stride)
        )
        self.residual = nn.Sequential(*stages)

        # The shortcut is only shape-compatible when neither resolution nor width changes.
        self.use_skip = bool(stride == 1 and in_channels == out_channels)

    def forward(self, x):
        branch = self.residual(x)
        return x + branch if self.use_skip else branch

In [137]:
class MobileNetV3(nn.Module):
    """MobileNetV3 classifier assembled from a per-block configuration table.

    Layout: 3x3 stride-2 stem (3 -> 16 channels, BatchNorm, Hardswish) -> one
    InvertedBottleneck per config row -> 1x1 conv to the last row's expansion
    width (BatchNorm, Hardswish) -> global average pool -> Linear -> Hardswish
    -> Linear.

    Args:
        configs: non-empty iterable of rows ``(k, exp, c, use_se, use_hs, s)``:
            kernel size, expansion channels, output channels, SE flag,
            Hardswish flag and stride for each bottleneck.
        last_channels: hidden width of the classifier head.
        num_classes: number of output logits.

    Raises:
        ValueError: if ``configs`` is empty.
    """
    def __init__(self, configs, last_channels, num_classes=1000):
        super().__init__()

        configs = list(configs)  # accept any iterable; we index configs[-1] below
        if not configs:
            raise ValueError("configs must contain at least one bottleneck row")

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.Hardswish()
        )

        in_channels = 16  # stem output width
        layers = []
        for k, exp, c, use_se, use_hs, s in configs:
            layers.append(InvertedBottleneck(in_channels, exp, c, k, use_se=use_se, use_hs=use_hs, stride=s))
            in_channels = c
        self.layers = nn.Sequential(*layers)

        # The final 1x1 conv widens to the last bottleneck's expansion width.
        # Read it explicitly from configs instead of relying on a loop variable
        # that only exists if the loop body ran.
        last_exp_channels = configs[-1][1]

        self.last_conv = nn.Sequential(
            nn.Conv2d(in_channels, last_exp_channels, 1, bias=False),
            nn.BatchNorm2d(last_exp_channels),
            nn.Hardswish()
        )

        self.GlobalAvgPool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Linear(last_exp_channels, last_channels),
            nn.Hardswish(),
            nn.Linear(last_channels, num_classes)
        )

    def forward(self, x):
        """Map an (N, 3, H, W) batch to logits of shape (N, num_classes)."""
        x = self.conv1(x)
        x = self.layers(x)
        x = self.last_conv(x)

        x = self.GlobalAvgPool(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)

        return x

In [138]:
def MobileNetV3_Large(num_classes=1000):
    """Build the MobileNetV3-Large network.

    Args:
        num_classes: number of output logits.

    Returns:
        A ``MobileNetV3`` with 15 bottleneck blocks and a 1280-wide classifier
        hidden layer.
    """
    large_blocks = [
        # kernel, exp, out,  SE,    HS,    stride
        (3,  16,  16, False, False, 1),
        (3,  64,  24, False, False, 2),
        (3,  72,  24, False, False, 1),
        (5,  72,  40, True,  False, 2),
        (5, 120,  40, True,  False, 1),
        (5, 120,  40, True,  False, 1),
        (3, 240,  80, False, True,  2),
        (3, 200,  80, False, True,  1),
        (3, 184,  80, False, True,  1),
        (3, 184,  80, False, True,  1),
        (3, 480, 112, True,  True,  1),
        (3, 672, 112, True,  True,  1),
        (5, 672, 160, True,  True,  2),
        (5, 960, 160, True,  True,  1),
        (5, 960, 160, True,  True,  1),
    ]
    return MobileNetV3(
        large_blocks,
        last_channels=1280,
        num_classes=num_classes,
    )

def MobileNetV3_Small(num_classes=1000):
    """Build the MobileNetV3-Small network.

    Args:
        num_classes: number of output logits; forwarded to the classifier head.

    Returns:
        A ``MobileNetV3`` with 11 bottleneck blocks and a 1024-wide classifier
        hidden layer.
    """
    configs = [
                #k,  exp, c,  SE,    HS,    s
                [3,  16,  16, True,  False, 2],
                [3,  72,  24, False, False, 2],
                [3,  88,  24, False, False, 1],
                [5,  96,  40, True,  True,  2],
                [5, 240,  40, True,  True,  1],
                [5, 240,  40, True,  True,  1],
                [5, 120,  48, True,  True,  1],
                [5, 144,  48, True,  True,  1],
                [5, 288,  96, True,  True,  2],
                [5, 576,  96, True,  True,  1],
                [5, 576,  96, True,  True,  1]
        ]

    # Pass the caller's num_classes through so the head matches the request.
    return MobileNetV3(configs, last_channels=1024, num_classes=num_classes)

In [139]:
# Build MobileNetV3-Large (1000 classes) and print a torchinfo layer summary
# for an input of shape (2, 3, 224, 224).
model = MobileNetV3_Large(num_classes=1000)
summary(model, (2, 3, 224, 224))

Layer (type:depth-idx)                                  Output Shape              Param #
MobileNetV3                                             [2, 1000]                 --
├─Sequential: 1-1                                       [2, 16, 112, 112]         --
│    └─Conv2d: 2-1                                      [2, 16, 112, 112]         432
│    └─BatchNorm2d: 2-2                                 [2, 16, 112, 112]         32
│    └─Hardswish: 2-3                                   [2, 16, 112, 112]         --
├─Sequential: 1-2                                       [2, 160, 7, 7]            --
│    └─InvertedBottleneck: 2-4                          [2, 16, 112, 112]         --
│    │    └─Sequential: 3-1                             [2, 16, 112, 112]         464
│    └─InvertedBottleneck: 2-5                          [2, 24, 56, 56]           --
│    │    └─Sequential: 3-2                             [2, 24, 56, 56]           3,440
│    └─InvertedBottleneck: 2-6                         

In [140]:
# Build MobileNetV3-Small (1000 classes) and print a torchinfo layer summary
# for an input of shape (2, 3, 224, 224). Note: this rebinds the global `model`.
model = MobileNetV3_Small(num_classes=1000)
summary(model, (2, 3, 224, 224))

Layer (type:depth-idx)                                  Output Shape              Param #
MobileNetV3                                             [2, 1000]                 --
├─Sequential: 1-1                                       [2, 16, 112, 112]         --
│    └─Conv2d: 2-1                                      [2, 16, 112, 112]         432
│    └─BatchNorm2d: 2-2                                 [2, 16, 112, 112]         32
│    └─Hardswish: 2-3                                   [2, 16, 112, 112]         --
├─Sequential: 1-2                                       [2, 96, 7, 7]             --
│    └─InvertedBottleneck: 2-4                          [2, 16, 56, 56]           --
│    │    └─Sequential: 3-1                             [2, 16, 56, 56]           612
│    └─InvertedBottleneck: 2-5                          [2, 24, 28, 28]           --
│    │    └─Sequential: 3-2                             [2, 24, 28, 28]           3,864
│    └─InvertedBottleneck: 2-6                         