In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import wandb
from matplotlib import pyplot as plt
import seaborn as sns

# Paper-ready plotting defaults: seaborn "paper" context with "ticks" style,
# small single-column figures at 300 dpi, and all text rendered by LaTeX.
# NOTE(review): with text.usetex enabled, plain-text labels (e.g. raw DataFrame
# column names used as facet/legend titles) must not contain bare underscores.
paper_rc = {
    # Figure size (inches) and resolution, on screen and when saved.
    "figure.figsize": (3.5, 2.0),
    "figure.dpi": 300,
    "savefig.dpi": 300,
    # Render every text element through LaTeX.
    "text.usetex": True,
    # Thin strokes to suit the small figure size.
    "lines.linewidth": 0.7,
    "axes.linewidth": 0.7,
    # Light dashed background grid.
    "axes.grid": True,
    "grid.linestyle": "--",
    "grid.linewidth": 0.5,
    # Embed TrueType (Type 42) fonts in PDF output.
    "pdf.fonttype": 42,
}
sns.set_theme(context="paper", style="ticks", font_scale=0.8, rc=paper_rc)

In [None]:
# Query W&B for every finished run tagged "compare-encoding-2", newest first.
api = wandb.Api()
runs = api.runs(
    path="damowerko-academic/motion-planning",
    filters={
        "tags": "compare-encoding-2",
        "state": "finished",
    },
    order="-created_at",
)

# Collect, in run order: run ids, their configs (keyed by id) and "entity/project/id" paths.
ids, configs, paths = [], {}, []
for run in runs:
    ids.append(run.id)
    configs[run.id] = run.config
    paths.append("/".join(run.path))

# Generate commands to test agent generalization for every run's checkpoint.
for path in paths:
    checkpoint = f"wandb://{path}"
    print(f"python scripts/test.py test --checkpoint {checkpoint}")
    print(f"python scripts/scalability.py --checkpoint {checkpoint}")

In [None]:
def _load_per_run(pattern):
    """Read one parquet file per run in `ids` and stack them into one frame.

    `pattern` is a path template containing `{run_id}`; every row read from a
    run's file is tagged with that run's id in an `id` column.
    """
    frames = []
    for run_id in ids:
        frame = pd.read_parquet(pattern.format(run_id=run_id))
        frames.append(frame.assign(id=run_id))
    return pd.concat(frames)


df_test = _load_per_run("../data/test_results/{run_id}/{run_id}.parquet")
df_scalability = _load_per_run("../data/test_results/{run_id}/scalability.parquet")

In [None]:
# Human-readable label per run: "<encoding_type> | <encoding_period> | <encoding_frequencies>".
names = {
    run_id: " | ".join(
        f"{cfg[key]}"
        for key in ("encoding_type", "encoding_period", "encoding_frequencies")
    )
    for run_id, cfg in configs.items()
}

# Copy the run label and encoding hyper-parameters from each run's config onto its rows.
run_ids = df_scalability["id"]
config_columns = {
    key: run_ids.map(lambda run_id, key=key: configs[run_id][key])
    for key in ("encoding_type", "encoding_period", "encoding_frequencies")
}
df_scalability = df_scalability.assign(
    name=run_ids.map(names.__getitem__),
    **config_columns,
    # Runs whose config has no attention_window are recorded as 0.0.
    attention_window=run_ids.map(
        lambda run_id: configs[run_id].get("attention_window", 0.0)
    ),
)

In [None]:
# Coverage over time for runs with attention_window == 0.0, one figure per
# encoding period (rows: encoding_frequencies, columns: number of agents).
# `text.usetex` is enabled in the theme, and seaborn copies column names into
# facet and legend titles. Bare underscores such as "n_agents" are invalid in
# LaTeX text mode, so the plotted columns are renamed to underscore-free labels.
plot_labels = {
    "encoding_type": "encoding type",
    "encoding_frequencies": "frequencies",
    "n_agents": "agents",
}
for period in [10.0, 20.0]:
    subset = df_scalability.query(
        "n_agents in [100, 1000] and encoding_period == @period and attention_window == 0.0"
    ).rename(columns=plot_labels)
    sns.relplot(
        data=subset,
        x="step",
        y="coverage",
        hue="encoding type",
        col="agents",
        row="frequencies",
        kind="line",
    )
    print(f"Coverage with encoding_period={period}")
    plt.show()

## Ablation with Geometric Frequencies

In [None]:
# Hard-coded W&B run ids for the ablation, mapped to their legend labels.
name_dict = {
    "2n07h71c": "rotary + window",
    "gdgmofog": "rotary",
    "o5mwohl0": "absolute + window",
    "xmyposrs": "absolute",
    "o1cnaiy8": "mlp + window",
    "72l28hqh": "mlp",
}
# Keep only the ablation runs at 100 and 1000 agents, labelled via name_dict.
# Every remaining id is a key of name_dict, so the mapping never yields NaN.
data = (
    df_scalability[df_scalability["id"].isin(name_dict.keys())]
    .query("n_agents in [100, 1000]")
    .assign(name=lambda frame: frame["id"].map(name_dict))
)

# Rename "n_agents" only for plotting. With `text.usetex` enabled, seaborn writes
# the column name into the facet titles, and a bare underscore is invalid LaTeX text.
sns.relplot(
    data=data.rename(columns={"n_agents": "agents"}),
    x="step",
    y="coverage",
    hue="name",
    col="agents",
    kind="line",
)
plt.show()

# Explaining Differences Between Positional Encodings

## Linear vs Geometric Frequency Set

In [None]:
import matplotlib.pyplot as plt
from motion_planning.architecture.transformer import (
    linear_frequencies,
    geometric_frequencies,
)
import torch

