Import Libraries

In [None]:
%matplotlib inline
import torchvision.datasets as dataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import torchvision.utils
import numpy as np
import random
from PIL import Image
import torch
from torch.autograd import Variable
import PIL.ImageOps
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import Adam
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Resize, ToTensor


Configuration

In [None]:
class ConfigClass:
    """Central place for the dataset locations and training hyper-parameters."""

    # Hyper-parameters read by the training cells.
    train_num_epochs = 100
    train_batch_size = 64

    # Dataset root directories, relative to the working directory.
    test_dir = './data/faces/testing'
    train_dir = './data/faces/training'

Siamese Dataset

In [None]:
class SiameseDataset(Dataset):
    """Dataset that yields randomly sampled image pairs for a Siamese network.

    Each item is ``(img_1, img_2, label)`` where ``label`` is a float32 tensor
    of shape ``(1,)`` holding ``1.0`` when the two images belong to different
    classes and ``0.0`` when they share a class.

    Args:
        image_folder: Object exposing ``imgs``, a sequence of
            ``(path, class_index)`` entries (for example a torchvision
            ``ImageFolder``).
        transform: Optional callable applied to each image after loading.
        invert: When true, both images are inverted before ``transform``.
    """

    def __init__(self, image_folder, transform = None, invert = False):
        self.image_folder = image_folder
        self.transform = transform
        self.invert = invert
        # Counted once so __getitem__ can refuse to look for a different-class
        # partner that cannot exist (which would otherwise loop forever).
        self._num_classes = len({entry[1] for entry in image_folder.imgs})

    def _load_grayscale(self, path):
        """Open ``path`` and return a detached single-channel ("L") copy."""
        # The context manager closes the file handle as soon as the converted
        # copy has been made.
        with Image.open(path) as handle:
            return handle.convert("L")

    def __getitem__(self, index):
        """Return one random pair; ``index`` is ignored by design.

        Raises:
            ValueError: If a different-class pair is requested but the folder
                contains fewer than two classes.
        """
        samples = self.image_folder.imgs
        sample_1 = random.choice(samples)
        same_class = random.randint(0, 1)

        if not same_class and self._num_classes < 2:
            raise ValueError(
                "SiameseDataset needs at least two classes to draw a "
                "different-class pair, found {}".format(self._num_classes)
            )

        # Rejection sampling keeps the partner uniformly distributed over all
        # eligible images. The same-class case always terminates because
        # sample_1 itself is eligible.
        while True:
            sample_2 = random.choice(samples)
            if (sample_1[1] == sample_2[1]) == bool(same_class):
                break

        img_1 = self._load_grayscale(sample_1[0])
        img_2 = self._load_grayscale(sample_2[0])

        if self.invert:
            img_1 = PIL.ImageOps.invert(img_1)
            img_2 = PIL.ImageOps.invert(img_2)

        if self.transform:
            img_1 = self.transform(img_1)
            img_2 = self.transform(img_2)

        # 1.0 marks a dissimilar pair, 0.0 a similar one.
        label = torch.tensor([float(sample_1[1] != sample_2[1])], dtype=torch.float32)
        return img_1, img_2, label

    def __len__(self):
        """Return the number of images in the wrapped folder."""
        return len(self.image_folder.imgs)

Model Definition

In [None]:
class Model(nn.Module):
    """Two-input module that can restore matching weights from a checkpoint.

    NOTE(review): ``__init__`` registers no sub-modules, yet ``forward`` reads
    ``self.layer1`` and ``self.layer2``. As written the state dict is empty
    (so no checkpoint entry can match) and calling the module raises
    ``AttributeError``. The layer definitions appear to be missing — confirm.

    Args:
        weights_path: Path to a saved state dict, or ``None`` to skip loading.
    """

    def __init__(self, weights_path="Model_Finished.pth"):
        super(Model, self).__init__()

        # Load the weights from the saved model
        if weights_path is not None:
            # map_location="cpu" lets a checkpoint saved on an accelerator be
            # restored on a CPU-only machine; load_state_dict then copies the
            # values into this module's own parameters.
            # torch.load unpickles the file: only load checkpoints you trust.
            pretrained_dict = torch.load(weights_path, map_location="cpu")
            model_dict = self.state_dict()
            # Keep only the entries whose names exist in this module.
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            model_dict.update(pretrained_dict)
            self.load_state_dict(model_dict)

    def forward(self, img1, img2):
        """Run each input through its layer and concatenate along dim 1."""
        output1 = self.layer1(img1)
        output2 = self.layer2(img2)
        output = torch.cat((output1, output2), dim=1)
        return output

# Training data: wrap the class-per-directory image folder in the pair-sampling
# dataset, resizing every image to 100x100 and converting it to a tensor.
image_folder = ImageFolder(root=ConfigClass.train_dir)
siamese_dataset = SiameseDataset(
    image_folder,
    transform=Compose([Resize((100, 100)), ToTensor()]),
    invert=False,
)

