In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


# 复用已实现的注意力相关组件
class SelfAttention(nn.Module):
    """Single-head self-attention with learned Q/K/V projections.

    ``x`` is projected to queries, keys and values, and the layer returns
    ``softmax(Q K^T / sqrt(dim_k)) V`` together with the softmax weights.

    Args:
        dim_q: feature size of the input ``x``.
        dim_k: size of the query/key projections; also sets the 1/sqrt(dim_k) scale.
        dim_v: size of the value projection (feature size of the output).
    """
    def __init__(self, dim_q, dim_k, dim_v):
        super(SelfAttention, self).__init__()
        self.linear_q = nn.Linear(dim_q, dim_k, bias=False)
        self.linear_k = nn.Linear(dim_q, dim_k, bias=False)
        self.linear_v = nn.Linear(dim_q, dim_v, bias=False)
        self._norm_fact = 1 / math.sqrt(dim_k)  # scaling so scores don't grow with dim_k
        self.dim_q = dim_q
        self.dim_k = dim_k
        self.dim_v = dim_v

    def forward(self, x, mask=None):
        """
        Args:
            x: input of shape [batch_size, seq_len, dim_q].
            mask: optional tensor broadcastable to [batch_size, seq_len, seq_len];
                positions where ``mask == 0`` are excluded from attention.
        Returns:
            att: attended values, [batch_size, seq_len, dim_v].
            attn_weights: softmax attention weights, [batch_size, seq_len, seq_len].
        """
        _, _, dim_q = x.shape
        assert dim_q == self.dim_q, f"输入维度{dim_q}与初始化dim_q{self.dim_q}不匹配"

        # Project the same input to queries, keys and values.
        q = self.linear_q(x)
        k = self.linear_k(x)
        v = self.linear_v(x)

        # Scaled dot-product scores: [batch_size, seq_len, seq_len]
        attn_scores = torch.bmm(q, k.transpose(1, 2)) * self._norm_fact
        if mask is not None:
            # Fill with the dtype's most negative finite value instead of a fixed -1e9:
            # -1e9 is outside float16's range and makes masked_fill raise for half inputs.
            # Masked positions get zero weight; a fully masked row becomes uniform.
            attn_scores = attn_scores.masked_fill(mask == 0, torch.finfo(attn_scores.dtype).min)

        attn_weights = F.softmax(attn_scores, dim=-1)
        att = torch.bmm(attn_weights, v)
        return att, attn_weights


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention.

    The input is projected to queries, keys and values. Each projection is split
    into ``num_heads`` sub-spaces of size ``head_dim``, scaled dot-product attention
    is computed independently per head, and the heads are then concatenated and
    passed through an output projection. Attention is computed directly on the
    projected Q/K/V heads, so ``linear_k`` and ``linear_v`` both feed the result.

    Args:
        dim_model: feature size of the input and output; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """
    def __init__(self, dim_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert dim_model % num_heads == 0, f"输入维度{dim_model}需能被头数{num_heads}整除"

        self.dim_model = dim_model
        self.num_heads = num_heads
        self.head_dim = dim_model // num_heads  # feature size of a single head
        # 1/sqrt(d_k) scaling of the per-head dot products
        self._norm_fact = 1 / math.sqrt(self.head_dim)

        # Q/K/V projections for all heads at once; heads are split afterwards.
        self.linear_q = nn.Linear(dim_model, dim_model, bias=False)
        self.linear_k = nn.Linear(dim_model, dim_model, bias=False)
        self.linear_v = nn.Linear(dim_model, dim_model, bias=False)
        # Output projection applied to the concatenated heads.
        self.linear_out = nn.Linear(dim_model, dim_model, bias=False)

    def _split_heads(self, x):
        """Reshape [batch, seq, dim_model] -> [batch, num_heads, seq, head_dim] (contiguous)."""
        batch_size, seq_len, dim_model = x.shape
        return x.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _concat_heads(self, x):
        """Reshape [batch, num_heads, seq, head_dim] -> [batch, seq, num_heads * head_dim]."""
        batch_size, num_heads, seq_len, head_dim = x.shape
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, num_heads * head_dim)

    def forward(self, x, mask=None):
        """
        Args:
            x: input of shape [batch_size, seq_len, dim_model].
            mask: optional mask; positions where ``mask == 0`` receive zero attention.
                A 3-D mask ([batch_size, seq_len, seq_len] or [batch_size, 1, seq_len])
                is shared by all heads of the same batch item. A 2-D [seq_len, seq_len]
                mask is shared by every batch item and head. A 4-D mask must broadcast
                to [batch_size, num_heads, seq_len, seq_len].
        Returns:
            out: [batch_size, seq_len, dim_model].
            att_weights: attention weights averaged over heads, [batch_size, seq_len, seq_len].
        """
        _, _, dim_model = x.shape
        assert dim_model == self.dim_model, f"输入维度{dim_model}与初始化dim_model{self.dim_model}不匹配"

        # Project, then split into heads: each is [batch, heads, seq, head_dim].
        q = self._split_heads(self.linear_q(x))
        k = self._split_heads(self.linear_k(x))
        v = self._split_heads(self.linear_v(x))

        # Per-head scaled dot-product scores: [batch, heads, seq, seq]
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self._norm_fact
        if mask is not None:
            if mask.dim() == 3:
                # Insert a head axis so each batch item's mask is broadcast over its own
                # heads (tiling with repeat() would pair heads with the wrong batch item).
                mask = mask.unsqueeze(1)
            # Most negative finite value of the dtype: safe for float16, unlike -1e9.
            attn_scores = attn_scores.masked_fill(mask == 0, torch.finfo(attn_scores.dtype).min)

        attn_weights = F.softmax(attn_scores, dim=-1)
        # Weighted values per head, then concatenate heads and apply the output projection.
        out = self.linear_out(self._concat_heads(torch.matmul(attn_weights, v)))

        # Average the attention maps over heads: [batch, seq, seq]
        att_weights = attn_weights.mean(dim=1)
        return out, att_weights


# 实现Add&Norm层（Transformer Encoder核心组件）
class AddNorm(nn.Module):
    """Residual addition ("Add") followed by LayerNorm over the feature dim ("Norm").

    The residual branch is scaled by a learnable scalar initialised to 1, so at
    initialisation the layer computes ``LayerNorm(x + residual)``.

    Args:
        dim_model: size of the last (feature) dimension that LayerNorm normalises.
        eps: LayerNorm epsilon.
    """
    def __init__(self, dim_model, eps=1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(dim_model, eps=eps)
        # Learnable scale on the skip connection; ones(1) -> plain residual at start.
        self.residual_weight = nn.Parameter(torch.ones(1))

    def forward(self, x, residual):
        """
        Args:
            x: sub-layer output, e.g. attention output, [batch_size, seq_len, dim_model].
            residual: the sub-layer's input, same shape as ``x``.
        Returns:
            ``LayerNorm(x + residual_weight * residual)``, same shape as ``x``.
        """
        return self.norm(x + self.residual_weight * residual)


# 实现Feed Forward层（Transformer Encoder核心组件）
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        dim_model: input/output feature size.
        hidden_dim: width of the intermediate layer.
        dropout: dropout probability applied after the ReLU.
    """
    def __init__(self, dim_model, hidden_dim=2048, dropout=0.1):
        super().__init__()
        # Registration order (linear1 before linear2) is kept so parameter init
        # consumes the RNG in the same order.
        self.linear1 = nn.Linear(dim_model, hidden_dim, bias=True)  # expand to hidden_dim
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(hidden_dim, dim_model, bias=True)  # project back to dim_model

    def forward(self, x):
        """
        Args:
            x: input of shape [batch_size, seq_len, dim_model].
        Returns:
            Tensor with the same shape as ``x``.
        """
        hidden = self.dropout(self.relu(self.linear1(x)))
        return self.linear2(hidden)


