# Attention
![image.png](attachment:fb787607-5747-4dcf-be4e-7d88d61f1227.png)

## 1. MHA (Multi-Head Attention)
![image.png](attachment:ab25d7c1-1bbe-4ae7-8d76-8e236dc28ec8.png)

In [2]:
import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Standard multi-head self-attention over a single input sequence.

    Q, K and V are each projected from the same input with full-width linear
    layers, split into ``nums_head`` heads of ``head_dim`` features, combined
    with scaled dot-product attention and merged through an output projection.

    Args:
        hidden_dim: Total model width; must be divisible by ``nums_head``.
        nums_head: Number of attention heads.
        dropout_rate: Dropout probability applied to the attention weights.
    """

    def __init__(self, hidden_dim, nums_head, dropout_rate=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim  # total model width
        self.nums_head = nums_head    # number of attention heads

        # Every head gets the same width, so the total must split evenly.
        assert hidden_dim % nums_head == 0
        self.head_dim = hidden_dim // nums_head  # width of a single head

        # Dropout applied to the attention weights (after softmax).
        self.dropout = nn.Dropout(dropout_rate)

        # Linear maps producing Q, K and V from the input.
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)  # query projection
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)  # key projection
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)  # value projection
        # Mixes the concatenated head outputs.
        self.output_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, attention_mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_len, hidden_dim).
            attention_mask: Optional tensor of shape
                (batch_size, nums_head, seq_len, seq_len). Positions equal to
                0 are excluded from attention. ``None`` disables masking.

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_dim).

        Raises:
            AssertionError: If ``attention_mask`` is given and its shape
                differs from the attention score tensor's shape.

        Note:
            A query row whose mask is 0 everywhere produces NaN, because the
            softmax is taken over a row consisting only of ``-inf``.
        """
        batch_size, seq_len, _ = x.size()

        # Project the input to Q, K, V; each is (batch_size, seq_len, hidden_dim).
        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)

        # Split into heads:
        # (batch_size, seq_len, hidden_dim) -> (batch_size, nums_head, seq_len, head_dim)
        q = Q.view(batch_size, seq_len, self.nums_head, self.head_dim).transpose(1, 2)
        k = K.view(batch_size, seq_len, self.nums_head, self.head_dim).transpose(1, 2)
        v = V.view(batch_size, seq_len, self.nums_head, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores: Q @ K^T / sqrt(d_k)
        # -> (batch_size, nums_head, seq_len, seq_len)
        attention_val = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)

        print(f"attention_val shape is {attention_val.size()}")

        if attention_mask is not None:
            # Only touch the mask once we know it exists; calling .size() on
            # the default None would raise AttributeError.
            print(f"attention_mask shape is {attention_mask.size()}")
            # The mask must match the score tensor exactly.
            assert attention_val.size() == attention_mask.size()
            # Positions where the mask is 0 get -inf so softmax gives them weight 0.
            attention_val = torch.masked_fill(attention_val, attention_mask == 0, float("-inf"))

        # Normalise the scores over the key axis.
        attention_weight = torch.softmax(attention_val, dim=-1)

        # Dropout on the attention weights.
        attention_weight = self.dropout(attention_weight)

        # Weighted sum of values:
        # (batch_size, nums_head, seq_len, seq_len) @ (batch_size, nums_head, seq_len, head_dim)
        # -> (batch_size, nums_head, seq_len, head_dim)
        output_tmp = attention_weight @ v

        # Merge the heads back together:
        # (batch_size, nums_head, seq_len, head_dim)
        # -> (batch_size, seq_len, nums_head, head_dim)
        # -> (batch_size, seq_len, hidden_dim)
        output_tmp = output_tmp.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_dim)

        # Final linear mix of the concatenated heads.
        output = self.output_proj(output_tmp)
        return output

