# mmFace + InsightFace2D Hybrid Model

## Build and Load Dataset

In [None]:
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler
from utils import by_experiment, get_crd_data
from glob import glob
from tqdm import tqdm
import torch
import json
import os
import numpy as np

# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Extract 10 ARD frames and match them with the rgb_emb from Experiments 0, 5 and 10
#   - Or duplicate the RGB embedding across multiple ARD frames so that 75 can be extracted

class HybridDataset(Dataset):
    """Index-aligned view over radar frames, RGB embeddings and their two label sets.

    Item ``i`` is the tuple
    ``(radar_data[i], rgb_embs[i], subject_labels[i], liveness_labels[i])``.
    The dataset length is taken from ``radar_data``; ``split`` is only stored.
    """

    def __init__(self, radar_data, rgb_embs, subject_labels, liveness_labels, split):
        # Inputs
        self.radar_data, self.rgb_embs = radar_data, rgb_embs
        # Targets
        self.subject_labels, self.liveness_labels = subject_labels, liveness_labels
        # Name of the partition this dataset represents (kept for reference only)
        self.split = split

    def __len__(self):
        return len(self.radar_data)

    def __getitem__(self, idx):
        sources = (self.radar_data, self.rgb_embs, self.subject_labels, self.liveness_labels)
        return tuple(source[idx] for source in sources)


def get_ard(path, subject, experiment, num_frames):
    """Load one experiment's radar recording and return its frames in random order.

    Reads ``{path}/{subject}-{experiment}_radar.json``, passes the parsed JSON to
    ``get_crd_data`` (16 chirps per burst), takes the magnitude, keeps at most the
    first ``num_frames`` entries as float32 and shuffles them along the first axis
    using NumPy's global random state.
    """
    radar_file = f"{path}/{subject}-{experiment}_radar.json"
    with open(radar_file, 'r') as f:
        raw = json.load(f)

    crd = get_crd_data(raw, num_chirps_per_burst=16)
    frames = np.abs(crd)[:num_frames].astype(np.float32)

    # Shuffle so that later contiguous slices of the result are random samples.
    order = np.random.permutation(len(frames))
    return frames[order]

