In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
from copy import deepcopy

In [None]:
# Leftover scratch expression: evaluates to 0.001 and is not referenced by any other cell shown here.
1e-3

In [None]:
from tueplots import bundles



# Matplotlib style modelled on bundles.neurips2023(), with the font sizes adapted
# for the pt12 standard.
settings_dict = {
    # Text is rendered through LaTeX; the preamble switches the roman default to
    # `ptm` and the sans-serif default to `phv`.
    "text.usetex": True,
    "font.family": "serif",
    "text.latex.preamble": "\\renewcommand{\\rmdefault}{ptm}\\renewcommand{\\sfdefault}{phv}",
    # Figure geometry: 5.5 wide; the height equals the width divided by the golden ratio.
    "figure.figsize": (5.5, 3.399186938124422),
    "figure.constrained_layout.use": True,
    "figure.autolayout": False,
    # Saved figures are cropped tightly with a small padding.
    "savefig.bbox": "tight",
    "savefig.pad_inches": 0.015,
    # Font sizes: 10 for text, axis labels and titles; 8 for legend and tick labels.
    "font.size": 10,
    "axes.labelsize": 10,
    "legend.fontsize": 8,
    "xtick.labelsize": 8,
    "ytick.labelsize": 8,
    "axes.titlesize": 10,
    "figure.dpi": 300,
}

# Apply the style globally to every figure created afterwards.
plt.rcParams.update(settings_dict)


# Can use colors from bundles.rgb.
#     tue_blue
#     tue_brown
#     tue_dark
#     tue_darkblue
#     tue_darkgreen
#     tue_gold
#     tue_gray
#     tue_green
#     tue_lightblue
#     tue_lightgold
#     tue_lightgreen
#     tue_lightorange
#     tue_mauve
#     tue_ocre
#     tue_orange
#     tue_red
#     tue_violet

In [None]:
# # %matplotlib widget
# import matplotlib
# matplotlib.rcParams["figure.dpi"] = 300

In [None]:
# Root folder of the result files. The cells below expect one sub-folder per dataset
# ("CIFAR-10-C", "R-MNIST", "R-FMNIST") inside it, each holding files named
# `<condition>_<model_id>.npy`.
# NOTE(review): the original note says the contents of 'laplace-redux/results' have to be
# placed in this folder — confirm against the experiment repository.

RESULTS_DIRECTORY = 'results - Toy Datasets - All conditions (map, TS, Laplace, scaling, ef)'


In [None]:
def load_results(RESULTS_DIRECTORY, dataset_name, sub_name, model_ids):
    """Load one condition's results for several models and average them.

    For every id in ``model_ids`` the file
    ``{RESULTS_DIRECTORY}/{dataset_name}/{sub_name}_{id}.npy`` is read. Each file is
    indexed as a sequence of dicts (metric name -> value); the metric names are taken
    from the first file.

    Returns a list with one dict per sequence position. For every metric ``k`` it
    holds the mean over the models under ``k`` and ``np.std / sqrt(number of models)``
    under ``k + "_se"``.

    Note: the files are read with ``allow_pickle=True``, which executes pickled
    content — only point this at result files you trust.
    """
    runs = [
        np.load(f'{RESULTS_DIRECTORY}/{dataset_name}/{sub_name}_{model_id}.npy', allow_pickle=True)
        for model_id in model_ids
    ]

    aggregated = []
    for position in range(len(runs[0])):
        summary = {}
        # The first run decides which metrics exist at this position.
        for metric in runs[0][position].keys():
            samples = [run[position][metric] for run in runs]
            summary[metric] = np.mean(samples)
            summary[metric + "_se"] = np.std(samples) / np.sqrt(len(samples))
        aggregated.append(summary)

    return aggregated


In [None]:
# automatically load all conditions

def get_all_conditions_model_id_names(RESULTS_DIRECTORY, DATASET_NAME):
    """Discover the conditions and model ids available for one dataset.

    Scans ``RESULTS_DIRECTORY/DATASET_NAME`` for files named
    ``<condition>_<model_id>.npy``. The model id is the part after the last
    underscore; the condition is everything before it.

    Entries that do not match this pattern (other extensions, hidden files,
    checkpoint folders, names without an underscore) are ignored instead of being
    turned into bogus conditions or ids.

    Returns:
        (conditions, model_ids): two alphabetically sorted lists of unique strings.
        Sorting makes the order reproducible across kernel restarts. This function
        does not check that every condition exists for every model id.
    """
    dataset_dir = os.path.join(RESULTS_DIRECTORY, DATASET_NAME)

    conditions, model_ids = set(), set()
    for file_name in os.listdir(dataset_dir):
        stem, extension = os.path.splitext(file_name)
        if extension != ".npy" or "_" not in stem:
            continue
        condition, model_id = stem.rsplit("_", 1)
        conditions.add(condition)
        model_ids.add(model_id)

    return sorted(conditions), sorted(model_ids)


