# Model Evaluation on Benchmark Datasets

This notebook evaluates the demographic analysis models on standard benchmark datasets:
- UTKFace
- FairFace
- AffectNet

In [None]:
import sys
sys.path.append('..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm import tqdm

from src.core.pipeline import DemographicPipeline
from src.models.age_estimator import AgeEstimator
from src.models.gender_classifier import GenderClassifier

%matplotlib inline
sns.set_style('whitegrid')

## 1. Age Estimation Evaluation

In [None]:
def evaluate_age_estimation(model, test_data):
    """Evaluate age estimation performance.

    Runs ``model.predict`` on every image in ``test_data``, prints a summary
    of the metrics, and returns them together with the raw arrays.

    Args:
        model: Object whose ``predict(img)`` returns a mapping with an
            ``'age'`` entry.
        test_data: Iterable of ``(image, true_age)`` pairs.

    Returns:
        dict with keys:
            'mae': mean absolute error, in years.
            'rmse': root mean squared error, in years.
            'cs_5' / 'cs_10': percentage of samples whose absolute error
                is at most 5 / 10 years.
            'predictions' / 'ground_truth': numpy arrays of the collected
                values, in iteration order.
    """
    predictions = []
    ground_truth = []

    for img, true_age in tqdm(test_data):
        result = model.predict(img)
        predictions.append(result['age'])
        ground_truth.append(true_age)

    predictions = np.array(predictions)
    ground_truth = np.array(ground_truth)

    # Compute the errors in float64 rather than in the inputs' own dtype:
    # unsigned integer ages (e.g. np.uint8 labels) would otherwise wrap
    # around on subtraction and silently corrupt every metric below.
    errors = predictions.astype(np.float64) - ground_truth.astype(np.float64)
    abs_errors = np.abs(errors)

    mae = np.mean(abs_errors)
    rmse = np.sqrt(np.mean(errors ** 2))
    # Cumulative score: share of samples within the given error tolerance.
    cs_5 = np.mean(abs_errors <= 5) * 100
    cs_10 = np.mean(abs_errors <= 10) * 100

    print("Age Estimation Metrics:")
    print(f"  MAE: {mae:.2f} years")
    print(f"  RMSE: {rmse:.2f} years")
    print(f"  CS@5: {cs_5:.2f}%")
    print(f"  CS@10: {cs_10:.2f}%")

    return {
        'mae': mae,
        'rmse': rmse,
        'cs_5': cs_5,
        'cs_10': cs_10,
        'predictions': predictions,
        'ground_truth': ground_truth
    }

## 2. Gender Classification Evaluation

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

def evaluate_gender_classification(model, test_data, pos_label='male'):
    """Evaluate gender classification performance.

    Runs ``model.predict`` on every image in ``test_data``, prints the
    metrics, shows a confusion-matrix heatmap, and returns the metrics.

    Args:
        model: Object whose ``predict(img)`` returns a mapping with a
            ``'gender'`` entry.
        test_data: Iterable of ``(image, true_gender)`` pairs.
        pos_label: Label treated as the positive class for precision,
            recall and F1. Defaults to ``'male'``; override it when the
            dataset encodes labels differently.

    Returns:
        dict with keys 'accuracy', 'precision', 'recall', 'f1' (all in
        percent) and 'confusion_matrix' (rows are true labels, columns are
        predicted labels, both in sorted label order).
    """
    predictions = []
    ground_truth = []

    for img, true_gender in tqdm(test_data):
        result = model.predict(img)
        predictions.append(result['gender'])
        ground_truth.append(true_gender)

    accuracy = accuracy_score(ground_truth, predictions) * 100
    precision = precision_score(
        ground_truth, predictions, average='binary', pos_label=pos_label
    ) * 100
    recall = recall_score(
        ground_truth, predictions, average='binary', pos_label=pos_label
    ) * 100
    f1 = f1_score(
        ground_truth, predictions, average='binary', pos_label=pos_label
    ) * 100

    print("Gender Classification Metrics:")
    print(f"  Accuracy: {accuracy:.2f}%")
    print(f"  Precision: {precision:.2f}%")
    print(f"  Recall: {recall:.2f}%")
    print(f"  F1-Score: {f1:.2f}%")

    # Fix the label order explicitly so the matrix rows/columns and the
    # heatmap tick labels are guaranteed to refer to the same classes.
    labels = sorted(set(ground_truth) | set(predictions))
    cm = confusion_matrix(ground_truth, predictions, labels=labels)

    plt.figure(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=labels,
        yticklabels=labels,
    )
    plt.title('Gender Classification Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'confusion_matrix': cm
    }

## 3. Visualization

In [None]:
def plot_age_predictions(predictions, ground_truth):
    """Plot age predictions vs ground truth.

    Shows two figures: a scatter plot of predicted against true age with a
    "perfect prediction" diagonal, and a histogram of the signed errors
    (prediction minus ground truth).

    Args:
        predictions: Sequence or array of predicted ages.
        ground_truth: Sequence or array of true ages, aligned with
            ``predictions``.
    """
    # Work on float arrays: this accepts plain lists and keeps the error
    # subtraction from wrapping around on unsigned integer inputs.
    predictions = np.asarray(predictions, dtype=float)
    ground_truth = np.asarray(ground_truth, dtype=float)

    # Keep the usual 0-100 diagonal, but stretch it when the data goes
    # beyond that range so the reference line always spans every point.
    lower, upper = 0.0, 100.0
    if predictions.size and ground_truth.size:
        lower = min(lower, predictions.min(), ground_truth.min())
        upper = max(upper, predictions.max(), ground_truth.max())

    plt.figure(figsize=(10, 6))
    plt.scatter(ground_truth, predictions, alpha=0.5)
    plt.plot([lower, upper], [lower, upper], 'r--', label='Perfect Prediction')
    plt.xlabel('Ground Truth Age')
    plt.ylabel('Predicted Age')
    plt.title('Age Prediction: Ground Truth vs Predictions')
    plt.legend()
    plt.grid(True)
    plt.show()

    plt.figure(figsize=(10, 6))
    errors = predictions - ground_truth
    plt.hist(errors, bins=50, edgecolor='black')
    plt.xlabel('Prediction Error (years)')
    plt.ylabel('Frequency')
    plt.title('Age Prediction Error Distribution')
    plt.axvline(x=0, color='r', linestyle='--', label='Zero Error')
    plt.legend()
    plt.show()

## 4. Run Evaluation

**Note:** You need to download and prepare the benchmark datasets before running this section.

In [None]:
# Example usage (uncomment and modify with your dataset path)
# age_estimator = AgeEstimator(device='cuda', ensemble=True)
# gender_classifier = GenderClassifier(device='cuda')

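# Load your test data.
# NOTE: load_utkface_dataset / load_fairface_dataset are placeholders; they are
# not defined or imported in the cells above, so supply your own loaders.
# Each loader must yield (image, label) pairs, which is how the evaluation
# functions unpack test_data.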
# test_data_age = load_utkface_dataset('path/to/utkface')
# test_data_gender = load_fairface_dataset('path/to/fairface')

# age_results = evaluate_age_estimation(age_estimator, test_data_age)
# gender_results = evaluate_gender_classification(gender_classifier, test_data_gender)

# plot_age_predictions(age_results['predictions'], age_results['ground_truth'])