# Analysis of Combined Results

Here we explore the performance of the combined planning+RL algorithm across the different environments.

For each algorithm, we look at its mean performance (i.e. episode returns) across the different environments, in both in-distribution (planning population matches the test population) and out-of-distribution (planning population does not match the test population) settings.
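A minimal sketch of how the in/out-of-distribution split can be computed with pandas. It assumes a results table with the `Planning Population`, `Test Population` and `Return` columns used in this notebook; the population ids and returns are made up for illustration:

```python
import pandas as pd

# Hypothetical toy results: column names mirror the notebook, values are invented.
df = pd.DataFrame(
    {
        "Planning Population": ["P0", "P0", "P1", "P1"],
        "Test Population": ["P0", "P1", "P0", "P1"],
        "Return": [1.0, 0.2, 0.4, 0.9],
    }
)

# A row is in-distribution exactly when the planning population is the test population.
df["In Distribution"] = df["Planning Population"] == df["Test Population"]

# Mean return per setting (False = out-of-distribution, True = in-distribution).
mean_by_setting = df.groupby("In Distribution")["Return"].mean()
print(mean_by_setting)
```

The vectorized column comparison gives the same labels as a row-wise `apply`, without calling a Python function per row.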

**Note:** each experiment was repeated 5 times (once for each RL policy seed), so we average the results across these 5 runs.
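One way to do this averaging, sketched on hypothetical data: first take the mean return of each seed's run, then average the 5 run-level means. The column names follow the notebook; the returns are invented:

```python
import pandas as pd

# Hypothetical per-episode returns for one planning/test population pair:
# 5 RL policy seeds x 2 episodes each (values made up for illustration).
episodes = pd.DataFrame(
    {
        "rl_policy_seed": [0, 0, 1, 1, 2, 2, 3, 3, 4, 4],
        "Return": [1.0, 0.0, 0.5, 0.5, 1.0, 1.0, 0.0, 0.0, 0.5, 1.0],
    }
)

# Step 1: mean return of each seed's run.
per_seed_mean = episodes.groupby("rl_policy_seed")["Return"].mean()

# Step 2: average the 5 run-level means; the std shows seed-to-seed spread.
mean_over_seeds = per_seed_mean.mean()
std_over_seeds = per_seed_mean.std()
print(per_seed_mean.tolist(), mean_over_seeds, std_over_seeds)
```

When every seed contributes the same number of episodes, as here, the mean of the seed means equals the pooled mean over all episodes; the seed-level standard deviation is a convenient summary of how much the results vary between RL policy seeds.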

In [None]:
import sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from posggym_baselines.config import REPO_DIR

sys.path.insert(0, str(REPO_DIR / "baseline_exps"))
import exp_utils

sns.set_theme()
sns.set_context("paper", font_scale=1.5)
sns.set_palette("colorblind")

SAVE_RESULTS = False

In [None]:
ALL_ENV_DATA = exp_utils.load_all_env_data()
for k in ALL_ENV_DATA:
    print(k)

NUM_ENVS = len(ALL_ENV_DATA)

# figure parameters
FIGSIZE = (10, 10)
N_COLS = min(3, NUM_ENVS)
N_ROWS = (NUM_ENVS // N_COLS) + int(NUM_ENVS % N_COLS > 0)

## Load Planning Results

In [None]:
results = []
for env_id, env_data in ALL_ENV_DATA.items():
    env_planning_results = pd.read_csv(env_data.combined_results_file)
    env_planning_results["full_env_id"] = env_id
    results.append(env_planning_results)

results_df = pd.concat(results, ignore_index=True)
results_df.rename(
    columns={
        "planning_pop_id": "Planning Population",
        "test_pop_id": "Test Population",
        "return": "Return",
    },
    inplace=True,
)

# Add In/Out of Distribution labels
def get_in_out_dist_label(row):
    return row["Planning Population"] == row["Test Population"]

results_df["In Distribution"] = results_df.apply(
    get_in_out_dist_label, axis=1
)

results_df.sort_values(
    by=[
        "alg", 
        "full_env_id", 
        "Planning Population", 
        "Test Population", 
        "search_time_limit"
    ], 
    inplace=True
)

# Can remove this once we have all the data
# results_df = results_df[results_df["search_time_limit"] > 0.05]
# results_df = results_df[
#     (results_df["alg"] == "POTMMCP") &
#     (results_df["full_env_id"] != "LevelBasedForaging-v3") &
#     (results_df["full_env_id"] != "PursuitEvasion-v1_i0")
# ]
# Restrict to the search time budgets used in this analysis
results_df = results_df[results_df["search_time_limit"].isin([0.05, 0.1, 0.5, 1, 5, 10, 20])]
print(results_df["search_time_limit"].unique())

# Take the max *after* filtering so that max_search_time is guaranteed to be present in results_df
max_search_time = results_df["search_time_limit"].max()

## RL+Planning Algorithm Performance against Planning and Test Populations

Here we look at the performance of the RL+Planning algorithm for each planning and test population combination. We look only at the performance of the algorithm given the maximum search time.

Dimensions:

- Environment
- Planning Population
- Test Population
- Search Time
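Selecting the maximum-search-time slice only works if the maximum is taken from the same data that is plotted; otherwise a budget dropped by the search-time filter can leave the slice empty. A minimal sketch with made-up budgets, where `kept_times` is a hypothetical stand-in for that filter:

```python
import pandas as pd

# Hypothetical search budgets (seconds); 30.0 stands in for a budget that
# the search-time filter drops.
df = pd.DataFrame(
    {
        "search_time_limit": [0.1, 1.0, 20.0, 30.0],
        "Return": [0.2, 0.5, 0.8, 0.9],
    }
)
kept_times = [0.05, 0.1, 0.5, 1, 5, 10, 20]

# Filter first, then take the max, so the max exists in the filtered data.
filtered = df[df["search_time_limit"].isin(kept_times)]
max_search_time = filtered["search_time_limit"].max()
at_max = filtered[filtered["search_time_limit"] == max_search_time]
print(max_search_time, len(at_max))
```

Taking the max from the unfiltered `df` instead would pick 30.0 and select no rows.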

In the first plot we look at the distribution of the returns for each environment (column) and for each RL policy seed (row).

In [None]:
plot = sns.catplot(
    data=results_df[
        results_df["search_time_limit"] == max_search_time
    ],
    x="Planning Population",
    y="Return",
    hue="Test Population",
    col="full_env_id",
    # col_wrap=N_COLS,
    row="rl_policy_seed",
    kind="violin",
    sharey=True,
)

del plot

Here we look at the same plot, but with the returns pooled across the different RL policy seeds.

In [None]:
plot = sns.catplot(
    data=results_df[
        results_df["search_time_limit"] == max_search_time
    ],
    x="Planning Population",
    y="Return",
    hue="Test Population",
    col="full_env_id",
    # col_wrap=N_COLS,
    kind="violin",
    sharey=True,
)

del plot

## In vs Out of Distribution Performance by Environment

Here we compare in-distribution vs out-of-distribution performance for each environment given the maximum search time. We look at both the mean returns (bar plot) and the distribution of returns (violin plot).

- `x-axis`: Environment
- `y-axis`: Episode return (mean for the bar plot, distribution for the violin plot)
- `hue/z-axis`: In (True) vs Out (False) of Distribution
- `col/figures`: Algorithm
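The bar heights can also be read off as a table. A sketch on hypothetical data (`EnvA`/`EnvB` are placeholder environment ids, returns are invented) that pivots the mean return by environment and setting and adds the in- minus out-of-distribution gap:

```python
import pandas as pd

# Hypothetical results for two placeholder environments at the max search time.
df = pd.DataFrame(
    {
        "full_env_id": ["EnvA", "EnvA", "EnvA", "EnvB", "EnvB", "EnvB"],
        "In Distribution": [True, False, False, True, True, False],
        "Return": [1.0, 0.4, 0.6, 0.8, 0.6, 0.1],
    }
)

# One row per environment, one column per setting, cells hold the mean return.
table = df.pivot_table(
    index="full_env_id", columns="In Distribution", values="Return", aggfunc="mean"
)
table = table.rename(columns={True: "ID", False: "OOD"})

# Gap between in- and out-of-distribution mean return for each environment.
table["ID - OOD"] = table["ID"] - table["OOD"]
print(table)
```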

In [None]:
for kind in ("bar", "violin"):
    plot = sns.catplot(
        data=results_df[
            results_df["search_time_limit"] == max_search_time
        ],
        x="full_env_id",
        y="Return",
        hue="In Distribution",
        col="alg",
        # col_wrap=N_COLS,
        # row="alg",
        kind=kind,
        sharey=False,
    )
    plot.set_xticklabels(rotation=90)

    if SAVE_RESULTS:
        plot.figure.savefig(
            exp_utils.ENV_DATA_DIR / f"combined_id_vs_ood_by_env_results_{kind}.png", 
            bbox_inches="tight"
        )

    del plot

## In vs Out of Distribution Performance vs Search Time

Here we look at the in vs out of distribution performance across search budgets.

- `x-axis`: Search time
- `y-axis`: Mean episode return
- `hue/z-axis`: In (True) vs Out (False) of Distribution
- `col/figures`: Environment
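Each line in these plots is the mean return at each search time, with a seaborn confidence band around it (bootstrapped by default). As a simpler analytic alternative, and not what seaborn computes, the sketch below tabulates the mean and standard error of the return per search time and setting on made-up data:

```python
import numpy as np
import pandas as pd

# Hypothetical returns at two search budgets (seconds), split by ID/OOD.
df = pd.DataFrame(
    {
        "search_time_limit": [0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 1.0, 1.0],
        "In Distribution": [True, True, False, False, True, True, False, False],
        "Return": [0.4, 0.6, 0.1, 0.3, 0.8, 1.0, 0.5, 0.7],
    }
)

# Mean, sample std and count of returns for each (search time, setting) cell.
summary = (
    df.groupby(["search_time_limit", "In Distribution"])["Return"]
    .agg(["mean", "std", "count"])
    .reset_index()
)
# Standard error of the mean: sample std divided by sqrt(n).
summary["sem"] = summary["std"] / np.sqrt(summary["count"])
print(summary)
```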

In [None]:
plot = sns.relplot(
    data=results_df,
    x="search_time_limit",
    y="Return",
    hue="In Distribution",
    col="full_env_id",
    col_wrap=N_COLS,
    kind="line",
)

if SAVE_RESULTS:
    plot.figure.savefig(
        exp_utils.ENV_DATA_DIR / "combined_id_vs_ood_by_search_time_per_alg.png",
        bbox_inches="tight"
    )

del plot

Here we show the same plot, but averaged across all environments (the returns from every environment are pooled together).

In [None]:
plot = sns.relplot(
    data=results_df,
    x="search_time_limit",
    y="Return",
    hue="In Distribution",
    # col="full_env_id",
    # col_wrap=N_COLS,
    kind="line",
)

del plot
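Pooling all rows weights each environment by how many episodes it contributes. If environments contribute different numbers of episodes, averaging within each environment first gives every environment equal weight. A sketch on hypothetical data with placeholder environment ids:

```python
import pandas as pd

# Hypothetical returns: EnvA contributes three episodes, EnvB only one.
df = pd.DataFrame(
    {
        "full_env_id": ["EnvA", "EnvA", "EnvA", "EnvB"],
        "search_time_limit": [1.0, 1.0, 1.0, 1.0],
        "In Distribution": [True, True, True, True],
        "Return": [1.0, 1.0, 1.0, 0.0],
    }
)
keys = ["search_time_limit", "In Distribution"]

# Pooled mean over all episodes (every row weighted equally).
pooled = df.groupby(keys)["Return"].mean()

# Equal-weight mean: average within each environment first, then across environments.
per_env = df.groupby(keys + ["full_env_id"])["Return"].mean()
equal_weight = per_env.groupby(level=keys).mean()
print(pooled.iloc[0], equal_weight.iloc[0])
```

With equal episode counts per environment the two estimates coincide; here EnvA's extra episodes pull the pooled mean towards its returns.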

## Search Statistics

Here we look at various statistics of the search process for each algorithm.

Each figure shows a different statistic and each line is a different environment, since we expect some differences between environments based on things like the average number of steps to reach a terminal state. In- and out-of-distribution results are grouped together, since we expect (and observe) no difference in search statistics between the two settings.

- `x-axis`: Search time
- `y-axis`: Search statistic value
- `hue/lines`: Environment
- `col/figures`: Planning Population
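The values behind these line plots can also be tabulated directly. A sketch on invented data for two placeholder environments, computing the mean of each statistic per environment and search time; the `sims_per_sec` column is a hypothetical derived metric, not one of the logged statistics:

```python
import pandas as pd

# Hypothetical search statistics for two placeholder environments (values invented).
df = pd.DataFrame(
    {
        "full_env_id": ["EnvA", "EnvA", "EnvA", "EnvA", "EnvB", "EnvB"],
        "search_time_limit": [0.1, 0.1, 1.0, 1.0, 0.1, 1.0],
        "num_sims": [100, 120, 1000, 1100, 50, 500],
        "search_depth": [4, 6, 8, 10, 3, 7],
    }
)
stat_keys = ["num_sims", "search_depth"]

# Mean of each statistic per (environment, search time) cell, i.e. the points
# the line plots draw before adding error bands.
stats_table = df.groupby(["full_env_id", "search_time_limit"])[stat_keys].mean()

# Hypothetical derived column: simulations per second of search budget.
budgets = stats_table.index.get_level_values("search_time_limit").to_numpy()
stats_table["sims_per_sec"] = stats_table["num_sims"] / budgets
print(stats_table)
```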

In [None]:
for stat_key in [
    "update_time",
    "reinvigoration_time",
    "evaluation_time",
    "policy_calls",
    "inference_time",
    "search_depth",
    "num_sims",
    "mem_usage",
    "min_value",
    "max_value",
]:

    plot = sns.relplot(
        data=results_df,
        x="search_time_limit",
        y=stat_key,
        hue="full_env_id",
        # col_wrap=N_COLS,
        kind="line",
        col="Planning Population",
    )

    del plot