# For each dataset: discover the conditions and model ids on disk, then load one
# aggregated result list per condition (same order as the conditions list).

DATASET = "CIFAR-10-C"
conditions_cifar10c, model_ids_cifar10c = get_all_conditions_model_id_names(RESULTS_DIRECTORY, DATASET)
results_cifar10c = [
    load_results(RESULTS_DIRECTORY, DATASET, condition, model_ids_cifar10c)
    for condition in conditions_cifar10c
]


DATASET = "R-MNIST"
conditions_rmnist, model_ids_rmnist = get_all_conditions_model_id_names(RESULTS_DIRECTORY, DATASET)
results_rmnist = [
    load_results(RESULTS_DIRECTORY, DATASET, condition, model_ids_rmnist)
    for condition in conditions_rmnist
]


DATASET = "R-FMNIST"
conditions_rfmnist, model_ids_rfmnist = get_all_conditions_model_id_names(RESULTS_DIRECTORY, DATASET)
results_rfmnist = [
    load_results(RESULTS_DIRECTORY, DATASET, condition, model_ids_rfmnist)
    for condition in conditions_rfmnist
]




# load all results for each condition into a list.
# make a list of all the condition names

In [None]:
# Maps the condition part of a result file name to the label shown in the figures.
# The keys must match the file names exactly (including the "fittted" spelling).
condition_translation_dict = {
    # Baselines without Laplace approximation
    'map': "MAP",
    'temp': "TS (pycalib)",
    'map_weight_inc_temp': "TS (WITS)",

    # Last-layer Laplace
    'laplace_last_layer_full': "LLLA",
    'laplace_weight_inc_temp_last_layer_full': "LLLA+WITS",
    'laplace_last_layer_full_scalingfittted_diagaddfitted_diagscalingfitted': "LLLA+CVS",
    'laplace_weight_inc_temp_last_layer_full_scalingfittted_diagaddfitted_diagscalingfitted': "LLLA+WITS+CVS",

    # CVS ablations: only one of the scaling parameters is fitted
    'laplace_last_layer_full_diagaddfitted': "LLLA+CVS - only diag fitted",
    'laplace_last_layer_full_scalingfittted': "LLLA+CVS - only scaling",
    'laplace_last_layer_full_diagscalingfitted': "LLLA+CVS - only diag scaling",

    # Last-layer Laplace, EF variants
    'laplace_ef_last_layer_full': "LLLA(EF)",
    'laplace_ef_weight_inc_temp_last_layer_full': "LLLA(EF)+WITS",
    'laplace_ef_last_layer_full_scalingfittted_diagaddfitted_diagscalingfitted': "LLLA(EF)+CVS",
    'laplace_ef_weight_inc_temp_last_layer_full_scalingfittted_diagaddfitted_diagscalingfitted': "LLLA(EF)+WITS+CVS",

    # Variants fitted on the OOD validation set
    'map_OODValSet_weight_inc_temp': "fitted on OOD: MAP+WITS",
    'laplace_OODValSet_last_layer_full_scalingfittted_diagaddfitted_diagscalingfitted': "fitted on OOD: LLLA+CVS",
    'laplace_OODValSet_weight_inc_temp_last_layer_full': "fitted on OOD: LLLA+WITS",
    'laplace_OODValSet_weight_inc_temp_last_layer_full_scalingfittted_diagaddfitted_diagscalingfitted': "fitted on OOD: LLLA+WITS+CVS",

    # EF variants fitted on the OOD validation set
    'laplace_ef_OODValSet_last_layer_full': "fitted on OOD: LLLA(EF)",
    'laplace_ef_OODValSet_weight_inc_temp_last_layer_full': "fitted on OOD: LLLA(EF)+WITS",
    'laplace_ef_OODValSet_last_layer_full_scalingfittted_diagaddfitted_diagscalingfitted': "fitted on OOD: LLLA(EF)+CVS",
    'laplace_ef_OODValSet_weight_inc_temp_last_layer_full_scalingfittted_diagaddfitted_diagscalingfitted': "fitted on OOD: LLLA(EF)+WITS+CVS",
}


