# Analysis of all methods results

Here we explore the performance of RL, Planning, and RL+Planning across the different environments.

We look at their mean performance (i.e. mean episode return) across the different environments, in both the in-distribution setting (the planning population matches the test population) and the out-of-distribution setting (the planning population does not match the test population).

**Note:** for RL and RL+Planning, each experiment was repeated 5 times (once for each RL policy seed), so we average the results across these 5 runs.

In [None]:
import sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from posggym_baselines.config import REPO_DIR

sys.path.insert(0, str(REPO_DIR / "baseline_exps"))
import exp_utils

# Global seaborn styling for every figure in this notebook. Order matters:
# set_theme applies its own default context/palette, so the more specific
# context and palette calls must come after it.
sns.set_theme(style="white")
sns.set_context("paper", font_scale=1.5)
sns.set_palette("colorblind")

# When True, figures are written as PDFs under exp_utils.ENV_DATA_DIR / "figures".
SAVE_RESULTS = True
# When True, use the compact layout: 2 facet columns instead of 3 and legends
# placed below the plots.
PRL_VERSION = False

In [None]:
# Load environment information and other formatting stuff.

print("Experiment Environment:")
ALL_ENV_DATA = exp_utils.load_all_env_data()
for k in ALL_ENV_DATA:
    print(k)

NUM_ENVS = len(ALL_ENV_DATA)

