### Setup

In [None]:
%load_ext autoreload
%autoreload 2

import pickle
from functools import partial
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from uq.analysis.constants import (
    all_misspec_queries,
    model_name,
    no_misspec_query,
)
from uq.analysis.dataframes import (
    get_datasets_df,
    load_config,
    load_df,
    make_test_df_for_tuning,
)
from uq.analysis.plot_boxplots import plot_sorted_boxplot
from uq.analysis.plot_calib_all_datasets import (
    compute_barplot_order,
    plot_calib_all_datasets,
    plot_hist_test_statistics,
    hyp_test_hists,
)
from uq.analysis.plot_cd_diagram import draw_my_cd_diagram
from uq.analysis.plot_cohen_d import (
    build_cohen_d,
    plot_cohen_d_boxplot,
)
from uq.analysis.plot_reliability_diagrams import make_reliability_df, plot_reliability_diagrams
from uq.utils.general import filter_dict, savefig, set_notebook_options

set_notebook_options()
# Ignore numpy floating-point "invalid" errors (e.g. 0/0 -> NaN) instead of warning.
np.seterr(invalid='ignore')

# Root directory under which every figure and cached result of this notebook is written.
path = Path('results')
# File extension (format) of every saved figure.
ext = 'pdf'

### Standard settings in all figures

In [None]:
def standard_setting(
    df,
    interleaved=False,
    lambda_=None,
    misspec=False,
    toy=False,
    posthoc=True,
    add_posthoc_dataset=True,
):
    """Restrict ``df`` to the standard setting shared by all figures and add a ``name`` index level.

    Keeps rows with ``s`` missing or equal to 100 and ``n_quantiles`` missing or
    equal to 64, then applies the optional filters below. Every row gets a display
    ``name`` (computed by ``model_name``) appended to the index. Rows are stably
    sorted by base loss in the order nll, crps, expected_qs; any other value becomes
    missing and is placed last.

    Args:
        df: results frame; the queried fields may be index levels or columns.
        interleaved: if False, drop rows whose ``interleaved`` flag is true.
        lambda_: if given, keep only these ``lambda_`` values (plus rows with missing ``lambda_``).
        misspec: if False and ``misspecification`` is an index level, keep only the
            rows matching ``no_misspec_query``.
        toy: keep only the "toy" dataset group if True, exclude it otherwise.
        posthoc: if False, drop rows that have a post-hoc method.
        add_posthoc_dataset: forwarded to ``model_name``.

    Returns:
        The filtered frame, indexed by the original index levels plus ``name``.
    """
    df = df.query('s.isna() or s == 100')
    df = df.query('n_quantiles.isna() or n_quantiles == 64')

    if toy:
        df = df.query('dataset_group == "toy"')
    else:
        df = df.query('dataset_group != "toy"')
    if not interleaved:
        df = df.query('not interleaved or interleaved.isna()')
    if lambda_ is not None:
        # `@lambda_` refers to the parameter; the bare `lambda_` is the column/index level.
        df = df.query('lambda_ in @lambda_ or lambda_.isna()')
    if 'misspecification' in df.index.names and not misspec:
        df = df.query(no_misspec_query)
    if not posthoc:
        df = df.query('posthoc_method.isna()')
    index = df.index.names
    df = df.reset_index()
    model_name_partial = partial(model_name, add_posthoc_dataset=add_posthoc_dataset)
    df['name'] = df.apply(model_name_partial, axis='columns').astype('string')
    # Sort while the column is still categorical (category order, missing last), and assign
    # the result; only afterwards convert to string. Sorting the string column would be
    # alphabetical (crps < expected_qs < nll), and a bare `df.sort_values(...)` is a no-op.
    df['base_loss'] = pd.Categorical(df['base_loss'], ['nll', 'crps', 'expected_qs'])
    df = df.sort_values('base_loss', kind='stable')
    df['base_loss'] = df['base_loss'].astype('string')
    df = df.set_index(index + ['name'])
    return df

### Loading dataframes

In [None]:
# Experiment configuration of the runs logged under `logs/full` (its `log_dir` is used further below).
config = load_config('logs/full')

In [None]:
# Per-run results; `tuning=True` presumably includes the lambda-tuning runs — TODO confirm in uq.analysis.dataframes.
df = load_df(config, tuning=True)
test_df = make_test_df_for_tuning(df, config)
# `drop_prob` is used as a join key in every comparison below, so it must exist as an
# index level; it is added here with all values missing.
test_df['drop_prob'] = np.nan
test_df = test_df.set_index('drop_prob', append=True)


