### Setup

In [None]:
# IPython magics: enable the autoreload extension in mode 2 so that edited
# project modules are re-imported before each cell runs.
%load_ext autoreload
%autoreload 2

from functools import partial
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from uq.analysis.dataframes import (
	load_config, load_df, make_test_df_for_tuning, fillna
)
from uq.analysis.plot_cohen_d import build_cohen_d, plot_cohen_d_boxplot
from uq.utils.general import filter_dict, set_notebook_options, savefig
from uq.analysis.plot_cd_diagram import draw_my_cd_diagram

# Apply the project's notebook-wide settings (helper from `uq.utils.general`).
set_notebook_options()

# Root directory under which the figures of this notebook are saved.
path = Path('results')
# File extension used by `plot_all_cohen_d` when building the figure filename.
ext = 'pdf'

### Loading the dataframe

In [None]:
# Configuration of the hyperparameter-tuning runs, read from their log directory.
config = load_config('logs/hparam_tuning/')

In [None]:
# Load the results of the tuning runs described by `config`.
df = load_df(config, tuning=True)
# Derive the dataframe used by every comparison below.
# NOTE(review): the exact schema is defined by `make_test_df_for_tuning`; the
# cells below rely on columns/index levels such as `model`, `base_loss`,
# `nb_hidden`, `mixture_size`, `n_quantiles`, `lambda_` and `s`.
test_df = make_test_df_for_tuning(df, config)

def op_without_index(df, op):
    """Run ``op`` on ``df`` with its index levels exposed as ordinary columns.

    The index is flattened with ``reset_index`` (which returns a new frame, so
    the caller's ``df`` is left untouched), ``op`` is applied to that flat
    frame, and the original index levels are then restored on the result.

    Args:
        df: Dataframe whose index levels are all named.
        op: Callable taking the flattened dataframe and returning a dataframe
            that still contains the original index columns.

    Returns:
        The dataframe returned by ``op``, re-indexed by the original levels.
    """
    index_names = df.index.names
    flat = df.reset_index()
    transformed = op(flat)
    return transformed.set_index(index_names)

def op(df):
    """Replace missing ``lambda_`` values with 0, modifying ``df`` in place.

    Frames without a ``lambda_`` column are returned unchanged.

    Returns:
        The same dataframe object that was passed in.
    """
    if 'lambda_' not in df.columns:
        return df
    filled = df['lambda_'].fillna(0)
    df['lambda_'] = filled
    return df

# Show `test_df` with missing `lambda_` values replaced by 0.
# NOTE(review): the result is not assigned, and `op_without_index` works on a
# new frame returned by `reset_index`, so `test_df` itself keeps its NaN
# `lambda_` values. Confirm whether `test_df = op_without_index(...)` was meant.
op_without_index(test_df, op)

### Helper functions

In [None]:
def hparam_str(hparam, value):
    """Return a human-readable label for one hyperparameter setting.

    Args:
        hparam: Name of the hyperparameter (e.g. ``'nb_hidden'``).
        value: Value of the hyperparameter; it is formatted with ``str``.

    Returns:
        A dedicated label for ``mixture_size``, ``nb_hidden``, ``n_quantiles``
        and ``lambda_``; a generic ``'<hparam>=<value>'`` label for any other
        hyperparameter.
    """
    if hparam == 'mixture_size':
        return f'{value} components'
    if hparam == 'nb_hidden':
        return f'{value} hidden layers'
    if hparam == 'n_quantiles':
        return f'{value} quantiles'
    if hparam == 'lambda_':
        # Raw f-string so that `\l` is kept verbatim for matplotlib's mathtext.
        return rf'$\lambda={value}$'
    # Previously this fell through and returned None, which made the
    # `', '.join(...)` in `model_name` fail with an opaque TypeError.
    return f'{hparam}={value}'


def model_name(d, hparams):
    """Build a display name by joining the labels of the given hyperparameters.

    Args:
        d: Mapping-like object (e.g. a dataframe row) indexed by hyperparameter name.
        hparams: Iterable of hyperparameter names to include, in order.

    Returns:
        The labels produced by ``hparam_str``, separated by ``', '``.
    """
    labels = (hparam_str(key, d[key]) for key in hparams)
    return ', '.join(labels)


