# RoPE

$$
x=[x^{(0)},x^{(1)},...,x^{(|D|-1)}]
$$
$$
f_{rope}([x^{(2d)},x^{(2d+1)}]^T)=\begin{pmatrix}  \cos m\theta_d & -\sin m\theta_d \\  \sin m \theta_d &  \cos m \theta_d \end{pmatrix}\begin{pmatrix}  x^{(2d)}  \\  x^{(2d+1)}    \end{pmatrix}
$$

In [None]:
import torch
from torch import nn

class RoPEEmbedding(nn.Module):
    """Rotary position embedding (RoPE) backed by precomputed sin/cos tables.

    Each adjacent pair of features ``(x[2d], x[2d+1])`` at position ``m`` is
    rotated by the angle ``m * theta_d``, where
    ``theta_d = 1 / base ** (2d / head_dim)``.

    Args:
        head_dim: Size of the last dimension of the inputs; must be even.
        max_seq_len: Number of positions for which angles are precomputed.
        base: Base of the geometric progression of rotation frequencies.

    Buffers:
        sin_table: ``[max_seq_len, head_dim // 2]`` sines of ``m * theta_d``.
        cos_table: ``[max_seq_len, head_dim // 2]`` cosines of ``m * theta_d``.
    """

    def __init__(self, head_dim, max_seq_len, base=10000):
        super().__init__()
        assert head_dim % 2==0, "维度必须为偶数"

        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.base = base

        # theta_d = 1 / (base^(2d / head_dim)) for d = 0 .. head_dim/2 - 1.
        # torch.arange has an exclusive end, so this yields exactly
        # head_dim/2 frequencies (torch.range is deprecated and inclusive).
        pair_exponents = torch.arange(0, head_dim, 2).float() / head_dim
        theta = 1.0 / (base ** pair_exponents)

        # Outer product of positions and frequencies:
        # [max_seq_len, 1] * [1, head_dim/2] -> [max_seq_len, head_dim/2].
        pos_ids = torch.arange(max_seq_len, dtype=torch.float32)
        freqs = pos_ids.unsqueeze(1) * theta.unsqueeze(0)

        self.register_buffer('sin_table', torch.sin(freqs))  # [max_seq_len, head_dim/2]
        self.register_buffer('cos_table', torch.cos(freqs))  # [max_seq_len, head_dim/2]

    def forward(self, x, offset=0):
        """Rotate ``x`` according to the positions ``offset .. offset+seq_len-1``.

        Args:
            x: Tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.
            offset: Position index of the first token in ``x``.

        Returns:
            Tensor with the same shape as ``x`` holding the rotated features,
            with even/odd features interleaved in their original order.

        Raises:
            ValueError: If the requested positions fall outside
                ``[0, max_seq_len)``.
        """
        _, _, seq_len, _ = x.shape  # [batch_size, num_heads, seq_len, head_dim]

        # Slicing past the end of the tables would silently return fewer rows,
        # which can then broadcast to the wrong result, so fail loudly instead.
        if offset < 0 or offset + seq_len > self.max_seq_len:
            raise ValueError(
                f"positions [{offset}, {offset + seq_len}) are outside the "
                f"precomputed range [0, {self.max_seq_len})"
            )

        sin = self.sin_table[offset:offset + seq_len]  # [seq_len, head_dim/2]
        cos = self.cos_table[offset:offset + seq_len]

        x1 = x[..., 0::2]  # even features: [batch_size, num_heads, seq_len, head_dim/2]
        x2 = x[..., 1::2]  # odd features
        rotated_x1 = x1 * cos - x2 * sin
        rotated_x2 = x2 * cos + x1 * sin

        # Re-interleave the two halves:
        # 1. stack   -> [batch_size, num_heads, seq_len, head_dim/2, 2]
        # 2. flatten -> [batch_size, num_heads, seq_len, head_dim]
        return torch.stack((rotated_x1, rotated_x2), dim=-1).flatten(-2)