In [None]:
"""
Evaluation script for EMONET-FACE benchmark mappings on FERD and AffectNet.
---------------------------------------------------------------------------
This script evaluates model predictions on FERD and AffectNet datasets
using dataset-specific emotion mappings. It computes overall and per-emotion
accuracies and outputs a Markdown comparison table.

Author: LAION AI

NOTE: Please use the merge.sh split before you continue
"""

import pandas as pd
import numpy as np
from config import parameters, affectnet_mapping, ferd_mapping
from IPython.display import display, Markdown

# ----------------------------------------------------------------------
# Utility function: compute emotion accuracies
# ----------------------------------------------------------------------
def evaluate_dataset(df, gt_col, results_col, emotion_mapping):
    """
    Compute per-emotion and overall accuracy after mapping model predictions.

    For every row, the fine-grained predictions whose key has a (truthy)
    target in ``emotion_mapping`` are kept, their ``"mean_subtracted"``
    scores are turned into probabilities with a softmax, and the row counts
    as correct when at least one kept prediction maps to the ground-truth
    label and has a probability above ``1.5 / n_mapped``.

    Rows with a missing ground-truth label, with a results cell that is not
    a mapping, or without any mapped prediction are counted as incorrect.

    Args:
        df (pd.DataFrame): Dataset with ground truth and prediction results.
        gt_col (str): Column with ground truth emotion labels.
        results_col (str): Column containing prediction dictionaries; each
            value must itself expose a ``"mean_subtracted"`` entry.
        emotion_mapping (dict): Mapping from fine-grained to target categories.

    Returns:
        dict: {overall_accuracy, emotion_accuracies, n, mapped_total_ratio}

    Side effects:
        Adds (or overwrites) a boolean ``"is_correct"`` column on ``df``.
    """
    n_mapped = sum(v is not None for v in emotion_mapping.values())
    total = len(emotion_mapping)
    # With no mapped emotion no prediction can ever be kept, so every row is
    # incorrect anyway; an infinite cutoff avoids a division by zero.
    cutoff = 1.5 / n_mapped if n_mapped else float("inf")

    def is_correct(gt, results):
        # A missing label or a missing/malformed results cell cannot be right.
        if pd.isna(gt) or not hasattr(results, "items"):
            return False
        preds = [
            (emotion_mapping.get(k), v["mean_subtracted"])
            for k, v in results.items()
            if emotion_mapping.get(k)
        ]
        if not preds:
            return False
        emotions, vals = zip(*preds)
        scores = np.asarray(vals, dtype=float)
        # Subtracting the maximum leaves the softmax unchanged mathematically
        # but keeps np.exp from overflowing to inf (and the ratio to NaN).
        exp_scores = np.exp(scores - scores.max())
        probs = exp_scores / exp_scores.sum()
        return any(e == gt and p > cutoff for e, p in zip(emotions, probs))

    # Iterating the two columns directly (instead of a row-wise apply) also
    # yields a well-formed boolean column for an empty frame.
    df["is_correct"] = np.array(
        [is_correct(gt, res) for gt, res in zip(df[gt_col], df[results_col])],
        dtype=bool,
    )
    return {
        "overall_accuracy": df["is_correct"].mean(),
        "emotion_accuracies": df.groupby(gt_col)["is_correct"].mean().to_dict(),
        "n": len(df),
        "mapped_total_ratio": f"{n_mapped} / {total}",
    }

