In [1]:
from facenet_pytorch import MTCNN, InceptionResnetV1
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
import pandas as pd
import os
# Use the first CUDA GPU when torch reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print('Running on device: {}'.format(device))

# Worker count for data loading: none on Windows (os.name == 'nt'), four elsewhere.
workers = 4 if os.name != 'nt' else 0

Running on device: cpu


In [2]:
# MTCNN instance that the later cells call on images to obtain face crops.
# One keyword per line so individual settings are easy to spot and tune.
mtcnn = MTCNN(
    image_size=160,
    margin=0,
    min_face_size=20,
    thresholds=[0.6, 0.7, 0.7],
    factor=0.709,
    post_process=True,
    device=device,
)

# InceptionResnetV1 loaded with the 'vggface2' pretrained weights, switched to
# eval mode and moved onto `device`.
resnet = InceptionResnetV1(pretrained='vggface2')
resnet = resnet.eval().to(device)

In [3]:
from PIL import Image
def collate_fn(x):
    """Return the first element of the batch ``x`` without any collation.

    Handed to the DataLoader as its ``collate_fn`` so that iterating the loader
    yields the raw sample rather than a stacked batch.
    NOTE(review): presumably relies on the loader producing single-sample
    batches (no batch_size is passed) - confirm before raising the batch size,
    since any further elements of ``x`` are dropped.
    """
    first_sample = x[0]
    return first_sample

dataset = datasets.ImageFolder('./images')

# Invert class_to_idx (name -> index) once, show it, and attach it to the
# dataset so labels coming out of the loader can be mapped back to names.
idx_to_class = {idx: cls_name for cls_name, idx in dataset.class_to_idx.items()}
print(idx_to_class)
dataset.idx_to_class = idx_to_class

loader = DataLoader(
    dataset,
    collate_fn=collate_fn,
    num_workers=workers,
)

{0: 'Dwight', 1: 'Michael'}


In [4]:
# Run the face detector over every image from the loader, keeping the crop and
# the class name of each image for which a crop was produced.
aligned_faces = []
names = []
for x, y in loader:
    x_aligned, prob = mtcnn(x, return_prob=True)
    # The None check skips images for which mtcnn produced no crop, so
    # `aligned_faces` and `names` stay index-aligned.
    if x_aligned is not None:
        print('Face detected with probability: {:8f}'.format(prob))
        aligned_faces.append(x_aligned)
        names.append(dataset.idx_to_class[y])

# torch.stack fails with an opaque message on an empty list; report the real
# cause (no usable face in any image) instead.
if not aligned_faces:
    raise RuntimeError('No faces were detected in any image from the loader.')

# Batch of all crops on the target device; `aligned` keeps its original name
# and final type (a stacked tensor) for the cells that follow.
aligned = torch.stack(aligned_faces).to(device)

# Inference only: disable autograd so the embeddings do not carry a graph.
with torch.no_grad():
    embeddings = resnet(aligned)

Face detected with probability: 0.999847
Face detected with probability: 0.999954
Face detected with probability: 0.999962
Face detected with probability: 0.999998


In [5]:
# Pairwise L2 distance between every pair of embeddings, built row by row.
dists = []
for row_emb in embeddings:
    dists.append([(row_emb - col_emb).norm().item() for col_emb in embeddings])

# Label both axes with the class name attached to each embedding.
dist_table = pd.DataFrame(dists, index=names, columns=names)
print(dist_table)

           Dwight    Dwight   Michael   Michael
Dwight   0.000000  0.496898  1.555670  1.505317
Dwight   0.496898  0.000000  1.494655  1.475376
Michael  1.555670  1.494655  0.000000  0.511582
Michael  1.505317  1.475376  0.511582  0.000000


In [6]:

query_path = "images/Dwight/Dwight2.jpg"
img = Image.open(query_path)

# Get cropped and prewhitened image tensor
# NOTE(review): save_path presumably also writes the crop to disk - confirm.
img_cropped = mtcnn(img, save_path="out/Dwight2.jpg")
# Fail with a clear message rather than an opaque error when the detector
# yields no crop for this image.
if img_cropped is None:
    raise RuntimeError("No face detected in {}".format(query_path))

# Add a leading batch dimension so the single crop can be fed to the network.
aligned = img_cropped.unsqueeze(0).to(device)

# Calculate embedding; inference only, so autograd is disabled.
with torch.no_grad():
    img_embedding = resnet(aligned)

# L2 distance from the query embedding to each previously computed embedding.
for embedding, name in zip(embeddings, names):
    print((embedding - img_embedding[0]).norm().item(), name)

0.496898353099823 Dwight
3.0352083513207617e-07 Dwight
1.4946551322937012 Michael
1.475375771522522 Michael
