# Set up the environment

## Import

In [21]:
# NN stuff
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

from datetime import datetime
from typing import Tuple, List, Optional
from dataclasses import dataclass, fields, astuple

# Pretty-printing, progress bars and plotting
from pprint import pprint
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import json
import random
import numpy as np
import pandas as pd

import csv
import os

## Path

In [None]:
# Decide where the project files live. `get_ipython` is only injected by
# IPython/Jupyter, so the lookup is guarded: when this code is executed as a
# plain Python script the name does not exist and we fall through to the
# local-directory branch instead of crashing with a NameError.
try:
    _shell = get_ipython()
except NameError:
    _shell = None

if 'google.colab' in str(_shell):
    print('Running on CoLab')
    root_drive  = "/content/drive/MyDrive/Colab Notebooks/MobileNetV3/"
    # Imported lazily: the package is only available inside Colab.
    from google.colab import drive
    drive.mount('/content/drive')
else:
    print('Not running on CoLab')
    root_drive = './'


# JSON specification files for the two variants, resolved against root_drive.
spec_small_path = root_drive +  "specification/mobilenetv3-small.json"
spec_large_path = root_drive +  "specification/mobilenetv3-large.json"

## Utils

In [23]:
# MobileNet Specification
@dataclass
class MobileNetSpecification:
    '''Settings of one MobileNet bottleneck block.

    Raw values (for example strings read from a specification file) are
    coerced to their working types in ``__post_init__``.

    Attributes:
        kernel: Kernel size of the depthwise convolution.
        input_size: Number of input channels of the block.
        exp_size: Number of channels of the expanded representation.
        out_size: Number of output channels of the block.
        se: Whether the squeeze-and-excitation module is applied.
        nl: Activation module. The constructor also accepts a name:
            ``"relu"`` selects ``nn.ReLU``; any other non-module value
            selects ``Hswish``.
        stride: Stride of the depthwise convolution.
    '''
    kernel: int
    input_size: int
    exp_size: int
    out_size: int
    se: bool
    nl: nn.Module
    stride: int

    # The __post_init__ method, will be the last thing called by __init__.
    def __post_init__(self):
        self.kernel = int(self.kernel)
        self.input_size = int(self.input_size)
        self.exp_size = int(self.exp_size)
        self.out_size = int(self.out_size)
        self.se = self._parse_flag(self.se)
        # Keep an activation that is already a module (e.g. when a spec is
        # rebuilt from another spec's values) instead of replacing it.
        if not isinstance(self.nl, nn.Module):
            self.nl = nn.ReLU(inplace=True) if self.nl == "relu" else Hswish(inplace=True)
        self.stride = int(self.stride)

    @staticmethod
    def _parse_flag(value) -> bool:
        '''Convert a raw flag to bool, honouring string-encoded falsy values.

        ``bool("False")`` is ``True`` because the string is non-empty, so
        textual flags need explicit handling.
        '''
        if isinstance(value, str):
            return value.strip().lower() not in ("", "0", "false", "no", "n", "off")
        return bool(value)

    def __iter__(self):
        # astuple() copies the field values (recursively), so callers receive
        # copies rather than the stored objects.
        yield from astuple(self)

    @staticmethod
    def get_header() -> List[str]:
        '''Return the field names in declaration order.'''
        return [field.name for field in fields(MobileNetSpecification)]

In [None]:
#swish

# MobileNetv3

In [24]:
class Hsigmoid(nn.Module):
    '''Hard-sigmoid activation: ``relu6(x + 3) / 6``.'''

    def __init__(self, inplace=True):
        super().__init__()
        # Forwarded to F.relu6 on every call.
        self.inplace = inplace

    def forward(self, x):
        shifted = x + 3
        clipped = F.relu6(shifted, inplace=self.inplace)
        return clipped / 6

In [25]:
class Hswish(nn.Module):
    '''Hard-swish activation: the input multiplied by its hard-sigmoid gate.'''

    def __init__(self, inplace=True):
        super().__init__()
        self.hsigmoid = Hsigmoid(inplace)

    def forward(self, x):
        gate = self.hsigmoid(x)
        return x * gate

