### 1. Preparing to collect predictions

In [None]:
import os
import numpy as np
import cv2
from detection import detect_face_for_testing
from feature_extraction import predict_spoof
from anti_spoof import load_antispoof_model

DATASET_PATH_REAL = "MSU-MFSD_pics/real"
DATASET_PATH_ATTACK = "MSU-MFSD_pics/attack"

# Load the anti-spoofing model once; get_predictions() reads these two globals.
antispoof_sess, antispoof_input = load_antispoof_model()

# Ground-truth labels derived from folder contents: every directory entry in the
# real folder is labelled True, every entry in the attack folder False, with the
# real labels first (the same order in which predictions are collected below).
# NOTE: this counts every entry, so it only lines up with y_pred if a prediction
# is produced for every file.
num_real_files = len(os.listdir(DATASET_PATH_REAL))
num_attack_files = len(os.listdir(DATASET_PATH_ATTACK))

y_true_real = np.ones(num_real_files, dtype=bool)
y_true_attack = np.zeros(num_attack_files, dtype=bool)
y_true = np.concatenate([y_true_real, y_true_attack])

def get_predictions(dataset_path, correct_pred, y_pred, max_images=None):
    """Run face detection and spoof prediction on the images in a folder.

    Each prediction is appended to ``y_pred`` in place. Images that cannot be
    read, or in which no face is detected, are skipped (nothing is appended
    for them) and reported on stdout, so ``y_pred`` may grow by fewer entries
    than there are files in ``dataset_path``.

    Args:
        dataset_path: Folder containing the images to evaluate.
        correct_pred: Ground-truth label shared by every image in the folder
            (True for the real folder, False for the attack folder in this
            notebook); used only to report misclassified images.
        y_pred: List that predictions are appended to.
        max_images: Optional cap on the number of files processed, e.g. for a
            quick smoke test instead of running the whole dataset.

    Returns:
        List of image paths that were skipped (unreadable or no face found).
    """
    skipped = []

    # Sort for a deterministic processing order across runs and platforms.
    image_names = sorted(os.listdir(dataset_path))
    if max_images is not None:
        image_names = image_names[:max_images]

    for image in image_names:
        image_path = os.path.join(dataset_path, image)

        image_array = cv2.imread(image_path)
        if image_array is None:
            # cv2.imread returns None (instead of raising) for unreadable or
            # non-image files; don't pass that on to the face detector.
            print(f'Could not read image, skipping: {image_path}')
            skipped.append(image_path)
            continue

        face_data, _, image_rgb = detect_face_for_testing(image_array)
        if face_data is None:
            print(f'No face detected, skipping: {image_path}')
            skipped.append(image_path)
            continue

        # Check if image is spoofed based on the face data
        pred = predict_spoof(face_data, image_rgb, antispoof_sess, antispoof_input)
        y_pred.append(pred)

        if pred != correct_pred:
            print(f'Incorrect prediction at: {image_path}')

    return skipped



### 2. Collecting predictions for the real and the spoofed (attack) images

In [None]:
y_pred = []

# Collect predictions for non-spoofed images
get_predictions(DATASET_PATH_REAL, True, y_pred)
num_real_preds = len(y_pred)

# Collect predictions for spoofed images
get_predictions(DATASET_PATH_ATTACK, False, y_pred)
num_attack_preds = len(y_pred) - num_real_preds

# get_predictions only appends a prediction when a face is detected, so the
# file-count based y_true from the first cell can be longer than y_pred and
# break the metrics. Rebuild the labels from the predictions actually
# collected (real first, then attack) so both sequences stay aligned.
y_true = np.concatenate((np.full(num_real_preds, True), np.full(num_attack_preds, False)))

# Look at the updated predictions list
print(y_pred)

### 3. Computing standard classification metrics (accuracy, precision, recall, F1) and plotting a confusion matrix

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, ConfusionMatrixDisplay

# The metrics compare y_true and y_pred element-wise, so they must be aligned.
if len(y_true) != len(y_pred):
    raise ValueError(
        f"y_true has {len(y_true)} labels but y_pred has {len(y_pred)} predictions; "
        "images skipped during prediction (e.g. no face detected) must also be "
        "excluded from y_true."
    )

# True (real) is the positive class for precision / recall / F1.
accuracy = accuracy_score(y_true, y_pred)
precision = precision_score(y_true, y_pred, pos_label=True)
recall = recall_score(y_true, y_pred, pos_label=True)
f1 = f1_score(y_true, y_pred, pos_label=True)

print(f"Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-score: {f1:.4f}\n")

# Global font
mpl.rcParams['font.family'] = 'Times New Roman'

# Global font size
mpl.rcParams['font.size'] = 18

# Fix the label order so the matrix is always 2x2 (even if a class never
# appears) and rows/columns are ordered [False, True] = [attack, real].
class_labels = [False, True]
conf_mat = confusion_matrix(y_true, y_pred, labels=class_labels)

# Plotting the confusion matrix with readable tick labels instead of 0/1
conf_mat_display = ConfusionMatrixDisplay(conf_mat, display_labels=["Attack", "Real"])
conf_mat_display.plot()

# Labels
plt.xlabel("Predicted Value")
plt.ylabel("True Value")

plt.show()