def load_dataset(raw_path, num_subjects, num_experiments=15, train_split=0.8, test_split=0.1, device="cuda", seed=24):
    """Gather per-experiment radar frames and labels into train/validation/test buckets.

    For every subject id and experiment the radar frames are loaded with
    ``get_ard`` (already shuffled), sliced into three consecutive parts and
    appended to ``data`` / ``labels``; bucket index 0 is train, 1 is validation
    and 2 is test.

    Args:
        raw_path: Directory containing the ``{subject}-{experiment}_radar.json`` files.
        num_subjects: Number of base subject ids; ids ``0..num_subjects-1`` and
            their ``9``-prefixed counterparts are both loaded.
        num_experiments: Number of experiments read for ids below 90.
        train_split: Fraction of the ``n`` frames that go to the train slice.
        test_split: Added to ``train_split`` to mark where the validation slice
            ends, so it sizes the *validation* slice; test receives the remainder.
        device: Not used anywhere in the current body.
        seed: Seed for NumPy's global random state (drives the shuffle in ``get_ard``).

    NOTE(review): the function is unfinished as written:
        * ``rgb_train`` / ``rgb_val`` / ``rgb_test`` are never assigned, so the
          first loop iteration raises ``NameError``;
        * ``real_rgb`` and ``fake_rgb`` are loaded but never read;
        * there is no ``return`` statement, so the result is ``None``.
    """
    # Seeds the global NumPy RNG so the frame shuffling in get_ard is reproducible.
    np.random.seed(seed)
    # Ids below 90 are handled as live=0 below; the "9"-prefixed ids (90, 91, ..., 910, ...) as live=1.
    subjects = list(range(num_subjects)) + [int(f"9{s}") for s in range(num_subjects)]
    # NOTE(review): this value is always overwritten inside the subject loop before use.
    experiments = range(num_experiments)
    val_split_end = train_split + test_split

    # Each list holds three buckets: [train, validation, test].
    data = {"radar": [[], [], []], "rgb_embs": [[], [], []]}
    labels = {"subject": [[], [], []], "liveness": [[], [], []]}

    # Subjects x Experiments x Frames x Embedding_Size *(21 x 15 x 10 x 512)
    real_rgb = np.load("data/InsightFace_embs/real_insightface_embs.npy")
    # BE CAREFUL THIS IS MOSTLY ZEROS
    fake_rgb = np.load("data/InsightFace_embs/fake_insightface_embs.npy")

    # TODO: FILTER UNDETECTED EMBEDDINGS
    for subject in subjects:
        # Per-id settings: frame cap, liveness label and which experiments to read.
        num_frames = 250 if subject < 90 else 74
        live = 0 if subject < 90 else 1
        experiments = range(num_experiments) if subject < 90 else [0, 5, 10]
        for experiment in experiments:
            exp_ard = get_ard(raw_path, subject, experiment, num_frames)
            # Only the first n (shuffled) frames are split: 15 for ids below 90, all loaded frames otherwise.
            n = 15 if subject < 90 else len(exp_ard)

            # Consecutive slices: [0, train) -> train, [train, train+test_split) -> validation, rest up to n -> test.
            exp_train = exp_ard[:int(n*train_split)]
            exp_val = exp_ard[int(n*train_split):int(n*val_split_end)]
            exp_test = exp_ard[int(n*val_split_end):n]

            data["radar"][0].append(exp_train)
            data["radar"][1].append(exp_val)
            data["radar"][2].append(exp_test)

            # NOTE(review): rgb_train / rgb_val / rgb_test are undefined here (NameError);
            # presumably they should be derived from real_rgb / fake_rgb — confirm the intended pairing.
            data["rgb_embs"][0].append(rgb_train)
            data["rgb_embs"][1].append(rgb_val)
            data["rgb_embs"][2].append(rgb_test)

            # One label per frame in the matching radar slice; the raw subject id is used as the label.
            labels["subject"][0].append([subject]*len(exp_train))
            labels["subject"][1].append([subject]*len(exp_val))
            labels["subject"][2].append([subject]*len(exp_test))

            labels["liveness"][0].append([live]*len(exp_train))
            labels["liveness"][1].append([live]*len(exp_val))
            labels["liveness"][2].append([live]*len(exp_test))




# Number of subjects; also used to size the MMFaceHybrid model further down.
num_subjects = 21

# NOTE(review): load_dataset declares raw_path and num_subjects as required
# parameters, so this argument-less call raises TypeError. In its current form
# load_dataset also has no return statement, so there is nothing to unpack here.
train, validation, test = load_dataset()

## Model Creation and Loading

In [None]:
from neural_nets import MMFaceHybrid
from utils import load_model, load_history
from torch import nn

# Training hyper-parameters
num_epochs = 50
learning_rate = 0.01

# Weights of the subject loss (lambda1) and the liveness loss (lambda2) in the combined loss
lambda1 = 1
lambda2 = 1

model = MMFaceHybrid(num_subjects).to(device)

# Loss + Optimiser
criterion = nn.CrossEntropyLoss()
optimiser = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=0.001, momentum=0.9)

# Restore the latest checkpoint state (epoch counter and histories) for this model name.
model_name = "mmFace-hybrid.pt"
cur_epoch, loss_history, train_acc, val_acc = load_model(model_name, model, optimiser)


def _format_acc(acc):
    """Format one accuracy-history entry to 4 decimal places.

    The training loop stores each entry as a (subject, liveness) tuple, which
    cannot be formatted with ``:.4f`` directly; plain scalar entries are
    formatted exactly as before.
    """
    if isinstance(acc, (tuple, list)):
        return ", ".join(f"{a:.4f}" for a in acc)
    return f"{acc:.4f}"


