# Rotary Positional Embedding

llama代码细读：https://dingfen.github.io/2023/10/30/2023-10-30-huggingface1/

自注意力机制本身是位置不敏感的，需要加上位置信息。

一般来说，绝对位置编码具有实现简单、计算速度快等优点，而相对位置编码则直接地体现了相对位置信号，跟我们的直观理解吻合，实际性能往往也更好。

RoPE：以绝对位置编码的方式实现相对位置编码。其特性为：将上下文表示与旋转矩阵相乘来编码位置，使 q、k 的内积只依赖二者的相对位置；可扩展到任意序列长度；可用于线性注意力机制；词间依赖性随相对距离增大而衰减。

参考苏剑林的博客（初读时看得不太明白）：https://kexue.fm/archives/8265

$$
\mathbf{R}_{\Theta,m} ^d x = 
\begin{pmatrix}
x_1 \\ x_2 \\ x_3 \\ x_4 \\ \vdots \\ x_{d-1} \\ x_d
\end{pmatrix}
\otimes
\begin{pmatrix}
\cos m \theta_1 \\ \cos m \theta_1 \\ \cos m \theta_2 \\ \cos m \theta_2 \\ \vdots \\ \cos m \theta_{d/2} \\ \cos m \theta_{d/2}
\end{pmatrix}
+
\begin{pmatrix} 
-x_2 \\ x_1 \\ -x_4 \\ x_3 \\ \vdots \\ -x_d \\ x_{d-1}
\end{pmatrix}
\otimes
\begin{pmatrix}
\sin m \theta_1 \\ \sin m \theta_1 \\ \sin m \theta_2 \\ \sin m \theta_2 \\ \vdots \\ \sin m \theta_{d/2} \\ \sin m \theta_{d/2}
\end{pmatrix}
$$

$$
\theta_{i} = 10000^{-2(i-1)/d}, \quad i \in [1, 2, \ldots, d/2]
$$

这个视频讲得不错：

【通俗易懂-大模型的关键技术之一：旋转位置编码rope （2）】 https://www.bilibili.com/video/BV1Tr421p7By/?share_source=copy_web&vd_source=308c6ef1763d60b08057708fbfe7c230

In [4]:
import torch

### Llama
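注：下面是 Meta 原版 Llama 代码中基于复数乘法的实现；Hugging Face transformers 的 `modeling_llama.py` 现在已不用这种写法，而是改用 cos/sin 加 `rotate_half` 的实现（与下文 ChatGLM 部分相同）。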

In [6]:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute the complex rotation factors for Llama-style RoPE.

    Args:
        dim: head dimension. Adjacent feature pairs share one frequency, so
            ``dim // 2`` frequencies are produced.
        end: number of positions ``0 .. end - 1`` to precompute.
        theta: RoPE base (10000 in the original paper).

    Returns:
        Complex tensor of shape ``[end, dim // 2]``. Entry ``[m, i]`` is
        ``cos(m * theta_i) + j * sin(m * theta_i)``, a unit-magnitude number,
        where ``theta_i = theta ** (-2i / dim)``.
    """
    even_idx = torch.arange(0, dim, 2)[: dim // 2]
    inv_freq = 1.0 / (theta ** (even_idx / dim))  # [dim / 2], theta_i
    positions = torch.arange(end, device=inv_freq.device)  # [end], m
    # angles[m, i] = m * theta_i  -> [end, dim / 2]
    angles = torch.outer(positions, inv_freq).float()
    # polar(abs=1, angle) turns each angle into the unit complex number e^{j * angle}.
    return torch.polar(torch.ones_like(angles), angles)

In [7]:
def reshape_for_broadcast(freqs_cis, x):
    """Return a view of ``freqs_cis`` that broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``. The view keeps
    axis 1 and the last axis of ``x`` and sets every other axis to size 1. For
    example, a 3-D ``x`` gives ``[1, x.shape[1], x.shape[-1]]`` and a 4-D ``x``
    gives ``[1, x.shape[1], 1, x.shape[-1]]``.
    """
    rank = x.ndim
    assert rank > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    kept_axes = (1, rank - 1)
    target_shape = [
        size if axis in kept_axes else 1 for axis, size in enumerate(x.shape)
    ]
    return freqs_cis.view(*target_shape)

In [8]:
def apply_rotary_embedding(q, k, freqs_cis):
    """Apply Llama-style (interleaved-pair) rotary position embedding to q and k.

    Each adjacent feature pair ``(x[2i], x[2i+1])`` is viewed as the complex
    number ``x[2i] + j * x[2i+1]``. It is then multiplied by
    ``freqs_cis[m, i] = e^{j * m * theta_i}``, i.e. rotated by the angle
    ``m * theta_i``, where ``m`` is the position along axis 1.

    Args:
        q, k: real tensors with the sequence on axis 1 and an even head_dim on
            the last axis. Examples: ``[batch, seq_len, head_dim]`` or
            ``[batch, seq_len, n_heads, head_dim]``.
        freqs_cis: complex tensor ``[seq_len, head_dim // 2]`` as produced by
            ``precompute_freqs_cis``.

    Returns:
        ``(q_out, k_out)`` with the same shapes and dtypes as ``q`` and ``k``.
    """
    # [..., head_dim] -> [..., head_dim // 2] complex: (x[2i], x[2i+1]) -> x[2i] + j * x[2i+1]
    q_ = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2))
    k_ = torch.view_as_complex(k.float().reshape(*k.shape[:-1], -1, 2))
    # Put freqs_cis on q's device and reshape it to [1, seq_len, (1,) head_dim // 2].
    freqs_cis = reshape_for_broadcast(freqs_cis.to(q_.device), q_)
    # A complex multiply is a 2-D rotation of each pair:
    #   (x1 + x2 j)(cos + sin j) = (x1 cos - x2 sin) + (x1 sin + x2 cos) j
    # view_as_real yields [..., head_dim // 2, 2]. flatten(-2) merges the last two
    # axes back into head_dim for any input rank. flatten(2) was only correct for
    # 3-D input; with 4-D input it merged n_heads into head_dim.
    q_out = torch.view_as_real(q_ * freqs_cis).flatten(-2)
    k_out = torch.view_as_real(k_ * freqs_cis).flatten(-2)
    return q_out.type_as(q), k_out.type_as(k)

In [46]:
# Smoke test for the Llama-style RoPE on q/k of shape [batch, seq_len, head_dim].
batch_size, seq_len, head_dim = 2, 4, 8
q = torch.randn(batch_size, seq_len, head_dim)
k = torch.randn(batch_size, seq_len, head_dim)

q_rope, k_rope = apply_rotary_embedding(
    q, k, precompute_freqs_cis(head_dim, seq_len)
)
print(q_rope.shape)  # RoPE keeps the input shape

torch.Size([4])
torch.Size([4, 4]) torch.Size([2, 4, 4])
torch.Size([2, 4, 8])


In [14]:
# The outer product of two vectors is a matrix: result[i, j] = t[i] * freqs[j]
t = torch.tensor([0, 1, 2], dtype=torch.float)
freqs = torch.outer(t, torch.tensor([10, 20, 30], dtype=torch.float))
print(freqs)

tensor([[ 0.,  0.,  0.],
        [10., 20., 30.],
        [20., 40., 60.]])


In [15]:
# torch.polar(abs, angle) = abs * cos(angle) + abs * sin(angle) * j.
# With abs = 1 every entry becomes the unit complex number e^{j * angle}.
freqs_cis = torch.polar(abs=torch.ones_like(freqs), angle=freqs)
freqs_cis

tensor([[ 1.0000+0.0000j,  1.0000+0.0000j,  1.0000+0.0000j],
        [-0.8391-0.5440j,  0.4081+0.9129j,  0.1543-0.9880j],
        [ 0.4081+0.9129j, -0.6669+0.7451j, -0.9524-0.3048j]])

In [17]:
import math

# Cross-check the entry with angle 10 from the polar demo: expect (cos 10, sin 10).
math.cos(10), math.sin(10)

(-0.8390715290764524, -0.5440211108893698)

In [25]:
# Complex multiplication rule: (a + bi)(c + di) = (ac - bd) + (bc + ad)i
# With a = b = 1 and c = d = -1: (-1 + 1) + (-1 - 1)i = -2i
complex(1.0, 1.0) * complex(-1.0, -1.0)

-2j

### ChatGLM

In [1]:
# Rotation pairs need not be adjacent dimensions: neurons are unordered, so the
# scheme does not depend on dimension order. Here dim i is paired with dim i + d/2.
def rotate_half(x):
    """Return ``cat(-x2, x1)`` along the last axis.

    ``x1`` and ``x2`` are the first and second halves of the last axis, split
    at ``d // 2``. For an odd size the second half is the longer one.
    """
    split = x.shape[-1] // 2
    first, second = x[..., :split], x[..., split:]
    return torch.cat((-second, first), dim=-1)

In [2]:
def rope(x, dim, theta=10000.0):
    """Apply half-split (rotate_half style) rotary position embedding to x.

    Feature ``i`` is rotated together with feature ``i + dim // 2`` by the angle
    ``m * theta_i``, where ``theta_i = theta ** (-2i / dim)`` and ``m`` is the
    position on axis 1. The result is computed as
    ``x * cos + rotate_half(x) * sin``.

    Args:
        x: real tensor with the sequence on axis 1 and ``dim`` features on the
            last axis. Examples: ``[batch, seq_len, dim]`` or
            ``[batch, seq_len, n_heads, dim]``.
        dim: size of the last axis of ``x`` (must be even).
        theta: RoPE base. The default of 10000 matches the previous
            hard-coded value.

    Returns:
        Tensor with the same shape as ``x``. Integer or half-precision input is
        promoted to float32 by the multiply.

    Raises:
        ValueError: if ``x`` has fewer than 3 dims or ``x.shape[-1] != dim``.
    """
    if x.ndim < 3:
        raise ValueError(f"rope expects x with at least 3 dims, got shape {tuple(x.shape)}")
    if x.shape[-1] != dim:
        raise ValueError(f"rope: x.shape[-1]={x.shape[-1]} does not match dim={dim}")
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)).to(x.device)  # [dim / 2]
    seq_len = x.shape[1]
    seq_idx = torch.arange(seq_len).float().to(x.device)  # [seq_len]
    idx_theta = torch.outer(seq_idx, inv_freq)  # [seq_len, dim / 2], angle m * theta_i
    # Duplicate so that feature i and feature i + dim/2 share the same angle.
    idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)  # [seq_len, dim]
    # Broadcast shape [1, seq_len, 1, ..., 1, dim]: keep only axis 1 and the last axis.
    # The former [None, :, :] view was only correct for 3-D x. For 4-D x it lined up
    # seq_len with n_heads, which is silently wrong when seq_len == n_heads.
    shape = [1] * x.ndim
    shape[1], shape[-1] = seq_len, dim
    cos_cached = idx_theta2.cos().view(*shape)
    sin_cached = idx_theta2.sin().view(*shape)
    return (x * cos_cached) + (rotate_half(x) * sin_cached)

In [None]:
# Compare the two RoPE implementations on the same input.
batch_size = 2
seq_len = 4
head_dim = 8
x = torch.randn(batch_size, seq_len, head_dim)

x_rope = rope(x, head_dim)
print(x_rope.shape)  # output shape, same as x

# The raw outputs differ, and that is expected. Both rotate feature pairs by the same
# angles m * theta_i, but they pair features differently:
#   - apply_rotary_embedding (Llama) rotates adjacent pairs (2i, 2i + 1);
#   - rope (ChatGLM, rotate_half) rotates pairs (i, i + head_dim // 2).
# At position 0 the angle is 0, so both act as the identity there; elsewhere they differ.
# Reordering the features (even dims first, then odd dims) maps one pairing onto the
# other, after which the two results agree.
q_rope, k_rope = apply_rotary_embedding(x, x, precompute_freqs_cis(head_dim, seq_len))
perm = torch.cat([torch.arange(0, head_dim, 2), torch.arange(1, head_dim, 2)])
assert torch.allclose(rope(x[..., perm], head_dim), q_rope[..., perm], atol=1e-5)
print("equal up to feature permutation")
print(x_rope, q_rope, k_rope)

tensor([[[ 0.1419,  0.0721, -0.6689, -0.9504, -1.3637,  1.5168, -1.0835,
           0.1435],
         [ 0.2478, -0.3760,  0.2059, -0.7293,  0.5550,  0.1493, -1.0972,
          -0.1126],
         [ 0.0925, -0.9905, -0.6838, -0.3321, -1.2511, -0.6986,  0.2962,
          -1.1512],
         [ 0.9016,  0.1168,  1.4854,  0.6318,  0.6785, -0.6640,  0.7347,
          -1.5072]],

        [[ 0.7554, -0.9979, -1.1566,  0.1857,  0.9541, -0.3422,  0.3528,
           0.8431],
         [ 3.0744,  0.3404,  0.2961,  0.4119,  0.1635,  0.2764, -0.7737,
           0.3776],
         [-0.2102,  1.1085, -0.9568, -0.0940,  0.2554, -0.7549,  1.1950,
          -0.0265],
         [ 0.4608, -1.0372,  2.6274, -2.1610, -1.0221, -0.1619,  0.3776,
           1.2964]]]) tensor([[[ 0.1419,  0.0721, -0.6689, -0.9504, -1.3637,  1.5168, -1.0835,
           0.1435],
         [ 0.6269,  0.3115,  0.2667, -0.7063,  0.0894,  0.1870, -1.0991,
          -0.1130],
         [ 1.4984, -0.6077, -0.5978, -0.4624,  0.4462, -0.4791,  0

### 冒泡排序
公司来人面试，忍不住小试牛刀

In [1]:
# Bubble sort: repeatedly sweep the list, compare adjacent elements and swap any
# pair that is out of order. After pass i, the largest i + 1 items sit in their
# final places at the end of the list.
def bubble_sort(arr):
    """Sort ``arr`` in place in ascending order (stable) and return ``None``.

    Stops as soon as a full pass makes no swaps. Already-sorted input therefore
    costs a single O(n) pass instead of always doing O(n^2) comparisons.
    """
    n = len(arr)
    for i in range(n - 1):
        swapped = False
        # The last i elements are already in their final positions.
        for j in range(n - i - 1):
            if arr[j] > arr[j + 1]:
                arr[j], arr[j + 1] = arr[j + 1], arr[j]
                swapped = True
        if not swapped:
            break

In [3]:
# Test bubble sort: it must sort the list in place and return None.
arr = [3, 5, 1, 2, 4]
expected = sorted(arr)
assert bubble_sort(arr) is None
assert arr == expected, arr
print(arr)

[1, 2, 3, 4, 5]
