In [1]:
from all_imports import *

In [2]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [3]:
from cifar10 import *

In [4]:
from mobile_net import *

In [5]:
# Loader settings handed to get_data(sz, bs) below.
# Presumably bs = mini-batch size and sz = image side length in pixels -- confirm
# against the get_data helper brought in by the star-imports.
bs, sz = 64, 32

In [6]:
# Build the data object used to construct the learners further down.
# get_data is not defined in this notebook; it comes from one of the star-imports
# above (presumably the cifar10 module -- TODO confirm).
data = get_data(sz, bs)

In [7]:
class Block(nn.Module):
    """Bottleneck unit: 1x1 expand -> 3x3 depthwise -> 1x1 project.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        expansion: multiplier giving the hidden width (in_planes * expansion).
        stride: stride of the 3x3 depthwise convolution.

    A residual connection is added only when ``stride == 1``. If the channel
    counts differ in that case, the skip path is a 1x1 conv + batch-norm
    projection; otherwise it is an empty ``nn.Sequential`` (identity).
    """

    def __init__(self, in_planes, out_planes, expansion, stride):
        super().__init__()
        self.stride = stride
        hidden = in_planes * expansion

        # 1x1 pointwise expansion to the hidden width
        self.conv1 = nn.Conv2d(in_planes, hidden, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(hidden)
        # 3x3 depthwise conv (groups == channels); this is the only strided layer
        self.conv2 = nn.Conv2d(hidden, hidden, kernel_size=3, stride=stride, padding=1,
                               groups=hidden, bias=False)
        self.bn2 = nn.BatchNorm2d(hidden)
        # 1x1 pointwise projection down to out_planes (no activation after it in forward)
        self.conv3 = nn.Conv2d(hidden, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        needs_projection = stride == 1 and in_planes != out_planes
        if needs_projection:
            # Match channel counts on the skip path so the addition is valid.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Run the three conv stages; add the skip path only for stride-1 units."""
        y = F.relu(self.bn1(self.conv1(x)))
        y = F.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        if self.stride != 1:
            # Spatial size changed, so there is no residual to add.
            return y
        return y + self.shortcut(x)

In [8]:
class expand_depthwise_block(nn.Module):
    """Expand (1x1) -> depthwise (3x3) -> project (1x1) unit built as one Sequential.

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
        expansion: multiplier giving the hidden width (in_c * expansion).
        stride: stride of the 3x3 depthwise convolution.

    The residual is added only when ``stride == 1``. When the channel counts
    differ in that case, ``resd`` is a 1x1 conv + batch-norm projection;
    otherwise it is an empty ``nn.Sequential`` (identity).
    """
    ## Thanks to https://github.com/kuangliu/pytorch-cifar/blob/master/models/mobilenetv2.py
    def __init__(self, in_c, out_c, expansion, stride):
        super().__init__()
        tk = in_c * expansion  # hidden (expanded) channel count
        self.stride = stride
        self.l1 = nn.Sequential(
            # 1x1 expansion
            nn.Conv2d(in_c, tk, 1, bias=False), nn.BatchNorm2d(tk), nn.ReLU6(inplace=True),
            # 3x3 depthwise conv. `stride` must be applied here: without it a
            # stride-2 unit keeps the input resolution while forward() still
            # drops the residual, so the network never downsamples.
            nn.Conv2d(tk, tk, 3, stride=stride, groups=tk, padding=1, bias=False),
            nn.BatchNorm2d(tk), nn.ReLU6(inplace=True),
            # 1x1 projection (no activation afterwards)
            nn.Conv2d(tk, out_c, 1, bias=False), nn.BatchNorm2d(out_c))
        self.resd = nn.Sequential()
        if stride == 1 and in_c != out_c:
            # Project the skip path so its channel count matches the main path.
            self.resd = nn.Sequential(nn.Conv2d(in_c, out_c, 1, bias=False), nn.BatchNorm2d(out_c))

    def forward(self, inp):
        """Apply the bottleneck; add the skip path only for stride-1 units."""
        out = self.l1(inp)
        if self.stride == 1:
            out = out + self.resd(inp)
        return out

In [9]:
class mblnetv2(nn.Module):
    """Stem conv + configurable stack of bottleneck stages + linear classifier.

    Args:
        block: module class instantiated as ``block(in_planes, out_planes, expansion, stride)``.
        inc_scale: divisor applied to ``inc_start`` (integer division) to get the stem width.
        inc_start: base stem width before scaling.
        tuple_list: one ``(expansion, out_planes, num_blocks, stride)`` tuple per stage.
        num_classes: size of the classifier output.

    ``forward`` returns log-probabilities (``log_softmax`` over the last dim).
    """

    def __init__(self, block, inc_scale, inc_start, tuple_list, num_classes):
        super().__init__()
        self.num_blocks = len(tuple_list)          # number of stages in the config
        self.in_planes = inc_start // inc_scale    # running channel count, updated by _make_layer

        # Stem: 3-channel input, 3x3 conv at stride 1
        self.conv1 = nn.Conv2d(3, self.in_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_planes)

        stages = [self._make_layer(block, expf, inc, nb, strl)
                  for expf, inc, nb, strl in tuple_list]
        self.lyrs = nn.Sequential(*stages)

        # Classifier input width is the out_planes of the last configured stage.
        final_width = tuple_list[-1][1]
        self.linear = nn.Linear(final_width, num_classes)

    def _make_layer(self, block, expf, planes, num_blocks, stride):
        """Build one stage: the first unit uses ``stride``, the remaining units use stride 1."""
        units = []
        for unit_stride in [stride] + [1] * (num_blocks - 1):
            units.append(block(self.in_planes, planes, expf, unit_stride))
            # Units after the first in this stage (and the next stage) take `planes` as input.
            self.in_planes = planes
        return nn.Sequential(*units)

    def forward(self, inp):
        """Stem -> stages -> global average pool -> linear -> log_softmax."""
        feats = self.lyrs(F.relu(self.bn1(self.conv1(inp))))
        pooled = F.adaptive_avg_pool2d(feats, 1)
        logits = self.linear(pooled.view(pooled.size(0), -1))
        return F.log_softmax(logits, dim=-1)


In [10]:
# Stage configuration consumed by mblnetv2: one tuple per stage, in the order
# (expansion, out_planes, num_blocks, stride).
tpl = [
    (1, 16, 1, 1),
    (6, 24, 2, 1),
    (6, 32, 3, 2),
    (6, 64, 4, 2),
    (6, 96, 3, 1),
    (6, 160, 3, 2),
    (6, 320, 1, 1),
]

In [11]:
# Network built from the `Block` unit: stem width 32 // 1 = 32, 10 output classes.
md_mbl_blck = mblnetv2(block=Block, inc_scale=1, inc_start=32, tuple_list=tpl, num_classes=10)

In [13]:
# Same configuration as md_mbl_blck, but assembled from `expand_depthwise_block` units.
md_mbl_exp = mblnetv2(block=expand_depthwise_block, inc_scale=1, inc_start=32,
                      tuple_list=tpl, num_classes=10)

In [12]:
# NOTE(review): this repeats the earlier `data = get_data(sz, bs)` call with the same
# arguments (bs and sz are not reassigned in between in this notebook) -- looks
# redundant; confirm whether a fresh data object is actually needed here.
data = get_data(sz, bs)

In [13]:
# Wrap the Block-based network and the data object in a learner; this is the
# learner trained in the cells below.
learn_blck = ConvLearner.from_model_data(md_mbl_blck, data)

In [16]:
# Learner for the expand_depthwise_block variant. Only its parameter count is
# inspected in the cells visible here; no fit call on it appears in this section.
learn_exp = ConvLearner.from_model_data(md_mbl_exp, data)

In [14]:
# Report the parameter count of the Block-based model (the output below shows 1875162).
# total_model_params comes from the star-imports; presumably it sums the counts in the
# learner summary -- confirm.
total_model_params(learn_blck.summary())

Total parameters in the model :1875162


In [18]:
# Same check for the expand_depthwise_block model; the output below shows the same
# total (1875162) as the Block-based model.
total_model_params(learn_exp.summary())

Total parameters in the model :1875162


In [15]:
# Run 1: 5 cycles of 1 epoch each (5 epochs in the log below), first argument 5e-2
# (presumably the learning rate -- confirm against the fastai Learner.fit signature).
# The best weights are written under 'best_mblnetv2_blk_1', which a later cell loads.
learn_blck.fit(
    5e-2,
    5,
    cycle_len=1,
    metrics=[accuracy],
    best_save_name='best_mblnetv2_blk_1',
)

HBox(children=(IntProgress(value=0, description='Epoch', max=5), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                   
    0      1.149386   1.102802   0.6022    
    1      0.948014   0.855649   0.6948                      
    2      0.772911   0.69118    0.7557                      
    3      0.671254   0.595959   0.7901                      
    4      0.594443   0.542263   0.8121                      



[array([0.54226]), 0.8121]

In [23]:
# Restore the weights saved by run 1 (same name as its best_save_name).
learn_blck.load('best_mblnetv2_blk_1')
# Presumably marks every layer group as trainable (fastai Learner.unfreeze) -- confirm.
learn_blck.unfreeze()

In [24]:
# Run 2: continue from the run-1 checkpoint with a smaller first argument (1e-2),
# again 5 cycles of 1 epoch. The log below ends at accuracy 0.8498.
learn_blck.fit(
    1e-2,
    5,
    cycle_len=1,
    metrics=[accuracy],
    best_save_name='best_mblnetv2_blk_2',
)

HBox(children=(IntProgress(value=0, description='Epoch', max=5), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.540091   0.504126   0.8287    
    1      0.519431   0.488398   0.8347                      
    2      0.490579   0.475887   0.8395                      
    3      0.450472   0.463159   0.8423                      
    4      0.456585   0.444863   0.8498                      



[array([0.44486]), 0.8498]

In [25]:
# Run 3: same 1e-2 / 5 x 1-epoch setup as run 2, now with use_clr=(20, 5)
# (presumably fastai's cyclical-learning-rate schedule parameters -- confirm).
# The log below ends at accuracy 0.8654.
learn_blck.fit(
    1e-2,
    5,
    cycle_len=1,
    use_clr=(20, 5),
    metrics=[accuracy],
    best_save_name='best_mblnetv2_blk_3',
)

HBox(children=(IntProgress(value=0, description='Epoch', max=5), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.462152   0.437694   0.8546    
    1      0.429822   0.42111    0.8568                      
    2      0.414158   0.41187    0.8601                      
    3      0.39125    0.401361   0.8639                      
    4      0.39197    0.401618   0.8654                      



[array([0.40162]), 0.8654]

In [26]:
# Run 4: same schedule arguments as run 3, but with the first argument raised back to 5e-2.
# The log below ends at accuracy 0.8729.
learn_blck.fit(
    5e-2,
    5,
    cycle_len=1,
    use_clr=(20, 5),
    metrics=[accuracy],
    best_save_name='best_mblnetv2_blk_4',
)

HBox(children=(IntProgress(value=0, description='Epoch', max=5), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.46978    0.43491    0.8512    
    1      0.41638    0.406135   0.8584                      
    2      0.424425   0.388045   0.8682                      
    3      0.38399    0.378791   0.87                        
    4      0.370013   0.367938   0.8729                      



[array([0.36794]), 0.8729]

In [None]:
# NOTE(review): this cell was never executed (no execution count) and `learn` is not
# defined in any cell of this notebook -- unless a star-import provides it, running it
# raises NameError. Probably a leftover; consider removing or using `learn_blck`.
learn

In [30]:
# Run 5: long schedule -- 5 cycles of 10 epochs each (the progress bar below shows max=50),
# using use_clr_beta=(20, 13.68, 0.95, 0.85) (presumably fastai's 1cycle-style
# LR/momentum schedule parameters -- confirm).
learn_blck.fit(
    5e-2,
    5,
    cycle_len=10,
    use_clr_beta=(20, 13.68, 0.95, 0.85),
    metrics=[accuracy],
    best_save_name='best_mblnetv2_blk_5',
)

HBox(children=(IntProgress(value=0, description='Epoch', max=50), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.396722   0.389371   0.8685    
    1      0.3895     0.386122   0.8671                      
    2      0.396267   0.411666   0.8571                      
    3      0.393032   0.420157   0.8532                      
    4      0.343123   0.393091   0.8694                      
    5      0.329349   0.362433   0.8773                      
    6      0.309325   0.354853   0.8806                      
    7      0.26548    0.321819   0.8894                      
    8      0.195675   0.289179   0.9042                      
    9      0.175346   0.287945   0.9041                      
    10     0.261564   0.314469   0.8932                      
    11     0.271705   0.346759   0.8858                      
    12     0.283616   0.358257   0.8805                      
    13     0.287176   0.351186   0.8826                      
    14     0.289707   0.366931   0.8785                      
    15     0.277409   0.32

[array([0.28656]), 0.9209]