## 1. 多头注意力机制

**为什么多头注意力（Multi-Head Attention）有效？**

多头注意力通过并行计算多个独立的注意力头，让模型从不同角度（如局部依赖、长程依赖、语法结构等）学习输入序列的多样化特征，从而提升模型的表达能力和泛化能力。

In [None]:

import torch
import torch.nn as nn 


class MHA(nn.Module):
    """Minimal multi-head self-attention (no output projection).

    Args:
        hidden_dim: model width; must be divisible by ``head`` (checked in
            ``forward``). NOTE(review): the defaults 5/2 do not satisfy this.
        head: number of attention heads.
        **kwargs: accepted and ignored.
    """

    def __init__(self, hidden_dim=5, head=2, **kwargs):
        super().__init__()
        # One fused projection produces Q, K and V in a single matmul.
        self.qkv_proj = nn.Linear(hidden_dim, 3 * hidden_dim)
        self.hidden_dim = hidden_dim
        self.head = head

    def softmax(self, x):
        """Numerically stable softmax over the last dimension.

        Subtracting the row maximum before ``exp`` avoids overflow without
        changing the result.
        """
        x_max = torch.max(x, dim=-1, keepdim=True).values
        x_exp = torch.exp(x - x_max)
        return x_exp / torch.sum(x_exp, dim=-1, keepdim=True)

    def forward(self, x, attn_mask=None):
        """Apply multi-head self-attention.

        Args:
            x: input of shape (bs, seq_len, hidden_dim).
            attn_mask: optional additive mask broadcastable to
                (bs, head, seq_len, seq_len), e.g. 0 for allowed positions and
                ``-inf`` for blocked ones. Omitted -> unmasked attention.

        Returns:
            Tensor of shape (bs, seq_len, hidden_dim): the per-head outputs
            concatenated along the last dimension.
        """
        bs, l, dim = x.shape
        assert dim == self.hidden_dim, "input last dim must equal hidden_dim"
        # Fail with a clear message instead of an opaque reshape error below.
        assert dim % self.head == 0, "hidden_dim must be divisible by head"
        head_dim = dim // self.head

        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        # (bs, seq_len, hidden_dim) -> (bs, head, seq_len, head_dim)
        q = q.reshape(bs, l, self.head, head_dim).transpose(1, 2)
        k = k.reshape(bs, l, self.head, head_dim).transpose(1, 2)
        v = v.reshape(bs, l, self.head, head_dim).transpose(1, 2)

        # Scaled dot-product scores: (bs, head, seq_len, seq_len)
        scores = torch.matmul(q, k.transpose(2, 3)) / (head_dim ** 0.5)
        if attn_mask is not None:
            scores = scores + attn_mask

        weights = self.softmax(scores)
        atten = torch.matmul(weights, v)  # (bs, head, seq_len, head_dim)
        # Merge heads back: (bs, seq_len, head, head_dim) -> (bs, seq_len, hidden_dim)
        return atten.transpose(1, 2).reshape(bs, l, dim)
        
        
        


x = torch.randn((4, 10, 6))  # input: (batch=4, seq_len=10, hidden_dim=6)
mha = MHA(head=2, hidden_dim=6)  # two heads of width 3
out = mha(x)  # expected shape: (4, 10, 6)

torch.Size([4, 10, 18])
18


## 2. 为什么稀疏注意力（Sparse Attention）在减少计算量、提高效率的同时，效果不会显著下降？

主要原因可以归结为以下几点：

1. 在标准注意力中，大部分位置的注意力权重接近于零，只有少数位置对当前token有显著贡献。
2. 稀疏注意力可视为对稠密注意力矩阵的稀疏近似：只保留少数权重较大的位置，丢弃接近零的部分（注意这与 Linformer 等方法采用的低秩近似是不同的思路）。当被丢弃的权重本身很小时，近似误差也较小，因此在许多任务中性能损失有限。

## 3. 局部窗口类稀疏机制

每个 Query 仅与自身及前后各 k 个相邻位置（最多 2k+1 个 Key）计算注意力，使注意力分数的计算量和显存占用从 O(n²) 降为 O(n·k)。

