# Facial Similarity with a Siamese Network on the LFW Dataset

In [10]:
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torchvision import transforms

### DataLoader

In [12]:
class LFWDataset(Dataset):
    """Pairs of LFW face images labelled as same-person or different-person.

    The first line of the pairs file is skipped as a header. Every other
    non-blank line must be whitespace-separated and have one of two layouts:

    * ``name idx_a idx_b`` -- two images of one person, label ``1``
    * ``name_a idx_a name_b idx_b`` -- images of two people, label ``0``

    Image paths are built as ``<img_folder>/<name>/<name>_<idx:04d>.jpg``.

    Args:
        pairs_file: Path of the text file listing the pairs.
        img_folder: Root folder containing one sub-folder per person.
        transform: Optional callable applied to each loaded RGB image.

    Raises:
        ValueError: If a non-blank line has neither 3 nor 4 fields, or an
            image index is not an integer.
    """

    def __init__(self, pairs_file, img_folder, transform=None):
        self.pairs_file = pairs_file
        self.img_folder = img_folder
        self.transform = transform
        self.image_pairs, self.labels = self._load_pairs()

    def _image_path(self, person, number):
        """Return the path of image ``number`` belonging to ``person``."""
        return os.path.join(
            self.img_folder, person, f"{person}_{int(number):04d}.jpg"
        )

    def _load_pairs(self):
        """Parse the pairs file into ``(image_pairs, labels)``.

        Returns:
            A list of ``(path_1, path_2)`` tuples and a parallel list of
            integer labels (1 = same person, 0 = different people).
        """
        image_pairs = []
        labels = []

        with open(self.pairs_file, 'r') as f:
            next(f, None)  # the first line is a header, not a pair
            for line_number, line in enumerate(f, start=2):
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue

                if len(fields) == 3:
                    person, first, second = fields
                    img_1 = self._image_path(person, first)
                    img_2 = self._image_path(person, second)
                    label = 1
                elif len(fields) == 4:
                    person_1, first, person_2, second = fields
                    img_1 = self._image_path(person_1, first)
                    img_2 = self._image_path(person_2, second)
                    label = 0
                else:
                    raise ValueError(
                        f"{self.pairs_file}:{line_number}: expected 3 or 4 "
                        f"fields, got {len(fields)}: {line.strip()!r}"
                    )

                image_pairs.append((img_1, img_2))
                labels.append(label)

        return image_pairs, labels

    @staticmethod
    def _load_image(path):
        """Read ``path`` as an RGB image and release the file handle."""
        # convert() returns an independent, fully loaded copy, so the
        # underlying file can be closed as soon as it has been produced.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, index):
        """Return ``(img_1, img_2, label)`` for the pair at ``index``.

        ``label`` is a float32 tensor of shape ``(1,)``. Raises
        ``IndexError`` when ``index`` is out of range.
        """
        # Look the pair up by position (the list is not callable).
        path_img_1, path_img_2 = self.image_pairs[index]
        label = self.labels[index]

        img_1 = self._load_image(path_img_1)
        img_2 = self._load_image(path_img_2)

        if self.transform is not None:
            img_1 = self.transform(img_1)
            img_2 = self.transform(img_2)

        return img_1, img_2, torch.tensor([label], dtype=torch.float32)

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.image_pairs)

# Preprocessing applied to both images of every pair: resize to 128x128
# (the spatial size the SiameseNetwork's first linear layer is sized for)
# and convert to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ]
)

# Training pairs read from the LFW development split.
lfw_dataset = LFWDataset(
    pairs_file='/home/mehran/Documents/Datasets/pairsDevTrain.txt',
    img_folder='/home/mehran/Documents/Datasets/lfw-deepfunneled',
    transform=transform,
)

# Shuffled mini-batches of 32 pairs.
train_loader = torch.utils.data.DataLoader(
    dataset=lfw_dataset,
    batch_size=32,
    shuffle=True,
)


### Siamese Network Architecture

In [7]:
class SiameseNetwork(nn.Module):
    """Twin-branch CNN that embeds two images with one shared set of weights.

    The defaults reproduce the original fixed architecture (3-channel
    128x128 input, 128-dimensional embedding), including the layer order
    inside ``cnn`` and ``fc``, so existing state dicts still load.

    Args:
        in_channels: Number of channels of the input images.
        image_size: Spatial size of the input, either one int for square
            images or a ``(height, width)`` pair.
        embedding_dim: Size of the embedding produced for each image.

    Raises:
        ValueError: If ``image_size`` is too small to survive the two 2x2
            poolings (either side smaller than 4).
    """

    def __init__(self, in_channels=3, image_size=128, embedding_dim=128):
        super().__init__()

        if isinstance(image_size, int):
            height = width = image_size
        else:
            height, width = image_size

        # The 3x3 convolutions use padding=1 and stride=1, so they keep the
        # spatial size; each MaxPool2d(2) halves it, rounding down.
        feat_height, feat_width = height // 4, width // 4
        if feat_height < 1 or feat_width < 1:
            raise ValueError(
                f"image_size {image_size!r} is too small: each side must be "
                "at least 4 pixels"
            )

        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten()
        )

        self.fc = nn.Sequential(
            nn.Linear(128 * feat_height * feat_width, 512),
            nn.ReLU(),
            nn.Linear(512, embedding_dim)
        )

    def forward_once(self, x):
        """Embed one batch of images: convolutional stack, then the MLP."""
        x = self.cnn(x)
        x = self.fc(x)
        return x

    def forward(self, input_1, input_2):
        """Embed both inputs with the same weights.

        Returns:
            A tuple ``(output_1, output_2)`` with the embeddings of
            ``input_1`` and ``input_2`` respectively.
        """
        output_1 = self.forward_once(input_1)
        output_2 = self.forward_once(input_2)
        return output_1, output_2
    

# Instantiate the Siamese network. forward() passes both inputs through the
# same layers, so this single instance serves both branches of every pair.
model = SiameseNetwork()

### Loss Function (Contrastive Loss)

In [9]:
class ContrastiveLoss(nn.Module):
    """Contrastive loss over pairs of embeddings.

    Labels follow the convention produced by ``LFWDataset`` in this file:
    ``1`` marks a same-person pair and ``0`` a different-person pair. With
    ``d`` the Euclidean distance between the two embeddings of a pair::

        loss = mean(label * d**2 + (1 - label) * max(margin - d, 0)**2)

    so matching pairs are pulled together and non-matching pairs are pushed
    apart until they are at least ``margin`` away from each other.

    Args:
        margin: Distance beyond which a different-person pair no longer
            contributes to the loss.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, output_1, output_2, label):
        """Return the mean contrastive loss of a batch as a scalar tensor.

        Args:
            output_1: Embeddings of the first images, shape ``(B, D)``.
            output_2: Embeddings of the second images, shape ``(B, D)``.
            label: One label per pair, shape ``(B,)`` or ``(B, 1)``.
        """
        euclidean_distance = nn.functional.pairwise_distance(output_1, output_2)

        # pairwise_distance returns shape (B,). A (B, 1) label would
        # broadcast against it into a (B, B) matrix and pair every label with
        # every distance, so align the label with the distances first.
        label = label.reshape(euclidean_distance.shape).to(euclidean_distance.dtype)

        same_term = label * torch.pow(euclidean_distance, 2)
        different_term = (1 - label) * torch.pow(
            torch.clamp(self.margin - euclidean_distance, min=0.0), 2
        )
        loss_contrastive = torch.mean(same_term + different_term)
        return loss_contrastive
    

# Loss instance built with the constructor's default margin (1.0).
criterion = ContrastiveLoss()