In [None]:
def imshow(img, text=None):
    """Display a channels-first image tensor with matplotlib, axes hidden.

    Args:
        img: Tensor laid out as ``(C, H, W)``; it is transposed to
            ``(H, W, C)`` before being handed to ``plt.imshow``.
        text: Optional caption drawn at data coordinates ``(75, 8)``.
    """
    # detach()/cpu() make the conversion safe for tensors that require grad or
    # live on an accelerator; both are no-ops for a plain CPU tensor.
    np_img = img.detach().cpu().numpy()
    plt.axis("off")
    if text:
        plt.text(75, 8, text, style='italic', fontweight='bold', bbox={'facecolor':'white', 'alpha':0.8, 'pad':10})
    plt.imshow(np.transpose(np_img, (1, 2, 0)))
    plt.show()

def show_plot(iteration,loss):
    """Plot the recorded loss values against their iteration counts.

    Args:
        iteration: Sequence of x values (iteration counters).
        loss: Sequence of y values (loss recorded at each counter).
    """
    plt.plot(iteration, loss)
    # Label the axes so the figure can be read on its own.
    plt.xlabel("Iteration")
    plt.ylabel("Loss")
    plt.show()

In [None]:
# Draw one batch of 8 pairs and show every image of the batch in one grid:
# first images of the pairs followed by second images.
dataloader = DataLoader(siamese_dataset, batch_size=8, shuffle=True, num_workers=8)
data_iter = iter(dataloader)
vis_batch = next(data_iter)
merged = torch.cat(vis_batch[:2], dim=0)
imshow(torchvision.utils.make_grid(merged))
# Cell output: the pair labels of the displayed batch.
vis_batch[2].numpy()

Siamese Network

In [None]:
class SiameseNetwork(nn.Module):
    """Siamese network: one shared branch embeds each of the two inputs.

    The branch is three reflection-padded 3x3 convolution stages (each
    followed by ReLU and batch normalisation) and three fully connected
    layers ending in a 5-dimensional output.

    Every convolution stage owns its batch-norm layer. ``conv2`` and ``conv3``
    previously shared ``batch_norm2``, which tied their affine parameters and
    mixed their running statistics. The state dict therefore now contains
    ``batch_norm3.*`` entries.

    Args:
        input_size: ``(height, width)`` of the single-channel input images.
            Padding of 1 with 3x3 kernels preserves the spatial size, so the
            first linear layer takes ``8 * height * width`` features.
    """

    def __init__(self, input_size=(100, 100)):
        super(SiameseNetwork, self).__init__()
        height, width = input_size

        self.reflection_pad = nn.ReflectionPad2d(1)
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3)
        self.conv2 = nn.Conv2d(4, 8, kernel_size=3)
        self.conv3 = nn.Conv2d(8, 8, kernel_size=3)
        self.relu = nn.ReLU(inplace=True)
        self.batch_norm1 = nn.BatchNorm2d(4)
        self.batch_norm2 = nn.BatchNorm2d(8)
        self.batch_norm3 = nn.BatchNorm2d(8)
        self.fc1 = nn.Linear(8 * height * width, 500)
        self.fc2 = nn.Linear(500, 500)
        self.fc3 = nn.Linear(500, 5)

    def forward_one_branch(self, x):
        """Embed one batch of images of shape ``(N, 1, height, width)``."""
        x = self.batch_norm1(self.relu(self.conv1(self.reflection_pad(x))))
        x = self.batch_norm2(self.relu(self.conv2(self.reflection_pad(x))))
        x = self.batch_norm3(self.relu(self.conv3(self.reflection_pad(x))))
        # Flatten everything except the batch dimension for the linear layers.
        x = torch.flatten(x, start_dim=1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)

        return x

    def forward(self, input1, input2):
        """Return the embeddings of both inputs, computed by the same branch."""
        output1 = self.forward_one_branch(input1)
        output2 = self.forward_one_branch(input2)

        return output1, output2

Contrastive Loss

