In [3]:
'''
    Makes trajectory plots of the "best" models against data, using the models'
    plot_generated_against_obs and plot_prediction methods.
    Depending on the plot_* flags below, saves those plots into the
    'test_set_ahead_prediction(best)', 'generated_against_data(best)' and
    'forced_prediction(best)' subfolders of eval_dir.
    main_dir:       Gridsearch results directory; each model is loaded from main_dir/<model_id>/<run>
    eval_dir:       Summary folder (relative to main_dir) containing evaluation.csv;
                    the plot subfolders are created inside it
    epoch_criterion: Passed to eval_reallabor_utils.load_model_and_data
                    (presumably selects which saved epoch to load — confirm)
    mode:           'min' picks the rows with the smallest abs_diff (|ground_truth - prediction|),
                    'max' the rows with the largest
    for_each:       List of hyperparameters, selects the best n models for each value of the 
                    parameters specified here (results in more plots)
    n_best:         Number of models plotted per group
    include_hyper:  Dict with entries hyperparameter:value, only models that have this
                    hyperparameter-value combination will be included in the selection
    exclude_hyper:  Dict with entries hyperparameter:value, models that have this
                    hyperparameter-value combination will be excluded from the selection
                    (a list of values is used below; presumably any listed value matches — confirm)
    prewarm_steps:  Number of final training time steps passed as prewarm data
                    before the test-set ahead prediction
    plot_ahead_prediction, plot_generated, plot_predicted:
                    Toggle the three plot types (and creation of their subfolders)
    format:         Image file extension / format of the saved plots
'''

main_dir = r'/home/janik.fechtelpeter/Documents/bptt/results/MRT1_Gridsearch02_EqualSpacing'
eval_dir = '00_summary_gridsearch_best_runs'
epoch_criterion = 'latest'
mode = 'min'
for_each = ['dim_z']
n_best = 3
include_hyper = {}
exclude_hyper = {'feature': ['EMA_emotion_control', 'EMA_emotion_change']}
# exclude_hyper = {}
prewarm_steps = 4
plot_ahead_prediction=True
plot_generated=True
plot_predicted=True
# NOTE: this name shadows the builtin format(); kept because the plotting cell reads it.
format='png'


In [4]:
import sys
sys.path.append('..')

import os
from typing import Iterable
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import eval_reallabor_utils
import custom_rcparams as crc

# Fail fast on a typo in `mode`: otherwise no sorting/truncation would happen
# and every model of every group would be plotted.
if mode not in ('min', 'max'):
    raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")

# Resolve the summary folder below main_dir. Re-running this cell is safe:
# os.path.join discards main_dir once eval_dir is already an absolute path.
eval_dir = os.path.join(main_dir, eval_dir)
results = pd.read_csv(os.path.join(eval_dir, 'evaluation.csv'))

# Output subfolders of eval_dir, one per plot type; only enabled ones are created.
AHEAD_SUBDIR = 'test_set_ahead_prediction(best)'
GENERATED_SUBDIR = 'generated_against_data(best)'
PREDICTED_SUBDIR = 'forced_prediction(best)'
for enabled, subdir in ((plot_ahead_prediction, AHEAD_SUBDIR),
                        (plot_generated, GENERATED_SUBDIR),
                        (plot_predicted, PREDICTED_SUBDIR)):
    if enabled:
        os.makedirs(os.path.join(eval_dir, subdir), exist_ok=True)

# All plots share this y-range for data on a 1-7 scale.
YLIM = (0.5, 7.5)


def _finish_figure(best_row, run_str, subdir, filename, set_yticks=True):
    """Title, optionally re-tick, save and close the current matplotlib figure.

    best_row:   row of `results` for the plotted model (reads 'participant' and 'abs_diff')
    run_str:    zero-padded run number shown in the title
    subdir:     subfolder of eval_dir the image is saved into
    filename:   file name without extension; the configured `format` is appended
    set_yticks: label the y ticks 1..7, showing only the odd values
    """
    plt.suptitle(f'{best_row["participant"]}, run {run_str}, mae={best_row["abs_diff"]:.3f}')
    if set_yticks:
        # NOTE(review): plt.yticks only changes the current axes; on multi-panel
        # figures the other panels keep their default ticks — confirm intended.
        plt.yticks(np.arange(7) + 1, [1, '', 3, '', 5, '', 7])
    plt.savefig(os.path.join(eval_dir, subdir, filename + f'.{format}'), dpi=200)
    plt.close()


