In [None]:
import torch
import torch.nn.functional as F
from transformers import DPTFeatureExtractor, DPTForDepthEstimation

from modello_antispoofing import (
    AntiSpoofingModel,
    RGBEncoder,
    DepthEncoder
)

from insightface.app import FaceAnalysis
from ultralytics import YOLO

import os
from dotenv import load_dotenv

import cv2
import numpy as np
from collections import deque
from datetime import datetime
import time
from pathlib import Path
import pickle
from tqdm import tqdm
import requests
import pytz

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Pull variables from a local .env file into the process environment, then
# read the webhook endpoint used by create_presence (None when it is not set).
load_dotenv()
webhook_url = os.environ.get("webhook_url")

# Current date/time helper
def get_current_datetime(timezone='Europe/Rome',as_utc=False):
    """Return the current time formatted as 'YYYY-MM-DD HH:MM:SS'.

    The clock is read in ``timezone``; when ``as_utc`` is true the value is
    converted to UTC before it is formatted.
    """
    moment = datetime.now(pytz.timezone(timezone))
    if as_utc:
        moment = moment.astimezone(pytz.utc)
    return moment.strftime('%Y-%m-%d %H:%M:%S')

# Webhook call that triggers the n8n workflow
def create_presence(name,recognition_score=0.0):
    """Register a check-in by POSTing it to the configured webhook.

    Args:
        name: Identity label in the form "<name>-<employee_id>". The label is
            split on its LAST hyphen, so names that themselves contain
            hyphens are handled.
        recognition_score: Similarity score of the match; sent as a plain float.

    Returns:
        The decoded JSON body of the webhook response on success, or an error
        string starting with "Collegamento a n8n fallito" on any failure.
    """
    try:
        if not webhook_url:
            raise ValueError("webhook_url non configurato")

        person_name, separator, employee_id = name.rpartition("-")
        if not separator or not person_name or not employee_id:
            raise ValueError(f"etichetta non valida: {name!r}")

        datetime_utc = get_current_datetime(as_utc=True)

        data = {
            "name": person_name,
            "employee_id": employee_id,
            "check_in": datetime_utc,
            # numpy scalars are not JSON serializable; force a builtin float.
            "recognition_score": float(recognition_score)
        }

        # Without a timeout an unreachable endpoint would block the caller
        # (the capture loop) indefinitely.
        response = requests.post(
            url=webhook_url,
            json=data,
            timeout=10
        )
        response.raise_for_status()
        print(f"Presenza registrata: {person_name} alle {datetime_utc}")
        return response.json()

    except Exception as e:
        return f"Collegamento a n8n fallito: {e}"

In [None]:
# Face analysis pipeline: "buffalo_l" model pack stored under ~/.insightface,
# executed through the OpenVINO execution provider.
app = FaceAnalysis(name="buffalo_l", root='~/.insightface', providers=['OpenVINOExecutionProvider'])

# Context id 0, 640x640 detector input, detections below 0.6 are dropped.
app.prepare(ctx_id=0, det_size=(640, 640), det_thresh=0.6)