In [26]:
class SeModule(nn.Module):
    '''Squeeze-and-excitation: rescale each channel by a learned gate.

    The input is average-pooled to one value per channel, passed through a
    two-layer bottleneck (``channel -> channel // reduction -> channel``)
    ending in a hard-sigmoid, and the result multiplies the input
    channel-wise.
    '''

    def __init__(self, channel, reduction=4):
        super().__init__()
        squeezed = channel // reduction

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.se = nn.Sequential(
            nn.Linear(channel, squeezed, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(squeezed, channel, bias=False),
            Hsigmoid(),
        )

    def forward(self, x):
        # Unpacking four values requires a 4-D input, as the pooling and
        # the final view below do.
        batch, channels, _height, _width = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        gate = self.se(pooled).view(batch, channels, 1, 1)
        return x * gate.expand_as(x)



# class SeModule(nn.Module):
#     def __init__(self, in_size, reduction=4):
#         super(SeModule, self).__init__()
#         self.se = nn.Sequential(
#             nn.AdaptiveAvgPool2d(1),
#             nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False),
#             nn.BatchNorm2d(in_size // reduction),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False),
#             nn.BatchNorm2d(in_size),
#             hsigmoid()
#         )

#     def forward(self, x):
#         return x * self.se(x)

In [28]:
class Bottleneck(nn.Module):
    '''Inverted-residual block: 1x1 expansion, depthwise conv, 1x1 projection.

    Args:
        input_size: Number of input channels.
        out_size: Number of output channels.
        spec: Block settings; ``kernel``, ``stride``, ``exp_size``, ``se``
            and ``nl`` are read from it.

    The input is added to the output when the block keeps the spatial size
    (``stride == 1``) and the channel count (``input_size == out_size``).
    '''

    def __init__(self, input_size:int, out_size:int, spec:"MobileNetSpecification"):
        super(Bottleneck, self).__init__()
        self.spec = spec

        # "Same"-style padding for an odd kernel size.
        padding = (spec.kernel - 1) // 2
        self.use_res_connect = spec.stride == 1 and input_size == out_size

        # PointWise
        self.conv2d_pw  = nn.Conv2d(input_size, spec.exp_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.batch_norm = nn.BatchNorm2d(spec.exp_size)
        self.non_lin    = spec.nl

        # DepthWise
        self.conv2d_dw  = nn.Conv2d(spec.exp_size, spec.exp_size, spec.kernel, spec.stride, padding, groups=spec.exp_size, bias=False)
        # The depthwise stage gets its own BatchNorm: sharing the pointwise
        # one would tie the affine parameters and mix the running statistics
        # of two different activations.
        self.batch_norm_dw = nn.BatchNorm2d(spec.exp_size)
        # Only allocate the squeeze-and-excitation weights when they are used.
        self.squeeze_ex = SeModule(spec.exp_size) if spec.se else nn.Identity()

        # PointWise-linear
        self.conv2d_pw_linear  = nn.Conv2d(spec.exp_size, out_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.batch_norm_linear = nn.BatchNorm2d(out_size)

    def forward(self, x):
        # PointWise
        out = self.conv2d_pw(x)
        out = self.batch_norm(out)
        out = self.non_lin(out)

        # DepthWise
        out = self.conv2d_dw(out)
        out = self.batch_norm_dw(out)
        out = self.squeeze_ex(out)
        out = self.non_lin(out)

        # PointWise-linear (no activation after the projection)
        out = self.conv2d_pw_linear(out)
        out = self.batch_norm_linear(out)

        if self.use_res_connect:
            out = x + out

        return out

In [29]:
class MobileNetV3(nn.Module):
    '''MobileNetV3 layer stack assembled from a JSON specification file.

    Args:
        input_channel: Currently unused; it was meant for the stem
            convolution, which is not built yet (see the TODO in __init__).
        mode: ``'small'`` loads the small specification file; any other
            value loads the large one.

    The layers are collected in ``self.features``. No ``forward`` method and
    no classifier are defined yet.
    '''

    def __init__(self, input_channel=16,  mode='small'):
        super(MobileNetV3, self).__init__()
        self.mode = mode

        self.specifications = self.load_specifications()

        # An nn.ModuleList (not a plain list) so that the parameters of the
        # appended layers are registered with this module.
        self.features = nn.ModuleList()
        # TODO: the stem is still missing; the earlier draft was
        # self.features = [conv_bn(3, input_channel, 2, nlin_layer=Hswish)]
        self.build_mobile_blocks()
        # NOTE(review): 96 / 576 look like the 'small' variant's channel
        # counts — confirm the values to use when mode == 'large'.
        self.build_last_layers(input_channel=96, last_channel=576)
        self.classifier = []

    def load_specifications(self) -> List[MobileNetSpecification]:
        '''Read the specification file selected by ``self.mode``.

        Also records the chosen path in ``self.spec_file``.
        '''
        self.spec_file = spec_small_path if self.mode == 'small' else spec_large_path

        # Load specifications
        with open(self.spec_file, "r") as spec_f:
            data = json.load(spec_f)

        # Values are passed positionally, so the key order of every JSON
        # object must match the field order of MobileNetSpecification.
        return [MobileNetSpecification(*spec.values()) for spec in data]

    def build_mobile_blocks(self):
        '''Append one Bottleneck per loaded specification to the features.'''
        for spec in self.specifications:
            self.features.append(Bottleneck(spec.input_size, spec.out_size, spec))

    @staticmethod
    def _conv_1x1_bn(in_channels, out_channels):
        '''1x1 convolution + BatchNorm + hard-swish.

        NOTE(review): stands in for the `conv_1x1_bn` helper named in the
        earlier commented-out draft, whose definition is not in this file —
        confirm it matches the intended layer.
        '''
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(out_channels),
            Hswish(inplace=True),
        )

    def build_last_layers(self, input_channel, last_channel):
        '''Append the layers that follow the bottleneck blocks.

        Follows the earlier commented-out draft of this method. Nothing is
        appended when ``self.mode`` is neither ``'large'`` nor ``'small'``.
        '''
        if self.mode == 'large':
            # The draft used make_divisible(960 * width_mult); there is no
            # width multiplier in this class, so the base value is used.
            last_conv = 960
            self.features.append(self._conv_1x1_bn(input_channel, last_conv))
            self.features.append(nn.AdaptiveAvgPool2d(1))
            self.features.append(nn.Conv2d(last_conv, last_channel, 1, 1, 0))
            self.features.append(Hswish(inplace=True))
        elif self.mode == 'small':
            self.features.append(self._conv_1x1_bn(input_channel, last_channel))
            self.features.append(nn.AdaptiveAvgPool2d(1))
            # TODO: the draft ended with Conv2d(last_channel, num_classes)
            # + Hswish, but this class has no num_classes parameter yet.


In [30]:
# Build the network with its default arguments (mode='small').
mn3 = MobileNetV3()

[MobileNetSpecification(kernel=3, exp_size=16, out=16, se=True, nl=ReLU(inplace=True), stride=2), MobileNetSpecification(kernel=3, exp_size=72, out=24, se=True, nl=ReLU(inplace=True), stride=2), MobileNetSpecification(kernel=3, exp_size=88, out=24, se=True, nl=ReLU(inplace=True), stride=1), MobileNetSpecification(kernel=5, exp_size=96, out=40, se=True, nl=Hswish(
  (hsigmoid): Hsigmoid()
), stride=2), MobileNetSpecification(kernel=5, exp_size=240, out=40, se=True, nl=Hswish(
  (hsigmoid): Hsigmoid()
), stride=1), MobileNetSpecification(kernel=5, exp_size=240, out=40, se=True, nl=Hswish(
  (hsigmoid): Hsigmoid()
), stride=1), MobileNetSpecification(kernel=5, exp_size=120, out=48, se=True, nl=Hswish(
  (hsigmoid): Hsigmoid()
), stride=1), MobileNetSpecification(kernel=5, exp_size=144, out=48, se=True, nl=Hswish(
  (hsigmoid): Hsigmoid()
), stride=1), MobileNetSpecification(kernel=5, exp_size=288, out=96, se=True, nl=Hswish(
  (hsigmoid): Hsigmoid()
), stride=2), MobileNetSpecification(ke