In [None]:
import json
import os

import matplotlib.pyplot as plt
import numpy as np

from collections import defaultdict

from transformers import DecisionTransformerConfig

# Suppress all Python warnings for the rest of the notebook.
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Directory that holds one sub-directory of evaluation outputs per model.
DIR = "./eval-outputs-pipeline"

# Model-type prefixes that a model name may start with (followed by "-").
MODEL_TYPES = [
    'dt',
    'ardt-simplest',
    'ardt-vanilla',
    'ardt-full',
]

In [None]:
# Collect every sub-directory of DIR as "DIR/<name>", in sorted order.
# The "/" separator is kept explicit because later cells split paths on "/".
results_paths = sorted(
    f"{DIR}/{entry}"
    for entry in os.listdir(DIR)
    if os.path.isdir(os.path.join(DIR, entry))
)
results_paths

In [None]:
# Column-oriented results table: every key maps to a list with one entry per
# recognised model, so index j across all lists describes the same model.
models_to_results = defaultdict(list)

for i, path in enumerate(results_paths):
    run_dir = path.split("/")[-1]
    # Drop the last 10 characters of the directory name to get the model name.
    # NOTE(review): presumably a fixed-width suffix (e.g. a date stamp) — confirm.
    model_name = run_dir[:-10]

    # Resolve the model type *before* touching disk or the network, so that
    # directories that are not one of MODEL_TYPES are skipped cheaply instead
    # of triggering a JSON read and a Hub request whose results are discarded.
    # Taking the first match also guarantees exactly one 'type' entry per model,
    # and avoids binding the builtin name `type` as a loop variable.
    model_type = next(
        (candidate for candidate in MODEL_TYPES if model_name.startswith(candidate + "-")),
        None,
    )
    if model_type is None:
        continue

    # Per-episode returns recorded for this model.
    with open(path + "/env-adv.json", "r") as f:
        model_returns = json.load(f)['ep_return']

    # The model config is fetched from the Hub repo named after the directory;
    # only its lambda1 / lambda2 attributes are used here.
    model_config = DecisionTransformerConfig.from_pretrained("afonsosamarques/" + run_dir, use_auth_token=True)

    # The dataset name is the last "-"-separated token of the model name.
    dataset = model_name.split("-")[-1]

    models_to_results['type'].append(model_type)
    models_to_results['name'].append(model_name)
    # 'number' is the position of the directory in results_paths (skipped
    # directories leave gaps).
    models_to_results['number'].append(i)
    # Mean and std are truncated to integers.
    models_to_results['return_mean'].append(int(np.mean(model_returns)))
    models_to_results['return_std'].append(int(np.std(model_returns)))
    models_to_results['lambda1'].append(model_config.lambda1)
    models_to_results['lambda2'].append(model_config.lambda2)
    models_to_results['dataset'].append(dataset)
    # Human-readable label; the " | " separators are relied on when the ids
    # are split and column-aligned afterwards.
    model_id = f"{model_type} | {dataset} | l1 = {model_config.lambda1} | l2 = {model_config.lambda2}"
    models_to_results['id'].append(model_id)

def get_length(length, max_length, part, nparts=4):
    """Compute a padded width for a field of a "|"-separated label.

    Args:
        length: Current length of the field.
        max_length: Length of the longest field in the same column.
        part: Column index of the field.
        nparts: Modulus applied to ``part`` to pick the padding factor.

    Returns:
        ``max_length`` if the field is already that long; otherwise
        ``max_length`` plus the truncated product of the shortfall
        ``(max_length - length)`` and a factor of 0.65 when
        ``part % nparts == 0``, 0.90 when ``part % nparts == 1``, and 0 for
        any other column.
    """
    position = part % nparts
    if position == 0:
        multiple = 0.65
    elif position == 1:
        multiple = 0.90
    else:
        multiple = 0

    if length == max_length:
        return max_length
    shortfall = max_length - length
    return max_length + int(shortfall * multiple)

