## Visualize benchmark results in a table

In [1]:
import os
import numpy as np
import functools
import pandas as pd
from matplotlib import colors
import matplotlib.pyplot as plt
from collections import defaultdict
from lm_polygraph.utils.manager import UEManager

def b_g(s, A, cmap='PuBu', low=0.8, high=0):
    """Return per-cell CSS background colours for one column of a styled table.

    Meant for ``Styler.apply(..., axis=0)``: ``s`` is the column being styled,
    and the colours are computed from the column of ``A`` that has the same
    label.  ``A`` is assumed to share ``s``'s row order.

    Args:
        s: Column handed over by ``Styler.apply``; only ``s.name`` is used.
            Column labels are expected to be tuples whose last element is the
            metric name (e.g. ``'prr'``) and whose second-to-last element is
            the generation-metric name (e.g. ``'Accuracy'``).
        A: DataFrame of numeric values that drive the colour scale.
        cmap: Name of a registered matplotlib colormap.
        low: Fraction of the column's value range by which the lower end of
            the colour scale is extended below the minimum, so the smallest
            value does not land on the start of the colormap.
        high: Same as ``low``, for the upper end above the maximum.

    Returns:
        A list with one ``'background-color: #rrggbb'`` string per row, or
        ``''`` (no styling) for rows whose value in ``A`` is NaN/inf.
    """
    key = s.name if isinstance(s.name, tuple) else (s.name,)
    # Pull the matching column out of A as a float copy (it may be negated below).
    i = A.columns.tolist().index(s.name)
    a = np.array(A.values[:, i], dtype=float)
    # rcc-auc gets the opposite sign from the other metrics, presumably because
    # it is lower-is-better; flipping keeps the colour direction consistent.
    if key[-1] == 'rcc-auc':
        a = -a
    # Index -2 is the generation-metric level of the column key, whether or not
    # a leading dataset/group level is present.
    if len(key) >= 2 and key[-2] == 'BARTScoreSeq-rh':
        a = -a
    # Build the scale from finite values only: a single NaN would otherwise make
    # min/max NaN and paint every cell with the colormap's "bad" colour.
    finite = np.isfinite(a)
    if not finite.any():
        return [''] * len(a)
    a_min, a_max = a[finite].min(), a[finite].max()
    rng = a_max - a_min
    norm = colors.Normalize(a_min - rng * low, a_max + rng * high)
    rgba = plt.colormaps[cmap](norm(a))
    return ['background-color: %s' % colors.rgb2hex(c) if ok else ''
            for c, ok in zip(rgba, finite)]

def get_array(dfs, row, col):
    """Gather the ``(row, col)`` cell from every frame that contains it.

    Frames missing either the row label or the column label are skipped, so
    the returned list may be shorter than ``dfs`` or empty.  Order follows
    ``dfs``.
    """
    return [
        frame.loc[row, col]
        for frame in dfs
        if row in frame.index and col in frame.columns
    ]

def _seed_frame(path, group_name, level, except_metrics, except_gen):
    """Load one saved UEManager file and tabulate its metrics for one seed.

    Returns ``(df, columns)``.  ``df`` has one row per estimator listed in
    ``man.estimations`` at ``level``, and one column per
    ``(group_name, gen_name, m_name)`` key that survives the exclusion filters.
    ``columns`` is the sorted list of those keys.  A cell whose metric is
    missing from the file is NaN.
    """
    man = UEManager.load(path)
    estimators = [e for (l, e) in man.estimations.keys() if l == level]
    cells = defaultdict(dict)
    for (l, e_name, gen_name, m_name), value in man.metrics.items():
        if l != level or m_name in except_metrics or gen_name in except_gen:
            continue
        cells[group_name, gen_name, m_name][e_name] = value
    columns = sorted(cells)
    # .get(): an estimator without a value for some metric yields NaN
    # instead of a KeyError.
    data = {col: [cells[col].get(e, np.nan) for e in estimators] for col in columns}
    df = pd.DataFrame(data=data, index=estimators).reindex(columns=columns)
    return df, columns