def standard_setting(df, hparams=None):
    """Add a ``name`` index level and order the rows by base loss.

    The index of ``df`` is flattened, a display name is built for every row
    from ``hparams`` (via ``model_name``), the rows are stably sorted so that
    base losses appear in the order ``nll``, ``crps``, ``expected_qs``, and the
    original index levels plus ``name`` are set as the new index.

    Args:
        df: Dataframe with named index levels that has a ``base_loss`` column
            or index level, plus one column or index level per entry of
            ``hparams``.
        hparams: Names of the hyperparameters used to build the display name.
            ``None`` is treated as no hyperparameters (empty name).

    Returns:
        A new dataframe indexed by the original levels followed by ``name``.
        ``base_loss`` is a string column; values outside the three known
        losses become missing and are sorted last.
    """
    hparams = [] if hparams is None else hparams
    index = list(df.index.names)
    df = df.reset_index()
    model_name_partial = partial(model_name, hparams=hparams)
    df['name'] = df.apply(model_name_partial, axis='columns').astype('string')
    # Sort while `base_loss` is still categorical so the explicit category
    # order is used; after the cast to string the order would be alphabetical.
    df['base_loss'] = pd.Categorical(df['base_loss'], ['nll', 'crps', 'expected_qs'])
    # `sort_values` returns a new frame: the result must be re-assigned
    # (the original call discarded it, so no sorting ever took place).
    df = df.sort_values('base_loss', kind='stable')
    df['base_loss'] = df['base_loss'].astype('string')
    return df.set_index(index + ['name'])

In [None]:
# Metrics compared in every figure below; also read as a global by `plot_all_cohen_d`.
metrics = ['test_calib_l1', 'test_wis', 'test_nll', 'test_stddev']

def metric_queries(metrics):
    """Map each metric name to the query string selecting that metric.

    Args:
        metrics: Iterable of metric names.

    Returns:
        A dict ``{metric: 'metric == "<metric>"'}`` in the order of ``metrics``.
    """
    queries = {}
    for name in metrics:
        queries[name] = f'metric == "{name}"'
    return queries

def plot_all_cohen_d(plot_df, default_cohen_d, path, fig_kwargs=None):
    """Compute Cohen's d for ``plot_df`` and save the resulting boxplot figure.

    Relies on the notebook globals ``metrics`` (metrics to plot) and ``ext``
    (file extension of the saved figure).

    Args:
        plot_df: Dataframe with the models to compare.
        default_cohen_d: Callable taking ``plot_df`` and returning the Cohen's d
            dataframe (in this notebook, a ``partial`` of ``build_cohen_d``).
        path: ``pathlib.Path`` of the directory in which
            ``cohen_d_boxplot.<ext>`` is written.
        fig_kwargs: Optional extra keyword arguments forwarded to
            ``plot_cohen_d_boxplot``.
    """
    # `None` sentinel instead of a mutable `{}` default shared across calls.
    fig_kwargs = {} if fig_kwargs is None else fig_kwargs
    df_cohen = default_cohen_d(plot_df)
    fig = plot_cohen_d_boxplot(df_cohen, metric_queries(metrics), legend=False, **fig_kwargs)
    savefig(path / f'cohen_d_boxplot.{ext}', fig)

### Comparison with different numbers of hidden layers

In [None]:
# `baseline_query` selects the baselines. The models that are not selected are the compared models.
baseline_query = 'nb_hidden == 3'
# `join_by` represents all columns that should be the same when comparing a model and its baseline.
join_by = [
    'dataset_group', 'dataset', 'run_id', 'metric',
    'base_loss', 'pred_type', 'mixture_size'
]
# `columns_to_keep` represents the columns of the compared model that should be kept in the final result.
# Note that these columns do not have to be the same between a compared model and its baseline.
columns_to_keep = ['name', 'nb_hidden']

# Pre-bind the comparison settings so that only the dataframe remains to be passed.
default_cohen_d = partial(build_cohen_d, 
    metrics=metrics, baseline_query=baseline_query, join_by=join_by, columns_to_keep=columns_to_keep
)

# Keep only the `no_regul` model trained with the `nll` base loss and 3 mixture
# components, so that the number of hidden layers is the only varying hyperparameter.
plot_df = standard_setting(test_df, hparams=['nb_hidden']).query('model == "no_regul" and mixture_size == 3 and base_loss == "nll"')
plot_all_cohen_d(plot_df, default_cohen_d, path=path / 'hparams' / 'nb_hidden')

### Comparison of mixtures with different numbers of components

In [None]:
# Baseline: mixtures with 3 components; every other mixture size is compared against it.
baseline_query = 'mixture_size == 3'
# A model and its baseline must agree on all of these columns.
join_by = [
    'dataset_group', 'dataset', 'run_id', 'metric',
    'base_loss', 'pred_type', 'nb_hidden'
]
# Columns of the compared model kept in the result.
columns_to_keep = ['name', 'mixture_size']

# Pre-bind the comparison settings so that only the dataframe remains to be passed.
default_cohen_d = partial(build_cohen_d, 
    metrics=metrics, baseline_query=baseline_query, join_by=join_by, columns_to_keep=columns_to_keep
)

