# Set up the environment

## Import

In [4]:
# NN stuff
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from datetime import datetime
from typing import Tuple, List, Optional
from dataclasses import dataclass, fields, astuple

# Pretty-printing, progress bars and plotting
from pprint import pprint
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import json
import random
import numpy as np
import pandas as pd

import csv
import os

## Path

In [None]:
# Detect whether the notebook runs on Google Colab. `get_ipython` only exists
# inside an IPython kernel, so treat its absence (plain `python` run) as
# "not on Colab" instead of crashing with NameError.
try:
    _on_colab = 'google.colab' in str(get_ipython())
except NameError:
    _on_colab = False

if _on_colab:
    print('Running on CoLab')
    root_drive = "/content/drive/MyDrive/Colab Notebooks/MobileNetV3/"
    from google.colab import drive
    drive.mount('/content/drive')
else:
    print('Not running on CoLab')
    root_drive = './'


# JSON files holding the per-layer specification of each MobileNetV3 variant.
spec_small_path = os.path.join(root_drive, "specification", "mobilenetv3-small.json")
spec_large_path = os.path.join(root_drive, "specification", "mobilenetv3-large.json")

## Utils

In [7]:
# MobileNet Specification
@dataclass
class MobileNetSpecification:
    '''Specification of one MobileNetV3 bottleneck layer (one row of a spec file).

    ``__post_init__`` coerces the numeric fields with ``int`` and ``se`` to a
    bool, so rows whose values arrive as strings are normalised.

    Attributes:
        kernel: kernel size of the depthwise convolution.
        exp_size: number of expansion channels.
        out: number of output channels.
        se: whether the layer uses squeeze-and-excitation.
        nl: name of the non-linearity (compared against ``'relu'`` by BottleNeck).
        stride: stride of the depthwise convolution.
    '''
    kernel: int
    exp_size: int
    out: int
    se: bool
    nl: str
    stride: int

    # Lower-cased spellings of "no" accepted for ``se`` when it is given as
    # text. A plain ``bool(...)`` would turn every non-empty string, including
    # "False", "0" or "-", into True. Not annotated, so not a dataclass field.
    _FALSE_STRINGS = frozenset({'', '0', 'false', 'f', 'no', 'n', 'none', 'off', '-'})

    def __post_init__(self):
        # Called by the generated __init__ after all fields are assigned.
        self.kernel = int(self.kernel)
        self.exp_size = int(self.exp_size)
        self.out = int(self.out)
        if isinstance(self.se, str):
            self.se = self.se.strip().lower() not in self._FALSE_STRINGS
        else:
            self.se = bool(self.se)
        self.stride = int(self.stride)

    def __iter__(self):
        '''Yield the field values in declaration order.'''
        yield from astuple(self)

    @classmethod
    def get_header(cls) -> List[str]:
        '''Return the field names in declaration order (e.g. as a CSV header).'''
        return [field.name for field in fields(cls)]

In [None]:
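# TODO(review): h-swish (`Hswish`) and `SELayer`, used by `BottleNeck`, are not defined in this part of the notebook — add them here.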

# MobileNetV3

In [None]:
class BottleNeck(nn.Module):
    """MobileNetV3 bottleneck block.

    Layout: 1x1 expansion conv -> BN -> non-linearity, then a depthwise
    ``spec.kernel`` x ``spec.kernel`` conv (stride ``spec.stride``) -> BN ->
    optional squeeze-and-excitation -> non-linearity, then a linear 1x1
    projection conv -> BN. The input is added back when the block keeps both
    the spatial resolution (stride 1) and the channel count.

    Args:
        input_size: number of input channels.
        out_size: number of output channels.
        spec: layer specification; ``kernel``, ``exp_size``, ``se``, ``nl`` and
            ``stride`` are read (the output width comes from ``out_size``).
    """

    def __init__(self, input_size: int, out_size: int, spec: "MobileNetSpecification"):
        super().__init__()
        # Keeps the spatial size unchanged for odd kernels at stride 1.
        padding = (spec.kernel - 1) // 2
        self.use_res_connect = spec.stride == 1 and input_size == out_size

        # PointWise (expansion)
        self.conv2d_pw = nn.Conv2d(input_size, spec.exp_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.batch_norm = nn.BatchNorm2d(spec.exp_size)
        # One activation instance is applied after both convs; assumes it holds
        # no per-layer state (true for nn.ReLU).
        self.non_lin = nn.ReLU(inplace=True) if spec.nl == 'relu' else Hswish(inplace=True)

        # DepthWise
        self.conv2d_dw = nn.Conv2d(spec.exp_size, spec.exp_size, spec.kernel, spec.stride, padding,
                                   groups=spec.exp_size, bias=False)
        # Separate BN: the depthwise output must not share running statistics
        # or affine parameters with the expansion BN.
        self.batch_norm_dw = nn.BatchNorm2d(spec.exp_size)
        self.squeeze_ex = SELayer(spec.exp_size) if spec.se else nn.Identity()

        # PointWise-linear (projection, no activation)
        self.conv2d_pw_linear = nn.Conv2d(spec.exp_size, out_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.batch_norm_linear = nn.BatchNorm2d(out_size)

    def forward(self, x):
        # PointWise
        out = self.conv2d_pw(x)
        out = self.batch_norm(out)
        out = self.non_lin(out)

        # DepthWise
        out = self.conv2d_dw(out)
        out = self.batch_norm_dw(out)
        out = self.squeeze_ex(out)
        out = self.non_lin(out)

        # PointWise-linear
        out = self.conv2d_pw_linear(out)
        out = self.batch_norm_linear(out)

        return x + out if self.use_res_connect else out

In [None]:
class MobileNetV3(nn.Module):
	"""MobileNetV3 network (work in progress).

	So far it loads the per-layer specifications of the requested variant
	and can build the stack of bottleneck blocks; no stem, head or
	``forward`` is defined here yet.

	Args:
		mode: ``'small'`` or ``'large'``; selects ``spec_small_path`` or
			``spec_large_path`` as the specification file.
	"""

	def __init__(self, mode='small'):
		super().__init__()
		self.mode = mode
		# Filled by ``build_mobile_blocks``.
		self.features = nn.ModuleList()
		self.load_specifications()

	def load_specifications(self) -> None:
		"""Read the JSON specification for ``self.mode`` into ``self.specifications``.

		The file must contain a list of objects, one per bottleneck layer.
		Rows whose keys are exactly the ``MobileNetSpecification`` field names
		are matched by name; otherwise their values are used positionally.

		Raises:
			ValueError: if ``self.mode`` is neither ``'small'`` nor ``'large'``.
		"""
		# Reject typos instead of silently falling back to the large model.
		if self.mode not in ('small', 'large'):
			raise ValueError(f"mode must be 'small' or 'large', got {self.mode!r}")
		self.spec_file = spec_small_path if self.mode == 'small' else spec_large_path

		with open(self.spec_file, "r", encoding="utf-8") as spec_f:
			data = json.load(spec_f)

		field_names = set(MobileNetSpecification.get_header())
		self.specifications = [
			MobileNetSpecification(**row) if set(row) == field_names
			else MobileNetSpecification(*row.values())
			for row in data
		]

	@staticmethod
	def _make_divisible(value: float, divisor: int = 8) -> int:
		"""Round ``value`` to the nearest multiple of ``divisor`` (never below
		``divisor``), moving up one step if rounding lost more than 10%."""
		new_value = max(divisor, int(value + divisor / 2) // divisor * divisor)
		if new_value < 0.9 * value:
			new_value += divisor
		return new_value

	def build_mobile_blocks(self, input_channel: int = 16, width_mult: float = 1.0) -> int:
		"""Append one ``BottleNeck`` per specification row to ``self.features``.

		Args:
			input_channel: channels entering the first bottleneck. The default
				of 16 is an assumption — NOTE(review): keep it in sync with
				the stem convolution once that exists.
			width_mult: multiplier for every expansion and output width; the
				scaled widths are rounded with ``_make_divisible``.

		Returns:
			Channel count produced by the last bottleneck (``input_channel``
			when there are no specifications).
		"""
		from dataclasses import replace

		for spec in self.specifications:
			output_channel = self._make_divisible(spec.out * width_mult)
			exp_channel = self._make_divisible(spec.exp_size * width_mult)
			# BottleNeck reads the expansion width from the spec itself, so
			# hand it a copy carrying the scaled widths.
			scaled_spec = replace(spec, exp_size=exp_channel, out=output_channel)
			self.features.append(BottleNeck(input_channel, output_channel, scaled_spec))
			input_channel = output_channel
		return input_channel
