# Transformer Encoder 架构实现与训练示例

本 notebook 实现了一个完整的 Transformer Encoder 架构，包含：
- 从零开始实现核心组件
- 简单的序列分类任务
- 可视化和调试功能
- 详细的维度变化说明

## 1. 环境准备和库导入

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Optional, Tuple
import math
import warnings
warnings.filterwarnings('ignore')

# Seed the torch and NumPy random number generators
torch.manual_seed(42)
np.random.seed(42)

# Select the compute device: CUDA when torch reports it as available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")
print(f"PyTorch 版本: {torch.__version__}")

# Configure matplotlib fonts.
# NOTE(review): the original comment described this as enabling Chinese text,
# but the selected family is 'DejaVu Sans', which presumably has no CJK glyphs
# -- confirm. The plot labels set in this file are English strings.
plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False

## 2. Scaled Dot-Product Attention 实现

注意力机制的核心公式：
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

In [None]:
def scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    dropout: Optional[nn.Dropout] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Only the last two dimensions take part in the matrix products, so any
    number of leading (batch / head) dimensions is accepted as long as they
    broadcast under ``torch.matmul``.

    Args:
        query: [..., seq_len_q, d_k]
        key: [..., seq_len_k, d_k]
        value: [..., seq_len_k, d_v]
        mask: tensor broadcastable to [..., seq_len_q, seq_len_k]; positions
            where ``mask == 0`` are excluded from attention.
        dropout: optional dropout module applied to the attention weights.

    Returns:
        output: [..., seq_len_q, d_v]
        attention_weights: [..., seq_len_q, seq_len_k]
    """
    d_k = query.size(-1)

    # [..., seq_len_q, d_k] @ [..., d_k, seq_len_k] -> [..., seq_len_q, seq_len_k]
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        # -1e9 is not representable in low-precision dtypes such as float16,
        # so clamp the fill value to the most negative finite value of the
        # score dtype. For float32/float64 this is still exactly -1e9.
        fill_value = max(-1e9, torch.finfo(scores.dtype).min)
        scores = scores.masked_fill(mask == 0, fill_value)

    # Normalise over the key axis
    attention_weights = F.softmax(scores, dim=-1)

    if dropout is not None:
        attention_weights = dropout(attention_weights)

    # [..., seq_len_q, seq_len_k] @ [..., seq_len_k, d_v] -> [..., seq_len_q, d_v]
    output = torch.matmul(attention_weights, value)

    return output, attention_weights

# Smoke-test the attention function on random inputs (no mask, no dropout)
batch_size, seq_len, d_k = 2, 4, 8
query = torch.randn(batch_size, seq_len, d_k)
key = torch.randn(batch_size, seq_len, d_k)
value = torch.randn(batch_size, seq_len, d_k)

output, weights = scaled_dot_product_attention(query, key, value)
print(f"输入 Query shape: {query.shape}")
print(f"输出 shape: {output.shape}")
print(f"注意力权重 shape: {weights.shape}")
# Sum of each attention row of the first batch element (softmax over the last axis)
print(f"注意力权重和: {weights[0].sum(dim=-1)}")

## 3. Multi-Head Attention 实现

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on top of ``scaled_dot_product_attention``."""

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        """
        Args:
            d_model: model (embedding) dimension
            num_heads: number of attention heads; must divide ``d_model``
            dropout: dropout probability applied to the attention weights
        """
        super().__init__()
        assert d_model % num_heads == 0, "d_model必须能被num_heads整除"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head dimension

        # Input projections for Q/K/V and the output projection
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply multi-head attention.

        Args:
            query: [batch_size, seq_len_q, d_model]
            key, value: [batch_size, seq_len_k, d_model]; ``seq_len_k`` may
                differ from ``seq_len_q``.
            mask: optional; positions where ``mask == 0`` are not attended to.
                A 3-D mask [batch_size, seq_len_q, seq_len_k] is shared by
                all heads. Any other rank (e.g. [seq_len_q, seq_len_k] or
                [batch_size, num_heads, seq_len_q, seq_len_k]) is passed
                through and must broadcast against the score tensor.

        Returns:
            output: [batch_size, seq_len_q, d_model]
            attention_weights: [batch_size, num_heads, seq_len_q, seq_len_k]
        """
        batch_size, q_len, _ = query.shape
        # Key/value length is read from `key`, not from `query`, so inputs
        # whose query and key lengths differ are reshaped correctly.
        kv_len = key.size(1)

        # Project, split the feature axis into heads, and move heads forward:
        # [batch, len, d_model] -> [batch, len, heads, d_k] -> [batch, heads, len, d_k]
        Q = self.W_q(query).view(batch_size, q_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, kv_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, kv_len, self.num_heads, self.d_k).transpose(1, 2)

        # A per-batch mask gets a singleton head axis and is broadcast over
        # heads by masked_fill, avoiding an explicit per-head copy.
        if mask is not None and mask.dim() == 3:
            mask = mask.unsqueeze(1)  # [batch, 1, seq_len_q, seq_len_k]

        # The attention helper works on the last two axes, so the 4-D
        # tensors can be passed directly.
        attn_output, attn_weights = scaled_dot_product_attention(
            Q, K, V, mask, self.dropout
        )

        # Merge heads: [batch, heads, q_len, d_k] -> [batch, q_len, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, q_len, self.d_model
        )

        # Final output projection
        output = self.W_o(attn_output)

        return output, attn_weights

# Smoke-test multi-head attention as self-attention on a random input
d_model, num_heads = 128, 4
mha = MultiHeadAttention(d_model, num_heads)

x = torch.randn(2, 10, d_model)  # [batch_size, seq_len, d_model]
output, weights = mha(x, x, x)

print(f"输入 shape: {x.shape}")
print(f"输出 shape: {output.shape}")
print(f"注意力权重 shape: {weights.shape}")
# Total number of parameter elements in the module
print(f"参数量: {sum(p.numel() for p in mha.parameters())}")

## 4. Position-wise Feed Forward Network

In [None]:
class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward network: FFN(x) = max(0, x W1 + b1) W2 + b2."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        """
        Args:
            d_model: model dimension (input and output width)
            d_ff: width of the hidden layer
            dropout: dropout probability applied after the activation
        """
        super().__init__()
        # Expansion layer first, then the projection back to d_model
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the two-layer network along the last axis.

        Args:
            x: [batch_size, seq_len, d_model]

        Returns:
            Tensor of shape [batch_size, seq_len, d_model]
        """
        # [batch_size, seq_len, d_model] -> [batch_size, seq_len, d_ff]
        hidden = self.dropout(self.activation(self.fc1(x)))
        # [batch_size, seq_len, d_ff] -> [batch_size, seq_len, d_model]
        return self.fc2(hidden)

