# Setup the environment

## Import

In [4]:
# NN stuff
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from datetime import datetime
from typing import Tuple, List, Optional
from dataclasses import dataclass, fields, astuple

# Progress bars (tqdm), pretty-printing (pprint) and plotting (matplotlib)
from pprint import pprint
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import json
import random
import numpy as np
import pandas as pd

import csv
import os

## Path

In [None]:
# Detect whether the notebook runs on Google Colab. `get_ipython` is only
# injected into builtins by IPython/Jupyter, so guard against NameError to keep
# this cell usable when the code is executed as a plain Python script.
try:
    _ipython_shell = str(get_ipython())
except NameError:
    _ipython_shell = ''

if 'google.colab' in _ipython_shell:
    print('Running on CoLab')
    root_drive = "/content/drive/MyDrive/Colab Notebooks/MobileNetV3/"
    from google.colab import drive
    drive.mount('/content/drive')
else:
    print('Not running on CoLab')
    root_drive = './'

# Paths to the JSON specification files of the two MobileNetV3 variants.
spec_small_path = os.path.join(root_drive, "specification", "mobilenetv3-small.json")
spec_large_path = os.path.join(root_drive, "specification", "mobilenetv3-large.json")

## Utils

In [7]:
# MobileNet Specification
@dataclass
class MobileNetSpecification:
    """One row of a MobileNet layer specification.

    The field names mirror the specification file. They presumably correspond
    to the columns of the MobileNetV3 architecture tables (kernel size,
    expansion size, output channels, squeeze-and-excite flag, non-linearity,
    stride). TODO: confirm against the JSON files.

    Values are coerced on construction, so rows whose numbers or flags are
    stored as strings can be passed in directly.
    """
    k: int
    exp_size: int
    out: int
    se: bool
    nl: str
    s: int

    # Case-insensitive strings that mean "no". Any other non-empty string is
    # still treated as true, exactly like plain ``bool(str)`` would.
    # (No annotation, so this is a class constant, not a dataclass field.)
    _FALSE_STRINGS = frozenset({'', 'false', 'f', 'no', 'n', '0', 'none', '-', 'off'})

    @staticmethod
    def _coerce_bool(value) -> bool:
        """Convert *value* to bool, parsing textual flags such as "False".

        ``bool("False")`` and ``bool("0")`` are ``True`` in Python, so strings
        need explicit handling; non-string values fall back to ``bool()``.
        """
        if isinstance(value, str):
            return value.strip().lower() not in MobileNetSpecification._FALSE_STRINGS
        return bool(value)

    def __post_init__(self):
        """Normalise field types; dataclasses call this at the end of ``__init__``."""
        self.k = int(self.k)
        self.exp_size = int(self.exp_size)
        self.out = int(self.out)
        self.se = self._coerce_bool(self.se)
        self.s = int(self.s)

    def __iter__(self):
        """Yield the field values in declaration order (enables tuple unpacking)."""
        yield from astuple(self)

    @staticmethod
    def get_header() -> List[str]:
        """Return the dataclass field names in declaration order."""
        return [field.name for field in fields(MobileNetSpecification)]

In [None]:
# TODO: swish activation (placeholder cell; nothing is implemented here yet)

# MobileNetv3

In [None]:
class MobileNetV3(nn.Module):
	"""MobileNetV3 network whose layer configuration is read from a JSON file."""

	def __init__(self, mode='small'):
		"""Load the layer specification for the requested variant.

		Args:
			mode: Variant to build, either ``'small'`` or ``'large'``.

		Raises:
			ValueError: If ``mode`` is neither ``'small'`` nor ``'large'``.
			OSError: If the specification file cannot be opened.
		"""
		super().__init__()
		self.mode = mode

		# An unrecognised mode used to fall back silently to the large spec;
		# reject it so that a typo such as 'smal' fails loudly.
		spec_paths = {'small': spec_small_path, 'large': spec_large_path}
		if mode not in spec_paths:
			raise ValueError(f"mode must be 'small' or 'large', got {mode!r}")
		self.spec_file = spec_paths[mode]

		with open(self.spec_file, "r", encoding="utf-8") as spec_f:
			data = json.load(spec_f)

		# `data` is expected to be a list of mappings, one per specification row.
		field_names = set(MobileNetSpecification.get_header())
		self.specifications = []
		for spec in data:
			if set(spec) == field_names:
				# Keys match the dataclass fields: bind by name so that the
				# key order in the JSON file does not matter.
				self.specifications.append(MobileNetSpecification(**spec))
			else:
				# Fall back to positional binding; this assumes the JSON key
				# order follows the dataclass field order.
				self.specifications.append(MobileNetSpecification(*spec.values()))
