# Setup the environment

## Import

In [54]:
import os
import json
import torch
import random
import numpy as np
import pandas as pd
from torch import nn
import pytorch_lightning as pl
from torch.nn import functional as F
from typing import Tuple, List, Optional
from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass, fields, astuple

# For nice progress bars and plots
from pprint import pprint
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

ModuleNotFoundError: No module named 'pytorch_lightning'

## Constants

In [49]:
# Single seed shared by every random number generator used in the notebook.
SEED = 1234

# Seed Python, NumPy and PyTorch (in that order) with the same value.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)
del _seed_rng

# Ask cuDNN to pick deterministic algorithms.
torch.backends.cudnn.deterministic = True

In [50]:
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

# At most one GPU is used; the batch size is larger when a GPU is present.
AVAIL_GPUS = min(1, torch.cuda.device_count())
BATCH_SIZE = 512 if AVAIL_GPUS else 64
# os.cpu_count() may return None when the count cannot be determined;
# treat that case as a single CPU instead of raising a TypeError.
NUM_WORKERS = (os.cpu_count() or 1) // 2

cpu


## Path

In [None]:
# `get_ipython` only exists inside an IPython/Jupyter session; when this code
# runs as a plain Python script the name is undefined, which simply means we
# are not on Colab.
try:
    IN_COLAB = 'google.colab' in str(get_ipython())
except NameError:
    IN_COLAB = False

if IN_COLAB:
    print('Running on CoLab')
    root_drive  = "/content/drive/MyDrive/Colab Notebooks/MobileNetV3/"
    from google.colab import drive
    drive.mount('/content/drive')
else:
    print('Not running on CoLab')
    root_drive = './'

# JSON files describing the bottleneck stack of each MobileNetV3 variant.
spec_small_path = root_drive +  "specification/mobilenetv3-small.json"
spec_large_path = root_drive +  "specification/mobilenetv3-large.json"

## Utils

In [55]:
# MobileNet Specification
@dataclass
class BNeckSpecification:
    '''Specification of a single MobileNet bottleneck block.

    Every field is normalised in ``__post_init__``, so values may be passed
    as anything ``int()`` / ``bool()`` accept (for instance strings).
    '''
    kernel: int       # depthwise convolution kernel size
    input_size: int   # number of input channels
    exp_size: int     # number of channels after the expansion convolution
    out_size: int     # number of output channels
    se: bool          # whether a squeeze-and-excitation module is used
    nl: nn.Module     # activation; may be given as a string ("relu" selects ReLU, anything else Hswish)
    stride: int       # depthwise convolution stride

    # The __post_init__ method, will be the last thing called by __init__.
    def __post_init__(self) -> None:
        self.kernel     = int(self.kernel)
        self.input_size = int(self.input_size)
        self.exp_size = int(self.exp_size)
        self.out_size = int(self.out_size)
        self.se  = bool(self.se)
        # Keep an activation that is already a module (e.g. when the spec is
        # rebuilt through dataclasses.replace); otherwise it would be silently
        # replaced by Hswish because a module never compares equal to "relu".
        if not isinstance(self.nl, nn.Module):
            self.nl  = nn.ReLU(inplace=True) if self.nl == "relu" else Hswish(inplace=True)
        self.stride = int(self.stride)

    def __iter__(self):
        '''Iterate over the field values in declaration order.'''
        yield from astuple(self)

    @staticmethod
    def get_header() -> List[str]:
        '''Return the field names in declaration order.'''
        return [field.name for field in fields(BNeckSpecification)]

# MobileNetv3

## nn.Modules

In [24]:
class Hsigmoid(nn.Module):
    """Hard sigmoid activation: ``relu6(x + 3) / 6``."""

    def __init__(self, inplace=True) -> None:
        super().__init__()
        # Forwarded to F.relu6, which is applied to the shifted tensor.
        self.inplace = inplace

    def forward(self, x) -> torch.Tensor:
        shifted = x + 3
        clipped = F.relu6(shifted, inplace=self.inplace)
        return clipped / 6

In [25]:
class Hswish(nn.Module):
    """Hard swish activation: the input multiplied by its hard sigmoid."""

    def __init__(self, inplace=True) -> None:
        super().__init__()
        self.hsigmoid = Hsigmoid(inplace)

    def forward(self, x) -> torch.Tensor:
        gate = self.hsigmoid(x)
        return x * gate

