# Analysis of Population Diversity

In this notebook we explore the diversity of the populations for each environment.

In [None]:
import sys
from typing import Optional, List

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from posggym_baselines.config import REPO_DIR

sys.path.insert(0, str(REPO_DIR / "baseline_exps"))
import exp_utils

# Plot styling. Order matters: sns.set_theme() (re)sets the seaborn context and
# palette, so the "paper" context (with enlarged fonts) and the colour-blind
# palette are applied after it.
sns.set_theme()
sns.set_context("paper", font_scale=1.5)
sns.set_palette("colorblind")

# When True, per-environment figures are written to disk by the plotting cell
# and the per-environment plot titles are omitted.
SAVE_RESULTS = False

In [None]:
ALL_ENV_DATA = exp_utils.load_all_env_data()
NUM_ENVS = len(ALL_ENV_DATA)

# Show which environments results were loaded for.
for loaded_env_id in ALL_ENV_DATA:
    print(loaded_env_id)


## Population Diversity

The diversity of a population of policies is calculated as the mean pairwise distance between the returns of all policies in the population, where the returns of a policy are its returns against each other policy. The Euclidean distance is used as the distance metric, since it captures the magnitude of the difference between the returns of each pair of policies.

In [None]:
def policy_pop_id(row, full_env_id):
    """Return the population id ("P0" or "P1") containing ``row["policy_name"]``.

    "P0" is checked before "P1", so a name listed in both is assigned "P0".
    Policies in neither population are labelled "P2" (not used in the exps).
    """
    pop_policy_names = ALL_ENV_DATA[full_env_id].pop_policy_names
    policy_name = row["policy_name"]
    for pop_id in ("P0", "P1"):
        if policy_name in pop_policy_names[pop_id]:
            return pop_id
    # not used in exps
    return "P2"

def co_team_name(row):
    """Return ``row["co_team_id"]`` with every "(" and ")" removed, e.g. "(A0)" -> "A0"."""
    strip_parens = {ord("("): None, ord(")"): None}
    return row["co_team_id"].translate(strip_parens)

def co_team_pop_id(row, full_env_id):
    """Return the population id ("P0" or "P1") containing ``row["co_team_name"]``.

    "P0" is checked before "P1"; co-teams in neither population are labelled
    "P2" (not used in the exps).
    """
    pop_co_team_names = ALL_ENV_DATA[full_env_id].pop_co_team_names
    return next(
        (
            pop_id
            for pop_id in ("P0", "P1")
            if row["co_team_name"] in pop_co_team_names[pop_id]
        ),
        "P2",  # not used in exps
    )


div_results_dfs = []
for full_env_id, env_data in ALL_ENV_DATA.items():
    print(full_env_id)
    env_div_df = pd.read_csv(env_data.pop_div_results_file)
    env_div_df["full_env_id"] = full_env_id
    env_div_df["agent_id"] = env_div_df["agent_id"].astype(str)
    # drop "Random-v0" policy, both as the evaluated policy and as a co-team.
    # `.copy()` makes the filtered frame independent, so the column assignments
    # below never trigger pandas' chained-assignment (SettingWithCopy) checks.
    env_div_df = env_div_df[
        (env_div_df["policy_id"] != "Random-v0")
        & (env_div_df["co_team_id"] != "(Random)")
    ].copy()

    env_div_df["co_team_name"] = env_div_df.apply(co_team_name, axis=1)
    env_div_df["policy_pop_id"] = env_div_df.apply(
        policy_pop_id, axis=1, args=(full_env_id,)
    )
    # needs the "co_team_name" column assigned above
    env_div_df["co_team_pop_id"] = env_div_df.apply(
        co_team_pop_id, axis=1, args=(full_env_id,)
    )

    if full_env_id == "Driving-v1":
        # Strip the "Shortestpath" tag from Driving names. This happens after the
        # population ids above were assigned from the full names.
        for name_col in ("policy_name", "co_team_name"):
            env_div_df[name_col] = env_div_df[name_col].str.replace(
                "Shortestpath", "", regex=False
            )

    # drop policies / co-teams outside the two experiment populations
    env_div_df = env_div_df[
        (env_div_df["policy_pop_id"] != "P2")
        & (env_div_df["co_team_pop_id"] != "P2")
    ]

    # average over any duplicate rows
    env_div_df = (
        env_div_df.groupby([
            "env_id",
            "full_env_id",
            "policy_name",
            "co_team_name",
            "policy_pop_id",
            "co_team_pop_id",
            "agent_id",
        ])
        .agg({"episode_reward_mean": "mean"})
        .reset_index()
    )

    # Min-max normalise returns within this environment (across all agents).
    # If every return is identical, use a unit range so values normalise to 0
    # instead of 0/0 = NaN.
    env_min_return = env_div_df["episode_reward_mean"].min()
    env_max_return = env_div_df["episode_reward_mean"].max()
    env_return_range = env_max_return - env_min_return
    if env_return_range == 0:
        env_return_range = 1.0
    env_div_df["normalized_episode_reward_mean"] = (
        env_div_df["episode_reward_mean"] - env_min_return
    ) / env_return_range

    div_results_dfs.append(env_div_df)

# Each per-env frame has its own 0..n-1 index; renumber to avoid duplicate labels.
div_results_df = pd.concat(div_results_dfs, ignore_index=True)


