In [28]:
from torch import nn
import os
import torchvision
from torch.nn import functional as F
import torch
import random
import argparse, random, copy
import numpy as np
from PIL import Image
import torch.optim as optim
from torchvision import transforms as T
from torch.optim.lr_scheduler import StepLR
from matplotlib import pyplot as plt

In [29]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [30]:
# Directory whose immediate sub-folders are listed to collect the training images.
data_file = "dataset/dataset/train"

In [31]:
class SiameseNN(nn.Module):
    """Two-input network that scores a pair of single-channel images.

    Both inputs pass through the same ResNet-18 trunk (first convolution
    replaced by a 1-channel one, last child module removed). The two flattened
    feature vectors are concatenated and fed to a small fully connected head
    whose single output goes through a sigmoid.
    """

    def __init__(self):
        super().__init__()

        backbone = torchvision.models.resnet18(pretrained=False)
        # Replace the stock stem so the trunk accepts 1-channel input.
        backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.fc_in_features = backbone.fc.in_features
        # Keep every child except the last one (the original fc layer).
        trunk_layers = list(backbone.children())[:-1]
        self.resnet = nn.Sequential(*trunk_layers)

        # The head sees both feature vectors side by side, hence the factor 2.
        head_layers = [
            nn.Linear(2 * self.fc_in_features, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 1),
        ]
        self.fc = nn.Sequential(*head_layers)
        self.sigmoid = nn.Sigmoid()

        for part in (self.resnet, self.fc):
            part.apply(self.init_weights)

    def init_weights(self, m):
        """Xavier-initialise the weight of an ``nn.Linear`` and set its bias to 0.01.

        Modules of any other exact type are left untouched.
        """
        if type(m) is not nn.Linear:
            return
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

    def forward_once(self, x):
        """Run one input through the shared trunk and flatten it per sample."""
        features = self.resnet(x)
        batch = features.size(0)
        return features.view(batch, -1)

    def forward(self, inp1, inp2):
        """Return the sigmoid score for the pair ``(inp1, inp2)``."""
        feats_a = self.forward_once(inp1)
        feats_b = self.forward_once(inp2)
        joined = torch.cat((feats_a, feats_b), 1)
        return self.sigmoid(self.fc(joined))

In [32]:
def triplet_loss(v1, v2, margin=1):
    """Batch triplet loss computed from the score matrix ``v1 @ v2.T``.

    Row ``i`` of ``v1`` paired with row ``i`` of ``v2`` is the positive (the
    diagonal of the score matrix); every other entry in row ``i`` is a
    negative. The result is the sum of two hinge terms, each averaged over
    rows: one against the highest-scoring negative, one against the mean of
    all negatives.

    Args:
        v1: 2-D tensor, one embedding per row.
        v2: 2-D tensor, expected to have the same number of rows as ``v1``.
        margin: required gap between the positive score and the negatives.

    Returns:
        Scalar tensor. With a single row there are no negatives, so the loss
        is zero (still attached to the autograd graph) rather than the NaN
        that a 0 / 0 mean would produce.
    """
    scores = torch.matmul(v1, v2.t())
    if not scores.is_floating_point():
        # -inf masking and the mean below need a floating dtype.
        scores = scores.float()
    class_size = scores.size(-2)

    if class_size < 2:
        # No negatives exist; multiplying by zero keeps backward() usable.
        return (scores * 0).sum()

    on_diagonal = torch.eye(class_size, dtype=torch.bool, device=scores.device)
    positive = scores.diagonal(dim1=-2, dim2=-1)

    # Mask the positive with -inf so it can never be picked as the hardest
    # negative, whatever the scale of the scores.
    hardest_negative = scores.masked_fill(on_diagonal, float("-inf")).max(dim=-1)[0]
    mean_negative = scores.masked_fill(on_diagonal, 0).sum(dim=-1) / (class_size - 1)

    hard_term = torch.clamp(margin + hardest_negative - positive, min=0)
    mean_term = torch.clamp(margin + mean_negative - positive, min=0)
    return hard_term.mean() + mean_term.mean()

# Number of train() calls made so far; train() uses it for progress messages.
ind = 0
# Every image is resized to the fixed 120 x 2000 grid the code below expects.
transform = T.Resize(size=(120, 2000))

def train(img_path_pair, label, model, loss_fn, optimizer):
    """Run one optimisation step on a single image pair.

    Args:
        img_path_pair: two-item sequence ``(path_a, path_b)`` of image files.
        label: target for the pair; converted to a ``(1, 1)`` float tensor.
        model: network called as ``model(img_a, img_b)``.
        loss_fn: callable taking ``(output, label)`` and returning a scalar
            tensor.
        optimizer: optimiser whose gradients are reset and stepped here.

    Returns:
        The loss value of this step as a Python float.

    Side effects:
        Increments the module-level ``ind`` counter and prints a progress line
        on every 1000th call.
    """
    global ind
    if ind % 1000 == 0:
        print("Training", ind)
    ind += 1

    model.train()
    optimizer.zero_grad()

    def _as_batch(img_path):
        # Open in a ``with`` block so the file handle is released right away.
        with Image.open(img_path) as raw:
            # Collapse multi-band images (e.g. RGB) to one channel; single-band
            # images are passed through untouched.
            single = raw.convert("L") if len(raw.getbands()) > 1 else raw
            pixels = np.array(transform(single))
        # (H, W) -> (1, 1, H, W): add batch and channel axes.
        return torch.tensor(pixels, dtype=torch.float32, device=device)[None, None]

    img1_path, img2_path = img_path_pair
    img1 = _as_batch(img1_path)
    img2 = _as_batch(img2_path)
    target = torch.tensor(label, dtype=torch.float32, device=device).reshape(1, 1)

    output = model(img1, img2)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
    return loss.item()


