In [None]:
import json
import os
import subprocess

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

from toyenv_one import OneStepEnvVOne
from toyenv_two import OneStepEnvVTwo
from toyenv_three import OneStepEnvVThree


def find_root_dir():
    """Return the toy-problem directory, i.e. `<repo root>/codebase/toy_problem`.

    The repository root is taken from `git rev-parse --show-toplevel` when that
    command succeeds. Otherwise it is recovered from the current working
    directory by cutting the path right after the first occurrence of
    'action-robust-decision-transformer'. If the cwd does not contain that name,
    the whole cwd is used.

    If the resolved root does not itself end with
    'action-robust-decision-transformer', that name is appended as a
    sub-directory before adding '/codebase/toy_problem'.
    """
    repo_name = 'action-robust-decision-transformer'
    try:
        root_dir = subprocess.check_output(
            ['git', 'rev-parse', '--show-toplevel'], stderr=subprocess.DEVNULL
        ).strip().decode('utf-8')
    except Exception:
        # Best effort: git missing or not inside a work tree -> derive from the cwd.
        cwd = os.getcwd()
        idx = cwd.find(repo_name)
        # str.find returns -1 when the name is absent; slicing with idx + len(...)
        # would then keep an arbitrary prefix of cwd, so use the whole cwd instead.
        root_dir = cwd[:idx + len(repo_name)] if idx != -1 else cwd
    if not root_dir.endswith(repo_name):
        root_dir += '/' + repo_name
    return root_dir + "/codebase/toy_problem"

In [None]:
# Number of evaluations behind each reported result; its square root converts
# the averaged standard deviation into a standard error for the error bars.
EVALS_PER_MODEL = 1024
# Results sub-directory to read under results/: "stochastic" or "deterministic".
ARCH_TYPE = "deterministic"

In [None]:
# For each toy environment: bar chart of observed worst-case return vs. target
# return, one bar per model type, with standard-error bars.
for env_version in ["v1", "v2", "v3"]:
    results_dir = find_root_dir() + f"/results/{ARCH_TYPE}/toy-env-{env_version}"
    model_order = ['dt', 'ardt-simplest', 'ardt-full']

    model_types = []
    targets = []
    means = []
    stdevs = []

    for model_type in model_order:
        results_path = f"{results_dir}/{model_type}/results.json"
        with open(results_path, "r") as f:
            results = json.load(f)["worstcase"]

        # plot mean and std of mean returns
        for entry in results:
            model_types.append(model_type)
            targets.append(entry["target_return"])
            means.append(np.mean(entry["mean_returns"]))
            # averaged std scaled by 1/sqrt(EVALS_PER_MODEL) -> standard error
            stdevs.append(np.mean(entry["std_returns"] / np.sqrt(EVALS_PER_MODEL)))

    df = pd.DataFrame(
        {
            "model_type": model_types,
            "target": targets,
            "mean": means,
            "stdev": stdevs,
        }
    )

    # Pin the category orders explicitly so bar positions are known, and look up
    # each bar's error by (model_type, target). Zipping ax.patches with df rows
    # only works when the JSON happens to list targets in seaborn's sorted order.
    target_order = sorted(df["target"].unique())
    stderr_by_bar = {
        (m, t): s for m, t, s in zip(df["model_type"], df["target"], df["stdev"])
    }

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(
        x='target', y='mean', data=df, hue='model_type',
        order=target_order, hue_order=model_order, palette='viridis', ax=ax,
    )

    err_x, err_y, err_size = [], [], []
    # seaborn issues one ax.bar call (one BarContainer) per hue level, in hue_order.
    for model_type, container in zip(model_order, ax.containers):
        for bar in container:
            x_center = bar.get_x() + 0.5 * bar.get_width()
            # categorical x positions are 0..n-1; hue offsets are < 0.5 from them
            target = target_order[int(round(x_center))]
            err_x.append(x_center)
            err_y.append(bar.get_height())
            err_size.append(stderr_by_bar.get((model_type, target), np.nan))
    ax.errorbar(x=err_x, y=err_y, yerr=err_size, fmt="none", c="k")

    ax.set_xlabel("Target Returns")
    ax.set_ylabel("Observed Returns")
    ax.set_title(f"Target vs Observed for Env-{env_version}")
    ax.legend(fontsize="small", bbox_to_anchor=(1.22, 1.0))

    if env_version == "v2":
        # Integer-spaced y ticks covering the range of the error bars.
        lower = df["mean"].to_numpy() - df["stdev"].to_numpy()
        upper = df["mean"].to_numpy() + df["stdev"].to_numpy()
        ax.set_yticks(np.arange(int(np.min(lower)) + 1, int(np.max(upper)) + 1, 1.0))

    plt.show();


