# Determine whether an NVIDIA GPU is available

In [36]:
import torch
import cv2

In [37]:
# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Running on device: ", device)

Running on device:  cuda:0


# Load models

In [38]:
from facenet_pytorch import MTCNN, InceptionResnetV1
from types import MethodType

In [39]:
# Embedding network with VGGFace2 weights, switched to eval mode and moved to `device`.
resnet = InceptionResnetV1(pretrained='vggface2')
resnet = resnet.eval().to(device)
# Face detector; keep_all=True keeps every detected face instead of selecting a single one.
mtcnn = MTCNN(keep_all=True, device=device)

def detect_box(self, img, save_path=None):
    """Detect faces in `img` and return both their bounding boxes and the extracted crops.

    Intended to be bound to an MTCNN instance (see the MethodType binding below it).

    Args:
        self: The MTCNN instance this function is bound to.
        img: Image forwarded unchanged to `self.detect` and `self.extract`.
        save_path: Optional path forwarded to `self.extract`.

    Returns:
        Tuple ``(boxes, faces)``: the detected (or, when keep_all is off, selected)
        boxes, and whatever `self.extract` returns for those boxes.
    """
    boxes, probs, points = self.detect(img, landmarks=True)

    # Only narrow the detections down when the instance is not keeping all faces.
    if not self.keep_all:
        boxes, probs, points = self.select_boxes(boxes, probs, points, img, method=self.selection_method)

    return boxes, self.extract(img, boxes, save_path)

# Attach detect_box to this one MTCNN instance so it can be called as mtcnn.detect_box(img).
setattr(mtcnn, "detect_box", MethodType(detect_box, mtcnn))

# Load the saved reference faces

In [40]:
from torch.utils.data import DataLoader
from torchvision import datasets
import os

import multiprocessing

# DataLoader worker processes must receive `collate_fn`. The lambda used for it
# below can only reach them under the "fork" start method. Under "spawn"
# (Windows, macOS) or "forkserver", pickling the lambda fails, so load in the
# main process instead.
workers = 4 if multiprocessing.get_start_method() == "fork" else 0

In [41]:
dataset = datasets.ImageFolder("./faces")
# Invert class_to_idx so an integer label can be mapped back to a person's name.
idx_to_class = dict(zip(dataset.class_to_idx.values(), dataset.class_to_idx.keys()))
# With the default batch size of 1, the collate function unwraps the one-element
# batch so each iteration yields a single (image, label) sample.
loader = DataLoader(dataset, collate_fn=lambda batch: batch[0], num_workers=workers)
idx_to_class.items()

dict_items([(0, 'Mayank'), (1, 'Seema'), (2, 'Vaibhav')])

# Compute face embeddings for all saved images

In [42]:
aligned_faces = []
labels = []

# Inference only: no autograd graph is needed for the reference embeddings.
with torch.no_grad():
    for image, label in loader:
        # mtcnn was built with keep_all=True, so this is a stack of every face
        # found in the image (first dim = number of faces), or None if none were found.
        image_aligned = mtcnn(image)
        if image_aligned is None:
            continue
        aligned_faces.append(image_aligned)
        # One label per detected face keeps labels[i] paired with embeddings[i].
        # Appending a single label per image would desync them on multi-face images.
        labels.extend([label] * image_aligned.shape[0])

    if not aligned_faces:
        raise RuntimeError("No faces were detected in any reference image; cannot build embeddings.")

    aligned_faces = torch.cat(aligned_faces, dim=0).to(device)
    embeddings = resnet(aligned_faces).cpu()

assert embeddings.shape[0] == len(labels), "each embedding row must have exactly one label"

# Classify a face against the saved embeddings

In [43]:
import torch.nn.functional as F

In [44]:
def classify_face(image, threshold=0.7):
    """Match an aligned face against the stored reference embeddings.

    Args:
        image: Aligned face tensor passed to `resnet`. The live loop in this
            notebook passes one face with a leading batch dimension of 1.
        threshold: Minimum cosine similarity the best match must exceed.

    Returns:
        The label (class index) of the most similar reference face, or None if
        the best similarity does not exceed `threshold`.
    """
    # Pure inference: avoid building and discarding an autograd graph on every frame.
    with torch.no_grad():
        embedding = resnet(image.to(device)).cpu().reshape(1, -1)

    # Similarity of this face to every reference embedding (one value per row).
    similarity = F.cosine_similarity(x1=embeddings, x2=embedding).reshape(-1)
    max_index = similarity.argmax().item()

    return labels[max_index] if similarity[max_index] > threshold else None

# Live detection

In [45]:
def detect(cam=0, threshold=0.6):
    """Run live face recognition on a video source until 'q' is pressed or the stream ends.

    Each frame is searched for faces with `mtcnn.detect_box`. Every aligned face
    is classified with `classify_face`, and a box plus the predicted name (or
    "Unknown") is drawn on the displayed frame.

    Args:
        cam: Source passed to cv2.VideoCapture (e.g. a camera index).
        threshold: Minimum cosine similarity for a match, forwarded to classify_face.
    """
    vid = cv2.VideoCapture(cam)
    try:
        while True:
            ok, frame = vid.read()
            if not ok:
                break

            # OpenCV delivers BGR frames; the detector and the reference images
            # (PIL images from ImageFolder) are RGB. Draw on the original BGR frame.
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            batch_boxes, aligned_images = mtcnn.detect_box(rgb)

            if aligned_images is not None:
                for box, aligned in zip(batch_boxes, aligned_images):
                    x1, y1, x2, y2 = [int(v) for v in box]

                    idx = classify_face(image=aligned.unsqueeze(0), threshold=threshold)
                    name = idx_to_class[idx] if idx is not None else "Unknown"

                    cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 2)
                    cv2.putText(frame, name, (x1 + 5, y1 + 10), cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255), 1)

            cv2.imshow("Face Recognition", frame)
            # Mask to the low byte: on some platforms waitKey sets extra high bits.
            if cv2.waitKey(1) & 0xFF == ord("q"):
                break
    finally:
        # Free the camera and close the window however the loop exits.
        vid.release()
        cv2.destroyAllWindows()

In [46]:
detect(cam=0)