In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to project code take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import wandb
from matplotlib import pyplot as plt
import seaborn as sns
from pathlib import Path

# Plotting defaults applied to every figure created in this notebook.
sns.set_theme(
    style="ticks",
    context="paper",
    font_scale=0.8,
    rc={
        # Figure geometry and resolution.
        "figure.figsize": (3.5, 2.0),
        "figure.dpi": 300,
        "savefig.dpi": 300,
        # Thin strokes for plotted lines and axes.
        "lines.linewidth": 0.7,
        "axes.linewidth": 0.7,
        # Dashed background grid.
        "axes.grid": True,
        "grid.linestyle": "--",
        "grid.linewidth": 0.5,
        # Text is rendered through LaTeX, so a LaTeX installation is required.
        "text.usetex": True,
        # Font type 42 (TrueType) for fonts embedded in PDF output.
        "pdf.fonttype": 42,
    },
)

In [None]:
# W&B tag that selects the runs to compare; also used to name output files.
tag = "compare-encoding-connected-mask"

# Output directory for the figures, created if it does not exist yet.
save_path = Path("../figures") / "compare-encoding"
save_path.mkdir(exist_ok=True, parents=True)

In [None]:
# Fetch every finished W&B run carrying `tag`, ordered by "-created_at".
api = wandb.Api()
runs = api.runs(
    path="damowerko-academic/motion-planning",
    filters=dict(tags=tag, state="finished"),
    order="-created_at",
)

# Per-run metadata, all in the order returned by the query.
ids = [r.id for r in runs]
configs = dict(zip(ids, (r.config for r in runs)))
paths = list(map("/".join, (r.path for r in runs)))

In [None]:
import shutil
import tempfile

# Bundle the test results of the selected runs into ../data/<tag>.zip.
# Result directories are first copied into a scratch directory so that the
# archive contains one top-level folder per run that has results on disk.
with tempfile.TemporaryDirectory() as temp_dir:
    temp_path = Path(temp_dir)
    # NOTE: this rebinds `paths` (previously the W&B run paths) to the local
    # result directories that were added to the archive.
    paths = []
    # `run_id` rather than `id`: a top-level loop variable stays bound after
    # the loop and would otherwise shadow the builtin `id` for later cells.
    for run_id in ids:
        src = Path(f"../data/test_results/{run_id}/")
        if not src.exists():
            # Runs without local results are silently left out of the archive.
            continue
        shutil.copytree(src, temp_path / run_id)
        paths.append(src)
    # Zip the scratch directory; make_archive appends the ".zip" suffix.
    shutil.make_archive(f"../data/{tag}", "zip", temp_path)

In [None]:
def _run_name(config):
    """Return the display name "<encoding type> : <attention window> : <frequencies>".

    MLP runs get the literal "NA " in the frequencies slot. A missing
    `attention_window` falls back to 0.0, matching the default used for the
    `attention_window` column below.
    """
    encoding_type = config["encoding_type"]
    attention_window = config.get("attention_window", 0.0)
    if encoding_type == "mlp":
        return f"{encoding_type} : {attention_window} : NA "
    return f"{encoding_type} : {attention_window} : {config['encoding_frequencies']}"


# Only runs whose results directory exists locally can be loaded.
existing_ids = [
    run_id for run_id in ids if Path(f"../data/test_results/{run_id}/").exists()
]
if len(existing_ids) < len(ids):
    print("WARNING: did not find test results for all runs.")

# Load the per-run parquet files and tag every row with the run id.
df_test = pd.concat(
    [
        pd.read_parquet(f"../data/test_results/{run_id}/{run_id}.parquet").assign(
            id=run_id
        )
        for run_id in existing_ids
    ]
)
df_scalability = pd.concat(
    [
        pd.read_parquet(f"../data/test_results/{run_id}/scalability.parquet").assign(
            id=run_id
        )
        for run_id in existing_ids
    ]
)

# Give each run a name. A comprehension is used so that no loop variable
# (previously `id`, shadowing the builtin) leaks into the notebook namespace.
names = {run_id: _run_name(config) for run_id, config in configs.items()}

# Attach the run name and the relevant hyperparameters to every row.
df_scalability = df_scalability.assign(
    name=df_scalability["id"].map(lambda x: names[x]),
    encoding_type=df_scalability["id"].map(lambda x: configs[x]["encoding_type"]),
    encoding_period=df_scalability["id"].map(lambda x: configs[x]["encoding_period"]),
    encoding_frequencies=df_scalability["id"].map(
        lambda x: configs[x]["encoding_frequencies"]
    ),
    attention_window=df_scalability["id"].map(
        lambda x: configs[x].get("attention_window", 0.0)
    ),
)

# Append a copy of the MLP rows labelled with encoding_frequencies="geometric".
# NOTE(review): presumably this makes the MLP baseline show up under both
# frequency sets in the plots -- confirm that MLP configs are stored as "linear".
df_scalability = pd.concat(
    [
        df_scalability,
        df_scalability[df_scalability["encoding_type"] == "mlp"].assign(
            encoding_frequencies="geometric"
        ),
    ]
)

