# Environment Experiment Analysis

This notebook contains the analysis of the experiments for a single environment.

In [None]:
import copy
from dataclasses import dataclass
from itertools import product
import os
import os.path as osp
from pprint import pprint
import sys
import tempfile
from typing import Dict, List, Tuple, Union

import pandas as pd
import matplotlib.pyplot as plt

import posggym_agents

import potmmcp.plot as plot_utils
import potmmcp.plot.paper as paper_utils
from potmmcp.config import BASE_REPO_DIR

# import experiment parameters for each env
sys.path.append(osp.join(BASE_REPO_DIR, "experiments"))
import common
import run_driving_exp as driving
import run_pe_evader_exp as pe_evader
import run_pe_pursuer_exp as pe_pursuer
import run_pp2_exp as pp2
import run_pp4_exp as pp4

# Directory layout: experiment scripts and their results live under the repo,
# while the fixed-policy results ship with the posggym_agents package.
BASE_EXP_DIR = osp.join(BASE_REPO_DIR, "experiments")
ENV_RESULTS_DIR = osp.join(BASE_EXP_DIR, "results")
POSGGYM_AGENTS_DIR = osp.join(posggym_agents.config.BASE_DIR, 'agents')

# Display names used in figure legends for the proposed algorithm and the
# planning baseline.
algname = "POTMMCP"
baselinealgname = "I-POMCP-PF"

# When False, figures are written to the system temp directory instead of the
# results directory (see EnvInfo.figure_dir).
SAVE_FIGURES = False
# When True, the averaged ("expected") result tables are written out as CSV.
SAVE_EXPECTED_RESULTS = True

# Select the environment to analyse: exactly one ENV_NAME line should be
# uncommented. The value must be a key of env_info_map.
# ENV_NAME = "driving"
# ENV_NAME = "pe_evader"
ENV_NAME = "pe_pursuer"
# ENV_NAME = "pp2"
# ENV_NAME = "pp4"

In [None]:
@dataclass
class EnvInfo:
    """Analysis metadata and result-file locations for one experiment environment.

    All ``*_results_file`` properties resolve to CSV files inside
    ``results_dir``; ``figure_dir`` is where figures are saved.
    """

    # full environment id (passed to the plotting utilities)
    id: str
    # human readable name of the environment
    label: str
    # short name; also the name of the environment's results sub-directory
    id_short: str
    # experiment parameters for the environment
    exp_params: common.EnvExperimentParams
    # path to the results file of the fixed policies
    policy_results_file: str
    # co-player team (tuple of policy ids) -> best-response policy id
    best_response_map: Dict[Tuple[str, ...], str]
    # policy id, or tuple of policy ids -> short label used in figures
    policy_labels: Dict[Union[str, Tuple[str, ...]], str]

    def _results_path(self, filename: str) -> str:
        """Get the path of a file located in this environment's results dir."""
        return osp.join(self.results_dir, filename)

    @property
    def results_dir(self) -> str:
        return osp.join(ENV_RESULTS_DIR, self.id_short)

    @property
    def figure_dir(self) -> str:
        # figures are only kept when SAVE_FIGURES is set, otherwise they are
        # written to the system temp directory
        if not SAVE_FIGURES:
            return tempfile.gettempdir()
        return self._results_path("figures")

    @property
    def baseline_exp_results_file(self) -> str:
        return self._results_path("baseline_experiment_results.csv")

    @property
    def baseline_avg_exp_results_file(self) -> str:
        return self._results_path("baseline_avg_performance_results.csv")

    @property
    def meta_exp_results_file(self) -> str:
        return self._results_path("meta_experiment_results.csv")

    @property
    def meta_avg_exp_results_file(self) -> str:
        return self._results_path("meta_avg_performance_results.csv")

    @property
    def lambda_exp_results_file(self) -> str:
        return self._results_path("lambda_experiment_results.csv")

    @property
    def lambda_avg_exp_results_file(self) -> str:
        return self._results_path("lambda_avg_performance_results.csv")

    @property
    def many_pi_exp_results_file(self) -> str:
        return self._results_path("many_pi_experiment_results.csv")

    @property
    def many_pi_avg_exp_results_file(self) -> str:
        return self._results_path("many_pi_avg_performance_results.csv")

    @property
    def sensitivity_exp_results_file(self) -> str:
        return self._results_path("sensitivity_experiment_results.csv")

    @property
    def sensitivity_avg_exp_results_file(self) -> str:
        return self._results_path("sensitivity_avg_performance_results.csv")
        
        
        
# Figure labels for the pursuit-evasion KLR policies: "klr_k<k>_seed0_i<i>-v0"
# is labelled "K<k>_<i>" for k = 0..4 and i = 0, 1.
pe_policy_labels = {
    f"klr_k{k}_seed0_i{i}-v0": f"K{k}_{i}"
    for i in (0, 1)
    for k in range(5)
}

# Also register every policy as a single-member team (1-tuple key) with the
# same label. The new entries are fully built before the dict is updated.
pe_policy_labels.update(
    {(pi_name,): pi_label for pi_name, pi_label in pe_policy_labels.items()}
)