# 示例代码（测试）
if __name__ == "__main__":
    # Random input of shape (batch_size=2, seq_len=3, hidden_dim=4).
    x = torch.randn(2, 3, 4)
    batch_size, seq_len, hidden_dim = x.size()
    nums_head = 2  # use two attention heads

    # Attention mask of shape (batch_size, nums_head, seq_len, seq_len).
    # The lower-triangular pattern gives causal attention: each position
    # sees only itself and earlier positions.
    attention_mask = torch.tril(torch.ones(batch_size, nums_head, seq_len, seq_len))
    print(attention_mask)

    # Build the multi-head attention module.
    multi_head_attention = MultiHeadAttention(hidden_dim=hidden_dim, nums_head=nums_head)

    # Forward pass. Call the module itself rather than .forward() so that
    # nn.Module.__call__ runs any registered hooks.
    x_forward = multi_head_attention(x, attention_mask=attention_mask)
    print(x_forward)
    print(x_forward.size())


tensor([[[[1., 0., 0.],
          [1., 1., 0.],
          [1., 1., 1.]],

         [[1., 0., 0.],
          [1., 1., 0.],
          [1., 1., 1.]]],


        [[[1., 0., 0.],
          [1., 1., 0.],
          [1., 1., 1.]],

         [[1., 0., 0.],
          [1., 1., 0.],
          [1., 1., 1.]]]])
attention_val shape is torch.Size([2, 2, 3, 3])
attention_mask shape is torch.Size([2, 2, 3, 3])
tensor([[[ 0.0117,  0.0545,  0.2684, -0.3192],
         [-0.4979,  0.0173,  0.4014, -0.2413],
         [-0.2924,  0.2526,  0.3353, -0.0917]],

        [[-0.1564, -0.3443, -0.1695, -0.9423],
         [-0.3423, -0.1217,  0.0141, -0.6216],
         [-0.2511, -0.0289,  0.0864, -0.4957]]], grad_fn=<ViewBackward0>)
torch.Size([2, 3, 4])


## 2. MQA (Multi-Query Attention)
![image.png](attachment:a31eb248-7716-47b8-a377-f1441d64ce14.png)

In [4]:
import torch
import torch.nn as nn
import math

