In [None]:
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..", "Scripts")))

import cv2
import torch
from PIL import Image
import numpy as np
import time
import matplotlib.pyplot as plt
from PIL import Image

import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
from torchvision import transforms

from config import Config
from models import AutoEncoder
from utils import load_model
from dataset import get_dataloaders

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# --- Configuration, model and preprocessing setup ---
config = Config()

# In-memory enrollment store: user ID -> {"id_vectors": recorded embeddings}.
# Re-running this cell resets it to empty.
vectorDB = dict()

model = AutoEncoder(config.latent_dim, config.image_size).to(config.device)

# map_location lets checkpoints saved on one device (e.g. CUDA) load on
# whatever device this session is configured for.
model.encoder.load_state_dict(
    torch.load(config.Encoder_path, map_location=config.device)
)
model.decoder.load_state_dict(
    torch.load(config.Decoder_path, map_location=config.device)
)

# The notebook only runs inference, so put the model in evaluation mode.
model.eval()

# Preprocessing applied to every webcam frame before it is encoded:
# resize to the model's input size, convert to a tensor, then normalize
# each channel with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.Resize((config.image_size, config.image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

In [14]:
def make_entry():
    """Prompt for a user ID, record webcam embeddings, and build a DB entry.

    Returns:
        dict: ``{id_: {"id_vectors": <value returned by get_vector()>}}``,
        suitable for merging into ``vectorDB`` with ``dict.update``.
    """
    id_ = input("Enter your ID")

    print("Recording...")
    id_vectors = get_vector()

    # The count is computed outside the f-string: reusing double quotes inside
    # a double-quoted f-string is a SyntaxError before Python 3.12.
    n_recorded = len(id_vectors)
    print(f"Samples---{n_recorded} running on {config.device}")

    return {id_: {"id_vectors": id_vectors}}

def get_vector(n_samples=50):
    """Record encoder embeddings of webcam frames for enrollment.

    Frames are read from camera 0 and encoded with ``model.encoder`` until
    ``n_samples`` embeddings are collected, the user presses 'q', or a frame
    cannot be read.

    Args:
        n_samples (int): Number of embeddings to collect. Defaults to 50.

    Returns:
        torch.Tensor: The collected embeddings stacked along the first axis
        (empty if no frame could be captured).
    """
    vectors = []
    cap = cv2.VideoCapture(0)

    try:
        while len(vectors) < n_samples:
            ret, frame = cap.read()
            if not ret:
                print("Failed to grab frame")
                break

            # Encode the clean frame BEFORE drawing any overlay, so the status
            # text is not baked into the stored embedding (test() also encodes
            # frames without overlays).
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image = transform(Image.fromarray(frame_rgb)).unsqueeze(0).to(config.device)

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                lat = model.encoder(image)
            # NOTE(review): lat[0] is taken as the embedding for this single
            # frame - confirm against the encoder's return type.
            vectors.append(lat[0].detach().cpu().numpy())

            # Overlays are for the preview window only.
            percentage = (len(vectors) / n_samples) * 100
            cv2.putText(frame, f"Recording: {int(percentage)}%", (10, frame.shape[0] - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)
            cv2.putText(frame, "Recording data...", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
            cv2.imshow('Live Face Authentication', frame)

            # Let the user abort early with 'q'.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Always free the camera and close the window, even if encoding fails.
        cap.release()
        cv2.destroyAllWindows()

    return torch.tensor(np.array(vectors))

def test():
    """Run live recognition on the webcam feed until 'q' is pressed.

    Each frame is encoded with ``model.encoder`` and compared against the
    enrolled entries via ``search_DB``. The best match and its score are
    drawn on the frame when the score is positive; otherwise
    "Who are you?" is shown.
    """
    cap = cv2.VideoCapture(0)

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("Failed to grab frame")
                break

            # Same preprocessing as enrollment: BGR -> RGB -> transform -> batch of 1.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image = transform(Image.fromarray(frame_rgb)).unsqueeze(0).to(config.device)
            with torch.no_grad():
                lat = model.encoder(image)
            embedding_np = lat[0].detach().cpu().numpy()

            id_, val = search_DB(embedding_np)
            if val > 0.:
                text = f"Hello {id_}---- Confidence: {val:.2f}"
            else:
                text = "Who are you?"

            cv2.putText(frame, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
            cv2.imshow("Webcam Feed", frame)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the camera and close the window even if the loop raised.
        cap.release()
        cv2.destroyAllWindows()

def cal_cosine(source_vec, target_vec):
    """Return the mean cosine similarity between two sets of vectors.

    Similarity is computed along dimension 1 and the resulting values are
    averaged into a single scalar tensor.
    """
    per_row_similarity = F.cosine_similarity(source_vec, target_vec, dim=1)
    return per_row_similarity.mean()

def search_DB(ref = None):
    """Find the enrolled ID whose stored vectors are most similar to ``ref``.

    Args:
        ref: Query embedding (array-like accepted by ``torch.tensor``).

    Returns:
        tuple: ``(id_, val)`` where ``id_`` is the best-matching key of
        ``vectorDB`` and ``val`` is its mean cosine similarity as a float.
        Returns ``(None, 0.0)`` when nothing is enrolled.

    Raises:
        ValueError: If ``ref`` is None.
    """
    if ref is None:
        raise ValueError("search_DB requires a reference embedding")

    # Nothing enrolled yet: report "no match" instead of failing in max().
    if not vectorDB:
        return None, 0.0

    query = torch.tensor(ref)
    matched = {
        key: cal_cosine(query, entry["id_vectors"].clone().detach()).item()
        for key, entry in vectorDB.items()
    }

    id_ = max(matched, key=matched.get)
    # Report the score of the winning ID, not of the last key iterated.
    val = matched[id_]

    return id_, val
        

In [9]:
# Enroll one user: make_entry() prompts for an ID and records webcam embeddings.
a = make_entry ()
# Merge the new entry into the in-memory DB; an existing entry with the same ID is replaced.
vectorDB.update(a)

Recording...
Samples---50 running on cuda


In [10]:
# Show the IDs currently enrolled in the in-memory DB.
vectorDB.keys()

dict_keys(['Mahesh', 'mahesh', 'Maryna', 'maryna'])

In [15]:
# Start the live recognition loop; press 'q' in the video window to stop.
test()