In [1]:
import os
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from facenet_pytorch import MTCNN
from PIL import Image
from torch.utils.data import DataLoader, Dataset

In [2]:
# Preprocessing applied to every image: resize to 64x64, reduce to a single
# grayscale channel, then convert the PIL image to a tensor.
transform = transforms.Compose(
    [transforms.Resize((64, 64)), transforms.Grayscale(), transforms.ToTensor()]
)

In [3]:
# Root of the ImageFolder tree (one sub-directory per class).
# Set the FACES_DIR environment variable to point elsewhere instead of editing
# the notebook; the default is the original author's local path.
DATA_DIR = os.environ.get("FACES_DIR", "D:\\si\\Faces for Training\\Faces for Training")
data = datasets.ImageFolder(DATA_DIR, transform=transform)

In [4]:
# Use the GPU when one is visible, otherwise the CPU. Passing the chosen name
# to torch.device keeps `device` a torch.device in both cases (previously the
# CPU branch produced a bare string).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
class siamese(nn.Module):
    """Siamese embedding network.

    Both images of a pair are passed through the same ``self.model`` trunk, so
    the two branches share weights. The first Conv2d takes 1 input channel and
    the last Linear emits 128 features, so each image becomes a 128-d vector.
    """
    def __init__(self):
        super().__init__()
        self.model=nn.Sequential(
            nn.Conv2d(1,96,(11,11)),
            nn.ReLU(),
            nn.MaxPool2d((2,2)),
            nn.Dropout(0.3),
            nn.Conv2d(96,256,(5,5)),
            nn.ReLU(),
            nn.MaxPool2d((2,2)),
            nn.Dropout(0.3),
            nn.Conv2d(256,384,(3,3)),
            nn.ReLU(),
            nn.MaxPool2d((2,2)),
            nn.Dropout(0.3),
            nn.AdaptiveAvgPool2d((1,1)),
            # (B, 384, 1, 1) -> (B, 384). Without this, the Linear below acts on
            # the trailing size-1 axis and raises a shape-mismatch error.
            nn.Flatten(),
            # NOTE(review): there is no nonlinearity between these two Linear
            # layers, so together they form a single affine map; insert
            # nn.ReLU() here if that was not intended.
            nn.Linear(384,1024),
            nn.Linear(1024,128)
        )

    def run(self,x):
        """Embed one batch through the shared trunk.

        Args:
            x: image batch with 1 channel, presumably (B, 1, H, W); the
                notebook's transform resizes to 64x64. Much smaller inputs may
                shrink to zero size inside the conv/pool stack.

        Returns:
            Tensor of shape (B, 128).
        """
        return self.model(x)

    def forward(self,x1,x2):
        """Embed both members of a pair with the same weights.

        Returns:
            Tuple ``(o1, o2)`` of embeddings for ``x1`` and ``x2``.
        """
        # Must go through `self`; a bare `run` is not in scope here.
        o1=self.run(x1)
        o2=self.run(x2)
        return o1,o2
        

In [6]:
# Standalone instance of the embedding network.
# NOTE(review): the training cell builds and trains its own `net`; this
# `model` is never trained in the visible cells — confirm it is still needed.
model: nn.Module = siamese()

In [7]:

class SiameseNetworkDataset(Dataset):
    """Dataset yielding random image pairs for contrastive training.

    Each item is ``(img0, img1, label)``. ``label`` is a float32 tensor of
    shape (1,): 0.0 when both images share a class, 1.0 when they differ.
    Pairs are drawn at random, so ``index`` is ignored. ``random.randint(0, 1)``
    makes roughly half of the pairs same-class.

    Args:
        imageFolderDataset: object exposing ``imgs``, a list of
            ``(path, class_index)`` tuples (e.g. torchvision's ImageFolder).
        transform: optional callable applied to each grayscale PIL image.

    Raises:
        ValueError: if ``imgs`` holds fewer than two distinct classes. A
            different-class partner could then never be found, and sampling
            would loop forever.
    """
    def __init__(self,imageFolderDataset,transform=None):
        super().__init__()
        self.imageFolderDataset = imageFolderDataset
        self.transform = transform
        n_classes = len({label for _, label in imageFolderDataset.imgs})
        if n_classes < 2:
            raise ValueError(
                "SiameseNetworkDataset needs at least two classes to form "
                f"different-class pairs, got {n_classes}"
            )

    @staticmethod
    def _load_gray(path):
        """Open ``path``, convert it to single-channel "L", and close the file."""
        # `convert` loads the pixels into a new image, so the result stays
        # valid after the file handle is closed by the `with` block.
        with Image.open(path) as img:
            return img.convert("L")

    def __getitem__(self,index):
        imgs = self.imageFolderDataset.imgs
        img0_tuple = random.choice(imgs)
        should_get_same_class = bool(random.randint(0,1))

        # Rejection-sample a partner until its class matches (or differs, as
        # requested). This always terminates:
        # - same-class: img0 itself is a valid candidate;
        # - different-class: __init__ guarantees a second class exists.
        while True:
            img1_tuple = random.choice(imgs)
            if (img0_tuple[1] == img1_tuple[1]) == should_get_same_class:
                break

        img0 = self._load_gray(img0_tuple[0])
        img1 = self._load_gray(img1_tuple[0])

        if self.transform is not None:
            img0 = self.transform(img0)
            img1 = self.transform(img1)

        # 1.0 = different classes, 0.0 = same class.
        label = torch.tensor([float(img1_tuple[1] != img0_tuple[1])], dtype=torch.float32)
        return img0, img1, label

    def __len__(self):
        return len(self.imageFolderDataset.imgs)

