In [None]:
import json
import os
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns

# Directory holding the outputs of one test run; everything below is read
# relative to it.
test_dir = "/scratch/cynthiachen/ilr-results/dmc-10-trajs-temporal-cpc-2021-05-17/test-results/1"

# Every directory entry whose name contains 'eval' is treated as an
# evaluation result file; the full paths are kept in sorted order.
eval_files = sorted(
    os.path.join(test_dir, entry)
    for entry in os.listdir(test_dir)
    if 'eval' in entry
)

# Configuration saved with the test run.
with open(os.path.join(test_dir, 'config.json')) as f:
    test_config = json.load(f)

benchmark_name = test_config['env_cfg']['task_name']
policy_dir = test_config['policy_dir']

# With a trailing '/', the os.path.dirname() call further down would return
# policy_dir itself rather than its parent, so strip the slash first.
if policy_dir[-1] == '/':
    policy_dir = os.path.dirname(policy_dir)

# The training config is read from the parent directory of policy_dir.
train_config_file = os.path.join(os.path.dirname(policy_dir), 'config.json')
with open(train_config_file) as f:
    train_config = json.load(f)

# Use the training experiment identifier as the label, except that the
# 'dmc-full-trajs-consistent-augs' identifier is rewritten to include the
# trajectory count whenever the config carries a truthy bc.n_trajs.
train_exp_ident = train_config['exp_ident']
if train_exp_ident == 'dmc-full-trajs-consistent-augs' and train_config['bc']['n_trajs']:
    train_exp_ident = f"dmc-{train_config['bc']['n_trajs']}-trajs-consistent-augs"

# One value per line, in the same order as before.
print(train_exp_ident, test_config, benchmark_name, policy_dir, sep='\n')

## Plot return curves

In [None]:
def get_mean_nupdate_list(eval_files):
    """Read the mean return and the update count from each evaluation file.

    Every path in ``eval_files`` is opened and parsed as JSON. The value
    stored under ``'return_mean'`` is collected unchanged. The update count
    is parsed as an ``int`` from the second-to-last ``'_'``-separated token
    of the last ``'/'``-separated component of ``'policy_path'`` (for a
    hypothetical ``'.../policy_00020_batches.pt'`` that token is ``'00020'``
    and the count is ``20``).

    Args:
        eval_files: iterable of paths to JSON evaluation result files.

    Returns:
        A ``(mean_list, nupdate_list)`` tuple of two lists, both in the same
        order as ``eval_files``. No sorting is done here.
    """
    returns = []
    update_counts = []
    for path in eval_files:
        with open(path) as handle:
            record = json.load(handle)
        returns.append(record['return_mean'])
        # Keep only the file-name part of the policy path, then pick the
        # token that sits just before the final '_'-separated piece.
        file_name = record['policy_path'].rsplit('/', 1)[-1]
        count_token = file_name.split('_')[-2]
        update_counts.append(int(count_token))
    return returns, update_counts

# One (mean return, update count) pair per evaluation file.
mean_list, nupdate_list = get_mean_nupdate_list(eval_files)

# The pairs come back in eval-file order, which is not guaranteed to be
# increasing in update count. Build the permutation that orders them by
# update count (ties keep their original relative order) and apply it to
# both lists so they stay aligned.
sorted_idx = sorted(range(len(nupdate_list)), key=nupdate_list.__getitem__)
mean_list = list(map(mean_list.__getitem__, sorted_idx))
nupdate_list = list(map(nupdate_list.__getitem__, sorted_idx))

sns.set(style='darkgrid')

# Long-form frame for seaborn: one row per evaluated checkpoint.
mean_dict = {'return_mean': mean_list, 'n_update': nupdate_list}
mean_df = pd.DataFrame(mean_dict)

ax = sns.lineplot(data=mean_df, x='n_update', y='return_mean')
ax.set_title(f"{benchmark_name}-{train_exp_ident}")

# Save the figure into the test results directory.
fig = ax.get_figure()
fig.savefig(f"{test_dir}/return_curve.png")

print(mean_list, nupdate_list)
print(sorted_idx)

## Plot loss curves

In [None]:
# Plot the training loss against the number of updates and save the figure
# into the test results directory.

# progress.csv is read from the parent directory of policy_dir, made absolute
# against the current working directory.
train_folder = Path(policy_dir).parent.absolute()
progress_path = os.path.join(train_folder, 'progress.csv')
progress_df = pd.read_csv(progress_path)

# NOTE(review): sns.lineplot draws on the current matplotlib axes. This
# assumes the figure from the previous cell is no longer current (the
# notebook inline backend closes figures after each cell by default);
# confirm before running this notebook as a plain script, where the loss
# curve could land on the return-curve axes.
ax = sns.lineplot(x='n_updates', y='loss', data=progress_df)
ax.set_title(f"{benchmark_name}-{train_exp_ident}")

fig = ax.get_figure()
fig.savefig(f"{test_dir}/loss_curve.png")