results = eval_reallabor_utils.include_exclude_hypers(results, include_hyper, exclude_hyper)
results['abs_diff'] = (results['ground_truth'] - results['prediction']).abs()
features = results['feature'].unique()
# NOTE(review): if evaluation.csv holds several rows per model/run (e.g. one per
# feature), the same model can be selected more than once and its plot files are
# overwritten — confirm the table's granularity.

if len(for_each) > 0:
    # A single-element list grouper triggers a pandas FutureWarning (scalar vs.
    # 1-tuple group names); a scalar key yields the same groups without it.
    grouped = results.groupby(for_each[0] if len(for_each) == 1 else list(for_each))
else:
    grouped = [('all', results)]
    print(f'Plotting {n_best} best trajectories over all.')

for name, group in grouped:
    if len(for_each) > 0:
        # Normalize to a tuple. A generic Iterable check would treat a string
        # group value as a sequence of characters.
        if not isinstance(name, tuple):
            name = (name, )
        print(f'Plotting {n_best} best trajectories for ' + ", ".join([f"{key}=={value}" for key, value in zip(for_each, name)]))
    # 'min': smallest absolute error first; 'max': largest first.
    group = group.sort_values(by='abs_diff', ascending=(mode == 'min'))[:n_best]
    for _, best in tqdm(group.iterrows(), total=len(group)):
        model_id = best['model_id']
        run = str(best['run']).zfill(3)
        model_dir = os.path.join(main_dir, model_id, run)
        model, test_dataset = eval_reallabor_utils.load_model_and_data(model_dir, epoch_criterion=epoch_criterion)
        train_dataset = model.dataset
        filename = model_id + f' run {run} crit mse'
        test_data, test_inputs = test_dataset.data()
        train_data, train_inputs = train_dataset.data()
        if plot_ahead_prediction:
            # Prewarm on the last prewarm_steps training steps. Slicing with
            # [-0:] would return the whole series, so for prewarm_steps <= 0 the
            # prewarm arguments are omitted and the method's defaults apply.
            prewarm_kwargs = {}
            if prewarm_steps > 0:
                prewarm_kwargs['prewarm_data'] = train_data[-prewarm_steps:]
                prewarm_kwargs['prewarm_inputs'] = (
                    train_inputs[-prewarm_steps:] if train_inputs is not None else None)
            model.plot_generated_against_obs(test_data, test_inputs,
                                             plot_mean=False, ylim=YLIM, features=features,
                                             **prewarm_kwargs)
            _finish_figure(best, run, AHEAD_SUBDIR, filename)
        if plot_generated:
            model.plot_generated_against_obs(train_data, train_inputs, ylim=YLIM, features=features)
            _finish_figure(best, run, GENERATED_SUBDIR, filename)
        if plot_predicted:
            model.plot_prediction(train_data, train_inputs, ylim=YLIM)
            _finish_figure(best, run, PREDICTED_SUBDIR, filename, set_yticks=False)

  for name, group in grouped:


Plotting 3 best trajectories for dim_z==1


100%|██████████| 3/3 [00:26<00:00,  8.97s/it]


Plotting 3 best trajectories for dim_z==3


100%|██████████| 3/3 [00:25<00:00,  8.47s/it]


Plotting 3 best trajectories for dim_z==5


100%|██████████| 3/3 [00:23<00:00,  7.84s/it]


Plotting 3 best trajectories for dim_z==7


100%|██████████| 3/3 [00:24<00:00,  8.32s/it]


Plotting 3 best trajectories for dim_z==9


100%|██████████| 3/3 [00:26<00:00,  8.81s/it]


Plotting 3 best trajectories for dim_z==11


100%|██████████| 3/3 [00:26<00:00,  8.96s/it]


Plotting 3 best trajectories for dim_z==13


100%|██████████| 3/3 [00:30<00:00, 10.14s/it]


Plotting 3 best trajectories for dim_z==15


100%|██████████| 3/3 [00:31<00:00, 10.62s/it]


Plotting 3 best trajectories for dim_z==20


100%|██████████| 3/3 [00:32<00:00, 10.86s/it]