In [None]:
# Every condition found on disk must have a display name in condition_translation_dict.
# Names that are already display names are accepted as well, so this check still
# passes when the cell is re-run after the conditions have been translated.
_known_condition_names = set(condition_translation_dict) | set(condition_translation_dict.values())
for _dataset_label, _dataset_conditions in (
    ("R-MNIST", conditions_rmnist),
    ("R-FMNIST", conditions_rfmnist),
    ("CIFAR-10-C", conditions_cifar10c),
):
    _unknown = sorted(c for c in _dataset_conditions if c not in _known_condition_names)
    # Name the offending conditions so a failure is actionable.
    assert not _unknown, f"{_dataset_label}: no entry in condition_translation_dict for {_unknown}"


In [None]:
# Replace the raw on-disk condition names by their display names.
# Names that already are display names are kept unchanged, so re-running this cell is
# a no-op instead of a KeyError. Unknown raw names still raise KeyError.
_display_names = set(condition_translation_dict.values())
conditions_rmnist = [c if c in _display_names else condition_translation_dict[c] for c in conditions_rmnist]
conditions_rfmnist = [c if c in _display_names else condition_translation_dict[c] for c in conditions_rfmnist]
conditions_cifar10c = [c if c in _display_names else condition_translation_dict[c] for c in conditions_cifar10c]


In [None]:
# Number of distinct condition names across the three datasets.
total_amount_of_conditions = len({*conditions_rmnist, *conditions_rfmnist, *conditions_cifar10c})

In [None]:
total_amount_of_conditions

In [None]:
all_unique_conditions = set(conditions_rmnist + conditions_rfmnist + conditions_cifar10c)

# Put every condition name into exactly one group, decided by substrings of the name:
#   'fitted on OOD' + 'EF' -> fitted_on_OOD_ef_conditions
#   'fitted on OOD'        -> fitted_on_OOD_conditions
#   'EF'                   -> ef_conditions
#   'only'                 -> scaling_ablation_conditions
#   anything else          -> standard_conditions
fitted_on_OOD_conditions = []
fitted_on_OOD_ef_conditions = []
ef_conditions = []
scaling_ablation_conditions = []
standard_conditions = []

for _name in all_unique_conditions:
    _is_ef = 'EF' in _name
    if 'fitted on OOD' in _name:
        if _is_ef:
            fitted_on_OOD_ef_conditions.append(_name)
        else:
            fitted_on_OOD_conditions.append(_name)
    elif _is_ef:
        ef_conditions.append(_name)
    elif "only" in _name:
        scaling_ablation_conditions.append(_name)
    else:
        standard_conditions.append(_name)


In [None]:
# Turn every group into an alphabetically sorted list so later cells see a stable order.
all_unique_conditions = sorted(all_unique_conditions)
standard_conditions = sorted(standard_conditions)

fitted_on_OOD_conditions = sorted(fitted_on_OOD_conditions)
fitted_on_OOD_ef_conditions = sorted(fitted_on_OOD_ef_conditions)
ef_conditions = sorted(ef_conditions)
scaling_ablation_conditions = sorted(scaling_ablation_conditions)


In [None]:
all_unique_conditions

In [None]:
standard_conditions

In [None]:
fitted_on_OOD_conditions

In [None]:
fitted_on_OOD_ef_conditions

In [None]:
ef_conditions

In [None]:
scaling_ablation_conditions

In [None]:
### Color palette

# Other colormaps that were tried: 'hsv', 'nipy_spectral', 'gist_rainbow'.
# Each group gets its colours by sampling a colormap at i / len(group);
# the standard conditions use 'Set1', every other group uses 'Set2'.
condition_to_color = {}
for _cmap_name, _group in (
    ('Set1', standard_conditions),
    ('Set2', fitted_on_OOD_conditions),
    ('Set2', ef_conditions),
    ('Set2', fitted_on_OOD_ef_conditions),
    ('Set2', scaling_ablation_conditions),
):
    palette = plt.get_cmap(_cmap_name)
    condition_to_color.update({c: palette(i / len(_group)) for i, c in enumerate(_group)})

# MAP always gets a fixed colour.
condition_to_color["MAP"] = "darkblue"

In [None]:
def _plot_metric_curves(axis, conditions, results, metric, x_values=None):
    """Draw one ``metric`` curve per condition on ``axis``.

    ``results[k]`` is the list of per-dataset metric dicts of ``conditions[k]``.
    Curves are drawn against ``x_values`` when given, otherwise against
    ``range(len(results[k]))``. Colours come from the module-level
    ``condition_to_color``.

    Returns the x-coordinates used for the last curve, or ``None`` if no curve was drawn.
    """
    used_x = None
    for condition, condition_results in zip(conditions, results):
        used_x = range(len(condition_results)) if x_values is None else x_values
        axis.plot(
            used_x,
            [entry[metric] for entry in condition_results],
            label=condition,
            color=condition_to_color[condition],
        )
    return used_x


