In [1]:
# Select matplotlib's 'Agg' backend; this is the first cell, so it runs before
# any other import in the notebook.
# NOTE(review): the later `%matplotlib inline` cell presumably switches the
# backend again — confirm which one is actually intended.
import matplotlib
matplotlib.use('Agg')

In [2]:
from tqdm import tqdm
# NOTE(review): presumably turns off tqdm's background monitor thread —
# confirm against the tqdm documentation for the installed version.
tqdm.monitor_interval = 0

from fastai.imports import *
from fastai.transforms import *
from fastai.conv_learner import *
from fastai.model import *
from fastai.dataset import *
from fastai.sgdr import *
from fastai.plots import *
import json
import pandas as pd
from sklearn.metrics import *

In [4]:
# Render matplotlib figures inside the notebook output.
%matplotlib inline
# (Re)load the autoreload extension and use mode 2, so edits to imported
# modules (e.g. mobile_net, cifar10) are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [5]:
from mobile_net import *

In [7]:
from cifar10 import get_data

In [6]:
import os

# Root directory of the CIFAR-10 data that is handed to get_data().
# The location is machine-specific, so it can be overridden with the
# CIFAR10_PATH environment variable; when the variable is unset the original
# hard-coded location is used, so existing runs are unaffected.
PATH = Path(os.environ.get(
    "CIFAR10_PATH",
    "/scratch/arka/miniconda3/external/fastai/courses/dl2/data/cifar10/",
))
# Human-readable class names.
# NOTE(review): presumably indexed by the integer label produced by get_data — confirm.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [8]:
# First argument to get_data(); presumably the image side length in pixels — confirm.
sz = 32
# Second argument to get_data(); presumably the batch size — confirm.
bs = 128

In [9]:
# Build the data object with the project-local cifar10.get_data helper.
# Later cells read data.trn_ds and the data.*_dl loaders from it.
data = get_data(sz, bs, PATH)

In [11]:
# Inspect the transform attached to the training dataset.
data.trn_ds.transform

[<fastai.transforms.Scale object at 0x7f9e307e45f8>, <fastai.transforms.AddPadding object at 0x7f9e307e4940>, <fastai.transforms.RandomFlip object at 0x7f9e307e4f98>, <fastai.transforms.RandomCrop object at 0x7f9e307e4a90>, <fastai.transforms.Normalize object at 0x7f9e307e47b8>, <fastai.transforms.ChannelOrder object at 0x7f9e307e4d68>]

In [22]:
# Inspect the `sz` attribute recorded on the training dataset.
data.trn_ds.sz

32

In [133]:
# Sanity check: draw 1000 values with np.random.randint(0, 2) and print the
# observed fraction of zeros and of ones. (torch.LongTensor(1).random_(0, 2)
# and torch.bernoulli(T(0.5)) were tried earlier as alternative samplers.)
counts = [0, 0]  # counts[k] is the number of draws equal to k
tot = 0
for _ in tqdm(range(1000)):
    draw = np.random.randint(0, 2)
    counts[draw] += 1
    tot += 1
c0, c1 = counts
assert (tot == c0 + c1)
print(float(c0) / tot)
print(float(c1) / tot)

100%|██████████| 1000/1000 [00:00<00:00, 585142.86it/s]
0.509
0.491


In [128]:
# Single draw from {0, 1}: the upper bound of np.random.randint is exclusive.
np.random.randint(0, 2)

1

In [None]:
# np.random.choice requires a population argument; calling it with no
# arguments raises TypeError. An int n means "draw from np.arange(n)", so this
# draws one value from {0, 1}, mirroring the np.random.randint(0, 2) experiment.
np.random.choice(2)

