In [55]:
# Install the third-party packages this notebook imports.
# Only albumentations is pinned (1.0.3); the other three packages resolve to
# whatever version pip selects (or finds already installed) at run time.
# NOTE(review): consider pinning pytorch-lightning, transformers and timm too so
# a fresh run reproduces the environment this notebook was written against.
!pip install pytorch-lightning
!pip install transformers
!pip install timm
!pip install -qq albumentations==1.0.3

In [56]:
import cv2
import os
import random
import pandas as pd
import numpy as np
import timm
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning import Callback
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import albumentations
from albumentations.pytorch.transforms import ToTensorV2
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, classification_report

In [57]:
class CFG:
    # Central configuration for the notebook. Every setting is a class-level
    # constant read elsewhere as CFG.<name>; the class is never instantiated.

    # --- model / paths ---
    model_name = 'tf_efficientnet_b0'  # timm architecture name handed to timm.create_model
    train_path = '../input/face-images/faces/training'  # root folder used to build train/val pairs
    test_path = '../input/face-images/faces/testing'  # root folder used to build test pairs
    save_dir = './'  # base directory given to the CSV logger

    # --- split / optimisation ---
    val_size = 0.2  # fraction of the pairs held out for validation (train_test_split test_size)
    seed = 42  # random_state of the train/validation split
    batch_size = 16  # batch size of every DataLoader
    lr = 1e-4  # learning rate of the Adam optimizer
    monitor = 'val_loss'  # metric watched by ModelCheckpoint and EarlyStopping
    patience = 3  # EarlyStopping patience
    epochs = 4  # Trainer max_epochs
    accumulate = 1  # Trainer accumulate_grad_batches
    loss_margin = 2.0  # margin of the contrastive SimilarityLoss
    embed_dims = 64  # output size of the replacement classifier layer (embedding length)

    # --- pair sampling: values passed as s_factor / ds_factor to create_dataset ---
    similarity_factor = 3  # label-1 (same folder) pairs requested per training image
    dissimilarity_factor = 3  # label-0 (other folder) pairs requested per training image
    test_similarity_factor = 2  # same, for the test set
    test_dissimilarity_factor = 2  # same, for the test set

In [58]:
class SiameseDataset(Dataset):
  """Dataset of image pairs for Siamese training / evaluation.

  Args:
    image_df: DataFrame with columns ``image1`` and ``image2`` (image file
      paths) and ``label`` (1 for pairs built from the same folder, 0 for
      pairs built from different folders).
    transform: optional albumentations-style callable, invoked as
      ``transform(image=array)['image']``. When ``None`` the raw RGB arrays
      are returned.
  """

  def __init__(self, image_df, transform=None):
    self.image_df = image_df
    self.transform = transform

  def __len__(self):
    """Number of pairs, i.e. number of rows in ``image_df``."""
    return self.image_df.shape[0]

  @staticmethod
  def _read_rgb(image_path):
    """Load ``image_path`` as an RGB array, failing loudly if it cannot be read."""
    image = cv2.imread(image_path)
    if image is None:
      # cv2.imread does not raise on a missing/corrupt file; it returns None,
      # which would otherwise surface as an obscure error inside the transform.
      raise FileNotFoundError('Could not read image: {}'.format(image_path))
    # OpenCV decodes to BGR, while the normalisation statistics used by the
    # transforms and the pretrained backbone assume RGB channel order.
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

  def __getitem__(self, index):
    """Return ``(image1, image2, label)`` for the pair at position ``index``.

    Raises:
      FileNotFoundError: if either image of the pair cannot be read.
    """
    image1 = self._read_rgb(self.image_df.image1.iloc[index])
    image2 = self._read_rgb(self.image_df.image2.iloc[index])
    if self.transform is not None:
      image1 = self.transform(image=image1)['image']
      image2 = self.transform(image=image2)['image']
    label = self.image_df.label.iloc[index]
    return (image1, image2, label)

