In [13]:
import torch

In [14]:
B = 2  # batch size
T = 4  # sequence length (stands in for the model's block size)
embd_sz = 512  # embedding width
nh = 4  # number of attention heads
# The embedding is split evenly across heads; fail loudly if it cannot be.
assert embd_sz % nh == 0, "embd_sz must be divisible by nh"
hs = embd_sz // nh  # per-head size; floor division keeps this an exact int
hs

128

In [15]:
torch.manual_seed(0)  # fixed seed so re-running this cell reproduces the same q
q = torch.randn((B, nh, T, hs))  # (B, nh, T, hs)
q.shape

torch.Size([2, 4, 4, 128])

`q` is the input to the `rope` function, along with its position `m`.

In [4]:
# Merge the batch and head dims, then group the last dim into hs//2 adjacent
# feature pairs; each pair is treated as one 2D vector to rotate.
x = q.view(B*nh, T, hs//2, 2)
x.shape

torch.Size([8, 4, 64, 2])

In [5]:
# Per-pair base angle: theta_i = 10000^(-2i / hs) for i in [0, hs//2).
theta_i = 10_000 ** ((-2 * torch.arange(0, hs//2)) / hs) # one theta value for each 2d pair (hs//2 pairs)
theta_i.shape, theta_i

(torch.Size([64]),
 tensor([1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01,
         4.2170e-01, 3.6517e-01, 3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01,
         1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01, 1.0000e-01, 8.6596e-02,
         7.4989e-02, 6.4938e-02, 5.6234e-02, 4.8697e-02, 4.2170e-02, 3.6517e-02,
         3.1623e-02, 2.7384e-02, 2.3714e-02, 2.0535e-02, 1.7783e-02, 1.5399e-02,
         1.3335e-02, 1.1548e-02, 1.0000e-02, 8.6596e-03, 7.4989e-03, 6.4938e-03,
         5.6234e-03, 4.8697e-03, 4.2170e-03, 3.6517e-03, 3.1623e-03, 2.7384e-03,
         2.3714e-03, 2.0535e-03, 1.7783e-03, 1.5399e-03, 1.3335e-03, 1.1548e-03,
         1.0000e-03, 8.6596e-04, 7.4989e-04, 6.4938e-04, 5.6234e-04, 4.8697e-04,
         4.2170e-04, 3.6517e-04, 3.1623e-04, 2.7384e-04, 2.3714e-04, 2.0535e-04,
         1.7783e-04, 1.5399e-04, 1.3335e-04, 1.1548e-04]))

2D vectors can be rotated using the typical `2x2` matrix, but there's a more elegant way to do this in the complex plane.<br>
Given a 2D vector $[x_1, x_2]$, you can write it as ($x = x_1 + i \cdot x_2$) in complex number notation.

You can then apply the rotation by an angle $\theta_i$ as follows:

$x$ is the complex vector as denoted above. Let $x'$ be the rotated vector.

$x' = x \cdot e^{i\cdot\theta_i}$

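where $e^{i\cdot\theta_i} = \cos(\theta_i) + i \cdot \sin(\theta_i)$.

In RoPE, the pair at position $m$ is rotated by the angle $m \cdot \theta_i$, which is why the cache below stores the $\cos$ and $\sin$ of `pos * theta_i`.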

In [6]:
# Token positions 0 .. T-1 as float32.
pos = torch.arange(T, dtype=torch.float32)
pos.shape

torch.Size([4])

In [7]:
# Outer product: the rotation angle pos * theta_i for every (position, pair).
freqs = torch.outer(pos, theta_i)
freqs.shape

torch.Size([4, 64])

In [8]:
# Keep (cos, sin) of each angle in a trailing dim of size 2, the layout that
# torch.view_as_complex later reads as (real, imag).
rope_cache = torch.stack((torch.cos(freqs), torch.sin(freqs)), dim=-1)
rope_cache.shape

torch.Size([4, 64, 2])

In [11]:
def precompute_rope_naive(dim, seq_len=None):
    """Build the RoPE cos/sin cache with an explicit Python loop over positions.

    Args:
        dim: Head dimension; ``dim // 2`` feature pairs are rotated.
        seq_len: Number of positions to precompute. When omitted, the
            notebook-level ``T`` is used, so existing one-argument calls
            keep working.

    Returns:
        Tensor of shape ``(seq_len, dim // 2, 2)`` where ``[..., 0]`` holds
        ``cos(pos * theta_i)`` and ``[..., 1]`` holds ``sin(pos * theta_i)``.
    """
    if seq_len is None:
        seq_len = T  # fall back to the notebook's global sequence length
    n_pairs = dim // 2
    theta_i = torch.tensor([10000 ** ((-2 * i) / dim) for i in range(n_pairs)])  # (dim//2,)
    # Size the cache from `dim` rather than the global `hs`, so the function
    # works for any head size it is called with.
    rope_cache = torch.zeros(seq_len, n_pairs, 2)
    for pos in range(seq_len):
        angles = pos * theta_i
        rope_cache[pos] = torch.stack((torch.cos(angles), torch.sin(angles)), dim=1)  # (dim//2, 2)
    return rope_cache

In [12]:
# Sanity check: the loop-based helper reproduces the cache built step by step above.
torch.allclose(precompute_rope_naive(hs), rope_cache)

True

In [14]:
def precompute_rope_vectorized(dim, seq_len=None):
    """Build the RoPE cos/sin cache for all positions at once, without a loop.

    Args:
        dim: Head dimension; ``dim // 2`` feature pairs are rotated.
        seq_len: Number of positions to precompute. When omitted, the
            notebook-level ``T`` is used, so existing one-argument calls
            keep working.

    Returns:
        Tensor of shape ``(seq_len, dim // 2, 2)`` where ``[..., 0]`` holds
        ``cos(pos * theta_i)`` and ``[..., 1]`` holds ``sin(pos * theta_i)``.
    """
    if seq_len is None:
        seq_len = T  # fall back to the notebook's global sequence length
    pair_idx = torch.arange(0, dim // 2, dtype=torch.float32)
    theta_i = 10000 ** ((-2 * pair_idx) / dim)  # (dim//2,)
    pos = torch.arange(seq_len, dtype=torch.float32)  # (seq_len,)
    freqs = torch.outer(pos, theta_i)  # angle for every (position, pair)
    return torch.stack((torch.cos(freqs), torch.sin(freqs)), dim=-1)

In [15]:
# Sanity check: the loop-based and vectorized helpers agree.
torch.allclose(precompute_rope_naive(hs), precompute_rope_vectorized(hs))

True

In [13]:
%timeit precompute_rope_naive(hs)

170 µs ± 8.53 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [16]:
%timeit precompute_rope_vectorized(hs)

58.3 µs ± 1.3 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [17]:
# Reinterpret each trailing (real, imag) pair as a single complex number.
# NOTE: this rebinds `x`, so the cell is not idempotent. Re-running it without
# first re-running the cell that builds `x` from `q` fails, because `x` is
# then already complex.
x = torch.view_as_complex(x)
x.shape

torch.Size([8, 4, 64])

In [18]:
# The cache's (cos, sin) pairs become the complex numbers cos + i*sin.
torch.view_as_complex(rope_cache).shape

torch.Size([4, 64])

In [19]:
# Complex multiplication applies the rotation. unsqueeze(0) adds a leading dim
# so the (T, hs//2) cache broadcasts across all B*nh rows of x.
rotd_x = x * torch.view_as_complex(rope_cache).unsqueeze(0)
rotd_x.shape

torch.Size([8, 4, 64])

In [20]:
# Back to real pairs: the new trailing dim of size 2 holds (real, imag).
# NOTE: this rebinds `rotd_x`, so re-running only this cell fails once the
# tensor is already real.
rotd_x = torch.view_as_real(rotd_x)
rotd_x.shape

torch.Size([8, 4, 64, 2])

In [21]:
# Undo the earlier reshape: flatten the pairs back into hs features and split
# the merged leading dim into batch and heads again.
rotd_q = rotd_x.view(B, nh, T, hs)
rotd_q.shape

torch.Size([2, 4, 4, 128])

Reference implementation from [rotary-embedding-torch](https://github.com/lucidrains/rotary-embedding-torch):

In [17]:
from rotary_embedding_torch import RotaryEmbedding

# Reference rotary module, configured with the per-head size hs.
rotary_emb = RotaryEmbedding(dim=hs)

In [18]:
# Rotate the same q with the reference implementation for comparison.
test_rotd_q = rotary_emb.rotate_queries_or_keys(q)
test_rotd_q.shape

torch.Size([2, 4, 4, 128])

In [24]:
# Compare our rotated queries with the reference output (default allclose tolerances).
torch.allclose(rotd_q, test_rotd_q)

True

Seems like our implementation gives the same results as the reference one by lucidrains.

References:

1. [You could have designed state of the art positional encoding](https://huggingface.co/blog/designing-positional-encoding)
2. [Rotary Embeddings: A Relative Revolution](https://blog.eleuther.ai/rotary-embeddings/)
