In [1]:
import matplotlib
matplotlib.use('Agg')

In [2]:
from tqdm import tqdm
tqdm.monitor_interval = 0

from fastai.imports import *
from fastai.transforms import *
from fastai.conv_learner import *
from fastai.model import *
from fastai.dataset import *
from fastai.sgdr import *
from fastai.plots import *
import json
import pandas as pd
from sklearn.metrics import *

In [3]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [4]:
from mobile_net import *

In [5]:
from cifar10 import get_data

In [6]:
import os

# CIFAR-10 data root. Overridable through the CIFAR10_PATH environment variable
# so the notebook is not tied to one machine; the default is the original path.
PATH = Path(os.environ.get(
    "CIFAR10_PATH",
    "/scratch/arka/miniconda3/external/fastai/courses/dl2/data/cifar10/"))
# CIFAR-10 class names, presumably in label-index order (standard CIFAR-10 order).
# NOTE(review): not referenced by any visible cell.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [7]:
# Image size and batch size handed to get_data when building `data`.
sz, bs = 32, 128

In [8]:
# Project-local helper (cifar10.get_data); presumably returns a fastai ImageData
# for CIFAR-10 at image size `sz` and batch size `bs`, read from PATH — confirm.
data = get_data(sz, bs, PATH)

In [9]:
class SiameseDS(BaseDataset):
    """Wrap a labelled dataset so that each item is a random image pair.

    Item ``idx`` uses image ``idx`` of the wrapped dataset as the anchor. The
    partner is drawn at random either from the same class (target 1) or from a
    different class (target 0), each with probability 1/2.

    Args:
        ds: wrapped dataset; must expose ``y``, ``n``, ``sz``, ``transform``,
            ``get_x`` and ``get_y``.
        wgt_pos_neg: stored as a tensor in ``self.wgt``.
            NOTE(review): not used when sampling; the positive/negative split
            is fixed at 50/50 by ``np.random.randint(0, 2)``.
    """
    def __init__(self, ds, wgt_pos_neg=0.5):
        # Set before super().__init__: the get_n/get_sz overrides read self.ds,
        # presumably because BaseDataset.__init__ calls them — confirm.
        self.ds = ds
        super().__init__(ds.transform)
        # Labels come from the wrapped dataset itself, not the notebook-global
        # `data.trn_ds`. Each split then samples only classes it actually contains.
        self.label_set = set(ds.y)
        # label -> indices of the wrapped dataset carrying that label.
        self.l2i = {l: np.where(l == ds.y)[0] for l in self.label_set}
        # label -> sorted list of the other labels (negative-pair candidates),
        # built once instead of on every __getitem__ call.
        self.neg_labels = {l: sorted(self.label_set - {l}) for l in self.label_set}
        self.wgt = T(wgt_pos_neg)

    def __len__(self):
        return self.get_n()

    def get_n(self):
        return self.ds.n

    def get_c(self):
        # Binary target: 1 = same-class pair, 0 = different-class pair.
        return 2

    def get_sz(self):
        return self.ds.sz

    def get_x(self, i):
        return self.ds.get_x(i)

    def get_y(self, i):
        return self.ds.get_y(i)

    def __getitem__(self, idx):
        """Return [x1, x2, y1, y2, targ, targ] for a randomly chosen partner."""
        targ = np.random.randint(0, 2)
        x1, y1 = self.get1item(idx)
        if targ == 1:
            # Same-class partner; may be `idx` itself.
            new_idx = np.random.choice(self.l2i[y1])
            x2, y2 = self.get1item(new_idx)
            assert y1 == y2
        else:
            # Pick a different class, then a random image from it.
            new_c = np.random.choice(self.neg_labels[y1])
            new_idx = np.random.choice(self.l2i[new_c])
            x2, y2 = self.get1item(new_idx)
            assert y1 != y2
        # targ appears twice: presumably the learner splits off the last element
        # as the target and feeds the first five to SiameseModel.forward.
        return [x1, x2, y1, y2, targ, targ]