# Analysis relevant info for each experiment environment, keyed by the short
# environment name (the values ENV_NAME can take).
#
# For every environment:
#   - best_response_map maps a co-player team (tuple of policy ids) to the id
#     of the policy used as the full-knowledge best response against it
#   - policy_labels maps both plain policy ids and team tuples to the short
#     labels shown in figures
env_info_map = {
    "driving": EnvInfo(
        id=driving.ENV_ID,
        label="Driving",
        id_short="driving",
        exp_params=driving.DRIVING_EXP_PARAMS,
        policy_results_file=osp.join(
            POSGGYM_AGENTS_DIR, "driving14x14wideroundabout_n2_v0",  "results", "klrbr_results.csv"
        ),
        # the K(k+1) policy is the best response to the K(k) policy
        best_response_map = {
            ("klr_k0_seed0-v0",): "klr_k1_seed0-v0",
            ("klr_k1_seed0-v0",): "klr_k2_seed0-v0",
            ("klr_k2_seed0-v0",): "klr_k3_seed0-v0",
            ("klr_k3_seed0-v0",): "klr_k4_seed0-v0",
            # No entry for K4 since it is not in the policy prior:
            # ("klr_k4_seed0-v0",): "klr_k4_seed0-v0",
        },
        policy_labels={
            "klr_k0_seed0-v0": "K0",
            "klr_k1_seed0-v0": "K1",
            "klr_k2_seed0-v0": "K2",
            "klr_k3_seed0-v0": "K3",
            "klr_k4_seed0-v0": "K4",
            ("klr_k0_seed0-v0",): "K0",
            ("klr_k1_seed0-v0",): "K1",
            ("klr_k2_seed0-v0",): "K2",
            ("klr_k3_seed0-v0",): "K3",
            ("klr_k4_seed0-v0",): "K4",
        }
    ),
    # The two pursuit-evasion entries share the same fixed-policy results file
    # and label map; they differ in which side ("i0" or "i1" policies) the
    # co-player uses.
    "pe_evader": EnvInfo(
        id=pe_evader.ENV_ID,
        label="PE (Evader)",
        id_short="pe_evader",
        exp_params=pe_evader.PE_EVADER_EXP_PARAMS,
        policy_results_file=osp.join(
            POSGGYM_AGENTS_DIR, "pursuitevasion16x16_v0", "results", "pairwise_results.csv"
        ),
        # co-player uses the "i1" policies, best responses are "i0" policies
        best_response_map={
            ("klr_k0_seed0_i1-v0",): "klr_k1_seed0_i0-v0",
            ("klr_k1_seed0_i1-v0",): "klr_k2_seed0_i0-v0",
            ("klr_k2_seed0_i1-v0",): "klr_k3_seed0_i0-v0",
            ("klr_k3_seed0_i1-v0",): "klr_k4_seed0_i0-v0",
            # K4 co-player entry left out (presumably for the same reason as in driving):
            # "klr_k4_seed0_i1-v0": "klr_k4_seed0_i0-v0",
        }, 
        policy_labels=pe_policy_labels
    ),
    "pe_pursuer": EnvInfo(
        id=pe_pursuer.ENV_ID,
        label="PE (Pursuer)",
        id_short="pe_pursuer",
        exp_params=pe_pursuer.PE_PURSUER_EXP_PARAMS,
        policy_results_file=osp.join(
            POSGGYM_AGENTS_DIR, "pursuitevasion16x16_v0", "results", "pairwise_results.csv"
        ),
        # co-player uses the "i0" policies, best responses are "i1" policies
        best_response_map={
            ("klr_k0_seed0_i0-v0",): "klr_k1_seed0_i1-v0",
            ("klr_k1_seed0_i0-v0",): "klr_k2_seed0_i1-v0",
            ("klr_k2_seed0_i0-v0",): "klr_k3_seed0_i1-v0",
            ("klr_k3_seed0_i0-v0",): "klr_k4_seed0_i1-v0",
            # K4 co-player entry left out (presumably for the same reason as in driving):
            # "klr_k4_seed0_i0-v0": "klr_k4_seed0_i1-v0",
        },
        policy_labels=pe_policy_labels
    ),
    "pp2": EnvInfo(
        id=pp2.ENV_ID,
        label="PP (two-agents)",
        id_short="pp2",
        exp_params=pp2.PP2_EXP_PARAMS,
        policy_results_file=osp.join(
            POSGGYM_AGENTS_DIR, "predatorprey10x10_P2_p3_s2_coop_v0", "results", "pairwise_results.csv"
        ),
        # each self-play policy is mapped to itself as the best response
        best_response_map={
            ('sp_seed0-v0',): 'sp_seed0-v0',
            ('sp_seed1-v0',): 'sp_seed1-v0',
            ('sp_seed2-v0',): 'sp_seed2-v0',
            ('sp_seed3-v0',): 'sp_seed3-v0',
            ('sp_seed4-v0',): 'sp_seed4-v0',
        },
        policy_labels={
            'sp_seed0-v0': "S0",
            'sp_seed1-v0': "S1",
            'sp_seed2-v0': "S2",
            'sp_seed3-v0': "S3",
            'sp_seed4-v0': "S4",
            ('sp_seed0-v0',): "S0",
            ('sp_seed1-v0',): "S1",
            ('sp_seed2-v0',): "S2",
            ('sp_seed3-v0',): "S3",
            ('sp_seed4-v0',): "S4",
        }
    ),
    "pp4": EnvInfo(
        id=pp4.ENV_ID,
        label="PP (four-agents)",
        id_short="pp4",
        exp_params=pp4.PP4_EXP_PARAMS,
        policy_results_file=osp.join(
            POSGGYM_AGENTS_DIR, "predatorprey10x10_P4_p3_s3_coop_v0", "results", "pairwise_results.csv"
        ),
        # co-player teams are 3-tuples here; only teams where all three
        # members use the same self-play policy are listed
        best_response_map={
            ('sp_seed0-v0', 'sp_seed0-v0', 'sp_seed0-v0'): 'sp_seed0-v0',
            ('sp_seed1-v0', 'sp_seed1-v0', 'sp_seed1-v0'): 'sp_seed1-v0',
            ('sp_seed2-v0', 'sp_seed2-v0', 'sp_seed2-v0'): 'sp_seed2-v0',
            ('sp_seed3-v0', 'sp_seed3-v0', 'sp_seed3-v0'): 'sp_seed3-v0',
            ('sp_seed4-v0', 'sp_seed4-v0', 'sp_seed4-v0'): 'sp_seed4-v0',
        },
        # single policies are labelled S<seed>, three-member teams T<seed>
        policy_labels={
            'sp_seed0-v0': "S0",
            'sp_seed1-v0': "S1",
            'sp_seed2-v0': "S2",
            'sp_seed3-v0': "S3",
            'sp_seed4-v0': "S4",
            ('sp_seed0-v0', 'sp_seed0-v0', 'sp_seed0-v0'): "T0",
            ('sp_seed1-v0', 'sp_seed1-v0', 'sp_seed1-v0'): "T1",
            ('sp_seed2-v0', 'sp_seed2-v0', 'sp_seed2-v0'): "T2",
            ('sp_seed3-v0', 'sp_seed3-v0', 'sp_seed3-v0'): "T3",
            ('sp_seed4-v0', 'sp_seed4-v0', 'sp_seed4-v0'): "T4",
        }
    ),
}


# Make sure the figure directory of every known environment exists, not only
# the one of the environment being analysed.
for known_env_info in env_info_map.values():
    os.makedirs(known_env_info.figure_dir, exist_ok=True)

# The environment all following cells operate on
env_info = env_info_map[ENV_NAME]
print(f"Running analysis for {ENV_NAME}")


## Utility functions

In [None]:
def display_df_info(df, display_columns: bool = True):
    """Print a summary of the unique values of the key columns of a results df.

    Only columns present in `df` are reported. If `display_columns` is True
    the names of all columns are listed at the end.
    """

    def unique_values(column, ordered=True):
        found = df[column].unique().tolist()
        return sorted(found) if ordered else found

    print("\nDF info")
    print("-------")

    if "agent_id" in df.columns:
        print(f"agent_id: {unique_values('agent_id')}")

    # (column, heading, underline, sort the values?, one value per line?)
    sections = (
        ("policy_id", "\nPolicies", "--------", True, True),
        ("alg_id", "\nAlg IDs", "-------", True, True),
        ("meta_pi", "\nMeta Pis:", "---------", False, False),
        ("co_team_id", "\nCo-player Team IDs", "------------------", True, True),
        ("co_team_seed", "\nCo-player Team Seeds", "--------------------", True, False),
        ("num_episodes", "\nNum episodes", "------------", True, False),
    )
    for column, heading, underline, ordered, one_per_line in sections:
        if column not in df.columns:
            continue
        # values are collected (and sorted) before anything is printed
        values = unique_values(column, ordered)
        print(heading)
        print(underline)
        if one_per_line:
            for value in values:
                print(value)
        else:
            print(values)

    if not display_columns:
        return

    print("\nColumns")
    print("-------")
    for column_name in df.columns:
        print(column_name)
            

## The Environment

In [None]:
# Draw the selected environment, save the image to the figure directory and
# show it inline.
env_figure_file = osp.join(env_info.figure_dir, "env.png")
fig, ax = plot_utils.plot_environment(env_info.id, (9, 9))
fig.savefig(env_figure_file)
plt.show()

## Policy Prior

In [None]:
# Each section is a heading, its underline and the experiment-parameter getter
# whose result (with the env id removed) is pretty-printed below it.
prior_params = env_info.exp_params
prior_sections = [
    ("Policy Prior", "------------", prior_params.get_policy_prior_map),
    ("\nPlanning agent policy ids", "-------------------------", prior_params.get_planning_policy_ids),
    ("\nOther agent policy ids", "----------------------", prior_params.get_other_policy_ids),
    ("\nAll policy ids", "----------------------", prior_params.get_all_policy_ids),
    ("\nOther joint policies", "----------------------", prior_params.get_other_joint_policies),
]

