In [1]:
import pandas as pd
import numpy as np
import torch as pt
import cv2 as cv
import matplotlib.pyplot as plt
from tqdm import tqdm
from os.path import exists

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torchvision.datasets import LFWPairs
from torch.utils.data import DataLoader
from torchvision import transforms

In [None]:
# All three splits share one root so the LFW image archive is downloaded and
# extracted into a single directory instead of once per split.
DATA_ROOT = './data'
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training preprocessing: resize, random flip augmentation, ImageNet normalisation.
# NOTE(review): vertical flips produce upside-down faces, which LFW pairs never
# contain — consider dropping RandomVerticalFlip.
train_transform = transforms.Compose([
    transforms.Resize((250, 250)),
    transforms.RandomHorizontalFlip(p=0.3),
    transforms.RandomVerticalFlip(p=0.3),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Deterministic preprocessing shared by the validation and test splits.
eval_transform = transforms.Compose([
    transforms.Resize((250, 250)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

training = LFWPairs(DATA_ROOT, split='train', download=True, transform=train_transform)
# '10fold' is LFWPairs' default split; it is spelled out so the choice is visible.
# NOTE(review): this split is the LFW 10-fold evaluation view, not a held-out
# slice of 'train' — confirm it is the intended validation set.
validating = LFWPairs(DATA_ROOT, split='10fold', download=True, transform=eval_transform)
testing = LFWPairs(DATA_ROOT, split='test', download=True, transform=eval_transform)

In [None]:
torch.manual_seed(42)
BATCH_SIZE = 16

# Shuffling draws its randomness only when iteration starts, and the global
# seed is reset (torch.manual_seed(24)) in a later cell before training, so a
# dedicated generator is what actually makes the training order reproducible.
training_set = DataLoader(training, batch_size=BATCH_SIZE, shuffle=True,
                          generator=torch.Generator().manual_seed(42))
# Evaluation metrics do not depend on sample order; keep it fixed.
validation_set = DataLoader(validating, batch_size=BATCH_SIZE, shuffle=False)
testing_set = DataLoader(testing, batch_size=BATCH_SIZE, shuffle=False)

In [13]:
class ContrastiveLoss(nn.Module):
    """Contrastive loss for a batch of embedding pairs.

    With Euclidean distance d between ``input1`` and ``input2`` and label y:

        loss = (1 - y) * d**2 + y * max(margin - d, 0)**2

    so y == 0 marks a *similar* pair (pulled together) and y == 1 a
    *dissimilar* pair (pushed at least ``margin`` apart). This is the
    Hadsell, Chopra & LeCun (2006) formulation without the 1/2 factors.

    NOTE(review): torchvision's LFWPairs presumably labels same-person pairs
    with 1, the opposite convention — pass ``1 - label`` in that case; confirm.

    Args:
        margin: minimum distance enforced between dissimilar pairs.
        reduction: 'none' (default, per-pair losses), 'mean' or 'sum'.
    """

    def __init__(self, margin=2.0, reduction='none'):
        super().__init__()
        if reduction not in ('none', 'mean', 'sum'):
            raise ValueError(f"reduction must be 'none', 'mean' or 'sum', got {reduction!r}")
        self.margin = margin
        self.reduction = reduction

    def forward(self, input1, input2, label):
        """Return the contrastive loss, reduced according to ``self.reduction``.

        ``label`` may be a tensor broadcastable to the per-pair distances or a
        Python scalar; it is cast to the distance dtype/device.
        """
        pairwise_dist = F.pairwise_distance(input1, input2)
        label = torch.as_tensor(label, dtype=pairwise_dist.dtype, device=pairwise_dist.device)
        similar_term = (1 - label) * pairwise_dist.pow(2)
        dissimilar_term = label * torch.clamp(self.margin - pairwise_dist, min=0).pow(2)
        loss = similar_term + dissimilar_term
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss
        

In [26]:
class SiameseNetwork(nn.Module):
    """Siamese face-verification network for pairs of RGB images.

    Both images pass through the same CNN + fully-connected tower and become
    128-d embeddings. The element-wise absolute difference of the two
    embeddings is mapped by a linear layer + sigmoid to the probability that
    the pair shows the same person.

    Expects 250x250 inputs: the CNN then emits a 256x4x4 map, i.e. the 4096
    features ``fc1`` consumes.
    """

    embedding_dim = 128

    def __init__(self):
        super().__init__()

        def conv(in_ch, out_ch, k, **kwargs):
            # Conv followed by ReLU: without a non-linearity, consecutive
            # convolutions collapse into a single linear map.
            return [nn.Conv2d(in_ch, out_ch, kernel_size=(k, k), **kwargs), nn.ReLU(inplace=True)]

        self.cnn = nn.Sequential(
            *conv(3, 64, 7, stride=2),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2, padding=1),
            nn.BatchNorm2d(64),

            *conv(64, 64, 1),
            *conv(64, 192, 3),
            nn.BatchNorm2d(192),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2, padding=1),

            *conv(192, 192, 1),
            *conv(192, 384, 3),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2, padding=1),

            *conv(384, 384, 1),
            *conv(384, 256, 3),

            *conv(256, 256, 1),
            *conv(256, 256, 3),

            *conv(256, 256, 1),
            *conv(256, 256, 3),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2, padding=1)
        )

        self.fc1 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(4096, 2048),
            nn.ReLU(inplace=True),

            nn.Linear(2048, 512),
            nn.ReLU(inplace=True),

            nn.Linear(512, self.embedding_dim),
            nn.ReLU(inplace=True),
        )

        # Input width must match the embedding size produced by fc1.
        self.fc2 = nn.Sequential(
            nn.Linear(self.embedding_dim, 1),
            nn.Sigmoid()
        )

    def forward_once(self, img):
        """Embed a batch of images: (B, 3, 250, 250) -> (B, 128)."""
        return self.fc1(self.cnn(img))

    def get_embeddings(self, img):
        """Public alias of :meth:`forward_once`."""
        return self.forward_once(img)

    def forward(self, img1, img2, label=None):
        """Return the probability that each pair shows the same person, shape (B,).

        The output is a differentiable float tensor (usable with ``nn.BCELoss``);
        threshold it to get hard decisions. ``label`` is accepted for backward
        compatibility but ignored: feeding the ground truth into the prediction
        would leak the answer.
        """
        emb1 = self.forward_once(img1)
        emb2 = self.forward_once(img2)
        return self.fc2(torch.abs(emb1 - emb2)).squeeze(-1)

