In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch as torch
from torch import nn
import torch.nn.functional as F
import torchvision.models as models

In [2]:
import pickle as pkl
from tqdm.notebook import tqdm

In [3]:
DEVICE = torch.device("cuda:0")
DIM = 128
IMAGE_SIZE = 112
dtype = torch.cuda.FloatTensor

In [4]:
torch.cuda.is_available()

True

In [5]:
data_ = pkl.load(open('Data_siamese.pickle', 'rb'))

In [6]:
data_.shape

(6317, 20, 112, 112, 3)

In [7]:
data = data_[:4000, :, :, :, :]

In [8]:
model = models.mobilenet_v2(width_mult=1, pretrained=True)
model = torch.nn.Sequential(*(list(model.children())[:-1]))

In [9]:
class Bottleneck(nn.Module):
    def __init__(self, model, in_dim=1280, out_dim=DIM, spartial=(IMAGE_SIZE+31) // 32):
        super().__init__()
        self.model = model
        self.depthwise = nn.Conv2d(in_dim, in_dim, spartial, bias=False, groups=in_dim)
        self.linear = nn.Linear(in_dim, out_dim)
        
    def forward(self, x):
        x = self.model(x)
        x = self.depthwise(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x

In [10]:
model = Bottleneck(model)
model = model.to(DEVICE)

In [12]:
class Probcalc(nn.Module):
    def __init__(self, hidden_dim=32):
        super().__init__()
        self.inp = nn.Linear(1, hidden_dim)
        self.linear1 = nn.Linear(hidden_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, 1)
        
    def forward(self, x):
        x = x.unsqueeze(-1)
        x = self.inp(x)
        x = F.leaky_relu(x, 0.1)
        x = self.linear1(x)
        x = F.leaky_relu(x, 0.1)
        x = self.linear2(x)
        x = F.leaky_relu(x, 0.1)
        x = self.linear3(x)
        x = F.leaky_relu(x, 0.1)
        x = self.out(x)
        x = x ** 2
        x = x.squeeze(-1)
        return torch.exp(-x)

In [13]:
calcprob = Probcalc().to(DEVICE)

In [14]:
optimizer = torch.optim.Adam(list(model.parameters()) + list(calcprob.parameters()), lr=0.003)

In [15]:
def change_lr(optimizer, epoch):
    lr = 1e-3
    if epoch < 20:
        lr = 1e-3
    elif epoch < 30:
        lr = 1e-3 / 2.0
    else:
        lr = 1e-4
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [16]:
rng = np.random.default_rng()

In [17]:
import cv2
from PIL import Image

In [18]:
def dist(x1, x2):
    return torch.abs(x1 - x2)

In [19]:
def d(x, y):
  # Cosine similarity
  return 1 - (x * y).sum(dim=-1) / (torch.sqrt((x * x).sum(dim=-1)) * torch.sqrt((y * y).sum(dim=-1)))

In [20]:
EPOCHS = 60
STEP_PER_EPOCH = 2000
BATCH = 128

In [21]:
def gen_pair():
    batch_data_1 = torch.zeros(BATCH, 3, 112, 112).byte()
    batch_data_2 = torch.zeros(BATCH, 3, 112, 112).byte()
    batch_target = torch.zeros(BATCH, 1)
    bad = rng.choice(4000, size = BATCH, replace=False)
    good = rng.choice(4000, size = BATCH // 2, replace=True)
    for i in range(BATCH):
        if (i % 2 == 0):
            ind = rng.choice(20, size=2, replace=False)
            batch_data_1[i] = torch.tensor((data[good[i // 2], ind[0]] - 255 / 2) /  (255 / 2)).transpose(0, 2)
            batch_data_2[i] = torch.tensor((data[good[i // 2], ind[1]] - 255 / 2) /  (255 / 2)).transpose(0, 2)
            batch_target[i] = 1
        else:
            ind = rng.choice(20, size=2, replace=True)
            batch_data_1[i] = torch.tensor((data[bad[i // 2], ind[0]] - 255 / 2) /  (255 / 2)).transpose(0, 2)
            batch_data_2[i] = torch.tensor((data[bad[i // 2 + BATCH // 2], ind[1]] - 255 / 2) /  (255 / 2)).transpose(0, 2)
            batch_target[i, 0] = 0
    return batch_data_1, batch_data_2, batch_target

In [22]:
from torchvision import transforms

In [23]:
for epoch in range(EPOCHS):
    avg_loss = 0.0
    for s in tqdm(range(STEP_PER_EPOCH)):
        change_lr(optimizer, epoch)
        optimizer.zero_grad()        
        X1, X2, target = gen_pair()
        X1 = X1.to(DEVICE).float()
        X2 = X2.to(DEVICE).float()
        target = target.to(DEVICE).float()
        diff = d(model(X1), model(X2))
        prob = calcprob(diff)
        loss = F.binary_cross_entropy(prob.unsqueeze(-1), target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(list(model.parameters()) + list(calcprob.parameters()), 10)
        optimizer.step()
        avg_loss += loss.item()
        
    print("Train epoch", epoch, "finished with avg_loss", avg_loss / STEP_PER_EPOCH)
    avg_loss = 0
    avg_acc = 0

HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 0 finished with avg_loss 0.7005546538829803


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 1 finished with avg_loss 0.6869382807612419


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 2 finished with avg_loss 0.6833848099112511


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 3 finished with avg_loss 0.6813656768798828


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 4 finished with avg_loss 0.6776510311365127


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 5 finished with avg_loss 0.6739863164722919


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 6 finished with avg_loss 0.672089160501957


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 7 finished with avg_loss 0.6699801239073276


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 8 finished with avg_loss 0.6663911433815957


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 9 finished with avg_loss 0.6650240035951137


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 10 finished with avg_loss 0.6634651988446713


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 11 finished with avg_loss 0.6617213295400143


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 12 finished with avg_loss 0.6608727854192257


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 13 finished with avg_loss 0.6597510891258717


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 14 finished with avg_loss 0.6575711988508701


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 15 finished with avg_loss 0.6568473676145077


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 16 finished with avg_loss 0.6546279498040676


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 17 finished with avg_loss 0.6540958048999309


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 18 finished with avg_loss 0.6518770459592342


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 19 finished with avg_loss 0.6509693306088448


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 20 finished with avg_loss 0.6448221949636936


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 21 finished with avg_loss 0.6417228334844113


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 22 finished with avg_loss 0.639461326956749


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 23 finished with avg_loss 0.6369258677959442


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 24 finished with avg_loss 0.6355411625802517


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 25 finished with avg_loss 0.633323041856289


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 26 finished with avg_loss 0.6300939404070377


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 27 finished with avg_loss 0.6301407740414142


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 28 finished with avg_loss 0.6262338322401046


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 29 finished with avg_loss 0.6255683187544346


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 30 finished with avg_loss 0.6192424784302711


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 31 finished with avg_loss 0.6136476902663708


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 32 finished with avg_loss 0.6122951903343201


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 33 finished with avg_loss 0.6099782647788524


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 34 finished with avg_loss 0.6077533559054136


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 35 finished with avg_loss 0.6045248571336269


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 36 finished with avg_loss 0.6041024835705757


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 37 finished with avg_loss 0.6013510505110026


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 38 finished with avg_loss 0.601161992162466


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 39 finished with avg_loss 0.5998744196593762


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 40 finished with avg_loss 0.5984885098785162


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 41 finished with avg_loss 0.5961843260526657


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 42 finished with avg_loss 0.5952776058614254


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 43 finished with avg_loss 0.5933133986592293


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 44 finished with avg_loss 0.5926032464653254


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 45 finished with avg_loss 0.5906897675991059


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 46 finished with avg_loss 0.5886155189275741


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 47 finished with avg_loss 0.5878014213144779


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 48 finished with avg_loss 0.5857852206975221


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 49 finished with avg_loss 0.5849672388285398


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 50 finished with avg_loss 0.5821096986979246


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 51 finished with avg_loss 0.5819014406651258


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 52 finished with avg_loss 0.5812230263054371


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 53 finished with avg_loss 0.577198471441865


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 54 finished with avg_loss 0.5779558226466179


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 55 finished with avg_loss 0.5758484231978654


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 56 finished with avg_loss 0.5745375305116177


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 57 finished with avg_loss 0.5734826191365718


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 58 finished with avg_loss 0.5718943717181683


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 59 finished with avg_loss 0.572985222786665


In [24]:
torch.save(model.state_dict(), "model_ver2.state")
torch.save(calcprob.state_dict(), "dist2prob_ver2.state")

In [29]:
BATCH = 4

In [30]:
def gen_pair_test(true_pair = True):
    batch_data_1 = torch.zeros(BATCH, 3, 112, 112).byte()
    batch_data_2 = torch.zeros(BATCH, 3, 112, 112).byte()
    if true_pair:
        good = rng.choice(2300, size = BATCH, replace=True)
        for i in range(BATCH):
            ind = rng.choice(20, size=2, replace=True)
            batch_data_1[i] = torch.tensor((data[good[i], ind[0]] - 255 / 2) /  (255 / 2)).transpose(0, 2)
            batch_data_2[i] = torch.tensor((data[good[i], ind[1]] - 255 / 2) /  (255 / 2)).transpose(0, 2)
    else:
        for i in range(BATCH):
            ind1 = rng.choice(2300, size=2, replace=False)
            ind = rng.choice(20, size=2, replace=True)
            batch_data_1[i] = torch.tensor((data[ind1[0], ind[0]] - 255 / 2) /  (255 / 2)).transpose(0, 2)
            batch_data_2[i] = torch.tensor((data[ind1[1], ind[1]] - 255 / 2) /  (255 / 2)).transpose(0, 2)   
        
    return batch_data_1, batch_data_2

In [32]:
cnt = 0.0
for s in tqdm(range(2000)):  
        X1, X2 = gen_pair_test()
        X1 = X1.to(DEVICE).float()
        X2 = X2.to(DEVICE).float()
        diff = d(model(X1), model(X2))
        prob = calcprob(diff)
        cnt += (prob >= 0.5).sum().item()
        
print("Accurancy on true pairs ", cnt / (2000 * BATCH))

cnt1 = 0.0
for s in tqdm(range(2000)):  
        X1, X2 = gen_pair_test(true_pair = False)
        X1 = X1.to(DEVICE).float()
        X2 = X2.to(DEVICE).float()
        diff = d(model(X1), model(X2))
        prob = calcprob(diff)
        cnt1 += (prob < 0.5).sum().item()
        
        
print("Accurancy on false pairs ", cnt1 / (2000 * BATCH))

HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Accurancy on true pairs  0.6665


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Accurancy on false pairs  0.5255
