In [1]:
from torch import nn
import os
import torchvision
from torch.nn import functional as F
import torch
import random
import argparse, random, copy
import numpy as np
from PIL import Image
from torchvision import transforms as T
from torch.optim.lr_scheduler import StepLR
from matplotlib import pyplot as plt
import torch.optim as optim
from torchvision.models import vgg16




In [2]:
# Seed every RNG the notebook draws from so that Restart & Run All repeats the
# same run: `random` drives pair shuffling and negative sampling, torch drives
# weight initialisation.  NumPy is seeded too because it is imported above.
# NOTE(review): some GPU kernels may still be non-deterministic — confirm if
# bit-exact reproducibility on CUDA is required.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Root folder of the training images; the notebook lists one sub-folder per
# class beneath it and reads the images inside each sub-folder.
data_file = "dataset/dataset/train"

In [4]:
# VGG16 was chosen over VGG19 because VGG19 has far more parameters: it took too long to train and also ran out of memory.
# The code was written with VGG19 in mind so that it can be switched back to VGG19 easily.
# NOTE(review): the SiameseNN class below is a small custom CNN and uses neither VGG16 nor VGG19 — this note looks stale; confirm and update.
class SiameseNN(nn.Module):
    """Embedding network applied to every image of a triplet.

    Ten convolutions (each followed by ReLU) interleaved with four pooling
    stages produce a feature map that is flattened and passed through a
    five-layer MLP ending in an 8-dimensional embedding.  The first linear
    layer expects 2496 flattened features, which is what a single-channel
    150 x 2000 input yields (64 channels x 1 x 39).
    """

    def __init__(self):
        super().__init__()

        def conv_relu(in_ch, out_ch, k):
            # One convolution followed by an in-place ReLU.
            return [nn.Conv2d(in_ch, out_ch, kernel_size=k), nn.ReLU(inplace=True)]

        feature_layers = [
            *conv_relu(1, 8, 2),
            *conv_relu(8, 8, 2),
            nn.MaxPool2d(2, 2),
            *conv_relu(8, 16, 2),
            *conv_relu(16, 16, 2),
            nn.MaxPool2d(2, 2),
            *conv_relu(16, 32, 3),
            *conv_relu(32, 32, 3),
            *conv_relu(32, 32, 3),
            nn.AvgPool2d(3, 3),
            *conv_relu(32, 32, 3),
            *conv_relu(32, 64, 3),
            *conv_relu(64, 64, 3),
            nn.AvgPool2d(4, 4),
        ]
        # The attribute name (spelling included) is kept as-is: it is part of
        # the state_dict keys of any checkpoint saved from this class.
        self.conolution = nn.Sequential(*feature_layers)

        # MLP head: 2496 -> 128 -> 64 -> 32 -> 16 -> 8, ReLU between layers only.
        widths = [2496, 128, 64, 32, 16, 8]
        head_layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            head_layers.append(nn.Linear(fan_in, fan_out))
            head_layers.append(nn.ReLU(inplace=True))
        head_layers.pop()  # no activation after the final projection
        self.fc = nn.Sequential(*head_layers)

    def forward_once(self, x):
        """Run only the convolutional stack and flatten each sample's features."""
        features = self.conolution(x)
        return features.view(features.size(0), -1)

    def forward(self, inp1):
        """Map a batch of images to their embeddings (conv features -> MLP head)."""
        return self.fc(self.forward_once(inp1))

In [5]:
def check(path1, path2):
    """Return True when both image paths live in the same class folder.

    An image's class is the name of the directory that directly contains it
    (``<root>/<class>/<image>``).  Deriving it from the path itself, rather
    than slicing at a fixed character offset, keeps the comparison correct
    whatever the length of the dataset root.

    Args:
        path1: Path of the first image.
        path2: Path of the second image.

    Returns:
        bool: True if the two parent-directory names are equal, else False.
    """
    def class_folder(path):
        # Name of the directory that directly holds the image file.
        return os.path.basename(os.path.dirname(path))

    return class_folder(path1) == class_folder(path2)


In [6]:
def triplet_loss(anchor, positive, negative, margin=1.0):
    """Batch-averaged triplet margin loss.

    For each triplet the loss is ``max(0, margin + d(a, p) - d(a, n))`` where
    ``d`` is ``F.pairwise_distance``; the per-triplet values are averaged.

    Args:
        anchor: Embeddings of the anchor samples.
        positive: Embeddings that should end up close to the anchors.
        negative: Embeddings that should end up far from the anchors.
        margin: Required gap between negative and positive distances.

    Returns:
        Scalar tensor holding the mean loss over the batch.
    """
    pos_dist = F.pairwise_distance(anchor, positive)
    neg_dist = F.pairwise_distance(anchor, negative)
    per_triplet = (margin + pos_dist - neg_dist).clamp(min=0.0)
    return per_triplet.mean()