def op_without_index(df, op):
    """Apply ``op`` to ``df`` with its index temporarily turned into columns.

    The index levels become ordinary columns, ``op`` is called on the resulting
    frame, and the same levels are restored as the index of ``op``'s result.
    """
    saved_levels = df.index.names
    flat = df.reset_index()
    return op(flat).set_index(saved_levels)


def op(df):
    """Return ``df`` with missing ``lambda_`` values replaced by 0.

    Rows without a regularization weight (missing ``lambda_``) are treated as
    ``lambda_ == 0``, the value used for the baselines in the lambda selection.
    If there is no ``lambda_`` column, ``df`` itself is returned unchanged.
    The input frame is never modified in place.
    """
    if 'lambda_' in df.columns:
        # `assign` returns a new frame instead of mutating the caller's one.
        df = df.assign(lambda_=df['lambda_'].fillna(0))
    return df


# Missing `lambda_` (models without a regularization weight) becomes 0.
test_df = op_without_index(test_df, op)

### Selection of best $\lambda$

In [None]:
# Rows matching this query are the unregularized baselines of the lambda selection.
baseline_query = 'model == "no_regul"'
# A regularized model is matched with the baseline rows sharing all of these values.
join_by = [
    'dataset_group',
    'dataset',
    'run_id',
    'nb_hidden',
    'drop_prob',
    'base_loss',
    'pred_type',
    'mixture_size',
    'posthoc_method',
    'posthoc_dataset',
]


def duplicate_baseline_per_regul(df, query=None, keys=None):
    """Copy each baseline once per regularization method, relabelled with ``lambda_ = 0``.

    For every regularized configuration (each unique value of ``keys`` + ``model``
    among the rows NOT matching ``query``), the baseline rows with the same ``keys``
    are copied with that regularized ``model`` and ``lambda_ = 0``. The output
    contains the regularized rows and these copies; the original baseline rows are
    not included. This lets "no regularization" compete with the tuned lambdas.

    Args:
        df: results frame indexed by (at least) ``keys``, ``model`` and ``lambda_``.
        query: ``DataFrame.query`` string selecting the baseline rows. Defaults to
            the module-level ``baseline_query``.
        keys: index levels that must match between a regularized model and its
            baseline. Defaults to the module-level ``join_by``.

    Returns:
        DataFrame with the same index levels as ``df``.
    """
    # Explicit arguments avoid relying on module globals that other cells overwrite.
    if query is None:
        query = baseline_query
    keys = list(join_by if keys is None else keys)
    df_baseline = df.query(query)
    df_regul = df.query(f'not ({query})')
    # One row per regularized configuration, as a plain frame with columns keys + ['model'].
    groups = df_regul.groupby(keys + ['model'], dropna=False).size()
    groups = groups.index.to_frame().reset_index(drop=True)
    # Attach the matching baseline rows to every regularized configuration; `model`
    # comes from `groups`, so the copies carry the regularized model's name.
    df_baseline_per_regul = groups.merge(df_baseline, how='left', on=keys)
    df_baseline_per_regul['lambda_'] = 0
    index = df_regul.index.names
    concat = pd.concat([df_regul.reset_index(), df_baseline_per_regul])
    return concat.set_index(index)


# Lambda grid of the tuning sweep; rows with a missing `lambda_` are kept as well.
tuned_lambda_values = [0, 0.01, 0.05, 0.2, 1, 5]
test_df_dup = duplicate_baseline_per_regul(test_df.query('lambda_.isna() or lambda_ in @tuned_lambda_values'))
# Number of rows per lambda value (shown as the cell output).
test_df_dup.reset_index().lambda_.value_counts()

In [None]:
# Index levels moved to columns in `select_best_lambda` before matching with the baseline.
columns_to_keep = ['lambda_']
# A lambda is admissible only if its mean validation WIS is at most 10% above the baseline's.
accepted_relative_wis_loss = 0.1


def make_test_df_mean(test_df):
    """Average every column of ``test_df`` over runs.

    The ``run_id`` index level is dropped and rows sharing all remaining index
    values are averaged; missing index values form their own groups.
    """
    test_df = test_df.reset_index(level='run_id', drop=True)
    return test_df.groupby(test_df.index.names, dropna=False).mean()


