In [None]:
# Automatically reload edited project modules (e.g. motion_planning.plot.*) before executing code.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import wandb
from motion_planning.plot.plot import set_theme_paper

# Notebook-wide setup: plotting theme from the project, a W&B API client,
# and the figure-output / test-result input directories (relative to this notebook).
set_theme_paper()
api = wandb.Api()
fig_path, data_path = Path("../figures"), Path("../data/test_results")

In [None]:
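# Prints the cluster commands that re-run scripts/test.py for every run under each encoding-comparison tag.
# from motion_planning.plot.compare_design import configs_from_tag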

# for tag in [
#     "compare-encoding-omniscient",
#     "compare-encoding-local",
#     "compare-encoding-connected-mask",
# ]:
#     for id, config in configs_from_tag(tag).items():
#         print(
#             f"./cluster/run.sh scripts/test.py --checkpoint wandb://damowerko-academic/motion-planning/{id} --n_trials 100 --n_workers 20"
#         )

In [None]:
from motion_planning.plot.data import df_from_tag, aggregate_results

tag = "compare-encoding-connected-mask"
# Aggregate the "basic" and "scalability" test results for every run under `tag`,
# in that order.
df_basic, df_scalability = (
    aggregate_results(df_from_tag(tag, kind)) for kind in ("basic", "scalability")
)

In [None]:
from motion_planning.plot.plot import plot_scenarios_initialization

# Scenario-initialization figure; the bare `fig` on the last line renders it as the cell output.
fig = plot_scenarios_initialization()
fig

In [None]:
from motion_planning.plot.data import load_baseline, load_test
from motion_planning.plot.plot import plot_scenarios

# Scenario results: the "d8_sq" baseline next to the khpb9hkx transformer run.
# NOTE(review): khpb9hkx is labelled "TF Masked" here, but the comparison cell
# uses "TF Masked" for o5tb680f and "TF Delayed" for khpb9hkx — confirm intent.
scenarios_baseline = aggregate_results(load_baseline("d8_sq", "scenarios"))
scenarios_masked = aggregate_results(
    load_test("khpb9hkx", "scenarios").assign(policy="TF Masked")
)
df_scenarios = pd.concat([scenarios_baseline, scenarios_masked], ignore_index=True)
plot_scenarios(df_scenarios)

In [None]:
from motion_planning.plot.plot import plot_scenarios_terminal

# Rows at the latest recorded time for the "TF Masked" policy. Only a cell's
# last expression is rendered automatically, so `display` is required here;
# without it the query result was computed and discarded.
display(df_scenarios.query("time == time.max() and policy == 'TF Masked'"))

plot_scenarios_terminal(df_scenarios)

In [None]:
from motion_planning.plot.plot import plot_delay_over_time, plot_delay_terminal
from motion_planning.plot.data import load_baselines, load_test

# Delay sweep: local vs. masked transformer evaluated on the "delay" tests.
delay_models = {
    "7969mfvs": "Local Transformer",
    "khpb9hkx": "Masked Transformer",
}
# Every delay listed is positive, so this filter also excludes zero-delay rows.
# NOTE(review): exact float match on delay_s — confirm stored values equal these literals.
delays_shown = [0.2, 0.4, 0.6, 0.8, 1.0]
df_delay = pd.concat(
    [
        aggregate_results(load_test(run_id, "delay")).assign(policy=label)
        for run_id, label in delay_models.items()
    ],
    ignore_index=True,
).query("delay_s in @delays_shown")
plot_delay_over_time(df_delay)

In [None]:
# Terminal view of the delay sweep; df_delay is built in the previous cell.
plot_delay_terminal(df_delay)

In [None]:
from motion_planning.plot.plot import plot_comparison

# Main comparison: assignment baselines, three transformer variants on the
# "basic" tests, and the khpb9hkx run under a 0.1 delay.
baseline_policies = ["LSAP", "DLSAP-0", "DLSAP-4", "DLSAP-8"]
# `@name` references the list directly instead of splicing its repr into the query string.
df_baseline = aggregate_results(load_baselines()).query(
    "policy in @baseline_policies"
)
# NOTE(review): "TF Masked" is o5tb680f here, while the scenario cell labels
# khpb9hkx as "TF Masked" — confirm which run is intended.
models = {
    "8hlpz45j": "TF Clairvoyant",
    "xdbf9fux": "TF Local",
    "o5tb680f": "TF Masked",
}
df_models = pd.concat(
    [
        aggregate_results(load_test(model, "basic")).assign(policy=name)
        for model, name in models.items()
    ],
    ignore_index=True,  # consistent with the other concats; avoids duplicate index labels
)
df_decentralized = (
    aggregate_results(load_test("khpb9hkx", "delay"))
    .assign(policy="TF Delayed")
    .query("delay_s == 0.1")
)

df_compare = pd.concat([df_baseline, df_models, df_decentralized], ignore_index=True)

plot_comparison(df_compare)

In [None]:
from motion_planning.plot.plot import (
    plot_encoding_comparison,
    plot_encoding_scalability,
)

# Render both encoding figures from one cell, comparison first, then scalability.
for plot_encoding, frame in (
    (plot_encoding_comparison, df_basic),
    (plot_encoding_scalability, df_scalability),
):
    display(plot_encoding(frame))

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.sparse.csgraph import connected_components
import networkx as nx

# Toy example for the architecture figure: six 2-D node positions, a random
# attention matrix, and the masks that restrict it.
# Legacy RandomState is kept so the exported attention values stay reproducible.
rng = np.random.RandomState(42)
comm_range = 2.0  # nodes closer than this are linked in the communication graph
attention_range = 3.0  # nodes closer than this fall inside the attention window

x = np.array(
    [
        [0.5, 5.5],
        [1.5, 4],
        [2.5, 3],
        [4, 1],
        [4.7, 2],
        [5.5, 1.0],
    ]
)
n_nodes = len(x)
# Random attention weights in [30, 100), one per ordered pair of nodes.
attention_matrix = 70 * rng.rand(n_nodes, n_nodes) + 30
# Pairwise Euclidean distances, shape (n_nodes, n_nodes).
dist = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)

comm_mask = dist < comm_range
# Boolean mask. The scalar radius now has its own name, so this no longer
# overwrites the 3.0 threshold with an array.
attention_window = dist < attention_range
# Label each node with its connected component in the (symmetric) communication graph.
components = connected_components(comm_mask, directed=False)[1]
connected_mask = components[:, None] == components[None, :]
attention_masked = attention_matrix * attention_window * connected_mask

# Save the positions, masks and attention matrices as .dat files for PGFPlots to read.
path = fig_path / "journal" / "architecture" / "attention_data"
path.mkdir(parents=True, exist_ok=True)
for dat_name, dat_values, dat_fmt in [
    ("positions.dat", x, "%.2f %.2f"),
    ("comm_mask.dat", comm_mask, "%d"),
    ("attention_window.dat", attention_window, "%d"),
    ("connected_mask.dat", connected_mask, "%d"),
    # "%d" truncates the float attention weights to integers in the file.
    ("attention_full.dat", attention_matrix, "%d"),
    ("attention_masked.dat", attention_masked, "%d"),
]:
    np.savetxt(path / dat_name, dat_values, fmt=dat_fmt)

# Preview of the toy example: communication graph plus the four attention panels.
fig_attention, axes = plt.subplots(1, 5, figsize=(12, 2))

# comm_mask is True on its diagonal (zero self-distance), which yields self-loops;
# materialize them with list() before removing edges from the same graph.
comm_graph = nx.from_numpy_array(comm_mask)
comm_graph.remove_edges_from(list(nx.selfloop_edges(comm_graph)))
# networkx expects `pos` as a node -> (x, y) mapping.
nx.draw(comm_graph, ax=axes[0], with_labels=False, pos=dict(enumerate(x)), node_size=10)
axes[0].set_title("Communication graph")

panels = [
    (attention_matrix, "Blues", "Attention"),
    (attention_window, "Reds", "Attention window"),
    (connected_mask, "Greens", "Connected mask"),
    (attention_masked, "Blues", "Masked attention"),
]
for panel_ax, (image, cmap, title) in zip(axes[1:], panels):
    panel_ax.imshow(image, cmap=cmap)
    panel_ax.set_title(title)
plt.show()