## Visualize benchmark results in a table

In [2]:
import os
import numpy as np
import functools
import pandas as pd
from matplotlib import colors
import matplotlib.pyplot as plt
from collections import defaultdict
from lm_polygraph.utils.manager import UEManager, _delete_nans

def b_g(s, A, cmap='PuBu', low=0.8, high=0):
    """Return per-cell ``background-color`` CSS for one column of a styled table.

    Written as a callback for ``Styler.apply(..., axis=0)``: ``s`` is the
    column currently being styled, but the colors are computed from the column
    with the same name in ``A`` rather than from ``s`` itself.

    Args:
        s: Column being styled; only ``s.name`` is read.
        A: DataFrame of numeric values that has a column named ``s.name``.
        cmap: Name of the matplotlib colormap to sample.
        low: Fraction of the value range by which the normalization interval
            is extended below the column minimum.
        high: Fraction of the value range by which the normalization interval
            is extended above the column maximum.

    Returns:
        A list of ``'background-color: #rrggbb'`` strings, one per row of ``A``.
    """
    # Look the column up by position so that tuple (MultiIndex) names work too.
    i = A.columns.tolist().index(s.name)
    a = A.values[:, i].astype(float)
    # Flip the sign for columns whose last name level is 'rcc-auc'.
    if s.name[-1] == 'rcc-auc':
        a = -a
    # NOTE(review): pretty_plot builds column names as
    # (group, generation metric, UE metric), so s.name[0] is the group label
    # there; this check looks like it was meant for the generation-metric
    # level. Left unchanged; confirm the intent.
    if s.name[0] == 'BARTScoreSeq-rh':
        a = -a
    # Use NaN-aware extrema: with plain min()/max() a single NaN entry turns
    # the whole normalization range into NaN and no cell gets a usable color.
    lo_val = np.nanmin(a)
    hi_val = np.nanmax(a)
    rng = hi_val - lo_val
    norm = colors.Normalize(lo_val - (rng * low),
                            hi_val + (rng * high))
    normed = norm(a)
    c = [colors.rgb2hex(x) for x in plt.colormaps[cmap](normed)]
    return ['background-color: %s' % color for color in c]

def get_array(dfs, row, col):
    """Collect the ``(row, col)`` cell from every frame that contains it.

    Frames whose index lacks ``row`` or whose columns lack ``col`` are
    skipped, so the result may be shorter than ``dfs`` (or empty).

    Args:
        dfs: Iterable of DataFrames to read from.
        row: Index label of the wanted cell.
        col: Column label of the wanted cell.

    Returns:
        List of the cell values, in the order of ``dfs``.
    """
    return [
        frame.loc[row, col]
        for frame in dfs
        if row in frame.index and col in frame.columns
    ]

