In [1]:
import os

import PIL
import torch
from sklearn.metrics import roc_auc_score
from skimage.io import imread
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from tqdm import tqdm

from settings.paths import LFW_PAIRS_6000, LFW_FUNNELED_DIR, LIGHT_CNN_9_WEIGHT
from src.light_cnn import LightCNN_9Layers

In [2]:
# Restrict PyTorch to physical GPUs 6 and 7 (they appear as cuda:0 and cuda:1).
# This must happen before CUDA is first initialised; torch initialises CUDA
# lazily, so setting it after `import torch` is fine as long as no CUDA call
# has run yet. The value is a comma-separated list without spaces.
os.environ["CUDA_VISIBLE_DEVICES"] = '6,7'

In [3]:
def init_model(name='LightCNN_9', cuda=True):
    """Build a face model by name and put it in eval mode.

    Args:
        name: model identifier. Only ``'LightCNN_9'`` is supported.
        cuda: if True, wrap the model in ``torch.nn.DataParallel`` and move it
            to the GPU with ``.cuda()``.

    Returns:
        The model in eval mode, wrapped in DataParallel when ``cuda`` is True.

    Raises:
        ValueError: if ``name`` is not a known model.
    """
    # Compare strings with ``==``: ``is`` tests object identity and only matched
    # by accident of string interning (and emits a SyntaxWarning on 3.8+).
    if name == 'LightCNN_9':
        model_class = LightCNN_9Layers
        # Classifier size of the pretrained checkpoint. Presumably it must match
        # the checkpoint for load_state_dict to succeed; TODO confirm against the weights.
        num_classes = 79077
    else:
        raise ValueError('No such model {}'.format(name))

    model = model_class(num_classes=num_classes)
    # Inference only: switch layers such as dropout / batch-norm (if present)
    # to evaluation behaviour.
    model.eval()

    if cuda:
        model = torch.nn.DataParallel(model).cuda()

    return model

In [4]:
def load_weights(model, weight_path, map_location=None):
    """Load the ``'state_dict'`` entry of a saved checkpoint into ``model``.

    Args:
        model: module whose parameters are overwritten in place.
        weight_path: path or file-like object accepted by ``torch.load``.
        map_location: forwarded to ``torch.load``. For example, pass ``'cpu'``
            to load a GPU-saved checkpoint on a machine without CUDA. The
            default ``None`` keeps the previous behaviour.

    Returns:
        The same ``model`` object, for chaining.

    Raises:
        KeyError: if the checkpoint has no ``'state_dict'`` entry.
    """
    checkpoint = torch.load(weight_path, map_location=map_location)
    # Strict load: keys must match the model exactly. NOTE(review): if the
    # checkpoint was saved from a DataParallel model its keys presumably carry a
    # 'module.' prefix, so it would only load into a model built with cuda=True.
    model.load_state_dict(checkpoint['state_dict'])
    return model

In [5]:
# init_model expects the model name as a string key; the bare name LightCNN_9
# was never defined and raised NameError on a fresh kernel.
model = init_model('LightCNN_9', cuda=True)
model = load_weights(model, LIGHT_CNN_9_WEIGHT)

