In [None]:
# Reload edited Python modules (e.g. motion_planning.architecture.transformer)
# before each cell executes, so changes to the encoding code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
from motion_planning.architecture.transformer import (
    linear_frequencies,
    geometric_frequencies,
)
import torch

# Shared settings for the two frequency plots below: the time span covered by
# each plot and how many frequencies the generators are asked to produce.
period, n_frequencies = 10, 16

In [None]:
# Plot cos(f * t) over one `period` for every frequency produced by the
# linear generator, one labelled curve per frequency index.
frequencies = linear_frequencies(period, n_frequencies)
# The time axis is identical for every curve, so build it once outside the loop.
t = torch.linspace(0, period, 1000)
plt.figure(figsize=(20, 5))
for i, frequency in enumerate(frequencies):
    plt.plot(t.numpy(), torch.cos(frequency * t).numpy(), label=i)
plt.legend()
plt.title(f"Linear Frequencies (Period {period}, {n_frequencies} Frequencies)")
plt.show()

In [None]:
# Plot cos(f * t) over one `period` for every frequency produced by the
# geometric generator, one labelled curve per frequency index.
frequencies = geometric_frequencies(period, n_frequencies)
# The time axis is identical for every curve, so build it once outside the loop.
t = torch.linspace(0, period, 1000)
plt.figure(figsize=(20, 5))
for i, frequency in enumerate(frequencies):
    plt.plot(t.numpy(), torch.cos(frequency * t).numpy(), label=i)
plt.legend()
plt.title(f"Geometric Frequencies (Period {period}, {n_frequencies} Frequencies)")
plt.show()

In [None]:
from motion_planning.architecture.transformer import (
    RotaryPositionalEncoding,
    AbsolutePositionalEncoding,
)
from itertools import product

# Settings shared by every encoding comparison in the cells below.
embed_dim, period, n_dimensions = 128, 10, 1
# Every (encoding class, frequency generator) combination is plotted.
encodings = [
    AbsolutePositionalEncoding,
    RotaryPositionalEncoding,
]
frequency_generators = [
    "linear",
    "geometric",
]


def plot_encoding(y):
    """Draw an encoding output as a heatmap in a new figure.

    ``y`` is squeezed (size-1 dimensions dropped), detached from the autograd
    graph and converted to NumPy before being shown with ``imshow``; rows are
    labelled "Position" and columns "Embedding Dimension", and a colorbar is
    attached. The new figure is left current so the caller can add a title
    and show it.

    NOTE(review): assumes ``y`` squeezes to a 2-D (positions, embed_dim)
    array, as ``imshow`` needs — confirm for encodings with extra dims.
    """
    image = y.squeeze().detach().numpy()
    plt.figure()
    plt.imshow(image)
    plt.ylabel("Position")
    plt.xlabel("Embedding Dimension")
    plt.colorbar()


# Plot every (encoding class, frequency generator) combination on the same
# inputs. Absolute encodings consume only positions; rotary encodings rotate
# an input tensor `x` according to the positions.
n_positions = 100
# Inputs do not depend on the loop variables, so build them once. `pos` is
# also read by later cells.
x = torch.ones(n_positions, embed_dim)
pos = torch.linspace(0, period, n_positions).unsqueeze(-1)

for encoding_cls, frequency_generator in product(encodings, frequency_generators):
    encoding = encoding_cls(embed_dim, period, n_dimensions, frequency_generator)
    if isinstance(encoding, AbsolutePositionalEncoding):
        y = encoding(pos)
    elif isinstance(encoding, RotaryPositionalEncoding):
        y = encoding(x, pos)
    else:
        # Without this, an unknown class would silently re-plot the previous
        # iteration's `y` (or raise NameError on the first iteration).
        raise TypeError(f"Unsupported encoding type: {type(encoding).__name__}")

    plot_encoding(y)
    plt.title(f"{type(encoding).__name__} | {frequency_generator}")
    plt.show()

In [None]:
# Absolute encoding: add the encoding to a constant input, then compare a few
# sampled rows against every row via pairwise-summed elementwise products.
n_positions = 100
# Defined here so the cell does not depend on `pos` leaking out of the
# previous cell's loop.
pos = torch.linspace(0, period, n_positions).unsqueeze(-1)
encoding = AbsolutePositionalEncoding(embed_dim, period, n_dimensions, "geometric")
x = torch.ones(n_positions, embed_dim)
y = x + encoding(pos)

for i in range(0, n_positions, 25):
    # Multiply row i elementwise with every row, then sum adjacent dimension
    # pairs (2k, 2k+1). Assuming y is (n_positions, embed_dim), A is
    # (n_positions, embed_dim // 2) — one column per dimension pair.
    A = (y[i, :] * y).reshape(-1, embed_dim // 2, 2).sum(-1)
    plt.figure()
    plt.imshow(A.detach().numpy())
    plt.colorbar()
    plt.title(f"{type(encoding).__name__} | Product of {i}th embedding with all embeddings")
    plt.ylabel("Position")
    plt.xlabel("Embedding Dimension Pair")
    plt.show()

In [None]:
# Rotary encoding: rotate a constant input by position, then compare a few
# sampled rows against every row via pairwise-summed elementwise products.
n_positions = 100
# Defined here so the cell does not depend on `pos` leaking out of an
# earlier cell's loop.
pos = torch.linspace(0, period, n_positions).unsqueeze(-1)
encoding = RotaryPositionalEncoding(embed_dim, period, n_dimensions, "geometric")
x = torch.ones(n_positions, embed_dim)
y = encoding(x, pos)

for i in range(0, n_positions, 25):
    # Multiply row i elementwise with every row, then sum adjacent dimension
    # pairs (2k, 2k+1). Assuming y is (n_positions, embed_dim), A is
    # (n_positions, embed_dim // 2) — one column per dimension pair.
    A = (y[i, :] * y).reshape(-1, embed_dim // 2, 2).sum(-1)
    plt.figure()
    plt.imshow(A.detach().numpy())
    plt.colorbar()
    plt.title(f"{type(encoding).__name__} | Product of {i}th embedding with all embeddings")
    plt.ylabel("Position")
    plt.xlabel("Embedding Dimension Pair")
    plt.show()