def _column_borders(columns):
    """Map column keys to left-border styles that mark column-group boundaries.

    A 2px border starts each new dataset group, and a 1px border starts each
    new generation metric within a group.  No border is added inside a run of
    columns that share both levels.
    """
    borders = {}
    for prev, col in zip(columns, columns[1:]):
        if col[:2] == prev[:2]:
            continue
        css = 'border-left: {}px solid black'.format(1 if col[0] == prev[0] else 2)
        borders[col] = [{'selector': 'th', 'props': css},
                        {'selector': 'td', 'props': css}]
    return borders


def pretty_plot(dataset_name, man_files, except_metrics=(), except_gen=('BARTScoreSeq-rh',), level='sequence'):
    """Build a styled table of estimator scores averaged over seed runs.

    Each file is loaded with ``UEManager.load``.  Only metric entries whose
    level equals ``level`` are kept.  An entry is dropped when its metric name
    is in ``except_metrics`` or its generation-metric name is in
    ``except_gen``.  Columns are keyed ``(dataset, generation_metric, metric)``
    and rows are estimators.  Each cell shows ``mean ± std`` over the files
    that contain it, multiplied by 100, or ``n/a`` when no file has that cell.

    Args:
        dataset_name: A dataset/group label, or a list of labels.
        man_files: When ``dataset_name`` is a string, a list of saved UEManager
            paths (one per seed).  Otherwise, a list of such path lists,
            parallel to ``dataset_name``.
        except_metrics: Metric names (e.g. ``'prr'``) to leave out.
        except_gen: Generation-metric names (e.g. ``'Accuracy'``) to leave out.
        level: Level to report; matched against the first element of each
            metric key.

    Returns:
        A pandas ``Styler`` with per-column background colours from ``b_g``,
        a cell hover highlight, and left borders between column groups.

    Raises:
        ValueError: if no files were given, or if every metric was filtered
            out.
    """
    if isinstance(dataset_name, str):
        dataset_name = [dataset_name]
        man_files = [man_files]

    dfs = []
    columns = []
    for group_name, group_files in zip(dataset_name, man_files):
        # Union over all seeds, so a metric present in only some files keeps
        # its column.
        group_columns = set()
        for path in group_files:
            df, file_columns = _seed_frame(path, group_name, level, except_metrics, except_gen)
            dfs.append(df)
            group_columns.update(file_columns)
        print('Will measure variance using', len(group_files), 'seeds')
        columns += sorted(group_columns)
    if not dfs:
        raise ValueError('pretty_plot: no UEManager files were given')
    if not columns:
        raise ValueError('pretty_plot: no metrics left after applying the level/except filters')

    # Every estimator seen in any file, in first-seen order.
    index = pd.Index(list(dict.fromkeys(e for df in dfs for e in df.index)))

    mean_rows, text_rows = [], []
    for row in index:
        mean_row, text_row = [], []
        for col in columns:
            vals = get_array(dfs, row, col)
            if not vals:
                mean_row.append(np.nan)
                text_row.append('n/a')
                continue
            avg, std = np.mean(vals).item(), np.std(vals).item()
            # The negated mean is what b_g turns into the colour scale.
            mean_row.append(-avg)
            text_row.append('{:.2f} ± {:.2f}'.format(avg * 100, std * 100))
        mean_rows.append(mean_row)
        text_rows.append(text_row)

    header = pd.MultiIndex.from_tuples(columns)
    total_df = pd.DataFrame(text_rows, index=index, columns=header)
    mean_df = pd.DataFrame(mean_rows, index=index, columns=header)

    s = total_df.style.apply(functools.partial(b_g, A=mean_df, cmap='Reds'), axis=0)
    s.set_table_styles([{  # for row hover use <tr> instead of <td>
        'selector': 'td:hover',
        'props': [('background-color', '#ffffb3')]
    }, {
        'selector': '.index_name',
        'props': 'font-style: italic; color: darkgrey; font-weight:normal;'
    }])
    s.set_table_styles(_column_borders(columns), overwrite=False, axis=0)
    return s

