In [None]:
import pandas as pd
from pathlib import Path
import seaborn as sns
import wandb
from matplotlib import pyplot as plt

# Publication styling: seaborn's compact "paper" context with tick-style axes
# and fonts scaled by 0.8, plus the matplotlib rc overrides collected below.
PAPER_RC = {
    # Figure size (inches) and resolution, both on screen and when saved.
    "figure.figsize": (2.0, 3.5),
    "figure.dpi": 300,
    "savefig.dpi": 300,
    # Thin strokes for data lines and axis spines.
    "axes.linewidth": 0.7,
    "lines.linewidth": 0.7,
    # Light dashed background grid.
    "axes.grid": True,
    "grid.linewidth": 0.5,
    "grid.linestyle": "--",
    # Render all text through LaTeX (a LaTeX installation is required to draw
    # any figure) and embed TrueType (Type 42) fonts in PDF output.
    "text.usetex": True,
    "pdf.fonttype": 42,
}
sns.set_theme(rc=PAPER_RC, font_scale=0.8, style="ticks", context="paper")

In [None]:
# Input/output locations, relative to the notebook's working directory.
data_dir = Path("../data/test_results")
fig_dir = Path("../figures/delay")
fig_dir.mkdir(parents=True, exist_ok=True)

# W&B "entity/project" that owns the evaluated runs.
WANDB_PROJECT = "damowerko-academic/motion-planning"

# old models: human-readable label -> W&B run id
models = {
    "mlp local": "cearrudq",
    "mlp masked": "phvd8c75",
    "rotary local": "7969mfvs",
    "rotary masked": "khpb9hkx",
}

# Initialize a W&B API object; fetching the runs below needs network access.
api = wandb.Api()

# `run_id` rather than `id` so the builtin is not shadowed.
runs = {name: api.run(f"{WANDB_PROJECT}/{run_id}") for name, run_id in models.items()}

# One metadata record per run, keeping only runs whose config has
# encoding_frequencies == "linear". The dict labels above are used only for
# fetching; each record's "name" is rebuilt from the run's own config.
run_metadata = [
    {
        "id": run.id,
        "name": f"{run.config['encoding_type']} {run.config['connected_mask']}",
        "path": "/".join(run.path),
        "encoding_type": run.config["encoding_type"],
        "attention_window": run.config["attention_window"],
        "connected_mask": run.config["connected_mask"],
        "encoding_period": run.config["encoding_period"],
        "encoding_frequencies": run.config["encoding_frequencies"],
    }
    for run in runs.values()
    if run.config["encoding_frequencies"] == "linear"
]
# Fail here with a clear message instead of letting the loading cell die in
# pd.concat with "No objects to concatenate".
assert run_metadata, "no runs with encoding_frequencies == 'linear'; nothing to load"

In [None]:
# Load each run's delay results and tag every row with that run's metadata
# (one column per metadata key), so runs can be told apart after concatenation.
dfs = [
    pd.read_parquet(data_dir / meta["id"] / "delay.parquet").assign(**meta)
    for meta in run_metadata
]
# ignore_index=True: the per-run frames' own indices could collide, which
# would leave duplicate index labels in the combined frame.
df = (
    pd.concat(dfs, ignore_index=True)
    # Shift steps to start at 1 (presumably stored 0-based; confirm upstream).
    .assign(step=lambda d: d["step"] + 1)
    # NOTE(review): treats one simulation step as one second; confirm the
    # environment's time step before reading the x-axis as seconds.
    .assign(time_s=lambda d: d["step"])
)
df.head()

In [None]:
# Mean coverage and its standard error over all rows sharing the same time,
# delay, mask setting, cluster size and encoding type.
data = (
    df.groupby(
        ["time_s", "delay_s", "connected_mask", "samples_per_cluster", "encoding_type"]
    )["coverage"]
    .agg(["mean", "sem"])
    .reset_index()
)
data["mean_minus_se"] = data["mean"] - data["sem"]
data["mean_plus_se"] = data["mean"] + data["sem"]
# NOTE(review): the ±SE columns above are not drawn by the plots below. Each
# (time, delay, mask, cluster size) already has a single aggregated row, so
# relplot draws bare lines. Either shade these bounds or drop them.

# One figure per encoding type: rows = connected_mask, columns = cluster size,
# one line per delay. Iterate the aggregated frame that is actually plotted.
for encoding_type in data["encoding_type"].unique():
    print(encoding_type)
    # Boolean mask instead of interpolating the value into a query string,
    # which would break on values containing quotes.
    subset = data[data["encoding_type"] == encoding_type]
    g = sns.relplot(
        data=subset,
        x="time_s",
        y="mean",
        hue="delay_s",
        kind="line",
        palette="viridis",
        row="connected_mask",
        col="samples_per_cluster",
        legend="full",
        height=2,
        aspect=1.5,
    )
    g.set_axis_labels("Time (s)", "Coverage")
    g.set_titles(
        template="Connected Mask: {row_name} \n  Cluster Size: {col_name}",
    )
    g.legend.set_title("Delay (s)")
    # Save the grid's own figure rather than pyplot's implicit "current figure".
    # FacetGrid.savefig defaults to bbox_inches="tight", so the legend is not
    # cropped from the saved image.
    g.savefig(fig_dir / f"{encoding_type}.png")
    plt.show()
    # Release the figure so figures do not accumulate across loop iterations.
    plt.close(g.figure)