In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:

from PIL import Image, ImageDraw, ImageFont
from datetime import datetime
from math import ceil, floor
from skimage.util import montage
from sklearn.metrics import confusion_matrix
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torchvision.transforms import Compose
from typing import *
import copy 
import cv2
import glob
import inspect
import itertools
import json
import math
import matplotlib.pyplot as plt 
import multiprocessing
import numpy as np
import os
import pandas as pd
import pathlib
import pickle
import random 
import scipy
import socket
import string
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim

### Read the data and create a DataFrame of image paths and labels

In [None]:
def get_df_from_folder(path: str, pattern: str = '**/*.jpg') -> pd.DataFrame:
    """
    Build a DataFrame with one row per image file found under ``path``.

    Args:
        path: root folder, searched with ``pattern``.
        pattern: glob pattern relative to ``path`` (default: every ``.jpg`` at any depth).

    Returns:
        DataFrame with columns ``path`` (absolute ``pathlib.Path``), ``label`` (name of
        the image's parent folder, used as the identity label) and ``idx`` (row number).
        Rows are sorted by path so the result, and any seeded shuffle applied to it,
        does not depend on filesystem listing order. An empty folder yields an empty
        frame that still has these three columns.
    """
    image_paths = sorted(pathlib.Path(path).glob(pattern))
    records = [
        {'path': p.absolute(), 'label': p.parent.name, 'idx': ix}
        for ix, p in enumerate(image_paths)
    ]
    # Explicit columns keep the schema even when no file matched.
    return pd.DataFrame(records, columns=['path', 'label', 'idx'])

In [None]:
# Images are found recursively under this folder; each image's label is its parent folder name.
# Set the LFW_TRAIN_DIR environment variable to point elsewhere instead of editing this path.
LFW_TRAIN_DIR = os.environ.get('LFW_TRAIN_DIR', '/home/anuj/code/data/lfw_train')
df = get_df_from_folder(LFW_TRAIN_DIR)

In [None]:
# (rows, columns): one row per .jpg file found under the data folder.
df.shape

In [None]:
# Preview: absolute image path, label (parent folder name) and idx (row number).
df.head()

In [None]:
# Number of people with more than one image; only they can form positive pairs.
(df.groupby('label').size() > 1).sum()

In [None]:
# Shuffle and split at the image level. Splitting by pairs would be more principled but is harder;
# note that the same person can therefore appear in both splits.
TRAIN_FRACTION = 0.8
# A new name (instead of reassigning `df`) keeps this cell idempotent: re-running it
# reproduces the same split instead of reshuffling an already shuffled frame.
df_shuffled = df.sample(frac=1, random_state=1111)
n = len(df_shuffled)
n_train = int(TRAIN_FRACTION * n)
df_train, df_val = df_shuffled.iloc[:n_train], df_shuffled.iloc[n_train:]

In [None]:
# People with more than one image in each split, as (train, val).
(df_train.groupby('label').size() > 1).sum(), (df_val.groupby('label').size() > 1).sum()

In [None]:
class PairDataset:
    """
    Dataset of image pairs built from a DataFrame of face images.

    Given two images, the label is defined as:
        - 0: if the images belong to different people
        - 1: if the images belong to the same person

    Pairs are picked from the given set using the following strategy:
        1. Positive pair: every combination of two images of the same person is one
           item, so ``len(dataset)`` equals the number of positive pairs.
        2. Negative pair: for each positive pair, one of its two images is chosen at
           random and paired with a random image of a different person.

    Each item therefore contains one positive and one negative pair.

    Args:
        df: frame with at least ``path`` and ``label`` columns. It is not modified;
            a re-indexed copy (index 0..n-1, ``idx`` equal to the row position) is
            stored as ``self.df``.

    Raises:
        ValueError: if the frame contains fewer than two distinct labels, since no
            negative pair could ever be sampled.
    """
    def __init__(self, df: pd.DataFrame) -> None:
        # Work on a re-indexed copy so the caller's frame is left untouched and the
        # row position, the index label and ``idx`` all coincide.
        df = df.reset_index(drop=True).assign(idx=lambda d: np.arange(len(d)))

        dfg_by_label = df.groupby('label')
        names = set(dfg_by_label.groups)
        if len(names) < 2:
            raise ValueError(
                'PairDataset needs images of at least two different people, got {}'.format(len(names)))

        # Every pair of images of the same person is a positive pair; people with a
        # single image cannot form one and are skipped.
        idx_by_person = [idxs for idxs in dfg_by_label['idx'].apply(list) if len(idxs) > 1]
        pos_pairs = [pair for idxs in idx_by_person for pair in itertools.combinations(idxs, 2)]

        self.names = names
        self.pos_pairs = pos_pairs
        self.df = df
        self.labels = df['label'].values
        self.images = [self._read_image(p) for p in df['path']]

    @staticmethod
    def _read_image(p: str) -> np.ndarray:
        """Read the image at ``p`` and return it as an RGB array of shape (H, W, 3)."""
        image = cv2.imread(str(p))
        if image is None:
            # cv2.imread reports missing/unreadable files by returning None instead of raising.
            raise FileNotFoundError('Could not read image: {}'.format(p))
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __len__(self) -> int:
        """Number of items, i.e. the number of positive pairs."""
        return len(self.pos_pairs)

    def _sample_neg_pair(self, pos_pair) -> Tuple[int, int]:
        """Pair a randomly chosen member of ``pos_pair`` with a random image of another person."""
        idx = pos_pair[np.random.choice([0, 1])]
        # Rows were re-indexed 0..n-1 in __init__, so positions into ``self.labels``
        # are also the row labels of ``self.df``.
        candidates = np.flatnonzero(self.labels != self.labels[idx])
        neg_idx = np.random.choice(candidates)
        return (idx, neg_idx)

    def _image(self, ix):
        """Return image ``ix`` as float32 with shape (1, C, H, W), ready to be stacked."""
        image = self.images[ix]
        return image.transpose(2, 0, 1).astype(np.float32)[np.newaxis, ...]

    def __getitem__(self, ix: int) -> Dict[str, Any]:
        """
        Return one positive and one negative pair.

        Returns:
            dict with ``images1`` and ``images2`` (float32, shape (2, C, H, W); row 0 is
            the positive pair, row 1 the negative pair) and ``labels`` (int64 ``[1, 0]``).
        """
        pos_pair = self.pos_pairs[ix]
        neg_pair = self._sample_neg_pair(pos_pair)
        images1 = np.vstack([self._image(pos_pair[0]), self._image(neg_pair[0])])
        images2 = np.vstack([self._image(pos_pair[1]), self._image(neg_pair[1])])
        # np.long was removed in NumPy 1.24; np.int64 keeps 64-bit integer labels.
        labels = np.array([1, 0], dtype=np.int64)

        return {
            'images1': images1,
            'images2': images2,
            'labels': labels,
        }