# Smoke-test the feed-forward network on a random input
d_model, d_ff = 128, 512
ffn = PositionwiseFeedForward(d_model, d_ff)

x = torch.randn(2, 10, d_model)
output = ffn(x)

print(f"FFN 输入 shape: {x.shape}")
print(f"FFN 输出 shape: {output.shape}")
# Total number of parameter elements in the module
print(f"FFN 参数量: {sum(p.numel() for p in ffn.parameters())}")

## 5. Positional Encoding

In [None]:
class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding.

    PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
    """

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        """
        Args:
            d_model: model dimension; both even and odd values are supported
            max_len: number of positions to precompute
            dropout: dropout probability applied after adding the encoding
        """
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # Table of encodings, one row per position
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()

        # Inverse frequencies 10000^(-2i/d_model), computed in log space
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() *
            -(math.log(10000.0) / d_model)
        )

        # [max_len, ceil(d_model / 2)] angle for every (position, frequency)
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # When d_model is odd there is one fewer odd-indexed column than
        # even-indexed ones, so drop the surplus frequency for the cosines.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Add a leading batch axis and register as a (non-trainable) buffer
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add the positional encoding to ``x`` and apply dropout.

        Args:
            x: [batch_size, seq_len, d_model], with seq_len <= max_len

        Returns:
            Tensor of shape [batch_size, seq_len, d_model]
        """
        seq_len = x.size(1)
        # Slice the table to the input length; broadcasts over the batch axis
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)

# Visualise the positional encoding
d_model = 128
pe_layer = PositionalEncoding(d_model, max_len=100)

# Take the first 50 positions of the registered `pe` buffer as a NumPy array
pe_matrix = pe_layer.pe[0, :50, :].numpy()

# Left panel: the encoding matrix as an image (rows = positions, columns = dimensions)
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(pe_matrix, cmap='RdBu', aspect='auto')
plt.colorbar()
plt.xlabel('Dimension')
plt.ylabel('Position')
plt.title('Positional Encoding Matrix')

plt.subplot(1, 2, 2)
# Right panel: encoding value versus position for a few selected dimensions
dims_to_plot = [0, 1, 4, 5, 8, 9]
for dim in dims_to_plot:
    plt.plot(pe_matrix[:, dim], label=f'dim {dim}')
