In [11]:
from functools import partial
from typing import Callable, Optional

In [4]:
import torch

# RoPE常规的直接实现

In [20]:
def get_rotary_matrix(seq_len: int, dim: int, base: int = 10000) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the cos/sin lookup tables for standard RoPE.

    Args:
        seq_len: number of positions; rows correspond to positions 0 .. seq_len - 1.
        dim: feature dimension; one frequency is generated for every even index in
            ``range(0, dim, 2)``, i.e. ``(dim + 1) // 2`` frequencies.
        base: frequency base; frequency ``j`` is ``1 / base ** (2j / dim)``.

    Returns:
        ``(cos, sin)``, each of shape ``[seq_len, (dim + 1) // 2]``, where entry ``(p, j)``
        is the cosine / sine of ``p * frequency_j``.
    """
    exponents = torch.arange(0, dim, 2).float() / dim
    inv_freq = 1.0 / (base ** exponents)  # [(dim + 1) // 2]
    positions = torch.arange(seq_len).float()  # [seq_len]

    # Outer product via broadcasting: angles[p, j] = positions[p] * inv_freq[j]
    angles = positions[:, None] * inv_freq[None, :]  # [seq_len, (dim + 1) // 2]
    return angles.cos(), angles.sin()


def apply_rotary_embedding(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding (RoPE) to ``x``.

    Adjacent feature pairs ``(x[..., 2j], x[..., 2j + 1])`` are treated as the real and
    imaginary parts of a complex number and multiplied by ``cos[p, j] + i * sin[p, j]``,
    where ``p`` is the index along the second-to-last (sequence) axis of ``x``.

    Args:
        x: tensor of shape ``[..., seq_len, dim]`` with even ``dim``, e.g.
            ``[batch, seq_len, dim]`` or ``[batch, heads, seq_len, dim]``. It does not
            need to be contiguous.
        cos: cosine table of shape ``[n, dim // 2]`` with ``n >= seq_len``; only the
            first ``seq_len`` rows are used.
        sin: sine table with the same shape as ``cos``.

    Returns:
        The rotated tensor, with the same shape as ``x`` (dtype follows PyTorch type
        promotion between ``x`` and the tables).

    Raises:
        ValueError: if ``dim`` is odd, or a table is not 2-D with at least ``seq_len``
            rows and exactly ``dim // 2`` columns.
    """
    seq_len, dim = x.shape[-2], x.shape[-1]
    if dim % 2 != 0:
        raise ValueError(f"RoPE needs an even feature dimension, got dim={dim}")
    half = dim // 2
    for name, table in (("cos", cos), ("sin", sin)):
        if table.dim() != 2 or table.shape[0] < seq_len or table.shape[1] != half:
            raise ValueError(
                f"{name} table must have shape [>= {seq_len}, {half}], got {tuple(table.shape)}"
            )

    # Keep only the rows for positions present in x (tables may be longer, e.g. cached),
    # and put them on x's device so the broadcast below works.
    cos = cos[:seq_len].to(device=x.device)
    sin = sin[:seq_len].to(device=x.device)

    # reshape (not view) so non-contiguous inputs such as transposed q/k are accepted.
    x_pairs = x.reshape(*x.shape[:-1], half, 2)  # [..., seq_len, dim//2, 2]
    x_real, x_imag = x_pairs[..., 0], x_pairs[..., 1]  # each [..., seq_len, dim//2]

    # Complex multiplication (x_real + i*x_imag) * (cos + i*sin); the [seq_len, dim//2]
    # tables broadcast over every leading axis of x (batch, heads, ...).
    out_real = x_real * cos - x_imag * sin
    out_imag = x_real * sin + x_imag * cos

    # Re-interleave each (real, imag) pair and restore the original layout.
    return torch.stack((out_real, out_imag), dim=-1).reshape(x.shape)


# 示例用法
def apply_rope(x: torch.Tensor, rotary_matrix_function: Callable, seq_len: int = None) -> torch.Tensor:
    """
    对输入向量应用RoPE位置编码
    x: 输入向量
    seq_len: 不同情况下输入不同
    """
    _, x_seq_len, dim = x.shape
    t_seq_len = seq_len if seq_len is not None else x_seq_len
    cos, sin = rotary_matrix_function(t_seq_len, dim)
    return apply_rotary_embedding(x, cos, sin)

In [21]:
rotary_matrix_function = get_rotary_matrix

batch_size, seq_len, dim = 2, 10, 512
x = torch.randn(batch_size, seq_len, dim)

# Apply RoPE
x_with_rope = apply_rope(x, rotary_matrix_function)
print(f"输入形状: {x.shape}")
print(f"输出形状: {x_with_rope.shape}")

# A shape check alone would also pass for a broken rotation. RoPE rotates each feature pair,
# so per-token norms must be preserved; position 0 has angle 0, so it must be left unchanged.
torch.testing.assert_close(x_with_rope.norm(dim=-1), x.norm(dim=-1), rtol=1e-4, atol=1e-4)
torch.testing.assert_close(x_with_rope[:, 0], x[:, 0])


输入形状: torch.Size([2, 10, 512])
输出形状: torch.Size([2, 10, 512])


# 外扩

## NTK版本的外扩

In [22]:
def get_ntk_rotary_matrix(seq_len: int, dim: int, base: int = 10000, scaling_factor: float = 1.0) -> tuple[torch.Tensor, torch.Tensor]:
    """RoPE cos/sin tables with NTK-aware base scaling.

    Args:
        seq_len: number of positions to generate (the desired extended length).
        dim: feature dimension.
        base: frequency base before scaling.
        scaling_factor: ratio of the extended length to the original (training) length.

    Returns:
        ``(cos, sin)``, each of shape ``[seq_len, (dim + 1) // 2]``.
    """
    # NTK-aware scaling enlarges the base by s ** (dim / (dim - 2)). The highest frequency
    # (index 0) is unchanged, and the lowest, base ** (-(dim - 2) / dim), is slowed by
    # exactly a factor s. With dim <= 2 the only frequency is base ** 0 == 1 for every base,
    # so scaling is a no-op. Skipping it also avoids the division by zero at dim == 2.
    if dim > 2:
        effective_base = base * (scaling_factor ** (dim / (dim - 2)))
    else:
        effective_base = base

    # Per-pair frequencies with the scaled base: [(dim + 1) // 2]
    theta = 1.0 / (effective_base ** (torch.arange(0, dim, 2).float() / dim))

    # Position indices: [seq_len]
    position = torch.arange(seq_len).float()

    # angles[p, j] = position[p] * theta[j]: [seq_len, (dim + 1) // 2]
    angles = torch.outer(position, theta)

    return torch.cos(angles), torch.sin(angles)

In [23]:
train_len = 10  # original (training) context length
batch_size, seq_len, dim = 2, 20, 512  # seq_len: the extended length the tables are generated for
# Derive the NTK scaling factor (extended / original length) from the two lengths instead of
# hard-coding 2.0, so editing one of them cannot silently break the match.
rotary_matrix_function = partial(get_ntk_rotary_matrix, scaling_factor=seq_len / train_len)
x = torch.randn(batch_size, seq_len, dim)

# Apply RoPE
x_with_rope = apply_rope(x, rotary_matrix_function)
print(f"输入形状: {x.shape}")
print(f"输出形状: {x_with_rope.shape}")


输入形状: torch.Size([2, 20, 512])
输出形状: torch.Size([2, 20, 512])


## 线性插值

In [49]:
def get_linear_interpolation_rope1(seq_len: int, dim: int, target_len: int,
                                   base: int = 10000) -> tuple[torch.Tensor, torch.Tensor]:
    """Stretch RoPE tables from ``seq_len`` to ``target_len`` positions by 1-D linear interpolation.

    Target position ``k`` is mapped to the fractional source position
    ``k * (seq_len - 1) / (target_len - 1)`` (``align_corners=True``: both endpoints line up),
    and the rotation angle is interpolated there. Each angle ``p * inv_freq[j]`` is linear in
    the position ``p``, so this equals cos/sin evaluated at the rescaled position.

    The angles are interpolated, not the cos/sin values. Linearly mixing cos/sin gives pairs
    with ``cos**2 + sin**2 < 1``, so the "rotation" would shrink the vector and use the wrong
    angle. Note that Position Interpolation (Chen et al., 2023) rescales positions by
    ``seq_len / target_len``, whereas this endpoint-aligned mapping uses
    ``(seq_len - 1) / (target_len - 1)``.

    Args:
        seq_len: original (training) sequence length.
        dim: feature dimension.
        target_len: number of positions in the returned tables.
        base: frequency base.

    Returns:
        ``(cos, sin)``, each of shape ``[target_len, (dim + 1) // 2]``.
    """
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # [(dim + 1) // 2]
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # [seq_len, (dim + 1) // 2]

    # One interpolate call for all frequencies: treat each frequency as a channel,
    # [1, channels, seq_len] -> [1, channels, target_len], instead of looping per column.
    interpolated_angles = torch.nn.functional.interpolate(
        angles.t().unsqueeze(0),
        size=(target_len,),
        mode='linear',
        align_corners=True,  # first/last target positions coincide with the original endpoints
    ).squeeze(0).t().contiguous()  # [target_len, (dim + 1) // 2]

    return torch.cos(interpolated_angles), torch.sin(interpolated_angles)

In [50]:
seq_len = 10  # training length
batch_size, target_len, dim = 2, 20, 512  # target_len: the extended length to interpolate to
# Reuse target_len instead of repeating the literal 20, so the tables always match x.
rotary_matrix_function = partial(get_linear_interpolation_rope1, target_len=target_len)
x = torch.randn(batch_size, target_len, dim)

# Apply RoPE; seq_len (the training length) is what the interpolation function receives
x_with_rope = apply_rope(x, rotary_matrix_function, seq_len=seq_len)
print(f"输入形状: {x.shape}")
print(f"输出形状: {x_with_rope.shape}")

输入形状: torch.Size([2, 20, 512])
输出形状: torch.Size([2, 20, 512])


In [53]:
def get_linear_interpolation_rope2(seq_len: int, dim: int, target_len: int,
                                   base: int = 10000) -> tuple[torch.Tensor, torch.Tensor]:
    """Stretch RoPE tables from ``seq_len`` to ``target_len`` positions with one bilinear call.

    The ``[seq_len, dim//2]`` grid of rotation angles is resized to ``[target_len, dim//2]``.
    The feature axis keeps its size, so with ``align_corners=True`` it is sampled exactly, and
    only the position axis is interpolated. Target position ``k`` maps to source position
    ``k * (seq_len - 1) / (target_len - 1)``.

    The angles are interpolated, not the cos/sin values. Linearly mixing cos/sin gives pairs
    with ``cos**2 + sin**2 < 1``, so the "rotation" would shrink the vector and use the wrong
    angle.

    Args:
        seq_len: original (training) sequence length.
        dim: feature dimension.
        target_len: number of positions in the returned tables.
        base: frequency base.

    Returns:
        ``(cos, sin)``, each of shape ``[target_len, (dim + 1) // 2]``.
    """
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # [(dim + 1) // 2]
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # [seq_len, (dim + 1) // 2]

    # Add batch and channel axes for 2-D interpolation: [1, 1, seq_len, (dim + 1) // 2]
    interpolated_angles = torch.nn.functional.interpolate(
        angles.unsqueeze(0).unsqueeze(0),
        size=(target_len, angles.size(1)),
        mode='bilinear',
        align_corners=True,
    ).squeeze(0).squeeze(0)  # [target_len, (dim + 1) // 2]

    return torch.cos(interpolated_angles), torch.sin(interpolated_angles)

In [54]:
seq_len = 10  # training length
batch_size, target_len, dim = 2, 20, 512  # target_len: the extended length to interpolate to
# Reuse target_len instead of repeating the literal 20, so the tables always match x.
rotary_matrix_function = partial(get_linear_interpolation_rope2, target_len=target_len)
x = torch.randn(batch_size, target_len, dim)

# Apply RoPE; seq_len (the training length) is what the interpolation function receives
x_with_rope = apply_rope(x, rotary_matrix_function, seq_len=seq_len)
print(f"输入形状: {x.shape}")
print(f"输出形状: {x_with_rope.shape}")

输入形状: torch.Size([2, 20, 512])
输出形状: torch.Size([2, 20, 512])


## 动态NTK

In [55]:
def compute_dynamic_ntk_scaling_factor(seq_len: int, target_len: int, alpha: float = 1.0) -> float:
    """Scaling factor for dynamic NTK RoPE.

    Args:
        seq_len: original (training) sequence length; must be positive.
        target_len: sequence length the tables are being built for.
        alpha: exponent applied to the length ratio.

    Returns:
        ``max(target_len / seq_len, 1.0) ** alpha``.

    Raises:
        ValueError: if ``seq_len`` is not positive.
    """
    if seq_len <= 0:
        raise ValueError(f"seq_len must be positive, got {seq_len}")
    # Dynamic NTK only rescales once the requested length exceeds the training length.
    # A ratio below 1 would shrink the base and change the encoding of positions the model
    # was trained on, so it is clamped to 1 (no scaling).
    return max(target_len / seq_len, 1.0) ** alpha


def get_dynamic_ntk_rotary_matrix(seq_len: int, dim: int, target_len: int, base: int = 10000,
                                  alpha: float = 1.0) -> tuple[torch.Tensor, torch.Tensor]:
    """NTK-scaled RoPE tables whose scaling factor is derived from the requested length.

    Args:
        seq_len: original (training) sequence length.
        dim: feature dimension.
        target_len: sequence length to build tables for; also the number of rows returned.
        base: frequency base before NTK scaling.
        alpha: exponent applied to the length ratio when computing the scaling factor
            (see ``compute_dynamic_ntk_scaling_factor``); it is not the scaling factor itself.

    Returns:
        The ``(cos, sin)`` pair from ``get_ntk_rotary_matrix`` for ``target_len`` positions.
    """
    return get_ntk_rotary_matrix(
        target_len,
        dim,
        base,
        compute_dynamic_ntk_scaling_factor(seq_len, target_len, alpha),
    )

In [56]:
seq_len = 10  # training length
batch_size, target_len, dim = 2, 20, 512  # target_len: the extended length to build tables for
# Reuse target_len instead of repeating the literal 20, so the tables always match x.
rotary_matrix_function = partial(get_dynamic_ntk_rotary_matrix, target_len=target_len)
x = torch.randn(batch_size, target_len, dim)

# Apply RoPE; seq_len (the training length) is what the rotary-matrix function receives
x_with_rope = apply_rope(x, rotary_matrix_function, seq_len=seq_len)
print(f"输入形状: {x.shape}")
print(f"输出形状: {x_with_rope.shape}")

输入形状: torch.Size([2, 20, 512])
输出形状: torch.Size([2, 20, 512])