In [None]:
# Spot check: display the rows for trial 0, step 1 with 100 agents.
df_scalability.query("trial == 0 and step == 1 and n_agents == 100")

In [None]:
# Precompute the mean and the standard error of the mean of the coverage for
# every (time, team size, encoding configuration) group.
df_summary = (
    df_scalability.groupby(
        [
            "time",
            "n_agents",
            "encoding_type",
            "encoding_frequencies",
            "attention_window",
            "name",
        ]
    )
    .agg(
        coverage_mean=("coverage", "mean"),
        # Previously only the mean was computed although the comment promised
        # the standard error too. Groups with a single sample yield NaN here.
        coverage_se=("coverage", "sem"),
    )
    .reset_index()
)

In [None]:
# Mean coverage over time: one row of panels per team size and one column per
# attention window; colour encodes the encoding type, line style the
# frequency set.
g = sns.relplot(
    data=df_summary.query("attention_window in [0, 500, 1000]"),
    x="time",
    y="coverage_mean",
    hue="encoding_type",
    style="encoding_frequencies",
    row="n_agents",
    col="attention_window",
    kind="line",
    height=2,
    aspect=1.0,
)
g.set_axis_labels("Time (s)", "Coverage")
g.set_titles(
    template="Agents: {row_name} $\\mid$  Window: {col_name}",
)

# Replace the raw column names that head each legend section with readable
# titles. Entries are matched by their text instead of by hard-coded position,
# so the relabelling stays correct when the number of encoding types changes.
legend_headers = {
    "encoding_type": "Encoding Type",
    "encoding_frequencies": "Frequencies",
}
for legend_text in g.legend.get_texts():
    header = legend_headers.get(legend_text.get_text())
    if header is not None:
        legend_text.set_text(header)

plt.savefig(save_path / f"{tag}.png")
plt.show()

In [None]:
# Compact variant of the comparison: linear frequencies only, attention
# windows 0 and 500, one panel per team size.
# NOTE(review): df_summary already holds one mean per group, so
# errorbar="se" only draws a band if several rows share the same
# x / hue / style / col values -- confirm whether the raw df_scalability was
# intended here.
g = sns.relplot(
    data=df_summary.query(
        "attention_window in [0, 500] and encoding_frequencies == 'linear'"
    ),
    x="time",
    y="coverage_mean",
    hue="encoding_type",
    style="attention_window",
    col="n_agents",
    kind="line",
    errorbar="se",
    height=2,
    aspect=1,
)
# Same axis labels as the full comparison figure.
g.set_axis_labels("Time (s)", "Coverage")

# Replace the raw column names that head each legend section with readable
# titles. Entries are matched by their text instead of by hard-coded position,
# so the relabelling stays correct when the number of encoding types changes.
legend_headers = {
    "encoding_type": "Encoding Type",
    "attention_window": "Attention Window",
}
for legend_text in g.legend.get_texts():
    header = legend_headers.get(legend_text.get_text())
    if header is not None:
        legend_text.set_text(header)

plt.savefig(save_path / f"{tag}-small.png")
plt.show()

# Explaining Differences Between Positional Encodings

## Linear vs Geometric Frequency Set

In [None]:
import matplotlib.pyplot as plt
from motion_planning.architecture.transformer import (
    linear_frequencies,
    geometric_frequencies,
)
import torch

# Arguments passed to linear_frequencies / geometric_frequencies in the
# frequency plots that follow.
period, n_frequencies = 1000, 16

In [None]:
# Plot cos(frequency * t) for every frequency of each generator:
# linear frequencies in the top panel, geometric frequencies in the bottom one.
fig, ax = plt.subplots(2, 1, figsize=(20, 10))
t = torch.linspace(0, period, 1000)

panels = (
    (ax[0], linear_frequencies, "Linear"),
    (ax[1], geometric_frequencies, "Geometric"),
)
for axis, make_frequencies, kind in panels:
    frequencies = make_frequencies(period, n_frequencies)
    for idx, frequency in enumerate(frequencies):
        axis.plot(t.numpy(), torch.cos(frequency * t).numpy(), label=idx)
    axis.legend()
    axis.set_title(
        f"{kind} Frequencies (Period {period}, {n_frequencies} Frequencies)"
    )
plt.show()

In [None]:
def _stem(axis, values, label, color):
    """Draw `values` against their index on `axis` as a half-transparent stem plot."""
    _, stemlines, _ = axis.stem(
        range(len(values)),
        values.numpy(),
        label=label,
        basefmt=" ",
        linefmt=f"{color}-",
        markerfmt=f"{color}o",
    )
    plt.setp(stemlines, "alpha", 0.5)


# Stem plots of the frequencies (left) and of the matching periods (right),
# with linear values in blue and geometric values in red.
fig, ax = plt.subplots(1, 2, figsize=(20, 5))

frequencies = linear_frequencies(period, n_frequencies)
frequencies_geom = geometric_frequencies(period, n_frequencies)

# Each frequency is converted to its period as 2 * pi / frequency.
periods = 2 * torch.pi / frequencies
periods_geom = 2 * torch.pi / frequencies_geom