for section_heading, section_underline, section_getter in prior_sections:
    print(section_heading)
    print(section_underline)
    pprint(section_getter(remove_env_id=True))

## Loading Fixed Policy Data and adding the Full-Knowledge Best-Response baseline

In [None]:
policy_df = plot_utils.import_results(env_info.policy_results_file)

# Show everything contained in the raw results file before any filtering
print("All Policies")
print("------------")
for pi_id in sorted(policy_df["policy_id"].unique().tolist()):
    print(pi_id)

print("\nAll Co-Team IDs")
print("---------------")
for t_id in sorted(policy_df["co_team_id"].unique().tolist()):
    print(t_id)

# Keep only the policies and co-player teams used in the experiments
policy_df = policy_df[policy_df["policy_id"].isin(env_info.exp_params.get_all_policy_ids(remove_env_id=True))]
policy_df = policy_df[policy_df["co_team_id"].isin(env_info.exp_params.get_other_joint_policies(remove_env_id=True))]


if env_info.exp_params.symmetric_env:
    # For a symmetric env the results of a (policy_id, co_team_id) pair may
    # only be recorded for another agent id. Since the env is symmetric those
    # rows are re-used for the planning agent so every pair has an entry.
    # (In an asymmetric env rows of the non-planning agent are simply dropped.)
    next_exp_id = policy_df["exp_id"].max() + 1
    new_rows = []
    for pi_id, co_team_id in product(policy_df["policy_id"].unique(), policy_df["co_team_id"].unique()):
        pair_df = policy_df[(policy_df["policy_id"] == pi_id) & (policy_df["co_team_id"] == co_team_id)]
        if pair_df.empty:
            print(f"missing entry for ({pi_id}, {co_team_id})")
            continue
        if env_info.exp_params.planning_agent_id in pair_df["agent_id"].unique():
            # already in df with correct agent id
            continue
        # Re-label all rows of the pair (there should only be one) as belonging
        # to the planning agent, under a fresh exp_id
        pair_row = pair_df.copy()
        pair_row["agent_id"] = env_info.exp_params.planning_agent_id
        pair_row["exp_id"] = next_exp_id
        next_exp_id += 1
        new_rows.append(pair_row)

    if new_rows:
        print(f"\nAdding {len(new_rows)} rows for agent_id={env_info.exp_params.planning_agent_id}")
        new_pairs_df = pd.concat(new_rows, axis='rows').reset_index(drop=True)
        policy_df = pd.concat([policy_df, new_pairs_df], ignore_index=True)

# Drop rows for non-planning agent
policy_df = policy_df[policy_df["agent_id"] == env_info.exp_params.planning_agent_id]
assert len(policy_df["agent_id"].unique()) == 1


# Add the full-knowledge best-response baseline: for each co-player team the
# results of its best-response policy are duplicated under the policy id
# "full-knowledge-br"
new_rows = []
next_exp_id = policy_df["exp_id"].max() + 1
for co_team_id, br_policy_id in env_info.best_response_map.items():
    br_row = policy_df.loc[
        (policy_df["policy_id"] == br_policy_id)
        & (policy_df["co_team_id"] == co_team_id)
    ].copy()
    if br_row.empty:
        # the baseline would otherwise silently miss this co-player team
        print(f"WARNING: no results found for best response {br_policy_id} against {co_team_id}")
    # update policy id to baseline name
    br_row["policy_id"] = "full-knowledge-br"
    br_row["exp_id"] = next_exp_id
    next_exp_id += 1
    new_rows.append(br_row)


print("\nStats from adding full-knowledge-br:")
print(f"{len(policy_df)=}")
print(f"{len(new_rows)=}")
br_df = pd.concat(new_rows, axis='rows').reset_index(drop=True)
print(f"{len(br_df)=}")
policy_df = pd.concat([policy_df, br_df], ignore_index=True)

print(f"{len(policy_df)=} (i.e. all together)")

display_df_info(policy_df)

## Fixed policies pairwise performance

This is what was used to generate the meta-policies.

Here we show the pairwise performance between each individual policy and each co-player team.

In [None]:
# Pairwise payoffs of the fixed policies only; the synthetic
# "full-knowledge-br" rows are left out of this figure.
fixed_only_df = policy_df[policy_df["policy_id"] != "full-knowledge-br"]

fig, axs = plot_utils.plot_pairwise_comparison(
    fixed_only_df,
    policy_key="policy_id",
    coplayer_policy_key="co_team_id",
    y_key="episode_return_mean",
    y_err_key="episode_return_CI",
    policies=None,
    coplayer_policies=None,
    policy_labels=env_info.policy_labels,
    vrange=None,
    valfmt="{x:.2f}",
    figsize=(3, 6),
    average_duplicates=True,
    duplicate_warning=True,
)
fig.tight_layout()

payoff_figure_file = osp.join(env_info.figure_dir, "fixed_policy_payoffs.png")
fig.savefig(payoff_figure_file)

## Loading BAPOSGMCP Data and Combining with Fixed-Policy data

In [None]:
# Policy-id tokens containing any of these substrings describe algorithm
# settings (or the wrapped policy) rather than the algorithm itself
_ALG_ID_IGNORED_SUBSTRINGS = (
    "actionselection",
    "searchtimelimit",
    "numsims",
    "truncated",
    "greedy",
    "softmax",
    "uniform",
    "piklr",
    "pisp",
    "i0",
    "i1",
)

_META_PI_NAMES = ("greedy", "softmax", "uniform")


def _alg_id_from_policy_id(pi_id):
    """Get the algorithm id for a policy id ("fixed" for klr/sp policies)."""
    if pi_id.startswith(("klr", "sp")):
        return "fixed"
    return "_".join(
        token for token in pi_id.split("_")
        if not any(s in token for s in _ALG_ID_IGNORED_SUBSTRINGS)
    )


def _meta_pi_from_policy_id(pi_id):
    """Get the name of the meta-policy named in a policy id, "NA" if none."""
    for meta_pi in _META_PI_NAMES:
        if meta_pi in pi_id:
            return meta_pi
    return "NA"


def combine_dfs(df1, df2):
    """Concatenate two results dataframes and add "alg_id" and "meta_pi" columns.

    If the exp_ids of `df2` overlap with those of `df1`, the exp_ids of the
    `df2` rows are shifted so they come after the largest exp_id of `df1`.
    Neither input dataframe is modified.

    Returns the combined dataframe with a fresh 0..n-1 index.
    """
    df1_max_exp_id = df1["exp_id"].max()
    if df2["exp_id"].min() <= df1_max_exp_id:
        # Shift a copy: the same df2 is combined with several dataframes, so
        # updating it in place would change the caller's data on every call
        df2 = df2.copy()
        df2["exp_id"] += df1_max_exp_id + 1

    combined_df = pd.concat([df1, df2]).reset_index(drop=True)

    # both columns are derived from the policy id only
    combined_df["alg_id"] = combined_df["policy_id"].map(_alg_id_from_policy_id)
    combined_df["meta_pi"] = combined_df["policy_id"].map(_meta_pi_from_policy_id)

    return combined_df

In [None]:
# Baseline experiment results: keep only the rows of the planning agent, then
# append the fixed-policy results.
baseline_df = plot_utils.import_results(env_info.baseline_exp_results_file)
baseline_planner_rows = baseline_df["agent_id"] == env_info.exp_params.planning_agent_id
baseline_df = combine_dfs(baseline_df[baseline_planner_rows], policy_df)

