In [92]:
import os
import numpy as np
import functools
import pandas as pd
from matplotlib import colors
import matplotlib.pyplot as plt
from collections import defaultdict
from utils.manager import UEManager

def b_g(s, A, cmap='PuBu', low=0.8, high=0):
    # Pass the columns from Dataframe A
    i = A.columns.tolist().index(s.name)
    a = A.values[:,i].copy()
    if s.name[-1] in ['rcc-auc']:
        a = -a
    if s.name[0] in ['BARTScoreSeq-rh', 'BARTScoreToken-rh']:
        a = -a
    rng = a.max() - a.min()
    norm = colors.Normalize(a.min() - (rng * low),
                        a.max() + (rng * high))
    normed = norm(a)
    c = [colors.rgb2hex(x) for x in plt.colormaps[cmap](normed)]
    return ['background-color: %s' % color for color in c]

def pretty_est_name(e):
    if e.startswith('AttentionRecursiveSeq_'):
        e = 'Exponential' + e
    if 'Attention' in e or e == 'PUncertainty' or 'AdaptedSampling' in e:
        return e + '*'
    return e

def pretty_gen_metrics_name(g):
    if g == 'WERTokenwise':
        return 'LCS'
    return g

def get_array(dfs, row, col):
    vals = []
    for df in dfs:
        if row in df.index and col in df.columns:
            vals.append(df.loc[row, col])
    return vals

def pretty_plot(dataset_name, man_files, except_metrics=[], except_gen=['BARTScoreSeq-rh'], level='sequence'):
    dfs = []
    if isinstance(dataset_name, str):
        dataset_name = [dataset_name]
        man_files = [man_files]
    columns = []
    for group_name, group_files in zip(dataset_name, man_files):
        gen_metrics = None
        for f in group_files:
            man = UEManager.load(f)
            estimators = [e for (l, e) in man.estimations.keys() if l == level]
            gen_metrics = list(set([(group_name, pretty_gen_metrics_name(gen_name), m_name)
               for (l, e_name, gen_name, m_name) in man.metrics
               if l == level and (m_name not in except_metrics) and (gen_name not in except_gen)]))
            gen_metrics.sort()
            df = {k: {} for k in gen_metrics}
            for (l, e_name, gen_name, m_name), value in man.metrics.items():
                if l == level and (m_name not in except_metrics) and (gen_name not in except_gen):
                    df[group_name, pretty_gen_metrics_name(gen_name), m_name][e_name] = value
            for k in gen_metrics:
                df[k] = [df[k][e] for e in estimators]
            df = pd.DataFrame(data=df, index=[pretty_est_name(e) for e in estimators])
            df = df.reindex(columns=gen_metrics)
            dfs.append(df)
        print('Will measure variance using', len(group_files), 'seeds')
        columns += gen_metrics
    assert(len(dfs) > 0)
    index = dfs[0].index
    mean, total = defaultdict(lambda: defaultdict(int)), defaultdict(lambda: defaultdict(int))
    for col in columns:
        for row in index:
            vals = get_array(dfs, row, col)
            mean[row][col] = -np.mean(vals)
            total[row][col] = '{:.2f} ± {:.2f}'.format(np.mean(vals).item() * 100, np.std(vals).item() * 100)
    
    total_df = pd.DataFrame([[total[row][col] for col in columns] for row in index],
                            index=index, columns=pd.MultiIndex.from_tuples(columns))
    mean_df = pd.DataFrame([[mean[row][col] for col in columns] for row in index],
                           index=index, columns=pd.MultiIndex.from_tuples(columns))
    
    s = total_df.style.apply(functools.partial(b_g, A=mean_df, cmap='Reds'), axis=0)
    s.set_table_styles([{  # for row hover use <tr> instead of <td>
        'selector': 'td:hover',
        'props': [('background-color', '#ffffb3')]
    }, {
        'selector': '.index_name',
        'props': 'font-style: italic; color: darkgrey; font-weight:normal;'
    }])
    s.set_table_styles({
        columns[i]: [{'selector': 'th', 'props': 'border-left: {}px solid black'.format(1 if columns[i][0] == columns[i - 1][0] else 2)},
                     {'selector': 'td', 'props': 'border-left: {}px solid black'.format(1 if columns[i][0] == columns[i - 1][0] else 2)}]
        for i in range(1, len(columns)) if i == 0 or columns[i][1] != columns[i - 1][1]
    }, overwrite=False, axis=0)
    return s

In [93]:
pretty_plot(
    'TriviaQA, Dolly3b',
    ['/Users/ekaterinafadeeva/work/data/uncertainty_mans/ue_mans/qa_dolly3b_seed' + str(x)
     for x in range(1, 10)])

Will measure variance using 9 seeds