# ----------------------------------------------------------------------
# Load and evaluate all datasets
# ----------------------------------------------------------------------
def load_datasets(params):
    """
    Load and prepare datasets according to the provided configuration.

    Each entry of ``params`` is read as a JSON-lines file and receives a
    ``"gt_emotion_final"`` column holding the ground-truth emotion:

    * FERD (path contains ``"ferd"``): the label is parsed from the value in
      ``emotion_col`` by taking the text after the last underscore and before
      the first dot, lower-casing it and turning ``"surprised"`` into
      ``"surprise"``.
    * Anything else is treated as AffectNet: the value in ``emotion_col`` is
      translated through the entry's ``"gt_mapping"``.

    Args:
        params (iterable of dict): Entries with the keys ``"dataset"`` (path),
            ``"emotion_col"``, ``"results_col"`` and, for non-FERD datasets,
            ``"gt_mapping"``.

    Returns:
        list[tuple]: ``(name, df, "gt_emotion_final", results_col)`` per
        entry, where ``name`` is ``"FERD"`` or ``"AffectNet"``.

    Raises:
        ValueError: If a non-FERD entry provides no ``"gt_mapping"``.
    """
    datasets = []
    for item in params:
        path = item["dataset"]
        # Decide once whether this is FERD so the label parsing and the
        # dataset name (which selects the emotion mapping later) can never
        # disagree for paths such as "ferd_val.jsonl".
        is_ferd = "ferd" in path
        df = pd.read_json(path, lines=True)
        raw_labels = df[item["emotion_col"]]
        if is_ferd:
            df["gt_emotion_final"] = raw_labels.apply(
                lambda x: x.split("_")[-1].split(".")[0].lower().replace("surprised", "surprise")
            )
        else:
            gt_mapping = item.get("gt_mapping")
            if gt_mapping is None:
                raise ValueError(
                    f"Dataset {path!r} is not FERD and therefore needs a 'gt_mapping' entry."
                )
            df["gt_emotion_final"] = raw_labels.map(gt_mapping)
        datasets.append(
            (
                "FERD" if is_ferd else "AffectNet",
                df,
                "gt_emotion_final",
                item["results_col"],
            )
        )
    return datasets

# Load every configured dataset, then score each one with the mapping that
# belongs to it while collecting the union of ground-truth emotions seen.
datasets = load_datasets(parameters)

final_results = {}
all_emotions = set()
for name, df, gt_col, res_col in datasets:
    if name == "FERD":
        mapping = ferd_mapping
    else:
        mapping = affectnet_mapping
    result = evaluate_dataset(df, gt_col, res_col, mapping)
    final_results[name] = result
    # Iterating the accuracy dict yields its keys, i.e. the emotion labels.
    all_emotions.update(result["emotion_accuracies"])

# ----------------------------------------------------------------------
# Generate Markdown comparison table
# ----------------------------------------------------------------------
# Rows of the table are the emotions in alphabetical order; a dataset that was
# not evaluated falls back to an empty result so its cells render as "-".
emotions = sorted(all_emotions)
ferd = final_results.get("FERD", {})
aff = final_results.get("AffectNet", {})

def fmt(x):
    """Render *x* with four decimal places, or "-" when it is None."""
    if x is None:
        return "-"
    return f"{x:.4f}"

# Per-emotion accuracy lookups, hoisted so they are not rebuilt for every row.
ferd_emotion_acc = ferd.get("emotion_accuracies", {})
aff_emotion_acc = aff.get("emotion_accuracies", {})

# Table header plus the dataset-level summary rows.
md = [
    "| | **FERD Acc.** | **AffectNet Acc.** |",
    "|---|---|---|",
    f"| n | {ferd.get('n', '-')} | {aff.get('n', '-')} |",
    f"| Mapped / Total | {ferd.get('mapped_total_ratio', '-')} | {aff.get('mapped_total_ratio', '-')} |",
    "| **Emotion** | | |",
]
# One row per emotion; an emotion absent from a dataset is rendered as "-".
for emo in emotions:
    md.append(f"| {emo.capitalize()} | {fmt(ferd_emotion_acc.get(emo))} | {fmt(aff_emotion_acc.get(emo))} |")

md.append("| **Overall** | | |")
md.append(f"| Mean Acc. | **{fmt(ferd.get('overall_accuracy'))}** | **{fmt(aff.get('overall_accuracy'))}** |")

md_table = "\n".join(md)

print("\n=== FINAL COMPARISON TABLE ===")
display(Markdown(md_table))

# An explicit encoding keeps the written file identical across platforms
# instead of depending on the locale's default encoding.
with open("results_ferd_affectnet.md", "w", encoding="utf-8") as f:
    f.write(md_table)



=== FINAL COMPARISON TABLE ===


| | **FERD Acc.** | **AffectNet Acc.** |
|---|---|---|
| n | 152 | 31002 |
| Mapped / Total | 40 / 40 | 40 / 40 |
| **Emotion** | | |
| Anger | 0.7368 | 0.7705 |
| Contempt | 0.3158 | 0.2875 |
| Disgust | 0.7895 | 0.4053 |
| Fear | 1.0000 | 0.6908 |
| Happy | 1.0000 | 0.9925 |
| Neutral | 0.8421 | 0.7896 |
| Sad | 0.7895 | 0.8394 |
| Surprise | 0.7895 | 0.9870 |
| **Overall** | | |
| Mean Acc. | **0.7829** | **0.7572** |