In [1]:
pip install torch facenet-pytorch tqdm

Note: you may need to restart the kernel to use updated packages.


In [26]:
import os
import cv2
import torch
from facenet_pytorch import InceptionResnetV1, MTCNN
from tqdm import tqdm
from types import MethodType
import time

In [7]:
def encode(img):
    res = resnet(torch.Tensor(img))
    return res
def detect_box(self, img, save_path = None):
    #Detect faces
    batch_boxes, batch_probs, batch_points = self.detect(img, landmarks = True)
    #select faces
    if not self.keep_all:
        batch_boxes, batch_probs, batch_points = self.select_boxes(
            batch_boxes, batch_probs, batch_points, img, method = self.selection_method
        )
    #Extract faces
    faces = self.extract(img, batch_boxes, save_path)
    return batch_boxes, faces
# load model
resnet = InceptionResnetV1(pretrained='vggface2').eval()
mtcnn = MTCNN(
    image_size = 224, keep_all = True, thresholds = [0.4, 0.5, 0.5], min_face_size = 60
)
mtcnn.detect_box = MethodType(detect_box, mtcnn)

In [28]:
# get encoded features for all saved images
saved_pictures = "./Saved/"
all_people_faces = {}

'''for file in person_face, extension = file.split(".")
    img = cv2.imread(f"{saved_pictures}/{person_face}.jpg")
    cropped = mtcnn(img)
    if cropped is not None:
        all_people_faces[person_face] = encode(cropped)[0, :1]'''
for file in os.listdir(saved_pictures):
    if file.endswith('.jpg') or file.endswith('.png'):
        person_face = os.path.splitext(file)[0]
        image_path = os.path.join(saved_pictures, file)

        img = cv2.imread(image_path)
        if img is None:
            continue 
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        cropped = mtcnn(img_rgb)

        if cropped is not None:
            if len(cropped.shape) == 3:
                cropped = cropped.unsqueeze(0) # add batch dim

            embedding = encode (cropped).detach()
            all_people_faces[person_face] = embedding

In [29]:
all_people_faces

{'Archit yadav (2)': tensor([[ 0.0839, -0.0056, -0.0204,  0.0109, -0.0289, -0.0229,  0.0097, -0.0086,
           0.0435,  0.0315, -0.0126, -0.0049,  0.0456,  0.0076, -0.0047,  0.0735,
           0.0279, -0.0189, -0.0296,  0.0510, -0.0010, -0.0515, -0.0137, -0.0275,
          -0.0179,  0.0303, -0.0820, -0.0130, -0.0087, -0.0501,  0.0054,  0.0481,
          -0.0623, -0.0431,  0.0176, -0.0207,  0.0286, -0.0691, -0.0830,  0.0288,
           0.0187,  0.0634, -0.0264,  0.0116, -0.0422, -0.0469,  0.0240,  0.0720,
          -0.0488,  0.0040, -0.0049, -0.0493, -0.0150, -0.0623, -0.0634, -0.0082,
          -0.0159, -0.0663, -0.0147,  0.0142,  0.0322,  0.1242, -0.0502,  0.0075,
          -0.0277, -0.0008,  0.0654, -0.0214,  0.0423, -0.0183,  0.0367,  0.0292,
          -0.0734,  0.0163, -0.0095, -0.0449,  0.0008,  0.0521,  0.0290, -0.0136,
           0.0101, -0.0478, -0.1124,  0.0401,  0.0182, -0.0601,  0.0434,  0.0711,
          -0.0810, -0.0037,  0.0612, -0.0346,  0.0239, -0.0427,  0.0653,  0.04

In [31]:
def detect(cam = 0, thres = 0.7):
    vdo = cv2.VideoCapture(cam)

    # Initialize time for FPS calculation
    prev_time = time.time()
    
    while vdo.grab():
        _, img0 = vdo.retrieve()
        batch_boxes, cropped_images = mtcnn.detect_box(img0)
    
        if cropped_images is not None:
            for box, cropped in zip(batch_boxes, cropped_images):
                x, y, x2, y2 = [int(x) for x in box]
                if len(cropped.shape) == 3:
                    cropped = cropped.unsqueeze(0)
                img_embedding = encode(cropped)
                
                detect_dict= {}
                for k, v in all_people_faces.items():
                    detect_dict[k] = (v - img_embedding).norm().item()

                min_key = min(detect_dict, key = detect_dict.get)
                if detect_dict[min_key] >= thres:
                    min_key = 'Undetected'
    
                cv2.rectangle(img0, (x, y), (x2, y2), (0,0,255), 2)
                cv2.putText(
                    img0, min_key, (x+5, y+10),
                    cv2.FONT_HERSHEY_DUPLEX, 0.5, (255,255,255), 1)
    
        # ==== FPS Calculation and Display ====
        curr_time = time.time()
        fps = 1 / (curr_time - prev_time)
        prev_time = curr_time

        # Show FPS on the top-left of the screen
        cv2.putText(
            img0, f"FPS: {fps:.2f}", (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2
        )
        # display
        cv2.imshow('output', img0)
        if cv2.waitKey(1) == ord('q'):
            cv2.destroyAllWindows()
            break
if __name__ == '__main__':
    detect(0)