In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


# Publication-style plotting defaults: small figures (3.5 x 2.5 in), 300 dpi on screen
# and on export, LaTeX text rendering, thin lines, and a dashed background grid.
# NOTE: `text.usetex` requires a working LaTeX installation on the machine running this.
sns.set_theme(
    context="paper",
    style="ticks",
    rc={
        "figure.figsize": (3.5, 2.5),
        "figure.dpi": 300,
        "savefig.dpi": 300,
        "figure.autolayout": True,
        "text.usetex": True,
        "lines.linewidth": 0.8,
        "axes.linewidth": 0.8,
        "axes.grid": True,
        "grid.linestyle": "--",
        "grid.linewidth": 0.5,
    },
)

# Output directory for exported figures. Create it up front so later `savefig`
# calls do not fail with FileNotFoundError on a fresh checkout.
fig_path = Path("../figures")
fig_path.mkdir(parents=True, exist_ok=True)

# Comparison
Compare the policies' performance in terms of coverage and number of collisions.

In [None]:
# Test runs to compare. Each key is the results directory and file stem under
# ../data/test_results/; each value is the label shown in the plot legend.
models = {
    "c": "LSAP",
    "1zgs74or": "GNN (IL)",
    "d0": "1-Hop",
    "capt": "CAPT",
    "j0pmfvt9": "GNN (RL)",
    "d1": "2-Hop",
}
# Time per simulation step, used to convert `step` into `time` (units not stated here).
DT = 0.1

# Stack all policies into one long frame. Each row is tagged with its policy label
# and with time = step * DT. `ignore_index=True` gives the result a unique RangeIndex
# instead of repeating each file's index labels once per policy.
df = pd.concat(
    [
        pd.read_parquet(f"../data/test_results/{key}/{key}.parquet").assign(
            policy=policy, time=lambda x: x["step"] * DT
        )
        for key, policy in models.items()
    ],
    ignore_index=True,
)

In [None]:
# Coverage c(t) over time for each policy. The shaded band is +/- one standard
# deviation across the runs in `df`. The figure is exported to fig_path as a PDF.
fig, ax = plt.subplots(figsize=(3.5, 3))
g = sns.lineplot(
    data=df,
    x="time",
    y="coverage",
    hue="policy",
    errorbar="sd",
    palette="deep",
    linewidth=0.8,
    ax=ax,
)
sns.move_legend(g, "lower right", ncol=2, title="Policy")
ax.set_ylim(0, 1.02)  # a little headroom above 1 so full coverage is not drawn on the frame
ax.set_xlabel("$t$")
ax.set_ylabel("$c(t)$")
fig.savefig(fig_path / "coverage_comparison.pdf")
plt.show()

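# Scalability Experiments
Vary the number of agents and the agent density (number of agents divided by the area). Report the discounted coverage $\bar{c} = \sum_k \gamma^k c_k / \sum_k \gamma^k$, where $c_k$ is the coverage at simulation step $k$ and $\gamma = 0.99$.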

In [None]:
# Discount factor: coverage at step k is weighted by gamma**k, so later steps count less.
gamma = 0.99

df = pd.read_parquet("../data/scalability.parquet")
# Density = n_agents / area, formatted to one decimal as a string. Seaborn then
# treats it as a categorical hue/style, and values that round to the same label
# are grouped together.
df["density"] = (df["n_agents"] / df["area"]).apply(lambda d: f"{d:.1f}")
df["discount"] = gamma ** df["step"]

# One discounted-coverage value per (n_agents, density, trial):
# sum_k gamma^k c_k / sum_k gamma^k, i.e. a discount-weighted mean of coverage.
data: pd.DataFrame = (
    df.groupby(["n_agents", "density", "trial"])
    .apply(
        lambda run: np.average(run["coverage"], weights=run["discount"]),
        include_groups=False,
    )
    .rename("discounted_coverage")  # type: ignore
    .to_frame()
)

# Flatten the (n_agents, density, trial) index into ordinary columns for plotting.
plot_data = data.reset_index()
# Use the same "o" marker for every density level. An explicit per-level mapping avoids
# seaborn's warning about cycling a markers sequence shorter than the number of levels.
density_markers = {level: "o" for level in plot_data["density"].unique()}

sns.relplot(
    data=plot_data,
    x="n_agents",
    y="discounted_coverage",
    hue="density",
    style="density",
    markers=density_markers,
    kind="line",
    errorbar="sd",
    aspect=1.5,
)
plt.show()
df.groupby("n_agents")[["coverage", "collisions"]].ag