In [59]:
class SiameseNetwork(nn.Module):
  """Twin network: one shared timm backbone embeds both images of a pair.

  The backbone's ``classifier`` layer is swapped for a linear layer with
  ``CFG.embed_dims`` outputs, so the model emits embeddings rather than
  class logits.
  """

  def __init__(self, model_name):
    super().__init__()
    backbone = timm.create_model(model_name, pretrained=True)
    # Keep the classifier's input width, change only its output width.
    backbone.classifier = nn.Linear(backbone.classifier.in_features, CFG.embed_dims)
    self.embed_model = backbone

  def forward(self, image1, image2):
    """Embed both inputs with the same weights; returns ``(emb1, emb2)``."""
    return self.embed_model(image1), self.embed_model(image2)

class SiameseNetworkDirver(pl.LightningModule):
    """Lightning wrapper that trains / evaluates a Siamese model.

    Args:
        model: module called as ``model(image1, image2)`` and returning the
            two embeddings.
        criterion: callable ``criterion(emb1, emb2, labels)`` returning a
            scalar loss.
        lr: learning rate for the Adam optimizer.
    """

    def __init__(self, model, criterion, lr):
        super().__init__()
        self.model = model
        self.criterion = criterion
        # Keep the rate that was passed in instead of silently reading the
        # global config inside configure_optimizers.
        self.lr = lr

    def forward(self, images1, images2):
        """Embed both image batches with the wrapped model."""
        # Previously referenced undefined names (image1/image2) -> NameError.
        return self.model(images1, images2)

    def configure_optimizers(self):
        """Create the Adam optimizer; kept on ``self`` so the steps can log its lr."""
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        return self.optimizer

    def _pair_loss(self, batch):
        """Shared by the train and validation steps: embed a batch of pairs and score it."""
        image1, image2, labels = batch[0], batch[1], batch[2]
        img1_embeds, img2_embeds = self.model(image1, image2)
        return self.criterion(img1_embeds, img2_embeds, labels)

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log it (epoch-aggregated) with the current lr."""
        loss = self._pair_loss(batch)
        logs = { 'train_loss': loss, 'lr': self.optimizer.param_groups[0]['lr'] }
        self.log_dict(logs, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and log it (epoch-aggregated) as ``val_loss``."""
        loss = self._pair_loss(batch)
        logs = { 'val_loss': loss }
        self.log_dict(logs, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def predict_step(self, batch, batch_idx):
        """Return the two embedding batches for a batch of pairs (labels ignored)."""
        image1, image2 = batch[0], batch[1]
        img1_embeds, img2_embeds = self.model(image1, image2)
        return img1_embeds, img2_embeds

In [60]:
def train_transform_object(DIM = 384):
    """Build the training preprocessing pipeline.

    Steps, in order: resize to ``DIM`` x ``DIM``, normalise with the given
    per-channel mean/std, convert to a tensor.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        albumentations.Resize(DIM, DIM),
        albumentations.Normalize(mean=channel_mean, std=channel_std),
        ToTensorV2(p=1.0),
    ]
    return albumentations.Compose(steps)

def valid_transform_object(DIM = 384):
    """Build the validation preprocessing pipeline.

    Resizes to ``DIM`` x ``DIM``, normalises each channel with the fixed
    mean/std below and converts the result to a tensor.
    """
    resize = albumentations.Resize(DIM, DIM)
    normalize = albumentations.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    to_tensor = ToTensorV2(p=1.0)
    return albumentations.Compose([resize, normalize, to_tensor])

In [61]:
class SimilarityLoss(nn.Module):
  """Contrastive loss over pairs of embeddings.

  For a pair at Euclidean distance ``d`` with label ``y`` (1 = similar,
  0 = dissimilar) the per-pair loss is
  ``y * d**2 + (1 - y) * max(margin - d, 0)**2``; the batch mean is returned.

  Args:
    margin: distance beyond which dissimilar pairs stop contributing.
  """

  def __init__(self, margin=2.0):
    super().__init__()
    self.margin = margin

  def forward(self, img1_embeds, img2_embeds, label):
    """Return the mean contrastive loss as a scalar tensor.

    Args:
      img1_embeds, img2_embeds: embedding batches of identical shape (B, D).
      label: B labels, shaped (B,) or (B, 1).
    """
    # One distance per pair, shape (B,). The earlier keepdim=True produced
    # (B, 1), which broadcast against (B,) labels into a (B, B) matrix and
    # mixed every label with every pair's distance.
    distance = F.pairwise_distance(img1_embeds, img2_embeds)
    # Flatten so (B,) and (B, 1) labels both line up element-wise.
    label = label.reshape(-1).to(distance.dtype)
    similar_term = label * distance.pow(2)
    dissimilar_term = (1 - label) * torch.clamp(self.margin - distance, min=0.0).pow(2)
    return torch.mean(similar_term + dissimilar_term)

In [62]:
def append_to_data(path, image, imgs, label, data):
    """Append one ``[anchor_path, other_path, label]`` row per entry of ``imgs``.

    ``image`` and every entry of ``imgs`` are joined onto ``path``. ``data`` is
    extended in place and also returned. Rows are not de-duplicated, so
    mirrored pairs (a, b) and (b, a) may both be present.
    """
    anchor_path = os.path.join(path, image)
    data.extend([anchor_path, os.path.join(path, other), label] for other in imgs)
    return data

def create_dataset(path, s_factor, ds_factor):
    """Build labelled image pairs from a folder-per-identity directory tree.

    Every sub-directory of ``path`` is treated as one identity and its
    ``.png`` files as that identity's images. For each image the function
    emits

    * up to ``s_factor`` pairs with other images of the same folder (label 1),
    * up to ``ds_factor`` pairs with one random image from each of that many
      distinct other folders (label 0).

    "Up to": when a folder has too few images, or there are too few other
    folders, as many pairs as possible are produced instead of raising
    ``ValueError`` from ``random.sample``.

    Args:
        path: root directory that contains one sub-directory per identity.
        s_factor: similar (label 1) pairs requested per image.
        ds_factor: dissimilar (label 0) pairs requested per image.

    Returns:
        list of ``[image1_path, image2_path, label]`` rows.
    """
    # Sorted listings keep the sampling reproducible for a given random seed;
    # os.listdir itself returns entries in arbitrary order.
    # Only directories are identities, which also skips stray files (README, ...).
    folders = sorted(
        entry for entry in os.listdir(path)
        if entry != 'README' and os.path.isdir(os.path.join(path, entry))
    )

    file_names = {
        folder: sorted(
            os.path.join(folder, file)
            for file in os.listdir(os.path.join(path, folder))
            if file.endswith('.png')
        )
        for folder in folders
    }

    data = []
    for folder in tqdm(folders):
        images = file_names[folder]
        # Other identities that actually have an image to draw from.
        others = [f for f in folders if f != folder and file_names[f]]
        n_similar = max(min(s_factor, len(images) - 1), 0)
        n_dissimilar = min(ds_factor, len(others))

        for image in images:
            # Draw one extra candidate so the anchor can be dropped if it was
            # picked; otherwise discard the surplus one. The anchor is thus
            # never paired with itself.
            candidates = random.sample(images, n_similar + 1)
            if image in candidates:
                candidates.remove(image)
            else:
                candidates = candidates[:-1]
            data = append_to_data(path, image, candidates, 1, data)

            # One random image from each of n_dissimilar distinct other folders.
            sel_fs = random.sample(others, n_dissimilar)
            negatives = [random.sample(file_names[f], 1)[0] for f in sel_fs]
            data = append_to_data(path, image, negatives, 0, data)

    return data


In [63]:
# Seed Python, NumPy and PyTorch before any stochastic step (pair sampling here,
# weight initialisation and batch shuffling further down) so that a
# Restart & Run All produces comparable results.
pl.seed_everything(CFG.seed)

# Sample similar / dissimilar training pairs and tabulate them.
data = create_dataset(CFG.train_path, CFG.similarity_factor, CFG.dissimilarity_factor)
dataframe = pd.DataFrame(data, columns=['image1', 'image2', 'label'])
print(dataframe.shape)
dataframe.head()

In [64]:
# Class balance of the sampled pairs (1 = same folder, 0 = different folders).
dataframe['label'].value_counts()

In [65]:
# Stratified train/validation split performed directly on the pair table.
# The label column stays inside `dataframe`, so this cell can be re-run safely
# (the previous in-place drop removed the column and broke a second run) and no
# column has to be assigned onto the split results afterwards.
labels = dataframe['label'].values
train_data, val_data = train_test_split(dataframe,
                                        stratify=labels,
                                        test_size=CFG.val_size,
                                        random_state=CFG.seed)
train_labels, val_labels = train_data['label'].values, val_data['label'].values

# `transform` is kept as a notebook-level name because the test cell reuses it.
transform = train_transform_object(256)
train_dataset = SiameseDataset(train_data, transform)
# Validation goes through its own pipeline so that adding augmentation to the
# training transform later cannot leak into evaluation.
val_dataset = SiameseDataset(val_data, valid_transform_object(256))

train_dataloader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=CFG.batch_size, shuffle=False)

In [66]:
logger = CSVLogger(save_dir=CFG.save_dir, name=CFG.model_name+'_logs')
# Log only the user-defined settings: vars(CFG) also carries class machinery
# (__module__, __dict__, __weakref__, __doc__) that are not hyperparameters.
logger.log_hyperparams({key: value for key, value in vars(CFG).items()
                        if not key.startswith('__')})

# Keep the best checkpoint by validation loss plus the last one.
# The filename may only reference metrics that are actually logged; the model
# logs `val_loss` (not `valid_loss` / `valid_acc`), so the old template could
# never show real values.
checkpoint_callback = ModelCheckpoint(monitor=CFG.monitor,
                                      save_top_k=1,
                                      save_last=True,
                                      save_weights_only=True,
                                      filename='{epoch:02d}-{val_loss:.4f}',
                                      verbose=False,
                                      mode='min')

# Stop early once the monitored metric has not improved for CFG.patience epochs.
early_stop_callback = EarlyStopping(monitor=CFG.monitor,
                                    patience=CFG.patience,
                                    verbose=False,
                                    mode="min")

# `gpus=[0]` and `weights_summary='top'` were removed from Trainer in later
# Lightning releases; accelerator/devices is the equivalent device selection.
# The argument list no longer requests a summary explicitly.
trainer = Trainer(
    max_epochs=CFG.epochs,
    accelerator='gpu',
    devices=[0],
    accumulate_grad_batches=CFG.accumulate,
    callbacks=[checkpoint_callback, early_stop_callback],
    logger=logger,
)

In [67]:
# Assemble backbone, loss and Lightning wrapper, then train.
model = SiameseNetwork(model_name=CFG.model_name)
criterion = SimilarityLoss(margin=CFG.loss_margin)
driver = SiameseNetworkDirver(model=model, criterion=criterion, lr=CFG.lr)

trainer.fit(
    model=driver,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

In [68]:
# Read back the metrics written by the CSV logger during training.
metrics = pd.read_csv(f'{trainer.logger.log_dir}/metrics.csv')

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Each metric is logged on its own rows, so the other columns are NaN there;
# dropping NaNs leaves one value per logged epoch.
train_loss = metrics['train_loss'].dropna().reset_index(drop=True)
val_loss = metrics['val_loss'].dropna().reset_index(drop=True)

axes[0].grid(True)
axes[0].plot(train_loss, color="r", marker="o", label='train/loss')
axes[0].plot(val_loss, color="b", marker="x", label='valid/loss')
axes[0].set_title('Loss per epoch')
axes[0].set_xlabel('epoch')
axes[0].set_ylabel('loss')
axes[0].legend(loc='upper right', fontsize=9)

lr = metrics['lr'].dropna().reset_index(drop=True)

axes[1].grid(True)
axes[1].plot(lr, color="g", marker="o", label='learning rate')
axes[1].set_title('Learning rate per epoch')
axes[1].set_xlabel('epoch')
axes[1].set_ylabel('learning rate')
axes[1].legend(loc='upper right', fontsize=9)

# Render the figure instead of echoing the Legend object's repr.
plt.show()

In [69]:
# Sample similar / dissimilar pairs from the held-out test folders.
data = create_dataset(
    path=CFG.test_path,
    s_factor=CFG.test_similarity_factor,
    ds_factor=CFG.test_dissimilarity_factor,
)
test_data = pd.DataFrame(data=data, columns=['image1', 'image2', 'label'])
print(test_data.shape)
test_data.head()

In [70]:
# Class balance of the sampled test pairs.
test_data['label'].value_counts()

In [71]:
# Test pairs go through the same preprocessing object as training; order is
# preserved (no shuffling) so predictions line up with the rows of test_data.
test_dataset = SiameseDataset(image_df=test_data, transform=transform)
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=CFG.batch_size,
    shuffle=False,
)

In [72]:
# Run predict_step over every test batch; the result is consumed below as a
# list holding one (img1_embeds, img2_embeds) pair per batch.
# NOTE(review): no model is passed, so Lightning decides which weights are used
# (the module from the last fit and/or a saved checkpoint) — confirm against the
# installed version if the exact weights matter.
predictions = trainer.predict(dataloaders=test_dataloader)

In [73]:
# Euclidean distance between the two embeddings of every pair, batch by batch.
distance = []
for img1_embeds, img2_embeds in predictions:
    distance.append(F.pairwise_distance(img1_embeds, img2_embeds, keepdim=True))

# Flatten the per-batch (B, 1) tensors into one list of floats.
preds_dec = [value for batch in distance for value in batch.squeeze(1).tolist()]

# Pairs closer than 1 are predicted as the same person (1), the rest as different (0).
preds = [int(pred < 1) for pred in preds_dec]
true = test_data['label'].values

In [74]:
# Overall accuracy of the thresholded predictions, rounded to two decimals.
round(accuracy_score(y_true=true, y_pred=preds), 2)

In [75]:
# Per-class precision / recall / F1 for the thresholded predictions.
print(classification_report(y_true=true, y_pred=preds))

In [80]:
# Show one test pair next to its predicted distance and its true label.
pair_index = 0
pair_row = test_data.iloc[pair_index]

fig, axes = plt.subplots(1, 2, figsize=(8, 5))

image1 = Image.open(pair_row['image1'])
axes[0].imshow(image1)

image2 = Image.open(pair_row['image2'])
axes[1].imshow(image2)

print('Predicted Euclidean Distance : {}'.format(preds_dec[pair_index]))
# Pairs are sampled randomly, so the ground-truth text is read from the data
# rather than hard-coded for this row.
print('Actual Label : {}'.format('Same Person' if pair_row['label'] == 1 else 'Different Person'))

In [79]:
# Show one test pair next to its predicted distance and its true label.
pair_index = 2
pair_row = test_data.iloc[pair_index]

fig, axes = plt.subplots(1, 2, figsize=(8, 5))

image1 = Image.open(pair_row['image1'])
axes[0].imshow(image1)

image2 = Image.open(pair_row['image2'])
axes[1].imshow(image2)

print('Predicted Euclidean Distance : {}'.format(preds_dec[pair_index]))
# Pairs are sampled randomly, so the ground-truth text is read from the data
# rather than hard-coded for this row.
print('Actual Label : {}'.format('Same Person' if pair_row['label'] == 1 else 'Different Person'))

In [86]:
# Show one test pair next to its predicted distance and its true label.
pair_index = 86
pair_row = test_data.iloc[pair_index]

fig, axes = plt.subplots(1, 2, figsize=(8, 5))

image1 = Image.open(pair_row['image1'])
axes[0].imshow(image1)

image2 = Image.open(pair_row['image2'])
axes[1].imshow(image2)

print('Predicted Euclidean Distance : {}'.format(preds_dec[pair_index]))
# Pairs are sampled randomly, so the ground-truth text is read from the data
# rather than hard-coded for this row.
print('Actual Label : {}'.format('Same Person' if pair_row['label'] == 1 else 'Different Person'))

In [84]:
# Show one test pair next to its predicted distance and its true label.
pair_index = 68
pair_row = test_data.iloc[pair_index]

fig, axes = plt.subplots(1, 2, figsize=(8, 5))

image1 = Image.open(pair_row['image1'])
axes[0].imshow(image1)

image2 = Image.open(pair_row['image2'])
axes[1].imshow(image2)

print('Predicted Euclidean Distance : {}'.format(preds_dec[pair_index]))
# Pairs are sampled randomly, so the ground-truth text is read from the data
# rather than hard-coded for this row.
print('Actual Label : {}'.format('Same Person' if pair_row['label'] == 1 else 'Different Person'))