# Figure grid parameters: N_COLS facets per row and enough rows (ceiling
# division) to fit one facet per environment.
N_COLS = 2 if PRL_VERSION else 3
N_ROWS = (NUM_ENVS // N_COLS) + int(NUM_ENVS % N_COLS > 0)


# (min, max) episode return of each environment, used to min-max normalize
# returns so they can be averaged across environments.
RETURN_RANGES = {
    "CooperativeReaching-v0": (0, 1),
    "Driving-v1": (-1, 1),
    "LevelBasedForaging-v3": (0, 1),
    "PredatorPrey-v0": (0, 1),
    "PursuitEvasion-v1_i0": (-1, 1),
    "PursuitEvasion-v1_i1": (-1, 1)
}

# Algorithms which have both in- and out-of-distribution results
ID_VS_OOD_ALGS = ["IPOMCP", "POTMMCP", "RL-BR", "COMBINED"]

# Plot Formatting
alg_order = [
    "INTMCP",
    "IPOMCP",
    "POMCP",
    "POTMMCP",
    "RL-BR",
    "COMBINED",
]
# ref: https://seaborn.pydata.org/tutorial/color_palettes.html
alg_color_palette = sns.color_palette("tab10")
alg_palette={
    # planning methods
    "INTMCP": alg_color_palette[0],
    "IPOMCP": alg_color_palette[1],
    "POMCP": alg_color_palette[2],
    "POTMMCP": alg_color_palette[3],
    # rl
    "RL-BR": alg_color_palette[4],
    # combined
    "COMBINED": alg_color_palette[6],
}
# Line dash pattern per algorithm: "" draws a solid line, a tuple is an
# (on, off) dash sequence. Planning methods are solid; RL and combined are
# dashed with different patterns. (A previous, unused definition of this
# dict that was immediately overwritten has been removed.)
alg_dashes = {
    # planning methods
    "INTMCP": "",
    "IPOMCP": "",
    "POMCP": "",
    "POTMMCP": "",
    # rl
    "RL-BR": (2, 2),
    # combined
    "COMBINED": (4, 2),
}

# Load Data

In [None]:
# Utility function to add In/Out of Distribution labels
def get_in_out_dist_label(row):
    """Return True if ``row`` is an in-distribution result, else False.

    POMCP and INTMCP rows are always labelled out-of-distribution. Every other
    row counts as in-distribution exactly when its planning population equals
    its test population.
    """
    return row["Algorithm"] not in ("POMCP", "INTMCP") and bool(
        row["Planning Population"] == row["Test Population"]
    )

## Planning Data

In [None]:
# Read each environment's planning results, tag every row with its env id and
# method type, and stack everything into one frame with readable column names.
planning_results_df = pd.concat(
    [
        pd.read_csv(env_data.planning_results_file).assign(
            full_env_id=env_id, Type="Planning"
        )
        for env_id, env_data in ALL_ENV_DATA.items()
    ],
    ignore_index=True,
).rename(
    columns={
        "alg": "Algorithm",
        "planning_pop_id": "Planning Population",
        "test_pop_id": "Test Population",
        "return": "Return",
        "discounted_return": "Discounted Return",
        "len": "Episode Length",
    }
)

planning_results_df["In Distribution"] = planning_results_df.apply(
    get_in_out_dist_label, axis=1
)

planning_results_df = planning_results_df.sort_values(
    by=[
        "Algorithm",
        "full_env_id",
        "Planning Population",
        "Test Population",
        "search_time_limit",
    ]
)

max_search_time = planning_results_df["search_time_limit"].max()
print(planning_results_df["search_time_limit"].unique())

## RL Data

In [None]:
# Read each environment's RL best-response results, tag every row with its env
# id and method type, and stack them. The RL policy's training population is
# renamed to "Planning Population" so it lines up with the planning results.
br_results_df = pd.concat(
    [
        pd.read_csv(env_data.rl_br_results_file).assign(
            full_env_id=full_env_id, Type="Learning"
        )
        for full_env_id, env_data in ALL_ENV_DATA.items()
    ],
    ignore_index=True,
).rename(
    columns={
        "train_pop": "Planning Population",
        "test_pop": "Test Population",
        "return": "Return",
        "discounted_return": "Discounted Return",
        "train_seed": "rl_policy_seed",
        "len": "Episode Length",
    }
)

# Every RL row gets the fixed algorithm label "RL-BR".
br_results_df["Algorithm"] = "RL-BR"
br_results_df["In Distribution"] = br_results_df.apply(
    get_in_out_dist_label, axis=1
)

br_results_df = br_results_df.sort_values(
    by=["full_env_id", "Planning Population", "Test Population"]
)

## Combined (RL+Planning) Data

In [None]:
# Read each environment's RL+Planning ("combined") results, tag every row with
# its env id and method type, and stack them into one frame.
combined_results_df = pd.concat(
    [
        pd.read_csv(env_data.combined_results_file).assign(
            full_env_id=env_id, Type="Combined"
        )
        for env_id, env_data in ALL_ENV_DATA.items()
    ],
    ignore_index=True,
).rename(
    columns={
        "alg": "Algorithm",
        "planning_pop_id": "Planning Population",
        "test_pop_id": "Test Population",
        "return": "Return",
        "discounted_return": "Discounted Return",
        "len": "Episode Length",
    }
)

combined_results_df["In Distribution"] = combined_results_df.apply(
    get_in_out_dist_label, axis=1
)

combined_results_df = combined_results_df.sort_values(
    by=[
        "Algorithm",
        "full_env_id",
        "Planning Population",
        "Test Population",
        "search_time_limit",
    ]
)

# Sanity check: the RL policy's population must be the planning population,
# which makes the rl_policy_pop_id column redundant.
assert combined_results_df["rl_policy_pop_id"].eq(
    combined_results_df["Planning Population"]
).all()
combined_results_df = combined_results_df.drop(columns=["rl_policy_pop_id"])
print(combined_results_df["search_time_limit"].unique())

## All Data

Here we combine planning, RL, and combined results into a single Dataframe.

In [None]:
# RL-BR results have no search time. To draw RL-BR as a horizontal line on the
# search-time plots, give it rows at both the largest and the smallest search
# time used by the planning / combined methods.
min_search_time = min(
    planning_results_df["search_time_limit"].min(),
    combined_results_df["search_time_limit"].min(),
)
max_search_time = max(
    planning_results_df["search_time_limit"].max(),
    combined_results_df["search_time_limit"].max(),
)
br_results_df["search_time_limit"] = max_search_time
min_br_results_df = br_results_df.copy(deep=True)
min_br_results_df["search_time_limit"] = min_search_time

# Only add the min-time copy when it is a distinct search time. Otherwise every
# RL-BR episode would appear twice at the same search time, doubling its
# episode counts and artificially shrinking its confidence intervals.
br_frames = [br_results_df]
if min_search_time != max_search_time:
    br_frames.append(min_br_results_df)

# combine planning, combined, and rl results together
all_results_df = pd.concat(
    [
        planning_results_df,
        combined_results_df,
        *br_frames,
    ],
    ignore_index=True
)

def normalize_return(row, return_ranges=None):
    """Min-max normalize an episode return using its environment's return range.

    Args:
        row: a DataFrame row (or any mapping) with ``"full_env_id"`` and
            ``"Return"`` entries.
        return_ranges: optional mapping ``env_id -> (min_return, max_return)``.
            Defaults to the notebook-level ``RETURN_RANGES`` table.

    Returns:
        ``(Return - min_return) / (max_return - min_return)``.

    Raises:
        KeyError: if the row's environment has no entry in the ranges table.
    """
    if return_ranges is None:
        return_ranges = RETURN_RANGES
    env_id = row["full_env_id"]
    try:
        min_return, max_return = return_ranges[env_id]
    except KeyError:
        raise KeyError(
            f"No return range configured for environment {env_id!r}; "
            "add it to RETURN_RANGES."
        ) from None
    return (row["Return"] - min_return) / (max_return - min_return)

# Per-environment min-max normalized return, comparable across environments.
all_results_df["Normalized Return"] = all_results_df.apply(normalize_return, axis=1)

all_results_df = all_results_df.sort_values(
    by=[
        "Algorithm",
        "full_env_id",
        "Planning Population",
        "Test Population",
        "search_time_limit",
    ]
)

# Seeds and episode numbers are identifiers, not quantities: store them as categoricals.
all_results_df = all_results_df.astype(
    {"rl_policy_seed": "category", "num": "category"}
)

## Generalization Data

Compute the generalization gap: in-distribution return minus out-of-distribution return.

In [None]:
# Keep only the columns needed for the generalization gap, restricted to the
# algorithms that have both in- and out-of-distribution results.
gg_df = all_results_df.loc[
    all_results_df["Algorithm"].isin(ID_VS_OOD_ALGS),
    [
        "Algorithm",
        "full_env_id",
        "Planning Population",
        "Test Population",
        "search_time_limit",
        "Return",
        "In Distribution",
    ],
]

# average over episodes
gg_gb = gg_df.groupby(
    [
        "Algorithm",
        "full_env_id",
        "Planning Population",
        "Test Population",
        "search_time_limit",
        "In Distribution",
    ]
).agg({"Return": ["mean", "count", "std", "var"]}).reset_index()
# Flatten the (column, statistic) MultiIndex, e.g. ("Return", "mean") -> "Return_mean".
gg_gb.columns = ["_".join(x) if x[1] != '' else x[0] for x in gg_gb.columns.ravel()]

# Pair every in-distribution row with each out-of-distribution row sharing its
# algorithm, env, planning population and search time; OOD columns get "_test".
gg_df = gg_gb[
    gg_gb["In Distribution"] == True
].merge(
    gg_gb[gg_gb["In Distribution"] == False],
    on=[
        "Algorithm",
        "full_env_id",
        "Planning Population",
        "search_time_limit",
    ],
    suffixes=("", "_test"),
)
gg_df["Generalization Gap"] = gg_df["Return_mean"] - gg_df["Return_mean_test"]
# 95% CI half-width of a difference of two means, treating ID and OOD episodes
# as independent samples.
gg_df["Generalization Gap CI"] = 1.96 * np.sqrt(
    gg_df["Return_var"] / gg_df["Return_count"]
    + gg_df["Return_var_test"] / gg_df["Return_count_test"]
)
print(gg_df["Algorithm"].unique())
gg_df


# Overall Results

Here we plot both in- and out-of-distribution performance averaged over environments.

In [None]:
# Mean return vs. search time averaged over environments, with separate
# in-distribution and out-of-distribution panels.

# The set of search times does not depend on the return key, so compute it once.
search_times = sorted(all_results_df["search_time_limit"].unique().tolist())

for return_key in ["Return", "Normalized Return"]:
    # Per-environment summary of `return_key` for every
    # (algorithm, environment, search time, ID/OOD) cell.
    df = all_results_df.groupby(
        [
            "Algorithm",
            "full_env_id",
            "search_time_limit",
            "In Distribution",
        ],
        as_index=False
    ).agg({return_key: ["mean", "count", "std", "var"]})
    # Flatten the (column, statistic) MultiIndex, e.g. ("Return", "mean") -> "Return_mean".
    df.columns = ["_".join(x).strip() if x[1] != '' else x[0] for x in df.columns.values]

    fig, axs = plt.subplots(
        nrows=1,
        ncols=2,
        figsize=(6, 4.125),
        sharey=True,
    )

    for in_dist, ax in zip([True, False], axs):
        df_dist = df[(df["In Distribution"] == in_dist)]

        for alg in alg_order:
            if alg not in df_dist["Algorithm"].unique():
                continue
            df_alg = df_dist[(df_dist["Algorithm"] == alg)]
            x = []
            y = []
            yerr = []
            for t in search_times:
                df_alg_t = df_alg[(df_alg["search_time_limit"] == t)]
                if df_alg_t.empty:
                    continue
                # The plotted value is the unweighted mean of the per-environment
                # means, so its variance is sum_e(var_e / n_e) / n_envs**2.
                n_envs = len(df_alg_t)
                mean = df_alg_t[f"{return_key}_mean"].mean()
                sem = np.sqrt(
                    (df_alg_t[f"{return_key}_var"] / df_alg_t[f"{return_key}_count"]).sum()
                ) / n_envs
                ci = 1.96 * sem
                x.append(t)
                y.append(mean)
                yerr.append(ci)
            ax.plot(
                x,
                y,
                label=alg,
                color=alg_palette[alg],
                # "" in alg_dashes means a solid line; tuples are dash patterns.
                linestyle="-" if isinstance(alg_dashes[alg], str) else (0, alg_dashes[alg]),
            )
            ax.fill_between(
                x,
                np.array(y) - np.array(yerr),
                np.array(y) + np.array(yerr),
                alpha=0.2,
                color=alg_palette[alg],
            )

        ax.set_xlabel("Search Time (s)")
        if in_dist:
            ax.set_ylabel(f"Mean {return_key}")
        ax.set_title("In Distribution" if in_dist else "Out of Distribution")
        ax.yaxis.set_tick_params(labelleft=True)

    # The OOD panel contains every algorithm, so take legend entries from it.
    lines, labels = axs[1].get_legend_handles_labels()
    if PRL_VERSION:
        fig.legend(
            lines,
            labels,
            loc="lower center",
            bbox_to_anchor=(0.5, 0.0),
            ncol=3,
            frameon=False,
            fontsize=12,
        )
        # left, bottom, right, top
        fig.tight_layout(rect=(0.0, 0.125, 1.0, 1.0))
    else:
        fig.legend(
            lines,
            labels,
            loc="center left",
            bbox_to_anchor=(1.0, 0.5),
            ncol=1,
            frameon=False,
            fontsize=12,
            title="Algorithm",
        )
        fig.tight_layout()

    if SAVE_RESULTS:
        print("saving figure")
        fig.savefig(
            str(exp_utils.ENV_DATA_DIR / "figures" / f"all_methods_{return_key}_ID_and_OOD.pdf"),
            bbox_inches="tight"
        )
    del fig, axs, df

# In-Distribution Results

Here we look specifically at the in-distribution performance of all methods in each environment separately.

## Mean Return

Dimensions:

- `y-axis`: mean episode return
- `x-axis`: search time
- `z-axis/hue`: algorithm

In [None]:
# In-distribution mean return vs. search time, one facet per environment.
# POMCP and INTMCP are always labelled out-of-distribution by
# get_in_out_dist_label, so they are left out of the hue order here.
g = sns.relplot(
    data=all_results_df[all_results_df["In Distribution"] == True],
    kind="line",
    x="search_time_limit",
    y="Return",
    col="full_env_id",
    col_wrap=N_COLS,
    hue="Algorithm",
    hue_order=[alg for alg in alg_order if alg not in ("INTMCP", "POMCP")],
    palette=alg_palette,
    style="Algorithm",
    dashes=alg_dashes,
    markers=False,
    legend="full",
    height=2.75 if PRL_VERSION else 4,
    aspect=1.2 if PRL_VERSION else 1,
    facet_kws=dict(sharex=False, sharey=False),
)

# Label y only on the first facet of each row and x only on the bottom row.
last_row_start = (N_ROWS - 1) * N_COLS
for panel_idx, (panel_title, ax) in enumerate(g.axes_dict.items()):
    ax.set_title(panel_title)
    is_first_col = panel_idx % N_COLS == 0
    if is_first_col:
        ax.set_ylabel("Mean Return")
    if panel_idx >= last_row_start:
        ax.set_xlabel("Search Time (s)")

if PRL_VERSION:
    sns.move_legend(
        g,
        "lower center",
        bbox_to_anchor=(0.5, 0.0),
        ncol=4,
        frameon=False,
        title=None,
    )
    g.tight_layout(rect=(0.0, 0.04, 1.0, 1.0))

if SAVE_RESULTS:
    print("saving figure")
    figure_path = exp_utils.ENV_DATA_DIR / "figures" / "all_methods_returns_ID_allenvs.pdf"
    g.savefig(figure_path, bbox_inches="tight")
del g

In [None]:
# Same as above, except with a separate figure for each environment.
id_mask = all_results_df["In Distribution"] == True
for env_id in all_results_df["full_env_id"].unique():
    env_id_df = all_results_df[id_mask & (all_results_df["full_env_id"] == env_id)]
    g = sns.relplot(
        data=env_id_df,
        kind="line",
        x="search_time_limit",
        y="Return",
        hue="Algorithm",
        hue_order=[alg for alg in alg_order if alg not in ("INTMCP", "POMCP")],
        palette=alg_palette,
        style="Algorithm",
        dashes=alg_dashes,
        markers=False,
        legend="full",
        height=4,
        aspect=1,
    )
    g.ax.set(title=env_id, xlabel="Search Time (s)", ylabel="Mean Return")

    if SAVE_RESULTS:
        print("saving figure")
        figure_path = exp_utils.ENV_DATA_DIR / "figures" / f"all_methods_returns_ID_{env_id}.pdf"
        g.savefig(figure_path, bbox_inches="tight")
del g

In [None]:
# In-distribution mean return vs. search time, faceted by environment (rows)
# and planning population (columns). This figure is not saved to disk.
g = sns.relplot(
    data=all_results_df[all_results_df["In Distribution"] == True],
    x="search_time_limit",
    y="Return",
    hue="Algorithm",
    style="Algorithm",
    row="full_env_id",
    col="Planning Population",
    hue_order=[a for a in alg_order if a not in ["INTMCP", "POMCP"]],
    legend="full",
    palette=alg_palette,
    markers=False,
    dashes=alg_dashes,
    kind="line",
    height=3,
    aspect=1.5,
    facet_kws={
        "sharey": "row",
        "sharex": True
    },
)

# With both `row` and `col` facets, axes_dict keys are (row value, col value)
# tuples. Format them rather than showing the tuple repr as the title.
for (env_name, planning_pop), ax in g.axes_dict.items():
    ax.set_title(f"{env_name} | {planning_pop}")
    # Same x-label wording as the other search-time figures.
    ax.set_xlabel("Search Time (s)")
    ax.set_ylabel("Mean Return")

del g

## Average Episode Length

Here we plot the average episode length of each algorithm in each environment. We show results using the maximum search time for the planning and combined methods.

Dimensions:

- `y-axis`: mean episode length
- `x-axis`: environment
- `z-axis/hue`: algorithm

In [None]:
# Mean in-distribution episode length per environment at the largest search time.
g = sns.catplot(
    data=all_results_df[
        (all_results_df["In Distribution"] == True)
        & (all_results_df["search_time_limit"] == max_search_time)
    ],
    x="full_env_id",
    y="Episode Length",
    hue="Algorithm",
    hue_order=[a for a in alg_order if a not in ["INTMCP", "POMCP"]],
    legend="full",
    # Use the shared per-algorithm palette so each algorithm has the same
    # colour here as in every other figure.
    palette=alg_palette,
    kind="bar",
    height=4,
    aspect=1.5,
)

g.set_xticklabels(rotation=90)
g.set_axis_labels("", "Mean Episode Length")

if SAVE_RESULTS:
    print("saving figure")
    g.savefig(
        exp_utils.ENV_DATA_DIR / "figures" / "all_methods_lens_ID_allenvs.pdf",
        bbox_inches="tight"
    )
del g

# Out-of-Distribution Results

Here we look at the out-of-distribution performance of all methods in each environment separately.

## Mean Return

Dimensions:

- `y-axis`: mean episode return
- `x-axis`: search time
- `z-axis/hue`: algorithm

In [None]:
# Out-of-distribution mean return vs. search time, one facet per environment.
# All algorithms in alg_order are included: POMCP and INTMCP rows are always
# labelled out-of-distribution, so they only appear in OOD figures.
g = sns.relplot(
    data=all_results_df[all_results_df["In Distribution"] == False],
    kind="line",
    x="search_time_limit",
    y="Return",
    col="full_env_id",
    col_wrap=N_COLS,
    hue="Algorithm",
    hue_order=alg_order,
    palette=alg_palette,
    style="Algorithm",
    dashes=alg_dashes,
    markers=False,
    legend="full",
    height=2.75 if PRL_VERSION else 4,
    aspect=1.2 if PRL_VERSION else 1,
    facet_kws=dict(sharex=False, sharey=False),
)

# Label y only on the first facet of each row and x only on the bottom row.
last_row_start = (N_ROWS - 1) * N_COLS
for panel_idx, (panel_title, ax) in enumerate(g.axes_dict.items()):
    ax.set_title(panel_title)
    is_first_col = panel_idx % N_COLS == 0
    if is_first_col:
        ax.set_ylabel("Mean Return")
    if panel_idx >= last_row_start:
        ax.set_xlabel("Search Time (s)")

if PRL_VERSION:
    sns.move_legend(
        g,
        "lower center",
        bbox_to_anchor=(0.5, 0.0),
        ncol=3,
        frameon=False,
        title=None,
    )
    g.tight_layout(rect=(0.0, 0.075, 1.0, 1.0))

if SAVE_RESULTS:
    print("saving figure")
    figure_path = exp_utils.ENV_DATA_DIR / "figures" / "all_methods_returns_OOD_allenvs.pdf"
    g.savefig(figure_path, bbox_inches="tight")
del g

# Generalization (In- vs Out-of-Distribution) Results

Here we look at the gap between in-distribution and out-of-distribution performance of each algorithm.

## Mean Return

In the first plot we show the mean return of each algorithm in the in- and out-of-distribution settings, averaged across all environments.

- `x-axis`: Search Time
- `y-axis`: Mean Return
- `hue/z-axis`: In Distribution (or not)
- `col`: Algorithm

In [None]:
# For each algorithm with both ID and OOD results: mean return vs. search time,
# averaged over environments, with one line for ID and one for OOD.

# The set of search times does not depend on the return key, so compute it once.
search_times = sorted(all_results_df["search_time_limit"].unique().tolist())

for return_key in ["Return", "Normalized Return"]:
    # Per-environment summary of `return_key` for every
    # (algorithm, environment, search time, ID/OOD) cell.
    df = all_results_df[
            all_results_df["Algorithm"].isin(ID_VS_OOD_ALGS)
        ].groupby(
            [
                "Algorithm",
                "full_env_id",
                "search_time_limit",
                "In Distribution",
            ],
            as_index=False
        ).agg({return_key: ["mean", "count", "std", "var"]})
    # Flatten the (column, statistic) MultiIndex, e.g. ("Return", "mean") -> "Return_mean".
    df.columns = ["_".join(x).strip() if x[1] != '' else x[0] for x in df.columns.values]

    fig, axs = plt.subplots(
        nrows=1,
        ncols=len(ID_VS_OOD_ALGS),
        figsize=(10, 4.125),
        sharey=True,
        squeeze=False
    )

    # i indexes the next free panel; algorithms with no rows are skipped.
    i = 0
    for alg in alg_order:
        if alg not in df["Algorithm"].unique():
            continue
        df_alg = df[(df["Algorithm"] == alg)]
        ax = axs[0][i]
        i += 1

        for in_dist in [True, False]:
            df_dist = df_alg[(df_alg["In Distribution"] == in_dist)]

            x = []
            y = []
            yerr = []
            for t in search_times:
                df_dist_t = df_dist[(df_dist["search_time_limit"] == t)]
                if df_dist_t.empty:
                    continue
                # The plotted value is the unweighted mean of the per-environment
                # means, so its variance is sum_e(var_e / n_e) / n_envs**2.
                n_envs = len(df_dist_t)
                mean = df_dist_t[f"{return_key}_mean"].mean()
                sem = np.sqrt(
                    (df_dist_t[f"{return_key}_var"] / df_dist_t[f"{return_key}_count"]).sum()
                ) / n_envs
                ci = 1.96 * sem
                x.append(t)
                y.append(mean)
                yerr.append(ci)
            ax.plot(x, y, label=in_dist)
            ax.fill_between(
                x,
                np.array(y) - np.array(yerr),
                np.array(y) + np.array(yerr),
                alpha=0.2,
            )

        ax.set_xlabel("Search Time (s)")
        if i == 1:
            ax.set_ylabel(f"Mean {return_key}")
        ax.set_title(alg)
        ax.yaxis.set_tick_params(labelleft=True)

    lines, labels = axs[0][0].get_legend_handles_labels()
    if PRL_VERSION:
        fig.legend(
            lines,
            labels,
            loc="lower center",
            bbox_to_anchor=(0.5, 0.0),
            ncol=3,
            frameon=False,
            fontsize=12,
            title="In Dist.",
        )
        # left, bottom, right, top
        fig.tight_layout(rect=(0.0, 0.125, 1.0, 1.0))
    else:
        fig.legend(
            lines,
            labels,
            loc="center left",
            bbox_to_anchor=(1.0, 0.5),
            ncol=1,
            frameon=False,
            fontsize=12,
            title="In Dist.",
        )
        fig.tight_layout()

    if SAVE_RESULTS:
        print("saving figure")
        fig.savefig(
            str(exp_utils.ENV_DATA_DIR / "figures" / f"all_methods_{return_key}_ID_vs_OOD.pdf"),
            bbox_inches="tight"
        )

    del fig, axs, df


## Generalization Gap

Here we plot ID - OOD performance for each algorithm. We show results for planning and combined methods using the maximum search time.

In [None]:
# Generalization gap (ID - OOD normalized return) at the largest search time,
# for the algorithms that have both in- and out-of-distribution results.
return_key = "Normalized Return"
gg_df = all_results_df.loc[
    (all_results_df["search_time_limit"] == max_search_time)
    & all_results_df["Algorithm"].isin(ID_VS_OOD_ALGS),
    [
        "Algorithm",
        "full_env_id",
        "Planning Population",
        "Test Population",
        "search_time_limit",
        return_key,
        "In Distribution",
    ],
]

# average over episodes
gg_gb = gg_df.groupby(
    [
        "Algorithm",
        "full_env_id",
        "Planning Population",
        "Test Population",
        "search_time_limit",
        "In Distribution",
    ]
).agg({return_key: ["mean", "count", "std", "var"]}).reset_index()
# Flatten the (column, statistic) MultiIndex into "<column>_<statistic>" names.
gg_gb.columns = ["_".join(x) if x[1] != '' else x[0] for x in gg_gb.columns.ravel()]

# merge ID and OOD rows into single row (OOD columns get the "_test" suffix)
gg_df = gg_gb[
    gg_gb["In Distribution"] == True
].merge(
    gg_gb[gg_gb["In Distribution"] == False],
    on=[
        "Algorithm",
        "full_env_id",
        "Planning Population",
        "search_time_limit",
    ],
    suffixes=("", "_test"),
)

# compute generalization gap
gg_df["Generalization Gap"] = (
    gg_df[f"{return_key}_mean"] - gg_df[f"{return_key}_mean_test"]
)
# 95% CI half-width of a difference of two means, treating ID and OOD episodes
# as independent samples: 1.96 * sqrt(var_id / n_id + var_ood / n_ood).
gg_df["Generalization Gap CI"] = 1.96 * np.sqrt(
    gg_df[f"{return_key}_var"] / gg_df[f"{return_key}_count"]
    + gg_df[f"{return_key}_var_test"] / gg_df[f"{return_key}_count_test"]
)
print(gg_df["Algorithm"].unique())

In [None]:
gg_key = "Generalization Gap"
# Average over populations: mean gap and mean CI half-width per (algorithm, env).
overall_gg_df = gg_df.groupby(
    [
        "Algorithm",
        "full_env_id",
    ],
    as_index=False
).agg({
    gg_key: ["mean"],
    f"{gg_key} CI": ["mean"]
})
overall_gg_df.columns = [
    "_".join(x).strip() if x[1] != '' else x[0]
    for x in overall_gg_df.columns.values
]

fig, ax = plt.subplots(
    nrows=1,
    ncols=1,
    figsize=(4, 4),
)

# One bar per entry of ID_VS_OOD_ALGS. x is sized from that same list so it
# always matches the number of bar heights, even if an algorithm has no rows.
x = np.arange(len(ID_VS_OOD_ALGS))
ys = []
yerrs = []
for alg in ID_VS_OOD_ALGS:
    df_alg = overall_gg_df[(overall_gg_df["Algorithm"] == alg)]
    # Averaged over environments. NOTE: the error bar is the mean of the
    # per-environment CI half-widths, not a CI of the averaged gap.
    ys.append(df_alg[f"{gg_key}_mean"].mean())
    yerrs.append(df_alg[f"{gg_key} CI_mean"].mean())
ax.bar(
    x,
    ys,
    yerr=yerrs,
    tick_label=ID_VS_OOD_ALGS,
    # Same per-algorithm colours as the other figures.
    color=[alg_palette[alg] for alg in ID_VS_OOD_ALGS],
)

ax.set_xticklabels(ID_VS_OOD_ALGS, rotation=90)
ax.set_xlabel("Algorithm")
ax.set_ylabel(gg_key)
ax.yaxis.set_tick_params(labelleft=True)
fig.tight_layout()

if SAVE_RESULTS:
    print("saving figure")
    fig.savefig(
        str(exp_utils.ENV_DATA_DIR / "figures" / f"all_methods_{gg_key}.pdf"),
        bbox_inches="tight"
    )
del fig, ax, overall_gg_df

## Driving-v1

Here we take a deeper dive into the performance of each algorithm in the `Driving-v1` environment.

In particular we can look at the distribution of episode lengths as well as the proportion of episodes that ended with a crash, timeout, or success.

In [None]:
# Driving-v1 deep dive: classify in-distribution episodes at the largest
# Driving search time by outcome and report per-algorithm percentages.
driving_df = all_results_df[
    (all_results_df["full_env_id"] == "Driving-v1")
    & (all_results_df["In Distribution"] == True)
]

# Use a Driving-specific name. Rebinding the notebook-wide `max_search_time`
# here would silently change the search time used by earlier cells on re-run.
driving_max_search_time = driving_df["search_time_limit"].max()
print(driving_max_search_time)

driving_df = driving_df[
    (driving_df["search_time_limit"] == driving_max_search_time)
]

# Outcome buckets by episode return: < 0 -> crashed, >= 0.99 -> success,
# anything in between -> timed out. The assert checks they partition the episodes.
crashed_df = driving_df[
    (driving_df["Return"] < 0)
]
success_df = driving_df[
    (driving_df["Return"] >= 0.99)
]
timedout_df = driving_df[
    (driving_df["Return"] < 0.99)
    & (driving_df["Return"] >= 0)
]
assert len(driving_df) - len(crashed_df) - len(success_df) - len(timedout_df) == 0

# Count outcomes once per algorithm (in alg_order) and reuse for both reports.
driving_algs = set(driving_df["Algorithm"].unique())
outcome_counts = {}
for alg in alg_order:
    if alg not in driving_algs:
        continue
    outcome_counts[alg] = (
        len(driving_df[driving_df["Algorithm"] == alg]),
        len(crashed_df[crashed_df["Algorithm"] == alg]),
        len(success_df[success_df["Algorithm"] == alg]),
        len(timedout_df[timedout_df["Algorithm"] == alg]),
    )

for alg, (total_eps, num_crashed, num_success, num_timeout) in outcome_counts.items():
    print(f"\n{alg}")
    print(f"Total:    {total_eps}")
    print(f"Crashed:  {num_crashed} ({num_crashed / total_eps * 100:.2f}%)")
    print(f"Success:  {num_success} ({num_success / total_eps * 100:.2f}%)")
    print(f"Timedout: {num_timeout} ({num_timeout / total_eps * 100:.2f}%)")

# latex table output. "\\%" in the f-strings emits a literal "\%" for LaTeX
# (a bare "\%" is an invalid escape sequence and warns on recent Python).
print("\n\n")
print("Algorithm & Crashed & Success & Timedout \\\\")
for alg, (total_eps, num_crashed, num_success, num_timeout) in outcome_counts.items():
    print(
        f"{alg} & ${num_crashed / total_eps * 100:.2f}\\%$"
        f" & ${num_success / total_eps * 100:.2f}\\%$"
        f" & ${num_timeout / total_eps * 100:.2f}\\%$ \\\\"
    )

# Belief Statistics

Here we plot the belief accuracy of the combined and POTMMCP methods.

## Combined Beliefs

In [None]:
# Load per-step belief statistics of the combined (RL+Planning) method for every
# environment and reshape them from wide (one column per step) to long format
# (one row per step).
combined_belief_dfs = []
for full_env_id in ALL_ENV_DATA:
    print(full_env_id)
    env_belief_results = pd.read_csv(
        ALL_ENV_DATA[full_env_id].env_data_dir / "combined_belief_per_step_results.csv"
    )
    env_belief_results["full_env_id"] = full_env_id

    # Largest step index among the "belief_state_acc_<step>" columns.
    env_step_limit = max(
        int(c.split("_")[-1]) 
        for c in env_belief_results.columns 
        if c.startswith("belief_state_acc_")
    )
    env_belief_results["step_limit"] = env_step_limit

    # drop unused columns
    # keep only id columns and per step belief stats (remove mean belief stats)
    # id_cols identify one episode; they become the `i` index of wide_to_long.
    id_cols = [
        "alg", 
        "full_env_id", 
        "planning_pop_id", 
        "test_pop_id", 
        "rl_policy_seed",
        "search_time_limit", 
        "step_limit",
        "num"
    ]
    # Columns whose name ends in "acc" (no step suffix) are the mean stats and are dropped.
    cols_to_keep = [*id_cols] + [
        c for c in env_belief_results.columns 
        if c.startswith("belief_") and c.split("_")[-1] != "acc"
    ]
    env_belief_results.drop(
        columns=[c for c in env_belief_results.columns if c not in cols_to_keep], 
        inplace=True
    )

    # convert from wide to long format for per step belief stats
    # Only reshape the statistics that actually have columns in this env's file.
    stub_names = [
        k for k in [
            "belief_state_acc", 
            "belief_history_acc", 
            "belief_action_acc", 
            "belief_policy_acc"
        ]
        if any(c.startswith(k) for c in env_belief_results.columns)
    ]

    # "<stub>_<step>" columns -> one "<stub>" column plus a "step" column.
    env_belief_results = pd.wide_to_long(
        env_belief_results, 
        stubnames=stub_names, 
        i=id_cols, 
        j="step", 
        sep="_"
    ).reset_index()
    combined_belief_dfs.append(env_belief_results)

combined_belief_df = pd.concat(combined_belief_dfs, ignore_index=True)
# Note: unlike the results tables, the "alg" column is not renamed here.
combined_belief_df.rename(
    columns={
        "planning_pop_id": "Planning Population",
        "test_pop_id": "Test Population",
        "belief_state_acc": "Belief State Accuracy",
        "belief_history_acc": "Belief History Accuracy",
        "belief_action_acc": "Belief Action Accuracy",
        "belief_policy_acc": "Belief Policy Accuracy",
        "search_time_limit": "Search Time (s)",
    },
    inplace=True,
)

# Add In/Out of Distribution labels: a combined-method row is in-distribution
# exactly when its planning population equals its test population.
# This is computed column-wise instead of by redefining the notebook-global
# `get_in_out_dist_label`. Redefining it here would replace the version used by
# the results-loading cells (which always labels POMCP/INTMCP as OOD), so
# re-running those cells afterwards would mislabel rows.
combined_belief_df["In Distribution"] = (
    combined_belief_df["Planning Population"] == combined_belief_df["Test Population"]
)

combined_belief_df = combined_belief_df.sort_values(
    by=["alg", "full_env_id", "Planning Population", "Test Population", "Search Time (s)"]
)

# Quick sanity summary: search times, environments, and the range of step indices.
for summary in (
    combined_belief_df["Search Time (s)"].unique(),
    combined_belief_df["full_env_id"].unique(),
):
    print(summary)
print(combined_belief_df["step"].min(), combined_belief_df["step"].max())

In [None]:
# Per-step belief accuracy of the combined method: one figure per belief
# statistic, one facet per environment, hue = search time.
for belief_stat_key in [
    "Belief State Accuracy",
    # "Belief History Accuracy",
    "Belief Action Accuracy",
    "Belief Policy Accuracy",
]:
    print(belief_stat_key)

    # Settings shared by both plot variants below.
    relplot_kwargs = dict(
        x="step",
        y=belief_stat_key,
        hue="Search Time (s)",
        col="full_env_id",
        col_wrap=N_COLS,
        kind="line",
        height=3,
        aspect=1.5,
        facet_kws={
            "sharey": False,
            "sharex": False,
        },
    )
    if belief_stat_key == "Belief Policy Accuracy":
        # Policy accuracy is only plotted for in-distribution episodes.
        g = sns.relplot(
            data=combined_belief_df[combined_belief_df["In Distribution"] == True],
            **relplot_kwargs,
        )
    else:
        # Other statistics show ID (solid) and OOD (dashed) episodes together.
        g = sns.relplot(
            data=combined_belief_df,
            style="In Distribution",
            style_order=[True, False],
            **relplot_kwargs,
        )

    # Only the bottom row of facets gets an x-axis label.
    for i, (col_key, ax) in enumerate(g.axes_dict.items()):
        ax.set_title(f"{col_key}")
        if i >= (N_ROWS-1) * N_COLS:
            ax.set_xlabel("Episode Step")

    if PRL_VERSION:
        sns.move_legend(
            g,
            loc="lower center",
            ncol=3,
            bbox_to_anchor=(0.5, 0.0),
            title=None,
            frameon=False
        )
        g.tight_layout(rect=(0.0, 0.1, 1.0, 1.0))

    # Respect the notebook-wide save flag (previously hard-coded to always save).
    if SAVE_RESULTS:
        print(f"saving {belief_stat_key} figure")
        g.savefig(
            exp_utils.ENV_DATA_DIR / "figures" / f"per_step_{belief_stat_key}.pdf",
            bbox_inches="tight"
        )

    del g

## Planning Beliefs

In [None]:
# Load per-step belief statistics of the planning methods for every environment
# and reshape them from wide (one column per step) to long format (one row per step).
planning_belief_dfs = []
for full_env_id in ALL_ENV_DATA:
    print(full_env_id)
    env_belief_results = pd.read_csv(
        ALL_ENV_DATA[full_env_id].env_data_dir / "planning_belief_per_step_results.csv"
    )
    env_belief_results["full_env_id"] = full_env_id

    # Largest step index among the "belief_state_acc_<step>" columns.
    env_step_limit = max(
        int(c.split("_")[-1]) 
        for c in env_belief_results.columns 
        if c.startswith("belief_state_acc_")
    )
    env_belief_results["step_limit"] = env_step_limit

    # drop unused columns
    # keep only id columns and per step belief stats (remove mean belief stats)
    # id_cols identify one episode (no rl_policy_seed, unlike the combined
    # results); they become the `i` index of wide_to_long.
    id_cols = [
        "alg", 
        "full_env_id", 
        "planning_pop_id", 
        "test_pop_id", 
        "search_time_limit", 
        "step_limit",
        "num"
    ]
    # Columns whose name ends in "acc" (no step suffix) are the mean stats and are dropped.
    cols_to_keep = [*id_cols] + [
        c for c in env_belief_results.columns 
        if c.startswith("belief_") and c.split("_")[-1] != "acc"
    ]
    env_belief_results.drop(
        columns=[c for c in env_belief_results.columns if c not in cols_to_keep], 
        inplace=True
    )

    # convert from wide to long format for per step belief stats
    # Only reshape the statistics that actually have columns in this env's file.
    stub_names = [
        k for k in [
            "belief_state_acc", 
            "belief_history_acc", 
            "belief_action_acc", 
            "belief_policy_acc"
        ]
        if any(c.startswith(k) for c in env_belief_results.columns)
    ]

    # "<stub>_<step>" columns -> one "<stub>" column plus a "step" column.
    env_belief_results = pd.wide_to_long(
        env_belief_results, 
        stubnames=stub_names, 
        i=id_cols, 
        j="step", 
        sep="_"
    ).reset_index()
    planning_belief_dfs.append(env_belief_results)

planning_belief_df = pd.concat(planning_belief_dfs, ignore_index=True)
planning_belief_df.rename(
    columns={
        "alg": "Algorithm",
        "planning_pop_id": "Planning Population",
        "test_pop_id": "Test Population",
        "belief_state_acc": "Belief State Accuracy",
        "belief_history_acc": "Belief History Accuracy",
        "belief_action_acc": "Belief Action Accuracy",
        "belief_policy_acc": "Belief Policy Accuracy",
        "search_time_limit": "Search Time (s)",
    },
    inplace=True,
)

# Add In/Out of Distribution labels
def get_in_out_dist_label(row):
    """Return whether ``row`` comes from an in-distribution run.

    Rows for POMCP and INTMCP are always labelled ``False``. For any other
    algorithm the row is in distribution exactly when its
    "Planning Population" equals its "Test Population".
    """
    can_be_in_dist = row["Algorithm"] not in ("POMCP", "INTMCP")
    return can_be_in_dist and bool(
        row["Planning Population"] == row["Test Population"]
    )

# Label each row as in- or out-of-distribution. This is a vectorised
# equivalent of ``get_in_out_dist_label``: POMCP and INTMCP rows are always
# False; every other row is True when its planning and test populations match.
# Comparing whole columns avoids a slow row-by-row ``DataFrame.apply`` over
# this long frame (one row per episode step), and still yields a boolean
# Series when the frame is empty (``apply`` would return a DataFrame there).
planning_belief_df["In Distribution"] = (
    ~planning_belief_df["Algorithm"].isin(["POMCP", "INTMCP"])
) & planning_belief_df["Planning Population"].eq(
    planning_belief_df["Test Population"]
)

planning_belief_df.sort_values(
    by=[
        "Algorithm",
        "full_env_id",
        "Planning Population",
        "Test Population",
        "Search Time (s)",
    ],
    inplace=True,
)

# Quick sanity checks on what was loaded.
print(planning_belief_df["Search Time (s)"].unique())
print(planning_belief_df["full_env_id"].unique())
print(planning_belief_df["step"].min(), planning_belief_df["step"].max())

In [None]:
# Plot per-step belief accuracy: one figure per (algorithm, belief statistic)
# pair, with one facet per environment and one line per search time.
for alg in planning_belief_df["Algorithm"].unique():
    # POMCP and INTMCP are excluded from these plots; their rows are always
    # labelled as out-of-distribution.
    if alg in ["POMCP", "INTMCP"]:
        continue

    alg_df = planning_belief_df[planning_belief_df["Algorithm"] == alg]
    for belief_stat_key in [
        "Belief State Accuracy",
        # "Belief History Accuracy",
        "Belief Action Accuracy",
        "Belief Policy Accuracy",
    ]:
        print(alg, belief_stat_key)

        # Settings shared by both plot variants below.
        plot_kwargs = dict(
            x="step",
            y=belief_stat_key,
            hue="Search Time (s)",
            col="full_env_id",
            col_wrap=N_COLS,
            kind="line",
            height=3,
            aspect=1.5,
            facet_kws={"sharey": False, "sharex": False},
        )
        if belief_stat_key == "Belief Policy Accuracy":
            # Only in-distribution rows are plotted for policy accuracy.
            # NOTE(review): presumably because policy accuracy is only
            # meaningful when the test policies are in the planning
            # population — confirm.
            in_dist_df = alg_df[alg_df["In Distribution"]]
            g = sns.relplot(data=in_dist_df, **plot_kwargs)
        else:
            g = sns.relplot(
                data=alg_df,
                style="In Distribution",
                style_order=[True, False],
                **plot_kwargs,
            )

        for col_key, ax in g.axes_dict.items():
            ax.set_title(f"{col_key}")
        # Label the bottom axis of every column (including columns whose last
        # facet is not in the final row when the grid is ragged); inner
        # x labels are cleared. Indexing by ``i >= N_COLS`` mislabels both
        # cases.
        g.set_xlabels("Episode Step")

        if SAVE_RESULTS:
            fig_dir = exp_utils.ENV_DATA_DIR / "figures"
            fig_dir.mkdir(parents=True, exist_ok=True)
            print(f"saving {alg} {belief_stat_key} figure")
            g.savefig(
                fig_dir / f"{alg}_per_step_{belief_stat_key}.pdf",
                bbox_inches="tight",
            )

        del g

# Search Statistics

Here we look at various statistics of the search process for each algorithm.

Each figure shows a different statistic, with one column per environment since we expect some differences between environments (e.g. due to the average number of steps to reach a terminal state). Each line is a different algorithm and each row a different planning population. In- and out-of-distribution results are not separated, since we expect (and observe) no difference in search statistics between the two settings.

- `x-axis`: Search time limit (s)
- `y-axis`: Search statistic value
- `hue/lines`: Algorithm
- `row`: Planning population
- `col`: Environment

In [None]:
# Search-statistic plots are disabled by default; set this flag to True to
# regenerate them.
PLOT_SEARCH_STATS = False

if PLOT_SEARCH_STATS:
    # RL-BR rows are excluded from every statistic plot, so filter once
    # instead of once per statistic.
    stat_df = all_results_df[all_results_df["Algorithm"] != "RL-BR"]
    for stat_key in [
        "search_time",
        "update_time",
        "reinvigoration_time",
        "evaluation_time",
        "policy_calls",
        "inference_time",
        "search_depth",
        "num_sims",
        "mem_usage",
        "min_value",
        "max_value",
    ]:
        print(stat_key)
        stat_min = stat_df[stat_key].min()
        stat_max = stat_df[stat_key].max()
        # Use a logarithmic y-axis when the statistic spans more than two
        # orders of magnitude ("min_value" is always kept linear).
        use_log_scale = (
            (stat_max / max(1, stat_min)) > 100 and stat_key != "min_value"
        )
        # A plain log axis cannot display zero or negative values (e.g. zero
        # policy calls), so fall back to a symmetric-log axis in that case.
        y_scale = "log" if stat_min > 0 else "symlog"

        g = sns.relplot(
            data=stat_df,
            x="search_time_limit",
            y=stat_key,
            hue="Algorithm",
            row="Planning Population",
            col="full_env_id",
            kind="line",
            facet_kws={"sharey": True, "sharex": True},
        )

        for (row_key, col_key), ax in g.axes_dict.items():
            ax.set_title(f"{row_key} | {col_key}")
            if use_log_scale:
                ax.set_yscale(y_scale)
        # x-axes are shared, so only the bottom row needs a label.
        g.set_xlabels("Search Time Limit (s)")

        del g