In [1]:
import os
from contextlib import redirect_stdout

import numpy as np
import torch
from torch.utils.data import ConcatDataset
from torchvision import transforms

# Import your custom utilities
from utils.dataset import LFWPairDataset, CASIAwebfaceDataset
from utils.config import DATASET_PATH
from utils.model_utils import tune_threshold, evaluate_lfw_10fold
from utils.criterion import SphereFaceNet, CosFaceNet, ArcFaceNet, CurricularFaceNet

# --- Define transforms ---
# Preprocessing applied to every LFW evaluation image: tensor conversion plus
# per-channel normalisation with mean=std=0.5 (maps [0, 1] pixels to [-1, 1]).
# NOTE(review): this differs from the CASIA transform below (Resize to 112x112 and
# ImageNet mean/std). If the evaluated models were trained with the CASIA-style
# preprocessing, the mismatch could depress LFW accuracy — confirm against the
# training notebook before trusting the numbers below.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# In this notebook CASIA-WebFace is only read for its identity count, which is
# passed as `num_classes` to every model constructor below (presumably so the
# classification head matches the saved checkpoints). Its transform is unused here.
train_dataset = CASIAwebfaceDataset(
    root_dir=f'{DATASET_PATH}/CASIA-webface',
    transform=transforms.Compose([
        transforms.Resize((112, 112)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
)
num_classes = train_dataset.num_of_identities
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using: ", device)

# Root of the aligned LFW images and pair lists. Override with the LFW_PATH
# environment variable; the default is the original author's local path.
dataset_path = os.environ.get("LFW_PATH", "/home/phatvo/callmePhineas/DACN/data_a_P/LFW")

# --- Dataset paths ---
aligned_lfw_path = dataset_path
match_train_file = f'{dataset_path}/matchpairsDevTrain.csv'
mismatch_train_file = f'{dataset_path}/mismatchpairsDevTrain.csv'
match_test_file = f'{dataset_path}/matchpairsDevTest.csv'
mismatch_test_file = f'{dataset_path}/mismatchpairsDevTest.csv'
pairs_all_file = f'{dataset_path}/pairs.csv'

# Fail fast with an actionable message instead of a deep error inside the dataset class.
if not os.path.isdir(aligned_lfw_path):
    raise FileNotFoundError(
        f"LFW directory not found: {aligned_lfw_path} (set the LFW_PATH environment variable)"
    )


def _lfw_pairs(pairs_file):
    """Build an LFWPairDataset over `aligned_lfw_path` for one pair-list CSV, using `test_transform`."""
    return LFWPairDataset(
        root_dir=aligned_lfw_path,
        pairs_files=pairs_file,
        transform=test_transform,
    )


# --- DevTrain + DevTest match/mismatch pairs, combined below for threshold tuning ---
train_pairs_dataset_match = _lfw_pairs(match_train_file)
train_pairs_dataset_mismatch = _lfw_pairs(mismatch_train_file)
test_pairs_dataset_match = _lfw_pairs(match_test_file)
test_pairs_dataset_mismatch = _lfw_pairs(mismatch_test_file)

# Combine all pairs for threshold tuning
combined_pairs_dataset = ConcatDataset([
    train_pairs_dataset_match,
    train_pairs_dataset_mismatch,
    test_pairs_dataset_match,
    test_pairs_dataset_mismatch,
])

  from .autonotebook import tqdm as notebook_tqdm


Loading CASIAwebface dataset |████████████████████✗︎ (!) 490623/10572 [4641%] in 1.0s (474854.42/s) 
Using:  cuda


In [2]:
def tune_and_validate(
    model,
    batch_size=64,
    thresholds=None,
    log_file="./validation_log",
    model_name=None,
):
    """Tune a verification threshold on DevTrain+DevTest, then report LFW 10-fold accuracy.

    The threshold chosen by ``tune_threshold`` on ``combined_pairs_dataset`` is reused
    unchanged for every fold of ``evaluate_lfw_10fold`` on ``pairs_all_file``. All
    ``print`` output inside is appended to ``log_file`` instead of the notebook, so only
    the returned tuple is displayed. Progress bars written to stderr are not captured.

    Relies on notebook globals: ``combined_pairs_dataset``, ``device``, ``pairs_all_file``,
    ``aligned_lfw_path`` and ``test_transform``.

    Args:
        model: Network to evaluate, already on ``device`` and in eval mode.
        batch_size: Batch size forwarded to both ``tune_threshold`` and ``evaluate_lfw_10fold``.
        thresholds: Candidate thresholds to sweep. ``None`` means ``np.arange(0.0, 0.5, 0.002)``.
            NOTE(review): that default never tries values >= 0.5; widen it if the best
            threshold lands near the upper edge.
        log_file: Text log that receives the progress messages (opened in append mode, UTF-8).
        model_name: Optional label written to the log so runs of different models are distinguishable.

    Returns:
        tuple: ``(best_threshold, best_accuracy, mean_acc, std_acc)``. The accuracies are
        logged with a ``%`` sign, presumably as percentages.
    """
    if thresholds is None:
        thresholds = np.arange(0.0, 0.5, 0.002)

    # Explicit UTF-8: the messages contain emoji, which the platform default encoding
    # (e.g. cp1252) cannot always represent.
    with open(log_file, "a", encoding="utf-8") as log_fh, redirect_stdout(log_fh):
        if model_name is not None:
            print(f"===== {model_name} =====")

        # --- Tune threshold ---
        print("🔧 Tuning threshold on combined DevTrain + DevTest ...")
        best_threshold, best_accuracy, threshold_list, accuracy_list = tune_threshold(
            model, combined_pairs_dataset, batch_size, device, thresholds=thresholds
        )
        print(f"✅ Best threshold: {best_threshold:.4f}, Combined accuracy: {best_accuracy:.2f}%")

        # --- Validate with 10-Fold LFW ---
        print("📊 Running 10-Fold cross-validation ...")
        mean_acc, std_acc = evaluate_lfw_10fold(
            model=model,
            pairs_file=pairs_all_file,
            batch_size=batch_size,
            root_dir=aligned_lfw_path,
            transform=test_transform,
            device=device,
            threshold=best_threshold
        )

        print(f"🎯 LFW 10-Fold Accuracy: {mean_acc:.2f}% ± {std_acc:.2f}% (Threshold: {best_threshold:.4f})")

    return best_threshold, best_accuracy, mean_acc, std_acc

# SphereFace

In [3]:
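# SphereFace evaluation is currently disabled — uncomment to run.
# # --- Load model ---
# model_weights_path = "./models_evaluation/SphereFace_min_loss.pth"
# print(f"🔹 Loading SphereFace model from {model_weights_path}")
# model = SphereFaceNet(num_classes=num_classes).to(device)
# state_dict = torch.load(model_weights_path, map_location=device)
# # NOTE(review): the other checkpoints store weights under 'model_state_dict';
# # use state_dict['model_state_dict'] here if the SphereFace checkpoint does too.
# model.load_state_dict(state_dict)
# model.eval()

# tune_and_validate(model=model)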


# CosFace

In [4]:
# Restore the CosFace network from its "_min_loss" checkpoint and score it on LFW.
# The checkpoint is a mapping; the network weights sit under 'model_state_dict'.
model_weights_path = "./models_evaluation/CosFace_min_loss.pth"
print(f"🔹 Loading CosFace model from {model_weights_path}")

model = CosFaceNet(num_classes=num_classes)
model = model.to(device)

state_dict = torch.load(model_weights_path, map_location=device)
model.load_state_dict(state_dict['model_state_dict'])
model.eval()  # switch to eval mode before threshold tuning / validation

tune_and_validate(model=model)


🔹 Loading CosFace model from ./models_evaluation/CosFace_min_loss.pth
Initialize CosFace with margin 0.35, scale 64.0


                                                                          

(0.39, 67.34375, 65.71666666666667, 1.6108141488769678)

# ArcFace

In [5]:
# Restore the ArcFace network from its "_min_loss" checkpoint and score it on LFW.
# The checkpoint is a mapping; the network weights sit under 'model_state_dict'.
model_weights_path = "./models_evaluation/ArcFace_min_loss.pth"
print(f"🔹 Loading ArcFace model from {model_weights_path}")

model = ArcFaceNet(num_classes=num_classes)
model = model.to(device)

state_dict = torch.load(model_weights_path, map_location=device)
model.load_state_dict(state_dict['model_state_dict'])
model.eval()  # switch to eval mode before threshold tuning / validation

tune_and_validate(model=model)

🔹 Loading ArcFace model from ./models_evaluation/ArcFace_min_loss.pth


Evaluating (Verification):   0%|          | 0/50 [00:00<?, ?it/s]

Initialize ArcFace with margin 0.35, scale 64.0


                                                                          

(0.304, 76.375, 77.35000000000001, 1.379311422413371)

# CurricularFace

In [6]:
# Restore the CurricularFace network from its "_min_loss" checkpoint and score it on LFW.
# The checkpoint is a mapping; the network weights sit under 'model_state_dict'.
model_weights_path = "./models_evaluation/CurricularFace_min_loss.pth"
print(f"🔹 Loading CurricularFace model from {model_weights_path}")

model = CurricularFaceNet(num_classes=num_classes)
model = model.to(device)

state_dict = torch.load(model_weights_path, map_location=device)
model.load_state_dict(state_dict['model_state_dict'])
model.eval()  # switch to eval mode before threshold tuning / validation

tune_and_validate(model=model)

🔹 Loading CurricularFace model from ./models_evaluation/CurricularFace_min_loss.pth
Initialize CurricularFace with margin 0.5, scale 64.0, momentum 0.01


                                                                          

(0.43, 71.9375, 71.71666666666667, 2.135740829054146)