# 1 RoPE（Rotary Position Embedding，旋转位置编码）

In [None]:
import torch


# 用于预计算旋转位置嵌入（Rotary Position Embedding）的频率
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute the complex rotation factors used by rotary position embedding.

    Args:
        dim: Size of the dimension that will be rotated. One frequency is
            produced per pair of channels, i.e. ``dim // 2`` frequencies.
        end: Number of positions to cover (positions ``0 .. end - 1``).
        theta: Base of the geometric frequency progression.

    Returns:
        A complex64 tensor of shape ``(end, dim // 2)``. Entry ``[t, k]`` is
        the unit-magnitude complex number with angle
        ``t * theta ** (-2k / dim)``, i.e. ``cos(angle) + i * sin(angle)``.
    """
    # Even channel indices 0, 2, 4, ... (at most dim // 2 of them), scaled by dim.
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    # Per-pair angular frequency: theta ** (-exponent).
    inv_freq = 1.0 / (theta ** exponents)
    # Position indices 0 .. end - 1, created on the same device as the frequencies.
    positions = torch.arange(end, device=inv_freq.device, dtype=torch.float32)
    # Rotation angle for every (position, frequency) combination.
    angles = torch.outer(positions, inv_freq)
    # Polar form: magnitude 1, phase = angle.
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)


def reshape_for_broadcast(freqs_cis, x):
    """View ``freqs_cis`` so that it broadcasts against ``x``.

    Args:
        freqs_cis: Tensor whose shape must equal ``(x.shape[1], x.shape[-1])``.
        x: Tensor with at least two dimensions.

    Returns:
        A view of ``freqs_cis`` with the same number of dimensions as ``x``:
        the sizes of ``x`` are kept at axis 1 and at the last axis, and every
        other axis has size 1.

    Raises:
        AssertionError: If ``x`` has fewer than two dimensions or the shape
            of ``freqs_cis`` does not match axes 1 and -1 of ``x``.
    """
    rank = x.ndim
    assert rank > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    # Keep the real extent only on axis 1 and the final axis.
    kept_axes = (1, rank - 1)
    target = tuple(
        size if axis in kept_axes else 1 for axis, size in enumerate(x.shape)
    )
    return freqs_cis.view(*target)


def apply_rotary_emb(xq, xk, freqs_cis):
    """Apply rotary position embedding to query and key tensors.

    The last dimension of each input is split into consecutive pairs, each
    pair is interpreted as one complex number and multiplied by the matching
    entry of ``freqs_cis``. The pairs are then merged back, so each output has
    the same shape and dtype as its input.

    Args:
        xq: Query tensor with at least two dimensions. Axis 1 is matched
            against the rows of ``freqs_cis`` and the last axis (even size)
            is the one being rotated, e.g.
            ``(batch_size, seq_len, num_heads, head_dim)``.
        xk: Key tensor laid out like ``xq``. It must be broadcastable against
            ``freqs_cis`` once that has been reshaped for ``xq``.
        freqs_cis: Complex tensor of shape
            ``(xq.shape[1], xq.shape[-1] // 2)``.

    Returns:
        A tuple ``(xq_out, xk_out)`` of the rotated tensors, cast back to the
        dtypes of ``xq`` and ``xk`` respectively.
    """
    # Compute in float32 and view every trailing pair as one complex number:
    # (..., head_dim) -> (..., head_dim // 2, 2) -> complex (..., head_dim // 2).
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # Complex multiplication performs the rotation. Merge only the trailing
    # (pair, 2) axes so the result has the input's shape for any rank,
    # not just for 4-D inputs.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)