def select_best_lambda(test_df_dup, keys=None, keep=None, max_wis_increase=None):
    """Select, per regularized configuration, the lambda with the best validation calibration.

    Validation metrics are first averaged over runs. A lambda is admissible when
    its mean ``val_wis`` is at most ``(1 + max_wis_increase)`` times the mean
    ``val_wis`` of the same configuration at ``lambda_ == 0``. Among admissible
    lambdas, the one with the smallest mean ``val_calib_l1`` is selected.

    Args:
        test_df_dup: per-run results indexed by (at least) ``keys``, ``model`` and
            ``lambda_``, with columns ``val_wis`` and ``val_calib_l1`` and a
            ``lambda_ == 0`` entry per configuration (as produced by
            ``duplicate_baseline_per_regul``).
        keys: index levels identifying a configuration; ``run_id`` is ignored if
            present. Defaults to the module-level ``join_by``.
        keep: index levels moved to columns before matching with the baseline.
            Defaults to the module-level ``columns_to_keep``.
        max_wis_increase: accepted relative WIS degradation. Defaults to the
            module-level ``accepted_relative_wis_loss``.

    Returns:
        All rows (every run) of ``test_df_dup`` for the selected lambdas, with the
        same index levels as ``test_df_dup``.
    """
    # Explicit arguments avoid relying on module globals that other cells overwrite.
    if keys is None:
        keys = join_by
    if keep is None:
        keep = columns_to_keep
    if max_wis_increase is None:
        max_wis_increase = accepted_relative_wis_loss

    # For more stable results, compare the *mean* WIS and calib_l1 per model. Only these
    # two columns are averaged (non-numeric columns would make `.mean()` fail).
    df = make_test_df_mean(test_df_dup[['val_wis', 'val_calib_l1']])
    # Compare each model (including the lambda_ == 0 copies) to its own lambda_ == 0 entry.
    baseline = df.query('lambda_ == 0')[['val_wis']]
    df = df.reset_index(level=keep)
    # Configuration keys once runs are averaged out (run_id removed, model added).
    join_by_for_mean = [key for key in list(keys) + ['model'] if key != 'run_id']
    df = df.merge(baseline, how='left', on=join_by_for_mean, suffixes=(None, '_baseline'), validate='many_to_one')
    # Constraint on WIS relative to the lambda_ == 0 entry.
    mask = df['val_wis'] <= df['val_wis_baseline'] * (1 + max_wis_increase)
    df = df[mask]
    df = df.set_index('lambda_', append=True)
    # Lambda with minimum calibration error per configuration (`model` is already a key).
    idxmin = df.groupby(join_by_for_mean, dropna=False)['val_calib_l1'].idxmin()
    df = df.loc[idxmin]
    selected_index = df.index.to_frame().reset_index(drop=True)
    index_names = test_df_dup.index.names
    # Recover every run (and every column) of the selected lambdas.
    df = selected_index.merge(test_df_dup.reset_index(), on=join_by_for_mean + ['lambda_'])
    df = df.set_index(index_names)
    return df


test_df_best = select_best_lambda(test_df_dup)
# From here on `test_df` holds the baselines plus, for every regularized configuration,
# only the runs of its selected lambda (a selected lambda_ = 0 is a relabelled baseline copy).
test_df = pd.concat([test_df_best, test_df.query(baseline_query)])
# Selected-lambda counts and number of selected rows (shown as the cell output).
test_df_best.reset_index().lambda_.value_counts(), len(test_df_best)

### Cohen's d figures

In [None]:
# Metrics of the main Cohen's d figures and of the critical-difference diagrams.
metrics = ['test_calib_l1', 'test_wis', 'test_nll', 'test_stddev']
# Additional metrics, shown in a separate Cohen's d figure.
more_metrics = ['test_calib_l2', 'test_rmse', 'test_mae']


def save_cohen_d(*args, path=None, **kwargs):
    """Draw a Cohen's d boxplot and save it as ``cohen_d_boxplot.<ext>`` inside ``path``.

    Positional and remaining keyword arguments go unchanged to
    ``plot_cohen_d_boxplot``. ``path`` must be provided: ``Path(None)`` raises
    ``TypeError``.
    """
    out_dir = Path(path)
    figure = plot_cohen_d_boxplot(*args, **kwargs)
    savefig(out_dir / f'cohen_d_boxplot.{ext}', figure)


def metric_queries(metrics):
    """Map each metric name to a ``DataFrame.query`` string selecting the rows of that metric."""
    queries = {}
    for name in metrics:
        queries[name] = f'metric == "{name}"'
    return queries

