# Analysis of Latent Recalibration

In [None]:
%load_ext autoreload
%autoreload 2

import logging
from functools import partial
import math
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd

from moc.analysis.dataframes import (
    agg_mean_sem,
    format_cell_jupyter,
    format_cell_latex,
    get_datasets_df,
    get_metric_df,
    load_config,
    load_df,
    to_latex,
)
from moc.analysis.highlighter import Highlighter
from moc.analysis.plot_barplots import barplots
from moc.analysis.plot_calibration import plot_reliability_diagrams
from moc.models.tarflow.tarflow import image_shape
from moc.utils import savefig, set_notebook_options

# Configure the notebook through the project helper, passing the WARNING log level.
set_notebook_options(logging.WARNING)

# Font sizes shared by every matplotlib figure created in this notebook.
plt.rcParams['axes.titlesize'] = 12
plt.rcParams['axes.labelsize'] = 12
plt.rcParams['legend.fontsize'] = 14

# Experiment name; `path` is the directory under which figures and tables are saved.
name = 'lr'
path = Path('results', name)

In [None]:
# Load the configuration stored in the log directory of the current experiment name.
config = load_config(Path('logs') / name)
# Frame loaded for that configuration (presumably the raw per-run results -- confirm in `load_df`).
df_raw = load_df(config)
# Metric frame built from it; the index is moved into regular columns, which the cells below
# filter with `DataFrame.query`.
df = get_metric_df(config, df_raw).reset_index()
# Dataset frame for the same configuration; not referenced elsewhere in the visible cells.
df_ds = get_datasets_df(config, reload=False)

## Table creation

In [None]:
def create_table(df, name, metrics, format_cell_kwargs=None):
    """Pivot the selected metrics into a table, export it to LaTeX and return it styled.

    Args:
        df: Long-format metric frame. After its index is reset it must provide the columns
            ``abb``, ``metric``, ``name``, ``value``, ``run_id`` and ``posthoc_method``.
        name: Prefix of the exported file, which is written to
            ``path / 'tables' / f'{name}_lr.tex'`` (``path`` is the notebook-level results
            directory).
        metrics: Metrics to keep; their order becomes the category order of ``metric``.
        format_cell_kwargs: Optional keyword arguments forwarded to both ``format_cell_latex``
            and ``format_cell_jupyter``.

    Returns:
        The pandas ``Styler`` of the pivoted table, formatted with ``format_cell_jupyter``.
    """
    # `None` replaces the former mutable `{}` default, which is shared between calls.
    format_cell_kwargs = {} if format_cell_kwargs is None else format_cell_kwargs

    df = df.copy().reset_index()
    df = df.query('metric in @metrics')
    # HDR rows are dropped for the `nll` and `latent_calibration` metrics only.
    df = df.query('posthoc_method != "HDR" or metric not in ["nll", "latent_calibration"]')
    # `assign` builds a new frame instead of writing into the filtered result, which avoids
    # pandas' chained-assignment warning.
    df = df.assign(metric=pd.Categorical(df['metric'], categories=metrics))
    # The index is not used by the pivot, so no further `reset_index()` is needed here.
    plot_df = df[['abb', 'metric', 'name', 'value', 'run_id']]
    pivot_df = plot_df.pivot_table(
        index='abb',
        columns=('metric', 'name'),
        values='value',
        aggfunc=agg_mean_sem,
        observed=True,
    )
    styled_table = pivot_df.style.apply(
        Highlighter().highlight_statistically_similar_to_best_per_metric, axis=None
    )
    # `Styler.format` updates the styler in place and returns it, so the LaTeX export has to run
    # before the Jupyter formatter replaces the LaTeX one.
    to_latex(
        styled_table.format(partial(format_cell_latex, **format_cell_kwargs)),
        path / 'tables' / f'{name}_lr.tex',
    )
    return styled_table.format(partial(format_cell_jupyter, add_sem=True, **format_cell_kwargs))


def create_tables(df, name):
    """Export the scoring-rule table and the calibration table for ``df``.

    Each table is produced by ``create_table`` under the prefix ``f'{name}_<group>'``.
    """
    metric_groups = (
        ('scoring_rules', ['nll', 'energy_score']),
        ('calibration', ['latent_calibration', 'hdr_calibration']),
    )
    for group, group_metrics in metric_groups:
        create_table(df, f'{name}_{group}', metrics=group_metrics)

## Bar plots

In [None]:
def plot_all_barplots(plot_df, dir_name):
    """Draw and save the calibration and the scoring-rule bar plots for ``plot_df``.

    Both figures are saved under ``path / 'barplot' / dir_name``, where ``path`` is the
    notebook-level results directory.
    """
    out_dir = path / 'barplot' / dir_name

    calibration_metrics = ['latent_calibration', 'hdr_calibration']
    barplots(plot_df, calibration_metrics)
    savefig(out_dir / 'calibration.pdf')

    scoring_metrics = ['nll', 'energy_score']
    barplots(plot_df, scoring_metrics, width=4.8)
    savefig(out_dir / 'scoring_rules.pdf')

## Reliability diagrams