display_df_info(baseline_df)

In [None]:
# Meta-policy experiment results: keep only the rows of the planning agent,
# then append the fixed-policy results.
meta_df = plot_utils.import_results(env_info.meta_exp_results_file)
meta_planner_rows = meta_df["agent_id"] == env_info.exp_params.planning_agent_id
meta_df = combine_dfs(meta_df[meta_planner_rows], policy_df)

display_df_info(meta_df)

In [None]:
# Lambda experiment results, restricted to the rows of the planning agent.
# Unlike the baseline and meta results, no fixed-policy rows are appended.
lambda_df = plot_utils.import_results(env_info.lambda_exp_results_file)
lambda_planner_rows = lambda_df["agent_id"] == env_info.exp_params.planning_agent_id
lambda_df = lambda_df[lambda_planner_rows]

display_df_info(lambda_df)

In [None]:
if (
    env_info.exp_params.many_pi_pairwise_returns is not None
    and osp.exists(env_info.many_pi_exp_results_file)
):
    many_pi_df = plot_utils.import_results(env_info.many_pi_exp_results_file)

    # drop non-planning agent rows
    many_pi_df = many_pi_df[many_pi_df["agent_id"] == env_info.exp_params.planning_agent_id]

    # Add Full-knowledge BR: its value is the best return achievable against
    # each policy state, averaged over all policy states
    many_pi_pairwise_returns = env_info.exp_params.many_pi_pairwise_returns
    br_returns = {
        pi_state: max(policy_returns.values(), default=-float("inf"))
        for pi_state, policy_returns in many_pi_pairwise_returns.items()
    }
    mean_br_value = sum(br_returns.values()) / len(br_returns)

    # Add BR row to df
    # We will only be using episode returns in analysis so we just copy any row.
    # The row is selected by position: after the filtering above the index
    # label 0 is not guaranteed to still exist.
    br_row = many_pi_df.iloc[[0], :].copy(deep=True)
    br_row["policy_id"] = "full-knowledge-br"
    br_row[f"coplayer_policy_id_{env_info.exp_params.planning_agent_id}"] = "full-knowledge-br"
    br_row["episode_return_mean"] = mean_br_value
    br_row["episode_return_CI"] = 0.0
    br_row["exp_id"] = many_pi_df["exp_id"].max() + 1

    many_pi_df = pd.concat([many_pi_df, br_row], ignore_index=True)

    display_df_info(many_pi_df)
else:
    # exp not run for current env
    many_pi_df = None

In [None]:
sens_results_available = (
    env_info.exp_params.sensitivity_pairwise_returns is not None
    and osp.exists(env_info.sensitivity_exp_results_file)
)

if not sens_results_available:
    # sensitivity experiment was not run for the current env
    sens_df = None
else:
    sens_df = plot_utils.import_results(env_info.sensitivity_exp_results_file)

    # only the rows of the planning agent are analysed
    sens_df = sens_df[sens_df["agent_id"] == env_info.exp_params.planning_agent_id]

    # Value of the Full-knowledge BR: best return against each policy state,
    # averaged over all policy states
    sens_pairwise_returns = env_info.exp_params.sensitivity_pairwise_returns
    br_returns = {}
    for pi_state, policy_returns in sens_pairwise_returns.items():
        best_return = -float("inf")
        for policy_return in policy_returns.values():
            if policy_return > best_return:
                best_return = policy_return
        br_returns[pi_state] = best_return

    mean_br_value = sum(br_returns.values()) / len(br_returns)

    # Only the episode returns of the BR row are used in the analysis, so the
    # row is created by copying the first row and overwriting the relevant
    # columns
    br_row = sens_df.iloc[[0], :].copy(deep=True)
    br_row_values = {
        "policy_id": "full-knowledge-br",
        f"coplayer_policy_id_{env_info.exp_params.planning_agent_id}": "full-knowledge-br",
        "episode_return_mean": mean_br_value,
        "episode_return_CI": 0.0,
        "exp_id": sens_df["exp_id"].max() + 1,
    }
    for column, value in br_row_values.items():
        br_row[column] = value

    sens_df = pd.concat([sens_df, br_row], ignore_index=True)

    display_df_info(sens_df)

## Pairwise performance

Here we look at the performance of each policy against each other policy including BAPOSGMCP and baselines with different number of simulations, action selection, and meta-policies.

In [None]:
# Mean episode return of every policy (rows) against every co-player team.
# The figure height grows with the number of policies.
num_baseline_policies = len(baseline_df["policy_id"].unique())

plot_utils.plot_pairwise_comparison(
    baseline_df,
    policy_key="policy_id",
    coplayer_policy_key="co_team_id",
    y_key="episode_return_mean",
    y_err_key=None,
    policies=None,
    coplayer_policies=None,
    policy_labels=env_info.policy_labels,
    vrange=None,
    valfmt="{x:.2f}",
    figsize=(20, num_baseline_policies),
    average_duplicates=True,
    duplicate_warning=False,
)

In [None]:
# Number of episodes completed for every policy / co-player team pairing.
# (Pass lambda_df instead of baseline_df to check the lambda experiment.)
num_episode_plot_rows = len(baseline_df["policy_id"].unique())

plot_utils.plot_pairwise_comparison(
    baseline_df,
    policy_key="policy_id",
    coplayer_policy_key="co_team_id",
    y_key="num_episodes",
    y_err_key=None,
    policies=None,
    coplayer_policies=None,
    policy_labels=env_info.policy_labels,
    vrange=None,
    valfmt="{x:.0f}",
    figsize=(20, num_episode_plot_rows),
    average_duplicates=True,
    duplicate_warning=False,
)

# Expected Performance

Here we look at the expected performance given the policy prior of BAPOSGMCP and the different baselines.

Specifically:

1. Comparing different meta-policies
2. Comparing performance using meta-policy vs using a single fixed policy
3. Comparing performance between all algorithms


In [None]:
# Expected performance of the baseline experiment policies, computed over the
# other agents' joint policies from the experiment parameters.
baseline_coplayer_policies = env_info.exp_params.get_other_joint_policies(remove_env_id=True)
baseline_exp_df = plot_utils.get_uniform_expected_df(
    baseline_df,
    coplayer_policy_key="co_team_id",
    coplayer_policies=baseline_coplayer_policies,
)
display_df_info(baseline_exp_df, display_columns=False)

if SAVE_EXPECTED_RESULTS:
    baseline_exp_df.to_csv(env_info.baseline_avg_exp_results_file)

In [None]:
# Expected performance of the meta-policy experiment policies, computed over
# the other agents' joint policies from the experiment parameters.
meta_coplayer_policies = env_info.exp_params.get_other_joint_policies(remove_env_id=True)
meta_exp_df = plot_utils.get_uniform_expected_df(
    meta_df,
    coplayer_policy_key="co_team_id",
    coplayer_policies=meta_coplayer_policies,
)
display_df_info(meta_exp_df, display_columns=False)

if SAVE_EXPECTED_RESULTS:
    meta_exp_df.to_csv(env_info.meta_avg_exp_results_file)

In [None]:
# Expected performance of the lambda experiment policies, computed over the
# other agents' joint policies from the experiment parameters.
lambda_coplayer_policies = env_info.exp_params.get_other_joint_policies(remove_env_id=True)
lambda_exp_df = plot_utils.get_uniform_expected_df(
    lambda_df,
    coplayer_policy_key="co_team_id",
    coplayer_policies=lambda_coplayer_policies,
)
display_df_info(lambda_exp_df, display_columns=False)