In [None]:
import torch
import torch.nn as nn 
import torch.nn.functional as F

class EfficientBandMHA(nn.Module):
    """Banded (sliding-window) multi-head self-attention.

    Query position ``i`` attends only to key positions ``j`` with
    ``|i - j| <= band_width``. Attention scores are computed only inside that
    window, so score memory is O(seq_len * (2 * band_width + 1)) per head
    instead of O(seq_len ** 2).

    Args:
        hidden_dim: model width; must be divisible by ``head``.
        head: number of attention heads.
        band_width: number of neighbours attended on each side (>= 0).
        **kwargs: accepted and ignored.
    """

    def __init__(self, hidden_dim=5, head=4, band_width=2, **kwargs):
        super().__init__()
        # Validate first so a bad configuration fails before any layer is built.
        # NOTE(review): the default hidden_dim=5 / head=4 does not pass this check.
        assert hidden_dim % head == 0, "hidden_dim must be divisible by head"
        if band_width < 0:
            raise ValueError("band_width must be non-negative")

        self.hidden_dim = hidden_dim
        self.head = head
        self.head_dim = hidden_dim // head
        self.band_width = band_width  # each position attends band_width neighbours per side

        self.qkv_proj = nn.Linear(hidden_dim, 3 * hidden_dim)  # fused Q/K/V projection
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        """Apply banded self-attention to ``x`` of shape (bs, seq_len, hidden_dim).

        Returns a tensor of the same shape.
        """
        bs, seq_len, _ = x.shape

        qkv = self.qkv_proj(x).reshape(bs, seq_len, 3, self.head, self.head_dim)
        # Each of q/k/v: (bs, head, seq_len, head_dim)
        q, k, v = (t.transpose(1, 2) for t in qkv.unbind(2))

        # Window slots beyond the sequence would all be masked, so clamp the band.
        w = min(self.band_width, max(seq_len - 1, 0))
        window = 2 * w + 1

        # Zero-pad the sequence axis by w on both sides, then take sliding windows
        # of length 2w+1: slot s of row i holds the key/value at position i - w + s.
        k_win = F.pad(k, (0, 0, w, w)).unfold(2, window, 1)  # (bs, head, seq_len, head_dim, window)
        v_win = F.pad(v, (0, 0, w, w)).unfold(2, window, 1)

        # Band-only scores: (bs, head, seq_len, window)
        scores = torch.einsum("bhld,bhldw->bhlw", q, k_win) / (self.head_dim ** 0.5)

        # Slots that landed in the zero padding are outside [0, seq_len); exclude
        # them so the softmax runs over exactly the valid band (the diagonal slot
        # is always valid, so no row is fully masked).
        offsets = torch.arange(window, device=x.device) - w
        key_pos = torch.arange(seq_len, device=x.device)[:, None] + offsets[None, :]
        outside = (key_pos < 0) | (key_pos >= seq_len)  # (seq_len, window)
        scores = scores.masked_fill(outside, float("-inf"))

        weights = F.softmax(scores, dim=-1)
        attn = torch.einsum("bhlw,bhldw->bhld", weights, v_win)  # (bs, head, seq_len, head_dim)

        # Merge heads (head-major along the feature axis) and project.
        attn = attn.transpose(1, 2).reshape(bs, seq_len, self.hidden_dim)
        return self.out_proj(attn)


# Demo: run banded attention on a long sequence and report the theoretical
# saving in attention-score memory versus a dense seq_len x seq_len matrix.
if __name__ == "__main__":
    # A long sequence makes the band-vs-dense difference visible.
    bs, seq_len, hidden_dim = 4, 1024, 512
    # The original comment promised a GPU test but always ran on CPU; use the
    # GPU when one is available, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn(bs, seq_len, hidden_dim, device=device)

    # Banded attention with 3 neighbours on each side.
    band_width = 3
    efficient_band_mha = EfficientBandMHA(
        hidden_dim=hidden_dim,
        head=8,
        band_width=band_width,
    ).to(device)

    # Forward pass only; no gradients needed.
    with torch.no_grad():
        out = efficient_band_mha(x)
        print(f"Input shape: {x.shape}")
        print(f"Output shape: {out.shape}")
        # Each row keeps at most 2*band_width + 1 scores instead of seq_len.
        print(f"理论显存节省比例: {1 - (2*band_width + 1)/seq_len:.2%}")
    

