# Analysis of the RL BR policy training

In this notebook we plot the learning curves for the RL-BR policies.

In [None]:
import sys

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from posggym_baselines.config import REPO_DIR

# Put the repo's "baseline_exps" directory at the front of the import path
# before importing `exp_utils` (presumably that is where the module lives).
sys.path.insert(0, str(REPO_DIR / "baseline_exps"))
import exp_utils

# Global seaborn styling for every figure in this notebook: default theme,
# "paper" context with fonts scaled by 1.5, and the "colorblind" palette.
sns.set_theme()
sns.set_context("paper", font_scale=1.5)
sns.set_palette("colorblind")

# When True, the plotting cell also writes each environment's figure to
# "rl_br_learning_curve.pdf" inside that environment's data directory.
SAVE_RESULTS = False

In [None]:
# Load the experiment data for every environment and list the keys it is
# stored under (these are used as `full_env_id` in the cells below).
ALL_ENV_DATA = exp_utils.load_all_env_data()
for env_key in ALL_ENV_DATA.keys():
    print(env_key)

## BR Policy Learning Curves

Here we plot the learning curves for the BR policies trained on `P0` and `P1` for each environment. 

We draw one figure per environment with one panel per population. Within each panel every training seed is plotted as a separate line, with the mean over seeds overlaid.

In [None]:
# Clean the raw learning-curve CSVs downloaded from wandb. Each file is
# rewritten in place, and files that were already cleaned are left untouched.
# The cleaning consists of:
# - removing the __MIN and __MAX columns
# - keeping a single per-run "_step" column and dropping the others
# - renaming columns ("global_step" -> "step", "_step" -> "update",
#   per-run return columns -> integer seed)
# - reformatting the data to use a seed column instead of separate columns
#   for each seed
for full_env_id, env_data in ALL_ENV_DATA.items():
    for pop_id, results_file in env_data.rl_br_training_results_files.items():
        raw_df = pd.read_csv(results_file)
        if "global_step" not in raw_df.columns:
            # The raw "global_step" column is renamed during cleaning, so its
            # absence means this file has already been processed.
            continue

        # Columns to discard: every __MIN/__MAX column, plus all per-run
        # "_step" columns except the first one encountered.
        bound_cols = [
            name for name in raw_df.columns if name.endswith(("__MIN", "__MAX"))
        ]
        run_step_cols = [
            name
            for name in raw_df.columns
            if name != "global_step" and name.endswith("_step")
        ]
        trimmed_df = raw_df.drop(columns=bound_cols + run_step_cols[1:])

        new_names = {}
        seed_cols = []
        for name in trimmed_df.columns:
            if name == "global_step":
                new_names[name] = "step"
                continue
            if name.endswith("_step"):
                new_names[name] = "update"
                continue

            # Format: BR-PPO_<env_id>[_agent_id]_<pop_id>_<seed>_<date>_<time> - policy_stats/BR/mean_episode_return
            # e.g. "BR-PPO_PursuitEvasion-v1_i1_P0_0_20240119_053338 - policy_stats/BR/mean_episode_return"
            assert name.endswith("/BR/mean_episode_return")
            parts = name.split("_")
            after_env = parts.index(env_data.env_id) + 1
            # An optional agent-id token (starting with "i") may sit between
            # the env id and the pop id, which shifts the seed by one token.
            seed_offset = 2 if parts[after_env].startswith("i") else 1
            seed = int(parts[after_env + seed_offset])
            new_names[name] = seed
            seed_cols.append(seed)

        # Wide -> long: one row per (step, update, seed).
        long_df = (
            trimmed_df.rename(columns=new_names)
            .melt(
                id_vars=["step", "update"],
                value_vars=seed_cols,
                var_name="seed",
                value_name="mean_episode_return",
            )
            .astype({"seed": int})
        )

        long_df.to_csv(str(results_file), index=False)

In [None]:
# Import data
#
# Read every cleaned results file, tag its rows with the population and the
# environment they belong to, and stack everything into one long DataFrame.
per_env_frames = []
for full_env_id, env_data in ALL_ENV_DATA.items():
    per_pop_frames = [
        pd.read_csv(results_file).assign(pop_id=pop_id)
        for pop_id, results_file in env_data.rl_br_training_results_files.items()
    ]
    per_env_frames.append(
        pd.concat(per_pop_frames).assign(full_env_id=full_env_id)
    )

training_results_df = pd.concat(per_env_frames)

# Sorted lists of the distinct values, used by the plotting cell.
full_env_ids = sorted(training_results_df["full_env_id"].unique().tolist())
print("Full env ids:", full_env_ids)

pop_ids = sorted(training_results_df["pop_id"].unique().tolist())
print("Pop ids:", pop_ids)

seeds = sorted(training_results_df["seed"].unique().tolist())
print("Seeds:", seeds)

In [None]:
# Plot data
#
# One figure per environment with one panel per population. Each panel shows
# every training seed as a thin grey line, with the across-seed aggregate
# (seaborn's default estimator) overlaid in the first palette colour.
seed_palette = {seed: "grey" for seed in seeds}

# Size the panel grid from the populations actually present in the data rather
# than assuming exactly two; with two populations this matches a 1x2 grid of
# size (6, 3).
n_pops = len(pop_ids)

for full_env_id, env_data in ALL_ENV_DATA.items():
    fig, axes = plt.subplots(
        nrows=1,
        ncols=n_pops,
        figsize=(3 * n_pops, 3),
        sharey=True,
        sharex=True,
        squeeze=False,  # always return a 2D array, even for a single panel
    )
    env_df = training_results_df[training_results_df["full_env_id"] == full_env_id]
    for ax, pop_id in zip(axes[0], pop_ids):
        pop_df = env_df[env_df["pop_id"] == pop_id]

        # Individual seeds: faint grey lines in the background.
        sns.lineplot(
            data=pop_df,
            x="step",
            y="mean_episode_return",
            hue="seed",
            palette=seed_palette,
            ax=ax,
            legend=False,
            alpha=0.5,
            linewidth=0.5,
        )

        # Aggregate over seeds: a single coloured line with no error band.
        sns.lineplot(
            data=pop_df,
            x="step",
            y="mean_episode_return",
            ax=ax,
            color=sns.color_palette()[0],
            legend=False,
            errorbar=None,
            linewidth=0.75,
        )

        # Set axis labels
        ax.set_ylabel("Mean Return")
        ax.set_xlabel("Training Step")
        ax.set_title(pop_id)

    fig.tight_layout()
    # The figure is saved before the suptitle is added, so the PDF does not
    # carry the environment title; the title only appears in the notebook.
    if SAVE_RESULTS:
        fig.savefig(str(env_data.env_data_dir / "rl_br_learning_curve.pdf"))
    fig.suptitle(full_env_id)
    fig.subplots_adjust(top=0.8)
