# In [1]:
import math
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import brier_score_loss
from torchmetrics.classification import BinaryCalibrationError


class CalibrationMetrics:
    """
    Metric utilities for confidence calibration & selective prediction.
    Pass the column names you want (prob_col, correct_col) to every method.

    Every metric first drops rows whose confidence or correctness value is
    missing, non-numeric or non-finite, and clips confidences to [0, 1].
    """

    # ---------- helpers ----------
    @staticmethod
    def _as_float(series: pd.Series) -> np.ndarray:
        """Coerce a column to a float ndarray; unparseable or missing values become NaN."""
        # astype(float) also turns nullable-dtype NA (Int64, boolean) into NaN so
        # np.isfinite can treat every flavour of "missing" the same way.
        return pd.to_numeric(series, errors="coerce").astype(float).to_numpy()

    @classmethod
    def _valid_mask(cls, df: pd.DataFrame, prob_col: str, correct_col: str) -> np.ndarray:
        """Boolean ndarray: True where both prob and correctness are finite numbers."""
        return np.isfinite(cls._as_float(df[prob_col])) & np.isfinite(cls._as_float(df[correct_col]))

    @classmethod
    def _y_true_prob(cls, df: pd.DataFrame, prob_col: str, correct_col: str) -> tuple[np.ndarray, np.ndarray]:
        """
        Return (y_true, y_prob) arrays filtered to valid rows with prob clipped to [0,1].

        y_true is cast to int (non-integer correctness values are truncated);
        y_prob is float.
        """
        m = cls._valid_mask(df, prob_col, correct_col)
        # Use the same numeric coercion as the mask so validation and casting agree.
        y_true = cls._as_float(df[correct_col])[m].astype(int)
        y_prob = np.clip(cls._as_float(df[prob_col])[m], 0.0, 1.0)
        return y_true, y_prob

    @staticmethod
    def _trapezoid(y: np.ndarray, x: np.ndarray) -> float:
        """Trapezoidal integral of y over x, portable across NumPy 1.x and 2.x."""
        # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
        # the `or` short-circuits so np.trapz is only touched on old NumPy.
        integrate = getattr(np, "trapezoid", None) or np.trapz
        return float(integrate(y, x))

    @staticmethod
    def _accuracy_curve(y_true: np.ndarray, y_prob: np.ndarray) -> np.ndarray:
        """Running accuracy when examples are admitted in order of decreasing confidence."""
        # Stable sort: tied confidences keep their original row order, so the
        # curve (and every area derived from it) is reproducible.
        order = np.argsort(-y_prob, kind="stable")
        return np.cumsum(y_true[order]) / np.arange(1, len(y_true) + 1)

    @classmethod
    def _curve_area(cls, accuracy: np.ndarray, coverage: np.ndarray, anchor: bool) -> float:
        """Area under an accuracy-coverage curve, optionally prepending the point (0, 1)."""
        if anchor:
            return cls._trapezoid(np.concatenate(([1.0], accuracy)), np.concatenate(([0.0], coverage)))
        return cls._trapezoid(accuracy, coverage)

    # ---------- metrics ----------
    @classmethod
    def normalized_selective_auc(cls, df: pd.DataFrame, prob_col: str, correct_col: str, anchor: bool = False):
        """
        Normalized selective AUC:
          - 0.0 = random ranking baseline (in expectation over random orderings)
          - 1.0 = perfect ranking (all correct before all incorrect)
        Returns (normalized_auc, raw_auc, coverage, accuracy_curve).

        If anchor=True, the curve is anchored at (coverage=0, accuracy=1).
        The random baseline is the expected trapezoidal area under a uniformly
        random ordering, whose expected running accuracy equals acc_full at
        every coverage.  Since the curve starts at coverage 1/n, that area is
        acc_full * (n - 1) / n without anchoring and
        acc_full + (1 - acc_full) / (2 n) with anchoring.
        When the perfect and random areas coincide (all rows correct, all
        incorrect, or a single row) the normalized value is reported as 0.0.
        Returns (nan, nan, [], []) when there are no valid rows.
        """
        y_true, y_prob = cls._y_true_prob(df, prob_col, correct_col)
        n = len(y_true)
        if n == 0:
            return math.nan, math.nan, np.array([]), np.array([])

        coverage = np.arange(1, n + 1) / n
        accuracy_curve = cls._accuracy_curve(y_true, y_prob)
        auc_raw = cls._curve_area(accuracy_curve, coverage, anchor)

        # Perfect ranking: every correct example admitted before any incorrect one.
        y_perfect = np.sort(y_true)[::-1]
        auc_perfect = cls._curve_area(np.cumsum(y_perfect) / np.arange(1, n + 1), coverage, anchor)

        # Random ranking: integrate the *expected* curve with the same rule, so the
        # baseline lives on exactly the same coverage grid as auc_raw / auc_perfect.
        acc_full = float(np.mean(y_true))
        auc_random = cls._curve_area(np.full(n, acc_full), coverage, anchor)

        denom = auc_perfect - auc_random
        if math.isclose(denom, 0.0, abs_tol=1e-12):
            n_auc = 0.0
        else:
            n_auc = (auc_raw - auc_random) / denom

        return float(n_auc), float(auc_raw), coverage, accuracy_curve

    @classmethod
    def selective_auc(cls, df: pd.DataFrame, prob_col: str, correct_col: str):
        """
        AUC of the selective accuracy-coverage curve (coverage runs from 1/n to 1).
        Returns (auc, coverage_array, accuracy_curve_array); (nan, [], []) if no valid rows.
        """
        y_true, y_prob = cls._y_true_prob(df, prob_col, correct_col)
        n = len(y_true)
        if n == 0:
            return math.nan, np.array([]), np.array([])
        coverage = np.arange(1, n + 1) / n
        accuracy_curve = cls._accuracy_curve(y_true, y_prob)
        return cls._trapezoid(accuracy_curve, coverage), coverage, accuracy_curve

    @classmethod
    def ece_torchmetrics_binary(
        cls, df: pd.DataFrame, prob_col: str, correct_col: str, n_bins: int = 10, norm: str = "l2"
    ) -> float:
        """
        Calibration error via torchmetrics' BinaryCalibrationError.

        norm='l1' is the classic ECE (bin-weighted mean |accuracy - confidence|),
        norm='l2' is the square root of the bin-weighted mean squared gap
        (RMS calibration error, not the squared value), and norm='max' is the
        largest per-bin gap (MCE).  Returns nan when there are no valid rows.
        """
        y_true, y_prob = cls._y_true_prob(df, prob_col, correct_col)
        if len(y_true) == 0:
            return math.nan
        eps = 1e-12
        y_prob_t = torch.tensor(np.clip(y_prob, eps, 1 - eps), dtype=torch.float32)
        y_true_t = torch.tensor(y_true.astype(int), dtype=torch.long)
        metric = BinaryCalibrationError(n_bins=n_bins, norm=norm)
        return float(metric(y_prob_t, y_true_t))

    @classmethod
    def brier(cls, df: pd.DataFrame, prob_col: str, correct_col: str) -> float:
        """Brier score (MSE between correctness and confidence); nan when there are no valid rows."""
        y_true, y_prob = cls._y_true_prob(df, prob_col, correct_col)
        if len(y_true) == 0:
            return math.nan
        return float(brier_score_loss(y_true, y_prob))

    @classmethod
    def reliability_table(cls, df: pd.DataFrame, prob_col: str, correct_col: str, n_bins: int = 10) -> pd.DataFrame:
        """
        Reliability (calibration) table over n_bins equal-width confidence bins on [0, 1].

        Bins are [lower, upper) except the last, which also includes 1.0.
        Columns: bin_lower, bin_upper, bin_center, count, accuracy (mean
        correctness in the bin), confidence (mean prob in the bin).  Empty bins
        are kept with count 0 and NaN accuracy/confidence.  Invalid rows are
        dropped first; if none remain an empty table with the same columns is
        returned.

        Raises:
            ValueError: if n_bins < 1.
        """
        columns = ["bin_lower", "bin_upper", "bin_center", "count", "accuracy", "confidence"]
        if n_bins < 1:
            raise ValueError(f"n_bins must be >= 1, got {n_bins}")
        y_true, y_prob = cls._y_true_prob(df, prob_col, correct_col)
        if len(y_true) == 0:
            return pd.DataFrame(columns=columns)

        edges = np.linspace(0.0, 1.0, n_bins + 1)
        # side="right" on the interior edges puts a value equal to an edge into the
        # bin that starts there; 1.0 falls past every interior edge -> last bin.
        bin_idx = np.searchsorted(edges[1:-1], y_prob, side="right")
        counts = np.bincount(bin_idx, minlength=n_bins)
        correct_sum = np.bincount(bin_idx, weights=y_true, minlength=n_bins)
        conf_sum = np.bincount(bin_idx, weights=y_prob, minlength=n_bins)

        # Divide only where the bin is populated; empty bins stay NaN without warnings.
        populated = counts > 0
        accuracy = np.divide(correct_sum, counts, out=np.full(n_bins, np.nan), where=populated)
        confidence = np.divide(conf_sum, counts, out=np.full(n_bins, np.nan), where=populated)

        return pd.DataFrame(
            {
                "bin_lower": edges[:-1],
                "bin_upper": edges[1:],
                "bin_center": (edges[:-1] + edges[1:]) / 2.0,
                "count": counts,
                "accuracy": accuracy,
                "confidence": confidence,
            },
            columns=columns,
        )

    # ---------- one-call summary ----------
    @classmethod
    def summarize(
        cls,
        df: pd.DataFrame,
        prob_col: str,
        correct_col: str,
        n_bins: int = 10,
        norm: str = "l2",
        include_curves: bool = True,
        include_table: bool = True,
    ) -> dict:
        """
        Compute a summary dict for arbitrary columns.
        Keys: 'selective_auc', 'ece', 'brier', and optionally 'coverage',
        'accuracy_curve' (include_curves) and 'reliability_table' (include_table).
        """
        auc, cov, acc = cls.selective_auc(df, prob_col=prob_col, correct_col=correct_col)
        ece = cls.ece_torchmetrics_binary(df, prob_col=prob_col, correct_col=correct_col, n_bins=n_bins, norm=norm)
        brier = cls.brier(df, prob_col=prob_col, correct_col=correct_col)
        out = {"selective_auc": auc, "ece": ece, "brier": brier}
        if include_curves:
            out.update({"coverage": cov, "accuracy_curve": acc})
        if include_table:
            out["reliability_table"] = cls.reliability_table(df, prob_col=prob_col, correct_col=correct_col, n_bins=n_bins)
        return out

    # ---------- optional: batch over many (prob, correct) pairs ----------
    @classmethod
    def summarize_many(cls, df: pd.DataFrame, pairs: list[tuple[str, str]], n_bins: int = 10, norm: str = "l2") -> dict:
        """
        Compute summaries for multiple (prob_col, correct_col) pairs.
        Returns a dict keyed by '<prob_col>/<correct_col>' -> summary dict
        (scalar metrics only: curves and reliability tables are omitted).
        """
        return {
            f"{prob_col}/{correct_col}": cls.summarize(
                df, prob_col, correct_col, n_bins=n_bins, norm=norm, include_curves=False, include_table=False
            )
            for prob_col, correct_col in pairs
        }