Unnamed: 0_level_0,"TriviaQA, Dolly3b","TriviaQA, Dolly3b","TriviaQA, Dolly3b","TriviaQA, Dolly3b","TriviaQA, Dolly3b","TriviaQA, Dolly3b","TriviaQA, Dolly3b","TriviaQA, Dolly3b","TriviaQA, Dolly3b"
Unnamed: 0_level_1,Rouge_rouge1,Rouge_rouge1,Rouge_rouge1,Rouge_rouge2,Rouge_rouge2,Rouge_rouge2,Rouge_rougeL,Rouge_rougeL,Rouge_rougeL
Unnamed: 0_level_2,prr,rcc-auc,rpp,prr,rcc-auc,rpp,prr,rcc-auc,rpp
MaxProbabilitySeq,2.12 ± 0.43,97.88 ± 0.43,15.31 ± 0.00,1.93 ± 0.50,98.07 ± 0.50,9.18 ± 0.00,2.12 ± 0.43,97.88 ± 0.43,15.31 ± 0.00
MaxProbabilityNormalizedSeq,3.69 ± 0.75,96.31 ± 0.75,15.31 ± 0.00,4.21 ± 1.08,95.79 ± 1.08,9.11 ± 0.00,3.69 ± 0.75,96.31 ± 0.75,15.31 ± 0.00
EntropySeq,2.30 ± 0.46,97.70 ± 0.46,15.43 ± 0.00,2.15 ± 0.55,97.85 ± 0.55,9.22 ± 0.00,2.30 ± 0.46,97.70 ± 0.46,15.43 ± 0.00
MutualInformationSeq,2.68 ± 0.54,97.32 ± 0.54,15.91 ± 0.00,2.76 ± 0.71,97.24 ± 0.71,9.98 ± 0.00,2.68 ± 0.54,97.32 ± 0.54,15.91 ± 0.00
ConditionalMutualInformationSeq,2.12 ± 0.43,97.88 ± 0.43,15.31 ± 0.00,1.93 ± 0.50,98.07 ± 0.50,9.18 ± 0.00,2.12 ± 0.43,97.88 ± 0.43,15.31 ± 0.00
AttentionEntropySeq*,2.15 ± 0.42,97.85 ± 0.42,14.11 ± 0.12,2.15 ± 0.55,97.85 ± 0.55,7.69 ± 0.18,2.15 ± 0.42,97.85 ± 0.42,14.11 ± 0.12
AttentionRecursiveSeq*,2.19 ± 0.43,97.81 ± 0.43,14.42 ± 0.07,2.06 ± 0.52,97.94 ± 0.52,8.00 ± 0.10,2.19 ± 0.43,97.81 ± 0.43,14.42 ± 0.07
ExponentialAttentionEntropySeq_0.9*,2.66 ± 0.54,97.34 ± 0.54,15.29 ± 0.00,2.84 ± 0.73,97.16 ± 0.73,9.09 ± 0.00,2.66 ± 0.54,97.34 ± 0.54,15.29 ± 0.00
ExponentialAttentionEntropySeq_0.8*,2.61 ± 0.53,97.39 ± 0.53,15.36 ± 0.00,2.67 ± 0.69,97.33 ± 0.69,9.16 ± 0.00,2.61 ± 0.53,97.39 ± 0.53,15.36 ± 0.00
ExponentialAttentionRecursiveSeq_0.9*,2.49 ± 0.50,97.51 ± 0.50,15.20 ± 0.00,2.45 ± 0.63,97.55 ± 0.63,9.03 ± 0.00,2.49 ± 0.50,97.51 ± 0.50,15.20 ± 0.00


In [78]:
pretty_plot(
    'TriviaQA, Bloomz3b',
    ['/Users/ekaterinafadeeva/work/data/uncertainty_mans/ue_mans/qa_bloomz3b_seed' + str(x)
     for x in range(1, 10)])

Will measure variance using 9 seeds


Unnamed: 0_level_0,"TriviaQA, Bloomz3b","TriviaQA, Bloomz3b","TriviaQA, Bloomz3b","TriviaQA, Bloomz3b","TriviaQA, Bloomz3b","TriviaQA, Bloomz3b","TriviaQA, Bloomz3b","TriviaQA, Bloomz3b","TriviaQA, Bloomz3b"
Unnamed: 0_level_1,Rouge_rouge1,Rouge_rouge1,Rouge_rouge1,Rouge_rouge2,Rouge_rouge2,Rouge_rouge2,Rouge_rougeL,Rouge_rougeL,Rouge_rougeL
Unnamed: 0_level_2,prr,rcc-auc,rpp,prr,rcc-auc,rpp,prr,rcc-auc,rpp
MaxProbabilitySeq,12.63 ± 0.00,87.37 ± 0.00,6.49 ± 0.00,6.61 ± 0.00,93.39 ± 0.00,3.14 ± 0.00,12.61 ± 0.00,87.39 ± 0.00,6.54 ± 0.00
MaxProbabilityNormalizedSeq,13.83 ± 0.00,86.17 ± 0.00,7.22 ± 0.00,6.97 ± 0.00,93.03 ± 0.00,3.53 ± 0.00,13.82 ± 0.00,86.18 ± 0.00,7.29 ± 0.00
EntropySeq,14.18 ± 0.00,85.82 ± 0.00,7.25 ± 0.00,6.98 ± 0.00,93.02 ± 0.00,3.36 ± 0.00,14.17 ± 0.00,85.83 ± 0.00,7.31 ± 0.00
MutualInformationSeq,22.70 ± 0.00,77.30 ± 0.00,11.54 ± 0.00,14.65 ± 0.00,85.35 ± 0.00,7.25 ± 0.00,22.69 ± 0.00,77.31 ± 0.00,11.62 ± 0.00
ConditionalMutualInformationSeq,12.63 ± 0.00,87.37 ± 0.00,6.49 ± 0.00,6.61 ± 0.00,93.39 ± 0.00,3.14 ± 0.00,12.61 ± 0.00,87.39 ± 0.00,6.54 ± 0.00
AttentionEntropySeq*,13.89 ± 0.00,86.11 ± 0.00,6.93 ± 0.00,7.27 ± 0.00,92.73 ± 0.00,3.41 ± 0.00,13.88 ± 0.00,86.12 ± 0.00,6.99 ± 0.00
AttentionRecursiveSeq*,13.99 ± 0.00,86.01 ± 0.00,6.91 ± 0.00,7.41 ± 0.00,92.59 ± 0.00,3.37 ± 0.00,13.98 ± 0.00,86.02 ± 0.00,6.98 ± 0.00
ExponentialAttentionEntropySeq_0.9*,13.78 ± 0.00,86.22 ± 0.00,6.92 ± 0.00,7.09 ± 0.00,92.91 ± 0.00,3.36 ± 0.00,13.77 ± 0.00,86.23 ± 0.00,6.98 ± 0.00
ExponentialAttentionEntropySeq_0.8*,13.80 ± 0.00,86.20 ± 0.00,6.91 ± 0.00,7.14 ± 0.00,92.86 ± 0.00,3.35 ± 0.00,13.79 ± 0.00,86.21 ± 0.00,6.97 ± 0.00
ExponentialAttentionRecursiveSeq_0.9*,14.22 ± 0.00,85.78 ± 0.00,7.19 ± 0.00,7.58 ± 0.00,92.42 ± 0.00,3.49 ± 0.00,14.20 ± 0.00,85.80 ± 0.00,7.24 ± 0.00