def combined_plot(conditions_rmnist, results_rmnist, conditions_rfmnist, results_rfmnist, conditions_cifar10c, results_cifar10c, legend_order_permutation=None):
    """Create a 2x3 figure: ECE (top row) and NLL (bottom row) for three datasets.

    Columns are R-MNIST, R-FMNIST and CIFAR-10-C. Each ``conditions_*`` sequence holds
    condition labels; the matching ``results_*`` sequence holds, per condition, a list
    of dicts with at least the keys ``'ece'`` and ``'nll'``.

    The two rotation datasets are plotted against the angles 0, 15, ..., 180, so their
    result lists must have 13 entries. CIFAR-10-C is plotted against the result index.

    The legend is placed right of the top-right panel. Its entries are taken from the
    R-FMNIST NLL panel; ``legend_order_permutation`` (a list of indices into those
    entries) optionally reorders them.

    The figure is left as the current pyplot figure (for ``plt.savefig`` /
    ``plt.show``); nothing is returned.
    """
    _, ax = plt.subplots(2, 3)

    rotation_angles = range(0, 181, 15)
    rotation_ticks = [0, 45, 90, 135, 180]
    rotated_datasets = (
        (conditions_rmnist, results_rmnist),
        (conditions_rfmnist, results_rfmnist),
    )

    for row, metric in enumerate(("ece", "nll")):
        for col, (conditions, results) in enumerate(rotated_datasets):
            _plot_metric_curves(ax[row][col], conditions, results, metric, rotation_angles)
            ax[row][col].set_xticks(rotation_ticks)

        corruption_levels = _plot_metric_curves(ax[row][2], conditions_cifar10c, results_cifar10c, metric)
        # Without any CIFAR condition there is nothing to derive ticks from; the
        # previous implementation failed here with an unbound local variable.
        if corruption_levels is not None:
            ax[row][2].set_xticks(corruption_levels)

    ax[0][0].set_ylabel("ECE")
    ax[1][0].set_ylabel("NLL")
    ax[1][0].set_xlabel("R-MNIST\n(rotation angle)")
    ax[1][1].set_xlabel("R-FMNIST\n(rotation angle)")
    ax[1][2].set_xlabel("CIFAR10-C\n(degree of corruption)")

    handles, labels = ax[1][1].get_legend_handles_labels()
    if legend_order_permutation:
        handles = [handles[idx] for idx in legend_order_permutation]
        labels = [labels[idx] for idx in legend_order_permutation]

    ax[0][2].legend(handles, labels, loc='center left', bbox_to_anchor=(1, 0.5))




In [None]:
# Overview figure: every loaded condition for all three datasets, legend in its default order.
combined_plot(conditions_rmnist, results_rmnist, conditions_rfmnist, results_rfmnist, conditions_cifar10c, results_cifar10c)
plt.show()

In [None]:
def filter_conditions(conditions, results, wanted_conditions_list):
    """Keep only the (condition, result) pairs whose condition is wanted.

    ``conditions`` and ``results`` are parallel sequences. The original order is
    preserved. Returns two new lists: the kept conditions and their results.
    """
    kept_pairs = [
        (condition, result)
        for condition, result in zip(conditions, results)
        if condition in wanted_conditions_list
    ]
    ret_conditions = [condition for condition, _ in kept_pairs]
    ret_results = [result for _, result in kept_pairs]
    return ret_conditions, ret_results

In [None]:
# TODO
# Concise plots:
# Baseline (MAP, LLLA, TS(pycalib)) (+ TS(mine?))
wanted_conditions = ["MAP", "LLLA", "TS (pycalib)", "TS (WITS)"]
conditions_rmnist_filtered, results_rmnist_filtered = filter_conditions(conditions_rmnist, results_rmnist, wanted_conditions)
conditions_rfmnist_filtered, results_rfmnist_filtered = filter_conditions(conditions_rfmnist, results_rfmnist, wanted_conditions)
conditions_cifar10c_filtered, results_cifar10c_filtered = filter_conditions(conditions_cifar10c, results_cifar10c, wanted_conditions)

