In [None]:
import json
import os
import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Directory of the evaluation run being analysed; config.json and the eval
# result files are read from it, and the figures are written back into it.
test_dir = "/home/cynthiachen/il-representations/runs/il_test_runs/8"

# Full paths of every directory entry whose name contains 'eval', in
# lexicographic order.
eval_files = sorted(
    os.path.join(test_dir, fname)
    for fname in os.listdir(test_dir)
    if 'eval' in fname
)

with open(os.path.join(test_dir, 'config.json')) as f:
    test_config = json.load(f)

benchmark_name = test_config['env_cfg']['task_name']
policy_dir = test_config['policy_dir']
# NOTE(review): this is a plain substring test, so any "no" anywhere in the
# path selects "no-augs" -- confirm this matches the directory naming scheme.
exp_ident = "with-augs" if "no" not in policy_dir else "no-augs"

# For a path ending in '/', os.path.dirname returns the same path without the
# trailing separator, so the parent lookup below resolves correctly.
if policy_dir[-1] == '/':
    policy_dir = os.path.dirname(policy_dir)

# The training config is read from the parent directory of policy_dir.
train_config_file = os.path.join(os.path.dirname(policy_dir), 'config.json')
with open(train_config_file) as f:
    train_config = json.load(f)

# For the 'full-trajs' identifier, substitute the configured trajectory count
# when that count is set (truthy).
train_exp_ident = train_config['exp_ident']
if train_exp_ident == 'dmc-full-trajs-consistent-augs' and train_config['bc']['n_trajs']:
    train_exp_ident = f"dmc-{train_config['bc']['n_trajs']}-trajs-consistent-augs"

print(train_exp_ident, test_config, benchmark_name, policy_dir, sep='\n')

## Plot return curves

In [None]:
def get_result_dict(eval_files, max_n_update=2000000):
    """Collect mean returns from a list of evaluation result files.

    Each file is a JSON object with a 'policy_path' entry plus any number of
    entries whose value is a dict containing 'return_mean'. The number of
    updates is parsed from the policy file name: it is the second-to-last
    '_'-separated field of the last path component.

    Args:
        eval_files: Iterable of paths to evaluation JSON files.
        max_n_update: Files whose parsed update count exceeds this value are
            skipped. Pass None to keep every file. Defaults to 2000000.

    Returns:
        A dict with an 'n_update' list and one list per result key. All lists
        have the same length and are index-aligned, in the order the files
        were given. A key absent from a file gets NaN at that position.

    Raises:
        KeyError: If a file has no 'policy_path' entry.
        ValueError: If the update count in the policy file name is not an int.
    """
    result_dict = {'n_update': []}

    for eval_file in eval_files:
        with open(eval_file) as f:
            test_result = json.load(f)

        policy_name = test_result['policy_path'].split('/')[-1]
        nupdate = int(policy_name.split('_')[-2])
        if max_n_update is not None and nupdate > max_n_update:
            continue

        # Number of files already recorded; used to back-fill new keys.
        n_seen = len(result_dict['n_update'])
        for key, value in test_result.items():
            # For procgen, we have different dicts for train_level and test_level results
            if isinstance(value, dict) and 'return_mean' in value:
                # A key first seen in a later file is back-filled with NaN so
                # its list stays index-aligned with 'n_update'.
                series = result_dict.setdefault(key, [float('nan')] * n_seen)
                series.append(value['return_mean'])

        result_dict['n_update'].append(nupdate)

        # Pad every known key that this file did not report, so callers can
        # safely index all lists with the same positions.
        n_seen += 1
        for series in result_dict.values():
            if len(series) < n_seen:
                series.append(float('nan'))

    return result_dict

result_dict = get_result_dict(eval_files)

# The results might not be sorted according to nupdates, so we make sure
# they are sorted correctly here.
sorted_idx = sorted(range(len(result_dict['n_update'])), key=lambda k: result_dict['n_update'][k])
result_dict = {key: [values[idx] for idx in sorted_idx] for key, values in result_dict.items()}

# seaborn looks columns up by name, so plot from a DataFrame holding
# 'n_update' plus one column per result key.
df = pd.DataFrame(result_dict)

sns.set(style='darkgrid')

for key, value in result_dict.items():
    # 'n_update' is the x-axis, not a return curve: no plot, file or average.
    if key == 'n_update':
        continue

    # A fresh figure per key keeps each saved image to its own curve instead
    # of accumulating every earlier curve on one shared axes.
    fig, ax = plt.subplots()
    sns.lineplot(x='n_update', y=key, data=df, label=key, ax=ax)
    ax.set_title(f"{benchmark_name}-{exp_ident}")
    ax.set_ylabel("Return")
    fig.savefig(f"{test_dir}/return_curve-{key}.png")

    # The average leaves out the first (lowest n_update) evaluation; guard
    # against a series that has nothing left after dropping it.
    later_returns = value[1:]
    if later_returns:
        print(f"Average number: {sum(later_returns)/len(later_returns)}")
    else:
        print(f"Average number: n/a (fewer than two evaluations for {key})")

## Plot loss curves

In [None]:
from pathlib import Path

# progress.csv is read from the parent directory of policy_dir.
train_folder = Path(policy_dir).parent.absolute()
progress_path = os.path.join(train_folder, 'progress.csv')
progress_df = pd.read_csv(progress_path)

# Draw on a dedicated figure so the loss curve can never land on axes left
# over from an earlier plot when the cells run in a single process.
fig, ax = plt.subplots()
sns.lineplot(x='n_updates', y='loss', data=progress_df, ax=ax)
ax.set_title(f"{benchmark_name}-{train_exp_ident}")

fig.savefig(f"{test_dir}/loss_curve.png")