In [1]:
from all_imports import *

In [2]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [3]:
from cifar10 import *

In [4]:
from mobile_net import *

In [5]:
# Batch size and (presumably) the input image side length -- CIFAR-10 images are 32x32.
# Both are passed to get_data(sz, bs) when building `data`.
bs, sz = 64, 32

In [6]:
data = get_data(sz, bs)  # get_data comes from a star import above (presumably cifar10); `data` is later handed to ConvLearner.from_model_data

In [10]:
class exp_dw_block(nn.Module):
    """MobileNetV2 inverted-residual block: 1x1 expand -> 3x3 depthwise -> 1x1 linear projection.

    Adapted from https://github.com/kuangliu/pytorch-cifar/blob/master/models/mobilenetv2.py

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
        expansion: width multiplier; the two hidden convs have ``in_c * expansion`` channels.
        stride: stride of the depthwise conv. A shortcut is added only when it is 1:
            identity if ``in_c == out_c``, otherwise a 1x1 conv + BatchNorm projection.
    """

    def __init__(self, in_c, out_c, expansion, stride):
        super().__init__()
        self.stride = stride
        hidden_c = expansion * in_c

        # Submodule names below become state_dict keys of saved checkpoints; keep them stable.
        # Expand: 1x1 pointwise conv up to the hidden width.
        self.ptwise_conv = nn.Conv2d(in_c, hidden_c, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(hidden_c)
        # Depthwise 3x3: groups == channels, so every channel is filtered on its own.
        self.dwise_conv = nn.Conv2d(hidden_c, hidden_c, kernel_size=3, stride=stride,
                                    padding=1, groups=hidden_c, bias=False)
        self.bn2 = nn.BatchNorm2d(hidden_c)
        # Project back down; no activation follows this BatchNorm (linear bottleneck).
        self.lin_conv = nn.Conv2d(hidden_c, out_c, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_c)

        if stride == 1 and in_c != out_c:
            # Channel counts differ, so the shortcut needs a 1x1 projection.
            self.res = nn.Sequential(nn.Conv2d(in_c, out_c, kernel_size=1, bias=False),
                                     nn.BatchNorm2d(out_c))
        else:
            # An empty Sequential passes its input through unchanged (unused when stride != 1).
            self.res = nn.Sequential()

    def forward(self, inp):
        expanded = F.relu6(self.bn1(self.ptwise_conv(inp)))
        filtered = F.relu6(self.bn2(self.dwise_conv(expanded)))
        projected = self.bn3(self.lin_conv(filtered))
        # Strided blocks change the spatial size, so they get no shortcut.
        if self.stride != 1:
            return projected
        return projected + self.res(inp)
        


In [11]:
class mblnetv2(nn.Module):
    """MobileNetV2-style classifier assembled from a per-stage configuration.

    Args:
        block: block class, called as ``block(in_channels, out_channels, expansion, stride)``.
        inc_scale: divisor applied to ``inc_start``; the stem conv has ``inc_start // inc_scale``
            output channels.
        inc_start: stem width before scaling.
        tuple_list: one ``(expansion, out_planes, num_blocks, stride)`` tuple per stage. Only the
            first block of a stage uses ``stride``; the rest use 1. A stage always contains at
            least one block, even if ``num_blocks`` is 0.
        num_classes: size of the final linear layer.

    ``forward`` returns ``log_softmax`` over the last dimension (log-probabilities).
    """

    def __init__(self, block, inc_scale, inc_start, tuple_list, num_classes):
        super().__init__()
        self.num_blocks = len(tuple_list)  # number of stages (tuples), not of individual blocks
        self.in_planes = inc_start // inc_scale
        # Attribute names below form the state_dict keys of saved checkpoints; keep them stable.
        self.conv1 = nn.Conv2d(3, self.in_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_planes)
        # Stages are built in order: each one advances self.in_planes for the next.
        stages = [self._make_layer(block, expansion, planes, count, first_stride)
                  for expansion, planes, count, first_stride in tuple_list]
        self.lyrs = nn.Sequential(*stages)
        self.linear = nn.Linear(tuple_list[-1][1], num_classes)

    def _make_layer(self, block, expf, planes, num_blocks, stride):
        """Build one stage of blocks and advance ``self.in_planes`` to ``planes``."""
        layers = []
        # max(..., 1): a stage always gets at least one block.
        for idx in range(max(num_blocks, 1)):
            layers.append(block(self.in_planes, planes, expf, stride if idx == 0 else 1))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, inp):
        feats = self.lyrs(F.relu(self.bn1(self.conv1(inp))))
        pooled = F.adaptive_avg_pool2d(feats, 1)
        logits = self.linear(pooled.view(pooled.size(0), -1))
        return F.log_softmax(logits, dim=-1)


In [12]:
# Per-stage configuration for mblnetv2, one tuple per stage:
#   (expansion factor, output channels, number of blocks, stride of the stage's first block)
# The remaining blocks in each stage use stride 1.
tpl = [(1,  16, 1, 1),
       (6,  24, 2, 1),
       (6,  32, 3, 2),
       (6,  64, 4, 2),
       (6,  96, 3, 1),
       (6, 160, 3, 2),
       (6, 320, 1, 1)]

# Use the exp_dw_block class defined in this notebook; `exp_dw_` is not defined by any cell here.
md_mbl1 = mblnetv2(exp_dw_block,
                   inc_scale=1,     # stem width = inc_start // inc_scale = 32 channels
                   inc_start=32,
                   tuple_list=tpl,
                   num_classes=10)  # CIFAR-10

In [13]:
data = get_data(sz, bs)  # NOTE(review): rebuilds `data` with the same sz/bs as the earlier get_data cell; looks redundant -- consider removing

In [15]:
learn = ConvLearner.from_model_data(md_mbl1, data)  # wrap the custom model + data in a ConvLearner (presumably fastai 0.7) for learn.fit below

total_model_params(learn.summary())  # total_model_params comes from a star import above; it reports the model's parameter count

Total parameters in the model :1875162


In [None]:
# Train for a single 50-epoch cycle (the progress bar below shows 50 epochs).
# NOTE(review): per the fastai 0.7 docs, use_clr_beta appears to be (div, pct, max_mom, min_mom) -- confirm.
learn.fit(
    5e-2,  # learning rate
    1,     # presumably the number of cycles
    cycle_len=50,
    use_clr_beta=(20, 13.68, 0.95, 0.85),
    metrics=[accuracy],
    best_save_name='best_mblnetv2_xp_1',
)

HBox(children=(IntProgress(value=0, description='Epoch', max=50), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                   
    0      1.189113   1.240661   0.563     
    1      0.925725   0.87501    0.6918                      
 63%|██████▎   | 496/782 [00:54<00:31,  9.02it/s, loss=0.806]