def pretty_plot(dataset_name, man_files, except_metrics=(), except_gen=('BARTScoreSeq-rh',), level='sequence'):
    """Build a color-coded table of UE metrics from saved ``UEManager`` files.

    Every file is loaded with ``UEManager.load``. Its ``metrics`` entries for
    the requested ``level`` become one DataFrame with estimators as rows and
    ``(group, generation metric, UE metric)`` tuples as columns. Each cell of
    the final table is the mean over all loaded files that contain that cell,
    multiplied by 100 and formatted with two decimals.

    Args:
        dataset_name: Either one group label (str) or a list of labels. The
            label becomes the top level of the column header.
        man_files: For a single label, a list of manager file paths (one per
            seed). For a list of labels, a list of such lists.
        except_metrics: UE metric names to leave out.
        except_gen: Generation metric names to leave out.
        level: Only estimations and metrics stored for this level are used.

    Returns:
        A pandas ``Styler`` over the formatted table, with cell backgrounds
        computed by ``b_g`` and vertical separators between column groups.

    Raises:
        ValueError: If no manager file was given.
    """
    if isinstance(dataset_name, str):
        dataset_name = [dataset_name]
        man_files = [man_files]

    def _selected(l, gen_name, m_name):
        # A metric entry is shown only for the requested level and when
        # neither of its names is excluded.
        return l == level and m_name not in except_metrics and gen_name not in except_gen

    dfs = []
    columns = []
    for group_name, group_files in zip(dataset_name, man_files):
        # Start from an empty list so a group without files contributes no
        # columns instead of failing on `columns += None`.
        gen_metrics = []
        for f in group_files:
            man = UEManager.load(f)
            print({str(k) : np.mean(v) for k, v in man.gen_metrics.items()})
            estimators = [e for (l, e) in man.estimations.keys() if l == level]
            gen_metrics = sorted({
                (group_name, gen_name, m_name)
                for (l, e_name, gen_name, m_name) in man.metrics
                if _selected(l, gen_name, m_name)
            })
            per_column = {k: {} for k in gen_metrics}
            for (l, e_name, gen_name, m_name), value in man.metrics.items():
                if _selected(l, gen_name, m_name):
                    per_column[group_name, gen_name, m_name][e_name] = value
            # An estimator may lack an entry for some metric; show it as NaN
            # rather than aborting the whole table with a KeyError.
            data = {
                k: [per_column[k].get(e, np.nan) for e in estimators]
                for k in gen_metrics
            }
            df = pd.DataFrame(data=data, index=list(estimators))
            df = df.reindex(columns=gen_metrics)
            dfs.append(df)
        print('Will measure variance using', len(group_files), 'seeds')
        # The columns of a group are those of the last file loaded for it.
        columns += gen_metrics
    if not dfs:
        raise ValueError('pretty_plot needs at least one manager file')

    # Rows of the table are the estimators of the first loaded file.
    index = dfs[0].index
    mean = {row: {} for row in index}
    total = {row: {} for row in index}
    for col in columns:
        for row in index:
            avg = np.mean(get_array(dfs, row, col))
            # mean_df only drives the coloring and stores the negated average;
            # total_df holds the text that is displayed (mean only, no std).
            mean[row][col] = -avg
            total[row][col] = '{:.2f}'.format(avg.item() * 100)

    header = pd.MultiIndex.from_tuples(columns)
    total_df = pd.DataFrame([[total[row][col] for col in columns] for row in index],
                            index=index, columns=header)
    mean_df = pd.DataFrame([[mean[row][col] for col in columns] for row in index],
                           index=index, columns=header)

    s = total_df.style.apply(functools.partial(b_g, A=mean_df, cmap='Greens'), axis=0)
    s.set_table_styles([{  # for row hover use <tr> instead of <td>
        'selector': 'td:hover',
        'props': [('background-color', '#ffffb3')]
    }, {
        'selector': '.index_name',
        'props': 'font-style: italic; color: darkgrey; font-weight:normal;'
    }])

    # Draw a left border on the first column of every (group, generation
    # metric) block: 1px inside a group, 2px where the group itself changes.
    # Comparing both levels also separates groups that share a metric name.
    borders = {}
    for prev, cur in zip(columns, columns[1:]):
        if cur[:2] == prev[:2]:
            continue
        rule = 'border-left: {}px solid black'.format(1 if cur[0] == prev[0] else 2)
        borders[cur] = [{'selector': 'th', 'props': rule},
                        {'selector': 'td', 'props': rule}]
    s.set_table_styles(borders, overwrite=False, axis=0)
    return s

In [3]:
# Visualize the benchmark results as a color-coded table.
s = pretty_plot(
    # Group label; it becomes the top level of the table's column header.
    'gsm8k, Llama3.1-8b',
    # Saved UEManager outputs generated by the scripts/polygraph_eval benchmark.
    # List one file per seed: each table cell is the mean over the listed files.
    ["../workdir/output/qa/{'path': 'meta-llama/Llama-3.1-8b', 'ensemble': False, 'mc': False, 'mc_seeds': None, 'dropout_rate': None, 'type': 'CausalLM', 'path_to_load_script': 'model/llama-3.1-8b-quantized.py', 'load_model_args': {'device_map': 'auto'}, 'load_tokenizer_args': {}}/['denis1699/gsm8k_reasoning']/2025-06-08/14-12-23/ue_manager_seed1"])

{"('sequence', 'AccuracyReasoning')": 0.5}
Will measure variance using 1 seeds


In [4]:
# Display the styled table returned by pretty_plot as the cell output.
s

Unnamed: 0_level_0,"gsm8k, Llama3.1-8b","gsm8k, Llama3.1-8b","gsm8k, Llama3.1-8b","gsm8k, Llama3.1-8b","gsm8k, Llama3.1-8b","gsm8k, Llama3.1-8b"
Unnamed: 0_level_1,AccuracyReasoning,AccuracyReasoning,AccuracyReasoning,AccuracyReasoning,AccuracyReasoning,AccuracyReasoning
Unnamed: 0_level_2,prr,prr_0.5,prr_0.5_normalized,prr_normalized,roc-auc,roc-auc_normalized
MaximumSequenceProbability,57.98,53.77,20.11,23.08,39.44,21.12
Perplexity,49.49,51.33,7.14,-1.62,47.96,4.08
MeanTokenEntropy,59.52,54.59,24.47,27.55,37.08,25.84
ProbasMinWithCoT,49.79,51.34,7.2,-0.77,48.4,3.2
StepsMinProb,42.52,48.97,-5.4,-21.92,54.36,-8.72
StepsAvgProb,49.45,49.3,-3.62,-1.76,50.84,-1.68
StepsMaxProb,56.4,49.33,-3.5,18.49,47.56,4.88


