In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Reload all (non-excluded) modules, e.g. the local `src` package, before executing each cell,
# so edits to src/*.py are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from PIL import Image, ImageDraw, ImageFont
from datetime import datetime
from math import ceil, floor
from skimage.util import montage
from sklearn.metrics import confusion_matrix
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torchvision.transforms import Compose
from typing import *
import copy 
import cv2
import glob
import inspect
import itertools
import json
import math
import matplotlib.pyplot as plt 
import multiprocessing
import numpy as np
import os
import pandas as pd
import pathlib
import pickle
import random 
import scipy
import sklearn
import socket
import string
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim

### Read and process the data

In [None]:
from src import data_prep, dataset, train

In [None]:
df = data_prep.get_df_from_folder('/home/anuj/code/data/lfw_train')
df_train, df_val = data_prep.split_train_val(df)

In [None]:
np.sum(df_train.groupby('label').count() > 1), np.sum(df_val.groupby('label').count() > 1)

In [None]:
dataset_train, dataloader_train = dataset.get_dataloader(df_train, image_side=160, batch_size=5*24)
dataset_val, dataloader_val = dataset.get_dataloader(df_val, image_side=160, batch_size=5*24)

### Model: build the Siamese network and load pretrained weights

In [None]:
from src import models
from src.loss import ContrastiveLoss

In [None]:
device_id = 1

In [None]:
model = models.SiameseNet(160)
model = torch.nn.DataParallel(model, device_ids=[1]).cuda(device_id)

In [None]:
model.module.load_state_dict(torch.load('weights/face-siamese-crop.pt'))

In [None]:
def get_df_dist_vs_label(dataloader, n_iters=50, net=None):
    """Embed image pairs from `dataloader` and return their distances and labels.

    Despite the name, this returns two NumPy arrays, not a DataFrame.

    Args:
        dataloader: iterable of batches; each batch is unpacked with
            `dataset.flatten(batch)` into `(images1, images2, labels)`.
        n_iters: maximum number of batches to consume; None consumes all of them.
        net: Siamese model mapping `(images1, images2)` to `(feats1, feats2)`.
            Defaults to the notebook-global `model`.

    Returns:
        (all_dists, all_labels): `pairwise_distance` (p=2) between paired
        embeddings and the matching labels, concatenated over the consumed
        batches. Both are empty arrays if no batch was read.
    """
    if net is None:
        # Notebook global; note a later cell rebinds `model` to a CPU model.
        net = model
    # Inference: disable dropout / use running BatchNorm stats, if the model has them.
    net.eval()

    stop = None if n_iters is None else max(0, n_iters)
    dist_chunks, label_chunks = [], []
    # islice stops after `stop` batches without loading one more just to discard it.
    for batch in itertools.islice(dataloader, stop):
        images1, images2, labels = dataset.flatten(batch)

        with torch.no_grad():
            feats1, feats2 = net(images1, images2)
            dist = torch.nn.functional.pairwise_distance(feats1, feats2)

        dist_chunks.append(dist.cpu().numpy())
        label_chunks.append(labels.cpu().numpy())

    if not dist_chunks:
        return np.array([]), np.array([])
    return np.concatenate(dist_chunks), np.concatenate(label_chunks)

In [None]:
%%time
all_dists_train, all_labels_train = get_df_dist_vs_label(dataloader_train)

In [None]:
%%time
all_dists_val, all_labels_val = get_df_dist_vs_label(dataloader_val)

In [None]:
from sklearn.metrics import precision_recall_fscore_support

In [None]:
threshs = [0.1, 0.2, 0.3, 0.5, 0.8, 1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2, 3]

In [None]:
for thresh in threshs:
    preds = np.array(all_dists_train < thresh, dtype=np.int)
    f1s = precision_recall_fscore_support(all_labels_train, preds)[2]
    print(thresh, f1s, f1s.mean())

In [None]:
for thresh in threshs:
    preds = np.array(all_dists_val < thresh, dtype=np.int)
    f1s = precision_recall_fscore_support(all_labels_val, preds)[2]
    print(thresh, f1s, f1s.mean())

In [None]:
THRESH = 1.8

### Test data

In [None]:
df_test = data_prep.get_df_from_folder('/home/anuj/code/data/lfw_heldout/')

In [None]:
dataset_test, dataloader_test = dataset.get_dataloader(df_test, image_side=160, batch_size=5*24)

In [None]:
%%time
all_dists_test, all_labels_test = get_df_dist_vs_label(dataloader_test)

