In [None]:
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.colors import Normalize
from matplotlib import patches
import numpy as np
import pandas as pd
import seaborn as sns
import os
import re

from collections import defaultdict

## Configs

In [None]:
# Results are read from <parent of cwd>/evaluation_protocol/<dirname>/<_env>/...
# NOTE(review): assumes the notebook is launched from a directory that is a sibling
# of evaluation_protocol — confirm. (Uses dirname() instead of slicing a fixed
# number of characters off the cwd string.)
dirpath = os.path.join(os.path.dirname(os.getcwd()), "evaluation_protocol")
dirname = "final-results-thesis"

# Environment whose results are loaded and plotted below.
_env = "halfcheetah"

## Data Prep

In [None]:
# Load per-episode returns for every (model, perturbed parameter, multiplier).
# Resulting structure:
#   model_to_results["<model_type>-<dataset_type>"][param_name][multiplier] -> list of ep_return
# Paths are expected to look like .../<run_dir>/<prefix><param_name>/<multiplier>.json
# (inferred from the split("/") indexing below — TODO confirm).
model_to_results = {}
output_folder_path = os.path.join(dirpath, dirname + "/" + _env)

# Every .json file inside a subdirectory (any depth) of output_folder_path; files
# sitting directly in output_folder_path are not included. One walk visits each
# directory exactly once.
json_files = set()
for walk_root, subdirs, walk_files in os.walk(output_folder_path):
    if walk_root == output_folder_path:
        continue
    for file_name in walk_files:
        if file_name.endswith(".json"):
            json_files.add(walk_root + "/" + file_name)

for json_path in sorted(json_files):  # sorted -> reproducible load order
    path_parts = json_path.split("/")
    denom = path_parts[-3]
    # NOTE(review): [8:] strips a fixed-length prefix from the parameter directory name — confirm naming scheme.
    param_name = path_parts[-2][8:].replace(" ", "")
    if param_name not in ("body-mass", "friction"):
        continue

    param = path_parts[-1].split(".json")[0]
    first_digit = re.search(r'-\d', denom)
    if first_digit is None:
        # No "-<digit>" in the run directory name: keep the whole name, mark as online.
        model_name = denom
        model_type = denom
        dataset_type = 'online'
        dataset_version = None
    else:
        model_name = denom[:first_digit.start()]
        # NOTE(review): names whose second-to-last char is "v" get their last 3 chars
        # replaced by "_best" — confirm intent.
        model_name = model_name[:-3] + "_best" if model_name[:-1].endswith("v") else model_name
        if "ardt-vanilla" in denom:
            model_type = "ardt-vanilla"
        elif "ardt-multipart" in denom:
            model_type = "ardt-multipart"
        else:
            model_type = "dt"
        # Dataset type = the "-"-separated field after the model prefix, cut before
        # "train" and with one trailing character removed (find() is -1 if absent).
        split_idx = 1 if model_type == "dt" else 2
        dataset_field = model_name.split("-")[split_idx]
        dataset_type = dataset_field[:dataset_field.find("train")][:-1]
        dataset_version = model_name.split("_")[-1]

    if model_name == 'arrl-sgld':
        continue

    # Keep only the two dataset configurations analysed here.
    if not ((dataset_type == "robust" and dataset_version == "level") or (dataset_type == "arrl" and dataset_version == "high")):
        continue

    # "dt" is renamed "aredt" only to change sort order ("aredt" sorts after "ardt-*").
    mt = ("aredt" if model_type == "dt" else model_type) + "-" + dataset_type

    if mt not in model_to_results:
        model_to_results[mt] = {"body-mass": defaultdict(list), "friction": defaultdict(list)}

    model_to_results[mt][param_name][float(param)].extend(pd.read_json(json_path)["ep_return"].to_list())

# Sort by model key (keys are unique, so sorting the items sorts by key).
model_to_results = dict(sorted(model_to_results.items()))

