In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    The input is projected into query, key and value spaces by three
    bias-free linear layers. Query/key dot products are scaled by
    ``1 / sqrt(dim_k)``, optionally masked, normalised with softmax and used
    to take a weighted sum of the values.

    Args:
        dim_q: Feature size of the input (last dimension of ``x``).
        dim_k: Feature size of the query/key projections.
        dim_v: Feature size of the value projection (and of the output).
    """

    def __init__(self, dim_q, dim_k, dim_v):
        super().__init__()
        # Query and key share the same output size so their dot product is defined.
        self.linear_q = nn.Linear(dim_q, dim_k, bias=False)
        self.linear_k = nn.Linear(dim_q, dim_k, bias=False)
        self.linear_v = nn.Linear(dim_q, dim_v, bias=False)
        # Scaling factor that keeps the raw scores from growing too large.
        self._norm_fact = 1 / math.sqrt(dim_k)
        self.dim_q, self.dim_k, self.dim_v = dim_q, dim_k, dim_v

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, dim_q]``.
            mask: Optional tensor broadcastable to
                ``[batch_size, seq_len, seq_len]``; positions where it equals
                0 are excluded from attention.

        Returns:
            A tuple ``(att, attn_weights)`` with shapes
            ``[batch_size, seq_len, dim_v]`` and
            ``[batch_size, seq_len, seq_len]``.
        """
        _, _, in_features = x.shape
        assert in_features == self.dim_q, "输入维度与初始化dim_q不匹配"

        # Project the input into the three attention spaces.
        queries = self.linear_q(x)
        keys = self.linear_k(x)
        values = self.linear_v(x)

        # Scaled pairwise query/key similarity.
        scores = torch.bmm(queries, keys.permute(0, 2, 1)) * self._norm_fact

        if mask is not None:
            # A very negative score turns into a zero weight after softmax.
            blocked = mask == 0
            scores = scores.masked_fill(blocked, -1e9)

        # Normalise over the key axis, then mix the values.
        weights = torch.softmax(scores, dim=-1)
        return torch.bmm(weights, values), weights


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values by three bias-free
    linear layers. Each projection is split into ``num_heads`` sub-spaces of
    size ``dim_model // num_heads``, attention is computed independently in
    every sub-space, and the concatenated head outputs pass through a final
    bias-free linear layer.

    Args:
        dim_model: Feature size of the input and of the output.
        num_heads: Number of attention heads; must divide ``dim_model``.

    Raises:
        AssertionError: If ``dim_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim_model, num_heads):
        super().__init__()
        assert dim_model % num_heads == 0, f"输入维度{dim_model}需能被头数{num_heads}整除"

        self.dim_model = dim_model
        self.num_heads = num_heads
        self.head_dim = dim_model // num_heads  # feature size of one head
        # Scaling factor that keeps the raw scores from growing too large.
        self._norm_fact = 1 / math.sqrt(self.head_dim)

        # Q/K/V projections covering all heads at once.
        self.linear_q = nn.Linear(dim_model, dim_model, bias=False)
        self.linear_k = nn.Linear(dim_model, dim_model, bias=False)
        self.linear_v = nn.Linear(dim_model, dim_model, bias=False)

        # Output projection applied after the heads are concatenated.
        self.linear_out = nn.Linear(dim_model, dim_model, bias=False)

    def _split_heads(self, x):
        """Split the feature axis into heads.

        Input:  ``[batch_size, seq_len, dim_model]``
        Output: ``[batch_size, num_heads, seq_len, head_dim]`` (contiguous)
        """
        batch_size, seq_len, _ = x.shape
        return x.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _concat_heads(self, x):
        """Merge the head axis back into the feature axis.

        Input:  ``[batch_size, num_heads, seq_len, head_dim]``
        Output: ``[batch_size, seq_len, num_heads * head_dim]``
        """
        batch_size, num_heads, seq_len, head_dim = x.shape
        # transpose() yields a non-contiguous tensor; view() needs contiguous memory.
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, num_heads * head_dim)

    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, dim_model]``.
            mask: Optional tensor whose zero entries mark positions that must
                not be attended to. Accepted layouts are ``[seq_len, seq_len]``
                (shared by every sample and head), ``[batch_size, seq_len,
                seq_len]`` (shared by every head of a sample) or any 4-D
                tensor broadcastable to ``[batch_size, num_heads, seq_len,
                seq_len]``.

        Returns:
            A tuple ``(out, att_weights)`` where ``out`` has shape
            ``[batch_size, seq_len, dim_model]`` and ``att_weights`` is the
            attention distribution averaged over heads, with shape
            ``[batch_size, seq_len, seq_len]``.

        Raises:
            AssertionError: If the last dimension of ``x`` is not ``dim_model``.
        """
        batch_size, seq_len, dim_model = x.shape
        assert dim_model == self.dim_model, "输入维度与初始化dim_model不匹配"

        # 1. Project and split: each tensor is [batch, heads, seq, head_dim].
        q = self._split_heads(self.linear_q(x))
        k = self._split_heads(self.linear_k(x))
        v = self._split_heads(self.linear_v(x))

        # 2. Scaled dot-product scores per head: [batch, heads, seq, seq].
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self._norm_fact

        if mask is not None:
            # Insert singleton axes so the mask broadcasts over the head axis;
            # this keeps every sample paired with its own mask.
            if mask.dim() == 2:
                mask = mask[None, None]
            elif mask.dim() == 3:
                mask = mask.unsqueeze(1)
            # A very negative score turns into a zero weight after softmax.
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)

        # 3. Normalise over the key axis and mix the values.
        attn_weights = F.softmax(attn_scores, dim=-1)
        context = torch.matmul(attn_weights, v)

        # 4. Concatenate heads and apply the output projection.
        out = self.linear_out(self._concat_heads(context))

        # Report the attention distribution averaged over heads.
        att_weights = attn_weights.mean(dim=1)

        return out, att_weights


# Smoke test of the module
if __name__ == "__main__":
    batch_size, seq_len, dim_model, num_heads = 2, 16, 512, 8

    # Random input tensor
    x = torch.randn(batch_size, seq_len, dim_model)
    # Mask that hides the last 8 positions of every sequence
    mask = torch.ones(batch_size, seq_len, seq_len)
    mask[..., 8:] = 0

    # Build the multi-head attention layer and run one forward pass
    multi_head_attn = MultiHeadAttention(dim_model, num_heads)
    mh_output, mh_att_weights = multi_head_attn(x, mask)

    # Report shapes and check that masked positions receive no attention
    report = (
        "=== Multi-Head Attention测试结果 ===",
        f"输入形状：{x.shape}",
        f"输出形状：{mh_output.shape}（预期：[{batch_size}, {seq_len}, {dim_model}]）",
        f"注意力权重形状：{mh_att_weights.shape}（预期：[{batch_size}, {seq_len}, {seq_len}]）",
        f"掩码有效性：第一批次第一序列后8位权重均值 = {mh_att_weights[0, 0, 8:].mean():.6f}（接近0为正常）",
    )
    for line in report:
        print(line)

=== Multi-Head Attention测试结果 ===
输入形状：torch.Size([2, 16, 512])
输出形状：torch.Size([2, 16, 512])（预期：[2, 16, 512]）
注意力权重形状：torch.Size([2, 16, 16])（预期：[2, 16, 16]）
掩码有效性：第一批次第一序列后8位权重均值 = 0.000000（接近0为正常）