In [7]:
# Centre-crop every image to a fixed 150 x 2000 size (torchvision sizes are
# given as (height, width)) so all inputs share one shape.  The first Linear
# layer of SiameseNN (2496 inputs) is sized for exactly this crop; changing
# the crop requires changing that layer as well.
transformer = T.CenterCrop((150, 2000))

In [8]:
def train(img_pairs, model, loss_fn, optimizer, batch_size, num_epochs):
    """Train `model` with triplets built on the fly from same-class image pairs.

    Every pair supplies the anchor and the positive; the negative is the first
    image of a randomly drawn pair whose class differs from the anchor's
    (as decided by `check`).  Images are loaded from disk for each batch.

    Args:
        img_pairs: Sequence of ``[anchor_path, positive_path]`` entries.  The
            sequence is copied, so the caller's list is left in its original order.
        model: Network mapping an image batch to embeddings.
        loss_fn: Callable ``loss_fn(anchor, positive, negative)`` returning a scalar tensor.
        optimizer: Optimizer bound to `model`'s parameters.
        batch_size: Number of triplets per optimisation step.
        num_epochs: Number of passes over `img_pairs`.

    Raises:
        ValueError: If all pairs belong to one class, since no negative could
            ever be sampled (the sampling loop would otherwise never end).
    """
    # Shuffle a private copy instead of reordering the caller's list in place.
    pairs = list(img_pairs)

    # Negative sampling retries until it hits another class; make sure one exists.
    if pairs and all(check(pairs[0][0], other[0]) for other in pairs):
        raise ValueError("train() needs image pairs from at least two classes to sample negatives")

    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0  # sum of the per-batch mean losses of this epoch
        random.shuffle(pairs)
        for start in range(0, len(pairs), batch_size):
            anchor_batch, positive_batch, negative_batch = [], [], []
            for anchor_path, positive_path in pairs[start:start + batch_size]:
                # Rejection-sample an image whose class differs from the anchor's.
                negative_path = random.choice(pairs)[0]
                while check(anchor_path, negative_path):
                    negative_path = random.choice(pairs)[0]

                anchor_batch.append(_load_gray_tensor(anchor_path))
                positive_batch.append(_load_gray_tensor(positive_path))
                negative_batch.append(_load_gray_tensor(negative_path))

            anchor_images = torch.stack(anchor_batch).to(device)
            positive_images = torch.stack(positive_batch).to(device)
            negative_images = torch.stack(negative_batch).to(device)

            optimizer.zero_grad()
            # The same module embeds all three inputs, so the weights are shared.
            loss = loss_fn(model(anchor_images), model(positive_images), model(negative_images))
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()

            step = start // batch_size + 1
            if step % 500 == 0:
                print(f"Epoch: {epoch+1}, iteration: {step}, loss: {loss.item()}")
        print("\n\n------------------------------------")
        print(f"Epoch: {epoch+1}, loss: {epoch_loss}")
        print("------------------------------------\n\n")


def _load_gray_tensor(path):
    """Open `path`, convert it to grayscale, crop it with `transformer`, and return a tensor."""
    # convert() decodes the pixels into a new image, so the file can be closed right away.
    with Image.open(path) as raw:
        gray = raw.convert("L")
    return T.ToTensor()(transformer(gray))
        
        


In [79]:
def test(img_pair, label, model, threshold=1.0):
    """Check whether `model` classifies one image pair correctly.

    Both images are embedded and their pairwise distance is compared with
    `threshold`: a distance below it means "same class".

    Args:
        img_pair: ``(anchor_path, test_path)`` of the two images to compare.
        label: 1 if the images are expected to match, 0 if they are not.
        model: Network mapping an image batch to embeddings; it is switched
            to eval mode.
        threshold: Distance below which the pair counts as a match.  The
            default of 1.0 equals the previously hard-coded cut-off.

    Returns:
        bool: True when the prediction agrees with `label`; False otherwise,
        including for any label other than 0 or 1.
    """
    model.eval()
    anchor_path, test_path = img_pair

    def embed(path):
        # Grayscale -> centre crop -> tensor with a leading batch dimension of 1.
        image = transformer(Image.open(path).convert("L"))
        return model(T.ToTensor()(image).unsqueeze(0).to(device))

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        distance = F.pairwise_distance(embed(anchor_path), embed(test_path)).item()

    if label == 1:
        return distance < threshold
    if label == 0:
        return distance >= threshold
    return False