In [None]:
# Bucket every model_to_results entry by model family. Each bucket holds a list of
# single-entry dicts {model_key: results}; keys outside the aredt/ardt families are skipped.
model_pairs = {"aredt": [], "ardt-vanilla": [], "ardt-multipart": []}
for model_key, model_results in model_to_results.items():
    family_prefix = model_key.split("-")[0]
    if family_prefix == "aredt":
        bucket = "aredt"
    elif family_prefix == "ardt":
        # any "ardt" key without "vanilla" in it is filed under multipart
        bucket = "ardt-vanilla" if "vanilla" in model_key else "ardt-multipart"
    else:
        continue
    model_pairs[bucket].append({model_key: model_results})

## Per Model Lineplots

In [None]:
def get_good_model_name(model_name):
    """Map an internal model key to its display name.

    The family is picked by prefix ("aredt" -> DT, "ardt-vanilla" -> Vanilla-ARDT,
    "ardt-multipart" -> Multipart-ARDT). The suffix is "-MP" when the key ends with
    "robust" and "-SP" otherwise. Returns None for an unrecognised prefix.
    """
    family_labels = (
        ("aredt", "DT"),
        ("ardt-vanilla", "Vanilla-ARDT"),
        ("ardt-multipart", "Multipart-ARDT"),
    )
    for prefix, label in family_labels:
        if not model_name.startswith(prefix):
            continue
        policy_suffix = "MP" if model_name.endswith("robust") else "SP"
        return f"{label}-{policy_suffix}"
    return None
    
def get_good_param_name(param_name):
    """Display label for a perturbed-parameter key; None if the key is unknown."""
    labels = (("body-mass", "Body Mass"), ("friction", "Friction"))
    for key, label in labels:
        if param_name == key:
            return label
    return None

In [None]:
# One row per model family (model_pairs keys, reverse-sorted); columns iterate over
# parameter x family member. Each panel plots several return statistics against the
# parameter multiplier. Y-limits are shared within a row, and only the first
# column shows y tick labels.
fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(16, 10))
params = ["body-mass", "friction"]
xtick_values = np.arange(0.5, 2.1, 0.1)

# (legend label, colour, reducer) for every curve drawn in each panel.
curve_specs = [
    ("Mean", "green", np.mean),
    ("Median", "blue", np.median),
    ("5th Percentile", "orange", lambda x: np.percentile(x, 5)),
    ("1st Percentile", "brown", lambda x: np.percentile(x, 1)),
    ("Minimum", "red", np.min),
]

for row, (family, members) in enumerate(sorted(model_pairs.items(), key=lambda item: item[0], reverse=True)):
    # Seed with +/-inf so the row limits always follow the data, whatever its range.
    row_min = float("inf")
    row_max = float("-inf")
    ylabel_axes = []
    all_axes = []

    curr_col = -1
    for param in params:
        for member in members:
            curr_col += 1
            ax = axes[row, curr_col]
            model_type, d = next(iter(member.items()))

            pv = sorted(d[param].keys())
            # One list of values per curve, aligned with pv.
            curves = [[reducer(d[param][key]) for key in pv] for label, color, reducer in curve_specs]
            for values in curves:
                if values:
                    row_min = min(row_min, min(values))
                    row_max = max(row_max, max(values))

            ax.set_title(f'{get_good_param_name(param)} Multipliers vs Returns for {get_good_model_name(model_type)}', fontsize=9.5)
            ax.set_xlabel(f'{get_good_param_name(param)} Multipliers')
            ax.set_xticks(xtick_values)
            ax.set_xticklabels([f'{x:.2f}' for x in xtick_values], rotation=90)
            if curr_col == 0:
                ylabel_axes.append(ax)
            all_axes.append(ax)
            # Lines first (they carry the legend labels), then the point markers.
            for (label, color, reducer), values in zip(curve_specs, curves):
                ax.plot(pv, values, color=color, label=label)
            for (label, color, reducer), values in zip(curve_specs, curves):
                ax.scatter(pv, values, color=color, s=10)
            ax.axvline(x=1.0, color='grey', linestyle='--')

    if row_min > row_max:
        # No data in this row: keep matplotlib's default limits.
        continue

    # Pad the shared limits out to multiples of 500.
    y_lo = int(row_min / 500) * 500 - 500
    y_hi = int(row_max / 500) * 500 + 500
    for ax in all_axes:
        ax.set_ylim(y_lo, y_hi)
        ax.set_ylabel('')
        ax.set_yticklabels([])

    for ax in ylabel_axes:
        ax.set_ylabel('Returns')
        ytick_values = np.arange(y_lo, y_hi, 500)
        ax.set_yticks(ytick_values)
        ax.set_yticklabels([f'{x:.0f}' for x in ytick_values])