In [26]:
class SeModule(nn.Module):
    """Squeeze-and-excitation block that rescales each channel of its input.

    The input is average-pooled to one value per channel, passed through a
    two-layer bottleneck MLP (reduced by ``reduction``) ending in a hard
    sigmoid, and the resulting per-channel weights multiply the input.
    """

    def __init__(self, channel, reduction=4) -> None:
        super().__init__()
        squeezed = channel // reduction

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.se = nn.Sequential(
            nn.Linear(channel, squeezed, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(squeezed, channel, bias=False),
            Hsigmoid(),
        )

    def forward(self, x) -> torch.Tensor:
        # Unpacking four dimensions requires a 4-D input.
        batch, channels, _, _ = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        weights = self.se(pooled).view(batch, channels, 1, 1)
        return x * weights.expand_as(x)



# class SeModule(nn.Module):
#     def __init__(self, in_size, reduction=4):
#         super(SeModule, self).__init__()
#         self.se = nn.Sequential(
#             nn.AdaptiveAvgPool2d(1),
#             nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False),
#             nn.BatchNorm2d(in_size // reduction),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False),
#             nn.BatchNorm2d(in_size),
#             hsigmoid()
#         )

#     def forward(self, x):
#         return x * self.se(x)

In [39]:
class BottleNeck(nn.Module):
    """MobileNet bottleneck: 1x1 expansion, depthwise conv, 1x1 linear projection.

    Args:
        spec: Block specification (kernel, channel sizes, squeeze-and-excitation
            flag, activation module and stride).

    A residual connection is added when the block keeps both the spatial
    resolution (stride 1) and the number of channels.
    """

    def __init__(self, spec:BNeckSpecification) -> None:
        super(BottleNeck, self).__init__()
        self.spec = spec

        # Padding that preserves the spatial size for odd kernels at stride 1.
        padding = (spec.kernel - 1) // 2
        self.use_res_connect = spec.stride == 1 and spec.input_size == spec.out_size

        # PointWise
        self.conv2d_pw  = nn.Conv2d(spec.input_size, spec.exp_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.batch_norm = nn.BatchNorm2d(spec.exp_size)
        self.non_lin    = spec.nl

        # DepthWise
        self.conv2d_dw  = nn.Conv2d(spec.exp_size, spec.exp_size, spec.kernel, spec.stride, padding, groups=spec.exp_size, bias=False)
        # The depthwise stage needs its own normalisation layer: reusing
        # `batch_norm` would share affine parameters and running statistics
        # between two different layers.
        self.batch_norm_dw = nn.BatchNorm2d(spec.exp_size)
        # Only create the squeeze-and-excitation weights when the block uses
        # them, so no unused parameters are registered.
        self.squeeze_ex = SeModule(spec.exp_size) if spec.se else nn.Identity()

        # PointWise-linear
        self.conv2d_pw_linear  = nn.Conv2d(spec.exp_size, spec.out_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.batch_norm_linear = nn.BatchNorm2d(spec.out_size)

    def forward(self, x) -> torch.Tensor:
        # PointWise
        out = self.conv2d_pw(x)
        out = self.batch_norm(out)
        out = self.non_lin(out)

        # DepthWise
        out = self.conv2d_dw(out)
        out = self.batch_norm_dw(out)
        out = self.squeeze_ex(out)  # identity when spec.se is False
        out = self.non_lin(out)

        # PointWise-linear (no activation after the projection)
        out = self.conv2d_pw_linear(out)
        out = self.batch_norm_linear(out)

        if self.use_res_connect:
            out = x + out

        return out

In [46]:
def conv2d_block(input_size:int, output_size:int, kernel_size:int, stride:int=1, padding:Optional[int]=None) -> nn.Sequential:
    """Build a convolution -> batch norm -> hard-swish block.

    Args:
        input_size: Number of input channels.
        output_size: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding. When ``None`` it defaults to
            ``(kernel_size - 1) // 2``, the same rule the bottleneck blocks
            use, so an odd kernel no longer drops border pixels. A 1x1 kernel
            still gets a padding of 0.

    Returns:
        The three layers wrapped in an ``nn.Sequential``.
    """
    if padding is None:
        padding = (kernel_size - 1) // 2
    return nn.Sequential(
        nn.Conv2d(input_size, output_size, kernel_size, stride, padding=padding, bias=False),
        nn.BatchNorm2d(output_size),
        Hswish(inplace=True),
    )

In [47]:
class MobileNetV3(nn.Module):
    """MobileNetV3 classifier assembled from a JSON bottleneck specification.

    Args:
        num_class: Number of outputs of the final linear layer.
        dropout: Drop probability of the dropout layer placed before the
            final linear layer.
        mode: ``'small'`` or ``'large'``; selects the specification file and
            the widths of the last layers.

    Raises:
        ValueError: If ``mode`` is neither ``'small'`` nor ``'large'``.
    """

    # mode -> (channels of the last conv block, channels fed to the classifier)
    _LAST_LAYER_SIZES = {'small': (576, 1024), 'large': (960, 1280)}

    def __init__(self, num_class:int, dropout:float, mode='small') -> None:
        super(MobileNetV3, self).__init__()
        # Fail early with a clear message: an unknown mode would otherwise
        # load the large specification and then crash while building the
        # last layers.
        if mode not in self._LAST_LAYER_SIZES:
            raise ValueError(f"mode must be 'small' or 'large', got {mode!r}")
        self.num_class = num_class
        self.mode = mode

        # Load specifications from file
        self.bneck_specs = self.load_bneck_specs()

        # Generate all the net blocks (collected in a plain list first)
        self.net_blocks = [conv2d_block(input_size=3, output_size=16, kernel_size=3, stride=2)]
        self.build_bneck_blocks()
        self.build_last_layers()

        # Turn the list into an nn.Sequential so the blocks get registered
        self.net_blocks = nn.Sequential(*self.net_blocks)

        # Building the classifier
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),    # refer to paper section 6
            nn.Linear(self.last_channel, self.num_class),
        )

    def forward(self, x) -> torch.Tensor:
        out = self.net_blocks(x)
        # net_blocks pools to 1x1, so `out` is (N, last_channel, 1, 1);
        # nn.Linear works on the last dimension and needs (N, last_channel).
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out

    def load_bneck_specs(self) -> List[BNeckSpecification]:
        """Read the bottleneck specifications of the selected mode from JSON."""
        if self.mode == 'small':
            self.spec_file = spec_small_path
        else:
            self.spec_file = spec_large_path
        # Load specifications
        with open(self.spec_file, "r", encoding="utf-8") as spec_f:
            data = json.load(spec_f)
        # Values are passed positionally, so each JSON object must list its
        # entries in the field order of BNeckSpecification.
        return [BNeckSpecification(*spec.values()) for spec in data]

    def build_bneck_blocks(self) -> None:
        """Append one BottleNeck per loaded specification to the block list."""
        self.net_blocks.extend(BottleNeck(bneck_spec) for bneck_spec in self.bneck_specs)

    def build_last_layers(self) -> None:
        """Append the layers between the last bottleneck and the classifier."""
        input_channel = self.bneck_specs[-1].out_size # Take the last bottleneck output size
        self.last_conv, self.last_channel = self._LAST_LAYER_SIZES[self.mode]

        self.net_blocks.append(conv2d_block(input_channel, self.last_conv, kernel_size=1))
        self.net_blocks.append(nn.AdaptiveAvgPool2d(1)) # or  out = F.avg_pool2d(out, 7)
        self.net_blocks.append(nn.Conv2d(self.last_conv, self.last_channel, kernel_size=1, stride=1, padding=0))
        self.net_blocks.append(Hswish(inplace=True))
        # NOTE(review): open question from the author - should a final 1x1
        # convolution to num_class replace the linear classifier?
        # self.net_blocks.append(nn.Conv2d(self.last_channel, self.num_class, kernel_size=1, stride=1, padding=0))
        # self.last_channel = self.num_class


## Lightning Module

In [None]:
class MobileNetV3Module(pl.LightningModule):
    """Lightning wrapper that trains a MobileNetV3 classifier.

    Args:
        num_classes: Number of classes predicted by the network.
        dropout: Drop probability forwarded to MobileNetV3.

    NOTE(review): no `configure_optimizers` is defined in this class yet;
    Lightning needs one before the module can be fitted.
    """

    def __init__(self, num_classes=1000, dropout=0.8) -> None:
        super().__init__()

        # The network must be stored on `self`: a plain local is discarded at
        # the end of __init__ and the module would own no parameters.
        self.net = MobileNetV3(num_classes, dropout)

    def forward(self, x) -> torch.Tensor:
        """Return the class scores computed by the wrapped network."""
        return self.net(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss of one training batch."""
        # Assumes each batch is an (inputs, targets) pair with integer class
        # targets - TODO confirm against the DataLoader used for training.
        inputs, targets = batch
        loss = F.cross_entropy(self(inputs), targets)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

## Create the model