# Biometric Fusion with MPT Testing

This notebook implements the testing and evaluation process for the Multi-Modal Biometric Fusion model with Modified Prompt Tuning (MPT).

## Overview

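The evaluation covers:
1. Rank-1 identification accuracy for several gallery sizes
2. Verification performance: ROC curve, AUC and equal error rate (EER)
3. Per-modality (periocular, forehead, iris) performance compared with the fused model
4. Embedding visualization using t-SNE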

In [None]:
# Import required libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
from pathlib import Path
from sklearn.manifold import TSNE
from sklearn.metrics import roc_curve, auc
import pandas as pd
from scipy.spatial.distance import cdist

from model.model_mpt import BiometricModel
from model.dataset import BiometricDataset

## Configuration Parameters

Set up parameters for testing.

In [None]:
# Pick the fastest available backend: CUDA, then Apple-silicon MPS, then CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Testing parameters -- every knob a reader may want to change lives here
params = dict(
    data_path='./dataset2',                               # dataset root: <modality>/test/<person_id>/
    model_path='./checkpoints/model_latest_mpt_best.pt',  # trained weights (loaded as a state_dict)
    embedding_dim=256,                                    # dimension of the fused embedding
    batch_size=32,                                        # batch size for testing (not referenced elsewhere in this notebook)
    output_dir='./results',                               # figures and text reports are written here
    gallery_sizes=[1, 3, 5, 7, 9],                        # gallery poses per person for rank-1 evaluation
    visualize_persons=30,                                 # number of persons in the t-SNE plot
)

# All result files go to output_dir; create it up front (no error if it already exists)
Path(params['output_dir']).mkdir(parents=True, exist_ok=True)

# Test-time preprocessing: resize to 128x128, convert to a tensor in [0, 1],
# then shift/scale to [-1, 1] (a single mean/std entry => one input channel)
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

## Load Model

Load the trained model for evaluation.

In [None]:
# Rebuild the architecture, then restore the trained weights.
# NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary
# code -- only load trusted files (weights_only=True is an option for a plain state_dict).
model = BiometricModel(embedding_dim=params['embedding_dim'])
model.load_state_dict(
    torch.load(params['model_path'], map_location=device)  # tensors land directly on `device`
)
model = model.to(device).eval()  # Module.eval() returns the module itself, so the calls chain

print(f"Loaded model from {params['model_path']}")

## Embedding Extraction

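Define helpers that load the test images and extract one fused embedding per person and pose (a pose is used only when the periocular, forehead and iris images are all present).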

In [None]:
def load_image(image_path, transform):
    """Load an image file as single-channel grayscale ('L') and apply ``transform``.

    Args:
        image_path: Path of the image file (``str`` or ``pathlib.Path``).
        transform: Callable applied to the grayscale PIL image, e.g. the
            torchvision ``Compose`` used throughout this notebook.

    Returns:
        The result of ``transform`` applied to the grayscale image.
    """
    # Open inside a context manager so the file handle is released as soon as the
    # pixels are decoded, rather than whenever lazy loading / GC gets around to it.
    with Image.open(image_path) as img:
        gray = img.convert('L')  # convert() decodes the pixels into a new in-memory image
    return transform(gray)

def create_embedding_dicts(data_root, model, device=device, person_ids=None,
                           num_poses=10, split='test'):
    """
    Create a dictionary of fused embeddings for all test images.

    Images are read from ``<data_root>/<modality>/<split>/<person_id>/`` for the
    modalities periocular, forehead and iris. Inside each folder the file names are
    sorted (lexicographically; hidden files such as ``.DS_Store`` are skipped) and
    the k-th file is treated as pose k. A fused embedding is computed for every
    pose that has an image in all three modality folders.

    Args:
        data_root: Root directory of the dataset.
        model: Fusion model, called as ``model(periocular, forehead, iris)`` with one
            batched image per modality.
        device: Device on which inference runs.
        person_ids: Person-ID folder names to process. Defaults to the zero-padded
            IDs ``"001"`` .. ``"247"``.
        num_poses: Maximum number of poses used per person.
        split: Dataset split sub-folder to read.

    Returns:
        Dictionary mapping person ID -> list of CPU embedding tensors in pose order.
        Persons without any complete pose are omitted.
    """
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    modalities = ['periocular', 'forehead', 'iris']
    if person_ids is None:
        # Person IDs are a fixed zero-padded range, not discovered from disk.
        person_ids = [f"{i:03d}" for i in range(1, 248)]

    model = model.to(device)
    model.eval()

    print(f"Extracting embeddings for {split} set...")
    test_dict = {}
    for person_id in tqdm(person_ids):
        # List each modality folder once per person instead of once per pose.
        listings = []
        for modality in modalities:
            person_dir = Path(data_root) / modality / split / person_id
            if not person_dir.is_dir():
                break
            listings.append([
                person_dir / name
                for name in sorted(os.listdir(person_dir))
                if not name.startswith('.')  # skip .DS_Store and other hidden/metadata files
            ])
        if len(listings) != len(modalities):
            continue  # a modality folder is missing, so no pose has all three images

        # Pose k is usable only when every modality folder has at least k images.
        usable_poses = min(num_poses, *(len(files) for files in listings))
        embeddings = []
        for pose_idx in range(usable_poses):
            # One (1, C, H, W) tensor per modality, in periocular/forehead/iris order.
            images = [load_image(files[pose_idx], transform).unsqueeze(0).to(device)
                      for files in listings]
            with torch.no_grad():
                embeddings.append(model(*images).squeeze(0).cpu())  # drop batch dim
        if embeddings:
            test_dict[person_id] = embeddings

    print(f"Created embeddings for {len(test_dict)} persons")
    return test_dict

## Rank-1 Recognition Evaluation

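Evaluate rank-1 identification accuracy for different gallery sizes: for each person, the first *k* embeddings are enrolled in the gallery and the remaining ones are used as probes; each probe is assigned the identity of its most similar gallery embedding (cosine similarity).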

In [None]:
def evaluate_rank1(test_dict, gallery_size=5, device=device):
    """
    Evaluate rank-1 (closed-set) identification accuracy using cosine similarity.

    For each person with more than ``gallery_size`` embeddings, the first
    ``gallery_size`` embeddings are enrolled in the gallery and the rest are used
    as probes; persons with too few embeddings are left out of both sets. Each
    probe is assigned the identity of its most similar gallery embedding.

    Args:
        test_dict: Dictionary with person IDs as keys and lists of embeddings as values
        gallery_size: Number of poses to use as gallery (remaining used as probe)
        device: Device to perform calculations on

    Returns:
        Tuple ``(rank1_rate, correct, total)``. When the gallery or probe set is
        empty, the rate and ``correct`` are 0 and ``total`` is the number of probes.
    """
    gallery_embeddings, gallery_labels = [], []
    probe_embeddings, probe_labels = [], []

    for person_id, embeddings in test_dict.items():
        # A person needs at least one probe beyond the gallery poses.
        if len(embeddings) < gallery_size + 1:
            continue
        gallery_embeddings.extend(embeddings[:gallery_size])
        gallery_labels.extend([person_id] * gallery_size)
        probe_embeddings.extend(embeddings[gallery_size:])
        probe_labels.extend([person_id] * (len(embeddings) - gallery_size))

    if not gallery_embeddings or not probe_embeddings:
        # torch.stack([]) would raise; report an empty evaluation instead.
        print(f"Warning: cannot evaluate rank-1 with gallery_size={gallery_size} "
              f"(gallery: {len(gallery_embeddings)}, probes: {len(probe_embeddings)})")
        return 0.0, 0, len(probe_embeddings)

    # L2-normalise so that a dot product equals cosine similarity
    gallery_tensor = F.normalize(torch.stack(gallery_embeddings).to(device), p=2, dim=1)
    probe_tensor = F.normalize(torch.stack(probe_embeddings).to(device), p=2, dim=1)

    print(f"Gallery size: {len(gallery_tensor)} embeddings")
    print(f"Probe size: {len(probe_tensor)} embeddings")

    # Integer-encode identities so correctness can be checked with tensor ops.
    id_to_index = {pid: k for k, pid in enumerate(dict.fromkeys(gallery_labels + probe_labels))}
    gallery_ids = torch.tensor([id_to_index[p] for p in gallery_labels], device=device)
    probe_ids = torch.tensor([id_to_index[p] for p in probe_labels], device=device)

    batch_size = 100  # probes per similarity block; bounds the size of the score matrix
    correct = 0
    total = len(probe_tensor)

    with torch.no_grad():
        for start in range(0, total, batch_size):
            stop = min(start + batch_size, total)
            similarities = torch.mm(probe_tensor[start:stop], gallery_tensor.t())
            # torch.max returns the first index among ties
            _, nearest = torch.max(similarities, dim=1)
            correct += int((gallery_ids[nearest] == probe_ids[start:stop]).sum().item())

    rank1_rate = correct / total if total > 0 else 0
    return rank1_rate, correct, total

In [None]:
# Extract one fused embedding per (person, pose) for the whole test split
test_dict = create_embedding_dicts(params['data_path'], model, device)

# Sweep the configured gallery sizes and record rank-1 accuracy for each
print("Evaluating rank-1 recognition performance...")
results = []
for gallery_size in params['gallery_sizes']:
    rank1_rate, correct, total = evaluate_rank1(test_dict, gallery_size=gallery_size, device=device)
    results.append(dict(gallery_size=gallery_size, rank1_rate=rank1_rate, correct=correct, total=total))
    print(f"Rank-1 recognition rate with gallery size {gallery_size}: {rank1_rate:.4f} ({rank1_rate*100:.2f}%) {correct}/{total}")

# Accuracy (in %) as a function of how many poses each person enrols
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(
    [r['gallery_size'] for r in results],
    [r['rank1_rate'] * 100 for r in results],
    'o-',
    linewidth=2,
)
ax.set_xlabel('Gallery Size')
ax.set_ylabel('Rank-1 Recognition Rate (%)')
ax.set_title('Rank-1 Recognition Rate vs. Gallery Size')
ax.grid(True)
ax.set_xticks([r['gallery_size'] for r in results])
fig.savefig(f"{params['output_dir']}/rank1_vs_gallery_size.png", dpi=300)
plt.show()

# Plain-text table of the same numbers
with open(f"{params['output_dir']}/rank1_results.txt", 'w') as f:
    f.write("Gallery Size | Rank-1 Rate | Correct/Total\n")
    f.write("-------------|-------------|-------------\n")
    f.writelines(
        f"{r['gallery_size']:12d} | {r['rank1_rate']*100:10.2f}% | {r['correct']}/{r['total']}\n"
        for r in results
    )

## Verification Performance (ROC Curve and EER)

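Evaluate verification (1:1 matching) performance: every probe–gallery pair is scored with cosine similarity, same-person pairs are treated as genuine and all others as impostors, and the ROC curve, AUC, EER and true-accept rate (TAR) at fixed false-accept rates (FAR) are derived from these scores.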

In [None]:
def compute_verification_metrics(test_dict, gallery_size=5, device=device):
    """
    Compute verification metrics (ROC curve, AUC, EER, TAR at fixed FARs).

    The gallery/probe split matches ``evaluate_rank1``: for each person with more
    than ``gallery_size`` embeddings, the first ``gallery_size`` are gallery and
    the rest are probes. Every (probe, gallery) pair is scored with cosine
    similarity; pairs of the same person are genuine, all others impostor.

    Args:
        test_dict: Dictionary with person IDs as keys and lists of embeddings as values.
        gallery_size: Number of leading embeddings per person used as gallery.
        device: Device on which the similarity matrix is computed.

    Returns:
        Dict with keys 'fpr', 'tpr', 'thresholds', 'roc_auc', 'eer', 'eer_threshold',
        'genuine_scores', 'impostor_scores' and 'far_thresholds'. ``far_thresholds``
        maps each target FAR to the ROC operating point with the highest TAR whose
        FAR does not exceed the target.

    Raises:
        ValueError: If the split leaves the gallery or the probe set empty.
    """
    gallery_embeddings, gallery_labels = [], []
    probe_embeddings, probe_labels = [], []

    for person_id, embeddings in test_dict.items():
        if len(embeddings) < gallery_size + 1:
            continue
        gallery_embeddings.extend(embeddings[:gallery_size])
        gallery_labels.extend([person_id] * gallery_size)
        probe_embeddings.extend(embeddings[gallery_size:])
        probe_labels.extend([person_id] * (len(embeddings) - gallery_size))

    if not gallery_embeddings or not probe_embeddings:
        raise ValueError(
            f"No gallery/probe embeddings for gallery_size={gallery_size}: every person "
            f"needs at least {gallery_size + 1} embeddings"
        )

    # L2-normalise so that a dot product equals cosine similarity
    gallery_tensor = F.normalize(torch.stack(gallery_embeddings).to(device), p=2, dim=1)
    probe_tensor = F.normalize(torch.stack(probe_embeddings).to(device), p=2, dim=1)

    with torch.no_grad():
        similarities = torch.mm(probe_tensor, gallery_tensor.t()).cpu().numpy()

    # Pair labels via broadcasting; ravel() keeps the probe-major order (i, j).
    same_person = (np.asarray(probe_labels)[:, None] == np.asarray(gallery_labels)[None, :]).ravel()
    y_true = same_person.astype(int)
    y_scores = similarities.ravel()

    fpr, tpr, thresholds = roc_curve(y_true, y_scores)
    roc_auc = auc(fpr, tpr)

    # EER: operating point where FAR (fpr) and FRR (1 - tpr) are closest
    fnr = 1 - tpr
    eer_idx = np.argmin(np.abs(fnr - fpr))
    eer = (fpr[eer_idx] + fnr[eer_idx]) / 2
    eer_threshold = thresholds[eer_idx]

    # TAR @ FAR: last ROC point whose FAR does not exceed the target.
    # fpr is non-decreasing, so searchsorted finds it; picking the *nearest* FPR
    # could land above the target and overstate the TAR.
    far_thresholds = {}
    for target_far in [0.001, 0.01, 0.1]:
        idx = max(int(np.searchsorted(fpr, target_far, side='right')) - 1, 0)
        far_thresholds[target_far] = {
            'threshold': thresholds[idx],
            'far': fpr[idx],
            'frr': 1 - tpr[idx],
            'tar': tpr[idx]  # True Accept Rate
        }

    # Separate genuine and impostor scores for distribution plotting
    genuine_scores = y_scores[same_person].tolist()
    impostor_scores = y_scores[~same_person].tolist()

    return {
        'fpr': fpr,
        'tpr': tpr,
        'thresholds': thresholds,
        'roc_auc': roc_auc,
        'eer': eer,
        'eer_threshold': eer_threshold,
        'genuine_scores': genuine_scores,
        'impostor_scores': impostor_scores,
        'far_thresholds': far_thresholds
    }

In [None]:
# Verification metrics are reported at one fixed gallery size
print("Computing verification metrics...")
gallery_size = 5  # Use a standard gallery size for verification metrics
verification_metrics = compute_verification_metrics(test_dict, gallery_size=gallery_size, device=device)

# Headline numbers (shared by the console output and the text report)
headline = {
    "ROC AUC": verification_metrics['roc_auc'],
    "EER": verification_metrics['eer'],
    "EER Threshold": verification_metrics['eer_threshold'],
}
far_lines = [
    f"FAR = {far:.4f}: Threshold = {data['threshold']:.4f}, FRR = {data['frr']:.4f}, TAR = {data['tar']:.4f}"
    for far, data in verification_metrics['far_thresholds'].items()
]

for name, value in headline.items():
    print(f"{name}: {value:.4f}")
print("\nPerformance at specific FARs:")
for line in far_lines:
    print(line)

# ROC curve against the chance diagonal
fig, ax = plt.subplots(figsize=(10, 8))
ax.plot(
    verification_metrics['fpr'],
    verification_metrics['tpr'],
    color='darkorange',
    lw=2,
    label=f'ROC curve (area = {verification_metrics["roc_auc"]:.4f})'
)
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Receiver Operating Characteristic')
ax.legend(loc="lower right")
ax.grid(True, alpha=0.3)
fig.savefig(f"{params['output_dir']}/roc_curve.png", dpi=300)
plt.show()

# Genuine vs. impostor score histograms with the EER threshold marked
fig, ax = plt.subplots(figsize=(10, 8))
ax.hist(verification_metrics['genuine_scores'], bins=50, alpha=0.5, color='green', label='Genuine Pairs')
ax.hist(verification_metrics['impostor_scores'], bins=50, alpha=0.5, color='red', label='Impostor Pairs')
ax.axvline(x=verification_metrics['eer_threshold'], color='black', linestyle='--',
           label=f'EER Threshold = {verification_metrics["eer_threshold"]:.4f}')
ax.set_xlabel('Similarity Score')
ax.set_ylabel('Frequency')
ax.set_title('Distribution of Similarity Scores')
ax.legend()
ax.grid(True, alpha=0.3)
fig.savefig(f"{params['output_dir']}/score_distribution.png", dpi=300)
plt.show()

# Text report with the same content as the console output
with open(f"{params['output_dir']}/verification_metrics.txt", 'w') as f:
    for name, value in headline.items():
        f.write(f"{name}: {value:.4f}\n")
    f.write("\n")
    f.write("Performance at specific FARs:\n")
    for line in far_lines:
        f.write(line + "\n")

## Individual Modality Analysis

Analyze the performance of individual modalities versus fusion.

In [None]:
def analyze_modalities(test_dict, model, device=device, data_root=None,
                       gallery_size=5, num_poses=10):
    """
    Compare rank-1 identification of each single-modality branch with the fusion.

    For every person in ``test_dict`` the periocular/forehead/iris images of each
    pose are re-loaded from ``<data_root>/<modality>/test/<person_id>/`` (sorted
    file names, hidden files skipped; the k-th file is pose k). Per-branch
    embeddings come from ``model.periocular_cnn`` / ``model.forehead_cnn`` /
    ``model.iris_cnn`` and the fused embedding from ``model(...)``. Each embedding
    set is then scored with ``evaluate_rank1``.

    Args:
        test_dict: Dictionary with person IDs as keys; only the keys are used here.
        model: BiometricModel instance.
        device: Device to run computations on.
        data_root: Dataset root; defaults to ``params['data_path']``.
        gallery_size: Gallery poses per person for the rank-1 evaluation.
        num_poses: Maximum number of poses loaded per person.

    Returns:
        Dict mapping 'periocular' / 'forehead' / 'iris' / 'fusion' to
        ``{'rank1_rate', 'correct', 'total'}``. A modality with no evaluable
        persons gets ``{'rank1_rate': 0.0, 'correct': 0, 'total': 0}``.
    """
    if data_root is None:
        data_root = params['data_path']

    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    model = model.to(device)
    model.eval()

    # Single-modality branches, in the argument order expected by model(...)
    branches = {
        'periocular': model.periocular_cnn,
        'forehead': model.forehead_cnn,
        'iris': model.iris_cnn,
    }
    modalities = [*branches, 'fusion']
    modality_dict = {modality: {person_id: [] for person_id in test_dict} for modality in modalities}

    print("Extracting embeddings for individual modalities...")
    for person_id in tqdm(list(test_dict.keys())):
        # List each modality folder once per person instead of once per pose.
        listings = []
        for modality in branches:
            person_dir = Path(data_root) / modality / 'test' / person_id
            if not person_dir.is_dir():
                break
            listings.append([
                person_dir / name
                for name in sorted(os.listdir(person_dir))
                if not name.startswith('.')  # skip .DS_Store and other hidden/metadata files
            ])
        if len(listings) != len(branches):
            continue  # a modality folder is missing, so no pose has all three images

        # Pose k is usable only when every modality folder has at least k images.
        for pose_idx in range(min(num_poses, *(len(files) for files in listings))):
            images = [load_image(files[pose_idx], transform).unsqueeze(0).to(device)
                      for files in listings]
            with torch.no_grad():
                for (modality, branch), image in zip(branches.items(), images):
                    embedding = F.normalize(branch(image), dim=1).cpu().squeeze(0)
                    modality_dict[modality][person_id].append(embedding)

                # Fusion (all modalities combined)
                fusion_emb = model(*images).cpu().squeeze(0)
                modality_dict['fusion'][person_id].append(fusion_emb)

    results = {}
    print("Evaluating rank-1 recognition for each modality...")
    for modality in modalities:
        # Filter out persons with insufficient samples
        filtered_dict = {k: v for k, v in modality_dict[modality].items() if len(v) >= gallery_size + 1}
        if not filtered_dict:
            print(f"Warning: No valid samples for {modality}")
            # Same shape as the success path, so callers can always index ['rank1_rate'].
            results[modality] = {'rank1_rate': 0.0, 'correct': 0, 'total': 0}
            continue

        rank1_rate, correct, total = evaluate_rank1(filtered_dict, gallery_size=gallery_size, device=device)
        results[modality] = {
            'rank1_rate': rank1_rate,
            'correct': correct,
            'total': total
        }
        print(f"{modality.capitalize()} Rank-1: {rank1_rate:.4f} ({rank1_rate*100:.2f}%) {correct}/{total}")

    return results

In [None]:
# Score each single-modality branch and the fused model on the test split
print("Analyzing individual modalities vs. fusion...")
modality_results = analyze_modalities(test_dict, model, device)

modalities = ['periocular', 'forehead', 'iris', 'fusion']
rates = [modality_results[m]['rank1_rate'] * 100 for m in modalities]
colors = ['blue', 'green', 'red', 'purple']

# Bar chart: one bar per modality, annotated with its rank-1 rate
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(modalities, rates, color=colors)
ax.set_ylabel('Rank-1 Recognition Rate (%)')
ax.set_title('Performance Comparison: Individual Modalities vs. Fusion')
ax.grid(True, alpha=0.3, axis='y')
for position, value in enumerate(rates):
    ax.text(position, value + 1, f"{value:.1f}%", ha='center')  # label just above the bar
fig.savefig(f"{params['output_dir']}/modality_comparison.png", dpi=300)
plt.show()

# Plain-text table of the same numbers
with open(f"{params['output_dir']}/modality_results.txt", 'w') as f:
    f.write("Modality | Rank-1 Rate | Correct/Total\n")
    f.write("---------|-------------|-------------\n")
    f.writelines(
        f"{m.capitalize():8s} | {modality_results[m]['rank1_rate']*100:10.2f}% | "
        f"{modality_results[m]['correct']}/{modality_results[m]['total']}\n"
        for m in modalities
    )

## Embedding Visualization

Visualize embeddings using t-SNE to see the clustering of identities.

In [None]:
def _plot_tsne_scatter(points, label_array, person_ids, colors, with_legend):
    """Draw one t-SNE scatter figure (one color per person) and return the figure."""
    fig = plt.figure(figsize=(12, 10))
    for idx, person_id in enumerate(person_ids):
        mask = label_array == idx  # rows belonging to the idx-th person
        style = {'color': colors[idx], 's': 30, 'alpha': 0.7}
        if with_legend:
            style['label'] = f"Person {person_id}"
        plt.scatter(points[mask, 0], points[mask, 1], **style)

    plt.title("t-SNE Visualization of Biometric Embeddings")
    plt.xlabel("t-SNE Dimension 1")
    plt.ylabel("t-SNE Dimension 2")
    plt.grid(True, alpha=0.3)
    if with_legend:
        # Put the legend outside the axes so it does not cover the clusters
        plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', fontsize='small')
    return fig


def visualize_embeddings(test_dict, num_persons=30):
    """
    Create t-SNE visualizations of the embeddings of the first ``num_persons`` persons.

    Two figures are saved to ``params['output_dir']``: ``embeddings_tsne.png`` (with
    a per-person legend) and ``embeddings_tsne_no_legend.png``. Both figures are
    left open so the caller can ``plt.show()`` them.

    Args:
        test_dict: Dictionary with person IDs as keys and lists of embeddings as values
        num_persons: Number of persons to visualize (fewer if test_dict is smaller)

    Returns:
        Tuple ``(embeddings_tsne, labels, person_ids)``: the 2-D projection (one row
        per embedding), the integer person index of each row, and the person IDs
        in index order.

    Raises:
        ValueError: If fewer than two embeddings are available (t-SNE needs a
            perplexity strictly between 0 and the number of samples).
    """
    embeddings = []
    labels = []
    person_ids = []

    # Take the first num_persons persons (dict insertion order)
    for index, (person_id, person_embeddings) in enumerate(list(test_dict.items())[:num_persons]):
        embeddings.extend(person_embeddings)
        labels.extend([index] * len(person_embeddings))
        person_ids.append(person_id)

    if len(embeddings) < 2:
        raise ValueError("t-SNE needs at least two embeddings to project")

    embeddings_np = torch.stack(embeddings).detach().cpu().numpy()

    print("Computing t-SNE projection...")
    tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embeddings) - 1))
    embeddings_tsne = tsne.fit_transform(embeddings_np)

    # Size the palette by the persons actually selected, which may be < num_persons
    colors = plt.cm.jet(np.linspace(0, 1, len(person_ids)))
    label_array = np.asarray(labels)

    fig = _plot_tsne_scatter(embeddings_tsne, label_array, person_ids, colors, with_legend=True)
    # bbox_inches='tight' keeps the outside legend inside the saved image
    fig.savefig(f"{params['output_dir']}/embeddings_tsne.png", dpi=300, bbox_inches='tight')

    # Same plot without the legend for clarity
    fig = _plot_tsne_scatter(embeddings_tsne, label_array, person_ids, colors, with_legend=False)
    fig.savefig(f"{params['output_dir']}/embeddings_tsne_no_legend.png", dpi=300)

    print("t-SNE visualization saved")
    return embeddings_tsne, labels, person_ids

In [None]:
# Project the first `visualize_persons` identities with t-SNE and display both figures
print("Creating t-SNE visualization...")
embeddings_tsne, labels, person_ids = visualize_embeddings(
    test_dict, params['visualize_persons']
)
plt.show()

print(f"Testing completed! All results have been saved to: {params['output_dir']}")

## Summary of Results

Summarize all evaluation metrics.

In [None]:
# Collect the headline numbers from the earlier cells.
# max() keeps the first of equal rates, matching np.argmax tie-breaking.
best_rank1 = max(results, key=lambda r: r['rank1_rate'])
summary = {
    'model': params['model_path'],
    'rank1': {
        'best': best_rank1['rank1_rate'],
        'gallery_size': best_rank1['gallery_size'],
    },
    'verification': {
        'roc_auc': verification_metrics['roc_auc'],
        'eer': verification_metrics['eer'],
    },
    'modalities': {
        m: modality_results[m]['rank1_rate'] for m in ['periocular', 'forehead', 'iris', 'fusion']
    },
}

# Build the report once; the console and summary.txt share the same lines
summary_lines = [
    "===== SUMMARY OF RESULTS =====",
    f"Model: {summary['model']}",
    f"Best Rank-1 Recognition Rate: {summary['rank1']['best']*100:.2f}% (Gallery Size: {summary['rank1']['gallery_size']})",
    f"ROC AUC: {summary['verification']['roc_auc']:.4f}",
    f"Equal Error Rate (EER): {summary['verification']['eer']*100:.2f}%",
    "",
    "Modality Performance (Rank-1):",
    *(f"  {modality.capitalize()}: {rate*100:.2f}%" for modality, rate in summary['modalities'].items()),
    "===============================",
]

print("\n" + "\n".join(summary_lines))

with open(f"{params['output_dir']}/summary.txt", 'w') as f:
    f.write("\n".join(summary_lines) + "\n")