def build_db(train_dir,min_det_score=0.7,embedding_dim=512):
    """Build the face-embedding database from a folder of labelled images.

    Expected layout: ``train_dir/<person_name>/<image files>``. For every image
    the face with the highest detection score is selected through the
    module-level ``app``; its embedding is L2-normalized and stored under the
    person's folder name.

    Args:
        train_dir: Root folder holding one sub-folder per person.
        min_det_score: Images whose best face scores below this are rejected.
        embedding_dim: Expected embedding length; vectors of any other size
            are rejected. Defaults to 512, the size the matcher checks for.

    Returns:
        dict mapping person name -> float32 array of shape
        (n_valid_images, embedding_dim). People with no valid image are omitted.
    """
    db = {}
    stats = {
        "total": 0,
        "valid": 0,
        "rejected": 0
    }

    def _reject(img_path):
        # Single place for the bookkeeping of a discarded image.
        stats["rejected"] += 1
        print(f"Immagine {img_path} scartata")

    # Only directories are people; stray files in train_dir are ignored.
    # Sorting keeps the processing order (and the stacked rows) deterministic.
    person_dirs = sorted(p for p in Path(train_dir).iterdir() if p.is_dir())

    for person_dir in person_dirs:
        embeddings = []
        person_name = person_dir.name
        image_paths = sorted(p for p in person_dir.iterdir() if p.is_file())

        # compute the embeddings of every image of this person
        for img_path in tqdm(image_paths,desc=person_name):
            stats["total"] += 1

            # image loading (cv2.imread expects a string path)
            img = cv2.imread(str(img_path))
            if img is None:
                _reject(img_path)
                continue

            # embedding extraction
            faces = app.get(img)
            if len(faces) == 0:
                _reject(img_path)
                continue

            best_face = max(faces, key=lambda face: getattr(face, "det_score", 0.0))
            if getattr(best_face, "det_score", 0.0) < min_det_score:
                _reject(img_path)
                continue

            raw_embedding = getattr(best_face, "embedding", None)
            if raw_embedding is None:
                _reject(img_path)
                continue

            emb = np.asarray(raw_embedding, dtype=np.float32).ravel()
            if emb.size != embedding_dim:
                _reject(img_path)
                continue

            # L2 normalization; the epsilon avoids a division by zero.
            emb = emb / (np.linalg.norm(emb) + 1e-8)

            stats["valid"] += 1
            embeddings.append(emb)

        if embeddings:
            db[person_name] = np.stack(embeddings,axis=0)
            print(f"{person_name}: {len(embeddings)} immagini valide")
        else:
            print(f"Nessuna immagine valida per {person_name}")
        print("\n")

    print(f"\n Statistiche DB:")
    print(f"   Totale immagini: {stats['total']}")
    print(f"   Valide: {stats['valid']}")
    print(f"   Scartate: {stats['rejected']}")

    return db

# Build the embedding database from the "train" folder, then persist it to
# face_db.pkl so it can be reloaded without recomputing the embeddings.
print("Costruzione db\n")
db = build_db("train", min_det_score=0.7)

with open("face_db.pkl", "wb") as db_file:
    pickle.dump(db, db_file)

