In [None]:
# pip install torch facenet_pytorch tqdm

In [None]:
import os
import cv2
import torch
from facenet_pytorch import InceptionResnetV1, MTCNN
from tqdm import tqdm
from types import MethodType

In [None]:
def encode(img):
    res = resnet(torch.Tensor(img))
    return res

def detect_box(self, img, save_path=None):
    # Detect faces
    batch_boxes, batch_probs, batch_points = self.detect(img, landmarks=True)
    # Select faces
    if not self.keep_all:
        batch_boxes, batch_probs, batch_points = self.select_boxes(
            batch_boxes, batch_probs, batch_points, img, method=self.selection_method
        )
    # Extract faces
    faces = self.extract(img, batch_boxes, save_path)
    return batch_boxes, faces


### load model
resnet = InceptionResnetV1(pretrained='vggface2').eval()
mtcnn = MTCNN(
  image_size=224, keep_all=True, thresholds=[0.4, 0.5, 0.5], min_face_size=60
)
mtcnn.detect_box = MethodType(detect_box, mtcnn)

In [None]:
# Assuming you have already defined the 'encode' function and 'mtcnn' model

saved_pictures = "./dataset/"
all_people_faces = {}

for filename in os.listdir(saved_pictures):
    if filename.endswith('.jpg'):
        person_face, extension = os.path.splitext(filename)
        img = cv2.imread(os.path.join(saved_pictures, filename))
        cropped = mtcnn(img)
        
        if cropped is not None:
            all_people_faces[person_face] = encode(cropped)[0, :]


In [None]:
def detect(cam=0, thres=0.7):
    vdo = cv2.VideoCapture(cam)
    while vdo.grab():
        _, img0 = vdo.retrieve()
        batch_boxes, cropped_images = mtcnn.detect_box(img0)

        if cropped_images is not None:
            for box, cropped in zip(batch_boxes, cropped_images):
                x, y, x2, y2 = [int(x) for x in box]
                img_embedding = encode(cropped.unsqueeze(0))
                detect_dict = {}
                for k, v in all_people_faces.items():
                    detect_dict[k] = (v - img_embedding).norm().item()
                min_key = min(detect_dict, key=detect_dict.get)

                if detect_dict[min_key] >= thres:
                    min_key = 'Undetected'
                
                cv2.rectangle(img0, (x, y), (x2, y2), (0, 0, 255), 2)
                cv2.putText(
                  img0, min_key, (x + 5, y + 10), 
                   cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255), 1)
                
        ### display
        cv2.imshow("output", img0)
        if cv2.waitKey(1) == ord('q'):
            cv2.destroyAllWindows()
            break

if __name__ == "__main__":
    detect(0)

In [2]:
import cv2
import os
import torch
from facenet_pytorch import InceptionResnetV1
from facenet_pytorch import MTCNN

def initialize_mtcnn():
    mtcnn = MTCNN(image_size=224, keep_all=True, thresholds=[0.4, 0.5, 0.5], min_face_size=60)
    return mtcnn

def encode(img, resnet_model):
    # Ensure that img is correctly formatted (RGB with 3 channels) and resize if needed
    if img.shape[-1] != 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    if img.shape[:2] != (160, 160):
        img = cv2.resize(img, (160, 160))
    return resnet_model(torch.Tensor(img))

def detect_faces_and_recognize(cam=0, thres=0.7, saved_pictures="./dataset/"):
    resnet = InceptionResnetV1(pretrained='vggface2').eval()
    mtcnn = initialize_mtcnn()

    all_people_faces = {}
    for filename in os.listdir(saved_pictures):
        if filename.endswith('.jpg'):
            person_face, extension = os.path.splitext(filename)
            img = cv2.imread(os.path.join(saved_pictures, filename))

            img = encode(img, resnet)  # Ensure that saved images are correctly formatted
            all_people_faces[person_face] = img[0, :]

    vdo = cv2.VideoCapture(cam)
    while vdo.isOpened():
        ret, img0 = vdo.read()
        if not ret:
            break
        batch_boxes, _ = mtcnn.detect(img0)

        if batch_boxes is not None:
            for box in batch_boxes:
                x, y, x2, y2 = [int(x) for x in box]
                cropped = img0[y:y2, x:x2]
                img_embedding = encode(cropped, resnet)
                detect_dict = {}
                for k, v in all_people_faces.items():
                    detect_dict[k] = (v - img_embedding).norm().item()
                min_key = min(detect_dict, key=detect_dict.get)

                if detect_dict[min_key] >= thres:
                    min_key = 'Undetected'

                cv2.rectangle(img0, (x, y), (x2, y2), (0, 0, 255), 2)
                cv2.putText(img0, min_key, (x + 5, y + 10), cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255), 1)

        cv2.imshow("output", img0)
        if cv2.waitKey(1) == ord('q'):
            cv2.destroyAllWindows()
            break

if __name__ == "__main__":
    detect_faces_and_recognize(0)


RuntimeError: Given groups=1, weight of size [32, 3, 3, 3], expected input[1, 160, 160, 3] to have 3 channels, but got 160 channels instead