In [None]:
import pandas as pd
from pathlib import Path
import seaborn as sns
import matplotlib.pyplot as plt
import wandb

# Publication-oriented matplotlib settings: LaTeX-rendered text, thin lines and
# axes, dashed light grid, high-resolution saving, and TrueType (type 42) fonts
# embedded in PDFs.
_PAPER_RC = {
    "figure.figsize": (2.0, 3.5),
    "figure.dpi": 100,
    "savefig.dpi": 300,
    "text.usetex": True,
    "lines.linewidth": 0.7,
    "axes.linewidth": 0.7,
    "axes.grid": True,
    "grid.linestyle": "--",
    "grid.linewidth": 0.5,
    "pdf.fonttype": 42,
}

# Small "paper" context with tick-style axes, fonts scaled down to 80%.
sns.set_theme(context="paper", style="ticks", font_scale=0.8, rc=_PAPER_RC)

In [None]:
# Earlier comparison runs, superseded by the tagged runs queried below
# (kept for provenance):
#   mlp: 72l28hqh, absolute: xmyposrs, rotary: qxqdo1vd,
#   rotary window: cbhe2s17, rotary small window: ba1rzptc

# W&B entity/project holding the runs. It is used both for the query and for the
# checkpoint URIs printed below, so the two cannot drift apart.
WANDB_PROJECT = "damowerko-academic/motion-planning"

# Initialize a W&B API object
api = wandb.Api()

# Finished runs tagged "compare-encoding-4", newest first.
runs = api.runs(
    path=WANDB_PROJECT,
    filters={
        "tags": "compare-encoding-4",
        "state": "finished",
    },
    order="-created_at",
)

# One metadata record per run, keeping only runs with linear encoding frequencies.
# `.get` skips runs whose config lacks the key instead of raising KeyError.
run_metadata = [
    {
        "id": run.id,
        "name": f"{run.config['encoding_type']} {run.config['attention_window']}",
        "path": "/".join(run.path),
        "encoding_type": run.config["encoding_type"],
        "attention_window": run.config["attention_window"],
        "encoding_period": run.config["encoding_period"],
        "encoding_frequencies": run.config["encoding_frequencies"],
    }
    for run in runs
    if run.config.get("encoding_frequencies") == "linear"
]

# Directory that later cells read per-run test results from.
data_dir = Path("../data/test_results")

# Print one shell command per run that evaluates its checkpoint with the
# `delay` test. The trailing `&` backgrounds each job when the lines are pasted
# into a shell.
for meta in run_metadata:
    model_id = meta["id"]
    print(
        f"python ./scripts/test.py delay --checkpoint wandb://{WANDB_PROJECT}/{model_id} &"
    )

In [None]:
# Load every run's delay-test results and tag each row with that run's metadata
# (id, name, encoding settings, ...), so all runs live in one long-form frame.
# Expects one `delay.parquet` per run under `data_dir/<run id>/`.
if not run_metadata:
    raise ValueError("No runs matched the W&B query; nothing to load.")

dfs = [
    pd.read_parquet(data_dir / meta["id"] / "delay.parquet").assign(**meta)
    for meta in run_metadata
]
# ignore_index: the per-run frames may each start their index at 0, and a plain
# concat would then leave duplicate index labels in the combined frame.
df = pd.concat(dfs, ignore_index=True)

# Shift `step` by one (presumably the saved steps are 0-based — confirm) and use
# it directly as the time axis. This assumes one step per second — TODO confirm dt.
df["step"] += 1
df["time_s"] = df["step"]
df.head()

In [None]:
# Coverage over time with one line per delay setting. Facet rows split runs by
# encoding type and facet columns by attention window. Bands show the standard
# error of the mean at each time point.
coverage_vs_time = sns.relplot(
    kind="line",
    data=df,
    x="time_s",
    y="coverage",
    hue="delay_s",
    palette="viridis",
    errorbar="se",
    row="encoding_type",
    col="attention_window",
)
plt.show()

In [None]:
# Time points (in `time_s` units) at which coverage is compared across delays.
DELAY_PLOT_TIMES_S = [5.0, 10.0, 15.0, 20.0]

# Coverage as a function of delay, one colour per sampled time.
# `style="name"` separates runs that share an encoding type but differ in
# attention window. Without it, their curves would be pooled into one mean/SE
# band. `name` combines type and window and is never null.
# Assumes the loaded results contain an `n_agents` column; it is not part of the
# run metadata — TODO confirm.
sns.relplot(
    data=df[df["time_s"].isin(DELAY_PLOT_TIMES_S)],
    x="delay_s",
    y="coverage",
    hue="time_s",
    style="name",
    errorbar="se",
    kind="line",
    palette="viridis",
    col="encoding_type",
    row="n_agents",
)
plt.show()