In [79]:
pretty_plot(
    'XSUM, Dolly3b',
    ['/Users/ekaterinafadeeva/work/data/uncertainty_mans/ue_mans/xsum_dolly3b_seed' + str(x)
     for x in range(1, 6)])

Will measure variance using 5 seeds


Unnamed: 0_level_0,"XSUM, Dolly3b","XSUM, Dolly3b","XSUM, Dolly3b","XSUM, Dolly3b","XSUM, Dolly3b","XSUM, Dolly3b","XSUM, Dolly3b","XSUM, Dolly3b","XSUM, Dolly3b"
Unnamed: 0_level_1,Rouge_rouge1,Rouge_rouge1,Rouge_rouge1,Rouge_rouge2,Rouge_rouge2,Rouge_rouge2,Rouge_rougeL,Rouge_rougeL,Rouge_rougeL
Unnamed: 0_level_2,prr,rcc-auc,rpp,prr,rcc-auc,rpp,prr,rcc-auc,rpp
MaxProbabilitySeq,34.94 ± 0.00,65.06 ± 0.00,25.23 ± 0.00,14.24 ± 0.00,85.76 ± 0.00,20.53 ± 0.00,41.22 ± 0.00,58.78 ± 0.00,23.42 ± 0.00
MaxProbabilityNormalizedSeq,35.41 ± 0.00,64.59 ± 0.00,25.51 ± 0.00,14.90 ± 0.00,85.10 ± 0.00,20.58 ± 0.00,41.75 ± 0.00,58.25 ± 0.00,23.63 ± 0.00
EntropySeq,34.78 ± 0.00,65.22 ± 0.00,24.94 ± 0.00,14.08 ± 0.00,85.92 ± 0.00,20.30 ± 0.00,40.83 ± 0.00,59.17 ± 0.00,23.10 ± 0.00
MutualInformationSeq,36.20 ± 0.00,63.80 ± 0.00,27.84 ± 0.00,17.34 ± 0.00,82.66 ± 0.00,22.55 ± 0.00,42.83 ± 0.00,57.17 ± 0.00,25.57 ± 0.00
ConditionalMutualInformationSeq,34.94 ± 0.00,65.06 ± 0.00,25.23 ± 0.00,14.24 ± 0.00,85.76 ± 0.00,20.53 ± 0.00,41.22 ± 0.00,58.78 ± 0.00,23.42 ± 0.00
AttentionEntropySeq*,34.47 ± 0.00,65.53 ± 0.00,24.39 ± 0.00,13.85 ± 0.00,86.15 ± 0.00,19.93 ± 0.00,40.46 ± 0.00,59.54 ± 0.00,22.77 ± 0.00
AttentionRecursiveSeq*,34.55 ± 0.00,65.45 ± 0.00,24.49 ± 0.00,13.96 ± 0.00,86.04 ± 0.00,19.99 ± 0.00,40.60 ± 0.00,59.40 ± 0.00,22.86 ± 0.00
ExponentialAttentionEntropySeq_0.9*,34.78 ± 0.00,65.22 ± 0.00,24.85 ± 0.00,14.09 ± 0.00,85.91 ± 0.00,20.21 ± 0.00,40.86 ± 0.00,59.14 ± 0.00,23.07 ± 0.00
ExponentialAttentionEntropySeq_0.8*,34.81 ± 0.00,65.19 ± 0.00,24.95 ± 0.00,14.15 ± 0.00,85.85 ± 0.00,20.32 ± 0.00,40.88 ± 0.00,59.12 ± 0.00,23.12 ± 0.00
ExponentialAttentionRecursiveSeq_0.9*,34.74 ± 0.00,65.26 ± 0.00,24.71 ± 0.00,14.08 ± 0.00,85.92 ± 0.00,20.11 ± 0.00,40.80 ± 0.00,59.20 ± 0.00,23.00 ± 0.00


In [80]:
pretty_plot(
    'XSUM, Dolly7b',
    ['/Users/ekaterinafadeeva/work/data/uncertainty_mans/ue_mans/xsum_dolly7b_seed' + str(x)
     for x in range(1, 6)])

Will measure variance using 5 seeds


