In [None]:
import torch
from torch.utils.data import ConcatDataset
from torchvision import transforms
import os
import numpy as np

# Import your custom utilities
from utils.dataset import LFWPairDataset, CASIAwebfaceDataset
from utils.config import DATASET_PATH
from utils.model_utils import tune_threshold, evaluate_lfw_10fold
from utils.criterion import SphereFaceNet, CosFaceNet, ArcFaceNet, CurricularFaceNet

# --- Preprocessing pipelines ---
# Evaluation-time pipeline applied to the LFW pair images: tensor conversion
# followed by per-channel normalisation with mean 0.5 / std 0.5.
# NOTE(review): this differs from the ImageNet statistics used in the CASIA
# pipeline below, and has no Resize step — confirm it matches the
# preprocessing the checkpoints were trained with.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# CASIA-WebFace pipeline: resize to 112x112, convert to tensor, normalise
# with ImageNet channel statistics.
casia_transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# In the cells visible here the training set is only read for its identity
# count, which sizes the classification head of every model below.
train_dataset = CASIAwebfaceDataset(
    root_dir=f'{DATASET_PATH}/CASIA-webface',
    transform=casia_transform,
)
num_classes = train_dataset.num_of_identities

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using: ", device)

# --- Dataset paths ---
# Reuse the dataset root configured in utils.config (already used for
# CASIA-WebFace above) instead of a hard-coded placeholder path, so that all
# datasets are resolved from a single configurable location.
dataset_path = DATASET_PATH

# Directory holding the aligned LFW face crops.
aligned_lfw_path = f'{dataset_path}/Labeled Faces in the Wild (LFW)_aligned'

# LFW development split pair lists (matched / mismatched pairs).
match_train_file = f'{dataset_path}/Labeled Faces in the Wild (LFW)/matchpairsDevTrain.csv'
mismatch_train_file = f'{dataset_path}/Labeled Faces in the Wild (LFW)/mismatchpairsDevTrain.csv'
match_test_file = f'{dataset_path}/Labeled Faces in the Wild (LFW)/matchpairsDevTest.csv'
mismatch_test_file = f'{dataset_path}/Labeled Faces in the Wild (LFW)/mismatchpairsDevTest.csv'

# Pair list consumed by the 10-fold evaluation.
pairs_all_file = f'{dataset_path}/Labeled Faces in the Wild (LFW)/pairs.csv'

# --- LFW DevTrain + DevTest pairs used for threshold tuning ---
def _load_lfw_pairs(pairs_csv):
    """Build an ``LFWPairDataset`` over the aligned LFW images for one pairs CSV."""
    return LFWPairDataset(
        root_dir=aligned_lfw_path,
        pairs_files=pairs_csv,
        transform=test_transform,
    )


# One dataset per pair list; the individual names are kept so that other
# cells can still refer to a single split.
train_pairs_dataset_match = _load_lfw_pairs(match_train_file)
train_pairs_dataset_mismatch = _load_lfw_pairs(mismatch_train_file)
test_pairs_dataset_match = _load_lfw_pairs(match_test_file)
test_pairs_dataset_mismatch = _load_lfw_pairs(mismatch_test_file)

# Single dataset spanning all four pair lists, used for threshold tuning.
combined_pairs_dataset = ConcatDataset(
    [
        train_pairs_dataset_match,
        train_pairs_dataset_mismatch,
        test_pairs_dataset_match,
        test_pairs_dataset_mismatch,
    ]
)