if SAVE_EXPECTED_RESULTS:
    lambda_exp_df.to_csv(env_info.lambda_avg_exp_results_file)

In [None]:
# The many-pi results are already averaged; they are only saved again so the
# saved file includes the Full-Knowledge BR row.
many_pi_results_loaded = many_pi_df is not None
if many_pi_results_loaded:
    display_df_info(many_pi_df, display_columns=False)
if many_pi_results_loaded and SAVE_EXPECTED_RESULTS:
    many_pi_df.to_csv(env_info.many_pi_avg_exp_results_file)

In [None]:
# The sensitivity results are already averaged; they are only saved again so
# the saved file includes the Full-Knowledge BR row.
sens_results_loaded = sens_df is not None
if sens_results_loaded:
    display_df_info(sens_df, display_columns=False)
if sens_results_loaded and SAVE_EXPECTED_RESULTS:
    sens_df.to_csv(env_info.sensitivity_avg_exp_results_file)

## Comparison of the different Meta-Policies

Here we look at the performance of our algorithm using the different meta-policies.

Looking at performance with:

- truncated search
- using PUCB

We also look at the performance of the metabaseline with the different meta-policies.

In [None]:
# Legend label (LaTeX) of each meta-policy
meta_pi_label_map = dict(
    greedy=r"$\sigma^{G}$",
    softmax=r"$\sigma^{S}$",
    uniform=r"$\sigma^{U}$",
)

# Rows of the "baposgmcp" algorithm that were run with truncated search
is_potmmcp_row = meta_exp_df["alg_id"].isin(["baposgmcp"])
is_truncated_row = meta_exp_df["truncated"] == True
potmmcp_meta_exp_df = meta_exp_df[is_potmmcp_row & is_truncated_row]

display_df_info(potmmcp_meta_exp_df, display_columns=False)

In [None]:
# Mean episode return vs search time, one line per meta-policy
fig, ax = plt.subplots(
    nrows=1,
    ncols=1,
    squeeze=True,
    subplot_kw={
        "ylabel": "Mean Episode Return",
        "xlabel": "Search time (s)",
    },
    figsize=(paper_utils.PAGE_COL_WIDTH, 3.8),
)

paper_utils.plot_meta_policy_performance(
    potmmcp_meta_exp_df,
    ax,
    x_key="search_time_limit",
    y_key="episode_return_mean",
    y_err_key="episode_return_CI",
    meta_pi_label_map=meta_pi_label_map,
)

# Single figure-level legend; the bottom of the layout rect is raised to
# leave room for it
fig.legend(*ax.get_legend_handles_labels(), ncol=3, loc="lower right")
fig.tight_layout(pad=0.1, w_pad=0.8, h_pad=1.0, rect=(0.0, 0.12, 1.0, 1.0))

fig.savefig(osp.join(env_info.figure_dir, "meta_pi_return.png"))

## Comparing Meta-Policy versus using a single policy

Here we look at the performance of POTMMCP with a meta-policy against not using a meta-policy (i.e. using the different fixed policies).

In [None]:
# Legend label for each policy-id prefix that can appear in the figure
meta_vs_single_label_map = {
    # meta-policy and random baselines
    "baposgmcp_metasoftmax": r"$\sigma^{S}$",
    "baposgmcp-random": "Random",
    "baposgmcp-random_i0": "Random",
    "baposgmcp-random_i1": "Random",
    # single fixed KLR policies
    "baposgmcp-fixed_piklrk0seed0-v0": "K0",
    "baposgmcp-fixed_piklrk1seed0-v0": "K1",
    "baposgmcp-fixed_piklrk2seed0-v0": "K2",
    "baposgmcp-fixed_piklrk3seed0-v0": "K3",
    "baposgmcp-fixed_piklrk4seed0-v0": "K4",
    # single fixed self-play policies
    "baposgmcp-fixed_pispseed0-v0": "S0",
    "baposgmcp-fixed_pispseed1-v0": "S1",
    "baposgmcp-fixed_pispseed2-v0": "S2",
    "baposgmcp-fixed_pispseed3-v0": "S3",
    "baposgmcp-fixed_pispseed4-v0": "S4",
    # single fixed KLR policies with an "i0" / "i1" suffix
    "baposgmcp-fixed_i0_piklrk0seed0i0-v0": "K0_0",
    "baposgmcp-fixed_i0_piklrk1seed0i0-v0": "K1_0",
    "baposgmcp-fixed_i0_piklrk2seed0i0-v0": "K2_0",
    "baposgmcp-fixed_i0_piklrk3seed0i0-v0": "K3_0",
    "baposgmcp-fixed_i0_piklrk4seed0i0-v0": "K4_0",
    "baposgmcp-fixed_i1_piklrk0seed0i1-v0": "K0_1",
    "baposgmcp-fixed_i1_piklrk1seed0i1-v0": "K1_1",
    "baposgmcp-fixed_i1_piklrk2seed0i1-v0": "K2_1",
    "baposgmcp-fixed_i1_piklrk3seed0i1-v0": "K3_1",
    "baposgmcp-fixed_i1_piklrk4seed0i1-v0": "K4_1",
}

# Keep PUCB rows that use either the softmax meta-policy or no meta-policy.
# The random variant is taken without truncated search, the meta-policy and
# fixed-policy variants with truncated search.
mvs_alg_id = meta_exp_df["alg_id"]
mvs_uses_pucb = meta_exp_df["action_selection"].isin(["pucb"])
mvs_meta_pi_ok = meta_exp_df["meta_pi"].isin(["softmax", "NA"])
mvs_random_rows = (mvs_alg_id == "baposgmcp-random") & (meta_exp_df["truncated"] == False)
mvs_truncated_rows = (
    mvs_alg_id.isin(["baposgmcp", "baposgmcp-fixed"]) & (meta_exp_df["truncated"] == True)
)
meta_vs_single_df = meta_exp_df[
    mvs_alg_id.isin(["baposgmcp", "baposgmcp-fixed", "baposgmcp-random"])
    & mvs_uses_pucb
    & mvs_meta_pi_ok
    & (mvs_random_rows | mvs_truncated_rows)
]
display_df_info(meta_vs_single_df, display_columns=False)

# Collect the policy-id prefix of every fixed-policy variant in the data:
# the first three "_" separated tokens when the id has an "_i0" / "_i1"
# part, otherwise the first two
fixed_policy_prefixes = set()
for pi_id in meta_vs_single_df["policy_id"].unique():
    tokens = pi_id.split("_")
    if tokens[0] != "baposgmcp-fixed":
        continue
    num_prefix_tokens = 3 if ("_i0" in pi_id or "_i1" in pi_id) else 2
    fixed_policy_prefixes.add("_".join(tokens[:num_prefix_tokens]))

# meta-policy first, then random, then the fixed policies in sorted order so
# the order in the figure is correct
policy_prefixes = [
    "baposgmcp_metasoftmax",
    "baposgmcp-random",
    *sorted(fixed_policy_prefixes),
]
print("\nPolicy Prefixes:")
pprint(policy_prefixes)

In [None]:
# Mean episode return vs search time: softmax meta-policy, random, and every
# single fixed policy
fig, ax = plt.subplots(
    nrows=1,
    ncols=1,
    squeeze=True,
    subplot_kw={
        "ylabel": "Mean Episode Return",
        "xlabel": "Search time (s)",
    },
    figsize=(paper_utils.PAGE_COL_WIDTH, 3.8),
)