Unnamed: 0_level_0,"XSUM, Dolly7b","XSUM, Dolly7b","XSUM, Dolly7b","XSUM, Dolly7b","XSUM, Dolly7b","XSUM, Dolly7b","XSUM, Dolly7b","XSUM, Dolly7b","XSUM, Dolly7b"
Unnamed: 0_level_1,Rouge_rouge1,Rouge_rouge1,Rouge_rouge1,Rouge_rouge2,Rouge_rouge2,Rouge_rouge2,Rouge_rougeL,Rouge_rougeL,Rouge_rougeL
Unnamed: 0_level_2,prr,rcc-auc,rpp,prr,rcc-auc,rpp,prr,rcc-auc,rpp
MaxProbabilitySeq,15.54 ± 9.75,84.46 ± 9.75,22.89 ± 0.00,3.79 ± 3.95,96.21 ± 3.95,19.93 ± 0.00,12.35 ± 8.93,87.65 ± 8.93,20.92 ± 0.00
MaxProbabilityNormalizedSeq,15.41 ± 9.67,84.59 ± 9.67,22.86 ± 0.00,3.69 ± 3.84,96.31 ± 3.84,19.82 ± 0.00,12.41 ± 8.97,87.59 ± 8.97,20.86 ± 0.00
EntropySeq,15.48 ± 9.71,84.52 ± 9.71,22.61 ± 0.00,3.79 ± 3.94,96.21 ± 3.94,19.82 ± 0.00,12.31 ± 8.90,87.69 ± 8.90,20.71 ± 0.00
MutualInformationSeq,16.07 ± 10.09,83.93 ± 10.09,25.64 ± 0.00,4.26 ± 4.43,95.74 ± 4.43,22.30 ± 0.00,13.06 ± 9.45,86.94 ± 9.45,23.69 ± 0.00
ConditionalMutualInformationSeq,15.54 ± 9.75,84.46 ± 9.75,22.89 ± 0.00,3.79 ± 3.95,96.21 ± 3.95,19.93 ± 0.00,12.35 ± 8.93,87.65 ± 8.93,20.92 ± 0.00
AttentionEntropySeq*,15.39 ± 9.63,84.61 ± 9.63,22.40 ± 0.03,3.69 ± 3.82,96.31 ± 3.82,19.46 ± 0.05,12.26 ± 8.86,87.74 ± 8.86,20.49 ± 0.02
AttentionRecursiveSeq*,15.41 ± 9.66,84.59 ± 9.66,22.46 ± 0.02,3.70 ± 3.84,96.30 ± 3.84,19.52 ± 0.04,12.27 ± 8.88,87.73 ± 8.88,20.54 ± 0.01
ExponentialAttentionEntropySeq_0.9*,15.46 ± 9.70,84.54 ± 9.70,22.60 ± 0.00,3.76 ± 3.92,96.24 ± 3.92,19.71 ± 0.00,12.29 ± 8.89,87.71 ± 8.89,20.67 ± 0.00
ExponentialAttentionEntropySeq_0.8*,15.46 ± 9.70,84.54 ± 9.70,22.60 ± 0.00,3.77 ± 3.92,96.23 ± 3.92,19.75 ± 0.00,12.29 ± 8.89,87.71 ± 8.89,20.68 ± 0.00
ExponentialAttentionRecursiveSeq_0.9*,15.45 ± 9.69,84.55 ± 9.69,22.66 ± 0.00,3.75 ± 3.91,96.25 ± 3.91,19.70 ± 0.00,12.28 ± 8.89,87.72 ± 8.89,20.70 ± 0.00


In [81]:
pretty_plot(
    'XSUM, bloomz3b',
    ['/Users/ekaterinafadeeva/work/data/uncertainty_mans/ue_mans/xsum_bloomz3b_seed' + str(x)
     for x in range(1, 10)])

Will measure variance using 9 seeds


Unnamed: 0_level_0,"XSUM, bloomz3b","XSUM, bloomz3b","XSUM, bloomz3b","XSUM, bloomz3b","XSUM, bloomz3b","XSUM, bloomz3b","XSUM, bloomz3b","XSUM, bloomz3b","XSUM, bloomz3b"
Unnamed: 0_level_1,Rouge_rouge1,Rouge_rouge1,Rouge_rouge1,Rouge_rouge2,Rouge_rouge2,Rouge_rouge2,Rouge_rougeL,Rouge_rougeL,Rouge_rougeL
Unnamed: 0_level_2,prr,rcc-auc,rpp,prr,rcc-auc,rpp,prr,rcc-auc,rpp
MaxProbabilitySeq,33.34 ± 0.00,66.66 ± 0.00,16.50 ± 0.00,11.88 ± 0.00,88.12 ± 0.00,16.59 ± 0.00,27.38 ± 0.00,72.62 ± 0.00,16.20 ± 0.00
MaxProbabilityNormalizedSeq,31.73 ± 0.00,68.27 ± 0.00,16.42 ± 0.00,12.18 ± 0.00,87.82 ± 0.00,17.15 ± 0.00,26.62 ± 0.00,73.38 ± 0.00,16.63 ± 0.00
EntropySeq,32.00 ± 0.00,68.00 ± 0.00,15.66 ± 0.00,11.46 ± 0.00,88.54 ± 0.00,15.93 ± 0.00,26.35 ± 0.00,73.65 ± 0.00,15.52 ± 0.00
MutualInformationSeq,33.81 ± 0.00,66.19 ± 0.00,20.18 ± 0.00,14.46 ± 0.00,85.54 ± 0.00,20.30 ± 0.00,29.17 ± 0.00,70.83 ± 0.00,20.57 ± 0.00
ConditionalMutualInformationSeq,33.34 ± 0.00,66.66 ± 0.00,16.50 ± 0.00,11.88 ± 0.00,88.12 ± 0.00,16.59 ± 0.00,27.38 ± 0.00,72.62 ± 0.00,16.20 ± 0.00
AttentionEntropySeq*,31.66 ± 0.00,68.34 ± 0.00,14.90 ± 0.00,11.10 ± 0.00,88.90 ± 0.00,15.36 ± 0.00,26.07 ± 0.00,73.93 ± 0.00,14.79 ± 0.00
AttentionRecursiveSeq*,31.60 ± 0.00,68.40 ± 0.00,14.77 ± 0.00,11.11 ± 0.00,88.89 ± 0.00,15.41 ± 0.00,26.11 ± 0.00,73.89 ± 0.00,14.84 ± 0.00
ExponentialAttentionEntropySeq_0.9*,31.84 ± 0.00,68.16 ± 0.00,15.04 ± 0.00,11.17 ± 0.00,88.83 ± 0.00,15.45 ± 0.00,26.22 ± 0.00,73.78 ± 0.00,14.93 ± 0.00
ExponentialAttentionEntropySeq_0.8*,31.78 ± 0.00,68.22 ± 0.00,15.10 ± 0.00,11.18 ± 0.00,88.82 ± 0.00,15.51 ± 0.00,26.17 ± 0.00,73.83 ± 0.00,15.01 ± 0.00
ExponentialAttentionRecursiveSeq_0.9*,31.53 ± 0.00,68.47 ± 0.00,14.66 ± 0.00,11.06 ± 0.00,88.94 ± 0.00,15.22 ± 0.00,26.01 ± 0.00,73.99 ± 0.00,14.62 ± 0.00


