In [1]:
import torch

RoPECache = torch.Tensor


"""
My interpretation:
- apply identically across sequence length
- Pair up subsequent dimensions - (0, 1), (2, 3), (4, 5), (6, 7)
- Theta - geometrically decreasing from 1: each entry is the previous one times the constant factor base^(-2/n_elem) (0.1 for base=10000, n_elem=8)
- seq_idx - sequence indices
- Multiply each value in theta by each value in seq_idx, outer product
- Each value in this is an angle
- Take cos and sin of each angle

"""

def build_rope_cache(
    seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
) -> RoPECache:
    """Precompute the cos/sin table used by Rotary Position Embedding (RoPE).

    Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
    transformers/rope/__init__.py. MIT License:
    https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.

    Args:
        seq_len: number of positions to precompute (rows of the table).
        n_elem: number of feature dimensions that get rotated; dimensions are
            paired up, so the table holds one frequency per even index in
            ``range(0, n_elem, 2)``.
        dtype: dtype of the model. float16 / bfloat16 / int8 produce a
            float16 table; other dtypes produce a float32 table.
        device: device on which the table is created.
        base: base of the geometric frequency progression.

    Returns:
        Tensor of shape ``(seq_len, ceil(n_elem / 2), 2)`` where ``[..., 0]``
        is the cosine and ``[..., 1]`` the sine of ``position * theta_i``.
    """
    # Build the angles in at least float32. Creating the position indices
    # directly in a low-precision dtype corrupts them before cos/sin is taken
    # (float16 cannot represent 2049, bfloat16 cannot represent 257, int8
    # overflows past 127). float32 / float64 inputs behave exactly as before.
    compute_dtype = dtype if dtype in (torch.float32, torch.float64) else torch.float32

    # $\Theta = {\theta_i = base^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=compute_dtype, device=device) / n_elem))

    # Create position indexes `[0, 1, ..., seq_len - 1]`
    seq_idx = torch.arange(seq_len, dtype=compute_dtype, device=device)

    # Calculate the product of position index and $\theta_i$; each entry is an angle.
    idx_theta = torch.outer(seq_idx, theta).float()

    # Last axis: index 0 is cos(angle), index 1 is sin(angle).
    cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

    # this is to mimic the behaviour of complex32, else we will get different results
    if dtype in (torch.float16, torch.bfloat16, torch.int8):
        cache = cache.half()
    return cache

In [62]:
# Toy configuration used to walk through build_rope_cache step by step.
seq_len = 100  # number of positions in the cache
n_elem = 8  # rotated dimensions per head (paired up -> n_elem / 2 frequencies)
dtype = torch.float32
# Use Apple's MPS backend when present; otherwise fall back to CPU so the
# notebook still runs on machines without it.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
base = 10000

In [63]:
# $\Theta = {\theta_i = base^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
# Use `base` from the configuration cell rather than repeating the literal 10000.
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
theta

tensor([1.0000, 0.1000, 0.0100, 0.0010], device='mps:0')

In [64]:

# Position indices 0 .. seq_len - 1, one per token in the sequence.
seq_idx = torch.arange(start=0, end=seq_len, dtype=dtype, device=device)

In [65]:

# Angle table: row p, column i holds position p times frequency theta_i
# (the outer product of the two vectors, written as a broadcasted multiply).
idx_theta = (seq_idx[:, None] * theta[None, :]).float()

In [66]:
# One index per position, so this is a 1-D tensor of length seq_len.
seq_idx.shape

torch.Size([100])

In [67]:
# One frequency per pair of dimensions (even indices of range(0, n_elem, 2)).
theta.shape

torch.Size([4])

In [41]:
# Outer product of positions and frequencies: (len(seq_idx), len(theta)).
idx_theta.shape

torch.Size([100, 4])

In [42]:

# Pair every angle with its cosine and sine on a new trailing axis:
# cache[..., 0] is cos(angle), cache[..., 1] is sin(angle).
cos_table = idx_theta.cos()
sin_table = idx_theta.sin()
cache = torch.stack((cos_table, sin_table), dim=-1)

In [44]:
# Stacking cos and sin adds a trailing axis of size 2 to idx_theta's shape.
cache.shape

torch.Size([100, 4, 2])

In [69]:
# Bind the table to the name used by the application cells below.
# This is the same tensor object as `cache`, not a copy.
rope_cache = cache

In [79]:
# Shape: (batch, seq_len, n_heads, head_size).
# head_size must equal the n_elem the cache was built with, otherwise the cache
# cannot be reshaped to line up with the dimension pairs of x. x is created on
# the same device (and dtype) as the cache so the two can be multiplied.
x = torch.randn(1, seq_len, 32, n_elem, dtype=dtype, device=device)

In [80]:
# Keep only the cache rows for the positions actually present in x.
# Slice from `cache` rather than from `rope_cache` itself so that re-running
# this cell (even after rope_cache has been reshaped further down) always
# yields the same result.
T = x.size(1)
rope_cache = cache[:T]

In [81]:

# cast because the reference does
# Split the last axis into consecutive pairs: (..., d) -> (..., d / 2, 2).
x_f32 = x.float()
xshaped = x_f32.reshape(tuple(x_f32.shape[:-1]) + (-1, 2))

In [85]:
# Cache shape before it is reshaped to broadcast against xshaped.
rope_cache.shape

torch.Size([100, 4, 2])

In [86]:
# The trailing axis of size 2 holds the two members of each dimension pair.
xshaped.shape

torch.Size([1, 100, 4, 2])

In [83]:
# Reshape the cache to (1, seq_len, 1, pairs, 2) so it broadcasts over the
# batch and head axes of xshaped, which is (batch, seq_len, n_heads, pairs, 2).
# This only works when the cache was built with n_elem equal to x's last
# dimension. Derived from `cache` so the cell can be re-run safely.
rope_cache = cache[: xshaped.size(1)].view(1, xshaped.size(1), 1, xshaped.size(3), 2)


RuntimeError: shape '[1, 100, 1, 2, 2]' is invalid for input of size 800

In [68]:

# Rotate every (a, b) pair by its position-dependent angle:
#   a' = a * cos - b * sin
#   b' = b * cos + a * sin
# where rope_cache[..., 0] is cos and rope_cache[..., 1] is sin.
x_out2 = torch.stack(
    [
        xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
        xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
    ],
    -1,
)

# Merge the (pairs, 2) axes back into a single feature axis.
x_out2 = x_out2.flatten(3)
# A notebook cell is not a function body, so `return` is a SyntaxError here;
# end the cell with the expression to display the result instead.
x_out2.type_as(x)