# Analysis of Combined Results

Here we explore the performance of the combined planning+RL algorithm across the different environments.

For each algorithm we look at its mean performance (i.e. episode return) across the different environments, in both in-distribution (planning population matches the test population) and out-of-distribution (planning population does not match the test population) settings.
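A row is in-distribution exactly when its planning and test population IDs match. Below is a minimal sketch of that labelling on a toy frame. It uses a vectorised column comparison, and the population IDs `P0`/`P1` and the returns are hypothetical placeholders:

```python
import pandas as pd

# Toy results table; population IDs and returns are hypothetical placeholders.
df = pd.DataFrame(
    {
        "Planning Population": ["P0", "P0", "P1", "P1"],
        "Test Population": ["P0", "P1", "P0", "P1"],
        "Return": [1.0, 0.2, 0.4, 0.9],
    }
)

# Element-wise comparison gives a boolean Series aligned with df's index.
df["In Distribution"] = df["Planning Population"] == df["Test Population"]
print(df["In Distribution"].tolist())  # [True, False, False, True]
```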

**Note:** each experiment was repeated 5 times (once for each RL policy seed), so we average the results across these 5 runs.
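Averaging "across runs" can mean pooling every row or averaging within each seed first. The two only differ when seeds contribute different numbers of rows. A minimal sketch on a toy frame with invented values follows. In the real `results_df`, the grouping keys would also include the population and search-time columns.

```python
import pandas as pd

# Toy data (invented for illustration): seed 0 has three rows, seed 1 has one.
df = pd.DataFrame(
    {
        "full_env_id": ["EnvA"] * 4,
        "rl_policy_seed": [0, 0, 0, 1],
        "Return": [1.0, 1.0, 1.0, 0.0],
    }
)

# Pooled mean: seeds with more rows carry more weight.
pooled = df.groupby("full_env_id")["Return"].mean()

# Two-stage mean: average within each seed, then across seeds (equal weights).
per_seed = df.groupby(["full_env_id", "rl_policy_seed"])["Return"].mean()
seed_avg = per_seed.groupby(level="full_env_id").mean()

print(pooled["EnvA"], seed_avg["EnvA"])  # 0.75 0.5
```

Seaborn's estimators pool rows, so the plots below correspond to the first form.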

In [None]:
import sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from posggym_baselines.config import REPO_DIR

sys.path.insert(0, str(REPO_DIR / "baseline_exps"))
import exp_utils

sns.set_theme(style="white")
sns.set_context("paper", font_scale=1.5)
sns.set_palette("colorblind")

SAVE_RESULTS = False

In [None]:
ALL_ENV_DATA = exp_utils.load_all_env_data()
for k in ALL_ENV_DATA:
    print(k)

NUM_ENVS = len(ALL_ENV_DATA)

# figure parameters
FIGSIZE = (10, 10)
N_COLS = min(3, NUM_ENVS)
N_ROWS = (NUM_ENVS // N_COLS) + int(NUM_ENVS % N_COLS > 0)

## Load Planning Results

In [None]:
results = []
for env_id, env_data in ALL_ENV_DATA.items():
    env_planning_results = pd.read_csv(env_data.combined_results_file)
    env_planning_results["full_env_id"] = env_id
    results.append(env_planning_results)

results_df = pd.concat(results, ignore_index=True)
results_df.rename(
    columns={
        "planning_pop_id": "Planning Population",
        "test_pop_id": "Test Population",
        "return": "Return",
    },
    inplace=True,
)

# Add In/Out of Distribution labels
results_df["In Distribution"] = (
    results_df["Planning Population"] == results_df["Test Population"]
)

results_df.sort_values(
    by=[
        "alg", 
        "full_env_id", 
        "Planning Population", 
        "Test Population", 
        "search_time_limit"
    ], 
    inplace=True
)

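# keep only the search budgets of interest, then take the largest remaining one
# (computing the max before filtering could select a budget that is filtered out)
results_df = results_df[
    results_df["search_time_limit"].isin([0.05, 0.1, 0.5, 1, 5, 10, 20])
]
max_search_time = results_df["search_time_limit"].max()
print(results_df["search_time_limit"].unique())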
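`max_search_time` is a single global value. If an environment was only evaluated at smaller budgets, it would silently drop out of the "maximum search time" plots below. Below is a sketch of selecting each environment's own largest budget instead, using `groupby(...).transform("max")`. The environment names and values are toy placeholders:

```python
import pandas as pd

# Toy frame: EnvB was (hypothetically) only run up to 5 s of search time.
df = pd.DataFrame(
    {
        "full_env_id": ["EnvA", "EnvA", "EnvB", "EnvB"],
        "search_time_limit": [1, 20, 1, 5],
        "Return": [0.1, 0.8, 0.3, 0.6],
    }
)

# Broadcast each environment's maximum budget back onto its own rows.
env_max = df.groupby("full_env_id")["search_time_limit"].transform("max")
at_env_max = df[df["search_time_limit"] == env_max]

# The global maximum keeps only environments that reached that budget.
global_max = df["search_time_limit"].max()
at_global_max = df[df["search_time_limit"] == global_max]

print(sorted(at_env_max["full_env_id"]))     # ['EnvA', 'EnvB']
print(sorted(at_global_max["full_env_id"]))  # ['EnvA']
```

If every environment shares the same set of budgets, the two selections coincide.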

## Performance against Planning and Test Populations

Here we look at the performance of the RL+Planning algorithm for each planning and test population combination. We look only at the performance of the algorithm given the maximum search time.

Dimensions:

- Environment
- Planning Population
- Test Population
- Search Time

In the first plot we look at the distribution of the returns for each environment (column) and for each RL policy seed (row).

In [None]:
plot = sns.catplot(
    data=results_df[
        results_df["search_time_limit"] == max_search_time
    ],
    x="Planning Population",
    y="Return",
    hue="Test Population",
    col="full_env_id",
    # col_wrap=N_COLS,
    row="rl_policy_seed",
    kind="violin",
    sharey=True,
    height=3,
    aspect=1,
)

del plot

Here we look at the same plot, but with the returns from all RL policy seeds pooled together.

In [None]:
plot = sns.catplot(
    data=results_df[
        results_df["search_time_limit"] == max_search_time
    ],
    x="Planning Population",
    y="Return",
    hue="Test Population",
    col="full_env_id",
    col_wrap=N_COLS,
    kind="violin",
    sharey=True,
    height=3,
    aspect=1,
)

del plot

## In vs Out of Distribution Performance by Environment

Here we look at the in-distribution vs out-of-distribution performance for each environment, given the maximum search time. We look at both the mean returns (bar plot) and the distribution of returns (violin plot).

- `x-axis`: Environment
- `y-axis`: Episode return (mean for the bar plot, distribution for the violin plot)
- `hue/z-axis`: In (True) vs Out (False) of Distribution
- `figures`: Plot kind (bar, violin)
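The bar heights are group means, so the in-vs-out-of-distribution gap can also be read off numerically. Below is a minimal sketch that pivots the `In Distribution` flag into columns. The data is a toy frame with invented values, and the column names `in_dist`/`out_dist` are illustrative:

```python
import pandas as pd

# Toy values, invented for illustration.
df = pd.DataFrame(
    {
        "full_env_id": ["EnvA", "EnvA", "EnvA", "EnvB", "EnvB", "EnvB"],
        "In Distribution": [True, False, False, True, False, False],
        "Return": [1.0, 0.4, 0.6, 0.5, 0.5, 0.2],
    }
)

# Mean return per environment, with the boolean flag spread into columns.
means = (
    df.groupby(["full_env_id", "In Distribution"])["Return"]
    .mean()
    .unstack()
    .rename(columns={True: "in_dist", False: "out_dist"})
)

# A positive gap means planning with the matching population helped.
gap = means["in_dist"] - means["out_dist"]
print(gap.round(2).to_dict())  # {'EnvA': 0.5, 'EnvB': 0.15}
```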

In [None]:
for kind in ("bar", "violin"):
    plot = sns.catplot(
        data=results_df[
            results_df["search_time_limit"] == max_search_time
        ],
        x="full_env_id",
        y="Return",
        hue="In Distribution",
        hue_order=[True, False],
        # col="",
        # col_wrap=N_COLS,
        # row="alg",
        kind=kind,
        sharey=False,
        height=4,
        aspect=1,
    )
    plot.set_xticklabels(rotation=90)

    if SAVE_RESULTS:
        plot.figure.savefig(
            exp_utils.ENV_DATA_DIR / f"combined_id_vs_ood_by_env_results_{kind}.png", 
            bbox_inches="tight"
        )

    del plot

## In vs Out of Distribution Performance vs Search Time

Here we look at the in vs out of distribution performance across search budgets.

- `x-axis`: Search time
- `y-axis`: Mean episode return
- `hue/z-axis`: In (True) vs Out (False) of Distribution
- `col/figures`: Environment
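When every population is used both for planning and for testing, n populations give n in-distribution pairings but n(n-1) out-of-distribution ones. As a result, the two lines are estimated from different amounts of data. Below is a sketch that reports the row count, mean, and standard error per (search time, `In Distribution`) group. The standard error here is a simpler analytic summary, not the bootstrapped confidence band that seaborn draws by default. The values are a toy frame invented for illustration:

```python
import pandas as pd

# Toy values, invented for illustration.
df = pd.DataFrame(
    {
        "search_time_limit": [0.1, 0.1, 0.1, 1.0, 1.0, 1.0],
        "In Distribution": [True, False, False, True, False, False],
        "Return": [0.2, 0.0, 0.2, 0.8, 0.4, 0.6],
    }
)

summary = (
    df.groupby(["search_time_limit", "In Distribution"])["Return"]
    .agg(["count", "mean", "sem"])  # sem is NaN for single-row groups
    .reset_index()
)
print(summary)
```

The budgets span 0.05 s to 20 s, so a logarithmic x-axis (`plot.set(xscale="log")` on the returned grid) may make the low-budget end easier to read.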

In [None]:
plot = sns.relplot(
    data=results_df,
    x="search_time_limit",
    y="Return",
    hue="In Distribution",
    hue_order=[True, False],
    col="full_env_id",
    col_wrap=N_COLS,
    kind="line",
    height=3,
    aspect=1.5,
)

if SAVE_RESULTS:
    plot.figure.savefig(
        exp_utils.ENV_DATA_DIR / "combined_id_vs_ood_by_search_time_per_alg.png",
        bbox_inches="tight"
    )

del plot

Here we show the same but averaged across environments.
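Pooling raw returns across environments mixes different reward scales, so environments with larger returns dominate the average. One option is to min-max normalise returns within each environment before averaging. This is an assumption on our part, not what the cell below does. A sketch on a toy frame:

```python
import pandas as pd

# Toy values: EnvB's returns are on a much larger scale than EnvA's.
df = pd.DataFrame(
    {
        "full_env_id": ["EnvA", "EnvA", "EnvB", "EnvB"],
        "Return": [0.0, 1.0, -50.0, 50.0],
    }
)

grouped = df.groupby("full_env_id")["Return"]
lo = grouped.transform("min")
hi = grouped.transform("max")

# Map each environment's returns onto [0, 1]; avoid dividing by zero
# for environments whose returns are all identical.
span = (hi - lo).where(hi > lo, 1.0)
df["Normalised Return"] = (df["Return"] - lo) / span
print(df["Normalised Return"].tolist())  # [0.0, 1.0, 0.0, 1.0]
```

The normalised column could then be passed as `y` to the same `sns.relplot` call. Min-max scaling is sensitive to outlier episodes, so a fixed per-environment reference return is an alternative.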

In [None]:
plot = sns.relplot(
    data=results_df,
    x="search_time_limit",
    y="Return",
    hue="In Distribution",
    hue_order=[True, False],
    # col="full_env_id",
    # col_wrap=N_COLS,
    kind="line",
    height=4,
    aspect=1,
)

del plot

## Search Statistics

Here we look at various statistics of the search process.

Each figure shows a different statistic, and each line is a different environment. We expect some differences between environments due to factors such as the average number of steps to a terminal state. In- and out-of-distribution results are grouped together, since we expect (and see) no difference in search statistics between the two.

- `x-axis`: Search time
- `y-axis`: Search statistic value
- `hue/z-axis`: Environment
- `col`: Planning Population
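Instead of one figure per statistic, the statistics could be stacked into long format and drawn as a single grid with one facet per statistic. Below is a sketch of the reshaping step with `DataFrame.melt`. It uses a toy frame with two of the statistic columns used below, and the values are invented:

```python
import pandas as pd

# Toy values for two of the search statistics, invented for illustration.
df = pd.DataFrame(
    {
        "full_env_id": ["EnvA", "EnvA"],
        "search_time_limit": [0.1, 1.0],
        "num_sims": [10, 100],
        "search_depth": [3, 6],
    }
)

# One row per (environment, search time, statistic).
long_df = df.melt(
    id_vars=["full_env_id", "search_time_limit"],
    value_vars=["num_sims", "search_depth"],
    var_name="Statistic",
    value_name="Value",
)
print(long_df.shape)  # (4, 4)
```

`long_df` could then be plotted with `sns.relplot(data=long_df, x="search_time_limit", y="Value", hue="full_env_id", col="Statistic", col_wrap=N_COLS, kind="line", facet_kws={"sharey": False})`. The independent y-axes matter here because the statistics have very different units.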

In [None]:
for stat_key in [
    "update_time",
    "reinvigoration_time",
    "evaluation_time",
    "policy_calls",
    "inference_time",
    "search_depth",
    "num_sims",
    "mem_usage",
    "min_value",
    "max_value",
]:

    plot = sns.relplot(
        data=results_df,
        x="search_time_limit",
        y=stat_key,
        hue="full_env_id",
        # col_wrap=N_COLS,
        kind="line",
        col="Planning Population",
        height=3,
        aspect=1.5,
    )

    del plot

## Belief Statistics

Here we look at the accuracy of the agent's belief (state, history, action, and policy accuracy) in each environment, both in and out of distribution.

In [None]:
belief_results_df = []
for full_env_id in ALL_ENV_DATA:
    env_belief_results = pd.read_csv(
        ALL_ENV_DATA[full_env_id].env_data_dir / "combined_belief_results.csv"
    )
    env_belief_results["full_env_id"] = full_env_id
    belief_results_df.append(env_belief_results)

belief_results_df = pd.concat(belief_results_df, ignore_index=True)
belief_results_df.rename(
    columns={
        "planning_pop_id": "Planning Population",
        "test_pop_id": "Test Population",
        "return": "Return",
        "belief_state_acc": "Belief State Accuracy",
        "belief_history_acc": "Belief History Accuracy",
        "belief_action_acc": "Belief Action Accuracy",
        "belief_policy_acc": "Belief Policy Accuracy",
    },
    inplace=True,
)

# Add In/Out of Distribution labels
belief_results_df["In Distribution"] = (
    belief_results_df["Planning Population"] == belief_results_df["Test Population"]
)

belief_results_df.sort_values(
    by=[
        "alg", 
        "full_env_id", 
        "Planning Population", 
        "Test Population", 
        "search_time_limit"
    ], 
    inplace=True
)

print(belief_results_df["search_time_limit"].unique())
print(belief_results_df["full_env_id"].unique())

# for c in belief_results_df.columns:
#     print(c)

In [None]:
for belief_stat_key in [
    "Belief State Accuracy",
    "Belief History Accuracy",
    "Belief Action Accuracy",
    "Belief Policy Accuracy",
]:
    g = sns.relplot(
        data=belief_results_df,
        x="search_time_limit",
        y=belief_stat_key,
        hue="In Distribution",
        hue_order=[True, False],
        col_wrap=N_COLS,
        kind="line",
        height=3,
        aspect=1,
        # col="Planning Population",
        col="full_env_id",
        facet_kws={
            "sharey": "row",
            "sharex": True, 
        },
    )

    # for i, (col_key, ax) in enumerate(g.axes_dict.items()):
    #     ax.set_title(col_key)
    #     if i % 3 == 0:
    #         ax.set_ylabel("Mean Return")
    #     if i >= 3:
    #         ax.set_xlabel("Search Time (s)")

    # for (row_key, col_key), ax in g.axes_dict.items():
    #     ax.set_title(f"{row_key} | {col_key}")

    for col_key, ax in g.axes_dict.items():
        ax.set_title(f"{col_key}")
        ax.set_xlabel("Search Time Limit (s)")

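    if SAVE_RESULTS:
        print(f"saving {belief_stat_key} figure")
        fig_dir = exp_utils.ENV_DATA_DIR / "figures"
        fig_dir.mkdir(parents=True, exist_ok=True)  # create the figures dir if missing
        g.savefig(fig_dir / f"{belief_stat_key}.pdf", bbox_inches="tight")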

    del g

### Per Episode Step

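Here we look at the mean belief statistics at each episode step.

- `x-axis`: Episode Step
- `y-axis`: Belief Statistic Values
- `col/figures`: Environment
- `hue/z-axis`: Search Time
- `style`: In (True) vs Out (False) of Distribution

For Belief Policy Accuracy only in-distribution results are plotted, so the `style` dimension does not apply to that figure.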
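The loading cell below reshapes per-step columns such as `belief_state_acc_0` and `belief_state_acc_1` into one row per step. Before reshaping, it drops the episode-mean columns whose names end in `acc`. Dropping them is required: `pd.wide_to_long` raises a `ValueError` if a stub name is identical to an existing column name. Below is a toy version of that reshaping. It assumes, as the column filter in the cell implies, that the step index is the integer after the final underscore. The `alg` value and accuracies are placeholders.

```python
import pandas as pd

# One row per (toy) experiment: an episode-mean column plus two per-step columns.
wide = pd.DataFrame(
    {
        "alg": ["x", "x"],
        "search_time_limit": [0.1, 1.0],
        "belief_state_acc": [0.5, 0.7],
        "belief_state_acc_0": [0.4, 0.6],
        "belief_state_acc_1": [0.6, 0.8],
    }
)

# Drop the mean column so the stub name no longer clashes with a column name.
per_step = wide.drop(columns=["belief_state_acc"])

# The i columns must uniquely identify each row of the wide frame.
per_step_long = pd.wide_to_long(
    per_step,
    stubnames=["belief_state_acc"],
    i=["alg", "search_time_limit"],
    j="step",
    sep="_",
).reset_index()
print(per_step_long.sort_values(["search_time_limit", "step"]).to_string(index=False))
```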

In [None]:
ps_belief_results_df = []
for full_env_id in ALL_ENV_DATA:
    print(full_env_id)
    env_belief_results = pd.read_csv(
        ALL_ENV_DATA[full_env_id].env_data_dir / "combined_belief_per_step_results.csv"
    )
    env_belief_results["full_env_id"] = full_env_id

    env_step_limit = max(
        int(c.split("_")[-1]) 
        for c in env_belief_results.columns 
        if c.startswith("belief_state_acc_")
    )
    env_belief_results["step_limit"] = env_step_limit

    # drop unused columns
    # keep only id columns and per step belief stats (remove mean belief stats)
    id_cols = [
        "alg", 
        "full_env_id", 
        "planning_pop_id", 
        "test_pop_id", 
        "rl_policy_seed",
        "search_time_limit", 
        "step_limit",
        "num"
    ]
    cols_to_keep = [*id_cols] + [
        c for c in env_belief_results.columns 
        if c.startswith("belief_") and c.split("_")[-1] != "acc"
    ]
    env_belief_results.drop(
        columns=[c for c in env_belief_results.columns if c not in cols_to_keep], 
        inplace=True
    )

    # convert from wide to long format for per step belief stats
    stub_names = [
        k for k in [
            "belief_state_acc", 
            "belief_history_acc", 
            "belief_action_acc", 
            "belief_policy_acc"
        ]
        if any(c.startswith(k) for c in env_belief_results.columns)
    ]

    env_belief_results = pd.wide_to_long(
        env_belief_results, 
        stubnames=stub_names, 
        i=id_cols, 
        j="step", 
        sep="_"
    ).reset_index()
    ps_belief_results_df.append(env_belief_results)

ps_belief_results_df = pd.concat(ps_belief_results_df, ignore_index=True)
ps_belief_results_df.rename(
    columns={
        "planning_pop_id": "Planning Population",
        "test_pop_id": "Test Population",
        "belief_state_acc": "Belief State Accuracy",
        "belief_history_acc": "Belief History Accuracy",
        "belief_action_acc": "Belief Action Accuracy",
        "belief_policy_acc": "Belief Policy Accuracy",
        "search_time_limit": "Search Time (s)",
    },
    inplace=True,
)

# Add In/Out of Distribution labels
ps_belief_results_df["In Distribution"] = (
    ps_belief_results_df["Planning Population"]
    == ps_belief_results_df["Test Population"]
)

ps_belief_results_df.sort_values(
    by=[
        "alg", 
        "full_env_id", 
        "Planning Population", 
        "Test Population", 
        "Search Time (s)"
    ], 
    inplace=True
)

print(ps_belief_results_df["Search Time (s)"].unique())
print(ps_belief_results_df["full_env_id"].unique())
print(ps_belief_results_df["step"].min(), ps_belief_results_df["step"].max())

# for c in belief_results_df.columns:
#     print(c)

In [None]:
for belief_stat_key in [
    "Belief State Accuracy",
    # "Belief History Accuracy",
    "Belief Action Accuracy",
    "Belief Policy Accuracy",
]:
    print(belief_stat_key)

    if belief_stat_key == "Belief Policy Accuracy":
        g = sns.relplot(
            data=ps_belief_results_df[ps_belief_results_df["In Distribution"]],
            x="step",
            y=belief_stat_key,
            hue="Search Time (s)",
            col_wrap=N_COLS,
            kind="line",
            col="full_env_id",
            height=3,
            aspect=1.5,
            facet_kws={
                "sharey": False,
                "sharex": False, 
            },
        )
    else:
        g = sns.relplot(
            data=ps_belief_results_df,
            x="step",
            y=belief_stat_key,
            hue="Search Time (s)",
            style="In Distribution",
            style_order=[True, False],
            col_wrap=N_COLS,
            kind="line",
            col="full_env_id",
            height=3,
            aspect=1.5,
            facet_kws={
                "sharey": False,
                "sharex": False, 
            },
        )

    # title each facet with its environment; label the x-axis of the bottom facets
    # (this also works when all environments fit on a single row)
    g.set_titles("{col_name}")
    g.set_xlabels("Episode Step")

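    if SAVE_RESULTS:
        print(f"saving {belief_stat_key} figure")
        fig_dir = exp_utils.ENV_DATA_DIR / "figures"
        fig_dir.mkdir(parents=True, exist_ok=True)  # create the figures dir if missing
        g.savefig(fig_dir / f"per_step_{belief_stat_key}.pdf", bbox_inches="tight")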

    del g