# Keep only the `no_regul` model trained with the `nll` base loss and 3 hidden
# layers, so that the mixture size is the only varying hyperparameter.
plot_df = standard_setting(test_df, hparams=['mixture_size']).query('model == "no_regul" and nb_hidden == 3 and base_loss == "nll"')
plot_all_cohen_d(plot_df, default_cohen_d, path=path / 'hparams' / 'mixture_size')

### Comparison of quantile predictions with different numbers of quantiles

In [None]:
# Baseline: 64 quantiles; every other number of quantiles is compared against it.
baseline_query = 'n_quantiles == 64'
# A model and its baseline must agree on all of these columns.
join_by = [
    'dataset_group', 'dataset', 'run_id', 'metric',
    'base_loss', 'pred_type', 'nb_hidden', 'mixture_size'
]
# Columns of the compared model kept in the result.
columns_to_keep = ['name', 'n_quantiles']

# Pre-bind the comparison settings so that only the dataframe remains to be passed.
default_cohen_d = partial(build_cohen_d, 
    metrics=metrics, baseline_query=baseline_query, join_by=join_by, columns_to_keep=columns_to_keep
)

# Keep only the `no_regul` model trained with the `expected_qs` base loss and
# 3 hidden layers; rows are labelled by their number of quantiles.
plot_df = standard_setting(test_df, hparams=['n_quantiles']).query('model == "no_regul" and nb_hidden == 3 and base_loss == "expected_qs"')
plot_all_cohen_d(plot_df, default_cohen_d, path=path / 'hparams' / 'n_quantiles')

### Comparison with different numbers of hidden layers and components in the mixtures

In [None]:
# Baseline: 3 hidden layers and 3 mixture components; all other combinations are compared against it.
baseline_query = 'nb_hidden == 3 and mixture_size == 3'
# Neither `nb_hidden` nor `mixture_size` is part of the join keys here, since both vary.
join_by = [
    'dataset_group', 'dataset', 'run_id', 'metric',
    'base_loss', 'pred_type',
]
# Columns of the compared model kept in the result.
columns_to_keep = ['name', 'nb_hidden', 'mixture_size',]

# Pre-bind the comparison settings so that only the dataframe remains to be passed.
default_cohen_d = partial(build_cohen_d, 
    metrics=metrics, baseline_query=baseline_query, join_by=join_by, columns_to_keep=columns_to_keep
)

# Keep only the `no_regul` model trained with the `nll` base loss; rows are
# labelled by both the number of hidden layers and the mixture size.
plot_df = standard_setting(test_df, hparams=['nb_hidden', 'mixture_size']).query('model == "no_regul" and base_loss == "nll"')
plot_all_cohen_d(plot_df, default_cohen_d, path=path / 'hparams' / 'nb_hidden_and_mixture_size')

### Comparison with different regularization strengths

In [None]:
# A model and its baseline must agree on all of these columns.
join_by = [
    'dataset_group', 'dataset', 'run_id', 'metric',
    'base_loss', 'pred_type',
]
# Columns of the compared model kept in the result.
columns_to_keep = ['name', 'lambda_', 's']

# One figure per regularized model, each evaluated at a single fixed value of `s`.
for model, s in [('cdf_based', 50), ('quantile_based', 0.01), ('entropy_based', 0.01)]:
    # Baseline: `lambda_ == 0.2` at the same `s`; the other `lambda_` values are compared against it.
    baseline_query = f'lambda_ == 0.2 and s == {s}'
    default_cohen_d = partial(build_cohen_d, 
        metrics=metrics, baseline_query=baseline_query, join_by=join_by, columns_to_keep=columns_to_keep
    )

    # Restrict to the rows with this `s` before labelling them by `lambda_`,
    # then keep only the current model trained with the `nll` base loss.
    plot_df = standard_setting(test_df.query(f's == {s}'), hparams=['lambda_']).query(f'model == "{model}" and base_loss == "nll"')
    # Figures go to a separate sub-directory per model.
    plot_all_cohen_d(plot_df, default_cohen_d, path=path / 'hparams' / 'lambda_' / model)

In [None]:
# Median `test_nll` and `test_calib_l1` per `lambda_` value for the `cdf_based`
# mixture models. `dropna=False` keeps rows with a missing `lambda_` as their own
# group, and `skipna=False` makes a group's median NaN if any of its values is NaN.
test_df.query('pred_type == "mixture" and model == "cdf_based"').groupby('lambda_', dropna=False)[['test_nll', 'test_calib_l1']].agg(lambda x: x.median(skipna=False))