# In [1]:
import re
import string
import pandas as pd
from collections import Counter
from typing import List, Tuple, Optional, Union
from logging import Logger

class MetricsHelper:
    """
    Text-normalizing evaluation metrics: exact match (EM), token-level F1
    and accuracy, plus a DataFrame-level report.

    Every method is a classmethod; the class keeps no instance state.

    Author: -
    Date: 2025-06-01
    """

    # Compiled once at class creation rather than looked up on every call.
    _ARTICLES_RE = re.compile(r'\b(a|an|the)\b')
    # Translation table that deletes every character in string.punctuation.
    _PUNCT_TABLE = str.maketrans('', '', string.punctuation)

    @classmethod
    def normalize_text(cls, s: Optional[str]) -> str:
        """
        Return a canonical form of ``s`` for comparison.

        Steps, in order: lowercase, replace the standalone words
        ``a``/``an``/``the`` with a space, delete ASCII punctuation
        (``string.punctuation``), collapse runs of whitespace.

        Args:
            s: Text to normalize. ``None`` yields ``""``; any other
                non-string value (bool, int, float, ...) is converted
                with ``str()`` first.

        Returns:
            The normalized string (possibly empty).
        """
        if s is None:
            return ""
        # Convert non-string inputs (like bool, int, float) to string
        if not isinstance(s, str):
            s = str(s)
        s = cls._ARTICLES_RE.sub(' ', s.lower())
        s = s.translate(cls._PUNCT_TABLE)
        return ' '.join(s.split())

    @classmethod
    def compute_em(cls, pred: str, truth: str) -> int:
        """
        Exact match between a prediction and a single reference.

        Returns:
            1 if both values are equal after ``normalize_text``, else 0.
        """
        return int(cls.normalize_text(pred) == cls.normalize_text(truth))

    @classmethod
    def compute_f1(cls, pred: str, truth: str) -> float:
        """
        Token-level F1 between a prediction and a single reference.

        Both values are normalized with ``normalize_text`` and split on
        whitespace; overlap is counted with multiplicity.

        Returns:
            The harmonic mean of token precision and recall. 1.0 when both
            sides normalize to no tokens at all, 0.0 when there is no
            token overlap.
        """
        pred_tokens = Counter(cls.normalize_text(pred).split())
        truth_tokens = Counter(cls.normalize_text(truth).split())
        # Two empty answers are identical: score them 1.0 so F1 never falls
        # below EM (compute_em returns 1 for this pair).
        if not pred_tokens and not truth_tokens:
            return 1.0
        num_same = sum((pred_tokens & truth_tokens).values())
        if num_same == 0:
            return 0.0
        precision = num_same / sum(pred_tokens.values())
        recall = num_same / sum(truth_tokens.values())
        return 2 * precision * recall / (precision + recall)

    @classmethod
    def max_em_f1(cls, pred: str, truths: Union[str, List[str]]) -> Tuple[int, float]:
        """
        Best EM and best F1 of ``pred`` over one or more references.

        Args:
            pred: The predicted answer.
            truths: A single reference, or a ``list`` of references. Any
                non-list value is treated as one reference.

        Returns:
            ``(max EM, max F1)``; the two maxima are taken independently.
            ``(0, 0.0)`` when ``truths`` is an empty list.
        """
        truth_list = truths if isinstance(truths, list) else [truths]
        # max() raises ValueError on an empty sequence; no reference means
        # nothing can match.
        if not truth_list:
            return 0, 0.0
        best_em = max(cls.compute_em(pred, t) for t in truth_list)
        best_f1 = max(cls.compute_f1(pred, t) for t in truth_list)
        return best_em, best_f1

    @classmethod
    def compute_accuracy(cls, preds: List[str], truths: List[str]) -> float:
        """
        Fraction of predictions equal to their reference after normalization.

        Pairs are formed with ``zip``, so surplus items in the longer list
        are never compared; the denominator is always ``len(preds)``.

        Returns:
            ``correct / len(preds)``, or 0.0 when ``preds`` is empty.
        """
        correct = sum(
            cls.normalize_text(p) == cls.normalize_text(t)
            for p, t in zip(preds, truths)
        )
        total = len(preds)
        return correct / total if total > 0 else 0.0

    @classmethod
    def evaluate(
        cls,
        results_df: pd.DataFrame,
        logger: Logger,
        column_name: str = 'final_answer',
        log: bool = False,
    ) -> None:
        """
        Compute EM, average F1 and accuracy over a DataFrame and log them.

        Args:
            results_df: Frame holding predictions in ``column_name`` and
                references in ``'g_t'``. A ``'g_t'`` cell may be a single
                value or a list of acceptable references.
            logger: Logger that receives the report via ``logger.info``.
            column_name: Name of the prediction column.
            log: When True, also log the example count and the raw EM
                fraction.

        Returns:
            None. Results are only logged.

        Note:
            Accuracy is computed on the ``astype(str)`` form of both columns,
            so a list-valued ``'g_t'`` cell is compared as one stringified
            value there, whereas EM/F1 take the best score over its items.
        """
        n = len(results_df)
        preds = results_df[column_name].astype(str).tolist()
        truths = results_df['g_t'].astype(str).tolist()
        accuracy = cls.compute_accuracy(preds, truths)

        # Walk the two columns directly instead of iterrows(), which builds a
        # Series per row and may coerce values to a common dtype. The raw
        # prediction is passed through unchanged: normalize_text maps None to
        # "" and stringifies everything else, so falsy answers such as 0 or
        # False are scored as "0" / "false" rather than as an empty string.
        scores = [
            cls.max_em_f1(pred, truth)
            for pred, truth in zip(results_df[column_name], results_df['g_t'])
        ]
        total_em = sum(em for em, _ in scores)
        # Guard both ratios so an empty frame reports 0% instead of raising
        # ZeroDivisionError.
        em_rate = total_em / n if n else 0.0
        avg_f1 = sum(f1 for _, f1 in scores) / n if n else 0.0

        if log:
            logger.info(f"Evaluated {n} examples")
            logger.info(f"Exact Match: {total_em}/{n} = {em_rate:.2%}")
        else:
            logger.info(f"Exact Match: {em_rate:.2%}")
        logger.info(f"Average F1 Score: {avg_f1:.2%}")
        logger.info(f"Accuracy: {accuracy:.2%}")