In [10]:
class SiameseData(ImageData):
    """ImageData whose datasets yield image pairs with a binary same-class target."""

    @classmethod
    def from_image_classifier_data(cls, data, bs=64, num_workers=4):
        """Build pair loaders from an existing image-classifier ImageData.

        Args:
            data: source ImageData. The dataset behind each of its loaders is
                wrapped in SiameseDS.
            bs: batch size for the new loaders (independent of ``data``'s own).
            num_workers: worker count for the new loaders.

        Returns:
            A ``cls`` instance with classes ``['neg', 'pos']``. Class index 1 is
            'pos', matching SiameseDS's target 1 (same-class pair).
        """
        def wrap(dl):
            return SiameseDS(dl.dataset)

        res = [wrap(data.trn_dl), wrap(data.val_dl), wrap(data.fix_dl), wrap(data.aug_dl)]
        if data.test_dl is not None:
            res += [wrap(data.test_dl), wrap(data.test_aug_dl)]
        else:
            res += [None, None]
        # Index order must follow the target encoding: 0 = different, 1 = same.
        return cls(data.path, res, bs, num_workers, classes=['neg', 'pos'])

In [25]:
class SiameseModel(nn.Module):
    """Siamese pair classifier.

    A shared network embeds both images. A linear layer then maps the Euclidean
    distance between the embeddings to 2-class log-probabilities.

    Args:
        mdl: shared embedding network providing ``emb_out(x)``. That method is
            assumed to return a 2-D (batch, features) tensor, as nn.PairwiseDistance
            expects — confirm for the mobile_net model.
    """
    def __init__(self, mdl):
        super().__init__()
        self.mdl = mdl
        self.pdist = nn.PairwiseDistance()
        # One input feature (the distance) -> two logits.
        self.ffc = nn.Linear(1, 2)

    def forward(self, inp0, inp1, y0, y1, targ):
        """Return (batch, 2) log-probabilities for each image pair.

        y0, y1 and targ are accepted but unused; presumably they are present
        so the signature matches the inputs produced by the pair dataset.
        """
        o0 = self.mdl.emb_out(inp0)
        o1 = self.mdl.emb_out(inp1)
        # Depending on the PyTorch version, PairwiseDistance returns (batch,) or
        # (batch, 1). Force (batch, 1) so nn.Linear(1, 2) always gets one feature.
        dist = self.pdist(o0, o1).view(-1, 1)
        return F.log_softmax(self.ffc(dist), dim=-1)

In [12]:
class ContrastiveLoss(nn.Module):
    """Contrastive loss for distance-based Siamese training.

    For a pair distance d and target y (1 = same class, 0 = different):

        loss = (y * d**2 + (1 - y) * max(0, margin - d)**2) / 2

    Same-class pairs are pulled together. Different-class pairs are pushed apart
    until they are at least ``margin`` away.

    NOTE(review): SiameseModel.forward returns (batch, 2) log-probabilities, not
    distances, so this loss does not match that model's output as written.
    """
    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, outp, targ, size_average=True):
        """Return the mean (size_average=True) or the sum of the per-pair losses."""
        is_same = targ.float()
        pull = is_same * outp ** 2
        push = (1 - is_same) * F.relu(self.margin - outp) ** 2
        per_pair = (pull + push) / 2
        if not size_average:
            return per_pair.sum()
        return per_pair.mean()

In [14]:
def acc_siamese(preds, targs, thresh=0.5):
    """Fraction of pairs whose same/different prediction matches the target.

    Args:
        preds: torch tensor (CPU or GPU) or array-like, either
            * 1-D (or (batch, 1)) scores: predicted label is 1 where score > thresh;
            * 2-D (batch, n_classes) scores such as SiameseModel's log_softmax
              output: predicted label is the argmax of each row.
        targs: 0/1 targets, one per pair (torch tensor or array-like).
        thresh: decision threshold for the 1-D case (0.5, as before).

    Returns:
        Mean accuracy as a numpy float.
    """
    def as_numpy(x):
        # Accept torch tensors on any device as well as numpy arrays / lists.
        if hasattr(x, 'detach'):
            x = x.detach()
        if hasattr(x, 'cpu'):
            x = x.cpu()
        return x.numpy() if hasattr(x, 'numpy') else np.asarray(x)

    p = as_numpy(preds)
    t = as_numpy(targs).reshape(-1)
    if p.ndim == 2 and p.shape[1] > 1:
        # Thresholding a log-probability matrix at 0.5 never fires (log-probs
        # are <= 0), so pick the most likely class per row instead.
        pred_lbl = p.argmax(axis=1)
    else:
        pred_lbl = p.reshape(-1) > thresh
    return (pred_lbl == t).mean()

