In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:
# Build the table of RoPE rotation angles
def get_rotary_frequency(dim: int, seq_len: int, theta_base: float = 10000.0) -> torch.Tensor:
    """Compute the RoPE rotation angle for every (position, dimension-pair).

    Despite the function name, the returned tensor holds angles
    (position * frequency), not the bare per-pair frequencies.

    Args:
        dim (int): Embedding dimension; must be even.
        seq_len (int): Number of positions to generate angles for.
        theta_base (float, optional): Base of the geometric frequency
            progression. Defaults to 10000.0.

    Returns:
        torch.Tensor: float32 tensor of shape (seq_len, dim // 2) whose entry
        [m, i] is m * theta_base ** (-2 * i / dim).
    """
    assert dim % 2 == 0
    pair_index = torch.arange(0, dim // 2, dtype=torch.float32)
    # Per-pair frequency theta_i, shape (dim // 2,)
    inv_freq = theta_base ** (-2 * pair_index / dim)

    # Position indices 0 .. seq_len - 1
    pos = torch.arange(seq_len, dtype=torch.float32)
    # Broadcasted product == outer product, shape (seq_len, dim // 2)
    return pos.unsqueeze(1) * inv_freq.unsqueeze(0)

In [3]:
# Demo: build the angle table for a 64-dim embedding over 128 positions and
# print its shape plus the first few angles at positions 0 and 1.
dim = 64
seq_len = 128
angles = get_rotary_frequency(dim, seq_len)
print(f"Angles shape: {angles.shape}")  # (128, 32)
print(f"Angles[0]: {angles[0][:5]}")    # angles of the first 5 dimension pairs at position 0
print(f"Angles[1]: {angles[1][:5]}")    # angles of the first 5 dimension pairs at position 1

Angles shape: torch.Size([128, 32])
Angles[0]: tensor([0., 0., 0., 0., 0.])
Angles[1]: tensor([1.0000, 0.7499, 0.5623, 0.4217, 0.3162])


In [None]:
# Build the sin/cos cache
def get_rotary_embedding(dim:int, seq_len:int, theta:float=10000.0):
    """
    Precompute the RoPE cos and sin tables.

    The (seq_len, dim // 2) angle table from get_rotary_frequency is passed
    through cos/sin and each result is concatenated with itself along the last
    axis, so column j and column j + dim // 2 carry the same angle.

    Args:
        dim: Embedding dimension (forwarded to get_rotary_frequency).
        seq_len: Number of positions to precompute.
        theta: Base frequency parameter. Defaults to 10000.0.

    Returns:
        cos: shape (seq_len, dim)
        sin: shape (seq_len, dim)
    """
    angles = get_rotary_frequency(dim, seq_len, theta)

    # Repeat each table so the first and second halves line up pair-by-pair
    cos, sin = (
        torch.cat([table, table], dim=-1)
        for table in (torch.cos(angles), torch.sin(angles))
    )
    return cos, sin

In [7]:
# Demo: build the cos/sin tables (using the dim and seq_len globals) and display
# their shapes together with both halves of the position-1 cos row for comparison.
cos, sin = get_rotary_embedding(dim, seq_len)
cos.shape, sin.shape, cos[1][0:32], cos[1][32:64]

(torch.Size([128, 64]),
 torch.Size([128, 64]),
 tensor([0.5403, 0.7318, 0.8460, 0.9124, 0.9504, 0.9720, 0.9842, 0.9911, 0.9950,
         0.9972, 0.9984, 0.9991, 0.9995, 0.9997, 0.9998, 0.9999, 0.9999, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000]),
 tensor([0.5403, 0.7318, 0.8460, 0.9124, 0.9504, 0.9720, 0.9842, 0.9911, 0.9950,
         0.9972, 0.9984, 0.9991, 0.9995, 0.9997, 0.9998, 0.9999, 0.9999, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000]))

In [16]:
# Apply the rotary transform:
def rotate_half(x):
    """
    Split the last axis into a front and a back half, then return the negated
    back half followed by the front half.

    Example: [x1, x2, x3, x4] -> [-x3, -x4, x1, x2]

    The split point is x.shape[-1] // 2; all leading axes are left untouched.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat([back.neg(), front], dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    """
    Rotate q and k with the given cos/sin tables.

    cos and sin get a new leading axis and a new axis at index 2, so a
    (seq_len, dim) table becomes (1, seq_len, 1, dim) and broadcasts against
    q/k. Each input t is mapped to t * cos + rotate_half(t) * sin.

    Returns:
        (q_embed, k_embed): the rotated q and k.
    """
    # [1, seq_len, 1, dim] for a [seq_len, dim] table
    cos_b = cos.unsqueeze(0).unsqueeze(2)
    sin_b = sin.unsqueeze(0).unsqueeze(2)

    def _rotate(t):
        # Same formula applied to both q and k
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)

In [17]:
# Complete RoPE module:
class RoPE(nn.Module):
    """Rotary position embedding with precomputed cos/sin tables.

    The tables for positions 0 .. max_seq_len - 1 are built once in __init__
    and registered as buffers (cos_cached / sin_cached), each of shape
    (max_seq_len, dim).
    """

    def __init__(self, dim:int, max_seq_len:int, theta:float = 10000.0 ):
        """
        Args:
            dim: Per-head embedding dimension the tables are built for.
            max_seq_len: Number of positions to precompute.
            theta: Base frequency parameter. Defaults to 10000.0.
        """
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.theta = theta

        # Precompute and cache cos / sin as buffers so they follow the module
        # across .to(device) calls.
        cos, sin = get_rotary_embedding(dim, max_seq_len, theta)
        self.register_buffer('cos_cached', cos)
        self.register_buffer('sin_cached', sin)

    def forward(self, q: torch.Tensor, k: torch.Tensor, position: torch.Tensor = None):
        """
        Apply RoPE to Query and Key.

        Args:
            q: Query, shape (batch, seq_len, num_heads, head_dim)
            k: Key, shape (batch, seq_len, num_heads, head_dim)
            position: Optional 1-D integer tensor of shape (seq_len,) giving the
                position index of each token; values must lie in
                [0, max_seq_len). Defaults to [0, 1, 2, ..., seq_len - 1].

        Returns:
            q_rot, k_rot: the rotated Query and Key.

        Raises:
            ValueError: If seq_len exceeds max_seq_len when position is omitted,
                or if position is not a 1-D integer tensor of length seq_len.
        """
        seq_len = q.shape[1]

        if position is None:
            # Without this check an over-long input fails later with an opaque
            # broadcasting error (or silently broadcasts when max_seq_len == 1).
            if seq_len > self.max_seq_len:
                raise ValueError(
                    f"seq_len ({seq_len}) exceeds max_seq_len ({self.max_seq_len})"
                )
            cos = self.cos_cached[:seq_len]
            sin = self.sin_cached[:seq_len]
        else:
            if position.dim() != 1 or position.shape[0] != seq_len:
                raise ValueError(
                    f"position must be 1-D with length seq_len ({seq_len}), "
                    f"got shape {tuple(position.shape)}"
                )
            if position.is_floating_point():
                raise ValueError("position must hold integer indices")
            # Gather the cached rows for the requested positions -> (seq_len, dim)
            index = position.to(device=self.cos_cached.device, dtype=torch.long)
            cos = self.cos_cached[index]
            sin = self.sin_cached[index]

        q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
        return q_rot, k_rot

In [18]:
# Smoke test: run the RoPE module on random q/k and check the output shapes
rope = RoPE(dim=64, max_seq_len=4096)

# Simulated input
batch_size = 2
seq_len = 128
num_heads = 8
head_dim = 64

q = torch.randn(batch_size, seq_len, num_heads, head_dim)
k = torch.randn(batch_size, seq_len, num_heads, head_dim)

q_rot, k_rot = rope(q, k)
print(f"Q_rot shape: {q_rot.shape}")
print(f"K_rot shape: {k_rot.shape}")

Q_rot shape: torch.Size([2, 128, 8, 64])
K_rot shape: torch.Size([2, 128, 8, 64])


In [None]:
def verify_relative_position_invariance():
    """
    Check RoPE's relative-position invariance.

    Rotates one fixed random q/k pair at positions (0, 5) and again at
    (10, 15), both with relative offset 5, and prints the two dot products,
    their absolute difference, and a pass/fail message (threshold 1e-5).
    Seeds the global torch RNG with 42.
    """
    dim = 64
    max_seq_len = 100

    # Precompute cos/sin
    cos, sin = get_rotary_embedding(dim, max_seq_len)

    # Draw one random query vector and one random key vector
    torch.manual_seed(42)
    q = torch.randn(1, 1, 1, dim)
    k = torch.randn(1, 1, 1, dim)

    def _rotated_dot(q_pos, k_pos):
        # Rotate q at q_pos and k at k_pos, then take their inner product
        q_rot, _ = apply_rotary_pos_emb(q, q, cos[q_pos:q_pos + 1], sin[q_pos:q_pos + 1])
        _, k_rot = apply_rotary_pos_emb(k, k, cos[k_pos:k_pos + 1], sin[k_pos:k_pos + 1])
        return (q_rot * k_rot).sum()

    # Scenario 1: q at position 0, k at position 5 (relative offset 5)
    dot_product_1 = _rotated_dot(0, 5)
    # Scenario 2: q at position 10, k at position 15 (relative offset still 5)
    dot_product_2 = _rotated_dot(10, 15)

    gap = abs(dot_product_1.item() - dot_product_2.item())

    print(f"位置 (0, 5) 的内积: {dot_product_1.item():.6f}")
    print(f"位置 (10, 15) 的内积: {dot_product_2.item():.6f}")
    print(f"差异: {gap:.10f}")
    print("验证通过！" if gap < 1e-5 else "验证失败！")

verify_relative_position_invariance()

位置 (0, 5) 的内积: 5.536925
位置 (10, 15) 的内积: 5.536925
差异: 0.0000004768
验证通过！