In [10]:
# Build the embedding network and move its parameters to the selected device.
model = SiameseNN().to(device)

In [11]:
# Total number of scalar parameters in the model: the element counts of every
# tensor returned by model.parameters(), summed.
total_params = sum(p.numel() for p in model.parameters())
total_params

420296

In [12]:
# Training pairs: each entry is [path_a, path_b] for two different images taken
# from the same class folder.
img_pairs = []
# Maps a class-folder name to the path of one image from that folder.
anchors = dict()

In [13]:
# Start from empty containers so re-running this cell never appends duplicates.
img_pairs.clear()
anchors.clear()

# os.listdir returns entries in arbitrary order; sorting makes the chosen
# anchors and the pair order identical from run to run.
for fld in sorted(os.listdir(data_file)):
    class_dir = data_file + "/" + fld
    if not os.path.isdir(class_dir):
        continue  # stray files in the root are not class folders
    img_set = sorted(os.listdir(class_dir))
    if not img_set:
        continue  # an empty class folder has neither an anchor nor pairs
    # The first image (in sorted order) represents the class.
    anchors[fld] = class_dir + "/" + img_set[0]
    # Every unordered pair of distinct images within the class, each listed once.
    for i in range(len(img_set)):
        for j in range(i + 1, len(img_set)):
            img_pairs.append([class_dir + "/" + img_set[i], class_dir + "/" + img_set[j]])


In [15]:
# Training hyper-parameters.
batch_size = 16
num_epochs = 10
learning_rate = 0.001

# Adam over all model parameters; the objective is the triplet loss defined earlier.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = triplet_loss

# Runs the full training loop. The per-epoch "loss" that train() prints is the
# sum of the per-batch mean losses, not an average.
train(img_pairs, model, loss_fn, optimizer, batch_size, num_epochs)


Epoch: 1, iteration: 500, loss: 0.35691121220588684
Epoch: 1, iteration: 1000, loss: 0.5152056813240051
Epoch: 1, iteration: 1500, loss: 0.42794567346572876
Epoch: 1, iteration: 2000, loss: 0.2779752016067505
Epoch: 1, iteration: 2500, loss: 0.4004034698009491


------------------------------------
Epoch: 1, loss: 1307.9787173718214
------------------------------------


Epoch: 2, iteration: 500, loss: 0.35440436005592346
Epoch: 2, iteration: 1000, loss: 0.5129297971725464
Epoch: 2, iteration: 1500, loss: 0.8183356523513794
Epoch: 2, iteration: 2000, loss: 0.461360365152359
Epoch: 2, iteration: 2500, loss: 0.21959036588668823


------------------------------------
Epoch: 2, loss: 1267.2192949354649
------------------------------------


Epoch: 3, iteration: 500, loss: 0.37616997957229614
Epoch: 3, iteration: 1000, loss: 0.3220243453979492
Epoch: 3, iteration: 1500, loss: 0.5128110647201538
Epoch: 3, iteration: 2000, loss: 0.7328042984008789
Epoch: 3, iteration: 2500, loss: 0.3928639292

KeyboardInterrupt: 

In [18]:
# Serialises the whole module object (architecture reference + weights).
# NOTE(review): a file saved this way is pickle-based and needs the SiameseNN
# class definition to be importable when loading — confirm this copy is still
# needed next to a state_dict checkpoint.
torch.save(model, "model.pt")

In [19]:
# Save only the learned tensors; to restore, build SiameseNN() and call
# load_state_dict() with the contents of this file.
torch.save(model.state_dict(), "model.pth")

In [80]:
# Score 5000 same-class pairs drawn with replacement: each one should be
# judged a match (label 1).
correct = total = 0
for img_pair in random.choices(img_pairs, k=5000):
    correct += 1 if test(img_pair, 1, model) else 0
    total += 1

In [81]:
correct

3934

In [82]:
total

5000

In [83]:
# Score 5000 different-class pairs: the first image of a sampled pair is
# matched with an image of another class and should be judged a non-match
# (label 0).
correct = 0
total = 0
for img_pair in random.choices(img_pairs, k=5000):
    img1_path = img_pair[0]
    # Rejection-sample until the second image comes from a different class.
    img2_path = random.choice(img_pairs)[0]
    while check(img1_path, img2_path):
        img2_path = random.choice(img_pairs)[0]
    if test([img1_path, img2_path], 0, model):
        correct += 1
    total += 1

In [84]:
correct

3407

In [85]:
total

5000