In [1]:
import matplotlib
# Select the Agg backend up front, before the fastai star-imports below run.
matplotlib.use('Agg')
import sys
# Make a local fastai checkout importable. The path is absolute and
# machine-specific; the commented-out line looks like the equivalent path on
# another host (NOTE(review): confirm and consider making this configurable).
# sys.path.append("/home/ec2-user/anaconda3/external/fastai/")
sys.path.append('/scratch/arka/miniconda3/external/fastai/')


from tqdm import tqdm
# NOTE(review): presumably turns off tqdm's background monitor thread
# (a common workaround for progress-bar issues in notebooks) - confirm.
tqdm.monitor_interval = 0

from fastai.imports import *
from fastai.transforms import *
from fastai.conv_learner import *
from fastai.model import *
from fastai.dataset import *
from fastai.sgdr import *
from fastai.plots import *
import json
import pandas as pd
from sklearn.metrics import *

In [2]:
# Render matplotlib figures inline in the cell output.
# NOTE(review): the first cell calls matplotlib.use('Agg'); confirm which
# backend is actually intended.
%matplotlib inline
# (Re)load the autoreload extension and use mode 2, so edited modules are
# re-imported automatically before each cell executes.
%reload_ext autoreload
%autoreload 2

In [3]:
# Root directory of the CIFAR-10 image folders (absolute, machine-specific).
PATH = Path("/scratch/arka/miniconda3/external/fastai/courses/dl2/data/cifar10/")

# The ten class labels.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

# Normalisation statistics handed to tfms_from_stats in get_data.
# Presumably (per-channel mean, per-channel std) of the training images - confirm.
stats = (
    np.array([0.4914, 0.48216, 0.44653]),
    np.array([0.24703, 0.24349, 0.26159]),
)
def get_data(sz,bs):
    """Build the fastai ``ImageClassifierData`` for the images under ``PATH``.

    Args:
        sz: image size passed to ``tfms_from_stats``; the padding is ``sz//8``.
        bs: batch size passed to ``ImageClassifierData.from_paths``.

    Returns:
        The object returned by ``ImageClassifierData.from_paths``, using the
        ``'test'`` folder as the validation set and ``RandomFlip`` as the only
        augmentation.
    """
    augmentations = [RandomFlip()]
    border = sz//8
    transforms_ = tfms_from_stats(stats, sz, aug_tfms=augmentations, pad=border)
    return ImageClassifierData.from_paths(PATH, val_name='test', tfms=transforms_, bs=bs)

In [4]:
def total_model_params(summary):
    """Print and return the total parameter count recorded in a model summary.

    Args:
        summary: mapping of layer name -> per-layer dict that contains an
            ``'nb_params'`` entry convertible with ``int()``.

    Returns:
        int: the sum of ``nb_params`` over all entries (0 for an empty
        summary). The same number is also printed.
    """
    # Iterate the per-layer dicts directly instead of re-indexing by key.
    total_params = sum(int(layer['nb_params']) for layer in summary.values())
    print("Total parameters in the model :"+str(total_params))
    return total_params

In [5]:
# Batch size and image size later passed to get_data(sz, bs).
bs, sz = 64, 32

In [6]:
class expand_depthwise_block(nn.Module):
    """Inverted-residual style block: 1x1 expand -> 3x3 depthwise -> 1x1 project.

    Thanks to https://github.com/kuangliu/pytorch-cifar/blob/master/models/mobilenetv2.py

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
        expansion: multiplier applied to ``in_c`` to get the width of the
            expanded (depthwise) stage.
        stride: stride of the 3x3 depthwise convolution. The shortcut branch
            is only added when ``stride == 1``.
    """
    def __init__(self, in_c, out_c, expansion, stride):
        super().__init__()
        tk = in_c * expansion
        self.stride = stride
        self.l1 = nn.Sequential(
            # 1x1 pointwise expansion
            nn.Conv2d(in_c, tk, 1, bias=False), nn.BatchNorm2d(tk), nn.ReLU6(inplace=True),
            # 3x3 depthwise conv (groups == channels). The block's stride is
            # applied here; without it, stride-2 blocks never downsample.
            nn.Conv2d(tk, tk, 3, stride=stride, groups=tk, padding=1, bias=False),
            nn.BatchNorm2d(tk), nn.ReLU6(inplace=True),
            # 1x1 linear projection (no activation afterwards)
            nn.Conv2d(tk, out_c, 1, bias=False), nn.BatchNorm2d(out_c))
        # Identity shortcut by default; a 1x1 conv + BN when the channel
        # count changes, so that the addition in forward() is shape-compatible.
        self.resd = nn.Sequential()
        if stride == 1 and in_c != out_c:
            self.resd = nn.Sequential(nn.Conv2d(in_c, out_c, 1, bias=False), nn.BatchNorm2d(out_c))

    def forward(self, inp):
        """Apply the block; add the shortcut branch only when stride is 1."""
        out = self.l1(inp)
        if self.stride == 1:
            out = out + self.resd(inp)
        return out