In [None]:
class Trainer:
    """Fit, validate and evaluate a pair-verification model.

    The model is called as ``model(img1, img2, labels)`` on each batch and is
    expected to return one match probability per pair (a float tensor with as
    many elements as ``labels``), suitable for a probability criterion such as
    ``nn.BCELoss``. Labels are cast to float before reaching the criterion.
    A prediction counts as "same" when it exceeds ``threshold``.

    ``self.metrics`` holds one value per epoch for each of the train/valid
    loss and accuracy curves.
    """

    def __init__(self, model, criterion, optim, device, threshold=0.5):
        self.model = model.to(device)
        self.criterion = criterion.to(device)
        self.optim = optim
        self.device = device
        self.threshold = threshold
        self.metrics = {
            'train_loss': [],
            'train_acc': [],
            'valid_loss': [],
            'valid_acc': []
        }

    def _count_correct(self, preds, labels):
        """Number of pairs whose thresholded prediction matches the label."""
        # .float() also accepts boolean predictions (True -> 1.0).
        predicted_same = preds.detach().reshape(-1).float() > self.threshold
        actual_same = labels.reshape(-1) != 0
        return int((predicted_same == actual_same).sum().item())

    def _calc_accuracy(self, preds, label):
        """Fraction of correct predictions in one batch (0.0 for an empty batch)."""
        total = label.numel()
        return self._count_correct(preds, label) / total if total else 0.0

    def _update_metrics(self, loss, accuracy, dataset_type="train"):
        """Append one epoch's loss (tensor or float) and accuracy to the history."""
        self.metrics[f"{dataset_type}_loss"].append(float(loss))
        self.metrics[f"{dataset_type}_acc"].append(accuracy)

    def _run_epoch(self, dataset, train):
        """One pass over ``dataset``; returns sample-weighted (mean loss, accuracy).

        Parameters are updated only when ``train`` is true. Returns NaNs for an
        empty dataset instead of dividing by zero.
        """
        loss_sum, correct, seen = 0.0, 0, 0
        for img1, img2, labels in dataset:
            img1 = img1.to(self.device)
            img2 = img2.to(self.device)
            labels = labels.to(self.device)
            # NOTE(review): labels are passed only to keep the
            # model(img1, img2, labels) contract; the model must not use them.
            preds = self.model(img1, img2, labels)
            # Probability criteria such as BCELoss require float targets.
            loss = self.criterion(preds, labels.float())
            if train:
                # Clear the previous batch's gradients; backward() accumulates.
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
            batch_size = labels.numel()
            # Assumes the criterion averages over the batch (default 'mean').
            loss_sum += loss.item() * batch_size
            correct += self._count_correct(preds, labels)
            seen += batch_size
        if seen == 0:
            return float('nan'), float('nan')
        return loss_sum / seen, correct / seen

    def _train(self, dataset):
        """Run one training epoch and record its loss/accuracy under 'train'."""
        self.model.train()
        loss, acc = self._run_epoch(dataset, train=True)
        self._update_metrics(loss, acc, 'train')
        return loss, acc

    @torch.no_grad()
    def _valid(self, dataset):
        """Evaluate one epoch without gradients and record it under 'valid'."""
        self.model.eval()
        loss, acc = self._run_epoch(dataset, train=False)
        self._update_metrics(loss, acc, 'valid')
        return loss, acc

    def _plot_metrics(self, epochs=None):
        """Plot the per-epoch loss and accuracy curves side by side.

        ``epochs`` is kept for backward compatibility; the x axis comes from the
        recorded history, so repeated ``fit`` calls still plot consistently.
        """
        fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))
        for ax, key, title in ((ax_loss, 'loss', 'Loss'), (ax_acc, 'acc', 'Accuracy')):
            train_curve = self.metrics[f"train_{key}"]
            ax.plot(range(1, len(train_curve) + 1), train_curve, label=f'Train {title}')
            valid_curve = self.metrics[f"valid_{key}"]
            if valid_curve:
                ax.plot(range(1, len(valid_curve) + 1), valid_curve, label=f'Validation {title}')
            ax.set_xlabel('Epochs')
            ax.set_ylabel(title)
            ax.set_title(f'{title} Curves')
            ax.legend()

        plt.tight_layout()
        plt.show()

    def fit(self, epochs, train_set, valid_set=None, save_every=5):
        """Train for ``epochs`` epochs, validating each epoch if ``valid_set`` is given.

        Writes ``checkpoint_{epoch}`` whenever the 0-based ``epoch`` is a
        multiple of ``save_every`` (pass 0/None to disable), then plots the
        recorded curves.
        """
        for epoch in tqdm(range(epochs)):
            self._train(train_set)
            if save_every and epoch % save_every == 0:
                torch.save(self.model.state_dict(), f"checkpoint_{epoch}")
            if valid_set is not None:
                self._valid(valid_set)
        self._plot_metrics(epochs)

    @torch.no_grad()
    def predict(self, dataset, name="Test"):
        """Print and return the accuracy of the model over ``dataset``.

        Accuracy is correct pairs / total pairs, so a smaller final batch is
        weighted correctly (averaging per-batch accuracies would not be).
        Returns NaN for an empty dataset.
        """
        self.model.eval()
        correct, seen = 0, 0
        for img1, img2, labels in tqdm(dataset, desc=name):
            img1, img2, labels = img1.to(self.device), img2.to(self.device), labels.to(self.device)
            preds = self.model(img1, img2, labels)
            correct += self._count_correct(preds, labels)
            seen += labels.numel()

        total_acc = correct / seen if seen else float('nan')
        print(f"{name} Accuracy: {total_acc}")
        return total_acc
            

In [None]:
torch.manual_seed(24)
PATH = './model_state.pt'
EPOCHS = 10

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SiameseNetwork()
# BCELoss expects per-pair probabilities in [0, 1] and float targets.
criterion = nn.BCELoss()
# Resume from previously saved weights when present. map_location lets a
# checkpoint written on a GPU machine be loaded on a CPU-only one.
if exists(PATH):
    model.load_state_dict(torch.load(PATH, map_location=device))
optim = Adam(model.parameters())

In [None]:
# Build the trainer from the objects configured above and run the training loop.
trainer = Trainer(model=model, criterion=criterion, optim=optim, device=device)
trainer.fit(epochs=EPOCHS, train_set=training_set, valid_set=validation_set)

In [None]:
# Report accuracy on the training pairs, then on the held-out test pairs.
trainer.predict(dataset=training_set, name="Train")
trainer.predict(dataset=testing_set, name="Test")

In [None]:
# Persist the trained weights so the setup cell can resume from them.
torch.save(obj=model.state_dict(), f=PATH)