Input shape: torch.Size([4, 1024, 512])
Output shape: torch.Size([4, 1024, 512])
理论显存节省比例: 99.32%


## 4. 多查询注意力机制
MQA 的核心特点是所有注意力头共享同一组 Key 和 Value，只保留各自独立的 Query。这样可以减少 K/V 投影的参数量，更重要的是将推理时 KV 缓存的显存占用降为标准多头注意力的 1/头数，从而降低解码阶段的显存和内存带宽开销；注意力分数本身的计算量与标准多头注意力基本相同。

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MQA(nn.Module):
    """Multi-query attention: per-head queries, one key/value head shared by all heads.

    Args:
        hidden_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of query heads.
        **kwargs: accepted and ignored.
    """

    def __init__(self, hidden_dim=512, num_heads=8, **kwargs):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        # Queries: a full hidden_dim (= num_heads * head_dim) projection, one slice per head.
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        # Keys/values: a single head_dim-wide projection shared by every head.
        self.k_proj = nn.Linear(hidden_dim, self.head_dim)
        self.v_proj = nn.Linear(hidden_dim, self.head_dim)

        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"

    def forward(self, x, attn_mask=None):
        """Apply multi-query self-attention.

        Args:
            x: input of shape (bs, seq_len, hidden_dim).
            attn_mask: optional additive mask broadcastable to
                (bs, num_heads, seq_len, seq_len), added to the scores before
                softmax (e.g. ``-inf`` at blocked positions).

        Returns:
            Tensor of shape (bs, seq_len, hidden_dim).
        """
        bs, seq_len, dim = x.shape
        assert dim == self.hidden_dim

        # (bs, seq_len, hidden_dim) -> (bs, num_heads, seq_len, head_dim)
        q = self.q_proj(x).view(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Shared K/V get a singleton head axis: (bs, 1, seq_len, head_dim).
        # torch.matmul broadcasts it across all query heads, so unlike
        # .repeat() no per-head copy of K/V is materialised -- which is the
        # memory advantage MQA exists for.
        k = self.k_proj(x).unsqueeze(1)
        v = self.v_proj(x).unsqueeze(1)

        # (bs, num_heads, seq_len, seq_len)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if attn_mask is not None:
            attn_scores = attn_scores + attn_mask

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)  # (bs, num_heads, seq_len, head_dim)

        # Merge heads: (bs, seq_len, num_heads, head_dim) -> (bs, seq_len, hidden_dim)
        attn_output = attn_output.transpose(1, 2).contiguous().view(bs, seq_len, self.hidden_dim)
        return self.out_proj(attn_output)


# Smoke test: check MQA output shapes, then compare its parameter count with
# PyTorch's nn.MultiheadAttention at the same width and head count.
if __name__ == "__main__":
    batch_size, seq_length, hidden_dim, num_heads = 4, 128, 512, 8

    # Random input first, then the module (same RNG consumption order as before).
    x = torch.randn(batch_size, seq_length, hidden_dim)
    mqa = MQA(hidden_dim=hidden_dim, num_heads=num_heads)

    output = mqa(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    assert output.shape == x.shape, "Output shape mismatch"

    def count_parameters(model):
        """Total number of trainable scalars in ``model``."""
        total = 0
        for param in model.parameters():
            if param.requires_grad:
                total += param.numel()
        return total

    mqa_params = count_parameters(mqa)

    class StandardMHA(nn.Module):
        """Thin wrapper over nn.MultiheadAttention used as the parameter-count baseline."""

        def __init__(self, hidden_dim, num_heads):
            super().__init__()
            self.mha = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)

        def forward(self, x):
            attn_output, _ = self.mha(x, x, x)
            return attn_output

    standard_mha = StandardMHA(hidden_dim, num_heads)
    standard_params = count_parameters(standard_mha)

    print(f"MQA参数数量: {mqa_params}")
    print(f"标准MHA参数数量: {standard_params}")
    print(f"参数减少比例: {1 - mqa_params/standard_params:.2%}")
    

