# Visualising joint training runs

In [None]:
%matplotlib inline
import collections
import html
import os
import json

from IPython.core.display import display, HTML
import numpy as np 
import pandas as pd

# Root directories searched recursively for joint_training.py run outputs;
# any directory underneath that contains both config.json and eval.json is
# treated as one run.
search_dirs = ['/scratch/sam/il-rep-jt-2021-05-18/joint_train_runs/']

# Preparing a table of results

In [None]:
def find_eval_json_files(search_dirs):
    """Recursively scan ``search_dirs`` for finished run directories.

    A run directory is any directory that directly contains both
    ``config.json`` and ``eval.json``. For each one, yield the tuple
    ``(config_path, eval_path)`` of absolute paths, in ``os.walk`` order,
    visiting the search directories in the order given.
    """
    wanted = ('config.json', 'eval.json')
    for root in search_dirs:
        for current_dir, _subdirs, files in os.walk(root):
            present = set(files)
            # Skip directories that are missing either file.
            if not all(name in present for name in wanted):
                continue
            abs_dir = os.path.abspath(current_dir)
            yield tuple(os.path.join(abs_dir, name) for name in wanted)

def read_and_combine_configs(conf_eval_path_iter):
    """Given an iterator that yields pairs of (config path, eval path), this
    function reads the corresponding (JSON) files and merges them into the
    same dict (specifically, it yields the config dict augmented with an
    extra 'eval' key for eval.json results and a 'conf_path' key holding the
    config path).

    The 'eval' and 'conf_path' keys always take precedence over same-named
    keys in config.json, so the config cannot silently clobber them. Files
    are decoded as UTF-8 (the encoding JSON requires) rather than with the
    platform's locale default.
    """
    for conf_path, eval_path in conf_eval_path_iter:
        # Read each file and close it before yielding, so no file handles
        # stay open while the consumer holds the suspended generator.
        with open(conf_path, 'r', encoding='utf-8') as conf_fp:
            conf_dict = json.load(conf_fp)
        with open(eval_path, 'r', encoding='utf-8') as eval_fp:
            eval_dict = json.load(eval_fp)
        yield {
            **conf_dict,
            # Placed after the config spread so these win on key collisions.
            'conf_path': conf_path,
            'eval': eval_dict,
        }
    
def _short_env_name(test_env):
    """Shorten a name like "MoveToCorner-Demo-v0" to "-Demo".

    Names with no '-' separated second component are returned unchanged
    instead of raising IndexError.
    """
    parts = test_env.split('-')
    if len(parts) < 2:
        return test_env
    return '-' + parts[1]


def _run_env_returns(env_cfg, eval_dict):
    """Return a list of (test_env_name, mean_return) pairs for a single run.

    For the 'magical' benchmark, this is one pair per entry in
    eval_dict['full_data'] plus an 'Average' pair. For any other benchmark,
    it is a single pair keyed by the run's own training env.
    """
    train_env = env_cfg['task_name']
    if env_cfg['benchmark_name'] != 'magical':
        return [(train_env, eval_dict['return_mean'])]
    pairs = [
        (_short_env_name(env_eval_dict['test_env']), env_eval_dict['mean_score'])
        for env_eval_dict in eval_dict['full_data']
    ]
    pairs.append(('Average', eval_dict['return_mean']))
    return pairs


def make_pandas_table(search_dirs):
    """Look for joint_training.py runs underneath each of the given
    search_dirs, then return results for all runs as a big Pandas table
    with the following columns:

    - exp_ident (human-readable name)
    - conf_path (path to config.json, uniquely identifies run)
    - train_env (train env for method)
    - test_env (evaluation env for this particular row; there may be multiple
      eval envs for each run)
    - return (mean return on test_env)

    The columns are always present, even when no runs are found, so callers
    can group or pivot on them without special-casing an empty result.
    """
    columns = ['exp_ident', 'conf_path', 'train_env', 'test_env', 'return']
    # Pre-create every column so an empty search still yields a typed schema.
    frame_dict = {col: [] for col in columns}
    path_pair_iter = find_eval_json_files(search_dirs=search_dirs)
    for data_dict in read_and_combine_configs(path_pair_iter):
        env_cfg = data_dict['env_cfg']
        train_env = env_cfg['task_name']
        env_returns = _run_env_returns(env_cfg, data_dict['eval'])
        for test_env_name, test_env_return in env_returns:
            frame_dict['exp_ident'].append(data_dict['exp_ident'])
            frame_dict['conf_path'].append(data_dict['conf_path'])
            frame_dict['train_env'].append(train_env)
            frame_dict['test_env'].append(test_env_name)
            frame_dict['return'].append(test_env_return)
    return pd.DataFrame(frame_dict, columns=columns)

def mean_std(arr):
    """Summarise ``arr`` as one display string for ``pd.pivot_table``.

    The result is formatted as ``"<mean>±<std> (<count>)"`` with two
    decimals. The std is numpy's default population standard deviation
    (``ddof=0``).
    """
    count = len(arr)
    avg, spread = np.mean(arr), np.std(arr)
    return f'{avg:.2f}±{spread:.2f} ({count})'

# Render one pivot table per training env: rows are experiments, columns are
# evaluation envs, cells are "mean±std (seed count)" strings from mean_std.
pandas_table = make_pandas_table(search_dirs=search_dirs)
if pandas_table.empty:
    # An empty table (e.g. a mistyped search dir) has nothing to group;
    # report it explicitly instead of rendering nothing or failing in groupby().
    display(HTML(
        '<p><strong>No runs found</strong> (no directory under '
        f'{html.escape(", ".join(map(str, search_dirs)))} contains both '
        'config.json and eval.json)</p>'))
else:
    for train_env_name, sub_table in pandas_table.groupby('train_env'):
        display(HTML(
            f'<p><strong>Results for {html.escape(train_env_name)}</strong>'
            '<br/>(numbers are mean return ± stddev and seed count)</p>'))
        pivoted_table = pd.pivot_table(
            sub_table,
            columns=['test_env'],
            values='return',
            index='exp_ident',
            aggfunc=mean_std,
        )
        display(pivoted_table)