In [3]:
# Render the benchmark results as a styled table; the returned Styler is the
# cell's last expression, so the notebook displays it.
pretty_plot(
    dataset_name='HotpotQA, Llama3.2-3b',
    # ue_manager files written by the scripts/polygraph_eval benchmark;
    # list one file per seed so each cell can report mean ± std.
    man_files=[
        "../workdir/output/qa/{'path': 'meta-llama/Llama-3.2-3B-Instruct', 'ensemble': False, 'mc': False, 'mc_seeds': None, 'dropout_rate': None, 'type': 'CausalLM', 'path_to_load_script': 'model/default_causal.py', 'load_model_args': {'device_map': 'auto'}, 'load_tokenizer_args': {}}/['denis1699/hotpot_cot']/2025-05-06/06-38-32/ue_manager_seed1",
    ],
)

Will measure variance using 1 seeds


Unnamed: 0_level_0,"HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b","HotpotQA, Llama3.2-3b"
Unnamed: 0_level_1,Accuracy,Accuracy,Accuracy,Accuracy,BLEU,BLEU,BLEU,BLEU,Rouge_rouge1,Rouge_rouge1,Rouge_rouge1,Rouge_rouge1,Rouge_rouge2,Rouge_rouge2,Rouge_rouge2,Rouge_rouge2,Rouge_rougeL,Rouge_rougeL,Rouge_rougeL,Rouge_rougeL
Unnamed: 0_level_2,prr,prr_0.5,prr_0.5_normalized,prr_normalized,prr,prr_0.5,prr_0.5_normalized,prr_normalized,prr,prr_0.5,prr_0.5_normalized,prr_normalized,prr,prr_0.5,prr_0.5_normalized,prr_normalized,prr,prr_0.5,prr_0.5_normalized,prr_normalized
MaximumSequenceProbability,29.29 ± 0.00,12.91 ± 0.00,100.00 ± 0.00,100.00 ± 0.00,55.70 ± 0.00,41.07 ± 0.00,78.16 ± 0.00,67.37 ± 0.00,57.92 ± 0.00,45.50 ± 0.00,46.68 ± 0.00,53.28 ± 0.00,46.94 ± 0.00,32.78 ± 0.00,-10.35 ± 0.00,42.17 ± 0.00,57.92 ± 0.00,45.50 ± 0.00,46.68 ± 0.00,53.28 ± 0.00
Perplexity,14.29 ± 0.00,12.91 ± 0.00,100.00 ± 0.00,21.71 ± 0.00,51.20 ± 0.00,41.07 ± 0.00,78.16 ± 0.00,53.99 ± 0.00,53.42 ± 0.00,45.50 ± 0.00,46.68 ± 0.00,40.13 ± 0.00,21.94 ± 0.00,32.78 ± 0.00,-10.35 ± 0.00,-37.91 ± 0.00,53.42 ± 0.00,45.50 ± 0.00,46.68 ± 0.00,40.13 ± 0.00
MeanTokenEntropy,10.96 ± 0.00,12.91 ± 0.00,100.00 ± 0.00,4.32 ± 0.00,44.02 ± 0.00,33.38 ± 0.00,0.31 ± 0.00,32.64 ± 0.00,47.18 ± 0.00,39.69 ± 0.00,-2.72 ± 0.00,21.90 ± 0.00,13.06 ± 0.00,26.11 ± 0.00,-98.63 ± 0.00,-66.38 ± 0.00,47.18 ± 0.00,39.69 ± 0.00,-2.72 ± 0.00,21.90 ± 0.00
MeanPointwiseMutualInformation,4.79 ± 0.00,9.58 ± 0.00,-16.51 ± 0.00,-27.87 ± 0.00,36.26 ± 0.00,26.84 ± 0.00,-65.81 ± 0.00,9.55 ± 0.00,38.52 ± 0.00,31.33 ± 0.00,-73.75 ± 0.00,-3.41 ± 0.00,13.06 ± 0.00,26.11 ± 0.00,-98.63 ± 0.00,-66.38 ± 0.00,38.52 ± 0.00,31.33 ± 0.00,-73.75 ± 0.00,-3.41 ± 0.00
MeanConditionalPointwiseMutualInformation,4.79 ± 0.00,9.58 ± 0.00,-16.51 ± 0.00,-27.87 ± 0.00,28.45 ± 0.00,31.21 ± 0.00,-21.62 ± 0.00,-13.66 ± 0.00,32.78 ± 0.00,39.75 ± 0.00,-2.23 ± 0.00,-20.18 ± 0.00,16.39 ± 0.00,32.78 ± 0.00,-10.35 ± 0.00,-55.70 ± 0.00,32.78 ± 0.00,39.75 ± 0.00,-2.23 ± 0.00,-20.18 ± 0.00
PTrue,19.29 ± 0.00,12.91 ± 0.00,100.00 ± 0.00,47.81 ± 0.00,45.27 ± 0.00,40.21 ± 0.00,69.43 ± 0.00,36.36 ± 0.00,46.68 ± 0.00,43.02 ± 0.00,25.62 ± 0.00,20.44 ± 0.00,46.94 ± 0.00,32.78 ± 0.00,-10.35 ± 0.00,42.17 ± 0.00,46.68 ± 0.00,43.02 ± 0.00,25.62 ± 0.00,20.44 ± 0.00
PTrueSampling,6.46 ± 0.00,12.91 ± 0.00,100.00 ± 0.00,-19.17 ± 0.00,11.58 ± 0.00,23.12 ± 0.00,-103.47 ± 0.00,-63.82 ± 0.00,13.11 ± 0.00,25.98 ± 0.00,-119.30 ± 0.00,-77.67 ± 0.00,8.89 ± 0.00,17.78 ± 0.00,-208.98 ± 0.00,-79.73 ± 0.00,13.11 ± 0.00,25.98 ± 0.00,-119.30 ± 0.00,-77.67 ± 0.00
MonteCarloSequenceEntropy,29.29 ± 0.00,12.91 ± 0.00,100.00 ± 0.00,100.00 ± 0.00,64.50 ± 0.00,43.23 ± 0.00,99.95 ± 0.00,93.51 ± 0.00,73.87 ± 0.00,51.74 ± 0.00,99.74 ± 0.00,99.90 ± 0.00,65.00 ± 0.00,41.11 ± 0.00,100.00 ± 0.00,100.00 ± 0.00,73.87 ± 0.00,51.74 ± 0.00,99.74 ± 0.00,99.90 ± 0.00
MonteCarloNormalizedSequenceEntropy,19.29 ± 0.00,12.91 ± 0.00,100.00 ± 0.00,47.81 ± 0.00,65.82 ± 0.00,43.23 ± 0.00,100.00 ± 0.00,97.43 ± 0.00,71.43 ± 0.00,51.77 ± 0.00,100.00 ± 0.00,92.76 ± 0.00,56.67 ± 0.00,41.11 ± 0.00,100.00 ± 0.00,73.31 ± 0.00,71.43 ± 0.00,51.77 ± 0.00,100.00 ± 0.00,92.76 ± 0.00
EigenScore,29.29 ± 0.00,12.91 ± 0.00,100.00 ± 0.00,100.00 ± 0.00,55.39 ± 0.00,43.22 ± 0.00,99.83 ± 0.00,66.44 ± 0.00,68.00 ± 0.00,51.67 ± 0.00,99.14 ± 0.00,82.75 ± 0.00,65.00 ± 0.00,41.11 ± 0.00,100.00 ± 0.00,100.00 ± 0.00,68.00 ± 0.00,51.67 ± 0.00,99.14 ± 0.00,82.75 ± 0.00
