In [1]:
import matplotlib
matplotlib.use('Agg')

In [2]:
from tqdm import tqdm
tqdm.monitor_interval = 0

from fastai.imports import *
from fastai.transforms import *
from fastai.conv_learner import *
from fastai.model import *
from fastai.dataset import *
from fastai.sgdr import *
from fastai.plots import *
import json
import pandas as pd
from sklearn.metrics import *

In [3]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [4]:
from mobile_net import *

In [5]:
from cifar10 import get_data

In [6]:
PATH = Path("/scratch/arka/miniconda3/external/fastai/courses/dl2/data/cifar10/")
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [7]:
sz = 32
bs = 128

In [8]:
data = get_data(sz, bs, PATH)

In [9]:
class SiameseDS(BaseDataset):
    def __init__(self, ds, wgt_pos_neg=0.5):
        self.ds = ds
        super().__init__(ds.transform)
        self.label_set = set(data.trn_ds.y)
        self.l2i = {l: np.where(l == ds.y)[0] 
                    for l in self.label_set}
        self.wgt = T(wgt_pos_neg)
    
    def __len__(self):
        return self.get_n()
        
    def get_n(self):
        return self.ds.n
    
    def get_c(self):
        return 2
    
    def get_sz(self):
        return self.ds.sz
    
    def get_x(self, i):
        return self.ds.get_x(i)
    
    def get_y(self, i):
        return self.ds.get_y(i)
    
    def __getitem__(self, idx):
        targ = np.random.randint(0, 2)
        x1, y1 = self.get1item(idx)
        if targ == 1:
            new_idx = np.random.choice(self.l2i[y1])
            x2, y2 = self.get1item(new_idx)
            assert y1 == y2
        else:
            new_c = np.random.choice(sorted(list(self.label_set - {y1})))
            new_idx = np.random.choice(self.l2i[new_c])
            x2, y2 = self.get1item(new_idx)
            assert y1 != y2
        return [x1, x2, targ]

In [10]:
class SiameseData(ImageData):
    @classmethod
    def from_image_classifier_data(cls, data, bs=64, num_workers=4):
        trn_ds = SiameseDS(data.trn_dl.dataset)
        val_ds = SiameseDS(data.val_dl.dataset)
        fix_ds = SiameseDS(data.fix_dl.dataset)
        aug_ds = SiameseDS(data.aug_dl.dataset)
        res = [trn_ds, val_ds, fix_ds, aug_ds]
        if data.test_dl is not None:
            test_ds = SiameseDS(data.test_dl.dataset)
            test_aug = SiameseDS(data.test_aug_dl.dataset)
            res += [test_ds, test_aug]
        else:
            res += [None, None]
        return cls(data.path, res, bs, num_workers, classes=['pos', 'neg'])

In [11]:
class SiameseModel(nn.Module):
    def __init__(self, mdl):
        super().__init__()
        self.mdl = mdl
        self.pdist = nn.PairwiseDistance()
    
#     def emb_out(self, inp):
#         out = F.relu(self.mdl.bn1(self.mdl.conv1(inp)))
#         out = F.adaptive_avg_pool2d(self.mdl.lyrs(out), 1)
#         out = out.view(out.size(0), -1)
#         return out
    
    def forward(self, inp0, inp1):
        o0 = self.mdl.emb_out(inp0)
        o1 = self.mdl.emb_out(inp1)
#         import pdb; pdb.set_trace()
        out = self.pdist(o0, o1)
#         import pdb; pdb.set_trace()
        return out

In [12]:
class ContrastiveLoss(nn.Module):
    def __init__(self, margin):
        super().__init__()
        self.margin = margin
    def forward(self, outp, targ, size_average=True):
        tmp1 = targ.float() * outp**2 
        tmp2 = (1 - targ.float())
        tmp3 = F.relu(self.margin - outp)
        tmp4 = tmp2 * tmp3**2
        res = (tmp1 + tmp4) / 2
        return res.mean() if size_average else res.sum()

In [17]:
??nn.Conv2d

In [24]:
class EmbeddingNet(nn.Module):
    def __init__(self):
        super(EmbeddingNet, self).__init__()
        self.convnet = nn.Sequential(nn.Conv2d(3, 32, 5), nn.PReLU(),
                                     nn.MaxPool2d(2, stride=2),
                                     nn.Conv2d(32, 64, 5), nn.PReLU(),
                                     nn.MaxPool2d(2, stride=2))

        self.fc = nn.Sequential(nn.Linear(1600, 256),
                                nn.PReLU(),
                                nn.Linear(256, 256),
                                nn.PReLU(),
                                nn.Linear(256, 2)
                                )

    def forward(self, x):
        output = self.convnet(x)
        output = output.view(output.size()[0], -1)
        output = self.fc(output)
        return output
    
    def emb_out(self, x):
        return self.forward(x)

    def get_embedding(self, x):
        return self.forward(x)

In [27]:
??accuracy

In [67]:
def acc_siamese(preds, targs):
    s = np.zeros(targs.shape)
    p = np.where(preds > 0.5)[0]
    s[p] = 1
    return (s==targs.cpu().numpy()).mean()

In [61]:
p,t = learn.predict_with_targs()

In [68]:
# md_mbl = mblnetv1(depthwise_block, 
#               inc_list=[64, 64, 128, 256], 
#               inc_scale = 1, 
#               num_blocks_list=[2, 2, 2], 
#               stride_list=[1, 2, 2], 
#               num_classes=10)
md_mbl = EmbeddingNet()
sia_mdl = SiameseModel(md_mbl)
sia_data = SiameseData.from_image_classifier_data(data)
learn = ConvLearner.from_model_data(sia_mdl, sia_data)

learn.crit = ContrastiveLoss(margin=1)
# learn.predict()
learn.fit(1e-2, 3, cycle_len=1, metrics=[acc_siamese])

HBox(children=(IntProgress(value=0, description='Epoch', max=3), HTML(value='')))

epoch      trn_loss   val_loss   acc_siamese                 
    0      0.148218   0.150396   0.4707    
    1      0.148321   0.148054   0.4771                      
    2      0.14728    0.148373   0.4901                      



[array([0.14837]), 0.4901]

In [None]:
logp = learn.predict()