In [21]:
# Write the styled table to table.html in the current working directory.
s.to_html("table.html")

In [19]:
# Display the styled table again.
# NOTE(review): the saved output of this cell lists other columns and
# estimators than the In [4] output, so it appears to come from a run with a
# different manager file; re-run the cell to refresh it.
s

Unnamed: 0_level_0,"gsm8k, Llama3.1-8b","gsm8k, Llama3.1-8b","gsm8k, Llama3.1-8b","gsm8k, Llama3.1-8b","gsm8k, Llama3.1-8b","gsm8k, Llama3.1-8b","gsm8k, Llama3.1-8b","gsm8k, Llama3.1-8b","gsm8k, Llama3.1-8b","gsm8k, Llama3.1-8b","gsm8k, Llama3.1-8b","gsm8k, Llama3.1-8b"
Unnamed: 0_level_1,Accuracy,Accuracy,Accuracy,Accuracy,Accuracy,Accuracy,AccuracyReasoning,AccuracyReasoning,AccuracyReasoning,AccuracyReasoning,AccuracyReasoning,AccuracyReasoning
Unnamed: 0_level_2,prr,prr_0.5,prr_0.5_normalized,prr_normalized,roc-auc,roc-auc_normalized,prr,prr_0.5,prr_0.5_normalized,prr_normalized,roc-auc,roc-auc_normalized
MaximumSequenceProbability,51.38,48.03,-4.82,7.32,49.38,1.76,50.03,49.45,-2.83,-0.05,51.36,-2.72
Perplexity,51.29,48.71,-1.12,7.07,48.42,3.67,50.72,50.37,2.07,1.94,49.28,1.44
MeanTokenEntropy,53.61,49.88,5.18,13.72,45.82,8.84,52.41,51.37,7.38,6.87,47.24,5.52
ProbasMinWithCoT,49.05,51.55,14.21,0.66,46.22,8.04,49.79,51.34,7.2,-0.77,48.4,3.2


In [6]:
# Load one saved UEManager for manual inspection in the cells below.
# Note that this is a different run directory (2025-06-02) than the one
# passed to pretty_plot above (2025-06-08).
man = UEManager.load("../workdir/output/qa/{'path': 'meta-llama/Llama-3.1-8b', 'ensemble': False, 'mc': False, 'mc_seeds': None, 'dropout_rate': None, 'type': 'CausalLM', 'path_to_load_script': 'model/llama-3.1-8b-quantized.py', 'load_model_args': {'device_map': 'auto'}, 'load_tokenizer_args': {}}/['denis1699/gsm8k_reasoning']/2025-06-02/09-01-59/ue_manager_seed1")

In [8]:
# List the (level, estimator name) keys under which estimations are stored.
print(man.estimations.keys())

dict_keys([('sequence', 'MaximumSequenceProbability'), ('sequence', 'Perplexity'), ('sequence', 'MeanTokenEntropy'), ('sequence', 'ProbasMinWithCoT')])


In [4]:
# Dump every stored generation metric: first its level and name, then the
# raw per-sample values on the following line.
for (gen_level, gen_name), generation_metric in man.gen_metrics.items():
    header = f"gen_level: {gen_level}, gen_name: {gen_name}"
    print(header, generation_metric, sep="\n")

gen_level: sequence, gen_name: Accuracy
[1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1]
gen_level: sequence, gen_name: AccuracyReasoning
[1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1]


In [8]:
# Show the repr of the loaded UEManager object.
man

<lm_polygraph.utils.manager.UEManager at 0x7f54a4a0ef90>

In [14]:
from lm_polygraph.ue_metrics import *
from lm_polygraph.ue_metrics.ue_metric import (
    get_random_scores,
    normalize_metric,
)