paper_utils.plot_performance(
    meta_vs_single_df,
    ax=ax,
    x_key="search_time_limit",
    y_key="episode_return_mean",
    y_err_key="episode_return_CI",
    policy_key="policy_id",
    policy_prefixes=policy_prefixes,
    constant_policy_prefixes=[],
    pi_label_map=meta_vs_single_label_map,
)

# Single figure-level legend; the bottom of the layout rect is raised to
# leave room for it
fig.legend(*ax.get_legend_handles_labels(), ncol=2, loc="lower right")
fig.tight_layout(pad=0.1, w_pad=0.8, h_pad=1.0, rect=(0.0, 0.35, 1.0, 1.0))

fig.savefig(osp.join(env_info.figure_dir, "meta_vs_fixed_return.png"))

In [None]:
# For each environment, the fixed policies labelled as the "Best" and the
# "Worst" single policy in the figure. The policies that are not shown are
# kept as comments.
meta_vs_single_best_worst_pi_env_label_maps = {
    "driving": {
        "baposgmcp-fixed_piklrk0seed0-v0": "Worst",
        # "baposgmcp-fixed_piklrk1seed0-v0": "K1",
        "baposgmcp-fixed_piklrk2seed0-v0": "Best",
        # "baposgmcp-fixed_piklrk3seed0-v0": "K3",
        # "baposgmcp-fixed_piklrk4seed0-v0": "K4"
    },
    "pe_evader": {
        "baposgmcp-fixed_i0_piklrk0seed0i0-v0": "Worst",
        # "baposgmcp-fixed_i0_piklrk1seed0i0-v0": "K1_0",
        "baposgmcp-fixed_i0_piklrk2seed0i0-v0": "Best",
        # "baposgmcp-fixed_i0_piklrk3seed0i0-v0": "K3_0",
        # "baposgmcp-fixed_i0_piklrk4seed0i0-v0": "K4_0"
    },
    "pe_pursuer": {
        "baposgmcp-fixed_i1_piklrk0seed0i1-v0": "Worst",
        # "baposgmcp-fixed_i1_piklrk1seed0i1-v0": "K1_1",
        "baposgmcp-fixed_i1_piklrk2seed0i1-v0": "Best",
        # "baposgmcp-fixed_i1_piklrk3seed0i1-v0": "K3_1",
        # "baposgmcp-fixed_i1_piklrk4seed0i1-v0": "K4_1"
    },
    "pp2": {
        # "baposgmcp-fixed_pispseed0-v0": "S0",
        "baposgmcp-fixed_pispseed1-v0": "Worst",
        # "baposgmcp-fixed_pispseed2-v0": "S2",
        # "baposgmcp-fixed_pispseed3-v0": "S3",
        "baposgmcp-fixed_pispseed4-v0": "Best",
    },
    "pp4": {
        "baposgmcp-fixed_pispseed0-v0": "Worst",
        "baposgmcp-fixed_pispseed1-v0": "Best",
        # "baposgmcp-fixed_pispseed2-v0": "S2",
        # "baposgmcp-fixed_pispseed3-v0": "S3",
        # "baposgmcp-fixed_pispseed4-v0": "S4"
    },
}

# Labels of the lines to plot: the meta-policy and random entries followed by
# the best / worst entries of the current environment
meta_vs_single_best_worst_pi_label_map = {
    "baposgmcp_metasoftmax": r"$\sigma^{S}$",
    "baposgmcp-random": "Random",
    **meta_vs_single_best_worst_pi_env_label_maps[env_info.id_short],
}

num_rows = 1
num_cols = 1

fig, ax = plt.subplots(
    nrows=num_rows,
    ncols=num_cols,
    squeeze=True,
    subplot_kw={
        "ylabel": "Mean Episode Return",
        "xlabel": "Search time (s)",
    },
    figsize=(paper_utils.PAGE_COL_WIDTH, 3.8),
)

paper_utils.plot_performance(
    meta_vs_single_df,
    ax=ax,
    x_key="search_time_limit",
    y_key="episode_return_mean",
    y_err_key="episode_return_CI",
    policy_key="policy_id",
    # the keys of the label map are the policy prefixes to plot, in order
    policy_prefixes=list(meta_vs_single_best_worst_pi_label_map),
    constant_policy_prefixes=[],
    pi_label_map=meta_vs_single_best_worst_pi_label_map,
)

# Single figure-level legend; the bottom of the layout rect is raised to
# leave room for it
fig.legend(*ax.get_legend_handles_labels(), ncol=2, loc="lower right")
fig.tight_layout(pad=0.1, w_pad=0.8, h_pad=1.0, rect=(0.0, 0.2, 1.0, 1.0))

fig.savefig(osp.join(env_info.figure_dir, "meta_vs_fixed_return_best_and_worst.png"))

## Comparing BAPOSGMCP versus baselines

Finally we compare BAPOSGMCP versus baselines. Specifically we compare:

- BAPOSGMCP (PUCB + Best Meta)
- IPOMCP-Meta (UCB + Best Meta)
- IPOMCP (UCB + Random)
- Full Knowledge BR
- Meta

In [None]:
best_meta_pi = "softmax"

# Row selection for the comparison figure:
#   - full-knowledge-br: all rows
#   - metabaseline: rows using the best meta-policy
#   - baposgmcp / ucbmcp: rows using the best meta-policy with truncated search
#   - baposgmcp-random / ucbmcp-random: rows without truncated search
perf_alg_id = baseline_exp_df["alg_id"]
perf_uses_best_meta = baseline_exp_df["meta_pi"] == best_meta_pi
# both flags are explicit comparisons so rows whose "truncated" value is
# neither True nor False match neither of them
perf_truncated = baseline_exp_df["truncated"] == True
perf_not_truncated = baseline_exp_df["truncated"] == False

baseline_perf_df = baseline_exp_df[
    (perf_alg_id == "full-knowledge-br")
    | ((perf_alg_id == "metabaseline") & perf_uses_best_meta)
    | (perf_alg_id.isin(["baposgmcp", "ucbmcp"]) & perf_uses_best_meta & perf_truncated)
    | (perf_alg_id.isin(["baposgmcp-random", "ucbmcp-random"]) & perf_not_truncated)
]

# Policy-id prefixes of the lines to plot, in plotting order.
# ("baposgmcp-random" is selected above but currently not plotted.)
baseline_policy_prefixes_to_plot = [
    f"baposgmcp_meta{best_meta_pi}",
    "ucbmcp-random",
    f"ucbmcp_meta{best_meta_pi}",
    f"metabaseline_{best_meta_pi}",
    "full-knowledge-br",
]
# Prefixes passed to the plotting function as "constant" policies
baseline_constant_policy_prefixes = [
    f"metabaseline_{best_meta_pi}",
    "full-knowledge-br",
]

# Legend label of each plotted prefix
# (for "baposgmcp-random" the label would be f"{algname} + Random")
baseline_pi_label_map = {
    f"baposgmcp_meta{best_meta_pi}": algname,
    "full-knowledge-br": "Best-Response",
    f"metabaseline_{best_meta_pi}": "Meta-Policy",
    f"ucbmcp_meta{best_meta_pi}": f"{baselinealgname} + Meta",
    "ucbmcp-random": f"{baselinealgname} + Random",
}

display_df_info(baseline_perf_df, display_columns=False)