In [None]:
def plot_all_cohen_d(test_df, default_cohen_d, path, cd_diagrams=True, fig_kwargs=None, **kwargs):
    """Save the Cohen's d boxplots (and optionally CD diagrams) of one comparison.

    Writes under ``path``:
    - ``main_metrics/cohen_d_boxplot.<ext>`` for the module-level ``metrics``;
    - ``cd_diagrams/<metric>.<ext>`` for each of ``metrics`` when ``cd_diagrams`` is true
      (skipped, with a message, when the Friedman test fails for that metric);
    - ``more_metrics/cohen_d_boxplot.<ext>`` for the module-level ``more_metrics``.

    Args:
        test_df: results to plot; first restricted with ``standard_setting``.
        default_cohen_d: callable ``(df, metrics=...)`` building the Cohen's d frame,
            e.g. a ``partial`` of ``build_cohen_d``.
        path: output directory (a ``Path``).
        cd_diagrams: whether to draw the critical-difference diagrams.
        fig_kwargs: extra keyword arguments for the boxplot (default: none).
        **kwargs: forwarded to ``standard_setting``.
    """
    # Avoid a shared mutable default argument.
    fig_kwargs = {} if fig_kwargs is None else fig_kwargs
    # Only mention the post-hoc dataset in model names when several are present.
    add_posthoc_dataset = test_df.reset_index()['posthoc_dataset'].nunique() > 1
    plot_df = standard_setting(test_df, add_posthoc_dataset=add_posthoc_dataset, **kwargs)

    # All base losses
    df_cohen = default_cohen_d(plot_df)
    save_cohen_d(df_cohen, metric_queries(metrics), legend=False, path=path / 'main_metrics', **fig_kwargs)
    if cd_diagrams:
        for metric in metrics:
            try:
                draw_my_cd_diagram(plot_df, metric)
            except ValueError:
                # Skip saving: `savefig` would otherwise presumably write whatever figure is
                # current (a partial or previous one) under this metric's name.
                print(f'The friedman test failed for {metric}')
                continue
            savefig(path / 'cd_diagrams' / f'{metric}.{ext}', dpi=300)

    # All base losses with other metrics
    df_cohen = default_cohen_d(plot_df, metrics=more_metrics)
    save_cohen_d(df_cohen, metric_queries(more_metrics), legend=False, path=path / 'more_metrics', **fig_kwargs)

### Comparison of some post-hoc and regularization vs baseline

In [None]:
# The baseline is precisely no regularization, no post-hoc, mixture prediction and nll loss
baseline_query = 'model == "no_regul" and posthoc_method.isna() and pred_type == "mixture" and base_loss == "nll"'
# Columns that must match between a compared model and the baseline; the compared model
# may differ in base loss, prediction type, regularization and post-hoc method.
join_by = [
    'dataset_group',
    'dataset',
    'run_id',
    'metric',
    'misspecification',
    'nb_hidden',
    'drop_prob',
]
# Columns of the compared models kept in the Cohen's d result.
columns_to_keep = [
    'base_loss',
    'pred_type',
    'mixture_size',
    'model',
    'posthoc_method',
    'posthoc_dataset',
    'name',
]

# Compared models: the baseline, unregularized models with one of the listed post-hoc
# methods (or none), and regularized models without post-hoc method.
cases = [
    baseline_query,
    'model == "no_regul" and (posthoc_method in ["CQR", "rec-kde", "rec-lin", "rec-emp"] or posthoc_method.isna())',
    'model in ["cdf_based", "entropy_based", "truncated"] and posthoc_method.isna()',
]
query = ' or '.join(f'({case})' for case in cases)
new_df = test_df.query(query)
# Keep rows without a post-hoc dataset or whose post-hoc dataset is "calib".
new_df = new_df.query('posthoc_dataset.isna() or posthoc_dataset == "calib"')

# One output directory per metric, with the Cohen's d computed for that metric only.
# NOTE(review): `plot_all_cohen_d` builds its boxplot queries and CD diagrams from the
# module-level `metrics`, so every iteration redraws the CD diagrams of all metrics — confirm intended.
for metric in metrics:
    default_cohen_d = partial(
        build_cohen_d,
        metrics=[metric],
        baseline_query=baseline_query,
        join_by=join_by,
        columns_to_keep=columns_to_keep,
    )

    plot_all_cohen_d(
        new_df,
        default_cohen_d,
        path=path / 'posthoc_and_regul_vs_baseline' / metric,
        fig_kwargs={'color_map_name': 'posthoc_or_regul'},
        cd_diagrams=True,
    )

In [None]:
# The baseline is precisely no regularization, no post-hoc, mixture prediction and nll loss
baseline_query = 'model == "no_regul" and posthoc_method.isna() and pred_type == "mixture" and base_loss == "nll"'
# Columns that must match between a compared model and the baseline.
join_by = [
    'dataset_group',
    'dataset',
    'run_id',
    'metric',
    'misspecification',
    'nb_hidden',
    'drop_prob',
]
# Columns of the compared models kept in the Cohen's d result.
columns_to_keep = [
    'base_loss',
    'pred_type',
    'mixture_size',
    'model',
    'posthoc_method',
    'posthoc_dataset',
    'name',
]

