In [None]:
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import numpy as np
import pandas as pd
import seaborn as sns
import os
import re

from collections import defaultdict

## Configs

In [None]:
# Root folder of the evaluation outputs: the current working directory with its
# last eight characters removed, followed by "evaluation_protocol".
# NOTE(review): the fixed-length slice presumably strips a trailing directory
# name of the notebook's location — confirm; it silently yields a wrong path
# when the notebook is started from any other working directory.
dirpath = os.getcwd()[:-8] + "evaluation_protocol"
# Sub-folder of `dirpath`; results are read from <dirpath>/<dirname>/<_env>.
dirname = "pre-results-thesis"

# Environment whose results are loaded (also used in the plot titles).
_env = "halfcheetah"
# Only results whose parsed dataset type / version equal these values are kept.
_dataset_type = "robust"
_dataset_version = "level"
# When True, results of the 'arrl' and 'arrl-sgld' models are kept even though
# they do not match the dataset filters above.
_allow_baselines = True

## Data Prep

In [None]:
# Gather the evaluation results found below <dirpath>/<dirname>/<_env> and group
# the episode returns as model_to_results[model][parameter][multiplier] -> list.
model_to_results = {}
output_folder_path = os.path.join(dirpath, dirname + "/" + _env)

# A single os.walk pass already visits every nested directory, so no inner walk
# per sub-directory is needed. Files lying directly in the top folder are not
# collected; only files inside sub-directories are.
json_files = set()
for root, _, files in os.walk(output_folder_path):
    if root == output_folder_path:
        continue
    for file_name in files:
        if file_name.endswith(".json"):
            json_files.add(root + "/" + file_name)

# Sorted iteration keeps the order of the collected returns reproducible
# between runs (set iteration order is not).
for file_path in sorted(json_files):
    path_parts = file_path.split("/")

    # Third-from-last path component; it is parsed below for the model name.
    denom = path_parts[-3]
    # Parent directory name without its first 8 characters and without spaces.
    # NOTE(review): presumably drops a fixed-length prefix — confirm against
    # the directory layout written by the evaluation code.
    param_name = path_parts[-2][8:].replace(" ", "")
    if param_name not in ("body-mass", "friction"):
        continue

    # File stem; converted to float below and used as the multiplier value.
    param = path_parts[-1].split(".json")[0]

    first_digit = re.search(r'\d', denom)
    if first_digit is None:
        # No digit in the directory name: the name itself is used as model
        # name and type, and the run is labelled as 'online'.
        model_name = denom
        model_type = denom
        dataset_type = 'online'
        dataset_version = None
    else:
        # Everything before the first digit is the raw model name.
        model_name = denom[:first_digit.start()]
        # A raw name ending in "v" loses its last two characters and gets the
        # "_best" suffix; any other raw name just loses its last character.
        model_name = model_name[:-2] + "_best" if model_name.endswith("v") else model_name[:-1]
        if "ardt-vanilla" in denom:
            model_type = "ardt-vanilla"
        elif "ardt-multipart" in denom:
            model_type = "ardt-multipart"
        else:
            model_type = "dt"
        # The dataset type sits in the 2nd dash-separated field for "dt" and in
        # the 3rd one for the ARDT variants: it is the text before "train",
        # minus one trailing character.
        split_idx = 1 if model_type == "dt" else 2
        dataset_field = model_name.split("-")[split_idx]
        dataset_type = dataset_field[:dataset_field.find("train")][:-1]
        dataset_version = model_name.split("_")[-1]

    matches_filters = dataset_type == _dataset_type and dataset_version == _dataset_version
    is_allowed_baseline = _allow_baselines and model_name in ['arrl', 'arrl-sgld']
    if not (matches_filters or is_allowed_baseline):
        continue

    # "dt" is stored under the key "aredt" purely to influence the alphabetical
    # ordering of the models (hacky, but the plotting helpers expect this key).
    mt = "aredt" if model_type == "dt" else model_type

    if mt not in model_to_results:
        model_to_results[mt] = {"body-mass": defaultdict(list), "friction": defaultdict(list)}

    model_to_results[mt][param_name][float(param)].extend(pd.read_json(file_path)["ep_return"].to_list())

# sort by model name
model_to_results = dict(sorted(model_to_results.items(), key=lambda item: item[0]))

## Per Model Lineplots

In [None]:
def get_good_model_name(model_name):
    """Return the display label for an internal model key.

    Unknown keys yield None.
    """
    display_names = {
        "aredt": "DT",
        "ardt-vanilla": "Vanilla-ARDT",
        "ardt-multipart": "Multipart-ARDT",
        "arrl": "AR-DDPG",
        "arrl-sgld": "AR-DDPG-SGLD",
    }
    return display_names.get(model_name)
    
