In [11]:
SAVE: bool = True  # write each participant's prediction figure to results/_paper/predictions as SVG

Combine the evaluation files of all models and MRTs into a single table, keeping only valid test days

In [2]:
import sys
sys.path.append('..')
import os
import pandas as pd
import eval_reallabor_utils
import data_utils

# Combine the per-model evaluation files of every MRT into one table, keep only the
# valid test days of each participant and clip the predictions.
MRT = [2, 3]
# Clip range applied to all predictions. NOTE(review): presumably the EMA score scale — confirm.
PREDICTION_MIN, PREDICTION_MAX = 0, 8
per_mrt_eval_files = []

for mrt in MRT:

    # Evaluation summaries of the different model runs for this MRT.
    results_dirs = [
        f'v3_MRT{mrt}_SimpleModels_every_day/00_summary_7stepsahead_interv',
        f'v3_MRT{mrt}_Kalman_every_day/00_summary_7stepsahead_interv',
        f'v3_MRT{mrt}_every_day/00_summary_7stepsahead_interv',
        f'v3_MRT{mrt}_Transformer_every_day/00_summary_7stepsahead_interv',
    ]

    eval_files = []
    for rd in results_dirs:
        eval_files.append(data_utils.join_ordinal_bptt_path('results', rd, 'evaluation.csv'))
        assert os.path.exists(eval_files[-1]), eval_files[-1]
    eval_reallabor_utils.clear_line_and_print('Combining evaluation files')
    mrt_combined_eval_file = eval_reallabor_utils.combine_evaluation_files(eval_files, save_path=None, specifier_name='latent_model')
    mrt_combined_eval_file['MRT'] = mrt

    # Valid test days: one column per participant ID, whose values are the test days to keep.
    # If columns differ in length, read_csv pads the shorter ones with NaN, so the
    # padding is dropped before matching (otherwise NaN test days would be kept).
    use_days_from_file = data_utils.train_test_split_path(mrt, 'valid_first_alarms_no_con_smoothed.csv')
    valid_days = pd.read_csv(use_days_from_file, index_col=0)
    valid_days.columns = [int(c) for c in valid_days.columns]

    eval_reallabor_utils.clear_line_and_print('Filtering out invalid days')
    # Participants without a column in the valid-days file are dropped entirely.
    eval_only_valid_days = [
        group.loc[group['test_day'].isin(valid_days[p].dropna())]
        for p, group in mrt_combined_eval_file.groupby('participant')
        if p in valid_days.columns
    ]
    if not eval_only_valid_days:
        raise ValueError(f'MRT {mrt}: no evaluated participant has valid days in {use_days_from_file}')
    mrt_combined_eval_file = pd.concat(eval_only_valid_days, axis=0, ignore_index=True)

    mrt_combined_eval_file['prediction'] = mrt_combined_eval_file['prediction'].clip(PREDICTION_MIN, PREDICTION_MAX)

    per_mrt_eval_files.append(mrt_combined_eval_file)

eval_reallabor_utils.clear_line_and_print('Creating grand combined evaluation file')
# ignore_index=True: every per-MRT frame is numbered from 0, so keeping those indices
# would produce duplicate row labels (and a meaningless index column in the saved CSV).
combined_eval_file = pd.concat(per_mrt_eval_files, axis=0, ignore_index=True)

# The CSV is written only if it does not exist yet (delete it to refresh it). It is saved
# before the label / train_until normalisation and the step filter below.
combined_csv_path = data_utils.join_ordinal_bptt_path('results/_paper/combined_every_day_ensemble.csv')
if not os.path.exists(combined_csv_path):
    os.makedirs(data_utils.join_ordinal_bptt_path('results/_paper'), exist_ok=True)
    combined_eval_file.to_csv(combined_csv_path)

if 'train_until' in combined_eval_file.columns:
    # Unify the spelling of the hierarchized model name and fall back to
    # train_on_data_until_timestep wherever train_until is missing.
    combined_eval_file.loc[combined_eval_file['latent_model'] == 'hierarchized clipped-shallow-PLRNN', 'latent_model'] = 'hierarchized-clipped-shallow-PLRNN'
    missing_train_until = combined_eval_file['train_until'].isna()
    combined_eval_file.loc[missing_train_until, 'train_until'] = combined_eval_file.loc[missing_train_until, 'train_on_data_until_timestep']