In [None]:
preds = np.array(all_dists_test < THRESH, dtype=np.int)
print(THRESH, precision_recall_fscore_support(all_labels_test, preds))

### Visualize

In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np

import predict


def show_images(ax, image_path1, image_path2, label):
    """Draw two images side by side on `ax`, titled with `label`.

    Args:
        ax: matplotlib Axes to draw on.
        image_path1, image_path2: paths readable by `cv2.imread`. Both images must
            have the same height for `np.hstack`.
        label: title text for the panel.

    Raises:
        FileNotFoundError: if OpenCV cannot read either path. `cv2.imread` returns
            None instead of raising, which would otherwise surface as a cryptic
            error from `cv2.cvtColor`.
    """
    def _read_rgb(path):
        # OpenCV loads BGR; matplotlib expects RGB.
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError('Could not read image: {}'.format(path))
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    image = np.hstack([_read_rgb(image_path1), _read_rgb(image_path2)])
    ax.imshow(image.astype(np.uint8))
    ax.set_title(label)

def visualize_preds(dataset, n=10, start=0, threshold=1.8):
    """Plot `n` image pairs with the model's distance and same-person decision.

    Reads the notebook globals `model` and `transforms` and runs
    `predict.predict` on CPU.

    Args:
        dataset: indexable so that `dataset[i]['pairs']` holds index pairs into
            `dataset.df`, which must have a 'path' column. (Shadows the `dataset`
            module inside this function.)
        n: number of pairs to plot, laid out two per row.
        start: first dataset item to take a pair from.
        threshold: distance threshold forwarded to `predict.predict`.
    """
    if n <= 0:
        return

    # ceil(n / 2) rows of two panels, so every row is used; an n-row grid with
    # only n of its 2n cells drawn leaves half of the figure blank.
    n_rows = (n + 1) // 2
    fig, axes = plt.subplots(n_rows, 2, figsize=(20, 5 * n_rows), squeeze=False)

    for ix, ax in enumerate(axes.flat):
        if ix >= n:
            ax.axis('off')  # spare cell when n is odd
            continue
        # Alternate between pair 0 and pair 1 of successive items.
        idx1, idx2 = dataset[start + ix]['pairs'][ix % 2]

        i1 = str(dataset.df.loc[idx1]['path'])
        i2 = str(dataset.df.loc[idx2]['path'])
        dist, issame = predict.predict(i1, i2, model, transforms, 'cpu', threshold)
        label = 'Dis-similarity: {:.2f}, Same person?: {}'.format(dist, bool(issame))

        show_images(ax, i1, i2, label)
    plt.show()

In [None]:
# `preprocess` is not imported in any other cell, so without this import the
# `get_transforms_inference` line raises NameError on a fresh kernel.
# NOTE(review): assumes the module lives in the `src` package like data_prep/dataset/models — confirm.
from src import preprocess

# Rebinds the notebook-global `model` (previously the CUDA DataParallel model) to a
# CPU model; visualize_preds reads both `model` and `transforms` as globals.
model = predict.get_model('cpu', 'weights/face-siamese-crop.pt')
transforms = preprocess.get_transforms_inference()

In [None]:
# Training split: 10 pairs, starting at dataset item 1000.
visualize_preds(dataset_train, n=10, start=1000)

In [None]:
# Validation split, same window of items.
visualize_preds(dataset_val, n=10, start=1000)

In [None]:
# Held-out test split: 10 pairs from the first items.
visualize_preds(dataset_test, n=10)

In [None]:
import torch
from src.models import segnet

def fix_weights(path, out, device, image_side=160):
    """Re-save a DataParallel checkpoint as the bare module's state_dict.

    Loads `path` into `DataParallel(segnet.SiameseNetworkLarge(image_side))`, so the
    checkpoint keys must carry the `module.` prefix, then writes the unwrapped
    module's state_dict to `out`.

    Args:
        path: checkpoint file produced from a DataParallel-wrapped model.
        out: destination file for the unwrapped state_dict.
        device: device (e.g. 'cpu' or 'cuda:0') onto which the checkpoint is loaded.
        image_side: first constructor argument of SiameseNetworkLarge (160 by default).
    """
    # map_location=device remaps every storage, whatever device it was saved on;
    # a {'cuda:2': device} mapping would only handle checkpoints saved from cuda:2.
    state_dict = torch.load(path, map_location=device)

    wrapper = torch.nn.DataParallel(segnet.SiameseNetworkLarge(image_side))
    wrapper.load_state_dict(state_dict)
    torch.save(wrapper.module.state_dict(), out)