In [None]:
import functools

import insightface
from insightface.app.common import Face
from insightface.model_zoo import model_zoo
from neural_nets import MMFaceFE, MMFaceClassifier, insightface_model, InsightFaceClassifier
# REQUIRED FOR CUDA TO BE USED
import torch

@functools.lru_cache(maxsize=None)  # one entry per distinct `device` value
def _load_recognition_models(device):
    """Load and prepare the detector, recogniser and radar extractor once per device."""
    # insightface's model_zoo `prepare` switches the ONNX session to the CPU
    # provider when ctx_id < 0; previously ctx_id was always 0, so ONNX stayed
    # on the GPU even for device="cpu".
    ctx_id = 0 if torch.device(device).type == "cuda" else -1

    det_model = model_zoo.get_model("../models/buffalo_l/det_10g.onnx")
    # The detector reads its threshold from the `det_thresh` keyword; the
    # previous `det_thres` spelling was silently ignored.
    # NOTE(review): insightface interprets input_size as (width, height);
    # confirm (480, 640) matches the camera orientation.
    det_model.prepare(ctx_id=ctx_id, input_size=(480, 640), det_thresh=0.5)

    rec_model = model_zoo.get_model("../models/buffalo_l/w600k_r50.onnx")
    rec_model.prepare(ctx_id=ctx_id)

    # NOTE(review): no checkpoint is loaded here; unless MMFaceFE's constructor
    # loads weights itself, this extractor runs with its initial weights.
    mmface_model = MMFaceFE().to(device)
    mmface_model.eval()
    return det_model, rec_model, mmface_model


def recognise(rgb_input, radar_input, device="cuda"):
    """Compute an RGB face embedding and a radar embedding.

    Models are loaded on first use for each `device` and reused afterwards,
    instead of being reloaded from disk on every call.

    Args:
        rgb_input: RGB image. Its last axis is reversed (RGB -> BGR) before it
            is passed to `insightface_model`. Presumably a NumPy (H, W, 3)
            array, since negative-step slicing fails on torch tensors
            (TODO confirm).
        radar_input: Input for the MMFaceFE extractor. Assumed to already be a
            tensor on `device` (TODO confirm).
        device: Torch device spec for the radar extractor. A non-CUDA device
            also runs the ONNX models on the CPU.

    Returns:
        Tuple `(rgb_emb, radar_emb)`: the outputs of `insightface_model` and
        of the radar extractor.
    """
    det_model, rec_model, mmface_model = _load_recognition_models(device)

    with torch.no_grad():
        rgb_emb = insightface_model(rgb_input[..., ::-1], det_model, Face, rec_model)
        radar_emb = mmface_model(radar_input)

    return rgb_emb, radar_emb


@functools.lru_cache(maxsize=None)  # one entry per (num_subjects, device) pair
def _load_classifiers(num_subjects, device):
    """Build the subject and liveness classifiers once per configuration."""
    # NOTE(review): no checkpoint is loaded here; unless the constructors load
    # weights themselves, these heads run with their initial weights.
    insightface_classifier = InsightFaceClassifier(num_subjects).to(device)
    insightface_classifier.eval()

    mmface_classifier = MMFaceClassifier().to(device)
    mmface_classifier.eval()
    return insightface_classifier, mmface_classifier


def classify(rgb_emb, radar_emb, num_subjects=21, device="cuda"):
    """Run the subject classifier and the liveness classifier on the embeddings.

    Classifiers are built once per `(num_subjects, device)` and reused,
    instead of being re-instantiated on every call.

    Args:
        rgb_emb: Embedding passed to InsightFaceClassifier. Presumably the
            `rgb_emb` returned by `recognise`, which may need converting to a
            tensor on `device` first (TODO confirm).
        radar_emb: Embedding passed to MMFaceClassifier.
        num_subjects: Passed to the InsightFaceClassifier constructor
            (presumably the number of identity classes).
        device: Torch device the classifiers are moved to.

    Returns:
        Tuple `(subject, liveness)`: the raw outputs of the two classifiers.
    """
    subject_head, liveness_head = _load_classifiers(num_subjects, device)

    with torch.no_grad():
        subject = subject_head(rgb_emb)
        liveness = liveness_head(radar_emb)

    return subject, liveness