# NOTE(review): presumably the arguments are (include, exclude) filters, i.e. rows with
# steps == 7 are removed — confirm against include_exclude_hypers.
combined_eval_file = eval_reallabor_utils.include_exclude_hypers(combined_eval_file, {}, {'steps':7})

Creating grand combined evaluation file                                                                                                                                                                 

In [3]:
from eval_reallabor import reallabor_metrics
# Metrics over the combined evaluation table; the list presumably gives the grouping
# levels (one metric value per MRT, model and participant).
m_manager = reallabor_metrics.MetricsManager(
    combined_eval_file,
    ['MRT', 'latent_model', 'participant'],
    use_gt_for_predicted_difference=False,
    include_r2=False,
)

In [4]:
# One row per (MRT, participant), one column per model.
mae = m_manager.mae().unstack('latent_model')
# Each model's MAE minus the PLRNN's MAE on the same row:
# positive = PLRNN wins, negative = other wins
difference_to_plrnn = mae.sub(mae['clipped-shallow-PLRNN'], axis=0)

In [20]:
# Rank (MRT, participant) pairs by how much higher a baseline's MAE is than the
# PLRNN's, largest PLRNN advantage first.
plrnn_var_best = difference_to_plrnn['VAR1'].sort_values(ascending=False)
plrnn_kalman_best = difference_to_plrnn['KalmanFilter'].sort_values(ascending=False)
# Combined ranking: advantage over VAR(1) plus advantage over the Kalman filter.
plrnn_var_kalman_best = (
    difference_to_plrnn[['VAR1', 'KalmanFilter']]
    .sum(axis=1)
    .sort_values(ascending=False)
)

In [21]:
# Plotting options:
#   SHOW     - keep every figure open and display all of them after the loop
#              (False: each figure is closed right after it is drawn/saved)
#   REALTIME - True: numeric timestep x-axis; False: timesteps converted to strings
#   SORTBY   - ranking of (MRT, participant) pairs that sets the plotting order
SHOW, REALTIME = False, False
SORTBY = plrnn_var_kalman_best

os.makedirs(
    data_utils.join_ordinal_bptt_path('results/_paper/predictions'),
    exist_ok=True,
)

import matplotlib.pyplot as plt 
from math import ceil, floor
from plotting_styles import PaperStyle, colors
from plotting_utils import adjust_lim

# Models drawn in the prediction figures (dict order = plotting order) and their labels.
# Commented-out alternatives from earlier versions: 'MovingAverage(1)' ('Last Step'),
# 'MeanPredictor' ('Global Mean'), 'InputsRegression' ('Linear Regression'),
# 'hierarchized-clipped-shallow-PLRNN' ('H-PLRNN').
model_labels = {
    'VAR1': 'VAR(1)',
    'KalmanFilter': 'Kalman Filter',
    'clipped-shallow-PLRNN': 'PLRNN',
    'Transformer': 'Transformer',
}

# Error bars of the per-model MAE panel (func='sem', presumably the standard error).
mae_errorbars = m_manager.mae(func='sem').unstack('latent_model')

# Same metrics additionally split by forecast step; its raw_metrics supply the
# per-step predictions that are plotted.
m_manager_stepwise = reallabor_metrics.MetricsManager(
    combined_eval_file,
    ['MRT', 'latent_model', 'participant', 'steps'],
    use_gt_for_predicted_difference=False,
    include_r2=False,
)

import warnings