# Cohen's d of all main metrics at once (same selection as the per-metric cell above).
default_cohen_d = partial(
    build_cohen_d,
    metrics=metrics,
    baseline_query=baseline_query,
    join_by=join_by,
    columns_to_keep=columns_to_keep,
)

# Compared models: the baseline, unregularized models with one of the listed post-hoc
# methods (or none), and regularized models without post-hoc method.
cases = [
    baseline_query,
    'model == "no_regul" and (posthoc_method in ["CQR", "rec-kde", "rec-lin", "rec-emp"] or posthoc_method.isna())',
    'model in ["cdf_based", "entropy_based", "truncated"] and posthoc_method.isna()',
]
query = ' or '.join(f'({case})' for case in cases)
new_df = test_df.query(query)
# Keep rows without a post-hoc dataset or whose post-hoc dataset is "calib".
new_df = new_df.query('posthoc_dataset.isna() or posthoc_dataset == "calib"')

plot_all_cohen_d(
    new_df,
    default_cohen_d,
    path=path / 'posthoc_and_regul_vs_baseline',
    fig_kwargs={'color_map_name': 'posthoc_or_regul'},
)

### Comparison between post-hoc on training and calibration dataset

In [None]:
# The baseline is precisely no regularization, no post-hoc, mixture prediction and nll loss
baseline_query = 'model == "no_regul" and posthoc_method.isna() and pred_type == "mixture" and base_loss == "nll"'
# Columns that must match between a compared model and the baseline.
join_by = [
    'dataset_group',
    'dataset',
    'run_id',
    'metric',
    'misspecification',
    'nb_hidden',
    'drop_prob',
]
# Columns of the compared models kept in the Cohen's d result (incl. `posthoc_dataset`,
# which distinguishes the two post-hoc datasets compared here).
columns_to_keep = [
    'base_loss',
    'pred_type',
    'mixture_size',
    'model',
    'posthoc_method',
    'posthoc_dataset',
    'name',
]

default_cohen_d = partial(
    build_cohen_d,
    metrics=metrics,
    baseline_query=baseline_query,
    join_by=join_by,
    columns_to_keep=columns_to_keep,
)

# Compared models: the baseline, unregularized models with CQR / rec-kde / rec-lin
# (on every post-hoc dataset), and the cdf_based regularization without post-hoc method.
cases = [
    baseline_query,
    'model == "no_regul" and posthoc_method in ["CQR", "rec-kde", "rec-lin"]',
    'model == "cdf_based" and posthoc_method.isna()',
]
query = ' or '.join(f'({case})' for case in cases)
new_df = test_df.query(query)
# Restrict to the "nll" and "expected_qs" base losses.
new_df = new_df.query('base_loss in ["nll", "expected_qs"]')

# Alternative selection, kept commented out:
# new_df = test_df.query('posthoc_method.isna() or posthoc_method in ["CQR", "rec-kde", "rec-lin"]')
# new_df = new_df.query('model == "no_regul" or (model == "cdf_based" and posthoc_method.isna())')
# new_df = new_df.query('base_loss in ["nll", "expected_qs"]')
plot_all_cohen_d(
    new_df,
    default_cohen_d,
    path=path / 'posthoc_dataset_train_vs_calib',
    fig_kwargs={'color_map_name': 'posthoc_dataset'},
)

### Comparison of regularization or posthoc vs vanilla per base loss

In [None]:
# Baseline: the unregularized model without post-hoc method, with the same base loss,
# prediction type and mixture size as the compared model (these are part of `join_by`).
baseline_query = 'model == "no_regul" and posthoc_method.isna()'
join_by = [
    'dataset_group',
    'dataset',
    'run_id',
    'metric',
    'misspecification',
    'nb_hidden',
    'drop_prob',
    'base_loss',
    'pred_type',
    'mixture_size',
]
# Columns of the compared models kept in the Cohen's d result.
columns_to_keep = ['model', 'posthoc_method', 'posthoc_dataset', 'name']

default_cohen_d = partial(
    build_cohen_d,
    metrics=metrics,
    baseline_query=baseline_query,
    join_by=join_by,
    columns_to_keep=columns_to_keep,
)

# One set of figures per base loss.
for base_loss, base_loss_df in test_df.groupby('base_loss'):
    print(base_loss, flush=True)
    # Either post-hoc only (unregularized) or regularization only (no post-hoc);
    # regularized models with a post-hoc method are excluded.
    new_df = base_loss_df.query('model == "no_regul" or posthoc_method.isna()')
    new_df = new_df.query('posthoc_dataset == "calib" or posthoc_dataset.isna()')
    plot_all_cohen_d(
        new_df,
        default_cohen_d,
        path=path / 'posthoc_or_regul_vs_vanilla' / base_loss,
        fig_kwargs={'color_map_name': 'posthoc_or_regul'},
    )