# Sort the pairs by condition name. The key restricts the comparison to the name, so
# a tie can never fall back to comparing the result dicts (which raises TypeError).
conditions_rmnist_filtered, results_rmnist_filtered = zip(*sorted(zip(conditions_rmnist_filtered, results_rmnist_filtered), key=lambda pair: pair[0]))
conditions_rfmnist_filtered, results_rfmnist_filtered = zip(*sorted(zip(conditions_rfmnist_filtered, results_rfmnist_filtered), key=lambda pair: pair[0]))
conditions_cifar10c_filtered, results_cifar10c_filtered = zip(*sorted(zip(conditions_cifar10c_filtered, results_cifar10c_filtered), key=lambda pair: pair[0]))


# Fixed colours for the conditions used in the figures from here on. This overwrites
# entries of the shared condition_to_color dict, so figures drawn before this cell
# used different colours.
condition_to_color.update({'MAP': bundles.rgb.tue_red,
                          'TS (pycalib)': bundles.rgb.tue_lightorange,
                          'TS (WITS)': bundles.rgb.tue_orange,
                          'LLLA': bundles.rgb.tue_violet,
                          'LLLA+CVS': bundles.rgb.tue_lightblue,
                          'LLLA+WITS': bundles.rgb.tue_blue,
                          'LLLA+WITS+CVS': bundles.rgb.tue_darkblue,
                          'fitted on OOD: MAP+WITS': bundles.rgb.tue_gray,
                          'fitted on OOD: LLLA+CVS': bundles.rgb.tue_lightgreen,
                          'fitted on OOD: LLLA+WITS': bundles.rgb.tue_green,
                          'fitted on OOD: LLLA+WITS+CVS': bundles.rgb.tue_darkgreen
})


# Sorted order is LLLA, MAP, TS (WITS), TS (pycalib); the permutation shows the legend
# as MAP, LLLA, TS (pycalib), TS (WITS). It assumes all four conditions exist for R-FMNIST.
combined_plot(conditions_rmnist_filtered, results_rmnist_filtered, conditions_rfmnist_filtered, results_rfmnist_filtered, conditions_cifar10c_filtered, results_cifar10c_filtered, legend_order_permutation=[1, 0, 3, 2])
os.makedirs('img/Results/ToyData_ECE_NLL/', exist_ok=True)
plt.savefig('img/Results/ToyData_ECE_NLL/MAP_LLLA_TS_ECE_NLL.pdf')
plt.show()


In [None]:
# WITS / Laplace with Cov-scaling

# All standard conditions except the pycalib temperature-scaling baseline.
# Filtering (instead of list.remove) builds a fresh list and does not fail if the
# baseline is absent.
wanted_conditions = [c for c in standard_conditions if c != "TS (pycalib)"]

conditions_rmnist_filtered, results_rmnist_filtered = filter_conditions(conditions_rmnist, results_rmnist, wanted_conditions)
conditions_rfmnist_filtered, results_rfmnist_filtered = filter_conditions(conditions_rfmnist, results_rfmnist, wanted_conditions)
conditions_cifar10c_filtered, results_cifar10c_filtered = filter_conditions(conditions_cifar10c, results_cifar10c, wanted_conditions)

# Sort the pairs by condition name only, so a tie never compares the result dicts.
conditions_rmnist_filtered, results_rmnist_filtered = zip(*sorted(zip(conditions_rmnist_filtered, results_rmnist_filtered), key=lambda pair: pair[0]))
conditions_rfmnist_filtered, results_rfmnist_filtered = zip(*sorted(zip(conditions_rfmnist_filtered, results_rfmnist_filtered), key=lambda pair: pair[0]))
conditions_cifar10c_filtered, results_cifar10c_filtered = zip(*sorted(zip(conditions_cifar10c_filtered, results_cifar10c_filtered), key=lambda pair: pair[0]))


# The permutation indexes the alphabetically sorted R-FMNIST conditions and assumes six of them.
combined_plot(conditions_rmnist_filtered, results_rmnist_filtered, conditions_rfmnist_filtered, results_rfmnist_filtered, conditions_cifar10c_filtered, results_cifar10c_filtered, legend_order_permutation=[4, 0, 1, 5, 2, 3])
os.makedirs('img/Results/ToyData_ECE_NLL/', exist_ok=True)
plt.savefig('img/Results/ToyData_ECE_NLL/WITS_CovScaling_ECE_NLL.pdf')
plt.show()


In [None]:
# With EF vs. GGN

