In [1]:
import sys
sys.path.insert(0,'..')

import yaml
import os
from yaml import Loader as Loader
from pathlib import Path
import pandas as pd
import numpy as np
import json
from sklearn.metrics import roc_auc_score

from analyze_results import (
    extract_result,
    aggregate_runs,
    from_model_outputs_calc_rcc_auc,
    format_results2,
    improvement_over_baseline,
    from_model_outputs_calc_pr_auc,
    from_model_outputs_calc_rpp,
    from_model_outputs_calc_roc_auc,
    from_model_outputs_calc_arc_auc,
    from_model_outputs_calc_f1_score,
    from_model_outputs_calc_recall,
    from_model_outputs_calc_precision
)

from utils.utils_wandb import init_wandb, wandb
from ue4nlp.ue_scores import *

In [2]:
def choose_metric(metric_type):
    if metric_type  == "rejection-curve-auc":
        return metric_type
    elif metric_type  == "roc-auc":
        return metric_type
    elif metric_type == "rcc-auc":
        return from_model_outputs_calc_rcc_auc
    elif metric_type == "rpp":
        return from_model_outputs_calc_rpp
    elif metric_type == "pr-auc":
        return from_model_outputs_calc_pr_auc
    elif metric_type == "f1-score":
        return from_model_outputs_calc_f1_score
    elif metric_type == "recall":
        return from_model_outputs_calc_recall
    elif metric_type == "precision":
        return from_model_outputs_calc_precision
    else:
        raise ValueError("Wrong metric type!")
        
def get_one_table(runs_dir, metric_types=["rejection-curve-auc", "roc-auc", "pr-auc"], baseline=None, methods=None):
    default_methods = {
        "bald": bald,
        "sampled_max_prob": sampled_max_prob,
        "variance": probability_variance,
    }
    
    if methods is None:
        methods = default_methods

    table = []
    for metric_type in metric_types:
        metric = choose_metric(metric_type=metric_type)

        agg_res = aggregate_runs(
            runs_dir, methods=methods, metric=metric, oos=True
        )
        if agg_res.empty:
            print("Broken\n")
            continue

        if metric_type == "rcc-auc":
            final_score = format_results2(agg_res, percents=False)
        elif metric_type == "rpp":
            final_score = format_results2(agg_res, percents=True)
        elif metric_type == "accuracy":
            final_score = format_results2(agg_res, percents=True)
        else:
            final_score = improvement_over_baseline(agg_res, baseline_col="max_prob", baseline=baseline, metric=metric_type, percents=True, subtract=False)
        table.append(final_score)
    res_table = pd.concat(table, axis=1)
    res_table.columns = metric_types
    # fix for rcc-auc and rpp
    if 'baseline (max_prob)' not in res_table.index:
        res_table.loc['baseline (max_prob)'] = 0
    try:
        res_table = res_table.drop(['max_prob', 'count'])
    except:
        res_table = res_table.drop(['count'])
    return res_table


def collect_tables(run_dirs, names, metric_types=["rejection-curve-auc", "roc-auc", "rcc-auc", "pr-auc", "rpp"], baseline=None, methods=None):
    all_tables = []
    for run_dir, name in zip(run_dirs, names):
        buf_table = get_one_table(run_dir, metric_types, baseline, methods)
        #print(buf_table)
        # add name to index
        indices = [(name, ind) for ind in list(buf_table.index)]
        baseline_name = 'baseline|'+'|'.join(name.split('|')[1:])
        buf_table.loc[baseline_name] = buf_table.loc['baseline (max_prob)']
        # add reindex
        indices = indices + [(baseline_name, 'max_prob')]
        
        index = pd.MultiIndex.from_tuples(indices, names=['Method', 'UE Score'])
        buf_table.index = index
        buf_table.drop((name, 'baseline (max_prob)'), inplace=True)
        # add buf_table to final_table
        all_tables.append(buf_table)
    return pd.concat(all_tables)


def collect_datasets(runs_dirs, names, dataset_names, metric_types=["rejection-curve-auc", "roc-auc", "rcc-auc", "pr-auc", "rpp"], baselines={}, methods=None):
    all_tables = []
    for run_dir, dataset_name in zip(runs_dirs, dataset_names):
        try:
            dataset_table = collect_tables(run_dir, names, metric_types, baselines.get(dataset_name, None), methods=methods)
            columns = pd.MultiIndex.from_tuples([(dataset_name, ind) for ind in list(dataset_table.columns)])
            dataset_table.columns = columns
            all_tables.append(dataset_table)
        except:
            print(f'empty dir {run_dir}')
    return pd.concat(all_tables, axis=1)

# OOD datasets

In [21]:
import os 
def choose_agg_func(method):
    agg_methods = {
        "bald": bald,
        "sampled_max_prob": sampled_max_prob,
        "variance": probability_variance,
    }
    if method=='nuq' or method=='nuq_best' or method=='nuq_best1':
        nuq_aleatoric = lambda x: np.squeeze(x[0], axis=-1)
        nuq_epistemic = lambda x: np.squeeze(x[1], axis=-1)
        nuq_total = lambda x: np.squeeze(x[2], axis=-1)
        agg_methods = {
            "nuq_aleatoric": nuq_aleatoric,
            "nuq_epistemic": nuq_epistemic,
            "nuq_total": nuq_total,
        }
    elif method=='decomposing_md':
        disc_md = lambda x: np.squeeze(x[0], axis=-1)
        nondisc_md = lambda x: np.squeeze(x[1], axis=-1)
        sum_md = lambda x: np.squeeze(x[2], axis=-1)
        agg_methods = {"disc_md": disc_md, 
                       'nondisc_md': nondisc_md,
                       'disc+nondisc_md': sum_md}
    elif method=='mahalanobis':
        maha_dist = lambda x: np.squeeze(x[0], axis=-1)
        rel_maha_dist = lambda x: np.squeeze(x[1], axis=-1)
        marg_maha_dist = lambda x: np.squeeze(x[2], axis=-1)
        agg_methods = {"mahalanobis_distance": maha_dist,
                       "relative_mahalanobis_distance": rel_maha_dist,
                       "marginal_mahalanobis_distance": marg_maha_dist}
    elif method=='ddu' or method=='ddu_maha':
        ddu = lambda x: -np.squeeze(x[:, 0], axis=-1)
        agg_methods = {"ddu": ddu}
    elif method=='sngp':
        sngp = lambda x: np.squeeze(x[:, 0], axis=-1)
        agg_methods = {"sngp": sngp}
    elif method=='mc_mahalanobis':
        sm_maha_dist = lambda x: np.squeeze(x[:, 1:], axis=-1).max(1)
        agg_methods = {"sampled_mahalanobis_distance": sm_maha_dist}
    return agg_methods
metric_types=["roc-auc", 'pr-auc']

methods = ['nuq', 'mahalanobis', 'decomposing_md', 'ddu']#, 'sngp']
regs = ['raw', 'reg', 'metric']
spectralnorm = ['sn', 'no_sn']
dataset_names = ['CLINC', 'ROSTD']
dataset_fnames = ['clinc', 'rostd']
names = []
tables = []
baselines = []
for method in methods:
    for reg in regs:
        for sn in spectralnorm:
            run_dirs = []
            name_sn = ''
            names = [f'{method}|{reg}|{sn}']
            for name in dataset_fnames:
                model_series_dir = f'../../workdir/run_glue_for_model_series/electra_{reg}_{sn}/{name}/{method}'
                print(model_series_dir)
                run_dirs.append([model_series_dir])
            agg_func = choose_agg_func(method)
            try:
                res_df = collect_datasets(run_dirs, names, dataset_names, metric_types=metric_types, baselines={}, methods=agg_func)
                baselines.append(res_df.iloc[-1:])
                tables.append(res_df.iloc[:-1])
            except:
                print('pass')
                pass

../../workdir/run_glue_for_model_series/electra_raw_sn/clinc/nuq
../../workdir/run_glue_for_model_series/electra_raw_sn/rostd/nuq
empty dir ['../../workdir/run_glue_for_model_series/electra_raw_sn/clinc/nuq']
empty dir ['../../workdir/run_glue_for_model_series/electra_raw_sn/rostd/nuq']
pass
../../workdir/run_glue_for_model_series/electra_raw_no_sn/clinc/nuq
../../workdir/run_glue_for_model_series/electra_raw_no_sn/rostd/nuq
empty dir ['../../workdir/run_glue_for_model_series/electra_raw_no_sn/clinc/nuq']
empty dir ['../../workdir/run_glue_for_model_series/electra_raw_no_sn/rostd/nuq']
pass
../../workdir/run_glue_for_model_series/electra_reg_sn/clinc/nuq
../../workdir/run_glue_for_model_series/electra_reg_sn/rostd/nuq
empty dir ['../../workdir/run_glue_for_model_series/electra_reg_sn/clinc/nuq']
empty dir ['../../workdir/run_glue_for_model_series/electra_reg_sn/rostd/nuq']
pass
../../workdir/run_glue_for_model_series/electra_reg_no_sn/clinc/nuq
../../workdir/run_glue_for_model_series/e

In [4]:
table_all = pd.concat([pd.concat(tables), pd.concat(baselines[:2])]).reset_index()

In [5]:
def preproc_regs(x):
    reg = x.split('|')[1]
    if reg == 'reg':
        return 'CER'
    elif reg == 'raw':
        return '-'
    else:
        return reg
    
def preproc_method(x):
    method = x.split('|')[0]
    sn = x.split('|')[-1]
    if method == 'mahalanobis' and not 'no_sn' in sn:
        return 'MD SN (ours)'
    elif method == 'mahalanobis':
        return 'MD'
    elif method == 'mc_mahalanobis' and not 'no_sn' in sn:
        return 'SMD SN (ours)'
    elif method == 'mc_mahalanobis':
        return 'SMD'
    elif method == 'nuq' and not 'no_sn' in sn:
        return 'NUQ SN'
    elif method == 'nuq':
        return 'NUQ'
    elif method == 'sngp':
        return 'SNGP'
    
    elif method == 'decomposing_md' and not 'no_sn' in sn:
        return 'Decomposing SN'
    elif method == 'decomposing_md':
        return 'Decomposing'
    
    elif method == 'nuq_best1' and not 'no_sn' in sn:
        return 'Best1 NUQ SN'
    elif method == 'nuq_best1':
        return 'Best1 NUQ'
    
    elif method == 'ddu' and not 'no_sn' in sn:
        return 'DDU SN'
    elif method == 'ddu':
        return 'DDU'
    
    elif method == 'ddu' and not 'no_sn' in sn:
        return 'DDU SN'
    elif method == 'ddu':
        return 'DDU'
    
    elif method == 'ddu_maha' and not 'no_sn' in sn:
        return 'DDU Maha SN'
    elif method == 'ddu_maha':
        return 'DDU Maha'
    
    elif 'ddpp_dpp' in method:
        return 'DDPP (+DPP) (ours)'
    elif 'ddpp_ood' in method:
        return 'DDPP (+OOD) (ours)'
    elif 'mc_all' in method:
        return 'MC dropout'
    elif 'Deep' in method:
        return 'Deep Ensemble'
    elif 'baseline|raw_no_sn' in x:
        return 'SR (baseline)'
    elif 'baseline' in x and not 'no_sn' in x:
        return 'SR SN'
    return 'SR'

def preproc_ue(x):
    if x == 'bald':
        return 'BALD'
    elif 'sampled_mahalanobis_distance' in x:
        return 'SMD'
    elif 'relative_mahalanobis_distance' in x:
        return 'rel_MD'
    elif 'marginal_mahalanobis_distance' in x:
        return 'marg_MD'
    elif 'mahalanobis_distance' in x:
        return 'MD'
    elif 'sampled_max_prob' in x:
        return 'SMP'
    elif 'variance' in x:
        return 'PV'
    elif 'aleatoric' in x:
        return 'aleatoric'
    elif 'epistemic' in x:
        return 'epistemic'
    elif 'total' in x:
        return 'total'
    elif x == 'disc_md':
        return 'Disc MD'
    elif x == 'nondisc_md':
        return 'Nondisc MD'
    elif x == 'disc+nondisc_md':
        return 'Disc+Nondisc MD'
    elif x == 'sngp':
        return 'std'
    return 'MP'