### Comparison of ALL post-hoc and regularization vs baseline

In [None]:
# The baseline is precisely no regularization, no post-hoc, mixture prediction and nll loss
baseline_query = 'model == "no_regul" and posthoc_method.isna() and pred_type == "mixture" and base_loss == "nll"'
# Columns that must match between a compared model and the baseline.
join_by = [
    'dataset_group',
    'dataset',
    'run_id',
    'metric',
    'misspecification',
    'nb_hidden',
    'drop_prob',
]
# Columns of the compared models kept in the Cohen's d result.
columns_to_keep = [
    'base_loss',
    'pred_type',
    'mixture_size',
    'model',
    'posthoc_method',
    'posthoc_dataset',
    'name',
]

default_cohen_d = partial(
    build_cohen_d,
    metrics=metrics,
    baseline_query=baseline_query,
    join_by=join_by,
    columns_to_keep=columns_to_keep,
)

# No preselection: every model of `test_df` (all post-hoc methods and datasets) is compared.
plot_all_cohen_d(
    test_df,
    default_cohen_d,
    path=path / 'all_posthoc_and_regul_vs_baseline',
    fig_kwargs={'figsize': (20, 10)},
)

### Comparison of regularization methods vs no regularization

In [None]:
# `baseline_query` selects the baselines. The models that are not selected are the compared models.
baseline_query = 'model == "no_regul"'
# `join_by` represents all columns that should be the same when comparing a model and its baseline.
join_by = [
    'dataset_group',
    'dataset',
    'run_id',
    'metric',
    'misspecification',
    'nb_hidden',
    'drop_prob',
    'base_loss',
    'pred_type',
    'mixture_size',
]
# `columns_to_keep` represents the columns of the compared model that should be kept in the final result.
# Note that these columns do not have to be the same between a compared model and its baseline.
columns_to_keep = ['model', 'name']

default_cohen_d = partial(
    build_cohen_d,
    metrics=metrics,
    baseline_query=baseline_query,
    join_by=join_by,
    columns_to_keep=columns_to_keep,
)

# Only models without post-hoc method, so the comparison isolates the regularization.
plot_all_cohen_d(test_df.query('posthoc_method.isna()'), default_cohen_d, path=path / 'regul_vs_no_regul')

### Comparison of post-hoc methods vs no post-hoc

In [None]:
# Baseline: the same model without post-hoc method.
baseline_query = 'posthoc_method.isna()'
# Everything except the post-hoc method must match, including the model itself.
join_by = [
    'dataset_group',
    'dataset',
    'run_id',
    'metric',
    'misspecification',
    'nb_hidden',
    'drop_prob',
    'base_loss',
    'pred_type',
    'mixture_size',
    'model',
]
# Columns of the compared (post-hoc) models kept in the Cohen's d result.
columns_to_keep = ['posthoc_method', 'posthoc_dataset', 'name']

default_cohen_d = partial(
    build_cohen_d,
    metrics=metrics,
    baseline_query=baseline_query,
    join_by=join_by,
    columns_to_keep=columns_to_keep,
)

# Only unregularized models, so the comparison isolates the post-hoc method.
plot_all_cohen_d(test_df.query('model == "no_regul"'), default_cohen_d, path=path / 'posthoc_vs_no_posthoc')

### Plot all methods (vanilla, post-hoc and regularization)

In [None]:
def plot_all_methods(test_df, path):
    """Save a sorted boxplot and CD diagrams comparing vanilla, post-hoc and regularized models.

    Kept models: unregularized models without post-hoc method or with CQR / rec-kde,
    and ``cdf_based`` models without post-hoc method. The columns ``test_calib_l1``,
    ``test_wis`` and ``test_nll`` are displayed as PCE, CRPS and NLL. Writes
    ``boxplot.<ext>`` and ``cd_diagrams/<metric>.<ext>`` under ``path``.
    """
    display_names = {'test_calib_l1': 'PCE', 'test_wis': 'CRPS', 'test_nll': 'NLL'}

    selected = standard_setting(test_df, posthoc=True, add_posthoc_dataset=True)
    selected = selected.query('posthoc_method.isna() or posthoc_method in ["CQR", "rec-kde"]')
    selected = selected.query('model == "no_regul" or (model == "cdf_based" and posthoc_method.isna())')
    shown = list(display_names.values())
    selected = selected.rename(columns=display_names)

    plot_sorted_boxplot(selected, metrics=shown)
    savefig(path / f'boxplot.{ext}')
    for metric_label in shown:
        draw_my_cd_diagram(selected, metric_label)
        savefig(path / 'cd_diagrams' / f'{metric_label}.{ext}', dpi=300)