In [32]:
# MobileNet-v1-style backbone from the project-local mobile_net module, with a
# 10-way output head.
# NOTE(review): inc_list has 4 entries while num_blocks_list/stride_list have 3;
# presumably the first entry is the stem width — confirm against mblnetv1.
# NOTE(review): this cell's execution count (In [32]) is later than the cells
# that load weights into md_mbl (In [18], In [24]). A Restart & Run All
# therefore does not reproduce the recorded state.
md_mbl = mblnetv1(depthwise_block, 
              inc_list=[64, 64, 128, 256], 
              inc_scale = 1, 
              num_blocks_list=[2, 2, 2], 
              stride_list=[1, 2, 2], 
              num_classes=10)

In [18]:
# Plain 10-class learner around the backbone, restored from the saved
# 'mobilenetv1_3' checkpoint.
learn1 = ConvLearner.from_model_data(md_mbl, data)
learn1.load('mobilenetv1_3')

In [24]:
# One epoch at learning rate 0, used only to report val_loss/accuracy of the
# loaded weights.
# NOTE(review): this is still a training pass, so any BatchNorm running
# statistics in the model get updated; a validation-only call would avoid that.
learn1.fit(0, 1)

HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.259186   0.339189   0.8875    



[array([0.33919]), 0.8875]

In [33]:
# Wrap the backbone as a Siamese pair classifier and build pair loaders from `data`.
# NOTE(review): md_mbl was rebuilt (In [32]) after the checkpoint was loaded
# (In [18]/In [24]). In the recorded run, sia_mdl therefore presumably started
# from freshly initialised weights (first-epoch accuracy 0.67). Run top-to-bottom,
# it may instead inherit the 'mobilenetv1_3' weights, if ConvLearner wraps the
# same module object. Confirm which is intended.
# NOTE(review): from_image_classifier_data uses its default bs=64 here, not the
# bs=128 used to build `data`.
# md_mbl = EmbeddingNet()
# md_mbl_trained = learn1.model
sia_mdl = SiameseModel(md_mbl)
sia_data = SiameseData.from_image_classifier_data(data)
learn = ConvLearner.from_model_data(sia_mdl, sia_data)

In [20]:
# learn.crit = ContrastiveLoss(margin=1)
# learn.predict()
# learn.fit(1e-2, 3, cycle_len=1, metrics=[acc_siamese])

In [34]:
# First Siamese training run: 3 cycles of 1 epoch at lr 1e-1. No crit/metrics are
# overridden, since the ContrastiveLoss / acc_siamese lines above are commented out.
learn.fit(1e-1, 3, cycle_len=1)