In [None]:
def plot_all_reliability_diagrams(df, dir_name):
    """Draw and save the ``latent_distance`` and ``hpd`` reliability diagrams for ``df``.

    Both figures are saved under ``path / 'reliability_diagrams' / dir_name``. The notebook-level
    ``config`` is read at call time and passed to the plotting helper.
    """
    out_dir = path / 'reliability_diagrams' / dir_name

    # Rows of the HDR post-hoc method are left out of the latent-distance diagram.
    without_hdr = df.query('posthoc_method != "HDR"')
    plot_reliability_diagrams(without_hdr, 'latent_distance', config, ncols=5, ncols_legend=5)
    savefig(out_dir / 'latent_distance.pdf')

    plot_reliability_diagrams(df, 'hpd', config, ncols=5, ncols_legend=3)
    savefig(out_dir / 'hpd.pdf')

## Convex potential flows

In [None]:
# Keep the MQF2 rows that either have no post-hoc density estimator or use the "kde" one.
plot_df = (
    df.query('model == "MQF2"')
    .query('posthoc_density_estimator.isna() or posthoc_density_estimator == "kde"')
)

# Figures first, then the LaTeX tables, all stored under the 'MQF2' name.
plot_all_barplots(plot_df, 'MQF2')
plot_all_reliability_diagrams(plot_df, 'MQF2')
create_tables(plot_df, 'MQF2')

## ARFlow results

In [None]:
# Keep the ARFlow rows of a single architecture: quadratic splines, hidden size 64, 2 layers.
plot_df = (
    df.query('model == "ARFlow"')
    .query('transform_type == "spline-quadratic" and hidden_size == 64 and num_layers == 2')
)

# Bar plots and LaTeX tables, stored under the 'ARFlow' name.
plot_all_barplots(plot_df, 'ARFlow')
create_tables(plot_df, 'ARFlow')

## Misspecified convex potential flow

In [None]:
# Rebind the notebook-level `config`, `df_raw`, `df` and `df_ds` to the 'lr_misspecified' logs.
# `plot_all_reliability_diagrams` reads the global `config`, so from here on it would use this one.
config = load_config(Path('logs') / 'lr_misspecified')
df_raw = load_df(config)
# Same preparation as for the main run: metric frame with its index moved into columns.
df = get_metric_df(config, df_raw).reset_index()
df_ds = get_datasets_df(config, reload=False)

In [None]:
# Keep the MQF2 rows that either have no post-hoc density estimator or use the "kde" one.
plot_df = df.query('model == "MQF2"')
plot_df = plot_df.query('posthoc_density_estimator.isna() or posthoc_density_estimator == "kde"')

# Use a dedicated output name: `path` still points at the results directory of the main run, so
# reusing 'MQF2' would overwrite the bar plots and tables already written for that run.
plot_all_barplots(plot_df, 'MQF2_misspecified')
create_tables(plot_df, 'MQF2_misspecified')

## TarFlow

In [None]:
# BPD metric
def nll_to_bpd(nll, k=128, n_dims=None):
    """Convert a negative log-likelihood into bits per dimension.

    Computes ``(nll + n_dims * log(k)) / (n_dims * log(2))``.

    Args:
        nll: Negative log-likelihood of one sample; expected in nats, since the result is
            divided by ``log(2)`` to obtain bits.
        k: Scale whose logarithm is added once per dimension.
            NOTE(review): presumably the factor by which the data were rescaled before
            modelling -- confirm against the TarFlow preprocessing.
        n_dims: Number of data dimensions. Defaults to ``image_shape.numel()``, the value that
            was previously hard-coded.

    Returns:
        The corresponding bits-per-dimension value.
    """
    if n_dims is None:
        n_dims = image_shape.numel()
    # Scale
    scaled_nll = nll + math.log(k) * n_dims
    # Bits per dimension
    return scaled_nll / (n_dims * math.log(2))


def add_bpd(df):
    """Return ``df`` extended with one ``bpd`` row for every ``nll`` row.

    The added rows are copies of the ``nll`` rows whose ``value`` is converted with
    ``nll_to_bpd`` and whose ``metric`` is set to ``'bpd'``. ``df`` itself is not modified.
    """
    bpd_rows = df.query('metric == "nll"').copy()
    bpd_rows['value'] = bpd_rows['value'].apply(nll_to_bpd)
    bpd_rows['metric'] = 'bpd'
    return pd.concat([df, bpd_rows])

In [None]:
# Each TarFlow experiment is loaded and tabulated separately.
names = ['lr_tarflow_noisy', 'lr_tarflow_no_noise']
# The loop variable is `run_name`, not `name`: the notebook-level `name` was used to derive
# `path` and to load the main configuration, so the loop must not overwrite it.
for run_name in names:
    config = load_config(Path('logs') / run_name)
    df = get_metric_df(config, load_df(config)).reset_index()
    # Add the bits-per-dimension rows derived from the NLL rows.
    df = add_bpd(df)
    display(
        create_table(
            df,
            run_name,
            metrics=['latent_calibration_100', 'bpd'],
            format_cell_kwargs={'mean_digits': 4, 'sem_digits': 4},
        )
    )