In [3]:
import torch
import torch.nn as nn 
import torch.nn.functional as F

### RoPE 代码实现

先计算旋转角度表。注意这里的角度向量采用“整块拼接”的排布方式：$[m\theta_1, m\theta_2, \dots, m\theta_{d/2}, m\theta_1, m\theta_2, \dots, m\theta_{d/2}]$，其中 $m$ 是位置下标，$\theta_i = 10000^{-2(i-1)/d}$，$i = 1, \dots, d/2$。

In [None]:
class RotaryEmbedding(nn.Module):
    """Builds the cos/sin lookup tables used by rotary position embeddings (RoPE).

    The inverse frequencies are ``base ** (-2i / dim)`` for ``i = 0 .. dim/2 - 1``
    and are stored in fp16. For every position ``m`` the angle vector is laid out
    block-wise, ``[m*f_0, ..., m*f_{dim/2-1}, m*f_0, ..., m*f_{dim/2-1}]``,
    which is the layout expected by a "rotate half" style application.

    Args:
        dim: size of the rotated feature dimension (the last axis of the tables).
        base: base of the geometric frequency progression.
        precision: only ``torch.bfloat16`` has an effect: the returned tables
            are then bfloat16. For any other value the tables use the current
            dtype of ``inv_freq`` (fp16 unless the module has been cast).
        learnable: if True, ``inv_freq`` is an ``nn.Parameter`` and the tables
            are recomputed on every forward call (nothing is cached). If False,
            ``inv_freq`` is a buffer and the tables are cached and reused.
    """

    def __init__(self, dim, base = 10000, precision = torch.half, learnable = False):
        super().__init__()
        # Shape (dim/2,). Stepping arange by 2 supplies the factor 2 in the
        # exponent, and the reciprocal supplies the minus sign.
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        inv_freq = inv_freq.half()
        self.learnable = learnable
        self.precision = precision

        # The cache attributes are created in BOTH modes. `_apply` reads them
        # unconditionally, so leaving them undefined in learnable mode made
        # `.to()` / `.float()` / `.cuda()` raise AttributeError.
        self.max_seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

        if learnable:
            # Trainable frequencies: tables must follow the parameter, so they
            # are never cached (see forward).
            self.inv_freq = nn.Parameter(inv_freq)
        else:
            # Fixed frequencies: a buffer follows the module across devices and
            # is saved in the state dict without being trained.
            self.register_buffer('inv_freq', inv_freq)

    def forward(self, x, seq_dim = 1, seq_len = None):
        """Return ``(cos, sin)`` tables covering positions ``0 .. seq_len - 1``.

        Args:
            x: tensor used for its device and, when ``seq_len`` is None, for
                its size along ``seq_dim``.
            seq_dim: axis of ``x`` that holds the sequence length.
            seq_len: explicit number of positions; overrides ``x.shape[seq_dim]``.

        Returns:
            Two tensors of shape ``(seq_len, 1, dim)``.
        """
        if seq_len is None:
            seq_len = x.shape[seq_dim]

        if self.max_seq_len_cached is None or self.max_seq_len_cached < seq_len:
            # In learnable mode the marker stays None so this branch runs on
            # every call.
            self.max_seq_len_cached = None if self.learnable else seq_len

            # Same output dtype rule as before: bfloat16 on request, otherwise
            # whatever dtype inv_freq currently has.
            if self.precision == torch.bfloat16:
                out_dtype = torch.bfloat16
            else:
                out_dtype = self.inv_freq.dtype

            # Angles are computed in float32. fp16 cannot represent every
            # integer above 2048, so building the position index in fp16 made
            # neighbouring long-range positions collapse onto the same angle.
            t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
            inv_freq = self.inv_freq.to(device=x.device, dtype=torch.float32)
            # Outer product m * f_i, shape (seq_len, dim/2).
            freqs = torch.einsum('i,j->ij', t, inv_freq)
            # Block-wise concat (not interleaved), shape (seq_len, dim).
            emb = torch.cat((freqs, freqs), dim = -1)

            # [:, None, :] inserts a broadcast axis: (seq_len, 1, dim).
            cos_cached = emb.cos()[:, None, :].to(out_dtype)
            sin_cached = emb.sin()[:, None, :].to(out_dtype)

            if self.learnable:
                # Return the fresh tables directly and keep nothing cached.
                return cos_cached, sin_cached

            self.cos_cached, self.sin_cached = cos_cached, sin_cached

        return self.cos_cached[:seq_len,...], self.sin_cached[:seq_len,...]

    def _apply(self, fn, *args, **kwargs):
        """Apply ``fn`` (device move / dtype cast) to the cached tables as well.

        The caches are plain attributes rather than buffers, so ``nn.Module``
        would not touch them on its own. Extra arguments (for example
        ``recurse`` in newer PyTorch releases) are forwarded unchanged.
        """
        if self.cos_cached is not None:
            self.cos_cached = fn(self.cos_cached)
        if self.sin_cached is not None:
            self.sin_cached = fn(self.sin_cached)
        return super()._apply(fn, *args, **kwargs)