In [332]:
class SiameseDS(BaseDataset):
    """Wrap a labelled dataset so that each item is a pair of inputs plus a same/different flag.

    Indexing returns ``[x1, x2, targ]`` where ``x1`` is the item at the requested
    index, ``x2`` is a randomly chosen second item, and ``targ`` is 1 when both
    items carry the same label and 0 when their labels differ.

    Parameters
    ----------
    ds : dataset exposing ``transform``, ``y``, ``n``, ``sz``, ``get_x`` and ``get_y``
        The underlying single-image dataset.
    wgt_pos_neg : float, default 0.5
        Probability of drawing a positive (same-label) pair. The default keeps
        positives and negatives equally likely.

    Raises
    ------
    ValueError
        If ``wgt_pos_neg`` is not within [0, 1].
    """

    def __init__(self, ds, wgt_pos_neg=0.5):
        p_pos = float(wgt_pos_neg)
        if not 0.0 <= p_pos <= 1.0:
            raise ValueError(f"wgt_pos_neg must be within [0, 1], got {wgt_pos_neg!r}")
        # self.ds is assigned before super().__init__ because get_n/get_c/get_sz
        # below read it (fastai's BaseDataset.__init__ is expected to call them —
        # confirm against the installed fastai version).
        self.ds = ds
        super().__init__(ds.transform)
        # Labels are taken from the wrapped dataset itself rather than from the
        # notebook-global `data`, so validation/test wrappers use their own labels
        # and the class works without any particular global being defined.
        self.label_set = set(ds.y)
        # label -> array of indices in ds that carry that label
        self.l2i = {l: np.where(l == ds.y)[0]
                    for l in self.label_set}
        # label -> sorted list of all *other* labels; built once here instead of
        # being rebuilt and re-sorted on every __getitem__ call.
        self.neg_labels = {l: sorted(self.label_set - {l})
                           for l in self.label_set}
        self.wgt = T(wgt_pos_neg)  # kept for backward compatibility
        self.p_pos = p_pos

    def __len__(self):
        """Number of items; identical to the wrapped dataset's size."""
        return self.get_n()

    def get_n(self):
        """Return the wrapped dataset's item count."""
        return self.ds.n

    def get_c(self):
        """Return the number of targets: 0 (different labels) and 1 (same label)."""
        return 2

    def get_sz(self):
        """Return the wrapped dataset's ``sz`` attribute."""
        return self.ds.sz

    def get_x(self, i):
        """Return the raw input at index ``i`` from the wrapped dataset."""
        return self.ds.get_x(i)

    def get_y(self, i):
        """Return the original label at index ``i`` from the wrapped dataset."""
        return self.ds.get_y(i)

    def __getitem__(self, idx):
        """Return ``[x1, x2, targ]`` for the anchor item at ``idx``.

        ``targ`` is 1 with probability ``wgt_pos_neg``; the partner item is then
        sampled from the anchor's own label, otherwise from a different label
        chosen uniformly among the remaining ones.
        """
        targ = int(np.random.random() < self.p_pos)
        # get1item is inherited; it is expected to return the transformed (x, y).
        x1, y1 = self.get1item(idx)
        if targ == 1:
            new_idx = np.random.choice(self.l2i[y1])
        else:
            new_c = np.random.choice(self.neg_labels[y1])
            new_idx = np.random.choice(self.l2i[new_c])
        x2, y2 = self.get1item(new_idx)
        # Sanity check: the labels must agree exactly when the pair is positive.
        assert (y1 == y2) == (targ == 1)
        return [x1, x2, targ]

In [333]:
class SiameseData(ImageData):
    """ImageData whose datasets yield Siamese pairs built from an existing classifier data object."""

    @classmethod
    def from_image_classifier_data(cls, data, bs=64, num_workers=4):
        """Wrap every dataset of ``data`` in a ``SiameseDS`` and build a new data object.

        Parameters
        ----------
        data : object exposing ``path`` and the ``trn_dl``, ``val_dl``, ``fix_dl``,
            ``aug_dl``, ``test_dl`` and ``test_aug_dl`` loaders
            The source classifier data. ``test_dl`` may be ``None``; in that case
            both test slots are filled with ``None``.
        bs : int, default 64
            Batch size forwarded to the constructor.
        num_workers : int, default 4
            Worker count forwarded to the constructor.

        Returns
        -------
        An instance of ``cls`` over the six (possibly ``None``) Siamese datasets.
        """
        # Order matters: train, validation, fixed, augmented, then the two test slots.
        res = [SiameseDS(dl.dataset)
               for dl in (data.trn_dl, data.val_dl, data.fix_dl, data.aug_dl)]
        if data.test_dl is not None:
            res += [SiameseDS(data.test_dl.dataset),
                    SiameseDS(data.test_aug_dl.dataset)]
        else:
            res += [None, None]
        # SiameseDS emits target 1 for same-label (positive) pairs and 0 for
        # different-label (negative) pairs, so the class name at index 1 must be
        # 'pos' and the one at index 0 must be 'neg'.
        return cls(data.path, res, bs, num_workers, classes=['neg', 'pos'])