plt.xlabel('Position')
plt.ylabel('Value')
plt.title('Positional Encoding Curves')
plt.legend()
plt.tight_layout()
plt.show()

print(f"位置编码 shape: {pe_layer.pe.shape}")

## 6. Encoder Layer 实现

In [None]:
class EncoderLayer(nn.Module):
    """One Transformer encoder layer in the Pre-LN arrangement.

    Each sub-layer (self-attention, then feed-forward) is applied to a
    layer-normalised copy of its input and added back through a residual
    connection with dropout.
    """

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        """
        Args:
            d_model: model dimension
            num_heads: number of attention heads
            d_ff: hidden width of the feed-forward network
            dropout: dropout probability used throughout the layer
        """
        super().__init__()

        # Sub-layers
        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)

        # One LayerNorm in front of each sub-layer
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout on each sub-layer output before the residual addition
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the layer.

        Args:
            x: [batch_size, seq_len, d_model]
            mask: [batch_size, seq_len, seq_len], forwarded to the attention module

        Returns:
            output: [batch_size, seq_len, d_model]
            attention_weights: [batch_size, num_heads, seq_len, seq_len]
        """
        # Self-attention sub-layer: normalise, attend, residual add
        attn_input = self.norm1(x)
        attn_out, attn_map = self.self_attention(attn_input, attn_input, attn_input, mask)
        hidden = x + self.dropout1(attn_out)

        # Feed-forward sub-layer: normalise, transform, residual add
        ffn_out = self.feed_forward(self.norm2(hidden))
        hidden = hidden + self.dropout2(ffn_out)

        return hidden, attn_map

# Smoke-test a single encoder layer on a random input (no mask)
d_model, num_heads, d_ff = 128, 4, 512
encoder_layer = EncoderLayer(d_model, num_heads, d_ff)

x = torch.randn(2, 10, d_model)
output, weights = encoder_layer(x)

print(f"Encoder层输入 shape: {x.shape}")
print(f"Encoder层输出 shape: {output.shape}")
print(f"注意力权重 shape: {weights.shape}")
# Total number of parameter elements in the layer
print(f"Encoder层参数量: {sum(p.numel() for p in encoder_layer.parameters())}")

## 7. 完整的 Encoder Stack

In [None]:
class TransformerEncoder(nn.Module):
    """Token embedding + positional encoding followed by a stack of encoder layers."""

    def __init__(
        self,
        num_layers: int,
        d_model: int,
        num_heads: int,
        d_ff: int,
        vocab_size: int,
        max_len: int = 5000,
        dropout: float = 0.1
    ):
        """
        Args:
            num_layers: number of stacked encoder layers
            d_model: model dimension
            num_heads: attention heads per layer
            d_ff: hidden width of each feed-forward network
            vocab_size: size of the embedding table
            max_len: number of positions precomputed by the positional encoding
            dropout: dropout probability
        """
        super().__init__()

        # Token embedding; its output is multiplied by sqrt(d_model) in forward
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.embedding_scale = math.sqrt(d_model)

        self.positional_encoding = PositionalEncoding(d_model, max_len, dropout)

        # Encoder stack, built one layer at a time
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(EncoderLayer(d_model, num_heads, d_ff, dropout))

        # LayerNorm applied to the output of the last layer
        self.final_norm = nn.LayerNorm(d_model)

        self.num_layers = num_layers

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, list]:
        """
        Encode a batch of token ids.

        Args:
            x: token ids, [batch_size, seq_len]
            mask: optional 2-D padding mask [batch_size, seq_len]; it is
                expanded to [batch_size, seq_len, seq_len] so every query
                row sees the same key mask.

        Returns:
            output: [batch_size, seq_len, d_model]
            attention_weights_list: one attention tensor per layer
        """
        hidden = self.positional_encoding(self.embedding(x) * self.embedding_scale)

        if mask is not None:
            batch_size, seq_len = mask.shape
            mask = mask[:, None, :].expand(batch_size, seq_len, seq_len)

        per_layer_attention = []
        for encoder_layer in self.layers:
            hidden, layer_attention = encoder_layer(hidden, mask)
            per_layer_attention.append(layer_attention)

        return self.final_norm(hidden), per_layer_attention

# Smoke-test the full encoder stack on random token ids (no mask)
vocab_size = 1000
encoder = TransformerEncoder(
    num_layers=3,
    d_model=128,
    num_heads=4,
    d_ff=512,
    vocab_size=vocab_size
)

# Random token ids in [0, vocab_size), shape [2, 15]
input_ids = torch.randint(0, vocab_size, (2, 15))
output, attn_weights_list = encoder(input_ids)

print(f"输入 shape: {input_ids.shape}")
print(f"输出 shape: {output.shape}")
print(f"模型总参数量: {sum(p.numel() for p in encoder.parameters()):,}")

## 8. 创建简单的分类任务数据集

In [None]:
class SimpleSequenceDataset(Dataset):
    """Synthetic binary classification data.

    Each sample is a random token sequence with ids in [1, vocab_size); its
    label is 1 when two neighbouring tokens are equal and 0 otherwise.
    All samples are generated eagerly in the constructor.
    """

    def __init__(self, num_samples: int, seq_len: int, vocab_size: int):
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.vocab_size = vocab_size

        # List of (sequence tensor, int label) pairs
        self.data = [self._make_sample() for _ in range(num_samples)]

    def _make_sample(self):
        """Draw one random sequence and derive its label."""
        tokens = torch.randint(1, self.vocab_size, (self.seq_len,))
        # Compare every token with its right-hand neighbour in one shot
        has_adjacent_repeat = bool((tokens[:-1] == tokens[1:]).any())
        return tokens, int(has_adjacent_repeat)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx]

# Build the training and validation datasets
train_dataset = SimpleSequenceDataset(1000, seq_len=20, vocab_size=50)
val_dataset = SimpleSequenceDataset(200, seq_len=20, vocab_size=50)

# Batches of 32; only the training loader shuffles
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Show the first 10 tokens and the label of the first training sample
seq, label = train_dataset[0]
print(f"序列样例: {seq[:10]}...")
print(f"标签: {label}")
print(f"训练集大小: {len(train_dataset)}")
print(f"验证集大小: {len(val_dataset)}")

## 9. 构建分类模型

In [None]:
class EncoderClassifier(nn.Module):
    """Sequence classifier: encoder -> pooling over positions -> MLP head."""

    def __init__(
        self,
        encoder: "TransformerEncoder",
        num_classes: int,
        d_model: int,
        pooling: str = 'mean'
    ):
        """
        Args:
            encoder: module called as ``encoder(input_ids, mask)`` and
                returning ``(hidden_states, attention_weights)`` with
                hidden_states of shape [batch_size, seq_len, d_model]
            num_classes: number of output classes
            d_model: width of the encoder output
            pooling: ``'mean'`` (default) averages over positions,
                ``'first'`` uses the hidden state of the first position

        Raises:
            ValueError: if ``pooling`` is neither ``'mean'`` nor ``'first'``
        """
        super().__init__()
        if pooling not in ('mean', 'first'):
            raise ValueError(f"pooling must be 'mean' or 'first', got {pooling!r}")

        self.encoder = encoder
        self.pooling = pooling

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_model // 2, num_classes)
        )

    def forward(self, input_ids: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """
        Classify a batch of sequences.

        Args:
            input_ids: token ids, [batch_size, seq_len]
            mask: optional padding mask [batch_size, seq_len]; entries equal
                to 0 mark padding positions

        Returns:
            logits: [batch_size, num_classes]
            attention_weights: whatever the encoder returns as its second value
        """
        encoder_output, attention_weights = self.encoder(input_ids, mask)

        if self.pooling == 'mean':
            if mask is None:
                pooled = encoder_output.mean(1)
            else:
                # Average over real tokens only, so padding positions do not
                # dilute the sequence representation.
                keep = (mask != 0).unsqueeze(-1).to(encoder_output.dtype)
                # clamp guards against division by zero for fully padded rows
                pooled = (encoder_output * keep).sum(1) / keep.sum(1).clamp(min=1.0)
        else:
            # Hidden state of the first position
            pooled = encoder_output[:, 0, :]

        logits = self.classifier(pooled)

        return logits, attention_weights

# Build the classification model (rebinds the module-level `encoder`)
encoder = TransformerEncoder(
    num_layers=2,
    d_model=128,
    num_heads=4,
    d_ff=256,
    vocab_size=50,
    dropout=0.1
)

model = EncoderClassifier(encoder, num_classes=2, d_model=128)
# Move the model to the selected device
model = model.to(device)

print(f"分类模型参数量: {sum(p.numel() for p in model.parameters()):,}")

## 10. 简单训练演示

In [None]:
# Demonstration: run one optimisation step on a single batch
model.train()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Fetch one batch from the training loader and move it to the device
sequences, labels = next(iter(train_loader))
sequences = sequences.to(device)
labels = labels.to(device)

# Forward pass (no padding mask is passed)
optimizer.zero_grad()
logits, _ = model(sequences)
loss = criterion(logits, labels)

# Backward pass and parameter update
loss.backward()
optimizer.step()

# Report results; the accuracy uses the logits computed before the update step
_, predicted = torch.max(logits, 1)
accuracy = (predicted == labels).float().mean()

print(f"Batch Loss: {loss.item():.4f}")
print(f"Batch Accuracy: {accuracy.item():.2%}")
print(f"\n训练演示完成！")

## 11. 注意力权重可视化

In [None]:
def visualize_attention(model, sequence, layer_idx=0, head_idx=0, max_positions=10):
    """
    Plot one attention head of ``model`` for a single sequence as a heatmap.

    Args:
        model: module called as ``model(batch)`` and returning
            ``(logits, attention_weights)``, where ``attention_weights[layer]``
            is indexed as [batch, head, query, key]
        sequence: 1-D tensor of token ids (a batch axis is added here)
        layer_idx: index of the layer to visualise
        head_idx: index of the head to visualise
        max_positions: only the top-left ``max_positions x max_positions``
            corner of the attention matrix is drawn

    Returns:
        The full attention matrix of the selected layer and head as a NumPy
        array (not cropped).
    """
    # Remember the current mode so the caller's model is not silently left
    # in eval mode after plotting.
    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            # Run on whichever device the model's parameters live on instead
            # of relying on a module-level global.
            first_param = next(model.parameters(), None)
            target_device = first_param.device if first_param is not None else sequence.device
            batch = sequence.unsqueeze(0).to(target_device)
            _, attention_weights = model(batch)
            attn = attention_weights[layer_idx][0, head_idx].cpu().numpy()
    finally:
        # Restores the mode recorded above on the model and all submodules
        model.train(was_training)

    plt.figure(figsize=(8, 6))
    sns.heatmap(attn[:max_positions, :max_positions], cmap='Blues', cbar=True, square=True)
    plt.xlabel('Key Position')
    plt.ylabel('Query Position')
    plt.title(f'Attention Weights - Layer {layer_idx}, Head {head_idx}')
    plt.show()

    return attn

# Visualise attention for the first validation sample (layer 0, head 0)
sample_seq, _ = val_dataset[0]
print("可视化第一层第一个头的注意力权重（前10个位置）")
attn_weights = visualize_attention(model, sample_seq, layer_idx=0, head_idx=0)

## 12. 模型分析

In [None]:
# Report parameter statistics for the model
print("="*50)
print("模型参数统计")
print("="*50)

# Count all parameter elements, and those with requires_grad set
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"总参数量: {total_params:,}")
print(f"可训练参数量: {trainable_params:,}")
# Size estimate at 4 bytes per parameter (float32), in MiB
print(f"参数大小: {total_params * 4 / 1024 / 1024:.2f} MB (float32)")

# Per-child breakdown: parameter count and share of the total for each
# direct child module that has at least one parameter
print("\n各层参数分布:")
for name, module in model.named_children():
    params = sum(p.numel() for p in module.parameters())
    if params > 0:
        print(f"{name}: {params:,} ({params/total_params*100:.1f}%)")

## 总结

本notebook成功实现了一个完整的Transformer Encoder架构，包括：

### ✅ 已实现的核心组件
1. **Scaled Dot-Product Attention** - 注意力计算的基础
2. **Multi-Head Attention** - 多头并行注意力机制
3. **Position-wise FFN** - 位置独立的前馈网络
4. **Positional Encoding** - 正弦余弦位置编码
5. **Encoder Layer** - 包含注意力和FFN的完整层
6. **Encoder Stack** - 多层堆叠的编码器
7. **分类模型** - 基于Encoder的序列分类器

### 🔍 关键特性
- 残差连接和层归一化（Pre-LN）
- Dropout正则化
- Padding mask处理
- 多头并行计算的注意力（各头使用 Q/K/V 投影中各自独立的一段参数，头与头之间并不共享参数）

### 📊 调试功能
- 维度变化追踪
- 注意力权重可视化  
- 参数量统计
- 梯度流监控（上文代码中未包含对应实现，可作为后续扩展）

### 💡 使用建议
1. 可以通过调整`d_model`、`num_heads`、`d_ff`等超参数来实验不同的模型容量
2. 注意力可视化有助于理解模型学到的模式
3. 这个实现可以作为理解BERT、GPT等模型的基础

这个实现为深入理解Transformer架构提供了一个清晰、可调试的起点。