In [None]:
import json
import os
import subprocess

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

from toyenv_one import OneStepEnvVOne
from toyenv_two import OneStepEnvVTwo
from toyenv_three import OneStepEnvVThree


def find_root_dir():
    """Return the path of the project's ``codebase/toy_problem`` directory.

    The repository root is taken from ``git rev-parse --show-toplevel``. If git is not
    installed, the current directory is not inside a git work tree, or git's output
    cannot be decoded, the root is instead derived from the current working directory:
    everything up to and including the first ``action-robust-decision-transformer``
    occurrence, or the whole working directory when that name does not occur in it.

    Returns:
        str: ``<root>/codebase/toy_problem``. When ``<root>`` does not already end with
        ``action-robust-decision-transformer``, that directory name is appended to it
        first (i.e. the detected root is treated as the repository's parent directory).
    """
    repo_name = 'action-robust-decision-transformer'
    try:
        root_dir = subprocess.check_output(
            ['git', 'rev-parse', '--show-toplevel'],
            # Keep git's "not a git repository" message out of the notebook output.
            stderr=subprocess.DEVNULL,
        ).strip().decode('utf-8')
    except (OSError, subprocess.SubprocessError, UnicodeDecodeError):
        # OSError: git missing; SubprocessError: non-zero exit (not a repository).
        cwd = os.getcwd()
        name_pos = cwd.find(repo_name)
        # find() returns -1 when the name is absent; slicing with it would cut cwd to an
        # arbitrary prefix, so use the whole working directory in that case.
        root_dir = cwd[:name_pos + len(repo_name)] if name_pos != -1 else cwd
    if not root_dir.endswith(repo_name):
        root_dir += '/' + repo_name
    return root_dir + "/codebase/toy_problem"

In [None]:
# Sample count assumed behind each reported std; the plots divide std by
# sqrt(EVALS_PER_MODEL) to show a standard error of the mean.
EVALS_PER_MODEL = 2 ** 10
# Toy environment versions to plot (v2 is currently not included).
envs_to_consider = ["v1", "v3"]

In [None]:
def return_correct_return(env, target):
    """Return the hard-coded reference return for ``target`` in toy environment ``env``.

    These values are plotted as the "best-conditional-worst-case" bars.

    Args:
        env: environment version string, one of "v1", "v2" or "v3".
        target: target return; matched with ``==`` so equal ints and floats
            (e.g. ``5`` and ``5.0``) are interchangeable.

    Returns:
        The reference return as a float, or ``None`` when ``env`` or ``target``
        is not in the table.
    """
    reference_table = {
        "v1": ((0.5, 0.5), (2.0, 1.5), (1.5, 1.5)),
        "v2": ((5, -3.0), (-3, -3.0), (-2, -2.0), (1, 1.0), (2, 2.0)),
        "v3": ((0, 0.0), (1.5, 0.75), (2.5, 0.75), (3, 0.75), (0.75, 0.75), (1, 0.75)),
    }
    for known_target, reference_return in reference_table.get(env, ()):
        if target == known_target:
            return reference_return
    return None



In [None]:
# For each toy environment: a grouped bar chart of each model's worst-case return per
# target return, next to the hard-coded reference ("best-conditional-worst-case") return.
LEARNED_MODEL_TYPES = ['dt', 'ardt-vanilla', 'ardt-multipart']
REFERENCE_MODEL_TYPE = 'best-conditional-worst-case'
# Colours are assigned to hue levels in hue_order: reference first (orange), then the
# learned models in LEARNED_MODEL_TYPES order.
CUSTOM_PALETTE = ['#FFA500', '#66CCFF', '#3399FF', '#336699']
# A bar of height exactly 0 would be invisible, so zero means are drawn at this height.
ZERO_BAR_HEIGHT = 0.003