In [None]:
# Full-page-width figure with a single axes for the baseline comparison.
fig_kwargs = dict(figsize=(paper_utils.PAGE_WIDTH, 6))
subplot_kwargs = dict(ylabel="Mean Episode Return", xlabel="Search time (s)")

num_rows, num_cols = 1, 1

fig, ax = plt.subplots(
    nrows=num_rows,
    ncols=num_cols,
    squeeze=True,
    subplot_kw=subplot_kwargs,
    **fig_kwargs,
)

# Mean episode return (with CI) against the search time limit, using the
# prefix lists and label map set up for the baseline comparison.
paper_utils.plot_performance(
    baseline_perf_df,
    ax=ax,
    policy_key="policy_id",
    policy_prefixes=baseline_policy_prefixes_to_plot,
    constant_policy_prefixes=baseline_constant_policy_prefixes,
    pi_label_map=baseline_pi_label_map,
    x_key="search_time_limit",
    y_key="episode_return_mean",
    y_err_key="episode_return_CI",
)

# Figure-level legend; tight_layout's rect keeps the bottom 12% of the figure
# out of the axes area.
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc="lower right", ncol=3)
fig.tight_layout(pad=0.1, w_pad=0.8, h_pad=1.0, rect=(0.0, 0.12, 1.0, 1.0))

figure_path = osp.join(env_info.figure_dir, "baselines_return_vs_search_time.png")
fig.savefig(figure_path)

del fig_kwargs, subplot_kwargs

## Comparing POTMMCP using different lambda values

Here we look at POTMMCP using different values of the lambda hyperparameter, which controls the weighting of the policy prior during PUCT action selection.

In [None]:
# Full-page-width figure with a single axes for the lambda comparison.
fig_kwargs = dict(figsize=(paper_utils.PAGE_WIDTH, 6))
subplot_kwargs = dict(ylabel="Mean Episode Return", xlabel="Search time (s)")

num_rows, num_cols = 1, 1

fig, ax = plt.subplots(
    nrows=num_rows,
    ncols=num_cols,
    squeeze=True,
    subplot_kw=subplot_kwargs,
    **fig_kwargs,
)

# "root_exploration_fraction" is the column passed as the lambda key.
paper_utils.plot_lambda_performance(
    lambda_exp_df,
    ax=ax,
    lambda_key="root_exploration_fraction",
    x_key="search_time_limit",
    y_key="episode_return_mean",
    y_err_key="episode_return_CI",
)

# Figure-level legend; tight_layout's rect keeps the bottom 8% of the figure
# out of the axes area.
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc="lower right", ncol=4)
fig.tight_layout(pad=0.1, w_pad=0.8, h_pad=1.0, rect=(0.0, 0.08, 1.0, 1.0))

figure_path = osp.join(env_info.figure_dir, "lambdas_return_vs_search_time.png")
fig.savefig(figure_path)

del fig_kwargs, subplot_kwargs

## Comparing POTMMCP with a large policy set

Here we look at POTMMCP when the policy set is large.

In [None]:
if many_pi_df is not None:
    # The many-policy results are only available for some environments.

    # Only the softmax meta-policy variants and the best response are plotted.
    # The label map also carries labels for the greedy and uniform variants.
    many_pi_constant_policy_prefixes = [
        "metabaseline_softmax",
        "full-knowledge-br",
    ]
    many_pi_policy_prefixes_to_plot = [
        "potmmcp_metasoftmax",
        *many_pi_constant_policy_prefixes,
    ]
    many_pi_pi_label_map = dict(
        [
            ("potmmcp_metagreedy", f"{algname} Greedy"),
            ("potmmcp_metasoftmax", algname),
            ("potmmcp_metauniform", f"{algname} Uniform"),
            ("full-knowledge-br", "Best-Response"),
            ("metabaseline_greedy", "Meta-Policy Greedy"),
            ("metabaseline_softmax", "Meta-Policy"),
            ("metabaseline_uniform", "Meta-Policy Uniform"),
        ]
    )

    # Full-page-width figure with a single axes.
    fig_kwargs = dict(figsize=(paper_utils.PAGE_WIDTH, 6))
    subplot_kwargs = dict(ylabel="Mean Episode Return", xlabel="Search time (s)")

    num_rows, num_cols = 1, 1

    fig, ax = plt.subplots(
        nrows=num_rows,
        ncols=num_cols,
        squeeze=True,
        subplot_kw=subplot_kwargs,
        **fig_kwargs,
    )

    paper_utils.plot_performance(
        many_pi_df,
        ax=ax,
        policy_key="policy_id",
        policy_prefixes=many_pi_policy_prefixes_to_plot,
        constant_policy_prefixes=many_pi_constant_policy_prefixes,
        pi_label_map=many_pi_pi_label_map,
        x_key="search_time_limit",
        y_key="episode_return_mean",
        y_err_key="episode_return_CI",
    )

    # Figure-level legend; tight_layout's rect keeps the bottom 8% of the
    # figure out of the axes area.
    handles, labels = ax.get_legend_handles_labels()
    fig.legend(handles, labels, loc="lower right", ncol=3)
    fig.tight_layout(pad=0.1, w_pad=0.8, h_pad=1.0, rect=(0.0, 0.08, 1.0, 1.0))

    figure_path = osp.join(env_info.figure_dir, "many_pi_return_vs_search_time.png")
    fig.savefig(figure_path)

    del fig_kwargs, subplot_kwargs

## Sensitivity of POTMMCP to out-of-distribution policies

Here we look at POTMMCP when the policies in the internal policy set are different from the actual policies used by the other agent.

In [None]:
if sens_df is not None:
    # only supported for some envs
    fig_kwargs = {"figsize": (paper_utils.PAGE_WIDTH, 6)}
    subplot_kwargs = {
        "ylabel": "Mean Episode Return",
        "xlabel": "Search time (s)"
    }

    num_rows = 1
    num_cols = 1

    fig, ax = plt.subplots(
        nrows=num_rows,
        ncols=num_cols,
        squeeze=True,
        subplot_kw=subplot_kwargs,
        **fig_kwargs,
    )

    # Meta-policy variant to plot; every prefix and label key below is
    # derived from this value so they cannot drift apart.
    best_meta_pi = "softmax"

    # Policy-id prefixes to plot.
    sens_policy_prefixes_to_plot = [
        f"potmmcp_meta{best_meta_pi}",
        f"ucbmcp_meta{best_meta_pi}",
        "ucbmcp-random",
        f"metabaseline_{best_meta_pi}",
        "full-knowledge-br",
    ]

    # Subset of the prefixes passed to the plot as constant policies.
    sens_constant_policy_prefixes = [
        f"metabaseline_{best_meta_pi}",
        "full-knowledge-br",
    ]

    # Legend label for each plotted prefix. The meta-policy baseline key used
    # to be hard-coded as "metabaseline_softmax"; it is now built from
    # best_meta_pi so it always matches the prefix lists above.
    sens_pi_label_map = {
        f"potmmcp_meta{best_meta_pi}": algname,
        "full-knowledge-br": "Best-Response",
        f"metabaseline_{best_meta_pi}": "Meta-Policy",
        f"ucbmcp_meta{best_meta_pi}": f"{baselinealgname} + Meta",
        "ucbmcp-random": f"{baselinealgname} + Random",
    }

    paper_utils.plot_performance(
        sens_df,
        ax=ax,
        x_key="search_time_limit",
        y_key="episode_return_mean",
        y_err_key="episode_return_CI",
        policy_key="policy_id",
        policy_prefixes=sens_policy_prefixes_to_plot,
        constant_policy_prefixes=sens_constant_policy_prefixes,
        pi_label_map=sens_pi_label_map,
    )

    # Figure-level legend; tight_layout's rect keeps the bottom 16% of the
    # figure out of the axes area.
    handles, labels = ax.get_legend_handles_labels()
    fig.legend(handles, labels, ncol=3, loc="lower right")

    fig.tight_layout(pad=0.1, w_pad=0.8, h_pad=1.0, rect=(0.0, 0.16, 1.0, 1.0))
    fig.savefig(osp.join(env_info.figure_dir, "sensitivity_return_vs_search_time.png"))

    del fig_kwargs
    del subplot_kwargs

