In [148]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch as torch
from torch import nn
import torch.nn.functional as F
import torchvision.models as models

In [149]:
import pickle as pkl
from tqdm.notebook import tqdm

In [150]:
DEVICE = torch.device("cuda:0")
DIM = 128
IMAGE_SIZE = 112
dtype = torch.cuda.FloatTensor

In [151]:
torch.cuda.is_available()

True

In [152]:
data_ = pkl.load(open('Data_siamese.pickle', 'rb'))

In [153]:
data_.shape

(6317, 20, 112, 112, 3)

In [168]:
data = data[:4000, :, :, :, :]

In [169]:
model = models.mobilenet_v2(width_mult=1, pretrained=True)
model = torch.nn.Sequential(*(list(model.children())[:-1]))

In [170]:
class Bottleneck(nn.Module):
    def __init__(self, model, in_dim=1280, out_dim=DIM, spartial=(IMAGE_SIZE+31) // 32):
        super().__init__()
        self.model = model
        self.depthwise = nn.Conv2d(in_dim, in_dim, spartial, bias=False, groups=in_dim)
        self.linear = nn.Linear(in_dim, out_dim)
        
    def forward(self, x):
        x = self.model(x)
        x = self.depthwise(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x

In [171]:
model = Bottleneck(model)
model = model.to(DEVICE)

In [172]:
class Probcalc(nn.Module):
    def __init__(self, in_dim = 128):
        super().__init__()
        self.layer = nn.Linear(in_dim, 1)
    
    def forward(self, x):
        x = self.layer(x)
        x = torch.sigmoid(x)
        return x

In [173]:
calcprob = Probcalc().to(DEVICE)

In [174]:
optimizer = torch.optim.Adam(list(model.parameters()) + list(calcprob.parameters()), lr=0.003)

In [175]:
def change_lr(optimizer, epoch):
    for param_group in optimizer.param_groups:
        param_group['lr'] = 1e-3 if epoch < 30 else 1e-4

In [176]:
rng = np.random.default_rng()

In [177]:
import cv2
from PIL import Image

In [178]:
def dist(x1, x2):
    return torch.abs(x1 - x2)

In [179]:
EPOCHS = 60
STEP_PER_EPOCH = 2000
BATCH = 128

In [183]:
def gen_pair():
    batch_data_1 = torch.zeros(BATCH, 3, 112, 112).byte()
    batch_data_2 = torch.zeros(BATCH, 3, 112, 112).byte()
    batch_target = torch.zeros(BATCH, 1)
    bad = rng.choice(4000, size = BATCH, replace=False)
    good = rng.choice(4000, size = BATCH // 2, replace=True)
    for i in range(BATCH):
        if (i % 2 == 0):
            ind = rng.choice(20, size=2, replace=False)
            batch_data_1[i] = torch.tensor((data[good[i // 2], ind[0]] - 255 / 2) /  (255 / 2)).transpose(0, 2)
            batch_data_2[i] = torch.tensor((data[good[i // 2], ind[1]] - 255 / 2) /  (255 / 2)).transpose(0, 2)
            batch_target[i] = 1
        else:
            ind = rng.choice(20, size=2, replace=True)
            batch_data_1[i] = torch.tensor((data[bad[i // 2], ind[0]] - 255 / 2) /  (255 / 2)).transpose(0, 2)
            batch_data_2[i] = torch.tensor((data[bad[i // 2 + BATCH // 2], ind[1]] - 255 / 2) /  (255 / 2)).transpose(0, 2)
            batch_target[i, 0] = 0
    return batch_data_1, batch_data_2, batch_target

In [184]:
from torchvision import transforms

In [185]:
for epoch in range(EPOCHS):
    avg_loss = 0.0
    for s in tqdm(range(STEP_PER_EPOCH)):
        change_lr(optimizer, epoch)
        optimizer.zero_grad()        
        X1, X2, target = gen_pair()
        X1 = X1.to(DEVICE).float()
        X2 = X2.to(DEVICE).float()
        target = target.to(DEVICE).float()
        diff = dist(model(X1), model(X2))
        prob = calcprob(diff)
        loss = F.binary_cross_entropy(prob, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(list(model.parameters()) + list(calcprob.parameters()), 10)
        optimizer.step()
        avg_loss += loss.item()
        
    print("Train epoch", epoch, "finished with avg_loss", avg_loss / STEP_PER_EPOCH)
    avg_loss = 0
    avg_acc = 0

HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 0 finished with avg_loss 0.6938010583519936


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 1 finished with avg_loss 0.6905699066817761


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 2 finished with avg_loss 0.6877190639972687


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 3 finished with avg_loss 0.6869825836122035


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 4 finished with avg_loss 0.6849676965773106


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 5 finished with avg_loss 0.6829261487722397


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 6 finished with avg_loss 0.6806220610737801


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 7 finished with avg_loss 0.676718344271183


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 8 finished with avg_loss 0.674365771740675


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 9 finished with avg_loss 0.6730069693624974


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 10 finished with avg_loss 0.6714949445128441


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 11 finished with avg_loss 0.6695337740182876


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 12 finished with avg_loss 0.6683293931186199


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 13 finished with avg_loss 0.6661673392951488


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 14 finished with avg_loss 0.6645199703872204


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 15 finished with avg_loss 0.6635359118878842


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 16 finished with avg_loss 0.662490732461214


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 17 finished with avg_loss 0.6610523048341275


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 18 finished with avg_loss 0.659051618874073


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 19 finished with avg_loss 0.6590144867002964


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 20 finished with avg_loss 0.6574895200133324


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 21 finished with avg_loss 0.6576705891489982


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 22 finished with avg_loss 0.6553261959850788


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 23 finished with avg_loss 0.6545511888861656


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 24 finished with avg_loss 0.6543206753134727


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 25 finished with avg_loss 0.6510918318331241


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 26 finished with avg_loss 0.6511752507090569


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 27 finished with avg_loss 0.6510797645449639


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 28 finished with avg_loss 0.6482979970574378


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 29 finished with avg_loss 0.6472031606435775


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 30 finished with avg_loss 0.6411068694293499


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 31 finished with avg_loss 0.6370683892965316


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 32 finished with avg_loss 0.6352505640685558


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 33 finished with avg_loss 0.6328861267268657


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 34 finished with avg_loss 0.630656957000494


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 35 finished with avg_loss 0.6296594120264053


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 36 finished with avg_loss 0.628438313394785


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 37 finished with avg_loss 0.6268924472332


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 38 finished with avg_loss 0.6254522507786751


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 39 finished with avg_loss 0.6238077172338963


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 40 finished with avg_loss 0.6233115270733833


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 41 finished with avg_loss 0.6218768092393875


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 42 finished with avg_loss 0.620036832511425


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 43 finished with avg_loss 0.6191580193340779


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 44 finished with avg_loss 0.6178618906140327


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 45 finished with avg_loss 0.6165306119024754


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 46 finished with avg_loss 0.6167994079887867


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 47 finished with avg_loss 0.6136142076253891


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 48 finished with avg_loss 0.6112606583237647


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 49 finished with avg_loss 0.6111368442177773


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 50 finished with avg_loss 0.6107633477151394


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 51 finished with avg_loss 0.6080972201228142


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 52 finished with avg_loss 0.6074958393871784


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 53 finished with avg_loss 0.6053764556050301


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 54 finished with avg_loss 0.6056299603283405


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 55 finished with avg_loss 0.6030126135647297


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 56 finished with avg_loss 0.603149858251214


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 57 finished with avg_loss 0.6013134385347366


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 58 finished with avg_loss 0.5992722234725952


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train epoch 59 finished with avg_loss 0.6004420094788074


In [186]:
torch.save(model.state_dict(), "model.state")
torch.save(calcprob.state_dict(), "dist2prob.state")