table_all['Reg. Type'] = table_all.Method.apply(lambda x: preproc_regs(x))
table_all['Method'] = table_all.Method.apply(lambda x: preproc_method(x))
table_all['UE Score'] = table_all['UE Score'].apply(lambda x: preproc_ue(x))
table_all = table_all[list(table_all.columns[:1]) + list(table_all.columns[-1:]) + list(table_all.columns[1:-1])].reset_index(drop=True)

In [8]:
table_all.iloc[:-2]

Unnamed: 0_level_0,Method,Reg. Type,UE Score,CLINC,CLINC,ROSTD,ROSTD
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,pr-auc,roc-auc,pr-auc,roc-auc
0,NUQ SN,-,aleatoric,81.0±2.7,66.3±26.5,95.2±0.9,98.1±0.2
1,NUQ SN,-,epistemic,86.5±1.7,95.1±0.9,78.3±5.8,83.7±3.5
2,NUQ SN,-,total,82.7±2.0,77.9±17.3,84.8±4.3,88.5±3.0
3,NUQ,-,aleatoric,79.6±3.0,53.3±26.2,95.9±0.4,98.3±0.2
4,NUQ,-,epistemic,87.1±0.4,95.9±0.5,82.4±3.2,86.7±2.7
5,NUQ,-,total,81.5±2.5,70.7±16.1,88.6±3.2,91.5±2.8
6,NUQ SN,CER,aleatoric,79.6±2.1,54.8±18.9,95.6±1.3,98.2±0.4
7,NUQ SN,CER,epistemic,87.2±0.7,95.1±0.5,81.6±6.9,86.2±4.8
8,NUQ SN,CER,total,81.8±1.6,71.9±12.2,87.5±5.6,90.7±4.3
9,NUQ,CER,aleatoric,80.3±1.7,62.1±17.5,92.9±3.3,97.6±0.5


In [9]:
table_all.iloc[-2:]

Unnamed: 0_level_0,Method,Reg. Type,UE Score,CLINC,CLINC,ROSTD,ROSTD
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,pr-auc,roc-auc,pr-auc,roc-auc
60,SR SN,-,MP,86.0±0.9,94.3±0.4,68.4±4.2,77.0±3.2
61,SR,-,MP,85.9±0.5,95.7±0.2,70.9±5.0,78.7±3.5


# OOD new Benchmark

## SST-2

In [17]:
import os 

def choose_agg_func(method):
    agg_methods = {
        "bald": bald,
        "sampled_max_prob": sampled_max_prob,
        "variance": probability_variance,
    }
    if method=='nuq' or method=='nuq_best' or method=='nuq_best1':
        nuq_aleatoric = lambda x: np.squeeze(x[0], axis=-1)
        nuq_epistemic = lambda x: np.squeeze(x[1], axis=-1)
        nuq_total = lambda x: np.squeeze(x[2], axis=-1)
        agg_methods = {
            "nuq_aleatoric": nuq_aleatoric,
            "nuq_epistemic": nuq_epistemic,
            "nuq_total": nuq_total,
        }
    elif method=='l_nuq':
        nuq_aleatoric = lambda x: np.squeeze(x[0], axis=-1)
        nuq_epistemic = lambda x: np.squeeze(x[1], axis=-1)
        nuq_total = lambda x: np.squeeze(x[2], axis=-1)
        agg_methods = {
            "l_nuq_aleatoric": nuq_aleatoric,
            "l_nuq_epistemic": nuq_epistemic,
            "l_nuq_total": nuq_total,
        }
    elif method=='decomposing_md':
        disc_md = lambda x: np.squeeze(x[0], axis=-1)
        nondisc_md = lambda x: np.squeeze(x[1], axis=-1)
        sum_md = lambda x: np.squeeze(x[2], axis=-1)
        agg_methods = {"disc_md": disc_md, 
                       'nondisc_md': nondisc_md,
                       'disc+nondisc_md': sum_md}
    elif method=='mahalanobis':
        maha_dist = lambda x: np.squeeze(x[0], axis=-1)
        rel_maha_dist = lambda x: np.squeeze(x[1], axis=-1)
        marg_maha_dist = lambda x: np.squeeze(x[2], axis=-1)
        agg_methods = {"mahalanobis_distance": maha_dist,
                       "relative_mahalanobis_distance": rel_maha_dist,
                       "marginal_mahalanobis_distance": marg_maha_dist}
    elif method=='l_mahalanobis':
        maha_dist = lambda x: np.squeeze(x[:, 0], axis=-1)
        agg_methods = {"l_mahalanobis_distance": maha_dist}
    elif method=='ddu' or method=='ddu_maha':
        ddu = lambda x: -np.squeeze(x[:, 0], axis=-1)
        agg_methods = {"ddu": ddu}
    elif method=='sngp':
        sngp = lambda x: np.squeeze(x[:, 0], axis=-1)
        agg_methods = {"sngp": sngp}
    elif method=='mc_mahalanobis':
        sm_maha_dist = lambda x: np.squeeze(x[:, 1:], axis=-1).max(1)
        agg_methods = {"sampled_mahalanobis_distance": sm_maha_dist}
    return agg_methods
metric_types=["roc-auc", 'pr-auc', 'f1-score', 'precision', 'recall']

methods = ['l_nuq', 'nuq', 'mahalanobis', 'l_mahalanobis', 'decomposing_md', 'ddu']
regs = ['raw']
spectralnorm = ['sn', 'no_sn']
dataset_names = ['IMDB', 'TREC', 'WMT16', 'Amazon', 'MNLI', '20newsgroups', 'RTE']
dataset_fnames = ['imdb', 'trec', 'wmt16', 'amazon', 'mnli', 'newsgroup', 'rte']
names = []
tables = []
baselines = []
for method in methods:
    for reg in regs:
        for sn in spectralnorm:
            run_dirs = []
            name_sn = ''
            names = [f'{method}|{reg}|{sn}']
            for name in dataset_fnames:
                model_series_dir = f'../../workdir/run_tasks_for_model_series/electra_{reg}_{sn}/sst2/{name}/{method}'
                print(model_series_dir)
                run_dirs.append([model_series_dir])
            agg_func = choose_agg_func(method)
            try:
                res_df = collect_datasets(run_dirs, names, dataset_names, metric_types=metric_types, baselines={}, methods=agg_func)
                baselines.append(res_df.iloc[-1:])
                tables.append(res_df.iloc[:-1])
            except:
                print('pass')
                pass