In [7]:
class mblnetv2(nn.Module):
    """MobileNetV2-style classifier assembled from a configurable block type.

    Args:
        block: callable ``block(in_planes, out_planes, expansion, stride)``
            returning an ``nn.Module``.
        inc_scale: divisor applied to ``inc_start``.
        inc_start: the stem convolution outputs ``inc_start // inc_scale``
            channels.
        tuple_list: sequence of ``(expansion, out_planes, num_blocks, stride)``
            tuples, one per stage.
        num_classes: output size of the final linear layer.
    """

    def __init__(self, block, inc_scale, inc_start, tuple_list, num_classes):
        super().__init__()
        self.num_blocks = len(tuple_list)
        # Running channel count; _make_layer updates it as blocks are added.
        self.in_planes = inc_start // inc_scale
        stem_width = self.in_planes
        self.conv1 = nn.Conv2d(3, stem_width, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(stem_width)
        # One nn.Sequential per stage, built in the order given.
        self.lyrs = nn.Sequential(*[
            self._make_layer(block, expf, inc, nb, strl)
            for expf, inc, nb, strl in tuple_list
        ])
        # The classifier input width is the out_planes of the last stage.
        final_width = tuple_list[-1][1]
        self.linear = nn.Linear(final_width, num_classes)

    def _make_layer(self, block, expf, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest use 1."""
        per_block_strides = [stride, *([1] * (num_blocks - 1))]
        stage = []
        for blk_stride in per_block_strides:
            stage.append(block(self.in_planes, planes, expf, blk_stride))
            # Every following block takes this stage's output width as input.
            self.in_planes = planes
        return nn.Sequential(*stage)

    def forward(self, inp):
        """Return log-probabilities of shape ``(batch, num_classes)``."""
        stem = F.relu(self.bn1(self.conv1(inp)))
        feats = self.lyrs(stem)
        pooled = F.adaptive_avg_pool2d(feats, 1)
        flat = pooled.view(pooled.size(0), -1)
        logits = self.linear(flat)
        return F.log_softmax(logits, dim=-1)


In [8]:
# Stage specification consumed by mblnetv2, one tuple per stage:
#   (expansion, out_planes, num_blocks, stride)
tpl = [
    (1,  16, 1, 1),
    (6,  24, 2, 1),
    (6,  32, 3, 2),
    (6,  64, 4, 2),
    (6,  96, 3, 1),
    (6, 160, 3, 2),
    (6, 320, 1, 1),
]

# 10-class model whose stem outputs 32 // 1 = 32 channels.
md_mbl = mblnetv2(
    block=expand_depthwise_block,
    inc_scale=1,
    inc_start=32,
    tuple_list=tpl,
    num_classes=10,
)

In [9]:
# Build the dataset object with the image size and batch size set above.
data = get_data(sz=sz, bs=bs)

In [10]:
# Wrap the custom model and the dataset in a fastai learner.
learn = ConvLearner.from_model_data(md_mbl, data)

In [None]:
# Display the learner's model summary (the same object is passed to
# total_model_params in the next cell).
learn.summary()

In [None]:
# Print the total of the 'nb_params' entries across the summary.
total_model_params(learn.summary())

In [None]:
# Learning rate 1e-2 for 5 cycles with cycle_len=1 and cycle_mult=2; the
# progress bar reports 31 epochs in total (1 + 2 + 4 + 8 + 16).
learn.fit(
    1e-2,
    5,
    cycle_len=1,
    cycle_mult=2,
    best_save_name='best_mblnetv2_1',
    metrics=[accuracy],
)

HBox(children=(IntProgress(value=0, description='Epoch', max=31), HTML(value='')))

  1%|▏         | 10/782 [00:04<06:02,  2.13it/s, loss=2.3]