## Visualize benchmark results in a table

In [4]:
import os
import numpy as np
import functools
import pandas as pd
from matplotlib import colors
import matplotlib.pyplot as plt
from collections import defaultdict
from lm_polygraph.utils.manager import UEManager

def b_g(s, A, cmap='PuBu', low=0.8, high=0):
    """Return one ``background-color`` CSS declaration per row of a column.

    Written to be passed to ``Styler.apply(..., axis=0)``: ``s`` is the column
    being styled, but the colours are computed from the column of ``A`` that
    carries the same label, not from the contents of ``s``.

    Args:
        s: Column being styled. Only ``s.name`` is read.
        A: DataFrame of numeric values that drive the colouring; it must
            contain a column labelled ``s.name``.
        cmap: Name looked up in ``plt.colormaps``.
        low: The lower normalisation bound is placed ``low`` times the value
            range below the column minimum.
        high: The upper normalisation bound is placed ``high`` times the
            value range above the column maximum.

    Returns:
        A list of ``'background-color: <hex>'`` strings, one per row of ``A``.

    Raises:
        ValueError: If ``A`` has no column labelled ``s.name``.
    """
    # Pass the columns from Dataframe A
    col_pos = A.columns.tolist().index(s.name)
    values = A.values[:, col_pos].copy()

    # Column labels are tuples whose last component is the metric name and
    # whose second-to-last component is the generation metric name, e.g.
    # (group, generation metric, metric). Reading the generation metric from
    # the end keeps the check correct whether or not a leading group label
    # is present; index 0 would be the group label for three-part labels.
    label = s.name if isinstance(s.name, tuple) else (s.name,)
    # NOTE(review): the sign flips presumably mark "lower is better"
    # quantities so they are shaded the opposite way - confirm.
    if label[-1] in ['rcc-auc']:
        values = -values
    if len(label) >= 2 and label[-2] == 'BARTScoreSeq-rh':
        values = -values

    span = values.max() - values.min()
    norm = colors.Normalize(values.min() - (span * low),
                            values.max() + (span * high))
    shades = plt.colormaps[cmap](norm(values))
    return ['background-color: %s' % colors.rgb2hex(rgba) for rgba in shades]

def get_array(dfs, row, col):
    """Collect the ``(row, col)`` cell from every frame that contains it.

    Frames that lack either the row label or the column label are skipped,
    so the returned list can be shorter than ``dfs``.

    Args:
        dfs: Iterable of DataFrames to read from.
        row: Index label of the wanted cell.
        col: Column label of the wanted cell.

    Returns:
        A list with one value per frame that has both labels, in the order
        the frames were given.
    """
    return [
        frame.loc[row, col]
        for frame in dfs
        if row in frame.index and col in frame.columns
    ]