In [None]:
def model_type_to_id(model_type):
    """Map a model or reference label to its integer id.

    'best' -> 0, 'correct' -> 1, 'dt' -> 2, 'ardt-simplest' -> 3,
    'ardt-full' -> 4.

    Raises:
        ValueError: if `model_type` is none of the labels above.
    """
    known_ids = (
        ("dt", 2),
        ("ardt-simplest", 3),
        ("ardt-full", 4),
        ("correct", 1),
        ("best", 0),
    )
    for label, model_id in known_ids:
        if model_type == label:
            return model_id
    raise ValueError(f"Invalid model type: {model_type}")

# Two fixed named colours first, then three viridis shades: five colours in total,
# matching the five labels known to model_type_to_id.
specific_colors = ['seagreen', 'gold']
viridis_palette = sns.color_palette("viridis", n_colors=3)
custom_palette = [*specific_colors, *viridis_palette]

def _reference_action_freqs(freq_pairs, all_actions):
    """Expand reference (action, frequency) pairs so every possible action gets a bar.

    Args:
        freq_pairs: iterable of (action, frequency) pairs.
        all_actions: every possible action; each action in `freq_pairs` must be
            present (list.remove raises ValueError otherwise). Not modified.

    Returns:
        list of (str(action), frequency): the given pairs in order, then every
        remaining action from `all_actions` with frequency 0.
    """
    remaining = list(all_actions)  # copy: never mutate the caller's / env class's list
    expanded = []
    for action, freq in freq_pairs:
        expanded.append((str(action), freq))
        remaining.remove(action)
    expanded.extend((str(action), 0) for action in remaining)
    return expanded


# For each toy environment and eval target: compare how often each model picked
# each action against the reference "correct" and "best" action distributions.
for env_version in ["v1", "v2", "v3"]:
    results_dir = find_root_dir() + f"/results/{ARCH_TYPE}/toy-env-{env_version}"
    env_class = OneStepEnvVOne if env_version == "v1" else (OneStepEnvVTwo if env_version == "v2" else OneStepEnvVThree)

    model_ids = []
    model_types = []
    targets = []
    pr_actions_selected = []
    pr_actions_freqs = []

    for model_type in ['dt', 'ardt-simplest', 'ardt-full']:
        results_path = f"{results_dir}/{model_type}/results.json"
        with open(results_path, "r") as f:
            results = json.load(f)["worstcase"]

        for entry in results:
            for pr_action, pr_action_freq in entry["pr_action_freqs"].items():
                model_ids.append(model_type_to_id(model_type))
                model_types.append(model_type)
                targets.append(entry["target_return"])
                # NOTE(review): presumably the JSON keys are numpy-style reprs like
                # "[1. 0.]"; this rewrites them to "[1, 0]" so they match str() of the
                # reference actions below -- confirm against the results writer.
                pr_actions_selected.append(str(pr_action).replace(".", "").replace(" ", ", "))
                pr_actions_freqs.append(np.mean(pr_action_freq))

    # Reference rows: every possible action appears, with 0 for actions the
    # reference distribution does not use.
    for target_return in np.unique(targets):
        for reference_type, get_reference in (
            ('correct', env_class.get_correct_pr_action),
            ('best', env_class.get_best_pr_action),
        ):
            all_actions = env_class.get_all_possible_pr_actions()
            for action_label, freq in _reference_action_freqs(get_reference(target_return), all_actions):
                model_ids.append(model_type_to_id(reference_type))
                model_types.append(reference_type)
                targets.append(target_return)
                pr_actions_selected.append(action_label)
                pr_actions_freqs.append(freq)

    df = pd.DataFrame(
        {
            "model_id": model_ids,
            "model_type": model_types,
            "target": targets,
            "action": pr_actions_selected,
            "freq": pr_actions_freqs,
        }
    )

    for target_return in env_class.get_eval_targets():
        filtered_df = df[df["target"] == target_return].sort_values(by=["model_id", "action"])

        fig, ax = plt.subplots(figsize=(10, 6))
        sns.barplot(x='action', y='freq', data=filtered_df, hue='model_type', palette=custom_palette, ax=ax)

        ax.set_xlabel("Action")
        ax.set_ylabel("Frequency")
        ax.set_title(f"Action Frequency for Env-{env_version} and Target Return {target_return}")
        ax.legend(fontsize="small", bbox_to_anchor=(1.22, 1.0))

        plt.show();

    print("\n ============================================================================================ \n")