# Analysis

Here we take a deeper dive into the characteristics of POTMMCP. Specifically looking at:

1. Belief accuracy
2. Planning time

## Looking at Belief accuracy by steps

- action_dist_distance
- bayes_accuracy

In [None]:
# Rows for the "baposgmcp" algorithm with truncated == True and PUCB action
# selection.
belief_df = baseline_exp_df[
    (baseline_exp_df["alg_id"] == "baposgmcp")
    & (baseline_exp_df["truncated"] == True)
    & (baseline_exp_df["action_selection"] == "pucb")
]

# Show which algorithm IDs survived the filter.
belief_alg_ids = sorted(belief_df["alg_id"].unique().tolist())
print("Alg IDs")
print("-------")
for n in belief_alg_ids:
    print(n)

# Grouping by algorithm and search time only, so the remaining columns
# (including the meta-policy) are aggregated over.
belief_group_keys = ["alg_id", "search_time_limit"]
belief_agg_dict = plot_utils.get_uniform_expected_agg_map(belief_df)

# Remove aggregation entries whose key is a grouping key or is not a column
# of the dataframe. Keys are collected first so the dict is not mutated
# while it is being iterated.
belief_df_cols = set(belief_df.columns)
keys_to_drop = [
    k
    for k in belief_agg_dict
    if k in belief_group_keys or k not in belief_df_cols
]
for k in keys_to_drop:
    del belief_agg_dict[k]

gb = belief_df.groupby(belief_group_keys)
gb_agg = gb.agg(**belief_agg_dict)
belief_gb_df = gb_agg.reset_index()
belief_gb_df = belief_gb_df.sort_values(by=["search_time_limit"])

print("Ungrouped size =", len(belief_df))
print("Grouped size =", len(belief_gb_df))

display_df_info(belief_gb_df, display_columns=False)

In [None]:
# Two side-by-side axes, each one column wide, sharing the "Step" x-label.
fig_kwargs = dict(figsize=(paper_utils.PAGE_COL_WIDTH * 2, 3.7))
subplot_kwargs = dict(xlabel="Step")

num_rows, num_cols = 1, 2
fig, axs = plt.subplots(
    nrows=num_rows,
    ncols=num_cols,
    squeeze=True,
    subplot_kw=subplot_kwargs,
    **fig_kwargs,
)

y_lims = [(0.05, 1.05), (-0.02, 1.7)]
y_labels = ["Mean Belief Probability", "Mean Wasserstein Distance"]

# Left axes: bayes_accuracy; right axes: action_dist_distance. Both are drawn
# with one curve set keyed on the search time limit.
belief_stat_prefixes = ("bayes_accuracy", "action_dist_distance")
for belief_ax, belief_stat_prefix, belief_y_label, belief_y_lim in zip(
    axs, belief_stat_prefixes, y_labels, y_lims
):
    paper_utils.plot_expected_belief_stat_by_step(
        belief_gb_df,
        belief_ax,
        z_key="search_time_limit",
        y_key_prefix=belief_stat_prefix,
        step_limit=env_info.exp_params.env_step_limit,
        # the other agent is 1 when the planning agent is 0, otherwise 0
        other_agent_id=1 if env_info.exp_params.planning_agent_id == 0 else 0,
        y_suffix="mean",
        y_err_suffix="CI",
    )
    belief_ax.set_ylabel(belief_y_label)
    belief_ax.set_ylim(belief_y_lim)

# The shared legend is built from the left axes' handles only.
handles, labels = axs[0].get_legend_handles_labels()
fig.legend(handles, labels, loc="lower center", ncol=5)

fig.tight_layout(rect=(0.0, 0.08, 1.0, 1.0))
fig.savefig(osp.join(env_info.figure_dir, "bayes_accuracy.png"))

del fig_kwargs, subplot_kwargs

## Looking at time

   - search_time
   - update_time
   - reinvigoration_time
   - policy_calls
   - inference_time
   - search_depth

In [None]:
# Rows of the three planners compared on timing statistics:
#   - baposgmcp / ucbmcp: best meta-policy, truncated == True
#   - ucbmcp-random: truncated == False
search_time_df = baseline_exp_df[
    (
        baseline_exp_df["alg_id"].isin(["baposgmcp", "ucbmcp"])
        & (baseline_exp_df["meta_pi"] == best_meta_pi)
        & (baseline_exp_df["truncated"] == True)
    )
    | (
        (baseline_exp_df["alg_id"] == "ucbmcp-random")
        & (baseline_exp_df["truncated"] == False)
    )
]

# Policy-id prefixes of those planners, in plotting order.
search_time_policy_prefixes_to_plot = [
    f"baposgmcp_meta{best_meta_pi}",
    "ucbmcp-random",
    f"ucbmcp_meta{best_meta_pi}",
]

display_df_info(search_time_df, display_columns=False)

In [None]:
# One single-column-wide figure is produced per statistic. The y-label entry
# of subplot_kwargs is overwritten on every iteration.
fig_kwargs = dict(figsize=(paper_utils.PAGE_COL_WIDTH, 3.8))
subplot_kwargs = dict(xlabel="Search time (s)")

# (column prefix, y-axis label) pairs; "<prefix>_mean" and "<prefix>_CI" are
# the columns plotted, and "<prefix>.png" is the output file name.
for y_key, y_label in (
    ("episode_steps", "Mean episode steps"),
    ("search_time", "Mean search time"),
    ("evaluation_time", "Mean leaf node evaluation time"),
    ("inference_time", "Mean inference time"),
    ("update_time", "Mean update time"),
    ("reinvigoration_time", "Mean belief reinvigoration time"),
    ("search_depth", "Mean search depth"),
):
    subplot_kwargs.update(ylabel=y_label)
    fig, ax = plt.subplots(
        nrows=1, ncols=1, squeeze=True, subplot_kw=subplot_kwargs, **fig_kwargs
    )

    # Labels are taken from the baseline comparison's label map.
    paper_utils.plot_performance(
        search_time_df,
        ax=ax,
        policy_key="policy_id",
        policy_prefixes=search_time_policy_prefixes_to_plot,
        constant_policy_prefixes=[],
        pi_label_map=baseline_pi_label_map,
        x_key="search_time_limit",
        y_key=f"{y_key}_mean",
        y_err_key=f"{y_key}_CI",
    )

    # Single-column legend; tight_layout's rect keeps the bottom 25% of the
    # figure out of the axes area.
    handles, labels = ax.get_legend_handles_labels()
    fig.legend(handles, labels, loc="lower right", ncol=1)
    fig.tight_layout(pad=0.1, w_pad=0.8, h_pad=1.0, rect=(0.0, 0.25, 1.0, 1.0))

    fig.savefig(osp.join(env_info.figure_dir, f"{y_key}.png"))

del fig_kwargs, subplot_kwargs