wanted_conditions = ["MAP", "LLLA", "LLLA(EF)"]
# Alternative comparisons:
# wanted_conditions = ["LLLA+CVS", "LLLA(EF)+CVS"]
# wanted_conditions = ["LLLA+WITS+CVS", "LLLA(EF)+WITS+CVS"]

# NOTE(review): the commented-out list below looks garbled (probably by a
# search-and-replace); several of its entries are not valid condition names.
# wanted_conditions = ["LLLA", "LLLA+WITS", "LLLA+CVSITS + Cov-sca+CVSF)", "LLLA(EF)+WITS", "LLLA(EF)+CVSF)+WITS+CVS"]

conditions_rmnist_filtered, results_rmnist_filtered = filter_conditions(conditions_rmnist, results_rmnist, wanted_conditions)
conditions_rfmnist_filtered, results_rfmnist_filtered = filter_conditions(conditions_rfmnist, results_rfmnist, wanted_conditions)
conditions_cifar10c_filtered, results_cifar10c_filtered = filter_conditions(conditions_cifar10c, results_cifar10c, wanted_conditions)

# Sort the pairs by condition name only, so a tie never compares the result dicts.
conditions_rmnist_filtered, results_rmnist_filtered = zip(*sorted(zip(conditions_rmnist_filtered, results_rmnist_filtered), key=lambda pair: pair[0]))
conditions_rfmnist_filtered, results_rfmnist_filtered = zip(*sorted(zip(conditions_rfmnist_filtered, results_rfmnist_filtered), key=lambda pair: pair[0]))
conditions_cifar10c_filtered, results_cifar10c_filtered = zip(*sorted(zip(conditions_cifar10c_filtered, results_cifar10c_filtered), key=lambda pair: pair[0]))


# The permutation indexes the alphabetically sorted R-FMNIST conditions and assumes three of them.
combined_plot(conditions_rmnist_filtered, results_rmnist_filtered, conditions_rfmnist_filtered, results_rfmnist_filtered, conditions_cifar10c_filtered, results_cifar10c_filtered, legend_order_permutation=[1, 2, 0])
plt.show()


In [None]:
# With EF: WITS / Laplace with Cov-scaling

wanted_conditions = ["MAP", "LLLA", "LLLA+WITS+CVS", "LLLA(EF)", "LLLA(EF)+WITS", "LLLA(EF)+CVS", "LLLA(EF)+WITS+CVS"]

conditions_rmnist_filtered, results_rmnist_filtered = filter_conditions(conditions_rmnist, results_rmnist, wanted_conditions)
conditions_rfmnist_filtered, results_rfmnist_filtered = filter_conditions(conditions_rfmnist, results_rfmnist, wanted_conditions)
conditions_cifar10c_filtered, results_cifar10c_filtered = filter_conditions(conditions_cifar10c, results_cifar10c, wanted_conditions)

# Sort the pairs by condition name only, so a tie never compares the result dicts.
conditions_rmnist_filtered, results_rmnist_filtered = zip(*sorted(zip(conditions_rmnist_filtered, results_rmnist_filtered), key=lambda pair: pair[0]))
conditions_rfmnist_filtered, results_rfmnist_filtered = zip(*sorted(zip(conditions_rfmnist_filtered, results_rfmnist_filtered), key=lambda pair: pair[0]))
conditions_cifar10c_filtered, results_cifar10c_filtered = zip(*sorted(zip(conditions_cifar10c_filtered, results_cifar10c_filtered), key=lambda pair: pair[0]))


# The permutation indexes the alphabetically sorted R-FMNIST conditions and assumes seven of them.
combined_plot(conditions_rmnist_filtered, results_rmnist_filtered, conditions_rfmnist_filtered, results_rfmnist_filtered, conditions_cifar10c_filtered, results_cifar10c_filtered, legend_order_permutation=[6, 0, 5, 3, 4, 1, 2])
os.makedirs('img/Results/ToyData_ECE_NLL/', exist_ok=True)
plt.savefig('img/Results/ToyData_ECE_NLL/EF_Ablation.pdf')
plt.show()


In [None]:
# # WITS / Laplace with Cov-scaling

# # wanted_conditions = ["MAP", "LLLA", "TS (pycalib)", "TS (WITS)", "LLLA+WITS", "LLLA+CVSITS+CVS
# wanted_conditions = ["MAP", "LLLA+WITS", "LLLA+CVSITS+CVS

# conditions_rmnist_filtered, results_rmnist_filtered = filter_conditions(conditions_rmnist, results_rmnist, wanted_conditions)
# conditions_rfmnist_filtered, results_rfmnist_filtered = filter_conditions(conditions_rfmnist, results_rfmnist, wanted_conditions)
# conditions_cifar10c_filtered, results_cifar10c_filtered = filter_conditions(conditions_cifar10c, results_cifar10c, wanted_conditions)

