In [1]:
from rope2d import *

dim = 128 # 注意力的维度
num_heads = 8 # 注意力头的数量
rope_theta = 10.0 # RoPE的theta参数
compute_cis = partial(
    compute_axial_cis, dim=dim // num_heads, theta=rope_theta
)

freqs_cis = compute_cis(end_x=14, end_y=14)
print(freqs_cis.shape)

tensor([1.0000, 0.5623, 0.3162, 0.1778])
tensor([1.0000, 0.5623, 0.3162, 0.1778])
torch.Size([4]) torch.Size([4])
torch.Size([196, 8])


In [4]:
def init_t_xy(end_x: int, end_y: int):
    t = torch.arange(end_x * end_y, dtype=torch.float32)
    t_x = (t % end_x).float()
    t_y = torch.div(t, end_x, rounding_mode="floor").float()
    return t_x, t_y

t_x, t_y = init_t_xy(14, 14)
print(t_x, t_y)

tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
         0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
         0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
         0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
         0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
         0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
         0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
         0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
         0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
         0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
         0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
         0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
         0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 1

$$
freqs = \frac{1}{N} \sum_{i=0}^{N-1} \delta(t - t_i)
$$

In [6]:
# rope 1d
import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    print(freqs)
    # print(freqs.shape)
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)  # type: ignore
    freqs = torch.outer(t, freqs)  # type: ignore
    print(freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

freqs_cis_1d = precompute_freqs_cis(dim // num_heads, 14 * 14, rope_theta)
print(freqs_cis_1d.shape)

tensor([1.0000, 0.7499, 0.5623, 0.4217, 0.3162, 0.2371, 0.1778, 0.1334])
tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 7.4989e-01, 5.6234e-01,  ..., 2.3714e-01, 1.7783e-01,
         1.3335e-01],
        [2.0000e+00, 1.4998e+00, 1.1247e+00,  ..., 4.7427e-01, 3.5566e-01,
         2.6670e-01],
        ...,
        [1.9300e+02, 1.4473e+02, 1.0853e+02,  ..., 4.5768e+01, 3.4321e+01,
         2.5737e+01],
        [1.9400e+02, 1.4548e+02, 1.0909e+02,  ..., 4.6005e+01, 3.4499e+01,
         2.5870e+01],
        [1.9500e+02, 1.4623e+02, 1.0966e+02,  ..., 4.6242e+01, 3.4676e+01,
         2.6004e+01]])
torch.Size([196, 8])


In [4]:
def compute_cis(dim: int, theta: float = 100.0):
    """
    计算旋转式位置编码
    :param dim: 位置编码的维度
    :param theta: 旋转角度
    :return: 复数形式的位置编码
    """
    # 计算频率
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    print(freqs)
    t = torch.arange(0, dim, 2).float() / dim
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

a = compute_cis(128, 20.0)

tensor([1.0000, 0.9543, 0.9106, 0.8690, 0.8293, 0.7913, 0.7551, 0.7206, 0.6877,
        0.6562, 0.6262, 0.5976, 0.5702, 0.5442, 0.5193, 0.4955, 0.4729, 0.4512,
        0.4306, 0.4109, 0.3921, 0.3742, 0.3571, 0.3408, 0.3252, 0.3103, 0.2961,
        0.2826, 0.2696, 0.2573, 0.2456, 0.2343, 0.2236, 0.2134, 0.2036, 0.1943,
        0.1854, 0.1769, 0.1689, 0.1611, 0.1538, 0.1467, 0.1400, 0.1336, 0.1275,
        0.1217, 0.1161, 0.1108, 0.1057, 0.1009, 0.0963, 0.0919, 0.0877, 0.0837,
        0.0798, 0.0762, 0.0727, 0.0694, 0.0662, 0.0632, 0.0603, 0.0575, 0.0549,
        0.0524])


In [8]:
import torch

# 生成旋转矩阵
def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    # 计算词向量元素两两分组之后，每组元素对应的旋转角度 \theta_i
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # 生成 token 序列索引 t = [0, 1,..., seq_len-1]
    t = torch.arange(seq_len, device=freqs.device)
    # freqs.shape = [seq_len, dim // 2]
    freqs = torch.outer(t, freqs).float()  # 计算m * \theta

    # 计算结果是个复数向量
    # 假设 freqs = [x, y]
    # 则 freqs_cis = [cos(x) + sin(x)i, cos(y) + sin(y)i]
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


# 旋转位置编码计算
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    """
    xq: q 矩阵
    xk: k 矩阵
    freqs_cis: 位置编码
    """
    # xq.shape = [batch_size, seq_len, dim]
    # xq_.shape = [batch_size, seq_len, dim // 2, 2]
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 2)

    # 转为复数域
    xq_ = torch.view_as_complex(xq_)
    xk_ = torch.view_as_complex(xk_)

    # 应用旋转操作，然后将结果转回实数域
    # xq_out.shape = [batch_size, seq_len, dim]
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2)
    return xq_out.type_as(xq), xk_out.type_as(xk)


# 测试代码
freqs = precompute_freqs_cis(128, 14 * 14, 10.0)
xq = torch.randn(2, 14 * 14, 128)
xk = torch.randn(2, 14 * 14, 128)
xq_out, xk_out = apply_rotary_emb(xq, xk, freqs)

print(freqs.shape)

torch.Size([196, 64])