handles, labels = axes[0, 0].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', ncol=6, bbox_to_anchor=(0.5, 1.05))

plt.tight_layout()
plt.show()

## Per-Stat Gridplots

In [None]:
def get_stat_func(stat):
    """Return a reducer mapping a sequence of episode returns to one summary value.

    Supported keys: "mean", "median", "min", "95th" (computes the 5th percentile)
    and "99th" (computes the 1st percentile).

    Raises:
        ValueError: for any other key. Previously None was returned, which only
        failed later as an opaque "'NoneType' object is not callable".
    """
    if stat == "mean":
        return np.mean
    if stat == "median":
        return np.median
    if stat == "min":
        return np.min
    if stat == "95th":
        return lambda x: np.percentile(x, 5)
    if stat == "99th":
        return lambda x: np.percentile(x, 1)
    raise ValueError(f"Unknown statistic {stat!r}; expected one of 'mean', 'median', 'min', '95th', '99th'")
    
def get_stat_name(stat):
    """Human-readable label for a statistic key; None when the key is unknown."""
    labels = (
        ("mean", "Mean"),
        ("median", "Median"),
        ("min", "Worst-Case"),
        ("95th", "5th Percentile"),
        ("99th", "1st Percentile"),
    )
    for key, label in labels:
        if stat == key:
            return label
    return None
    
def get_env_name(name):
    """Display name for an environment key used in plot titles.

    Known keys get their canonical capitalisation. Any other key is returned
    unchanged, so titles show the raw name instead of "None".
    """
    pretty_names = {
        "halfcheetah": "HalfCheetah",
        "hopper": "Hopper",
        "walker2d": "Walker2D",
    }
    return pretty_names.get(name, name)
    

In [None]:
# Perturbed parameters and return statistics (keys understood by get_stat_func)
# that the grid plots iterate over.
metrics = ["body-mass", "friction"]
statistics = ['mean', '95th', 'min']

### Lineplots

In [None]:
# One panel per (parameter, statistic) with one curve per model. For the mean, a
# +/- 1 standard-error band (std / sqrt(n)) is shaded in the line's colour.
fig, axes = plt.subplots(2, 3, figsize=(15, 5))

for i, statistic in enumerate(statistics):
    func = get_stat_func(statistic)

    for metric_index, metric in enumerate(metrics):
        ax = axes[metric_index][i]

        for model_name, d in sorted(model_to_results.items(), key=lambda item: item[0], reverse=True):
            pv = sorted(d[metric].keys())
            statistics_values = np.array([func(d[metric][key]) for key in pv])
            bands = np.array([np.std(d[metric][key]) / np.sqrt(len(d[metric][key])) for key in pv])

            line, = ax.plot(pv, statistics_values, label=get_good_model_name(model_name))
            color = line.get_color()
            ax.scatter(pv, statistics_values, s=10, color=color)
            if statistic == "mean":
                # Pin the band to its line's colour instead of relying on the automatic cycle.
                ax.fill_between(pv, statistics_values - bands, statistics_values + bands, color=color, alpha=0.2)

        ax.set_title(f"{get_good_param_name(metric)} Multipliers vs {get_stat_name(statistic)} Returns ({get_env_name(_env)})", fontsize=9.5)
        ax.set_xlabel(f'{get_good_param_name(metric)} Multipliers')
        ax.set_ylabel("Returns")
        # Always include 0 on the y-axis.
        y_lo, y_hi = ax.get_ylim()
        ax.set_ylim(min(0, y_lo), y_hi)
        ax.tick_params(axis='x', rotation=90)
        ax.axvline(x=1.0, color='grey', linestyle='--')

