In [None]:
import pandas as pd
from pathlib import Path
import seaborn as sns
import matplotlib.pyplot as plt

# Matplotlib rc overrides applied on top of the seaborn "paper" / "ticks" theme.
rc_overrides = {
    # Default figure geometry and resolution (on screen vs. saved to disk).
    "figure.figsize": (2.0, 3.5),
    "figure.dpi": 100,
    "savefig.dpi": 300,
    # Typeset all text with LaTeX; this needs a working LaTeX installation.
    "text.usetex": True,
    # Thin strokes for plotted lines and axes spines.
    "lines.linewidth": 0.7,
    "axes.linewidth": 0.7,
    # Dashed background grid.
    "axes.grid": True,
    "grid.linestyle": "--",
    "grid.linewidth": 0.5,
    # Embed fonts as Type 42 (TrueType) in PDF output.
    "pdf.fonttype": 42,
}

# Global plotting theme for every figure produced by this notebook.
sns.set_theme(context="paper", style="ticks", font_scale=0.8, rc=rc_overrides)

In [None]:
# mlp: 72l28hqh
# absolute: xmyposrs
# rotary: qxqdo1vd
# rotary window: cbhe2s17
# rotary small window: ba1rzptc

# Root folder of the stored test results.
data_dir = Path("../data/test_results")

# Display name of each model variant -> the run ID used both in the
# wandb:// checkpoint URI below and as the results sub-folder name.
models = {
    "mlp": "72l28hqh",
    "absolute": "xmyposrs",
    "rotary": "qxqdo1vd",
    "rotary window": "cbhe2s17",
    "rotary small window": "ba1rzptc",
}

# Print one ready-to-paste shell command per model.
# NOTE(review): presumably these runs produce the delay.parquet files that
# are loaded later in the notebook -- confirm against scripts/test.py.
for model_id in models.values():
    print(
        f"python ./scripts/test.py delay --checkpoint wandb://damowerko-academic/motion-planning/{model_id}"
    )

In [None]:
# Load every model's delay results and stack them into one long-format frame.
dfs = []
for model_name, model_id in models.items():
    # `model_id` (not `id`) so the built-in id() is not shadowed for the rest
    # of the kernel session.
    model_df = pd.read_parquet(data_dir / model_id / "delay.parquet")
    # Tag each row with its origin so the plots can facet by model.
    model_df["model"] = model_name
    model_df["model_id"] = model_id
    dfs.append(model_df)

# NOTE(review): concat keeps each file's own index, so labels may repeat
# across models; pass ignore_index=True if a unique index is ever needed.
df = pd.concat(dfs)

# Shift the step counter by one and derive the elapsed time from it.
# NOTE(review): assumes steps are stored 0-based and last 0.1 s each -- confirm.
df["step"] += 1
df["time_s"] = df["step"] * 0.1
df.head()

In [None]:
# Delay in whole milliseconds, used as a discrete hue / filter key in the plots.
# Round before casting: astype(int) truncates toward zero, so a product that
# lands just below an integer (e.g. 299.99999999999994) would become 299 and
# be silently dropped by the isin([...]) filters in the plotting cells.
df["delay_ms"] = (df["delay_s"] * 1000).round().astype(int)

# Error-bar style shared by all line plots ("se" = standard error in seaborn).
errorbar = "se"

In [None]:
# Disabled variant of the coverage-over-time plot: short delays only
# (0, 20, 25, 50, 100 ms), with one facet row per model and one facet
# column per n_agents value. Kept for reference; not executed.
# sns.relplot(
#     data=df[df["delay_ms"].isin([0, 20, 25, 50, 100])],
#     x="time_s",
#     y="coverage",
#     hue="delay_ms",
#     kind="line",
#     errorbar=errorbar,
#     palette="viridis",
#     row="model",
#     col="n_agents",
# )
# plt.show()

In [None]:
# Coverage over time: one line per delay (every 100 ms from 0 to 1000 ms),
# one facet row per model.
plotted_delays_ms = list(range(0, 1001, 100))
delay_mask = df["delay_ms"].isin(plotted_delays_ms)

sns.relplot(
    data=df[delay_mask],
    kind="line",
    x="time_s",
    y="coverage",
    hue="delay_ms",
    row="model",
    palette="viridis",
    errorbar=errorbar,
    # col="n_agents",
)
plt.show()

In [None]:
# Coverage as a function of delay: one line per snapshot time,
# one facet row per model.
# NOTE(review): the filter is an exact float match against time_s
# (= step * 0.1), so it relies on the product reproducing these values
# bit-for-bit.
snapshot_times_s = [5.0, 10.0, 15.0, 20.0]
snapshot_mask = df["time_s"].isin(snapshot_times_s)

sns.relplot(
    data=df[snapshot_mask],
    kind="line",
    x="delay_ms",
    y="coverage",
    hue="time_s",
    row="model",
    palette="viridis",
    errorbar=errorbar,
    # col="n_agents",
)
plt.show()