In [None]:
import pickle
import os
import uncertainty_estimation.experiments_utils.ood_experiments_utils as ood_utils
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns

In [None]:
def barplot_from_nested_dict(nested_dict, metric_name='OOD detection AUC',
                             group_name='OOD group', vline=None, xlim=(0, 1.0), save_dir=None,
                             height=6,
                             aspect=1.5, legend_out=False):
    """Draw a horizontal grouped bar plot from a two-level dictionary.

    The dictionary is turned into a DataFrame (outer keys become columns,
    inner keys become the index), reshaped to long form and plotted with
    ``seaborn.catplot(kind='bar')``: inner keys are placed on the y axis,
    the values on the x axis, and outer keys select the bar colour (hue).

    Note: this sets the seaborn palette and style globally as a side effect.

    Parameters
    ----------
    nested_dict : dict
        Mapping ``{outer_key: {inner_key: value}}``.
    metric_name : str
        Label of the value (x) axis.
    group_name : str
        Label of the y axis (inner keys).
    vline : float or None
        x position of a dashed vertical reference line; ``None`` draws none.
    xlim : tuple
        Limits applied to the x axis.
    save_dir : str or None
        Passed as-is to ``plt.savefig``, so despite the name it must be the
        output *file* path. When falsy the figure is shown instead of saved.
    height, aspect, legend_out
        Forwarded to ``seaborn.catplot``.
    """
    sns.set_palette("Set1", 10)
    sns.set_style('whitegrid')
    df = pd.DataFrame.from_dict(nested_dict)

    # Long form: one row per (inner key, outer key) pair.
    df = df.stack().reset_index()
    # The outer-key column is deliberately named '' so the legend has no title.
    df.columns = [group_name, '', metric_name]

    sns.catplot(x=metric_name, y=group_name, hue='', data=df, kind='bar',
                height=height, aspect=aspect, facet_kws=dict(despine=False), alpha=0.9,
                legend_out=legend_out)
    plt.xlim(xlim)
    # Compare against None so that a reference line at 0 is still drawn.
    if vline is not None:
        plt.axvline(vline, linestyle='--')
    if save_dir:
        # `pad_inches` is the savefig keyword for the padding around a tight bbox.
        plt.savefig(save_dir, dpi=300,
                    bbox_inches='tight', pad_inches=0)
        plt.close()
    else:
        plt.show()


def boxplot_from_nested_listdict(nested_dict, name, hline=None, ylim=(0.0, 1.0), x_name='scale',
                                 save_dir=None, **kwargs):
    """Draw grouped box plots from a two-level dictionary of list-like values.

    The dictionary is turned into a DataFrame (outer keys become columns,
    inner keys become the index), reshaped to long form, and every list-like
    cell is expanded into one row per element. The result is plotted with
    ``seaborn.catplot(kind='box')``: inner keys on the x axis, the values on
    the y axis, and outer keys as hue.

    Note: this sets the seaborn palette and style globally as a side effect.

    Parameters
    ----------
    nested_dict : dict
        Mapping ``{outer_key: {inner_key: list_of_values}}``.
    name : str
        Label of the value (y) axis.
    hline : float or None
        y position of a dashed horizontal reference line; ``None`` draws none.
    ylim : tuple
        Limits applied to the y axis.
    x_name : str
        Label of the x axis (inner keys).
    save_dir : str or None
        Passed as-is to ``plt.savefig``, so despite the name it must be the
        output *file* path. When falsy the figure is shown instead of saved.
    **kwargs
        Forwarded to ``seaborn.catplot``.
    """
    sns.set_palette("Set1", 10)
    sns.set_style('whitegrid')
    df = pd.DataFrame.from_dict(nested_dict,
                                orient='columns')

    # Long form: one row per (inner key, outer key) pair.
    df = df.stack().reset_index()
    # The outer-key column is deliberately named '' so the legend has no title.
    df.columns = [x_name, '', name]
    # explode() repeats the row label for every element of a list; renumber so
    # each observation has a unique label.
    df = df.explode(name).reset_index(drop=True)

    sns.catplot(x=x_name, y=name, hue='', data=df, kind='box',
                facet_kws=dict(despine=False), legend_out=False, **kwargs)
    # Compare against None so that a reference line at 0 is still drawn.
    if hline is not None:
        plt.axhline(hline, linestyle='--')
    plt.ylim(ylim)
    if save_dir:
        # `pad_inches` is the savefig keyword for the padding around a tight bbox.
        plt.savefig(save_dir, dpi=300,
                    bbox_inches='tight', pad_inches=0)
        plt.close()
    else:
        plt.show()
    