# Every panel plots the same models, so a single figure-level legend identifies the lines.
handles, labels = axes[0][0].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=6, fontsize=12)

plt.tight_layout()
plt.show()


In [None]:
# 3 x 6 grid: one row per model family comparing its two variants
# (multi-policy vs single-policy); columns are (parameter x statistic).
fig, axes = plt.subplots(3, 6, figsize=(30, 10))

model_names = sorted(model_to_results.keys(), reverse=True)
# Presumably each family contributes exactly two adjacent keys in reverse-sorted order.
# Named model_name_pairs so the model_pairs dict used by the per-model plots is not overwritten.
model_name_pairs = list(zip(model_names[0::2], model_names[1::2]))
n_rows = len(model_name_pairs)

for i, (model_name1, model_name2) in enumerate(model_name_pairs):
    last_col = -1

    for metric in metrics:
        for statistic in statistics:
            last_col += 1
            ax = axes[i][last_col]
            func = get_stat_func(statistic)

            for model_name in (model_name1, model_name2):
                d = model_to_results[model_name]
                pv = sorted(d[metric].keys())
                statistics_values = np.array([func(d[metric][key]) for key in pv])
                bands = np.array([np.std(d[metric][key]) / np.sqrt(len(d[metric][key])) for key in pv])

                policy_label = "Multi-Policy" if get_good_model_name(model_name)[-2:] == "MP" else "Single-Policy"
                line, = ax.plot(pv, statistics_values, label=policy_label)
                color = line.get_color()
                ax.scatter(pv, statistics_values, s=10, color=color)
                if statistic == "mean":
                    # Standard-error band, pinned to the line's colour.
                    ax.fill_between(pv, statistics_values - bands, statistics_values + bands, color=color, alpha=0.2)

            ax.set_title(f"{get_good_param_name(metric)} Multipliers vs {get_stat_name(statistic)} Returns ({get_env_name(_env)})", fontsize=10.5)
            ax.set_xlabel(f'{get_good_param_name(metric)} Multipliers')
            ax.set_ylabel("Returns")
            ax.tick_params(axis='x', rotation=90)
            ax.axvline(x=1.0, color='grey', linestyle='--')

    # Strip the "-MP"/"-SP" suffix to get the family name.
    family_name = get_good_model_name(model_name1)[:-3]
    row_label = family_name if family_name != 'DT' else 'Decision Transformer'
    # Figure y grows upwards, but axes row 0 is the TOP row: count from the bottom.
    fig.text(0, ((n_rows - 1 - i) * 0.33) + 0.2, row_label, va='center', ha='center', rotation='vertical', fontsize=12, fontweight='bold')

handles, labels = axes[0][0].get_legend_handles_labels()

unique_labels = []
unique_handles = []
for h, l in zip(handles, labels):
    if l not in unique_labels:
        unique_labels.append(l)
        unique_handles.append(h)

fig.legend(unique_handles, unique_labels, loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=6)

plt.tight_layout()
plt.show()


In [None]:
# 6 x 3 grid: one column per model family comparing its two variants
# (multi-policy vs single-policy); rows are (parameter x statistic).
fig, axes = plt.subplots(6, 3, figsize=(15, 20))

model_names = sorted(model_to_results.keys(), reverse=True)
# Presumably each family contributes exactly two adjacent keys in reverse-sorted order.
# Named model_name_pairs so the model_pairs dict used by the per-model plots is not overwritten.
model_name_pairs = list(zip(model_names[0::2], model_names[1::2]))

