# Postprocessing

In [None]:
import os

# When the notebook is launched from inside `notebooks/`, step up one level so
# the relative paths used below ('data/...', 'figures/...') resolve against the
# parent directory (presumably the project root — confirm against repo layout).
# `os.path.basename` honours the platform separator, unlike splitting on '/',
# so the check also works on Windows.
if os.path.basename(os.getcwd()) == 'notebooks':
    os.chdir(os.pardir)

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from src.plotting import set_paper_context, set_poster_context

## Data

In [None]:
# Folder holding one saved absolute-error array per model (`<model>.npy`).
_DATA_DIR = 'data'
PATH_ERROR_DATA = os.path.join(_DATA_DIR, 'models')

In [None]:
# Load the stored absolute-error array of each model from PATH_ERROR_DATA.
# File order matches the unpacking order of the names on the left.
_error_files = ('baseline.npy',
                'multilayer_perceptron.npy',
                'mixture_of_experts.npy',
                'tabnet.npy')
base_ae, mlp_ae, moe_ae, tabnet_ae = (
    np.load(os.path.join(PATH_ERROR_DATA, fname)) for fname in _error_files
)

In [None]:
# Wide-format frame with one column per model. The column order fixes the
# order of the bars when `df` is passed directly to seaborn.
_model_errors = {
    'baseline': base_ae,
    'MLP': mlp_ae,
    'TabNet': tabnet_ae,
    'MoE': moe_ae,
}
df = pd.DataFrame(_model_errors)

In [None]:
with set_paper_context():
    cs = sns.color_palette('rocket', 4)
    # Edge colour for each bar, in df column order (baseline, MLP, TabNet, MoE).
    edge_colors = tuple(cs[idx] for idx in (2, 0, 1, 3))
    tick_values = [0, 0.06, 0.12]
    # Hollow bars (no face colour) with capped error bars.
    # NOTE(review): `errwidth` is deprecated in newer seaborn releases in
    # favour of `err_kws`; confirm against the installed seaborn version.
    bar_style = dict(edgecolor=edge_colors, facecolor='none',
                     errwidth=1.5, capsize=0.1, lw=2)

    fig, ax = plt.subplots(1, 1, figsize=(4.5, 4))
    ax = sns.barplot(df, **bar_style, ax=ax)
    # Axis titles come from the figure-level labels below, so clear them here.
    ax.set(xlabel='', ylabel='',
           yticks=tick_values,
           yticklabels=tick_values,
           ylim=[0, 0.125])

    fig.supxlabel('model')
    fig.supylabel('absolute error [°C]')
    fig.tight_layout()
    sns.despine()

    # Uncomment to export the paper figure.
    # fig_name = os.path.join('figures', 'models.pdf')
    # fig.savefig(fig_name, dpi=200, bbox_inches='tight')

In [None]:
with set_poster_context(font_scale=2):
    cs = sns.color_palette('rocket', 4)
    # Edge colour for each bar, in df column order (baseline, MLP, TabNet, MoE).
    edge_colors = tuple(cs[idx] for idx in (2, 0, 1, 3))
    tick_values = [0, 0.06, 0.12]
    # Hollow bars with thicker outlines and error bars for poster legibility.
    # NOTE(review): `errwidth` is deprecated in newer seaborn releases in
    # favour of `err_kws`; confirm against the installed seaborn version.
    bar_style = dict(edgecolor=edge_colors,
                     facecolor='none',
                     errwidth=2, capsize=0.1, lw=3)

    fig, ax = plt.subplots(1, 1, figsize=(6.5, 5.5))
    ax = sns.barplot(df, **bar_style, ax=ax)
    # Write each bar's value, to three decimals, in black at the bar centre.
    for bar_group in ax.containers:
        ax.bar_label(bar_group, label_type='center', fmt='%.3f', color='k')
    # The title carries the quantity and unit instead of axis labels.
    ax.set(xlabel='', ylabel='',
           title='absolute estimation error (°C)',
           yticks=tick_values,
           yticklabels=tick_values,
           ylim=[0, 0.125])
    fig.tight_layout()
    sns.despine()

    # Uncomment to export the poster figure.
    # fig_name = os.path.join('figures', 'poster', 'models.png')
    # fig.savefig(fig_name, dpi=500, bbox_inches='tight')