In [None]:
import cv2, numpy as np, torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1

# Choose the compute device for the embedding network: Apple MPS first,
# then CUDA, falling back to CPU.
device = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print("Device:", device)

# The MTCNN face detector is kept on CPU to avoid MPS adaptive-pooling issues.
# It crops each detected face and resizes it to exactly 160x160, the input
# size fed to FaceNet's InceptionResnetV1.
mtcnn = MTCNN(image_size=160, margin=20, keep_all=True, device="cpu")

# The embedding network (VGGFace2 weights) runs on GPU/MPS when available for
# faster embeddings, in eval mode.
resnet = InceptionResnetV1(pretrained="vggface2")
resnet = resnet.eval().to(device)

# Cosine-similarity threshold: scores strictly above it are labelled "Matched".
THRESH = 0.7
# Reference face embedding; stays None until the capture cell fills it in.
ref_emb = None

In [None]:
def safe_close_windows(n=8):
    """Destroy every OpenCV HighGUI window, then pump the GUI event loop.

    After ``cv2.destroyAllWindows()`` this calls ``cv2.waitKey(1)`` ``n``
    times, presumably so the GUI backend gets a chance to process the
    pending close events (windows can otherwise linger, e.g. in notebooks).
    """
    cv2.destroyAllWindows()
    for _pump in range(n):
        cv2.waitKey(1)

def cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the cosine similarity of two embeddings as a Python float.

    A 1-D input is treated as a single row vector. Both inputs are
    L2-normalized along dim 1 and their row-wise dot product is taken.
    Because the result is read with ``.item()``, it must contain exactly
    one element, i.e. this compares one pair of vectors.
    """
    rows_a = a.unsqueeze(0) if a.dim() == 1 else a
    rows_b = b.unsqueeze(0) if b.dim() == 1 else b

    unit_a = F.normalize(rows_a, p=2, dim=1)
    unit_b = F.normalize(rows_b, p=2, dim=1)

    dot = torch.sum(unit_a * unit_b, dim=1)
    return float(dot.item())

def detect_faces_mtcnn(rgb: np.ndarray):
    """Run the global MTCNN detector on an RGB frame.

    Returns the ``(boxes, probs)`` pair produced by ``mtcnn.detect``.
    Callers treat ``boxes`` as None when nothing was found and otherwise
    index it as rows of ``[x1, y1, x2, y2]``.
    """
    detected_boxes, confidences = mtcnn.detect(rgb)
    return detected_boxes, confidences

def crop_faces_mtcnn(rgb: np.ndarray):
    """Return MTCNN's face crops for an RGB frame (its forward pass).

    Callers treat a None result as "no face" and otherwise index it along
    dim 0, one crop per face (the detector is built with keep_all=True).
    NOTE(review): this presumably re-runs detection internally, separately
    from detect_faces_mtcnn; callers should not assume a 1:1 pairing.
    """
    return mtcnn(rgb)

@torch.no_grad()
def embed_faces(face_tensor: torch.Tensor):
    """Compute FaceNet embeddings for one face crop or a batch of crops.

    A 3-D input is treated as a single crop and given a leading batch
    dimension. The batch is moved to ``device`` and run through ``resnet``
    with autograd disabled. Returns None when ``face_tensor`` is None.
    """
    if face_tensor is not None:
        batch = face_tensor.unsqueeze(0) if face_tensor.dim() == 3 else face_tensor
        return resnet(batch.to(device))  # (N, 512)
    return None

def label_from_similarity(sim: float, thresh=THRESH):
    """Turn a similarity score into a display label.

    Returns "Matched" only when ``sim`` is strictly greater than ``thresh``
    (default: THRESH as it was when this function was defined); any other
    value, including an exact tie, yields "Unknown".
    """
    if sim > thresh:
        return "Matched"
    return "Unknown"

def draw_boxes(frame: np.ndarray, boxes, labels=None, probs=None):
    """Draw detection boxes and optional captions onto ``frame`` in place.

    Args:
        frame: image drawn on with OpenCV (callers pass the BGR webcam
            frame or a copy of it); it is modified in place and returned.
        boxes: iterable of ``[x1, y1, x2, y2]`` boxes (cast to int), or None
            to draw nothing.
        labels: optional per-box ``(name, similarity)`` pairs. A box whose
            name is "Matched" is drawn green, any other name red.
        probs: optional per-box detection confidences, prefixed to the caption.

    Returns:
        The same ``frame`` object.
    """
    if boxes is None:
        return frame

    match_color = (0, 255, 0)  # green
    miss_color = (0, 0, 255)   # red in OpenCV's BGR channel order

    for i, box in enumerate(boxes):
        x1, y1, x2, y2 = (int(coord) for coord in box)

        caption_parts = []
        color = match_color
        if probs is not None:
            caption_parts.append(f"{probs[i]:.2f} ")
        if labels is not None:
            name, sim = labels[i]
            caption_parts.append(f"{name} sim={sim:.2f}")
            color = match_color if name == "Matched" else miss_color
        caption = "".join(caption_parts)

        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
        if caption:
            # Baseline at y >= 20 so captions of boxes near the top edge stay visible.
            text_origin = (x1, max(20, y1 - 10))
            cv2.putText(frame, caption, text_origin,
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
    return frame


In [None]:
# Capture a reference face embedding from the webcam.
# Press 'c' to capture the largest visible face, 'q' or Esc to quit.
ref_emb = None
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    cap.release()
    raise RuntimeError("Không mở được webcam")

print("Nhấn 'c' để capture reference, 'q' để thoát.")
win = "Capture Reference"
cv2.namedWindow(win, cv2.WINDOW_NORMAL)

# try/finally so the camera and window are released even if detection or
# embedding raises mid-loop; otherwise the webcam stays held by the kernel.
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        boxes, probs = detect_faces_mtcnn(rgb)

        # Preview only: draw on a copy so `frame` itself stays untouched.
        preview = frame.copy()
        draw_boxes(preview, boxes, probs=probs)
        cv2.imshow(win, preview)

        # Keep only the low byte of the key code; some platforms set extra bits.
        key = cv2.waitKey(30) & 0xFF
        # Quit on 'q'/Esc or when the user closed the window. The visibility
        # check follows waitKey so the window-close event has been processed
        # (imshow would otherwise re-create the closed window).
        if key in (ord('q'), 27) or cv2.getWindowProperty(win, cv2.WND_PROP_VISIBLE) < 1:
            break

        if key == ord('c'):
            if boxes is None:
                print("Chưa thấy mặt rõ. Thử lại.")
                continue

            # Pick the largest detected face.
            areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
            idx = int(np.argmax(areas))

            # crop_faces_mtcnn runs its own detection pass; make sure it
            # returned enough crops for `idx` to be a valid index.
            faces = crop_faces_mtcnn(rgb)
            if faces is None or idx >= len(faces):
                print("Không crop được mặt. Thử lại.")
                continue

            ref_emb = embed_faces(faces[idx]).squeeze(0)  # (512,)
            print("Đã capture reference embedding.")
            break
finally:
    cap.release()
    safe_close_windows()


In [None]:
# Real-time recognition: compare every detected face with the reference
# embedding and label it "Matched" / "Unknown". Press 'q' or Esc to quit.
if ref_emb is None:
    raise RuntimeError("Chưa có reference embedding.")

cap = cv2.VideoCapture(0)
if not cap.isOpened():
    cap.release()
    raise RuntimeError("Không mở được webcam")

win = "Realtime Recognition"
cv2.namedWindow(win, cv2.WINDOW_NORMAL)

k = 3  # run detection + embedding only on every k-th frame; reuse results in between
frame_id = 0
last_boxes, last_labels = None, None

# try/finally so the camera and window are released even if a frame raises.
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        if frame_id % k == 0:
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            boxes, probs = detect_faces_mtcnn(rgb)
            faces = crop_faces_mtcnn(rgb)

            # Clear previous results so faces that left the frame are not redrawn.
            last_boxes, last_labels = None, None
            if boxes is not None and faces is not None:
                # Boxes and crops come from separate calls; pair up only as
                # many as both returned.
                n = min(len(boxes), faces.shape[0])
                embs = embed_faces(faces[:n])  # (n, 512)
                labels = []
                for i in range(n):
                    sim = cosine_similarity(ref_emb, embs[i])
                    labels.append((label_from_similarity(sim), sim))
                last_boxes, last_labels = boxes[:n], labels

        out = frame.copy()
        if last_boxes is not None:
            draw_boxes(out, last_boxes, labels=last_labels)
        cv2.imshow(win, out)

        # Keep only the low byte of the key code; some platforms set extra bits.
        key = cv2.waitKey(30) & 0xFF
        # The visibility check must come after waitKey: checked right after
        # imshow it always sees a window, because imshow re-creates a window
        # the user has just closed.
        if key in (ord('q'), 27) or cv2.getWindowProperty(win, cv2.WND_PROP_VISIBLE) < 1:
            break

        frame_id += 1
finally:
    cap.release()
    safe_close_windows()
