In [None]:
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import numpy as np
import pandas as pd
import seaborn as sns
import os
import re

from collections import defaultdict

## Configs

In [None]:
# Root folder of the evaluation outputs: the current working directory with its
# last 8 characters removed, followed by "evaluation_protocol".
# NOTE(review): the fixed 8-character slice presumably strips the name of the
# folder this notebook is launched from -- confirm before running it elsewhere.
dirpath = os.getcwd()[:-8] + "evaluation_protocol"
# Name of the results folder inside `dirpath`.
dirname = "final-results-thesis"

# Environment sub-folder whose results are loaded; also shown in plot titles.
_env = "walker2d"
# Only models whose parsed dataset type and version equal these two values are kept.
_dataset_type = "arrl"
_dataset_version = "high"
# When True, results named 'arrl' / 'arrl-sgld' bypass the dataset filter above.
_allow_baselines = True

## Data Prep

In [None]:
# Maps a model type to {"adv": {adversary strength -> list of return arrays}},
# with one array of episode returns per result file.
model_to_results = {}
output_folder_path = os.path.join(dirpath, dirname, _env)

# A single recursive walk visits every sub-folder exactly once. Only files
# inside sub-folders are collected; files lying directly in the environment
# folder are ignored.
json_files = set()
for root, _dirs, filenames in os.walk(output_folder_path):
    if root == output_folder_path:
        continue
    for filename in filenames:
        if filename.endswith(".json"):
            json_files.add(os.path.join(root, filename))

# Sorted so the per-strength lists are filled in the same order on every run.
for json_path in sorted(json_files):
    # Expected layout: .../<run folder>/<perturbation folder>/<value>.json
    denom, perturbation_dir, filename = os.path.normpath(json_path).split(os.sep)[-3:]

    # Body-mass and friction sweeps are dropped; everything else is stored
    # under the single "adv" key.
    # NOTE(review): the [8:] slice presumably removes a fixed-length prefix of
    # the perturbation folder name -- confirm against the folder naming scheme.
    if perturbation_dir[8:].replace(" ", "") in ("body-mass", "friction"):
        continue
    param_name = "adv"

    # The file stem is the adversary strength; the "no-adv" folder counts as 0.
    param = filename.split(".json")[0]
    if perturbation_dir.replace(" ", "") == "no-adv":
        param = 0.0

    digit_match = re.search(r'-\d', denom)
    if digit_match is None:
        # No "-<digit>" in the run folder name: the folder name is used as both
        # model name and model type, with an 'online' dataset type.
        model_name = denom
        model_type = denom
        dataset_type = 'online'
        dataset_version = None
    else:
        model_name = denom[:digit_match.start()]
        # A name ending in "v<char>" gets its last three characters replaced by "_best".
        if model_name[:-1].endswith("v"):
            model_name = model_name[:-3] + "_best"

        if "ardt-vanilla" in denom:
            model_type = "ardt-vanilla"
        elif "ardt-multipart" in denom:
            model_type = "ardt-multipart"
        else:
            model_type = "dt"

        # The dataset field is the 2nd dash-separated token for "dt" names and
        # the 3rd for "ardt-*" names, whose type itself contains a dash.
        split_idx = 1 if model_type == "dt" else 2
        dataset_field = model_name.split("-")[split_idx]
        # Keep what precedes "train", minus the one character right before it.
        dataset_type = dataset_field[:dataset_field.find("train")][:-1]
        dataset_version = model_name.split("_")[-1]

    matches_dataset = dataset_type == _dataset_type and dataset_version == _dataset_version
    is_allowed_baseline = _allow_baselines and model_name in ('arrl', 'arrl-sgld')
    if not (matches_dataset or is_allowed_baseline):
        continue

    # The 'arrl' baseline is skipped for the arrl_sgld dataset and vice versa.
    if (model_type == "arrl" and _dataset_type == "arrl_sgld") or (model_type == "arrl-sgld" and _dataset_type == "arrl"):
        continue

    # "dt" is stored as "aredt" only to influence the alphabetical ordering
    # used by the plots (hacky).
    mt = "aredt" if model_type == "dt" else model_type

    if mt not in model_to_results:
        model_to_results[mt] = {"adv": defaultdict(list)}

    episode_returns = np.array(pd.read_json(json_path)["ep_return"].to_list())
    model_to_results[mt][param_name][float(param)].append(episode_returns)