In [334]:
class SiameseModel(nn.Module):
    """Siamese wrapper: embed two inputs with one shared backbone and return their distance.

    ``mdl`` must expose ``conv1``, ``bn1`` and ``lyrs``; any other part of the
    backbone is not used here.
    """

    def __init__(self, mdl):
        super().__init__()
        self.mdl = mdl
        self.pdist = nn.PairwiseDistance()

    def emb_out(self, inp):
        """Return one flat embedding per sample: conv1 -> bn1 -> ReLU -> lyrs -> global average pool."""
        stem = F.relu(self.mdl.bn1(self.mdl.conv1(inp)))
        pooled = F.adaptive_avg_pool2d(self.mdl.lyrs(stem), 1)
        # Collapse everything after the batch dimension into a single axis.
        return pooled.view(pooled.size(0), -1)

    def forward(self, inp0, inp1):
        """Return the pairwise distance between the embeddings of ``inp0`` and ``inp1``."""
        emb0 = self.emb_out(inp0)
        emb1 = self.emb_out(inp1)
        return self.pdist(emb0, emb1)

In [335]:
class ContrastiveLoss(nn.Module):
    """Contrastive loss over pairwise distances.

    For a distance ``d`` and a target ``t`` (1 = similar pair, 0 = dissimilar
    pair) the per-pair loss is ``(t * d**2 + (1 - t) * relu(margin - d)**2) / 2``.
    """

    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, outp, targ, size_average=True):
        """Reduce the per-pair loss with the mean (``size_average=True``) or the sum."""
        is_pos = targ.float()
        # Similar pairs are penalised by their squared distance ...
        pull = is_pos * outp.pow(2)
        # ... dissimilar pairs only while they are closer than the margin.
        push = (1 - is_pos) * F.relu(self.margin - outp).pow(2)
        per_pair = (pull + push) / 2
        if size_average:
            return per_pair.mean()
        return per_pair.sum()

In [342]:
# IPython `??` introspection: show the source of Callback (presumably fastai's
# callback base class, pulled in by the star imports). Development aid only.
??Callback

In [346]:
# Backbone network (mblnetv1 / depthwise_block presumably come from the local
# mobile_net module imported above).
md_mbl = mblnetv1(
    depthwise_block,
    inc_list=[64, 64, 128, 256],
    inc_scale=1,
    num_blocks_list=[2, 2, 2],
    stride_list=[1, 2, 2],
    num_classes=10,
)

# Wrap the backbone and the classifier data in their Siamese counterparts and
# hand both to a fastai learner.
sia_mdl = SiameseModel(md_mbl)
sia_data = SiameseData.from_image_classifier_data(data)
learn = ConvLearner.from_model_data(sia_mdl, sia_data)

# The model outputs a distance per pair, so train it with the contrastive loss.
learn.crit = ContrastiveLoss(margin=1)
learn.fit(1e-2, 1, cycle_len=1)

HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))

  0%|          | 0/782 [00:00<?, ?it/s]--Return--
> <ipython-input-345-96d3b0999615>(3)on_batch_begin()->None
-> import pdb; pdb.set_trace()
(Pdb) ls
*** NameError: name 'ls' is not defined
(Pdb) q



BdbQuit: 

In [337]:
# Same fit call as at the end of the setup cell (first argument 1e-2 is
# presumably the learning rate; the recorded output shows a single epoch).
learn.fit(1e-2, 1, cycle_len=1)

HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.125619   0.125715   0.5017    



[array([0.12572]), 0.5017]