In [82]:
pretty_plot(
    'MT, Dolly3b',
    ['/Users/ekaterinafadeeva/work/data/uncertainty_mans/ue_mans/wmt_dolly3b_seed' + str(x)
     for x in range(1, 10)])

Will measure variance using 9 seeds


Unnamed: 0_level_0,"MT, Dolly3b","MT, Dolly3b","MT, Dolly3b","MT, Dolly3b","MT, Dolly3b","MT, Dolly3b","MT, Dolly3b","MT, Dolly3b","MT, Dolly3b"
Unnamed: 0_level_1,Rouge_rouge1,Rouge_rouge1,Rouge_rouge1,Rouge_rouge2,Rouge_rouge2,Rouge_rouge2,Rouge_rougeL,Rouge_rougeL,Rouge_rougeL
Unnamed: 0_level_2,prr,rcc-auc,rpp,prr,rcc-auc,rpp,prr,rcc-auc,rpp
MaxProbabilitySeq,17.68 ± 0.00,82.32 ± 0.00,18.93 ± 0.00,10.26 ± 0.00,89.74 ± 0.00,15.54 ± 0.00,15.98 ± 0.00,84.02 ± 0.00,18.46 ± 0.00
MaxProbabilityNormalizedSeq,22.58 ± 0.00,77.42 ± 0.00,20.28 ± 0.00,13.47 ± 0.00,86.53 ± 0.00,16.50 ± 0.00,20.47 ± 0.00,79.53 ± 0.00,19.83 ± 0.00
EntropySeq,19.33 ± 0.00,80.67 ± 0.00,19.72 ± 0.00,11.34 ± 0.00,88.66 ± 0.00,16.25 ± 0.00,17.40 ± 0.00,82.60 ± 0.00,19.23 ± 0.00
MutualInformationSeq,26.14 ± 0.00,73.86 ± 0.00,23.79 ± 0.00,15.94 ± 0.00,84.06 ± 0.00,18.88 ± 0.00,24.12 ± 0.00,75.88 ± 0.00,23.39 ± 0.00
ConditionalMutualInformationSeq,17.68 ± 0.00,82.32 ± 0.00,18.93 ± 0.00,10.26 ± 0.00,89.74 ± 0.00,15.54 ± 0.00,15.98 ± 0.00,84.02 ± 0.00,18.46 ± 0.00
AttentionEntropySeq*,18.65 ± 0.00,81.35 ± 0.00,19.57 ± 0.00,10.76 ± 0.00,89.24 ± 0.00,15.88 ± 0.00,16.94 ± 0.00,83.06 ± 0.00,19.11 ± 0.00
AttentionRecursiveSeq*,18.95 ± 0.00,81.05 ± 0.00,19.84 ± 0.00,11.11 ± 0.00,88.89 ± 0.00,16.23 ± 0.00,17.16 ± 0.00,82.84 ± 0.00,19.36 ± 0.00
ExponentialAttentionEntropySeq_0.9*,19.58 ± 0.00,80.42 ± 0.00,19.90 ± 0.00,11.42 ± 0.00,88.58 ± 0.00,16.36 ± 0.00,17.64 ± 0.00,82.36 ± 0.00,19.42 ± 0.00
ExponentialAttentionEntropySeq_0.8*,19.59 ± 0.00,80.41 ± 0.00,19.89 ± 0.00,11.46 ± 0.00,88.54 ± 0.00,16.37 ± 0.00,17.65 ± 0.00,82.35 ± 0.00,19.40 ± 0.00
ExponentialAttentionRecursiveSeq_0.9*,19.55 ± 0.00,80.45 ± 0.00,20.01 ± 0.00,11.48 ± 0.00,88.52 ± 0.00,16.47 ± 0.00,17.63 ± 0.00,82.37 ± 0.00,19.54 ± 0.00


In [83]:
pretty_plot(
    'MT, Bloomz3b',
    ['/Users/ekaterinafadeeva/work/data/uncertainty_mans/ue_mans/wmt_bloomz3b_seed' + str(x)
     for x in range(1, 10)])

Will measure variance using 9 seeds