def get_good_param_name(param_name):
    """Return the display label for a perturbed-parameter key (None if unknown)."""
    display_names = {"body-mass": "Body Mass", "friction": "Friction"}
    return display_names.get(param_name)

In [None]:
# One row of panels per model and one column per perturbed parameter; every
# panel shows several summary statistics of the returns against the multiplier.
params = ["body-mass", "friction"]

# (legend label, colour, reducer applied to the list of returns)
summary_stats = [
    ("Mean", 'green', np.mean),
    ("Median", 'blue', np.median),
    ("5th Percentile", 'orange', lambda returns: np.percentile(returns, 5)),
    ("1st Percentile", 'brown', lambda returns: np.percentile(returns, 1)),
    ("Minimum", 'red', np.min),
]
tick_positions = np.arange(0.5, 2.1, 0.1)

# squeeze=False keeps `axes` two-dimensional, so axes[row, col] also works when
# only a single model has been loaded.
fig, axes = plt.subplots(
    nrows=len(model_to_results),
    ncols=len(params),
    figsize=(15, 5 * len(model_to_results)),
    squeeze=False,
)

for row, (model_type, d) in enumerate(sorted(model_to_results.items(), key=lambda item: item[0], reverse=True)):
    for col, param in enumerate(params):
        ax = axes[row, col]
        pv = sorted(d[param].keys())

        ax.set_title(f'{get_good_param_name(param)} Multipliers vs Returns for {get_good_model_name(model_type)}', fontsize=10.5)
        ax.set_xlabel(f'{get_good_param_name(param)} Multipliers')
        ax.set_ylabel('Returns')
        ax.set_xticks(tick_positions)
        ax.set_xticklabels([f'{x:.2f}' for x in tick_positions], rotation=90)

        for label, color, reducer in summary_stats:
            values = [reducer(d[param][key]) for key in pv]
            # Only the line carries the legend label; the scatter just marks
            # the individual multipliers.
            ax.plot(pv, values, color=color, label=label)
            ax.scatter(pv, values, color=color, s=10)

        # Reference line at multiplier 1.0.
        ax.axvline(x=1.0, color='grey', linestyle='--')


# All panels share the same series, so one figure-level legend is built from
# the first panel.
handles, labels = axes[0, 0].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', ncol=6, bbox_to_anchor=(0.5, 1.02))

plt.tight_layout()
plt.show()

## Per-Stat Gridplots

In [None]:
def get_stat_func(stat):
    """Return the reducer that computes the statistic named by `stat`.

    Supported keys are "mean", "median", "min", "95th" and "99th". Note that
    "95th" and "99th" select the 5th and 1st percentile respectively.

    Raises:
        ValueError: if `stat` is not one of the supported keys, so that a typo
            fails here instead of later as "NoneType is not callable".
    """
    reducers = {
        "mean": np.mean,
        "median": np.median,
        "min": np.min,
        "95th": lambda x: np.percentile(x, 5),
        "99th": lambda x: np.percentile(x, 1),
    }
    try:
        return reducers[stat]
    except (KeyError, TypeError):
        raise ValueError(
            f"unknown statistic {stat!r}; expected one of {sorted(reducers)}"
        ) from None
    
def get_stat_name(stat):
    """Return the display label for a statistic key (None if unknown).

    The keys "95th" and "99th" are labelled as the 5th and 1st percentile.
    """
    labels = {
        "mean": "Mean",
        "median": "Median",
        "min": "Worst-Case",
        "95th": "5th Percentile",
        "99th": "1st Percentile",
    }
    return labels.get(stat)
    
def get_env_name(name):
    """Return the display label for an environment key (None if unknown)."""
    labels = {
        "halfcheetah": "HalfCheetah",
        "hopper": "Hopper",
        "walker2d": "Walker2D",
    }
    return labels.get(name)
    

In [None]:
# Statistics drawn by the cells below; each entry is passed to get_stat_func
# and get_stat_name, so it must be a key those helpers know.
statistics = ['mean', 'min']
# Perturbed-parameter keys drawn by the cells below; each entry indexes the
# per-model result dictionaries and is passed to get_good_param_name.
metrics = ["body-mass", "friction"]

### Lineplots