if len(loss_history) > 0:
    # The checkpoint is written before validation finishes, so an interrupted
    # run can leave val_acc shorter than loss_history (possibly empty).
    last_train_acc = _format_acc(train_acc[-1]) if len(train_acc) > 0 else "n/a"
    last_val_acc = _format_acc(val_acc[-1]) if len(val_acc) > 0 else "n/a"
    print(f"{model_name}\n\tEpoch: {cur_epoch}\n\tLoss: {loss_history[-1]:.4f}"
          f"\n\tTrain Accuracy: {last_train_acc}\n\tValidation Accuracy: {last_val_acc}")

## Training

In [None]:
# Train from the restored epoch up to num_epochs, checkpointing after every epoch.
for epoch in range(cur_epoch, num_epochs):
    print(f"\nEpoch [{epoch}/{num_epochs-1}]:")
    # Re-read the saved histories so this epoch's entries extend what is on disk.
    if os.path.exists(f"models/{model_name}"):
        loss_history, train_acc, val_acc = load_history(f"models/{model_name}")

    model.train()
    # Running Loss and Accuracy (separate sample counters for the subject and liveness heads)
    running_loss, running_acc_s, running_acc_l, total_s, total_l = 0., 0., 0., 0., 0.

    for radar, rgb_emb, labels_s, labels_l in tqdm(train):
        # Forward Pass: out1 scores subjects, out2 scores liveness
        out1, out2 = model(radar, rgb_emb)
        _, preds_s = torch.max(out1.data, 1)
        _, preds_l = torch.max(out2.data, 1)
        loss1 = criterion(out1, labels_s)
        loss2 = criterion(out2, labels_l)
        loss = lambda1*loss1 + lambda2*loss2

        # Backward Pass and Optimise
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        running_loss += loss.item()
        total_s += labels_s.size(0)
        total_l += labels_l.size(0)
        running_acc_s += (preds_s == labels_s).sum().item()
        running_acc_l += (preds_l == labels_l).sum().item()

        del radar, rgb_emb, labels_s, labels_l, out1, out2
        torch.cuda.empty_cache()

    avg_train_loss = running_loss/len(train)
    # Each head is normalised by its own sample counter (previously an undefined `total`).
    avg_train_acc_s = 100*running_acc_s/total_s
    avg_train_acc_l = 100*running_acc_l/total_l
    print(f"\tAverage Train Loss: {avg_train_loss:.4f}")
    print(f"\tTrain Accuracy (Subjects): {avg_train_acc_s:.4f}%")
    print(f"\tTrain Accuracy (Liveness): {avg_train_acc_l:.4f}%")

    # Checkpoint before validation; accuracy entries are (subject, liveness) tuples.
    torch.save({"epoch": epoch+1,
                "model_state_dict": model.state_dict(),
                "optimiser_state_dict": optimiser.state_dict(),
                "loss_history": loss_history + [avg_train_loss],
                "train_acc": train_acc + [(avg_train_acc_s, avg_train_acc_l)],
                "val_acc": val_acc},
                f"models/{model_name}")

    # Validation
    model.eval()
    with torch.no_grad():
        correct_s, correct_l, total_s, total_l = 0., 0., 0., 0.
        for radar, rgb_emb, labels_s, labels_l in validation:
            out1, out2 = model(radar, rgb_emb)
            _, preds_s = torch.max(out1.data, 1)
            _, preds_l = torch.max(out2.data, 1)
            total_s += labels_s.size(0)
            total_l += labels_l.size(0)
            correct_s += (preds_s == labels_s).sum().item()
            correct_l += (preds_l == labels_l).sum().item()
            del radar, rgb_emb, labels_s, labels_l, out1, out2

        avg_val_acc_s = 100*correct_s/total_s
        avg_val_acc_l = 100*correct_l/total_l
        print(f"\tValidation Accuracy (Subject): {avg_val_acc_s:.4f}%")
        print(f"\tValidation Accuracy (Liveness): {avg_val_acc_l:.4f}%")

    # Re-open the checkpoint written above and add this epoch's validation accuracy to it.
    model_checkpoint = torch.load(f"models/{model_name}")
    model_checkpoint["val_acc"].append((avg_val_acc_s, avg_val_acc_l))
    torch.save(model_checkpoint, f"models/{model_name}")