# 实现单个Encoder Layer（Transformer Encoder的基础单元）
class EncoderLayer(nn.Module):
    """One post-LN encoder layer.

    Multi-head attention -> dropout -> Add&Norm, then
    feed-forward -> dropout -> Add&Norm.

    Args:
        dim_model: model feature size.
        num_heads: number of attention heads.
        hidden_dim: feed-forward hidden width.
        dropout: dropout probability (used on both sub-layer outputs and inside the FFN).
    """
    def __init__(self, dim_model, num_heads, hidden_dim=2048, dropout=0.1):
        super().__init__()
        # Submodules registered in the original order so initialisation is unchanged.
        self.multi_head_attn = MultiHeadAttention(dim_model=dim_model, num_heads=num_heads)
        self.add_norm1 = AddNorm(dim_model=dim_model)
        self.feed_forward = FeedForward(dim_model=dim_model, hidden_dim=hidden_dim, dropout=dropout)
        self.add_norm2 = AddNorm(dim_model=dim_model)
        self.dropout = nn.Dropout(dropout)  # shared by both sub-layers (it has no parameters)

    def forward(self, x, mask=None):
        """
        Args:
            x: input of shape [batch_size, seq_len, dim_model].
            mask: optional mask, forwarded unchanged to the attention sub-layer.
        Returns:
            out: layer output, same shape as ``x``.
            att_weights: attention weights returned by the attention sub-layer.
        """
        attended, att_weights = self.multi_head_attn(x, mask=mask)
        hidden = self.add_norm1(self.dropout(attended), residual=x)

        transformed = self.dropout(self.feed_forward(hidden))
        return self.add_norm2(transformed, residual=hidden), att_weights