def pretty_plot(dataset_name, man_files, except_metrics=[], except_gen=['BARTScoreSeq-rh'], level='sequence'):
    """Build a colour-coded table of benchmark metrics from saved managers.

    Every file is loaded with ``UEManager.load``. The entries of
    ``man.metrics`` stored under ``level`` are gathered into one frame per
    file, with estimator names as rows and
    ``(group name, generation metric, metric)`` tuples as columns. Each cell
    of the result is the text ``mean ± std`` computed over the files of the
    group, with both numbers multiplied by 100.

    Row labels are taken from the first loaded file. The columns of a group
    are the ones found in the last file of that group.

    Args:
        dataset_name: Column-group label, or a list of such labels.
        man_files: If ``dataset_name`` is a string, the saved-manager paths
            of that group (a list, or a single path string). If it is a
            list, one list of paths per label, in the same order.
        except_metrics: Metric names (last component of a ``man.metrics``
            key) to leave out. The default list is only read, never mutated.
        except_gen: Generation metric names (third component of a
            ``man.metrics`` key) to leave out. Only read, never mutated.
        level: Only estimations and metrics whose first key component equals
            this value are used.

    Returns:
        The pandas ``Styler`` of the ``mean ± std`` table, with cell
        backgrounds produced by ``b_g`` and left borders separating column
        groups.

    Raises:
        ValueError: If a group has no files.
        AssertionError: If no group was given at all.
    """
    if isinstance(dataset_name, str):
        # A lone path is accepted as well; iterating a bare string would
        # otherwise try to load it one character at a time.
        if isinstance(man_files, str):
            man_files = [man_files]
        dataset_name = [dataset_name]
        man_files = [man_files]

    def is_kept(key):
        # Keys of man.metrics are (level, estimator, generation metric, metric).
        lvl, _, gen_name, m_name = key
        return lvl == level and (m_name not in except_metrics) and (gen_name not in except_gen)

    dfs = []
    columns = []
    for group_name, group_files in zip(dataset_name, man_files):
        if len(group_files) == 0:
            raise ValueError('no UEManager files given for {!r}'.format(group_name))
        gen_metrics = []
        for f in group_files:
            man = UEManager.load(f)
            estimators = [e for (lvl, e) in man.estimations.keys() if lvl == level]
            kept = {key: value for key, value in man.metrics.items() if is_kept(key)}
            gen_metrics = sorted({(group_name, gen_name, m_name)
                                  for (_, _, gen_name, m_name) in kept})
            by_column = {k: {} for k in gen_metrics}
            for (_, e_name, gen_name, m_name), value in kept.items():
                by_column[group_name, gen_name, m_name][e_name] = value
            # One list per column, ordered like the estimator index.
            data = {k: [by_column[k][e] for e in estimators] for k in gen_metrics}
            df = pd.DataFrame(data=data, index=estimators)
            df = df.reindex(columns=gen_metrics)
            dfs.append(df)
        print('Will measure variance using', len(group_files), 'seeds')
        columns += gen_metrics
    assert len(dfs) > 0, 'no UEManager files were given'

    index = dfs[0].index
    mean_rows, total_rows = [], []
    for row in index:
        mean_cells, total_cells = [], []
        for col in columns:
            vals = get_array(dfs, row, col)
            avg = np.mean(vals)
            # The negated mean is what b_g colours by.
            mean_cells.append(-avg)
            total_cells.append('{:.2f} ± {:.2f}'.format(avg.item() * 100, np.std(vals).item() * 100))
        mean_rows.append(mean_cells)
        total_rows.append(total_cells)

    header = pd.MultiIndex.from_tuples(columns)
    total_df = pd.DataFrame(total_rows, index=index, columns=header)
    mean_df = pd.DataFrame(mean_rows, index=index, columns=header)

    s = total_df.style.apply(functools.partial(b_g, A=mean_df, cmap='Greens'), axis=0)
    s.set_table_styles([{  # for row hover use <tr> instead of <td>
        'selector': 'td:hover',
        'props': [('background-color', '#ffffb3')]
    }, {
        'selector': '.index_name',
        'props': 'font-style: italic; color: darkgrey; font-weight:normal;'
    }])

    # Draw a left border wherever the (group, generation metric) pair
    # changes: 2px when the group itself changes, 1px otherwise. Comparing
    # the pair rather than the generation metric alone keeps the separator
    # between two groups whose neighbouring columns share a generation metric.
    borders = {}
    for prev, cur in zip(columns, columns[1:]):
        if cur[:2] == prev[:2]:
            continue
        rule = 'border-left: {}px solid black'.format(1 if cur[0] == prev[0] else 2)
        borders[cur] = [{'selector': 'th', 'props': rule},
                        {'selector': 'td', 'props': rule}]
    s.set_table_styles(borders, overwrite=False, axis=0)
    return s

In [6]:
# Render the results table for a single dataset/model pair.
pretty_plot(
    dataset_name='HotpotQA, Llama3.2-3b',
    # Outputs generated by the scripts/polygraph_eval benchmark;
    # list several seeds here to calculate variance.
    man_files=["../workdir/output/qa/{'path': 'meta-llama/Llama-3.2-3B-Instruct', 'ensemble': False, 'mc': False, 'mc_seeds': None, 'dropout_rate': None, 'type': 'CausalLM', 'path_to_load_script': 'model/default_causal.py', 'load_model_args': {'device_map': 'auto'}, 'load_tokenizer_args': {}}/['denis1699/hotpot_cot']/2025-05-06/09-26-59/ue_manager_seed1"],
)

Will measure variance using 1 seeds