In [33]:
# Build the network, then move its parameters to the selected device.
model = SiameseNN()
model = model.to(device)



In [34]:
# Collected [path_a, path_b] pairs; filled by the cell that scans the data folder.
img_pairs = list()

In [35]:
# Start from an empty list so re-running this cell never appends duplicates.
img_pairs = []
# Sorted listings make the pair order reproducible across runs and machines.
for fld in sorted(os.listdir(data_file)):
    fld_path = data_file + "/" + fld
    if not os.path.isdir(fld_path):
        # Stray files next to the class folders would make os.listdir fail.
        continue
    img_set = sorted(os.listdir(fld_path))
    # Every ordered pair of two different images from the same folder.
    img_pairs.extend(
        [fld_path + "/" + img, fld_path + "/" + img2]
        for img in img_set
        for img2 in img_set
        if img != img2
    )
        

In [36]:
# Show the first path relative to the data folder; the prefix length is derived
# from data_file (plus the "/" separator) instead of being hard-coded.
img_pairs[0][0][len(data_file) + 1:]

'P577/B3.jpg'

In [37]:
def check(path1, path2):
    """Return True when both image paths sit in folders with the same name.

    The folder name is taken from each path itself (the directory directly
    containing the file), so the result does not depend on how long the
    dataset root prefix is.
    """
    folder1 = os.path.basename(os.path.dirname(path1))
    folder2 = os.path.basename(os.path.dirname(path2))
    return folder1 == folder2

In [38]:
# One optimiser for the whole run: building a new Adam on every step throws
# away its running moment estimates.
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# The model emits a sigmoid probability and the labels are 0/1, so binary
# cross-entropy is the matching criterion. triplet_loss on a single (1, 1)
# score divides 0 by 0 and yields NaN.
criterion = nn.BCELoss()

# Negative sampling below retries until it finds another folder; make sure one exists.
if img_pairs and all(check(pair[0], img_pairs[0][0]) for pair in img_pairs):
    raise ValueError("need images from at least two folders to draw negative pairs")

for image_pair in img_pairs:
    # Positive example: two different images from the same folder.
    train(image_pair, 1, model, criterion, optimizer)

    # Negative example: redraw until the sampled image is from another folder.
    new_img = random.choice(img_pairs)[0]
    while check(new_img, image_pair[0]):
        new_img = random.choice(img_pairs)[0]
    train((new_img, image_pair[0]), 0, model, criterion, optimizer)

Training 0
Training 1000
Training 2000
Training 3000
Training 4000
Training 5000
Training 6000
Training 7000
Training 8000
Training 9000
Training 10000
Training 11000
Training 12000
Training 13000
Training 14000
Training 15000
Training 16000
Training 17000
Training 18000
Training 19000
Training 20000
Training 21000
Training 22000
Training 23000
Training 24000
Training 25000
Training 26000
Training 27000
Training 28000
Training 29000
Training 30000
Training 31000
Training 32000
Training 33000
Training 34000
Training 35000
Training 36000
Training 37000
Training 38000
Training 39000
Training 40000
Training 41000
Training 42000
Training 43000
Training 44000
Training 45000
Training 46000
Training 47000
Training 48000
Training 49000
Training 50000
Training 51000
Training 52000
Training 53000
Training 54000
Training 55000
Training 56000
Training 57000
Training 58000
Training 59000
Training 60000
Training 61000
Training 62000
Training 63000
Training 64000
Training 65000
Training 66000
Training

KeyboardInterrupt: 

In [None]:
# Serialise the whole module object to disk.
torch.save(obj=model, f="model.pt")

In [41]:
# Serialise only the parameter/buffer dictionary.
torch.save(obj=model.state_dict(), f="model.pth")

In [43]:
def test(img_path_pair, label, model, threshold=0.5):
    """Check whether the model classifies one image pair correctly.

    Args:
        img_path_pair: two-item sequence ``(path_a, path_b)`` of image files.
        label: expected class of the pair (0 or 1).
        model: network called as ``model(img_a, img_b)``; its output is
            treated as a probability.
        threshold: outputs at or above this value count as class 1.

    Returns:
        1 when the thresholded prediction equals ``label``, otherwise 0.

    The model is switched to eval mode for the forward pass and restored to
    its previous mode afterwards; no gradients are recorded.
    """
    img1_path, img2_path = img_path_pair

    batch = []
    for img_path in (img1_path, img2_path):
        # Open in a ``with`` block so the file handle is released right away.
        with Image.open(img_path) as raw:
            # Collapse multi-band images (e.g. RGB) to one channel.
            single = raw.convert("L") if len(raw.getbands()) > 1 else raw
            pixels = np.array(transform(single))
        # (H, W) -> (1, 1, H, W): add batch and channel axes.
        batch.append(torch.tensor(pixels, dtype=torch.float32, device=device)[None, None])

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model(batch[0], batch[1])
    finally:
        model.train(was_training)

    # A sigmoid output is almost never exactly 0.0 or 1.0, so compare the
    # thresholded decision rather than the raw value.
    predicted = 1 if output.item() >= threshold else 0
    expected = 1 if float(label) >= 0.5 else 0
    return 1 if predicted == expected else 0

In [49]:
# Evaluate the first same-folder pair; the expected label is 1.
test(img_path_pair=img_pairs[0], label=1, model=model)

0