In [None]:
def tune_and_validate(
    model,
    batch_size=512,
    device=None
):
    """Tune a verification threshold on LFW pairs, then validate it with 10-fold CV.

    The threshold is chosen by ``tune_threshold`` on the notebook-level
    ``combined_pairs_dataset`` and then passed unchanged to
    ``evaluate_lfw_10fold``, which is run over ``pairs_all_file`` with
    ``aligned_lfw_path`` and ``test_transform``.

    Args:
        model: Network to evaluate. Presumably a ``torch.nn.Module`` that is
            already in eval mode and placed on its target device — the
            function itself does not move it or change its mode.
        batch_size: Batch size forwarded to both helpers.
        device: Device forwarded to both helpers. When ``None``, the device
            of the model's first parameter is used so that the helpers
            receive the device the model actually lives on; if the model
            exposes no parameters, ``None`` is forwarded as before.

    Returns:
        Tuple ``(best_threshold, best_accuracy, mean_acc, std_acc)`` exactly
        as reported by the two helpers (the accuracies are printed as
        percentages).
    """
    # Callers in this notebook omit `device`; resolve it from the model rather
    # than handing `None` to the helpers, whose handling of `None` is not
    # visible here.
    if device is None:
        params = getattr(model, "parameters", None)
        first_param = next(iter(params()), None) if callable(params) else None
        if first_param is not None:
            device = first_param.device

    # --- Tune threshold ---
    print("🔧 Tuning threshold on combined DevTrain + DevTest ...")
    best_threshold, best_accuracy, threshold_list, accuracy_list = tune_threshold(
        model, combined_pairs_dataset, batch_size, device
    )
    print(f"✅ Best threshold: {best_threshold:.4f}, Combined accuracy: {best_accuracy:.2f}%")

    # --- Validate with 10-Fold LFW ---
    print("📊 Running 10-Fold cross-validation ...")
    mean_acc, std_acc = evaluate_lfw_10fold(
        model=model,
        pairs_file=pairs_all_file,
        batch_size=batch_size,
        root_dir=aligned_lfw_path,
        transform=test_transform,
        device=device,
        threshold=best_threshold
    )

    print(f"🎯 LFW 10-Fold Accuracy: {mean_acc:.2f}% ± {std_acc:.2f}% (Threshold: {best_threshold:.4f})")
    return best_threshold, best_accuracy, mean_acc, std_acc

# SphereFace

In [None]:
# --- Load model ---
# The SphereFace network must be restored from its own checkpoint; the file
# name follows the "<Method>_min_loss.pth" pattern used by the other model
# cells (previously this cell loaded the CosFace weights by mistake).
model_weights_path = "./models_evaluation/SphereFace_min_loss.pth"
print(f"🔹 Loading SphereFace model from {model_weights_path}")
model = SphereFaceNet(num_classes=num_classes).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
state_dict = torch.load(model_weights_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()

# Forward the notebook's device explicitly so the evaluation helpers run on
# the same device the model was moved to.
tune_and_validate(model=model, device=device)


# CosFace

In [None]:
# --- Load model ---
# Restore the CosFace network from its lowest-loss checkpoint.
model_weights_path = "./models_evaluation/CosFace_min_loss.pth"
print(f"🔹 Loading CosFace model from {model_weights_path}")
model = CosFaceNet(num_classes=num_classes).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
state_dict = torch.load(model_weights_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()

# Forward the notebook's device explicitly so the evaluation helpers run on
# the same device the model was moved to.
tune_and_validate(model=model, device=device)


# ArcFace

In [None]:
# --- Load model ---
# Restore the ArcFace network from its lowest-loss checkpoint.
model_weights_path = "./models_evaluation/ArcFace_min_loss.pth"
print(f"🔹 Loading ArcFace model from {model_weights_path}")
model = ArcFaceNet(num_classes=num_classes).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
state_dict = torch.load(model_weights_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()

# Forward the notebook's device explicitly so the evaluation helpers run on
# the same device the model was moved to.
tune_and_validate(model=model, device=device)

# CurricularFace

In [None]:
# --- Load model ---
# Restore the CurricularFace network from its lowest-loss checkpoint.
model_weights_path = "./models_evaluation/CurricularFace_min_loss.pth"
print(f"🔹 Loading CurricularFace model from {model_weights_path}")
model = CurricularFaceNet(num_classes=num_classes).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
state_dict = torch.load(model_weights_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()

# Forward the notebook's device explicitly so the evaluation helpers run on
# the same device the model was moved to.
tune_and_validate(model=model, device=device)