In [15]:
import torch
from torch_snippets import *
device='cuda' if torch.cuda.is_available() else 'cpu'
import random
import numpy as np
import glob

# define the dataset class

In [16]:
class siamese_network (Dataset):
    def __init__ (self,folder,transform=None,should_invert=True):
        self.folder=folder
        self.transform=transform
        self.items=Glob(f'{self.folder}/*/*')
        
    def __getitem__(self,ix):
        itemA=self.items[ix]
        person=str(self.items[ix]).split('\\')[-2]
        same_person = randint(2)
        if same_person:
            rando=random.randint(1,10)
            itemB=os.path.join(self.folder,person,'{}.png'.format(str(rando)))
        else: 
            while True:
                itemB=choose(self.items)
                if person != (str(itemB).split('\\')[-2]):
                    break
                    
        imgA=read(itemA)
        imgB=read(itemB)
        if self.transform:
            imgA=self.transform(imgA)
            imgB=self.transform(imgB)
            
        return imgA, imgB, np.array([1-same_person])
    
    def __len__(self):
        return len(self.items)
    
    
        
        

# define transformations needed in training and validation data

In [18]:
from torchvision import transforms

trn_tfms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomAffine(5, (0.01,0.2),
                            scale=(0.9,1.1)),
    transforms.Resize((100,100)),
    transforms.ToTensor(),
    transforms.Normalize((0.5), (0.5))
])

val_tfms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((100,100)),
    transforms.ToTensor(),
    transforms.Normalize((0.5), (0.5))
])

In [19]:
trn_ds = siamese_network (folder=r"C:\Users\mosta\Desktop\png faces\all_images\training", transform=trn_tfms)
val_ds = siamese_network (folder=r"C:\Users\mosta\Desktop\png faces\all_images\testing", transform=val_tfms)

2021-12-10 23:34:02.019 | INFO     | torch_snippets.paths:inner:24 - 370 files found at C:\Users\mosta\Desktop\png faces\all_images\training/*/*
2021-12-10 23:34:02.022 | INFO     | torch_snippets.paths:inner:24 - 30 files found at C:\Users\mosta\Desktop\png faces\all_images\testing/*/*


# make dataloader for both datasets with batch size=64

In [21]:
trn_dl = DataLoader(trn_ds, shuffle=True, batch_size=64)
val_dl = DataLoader(val_ds, shuffle=False, batch_size=64)

# define convolution block

In [23]:

def convBlock(ni, no):
    return nn.Sequential(
        nn.Dropout(0.2),
        nn.Conv2d(ni, no, kernel_size=3, padding=1, padding_mode='reflect'),
        nn.ReLU(inplace=True),
        nn.BatchNorm2d(no),
    )



# define architecture of siamese network 

In [24]:
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.features = nn.Sequential(
            convBlock(1,4),
            convBlock(4,12),
            convBlock(12,16),
            nn.Flatten(),
            nn.Linear(16*100*100,500), nn.ReLU(inplace=True),
            nn.Linear(500,500), nn.ReLU(inplace=True),
            nn.Linear(500,500), nn.ReLU(inplace=True),
            nn.Linear(500,5)
        )

    def forward(self, input1, input2):
        output1 = self.features(input1)
        output2 = self.features(input2)
        return output1, output2

# define contrastive loss function

In [25]:
class ContrastiveLoss(torch.nn.Module):

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2,keepdim=True)
        pos = (1-label) * torch.pow(euclidean_distance, 2)
        neg = (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        loss_contrastive = torch.mean( pos + neg )
        acc = ((euclidean_distance>0.6)==label).float().mean()
        return loss_contrastive,acc

# define training batch function to get loss and accuracy at each batch 

In [26]:
def train_batch(model, data, optimizer, criterion):
    imgsA, imgsB, labels = [t.to(device) for t in data]
    optimizer.zero_grad()
    codesA, codesB = model(imgsA, imgsB)
    loss,acc = criterion(codesA, codesB, labels)
    loss.backward()
    optimizer.step()
    return loss.item(),acc.item()


@torch.no_grad()
def validate_batch(model, data, criterion):
    imgsA, imgsB, labels = [t.to(device) for t in data]
    codesA, codesB = model(imgsA, imgsB)
    loss,acc = criterion(codesA, codesB, labels)
    return loss.item(),acc.item()

In [27]:
model = SiameseNetwork().to(device)
criterion = ContrastiveLoss()
optimizer = optim.Adam(model.parameters(),lr = 0.001)

# train the model

In [28]:


n_epochs = 100
log = Report(n_epochs)
for epoch in range(n_epochs):
    N = len(trn_dl)
    for i, data in enumerate(trn_dl):
        loss,acc= train_batch(model, data, optimizer, criterion)
        log.record(epoch+(1+i)/N, trn_loss=loss, trn_acc=acc, end='\r')
     
    N = len(val_dl)
    for i, data in enumerate(val_dl):
        loss,acc= validate_batch(model, data, criterion) 
        log.record(epoch+(1+i)/N, val_loss=loss, val_acc=acc, end='\r')
    if (epoch+1)%20==0: log.report_avgs(epoch+1)
    if epoch==10: optimizer = optim.Adam(model.parameters(), lr=0.0005)    

EPOCH: 20.000	trn_loss: 0.469	trn_acc: 0.695	val_loss: 0.245	val_acc: 0.867	(25.08s - 100.34s remaining)
EPOCH: 40.000	trn_loss: 0.352	trn_acc: 0.776	val_loss: 0.150	val_acc: 0.900	(50.74s - 76.12s remaining)
EPOCH: 60.000	trn_loss: 0.281	trn_acc: 0.867	val_loss: 0.097	val_acc: 1.000	(76.46s - 50.97s remaining)
EPOCH: 80.000	trn_loss: 0.258	trn_acc: 0.864	val_loss: 0.109	val_acc: 0.900	(101.52s - 25.38s remaining)
EPOCH: 100.000	trn_loss: 0.196	trn_acc: 0.931	val_loss: 0.057	val_acc: 0.967	(127.03s - 0.00s remaining)