class MultiQueryAttention(nn.Module):
    """Multi-query attention: per-head queries, one shared key/value head.

    The query projection keeps the full ``hidden_dim`` width and is split into
    ``nums_head`` heads. Keys and values are projected to a single head of
    ``head_dim`` features that is broadcast across every query head.

    Args:
        hidden_dim: Total model width; must be divisible by ``nums_head``.
        nums_head: Number of query heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, hidden_dim, nums_head, dropout=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.nums_head = nums_head
        assert hidden_dim % nums_head == 0
        self.head_dim = hidden_dim // nums_head

        # Dropout on the attention weights.
        self.dropout = nn.Dropout(p=dropout)

        # Queries: one slice of the projection per head.
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)

        # Keys / values: projected once, to a single head's width.
        self.k_proj = nn.Linear(hidden_dim, self.head_dim)
        self.v_proj = nn.Linear(hidden_dim, self.head_dim)

        # Mixes the concatenated head outputs.
        self.output_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, attention_mask=None):
        """Apply multi-query self-attention to ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_len, hidden_dim).
            attention_mask: Optional tensor broadcastable to
                (batch_size, nums_head, seq_len, seq_len); positions equal to
                0 are excluded from attention.

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_dim).
        """
        bsz, n_tokens, _ = x.shape

        # queries: (bsz, n_tokens, hidden_dim); shared_k / shared_v: (bsz, n_tokens, head_dim)
        queries = self.q_proj(x)
        shared_k = self.k_proj(x)
        shared_v = self.v_proj(x)

        # Split the queries into heads: (bsz, nums_head, n_tokens, head_dim).
        queries = queries.view(bsz, n_tokens, self.nums_head, self.head_dim).transpose(1, 2)

        # Give K and V a singleton head axis so they broadcast over all heads:
        # (bsz, 1, n_tokens, head_dim).
        shared_k = shared_k.unsqueeze(1)
        shared_v = shared_v.unsqueeze(1)

        # Scaled dot-product scores:
        # (bsz, nums_head, n_tokens, head_dim) x (bsz, 1, head_dim, n_tokens)
        # -> (bsz, nums_head, n_tokens, n_tokens)
        scores = torch.matmul(queries, shared_k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        print(f"attention_val shape is {scores.size()}")

        # Masked-out positions get -inf so softmax assigns them zero weight.
        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask == 0, float("-inf"))

        probs = torch.softmax(scores, dim=-1)
        print(f"attention_weight is {probs}")
        probs = self.dropout(probs)

        # Weighted sum of the shared values:
        # (bsz, nums_head, n_tokens, n_tokens) @ (bsz, 1, n_tokens, head_dim)
        # -> (bsz, nums_head, n_tokens, head_dim)
        context = torch.matmul(probs, shared_v)

        # Concatenate the heads back into (bsz, n_tokens, hidden_dim).
        merged = context.transpose(1, 2).contiguous().view(bsz, n_tokens, self.hidden_dim)

        return self.output_proj(merged)
if __name__ == "__main__":
    # Random input: batch_size=2, seq_len=3, hidden_dim=4.
    x = torch.randn(2, 3, 4)
    batch_size, seq_len, hidden_dim = x.size()
    nums_head = 2

    # Lower-triangular (causal) mask: each position sees itself and the past.
    attention_mask = torch.tril(torch.ones(batch_size, nums_head, seq_len, seq_len))
    print(attention_mask)

    multi_query_attention = MultiQueryAttention(hidden_dim=hidden_dim, nums_head=nums_head, dropout=0.2)

    # Call the module itself rather than .forward() so that
    # nn.Module.__call__ runs any registered hooks.
    x_forward = multi_query_attention(x, attention_mask=attention_mask)
    print(x_forward)
    print(x_forward.size())  # expected: (2, 3, 4)


tensor([[[[1., 0., 0.],
          [1., 1., 0.],
          [1., 1., 1.]],

         [[1., 0., 0.],
          [1., 1., 0.],
          [1., 1., 1.]]],


        [[[1., 0., 0.],
          [1., 1., 0.],
          [1., 1., 1.]],

         [[1., 0., 0.],
          [1., 1., 0.],
          [1., 1., 1.]]]])
attention_val shape is torch.Size([2, 2, 3, 3])
attention_weight is tensor([[[[1.0000, 0.0000, 0.0000],
          [0.6390, 0.3610, 0.0000],
          [0.2514, 0.4027, 0.3460]],

         [[1.0000, 0.0000, 0.0000],
          [0.5723, 0.4277, 0.0000],
          [0.3653, 0.3053, 0.3294]]],


        [[[1.0000, 0.0000, 0.0000],
          [0.2356, 0.7644, 0.0000],
          [0.2994, 0.3721, 0.3285]],

         [[1.0000, 0.0000, 0.0000],
          [0.2222, 0.7778, 0.0000],
          [0.4326, 0.2576, 0.3098]]]], grad_fn=<SoftmaxBackward0>)
tensor([[[ 0.0666,  0.7195, -0.6821, -0.8065],
         [-0.3388,  1.5989, -0.7055, -0.7176],
         [-0.1142,  1.2263, -0.5566, -0.4603]],

        [[ 0.3196, 

## 3. GQA (Grouped-Query Attention)
![image.png](attachment:1f84576b-0bcd-4482-8df0-a8555c99991f.png)

In [5]:
import math
import torch
import torch.nn as nn

class GQABroadcast(nn.Module):
    """Grouped-query attention that shares K/V across query heads by broadcasting.

    ``nums_kv_head`` (the number of K/V groups, G) selects the variant:
      - ``nums_kv_head == nums_head``: one K/V head per query head (MHA layout)
      - ``nums_kv_head == 1``: a single shared K/V head (MQA layout)
      - ``1 < nums_kv_head < nums_head``: general GQA

    Query head ``h`` belongs to group ``h // (nums_head // nums_kv_head)``.

    Args:
        hidden_dim: Total model width; must be divisible by ``nums_head``.
        nums_head: Number of query heads (H).
        nums_kv_head: Number of K/V heads (G); must divide ``nums_head``.
        dropout_rate: Dropout probability applied to the attention weights.
    """

    def __init__(self, hidden_dim, nums_head, nums_kv_head, dropout_rate=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.nums_head = nums_head        # total number of query heads (H)
        self.nums_kv_head = nums_kv_head  # number of K/V heads, i.e. groups (G)

        # hidden_dim must split evenly into query heads, and the query heads
        # must split evenly into K/V groups.
        assert hidden_dim % nums_head == 0
        assert nums_head % nums_kv_head == 0

        self.head_dim = hidden_dim // nums_head              # width of one head
        self.q_heads_per_group = nums_head // nums_kv_head   # query heads per group
        self.dropout = nn.Dropout(dropout_rate)

        # Queries use the full model width.
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        # Keys / values only produce G heads of head_dim features each.
        self.k_proj = nn.Linear(hidden_dim, nums_kv_head * self.head_dim)
        self.v_proj = nn.Linear(hidden_dim, nums_kv_head * self.head_dim)

        # Mixes the concatenated head outputs.
        self.output_proj = nn.Linear(hidden_dim, hidden_dim)

    def _to_grouped_mask(self, attention_mask):
        """Regroup a per-head 4-D mask into the 5-D grouped score layout."""
        if attention_mask.dim() == 4 and attention_mask.size(1) == self.nums_head:
            lead, _, q_len, k_len = attention_mask.size()
            # Head h maps to (h // q_heads_per_group, h % q_heads_per_group),
            # the same split that is applied to the query heads in forward().
            return attention_mask.reshape(
                lead, self.nums_kv_head, self.q_heads_per_group, q_len, k_len
            )
        return attention_mask

    def forward(self, x, attention_mask=None):
        """Apply grouped-query self-attention to ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_len, hidden_dim).
            attention_mask: Optional mask; positions equal to 0 are excluded
                from attention. A 4-D mask whose second dimension equals
                ``nums_head``, e.g. (batch_size, nums_head, seq_len, seq_len),
                is regrouped automatically. Any other mask is used as given
                and must broadcast against the score tensor of shape
                (batch_size, G, H/G, seq_len, seq_len).

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_dim).
        """
        batch_size, seq_len, _ = x.size()

        # Project to Q, K, V.
        Q = self.q_proj(x)  # (batch, seq, hidden_dim)
        K = self.k_proj(x)  # (batch, seq, G * head_dim)
        V = self.v_proj(x)

        # Split Q into heads: (batch, H, seq, head_dim) ...
        q = Q.view(batch_size, seq_len, self.nums_head, self.head_dim).transpose(1, 2)
        # ... then group the heads: (batch, G, H/G, seq, head_dim).
        q = q.view(batch_size, self.nums_kv_head, self.q_heads_per_group, seq_len, self.head_dim)

        # K/V become (batch, G, 1, seq, head_dim); the singleton axis
        # broadcasts over the query heads inside each group.
        k = K.view(batch_size, seq_len, self.nums_kv_head, self.head_dim).transpose(1, 2).unsqueeze(2)
        v = V.view(batch_size, seq_len, self.nums_kv_head, self.head_dim).transpose(1, 2).unsqueeze(2)

        # Scaled dot-product scores: (batch, G, H/G, seq, seq).
        attention_val = q @ k.transpose(-1, -2) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            # Accept the conventional per-head mask as well as the grouped one.
            attention_mask = self._to_grouped_mask(attention_mask)
            # Masked-out positions get -inf so softmax gives them zero weight.
            attention_val = torch.masked_fill(attention_val, attention_mask == 0, float("-inf"))

        # Normalise over the key axis, then apply dropout to the weights.
        attention_weight = torch.softmax(attention_val, dim=-1)
        attention_weight = self.dropout(attention_weight)

        # Weighted sum of values: (batch, G, H/G, seq, head_dim).
        output_tmp = attention_weight @ v

        # Flatten the (G, H/G) axes back into H: (batch, H, seq, head_dim).
        output_tmp = output_tmp.reshape(batch_size, self.nums_head, seq_len, self.head_dim)

        # Concatenate the heads and apply the output projection.
        output_concat = output_tmp.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_dim)
        output = self.output_proj(output_concat)
        return output


class GQARepeat(nn.Module):
    """Grouped-query attention that materialises K/V once per query head.

    The G key/value heads are explicitly repeated so that each of the H query
    heads has its own copy; consecutive query heads share the same K/V head.

    Args:
        hidden_dim: Total model width; must be divisible by ``nums_head``.
        nums_head: Number of query heads (H).
        nums_kv_head: Number of K/V heads (G); must divide ``nums_head``.
        dropout_rate: Dropout probability applied to the attention weights.
    """

    def __init__(self, hidden_dim, nums_head, nums_kv_head, dropout_rate=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.nums_head = nums_head
        self.nums_kv_head = nums_kv_head

        assert hidden_dim % nums_head == 0
        assert nums_head % nums_kv_head == 0

        self.head_dim = hidden_dim // nums_head
        self.q_head_per_group = nums_head // nums_kv_head

        kv_width = nums_kv_head * self.head_dim
        self.q_proj = nn.Linear(hidden_dim, nums_head * self.head_dim)
        self.k_proj = nn.Linear(hidden_dim, kv_width)
        self.v_proj = nn.Linear(hidden_dim, kv_width)
        self.output_proj = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout_rate)

    def _heads_first(self, projected, n_heads):
        """Reshape (batch, seq, n_heads * head_dim) to (batch, n_heads, seq, head_dim)."""
        bsz, n_tokens = projected.size(0), projected.size(1)
        return projected.view(bsz, n_tokens, n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, attention_mask=None):
        """Apply grouped-query self-attention to ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_len, hidden_dim).
            attention_mask: Optional tensor broadcastable to
                (batch_size, nums_head, seq_len, seq_len); positions equal to
                0 are excluded from attention.

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_dim).
        """
        bsz, n_tokens, _ = x.shape

        # Project, then move the head axis in front of the sequence axis.
        queries = self._heads_first(self.q_proj(x), self.nums_head)
        keys = self._heads_first(self.k_proj(x), self.nums_kv_head)
        values = self._heads_first(self.v_proj(x), self.nums_kv_head)

        # Expand the G K/V heads to H heads; each one is copied
        # q_head_per_group times in a row.
        keys = torch.repeat_interleave(keys, self.q_head_per_group, dim=1)
        values = torch.repeat_interleave(values, self.q_head_per_group, dim=1)

        # Scaled dot-product scores: (batch, H, seq, seq).
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Masked-out positions get -inf so softmax assigns them zero weight.
        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask == 0, float('-inf'))

        probs = self.dropout(torch.softmax(scores, dim=-1))

        # Weighted sum of values: (batch, H, seq, head_dim).
        context = torch.matmul(probs, values)

        # Concatenate the heads into (batch, seq, hidden_dim) and project.
        merged = context.transpose(1, 2).contiguous().view(bsz, n_tokens, self.hidden_dim)
        return self.output_proj(merged)


if __name__ == "__main__":
    # Example input tensor: (batch_size=2, seq_len=3, hidden_dim=16).
    x = torch.randn(2, 3, 16)
    batch_size, seq_len, hidden_dim = x.size()
    nums_head = 8
    head_dim = hidden_dim // nums_head
    nums_kv_head = 4
    q_heads_per_group = nums_head // nums_kv_head

    # Mask for the broadcast variant: (batch_size, G, H/G, seq, seq).
    attention_mask_v1 = torch.tril(torch.ones(batch_size, nums_kv_head, q_heads_per_group, seq_len, seq_len))
    gqa_broadcast = GQABroadcast(hidden_dim=hidden_dim, nums_head=nums_head,
                                  nums_kv_head=nums_kv_head, dropout_rate=0.1)
    # Call the module itself rather than .forward() so that
    # nn.Module.__call__ runs any registered hooks.
    x_forward_v1 = gqa_broadcast(x, attention_mask=attention_mask_v1)
    print(x_forward_v1.size())  # expected: (2, 3, 16)

    # Mask for the repeat variant: (batch_size, nums_head, seq, seq).
    attention_mask_v2 = torch.tril(torch.ones(batch_size, nums_head, seq_len, seq_len))
    gqa_repeat = GQARepeat(hidden_dim=hidden_dim, nums_head=nums_head,
                           nums_kv_head=nums_kv_head, dropout_rate=0.1)
    x_forward_v2 = gqa_repeat(x, attention_mask=attention_mask_v2)
    print(x_forward_v2.size())  # expected: (2, 3, 16)


torch.Size([2, 3, 16])
torch.Size([2, 3, 16])