In [None]:
def get_dataloader(df, batch_size, shuffle=True, num_workers=1, name='Training'):
    """
    Wrap ``df`` in a PairDataset and a torch DataLoader.

    Args:
        df: image frame passed (as a copy) to PairDataset.
        batch_size: number of PairDataset items per mini-batch; each item holds
            one positive and one negative pair.
        shuffle: whether the DataLoader reshuffles items every epoch.
        num_workers: number of DataLoader worker processes.
        name: prefix of the printed summary line, e.g. 'Validation'.

    Returns:
        (dataset, dataloader)
    """
    dataset = PairDataset(df.copy())
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    print('{}: {:,} total positive pairs {:,} mini batches'.format(name, len(dataset), len(dataloader)))
    return dataset, dataloader

In [None]:
def flatten(batch):
    """
    Merge the batch and pair dimensions of a collated PairDataset batch.

    ``images1``/``images2`` of shape (B, P, C, H, W) become (B*P, C, H, W) and
    ``labels`` of shape (B, P) becomes (B*P,), keeping row i of each output aligned.

    Returns:
        (images1, images2, labels)
    """
    _, _, c, h, w = batch['images1'].shape
    # reshape (unlike view) also accepts non-contiguous tensors, copying only when needed.
    images1 = batch['images1'].reshape(-1, c, h, w)
    images2 = batch['images2'].reshape(-1, c, h, w)
    labels = batch['labels'].reshape(-1)
    return images1, images2, labels

def visualize(batch):
    """
    Print the labels of a PairDataset batch and show its pairs side by side.

    Each montage row is one pair (left: ``images1``, right: ``images2``); the printed
    label tensor lists 1 (same person) / 0 (different people) in the same row order.
    """
    images1, images2, labels = flatten(batch)

    # (N, C, H, W) float tensors -> (N, H, W, C) uint8 arrays for plotting.
    images1 = images1.detach().cpu().numpy().transpose(0, 2, 3, 1).astype(np.uint8)
    images2 = images2.detach().cpu().numpy().transpose(0, 2, 3, 1).astype(np.uint8)

    # Interleave as [a0, b0, a1, b1, ...] so the row-major 2-column grid puts each pair on one row.
    n_pairs = images1.shape[0]
    images = np.stack([images1, images2], axis=1).reshape(-1, *images1.shape[1:])

    try:
        grid = montage(images, grid_shape=(n_pairs, 2), channel_axis=-1)
    except TypeError:
        # scikit-image < 0.19 has no ``channel_axis``; newer releases deprecate ``multichannel``.
        grid = montage(images, grid_shape=(n_pairs, 2), multichannel=True)

    print(labels)
    plt.imshow(grid)
    plt.axis('off')
    plt.show()

In [None]:
# PairDataset reads every image into memory when it is built, which can be slow and
# memory-heavy for the training split, so the training loader is opt-in.
LOAD_TRAIN = False
if LOAD_TRAIN:
    dataset_train, dataloader_train = get_dataloader(df_train, 1)
dataset_val, dataloader_val = get_dataloader(df_val, 1)

### Visualize a few sample pairs

In [None]:
# Preview the first two validation batches. The training loader is not created by
# default, so iterating `dataloader_train` here would raise NameError on a fresh kernel.
for batch in itertools.islice(dataloader_val, 2):
    visualize(batch)