# Sort by model name.
model_to_results = dict(sorted(model_to_results.items(), key=lambda item: item[0]))

## Per Model Lineplots

In [None]:
def get_good_model_name(model_name):
    """Translate an internal model key into the label shown on plots.

    Returns None when `model_name` is not one of the known keys.
    """
    display_names = {
        "aredt": "DT",
        "ardt-vanilla": "Vanilla-ARDT",
        "ardt-multipart": "Multipart-ARDT",
        "arrl": "AR-DDPG",
        "arrl-sgld": "AR-DDPG-SGLD",
    }
    return display_names.get(model_name)

In [None]:
param = "adv"

# Pick the figure layout. `panel_cells` lists the GridSpec cell (row, column)
# of each panel in the order the panels are filled. A `legend_ncol` of None
# means "one legend column per plotted series".
if _allow_baselines and _dataset_type not in ("arrl_sgld", "arrl"):
    # Five panels: two on the first row, three on the second.
    figsize, grid_shape = (16, 8), (2, 3)
    panel_cells = [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2)]
    legend_ncol, legend_anchor = 3, (0.845, 0.965)
elif _allow_baselines:
    # Four panels on a 2 x 2 grid.
    figsize, grid_shape = (12, 8), (2, 2)
    panel_cells = [(0, 0), (0, 1), (1, 0), (1, 1)]
    legend_ncol, legend_anchor = None, (0.52, 1.05)
else:
    # Three panels on a single row.
    figsize, grid_shape = (16, 4), (1, 3)
    panel_cells = [(0, 0), (0, 1), (0, 2)]
    legend_ncol, legend_anchor = None, (0.52, 1.10)

fig = plt.figure(figsize=figsize)
gs = GridSpec(*grid_shape)
panels = [fig.add_subplot(gs[row, col]) for row, col in panel_cells]

# Running extremes of everything plotted, used for a shared y-range.
# Note the seeds: the lower bound never starts above 5000 and the upper bound
# never starts below 0.
global_min = 5000
global_max = 0
ylabel_axes = []
all_axes = []
x_ticks = np.arange(0.0, 0.6, 0.1)

# Models are drawn in reverse alphabetical order of their key. Indexing
# `panels[i]` raises if there are more models than panels in the layout.
for i, (model_type, d) in enumerate(sorted(model_to_results.items(), key=lambda item: item[0], reverse=True)):
    returns_by_strength = d[param]
    pv = sorted(returns_by_strength.keys())
    means, medians, mins, p5, p1 = [], [], [], [], []

    for key in pv:
        runs = returns_by_strength[key]
        means.append(np.mean(runs))
        medians.append(np.median(runs))
        # "Minimum" is the lowest per-file 5th percentile. The axis=1 reduction
        # needs a 2-D array, i.e. every file must hold equally many episodes.
        mins.append(np.min(np.percentile(np.array(runs), 5, axis=1, keepdims=True)))
        p5.append(np.percentile(runs, 5))
        p1.append(np.percentile(runs, 1))

    # Update the shared range once per model instead of once per strength.
    plotted_values = means + medians + mins + p5 + p1
    global_min = min([global_min] + plotted_values)
    global_max = max([global_max] + plotted_values)

    ax = panels[i]
    ax.set_title(f'Adversarial Returns for {get_good_model_name(model_type)}', fontsize=10.5)
    ax.set_xlabel('Adversary Strength (\u03B1)')
    ax.set_xticks(x_ticks)
    ax.set_xticklabels([f'{x:.2f}' for x in x_ticks], rotation=90)
    # Only panels in the first grid column keep their y-axis labelling.
    if panel_cells[i][1] == 0:
        ylabel_axes.append(ax)
    all_axes.append(ax)

    series = [
        (means, 'green', "Mean"),
        (medians, 'blue', "Median"),
        (p5, 'orange', "5th Percentile"),
        (p1, 'brown', "1st Percentile"),
        (mins, 'red', "Minimum"),
    ]
    # All lines first, then all markers, so the markers are added last.
    for values, color, label in series:
        ax.plot(pv, values, color=color, label=label)
    for values, color, _label in series:
        ax.scatter(pv, values, color=color, s=10)