for i, (model_name1, model_name2) in enumerate(model_name_pairs):
    last_row = -1

    for metric in metrics:
        for statistic in statistics:
            last_row += 1
            ax = axes[last_row][i]
            func = get_stat_func(statistic)

            for model_name in (model_name1, model_name2):
                d = model_to_results[model_name]
                pv = sorted(d[metric].keys())
                statistics_values = np.array([func(d[metric][key]) for key in pv])
                bands = np.array([np.std(d[metric][key]) / np.sqrt(len(d[metric][key])) for key in pv])

                policy_label = "Multi-Policy" if get_good_model_name(model_name)[-2:] == "MP" else "Single-Policy"
                line, = ax.plot(pv, statistics_values, label=policy_label)
                color = line.get_color()
                ax.scatter(pv, statistics_values, s=10, color=color)
                if statistic == "mean":
                    # Standard-error band, pinned to the line's colour.
                    ax.fill_between(pv, statistics_values - bands, statistics_values + bands, color=color, alpha=0.2)

            ax.set_title(f"{get_good_param_name(metric)} Multipliers vs {get_stat_name(statistic)} Returns", fontsize=9.5)
            ax.set_xlabel(f'{get_good_param_name(metric)} Multipliers')
            ax.set_ylabel("Returns")
            ax.tick_params(axis='x', rotation=90)
            ax.axvline(x=1.0, color='grey', linestyle='--')

    # Column header: family name (the "-MP"/"-SP" suffix stripped) plus the environment.
    family_name = get_good_model_name(model_name1)[:-3]
    column_label = family_name if family_name != 'DT' else 'Decision Transformer'
    column_label += f" ({get_env_name(_env)})"
    fig.text((i * 0.33) + 0.191, 1.005, column_label, va='center', ha='center', fontsize=10, fontweight='bold')

handles, labels = axes[0][0].get_legend_handles_labels()

unique_labels = []
unique_handles = []
for h, l in zip(handles, labels):
    if l not in unique_labels:
        unique_labels.append(l)
        unique_handles.append(h)

fig.legend(unique_handles, unique_labels, loc='upper center', bbox_to_anchor=(0.5, 1.035), ncol=6)

plt.tight_layout()
plt.show()

### Heatmaps

In [None]:
# Heatmaps: one panel per (parameter, statistic).
# Rows are models, columns are parameter multipliers, and colour is the return statistic.
fig, axes = plt.subplots(len(metrics), len(statistics), figsize=(16, 8))
all_dfs = []

ordered_model_keys = sorted(model_to_results.keys(), reverse=True)
# Unique display names in first-seen order over the reverse-sorted keys, then
# reversed; used as the ordered row categories of every heatmap.
model_display_order = list(dict.fromkeys(get_good_model_name(k) for k in ordered_model_keys))[::-1]

for j, statistic in enumerate(statistics):
    func = get_stat_func(statistic)

    for i, metric in enumerate(metrics):
        records = []
        for model_key in ordered_model_keys:
            returns_by_param = model_to_results[model_key][metric]
            for key in sorted(returns_by_param.keys()):
                records.append({
                    "model_name": get_good_model_name(model_key),
                    "param_value": key,
                    "returns": func(returns_by_param[key]),
                })

        # Column categories come from THIS metric's own multipliers, sorted numerically.
        # Reusing another metric's list would turn unmatched multipliers into NaN and drop them.
        param_order = sorted({record["param_value"] for record in records})

        df = pd.DataFrame(records)
        df['model_name'] = pd.Categorical(df['model_name'], categories=model_display_order, ordered=True)
        df = df.sort_values('model_name')
        df['param_value'] = pd.Categorical(df['param_value'], categories=param_order, ordered=True)
        all_dfs.append(df)

        heatmap_data = df.pivot_table(index='model_name', columns='param_value', values='returns')

        ax = axes[i][j]
        sns.heatmap(heatmap_data, cmap='coolwarm_r', ax=ax)
        ax.set_title(f'{get_stat_name(statistic)} Returns by Model and {get_good_param_name(metric)} Multipliers ({get_env_name(_env)})', fontsize=9)
        ax.set_xlabel(f'{get_good_param_name(metric)} Multipliers', fontsize=9)

        # Only the first column carries model names on the y-axis.
        if j == 0:
            ax.set_ylabel('Models', fontsize=9)
            ax.tick_params(axis='y', labelsize=9)
        else:
            ax.set_ylabel('')
            ax.set_yticklabels([])

plt.tight_layout()
plt.show()