In [None]:
import json
import os
import subprocess

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

from toyenv_one import OneStepEnvVOne
from toyenv_two import OneStepEnvVTwo
from toyenv_three import OneStepEnvVThree


def find_root_dir():
    """Return the path of ``codebase/toy_problem`` inside the repository checkout.

    The repository root is taken from ``git rev-parse --show-toplevel``. If that
    fails for any reason (git not installed, not inside a work tree, ...), the
    current working directory is cut right after the first occurrence of
    ``action-robust-decision-transformer``; if the working directory does not
    contain that name at all, the working directory itself is used as the base.

    Returns:
        str: ``<base>[/action-robust-decision-transformer]/codebase/toy_problem``,
        where the repository name is appended only when ``<base>`` does not
        already end with it.
    """
    repo_name = 'action-robust-decision-transformer'
    try:
        # Silence git's "not a git repository" message on the fallback path.
        root_dir = subprocess.check_output(
            ['git', 'rev-parse', '--show-toplevel'], stderr=subprocess.DEVNULL
        ).strip().decode('utf-8')
    except Exception:
        # Best-effort fallback: derive the root from the working directory.
        cwd = os.getcwd()
        name_pos = cwd.find(repo_name)
        # str.find returns -1 when the name is absent; slicing with that index
        # would silently truncate the path, so keep the whole cwd instead.
        root_dir = cwd if name_pos == -1 else cwd[:name_pos + len(repo_name)]
    if not root_dir.endswith(repo_name):
        root_dir += '/' + repo_name
    return root_dir + "/codebase/toy_problem"

In [None]:
# Evaluation budget per model. NOTE(review): presumably the number of
# evaluation rollouts; it is not referenced by the plotting code visible here.
EVALS_PER_MODEL = 1_024
# Toy-environment versions to plot; each one is read from results/toy-env-<version>.
envs_to_consider = ["v1", "v3"]

In [None]:
def model_type_to_id(model_type):
    """Map a model-type label to the integer used to order it in the plots.

    Args:
        model_type: One of ``"best-conditional-worst-case"``, ``"dt"``,
            ``"ardt-vanilla"`` or ``"ardt-multipart"``.

    Returns:
        int: 0, 1, 2 or 3 respectively.

    Raises:
        ValueError: For any other value.
    """
    id_by_type = {
        "best-conditional-worst-case": 0,
        "dt": 1,
        "ardt-vanilla": 2,
        "ardt-multipart": 3,
    }
    try:
        return id_by_type[model_type]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which are simply "not a known type".
        raise ValueError(f"Invalid model type: {model_type}") from None


# For every toy environment: load the worst-case action frequencies of each
# trained model, add the theoretically best conditional worst-case action
# distribution for every target return, and plot one bar chart per
# evaluation target.

# Plot order of the model types (ascending model_type_to_id); pinning it keeps
# both bar positions and colours consistent across subplots.
MODEL_TYPE_ORDER = ['best-conditional-worst-case', 'dt', 'ardt-vanilla', 'ardt-multipart']
custom_palette = ['#FFA500', '#66CCFF', '#3399FF', '#336699']
MODEL_PALETTE = dict(zip(MODEL_TYPE_ORDER, custom_palette))
# Frequencies below this are drawn at this height so every bar stays visible.
MIN_VISIBLE_FREQ = 0.003
ENV_CLASSES = {"v1": OneStepEnvVOne, "v2": OneStepEnvVTwo, "v3": OneStepEnvVThree}

for env_idx, env_version in enumerate(envs_to_consider):
    results_dir = find_root_dir() + f"/results/toy-env-{env_version}"
    # Unknown versions raise KeyError instead of silently falling back to v3.
    env_class = ENV_CLASSES[env_version]

    # Long-format columns: one row per (model, target return, action).
    model_ids = []
    model_types = []
    targets = []
    pr_actions_selected = []
    pr_actions_freqs = []

    for model_type in ['dt', 'ardt-vanilla', 'ardt-multipart']:
        results_path = f"{results_dir}/{model_type}/results.json"
        with open(results_path, "r") as f:
            results = json.load(f)["worstcase"]

        for entry in results:
            for pr_action, pr_action_freq in entry["pr_action_freqs"].items():
                model_ids.append(model_type_to_id(model_type))
                model_types.append(model_type)
                targets.append(entry["target_return"])
                # Normalise the JSON key's formatting (drop '.', comma-separate).
                pr_actions_selected.append(str(pr_action).replace(".", "").replace(" ", ", "))
                # Averaged over the stored values for this action.
                pr_actions_freqs.append(np.mean(pr_action_freq))

    # np.unique snapshots the targets seen so far, so appending below is safe.
    for target_return in np.unique(targets):
        # Copy before mutating: the returned list may be shared by the env class.
        all_actions = list(env_class.get_all_possible_pr_actions())
        for th_pr_action, th_pr_action_freq in env_class.get_best_pr_action(target_return):
            model_ids.append(model_type_to_id('best-conditional-worst-case'))
            model_types.append('best-conditional-worst-case')
            targets.append(target_return)
            pr_actions_selected.append(str(th_pr_action))
            pr_actions_freqs.append(th_pr_action_freq)
            all_actions.remove(th_pr_action)
        # Actions the best policy never picks are recorded with frequency 0.
        for action in all_actions:
            model_ids.append(model_type_to_id('best-conditional-worst-case'))
            model_types.append('best-conditional-worst-case')
            targets.append(target_return)
            pr_actions_selected.append(str(action))
            pr_actions_freqs.append(0)

    df = pd.DataFrame(
        {
            "model_id": model_ids,
            "model_type": model_types,
            "target": targets,
            "action": pr_actions_selected,
            "freq": pr_actions_freqs,
        }
    )

    # Floor tiny/zero frequencies so that some part of every bar is visible.
    df['freq'] = df['freq'].clip(lower=MIN_VISIBLE_FREQ)

    eval_targets = env_class.get_eval_targets()
    num_targets = len(eval_targets)

    num_cols = 3
    num_rows = int(np.ceil(num_targets / num_cols))

    # squeeze=False always returns a 2-D grid of axes, even for a single row.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 5 * num_rows), squeeze=False)

    for target_idx, target_return in enumerate(eval_targets):
        ax = axes[target_idx // num_cols, target_idx % num_cols]

        filtered_df = df[df["target"] == target_return]
        filtered_df = filtered_df.sort_values(by=["model_id", "action"])

        # Explicit hue order + per-model colours: a model missing for this
        # target no longer shifts the colours of the remaining models.
        sns.barplot(
            x='action', y='freq', data=filtered_df, hue='model_type',
            hue_order=MODEL_TYPE_ORDER, palette=MODEL_PALETTE, ax=ax,
        )

        ax.set_xlabel("Action")
        ax.set_ylabel("Frequency")
        # NOTE(review): min(2, ...) titles v3 as "Version 2" -- presumably
        # deliberate naming for the presented environments; confirm.
        ax.set_title(f"Toy Environment Version {min(2, int(env_version[-1]))}, Target Return {target_return}")
        # Per-axes legends are hidden; a single figure-level legend is drawn below.
        ax.legend().set_visible(False)

    # Remove the unused cells of the last row of the grid.
    for i in range(num_targets, num_rows * num_cols):
        fig.delaxes(axes.flat[i])

    handles, labels = ax.get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper right', fontsize='small', bbox_to_anchor=(1.15, 0.95))

    plt.tight_layout()
    plt.show()

    print("\n ============================================================================================ \n")