In [1]:
from facenet_pytorch import MTCNN, InceptionResnetV1
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
import numpy as np
import pandas as pd
import os

workers = 0 if os.name == 'nt' else 4

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Running on device: {}'.format(device))

Running on device: cpu


In [3]:
mtcnn = MTCNN(
    image_size=160, margin=0, min_face_size=20,
    thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,
    device=device
)

In [4]:
resnet = InceptionResnetV1(
    classify=True,
    num_classes=1,
).to(device)

resnet.load_state_dict(torch.load('model_weights/model.pth'))

<All keys matched successfully>

In [5]:
def collate_fn(x):
    return x[0]

dataset = datasets.ImageFolder('../data/test_images_cropped')
dataset.idx_to_class = {i:c for c, i in dataset.class_to_idx.items()}
loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=0)

In [6]:
aligned = []
names = []
for x, y in loader:
    x_aligned, prob = mtcnn(x, return_prob=True)
    if x_aligned is not None:
        print('Face detected with probability: {:8f}'.format(prob))
        aligned.append(x_aligned)
        names.append(dataset.idx_to_class[y])

Face detected with probability: 0.999999
Face detected with probability: 0.999993
Face detected with probability: 0.999985
Face detected with probability: 0.998590
Face detected with probability: 0.999997
Face detected with probability: 0.999817
Face detected with probability: 0.999987
Face detected with probability: 0.999577
Face detected with probability: 0.999992
Face detected with probability: 0.999877
Face detected with probability: 0.999999
Face detected with probability: 0.999975
Face detected with probability: 0.999994
Face detected with probability: 0.999723
Face detected with probability: 0.999794
Face detected with probability: 0.999891
Face detected with probability: 0.999945
Face detected with probability: 0.999992
Face detected with probability: 0.999952
Face detected with probability: 0.998611
Face detected with probability: 0.999939
Face detected with probability: 0.999980
Face detected with probability: 0.999993
Face detected with probability: 0.999999
Face detected wi

In [7]:
aligned = torch.stack(aligned).to(device)
embeddings = resnet(aligned).detach().cpu()
print(embeddings)
print(len(embeddings))

tensor([[ 0.8679],
        [-0.1288],
        [ 1.1762],
        [ 0.9454],
        [ 0.5870],
        [ 0.6509],
        [ 0.4021],
        [-0.3559],
        [-0.3252],
        [-0.3483],
        [-0.0775],
        [-0.3393],
        [-1.1245],
        [ 0.4625],
        [ 0.2653],
        [ 0.1825],
        [ 0.3078],
        [ 0.0404],
        [ 0.5282],
        [-0.6515],
        [-0.9897],
        [-0.1552],
        [-0.5227],
        [-0.1942],
        [-0.3633],
        [-0.5666],
        [-0.4121],
        [-0.4903],
        [ 0.5943],
        [ 1.1301],
        [ 0.4374],
        [-0.1890],
        [-0.1740],
        [ 0.2709],
        [ 0.6606],
        [-0.3859],
        [-0.5883]])
37


In [8]:
dists = [[(e1 - e2).norm().item() for e2 in embeddings] for e1 in embeddings]
print(pd.DataFrame(dists, columns=names, index=names))

          ValidVic  ValidVic  ValidVic  ValidVic  ValidVic  ValidVic  \
ValidVic  0.000000  0.996750  0.308342  0.077513  0.280953  0.216957   
ValidVic  0.996750  0.000000  1.305093  1.074263  0.715798  0.779793   
ValidVic  0.308342  1.305093  0.000000  0.230830  0.589295  0.525300   
ValidVic  0.077513  1.074263  0.230830  0.000000  0.358465  0.294470   
ValidVic  0.280953  0.715798  0.589295  0.358465  0.000000  0.063995   
ValidVic  0.216957  0.779793  0.525300  0.294470  0.063995  0.000000   
ValidVic  0.465803  0.530947  0.774145  0.543316  0.184850  0.248846   
ValidVic  1.223805  0.227054  1.532147  1.301317  0.942852  1.006847   
ValidVic  1.193140  0.196389  1.501482  1.270652  0.912187  0.976182   
ValidVic  1.216203  0.219453  1.524545  1.293716  0.935250  0.999246   
ValidVic  0.945359  0.051392  1.253701  1.022871  0.664406  0.728401   
ValidVic  1.207232  0.210482  1.515574  1.284745  0.926279  0.990275   
ValidVic  1.992396  0.995645  2.300738  2.069908  1.711443  1.77