In [1]:
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..", "Scripts")))
from config import Config
import cv2
import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
from PIL import Image
import numpy as np
import time
import pickle
Load_DB = False


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
config = Config()

vectorDB = dict()
if Load_DB:
    with open('./vectorDb.pkl','rb') as f:vectorDB =  pickle.load(f)

mtcnn = MTCNN(image_size=160, margin=20, device=config.device)
resnet = InceptionResnetV1(pretrained='vggface2').eval().to(config.device)

In [3]:
def make_entry (vectorDB,mtcnn,resnet,device=config.device): 
    id_ = input("Enter your ID")

    a = {id_: {
                "id_vectors": []
                }
            }
    
    print("Recording...") 
    a[id_]["id_vectors"] =  get_vector(mtcnn,resnet,device)
    print(f"Samples---{len(a[id_]["id_vectors"] )} running on {device}")
    vectorDB.update(a)


def get_vector(mtcnn,resnet,device):
    vectors = []
    cap = cv2.VideoCapture(0)

    n_samples = 50

    while True:
        ret, frame = cap.read()
        if not ret:
            print("Failed to grab frame")
            break

        percentage = (len(vectors)/n_samples)*100
        cv2.putText(frame, f"Recording: {int(percentage)}%", (10, frame.shape[0] - 10),cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)
        
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(frame_rgb)
        face_tensor = mtcnn(img)

        
        boxes, _ = mtcnn.detect(img)
        if boxes is not None:
            for box in boxes:
                box = [int(b) for b in box]
                cv2.rectangle(frame, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)

        
        if face_tensor is not None:
            face_tensor = face_tensor.unsqueeze(0).to(device)
            with torch.no_grad():
                embedding = resnet(face_tensor)
            
            embedding_np = embedding.cpu().numpy().flatten()

            cv2.putText(frame, f"Recording data...", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
            vectors.append(embedding_np)
        if mtcnn(img) is None:
            cv2.putText(frame, "No face detected", (10, 30),cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
        cv2.imshow('Live Face Authentication', frame)
        
        
        if percentage >= 98 or cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()
    return torch.tensor(np.array(vectors))
        
def cal_cosine(source_vec, target_vec):
    cosine_sim = F.cosine_similarity(source_vec, target_vec, dim=1)
    return torch.mean(cosine_sim)

def search_DB(ref = None):
    matched = dict()
    for k in vectorDB.keys():
        target = vectorDB[k]["id_vectors"]
        sim = cal_cosine(torch.tensor(ref), target.clone().detach()).item()
        matched.update({k:sim})
    id_ = max(matched, key=matched.get)
    val = matched[k]

    return id_,val  
        
def test():    
    cap = cv2.VideoCapture(0)
    while True:
        ret, frame = cap.read()
        if not ret:
            print("Failed to grab frame")
            break
        
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(frame_rgb)
        face_tensor = mtcnn(img)
        if face_tensor is not None:
            face_tensor = face_tensor.unsqueeze(0).to(config.device)
            with torch.no_grad():
                embedding = resnet(face_tensor)
        
            embedding_np = embedding.cpu().detach().flatten()
            id_,val = search_DB(embedding_np)
            
            if val>0.15:text = f"Hello {id_}... Confidence: {val:.2f}"
            else: text = "Who are You?"
            
            cv2.putText(frame, text, (10, 30),cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)

        
        cv2.imshow("Webcam Feed", frame)
        
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()

In [6]:
make_entry (vectorDB,mtcnn,resnet)

Recording...
Samples---50 running on cuda


In [None]:
vectorDB.keys()

In [8]:
test()

  sim = cal_cosine(torch.tensor(ref), target.clone().detach()).item()