# 实现完整Transformer Encoder
class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` EncoderLayer modules applied in sequence.

    Args:
        dim_model: model feature size (input and output).
        num_heads: attention heads per layer.
        num_layers: number of stacked layers; 0 makes the encoder an identity map.
        hidden_dim: feed-forward hidden width.
        dropout: dropout probability.
    """
    def __init__(self, dim_model=512, num_heads=8, num_layers=6, hidden_dim=2048, dropout=0.1):
        super(TransformerEncoder, self).__init__()
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(
                dim_model=dim_model,
                num_heads=num_heads,
                hidden_dim=hidden_dim,
                dropout=dropout
            ) for _ in range(num_layers)
        ])
        self.dim_model = dim_model

    def forward(self, x, mask=None):
        """
        Args:
            x: input sequence tensor, [batch_size, seq_len, dim_model].
            mask: optional mask passed unchanged to every layer.
        Returns:
            out: encoder output, [batch_size, seq_len, dim_model].
            all_att_weights: per-layer attention weights stacked on dim 0,
                [num_layers, batch_size, seq_len, seq_len].
        """
        batch_size, seq_len, dim_model = x.shape
        assert dim_model == self.dim_model, f"输入维度{dim_model}与Encoder dim_model{self.dim_model}不一致"

        out = x
        all_att_weights = []  # one attention map per layer, kept for analysis
        for encoder_layer in self.encoder_layers:
            out, att_weights = encoder_layer(out, mask=mask)
            all_att_weights.append(att_weights)

        if not all_att_weights:
            # num_layers == 0: torch.stack([]) would raise, so return an empty
            # weight tensor with the documented trailing shape instead.
            return out, x.new_zeros((0, batch_size, seq_len, seq_len))
        return out, torch.stack(all_att_weights, dim=0)


# 测试Transformer Encoder功能
if __name__ == "__main__":
    # Smoke test: run a randomly initialised encoder on random data and check the
    # output/attention shapes plus the effect of a padding-style mask.

    # 1. Transformer-base sizes; dim_model must be divisible by num_heads.
    batch_size, seq_len, dim_model = 2, 16, 512
    num_heads, num_layers = 8, 6

    # 2. Random input, [batch_size, seq_len, dim_model].
    x = torch.randn(batch_size, seq_len, dim_model)

    # 3. Every query may attend only to the first 8 key positions.
    mask = torch.ones(batch_size, seq_len, seq_len)
    mask[..., 8:] = 0

    # 4. Build the encoder (other hyper-parameters at their defaults) and run it.
    encoder = TransformerEncoder(dim_model=dim_model, num_heads=num_heads, num_layers=num_layers)
    encoder_out, encoder_att_weights = encoder(x, mask=mask)

    # 5. Report. The mask line averages layer 0 / batch 0 / query 0 weights on the
    #    masked key positions, which should be ~0.
    report = [
        "=== Transformer Encoder测试结果 ===",
        f"输入随机矩阵形状：{x.shape}",
        f"Encoder输出形状：{encoder_out.shape}（预期：[{batch_size}, {seq_len}, {dim_model}]）",
        f"注意力权重形状：{encoder_att_weights.shape}（预期：[{num_layers}, {batch_size}, {seq_len}, {seq_len}]）",
        f"掩码有效性：第1层第1批次第1序列后8位权重均值 = {encoder_att_weights[0, 0, 0, 8:].mean():.6f}（接近0为正常）",
        "\nTransformer Encoder实现完成，可正常用于后续任务！",
    ]
    for line in report:
        print(line)

=== Transformer Encoder测试结果 ===
输入随机矩阵形状：torch.Size([2, 16, 512])
Encoder输出形状：torch.Size([2, 16, 512])（预期：[2, 16, 512]）
注意力权重形状：torch.Size([6, 2, 16, 16])（预期：[6, 2, 16, 16]）
掩码有效性：第1层第1批次第1序列后8位权重均值 = 0.000000（接近0为正常）

Transformer Encoder实现完成，可正常用于后续任务！