## OOD barplots

In [None]:
# Load the pickled OOD results: one sub-directory per method, each holding
# 'detect_auc.pkl' and 'recall.pkl'.
# NOTE: pickle.load can execute arbitrary code -- only load result files that
# were produced by a trusted run of the experiments.
ood_dir_name = '../uncertainty_estimation/experiments/pickled_results/OOD/'
auc_dict, recall_dict = dict(), dict()
# os.listdir returns entries in arbitrary order; sort for a reproducible
# method order in the plots.
for method in sorted(os.listdir(ood_dir_name)):
    method_dir = os.path.join(ood_dir_name, method)
    # Skip stray files and hidden folders (e.g. .DS_Store, .ipynb_checkpoints)
    # that are not method result directories.
    if method.startswith('.') or not os.path.isdir(method_dir):
        continue
    with open(os.path.join(method_dir, 'detect_auc.pkl'), 'rb') as f:
        auc_dict[method] = pickle.load(f)
    with open(os.path.join(method_dir, 'recall.pkl'), 'rb') as f:
        recall_dict[method] = pickle.load(f)

In [None]:
# Detection AUC per OOD group and method; dashed reference line at 0.5,
# x axis cropped to [0.45, 1.0].
barplot_from_nested_dict(
    auc_dict,
    vline=0.5,
    xlim=(0.45, 1.0),
)

In [None]:
# Recall per OOD group and method, with a dashed reference line at 0.05.
barplot_from_nested_dict(
    recall_dict,
    metric_name='OOD recall',
    vline=0.05,
)

## Perturbation boxplots

In [None]:
# Load the pickled perturbation results: one sub-directory per method, each
# holding 'perturb_detect_auc.pkl' and 'perturb_recall.pkl'.
# NOTE: pickle.load can execute arbitrary code -- only load result files that
# were produced by a trusted run of the experiments.
ood_dir_name = '../uncertainty_estimation/experiments/pickled_results/perturbation/'
auc_dict, recall_dict = dict(), dict()
# Not populated in this cell; kept in case other cells rely on these names.
auc_dict_std, recall_dict_std = dict(), dict()
# os.listdir returns entries in arbitrary order; sort for a reproducible
# method order in the plots.
for method in sorted(os.listdir(ood_dir_name)):
    method_dir = os.path.join(ood_dir_name, method)
    # Skip stray files and hidden folders (e.g. .DS_Store, .ipynb_checkpoints)
    # that are not method result directories.
    if method.startswith('.') or not os.path.isdir(method_dir):
        continue
    with open(os.path.join(method_dir, 'perturb_detect_auc.pkl'), 'rb') as f:
        auc_dict[method] = pickle.load(f)
    with open(os.path.join(method_dir, 'perturb_recall.pkl'), 'rb') as f:
        recall_dict[method] = pickle.load(f)

In [None]:
# Sanity check: list the loaded methods (keys are the sub-directory names
# found under the results folder).
recall_dict.keys()

In [None]:
# Distribution of recall per method, with a dashed reference line at 0.05.
boxplot_from_nested_listdict(
    recall_dict,
    "OOD recall",
    hline=0.05,
)

In [None]:
# Distribution of detection AUC per method; dashed reference line at 0.5,
# y axis cropped to [0.45, 1.0].
boxplot_from_nested_listdict(auc_dict, "OOD detection AUC", hline=0.5, ylim=(0.45, 1.0))