In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
from PIL import Image, ImageDraw, ImageFont
from datetime import datetime
from math import ceil, floor
from skimage.util import montage
from sklearn.metrics import confusion_matrix
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torchvision.transforms import Compose
from typing import *
import copy 
import cv2
import glob
import inspect
import itertools
import json
import math
import matplotlib.pyplot as plt 
import multiprocessing
import numpy as np
import os
import pandas as pd
import pathlib
import pickle
import random 
import scipy
import sklearn
import socket
import string
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim

### Read and process the data

In [3]:
from src import data_prep, dataset, train

In [4]:
df = data_prep.get_df_from_folder('/home/anuj/code/data/lfw_train')
df_train, df_val = data_prep.split_train_val(df)

In [5]:
np.sum(df_train.groupby('label').count() > 1), np.sum(df_val.groupby('label').count() > 1)

(idx     1356
 path    1356
 dtype: int64, idx     340
 path    340
 dtype: int64)

In [6]:
dataset_train, dataloader_train = dataset.get_dataloader(df_train, batch_size=5*24)
dataset_val, dataloader_val = dataset.get_dataloader(df_val, batch_size=5*24)

Training: 327,846 total positive pairs 2,733 mini batches
Training: 15,200 total positive pairs 127 mini batches


### Visualize some

### Model, Optimizer, Loss

In [9]:
from src.models import segnet
from src.loss import ContrastiveLoss

In [10]:
device_id = 1

In [11]:
model = segnet.SiameseNetworkLarge(256)
model = torch.nn.DataParallel(model, device_ids=[1]).cuda(device_id)

In [24]:
model.module.load_state_dict(torch.load('face-contrastive-2.04-7500.pt'))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [22]:
def get_df_dist_vs_label(dataloader, n_iters=50):
    all_labels, all_dists = [], []

    for ix, batch in enumerate(dataloader):
        if ix >= n_iters:
            break
        images1, images2, labels = dataset.flatten(batch) 

        with torch.no_grad():
            feats1, feats2 = model(images1, images2)
            dist = torch.nn.functional.pairwise_distance(feats1, feats2)

        all_labels.extend(labels.data.cpu().numpy())
        all_dists.extend(dist.data.cpu().numpy())

    all_dists = np.array(all_dists)
    all_labels = np.array(all_labels)

#     df_dist_label = pd.DataFrame([all_dists, all_labels], index=['dist', 'label']).T
    return all_dists, all_labels

In [15]:
%%time
all_dists_train, all_labels_train = get_df_dist_vs_label(dataloader_train)

CPU times: user 58.7 s, sys: 33 s, total: 1min 31s
Wall time: 1min 46s


In [16]:
%%time
all_dists_val, all_labels_val = get_df_dist_vs_label(dataloader_val)

CPU times: user 58.5 s, sys: 32.5 s, total: 1min 30s
Wall time: 1min 38s


In [17]:
from sklearn.metrics import precision_recall_fscore_support

In [18]:
threshs = [0.1, 0.5, 0.8, 1, 1.2, 1.5, 2, 3]

In [19]:
for thresh in threshs:
    preds = np.array(all_dists_train < thresh, dtype=np.int)
    print(thresh, precision_recall_fscore_support(all_labels_train, preds)[2])

0.1 [0.67609602 0.08090822]
0.5 [0.9333438  0.92462044]
0.8 [0.96024942 0.95898065]
1 [0.96562656 0.96570668]
1.2 [0.96620278 0.96719898]
1.5 [0.95368485 0.95710101]
2 [0.60005822 0.77716139]
3 [0.22483544 0.69601787]


In [20]:
for thresh in threshs:
    preds = np.array(all_dists_val < thresh, dtype=np.int)
    print(thresh, precision_recall_fscore_support(all_labels_val, preds)[2])

0.1 [0.70860653 0.30236176]
0.5 [0.85651984 0.80726421]
0.8 [0.88963495 0.8705062 ]
1 [0.90359183 0.89495451]
1.2 [0.91119468 0.90970968]
1.5 [0.90288256 0.91116518]
2 [0.63617487 0.78519598]
3 [0.21466825 0.69446695]


In [30]:
THRESH = 1.2

### Test data

In [27]:
df_test = data_prep.get_df_from_folder('/home/anuj/code/data/lfw_heldout/')

In [28]:
dataset_test, dataloader_test = dataset.get_dataloader(df_test, batch_size=5*24)

Training: 422 total positive pairs 4 mini batches


In [29]:
%%time
all_dists_test, all_labels_test = get_df_dist_vs_label(dataloader_test)

CPU times: user 2.07 s, sys: 1.24 s, total: 3.31 s
Wall time: 4.42 s


In [35]:
preds = np.array(all_dists_test < 1.5, dtype=np.int)
print(THRESH, precision_recall_fscore_support(all_labels_test, preds))

1.2 (array([0.74429224, 0.7635468 ]), array([0.77251185, 0.73459716]), array([0.75813953, 0.74879227]), array([422, 422]))