Unnamed: 0_level_0,"MT, Bloomz3b","MT, Bloomz3b","MT, Bloomz3b","MT, Bloomz3b","MT, Bloomz3b","MT, Bloomz3b","MT, Bloomz3b","MT, Bloomz3b","MT, Bloomz3b"
Unnamed: 0_level_1,Rouge_rouge1,Rouge_rouge1,Rouge_rouge1,Rouge_rouge2,Rouge_rouge2,Rouge_rouge2,Rouge_rougeL,Rouge_rougeL,Rouge_rougeL
Unnamed: 0_level_2,prr,rcc-auc,rpp,prr,rcc-auc,rpp,prr,rcc-auc,rpp
MaxProbabilitySeq,13.43 ± 0.20,86.57 ± 0.20,11.53 ± 0.10,4.53 ± 0.08,95.47 ± 0.08,9.66 ± 0.06,11.98 ± 0.17,88.02 ± 0.17,11.67 ± 0.11
MaxProbabilityNormalizedSeq,8.09 ± 0.14,91.91 ± 0.14,5.63 ± 0.05,3.24 ± 0.07,96.76 ± 0.07,5.31 ± 0.08,7.36 ± 0.13,92.64 ± 0.13,5.94 ± 0.06
EntropySeq,8.75 ± 0.16,91.25 ± 0.16,5.85 ± 0.07,2.92 ± 0.07,97.08 ± 0.07,5.19 ± 0.07,7.85 ± 0.15,92.15 ± 0.15,6.08 ± 0.07
MutualInformationSeq,8.64 ± 0.15,91.36 ± 0.15,6.71 ± 0.08,3.98 ± 0.08,96.02 ± 0.08,6.58 ± 0.10,7.96 ± 0.14,92.04 ± 0.14,7.04 ± 0.08
ConditionalMutualInformationSeq,13.43 ± 0.20,86.57 ± 0.20,11.53 ± 0.10,4.53 ± 0.08,95.47 ± 0.08,9.66 ± 0.06,11.98 ± 0.17,88.02 ± 0.17,11.67 ± 0.11
AttentionEntropySeq*,9.58 ± 0.19,90.42 ± 0.19,6.95 ± 0.10,3.23 ± 0.08,96.77 ± 0.08,5.99 ± 0.10,8.57 ± 0.16,91.43 ± 0.16,7.14 ± 0.10
AttentionRecursiveSeq*,9.82 ± 0.19,90.18 ± 0.19,7.24 ± 0.10,3.34 ± 0.08,96.66 ± 0.08,6.21 ± 0.11,8.79 ± 0.17,91.21 ± 0.17,7.43 ± 0.10
ExponentialAttentionEntropySeq_0.9*,9.62 ± 0.19,90.38 ± 0.19,6.95 ± 0.10,3.18 ± 0.08,96.82 ± 0.08,5.87 ± 0.10,8.59 ± 0.16,91.41 ± 0.16,7.14 ± 0.10
ExponentialAttentionEntropySeq_0.8*,9.52 ± 0.18,90.48 ± 0.18,6.82 ± 0.09,3.14 ± 0.08,96.86 ± 0.08,5.75 ± 0.10,8.52 ± 0.16,91.48 ± 0.16,7.03 ± 0.09
ExponentialAttentionRecursiveSeq_0.9*,9.30 ± 0.18,90.70 ± 0.18,6.69 ± 0.10,3.16 ± 0.08,96.84 ± 0.08,5.86 ± 0.10,8.28 ± 0.16,91.72 ± 0.16,6.86 ± 0.10


In [94]:
pretty_plot(
    ['TriviaQA, Dolly3b', 'XSUM, Dolly3b', 'MT, Dolly3b'],
    [
        ['/Users/ekaterinafadeeva/work/data/uncertainty_mans/ue_mans/qa_dolly3b_seed' + str(x) for x in range(2, 9)],
        ['/Users/ekaterinafadeeva/work/data/uncertainty_mans/ue_mans/xsum_dolly3b_seed' + str(x) for x in range(2, 10)],
        ['/Users/ekaterinafadeeva/work/data/uncertainty_mans/ue_mans/wmt_dolly3b_seed' + str(x) for x in range(2, 10)],
    ], level='token')

Will measure variance using 7 seeds
Will measure variance using 8 seeds
Will measure variance using 8 seeds