# combined_plot(conditions_rmnist_filtered, results_rmnist_filtered, conditions_rfmnist_filtered, results_rfmnist_filtered, conditions_cifar10c_filtered, results_cifar10c_filtered)

# CVS Scaling parameters ablation

In [None]:
# CVS ablation: the "only ..." variants next to plain LLLA and full LLLA+CVS.
# A new list is built so scaling_ablation_conditions itself is left untouched.
wanted_conditions = list(scaling_ablation_conditions) + ["LLLA", 'LLLA+CVS']

conditions_rmnist_filtered, results_rmnist_filtered = filter_conditions(conditions_rmnist, results_rmnist, wanted_conditions)
conditions_rfmnist_filtered, results_rfmnist_filtered = filter_conditions(conditions_rfmnist, results_rfmnist, wanted_conditions)
conditions_cifar10c_filtered, results_cifar10c_filtered = filter_conditions(conditions_cifar10c, results_cifar10c, wanted_conditions)

# Sort the pairs by condition name only, so a tie never compares the result dicts.
conditions_rmnist_filtered, results_rmnist_filtered = zip(*sorted(zip(conditions_rmnist_filtered, results_rmnist_filtered), key=lambda pair: pair[0]))
conditions_rfmnist_filtered, results_rfmnist_filtered = zip(*sorted(zip(conditions_rfmnist_filtered, results_rfmnist_filtered), key=lambda pair: pair[0]))
conditions_cifar10c_filtered, results_cifar10c_filtered = zip(*sorted(zip(conditions_cifar10c_filtered, results_cifar10c_filtered), key=lambda pair: pair[0]))


combined_plot(conditions_rmnist_filtered, results_rmnist_filtered, conditions_rfmnist_filtered, results_rfmnist_filtered, conditions_cifar10c_filtered, results_cifar10c_filtered)
os.makedirs('img/Results/ToyData_ECE_NLL/', exist_ok=True)
plt.savefig('img/Results/ToyData_ECE_NLL/CVS_Parameter_Ablation.pdf')
plt.show()


In [None]:

fig, ax = plt.subplots()

# One ECE curve per condition, drawn against the index of the result entry.
for condition, results in zip(conditions_cifar10c, results_cifar10c):
    datasets = range(len(results))
    ece_values = [entry['ece'] for entry in results]
    ax.plot(datasets, ece_values, label=condition)

# Ticks follow the x-range of the last plotted condition.
ax.set_xticks(datasets)
# ax.set_xticklabels(["ID", "OOD"])

ax.set_title("ECE on cifar10c")
ax.legend()
plt.show()



In [None]:

fig, ax = plt.subplots()

# One NLL curve per condition, drawn against the index of the result entry.
for condition, results in zip(conditions_cifar10c, results_cifar10c):
    datasets = range(len(results))
    nll_values = [entry['nll'] for entry in results]
    ax.plot(datasets, nll_values, label=condition)

# Ticks follow the x-range of the last plotted condition.
ax.set_xticks(datasets)
# ax.set_xticklabels(["ID", "OOD"])

ax.set_title("NLL on cifar10c")
ax.legend()
plt.show()



In [None]:

fig, ax = plt.subplots()

# One ECE curve per condition, drawn against the index of the result entry.
for condition, results in zip(conditions_rmnist, results_rmnist):
    datasets = range(len(results))
    ece_values = [entry['ece'] for entry in results]
    ax.plot(datasets, ece_values, label=condition)

# Ticks follow the x-range of the last plotted condition.
ax.set_xticks(datasets)
# ax.set_xticklabels(["ID", "OOD"])

ax.set_title("ECE on rmnist")
ax.legend()
plt.show()



In [None]:

fig, ax = plt.subplots()

# One NLL curve per condition, drawn against the index of the result entry.
for condition, results in zip(conditions_rmnist, results_rmnist):
    datasets = range(len(results))
    nll_values = [entry['nll'] for entry in results]
    ax.plot(datasets, nll_values, label=condition)

# Ticks follow the x-range of the last plotted condition.
ax.set_xticks(datasets)
# ax.set_xticklabels(["ID", "OOD"])

ax.set_title("NLL on rmnist")
ax.legend()
plt.show()



In [None]:

fig, ax = plt.subplots()