In [36]:
# Recompute the UE metrics by hand and print them next to the oracle and
# random baselines that enter the normalization.
for (gen_level, gen_name), generation_metric in man.gen_metrics.items():
    print(f"Gen. metric: {gen_name}")
    gen_values = np.array(generation_metric)
    for ue_metric in [PredictionRejectionArea(), PredictionRejectionArea(max_rejection=0.5), ROCAUC()]:
        print(f"UE metric: {ue_metric}")
        # Baselines over all samples: the oracle ranks by the negated
        # generation metric itself.
        oracle_score_all = ue_metric(-gen_values, gen_values)
        random_score_all = get_random_scores(ue_metric, gen_values)
        print(f"Oracle: {oracle_score_all}")
        print(f"Random: {random_score_all}")
        for (e_level, e_name), estimator_values in man.estimations.items():
            # Only pair estimations and generation metrics of the same level.
            if gen_level != e_level:
                continue
            print(f"\nEstimator: {e_name}")
            ue, metric = _delete_nans(estimator_values, generation_metric)
            if len(ue) == 0:
                # Nothing left after dropping NaNs: the metric is undefined.
                print(f"{ue_metric}: {np.nan}")
                continue
            if len(ue) != len(estimator_values):
                # Samples were dropped, so the baselines must be recomputed
                # on the same subset to keep the normalization consistent.
                oracle_score = ue_metric(-metric, metric)
                random_score = get_random_scores(ue_metric, metric)
            else:
                oracle_score = oracle_score_all
                random_score = random_score_all
            ue_metric_val = ue_metric(ue, metric)
            print(f"{ue_metric}: {ue_metric_val}")
            ue_metric_val_normalized = normalize_metric(ue_metric_val, oracle_score, random_score)
            print(f"{ue_metric}_normalized: {ue_metric_val_normalized}")
        print("\n")

Gen. metric: Accuracy
UE metric: prr
Oracle: 0.8370043678619956
Random: 0.48821631197200654

Estimator: MaximumSequenceProbability
prr: 0.5137612488451946
prr_normalized: 0.07323913890344097

Estimator: Perplexity
prr: 0.5128699611662411
prr_normalized: 0.07068375415358404

Estimator: MeanTokenEntropy
prr: 0.5360746893352977
prr_normalized: 0.13721334935387264

Estimator: ProbasMinWithCoT
prr: 0.4905028917755164
prr_normalized: 0.00655578585589837


UE metric: prr_0.5
Oracle: 0.6744087357239913
Random: 0.4891973259163078

Estimator: MaximumSequenceProbability
prr_0.5: 0.4802610989516802
prr_0.5_normalized: -0.04824879295453065

Estimator: Perplexity
prr_0.5: 0.487126467180894
prr_0.5_normalized: -0.011181053789095128

Estimator: MeanTokenEntropy
prr_0.5: 0.49878368845074994
prr_0.5_normalized: 0.05175902793675742

Estimator: ProbasMinWithCoT
prr_0.5: 0.5155070432897816
prr_0.5_normalized: 0.14205235736174557


UE metric: roc-auc
Oracle: 0.0
Random: 0.5026202480992397

Estimator: Maximu

In [None]:
# Recompute every UE metric for the loaded manager and store the raw and the
# normalized value in man.metrics.
import logging

# The loop below logs through `log`; define it here so the cell runs on its own.
log = logging.getLogger(__name__)

for (gen_level, gen_name), generation_metric in man.gen_metrics.items():
    gen_values = np.array(generation_metric)
    for ue_metric in man.ue_metrics:
        log.info(f"Metric: {ue_metric}")

        # Baselines over all samples: the oracle ranks by the negated
        # generation metric itself.
        oracle_score_all = ue_metric(-gen_values, gen_values)
        random_score_all = get_random_scores(ue_metric, gen_values)
        for (e_level, e_name), estimator_values in man.estimations.items():
            # Only pair estimations and generation metrics of the same level.
            if gen_level != e_level:
                continue
            if len(estimator_values) != len(generation_metric):
                raise Exception(
                    f"Got different number of metrics for {e_name} and {gen_name}: "
                    f"{len(estimator_values)} and {len(generation_metric)}"
                )

            n_nans = np.sum(~np.isfinite(estimator_values))
            if n_nans > 0:
                log.warning(f"We got {n_nans} nans in {e_name} estimator.")

            n_nans = np.sum(~np.isfinite(generation_metric))
            if n_nans > 0:
                log.warning(
                    f"We got {n_nans} nans in {gen_name} generation metric."
                )

            ue, metric = _delete_nans(estimator_values, generation_metric)
            if len(ue) == 0:
                # Nothing left after dropping NaNs: only the raw key is
                # written (as NaN); no "_normalized" entry is stored.
                man.metrics[e_level, e_name, gen_name, str(ue_metric)] = np.nan
                continue

            if len(ue) != len(estimator_values):
                # Samples were dropped, so recompute the baselines on the
                # same subset that the metric is evaluated on.
                oracle_score = ue_metric(-metric, metric)
                random_score = get_random_scores(ue_metric, metric)
            else:
                oracle_score = oracle_score_all
                random_score = random_score_all

            ue_metric_val = ue_metric(ue, metric)
            man.metrics[e_level, e_name, gen_name, str(ue_metric)] = (
                ue_metric_val
            )
            man.metrics[
                e_level, e_name, gen_name, str(ue_metric) + "_normalized"
            ] = normalize_metric(ue_metric_val, oracle_score, random_score)