Unnamed: 0_level_0,"TriviaQA, Dolly3b","TriviaQA, Dolly3b","TriviaQA, Dolly3b","TriviaQA, Dolly3b","TriviaQA, Dolly3b","TriviaQA, Dolly3b","XSUM, Dolly3b","XSUM, Dolly3b","XSUM, Dolly3b","XSUM, Dolly3b","XSUM, Dolly3b","XSUM, Dolly3b","MT, Dolly3b","MT, Dolly3b","MT, Dolly3b","MT, Dolly3b","MT, Dolly3b","MT, Dolly3b"
Unnamed: 0_level_1,BARTScoreToken-rh,BARTScoreToken-rh,BARTScoreToken-rh,LCS,LCS,LCS,BARTScoreToken-rh,BARTScoreToken-rh,BARTScoreToken-rh,LCS,LCS,LCS,BARTScoreToken-rh,BARTScoreToken-rh,BARTScoreToken-rh,LCS,LCS,LCS
Unnamed: 0_level_2,prr,rcc-auc,rpp,prr,rcc-auc,rpp,prr,rcc-auc,rpp,prr,rcc-auc,rpp,prr,rcc-auc,rpp,prr,rcc-auc,rpp
MaxProbabilityToken,37.23 ± 0.00,62.77 ± 0.00,25.22 ± 0.00,0.29 ± 0.00,99.71 ± 0.00,0.15 ± 0.00,35.18 ± 0.00,64.82 ± 0.00,25.73 ± 0.00,3.59 ± 0.00,96.41 ± 0.00,1.86 ± 0.00,41.52 ± 0.00,58.48 ± 0.00,26.23 ± 0.00,4.20 ± 0.00,95.80 ± 0.00,2.19 ± 0.00
MaxProbabilityNormalizedToken,37.74 ± 0.00,62.26 ± 0.00,25.41 ± 0.00,0.48 ± 0.00,99.52 ± 0.00,0.18 ± 0.00,35.43 ± 0.00,64.57 ± 0.00,25.90 ± 0.00,3.91 ± 0.00,96.09 ± 0.00,1.92 ± 0.00,41.46 ± 0.00,58.54 ± 0.00,26.34 ± 0.00,5.06 ± 0.00,94.94 ± 0.00,2.38 ± 0.00
EntropyToken,37.23 ± 0.00,62.77 ± 0.00,25.17 ± 0.00,0.29 ± 0.00,99.71 ± 0.00,0.15 ± 0.00,35.17 ± 0.00,64.83 ± 0.00,25.73 ± 0.00,3.57 ± 0.00,96.43 ± 0.00,1.85 ± 0.00,41.37 ± 0.00,58.63 ± 0.00,26.28 ± 0.00,4.25 ± 0.00,95.75 ± 0.00,2.17 ± 0.00
MutualInformationToken,39.21 ± 0.00,60.79 ± 0.00,27.49 ± 0.00,0.29 ± 0.00,99.71 ± 0.00,0.14 ± 0.00,35.40 ± 0.00,64.60 ± 0.00,25.62 ± 0.00,3.86 ± 0.00,96.14 ± 0.00,1.96 ± 0.00,41.72 ± 0.00,58.28 ± 0.00,26.92 ± 0.00,4.85 ± 0.00,95.15 ± 0.00,2.38 ± 0.00
ConditionalMutualInformationToken,37.23 ± 0.00,62.77 ± 0.00,25.22 ± 0.00,0.29 ± 0.00,99.71 ± 0.00,0.15 ± 0.00,35.18 ± 0.00,64.82 ± 0.00,25.73 ± 0.00,3.59 ± 0.00,96.41 ± 0.00,1.86 ± 0.00,41.52 ± 0.00,58.48 ± 0.00,26.23 ± 0.00,4.20 ± 0.00,95.80 ± 0.00,2.19 ± 0.00
AttentionEntropyToken*,37.63 ± 0.00,62.37 ± 0.00,25.36 ± 0.00,0.31 ± 0.00,99.69 ± 0.00,0.15 ± 0.00,35.30 ± 0.00,64.70 ± 0.00,25.88 ± 0.00,3.66 ± 0.00,96.34 ± 0.00,1.86 ± 0.00,41.60 ± 0.00,58.40 ± 0.00,26.42 ± 0.00,4.41 ± 0.00,95.59 ± 0.00,2.22 ± 0.00
AttentionRecursiveToken*,36.91 ± 0.00,63.09 ± 0.00,24.14 ± 0.00,0.32 ± 0.00,99.68 ± 0.00,0.16 ± 0.00,35.06 ± 0.00,64.94 ± 0.00,25.12 ± 0.00,3.73 ± 0.00,96.27 ± 0.00,1.92 ± 0.00,41.33 ± 0.00,58.67 ± 0.00,26.15 ± 0.00,4.08 ± 0.00,95.92 ± 0.00,2.11 ± 0.00
ExponentialAttentionEntropyToken_0.9*,37.40 ± 0.00,62.60 ± 0.00,25.26 ± 0.00,0.33 ± 0.00,99.67 ± 0.00,0.15 ± 0.00,35.23 ± 0.00,64.77 ± 0.00,25.82 ± 0.00,3.64 ± 0.00,96.36 ± 0.00,1.85 ± 0.00,41.54 ± 0.00,58.46 ± 0.00,26.42 ± 0.00,4.29 ± 0.00,95.71 ± 0.00,2.18 ± 0.00
ExponentialAttentionEntropyToken_0.8*,37.36 ± 0.00,62.64 ± 0.00,25.24 ± 0.00,0.31 ± 0.00,99.69 ± 0.00,0.15 ± 0.00,35.21 ± 0.00,64.79 ± 0.00,25.80 ± 0.00,3.60 ± 0.00,96.40 ± 0.00,1.85 ± 0.00,41.52 ± 0.00,58.48 ± 0.00,26.38 ± 0.00,4.23 ± 0.00,95.77 ± 0.00,2.17 ± 0.00
AttentionRecursiveToken_0.9*,36.98 ± 0.00,63.02 ± 0.00,24.30 ± 0.00,0.35 ± 0.00,99.65 ± 0.00,0.17 ± 0.00,35.11 ± 0.00,64.89 ± 0.00,25.19 ± 0.00,3.76 ± 0.00,96.24 ± 0.00,1.96 ± 0.00,41.39 ± 0.00,58.61 ± 0.00,26.28 ± 0.00,4.42 ± 0.00,95.58 ± 0.00,2.23 ± 0.00


In [85]:
pretty_plot(
    ['TriviaQA, Bloomz3b', 'XSUM, Bloomz3b', 'MT, Bloomz3b'],
    [
        ['/Users/ekaterinafadeeva/work/data/uncertainty_mans/ue_mans/qa_bloomz3b_seed' + str(x) for x in range(2, 10)],
        ['/Users/ekaterinafadeeva/work/data/uncertainty_mans/ue_mans/xsum_bloomz3b_seed' + str(x) for x in range(2, 10)],
        ['/Users/ekaterinafadeeva/work/data/uncertainty_mans/ue_mans/wmt_bloomz3b_seed' + str(x) for x in range(2, 10)],
    ], level='token')

Will measure variance using 8 seeds
Will measure variance using 8 seeds
Will measure variance using 8 seeds