# One ECE curve per condition, drawn against the index of the result entry.
for condition, results in zip(conditions_rfmnist, results_rfmnist):
    datasets = range(len(results))
    ece_values = [entry['ece'] for entry in results]
    ax.plot(datasets, ece_values, label=condition)

# Ticks follow the x-range of the last plotted condition.
ax.set_xticks(datasets)
# ax.set_xticklabels(["ID", "OOD"])

ax.set_title("ECE on rfmnist")
ax.legend()
plt.show()



In [None]:

fig, ax = plt.subplots()

# One NLL curve per condition, drawn against the index of the result entry.
for condition, results in zip(conditions_rfmnist, results_rfmnist):
    datasets = range(len(results))
    nll_values = [entry['nll'] for entry in results]
    ax.plot(datasets, nll_values, label=condition)

# Ticks follow the x-range of the last plotted condition.
ax.set_xticks(datasets)
# ax.set_xticklabels(["ID", "OOD"])

ax.set_title("NLL on rfmnist")
ax.legend()
plt.show()



In [None]:
fig, ax = plt.subplots()

# Accuracy and confidence per condition, drawn against the index of the result entry.
for condition, results in zip(conditions_cifar10c, results_cifar10c):
    datasets = range(len(results))
    acc = [m["acc"] for m in results]
    conf = [m["conf"] for m in results]
    ax.plot(datasets, acc, label=f"acc - {condition}")

    # Some conditions (e.g. MAP) have no 'mean_variance'. Test for the key explicitly
    # instead of catching every exception, which hid unrelated errors and could draw
    # the accuracy curve twice.
    if all("mean_variance" in m for m in results):
        # NOTE(review): 'mean_variance' is passed directly as yerr; if it is a variance
        # rather than a standard deviation the bars are not in the units of conf — confirm.
        mean_variances = [m["mean_variance"] for m in results]
        ax.errorbar(datasets, conf, yerr=mean_variances, label=f"conf - {condition}", fmt='-o')
    else:
        ax.plot(datasets, conf, label=f"conf - {condition}")

ax.set_xticks(datasets)
# ax.set_xticklabels(["ID", "OOD"])

plt.title("Accuracy/Confidence on cifar10c")
plt.legend()
plt.show()

In [None]:
fig, ax = plt.subplots()

# Accuracy and confidence per condition, drawn against the index of the result entry.
for condition, results in zip(conditions_rmnist, results_rmnist):
    datasets = range(len(results))
    acc = [m["acc"] for m in results]
    conf = [m["conf"] for m in results]
    ax.plot(datasets, acc, label=f"acc - {condition}")

    # Some conditions (e.g. MAP) have no 'mean_variance'. Test for the key explicitly
    # instead of catching every exception, which hid unrelated errors and could draw
    # the accuracy curve twice.
    if all("mean_variance" in m for m in results):
        # NOTE(review): 'mean_variance' is passed directly as yerr; if it is a variance
        # rather than a standard deviation the bars are not in the units of conf — confirm.
        mean_variances = [m["mean_variance"] for m in results]
        ax.errorbar(datasets, conf, yerr=mean_variances, label=f"conf - {condition}", fmt='-o')
    else:
        ax.plot(datasets, conf, label=f"conf - {condition}")

ax.set_xticks(datasets)
# ax.set_xticklabels(["ID", "OOD"])

plt.title("Accuracy/Confidence on rmnist")
plt.legend()
plt.show()

In [None]:
fig, ax = plt.subplots()

# Accuracy and confidence per condition, drawn against the index of the result entry.
for condition, results in zip(conditions_rfmnist, results_rfmnist):
    datasets = range(len(results))
    acc = [m["acc"] for m in results]
    conf = [m["conf"] for m in results]
    ax.plot(datasets, acc, label=f"acc - {condition}")

    # Some conditions (e.g. MAP) have no 'mean_variance'. Test for the key explicitly
    # instead of catching every exception, which hid unrelated errors and could draw
    # the accuracy curve twice.
    if all("mean_variance" in m for m in results):
        # NOTE(review): 'mean_variance' is passed directly as yerr; if it is a variance
        # rather than a standard deviation the bars are not in the units of conf — confirm.
        mean_variances = [m["mean_variance"] for m in results]
        ax.errorbar(datasets, conf, yerr=mean_variances, label=f"conf - {condition}", fmt='-o')
    else:
        ax.plot(datasets, conf, label=f"conf - {condition}")

ax.set_xticks(datasets)
# ax.set_xticklabels(["ID", "OOD"])

plt.title("Accuracy/Confidence on rfmnist")
plt.legend()
plt.show()