In [None]:
def load_model(chekpoint_path,device=None):
    """Build the anti-spoofing network and load its weights from a checkpoint.

    Args:
        chekpoint_path: Path of the checkpoint handed to ``torch.load``.
        device: Target device. When omitted, CUDA is used if available,
            otherwise the CPU.

    Returns:
        The ``AntiSpoofingModel`` instance, moved to ``device`` and switched to
        evaluation mode.

    Raises:
        RuntimeError: if the checkpoint cannot be read or applied to the model.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("Caricamento modello anti-spoofing")
    model = AntiSpoofingModel(RGBEncoder(),DepthEncoder())
    try:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(chekpoint_path,map_location=device)
        load_result = model.load_state_dict(state_dict=state_dict,strict=False)
    except Exception as e:
        # Fail loudly: returning an error string would only postpone the
        # failure to the first inference call.
        raise RuntimeError(f"Errore nel caricamento del modello di anti-spoofing: {e}") from e

    # strict=False tolerates key mismatches, so surface them instead of
    # silently running with partially initialized weights.
    missing = list(getattr(load_result, "missing_keys", []))
    unexpected = list(getattr(load_result, "unexpected_keys", []))
    if missing or unexpected:
        print(f"Attenzione: chiavi mancanti={len(missing)}, chiavi inattese={len(unexpected)}")

    # assumes AntiSpoofingModel is a torch.nn.Module (it exposes load_state_dict) — TODO confirm
    model.to(device)
    model.eval()
    print("Modello anti-spoofing caricato.")
    return model

In [None]:
def db_match_from_embedding(query_embedding,db,top_k=3,embedding_dim=512):
    """Find the database identity most similar to a query embedding.

    Similarity is the dot product between the query and each stored vector.
    For a person stored as a 2-D gallery (one row per image) the score is the
    mean of the ``top_k`` highest similarities; for a single 1-D vector it is
    the plain dot product.

    Args:
        query_embedding: Query vector; it is flattened before use.
        db: Mapping name -> stored embedding(s), either (D,) or (N, D).
        top_k: Number of best gallery similarities averaged per person.
        embedding_dim: Expected embedding length D; entries with another
            size are reported and skipped.

    Returns:
        (best_name, best_score). ``("Sconosciuto", -1.0)`` when no entry
        could be scored.
    """
    best_name = "Sconosciuto"
    best_score = -1.0

    query = np.asarray(query_embedding,dtype=np.float32).ravel()

    for name,emb in db.items():
        arr = np.asarray(emb,dtype=np.float32)

        if arr.ndim == 0:
            print(f"{name} ha un embedding vuoto")
            continue

        if arr.ndim == 1:
            if arr.shape[0] != embedding_dim:
                print(f"{name} ha un embedding di dimensione errata: {arr.shape[0]} anziché {embedding_dim}")
                continue
            similarity = float(np.dot(query,arr))

        elif arr.ndim == 2:
            if arr.shape[1] != embedding_dim:
                print(f"{name} ha un embedding di dimensione errata: {arr.shape[1]} anziché {embedding_dim}")
                continue
            # matrix-vector product: (N, D) @ (D,) -> (N,)
            similarities = arr @ query
            # builtin min: np.min(a, b) would interpret b as the axis argument
            k = min(top_k,len(similarities))
            if k > 0:
                top_indices = np.argpartition(similarities,-k)[-k:]
                similarity = float(np.mean(similarities[top_indices]))
            else:
                similarity = -1.0

        else:
            # Without this branch a stale similarity from the previous
            # person would be compared below.
            print(f"{name} ha un embedding con {arr.ndim} dimensioni: non supportato")
            continue

        if similarity > best_score:
            best_score = similarity
            best_name = name

    return best_name, best_score



def recognize_face_from_frame(frame_bgr,db,threshold=0.8,min_det_score=0.7):
    """Detect the faces in a frame and match each one against the database.

    Args:
        frame_bgr: Frame passed as-is to the module-level ``app.get``.
        db: Embedding database forwarded to ``db_match_from_embedding``.
        threshold: Minimum similarity for a match to be accepted; below it
            the face is reported as "Sconosciuto".
        min_det_score: Faces with a lower detection score are ignored.

    Returns:
        A list with one ``(bbox, label, name, score)`` tuple per accepted
        face, where ``bbox`` is a tuple of ints. The list is empty when no
        face is usable or when detection fails.
    """
    results = []

    try:
        faces = app.get(frame_bgr)
    except Exception as e:
        # Keep the return type stable: callers iterate over the result, so an
        # error string would be unpacked character by character.
        print(f"Errore nella detection del frame: {e}")
        return results

    for f in faces:
        det_score = getattr(f,"det_score",0.0)
        if det_score < min_det_score:
            continue

        raw_embedding = getattr(f,"embedding",None)
        if raw_embedding is None:
            continue

        # L2-normalize so the dot product in the matcher is a cosine similarity
        # (assuming the stored embeddings are normalized too).
        emb = np.asarray(raw_embedding,dtype=np.float32).ravel()
        emb = emb / (np.linalg.norm(emb) + 1e-8)

        matched_name, score = db_match_from_embedding(emb,db)

        name = matched_name if score >= threshold else "Sconosciuto"
        label = f"{name} (sim={score:.2f})"

        # One result per face: this must stay inside the loop.
        bbox = tuple(map(int,f.bbox))
        results.append((bbox,label,name,score))

    return results


In [None]:
# Prefer the GPU when CUDA is available, otherwise run on the CPU.
_backend = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_backend)
print(f"device: {device}")

def estimate_depth_map(rgb_frame,feature_extractor,depth_model):
    """Estimate a min-max normalized depth map for one image.

    Args:
        rgb_frame: Despite the name, a BGR image array as produced by OpenCV
            (it is converted to RGB here). A file path is accepted as well
            and is read with ``cv2.imread``.
        feature_extractor: Callable preparing the model inputs; invoked with
            ``images=...`` and ``return_tensors="pt"``.
        depth_model: Model whose output exposes ``predicted_depth``.

    Returns:
        float32 array with values scaled to [0, 1]; all zeros when the
        prediction is constant.

    Raises:
        ValueError: if a path is given and the image cannot be read.
    """
    if isinstance(rgb_frame,(str,os.PathLike)):
        img_bgr = cv2.imread(os.fspath(rgb_frame))
        if img_bgr is None:
            raise ValueError(f"Impossibile leggere l'immagine: {rgb_frame}")
    else:
        img_bgr = rgb_frame

    # Colour conversion of an in-memory frame is cvtColor, not imread.
    img_rgb = cv2.cvtColor(img_bgr,cv2.COLOR_BGR2RGB)

    # `device` is the module-level torch device.
    inputs = feature_extractor(
        images=img_rgb,
        return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        outputs = depth_model(**inputs)
        predicted_depth = outputs.predicted_depth.squeeze().cpu().numpy()

    # Min-max normalization; the guard avoids dividing by zero on flat maps.
    depth = predicted_depth - predicted_depth.min()
    max_depth = depth.max()
    if max_depth > 0:
        depth /= max_depth

    return depth.astype(np.float32)
    

def preprocess_for_antispoofing(rgb_frame, depth_frame):
    """Convert a BGR frame and its depth map into 224x224 model inputs.

    The colour image is converted to RGB, resized, scaled to [0, 1] and
    standardized per channel, then laid out channels-first with a leading
    batch axis. The depth map is resized, clipped to [0, 1] and given two
    leading singleton axes (assumes a 2-D depth map — TODO confirm).

    Returns:
        (rgb_tensor, depth_tensor), both float32 torch tensors.
    """
    side = (224, 224)
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])

    # Colour branch
    color = cv2.resize(cv2.cvtColor(rgb_frame, cv2.COLOR_BGR2RGB), side)
    color = (color.astype(np.float32) / 255.0 - channel_mean) / channel_std
    rgb_tensor = torch.from_numpy(np.transpose(color, (2, 0, 1))).unsqueeze(0)

    # Depth branch
    depth_map = np.clip(cv2.resize(depth_frame, side), 0, 1).astype(np.float32)
    depth_tensor = torch.from_numpy(depth_map)[None, None]

    return rgb_tensor.float(), depth_tensor.float()

In [None]:
# 1. Anti-spoofing model
antispoofing_model_path = os.getenv("checkpoint_path")
antispoofind_model = load_model(antispoofing_model_path, device)
# Fail fast instead of discovering an unusable model on the first frame.
if antispoofind_model is None or isinstance(antispoofind_model, str):
    raise RuntimeError(f"Modello anti-spoofing non disponibile: {antispoofind_model}")

# 2. Depth-map estimation model
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
depth_model.eval()

# 3. YOLO (phone detector)
yolo_stride = 1
yolo_weights = os.getenv("yolo_weights")
yolo = YOLO(yolo_weights)
model_names = yolo.model.names
phone_id = next((i for i,n in model_names.items() if n.lower() == "cell phone"), None)
if phone_id is None:
    raise ValueError("La classe 'cell phone' non è presente nel modello YOLO")

# 4. Embedding database
# NOTE: pickle.load executes arbitrary code; only load a face_db.pkl you produced.
with open("face_db.pkl","rb") as f:
    db = pickle.load(f)

print(f"Database caricato: {len(db)} persone\n")

# Configuration
ANTISPOOFING_THRESHOLD = 0.7
ANTISPOOFING_STRIDE = 1
FACE_RECOGNITION_THRESHOLD = 0.8
PHONE_CONF_THRESHOLD = 0.7
WINDOW_NAME = "Sistema presenze"

phone_lock_timestamp = 0
last_phone_dets = []
last_antispoofing_result = (False, 0.0)
# True once at least one anti-spoofing inference has succeeded, so a cached
# result is never trusted before it exists.
antispoofing_ready = False


cap = cv2.VideoCapture(0)
if not cap.isOpened():
    raise IOError("Non si accende la webcam")

cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*"MJPG"))

frame_idx = 0
device_lock = False
antispoofing_buffer = deque(maxlen=5)


print("\n" + "="*70)
print(" SISTEMA ATTIVO - MODALITÀ FRAME INTERO")
print("="*70)
print("Comandi:")
print("  'q' = Esci")
print("  'r' =   Reset lock telefono (UNICO MODO PER SBLOCCARE)")
print(f" Soglia AS: {ANTISPOOFING_THRESHOLD:.2f}")
print(f" Soglia telefono: {PHONE_CONF_THRESHOLD:.2f}")
print(f"  Lock: MANUALE ONLY (nessun timeout automatico)")

# try/finally guarantees the webcam and the window are released even when a
# frame raises.
try:
    while True:
        ret,frame = cap.read()
        if not ret:
            print("Errore nella lettura del frame dalla webcam")
            break

        frame = cv2.resize(frame,(0,0),fx=1.5,fy=1.5)
        vis = frame.copy()
        depth_colored = None

        ### YOLO detection
        if frame_idx % yolo_stride == 0 and not device_lock:
            try:
                # The call returns a list of results (one per input image).
                res = yolo(frame,conf=PHONE_CONF_THRESHOLD,classes=[phone_id])
                phone_dets_current = []
                for r in res:
                    for b in r.boxes:
                        if int(b.cls[0]) == phone_id:
                            x1,y1,x2,y2 = map(int,b.xyxy[0].tolist())
                            phone_dets_current.append((
                                (x1,y1,x2,y2),
                                float(b.conf[0])
                            ))
                if phone_dets_current:
                    last_phone_dets = phone_dets_current
                    device_lock = True
                    phone_lock_timestamp = time.time()
                    print(f"Device lock attivo, telefono rilevato")
                    for (x1,y1,x2,y2),conf in phone_dets_current:
                        print(f"Bbox: ({x1},{y1})-({x2},{y2})) , conf={conf:.3f}")

            except Exception as e:
                # No `continue` here: skipping the rest of the iteration would
                # also skip key handling and the frame counter.
                print(f"Errore nella rivelazione di cellulari: {e}")


        ### Anti-spoofing on the whole frame
        frame_is_fake = False
        prob_fake_smooth_global = 0.0
        antispoofing_ok = False  # liveness actually verified for this frame?

        if not device_lock and (frame_idx % ANTISPOOFING_STRIDE == 0):
            try:
                depth_map_full = estimate_depth_map(frame,feature_extractor,depth_model)
                rgb_t,depth_t = preprocess_for_antispoofing(frame,depth_map_full)
                rgb_t, depth_t = rgb_t.to(device), depth_t.to(device)

                with torch.no_grad():
                    logits = antispoofind_model(rgb_t,depth_t)
                    probs = F.softmax(logits,dim=1)
                    # assumes class index 1 is "fake" — TODO confirm against training labels
                    prob_fake = probs[0,1].item()
                    antispoofing_buffer.append(prob_fake)

                # Smooth over the last frames held by the buffer.
                prob_fake_smooth_global = float(np.mean(antispoofing_buffer))
                frame_is_fake = prob_fake_smooth_global > ANTISPOOFING_THRESHOLD
                last_antispoofing_result = (frame_is_fake, prob_fake_smooth_global)
                antispoofing_ready = True
                antispoofing_ok = True

                # applyColorMap needs an 8-bit image and an explicit colormap.
                depth_u8 = (np.clip(depth_map_full,0,1) * 255).astype(np.uint8)
                depth_colored = cv2.applyColorMap(depth_u8,cv2.COLORMAP_JET)
                depth_colored = cv2.resize(depth_colored,(200,200))
            except Exception as e:
                print(f"Errore nel flusso anti-spoofing: {e}")
        else:
            frame_is_fake,prob_fake_smooth_global = last_antispoofing_result
            antispoofing_ok = antispoofing_ready

        ### Face recognition and visualization
        for (x1,y1,x2,y2),conf in last_phone_dets:
            cv2.rectangle(vis,(x1,y1),(x2,y2),(0,0,255),3)
            cv2.putText(vis,f"Telefono rilevato: {conf:.2f}",(x1,y1-10),cv2.FONT_HERSHEY_COMPLEX,0.7,(0,0,255),2)

        if frame_is_fake:
            cv2.rectangle(vis, (5, 5), (vis.shape[1]-5, 80), (0, 0, 255), -1)
            cv2.putText(vis, f" Frame falso ({prob_fake_smooth_global:.2f})",(15, 45), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)

        # Best recognised, liveness-verified face of this frame (name, score).
        presence_candidate = None

        if device_lock:
            elapsed = time.time() - phone_lock_timestamp

            cv2.rectangle(vis, (5, 90), (vis.shape[1]-5, 150), (0, 0, 200), -1)
            cv2.putText(vis, f"  DISPOSITIVO RILEVATO - FR BLOCCATO",
                       (15, 120), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
            cv2.putText(vis, f" Lock attivo da {elapsed:.1f}s | Premi 'r' per sbloccare",
                       (15, 145), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        else:
            fr_results = recognize_face_from_frame(frame,db,FACE_RECOGNITION_THRESHOLD)
            if isinstance(fr_results, str):
                # Defensive: an error message instead of a result list.
                print(fr_results)
                fr_results = []

            for (x1,y1,x2,y2), label, name, score in fr_results:
                if name != "Sconosciuto":
                    if frame_is_fake:
                        label = f"Fake {name} ({prob_fake_smooth_global:.2f})"
                        color = (0,0,255)
                    elif not antispoofing_ok:
                        # Fail closed: no presence without a liveness verdict.
                        label = f"{name} (non verificato)"
                        color = (0,165,255)
                    else:
                        label = f"{name} ({1-prob_fake_smooth_global:.2f})"
                        color = (0,255,0)
                        if presence_candidate is None or score > presence_candidate[1]:
                            presence_candidate = (name, score)
                else:
                    color = (0,165,255)
                    label = f"Sconosciuto ({score:.2f})"

                cv2.rectangle(vis, (x1, y1), (x2, y2), color, 2)
                cv2.putText(vis, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

        # Depth preview in the bottom-right corner, when it fits.
        if depth_colored is not None:
            dh, dw = depth_colored.shape[:2]
            if vis.shape[0] >= dh and vis.shape[1] >= dw:
                vis[-dh:, -dw:] = depth_colored

        # Without a window cv2.waitKey never receives key presses.
        cv2.imshow(WINDOW_NAME, vis)

        # Webhook call towards n8n: only for a recognised, live person.
        if presence_candidate is not None:
            outcome = create_presence(name=presence_candidate[0],recognition_score=presence_candidate[1])
            if isinstance(outcome, str) and outcome.startswith("Collegamento a n8n fallito"):
                print(outcome)
            # NOTE(review): the loop stops after the first registration
            # attempt, as in the original flow — confirm this is intended.
            break

        # Keyboard commands
        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            break
        elif key == ord('r'):
            if device_lock:
                device_lock = False
                last_phone_dets = []
                print(" Lock resettato manualmente")
            else:
                print(" Lock già disattivato")

        frame_idx += 1
finally:
    cap.release()
    cv2.destroyAllWindows()

print("\n  Sistema terminato")