In [1]:
# LLaMA's RoPE implementation
import torch

def rotate_half(x):
    """Swap the two halves of the last axis, negating the half that moves to the front.

    With the last axis split at ``x.shape[-1] // 2`` into ``(front, back)``,
    the result is ``(-back, front)`` concatenated along that same axis.
    """
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat([back.neg(), front], dim=-1)

def _flatten_rope_table(table, name):
    """Collapse a cos/sin table to 2-D ``[num_positions, dim]``.

    The last axis is taken as the feature axis. At most one of the leading
    axes may be larger than 1 (the position axis), so ``[1, 1, P, dim]``,
    ``[1, P, 1, dim]`` and ``[P, dim]`` all flatten to the same ``[P, dim]``.

    Raises:
        ValueError: if more than one leading axis is non-singleton, because the
            position axis would then be ambiguous.
    """
    num_positions = 1
    non_singleton = 0
    for size in table.shape[:-1]:
        num_positions *= size
        if size != 1:
            non_singleton += 1
    if non_singleton > 1:
        raise ValueError(
            f"{name} must have a single position axis, got shape {tuple(table.shape)}"
        )
    # Explicit sizes (rather than -1) keep the reshape unambiguous for empty tables.
    return table.reshape(num_positions, table.shape[-1])


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Apply rotary position embeddings (RoPE) to query and key tensors.

    Args:
        q, k: tensors laid out as ``[bs, num_heads, seq_len, head_dim]``.
        cos, sin: lookup tables whose last axis is ``head_dim`` and whose only
            non-singleton leading axis is the position axis, e.g.
            ``[1, 1, P, dim]``, ``[1, P, 1, dim]`` or ``[P, dim]``.
        position_ids: integer tensor of positions to look up, shaped
            ``[bs, seq_len]``, or ``[seq_len]`` when every batch row uses the
            same positions.

    Returns:
        Tuple ``(q_embed, k_embed)`` with the same shapes as ``q`` and ``k``.

    Raises:
        ValueError: if ``cos`` or ``sin`` has more than one non-singleton
            leading axis.
    """
    # Normalise the tables first. Squeezing fixed axes only works for the
    # [1, 1, P, dim] layout; any other layout would leave a stray axis that
    # silently broadcasts positions against the batch axis of q/k.
    cos = _flatten_rope_table(cos, "cos")  # [P, dim]
    sin = _flatten_rope_table(sin, "sin")  # [P, dim]
    if position_ids.dim() == 1:
        # A 1-D index is shared by all batch rows; add a broadcastable batch axis.
        position_ids = position_ids.unsqueeze(0)  # [1, seq_len]
    # Gather per-token rows, then add a head axis that broadcasts over num_heads.
    cos = cos[position_ids].unsqueeze(1)  # [bs or 1, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs or 1, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

In [5]:
import torch
import math

# Example usage
def create_rope_embeddings(seq_len, dim, base=10000):
    """Build the cosine and sine lookup tables used by RoPE.

    Args:
        seq_len: number of positions; rows are generated for 0 .. seq_len - 1.
        dim: feature size; one frequency is produced per index in
            ``range(0, dim, 2)``.
        base: base of the geometric frequency progression (default 10000).

    Returns:
        Tuple ``(cos, sin)``, each of shape ``[1, seq_len, 1, D]`` where ``D``
        is twice the number of frequencies (equal to ``dim`` when ``dim`` is
        even).

    NOTE(review): the position axis is axis 1 here, not axis 2 -- callers that
    expect a ``[1, 1, seq_len, dim]`` table must account for that.
    """
    exponents = torch.arange(0, dim, 2).float() / dim
    inv_freq = 1.0 / (base ** exponents)  # [D // 2]
    positions = torch.arange(seq_len).to(inv_freq)
    # Outer product: angle[p, i] = p * inv_freq[i].
    half_angles = positions[:, None] * inv_freq[None, :]  # [seq_len, D // 2]
    # Duplicate the angles so both halves of the feature axis share frequencies.
    angles = torch.cat([half_angles, half_angles], dim=-1)  # [seq_len, D]
    layout = (1, angles.shape[0], 1, angles.shape[1])
    return angles.cos().reshape(layout), angles.sin().reshape(layout)
# Demo: run RoPE end to end on random tensors and report the tensor shapes.
batch_size = 4
num_heads = 2
seq_len = 4
head_dim = 4

# Random queries and keys, each [batch_size, num_heads, seq_len, head_dim].
# q is drawn before k, so the random stream is consumed in the same order.
q, k = (torch.randn(batch_size, num_heads, seq_len, head_dim) for _ in range(2))

# One position index per sequence slot, plus the cos/sin lookup tables.
position_ids = torch.arange(seq_len)
cos, sin = create_rope_embeddings(seq_len, head_dim)

print(
    "Original shapes:",
    f"q: {q.shape}, k: {k.shape}",
    f"cos: {cos.shape}, sin: {sin.shape}",
    sep="\n",
)

# Rotate the queries and keys by their positions.
q_rope, k_rope = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

print(f"\nAfter RoPE:", f"q_rope: {q_rope.shape}, k_rope: {k_rope.shape}", sep="\n")


Original shapes:
q: torch.Size([4, 2, 4, 4]), k: torch.Size([4, 2, 4, 4])
cos: torch.Size([1, 4, 1, 4]), sin: torch.Size([1, 4, 1, 4])

After RoPE:
q_rope: torch.Size([4, 2, 4, 4]), k_rope: torch.Size([4, 2, 4, 4])