# Widen the range to multiples of 500 with one extra step of margin on each side.
global_min = int(global_min / 500) * 500 - 500
global_max = int(global_max / 500) * 500 + 500
for ax in all_axes:
    ax.set_ylim(global_min, global_max)
    ax.set_ylabel('')
    ax.set_yticklabels([])

y_ticks = np.arange(global_min, global_max, 500)
for ax in ylabel_axes:
    ax.set_ylabel('Returns')
    ax.set_yticks(y_ticks)
    ax.set_yticklabels([f'{x:.0f}' for x in y_ticks])

# Every panel carries the same series, so the first one supplies the legend.
handles, labels = panels[0].get_legend_handles_labels()
fig.legend(
    handles,
    labels,
    loc='upper center',
    ncol=len(handles) if legend_ncol is None else legend_ncol,
    bbox_to_anchor=legend_anchor,
)

plt.tight_layout()
plt.show()

## Per-Stat Gridplots

In [None]:
def get_stat_func(stat):
    """Return the callable that reduces a collection of return arrays to one value.

    Known keys: "mean" and "median" (the NumPy functions themselves), "min"
    (lowest row-wise 5th percentile of a 2-D input), "95th" (5th percentile
    over all values) and "99th" (1st percentile over all values).
    Returns None for any other key.
    """
    def worst_row_percentile(x):
        # 5th percentile of each row, then the smallest of those.
        return np.min(np.percentile(np.array(x), 5, axis=1, keepdims=True))

    def fifth_percentile(x):
        return np.percentile(x, 5)

    def first_percentile(x):
        return np.percentile(x, 1)

    reducers = {
        "mean": np.mean,
        "median": np.median,
        "min": worst_row_percentile,
        "95th": fifth_percentile,
        "99th": first_percentile,
    }
    return reducers.get(stat)
    
def get_stat_name(stat):
    """Return the display label of a statistic key, or None if the key is unknown."""
    stat_labels = {
        "mean": "Mean",
        "median": "Median",
        "min": "Worst-Case",
        "95th": "5th Percentile",
        "99th": "1st Percentile",
    }
    return stat_labels.get(stat)
    
def get_env_name(name):
    """Return the display name of an environment key, or None if the key is unknown."""
    env_labels = {
        "halfcheetah": "HalfCheetah",
        "hopper": "Hopper",
        "walker2d": "Walker2D",
    }
    return env_labels.get(name)
    

In [None]:
# Statistic keys plotted by the per-statistic figures; each entry is passed to
# get_stat_func and get_stat_name.
statistics = ['mean', '95th', 'min']

### Lineplots

In [None]:
# One panel per statistic, one line per model, over adversary strength.
fig, axes = plt.subplots(1, 3, figsize=(16, 3))
x_ticks = np.arange(0.0, 0.6, 0.1)

for ax, statistic in zip(axes, statistics):
    func = get_stat_func(statistic)

    for model_name, d in sorted(model_to_results.items(), key=lambda item: item[0], reverse=True):
        returns_by_strength = d["adv"]
        pv = sorted(returns_by_strength.keys())
        statistics_values = np.array([func(returns_by_strength[key]) for key in pv])

        line, = ax.plot(pv, statistics_values, label=get_good_model_name(model_name))
        color = line.get_color()
        ax.scatter(pv, statistics_values, s=10, color=color)

        if statistic == "mean":
            # Band half-width: standard deviation over all values divided by
            # the square root of the number of result files for that strength.
            bands = np.array([
                np.std(returns_by_strength[key]) / np.sqrt(len(returns_by_strength[key]))
                for key in pv
            ])
            # Tie the band colour to its line explicitly rather than relying
            # on the separate patch colour cycle staying in step.
            ax.fill_between(pv, statistics_values - bands, statistics_values + bands, alpha=0.2, facecolor=color)

    # Panel decoration does not depend on the model, so it is applied once,
    # after all series have been drawn.
    ax.set_title(f'{get_stat_name(statistic)} Adversarial Returns ({get_env_name(_env)})', fontsize=10)
    ax.set_xlabel('Adversary Strength (\u03B1)', fontsize=9.5)
    ax.set_ylabel('Returns', fontsize=9.5)
    ax.set_xticks(x_ticks)
    ax.set_xticklabels([f'{x:.2f}' for x in x_ticks], rotation=90)
    ax.tick_params(axis='x', rotation=90)