Input shape: torch.Size([4, 128, 512])
Output shape: torch.Size([4, 128, 512])
MQA参数数量: 590976
标准MHA参数数量: 1050624
参数减少比例: 43.75%


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class KVCache:
    """Growable key/value cache for autoregressive decoding.

    Keys and values are kept in ``[batch, n_heads, capacity, head_dim]``
    buffers; the first ``cur_len`` positions along dim 2 are valid.

    Args:
        max_size: if truthy, keep at most this many positions as a sliding
            window (oldest positions are dropped first). ``None`` (or 0) means
            unbounded: capacity starts at 100 (or the first update's length, if
            larger) and doubles whenever it runs out.
    """

    def __init__(self, max_size=None):
        self.k_cache = None  # key buffer, [batch, n_heads, capacity, head_dim]
        self.v_cache = None  # value buffer, same shape as k_cache
        self.cur_len = 0     # number of valid cached positions (never exceeds capacity)
        self.max_size = max_size  # sliding-window length; None/0 = unbounded

    @staticmethod
    def _allocate(like, size):
        """Uninitialised buffer shaped like ``like`` but with ``size`` positions on dim 2."""
        batch, n_heads, _, head_dim = like.shape
        return torch.empty(
            (batch, n_heads, size, head_dim),
            device=like.device,
            dtype=like.dtype,
        )

    def update(self, k, v):
        """Append new positions and return every cached key/value.

        Args:
            k: new keys, ``[batch, n_heads, new_len, head_dim]``; ``new_len`` is
                typically 1 during decoding but may be larger (e.g. a prompt).
            v: new values, same shape as ``k``.

        Returns:
            ``(keys, values)``, each ``[batch, n_heads, cur_len, head_dim]``.
        """
        new_len = k.shape[2]

        if self.k_cache is None:
            init_size = self.max_size if self.max_size else max(100, new_len)
            self.k_cache = self._allocate(k, init_size)
            self.v_cache = self._allocate(v, init_size)

        needed = self.cur_len + new_len

        if self.max_size and needed > self.max_size:
            # Sliding window: keep only the most recent max_size positions.
            # cur_len is pinned at max_size so it always matches the buffer.
            self.k_cache = torch.cat(
                [self.k_cache[:, :, :self.cur_len], k], dim=2
            )[:, :, -self.max_size:].contiguous()
            self.v_cache = torch.cat(
                [self.v_cache[:, :, :self.cur_len], v], dim=2
            )[:, :, -self.max_size:].contiguous()
            self.cur_len = self.max_size
            return self.k_cache, self.v_cache

        capacity = self.k_cache.shape[2]
        if needed > capacity:
            # Unbounded mode: double until the new positions fit (amortised O(1) appends).
            new_size = capacity
            while new_size < needed:
                new_size *= 2
            grown_k = self._allocate(self.k_cache, new_size)
            grown_v = self._allocate(self.v_cache, new_size)
            grown_k[:, :, :self.cur_len] = self.k_cache[:, :, :self.cur_len]
            grown_v[:, :, :self.cur_len] = self.v_cache[:, :, :self.cur_len]
            self.k_cache = grown_k
            self.v_cache = grown_v

        self.k_cache[:, :, self.cur_len:needed] = k
        self.v_cache[:, :, self.cur_len:needed] = v
        self.cur_len = needed

        return self.k_cache[:, :, :self.cur_len], self.v_cache[:, :, :self.cur_len]

    def reset(self):
        """Drop all cached state; the next update re-allocates buffers."""
        self.k_cache = None
        self.v_cache = None
        self.cur_len = 0


