In [16]:
import wandb
import numpy as np
import pandas as pd
from tabulate import tabulate
from scipy.stats import bootstrap
from collections import defaultdict

def nested_dict():
    return defaultdict(nested_dict)

def ci_95(data):
    if len(data) < 2:
        return (-1, -1)
    res = bootstrap((np.array(data),), np.mean, confidence_level=0.95, n_resamples=1000, method="basic")
    return int(res.confidence_interval.low), int(res.confidence_interval.high)

In [12]:
group_map = {}
env_translation = {
    "Pendulum": "BalancePendulum",
    "Quadrotor": "BalanceQuadrotor",
    "Energy System": "LoadBalanceHousehold"
}
safeguard_translation = {
    "No Safeguard": "NP",
    "Boundary Projection Base": "P",
    "Boundary Projection Regularised": "P",
    "Ray Mask Base": "ZRP-Lin",
    "Ray Mask Regularised": "ZRP-Lin",
    "Ray Mask Passthrough": "ZRP-Lin-PT",
    "Hyperbolic Ray Mask Base": "ZRP-Tanh",
    "Hyperbolic Ray Mask Regularised": "ZRP-Tanh",
}
algorithms = ["SHAC", "SAC", "PPO"]
for env in env_translation.keys():
    for algo in algorithms:
        for safeguard in safeguard_translation.keys():
            translation = safeguard_translation[safeguard]
            if env == "Energy System" and translation == "P":
                translation = translation.replace("P", "BP")
            elif env == "Energy System" and "ZRP" in translation:
                translation = translation.replace("ZRP", "ZRM")

            key = env_translation[env] + "-" + translation + "-" + algo
            key += "" if "Regularised" not in safeguard else "-Reg"

            group_map[key] = (env, algo, safeguard)

stuck_runs = {
    "Pendulum": {
        "SHAC": {
            "No Safeguard": ["likely-paper-5140"],
            "Boundary Projection Base": ["clean-universe-5155", "trim-shadow-5153"],
            "Boundary Projection Regularised": ["turtle-brulee-5032"],
            "Ray Mask Base": [],
            "Ray Mask Regularised": [],
            "Ray Mask Passthrough": ["happy-shadow-5213", "boysenberry-strudel-5040"],
            "Hyperbolic Ray Mask Base": ["different-terrain-5205"],
            "Hyperbolic Ray Mask Regularised": [""],
        },
        "PPO": {
            "No Safeguard": [],
            "Boundary Projection Base": [],
            "Ray Mask Base": [],
        },
        "SAC": {
            "No Safeguard": [],
            "Boundary Projection Base": [],
            "Ray Mask Base": [],
        }},
    "Quadrotor": {
        "SHAC": {
            "No Safeguard": [],
            "Boundary Projection Base": [],
            "Boundary Projection Regularised": [],
            "Ray Mask Base": [],
            "Ray Mask Regularised": [],
            "Ray Mask Passthrough": ["comfy-gorge-5215"],
            "Hyperbolic Ray Mask Base": [],
            "Hyperbolic Ray Mask Regularised": [],
        },
        "PPO": {
            "No Safeguard": [],
            "Boundary Projection Base": [],
            "Ray Mask Base": [],
        },
        "SAC": {
            "No Safeguard": [],
            "Boundary Projection Base": [],
            "Ray Mask Base": [],
        }},
    "Energy System": {
        "SHAC": {
            "No Safeguard": [],
            "Boundary Projection Base": [],
            "Boundary Projection Regularised": [],
            "Ray Mask Base": [],
            "Ray Mask Regularised": [],
            "Ray Mask Passthrough": [],
            "Hyperbolic Ray Mask Base": [],
            "Hyperbolic Ray Mask Regularised": [],
        },
        "PPO": {
            "No Safeguard": [],
            "Boundary Projection Base": [],
            "Ray Mask Base": [],
        },
        "SAC": {
            "No Safeguard": [],
            "Boundary Projection Base": [],
            "Ray Mask Base": [],
        }},
}

In [13]:
api = wandb.Api()