In [None]:
class ConstrastiveLoss(torch.nn.Module):
    """Contrastive loss over a batch of embedding pairs.

    With ``d`` the pairwise distance between the two embeddings and ``y`` the
    pair label, each pair contributes
    ``(1 - y) * d**2 + y * max(margin - d, 0)**2``; the mean over all
    elements is returned.

    Args:
        margin: Distance beyond which a pair labelled ``1`` adds no loss.
    """

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Return the scalar loss for embeddings ``output1``/``output2``."""
        distance = F.pairwise_distance(output1, output2, keepdim=True)

        # Pairs labelled 0 are pulled together.
        pull_term = (1 - label) * distance.pow(2)
        # Pairs labelled 1 are pushed apart until they clear the margin.
        shortfall = (self.margin - distance).clamp(min=0.0)
        push_term = label * shortfall.pow(2)

        return (pull_term + push_term).mean()

In [None]:
# Objects used by the training loop: data loader, network, loss and optimizer.
dataloader = DataLoader(
    siamese_dataset,
    batch_size=ConfigClass.train_batch_size,
    shuffle=True,
    num_workers=8,
)
model = SiameseNetwork()
criterion = ConstrastiveLoss()
optimizer = Adam(model.parameters(), lr=0.0005)

# Bookkeeping for the loss curve: x values, y values and the running counter.
iteration = 0
counter, loss_history = [], []

Training Optimization

In [None]:

# Rebuild the loader for training: loading happens in the main process and the
# batch size is half of the configured value.
dataloader = DataLoader(
    siamese_dataset,
    batch_size=ConfigClass.train_batch_size // 2,
    shuffle=True,
    num_workers=0,
)
cpu_device = torch.device('cpu')

for epoch in range(ConfigClass.train_num_epochs):
    for i, batch in enumerate(dataloader):
        optimizer.zero_grad()

        # Place the two image batches and the labels on the CPU.
        input1, input2, label = (tensor.to(cpu_device) for tensor in batch)

        output1, output2 = model(input1, input2)
        contrastive_loss = criterion(output1, output2, label)
        contrastive_loss.backward()
        optimizer.step()

        # Report and record the loss on every 10th batch of the epoch.
        if i % 10 == 9:
            loss_value = contrastive_loss.item()
            print("Epoch: {} \t Loss: {}".format(epoch, loss_value))
            iteration += 10
            loss_history.append(loss_value)
            counter.append(iteration)


show_plot(counter, loss_history)

Sample Test

In [None]:
# Pair dataset over the test images, preprocessed exactly like the training set.
test_folder = dataset.ImageFolder(root=ConfigClass.test_dir)
siamese_dataset = SiameseDataset(image_folder=test_folder,
                                 transform=transforms.Compose([transforms.Resize((100, 100)), transforms.ToTensor()]),
                                 invert=False)

dataloader = DataLoader(siamese_dataset, num_workers=6, batch_size=1, shuffle=True)
data_iter = iter(dataloader)
# Reference image: the first image of the first sampled pair.
img0, _, _ = next(data_iter)

# Inference only. eval() makes the batch-norm layers use their running
# statistics rather than the statistics of each single-image batch (and stops
# them from being updated); no_grad() skips autograd bookkeeping.
model.eval()
with torch.no_grad():
    for i in range(10):
        # Compare the reference image with the second image of a new pair.
        _, img1, _ = next(data_iter)
        merged = torch.cat((img0,img1), 0)

        output1, output2 = model(img0, img1)
        distance = F.pairwise_distance(output1, output2)
        imshow(torchvision.utils.make_grid(merged), 'Dissimilarity: {:.2f}'.format(distance.item()))

Performance Evaluation

In [None]:
# Files of the two images whose dissimilarity is computed further down.
image_path_0, image_path_1 = "/content/2a.png", "/content/gen2b.png"

In [None]:
class Model(nn.Module):
    """Two-input module that concatenates the per-input outputs along dim 1.

    NOTE(review): this definition replaces the ``Model`` class declared
    earlier in the notebook. ``__init__`` registers no sub-modules although
    ``forward`` reads ``self.layer1`` and ``self.layer2``, so calling the
    module raises ``AttributeError`` — confirm whether the layers were meant
    to be defined here.
    """

    def __init__(self):
        super().__init__()

    def forward(self, img1, img2):
        """Apply ``layer1`` to ``img1`` and ``layer2`` to ``img2``, then join."""
        features = (self.layer1(img1), self.layer2(img2))
        return torch.cat(features, dim=1)

In [None]:
import torch
import torchvision.transforms as transforms
from PIL import Image
import torchvision.utils as utils
import torch.nn.functional as F

# Preprocessing: resize to 100x100, reduce to one channel, convert to a tensor.
transform = transforms.Compose([
    transforms.Resize((100, 100)),
    transforms.Grayscale(),  # Convert to grayscale
    transforms.ToTensor(),
])

# Load and transform images
img0 = transform(Image.open(image_path_0).convert("RGB"))
img1 = transform(Image.open(image_path_1).convert("RGB"))

# Pass images through the model. This is inference on a batch of one image:
# eval() makes the batch-norm layers use their running statistics instead of
# per-batch statistics, and no_grad() skips autograd bookkeeping.
model.eval()
with torch.no_grad():
    output1, output2 = model(img0.unsqueeze(0), img1.unsqueeze(0))
    distance = F.pairwise_distance(output1, output2)

# Show images and distance.
# The tensors are (C, H, W), so dim 1 is the height axis: the two images are
# joined one above the other into a single image.
merged = torch.cat((img0, img1), 1)

grid = utils.make_grid(merged)


imshow(grid, 'Dissimilarity: {:.2f}'.format(distance.item()))