In [None]:
def sort_policy_name_fn(name):
    """Sort key for policy / co-team names.

    Only the part of ``name`` before the first "_" is used. Names with a known
    prefix followed by an integer sort numerically within that prefix (so "RL2"
    comes before "RL10"):

    - "H<n>" / "A<n>" -> ("H" | "A", n)
    - "RL<n>"         -> ("RL", n)
    - "KLRBR..."      -> ("KLRBR", 1000)
    - "KLR<n>"        -> ("KLR", n)

    Any other name, including a known prefix whose suffix is not an integer
    (e.g. "Hello" or a bare "A"), sorts by ``(name, 0)`` instead of raising.
    """
    name = name.split("_")[0]

    def _prefixed_key(prefix):
        # Fall back to plain-name ordering when the suffix isn't an integer.
        try:
            return (prefix, int(name[len(prefix):]))
        except ValueError:
            return (name, 0)

    if name.startswith(("H", "A")):
        return _prefixed_key(name[0])
    if name.startswith("RL"):
        return _prefixed_key("RL")
    # "KLRBR" must be checked before its prefix "KLR".
    if name.startswith("KLRBR"):
        return ("KLRBR", 1000)
    if name.startswith("KLR"):
        return _prefixed_key("KLR")
    return (name, 0)


def sort_index_fn(x: pd.Index, sort_fn=sort_policy_name_fn):
    """Key for pandas ``sort_index``: map each index value to its rank under ``sort_fn``.

    The rank is the position of the value's first occurrence in the list of
    values sorted by ``sort_fn``. Equal values, such as the repeated labels of
    one MultiIndex level, therefore get the same rank.
    """
    ordered = sorted(x.to_list(), key=sort_fn)
    # One pass to record first occurrences, instead of an O(n) list.index()
    # search per value.
    first_pos = {}
    for pos, value in enumerate(ordered):
        first_pos.setdefault(value, pos)
    return pd.Index([first_pos[value] for value in x])


def get_env_pw_returns_df(
    full_env_id: str,
    policy_pop_id: Optional[str] = None,
    co_team_pop_id: Optional[str] = None,
    values: Optional[List[str]] = None,
):
    """Pivot one environment's results into a policy x co-team returns table.

    Uses only the rows of ``div_results_df`` for ``full_env_id`` whose
    ``agent_id`` equals ``ALL_ENV_DATA[full_env_id].agent_id``.

    Args:
        full_env_id: key into ``ALL_ENV_DATA``.
        policy_pop_id: if given, keep only policies from this population.
        co_team_pop_id: if given, keep only co-teams from this population.
        values: result columns to pivot. Defaults to
            ``["episode_reward_mean", "normalized_episode_reward_mean"]``.

    Returns:
        DataFrame indexed by ``policy_name`` with (value, co_team_name)
        columns. The policy rows and co-team columns are ordered using
        ``sort_index_fn`` as the sort key.
    """
    if values is None:
        # Build a fresh list per call rather than sharing a mutable default.
        values = ["episode_reward_mean", "normalized_episode_reward_mean"]

    env_df = div_results_df[
        (div_results_df["full_env_id"] == full_env_id)
        & (div_results_df["agent_id"] == ALL_ENV_DATA[full_env_id].agent_id)
    ]
    if policy_pop_id is not None:
        env_df = env_df[env_df["policy_pop_id"] == policy_pop_id]
    if co_team_pop_id is not None:
        env_df = env_df[env_df["co_team_pop_id"] == co_team_pop_id]

    pw_returns_df = env_df.pivot(
        index="policy_name",
        columns="co_team_name",
        values=values,
    )
    # Co-team names are level 1 of the pivoted columns; policy names are the rows.
    pw_returns_df = pw_returns_df.sort_index(
        axis="columns",
        level=1,
        key=sort_index_fn,
    )
    pw_returns_df = pw_returns_df.sort_index(
        axis="rows",  # type: ignore
        level=0,
        key=sort_index_fn,
    )  # type: ignore
    return pw_returns_df
    

In [None]:
plot_individual_envs = True

# Either one 6x6 figure per environment (created inside the loop), or a single
# figure with one row of axes per environment created up front.
if plot_individual_envs:
    fig = None
    axes = None
else:
    fig, axes = plt.subplots(
        nrows=NUM_ENVS,
        ncols=1,
        figsize=(6, NUM_ENVS * 4 + 2),
        squeeze=False,
    )

# NOTE(review): not used within this cell.
pop_div_results_df = []
for row, full_env_id in enumerate(ALL_ENV_DATA):
    print(full_env_id)
    pw_returns_df = get_env_pw_returns_df(full_env_id)

    # env_axes is a 1-D array of this environment's axes in both modes.
    if not plot_individual_envs:
        assert axes is not None
        env_axes = axes[row, :]
    else:
        fig, env_axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 6), squeeze=False)
        env_axes = env_axes[0]

    heatmap_ax = env_axes[0]
    sns.heatmap(
        data=pw_returns_df["episode_reward_mean"],
        ax=heatmap_ax,
        annot=True,
        cmap="YlGnBu",
        fmt=".2f",
        square=True,
        annot_kws={"fontsize": 6},
    )
    heatmap_ax.set(xlabel="Other Agent Policy", ylabel="Policy")
    # Titles are only added when results are not being saved.
    if not SAVE_RESULTS:
        heatmap_ax.set_title(f"Env: {full_env_id}")

    if plot_individual_envs:
        assert fig is not None
        fig.tight_layout()
        if SAVE_RESULTS:
            save_path = ALL_ENV_DATA[full_env_id].env_data_dir / "pairwise_returns.pdf"
            fig.savefig(str(save_path), bbox_inches="tight")

if not plot_individual_envs:
    assert fig is not None
    fig.tight_layout()