entity = "tim-walter-tum"
project = "Safe Differentiable Reinforcement Learning"

runs = api.runs(f"{entity}/{project}")

In [14]:
results = nested_dict()
timing_results = nested_dict()

for run in api.runs(f"{entity}/{project}"):
    try:
        group = run.group
        if group not in group_map:
            continue
        env, algo, safeguard = group_map[group]
        if run.id in stuck_runs[env][algo][safeguard]:
            continue
        if "Timing" in run.name:
            timing_results[env][algo][safeguard] = run.summary['_wandb']['runtime']
        else:
            rewards = run.history(keys=["eval/Episodic Reward"])["eval/Episodic Reward"].dropna().values
            steps = run.history(keys=["eval/Episodic Reward"])["_step"].dropna().values

            threshold = rewards[-1] * 0.95 if rewards[-1] > 0 else rewards[-1] * 1.05
            idx = np.where(rewards >= threshold)[0]
            nr_steps = steps[idx[0]] if len(idx) > 0 else steps[-1]

            if len(results[env][algo][safeguard]) == 0:
                results[env][algo][safeguard] = {"nr_steps": [], "reward":[]}

            results[env][algo][safeguard]["nr_steps"] += [nr_steps]
            results[env][algo][safeguard]["reward"] += [rewards[-1]]
    except Exception as e:
        print(e)
        print(run.group)
        print(run.name)

In [21]:
#Table 2
print(r"""
\begin{table*}
  \centering
  \caption{Comparison of learning algorithms in unsafe training.}
  \label{tab:unsafe}
  \begin{tabular}{l r r r r r}
    \toprule
    \multirow{2}{*}{Algorithm} & \multicolumn{2}{c}{\# Step} & \multicolumn{2}{c}{Reward} & \multirow{2}{*}{\# Stuck} \\
    & Mean & 95\% CI & Mean & 95\% CI & \\""")

for env in results.keys():
    print(rf"""
    \midrule
    \multicolumn{{6}}{{c}}{{{env}}}             \\
    """)
    for algo in algorithms:
        res = results[env][algo]["No Safeguard"]

        mean_steps = int(np.mean(res["nr_steps"]))
        low_steps, high_steps = ci_95(res["nr_steps"])

        mean_reward = int(np.mean(res["reward"]))
        low_reward, high_reward = ci_95(res["reward"])

        nr_stuck = len(stuck_runs[env][algo]["No Safeguard"])
        print(rf"""{algo} & {mean_steps} & [{low_steps}, {high_steps}] & {mean_reward} & [{low_reward}, {high_reward}] & {nr_stuck} \\""")
print(r"""    \bottomrule
  \end{tabular}
\end{table*}""")


\begin{table*}
  \centering
  \caption{Comparison of learning algorithms in unsafe training.}
  \label{tab:unsafe}
  \begin{tabular}{l r r r r r}
    \toprule
    \multirow{2}{*}{Algorithm} & \multicolumn{2}{c}{\# Step} & \multicolumn{2}{c}{Reward} & \multirow{2}{*}{\# Stuck} \\
    & Mean & 95\% CI & Mean & 95\% CI & \\

    \midrule
    \multicolumn{6}{c}{Pendulum}             \\
    
SHAC & 12288 & [10240, 14592] & -71 & [-133, 53] & 1 \\
SAC & 8513 & [6260, 10767] & -14 & [-15, -11] & 0 \\
PPO & 81600 & [81600, 81600] & -596 & [-1174, 300] & 0 \\

    \midrule
    \multicolumn{6}{c}{Quadrotor}             \\
    
SHAC & 20364 & [11544, 28070] & -155 & [-167, -137] & 0 \\
SAC & 80628 & [59094, 108686] & -1046 & [-1857, 145] & 0 \\
PPO & 80640 & [42504, 118720] & -1710 & [-2117, -1268] & 0 \\

    \midrule
    \multicolumn{6}{c}{Energy System}             \\
    
SHAC & 328680 & [125400, 500313] & -104339 & [-138590, -67952] & 0 \\
SAC & 450000 & [-1, -1] & -349419 & [-1, -1] & 0 \\