# Every panel shows the same models, so the first one supplies the legend.
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.08), ncol=5)

plt.tight_layout()
plt.show()

### Barplots

In [None]:
# One panel per statistic: bars grouped by model, coloured by adversary strength.
fig, axes = plt.subplots(1, 3, figsize=(16, 5))
# label -> first handle seen with that label (insertion-ordered).
legend_entries = {}

for ax, statistic in zip(axes, statistics):
    func = get_stat_func(statistic)

    data = [
        {
            "model_name": get_good_model_name(model_name),
            "param_value": key,
            "returns": func(d["adv"][key]),
        }
        for model_name, d in sorted(model_to_results.items(), key=lambda item: item[0], reverse=True)
        for key in sorted(d["adv"].keys())
    ]
    df = pd.DataFrame(data)

    # Each (model, strength) pair contributes exactly one row, so there is
    # nothing for an error bar to summarise.
    sns.barplot(x='model_name', y='returns', hue='param_value', data=df, ax=ax, errorbar=None)
    ax.set_title(f'{get_stat_name(statistic)} Adversarial Returns ({get_env_name(_env)})', fontsize=10)
    ax.set_xlabel('Model Type', fontsize=9)
    ax.set_ylabel('Returns', fontsize=9)

    # Collect the legend entries, then drop the per-panel legend in favour of
    # one shared figure legend.
    panel_handles, panel_labels = ax.get_legend_handles_labels()
    for handle, label in zip(panel_handles, panel_labels):
        legend_entries.setdefault(label, handle)
    panel_legend = ax.get_legend()
    if panel_legend is not None:
        panel_legend.remove()

handles, labels = [], []
if legend_entries:
    labels, handles = zip(*legend_entries.items())
    fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.08), ncol=5)

plt.tight_layout()
plt.show()


In [None]:
default_palette = sns.color_palette()
# Default palette with its fifth colour repeated at the front, so every other
# colour moves one hue level later.
# NOTE(review): presumably this gives the extra baseline model its own colour
# while the remaining models keep their usual ones -- confirm.
shifted_palette = [default_palette[4]] + default_palette

# Long-format table: one row per (statistic, model, adversary strength).
data = []
for statistic in statistics:
    func = get_stat_func(statistic)
    stat_label = get_stat_name(statistic)

    for model_name, d in sorted(model_to_results.items(), key=lambda item: item[0], reverse=True):
        model_label = get_good_model_name(model_name)

        for key in sorted(d["adv"].keys()):
            data.append({
                "model_name": model_label,
                "statistic": stat_label,
                "returns": func(d["adv"][key]),
                "param_value": key,
            })

df = pd.DataFrame(data)

unique_param_values = df['param_value'].unique()
# squeeze=False keeps the result two-dimensional, so indexing also works when
# there is only a single adversary strength (plt.subplots would otherwise
# return a bare Axes object).
fig, axes = plt.subplots(1, len(unique_param_values), figsize=(15, 5), squeeze=False)
axes = axes[0]
handles, labels = [], []

palette = shifted_palette if _allow_baselines else default_palette
for ax, param_value in zip(axes, unique_param_values):
    sns.barplot(
        x='statistic',
        y='returns',
        hue='model_name',
        data=df[df['param_value'] == param_value],
        ax=ax,
        errorbar=None,
        palette=palette,
    )

    ax.set_title(f'Adversarial Returns for \u03B1={param_value} ({get_env_name(_env)})', fontsize=10)
    ax.set_xlabel('Statistic Type', fontsize=9)
    ax.set_ylabel('Returns', fontsize=9)

    # The handles of the last panel end up in the shared figure legend.
    handles, labels = ax.get_legend_handles_labels()
    panel_legend = ax.get_legend()
    if panel_legend is not None:
        panel_legend.remove()

fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.08), ncol=5)
plt.tight_layout()
plt.show()