# For every (MRT, participant) in SORTBY order, draw the ground truth and each model's
# multi-step forecasts (left panel) next to each model's MAE with error bar (right panel).
with PaperStyle():

    key_names = SORTBY.index.names
    participants = SORTBY.index.get_level_values('participant')
    # Lower-case loop name so the MRT list defined in the data-loading cell is not overwritten.
    mrts = SORTBY.index.get_level_values('MRT')

    for key, participant, mrt in zip(SORTBY.index, participants, mrts):
        eval_xs = (
            m_manager_stepwise.raw_metrics
            .xs(key, level=key_names, drop_level=False)
            .sort_values(['train_on_data_until_timestep', 'steps'])
            .reset_index()
        )
        # At step 0 show the observed value, so every forecast line starts at the ground truth.
        is_origin = eval_xs['steps'] == 0
        eval_xs.loc[is_origin, 'prediction'] = eval_xs.loc[is_origin, 'ground_truth']

        # Absolute timestep of each row = train_on_data_until_timestep + steps.
        eval_xs['timesteps'] = eval_xs[['steps', 'train_on_data_until_timestep']].sum(axis=1).to_numpy()
        first_alarms = eval_xs.loc[is_origin, 'timesteps'].unique()

        participant_data = pd.read_csv(data_utils.single_subject_path(mrt, 'processed_csv_no_con_smoothed_causal', participant))
        # Look up the DayNr of each first alarm individually so the tick labels always line
        # up with first_alarms. An isin() filter returns rows in file order and can yield
        # more or fewer labels than ticks, which recent matplotlib versions reject.
        day_of_timestep = dict(zip(participant_data['Timesteps'], participant_data['DayNr']))
        missing_days = [t for t in first_alarms if t not in day_of_timestep]
        if missing_days:
            warnings.warn(f'Participant {participant} (MRT {mrt}): no DayNr found for timesteps {missing_days}')
        day_nr = [day_of_timestep.get(t, '?') for t in first_alarms]

        if not REALTIME:
            # String x-values put matplotlib on a categorical axis: distinct timesteps are
            # evenly spaced, so time gaps between days are not shown.
            eval_xs['timesteps'] = eval_xs['timesteps'].astype(str)
            first_alarms = first_alarms.astype(str)

        fig, (ax_forecast, ax_mae) = plt.subplots(1, 2, figsize=(12, 2), width_ratios=(5, 1))
        ax_forecast.plot(eval_xs['timesteps'], eval_xs['ground_truth'], label='ground truth', linestyle='', marker='.')

        forecast_origins = eval_xs['train_on_data_until_timestep'].unique()
        mae_of_key = mae.xs(key, level=key_names)
        sem_of_key = mae_errorbars.xs(key, level=key_names)
        for m, model in enumerate(model_labels):
            color = colors.model_colors[model]
            is_model = eval_xs['latent_model'] == model
            # One line per forecast origin, so consecutive forecasts are not joined.
            for origin in forecast_origins:
                mask = is_model & (eval_xs['train_on_data_until_timestep'] == origin)
                ax_forecast.plot(eval_xs.loc[mask, 'timesteps'], eval_xs.loc[mask, 'prediction'], label=model_labels[model], color=color)
            ax_mae.errorbar(m, mae_of_key[model], yerr=sem_of_key[model], linestyle='', marker='o', color=color)

        ax_forecast.set(xlabel='day', ylabel='mean EMA score', title=f'Participant {participant}', xticks=first_alarms, xticklabels=day_nr)
        ax_mae.set_xticks(range(len(model_labels)), labels=list(model_labels.values()), rotation=90)

        # Vertical separators at each first alarm, spanning the y-range rounded outward to integers.
        y_low, y_high = ax_forecast.get_ylim()
        separator_span = [floor(y_low), ceil(y_high)]
        for fa in first_alarms:
            ax_forecast.plot((fa, fa), separator_span, color='white', zorder=0)

        adjust_lim(ax_forecast, axis='y', upper=0.05, lower=0.05)
        adjust_lim(ax_mae, axis='x', upper=0.05, lower=0.05)
        adjust_lim(ax_mae, axis='y', upper=0.05, lower=0.05)
        if SAVE:
            # NOTE(review): the file name contains only the participant ID; if IDs can repeat
            # across MRTs, a later figure overwrites an earlier one — confirm.
            fig.savefig(data_utils.join_ordinal_bptt_path('results/_paper/predictions', f'participant_{participant}.svg'))
        if not SHOW:
            # Close immediately so figures do not pile up in memory over the loop.
            plt.close(fig)
if SHOW:
    plt.show()