# Figures comparing vanilla, post-hoc and regularized models, written to `results/all`.
plot_all_methods(test_df, path=path / 'all')

### Reliability diagrams

In [None]:
# Vanilla models: no post-hoc method and no regularization.
plot_df = test_df.query('posthoc_method.isna() and model == "no_regul"')
plot_df = standard_setting(plot_df)
rel_df = make_reliability_df(plot_df)
fig = plot_reliability_diagrams(rel_df, ncols=7, agg_run=True)
savefig(path / 'rel_diags' / f'vanilla.{ext}')

# All models of the standard setting.
# NOTE(review): `standard_setting` keeps post-hoc models by default (posthoc=True), so this
# "regul" figure also contains post-hoc methods — confirm intended.
plot_df = standard_setting(test_df)
rel_df = make_reliability_df(plot_df)
fig = plot_reliability_diagrams(rel_df, ncols=7, agg_run=True)
savefig(path / 'rel_diags' / f'regul.{ext}')

# Unregularized models with rec-emp or CQR whose post-hoc dataset is "calib".
plot_df = test_df.query('posthoc_method in ["rec-emp", "CQR"] and posthoc_dataset == "calib" and model == "no_regul"')
plot_df = standard_setting(plot_df, add_posthoc_dataset=False)
rel_df = make_reliability_df(plot_df)
fig = plot_reliability_diagrams(rel_df, ncols=7, agg_run=True)
savefig(path / 'rel_diags' / f'posthoc.{ext}')

### Calibration bar plots

In [None]:
# Models without a post-hoc dataset or whose post-hoc dataset is "calib"; `plot_df` is reused
# by the calibration bar plots and the test-statistic histograms below.
plot_df = standard_setting(test_df, add_posthoc_dataset=False).query(
    'posthoc_dataset.isna() or posthoc_dataset == "calib"'
)
# Display order of the bar plots (presumably the datasets, ranked by the MIX-NLL model's
# test_calib_l1), also saved next to the logs.
order = compute_barplot_order(plot_df, 'name == "MIX-NLL"', 'test_calib_l1')
order.to_pickle(Path(config.log_dir) / 'order.pickle')

In [None]:
test_statistics_path = path / 'test_statistics.pkl'
# Sampling the hypothesis-test statistics (10000 test samples) is only done when no cached
# file exists; delete `results/test_statistics.pkl` to force a recomputation.
# The cache is produced locally by this notebook, so unpickling it is trusted input.
if not test_statistics_path.exists():
    test_statistics = hyp_test_hists(plot_df, config, nb_test_samples=10000)
    with test_statistics_path.open('wb') as cache_file:
        pickle.dump(test_statistics, cache_file)
else:
    with test_statistics_path.open('rb') as cache_file:
        test_statistics = pickle.load(cache_file)

In [None]:
# Set to True to order the datasets by their total number of instances instead of `order`.
sort_by_nb_instances = False
if sort_by_nb_instances:
    order = get_datasets_df(config, reload=True).reset_index().sort_values('Total instances').Dataset
    datasets = plot_df.reset_index().dataset.unique()
    order = order[order.isin(datasets)]

# (model names, output file) of each figure written to `pce_and_rel_diags`.
calib_figures = [
    (['MIX-NLL'], f'vanilla_nll.{ext}'),
    (['MIX-CRPS'], f'vanilla_crps.{ext}'),
    (['SQR-CRPS'], f'vanilla_wis.{ext}'),
    (['MIX-NLL + Rec-EMP'], f'posthoc_nll.{ext}'),
]
for model_names, file_name in calib_figures:
    plot_calib_all_datasets(
        plot_df,
        config,
        order,
        names=model_names,
        test_statistics=test_statistics,
        path=path / 'pce_and_rel_diags' / file_name,
    )

### Histogram of test statistics distribution

In [None]:
# Use scientific notation on tick labels only when the axis range is below 1e-3 or above 1e4.
with mpl.rc_context({'axes.formatter.limits': (-3, 4)}):
    plot_hist_test_statistics(
        plot_df, config, order, test_statistics, path=path / 'pce_and_rel_diags' / f'hist_test_statistic.{ext}'
    )

