## Visualize benchmark results in a table

In [None]:
import os
import numpy as np
import functools
import pandas as pd
from matplotlib import colors
import matplotlib.pyplot as plt
from collections import defaultdict
from lm_polygraph.utils.manager import UEManager

def b_g(s, A, cmap='PuBu', low=0.8, high=0):
    """Return per-cell CSS background colours for one styled column.

    Intended for ``Styler.apply(..., axis=0)``. ``s`` is the column being
    styled; only its label ``s.name`` is used. ``A`` is a numeric DataFrame
    with the same column labels, and its values (not ``s``'s) drive the
    colour gradient.

    Values are negated before colouring when the last label level is
    ``'rcc-auc'``. They are also negated when the generation-metric level
    (second-to-last) is ``'BARTScoreSeq-rh'``. If both hold, the two flips
    cancel.

    Args:
        s: Column Series from the styled frame.
        A: Numeric frame providing the values to colour by.
        cmap: Name of a registered matplotlib colormap.
        low: Pushes the lower normalisation bound down by ``low`` times the
            column's value range, so the smallest value does not map to the
            very start of the colormap.
        high: Same as ``low``, for the upper bound.

    Returns:
        One CSS string per cell. NaN cells get ``''`` (no style).
    """
    label = s.name if isinstance(s.name, tuple) else (s.name,)
    i = A.columns.tolist().index(s.name)
    a = A.values[:, i].astype(float)  # astype copies, so A is never modified
    if label[-1] in ['rcc-auc']:
        a = -a
    # The generation metric is the second-to-last level of the label:
    # (dataset, gen_metric, metric) or (gen_metric, metric). Index 0 would
    # be the dataset name for the 3-level columns built by pretty_plot.
    if len(label) >= 2 and label[-2] == 'BARTScoreSeq-rh':
        a = -a
    if a.size == 0:
        return []
    # NaN-aware extremes, so one missing cell does not blank out the column.
    a_min, a_max = np.nanmin(a), np.nanmax(a)
    rng = a_max - a_min
    norm = colors.Normalize(a_min - (rng * low), a_max + (rng * high))
    rgba = plt.colormaps[cmap](norm(a))
    return ['' if np.isnan(value) else 'background-color: %s' % colors.rgb2hex(color)
            for value, color in zip(a, rgba)]

def get_array(dfs, row, col):
    """Gather the value at ``(row, col)`` from every frame holding both labels.

    Frames missing either the row or the column label are skipped. The
    result may therefore be shorter than ``dfs``, or empty.
    """
    return [frame.loc[row, col]
            for frame in dfs
            if row in frame.index and col in frame.columns]

def _load_group_frames(group_name, group_files, except_metrics, except_gen, level):
    """Load one dataset group's manager files into per-file estimator x metric frames.

    Returns ``(frames, gen_metrics)``:
      - ``frames`` holds one DataFrame per file. Rows are the estimators
        recorded at ``level``; columns are ``(group_name, gen_name, m_name)``
        tuples.
      - ``gen_metrics`` is the sorted column list of the last file loaded.
        It is empty when ``group_files`` is empty.
    """
    frames = []
    gen_metrics = []
    for path in group_files:
        man = UEManager.load(path)
        estimators = [e for (l, e) in man.estimations.keys() if l == level]
        # (group_name, gen_name, m_name) -> {estimator name: metric value}
        values = defaultdict(dict)
        for (l, e_name, gen_name, m_name), value in man.metrics.items():
            if l == level and m_name not in except_metrics and gen_name not in except_gen:
                values[group_name, gen_name, m_name][e_name] = value
        gen_metrics = sorted(values)
        # An estimator with no value for a metric becomes NaN instead of
        # raising KeyError.
        data = {k: [values[k].get(e, np.nan) for e in estimators] for k in gen_metrics}
        frame = pd.DataFrame(data=data, index=list(estimators))
        frames.append(frame.reindex(columns=gen_metrics))
    return frames, gen_metrics


