In [19]:
import torch
from typing import Tuple, Literal

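Interleaved layout — rotate adjacent lane pairs $(q_{2i-1}, q_{2i})$ for $i = 1, \dots, D/2$:
 - $q'_{2i-1} = q_{2i-1}\cos\phi_i - q_{2i}\sin\phi_i$
 - $q'_{2i} = q_{2i-1}\sin\phi_i + q_{2i}\cos\phi_i$


Half layout — rotate lane pairs $(q_i, q_{i+h})$ split across the two halves, with $h = D/2$ and $i = 1, \dots, h$:
 - $q'_i = q_i\cos\phi_i - q_{i+h}\sin\phi_i$
 - $q'_{i+h} = q_i\sin\phi_i + q_{i+h}\cos\phi_i$

In both layouts $\phi_i = m\,\theta_i$ for token position $m$, with $\theta_i = \text{base}^{-2(i-1)/D}$.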

In [20]:
def build_rope_cache(
    seq_len: int,
    dim: int,
    base: float = 10000.0,
    layout: Literal["interleaved", "half"] = "interleaved",
    device=None,
    dtype: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Precompute the RoPE cos/sin tables for positions 0..seq_len-1.

    Args:
        seq_len: number of positions S to precompute.
        dim: head dimension D; must be even (one rotation angle per pair of lanes).
        base: frequency base; pair i uses inverse frequency base^(-2i/D).
        layout: "interleaved" pairs lanes (2i, 2i+1); "half" pairs lanes (i, i + D/2).
        device: device on which the tables are built.
        dtype: dtype of the returned tables.

    Returns:
        (cos, sin), each shaped (1, S, 1, D/2) for "interleaved" or (1, S, 1, D)
        for "half", so they broadcast against (B, S, H, D) inputs.

    Raises:
        AssertionError: if dim is odd.
        ValueError: if layout is not "interleaved" or "half".
    """
    assert dim % 2 == 0, "dimension must be even"
    # Fail fast on a bad layout before allocating anything.
    if layout not in ("interleaved", "half"):
        raise ValueError("layout must be 'interleaved' or 'half'.")

    # Compute angles in at least float32, and in float64 when a float64 table is
    # requested: computing in float32 and upcasting would cap a float64 cache at
    # float32 accuracy, and the angle error grows with pos * inv_freq (i.e. seq_len).
    compute_dtype = torch.promote_types(dtype, torch.float32)

    # Frequencies per pair: inv_freq[i] = base^(-(2i)/D)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=compute_dtype) / dim))  # (D/2,)
    pos = torch.arange(seq_len, device=device, dtype=compute_dtype)                                 # (S,)
    angles = torch.outer(pos, inv_freq)  # angles[m, i] = pos[m] * inv_freq[i]  (S, D/2)

    if layout == "half":
        # Lanes i and i + D/2 share pair i's angle, so duplicate along the last axis.
        angles = torch.cat([angles, angles], dim=-1)  # (S, D)

    # Insert batch and head axes: (1, S, 1, D/2) or (1, S, 1, D).
    cos = angles.cos()[None, :, None, :].to(dtype)
    sin = angles.sin()[None, :, None, :].to(dtype)
    return cos, sin

def apply_rope(x, cos_cached, sin_cached, layout="interleaved"):
    """Rotate the last axis of `x` using precomputed RoPE cos/sin tables.

    Args:
        x: input tensor; in this notebook laid out as (B, S, H, D).
        cos_cached, sin_cached: tables built for the same `layout`; they are cast to
            x's device and dtype before use.
        layout: "interleaved" rotates lane pairs (2i, 2i+1);
            "half" rotates lane pairs (i, i + D/2).

    Returns:
        A new tensor with the rotated values (x itself is not modified). For a
        (B, S, H, D) input and a matching cache it has the same shape as x.

    Raises:
        ValueError: if layout is not "interleaved" or "half".
    """
    cos = cos_cached.to(device=x.device, dtype=x.dtype)
    sin = sin_cached.to(device=x.device, dtype=x.dtype)

    if layout == "half":
        # [-back | front] is the 90-degree partner of [front | back].
        mid = x.shape[-1] // 2
        front, back = x[..., :mid], x[..., mid:]
        partner = torch.cat([-back, front], dim=-1)
        return x * cos + partner * sin

    if layout != "interleaved":
        raise ValueError("layout must be 'interleaved' or 'half'.")

    # Rotate each (even, odd) lane pair and write it back into the same slots.
    lane_a, lane_b = x[..., ::2], x[..., 1::2]
    rotated = torch.empty_like(x)
    rotated[..., ::2] = lane_a * cos - lane_b * sin
    rotated[..., 1::2] = lane_a * sin + lane_b * cos
    return rotated

In [21]:
# ----- fixed permutation to align "half" with "interleaved" -----
def to_even_odd_first(x):
    """Permute the last axis to [x0, x2, x4, ..., x1, x3, ...]: even lanes, then odd lanes.

    For even D this sends lane 2i to position i and lane 2i+1 to position i + D/2,
    so the interleaved pairing (2i, 2i+1) becomes the "half" pairing (i, i + D/2).
    """
    width = x.shape[-1]
    even_lanes = torch.arange(0, width, 2, device=x.device)
    odd_lanes = torch.arange(1, width, 2, device=x.device)
    order = torch.cat([even_lanes, odd_lanes], 0)
    return x[..., order]

def from_even_odd_first(x):
    """Inverse of to_even_odd_first: restore natural lane order on the last axis.

    Input lanes are expected as [evens | odds]; output lanes are [x0, x1, x2, ...].
    """
    width = x.shape[-1]
    forward = torch.cat([torch.arange(0, width, 2), torch.arange(1, width, 2)], 0).to(x.device)
    # Sorting a permutation's values yields its inverse: backward[forward[j]] == j.
    backward = torch.argsort(forward)
    return x[..., backward]

# ====== tiny test ======
# Check that the "half" layout, applied to even/odd-permuted inputs, reproduces the
# "interleaved" layout, and that RoPE behaves like a position-dependent rotation.
torch.manual_seed(0)
B, S, H, D = 2, 16, 8, 64
q = torch.randn(B, S, H, D)
k = torch.randn(B, S, H, D)

# 1) interleaved path
cos_i, sin_i = build_rope_cache(S, D, layout="interleaved", dtype=q.dtype, device=q.device)
q_i = apply_rope(q, cos_i, sin_i, "interleaved")
k_i = apply_rope(k, cos_i, sin_i, "interleaved")
scores_i = torch.einsum("b s h d, b t h d -> b h s t", q_i, k_i)

# 2) half path + permutation alignment.
# Applying "half" directly to q would rotate lane pairs (i, i + D/2) instead of
# (2i, 2i+1); permuting to [evens | odds] first (and undoing it afterwards) makes
# the two layouts rotate the same pairs by the same angles.
qP, kP = to_even_odd_first(q), to_even_odd_first(k)
cos_h, sin_h = build_rope_cache(S, D, layout="half", dtype=q.dtype, device=q.device)
q_h = from_even_odd_first(apply_rope(qP, cos_h, sin_h, "half"))
k_h = from_even_odd_first(apply_rope(kP, cos_h, sin_h, "half"))
scores_h = torch.einsum("b s h d, b t h d -> b h s t", q_h, k_h)

# compare: the rotated vectors themselves must agree, not only their dot products
torch.testing.assert_close(q_i, q_h, atol=1e-6, rtol=1e-6)
torch.testing.assert_close(k_i, k_h, atol=1e-6, rtol=1e-6)
torch.testing.assert_close(scores_i, scores_h, atol=1e-6, rtol=1e-6)

# sanity 1: RoPE is a rotation, so per-vector norms are preserved
torch.testing.assert_close(q_i.norm(dim=-1), q.norm(dim=-1), atol=1e-5, rtol=1e-5)

# sanity 2: with the same q/k vector at every position, scores depend only on t - s
q_const = q[:, :1].repeat(1, S, 1, 1)
k_const = k[:, :1].repeat(1, S, 1, 1)
scores_const = torch.einsum(
    "b s h d, b t h d -> b h s t",
    apply_rope(q_const, cos_i, sin_i, "interleaved"),
    apply_rope(k_const, cos_i, sin_i, "interleaved"),
)
torch.testing.assert_close(scores_const[..., 1:, 1:], scores_const[..., :-1, :-1], atol=1e-4, rtol=1e-4)

print("Max abs diff:", (scores_i - scores_h).abs().max().item())
print("✅ Attention scores match (within numerical tolerance).")

Max abs diff: 0.0
✅ Attention scores match (within numerical tolerance).