In [None]:
# # 假设:
# position_id = [[0, 2],  # 批次0: 位置0, 批次1: 位置2
#               [1, 1]]   # 批次0: 位置1, 批次1: 位置1

# cos = [[cos 0],  # 位置0的cos
#        [cos 1],  # 位置1的cos
#        [cos 2]]  # 位置2的cos

# # 输出: shape 是 (position_id.shape[0], position_id.shape[1], dim)

# [[cos 0, cos 2],  # 第一行
#  [cos 1, cos 1]]  # 第二行

In [5]:
def rotate_half(x):
    """Swap the two halves of the last axis, negating the half moved to the front.

    With the last axis split at ``x.shape[-1] // 2`` into ``(lower, upper)``,
    the result is ``(-upper, lower)`` concatenated along the last axis.
    """
    split = x.shape[-1] // 2
    lower = x[..., :split]
    upper = x[..., split:]
    return torch.cat((-upper, lower), dim = -1)

@torch.jit.script
def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
    """Rotate ``q`` and ``k`` using the cos/sin rows selected by ``position_id``.

    Shapes as described by the original notes (not enforced here):
        q, k:         (seq_len, batch_size, num_head, head_dim)
        cos, sin:     (num_positions, 1, head_dim)
        position_id:  (seq_len, batch_size) integer indices into cos/sin

    Returns the rotated ``(q, k)`` pair.
    """
    # Drop the singleton axis so each table is a 2-D lookup matrix.
    cos_table = cos.squeeze(1)
    sin_table = sin.squeeze(1)

    # Gather one row per (position, batch) entry, then add a head axis so the
    # result broadcasts against q and k.
    cos = F.embedding(position_id, cos_table).unsqueeze(2)
    sin = F.embedding(position_id, sin_table).unsqueeze(2)

    # x * cos + rotate_half(x) * sin, applied to both q and k.
    q_rotated = (q * cos) + (rotate_half(q) * sin)
    k_rotated = (k * cos) + (rotate_half(k) * sin)

    return q_rotated, k_rotated

In [8]:
def main():
    """Demo: build rotary tables for random q/k and print the rotated shapes."""
    batch_size, seq_len, num_head, head_dim = 8, 4, 4, 8
    q = torch.randn(seq_len, batch_size, num_head, head_dim)
    k = torch.randn(seq_len, batch_size, num_head, head_dim)
    # Position ids can be arbitrary (not necessarily contiguous). Here they are
    # contiguous: [[0,0,...], [1,1,...], ...] with shape (seq_len, batch_size).
    position_id = torch.arange(seq_len).unsqueeze(1).repeat(1, batch_size)

    rotary_emb = RotaryEmbedding(dim=head_dim)
    # q is sequence-first, so the sequence axis is 0. The default seq_dim=1
    # would size the tables by batch_size instead, which only works while
    # batch_size >= seq_len and otherwise makes the position lookup go out of
    # range.
    cos, sin = rotary_emb(q, seq_dim=0)  # each: (seq_len, 1, head_dim)
    q_rotated, k_rotated = apply_rotary_pos_emb_index(q, k, cos, sin, position_id)

    print("应用旋转位置编码后 - q_rotated形状:", q_rotated.shape)
    print("应用旋转位置编码后 - k_rotated形状:", k_rotated.shape)

# Run the demo; it prints the shapes of the rotated q and k tensors.
main()

应用旋转位置编码后 - q_rotated形状: torch.Size([4, 8, 4, 8])
应用旋转位置编码后 - k_rotated形状: torch.Size([4, 8, 4, 8])