In [None]:
### Comparison of posthoc + regularization vs posthoc
# Baseline: the unregularized model with the *same* post-hoc method and post-hoc dataset
# (both are part of `join_by`), so each comparison isolates the effect of the regularization.
baseline_query = 'model == "no_regul"'
join_by = [
    'dataset_group',
    'dataset',
    'run_id',
    'metric',
    'nb_hidden',
    'drop_prob',
    'base_loss',
    'pred_type',
    'mixture_size',
    'posthoc_method',
    'posthoc_dataset',
]
# Columns of the compared (regularized) models kept in the Cohen's d result.
columns_to_keep = [
    'model',
    'name',
]

default_cohen_d = partial(
    build_cohen_d,
    metrics=metrics,
    baseline_query=baseline_query,
    join_by=join_by,
    columns_to_keep=columns_to_keep,
)

# Each pair selects one post-hoc method together with the unregularized model and one
# regularization method.
query_pairs = [
    'posthoc_method == "rec-kde" and model in ["no_regul", "cdf_based"]',
    'posthoc_method == "rec-kde" and model in ["no_regul", "entropy_based"]',
    'posthoc_method == "rec-lin" and model in ["no_regul", "cdf_based"]',
    'posthoc_method == "CQR" and model in ["no_regul", "truncated"]',
]
query = ' or '.join(f'({pair})' for pair in query_pairs)

# Only rows whose post-hoc dataset is "calib".
plot_all_cohen_d(
    test_df.query(query).query('posthoc_dataset == "calib"'),
    default_cohen_d,
    path=path / 'posthoc_and_regul_vs_posthoc',
)

### Comparison of posthoc + regularization vs posthoc (recomputed after reloading the data and re-selecting $\lambda$)

In [None]:
# Reload the configuration (same log directory as at the top) so this section can run on its own.
config = load_config('logs/full')

In [None]:
# Same pipeline as at the top of the notebook: load the results, add the `drop_prob`
# index level, fill missing `lambda_` with 0, then select the best lambda.
df = load_df(config, tuning=True)
test_df = make_test_df_for_tuning(df, config)
test_df['drop_prob'] = np.nan
test_df = test_df.set_index('drop_prob', append=True)
test_df = op_without_index(test_df, op)

# `duplicate_baseline_per_regul` and `select_best_lambda` read the module-level
# `baseline_query`, `join_by`, `columns_to_keep` and `accepted_relative_wis_loss`, which the
# cells above overwrote, so they are reset to their lambda-selection values here.
baseline_query = 'model == "no_regul"'
join_by = [
    'dataset_group',
    'dataset',
    'run_id',
    'nb_hidden',
    'drop_prob',
    'base_loss',
    'pred_type',
    'mixture_size',
    'posthoc_method',
    'posthoc_dataset',
]
tuned_lambda_values = [0, 0.01, 0.05, 0.2, 1, 5]
test_df_dup = duplicate_baseline_per_regul(test_df.query('lambda_.isna() or lambda_ in @tuned_lambda_values'))

columns_to_keep = ['lambda_']
accepted_relative_wis_loss = 0.1
test_df_best = select_best_lambda(test_df_dup)
test_df = pd.concat([test_df_best, test_df.query(baseline_query)])

In [None]:
# Baseline: the unregularized model with the *same* post-hoc method and post-hoc dataset
# (both are part of `join_by`), so each comparison isolates the effect of the regularization.
baseline_query = 'model == "no_regul"'
join_by = [
    'dataset_group',
    'dataset',
    'run_id',
    'metric',
    'nb_hidden',
    'drop_prob',
    'base_loss',
    'pred_type',
    'mixture_size',
    'posthoc_method',
    'posthoc_dataset',
]
# Columns of the compared (regularized) models kept in the Cohen's d result.
columns_to_keep = [
    'model',
    'name',
]

default_cohen_d = partial(
    build_cohen_d,
    metrics=metrics,
    baseline_query=baseline_query,
    join_by=join_by,
    columns_to_keep=columns_to_keep,
)

# Each pair selects one post-hoc method together with the unregularized model and one
# regularization method.
query_pairs = [
    'posthoc_method == "rec-kde" and model in ["no_regul", "cdf_based"]',
    'posthoc_method == "rec-kde" and model in ["no_regul", "entropy_based"]',
    'posthoc_method == "rec-lin" and model in ["no_regul", "cdf_based"]',
    'posthoc_method == "CQR" and model in ["no_regul", "truncated"]',
]
query = ' or '.join(f'({pair})' for pair in query_pairs)

# Only rows whose post-hoc dataset is "calib".
plot_all_cohen_d(
    test_df.query(query).query('posthoc_dataset == "calib"'),
    default_cohen_d,
    path=path / 'posthoc_and_regul_vs_posthoc',
)