In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Publication-style theme: seaborn's small "paper" context with thin lines and
# a dashed grid. Text is rendered through LaTeX (text.usetex needs a working
# LaTeX installation), and pdf.fonttype 42 selects Type 42 (TrueType) fonts
# for PDF output.
paper_rc = {
    "figure.figsize": (3.5, 2.0),
    "figure.dpi": 300,
    "savefig.dpi": 300,
    "text.usetex": True,
    "lines.linewidth": 0.7,
    "axes.linewidth": 0.7,
    "axes.grid": True,
    "grid.linestyle": "--",
    "grid.linewidth": 0.5,
    "pdf.fonttype": 42,
    "figure.autolayout": True,
}
sns.set_theme(context="paper", style="ticks", font_scale=0.8, rc=paper_rc)


# Output / input roots, resolved relative to the notebook's working directory.
fig_path = Path("..") / "figures"
data_path = Path("..") / "data"

# Result-directory name under data_path / "test_results" -> display label used
# in plots and tables. Keys like "1zgs74or" are presumably training-run IDs;
# the short keys are baselines. Insertion order is the order results are loaded.
models = dict(
    [
        ("1zgs74or", "GNN IL"),
        ("lo49pixb", "GNN RL"),
        ("jwtdsmlx", "Transformer IL (global)"),
        ("cbhe2s17", "Transformer IL (local)"),
        ("c", "Centralized (LSAP)"),
        ("capt", "Centralized (CAPT)"),
        ("d0", "Decentralized (0-Hop)"),
        ("d1", "Decentralized (1-Hop)"),
    ]
)

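# Load Data

Read each model's test results from `data_path/test_results/<key>/<key>.parquet` and tag the rows with the model's display label in a `policy` column. Models without a results directory are skipped.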

In [None]:
# Stack every model's test results into one frame. Each row is tagged with the
# model's display label in a `policy` column. Models with no results directory
# are skipped, but finding none at all is an error.
dfs = []
for model_key, label in models.items():
    result_dir = data_path / "test_results" / model_key
    if not result_dir.is_dir():
        continue
    results = pd.read_parquet(result_dir / f"{model_key}.parquet")
    dfs.append(results.assign(policy=label))

if not dfs:
    # pd.concat([]) would otherwise fail with an unhelpful
    # "No objects to concatenate".
    raise FileNotFoundError(
        f"No test results found under {data_path / 'test_results'} "
        f"for any of: {', '.join(models)}"
    )

# ignore_index: the per-model frames' indices (typically each starting at 0)
# would otherwise collide and leave duplicate row labels in the stacked frame.
df = pd.concat(dfs, ignore_index=True)
df.head()

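# Results

Coverage over time for each policy, followed by the mean number of collisions and near-collisions per trial.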

In [None]:
# Coverage vs. step for each policy over the first MAX_STEP steps. Seaborn
# aggregates the repeated rows at each step (presumably one per trial) into a
# mean line with an error band. Saved as both PNG and PDF.
# Pass errorbar="sd" to sns.lineplot to show ±1 standard deviation instead of
# seaborn's default band.
MAX_STEP = 100  # only the early part of each episode is plotted

fig, ax = plt.subplots(figsize=(7, 4))
sns.lineplot(
    data=df[df["step"] < MAX_STEP],
    x="step",
    y="coverage",
    hue="policy",
    ax=ax,
)
ax.set(xlabel="Step", ylabel="Coverage")

# savefig does not create missing parent directories.
fig_path.mkdir(parents=True, exist_ok=True)
for ext in ("png", "pdf"):
    fig.savefig(fig_path / f"coverage_comparison.{ext}")
plt.show()

In [None]:
# Mean collisions and near-collisions per trial, for each policy:
# total the counts within every (policy, trial) pair, then average those
# per-trial totals over each policy's trials.
# TODO: normalize by the number of agents. This is not done yet; no agent
# count is used here.
collisions_per_trial = df.groupby(["policy", "trial"])[
    ["collisions", "near_collisions"]
].sum()
collisions_per_trial.groupby(level="policy").mean()