### 1. Functions and variables for collecting predictions

In [None]:
import os
import numpy as np
import cv2
from detection import detect_face_for_testing
from feature_extraction import predict_spoof
from anti_spoof import load_antispoof_model
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

DATASET_PATH_REAL = "MSU-MFSD_pics/real"
DATASET_PATH_ATTACK = "MSU-MFSD_pics/attack"

# Thresholds to sweep: 0, then 0.94, 0.95, ..., 0.99.
# Built from an integer range on purpose: np.arange(0.94, 1.00, 0.01) computes
# its length as ceil((1.00 - 0.94) / 0.01), which rounds to ceil(6.000000000000005)
# = 7 in floating point and silently appends a seventh threshold of ~1.0.
SPOOF_THRESHOLDS = np.concatenate(([0], np.arange(94, 100) / 100))

antispoof_sess, antispoof_input = load_antispoof_model()

# Ground-truth labels: True for every entry listed in the real directory,
# followed by False for every entry listed in the attack directory.
# NOTE: this counts every directory entry, while get_predictions only records a
# prediction when a face is detected, so these labels line up with a prediction
# list only if every entry produced a prediction.
n_real_entries = len(os.listdir(DATASET_PATH_REAL))
n_attack_entries = len(os.listdir(DATASET_PATH_ATTACK))
y_true_real = np.ones(n_real_entries, dtype=bool)
y_true_attack = np.zeros(n_attack_entries, dtype=bool)
y_true = np.concatenate((y_true_real, y_true_attack))

def get_predictions(dataset_path, correct_pred, y_pred, threshold, labels=None, limit=None):
    """Run the anti-spoofing model on every image in ``dataset_path``.

    Predictions are appended to ``y_pred`` in place, in sorted filename order.
    Entries that cv2 cannot read, or in which no face is detected, produce no
    prediction and are skipped.

    Args:
        dataset_path: Directory expected to contain only image files.
        correct_pred: Ground-truth label shared by every image in the directory
            (run_test passes True for real faces and False for attacks).
        y_pred: List that receives one prediction per successfully processed image.
        threshold: Decision threshold forwarded to ``predict_spoof``.
        labels: Optional list that receives ``correct_pred`` once for every
            prediction appended to ``y_pred``. This keeps labels aligned with
            predictions even when images are skipped.
        limit: Optional cap on the number of directory entries to process
            (useful for quick smoke tests without loading the whole dataset).

    Returns:
        The number of entries skipped (unreadable or no face detected).
    """
    image_names = sorted(os.listdir(dataset_path))  # sorted for a reproducible order
    if limit is not None:
        image_names = image_names[:limit]

    skipped = 0
    for image in image_names:
        image_path = os.path.join(dataset_path, image)

        image_array = cv2.imread(image_path)
        if image_array is None:
            # cv2.imread signals unreadable / non-image files by returning None.
            skipped += 1
            continue

        face_data, _, image_rgb = detect_face_for_testing(image_array)
        if face_data is None:
            skipped += 1
            continue

        # Check if image is spoofed based on the face data
        pred = predict_spoof(face_data, image_rgb, antispoof_sess, antispoof_input, threshold)
        y_pred.append(pred)
        if labels is not None:
            labels.append(correct_pred)

    if skipped:
        print(f"Skipped {skipped} of {len(image_names)} image(s) in {dataset_path} "
              f"(unreadable or no face detected)")
    return skipped

def run_test(path_real, path_attack, threshold=0.98, *, return_labels=False):
    """Evaluate the anti-spoofing model on a real and an attack image set.

    Prints accuracy, precision, recall and F1-score for ``threshold``. Real
    faces are labeled True (the positive class) and attacks False.

    Args:
        path_real: Directory of genuine (non-spoofed) face images.
        path_attack: Directory of spoofed face images.
        threshold: Decision threshold forwarded to ``predict_spoof``.
        return_labels: If True, also return the ground-truth labels aligned
            with the predictions.

    Returns:
        The prediction list (real images first, then attack images), or
        ``(y_pred, labels)`` when ``return_labels`` is True.

    Raises:
        ValueError: If no image in either directory produced a prediction.
    """
    y_pred = []
    labels = []

    # Collect predictions for non-spoofed images
    get_predictions(path_real, True, y_pred, threshold, labels=labels)

    # Collect predictions for spoofed images
    get_predictions(path_attack, False, y_pred, threshold, labels=labels)

    if not y_pred:
        raise ValueError(f"No predictions were produced from {path_real!r} or {path_attack!r}")

    # Score against labels gathered alongside the predictions rather than the
    # module-level y_true, which has one entry per directory entry and so has
    # a different length whenever an image is skipped.
    accuracy = accuracy_score(labels, y_pred)
    precision = precision_score(labels, y_pred)
    recall = recall_score(labels, y_pred)
    f1 = f1_score(labels, y_pred)

    print(f"Threshold: {threshold:.2f}")
    print(f"Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-score: {f1:.4f}\n")

    if return_labels:
        return y_pred, labels
    return y_pred



### 2. Collecting predictions on the real and spoofed image sets for each threshold

In [None]:
from plotting import plot_conf_mat, plot_ROC

# Sweep every threshold. run_test prints the metrics for each one and returns
# its predictions; the list keeps them in SPOOF_THRESHOLDS order.
y_pred_list = []
for threshold in SPOOF_THRESHOLDS:
    y_pred = run_test(
        path_real=DATASET_PATH_REAL,
        path_attack=DATASET_PATH_ATTACK,
        threshold=threshold,
    )
    y_pred_list.append(y_pred)

# One ROC plot over all thresholds, scored against the module-level y_true.
# NOTE(review): this assumes each prediction list has len(y_true) entries.
plot_ROC(y_true, y_pred_list, SPOOF_THRESHOLDS)


### 3. Plotting a confusion matrix for each threshold

In [None]:
# y_pred_list was filled in SPOOF_THRESHOLDS order, so pair them up and print
# the threshold before each matrix; otherwise the plots cannot be told apart.
for threshold, y_pred in zip(SPOOF_THRESHOLDS, y_pred_list):
    print(f"Confusion matrix for threshold {threshold:.2f}")
    plot_conf_mat(y_true, y_pred)