In [6]:
class LFWFunneled6000Pairs(Dataset):
    """Face-verification pairs from an LFW pairs file.

    Each item is ``(image_1, image_2, label)``. Pair lines are tab-separated:

    * 3 fields ``name, number_1, number_2``: same person, ``label = 1``;
    * 4 fields ``name_1, number_1, name_2, number_2``: different people,
      ``label = 0``.

    The first line of the file is a header and is skipped. Blank lines, such as
    the one produced by a trailing newline, are ignored.

    Images are read from ``<lfw_funneled_dir>/<name>/<name>_<NNNN>.jpg``. Each
    image is converted to grayscale, center-cropped to 128 px and turned into
    a tensor. With ``downscale=True`` the crop is first reduced to 32 px
    (nearest neighbour) and then scaled back to 128 px (bicubic), simulating a
    low-resolution input.
    """

    def __init__(self, lfw_funneled_dir, pairs_file, downscale=False):
        self._lfw_funneled_dir = lfw_funneled_dir
        self._pairs = self._read_pairs(pairs_file)

        transforms_list = [
            transforms.ToPILImage(),
            transforms.Grayscale(),
            transforms.CenterCrop(128),
        ]
        if downscale:
            # PIL.Image.NEAREST is the value 0 used previously as a magic number.
            transforms_list.append(transforms.Resize(32, PIL.Image.NEAREST))
            transforms_list.append(transforms.Resize(128, PIL.Image.BICUBIC))
        transforms_list.append(transforms.ToTensor())

        self._transforms = transforms.Compose(transforms_list)

    @staticmethod
    def _read_pairs(pairs_file):
        """Return the pair lines of ``pairs_file``, without header or blank lines."""
        # `with` closes the file handle, which the bare open() used to leak.
        with open(pairs_file) as f:
            lines = f.read().splitlines()
        return [line for line in lines[1:] if line.strip()]

    def _make_filepath(self, name, number):
        """Build ``<root>/<name>/<name>_<number as 4 digits>.jpg``."""
        image_name = '{}_{:04d}.jpg'.format(name, int(number))
        return os.path.join(self._lfw_funneled_dir, name, image_name)

    def __getitem__(self, index):
        line = self._pairs[index]
        fields = line.split('\t')
        if len(fields) == 3:
            # Same identity, two different image numbers.
            label = 1
            name, number_1, number_2 = fields

            image_path_1 = self._make_filepath(name, number_1)
            image_path_2 = self._make_filepath(name, number_2)
        elif len(fields) == 4:
            # Two different identities.
            label = 0
            name_1, number_1, name_2, number_2 = fields

            image_path_1 = self._make_filepath(name_1, number_1)
            image_path_2 = self._make_filepath(name_2, number_2)
        else:
            # Fail loudly instead of hitting an unbound image_path_1 below.
            raise ValueError('Malformed pair #{}: {!r}'.format(index, line))

        return (self._transforms(imread(image_path_1)),
                self._transforms(imread(image_path_2)),
                label)

    def __len__(self):
        # Derived from the file (6000 for the standard LFW pairs file) rather
        # than hard-coded.
        return len(self._pairs)


In [7]:
def calculate_roc_auc(lwf_loader, model):
    """Compute the verification ROC AUC of ``model`` over image pairs.

    Args:
        lwf_loader: iterable of ``(image_batch_1, image_batch_2, labels)``
            batches. Labels are 1 for same identity and 0 otherwise, as
            produced by ``LFWFunneled6000Pairs``.
        model: network whose call returns a 2-tuple; the second element is
            used as the embedding. Assumes embeddings are (batch, dim); TODO confirm.

    Returns:
        ``roc_auc_score`` of the labels against pair scores. The score is the
        negative mean squared difference between the two embeddings, so higher
        means more similar.
    """
    # Send inputs to the device holding the model's weights. For a DataParallel
    # model moved with .cuda() this is the same device .cuda() used to pick.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    labels = []
    scores = []
    # torch.no_grad() replaces Variable(..., volatile=True). volatile has had no
    # effect since PyTorch 0.4, so autograd graphs were being recorded needlessly.
    with torch.no_grad():
        for image_batch_1, image_batch_2, batch_labels in tqdm(lwf_loader):
            _, features_1 = model(image_batch_1.to(device))
            _, features_2 = model(image_batch_2.to(device))

            batch_scores = -torch.mean((features_1 - features_2) ** 2, dim=-1)

            scores.extend(batch_scores.cpu().tolist())
            # Plain ints rather than 0-d tensors.
            labels.extend(torch.as_tensor(batch_labels).tolist())

    return roc_auc_score(labels, scores)

In [8]:
# Full-resolution evaluation: 128 px center crops, no downscaling.
lwf_dataset_hr = LFWFunneled6000Pairs(
    LFW_FUNNELED_DIR,
    LFW_PAIRS_6000,
    downscale=False,
)
# Order does not affect ROC AUC; shuffle=False just keeps runs deterministic.
lwf_loader_hr = DataLoader(lwf_dataset_hr, batch_size=16, shuffle=False)

roc_auc_score_hr = calculate_roc_auc(lwf_loader_hr, model)

100%|██████████| 375/375 [00:49<00:00, 11.73it/s]


In [9]:
# ROC AUC on full-resolution pairs (downscale=False).
roc_auc_score_hr

0.9820735555555555

In [10]:
lwf_dataset_hr = LFWFunneled6000Pairs(LFW_FUNNELED_DIR, LFW_PAIRS_6000, downscale=True)
lwf_loader_hr = DataLoader(
    dataset=lwf_dataset_hr,
    batch_size=16,
    shuffle=False
)

roc_auc_score_hr = calculate_roc_auc(lwf_loader_hr, model)

100%|██████████| 375/375 [00:39<00:00,  8.12it/s]


In [11]:
roc_auc_score_hr

0.9352327777777778