# Analysis of Planning Results

This notebook compares the performance of the planning algorithms across the different environments.

For each algorithm, we look at mean performance (i.e. mean episode return) across the different environments, in both the in-distribution setting (the planning population matches the test population) and the out-of-distribution setting (the planning population does not match the test population).

**Note:** Two algorithms, I-NTMCP and POMCP, have no planning population. I-NTMCP models the other agent as an I-POMDP, and POMCP models it as uniform random. We therefore only consider the out-of-distribution setting for these algorithms, looking at their performance against the various populations (`P0` and `P1`).
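
The in-distribution / out-of-distribution split described above reduces to comparing two columns. The following is a minimal sketch on a made-up table, not part of the experiment code: the helper `label_setting`, the algorithm name `some_planner`, and the assumption that algorithms without a planning population store a missing value in `planning_pop_id` are all illustrative.

```python
import pandas as pd


def label_setting(df: pd.DataFrame) -> pd.Series:
    """Label each row as in- or out-of-distribution.

    Assumption: algorithms with no planning population have a missing
    value in `planning_pop_id`; a missing value never equals a test
    population, so those rows are labelled out-of-distribution.
    """
    matches = df["planning_pop_id"].eq(df["test_pop_id"])
    return matches.map({True: "in-distribution", False: "out-of-distribution"})


# Made-up rows for illustration only (not experiment results).
toy_settings = pd.DataFrame(
    {
        "alg": ["some_planner", "some_planner", "POMCP", "POMCP"],
        "planning_pop_id": ["P0", "P0", None, None],
        "test_pop_id": ["P0", "P1", "P0", "P1"],
    }
)
toy_settings["setting"] = label_setting(toy_settings)
print(toy_settings)
```

Only the first row is labelled in-distribution; the POMCP rows fall into the out-of-distribution setting regardless of the test population.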

In [None]:
import os
import os.path as osp
import yaml
import sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from posggym_baselines.config import REPO_DIR

sys.path.insert(0, osp.join(REPO_DIR, "baseline_exps"))
import exp_utils

sns.set_theme()
sns.set_context("paper", font_scale=1.5)
sns.set_palette("colorblind")

SAVE_RESULTS = False

In [None]:
ALL_ENV_DATA = exp_utils.load_all_env_data()
for k in ALL_ENV_DATA:
    print(k)

NUM_ENVS = len(ALL_ENV_DATA)

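# Figure parameters: lay the environments out on a grid of at most 3 columns.
FIGSIZE = (10, 10)
N_COLS = min(3, NUM_ENVS)
# Ceiling division: add one extra row when the last row is only partly filled.
N_ROWS = (NUM_ENVS // N_COLS) + int(NUM_ENVS % N_COLS > 0)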

## Planning Algorithm Performance against Planning and Test Populations

The results vary along the following dimensions:

- Environment
- Algorithm
- Planning Population
- Test Population
- Search Time
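
One way to reduce these dimensions to the mean performance discussed in the introduction is a grouped aggregation. The sketch below is illustrative only: it assumes the column names used in this notebook after renaming (`full_env_id`, `alg`, `Planning Population`, `Test Population`, `search_time_limit`, `Return`), the helper `summarize_returns` is hypothetical, the rows are made up, and the 95% interval uses a normal approximation that may not suit small sample sizes.

```python
import numpy as np
import pandas as pd

# Column names follow the renamed results table used in this notebook.
GROUP_COLS = [
    "full_env_id",
    "alg",
    "Planning Population",
    "Test Population",
    "search_time_limit",
]


def summarize_returns(df: pd.DataFrame) -> pd.DataFrame:
    """Mean return, sample count and approximate 95% CI half-width per group."""
    summary = (
        # dropna=False keeps groups whose planning population is missing.
        df.groupby(GROUP_COLS, dropna=False)["Return"]
        .agg(mean="mean", std="std", n="count")
        .reset_index()
    )
    # Assumption: normal approximation (1.96 * standard error of the mean).
    summary["ci95"] = 1.96 * summary["std"] / np.sqrt(summary["n"])
    return summary


# Made-up rows for illustration only (not experiment results).
toy_returns = pd.DataFrame(
    {
        "full_env_id": ["EnvA"] * 4,
        "alg": ["some_planner"] * 4,
        "Planning Population": ["P0"] * 4,
        "Test Population": ["P0", "P0", "P1", "P1"],
        "search_time_limit": [1.0] * 4,
        "Return": [1.0, 3.0, 0.0, 2.0],
    }
)
summary = summarize_returns(toy_returns)
print(summary)
```

Passing `dropna=False` matters if algorithms without a planning population are stored with a missing value, since `groupby` would otherwise silently drop those rows.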

In [None]:
planning_results = []
for env_id, env_data in ALL_ENV_DATA.items():
    env_planning_results = pd.read_csv(env_data.planning_results_file)
    env_planning_results["full_env_id"] = env_id
    planning_results.append(env_planning_results)

planning_results_df = pd.concat(planning_results, ignore_index=True)
planning_results_df.rename(
    columns={
        "planning_pop_id": "Planning Population",
        "test_pop_id": "Test Population",
        "return": "Return",
    },
    inplace=True,
)
planning_results_df.sort_values(
    by=["alg", "full_env_id", "Planning Population", "Test Population"],
    inplace=True,
)

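# Keep only the runs that used the largest search time limit.
# Note: this is the global maximum across all environments.
max_search_time = planning_results_df["search_time_limit"].max()
planning_results_max_search_df = planning_results_df[
    planning_results_df["search_time_limit"] == max_search_time
]

# Sanity check: list the columns available for plotting.
for c in planning_results_max_search_df.columns:
    print(c)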


planning_max_search_plot = sns.catplot(
    data=planning_results_max_search_df,
    x="Planning Population",
    y="Return",
    hue="Test Population",
    col="full_env_id",
    # col_wrap=N_COLS,  # disabled: seaborn does not allow `col_wrap` together with `row`
    row="alg",
    kind="box",
    sharey=False,
)

if SAVE_RESULTS:
    planning_max_search_plot.figure.savefig(
        osp.join(
            exp_utils.ENV_DATA_DIR,
            "all_env_planning_max_search_results.png",
        ),
        bbox_inches="tight",
    )
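
The filter above keeps rows whose `search_time_limit` equals the single global maximum. If the environments were run with different search time limits (an assumption; the source does not say either way), environments with a smaller maximum would drop out of the plot entirely. A possible alternative, sketched here on made-up data with a hypothetical helper `filter_max_search_time_per_env`, is to take the maximum per environment:

```python
import pandas as pd


def filter_max_search_time_per_env(df: pd.DataFrame) -> pd.DataFrame:
    """Keep, for each environment, only the rows at that environment's
    largest search time limit."""
    per_env_max = df.groupby("full_env_id")["search_time_limit"].transform("max")
    return df[df["search_time_limit"] == per_env_max].copy()


# Made-up rows for illustration only (not experiment results).
toy_search = pd.DataFrame(
    {
        "full_env_id": ["EnvA", "EnvA", "EnvB", "EnvB"],
        "search_time_limit": [1.0, 10.0, 1.0, 5.0],
        "Return": [0.1, 0.2, 0.3, 0.4],
    }
)
filtered = filter_max_search_time_per_env(toy_search)
print(filtered)
```

With a global maximum of 10.0, `EnvB` would be filtered out completely; the per-environment version keeps its 5.0 rows. When every environment shares the same limits, both approaches give the same result.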