# Column-align the model ids: split each id on "|", measure the widest field
# per column, left-justify every field to the width chosen by get_length, and
# glue the fields back together with "|".
parts = [model_id.split("|") for model_id in models_to_results['id']]
counts = list(range(len(parts[0])))
max_lengths = [max(len(fields[col]) for fields in parts) for col in counts]
aligned_strings = [
    "|".join(
        field.ljust(get_length(len(field), width, col))
        for field, width, col in zip(fields, max_lengths, counts)
    )
    for fields in parts
]
models_to_results['id'] = aligned_strings
models_to_results['id']

In [None]:
# Set to True to restrict the results to a single dataset.
filter_by_dataset = False  # FIXME
# Unique dataset names, sorted so the order is the same on every kernel run.
# Positional lookups (datasets[dataset_idx], datasets.index(...)) depend on
# this order; a bare list(set(...)) would reshuffle it between sessions
# because string hashing is randomised per process.
datasets = sorted(set(models_to_results['dataset']))
datasets

In [None]:
# Optionally keep only the rows whose dataset matches the selected one.
# Every column is filtered with the same row mask, so the columns stay aligned.
if filter_by_dataset:
    dataset_idx = -1  # FIXME
    dataset = datasets[dataset_idx]
    keep = [name == dataset for name in models_to_results['dataset']]
    models_to_results = {
        key: [value for row, value in enumerate(values) if keep[row]]
        for key, values in models_to_results.items()
    }
models_to_results

In [None]:
def get_color(model_type, idx):
    """Pick the fill color for a model's marker.

    Each model type has its own sequential colormap, resampled to as many
    levels as there are entries in ``models_to_results['name']``; ``idx``
    selects the level.

    Args:
        model_type: One of 'dt', 'ardt-simplest', 'ardt-vanilla', 'ardt-full'.
        idx: Integer index into the resampled colormap.

    Returns:
        The color produced by calling the colormap with ``idx``.

    Raises:
        RuntimeError: If ``model_type`` is not one of the known types.
    """
    cmap_by_type = {
        'dt': 'Blues',
        'ardt-simplest': 'Oranges',
        'ardt-vanilla': 'Reds',
        'ardt-full': 'Purples',
    }
    if model_type not in cmap_by_type:
        raise RuntimeError(f"Model type {model_type} not recognized.")
    # plt.get_cmap(name, lut) is used instead of plt.cm.get_cmap, which was
    # deprecated in Matplotlib 3.7 and removed in 3.9.
    cmap = plt.get_cmap(cmap_by_type[model_type], len(models_to_results['name']))
    return cmap(idx)


def get_ecolor(dataset, idx):
    """Pick the marker edge color for a dataset.

    The color is chosen by the dataset's position in the module-level
    ``datasets`` list. The palette has four entries; positions beyond that
    wrap around, so with more than four datasets some edge colors repeat
    instead of raising ``IndexError``.

    Args:
        dataset: Dataset name; must be present in ``datasets``.
        idx: Unused; kept so existing callers keep working.

    Returns:
        A Matplotlib color name.

    Raises:
        ValueError: If ``dataset`` is not in ``datasets``.
    """
    colors = [
        'magenta',
        'sienna',
        'olivedrab',
        'grey',
    ]
    dataset_idx = datasets.index(dataset)
    return colors[dataset_idx % len(colors)]
    

# Scatter plot of return mean vs. return std, one marker per model.
# Fill color encodes the model type, edge color encodes the dataset.
plt.figure(figsize=(10, 10))

# NOTE(review): plotting starts at row 9, so the first nine rows of
# models_to_results are never drawn — confirm this is intended.
first_plotted = 9
n_models = len(models_to_results['return_mean'])

for i in range(first_plotted, n_models):
    plt.scatter(
        models_to_results['return_mean'][i],
        models_to_results['return_std'][i],
        s=100,
        color=get_color(models_to_results['type'][i], i),
        edgecolors=get_ecolor(models_to_results['dataset'][i], i),
        linewidths=3,
        label=models_to_results['id'][i],
    )

plt.xlabel("Return Mean")
plt.ylabel("Return Std")
plt.title("Model Comparison")
# Legend is placed outside the axes, to the right of the plot.
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0., ncol=1)
plt.show()