class AttentionWithCache(nn.Module):
    """Causal multi-query self-attention with optional KV caching.

    Every head has its own query projection, while a single key/value head
    (width ``head_dim``) is shared by all heads (MQA).
    """

    def __init__(self, hidden_dim, n_heads):
        super().__init__()
        assert hidden_dim % n_heads == 0, "hidden_dim must be divisible by n_heads"
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.head_dim = hidden_dim // n_heads

        # MQA: per-head queries, one K/V head shared by all heads.
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, self.head_dim)
        self.v_proj = nn.Linear(hidden_dim, self.head_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, cache=None):
        """Attend causally over cached positions plus the new positions in ``x``.

        Args:
            x: (batch, seq_len, hidden_dim) hidden states of the *new* positions.
            cache: optional object whose ``update(k, v)`` appends keys/values and
                returns all cached ``(k, v)`` along dim 2. Every new position is
                appended. NOTE(review): assumes the cache keeps at least
                ``seq_len`` positions; a smaller sliding window would fully mask
                early query rows.

        Returns:
            (batch, seq_len, hidden_dim).
        """
        batch, seq_len, _ = x.shape

        q = self.q_proj(x).view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        # Keep K/V as one shared head, (batch, 1, seq_len, head_dim): caching it
        # un-repeated stores n_heads times less data, and matmul broadcasts it
        # across the query heads below.
        k = self.k_proj(x).view(batch, seq_len, 1, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch, seq_len, 1, self.head_dim).transpose(1, 2)

        if cache is not None and seq_len > 0:
            # Cache *every* new position (not only the last one), one at a time
            # because KVCache.update is documented for single-token inputs.
            for t in range(seq_len):
                k_all, v_all = cache.update(k[:, :, t:t + 1], v[:, :, t:t + 1])
            k, v = k_all, v_all

        # (batch, n_heads, seq_len, total_len)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # The new queries sit at absolute positions past_len .. total_len - 1, so
        # query i may see keys 0 .. past_len + i; offset the diagonal accordingly
        # (without a cache past_len == 0 and this is the usual causal mask).
        total_len = k.size(2)
        past_len = total_len - seq_len
        causal_mask = torch.triu(
            torch.ones(seq_len, total_len, device=x.device), diagonal=past_len + 1
        ).bool()
        attn_scores = attn_scores.masked_fill(causal_mask[None, None, :, :], float('-inf'))

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)  # (batch, n_heads, seq_len, head_dim)

        # Merge heads back to (batch, seq_len, hidden_dim).
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch, seq_len, self.hidden_dim)
        return self.out_proj(attn_output)


class TransformerLayer(nn.Module):
    """Pre-LayerNorm transformer block: cached self-attention, then a GELU MLP.

    Each sub-layer normalises its input first and adds its result back onto the
    residual stream.
    """

    def __init__(self, hidden_dim, n_heads):
        super().__init__()
        self.attn = AttentionWithCache(hidden_dim, n_heads)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        inner_dim = 4 * hidden_dim  # MLP expansion width
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, inner_dim),
            nn.GELU(),
            nn.Linear(inner_dim, hidden_dim),
        )

    def forward(self, x, cache=None):
        """Run attention and MLP sub-layers; ``cache`` is forwarded to attention."""
        attn_out = self.attn(self.norm1(x), cache)
        hidden = x + attn_out
        mlp_out = self.ffn(self.norm2(hidden))
        return hidden + mlp_out