def _load_worstcase_rows(env_version, results_dir):
    """Return plot records (model_type, target, mean, stdev) for one environment.

    For each learned model, reads the "worstcase" list from
    ``<results_dir>/<model_type>/results.json``; each entry is read for its
    "target_return", "mean_returns" and "std_returns" keys. A reference record
    (``REFERENCE_MODEL_TYPE``, stdev 0) is then added once for every target seen, valued
    by ``return_correct_return`` (None for targets it does not know).
    """
    rows = []
    for model_type in LEARNED_MODEL_TYPES:
        with open(f"{results_dir}/{model_type}/results.json", "r") as f:
            results = json.load(f)["worstcase"]
        for entry in results:
            rows.append({
                "model_type": model_type,
                "target": entry["target_return"],
                "mean": np.mean(entry["mean_returns"]),
                # std / sqrt(n) as a standard error; np.asarray so a JSON list
                # divides element-wise.
                "stdev": np.mean(np.asarray(entry["std_returns"]) / np.sqrt(EVALS_PER_MODEL)),
            })
    # Reference bars for every target a learned model was evaluated on (first-seen order),
    # built from the loaded data rather than from a variable left over from the loop.
    for target in dict.fromkeys(row["target"] for row in rows):
        rows.append({
            "model_type": REFERENCE_MODEL_TYPE,
            "target": target,
            "mean": return_correct_return(env_version, target),
            "stdev": 0.0,
        })
    return rows


def _add_error_bars(ax, df, order, hue_order):
    """Overlay each row's ``stdev`` as a black error bar on its own bar.

    Bars are matched to rows by (model_type, target), not by height (heights tie and
    carry no row identity). Assumes seaborn created one bar container per hue level in
    ``hue_order`` and centred category ``i`` of ``order`` at x = i, with dodged bars
    staying within half a unit of that centre. Assumes one row per
    (model_type, target); with duplicates the last row's stdev is used.
    """
    stdev_by_bar = {
        (model_type, target): stdev
        for model_type, target, stdev in zip(df["model_type"], df["target"], df["stdev"])
    }
    # Snapshot first: ax.errorbar appends its own containers to ax.containers.
    bar_containers = list(ax.containers)
    for model_type, container in zip(hue_order, bar_containers):
        for bar in container:
            height = bar.get_height()
            if np.isnan(height):
                continue  # placeholder bar for a missing (model_type, target) pair
            x_center = bar.get_x() + 0.5 * bar.get_width()
            category_idx = int(round(x_center))
            if not 0 <= category_idx < len(order):
                continue
            yerr = stdev_by_bar.get((model_type, order[category_idx]), 0.0)
            ax.errorbar(x_center, height, yerr=yerr, fmt='none', ecolor='black', capsize=3)


for env_version in envs_to_consider:
    results_dir = find_root_dir() + f"/results/toy-env-{env_version}"
    df = pd.DataFrame(
        _load_worstcase_rows(env_version, results_dir),
        columns=["model_type", "target", "mean", "stdev"],
    )

    present_models = set(df["model_type"])
    hue_order = [m for m in [REFERENCE_MODEL_TYPE, *LEARNED_MODEL_TYPES] if m in present_models]
    # Numeric x categories in ascending order (seaborn's default for numeric data).
    order = sorted(df["target"].unique())

    df["model_type"] = pd.Categorical(df["model_type"], categories=hue_order, ordered=True)
    df = df.sort_values(by=["model_type", "target"], ascending=[True, False]).reset_index(drop=True)
    # Plot a nudged copy so that `mean` keeps the raw values.
    df["bar_height"] = df["mean"].mask(df["mean"] == 0.0, ZERO_BAR_HEIGHT)

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(x='target', y='bar_height', data=df, hue='model_type',
                order=order, hue_order=hue_order, palette=CUSTOM_PALETTE, ax=ax)
    _add_error_bars(ax, df, order, hue_order)

    ax.set_xlabel("Target Returns")
    ax.set_ylabel("Observed Returns")
    # The displayed version number is capped at 2, so "v3" is titled as version 2.
    ax.set_title(f"Returns for Toy Environment Version {min(2, int(env_version[-1]))}")
    ax.legend(fontsize="small", bbox_to_anchor=(-0.05, -0.15), loc="lower left", ncol=2, borderaxespad=0.0)

    if env_version == "v2":
        # Unit-spaced y-ticks over the range covered by means +/- stdev (NaNs skipped).
        low = (df["mean"] - df["stdev"]).min()
        high = (df["mean"] + df["stdev"]).max()
        ax.set_yticks(np.arange(int(low) + 1, int(high) + 1, 1.0))

    plt.show()
    plt.close(fig)