Unnamed: 0_level_0,"HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b"
Unnamed: 0_level_1,Accuracy,Accuracy,Accuracy,Accuracy,BLEU,BLEU,BLEU,BLEU,Rouge_rouge1,Rouge_rouge1,Rouge_rouge1,Rouge_rouge1,Rouge_rouge2,Rouge_rouge2,Rouge_rouge2,Rouge_rouge2,Rouge_rougeL,Rouge_rougeL,Rouge_rougeL,Rouge_rougeL
Unnamed: 0_level_2,prr,prr_0.5,prr_0.5_normalized,prr_normalized,prr,prr_0.5,prr_0.5_normalized,prr_normalized,prr,prr_0.5,prr_0.5_normalized,prr_normalized,prr,prr_0.5,prr_0.5_normalized,prr_normalized,prr,prr_0.5,prr_0.5_normalized,prr_normalized
MaximumSequenceProbability,29.89 ± 0.00,36.33 ± 0.00,-26.36 ± 0.00,-27.21 ± 0.00,30.36 ± 0.00,37.28 ± 0.00,-28.28 ± 0.00,-29.70 ± 0.00,30.58 ± 0.00,37.71 ± 0.00,-29.10 ± 0.00,-30.83 ± 0.00,22.75 ± 0.00,31.47 ± 0.00,9.68 ± 0.00,-22.90 ± 0.00,30.58 ± 0.00,37.71 ± 0.00,-29.10 ± 0.00,-30.83 ± 0.00
Perplexity,32.30 ± 0.00,32.06 ± 0.00,-57.82 ± 0.00,-20.45 ± 0.00,32.44 ± 0.00,32.34 ± 0.00,-63.40 ± 0.00,-23.85 ± 0.00,32.50 ± 0.00,32.47 ± 0.00,-65.81 ± 0.00,-25.41 ± 0.00,20.61 ± 0.00,26.43 ± 0.00,-49.48 ± 0.00,-29.20 ± 0.00,32.50 ± 0.00,32.47 ± 0.00,-65.81 ± 0.00,-25.41 ± 0.00
MeanTokenEntropy,28.05 ± 0.00,30.57 ± 0.00,-68.74 ± 0.00,-32.35 ± 0.00,28.20 ± 0.00,30.85 ± 0.00,-73.94 ± 0.00,-35.80 ± 0.00,28.26 ± 0.00,30.98 ± 0.00,-76.18 ± 0.00,-37.37 ± 0.00,18.79 ± 0.00,22.50 ± 0.00,-95.73 ± 0.00,-34.54 ± 0.00,28.26 ± 0.00,30.98 ± 0.00,-76.18 ± 0.00,-37.37 ± 0.00
MeanPointwiseMutualInformation,48.88 ± 0.00,34.89 ± 0.00,-36.93 ± 0.00,26.12 ± 0.00,49.36 ± 0.00,35.85 ± 0.00,-38.47 ± 0.00,23.81 ± 0.00,49.57 ± 0.00,36.28 ± 0.00,-39.14 ± 0.00,22.76 ± 0.00,32.91 ± 0.00,28.10 ± 0.00,-29.89 ± 0.00,7.02 ± 0.00,49.57 ± 0.00,36.28 ± 0.00,-39.14 ± 0.00,22.76 ± 0.00
MeanConditionalPointwiseMutualInformation,49.75 ± 0.00,40.65 ± 0.00,5.44 ± 0.00,28.55 ± 0.00,50.80 ± 0.00,42.49 ± 0.00,8.73 ± 0.00,27.88 ± 0.00,51.28 ± 0.00,43.33 ± 0.00,10.14 ± 0.00,27.58 ± 0.00,23.43 ± 0.00,29.61 ± 0.00,-12.08 ± 0.00,-20.89 ± 0.00,51.28 ± 0.00,43.33 ± 0.00,10.14 ± 0.00,27.58 ± 0.00
PTrue,51.65 ± 0.00,48.98 ± 0.00,66.76 ± 0.00,33.89 ± 0.00,52.22 ± 0.00,50.13 ± 0.00,63.00 ± 0.00,31.88 ± 0.00,52.48 ± 0.00,50.65 ± 0.00,61.38 ± 0.00,30.96 ± 0.00,39.36 ± 0.00,32.03 ± 0.00,16.36 ± 0.00,25.99 ± 0.00,52.48 ± 0.00,50.65 ± 0.00,61.38 ± 0.00,30.96 ± 0.00
PTrueSampling,33.60 ± 0.00,47.49 ± 0.00,55.74 ± 0.00,-16.78 ± 0.00,33.67 ± 0.00,47.62 ± 0.00,45.18 ± 0.00,-20.38 ± 0.00,33.70 ± 0.00,47.69 ± 0.00,40.63 ± 0.00,-22.02 ± 0.00,30.19 ± 0.00,33.55 ± 0.00,34.16 ± 0.00,-0.99 ± 0.00,33.70 ± 0.00,47.69 ± 0.00,40.63 ± 0.00,-22.02 ± 0.00
MonteCarloSequenceEntropy,38.71 ± 0.00,42.34 ± 0.00,17.84 ± 0.00,-2.44 ± 0.00,40.52 ± 0.00,44.18 ± 0.00,20.70 ± 0.00,-1.09 ± 0.00,41.34 ± 0.00,45.01 ± 0.00,21.93 ± 0.00,-0.48 ± 0.00,24.48 ± 0.00,35.22 ± 0.00,53.75 ± 0.00,-17.81 ± 0.00,41.34 ± 0.00,45.01 ± 0.00,21.93 ± 0.00,-0.48 ± 0.00
MonteCarloNormalizedSequenceEntropy,52.69 ± 0.00,41.50 ± 0.00,11.71 ± 0.00,36.80 ± 0.00,54.77 ± 0.00,43.34 ± 0.00,14.78 ± 0.00,39.06 ± 0.00,55.72 ± 0.00,44.18 ± 0.00,16.10 ± 0.00,40.08 ± 0.00,37.30 ± 0.00,35.22 ± 0.00,53.75 ± 0.00,19.91 ± 0.00,55.72 ± 0.00,44.18 ± 0.00,16.10 ± 0.00,40.08 ± 0.00
EigenScore,69.41 ± 0.00,48.40 ± 0.00,62.43 ± 0.00,83.77 ± 0.00,70.09 ± 0.00,49.76 ± 0.00,60.33 ± 0.00,82.22 ± 0.00,70.40 ± 0.00,50.37 ± 0.00,59.42 ± 0.00,81.52 ± 0.00,56.09 ± 0.00,33.55 ± 0.00,34.16 ± 0.00,75.20 ± 0.00,70.40 ± 0.00,50.37 ± 0.00,59.42 ± 0.00,81.52 ± 0.00