def _column_border_styles(columns):
    """Build left-border styles for each column that starts a new (dataset, gen_metric) group.

    The border is 1px when the dataset is unchanged and 2px where the
    dataset changes.
    """
    styles = {}
    for prev, col in zip(columns, columns[1:]):
        # Compare both dataset and gen_metric. This gives a dataset boundary
        # a border even when the gen_metric name repeats across datasets.
        if col[:2] == prev[:2]:
            continue
        border = 'border-left: {}px solid black'.format(1 if col[0] == prev[0] else 2)
        styles[col] = [{'selector': 'th', 'props': border},
                       {'selector': 'td', 'props': border}]
    return styles


def pretty_plot(dataset_name, man_files, except_metrics=(), except_gen=('BARTScoreSeq-rh',), level='sequence'):
    """Render a styled table of ``mean ± std`` benchmark metrics across seeds.

    Args:
        dataset_name: A group label, or a list of labels.
        man_files: For a single label, the list of manager files for that
            group (one per seed). For a list of labels, a parallel list of
            such lists. A bare path string for a group is treated as a
            one-element list.
        except_metrics: Metric names (last column level) to leave out.
        except_gen: Generation-metric names (middle column level) to leave out.
        level: Only estimators and metrics recorded at this level are shown.

    Returns:
        A pandas ``Styler`` over a frame indexed by estimator, with
        ``(group, gen_metric, metric)`` MultiIndex columns. Each cell reads
        ``"mean ± std"`` of the metric across the files, times 100.

    Raises:
        ValueError: If no manager files were given.
    """
    if isinstance(dataset_name, str):
        dataset_name = [dataset_name]
        man_files = [man_files]
    dfs, columns = [], []
    for group_name, group_files in zip(dataset_name, man_files):
        if isinstance(group_files, str):
            group_files = [group_files]
        frames, gen_metrics = _load_group_frames(
            group_name, group_files, except_metrics, except_gen, level)
        dfs += frames
        print('Will measure variance using', len(group_files), 'seeds')
        columns += gen_metrics
    if not dfs:
        raise ValueError('pretty_plot: no manager files were provided')

    # Rows follow the estimators of the first loaded file.
    index = dfs[0].index
    mean_rows, total_rows = [], []
    for row in index:
        mean_row, total_row = [], []
        for col in columns:
            vals = get_array(dfs, row, col)
            mu = np.mean(vals)
            # mean_df stores the negated mean and only drives b_g's colour
            # gradient. Presumably this makes higher-is-better metrics and
            # rcc-auc (re-flipped in b_g) shade the same way; confirm.
            mean_row.append(-mu)
            total_row.append('{:.2f} ± {:.2f}'.format(mu.item() * 100, np.std(vals).item() * 100))
        mean_rows.append(mean_row)
        total_rows.append(total_row)

    header = pd.MultiIndex.from_tuples(columns)
    total_df = pd.DataFrame(total_rows, index=index, columns=header)
    mean_df = pd.DataFrame(mean_rows, index=index, columns=header)

    s = total_df.style.apply(functools.partial(b_g, A=mean_df, cmap='Reds'), axis=0)
    s.set_table_styles([{  # for row hover use <tr> instead of <td>
        'selector': 'td:hover',
        'props': [('background-color', '#ffffb3')]
    }, {
        'selector': '.index_name',
        'props': 'font-style: italic; color: darkgrey; font-weight:normal;'
    }])
    s.set_table_styles(_column_border_styles(columns), overwrite=False, axis=0)
    return s

In [None]:
# Render the benchmark table for TriviaQA with Dolly-3b.
# Each path points at output saved by the scripts/polygraph_eval benchmark for
# one seed (seeds 1-9); several seeds are passed so the variance can be shown.
pretty_plot(
    'TriviaQA, Dolly3b',
    [f'./workdir/output_seed{seed}' for seed in range(1, 10)],
)