class SimpleTransformer(nn.Module):
    """Token embedding -> stack of TransformerLayer -> LayerNorm -> vocabulary logits.

    No positional encoding is applied anywhere in this model.
    """

    def __init__(self, vocab_size, hidden_dim=128, n_layers=2, n_heads=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.layers = nn.ModuleList([
            TransformerLayer(hidden_dim, n_heads) for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(hidden_dim)
        self.head = nn.Linear(hidden_dim, vocab_size)

    def _per_layer_caches(self, cache):
        """Map the ``cache`` argument to one cache entry (or None) per layer."""
        n_layers = len(self.layers)
        if cache is None:
            return [None] * n_layers
        if isinstance(cache, (list, tuple)):
            if len(cache) != n_layers:
                raise ValueError(
                    f"expected {n_layers} per-layer caches, got {len(cache)}"
                )
            return list(cache)
        # Legacy: a single cache object shared by every layer. Each layer then
        # appends its own K/V to the same buffer, so this is only correct for
        # n_layers == 1 -- pass a list with one cache per layer instead.
        return [cache] * n_layers

    def forward(self, x, cache=None):
        """Compute logits for token ids ``x`` of shape (batch, seq_len).

        Args:
            x: token ids.
            cache: None, a list/tuple with one cache per layer, or (legacy) a
                single cache shared by all layers.

        Returns:
            Logits of shape (batch, seq_len, vocab_size).
        """
        x = self.embedding(x)
        for layer, layer_cache in zip(self.layers, self._per_layer_caches(cache)):
            x = layer(x, layer_cache)
        x = self.norm(x)
        return self.head(x)


def generate(model, prompt, max_length=50, temperature=1.0, top_k=30):
    """Autoregressively sample tokens, reusing a KV cache across steps.

    Args:
        model: module called as ``model(tokens, cache)`` returning logits
            ``(batch, seq_len, vocab)``; if it has a ``layers`` attribute, one
            KVCache is created per layer.
        prompt: (batch, prompt_len) token ids.
        max_length: maximum total length (prompt included) of the result.
        temperature: logits are divided by this before sampling (must be > 0).
        top_k: sample only among the ``top_k`` highest logits (clamped to the
            vocabulary size); None or <= 0 disables filtering.

    Returns:
        (batch, <= max_length) tensor: the prompt followed by sampled tokens.
        Generation stops early when every sequence samples token 0 (treated as
        end-of-sequence) at the same step.
    """
    model.eval()
    generated = prompt.clone()
    if generated.size(1) >= max_length:
        return generated

    # One cache per layer: a single shared cache would interleave the K/V of
    # different layers in one buffer.
    layers = getattr(model, "layers", None)
    cache = [KVCache() for _ in layers] if layers is not None else KVCache()

    with torch.no_grad():
        # Process the whole prompt once, filling the caches.
        logits = model(generated, cache)

        while generated.size(1) < max_length:
            next_logits = logits[:, -1, :] / temperature

            if top_k is not None and top_k > 0:
                k = min(top_k, next_logits.size(-1))  # topk fails if k > vocab size
                top_k_values, top_k_indices = torch.topk(next_logits, k)
                next_logits = torch.full_like(next_logits, float('-inf'))
                next_logits.scatter_(1, top_k_indices, top_k_values)

            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token], dim=1)

            # Stop on end-of-sequence (token 0) or when the length budget is
            # used up -- either way the next forward pass would be wasted.
            if (next_token == 0).all() or generated.size(1) >= max_length:
                break

            # Only the newest token is fed; earlier positions come from the caches.
            logits = model(next_token, cache)

    return generated


# Smoke test: build a small random model and sample a long sequence with it.
if __name__ == "__main__":
    VOCAB_SIZE, HIDDEN_DIM, N_LAYERS, N_HEADS = 1000, 128, 2, 4
    # Long generation, intended to exercise KV-cache growth.
    # NOTE(review): whether it exceeds the cache's initial capacity depends on
    # how many positions KVCache receives per step -- confirm.
    MAX_LENGTH = 100

    model_kwargs = {
        "vocab_size": VOCAB_SIZE,
        "hidden_dim": HIDDEN_DIM,
        "n_layers": N_LAYERS,
        "n_heads": N_HEADS,
    }
    model = SimpleTransformer(**model_kwargs)

    # Prompt ids come from [1, 100), so the prompt never contains token 0,
    # which generate() treats as end-of-sequence.
    prompt = torch.randint(1, 100, (1, 5))
    print("输入提示:", prompt.tolist())

    generated = generate(model, prompt, max_length=MAX_LENGTH)

    print("生成结果:", generated.tolist())
    print(f"生成序列长度: {generated.size(1)}, 成功避免缓存越界问题!")
    

输入提示: [[10, 52, 93, 47, 50]]
生成结果: [[10, 52, 93, 47, 50, 759, 900, 188, 837, 106, 328, 401, 176, 204, 185, 567, 853, 29, 246, 457, 586, 649, 129, 615, 339, 615, 33, 277, 904, 752, 866, 229, 91, 971, 507, 295, 370, 27, 706, 751, 761, 800, 303, 760, 681, 486, 446, 820, 596, 239, 824, 447, 748, 336, 121, 540, 746, 752, 142, 387, 739, 177, 688, 637, 0]]
生成序列长度: 65, 成功避免缓存越界问题!