In [8]:
# Wrap the ImageFolder in the random-pair dataset, reusing the same preprocessing.
sdata = SiameseNetworkDataset(imageFolderDataset=data, transform=transform)

In [9]:
# Worker processes started with "spawn" (the only method on Windows) re-import
# __main__ and cannot find classes defined in a notebook cell, such as
# SiameseNetworkDataset. Use background workers only when the start method is
# "fork"; otherwise load batches in the main process.
NUM_WORKERS = 2 if torch.multiprocessing.get_start_method() == "fork" else 0
BATCH_SIZE = 8
data_loader = DataLoader(sdata, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

In [10]:
class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss between pairs of embeddings.

    For each pair with Euclidean distance ``d``:
      * label 0 (similar pair):    loss = d**2
      * label 1 (dissimilar pair): loss = max(margin - d, 0)**2
    The per-pair losses are averaged.

    Args:
        margin: distance beyond which dissimilar pairs are no longer penalised.
    """
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Return the mean contrastive loss.

        Args:
            output1, output2: embedding batches of matching shape, presumably (B, D).
            label: 0/1 float tensor. It must broadcast against the (B, 1)
                distance tensor, so (B, 1) is the natural shape.
        """
        # keepdim=True yields (B, 1), which lines up with (B, 1) labels instead
        # of broadcasting to (B, B).
        euclidean_distance = F.pairwise_distance(output1, output2, keepdim=True)

        similar_term = (1 - label) * torch.pow(euclidean_distance, 2)
        dissimilar_term = label * torch.pow(
            torch.clamp(self.margin - euclidean_distance, min=0.0), 2
        )
        return torch.mean(similar_term + dissimilar_term)

In [11]:
# Model, loss and optimizer. `.to(device)` replaces a hard-coded `.cuda()`,
# which raised on machines without a CUDA GPU.
LEARNING_RATE = 0.0005
net = siamese().to(device)
criterion = ContrastiveLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)

In [None]:
# Train for NUM_EPOCHS passes over random pairs, recording the loss every
# LOG_EVERY batches and plotting the curve at the end.
NUM_EPOCHS = 100
LOG_EVERY = 10

counter = []        # global batch index at each logged point (x-axis of the plot)
loss_history = []   # loss of the logged batch
iteration_number = 0

net.train()
for epoch in range(NUM_EPOCHS):
    for i, (img0, img1, label) in enumerate(data_loader):
        # Follow `device` rather than hard-coding .cuda(), so CPU-only runs work.
        img0, img1, label = img0.to(device), img1.to(device), label.to(device)

        optimizer.zero_grad()
        output1, output2 = net(img0, img1)
        loss_contrastive = criterion(output1, output2, label)
        loss_contrastive.backward()
        optimizer.step()

        if i % LOG_EVERY == 0:
            loss_value = loss_contrastive.item()
            print(f"Epoch number {epoch}\n Current loss {loss_value}\n")
            # True global step. The former `+= 10` drifted whenever an epoch's
            # batch count was not a multiple of 10.
            iteration_number = epoch * len(data_loader) + i
            counter.append(iteration_number)
            loss_history.append(loss_value)

# `show_plot` is not defined in any visible cell; plot directly instead.
fig, ax = plt.subplots()
ax.plot(counter, loss_history)
ax.set(title="Contrastive loss during training", xlabel="Batch", ylabel="Loss")
plt.show()