In [None]:
# One figure per statistic, one panel per perturbed parameter, one line per
# model: the chosen statistic of the returns against the multiplier.
for statistic in statistics:
    func = get_stat_func(statistic)
    # squeeze=False followed by [0] yields a 1-D array of panels for any
    # number of metrics, including a single one.
    fig, axes_grid = plt.subplots(1, len(metrics), figsize=(15, 5), squeeze=False)
    axes = axes_grid[0]

    for metric_index, metric in enumerate(metrics):
        ax = axes[metric_index]

        for model_name, d in sorted(model_to_results.items(), key=lambda item: item[0], reverse=True):
            pv = sorted(d[metric].keys())
            statistics_values = np.array([func(d[metric][key]) for key in pv])

            line, = ax.plot(pv, statistics_values, label=get_good_model_name(model_name))
            color = line.get_color()
            ax.scatter(pv, statistics_values, s=10, color=color)

            if statistic == "mean":
                # Standard error of the mean, based on np.std with its default
                # settings; only meaningful (and only drawn) for the mean.
                bands = np.array([np.std(d[metric][key]) / np.sqrt(len(d[metric][key])) for key in pv])
                # The band colour is tied to the line explicitly instead of
                # relying on a separate colour cycle staying in step.
                ax.fill_between(pv, statistics_values - bands, statistics_values + bands, alpha=0.2, facecolor=color)

        ax.set_title(f"{get_good_param_name(metric)} Multipliers vs {get_stat_name(statistic)} Returns ({get_env_name(_env)})", fontsize=10.5)
        ax.set_xlabel(f'{get_good_param_name(metric)} Multipliers')
        ax.set_ylabel("Returns")
        ax.tick_params(axis='x', rotation=90)
        # Reference line at multiplier 1.0.
        ax.axvline(x=1.0, color='grey', linestyle='--')

    # Every panel contains the same models, so the legend is taken from the
    # last panel and attached to the whole figure.
    handles, labels = axes[-1].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.08), ncol=5)

    plt.tight_layout()
    plt.show()


### Heatmaps

In [None]:
# One figure per statistic with one heatmap per perturbed parameter
# (rows: models, columns: multipliers) and a single shared colorbar.
for statistic in statistics:
    func = get_stat_func(statistic)
    # squeeze=False followed by [0] yields a 1-D array of panels for any
    # number of metrics, including a single one.
    fig, axes_grid = plt.subplots(1, len(metrics), figsize=(8, 4), squeeze=False)
    axes = axes_grid[0]

    # Row order shared by all panels. It has to be identical everywhere because
    # only the left-most panel keeps its y tick labels.
    model_order = list(dict.fromkeys(
        get_good_model_name(model_name)
        for model_name in sorted(model_to_results, reverse=True)
    ))[::-1]

    dfs = []
    tables = []
    for metric in metrics:
        records = [
            {
                "model_name": get_good_model_name(model_name),
                "param_value": key,
                "returns": func(d[metric][key]),
            }
            for model_name, d in sorted(model_to_results.items(), key=lambda item: item[0], reverse=True)
            for key in sorted(d[metric].keys())
        ]
        df = pd.DataFrame(records, columns=["model_name", "param_value", "returns"])
        dfs.append(df)

        # Each metric is pivoted on its own multiplier values, so metrics
        # evaluated on different multiplier grids do not lose columns. The
        # pivot sorts the columns; the rows are put into the shared order.
        table = df.pivot_table(index='model_name', columns='param_value', values='returns')
        tables.append(table.reindex(index=model_order))

    # Colour range common to all panels, widened to multiples of 500.
    global_min = min(df['returns'].min() for df in dfs)
    global_min = int(global_min / 500) * 500 - 500
    global_max = max(df['returns'].max() for df in dfs)
    global_max = int(global_max / 500) * 500 + 500

    for i, (heatmap_data, metric) in enumerate(zip(tables, metrics)):
        # vmin/vmax make every panel use the same scale as the shared colorbar.
        sns.heatmap(heatmap_data, cmap='coolwarm_r', ax=axes[i], cbar=False, vmin=global_min, vmax=global_max)
        axes[i].set_title(f'{get_stat_name(statistic)} Returns by Model and {get_good_param_name(metric)} Multipliers ({get_env_name(_env)})', fontsize=9.5)
        axes[i].set_xlabel(f'{get_good_param_name(metric)} Multipliers', fontsize=9)
        axes[i].set_ylabel('Models', fontsize=9)
        axes[i].tick_params(axis='y', labelsize=9)

        if i != 0:
            axes[i].set_ylabel('')        # Remove y-label on the plot to the right
            axes[i].set_yticklabels([])   # Remove y-tick labels on the plot to the right

    # Lay out the heatmap panels first; the colorbar axes added afterwards is
    # positioned manually in figure coordinates and is not managed by
    # tight_layout.
    plt.tight_layout()

    norm = plt.Normalize(vmin=global_min, vmax=global_max)
    sm = plt.cm.ScalarMappable(cmap='coolwarm_r', norm=norm)
    sm.set_array([])
    cbar_ax = fig.add_axes([1.02, 0.15, 0.02, 0.78])
    fig.colorbar(sm, cax=cbar_ax)

    plt.show()