## Testing

In [None]:
# Per-sample predictions and ground truth, collected for the confusion matrices.
preds_subject, preds_liveness = [], []
true_subject, true_liveness = [], []

model.eval()
with torch.no_grad():
    correct_s, correct_l, total_s, total_l = 0., 0., 0., 0.
    # Evaluate on the held-out test loader (the validation loader was used here before).
    for radar, rgb_emb, labels_s, labels_l in test:
        out1, out2 = model(radar, rgb_emb)
        _, preds_s = torch.max(out1.data, 1)
        _, preds_l = torch.max(out2.data, 1)
        total_s += labels_s.size(0)
        total_l += labels_l.size(0)
        correct_s += (preds_s == labels_s).sum().item()
        correct_l += (preds_l == labels_l).sum().item()

        preds_subject.extend(preds_s.cpu().numpy())
        preds_liveness.extend(preds_l.cpu().numpy())
        true_subject.extend(labels_s.data.cpu().numpy())
        true_liveness.extend(labels_l.data.cpu().numpy())
        # Release the batch only after its labels have been recorded above.
        del radar, rgb_emb, labels_s, labels_l, out1, out2

    print(f"Test Accuracy (Subject): {100*correct_s/total_s:.4f}%")
    print(f"Test Accuracy (Liveness): {100*correct_l/total_l:.4f}%")

In [None]:
import matplotlib.pyplot as plt

# Reload the per-epoch histories stored with the checkpoint.
losses, train_acc, val_acc = load_history(f"models/{model_name}")

fig, axs = plt.subplots(1, 2, figsize=(10, 3), dpi=120)
acc_ax, loss_ax = axs

# Left panel: train and validation accuracy against epoch index.
for history in (train_acc, val_acc):
    acc_ax.plot(range(len(history)), history)
acc_ax.set_xlabel("Epoch")
acc_ax.set_ylabel("Accuracy (%)")
acc_ax.set_title(model_name)

# Right panel: training loss against epoch index.
loss_ax.plot(range(len(losses)), losses)
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")

plt.show()

### Subject Confusion Matrix

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd
import sys

# Build confusion matrix
# scikit-learn expects confusion_matrix(y_true, y_pred): rows are the actual
# classes and columns the predicted ones, which matches the axis labels below.
# normalize="pred" divides every column by its total, so the diagonal holds the
# per-class precision; columns with no predictions become 0 instead of NaN.
# NOTE(review): subject_names is not defined in the visible part of this
# notebook — confirm it exists and has one entry per class before running.
cf_matrix = confusion_matrix(true_subject, preds_subject, normalize="pred")
df_cm = pd.DataFrame(cf_matrix, index=subject_names, columns=subject_names)
plt.figure(figsize = (14, 5))
heatmap = sn.heatmap(df_cm, annot=True)
heatmap.set(xlabel ="Predictions", ylabel = "Actual", title ='Precision Confusion Matrix')

### Liveness Confusion Matrix

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd
import sys

# Build confusion matrix
# scikit-learn expects confusion_matrix(y_true, y_pred): rows are the actual
# classes and columns the predicted ones, which matches the axis labels below.
# normalize="pred" divides every column by its total, so the diagonal holds the
# per-class precision; columns with no predictions become 0 instead of NaN.
# labels=[0, 1] keeps the matrix 2x2 (0 = real, 1 = fake) even when one class
# is absent from the evaluated samples.
cf_matrix = confusion_matrix(true_liveness, preds_liveness, labels=[0, 1], normalize="pred")
df_cm = pd.DataFrame(cf_matrix, index=["Real", "Fake"], columns=["Real", "Fake"])
plt.figure(figsize = (14, 5))
heatmap = sn.heatmap(df_cm, annot=True)
heatmap.set(xlabel ="Predictions", ylabel = "Actual", title ='Precision Confusion Matrix')