Unnamed: 0_level_0,"TriviaQA, Bloomz3b","TriviaQA, Bloomz3b","TriviaQA, Bloomz3b","TriviaQA, Bloomz3b","TriviaQA, Bloomz3b","TriviaQA, Bloomz3b","XSUM, Bloomz3b","XSUM, Bloomz3b","XSUM, Bloomz3b","XSUM, Bloomz3b","XSUM, Bloomz3b","XSUM, Bloomz3b","MT, Bloomz3b","MT, Bloomz3b","MT, Bloomz3b","MT, Bloomz3b","MT, Bloomz3b","MT, Bloomz3b"
Unnamed: 0_level_1,BARTScoreToken-rh,BARTScoreToken-rh,BARTScoreToken-rh,LCS,LCS,LCS,BARTScoreToken-rh,BARTScoreToken-rh,BARTScoreToken-rh,LCS,LCS,LCS,BARTScoreToken-rh,BARTScoreToken-rh,BARTScoreToken-rh,LCS,LCS,LCS
Unnamed: 0_level_2,prr,rcc-auc,rpp,prr,rcc-auc,rpp,prr,rcc-auc,rpp,prr,rcc-auc,rpp,prr,rcc-auc,rpp,prr,rcc-auc,rpp
MaxProbabilityToken,32.26 ± 0.00,67.74 ± 0.00,24.62 ± 0.00,4.05 ± 0.00,95.95 ± 0.00,2.20 ± 0.00,50.31 ± 0.00,49.69 ± 0.00,26.82 ± 0.00,21.61 ± 0.00,78.39 ± 0.00,8.04 ± 0.00,43.04 ± 0.06,56.96 ± 0.06,23.29 ± 0.04,17.43 ± 0.16,82.57 ± 0.16,5.66 ± 0.02
MaxProbabilityNormalizedToken,32.39 ± 0.00,67.61 ± 0.00,24.50 ± 0.00,3.91 ± 0.00,96.09 ± 0.00,2.31 ± 0.00,51.56 ± 0.00,48.44 ± 0.00,27.52 ± 0.00,22.21 ± 0.00,77.79 ± 0.00,8.05 ± 0.00,44.18 ± 0.04,55.82 ± 0.04,24.05 ± 0.03,19.56 ± 0.18,80.44 ± 0.18,6.37 ± 0.03
EntropyToken,32.88 ± 0.00,67.12 ± 0.00,24.65 ± 0.00,3.85 ± 0.00,96.15 ± 0.00,2.16 ± 0.00,50.36 ± 0.00,49.64 ± 0.00,26.87 ± 0.00,21.51 ± 0.00,78.49 ± 0.00,8.06 ± 0.00,43.23 ± 0.04,56.77 ± 0.04,23.59 ± 0.03,16.74 ± 0.18,83.26 ± 0.18,5.43 ± 0.03
MutualInformationToken,32.14 ± 0.00,67.86 ± 0.00,23.84 ± 0.00,4.16 ± 0.00,95.84 ± 0.00,2.52 ± 0.00,51.52 ± 0.00,48.48 ± 0.00,26.98 ± 0.00,25.57 ± 0.00,74.43 ± 0.00,9.57 ± 0.00,44.56 ± 0.02,55.44 ± 0.02,24.90 ± 0.02,23.63 ± 0.21,76.37 ± 0.21,8.50 ± 0.04
ConditionalMutualInformationToken,32.26 ± 0.00,67.74 ± 0.00,24.62 ± 0.00,4.05 ± 0.00,95.95 ± 0.00,2.20 ± 0.00,50.31 ± 0.00,49.69 ± 0.00,26.82 ± 0.00,21.61 ± 0.00,78.39 ± 0.00,8.04 ± 0.00,43.04 ± 0.06,56.96 ± 0.06,23.29 ± 0.04,17.43 ± 0.16,82.57 ± 0.16,5.66 ± 0.02
AttentionEntropyToken*,33.44 ± 0.00,66.56 ± 0.00,24.78 ± 0.00,3.62 ± 0.00,96.38 ± 0.00,2.09 ± 0.00,51.17 ± 0.00,48.83 ± 0.00,27.40 ± 0.00,21.15 ± 0.00,78.85 ± 0.00,7.89 ± 0.00,44.41 ± 0.04,55.59 ± 0.04,24.40 ± 0.03,17.15 ± 0.21,82.85 ± 0.21,5.52 ± 0.04
AttentionRecursiveToken*,32.74 ± 0.00,67.26 ± 0.00,23.88 ± 0.00,4.05 ± 0.00,95.95 ± 0.00,2.22 ± 0.00,49.88 ± 0.00,50.12 ± 0.00,25.01 ± 0.00,20.37 ± 0.00,79.63 ± 0.00,7.45 ± 0.00,43.21 ± 0.04,56.79 ± 0.04,23.32 ± 0.03,16.75 ± 0.26,83.25 ± 0.26,5.31 ± 0.06
ExponentialAttentionEntropyToken_0.9*,33.58 ± 0.00,66.42 ± 0.00,24.85 ± 0.00,3.58 ± 0.00,96.42 ± 0.00,2.09 ± 0.00,51.31 ± 0.00,48.69 ± 0.00,27.33 ± 0.00,21.00 ± 0.00,79.00 ± 0.00,7.72 ± 0.00,44.42 ± 0.04,55.58 ± 0.04,24.35 ± 0.03,16.93 ± 0.21,83.07 ± 0.21,5.45 ± 0.04
ExponentialAttentionEntropyToken_0.8*,33.53 ± 0.00,66.47 ± 0.00,24.83 ± 0.00,3.59 ± 0.00,96.41 ± 0.00,2.09 ± 0.00,51.13 ± 0.00,48.87 ± 0.00,27.18 ± 0.00,20.75 ± 0.00,79.25 ± 0.00,7.67 ± 0.00,44.18 ± 0.04,55.82 ± 0.04,24.08 ± 0.03,16.62 ± 0.19,83.38 ± 0.19,5.34 ± 0.03
AttentionRecursiveToken_0.9*,32.96 ± 0.00,67.04 ± 0.00,23.72 ± 0.00,3.62 ± 0.00,96.38 ± 0.00,2.01 ± 0.00,49.97 ± 0.00,50.03 ± 0.00,25.20 ± 0.00,20.89 ± 0.00,79.11 ± 0.00,7.72 ± 0.00,43.38 ± 0.03,56.62 ± 0.03,23.39 ± 0.01,17.27 ± 0.21,82.73 ± 0.21,5.55 ± 0.04