HBox(children=(IntProgress(value=0, description='Epoch', max=3), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.604433   0.596333   0.6711    
    1      0.593759   0.580405   0.6879                      
    2      0.567241   0.559725   0.7069                      



[array([0.55973]), 0.7069]

In [35]:
# Two cycles of 10 epochs each (20 epochs total) at lr 1e-1.
learn.fit(1e-1, 2, cycle_len=10)

HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.598594   0.602307   0.6647    
    1      0.569675   0.550207   0.7152                      
    2      0.546522   0.521318   0.7298                      
    3      0.520268   0.505417   0.7481                      
    4      0.515452   0.478688   0.7666                      
    5      0.484826   0.464501   0.7705                      
    6      0.46484    0.446984   0.7863                      
    7      0.447556   0.431713   0.793                       
    8      0.451877   0.420704   0.8004                      
    9      0.439181   0.425464   0.7976                      
    10     0.497102   0.463512   0.7743                      
    11     0.471967   0.448574   0.786                       
    12     0.461787   0.458407   0.7814                      
    13     0.444516   0.428252   0.7969                      
    14     0.437356   0.398396   0.8135                      
    15     0.413933   0.39

[array([0.35737]), 0.8348]

In [35]:
# NOTE(review): exact duplicate of the preceding In [35] cell (same code, same
# output), likely an export artifact; remove it so the notebook is not run twice.
learn.fit(1e-1, 2, cycle_len=10)

HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.598594   0.602307   0.6647    
    1      0.569675   0.550207   0.7152                      
    2      0.546522   0.521318   0.7298                      
    3      0.520268   0.505417   0.7481                      
    4      0.515452   0.478688   0.7666                      
    5      0.484826   0.464501   0.7705                      
    6      0.46484    0.446984   0.7863                      
    7      0.447556   0.431713   0.793                       
    8      0.451877   0.420704   0.8004                      
    9      0.439181   0.425464   0.7976                      
    10     0.497102   0.463512   0.7743                      
    11     0.471967   0.448574   0.786                       
    12     0.461787   0.458407   0.7814                      
    13     0.444516   0.428252   0.7969                      
    14     0.437356   0.398396   0.8135                      
    15     0.413933   0.39

[array([0.35737]), 0.8348]

In [36]:
# A further two 10-epoch cycles at lr 1e-1, continuing from the current weights.
learn.fit(1e-1, 2, cycle_len=10)

HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.438301   0.433398   0.7931    
    1      0.428252   0.416338   0.8025                      
    2      0.415245   0.409257   0.8058                      
    3      0.397022   0.391332   0.8139                      
    4      0.389665   0.376615   0.8265                      
    5      0.381606   0.351701   0.8411                      
    6      0.355832   0.34792    0.8417                      
    7      0.370487   0.341653   0.8445                      
    8      0.346154   0.334467   0.8518                      
    9      0.333973   0.335092   0.8496                      
    10     0.386522   0.379303   0.823                       
    11     0.376059   0.369022   0.8306                      
    12     0.386341   0.35766    0.8406                      
    13     0.36435    0.360092   0.8383                      
    14     0.365056   0.348125   0.8437                      
    15     0.339913   0.33

[array([0.31813]), 0.8615]

In [37]:
# Two longer cycles of 20 epochs each (40 epochs total) at lr 1e-1.
learn.fit(1e-1, 2, cycle_len=20)

HBox(children=(IntProgress(value=0, description='Epoch', max=40), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.319367   0.317896   0.8624    
    1      0.367173   0.396415   0.8164                      
    2      0.38018    0.403097   0.817                       
    3      0.35979    0.333851   0.849                       
    4      0.366961   0.347397   0.8452                      
    5      0.342214   0.349037   0.8444                      
    6      0.339792   0.334612   0.8518                      
    7      0.348944   0.322016   0.8597                      
    8      0.327411   0.340066   0.8528                      
    9      0.317035   0.320986   0.8619                      
    10     0.318865   0.30911    0.8635                      
    11     0.303005   0.315545   0.8629                      
    12     0.312601   0.307441   0.8643                      
    13     0.288525   0.292337   0.873                       
    14     0.290255   0.285307   0.8751                      
    15     0.298328   0.29

[array([0.26984]), 0.888]

In [38]:
# Checkpoint the Siamese weights after the lr 1e-1 schedule.
learn.save('mbnet1_siam1')

In [40]:
# Restore the 'mbnet1_siam1' checkpoint.
learn.load('mbnet1_siam1')

In [41]:
# Fine-tune at the lower lr 1e-2 for two 1-epoch cycles. best_save_name presumably
# keeps the best-validation weights under 'mbnet_siam2' (loaded in the next cell).
learn.fit(1e-2, 2, cycle_len=1, best_save_name='mbnet_siam2')

HBox(children=(IntProgress(value=0, description='Epoch', max=2), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.261653   0.268922   0.885     
    1      0.255375   0.266008   0.8895                      



[array([0.26601]), 0.8895]

In [43]:
# Restore the best fine-tuned checkpoint and make all layers trainable.
# NOTE(review): a learner built with ConvLearner.from_model_data on a custom
# model may have no frozen layer groups, in which case unfreeze() changes
# nothing — confirm.
learn.load('mbnet_siam2')
learn.unfreeze()

In [44]:
# Two 1-epoch cycles at lr 1e-1 with a cyclical LR schedule (use_clr=(20, 5)).
# NOTE(review): validation accuracy fell below the 'mbnet_siam2' checkpoint
# (0.8761 vs 0.8895), so lr 1e-1 may be too high for this stage.
learn.fit(1e-1, 2, cycle_len=1, use_clr=(20, 5), best_save_name='mbnet_siam3')

HBox(children=(IntProgress(value=0, description='Epoch', max=2), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.269921   0.275691   0.8798    
    1      0.274408   0.285801   0.8761                      



[array([0.2858]), 0.8761]