# Left panel: frequencies.
_stem(ax[0], frequencies, "Linear", "b")
_stem(ax[0], frequencies_geom, "Geometric", "r")
ax[0].set(xlabel="Index", ylabel="Frequency", title="Distribution of Frequencies")
ax[0].legend()

# Right panel: periods.
_stem(ax[1], periods, "Linear", "b")
_stem(ax[1], periods_geom, "Geometric", "r")
ax[1].set(xlabel="Index", ylabel="Period", title="Distribution of Periods")
ax[1].legend()

plt.show()

# Visualize Absolute vs Rotary Encodings

In [None]:
from motion_planning.architecture.transformer import (
    RotaryPositionalEncoding,
    AbsolutePositionalEncoding,
)
from itertools import product

# Settings for the encoding visualisations (also reused by the next section).
embed_dim = 64
period = 1000
length = 64
n_dimensions = 1
encodings = [AbsolutePositionalEncoding, RotaryPositionalEncoding]
frequency_generators = ["linear", "geometric"]

# Inputs shared by every panel, built once instead of on every iteration:
# an all-ones signal of shape (length, embed_dim) and `length` positions
# evenly spaced over [0, period] with shape (length, 1).
x = torch.ones(length, embed_dim)
pos = torch.linspace(0, period, length).unsqueeze(-1)

# One heatmap per (encoding class, frequency generator) combination:
# rows are encoding classes, columns are frequency generators.
fig, ax = plt.subplots(len(encodings), len(frequency_generators), figsize=(5, 5))
for (i, encoding_cls), (j, frequency_generator) in product(
    enumerate(encodings), enumerate(frequency_generators)
):
    encoding = encoding_cls(embed_dim, period, n_dimensions, frequency_generator)
    # The two classes are called differently: the absolute encoding takes only
    # the positions, the rotary encoding takes the signal and the positions.
    if isinstance(encoding, AbsolutePositionalEncoding):
        y = encoding(pos)
    elif isinstance(encoding, RotaryPositionalEncoding):
        y = encoding(x, pos)
    else:
        # Fail loudly instead of plotting a stale or undefined `y`.
        raise TypeError(f"Unsupported encoding class: {type(encoding).__name__}")

    ax[i, j].imshow(y.squeeze().detach().numpy())
    ax[i, j].set_title(f"{type(encoding).__name__} : {frequency_generator}")
    # Label only the outer axes: x on the bottom row, y on the first column.
    if i == len(encodings) - 1:
        ax[i, j].set_xlabel("Embedding Dimension")
    if j == 0:
        ax[i, j].set_ylabel("Position")

plt.show()

## Visualizing Products of Keys and Queries
By visualizing the products of keys and queries, we can see that rotary encodings are shift-equivariant, while absolute encodings are not.
In the example below we use a 1D input sequence of length $64$ whose positions are evenly distributed in $[0, 1000]$, and positional encodings with period $1000$.
Let $q_i$ and $k_j$ denote the query at the $i$-th index and the key at the $j$-th index, respectively. The input signal is simply a signal of all ones.

In [None]:
n_plots = 4
fig, ax = plt.subplots(2, n_plots, figsize=(10, 10))

# Build the inputs explicitly instead of relying on `pos` left over from the
# loop in the previous cell: an all-ones signal and `length` positions evenly
# spaced over [0, period].
x = torch.ones(length, embed_dim)
pos = torch.linspace(0, period, length).unsqueeze(-1)

# Absolute encodings are added to the signal; the rotary encoding is applied
# to the signal directly.
encoding = AbsolutePositionalEncoding(embed_dim, period, n_dimensions, "linear")
y_absolute = x + encoding(pos)
encoding = RotaryPositionalEncoding(embed_dim, period, n_dimensions, "linear")
y_rotary = encoding(x, pos)

# One column per reference index `idx`; top row absolute, bottom row rotary.
for i, idx in enumerate(range(0, embed_dim, embed_dim // n_plots)):
    # Raw f-strings keep the LaTeX "\odot" from being read as an (invalid)
    # Python escape sequence.
    ax[0, i].set_title(rf"Absolute: $q_j \odot k_{{{idx}}}$")
    ax[1, i].set_title(rf"Rotary: $q_j \odot k_{{{idx}}}$")
    if i == 0:
        ax[0, i].set_ylabel("j")
        ax[1, i].set_ylabel("j")
    ax[1, i].set_xlabel("Embedding Dimension")

    # Element-wise product of row `idx` with every row, then summed over
    # consecutive pairs of embedding dimensions.
    # NOTE(review): presumably each pair holds the two components of one
    # frequency -- confirm against the encoding implementation.
    A = (y_absolute[idx, :] * y_absolute).reshape(-1, embed_dim // 2, 2).sum(-1)
    ax[0, i].imshow(A.detach().numpy())

    A = (y_rotary[idx, :] * y_rotary).reshape(-1, embed_dim // 2, 2).sum(-1)
    ax[1, i].imshow(A.detach().numpy())

plt.savefig(save_path / "absolute-rotary.png")
plt.show()