# Encoding period and number of frequencies used by the frequency plots below.
period, n_frequencies = 10, 16

In [None]:
# Plot cos(frequency * t) over one period for each linearly spaced frequency.
frequencies = linear_frequencies(period, n_frequencies)
# The sample grid is the same for every frequency, so build it once.
t = torch.linspace(0, period, 1000)
fig, ax = plt.subplots(figsize=(20, 5))
for idx, frequency in enumerate(frequencies):
    ax.plot(t.numpy(), torch.cos(frequency * t).numpy(), label=idx)
ax.set_xlabel(r"$t$")
ax.set_ylabel(r"$\cos(\omega t)$")
ax.legend()
ax.set_title(f"Linear Frequencies (Period {period}, {n_frequencies} Frequencies)")
plt.show()

In [None]:
# Plot cos(frequency * t) over one period for each geometrically spaced frequency.
frequencies = geometric_frequencies(period, n_frequencies)
# The sample grid is the same for every frequency, so build it once.
t = torch.linspace(0, period, 1000)
fig, ax = plt.subplots(figsize=(20, 5))
for idx, frequency in enumerate(frequencies):
    ax.plot(t.numpy(), torch.cos(frequency * t).numpy(), label=idx)
ax.set_xlabel(r"$t$")
ax.set_ylabel(r"$\cos(\omega t)$")
ax.legend()
ax.set_title(f"Geometric Frequencies (Period {period}, {n_frequencies} Frequencies)")
plt.show()

# Visualize Absolute vs Rotary Encodings

In [None]:
from motion_planning.architecture.transformer import (
    RotaryPositionalEncoding,
    AbsolutePositionalEncoding,
)
from itertools import product

embed_dim = 64
period = 10
length = 64
n_dimensions = 1
encodings = [AbsolutePositionalEncoding, RotaryPositionalEncoding]
frequency_generators = ["linear", "geometric"]

# Inputs shared by every panel: an all-ones signal and `length` positions spread
# evenly over [0, period]. Both are also reused by the key/query cell below.
x = torch.ones(length, embed_dim)
pos = torch.linspace(0, period, length).unsqueeze(-1)

# One panel per (encoding class, frequency generator); rows are encoding classes.
# squeeze=False keeps `ax` 2-D even if either list has a single entry.
fig, ax = plt.subplots(
    len(encodings), len(frequency_generators), figsize=(5, 5), squeeze=False
)
for i, j in product(range(len(encodings)), range(len(frequency_generators))):
    encoding_cls = encodings[i]
    frequency_generator = frequency_generators[j]

    encoding = encoding_cls(embed_dim, period, n_dimensions, frequency_generator)
    # Absolute encodings are called on the positions alone; rotary encodings
    # are applied to the input signal `x`.
    if isinstance(encoding, AbsolutePositionalEncoding):
        y = encoding(pos)
    elif isinstance(encoding, RotaryPositionalEncoding):
        y = encoding(x, pos)
    else:
        raise TypeError(f"Unsupported encoding class: {type(encoding).__name__}")

    ax[i, j].imshow(y.squeeze().detach().numpy())
    ax[i, j].set_title(f"{type(encoding).__name__} : {frequency_generator}")
    # Label only the outer edges of the grid.
    if i == len(encodings) - 1:
        ax[i, j].set_xlabel("Embedding Dimension")
    if j == 0:
        ax[i, j].set_ylabel("Position")
plt.show()

## Visualizing Products of Keys and Queries
By visualizing the products of keys and queries, we can see that with rotary encodings the query-key product depends only on the relative offset between the two positions (so the pattern shifts along with the key position), while with absolute encodings it does not.
In the example below we have a 1D input sequence of length 64, with positions evenly distributed in $[0, 10]$ and positional encodings with period $10$.
Let $q_j$ and $k_i$ denote the query at the $j$-th index and the key at the $i$-th index, respectively; each panel fixes the key index and varies the query index $j$. The input signal is simply all ones.

In [None]:
n_plots = 4
fig, ax = plt.subplots(2, n_plots, figsize=(10, 10))
x = torch.ones(length, embed_dim)

# Encode the same all-ones signal with both schemes at the positions `pos` from
# the previous cell. The absolute encoding is added to the input; the rotary
# encoding is applied to it.
encoding = AbsolutePositionalEncoding(embed_dim, period, n_dimensions, "linear")
y_absolute = x + encoding(pos)
encoding = RotaryPositionalEncoding(embed_dim, period, n_dimensions, "linear")
y_rotary = encoding(x, pos)

# `n_plots` evenly spaced key positions. These index rows (sequence positions),
# so they are drawn from `length`, not from `embed_dim`.
key_indices = [k * length // n_plots for k in range(n_plots)]

for i, idx in enumerate(key_indices):
    # Raw f-strings keep the LaTeX backslash without an invalid "\o" escape.
    ax[0, i].set_title(rf"Absolute: $q_j \odot k_{{{idx}}}$")
    ax[1, i].set_title(rf"Rotary: $q_j \odot k_{{{idx}}}$")
    if i == 0:
        ax[0, i].set_ylabel("j")
        ax[1, i].set_ylabel("j")
    ax[1, i].set_xlabel("Embedding Dimension Pair")

    # Elementwise product of key `idx` with every query, summed over each
    # adjacent pair of embedding dimensions: one row per query, one column per
    # pair (assumes y_* is (length, embed_dim) - TODO confirm).
    A = (y_absolute[idx, :] * y_absolute).reshape(-1, embed_dim // 2, 2).sum(-1)
    ax[0, i].imshow(A.detach().numpy())

    A = (y_rotary[idx, :] * y_rotary).reshape(-1, embed_dim // 2, 2).sum(-1)
    ax[1, i].imshow(A.detach().numpy())
plt.show()