../../workdir/run_tasks_for_model_series/electra_raw_sn/sst2/imdb/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn/sst2/trec/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn/sst2/wmt16/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn/sst2/amazon/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn/sst2/mnli/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn/sst2/newsgroup/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn/sst2/rte/l_nuq
empty dir ['../../workdir/run_tasks_for_model_series/electra_raw_sn/sst2/imdb/l_nuq']
empty dir ['../../workdir/run_tasks_for_model_series/electra_raw_sn/sst2/trec/l_nuq']
empty dir ['../../workdir/run_tasks_for_model_series/electra_raw_sn/sst2/wmt16/l_nuq']
empty dir ['../../workdir/run_tasks_for_model_series/electra_raw_sn/sst2/amazon/l_nuq']
empty dir ['../../workdir/run_tasks_for_model_series/electra_raw_sn/sst2/mnli/l_nuq']
empty dir ['../../workdir/run_tasks_for_model_series/ele

In [18]:
table_all = pd.concat([pd.concat(tables), pd.concat(baselines[-2:])]).reset_index()

ValueError: No objects to concatenate

In [19]:
def preproc_regs(x):
    reg = x.split('|')[1]
    if reg == 'reg':
        return 'CER'
    elif reg == 'raw':
        return '-'
    else:
        return reg
    
def preproc_method(x):
    method = x.split('|')[0]
    sn = x.split('|')[-1]
    if method == 'mahalanobis' and not 'no_sn' in sn:
        return 'MD SN (ours)'
    elif method == 'mahalanobis':
        return 'MD'
    elif method == 'l_mahalanobis' and not 'no_sn' in sn:
        return 'L-MD SN (ours)'
    elif method == 'l_mahalanobis':
        return 'L-MD'
    elif method == 'mc_mahalanobis' and not 'no_sn' in sn:
        return 'SMD SN (ours)'
    elif method == 'mc_mahalanobis':
        return 'SMD'
    elif method == 'nuq' and not 'no_sn' in sn:
        return 'NUQ SN'
    elif method == 'nuq':
        return 'NUQ'
    elif method == 'l_nuq' and not 'no_sn' in sn:
        return 'L-NUQ SN'
    elif method == 'l_nuq':
        return 'L-NUQ'
    elif method == 'sngp':
        return 'SNGP'
    
    elif method == 'decomposing_md' and not 'no_sn' in sn:
        return 'Decomposing SN'
    elif method == 'decomposing_md':
        return 'Decomposing'
    
    elif method == 'nuq_best1' and not 'no_sn' in sn:
        return 'Best1 NUQ SN'
    elif method == 'nuq_best1':
        return 'Best1 NUQ'
    
    elif method == 'ddu' and not 'no_sn' in sn:
        return 'DDU SN'
    elif method == 'ddu':
        return 'DDU'
    
    elif method == 'ddu' and not 'no_sn' in sn:
        return 'DDU SN'
    elif method == 'ddu':
        return 'DDU'
    
    elif method == 'ddu_maha' and not 'no_sn' in sn:
        return 'DDU Maha SN'
    elif method == 'ddu_maha':
        return 'DDU Maha'
    
    elif 'ddpp_dpp' in method:
        return 'DDPP (+DPP) (ours)'
    elif 'ddpp_ood' in method:
        return 'DDPP (+OOD) (ours)'
    elif 'mc_all' in method:
        return 'MC dropout'
    elif 'Deep' in method:
        return 'Deep Ensemble'
    elif 'baseline|raw_no_sn' in x:
        return 'SR (baseline)'
    elif 'baseline' in x and not 'no_sn' in x:
        return 'SR SN'
    return 'SR'

def preproc_ue(x):
    if x == 'bald':
        return 'BALD'
    elif 'sampled_mahalanobis_distance' in x:
        return 'SMD'
    elif 'relative_mahalanobis_distance' in x:
        return 'rel_MD'
    elif 'marginal_mahalanobis_distance' in x:
        return 'marg_MD'
    elif 'mahalanobis_distance' in x:
        return 'MD'
    elif 'sampled_max_prob' in x:
        return 'SMP'
    elif 'variance' in x:
        return 'PV'
    elif 'aleatoric' in x:
        return 'aleatoric'
    elif 'epistemic' in x:
        return 'epistemic'
    elif 'total' in x:
        return 'total'
    elif x == 'disc_md':
        return 'Disc MD'
    elif x == 'nondisc_md':
        return 'Nondisc MD'
    elif x == 'disc+nondisc_md':
        return 'Disc+Nondisc MD'
    elif x == 'sngp':
        return 'std'
    return 'MP'

table_all['Reg. Type'] = table_all.Method.apply(lambda x: preproc_regs(x))
table_all['Method'] = table_all.Method.apply(lambda x: preproc_method(x))
table_all['UE Score'] = table_all['UE Score'].apply(lambda x: preproc_ue(x))
table_all = table_all[list(table_all.columns[:1]) + list(table_all.columns[-1:]) + list(table_all.columns[1:-1])].reset_index(drop=True)

IndexError: list index out of range

In [20]:
roc_auc_mean = table_all[table_all.columns[3::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
roc_auc_std = table_all[table_all.columns[3::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

f1_mean = table_all[table_all.columns[5::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
f1_std = table_all[table_all.columns[5::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

pr_auc_mean = table_all[table_all.columns[4::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
pr_auc_std = table_all[table_all.columns[4::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

pr_mean = table_all[table_all.columns[6::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
pr_std = table_all[table_all.columns[6::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

rec_auc_mean = table_all[table_all.columns[7::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
rec_auc_std = table_all[table_all.columns[7::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

In [27]:
table_all[('Mean',  'pr-auc')] = pr_auc_mean.round(2).astype(str)+'±'+pr_auc_std.round(2).astype(str)
table_all[('Mean',  'f1-score')] = f1_mean.round(2).astype(str)+'±'+f1_std.round(2).astype(str)
table_all[('Mean',  'precision')] = pr_mean.round(2).astype(str)+'±'+pr_std.round(2).astype(str)
table_all[('Mean',  'recall')] = rec_auc_mean.round(2).astype(str)+'±'+rec_auc_std.round(2).astype(str)
table_all[('Mean',  'roc-auc')] = roc_auc_mean.round(2).astype(str)+'±'+roc_auc_std.round(2).astype(str)

In [28]:
def bold_max(table):
    attr = 'font-weight: bold'
    data = table[table.columns[3:]].apply(lambda x: x.str.split('±').str[0].astype(float))
    is_max = data == data.max()
    info_col = table[table.columns[:3]]
    for col in info_col.columns:
        info_col[col] = ''
    vals = pd.DataFrame(np.where(is_max, attr, ''),
                        index=data.index, columns=data.columns)
    return pd.concat([info_col, vals], axis=1)

def highlight_nmax(s):
    try:
        s_vals = s.str.split('±').str[0].astype(float)
        is_large = s_vals.nlargest(3).values
        return ['background-color: yellow' if v in is_large else '' for v in s_vals]
    except:
        return ['']*len(s)

In [29]:
table_all[list(table_all.columns[:3]) + list(table_all.columns[-5:])].style.apply(highlight_nmax).apply(bold_max, axis=None)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  import sys


Unnamed: 0_level_0,Method,Reg. Type,UE Score,Mean,Mean,Mean,Mean,Mean
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,pr-auc,f1-score,precision,recall,roc-auc
0,L-NUQ SN,-,aleatoric,93.7±1.84,92.16±1.46,89.46±2.6,95.41±1.41,92.61±1.93
1,L-NUQ SN,-,epistemic,87.56±2.3,88.29±1.09,83.51±1.86,95.06±1.3,82.1±3.26
2,L-NUQ SN,-,total,92.03±1.89,91.06±1.27,87.43±2.17,95.64±1.27,90.37±1.99
3,L-NUQ,-,aleatoric,92.06±2.03,90.3±1.17,87.11±2.31,94.46±1.61,90.21±2.47
4,L-NUQ,-,epistemic,87.13±2.13,87.9±0.81,83.4±1.8,94.14±1.44,81.97±2.64
5,L-NUQ,-,total,90.63±1.84,89.71±0.77,85.83±1.83,94.9±1.39,88.53±2.29
6,NUQ SN,-,aleatoric,88.67±2.31,90.3±1.23,85.74±2.09,96.34±1.13,84.01±3.03
7,NUQ SN,-,epistemic,87.03±3.09,89.41±2.31,85.63±2.54,94.26±3.47,78.46±6.79
8,NUQ SN,-,total,87.9±3.24,89.9±2.09,85.9±2.53,95.07±2.39,80.13±6.09
9,NUQ,-,aleatoric,87.77±2.24,89.41±1.09,84.97±1.69,95.2±0.89,81.69±3.03


In [30]:
table_all.style.apply(highlight_nmax).apply(bold_max, axis=None)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  import sys


Unnamed: 0_level_0,Method,Reg. Type,UE Score,IMDB,IMDB,IMDB,IMDB,IMDB,TREC,TREC,TREC,TREC,TREC,WMT16,WMT16,WMT16,WMT16,WMT16,Amazon,Amazon,Amazon,Amazon,Amazon,MNLI,MNLI,MNLI,MNLI,MNLI,20newsgroups,20newsgroups,20newsgroups,20newsgroups,20newsgroups,RTE,RTE,RTE,RTE,RTE,Mean,Mean,Mean,Mean,Mean
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,roc-auc,pr-auc,f1-score,precision,recall,roc-auc,pr-auc,f1-score,precision,recall,roc-auc,pr-auc,f1-score,precision,recall,roc-auc,pr-auc,f1-score,precision,recall,roc-auc,pr-auc,f1-score,precision,recall,roc-auc,pr-auc,f1-score,precision,recall,roc-auc,pr-auc,f1-score,precision,recall,pr-auc,f1-score,precision,recall,roc-auc
0,L-NUQ SN,-,aleatoric,96.9±1.8,99.8±0.1,99.4±0.1,99.1±0.2,99.6±0.1,89.7±4.3,77.6±6.4,80.4±5.4,71.4±7.0,92.4±4.1,88.1±2.1,95.6±1.0,91.1±1.0,88.0±2.5,94.4±1.2,93.5±0.7,99.7±0.0,99.0±0.1,98.3±0.1,99.6±0.1,85.6±3.0,98.2±0.5,96.3±0.2,93.6±0.7,99.3±0.4,99.8±0.1,100.0±0.0,99.8±0.1,99.6±0.2,99.9±0.1,94.7±1.5,85.0±4.9,79.1±3.3,76.2±7.5,82.7±3.9,93.7±1.84,92.16±1.46,89.46±2.6,95.41±1.41,92.61±1.93
1,L-NUQ SN,-,epistemic,78.1±2.2,98.7±0.2,98.6±0.0,97.4±0.1,99.8±0.1,77.2±4.3,59.0±6.6,68.3±2.6,53.6±3.3,94.3±2.4,80.3±3.7,91.9±1.7,89.3±0.7,83.0±2.1,96.7±2.1,76.4±4.5,98.8±0.3,98.6±0.0,97.4±0.1,99.8±0.1,79.2±3.5,97.2±0.5,96.1±0.2,92.9±0.6,99.5±0.3,95.3±1.4,99.2±0.3,98.2±0.3,96.9±0.5,99.4±0.3,88.2±3.2,68.1±6.5,68.9±3.8,63.4±6.3,75.9±3.8,87.56±2.3,88.29±1.09,83.51±1.86,95.06±1.3,82.1±3.26
2,L-NUQ SN,-,total,93.8±2.0,99.6±0.1,99.1±0.1,98.7±0.1,99.5±0.1,86.7±3.7,72.0±5.8,77.0±4.3,66.0±5.9,92.8±2.9,86.2±2.3,94.7±1.1,90.6±1.0,86.9±2.1,94.7±1.3,89.4±0.7,99.5±0.0,98.8±0.1,97.8±0.2,99.7±0.1,84.0±2.9,97.9±0.5,96.3±0.3,93.3±0.8,99.5±0.3,99.3±0.5,99.9±0.1,99.6±0.2,99.2±0.3,99.9±0.1,93.2±1.8,80.6±5.6,76.0±2.9,70.1±5.8,83.4±4.1,92.03±1.89,91.06±1.27,87.43±2.17,95.64±1.27,90.37±1.99
3,L-NUQ,-,aleatoric,96.3±3.0,99.8±0.2,99.4±0.1,99.1±0.2,99.7±0.1,85.5±5.0,72.4±7.5,74.0±4.9,63.5±8.5,90.1±5.5,84.9±2.0,94.5±1.0,89.6±0.9,85.4±2.2,94.3±2.0,92.8±2.5,99.7±0.1,98.9±0.1,98.2±0.2,99.5±0.1,80.0±3.3,97.5±0.5,96.0±0.1,92.6±0.2,99.7±0.1,99.8±0.1,100.0±0.0,99.7±0.2,99.4±0.5,99.9±0.0,92.2±1.4,80.5±4.9,74.5±1.9,71.6±4.4,78.0±3.5,92.06±2.03,90.3±1.17,87.11±2.31,94.46±1.61,90.21±2.47
4,L-NUQ,-,epistemic,78.5±3.0,98.7±0.2,98.7±0.1,97.6±0.2,99.8±0.1,77.7±6.6,58.7±10.2,68.8±2.9,55.4±4.6,91.3±2.9,79.6±1.7,91.7±0.7,89.0±0.3,82.4±1.2,96.9±1.4,78.5±2.9,98.9±0.2,98.6±0.0,97.5±0.1,99.8±0.1,77.2±2.4,96.9±0.3,96.0±0.0,92.5±0.2,99.7±0.1,95.2±0.7,99.1±0.1,98.2±0.3,97.0±0.6,99.4±0.2,87.1±1.2,65.9±3.2,66.0±2.1,61.4±5.7,72.1±5.3,87.13±2.13,87.9±0.81,83.4±1.8,94.14±1.44,81.97±2.64
5,L-NUQ,-,total,92.9±3.7,99.6±0.2,99.2±0.1,98.8±0.3,99.6±0.1,83.6±3.8,67.5±6.9,72.9±2.8,60.7±3.9,91.5±2.3,83.8±1.6,93.8±0.8,89.8±0.7,84.2±1.8,96.3±0.9,89.1±2.9,99.5±0.2,98.7±0.0,97.8±0.1,99.7±0.1,79.9±2.4,97.4±0.4,96.0±0.1,92.6±0.3,99.7±0.2,99.2±0.4,99.9±0.1,99.5±0.3,99.0±0.5,99.9±0.0,91.2±1.2,76.7±4.3,71.9±1.4,67.7±5.9,77.6±6.1,90.63±1.84,89.71±0.77,85.83±1.83,94.9±1.39,88.53±2.29
6,NUQ SN,-,aleatoric,69.0±2.3,98.2±0.2,98.3±0.0,96.7±0.1,100.0±0.0,82.5±4.5,62.5±6.7,73.9±3.8,61.9±5.5,92.2±2.8,85.9±2.5,93.5±1.4,91.6±1.1,86.9±2.5,96.8±1.3,78.1±5.8,98.9±0.3,98.7±0.0,97.6±0.2,99.8±0.1,86.1±1.8,97.9±0.3,96.9±0.3,94.7±0.6,99.1±0.1,94.6±2.3,98.8±0.6,98.1±0.5,97.1±0.8,99.2±0.3,91.9±2.0,70.9±6.7,74.6±2.9,65.3±4.9,87.3±3.3,88.67±2.31,90.3±1.23,85.74±2.09,96.34±1.13,84.01±3.03
7,NUQ SN,-,epistemic,66.8±4.9,98.0±0.4,98.4±0.0,96.9±0.0,99.9±0.0,84.6±3.5,65.1±5.4,76.8±3.7,66.0±5.5,92.4±5.3,78.0±8.1,90.8±3.0,90.6±1.6,84.7±3.6,97.6±1.3,64.1±13.6,98.0±0.8,98.7±0.0,97.5±0.1,99.9±0.0,79.2±7.1,97.1±0.9,96.8±0.4,94.5±1.0,99.3±0.3,93.4±1.7,98.7±0.2,97.6±0.8,95.9±1.6,99.3±0.9,83.1±8.6,61.5±10.9,67.0±9.7,63.9±6.0,71.4±16.5,87.03±3.09,89.41±2.31,85.63±2.54,94.26±3.47,78.46±6.79
8,NUQ SN,-,total,67.5±3.7,98.1±0.3,98.4±0.0,96.9±0.1,99.9±0.0,84.2±4.5,64.4±6.4,76.5±4.3,64.9±6.3,93.6±2.1,80.4±7.0,91.9±2.8,90.7±1.6,84.8±3.8,97.5±1.2,66.5±12.7,98.3±0.7,98.7±0.1,97.6±0.2,99.9±0.1,81.2±5.7,97.4±0.8,96.8±0.4,94.4±0.9,99.4±0.2,94.5±2.0,98.8±0.5,98.0±0.4,97.3±0.6,98.7±0.8,86.6±7.0,66.4±11.2,70.2±7.8,65.4±5.8,76.5±12.3,87.9±3.24,89.9±2.09,85.9±2.53,95.07±2.39,80.13±6.09
9,NUQ,-,aleatoric,66.0±3.8,98.0±0.2,98.3±0.0,96.7±0.0,100.0±0.0,82.8±6.3,63.8±10.4,73.9±4.3,62.6±6.3,90.9±2.7,83.6±1.5,92.7±0.9,90.8±0.8,86.4±1.9,95.9±0.8,73.8±5.2,98.7±0.3,98.6±0.0,97.4±0.1,99.9±0.1,83.0±1.6,97.5±0.2,96.5±0.3,93.9±0.8,99.2±0.3,93.3±1.9,98.7±0.3,97.7±0.6,96.3±1.0,99.0±0.4,89.3±0.9,65.0±3.4,70.1±1.6,61.5±1.7,81.5±1.9,87.77±2.24,89.41±1.09,84.97±1.69,95.2±0.89,81.69±3.03


## 20newsgroups

In [31]:
import os 
def choose_agg_func(method):
    agg_methods = {
        "bald": bald,
        "sampled_max_prob": sampled_max_prob,
        "variance": probability_variance,
    }
    if method=='nuq' or method=='nuq_best' or method=='nuq_best1':
        nuq_aleatoric = lambda x: np.squeeze(x[0], axis=-1)
        nuq_epistemic = lambda x: np.squeeze(x[1], axis=-1)
        nuq_total = lambda x: np.squeeze(x[2], axis=-1)
        agg_methods = {
            "nuq_aleatoric": nuq_aleatoric,
            "nuq_epistemic": nuq_epistemic,
            "nuq_total": nuq_total,
        }
    elif method=='l_nuq':
        nuq_aleatoric = lambda x: np.squeeze(x[0], axis=-1)
        nuq_epistemic = lambda x: np.squeeze(x[1], axis=-1)
        nuq_total = lambda x: np.squeeze(x[2], axis=-1)
        agg_methods = {
            "l_nuq_aleatoric": nuq_aleatoric,
            "l_nuq_epistemic": nuq_epistemic,
            "l_nuq_total": nuq_total,
        }
    elif method=='decomposing_md':
        disc_md = lambda x: np.squeeze(x[0], axis=-1)
        nondisc_md = lambda x: np.squeeze(x[1], axis=-1)
        sum_md = lambda x: np.squeeze(x[2], axis=-1)
        agg_methods = {"disc_md": disc_md, 
                       'nondisc_md': nondisc_md,
                       'disc+nondisc_md': sum_md}
    elif method=='mahalanobis':
        maha_dist = lambda x: np.squeeze(x[0], axis=-1)
        rel_maha_dist = lambda x: np.squeeze(x[1], axis=-1)
        marg_maha_dist = lambda x: np.squeeze(x[2], axis=-1)
        agg_methods = {"mahalanobis_distance": maha_dist,
                       "relative_mahalanobis_distance": rel_maha_dist,
                       "marginal_mahalanobis_distance": marg_maha_dist}
    elif method=='l_mahalanobis':
        maha_dist = lambda x: np.squeeze(x[:, 0], axis=-1)
        agg_methods = {"l_mahalanobis_distance": maha_dist}
    elif method=='ddu' or method=='ddu_maha':
        ddu = lambda x: -np.squeeze(x[:, 0], axis=-1)
        agg_methods = {"ddu": ddu}
    elif method=='sngp':
        sngp = lambda x: np.squeeze(x[:, 0], axis=-1)
        agg_methods = {"sngp": sngp}
    elif method=='mc_mahalanobis':
        sm_maha_dist = lambda x: np.squeeze(x[:, 1:], axis=-1).max(1)
        agg_methods = {"sampled_mahalanobis_distance": sm_maha_dist}
    return agg_methods

metric_types=["roc-auc", 'pr-auc', 'f1-score', 'precision', 'recall']

methods = ['l_nuq', 'nuq', 'mahalanobis', 'l_mahalanobis', 'decomposing_md', 'ddu']
regs = ['raw']
spectralnorm = ['sn', 'no_sn']
dataset_names = ['IMDB', 'TREC', 'WMT16', 'Amazon', 'MNLI', 'RTE', 'SST-2']
dataset_fnames = ['imdb', 'trec', 'wmt16', 'amazon', 'mnli', 'rte', 'sst2']
names = []
tables = []
baselines = []
for method in methods:
    for reg in regs:
        for sn in spectralnorm:
            run_dirs = []
            name_sn = ''
            names = [f'{method}|{reg}|{sn}']
            for name in dataset_fnames:
                model_series_dir = f'../../workdir/run_tasks_for_model_series/electra_{reg}_{sn}/20newsgroups/{name}/{method}'
                print(model_series_dir)
                run_dirs.append([model_series_dir])
            agg_func = choose_agg_func(method)
            try:
                res_df = collect_datasets(run_dirs, names, dataset_names, metric_types=metric_types, baselines={}, methods=agg_func)
                baselines.append(res_df.iloc[-1:])
                tables.append(res_df.iloc[:-1])
            except:
                print('pass')
                pass

../../workdir/run_tasks_for_model_series/electra_raw_sn/20newsgroups/imdb/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn/20newsgroups/trec/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn/20newsgroups/wmt16/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn/20newsgroups/amazon/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn/20newsgroups/mnli/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn/20newsgroups/rte/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn/20newsgroups/sst2/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_no_sn/20newsgroups/imdb/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_no_sn/20newsgroups/trec/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_no_sn/20newsgroups/wmt16/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_no_sn/20newsgroups/amazon/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_no_sn/20newsgroups/mnli/l_nuq
../../workdir/run_ta

In [32]:
table_all = pd.concat([pd.concat(tables), pd.concat(baselines[-2:])]).reset_index()

In [33]:
def preproc_regs(x):
    reg = x.split('|')[1]
    if reg == 'reg':
        return 'CER'
    elif reg == 'raw':
        return '-'
    else:
        return reg
    
def preproc_method(x):
    method = x.split('|')[0]
    sn = x.split('|')[-1]
    if method == 'mahalanobis' and not 'no_sn' in sn:
        return 'MD SN (ours)'
    elif method == 'mahalanobis':
        return 'MD'
    elif method == 'l_mahalanobis' and not 'no_sn' in sn:
        return 'L-MD SN (ours)'
    elif method == 'l_mahalanobis':
        return 'L-MD'
    elif method == 'mc_mahalanobis' and not 'no_sn' in sn:
        return 'SMD SN (ours)'
    elif method == 'mc_mahalanobis':
        return 'SMD'
    elif method == 'nuq' and not 'no_sn' in sn:
        return 'NUQ SN'
    elif method == 'nuq':
        return 'NUQ'
    elif method == 'l_nuq' and not 'no_sn' in sn:
        return 'L-NUQ SN'
    elif method == 'l_nuq':
        return 'L-NUQ'
    elif method == 'sngp':
        return 'SNGP'
    
    elif method == 'decomposing_md' and not 'no_sn' in sn:
        return 'Decomposing SN'
    elif method == 'decomposing_md':
        return 'Decomposing'
    
    elif method == 'nuq_best1' and not 'no_sn' in sn:
        return 'Best1 NUQ SN'
    elif method == 'nuq_best1':
        return 'Best1 NUQ'
    
    elif method == 'ddu' and not 'no_sn' in sn:
        return 'DDU SN'
    elif method == 'ddu':
        return 'DDU'
    
    elif method == 'ddu' and not 'no_sn' in sn:
        return 'DDU SN'
    elif method == 'ddu':
        return 'DDU'
    
    elif method == 'ddu_maha' and not 'no_sn' in sn:
        return 'DDU Maha SN'
    elif method == 'ddu_maha':
        return 'DDU Maha'
    
    elif 'ddpp_dpp' in method:
        return 'DDPP (+DPP) (ours)'
    elif 'ddpp_ood' in method:
        return 'DDPP (+OOD) (ours)'
    elif 'mc_all' in method:
        return 'MC dropout'
    elif 'Deep' in method:
        return 'Deep Ensemble'
    elif 'baseline|raw_no_sn' in x:
        return 'SR (baseline)'
    elif 'baseline' in x and not 'no_sn' in x:
        return 'SR SN'
    return 'SR'

def preproc_ue(x):
    if x == 'bald':
        return 'BALD'
    elif 'sampled_mahalanobis_distance' in x:
        return 'SMD'
    elif 'relative_mahalanobis_distance' in x:
        return 'rel_MD'
    elif 'marginal_mahalanobis_distance' in x:
        return 'marg_MD'
    elif 'mahalanobis_distance' in x:
        return 'MD'
    elif 'sampled_max_prob' in x:
        return 'SMP'
    elif 'variance' in x:
        return 'PV'
    elif 'aleatoric' in x:
        return 'aleatoric'
    elif 'epistemic' in x:
        return 'epistemic'
    elif 'total' in x:
        return 'total'
    elif x == 'disc_md':
        return 'Disc MD'
    elif x == 'nondisc_md':
        return 'Nondisc MD'
    elif x == 'disc+nondisc_md':
        return 'Disc+Nondisc MD'
    elif x == 'sngp':
        return 'std'
    return 'MP'

table_all['Reg. Type'] = table_all.Method.apply(lambda x: preproc_regs(x))
table_all['Method'] = table_all.Method.apply(lambda x: preproc_method(x))
table_all['UE Score'] = table_all['UE Score'].apply(lambda x: preproc_ue(x))
table_all = table_all[list(table_all.columns[:1]) + list(table_all.columns[-1:]) + list(table_all.columns[1:-1])].reset_index(drop=True)

In [34]:
roc_auc_mean = table_all[table_all.columns[3::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
roc_auc_std = table_all[table_all.columns[3::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

f1_mean = table_all[table_all.columns[5::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
f1_std = table_all[table_all.columns[5::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

pr_auc_mean = table_all[table_all.columns[4::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
pr_auc_std = table_all[table_all.columns[4::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

pr_mean = table_all[table_all.columns[6::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
pr_std = table_all[table_all.columns[6::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

rec_auc_mean = table_all[table_all.columns[7::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
rec_auc_std = table_all[table_all.columns[7::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

table_all[('Mean',  'pr-auc')] = pr_auc_mean.round(2).astype(str)+'±'+pr_auc_std.round(2).astype(str)
table_all[('Mean',  'f1-score')] = f1_mean.round(2).astype(str)+'±'+f1_std.round(2).astype(str)
table_all[('Mean',  'precision')] = pr_mean.round(2).astype(str)+'±'+pr_std.round(2).astype(str)
table_all[('Mean',  'recall')] = rec_auc_mean.round(2).astype(str)+'±'+rec_auc_std.round(2).astype(str)
table_all[('Mean',  'roc-auc')] = roc_auc_mean.round(2).astype(str)+'±'+roc_auc_std.round(2).astype(str)

In [35]:
def bold_max(table):
    attr = 'font-weight: bold'
    data = table[table.columns[3:]].apply(lambda x: x.str.split('±').str[0].astype(float))
    is_max = data == data.max()
    info_col = table[table.columns[:3]]
    for col in info_col.columns:
        info_col[col] = ''
    vals = pd.DataFrame(np.where(is_max, attr, ''),
                        index=data.index, columns=data.columns)
    return pd.concat([info_col, vals], axis=1)

def highlight_nmax(s):
    try:
        s_vals = s.str.split('±').str[0].astype(float)
        is_large = s_vals.nlargest(3).values
        style_str = ['background-color: yellow' if v in is_large else '' for v in s_vals]
        if len(np.unique(style_str))==1:
            return ['']*len(style_str)
        return style_str
    except:
        return ['']*len(s)

In [36]:
table_all[list(table_all.columns[:3]) + list(table_all.columns[-5:])].style.apply(highlight_nmax).apply(bold_max, axis=None)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  import sys


Unnamed: 0_level_0,Method,Reg. Type,UE Score,Mean,Mean,Mean,Mean,Mean
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,pr-auc,f1-score,precision,recall,roc-auc
0,L-NUQ SN,-,aleatoric,59.56±4.17,70.27±2.97,62.07±3.96,87.69±3.93,89.97±1.79
1,L-NUQ SN,-,epistemic,65.37±5.07,72.17±2.67,65.04±3.7,84.23±4.74,91.84±1.74
2,L-NUQ SN,-,total,62.93±5.41,71.54±3.37,64.16±5.6,86.2±4.59,91.1±1.8
3,L-NUQ,-,aleatoric,59.59±3.86,68.27±2.87,60.21±3.53,83.64±4.19,89.2±2.31
4,L-NUQ,-,epistemic,64.5±5.69,69.64±2.41,62.9±3.53,80.63±3.94,90.86±1.8
5,L-NUQ,-,total,62.46±4.46,69.31±2.59,62.01±3.56,82.27±4.86,90.39±1.97
6,NUQ SN,-,aleatoric,60.03±2.09,68.74±1.77,61.11±2.39,83.73±3.86,89.23±1.23
7,NUQ SN,-,epistemic,65.63±2.91,70.47±1.94,64.84±3.59,78.69±4.24,91.04±1.26
8,NUQ SN,-,total,61.17±2.1,69.13±1.79,61.94±2.6,82.3±3.97,89.71±1.23
9,NUQ,-,aleatoric,61.23±1.86,67.09±1.99,60.19±2.17,78.69±2.9,88.73±1.19


In [37]:
table_all.style.apply(highlight_nmax).apply(bold_max, axis=None)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  import sys


Unnamed: 0_level_0,Method,Reg. Type,UE Score,IMDB,IMDB,IMDB,IMDB,IMDB,TREC,TREC,TREC,TREC,TREC,WMT16,WMT16,WMT16,WMT16,WMT16,Amazon,Amazon,Amazon,Amazon,Amazon,MNLI,MNLI,MNLI,MNLI,MNLI,RTE,RTE,RTE,RTE,RTE,SST-2,SST-2,SST-2,SST-2,SST-2,Mean,Mean,Mean,Mean,Mean
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,roc-auc,pr-auc,f1-score,precision,recall,roc-auc,pr-auc,f1-score,precision,recall,roc-auc,pr-auc,f1-score,precision,recall,roc-auc,pr-auc,f1-score,precision,recall,roc-auc,pr-auc,f1-score,precision,recall,roc-auc,pr-auc,f1-score,precision,recall,roc-auc,pr-auc,f1-score,precision,recall,pr-auc,f1-score,precision,recall,roc-auc
0,L-NUQ SN,-,aleatoric,90.8±2.0,93.2±1.8,95.8±0.6,93.7±1.3,98.1±0.3,92.0±2.1,30.5±8.8,48.4±8.6,36.7±9.8,74.0±8.9,89.9±1.8,65.4±6.0,75.9±1.9,66.0±5.5,90.2±5.5,88.5±1.3,94.0±1.1,94.5±0.3,91.3±0.7,98.0±0.3,89.4±1.8,83.3±2.8,89.8±1.7,83.7±1.9,96.9±2.1,88.5±1.5,14.5±3.1,27.3±4.2,17.4±4.0,67.5±6.2,90.7±2.0,36.0±5.6,60.2±3.5,45.7±4.5,89.1±4.2,59.56±4.17,70.27±2.97,62.07±3.96,87.69±3.93,89.97±1.79
1,L-NUQ SN,-,epistemic,93.2±2.3,95.4±2.1,96.1±0.2,94.3±0.5,98.0±0.2,93.2±1.4,36.9±7.4,50.3±5.4,40.7±6.9,67.6±7.7,91.8±1.5,74.8±6.4,76.5±1.0,66.3±2.5,90.6±3.0,89.4±1.6,94.8±1.3,94.2±0.3,90.5±0.5,98.2±0.4,92.2±1.6,88.5±3.3,91.2±0.5,85.5±0.8,97.7±1.0,89.7±1.7,20.7±7.0,31.3±6.5,24.1±7.1,52.0±15.1,93.4±2.1,46.5±8.0,65.6±4.8,53.9±7.6,85.5±5.8,65.37±5.07,72.17±2.67,65.04±3.7,84.23±4.74,91.84±1.74
2,L-NUQ SN,-,total,92.1±2.0,94.5±1.8,96.0±0.2,94.1±0.6,98.1±0.3,92.9±1.8,35.9±11.7,50.6±9.4,40.9±14.1,71.1±6.9,91.1±1.8,70.6±6.8,76.4±1.8,67.7±7.6,89.3±6.9,89.2±1.4,94.6±1.0,94.4±0.3,91.2±0.5,97.9±0.3,90.9±1.7,86.0±3.1,90.6±1.0,84.5±1.4,97.6±1.1,89.3±1.8,17.6±5.7,29.6±6.6,20.7±8.2,61.9±11.6,92.2±2.1,41.3±7.8,63.2±4.3,50.0±6.8,87.5±5.0,62.93±5.41,71.54±3.37,64.16±5.6,86.2±4.59,91.1±1.8
3,L-NUQ,-,aleatoric,90.7±1.9,93.1±1.7,95.0±0.8,92.2±1.4,98.0±0.4,90.5±2.0,25.7±3.9,42.6±4.2,32.0±4.8,66.5±14.3,89.6±1.5,66.4±4.0,73.7±2.9,62.4±4.4,90.3±2.6,85.6±3.0,93.5±1.9,92.4±1.4,88.8±1.5,96.2±1.6,89.4±3.3,85.0±5.7,87.9±2.0,81.1±3.2,95.9±0.9,87.4±2.1,14.4±2.7,27.1±3.8,17.9±3.1,57.8±5.5,91.2±2.4,39.0±7.1,59.2±5.0,47.1±6.3,80.8±4.0,59.59±3.86,68.27±2.87,60.21±3.53,83.64±4.19,89.2±2.31
4,L-NUQ,-,epistemic,93.2±2.1,95.5±2.3,95.7±0.3,94.0±0.7,97.5±0.6,92.0±2.0,33.9±8.7,44.1±4.0,35.7±5.8,59.5±8.3,90.4±1.1,71.6±5.7,74.3±1.4,63.4±1.8,89.8±2.4,87.3±1.3,94.5±0.9,92.8±1.0,89.3±0.9,96.5±1.5,92.0±2.3,89.5±5.0,90.5±1.2,85.0±1.8,96.8±0.6,88.1±1.1,17.1±3.4,28.9±2.9,22.0±4.3,45.7±9.3,93.0±2.7,49.4±13.8,61.2±6.1,50.9±9.4,78.6±4.9,64.5±5.69,69.64±2.41,62.9±3.53,80.63±3.94,90.86±1.8
5,L-NUQ,-,total,92.2±2.0,94.4±2.0,95.5±0.4,93.4±0.8,97.7±0.6,91.6±1.8,30.0±5.9,43.7±4.4,34.0±5.2,64.2±15.3,90.4±1.1,70.3±4.3,74.4±1.9,63.7±3.4,89.8±3.7,86.8±2.1,94.2±1.3,92.7±1.3,89.0±1.5,96.6±1.4,91.2±2.9,87.7±5.3,89.8±1.6,83.7±2.6,97.0±1.1,88.1±1.3,16.1±2.9,28.5±3.4,19.7±3.4,53.3±6.9,92.4±2.6,44.5±9.5,60.6±5.1,50.6±8.0,77.3±5.0,62.46±4.46,69.31±2.59,62.01±3.56,82.27±4.86,90.39±1.97
6,NUQ SN,-,aleatoric,92.0±1.3,95.1±1.0,95.8±0.2,94.0±0.4,97.7±0.3,89.2±1.1,24.8±2.1,42.9±4.0,31.5±4.0,67.7±5.1,87.6±0.8,62.8±2.1,71.0±1.6,61.5±3.5,84.4±3.7,87.7±1.2,94.3±0.7,94.0±0.5,90.5±0.8,97.8±0.7,90.6±1.1,87.2±1.5,89.5±0.4,83.9±1.0,95.9±0.9,85.9±1.5,12.9±1.3,25.4±3.0,16.4±2.3,58.1±10.2,91.6±1.6,43.1±5.9,62.6±2.7,50.0±4.7,84.5±6.1,60.03±2.09,68.74±1.77,61.11±2.39,83.73±3.86,89.23±1.23
7,NUQ SN,-,epistemic,94.6±1.1,97.0±0.8,95.7±0.2,93.9±0.5,97.6±0.4,91.1±1.3,34.3±4.8,47.9±4.2,43.0±7.8,55.9±6.5,89.3±1.0,70.5±2.9,70.7±1.7,61.6±5.0,83.8±6.4,88.7±1.5,95.0±0.5,93.9±0.6,90.5±0.8,97.6±0.8,92.4±1.0,90.7±1.7,89.3±0.5,83.6±1.2,95.8±0.9,87.2±1.7,17.8±2.9,29.3±3.8,23.6±4.6,40.6±8.9,94.0±1.2,54.1±6.8,66.5±2.6,57.7±5.2,79.5±5.8,65.63±2.91,70.47±1.94,64.84±3.59,78.69±4.24,91.04±1.26
8,NUQ SN,-,total,92.7±1.2,95.6±0.9,95.8±0.2,94.2±0.4,97.4±0.4,89.7±1.1,26.7±2.8,44.3±4.3,34.0±5.1,64.7±6.1,88.0±0.9,64.5±2.4,70.8±1.7,62.3±4.4,82.6±4.5,88.0±1.3,94.6±0.7,94.0±0.5,90.4±0.8,97.8±0.6,91.1±1.1,88.1±1.6,89.4±0.5,83.8±0.7,95.7±0.7,86.2±1.6,13.4±1.7,26.0±3.1,17.3±2.7,54.5±10.5,92.3±1.4,45.3±4.6,63.6±2.2,51.6±4.1,83.4±5.0,61.17±2.1,69.13±1.79,61.94±2.6,82.3±3.97,89.71±1.23
9,NUQ,-,aleatoric,92.9±1.2,96.2±0.9,95.5±0.5,93.5±0.7,97.6±0.5,88.0±1.2,25.1±1.7,39.1±2.1,30.3±2.1,55.5±3.0,86.6±0.9,64.3±1.5,67.9±2.0,59.5±1.9,79.1±3.1,86.5±1.7,94.4±0.5,93.0±0.9,88.4±1.0,98.0±0.8,90.0±1.0,87.9±1.3,87.8±0.7,82.0±1.5,94.5±1.1,84.9±1.0,14.5±1.5,25.6±2.5,17.4±1.6,48.9±7.5,92.2±1.3,46.2±5.6,60.7±5.2,50.2±6.4,77.2±4.3,61.23±1.86,67.09±1.99,60.19±2.17,78.69±2.9,88.73±1.19


## SN ALL

In [21]:
import os 

def choose_agg_func(method):
    agg_methods = {
        "bald": bald,
        "sampled_max_prob": sampled_max_prob,
        "variance": probability_variance,
    }
    if method=='nuq' or method=='nuq_best' or method=='nuq_best1':
        nuq_aleatoric = lambda x: np.squeeze(x[0], axis=-1)
        nuq_epistemic = lambda x: np.squeeze(x[1], axis=-1)
        nuq_total = lambda x: np.squeeze(x[2], axis=-1)
        agg_methods = {
            "nuq_aleatoric": nuq_aleatoric,
            "nuq_epistemic": nuq_epistemic,
            "nuq_total": nuq_total,
        }
    elif method=='l_nuq':
        nuq_aleatoric = lambda x: np.squeeze(x[0], axis=-1)
        nuq_epistemic = lambda x: np.squeeze(x[1], axis=-1)
        nuq_total = lambda x: np.squeeze(x[2], axis=-1)
        agg_methods = {
            "l_nuq_aleatoric": nuq_aleatoric,
            "l_nuq_epistemic": nuq_epistemic,
            "l_nuq_total": nuq_total,
        }
    elif method=='decomposing_md':
        disc_md = lambda x: np.squeeze(x[0], axis=-1)
        nondisc_md = lambda x: np.squeeze(x[1], axis=-1)
        sum_md = lambda x: np.squeeze(x[2], axis=-1)
        agg_methods = {"disc_md": disc_md, 
                       'nondisc_md': nondisc_md,
                       'disc+nondisc_md': sum_md}
    elif method=='mahalanobis':
        maha_dist = lambda x: np.squeeze(x[0], axis=-1)
        rel_maha_dist = lambda x: np.squeeze(x[1], axis=-1)
        marg_maha_dist = lambda x: np.squeeze(x[2], axis=-1)
        agg_methods = {"mahalanobis_distance": maha_dist,
                       "relative_mahalanobis_distance": rel_maha_dist,
                       "marginal_mahalanobis_distance": marg_maha_dist}
    elif method=='l_mahalanobis':
        maha_dist = lambda x: np.squeeze(x[:, 0], axis=-1)
        agg_methods = {"l_mahalanobis_distance": maha_dist}
    elif method=='ddu' or method=='ddu_maha':
        ddu = lambda x: -np.squeeze(x[:, 0], axis=-1)
        agg_methods = {"ddu": ddu}
    elif method=='sngp':
        sngp = lambda x: np.squeeze(x[:, 0], axis=-1)
        agg_methods = {"sngp": sngp}
    elif method=='mc_mahalanobis':
        sm_maha_dist = lambda x: np.squeeze(x[:, 1:], axis=-1).max(1)
        agg_methods = {"sampled_mahalanobis_distance": sm_maha_dist}
    return agg_methods
metric_types=["roc-auc", 'pr-auc', 'f1-score', 'precision', 'recall']

methods = ['l_nuq', 'nuq', 'mahalanobis', 'l_mahalanobis', 'decomposing_md', 'ddu']
regs = ['raw']
spectralnorm = ['sn_all']
dataset_names = ['IMDB', 'TREC', 'WMT16', 'Amazon', 'MNLI', '20newsgroups', 'RTE']
dataset_fnames = ['imdb', 'trec', 'wmt16', 'amazon', 'mnli', 'newsgroup', 'rte']
names = []
tables = []
baselines = []
for method in methods:
    for reg in regs:
        for sn in spectralnorm:
            run_dirs = []
            name_sn = ''
            names = [f'{method}|{reg}|{sn}']
            for name in dataset_fnames:
                model_series_dir = f'../../workdir/run_tasks_for_model_series/electra_{reg}_{sn}/sst2/{name}/{method}'
                print(model_series_dir)
                run_dirs.append([model_series_dir])
            agg_func = choose_agg_func(method)
            try:
                res_df = collect_datasets(run_dirs, names, dataset_names, metric_types=metric_types, baselines={}, methods=agg_func)
                baselines.append(res_df.iloc[-1:])
                tables.append(res_df.iloc[:-1])
            except:
                print('pass')
                pass

../../workdir/run_tasks_for_model_series/electra_raw_sn_all/sst2/imdb/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn_all/sst2/trec/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn_all/sst2/wmt16/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn_all/sst2/amazon/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn_all/sst2/mnli/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn_all/sst2/newsgroup/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn_all/sst2/rte/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn_all/sst2/imdb/nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn_all/sst2/trec/nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn_all/sst2/wmt16/nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn_all/sst2/amazon/nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn_all/sst2/mnli/nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn_all/sst2/newsgroup/nuq
../../wor

In [22]:
table_all = pd.concat([pd.concat(tables), pd.concat(baselines[-1:])]).reset_index()

In [23]:
def preproc_regs(x):
    reg = x.split('|')[1]
    if reg == 'reg':
        return 'CER'
    elif reg == 'raw':
        return '-'
    else:
        return reg
    
def preproc_method(x):
    method = x.split('|')[0]
    sn = x.split('|')[-1]
    if method == 'mahalanobis' and not 'no_sn' in sn:
        return 'MD SN_ALL (ours)'
    elif method == 'mahalanobis':
        return 'MD'
    elif method == 'l_mahalanobis' and not 'no_sn' in sn:
        return 'L-MD SN_ALL (ours)'
    elif method == 'l_mahalanobis':
        return 'L-MD'
    elif method == 'mc_mahalanobis' and not 'no_sn' in sn:
        return 'SMD SN_ALL (ours)'
    elif method == 'mc_mahalanobis':
        return 'SMD'
    elif method == 'nuq' and not 'no_sn' in sn:
        return 'NUQ SN_ALL'
    elif method == 'nuq':
        return 'NUQ'
    elif method == 'l_nuq' and not 'no_sn' in sn:
        return 'L-NUQ SN_ALL'
    elif method == 'l_nuq':
        return 'L-NUQ'
    elif method == 'sngp':
        return 'SNGP'
    
    elif method == 'decomposing_md' and not 'no_sn' in sn:
        return 'Decomposing SN_ALL'
    elif method == 'decomposing_md':
        return 'Decomposing'
    
    elif method == 'nuq_best1' and not 'no_sn' in sn:
        return 'Best1 NUQ SN_ALL'
    elif method == 'nuq_best1':
        return 'Best1 NUQ'
    
    elif method == 'ddu' and not 'no_sn' in sn:
        return 'DDU SN_ALL'
    elif method == 'ddu':
        return 'DDU'
    
    elif method == 'ddu_maha' and not 'no_sn' in sn:
        return 'DDU Maha SN_ALL'
    elif method == 'ddu_maha':
        return 'DDU Maha'
    
    elif 'ddpp_dpp' in method:
        return 'DDPP (+DPP) (ours)'
    elif 'ddpp_ood' in method:
        return 'DDPP (+OOD) (ours)'
    elif 'mc_all' in method:
        return 'MC dropout'
    elif 'Deep' in method:
        return 'Deep Ensemble'
    elif 'baseline|raw_no_sn' in x:
        return 'SR (baseline)'
    elif 'baseline' in x and not 'no_sn' in x:
        return 'SR SN'
    return 'SR'

def preproc_ue(x):
    if x == 'bald':
        return 'BALD'
    elif 'sampled_mahalanobis_distance' in x:
        return 'SMD'
    elif 'relative_mahalanobis_distance' in x:
        return 'rel_MD'
    elif 'marginal_mahalanobis_distance' in x:
        return 'marg_MD'
    elif 'mahalanobis_distance' in x:
        return 'MD'
    elif 'sampled_max_prob' in x:
        return 'SMP'
    elif 'variance' in x:
        return 'PV'
    elif 'aleatoric' in x:
        return 'aleatoric'
    elif 'epistemic' in x:
        return 'epistemic'
    elif 'total' in x:
        return 'total'
    elif x == 'disc_md':
        return 'Disc MD'
    elif x == 'nondisc_md':
        return 'Nondisc MD'
    elif x == 'disc+nondisc_md':
        return 'Disc+Nondisc MD'
    elif x == 'sngp':
        return 'std'
    return 'MP'

table_all['Reg. Type'] = table_all.Method.apply(lambda x: preproc_regs(x))
table_all['Method'] = table_all.Method.apply(lambda x: preproc_method(x))
table_all['UE Score'] = table_all['UE Score'].apply(lambda x: preproc_ue(x))
table_all = table_all[list(table_all.columns[:1]) + list(table_all.columns[-1:]) + list(table_all.columns[1:-1])].reset_index(drop=True)

In [24]:
roc_auc_mean = table_all[table_all.columns[3::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
roc_auc_std = table_all[table_all.columns[3::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

f1_mean = table_all[table_all.columns[5::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
f1_std = table_all[table_all.columns[5::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

pr_auc_mean = table_all[table_all.columns[4::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
pr_auc_std = table_all[table_all.columns[4::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

pr_mean = table_all[table_all.columns[6::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
pr_std = table_all[table_all.columns[6::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

rec_auc_mean = table_all[table_all.columns[7::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
rec_auc_std = table_all[table_all.columns[7::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

table_all[('Mean',  'pr-auc')] = pr_auc_mean.round(2).astype(str)+'±'+pr_auc_std.round(2).astype(str)
table_all[('Mean',  'f1-score')] = f1_mean.round(2).astype(str)+'±'+f1_std.round(2).astype(str)
table_all[('Mean',  'precision')] = pr_mean.round(2).astype(str)+'±'+pr_std.round(2).astype(str)
table_all[('Mean',  'recall')] = rec_auc_mean.round(2).astype(str)+'±'+rec_auc_std.round(2).astype(str)
table_all[('Mean',  'roc-auc')] = roc_auc_mean.round(2).astype(str)+'±'+roc_auc_std.round(2).astype(str)

In [25]:
def bold_max(table):
    attr = 'font-weight: bold'
    data = table[table.columns[3:]].apply(lambda x: x.str.split('±').str[0].astype(float))
    is_max = data == data.max()
    info_col = table[table.columns[:3]]
    for col in info_col.columns:
        info_col[col] = ''
    vals = pd.DataFrame(np.where(is_max, attr, ''),
                        index=data.index, columns=data.columns)
    return pd.concat([info_col, vals], axis=1)

def highlight_nmax(s):
    try:
        s_vals = s.str.split('±').str[0].astype(float)
        is_large = s_vals.nlargest(3).values
        style_str = ['background-color: yellow' if v in is_large else '' for v in s_vals]
        if len(np.unique(style_str))==1:
            return ['']*len(style_str)
        return style_str
    except:
        return ['']*len(s)

In [26]:
table_all[list(table_all.columns[:3]) + list(table_all.columns[-5:])].style.apply(highlight_nmax).apply(bold_max, axis=None)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  import sys


Unnamed: 0_level_0,Method,Reg. Type,UE Score,Mean,Mean,Mean,Mean,Mean
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,pr-auc,f1-score,precision,recall,roc-auc
0,L-NUQ SN_ALL,-,aleatoric,81.79±1.31,85.13±0.59,78.89±0.93,95.8±1.86,69.6±2.64
1,L-NUQ SN_ALL,-,epistemic,80.53±1.0,84.31±0.53,77.61±0.81,96.29±1.81,70.76±1.76
2,L-NUQ SN_ALL,-,total,81.47±1.2,84.91±0.44,78.47±0.74,95.73±1.66,70.8±1.9
3,NUQ SN_ALL,-,aleatoric,79.74±0.69,84.09±0.51,77.69±0.87,95.27±2.13,70.79±1.76
4,NUQ SN_ALL,-,epistemic,79.56±0.96,84.23±0.6,77.96±1.07,94.97±2.6,68.76±2.24
5,NUQ SN_ALL,-,total,79.67±0.71,84.2±0.53,77.89±1.01,95.06±2.49,69.7±2.09
6,MD SN_ALL (ours),-,MD,79.87±1.09,84.31±0.53,77.86±0.8,95.51±1.47,72.87±1.8
7,MD SN_ALL (ours),-,rel_MD,76.1±2.76,81.66±0.91,74.9±1.59,95.3±4.86,54.64±13.2
8,MD SN_ALL (ours),-,marg_MD,79.9±1.06,84.31±0.54,77.89±0.77,95.59±1.54,72.89±1.73
9,L-MD SN_ALL (ours),-,MD,81.09±0.91,86.13±0.39,79.73±0.53,96.9±0.76,76.51±1.33


#### baseline - 75.56±5.49
#### best - 92.64±1.23

In [27]:
import os 

def choose_agg_func(method):
    agg_methods = {
        "bald": bald,
        "sampled_max_prob": sampled_max_prob,
        "variance": probability_variance,
    }
    if method=='nuq' or method=='nuq_best' or method=='nuq_best1':
        nuq_aleatoric = lambda x: np.squeeze(x[0], axis=-1)
        nuq_epistemic = lambda x: np.squeeze(x[1], axis=-1)
        nuq_total = lambda x: np.squeeze(x[2], axis=-1)
        agg_methods = {
            "nuq_aleatoric": nuq_aleatoric,
            "nuq_epistemic": nuq_epistemic,
            "nuq_total": nuq_total,
        }
    elif method=='l_nuq':
        nuq_aleatoric = lambda x: np.squeeze(x[0], axis=-1)
        nuq_epistemic = lambda x: np.squeeze(x[1], axis=-1)
        nuq_total = lambda x: np.squeeze(x[2], axis=-1)
        agg_methods = {
            "l_nuq_aleatoric": nuq_aleatoric,
            "l_nuq_epistemic": nuq_epistemic,
            "l_nuq_total": nuq_total,
        }
    elif method=='decomposing_md':
        disc_md = lambda x: np.squeeze(x[0], axis=-1)
        nondisc_md = lambda x: np.squeeze(x[1], axis=-1)
        sum_md = lambda x: np.squeeze(x[2], axis=-1)
        agg_methods = {"disc_md": disc_md, 
                       'nondisc_md': nondisc_md,
                       'disc+nondisc_md': sum_md}
    elif method=='mahalanobis':
        maha_dist = lambda x: np.squeeze(x[0], axis=-1)
        rel_maha_dist = lambda x: np.squeeze(x[1], axis=-1)
        marg_maha_dist = lambda x: np.squeeze(x[2], axis=-1)
        agg_methods = {"mahalanobis_distance": maha_dist,
                       "relative_mahalanobis_distance": rel_maha_dist,
                       "marginal_mahalanobis_distance": marg_maha_dist}
    elif method=='l_mahalanobis':
        maha_dist = lambda x: np.squeeze(x[:, 0], axis=-1)
        agg_methods = {"l_mahalanobis_distance": maha_dist}
    elif method=='ddu' or method=='ddu_maha':
        ddu = lambda x: -np.squeeze(x[:, 0], axis=-1)
        agg_methods = {"ddu": ddu}
    elif method=='sngp':
        sngp = lambda x: np.squeeze(x[:, 0], axis=-1)
        agg_methods = {"sngp": sngp}
    elif method=='mc_mahalanobis':
        sm_maha_dist = lambda x: np.squeeze(x[:, 1:], axis=-1).max(1)
        agg_methods = {"sampled_mahalanobis_distance": sm_maha_dist}
    return agg_methods
metric_types=["roc-auc", 'pr-auc', 'f1-score', 'precision', 'recall']

methods = ['l_nuq', 'nuq', 'mahalanobis', 'l_mahalanobis', 'decomposing_md', 'ddu']
regs = ['raw']
spectralnorm = ['sn_all']
dataset_names = ['IMDB', 'TREC', 'WMT16', 'Amazon', 'MNLI', 'SST-2', 'RTE']
dataset_fnames = ['imdb', 'trec', 'wmt16', 'amazon', 'mnli', 'sst2', 'rte']
names = []
tables = []
baselines = []
for method in methods:
    for reg in regs:
        for sn in spectralnorm:
            run_dirs = []
            name_sn = ''
            names = [f'{method}|{reg}|{sn}']
            for name in dataset_fnames:
                model_series_dir = f'../../workdir/run_tasks_for_model_series/electra_{reg}_{sn}/20newsgroups/{name}/{method}'
                print(model_series_dir)
                run_dirs.append([model_series_dir])
            agg_func = choose_agg_func(method)
            try:
                res_df = collect_datasets(run_dirs, names, dataset_names, metric_types=metric_types, baselines={}, methods=agg_func)
                baselines.append(res_df.iloc[-1:])
                tables.append(res_df.iloc[:-1])
            except:
                print('pass')
                pass

../../workdir/run_tasks_for_model_series/electra_raw_sn_all/20newsgroups/imdb/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn_all/20newsgroups/trec/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn_all/20newsgroups/wmt16/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn_all/20newsgroups/amazon/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn_all/20newsgroups/mnli/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn_all/20newsgroups/sst2/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn_all/20newsgroups/rte/l_nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn_all/20newsgroups/imdb/nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn_all/20newsgroups/trec/nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn_all/20newsgroups/wmt16/nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn_all/20newsgroups/amazon/nuq
../../workdir/run_tasks_for_model_series/electra_raw_sn_all/20newsgroups/mnli/n

In [28]:
table_all = pd.concat([pd.concat(tables), pd.concat(baselines[-1:])]).reset_index()

In [29]:
def preproc_regs(x):
    reg = x.split('|')[1]
    if reg == 'reg':
        return 'CER'
    elif reg == 'raw':
        return '-'
    else:
        return reg
    
def preproc_method(x):
    method = x.split('|')[0]
    sn = x.split('|')[-1]
    if method == 'mahalanobis' and not 'no_sn' in sn:
        return 'MD SN_ALL (ours)'
    elif method == 'mahalanobis':
        return 'MD'
    elif method == 'l_mahalanobis' and not 'no_sn' in sn:
        return 'L-MD SN_ALL (ours)'
    elif method == 'l_mahalanobis':
        return 'L-MD'
    elif method == 'mc_mahalanobis' and not 'no_sn' in sn:
        return 'SMD SN_ALL (ours)'
    elif method == 'mc_mahalanobis':
        return 'SMD'
    elif method == 'nuq' and not 'no_sn' in sn:
        return 'NUQ SN_ALL'
    elif method == 'nuq':
        return 'NUQ'
    elif method == 'l_nuq' and not 'no_sn' in sn:
        return 'L-NUQ SN_ALL'
    elif method == 'l_nuq':
        return 'L-NUQ'
    elif method == 'sngp':
        return 'SNGP'
    
    elif method == 'decomposing_md' and not 'no_sn' in sn:
        return 'Decomposing SN_ALL'
    elif method == 'decomposing_md':
        return 'Decomposing'
    
    elif method == 'nuq_best1' and not 'no_sn' in sn:
        return 'Best1 NUQ SN_ALL'
    elif method == 'nuq_best1':
        return 'Best1 NUQ'
    
    elif method == 'ddu' and not 'no_sn' in sn:
        return 'DDU SN_ALL'
    elif method == 'ddu':
        return 'DDU'
    
    elif method == 'ddu_maha' and not 'no_sn' in sn:
        return 'DDU Maha SN_ALL'
    elif method == 'ddu_maha':
        return 'DDU Maha'
    
    elif 'ddpp_dpp' in method:
        return 'DDPP (+DPP) (ours)'
    elif 'ddpp_ood' in method:
        return 'DDPP (+OOD) (ours)'
    elif 'mc_all' in method:
        return 'MC dropout'
    elif 'Deep' in method:
        return 'Deep Ensemble'
    elif 'baseline|raw_no_sn' in x:
        return 'SR (baseline)'
    elif 'baseline' in x and not 'no_sn' in x:
        return 'SR SN'
    return 'SR'

def preproc_ue(x):
    if x == 'bald':
        return 'BALD'
    elif 'sampled_mahalanobis_distance' in x:
        return 'SMD'
    elif 'relative_mahalanobis_distance' in x:
        return 'rel_MD'
    elif 'marginal_mahalanobis_distance' in x:
        return 'marg_MD'
    elif 'mahalanobis_distance' in x:
        return 'MD'
    elif 'sampled_max_prob' in x:
        return 'SMP'
    elif 'variance' in x:
        return 'PV'
    elif 'aleatoric' in x:
        return 'aleatoric'
    elif 'epistemic' in x:
        return 'epistemic'
    elif 'total' in x:
        return 'total'
    elif x == 'disc_md':
        return 'Disc MD'
    elif x == 'nondisc_md':
        return 'Nondisc MD'
    elif x == 'disc+nondisc_md':
        return 'Disc+Nondisc MD'
    elif x == 'sngp':
        return 'std'
    return 'MP'

table_all['Reg. Type'] = table_all.Method.apply(lambda x: preproc_regs(x))
table_all['Method'] = table_all.Method.apply(lambda x: preproc_method(x))
table_all['UE Score'] = table_all['UE Score'].apply(lambda x: preproc_ue(x))
table_all = table_all[list(table_all.columns[:1]) + list(table_all.columns[-1:]) + list(table_all.columns[1:-1])].reset_index(drop=True)

In [30]:
roc_auc_mean = table_all[table_all.columns[3::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
roc_auc_std = table_all[table_all.columns[3::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

f1_mean = table_all[table_all.columns[5::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
f1_std = table_all[table_all.columns[5::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

pr_auc_mean = table_all[table_all.columns[4::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
pr_auc_std = table_all[table_all.columns[4::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

pr_mean = table_all[table_all.columns[6::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
pr_std = table_all[table_all.columns[6::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

rec_auc_mean = table_all[table_all.columns[7::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
rec_auc_std = table_all[table_all.columns[7::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

table_all[('Mean',  'pr-auc')] = pr_auc_mean.round(2).astype(str)+'±'+pr_auc_std.round(2).astype(str)
table_all[('Mean',  'f1-score')] = f1_mean.round(2).astype(str)+'±'+f1_std.round(2).astype(str)
table_all[('Mean',  'precision')] = pr_mean.round(2).astype(str)+'±'+pr_std.round(2).astype(str)
table_all[('Mean',  'recall')] = rec_auc_mean.round(2).astype(str)+'±'+rec_auc_std.round(2).astype(str)
table_all[('Mean',  'roc-auc')] = roc_auc_mean.round(2).astype(str)+'±'+roc_auc_std.round(2).astype(str)

In [31]:
def bold_max(table):
    attr = 'font-weight: bold'
    data = table[table.columns[3:]].apply(lambda x: x.str.split('±').str[0].astype(float))
    is_max = data == data.max()
    info_col = table[table.columns[:3]]
    for col in info_col.columns:
        info_col[col] = ''
    vals = pd.DataFrame(np.where(is_max, attr, ''),
                        index=data.index, columns=data.columns)
    return pd.concat([info_col, vals], axis=1)

def highlight_nmax(s):
    try:
        s_vals = s.str.split('±').str[0].astype(float)
        is_large = s_vals.nlargest(3).values
        style_str = ['background-color: yellow' if v in is_large else '' for v in s_vals]
        if len(np.unique(style_str))==1:
            return ['']*len(style_str)
        return style_str
    except:
        return ['']*len(s)

In [32]:
table_all[list(table_all.columns[:3]) + list(table_all.columns[-5:])].style.apply(highlight_nmax).apply(bold_max, axis=None)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  import sys


Unnamed: 0_level_0,Method,Reg. Type,UE Score,Mean,Mean,Mean,Mean,Mean
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,pr-auc,f1-score,precision,recall,roc-auc
0,L-NUQ SN_ALL,-,aleatoric,70.49±4.9,71.34±2.66,66.99±5.14,77.64±4.86,89.81±1.39
1,L-NUQ SN_ALL,-,epistemic,67.79±3.37,67.44±2.29,63.11±4.54,74.46±3.06,88.41±1.07
2,L-NUQ SN_ALL,-,total,71.49±4.17,71.11±2.79,68.21±4.5,75.61±3.26,89.57±1.19
3,NUQ SN_ALL,-,aleatoric,57.36±2.09,63.63±1.4,56.07±1.66,77.61±4.14,85.11±1.31
4,NUQ SN_ALL,-,epistemic,61.51±1.6,63.7±1.39,57.83±2.71,71.81±2.97,85.64±1.2
5,NUQ SN_ALL,-,total,58.31±1.91,63.66±1.43,56.74±2.09,75.21±4.04,85.23±1.2
6,MD SN_ALL (ours),-,MD,64.89±2.29,68.5±1.74,61.94±2.27,78.19±2.89,89.49±1.3
7,MD SN_ALL (ours),-,rel_MD,57.17±7.91,62.19±5.46,56.07±6.71,72.16±9.0,73.63±19.04
8,MD SN_ALL (ours),-,marg_MD,64.09±3.36,67.56±2.96,61.53±3.4,76.36±4.81,86.81±6.41
9,L-MD SN_ALL (ours),-,MD,68.5±1.83,71.14±1.26,64.03±1.96,81.83±2.76,91.26±0.74


#### baseline - 81.06±3.09
#### best - 95.91±0.87

In [51]:
import os 

def choose_agg_func(method):
    agg_methods = {
        "bald": bald,
        "sampled_max_prob": sampled_max_prob,
        "variance": probability_variance,
    }
    if method=='nuq' or method=='nuq_best' or method=='nuq_best1':
        nuq_aleatoric = lambda x: np.squeeze(x[0], axis=-1)
        nuq_epistemic = lambda x: np.squeeze(x[1], axis=-1)
        nuq_total = lambda x: np.squeeze(x[2], axis=-1)
        agg_methods = {
            "nuq_aleatoric": nuq_aleatoric,
            "nuq_epistemic": nuq_epistemic,
            "nuq_total": nuq_total,
        }
    elif method=='l_nuq':
        nuq_aleatoric = lambda x: np.squeeze(x[0], axis=-1)
        nuq_epistemic = lambda x: np.squeeze(x[1], axis=-1)
        nuq_total = lambda x: np.squeeze(x[2], axis=-1)
        agg_methods = {
            "l_nuq_aleatoric": nuq_aleatoric,
            "l_nuq_epistemic": nuq_epistemic,
            "l_nuq_total": nuq_total,
        }
    elif method=='decomposing_md':
        disc_md = lambda x: np.squeeze(x[0], axis=-1)
        nondisc_md = lambda x: np.squeeze(x[1], axis=-1)
        sum_md = lambda x: np.squeeze(x[2], axis=-1)
        agg_methods = {"disc_md": disc_md, 
                       'nondisc_md': nondisc_md,
                       'disc+nondisc_md': sum_md}
    elif method=='mahalanobis':
        maha_dist = lambda x: np.squeeze(x[:, 0], axis=-1)
        #rel_maha_dist = lambda x: np.squeeze(x[1], axis=-1)
        #marg_maha_dist = lambda x: np.squeeze(x[2], axis=-1)
        agg_methods = {"mahalanobis_distance": maha_dist,
                       #"relative_mahalanobis_distance": rel_maha_dist,
                       #"marginal_mahalanobis_distance": marg_maha_dist
                      }
    elif method=='l_mahalanobis':
        maha_dist = lambda x: np.squeeze(x[:, 0], axis=-1)
        agg_methods = {"l_mahalanobis_distance": maha_dist}
    elif method=='ddu' or method=='ddu_maha':
        ddu = lambda x: -np.squeeze(x[:, 0], axis=-1)
        agg_methods = {"ddu": ddu}
    elif method=='sngp':
        sngp = lambda x: np.squeeze(x[:, 0], axis=-1)
        agg_methods = {"sngp": sngp}
    elif method=='mc_mahalanobis':
        sm_maha_dist = lambda x: np.squeeze(x[:, 1:], axis=-1).max(1)
        agg_methods = {"sampled_mahalanobis_distance": sm_maha_dist}
    return agg_methods
metric_types=["roc-auc", 'pr-auc', 'f1-score', 'precision', 'recall']

methods = ['mahalanobis', 'mc', 'nuq']
regs = ['hs_rau']
spectralnorm = ['no_sn']
dataset_names = ['IMDB', 'TREC', 'WMT16', 'Amazon', 'MNLI', '20NewsGroups', 'RTE']
dataset_fnames = ['imdb', 'trec', 'wmt16', 'amazon', 'mnli', '20newsgroups', 'rte']
names = []
tables = []
baselines = []
for method in methods:
    for reg in regs:
        for sn in spectralnorm:
            run_dirs = []
            name_sn = ''
            names = [f'{method}|{reg}|{sn}']
            for name in dataset_fnames:
                model_series_dir = f'../../../uncertainty-estimation_cp/workdir/run_tasks_for_model_series/electra_{reg}_{sn}/sst2/{name}/{method}'
                print(model_series_dir)
                run_dirs.append([model_series_dir])
            agg_func = choose_agg_func(method)
            try:
                res_df = collect_datasets(run_dirs, names, dataset_names, metric_types=metric_types, baselines={}, methods=agg_func)
                baselines.append(res_df.iloc[-1:])
                tables.append(res_df.iloc[:-1])
            except:
                print('pass')
                pass

../../../uncertainty-estimation_cp/workdir/run_tasks_for_model_series/electra_hs_rau_no_sn/sst2/imdb/mahalanobis
../../../uncertainty-estimation_cp/workdir/run_tasks_for_model_series/electra_hs_rau_no_sn/sst2/trec/mahalanobis
../../../uncertainty-estimation_cp/workdir/run_tasks_for_model_series/electra_hs_rau_no_sn/sst2/wmt16/mahalanobis
../../../uncertainty-estimation_cp/workdir/run_tasks_for_model_series/electra_hs_rau_no_sn/sst2/amazon/mahalanobis
../../../uncertainty-estimation_cp/workdir/run_tasks_for_model_series/electra_hs_rau_no_sn/sst2/mnli/mahalanobis
../../../uncertainty-estimation_cp/workdir/run_tasks_for_model_series/electra_hs_rau_no_sn/sst2/20newsgroups/mahalanobis
../../../uncertainty-estimation_cp/workdir/run_tasks_for_model_series/electra_hs_rau_no_sn/sst2/rte/mahalanobis
../../../uncertainty-estimation_cp/workdir/run_tasks_for_model_series/electra_hs_rau_no_sn/sst2/imdb/mc
../../../uncertainty-estimation_cp/workdir/run_tasks_for_model_series/electra_hs_rau_no_sn/sst2

In [52]:
table_all = pd.concat([pd.concat([t[:-1] for t in tables]), pd.concat(baselines[-1:])]).reset_index()

In [53]:
def preproc_regs(x):
    reg = x.split('|')[1]
    if reg == 'reg':
        return 'CER'
    elif reg == 'raw':
        return '-'
    elif reg == 'hs_rau':
        return 'HS-RAU'
    else:
        return reg
    
def preproc_method(x):
    method = x.split('|')[0]
    sn = x.split('|')[-1]
    if method == 'mahalanobis' and not 'no_sn' in sn:
        return 'MD SN_ALL (ours)'
    elif method == 'mahalanobis':
        return 'MD'
    elif method == 'l_mahalanobis' and not 'no_sn' in sn:
        return 'L-MD SN_ALL (ours)'
    elif method == 'l_mahalanobis':
        return 'L-MD'
    elif method == 'mc_mahalanobis' and not 'no_sn' in sn:
        return 'SMD SN_ALL (ours)'
    elif method == 'mc_mahalanobis':
        return 'SMD'
    elif method == 'nuq' and not 'no_sn' in sn:
        return 'NUQ SN_ALL'
    elif method == 'nuq':
        return 'NUQ'
    elif method == 'l_nuq' and not 'no_sn' in sn:
        return 'L-NUQ SN_ALL'
    elif method == 'l_nuq':
        return 'L-NUQ'
    elif method == 'sngp':
        return 'SNGP'
    
    elif method == 'decomposing_md' and not 'no_sn' in sn:
        return 'Decomposing SN_ALL'
    elif method == 'decomposing_md':
        return 'Decomposing'
    
    elif method == 'nuq_best1' and not 'no_sn' in sn:
        return 'Best1 NUQ SN_ALL'
    elif method == 'nuq_best1':
        return 'Best1 NUQ'
    
    elif method == 'ddu' and not 'no_sn' in sn:
        return 'DDU SN_ALL'
    elif method == 'ddu':
        return 'DDU'
    
    elif method == 'ddu_maha' and not 'no_sn' in sn:
        return 'DDU Maha SN_ALL'
    elif method == 'ddu_maha':
        return 'DDU Maha'
    
    elif 'ddpp_dpp' in method:
        return 'DDPP (+DPP) (ours)'
    elif 'ddpp_ood' in method:
        return 'DDPP (+OOD) (ours)'
    elif 'mc' in method:
        return 'MC dropout'
    elif 'Deep' in method:
        return 'Deep Ensemble'
    elif 'baseline|raw_no_sn' in x:
        return 'SR (baseline)'
    elif 'baseline' in x and not 'no_sn' in x:
        return 'SR SN'
    return 'SR'

def preproc_ue(x):
    if x == 'bald':
        return 'BALD'
    elif 'sampled_mahalanobis_distance' in x:
        return 'SMD'
    elif 'relative_mahalanobis_distance' in x:
        return 'rel_MD'
    elif 'marginal_mahalanobis_distance' in x:
        return 'marg_MD'
    elif 'mahalanobis_distance' in x:
        return 'MD'
    elif 'sampled_max_prob' in x:
        return 'SMP'
    elif 'variance' in x:
        return 'PV'
    elif 'aleatoric' in x:
        return 'aleatoric'
    elif 'epistemic' in x:
        return 'epistemic'
    elif 'total' in x:
        return 'total'
    elif x == 'disc_md':
        return 'Disc MD'
    elif x == 'nondisc_md':
        return 'Nondisc MD'
    elif x == 'disc+nondisc_md':
        return 'Disc+Nondisc MD'
    elif x == 'sngp':
        return 'std'
    return 'MP'

table_all['Reg. Type'] = table_all.Method.apply(lambda x: preproc_regs(x))
table_all['Method'] = table_all.Method.apply(lambda x: preproc_method(x))
table_all['UE Score'] = table_all['UE Score'].apply(lambda x: preproc_ue(x))
table_all = table_all[list(table_all.columns[:1]) + list(table_all.columns[-1:]) + list(table_all.columns[1:-1])].reset_index(drop=True)

In [54]:
roc_auc_mean = table_all[table_all.columns[3::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
roc_auc_std = table_all[table_all.columns[3::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

f1_mean = table_all[table_all.columns[5::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
f1_std = table_all[table_all.columns[5::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

pr_auc_mean = table_all[table_all.columns[4::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
pr_auc_std = table_all[table_all.columns[4::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

pr_mean = table_all[table_all.columns[6::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
pr_std = table_all[table_all.columns[6::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

rec_auc_mean = table_all[table_all.columns[7::5]].apply(lambda x: x.str.split('±').str[0].astype(float)).mean(axis=1)
rec_auc_std = table_all[table_all.columns[7::5]].apply(lambda x: x.str.split('±').str[1].astype(float)).mean(axis=1)

table_all[('Mean',  'pr-auc')] = pr_auc_mean.round(2).astype(str)+'±'+pr_auc_std.round(2).astype(str)
table_all[('Mean',  'f1-score')] = f1_mean.round(2).astype(str)+'±'+f1_std.round(2).astype(str)
table_all[('Mean',  'precision')] = pr_mean.round(2).astype(str)+'±'+pr_std.round(2).astype(str)
table_all[('Mean',  'recall')] = rec_auc_mean.round(2).astype(str)+'±'+rec_auc_std.round(2).astype(str)
table_all[('Mean',  'roc-auc')] = roc_auc_mean.round(2).astype(str)+'±'+roc_auc_std.round(2).astype(str)

In [55]:
table_all[list(table_all.columns[:3]) + list(table_all.columns[-5:])]

Unnamed: 0_level_0,Method,Reg. Type,UE Score,Mean,Mean,Mean,Mean,Mean
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,pr-auc,f1-score,precision,recall,roc-auc
0,MD,HS-RAU,MD,91.03±2.31,90.74±1.55,86.73±2.42,95.98±1.11,87.12±2.38
1,MC dropout,HS-RAU,BALD,82.24±1.83,87.67±1.17,82.45±1.59,94.66±1.41,73.17±2.84
2,MC dropout,HS-RAU,SMP,83.78±2.06,88.15±1.6,83.03±2.09,95.06±2.04,74.63±3.37
3,MC dropout,HS-RAU,PV,82.55±1.87,87.88±1.28,82.79±1.88,94.75±1.84,73.71±2.97
4,NUQ,HS-RAU,aleatoric,89.61±2.63,90.05±1.94,85.89±2.59,95.52±1.25,84.4±3.78
5,NUQ,HS-RAU,epistemic,88.8±2.16,89.86±1.76,85.95±1.68,94.63±3.01,82.6±3.71
6,NUQ,HS-RAU,total,89.5±2.46,90.1±1.72,85.98±2.29,95.33±2.27,83.89±3.71
7,SR,HS-RAU,MP,87.74±2.57,88.61±2.23,85.98±2.29,95.33±2.27,78.53±4.15


In [56]:
table_all

Unnamed: 0_level_0,Method,Reg. Type,UE Score,IMDB,IMDB,IMDB,IMDB,IMDB,TREC,TREC,...,RTE,RTE,RTE,RTE,RTE,Mean,Mean,Mean,Mean,Mean
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,roc-auc,pr-auc,f1-score,precision,recall,roc-auc,pr-auc,...,roc-auc,pr-auc,f1-score,precision,recall,pr-auc,f1-score,precision,recall,roc-auc
0,MD,HS-RAU,MD,82.18±2.00,99.04±0.14,98.39±0.04,96.97±0.10,99.86±0.04,84.71±7.71,69.45±11.56,...,92.61±0.92,77.40±3.23,75.10±1.96,67.34±3.99,85.20±3.68,91.03±2.31,90.74±1.55,86.73±2.42,95.98±1.11,87.12±2.38
1,MC dropout,HS-RAU,BALD,57.20±3.13,96.60±0.18,98.29±0.00,96.63±0.00,100.00±0.00,82.20±5.20,62.24±9.80,...,79.07±1.94,42.07±1.60,60.65±3.01,50.16±3.20,76.90±4.75,82.24±1.83,87.67±1.17,82.45±1.59,94.66±1.41,73.17±2.84
2,MC dropout,HS-RAU,SMP,58.89±3.66,97.24±0.21,98.29±0.02,96.67±0.06,99.98±0.04,82.29±5.67,61.94±9.37,...,81.32±2.98,47.41±2.88,62.55±4.13,51.76±3.88,79.60±8.24,83.78±2.06,88.15±1.6,83.03±2.09,95.06±2.04,74.63±3.37
3,MC dropout,HS-RAU,PV,57.96±3.27,96.74±0.19,98.29±0.00,96.64±0.01,100.00±0.00,82.27±5.35,62.29±9.80,...,79.82±2.18,43.04±1.73,61.46±3.33,50.47±3.88,79.06±6.00,82.55±1.87,87.88±1.28,82.79±1.88,94.75±1.84,73.71±2.97
4,NUQ,HS-RAU,aleatoric,77.77±2.89,98.76±0.15,98.37±0.05,96.88±0.13,99.91±0.05,81.94±11.57,65.09±12.64,...,91.09±1.58,73.90±3.99,73.68±2.07,66.99±2.32,82.01±3.90,89.61±2.63,90.05±1.94,85.89±2.59,95.52±1.25,84.4±3.78
5,NUQ,HS-RAU,epistemic,75.60±3.15,98.61±0.18,98.37±0.04,96.90±0.13,99.88±0.07,84.45±5.46,66.25±7.98,...,88.67±3.25,68.90±4.81,71.35±3.81,65.23±2.73,79.36±8.74,88.8±2.16,89.86±1.76,85.95±1.68,94.63±3.01,82.6±3.71
6,NUQ,HS-RAU,total,77.07±2.94,98.71±0.15,98.37±0.04,96.89±0.13,99.89±0.07,83.45±8.62,66.02±10.91,...,90.29±2.26,72.76±4.32,72.94±2.27,66.42±4.03,81.59±6.99,89.5±2.46,90.1±1.72,85.98±2.29,95.33±2.27,83.89±3.71
7,SR,HS-RAU,MP,64.87±4.47,98.00±0.28,98.31±0.03,96.89±0.13,99.89±0.07,83.51±6.63,65.54±8.90,...,86.94±4.17,64.97±6.38,66.62±6.18,66.42±4.03,81.59±6.99,87.74±2.57,88.61±2.23,85.98±2.29,95.33±2.27,78.53±4.15
