In [3]:
"""
Configuration for plotting the metric summaries of one evaluated results directory.

    main_dir: results directory (resolved via data_utils.join_ordinal_bptt_path).
    eval_dir: summary subdirectory of main_dir that must contain evaluation.csv.
    hyperparameters: can include any model hyperparameters plus 'steps' for ahead prediction.
        These hyperparameters must appear as columns in the evaluation.csv document
        created by evaluate_complete_directory.
    metrics: metric columns to plot; one subplot (or one figure) per metric.
    include_hyper / exclude_hyper: {hyperparameter: values} dicts passed to
        eval_reallabor_utils.include_exclude_hypers before plotting
        (presumably keep-only / drop filters -- confirm in eval_reallabor_utils).
    outlier_threshold: forwarded to eval_reallabor_utils.calculate_metrics.
    pairwise_hypers: if True, creates plots for every single hyperparameter and every
        pair of hyperparameters. Else, only one plot over all hyperparameters jointly.
    metric_subplots: if True, all metrics share one figure (one subplot each);
        else one figure is saved per metric.
    file_format: image format / file extension of the saved figures.
    sort_index: if True, results are sorted by hyperparameter values before plotting.
    plot_kind: 'line' forces line plots; 'auto' draws lines when the first hyperparameter
        name contains 'steps' and bars otherwise; any other value (e.g. 'bar') draws bars.
"""
import os
import sys

# Guard so that re-running this cell does not keep appending '..' to sys.path.
if '..' not in sys.path:
    sys.path.append('..')
import data_utils

main_dir = data_utils.join_ordinal_bptt_path('results/hiera_MRT1_5splits_small_batch')
eval_dir = '00_summary_7stepsahead'
hyperparameters = ['seq_per_subject', 'subjects_per_batch']
metrics = ['mae', 'diff_mae', 'change_mae', 'training_time', 'n_params']
include_hyper = {}  # e.g. {'participant': [15, 34, 35, 52]}; alt. participants: [15, 17, 28, 34, 35, 52, 57, 61, 62]
exclude_hyper = {'feature': ['EMA_emotion_control', 'EMA_emotion_change']}  # optionally also 'participant': 24
outlier_threshold = 10
pairwise_hypers = True
metric_subplots = True
file_format = 'png'
sort_index = True
plot_kind = 'line'  # "auto" for automatic choice (steps are line, rest is bars)

# Fail early, with the offending path in the message, if the summary directory is missing.
assert os.path.exists(os.path.join(main_dir, eval_dir)), \
    f'Evaluation summary directory not found: {os.path.join(main_dir, eval_dir)}'

In [4]:
# Plot every metric against each requested hyperparameter set and save figures + CSV tables.
import os
import sys

# Guard so that re-running this cell does not keep appending '..' to sys.path.
if '..' not in sys.path:
    sys.path.append('..')
os.environ['PYDEVD_WARN_SLOW_RESOLVE_TIMEOUT'] = '1.0'

import itertools as it
import warnings

import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

import custom_rcparams as crc  # imported for its side effects (not referenced below)
import eval_reallabor_utils

# Metrics that get a horizontal reference line at y=0 / y=1.
# NOTE(review): 'diff_mae' / 'change_mae' from the config are not listed -- confirm whether
# they should get a zero baseline as well.
ZERO_BASELINE_METRICS = ('diff_mse', 'var_adjusted_diff_mse', 'mean_adjusted_diff_mse', 'inv_r_squared')
ONE_BASELINE_METRICS = ('rel_mse',)


def _plot_metric(ax, values, errors, metric, kind):
    """Draw one metric with error bars on ``ax`` and label the y-axis.

    values, errors: the metric's column of the unstacked calculate_metrics output.
    kind: 'line' for a line plot with markers, anything else for a bar plot.
    """
    with warnings.catch_warnings():
        # All warnings raised while plotting are suppressed.
        warnings.simplefilter('ignore')
        if kind == 'line':
            values.plot(ax=ax, kind='line', yerr=errors, capsize=5, linestyle='-', marker='o')
        else:
            values.plot(ax=ax, kind='bar', yerr=errors, capsize=5)
    # axhline spans the full axes width without re-triggering x-autoscaling (unlike
    # plotting a segment over ax.get_xlim(), which adds margins and shrinks the line).
    if metric in ZERO_BASELINE_METRICS:
        ax.axhline(0, linestyle='-', color='black')
    elif metric in ONE_BASELINE_METRICS:
        ax.set_ylim((max(0, ax.get_ylim()[0]), ax.get_ylim()[1]))
        ax.axhline(1, linestyle='-', color='black')
    ax.set_ylabel(metric)


def _save_and_close(fig, title, directory, fmt):
    """Give ``fig`` the suptitle ``title``, save it as ``directory/title.fmt`` and close it."""
    fig.suptitle(title)
    fig.tight_layout()
    fig.savefig(fname=os.path.join(directory, title + '.' + fmt), format=fmt, dpi=200)
    plt.close(fig)  # close this specific figure to avoid accumulating open figures


print('Plotting metric summaries...')
# Build the full path in a new name instead of overwriting the config value eval_dir:
# with a relative main_dir, re-running this cell would otherwise nest main_dir twice.
eval_path = os.path.join(main_dir, eval_dir)
results = pd.read_csv(os.path.join(eval_path, 'evaluation.csv'), index_col=0)
results = eval_reallabor_utils.include_exclude_hypers(results, include_hyper, exclude_hyper)
save_path = os.path.join(eval_path, 'summary_plots')
os.makedirs(save_path, exist_ok=True)

# Either every single hyperparameter and every pair of them, or all of them jointly.
if pairwise_hypers:
    hyperset = [list(combo) for r in (1, 2) for combo in it.combinations(hyperparameters, r)]
else:
    hyperset = [list(hyperparameters)]

for hp in tqdm(hyperset):
    metric_results, errorbars = eval_reallabor_utils.calculate_metrics(
        results, hp, metrics, outlier_threshold=outlier_threshold)
    if sort_index:
        metric_results = metric_results.sort_index()
        errorbars = errorbars.sort_index()
    # Unstack once per hyperparameter set instead of once per metric.
    values_wide = metric_results.unstack()
    errors_wide = errorbars.unstack()
    kind = 'line' if plot_kind == 'line' or (plot_kind == 'auto' and 'steps' in hp[0]) else 'bar'
    width = len(results[hp[0]].unique()) * 2
    # Name of the all-metrics figure and of the CSV table, independent of metric_subplots.
    summary_title = ', '.join(metrics) + ' wrt. ' + ', '.join(hp)

    if metric_subplots:
        fig, axes = plt.subplots(len(metrics), 1, sharex=True,
                                 figsize=(width, 1 + 4 * len(metrics)), squeeze=False)
        for ax, m in zip(axes.flatten(), metrics):
            _plot_metric(ax, values_wide[m], errors_wide[m], m, kind)
        _save_and_close(fig, summary_title, save_path, file_format)
    else:
        for m in metrics:
            fig, ax = plt.subplots(1, 1, figsize=(width, 5))
            _plot_metric(ax, values_wide[m], errors_wide[m], m, kind)
            _save_and_close(fig, f'{m} wrt. {", ".join(hp)}', save_path, file_format)

    # The table holds all metrics, so it is always named after all of them (previously it
    # took the last single-metric title when metric_subplots was False).
    metric_results.to_csv(os.path.join(save_path, summary_title + '.csv'))

Plotting metric summaries...


100%|██████████| 3/3 [00:03<00:00,  1.15s/it]
