# 位置编码

学习顺序：
1、位置编码介绍：https://www.bilibili.com/video/BV1xR1RY9ECm/?spm_id_from=333.337.search-card.all.click&vd_source=071b23b9c7175dbaf674c65294124341  
2、transformer位置编码介绍：https://www.bilibili.com/video/BV1AD421g7hs/?spm_id_from=333.337.search-card.all.click&vd_source=071b23b9c7175dbaf674c65294124341  
3、RoPE：视频一： https://www.bilibili.com/video/BV12x42127Pb?spm_id_from=333.788.videopod.sections&vd_source=071b23b9c7175dbaf674c65294124341  
视频二：https://www.bilibili.com/video/BV1F1421B7iv/?spm_id_from=333.337.search-card.all.click&vd_source=071b23b9c7175dbaf674c65294124341
  
  

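优势：与可学习的绝对位置编码（见第 2 节，每个位置对应一个训练出来的向量）相比，RoPE 具有更好的外推性，即对于超出训练数据长度的序列，RoPE 仍然能够给出位置编码。这是因为 RoPE 的旋转角度由正弦和余弦函数直接计算，任意位置都能算出对应的旋转；并且旋转后 q·k 的内积只依赖两个 token 的相对距离，所以 RoPE 在效果上本身就是一种相对位置编码。而可学习的位置编码只为训练中见过的位置学到了向量。例如，如果一个模型在训练时只使用了 512 个 token 的文本（位置表大小为 512），那么在预测时如果输入超过 512 个 token，模型就没有对应的位置向量，无法正确处理。这就限制了大模型在处理长文本或多轮对话等任务时的效果。需要注意的是，RoPE 的外推并非无限：输入远超训练长度时效果通常仍会下降，实践中常配合位置插值（Position Interpolation）或 NTK-aware 缩放等方法来扩展上下文长度。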

## 1、三角函数式（正弦/余弦）绝对位置编码  
该方法的位置编码由固定的正弦/余弦公式直接计算得到，不含可训练参数，训练时不会被学习或更新。

In [3]:
# 绝对位置编码，transformer中的三角式
import math
import torch

seq_len = 50           # sequence length (block_size)
embedding_dim = 512    # embedding size of a single token


def get_pe(pos, j, dim):
    """Return the (sin, cos) pair of the sinusoidal encoding for one frequency.

    Implements the Transformer formula
        PE(pos, 2j)   = sin(pos / 10000^(2j / dim))
        PE(pos, 2j+1) = cos(pos / 10000^(2j / dim))

    Args:
        pos: token position.
        j: frequency index.
        dim: embedding dimension.

    Returns:
        Tuple ``(sin, cos)`` of Python floats.
    """
    angle = pos / (10000 ** (2 * j / dim))
    return math.sin(angle), math.cos(angle)


def build_sinusoidal_pe(seq_len, dim):
    """Build the full ``(seq_len, dim)`` sinusoidal position-encoding table.

    Column ``2j`` holds ``sin`` and column ``2j + 1`` holds ``cos`` of the
    angle computed by ``get_pe(pos, j, dim)``. When ``dim`` is odd the last
    column is a ``sin`` column, so every entry is filled (a loop over
    ``range(dim // 2)`` into ``torch.empty`` would leave it uninitialised).

    Args:
        seq_len: number of positions (rows).
        dim: embedding dimension (columns).

    Returns:
        Tensor of shape ``(seq_len, dim)`` in the default float dtype.
    """
    # Work in float64 (as math.sin/cos in get_pe do) and cast once at the end,
    # so the table agrees with the scalar formula to float32 precision.
    positions = torch.arange(seq_len, dtype=torch.float64)[:, None]   # (seq_len, 1)
    j = torch.arange((dim + 1) // 2, dtype=torch.float64)              # ceil(dim / 2) frequencies
    angles = positions / (10000.0 ** (2 * j / dim))                    # (seq_len, ceil(dim / 2))

    table = torch.zeros(seq_len, dim, dtype=torch.float64)
    table[:, 0::2] = torch.sin(angles)                  # ceil(dim / 2) even columns
    table[:, 1::2] = torch.cos(angles[:, : dim // 2])   # floor(dim / 2) odd columns
    return table.to(torch.get_default_dtype())


pe = build_sinusoidal_pe(seq_len, embedding_dim)

## 2、可学习式绝对位置编码

In [4]:
from torch import nn

# Learnable absolute position table: one trainable vector per position,
# looked up by integer position index. nn.Embedding only has rows for
# indices 0..seq_len-1, so longer sequences have no position vector.
wpe = nn.Embedding(num_embeddings=seq_len, embedding_dim=embedding_dim)
pos = torch.arange(seq_len, dtype=torch.long)  # position indices, shape (T,)
pos_emb = wpe(pos)                              # shape (T, embedding_dim)

# Alternative using a raw parameter table instead of nn.Embedding:
#   wpe_para = nn.Parameter(torch.randn(seq_len, embedding_dim))
#   pos_emb = wpe_para[pos]

## 3、旋转位置编码

In [None]:
# 无多头注意力机制版

import torch
import torch.nn as nn
import math
from torch.nn import functional as F

def rotate_half(x):
    """Rotate the halves of the last dimension: (x1, x2) -> (-x2, x1).

    The last dimension is split with ``chunk(2)``. For an even size this is
    the same as slicing at ``x.shape[-1] // 2``; for an odd size ``chunk``
    puts the extra element in the first half.
    """
    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat([second_half.neg(), first_half], dim=-1)

def apply_rotary_pos_emb(q, k, sin, cos):
    """Rotate query and key tensors by the angles encoded in ``sin``/``cos``.

    Each tensor ``t`` becomes ``t * cos + rotate_half(t) * sin``; ``sin`` and
    ``cos`` must broadcast against ``q`` and ``k``.

    Returns:
        Tuple ``(q_rot, k_rot)``.
    """
    def _rotate(t):
        # Same formula for queries and keys.
        return t * cos + rotate_half(t) * sin

    return _rotate(q), _rotate(k)

class RotaryPositionEmbedding(nn.Module):
    """Rotary position embedding (RoPE) applied to query/key tensors.

    Args:
        dim: size of the last (feature) axis of q and k; presumably even,
            since ``rotate_half`` splits it into two halves.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # One frequency per coordinate pair: 10000^(-2i / dim).
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

    def _get_sin_cos(self, seq_len, device):
        """Return (sin, cos) tables for positions 0..seq_len-1."""
        positions = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        # Angle for position m and frequency f is m * inv_freq[f].
        angles = positions[:, None] * self.inv_freq[None, :]
        # Duplicate so both halves seen by rotate_half use the same angles.
        angles = angles.repeat(1, 2)
        return angles.sin(), angles.cos()

    def forward(self, q, k):
        """Rotate q and k laid out as (batch, seq_len, heads, dim).

        k is rotated with q's sequence length, so it presumably must share it.
        """
        batch_size, seq_len, _, _ = q.shape
        sin, cos = self._get_sin_cos(seq_len, q.device)
        # (seq_len, dim) -> (batch, seq_len, 1, dim); the size-1 axis broadcasts over heads.
        sin = sin.unsqueeze(0).unsqueeze(2).expand(batch_size, -1, -1, -1)
        cos = cos.unsqueeze(0).unsqueeze(2).expand(batch_size, -1, -1, -1)
        return apply_rotary_pos_emb(q, k, sin, cos)

class AttentionWithRoPE(nn.Module):
    """Single-head self-attention with rotary position embedding on q and k.

    Args:
        dim: feature size of the input ``x``.
        head_dim: size of the q/k/v projections and of the output; presumably
            even, since RoPE's ``rotate_half`` splits it into two halves.
    """

    def __init__(self, dim, head_dim):
        super().__init__()
        self.head_dim = head_dim
        self.Wq = nn.Linear(dim, head_dim, bias=False)
        self.Wk = nn.Linear(dim, head_dim, bias=False)
        self.Wv = nn.Linear(dim, head_dim, bias=False)
        self.rotary = RotaryPositionEmbedding(head_dim)

    def forward(self, x):
        """Attend across the sequence axis of ``x``.

        Args:
            x: tensor of shape (batch, seq_len, dim).

        Returns:
            Tensor of shape (batch, seq_len, head_dim).
        """
        batch_size, seq_len, _ = x.shape

        # Project and add a singleton head axis: (batch, seq_len, 1, head_dim),
        # the layout RotaryPositionEmbedding.forward unpacks.
        q = self.Wq(x).view(batch_size, seq_len, 1, self.head_dim)
        k = self.Wk(x).view(batch_size, seq_len, 1, self.head_dim)
        v = self.Wv(x).view(batch_size, seq_len, 1, self.head_dim)

        # Apply rotary position encoding.
        q, k = self.rotary(q, k)

        # Queries (i) and keys (j) need distinct sequence labels so every
        # position is scored against every other. Sharing one label for the
        # sequence axis would give each token a 1x1 score against itself,
        # making softmax 1.0 and the output just v.
        scores = torch.einsum("bihd,bjhd->bhij", q, k) / math.sqrt(self.head_dim)
        attn = torch.softmax(scores, dim=-1)                # normalise over keys j
        output = torch.einsum("bhij,bjhd->bihd", attn, v)   # (batch, seq_len, 1, head_dim)
        # reshape (not view): einsum may return a non-contiguous result.
        return output.reshape(batch_size, seq_len, self.head_dim)

# 示例使用
if __name__ == "__main__":
    # Demo configuration.
    dim, head_dim = 128, 64        # model width, attention head width
    seq_len, batch_size = 10, 2    # sequence length, batch size

    attn = AttentionWithRoPE(dim=dim, head_dim=head_dim)
    x = torch.randn((batch_size, seq_len, dim))   # random input
    output = attn(x)                              # forward pass

    print("输入形状:", x.shape)
    print("输出形状:", output.shape)
    print("输出示例:")
    print(output[0, 0, :10])  # first 10 features of sample 0 at position 0

输入形状: torch.Size([2, 10, 128])
输出形状: torch.Size([2, 10, 64])
输出示例:
tensor([-0.1196, -0.7089,  0.0172, -0.8604, -0.4860, -0.5124, -0.6146, -0.3723,
         0.3353,  0.0694], grad_fn=<SliceBackward0>)


In [14]:
# 有多头注意力机制版

import torch
import torch.nn as nn
import math

def rotate_half(x):
    """Map the last-axis halves (x1, x2) to (-x2, x1), split via ``chunk(2)``."""
    lower, upper = x.chunk(2, dim=-1)
    return torch.cat((torch.neg(upper), lower), dim=-1)

def apply_rotary_pos_emb(q, k, sin, cos):
    """Return ``(q_rot, k_rot)`` where each is ``t * cos + rotate_half(t) * sin``.

    ``sin`` and ``cos`` must broadcast against ``q`` and ``k``.
    """
    rotated = tuple(t * cos + rotate_half(t) * sin for t in (q, k))
    return rotated

class RotaryPositionEmbedding(nn.Module):
    """Rotary position embedding for multi-head q/k tensors.

    Args:
        dim: per-head feature size; presumably even, since ``rotate_half``
            splits it into two halves.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        freq_exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (10000 ** freq_exponents)  # one frequency per coordinate pair
        self.register_buffer("inv_freq", inv_freq)

    def _get_sin_cos(self, seq_len, device):
        """Return (sin, cos) tables for positions 0..seq_len-1."""
        positions = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        freqs = torch.outer(positions, self.inv_freq)  # angle = position * frequency
        emb = freqs.repeat(1, 2)                        # same angles for both halves
        return emb.sin(), emb.cos()

    def forward(self, q, k):
        """Rotate q and k of shape (batch, seq_len, num_heads, head_dim)."""
        batch_size, seq_len, num_heads, _ = q.size()
        # Broadcast the (seq_len, head_dim) tables to q's full 4-D shape.
        sin, cos = (
            table.view(1, seq_len, 1, -1).expand(batch_size, -1, num_heads, -1)
            for table in self._get_sin_cos(seq_len, q.device)
        )
        return apply_rotary_pos_emb(q, k, sin, cos)

class MultiHeadAttentionWithRoPE(nn.Module):
    """Multi-head self-attention with rotary position embedding on q and k.

    Args:
        dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads (default 4). Each head has width
            ``dim // num_heads``, presumably even for RoPE's half-rotation.
    """

    def __init__(self, dim, num_heads=4):
        super().__init__()
        assert dim % num_heads == 0, "dim必须能被num_heads整除"

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        # Bias-free q/k/v projections (dim -> dim), later split into heads.
        self.Wq = nn.Linear(dim, dim, bias=False)
        self.Wk = nn.Linear(dim, dim, bias=False)
        self.Wv = nn.Linear(dim, dim, bias=False)
        # RoPE is applied per head, so it is sized to head_dim.
        self.rotary = RotaryPositionEmbedding(self.head_dim)
        # Output projection applied after the heads are merged.
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Map x of shape (batch, seq_len, dim) to (batch, seq_len, dim)."""
        batch_size, seq_len, _ = x.shape
        per_head = (batch_size, seq_len, self.num_heads, self.head_dim)

        # (batch, seq_len, heads, head_dim): the layout the rotary module expects.
        q = self.Wq(x).view(per_head)
        k = self.Wk(x).view(per_head)
        v = self.Wv(x).view(per_head)
        q, k = self.rotary(q, k)

        # Heads before the sequence axis: (batch, heads, seq_len, head_dim).
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))

        scale = math.sqrt(self.head_dim)
        weights = torch.softmax(q @ k.transpose(-1, -2) / scale, dim=-1)
        context = weights @ v  # (batch, heads, seq_len, head_dim)

        # Merge heads back to (batch, seq_len, dim) and project.
        merged = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.dim)
        return self.out_proj(merged)

# 示例使用
if __name__ == "__main__":
    # Demo configuration: total width, heads, sequence length, batch size.
    dim, num_heads, seq_len, batch_size = 128, 4, 10, 2

    attn = MultiHeadAttentionWithRoPE(dim=dim, num_heads=num_heads)
    x = torch.randn((batch_size, seq_len, dim))   # random input
    output = attn(x)                              # (batch_size, seq_len, dim)

    print("输入形状:", x.shape)
    print("输出形状:", output.shape)
    print("输出示例:")
    print(output[0, 0, :10])  # first 10 features of sample 0 at position 0

输入形状: torch.Size([2, 10, 128])
输出形状: torch.Size([2, 10, 128])
输出示例:
tensor([-0.0527, -0.1726,  0.0473,  0.1299,  0.0536, -0.1116, -0.1719,  0.0064,
        -0.0302,  0.1980], grad_fn=<SliceBackward0>)
