代码地址:https://github.com/Long-LLM/Transformer-model
我们要实现的是一个 Decoder-only 的 Transformer 语言模型,和 GPT 系列的结构一致。整体架构如下:
Model (完整模型)
├── Token Embedding (词嵌入)
├── Positional Encoding (位置编码)
├── N × TransformerBlock (多个 Transformer 块)
│ ├── LayerNorm
│ ├── Multi-Head Attention
│ ├── LayerNorm
│ └── Feed Forward Network
├── Final LayerNorm (最后的归一化)
└── Output Linear (输出投影到词表)
数据流向:输入 Token → Token Embedding + Positional Encoding → N 个 Transformer Block → Final LayerNorm → Linear → Logits → 预测下一个 Token。
MyLLM/
├── config.py # 超参数集中配置
├── model.py # Transformer 模型定义
├── train.py # 训练脚本
├── inference.py # 推理/生成脚本
├── Chinese_poetry.txt # 本地数据集
├── best_model.pt # 训练保存的最优模型
└── loss_curve.png # 损失曲线图
训练前,我们把所有超参数集中放在 config.py 中,方便统一管理和复现。
"""共享超参数配置"""
batch_size = 4 # 每个批次的样本数
context_length = 16 # 模型能看到的最大上下文长度(序列长度)
d_model = 64 # 模型隐藏层维度(Embedding 维度、Attention 维度等)
num_blocks = 4 # Transformer Block 的堆叠层数
num_heads = 4 # 多头注意力的头数
learning_rate = 1e-4 # Adam 优化器的学习率
dropout = 0.2 # Dropout 概率,防止过拟合
max_iters = 500 # 训练总步数
eval_interval = 500 # 每隔多少步做一次验证
eval_iters = 10 # 每次验证抽取多少个 batch 计算平均损失
import torch
# 自动选择设备:优先 CUDA,其次 Intel XPU,最后 CPU
if torch.cuda.is_available():
device = 'cuda'
elif hasattr(torch, 'xpu') and torch.xpu.is_available():
device = 'xpu'
else:
device = 'cpu'
TORCH_SEED = 1337 # 固定随机种子,保证实验可复现
gradient_accumulation_steps = 2 # 梯度累积步数:变相扩大 batch size,训练更稳定
gradient_checkpointing = True # 梯度检查点:用计算换显存,可节省约 50% 显存关键设计说明:
context_length = 16:模型一次最多看 16 个 token,超过的部分在生成时会被截断,用16仅用于演示流程。d_model = 64,num_heads = 4:每个头的维度d_k = 64// 4 = 16。gradient_accumulation_steps = 2:实际等效 batch size =4 × 2 = 8,小显存也能训大 batch。gradient_checkpointing = True:Transformer Block 的前向传播不保存中间激活值,反向传播时重新计算,大幅降低显存占用。
model.py 是整个项目的核心,包含以下组件:
| 类名 | 作用 |
|---|---|
FeedForwardNet |
前馈网络(FFN) |
ScaledDotProductAttention |
单头缩放点积注意力 |
MultiHeadAttention |
多头注意力(并行多个单头) |
TransformerBlock |
Transformer 基本块(Attention + FFN + 残差) |
Model |
完整模型(Embedding + N×Block + LayerNorm + Linear) |
下面逐模块详细讲解。
FFN 是一个两层的全连接网络,作用是为模型引入非线性变换和更强的表达能力:
输入 [batch, seq, d_model]
↓
Linear1: d_model → d_model × 4
↓
SiLU 激活(Swish)
↓
Linear2: d_model × 4 → d_model
↓
Dropout
↓
输出 [batch, seq, d_model]
先扩展 4 倍再缩回,这是 Transformer 论文中的经典设计。现代模型(如 LLaMA、Mistral)通常将 ReLU 替换为 SiLU(也称 Swish),训练更稳定。
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.checkpoint import checkpoint
class FeedForwardNet(nn.Module):
def __init__(self, d_model, dropout):
super().__init__()
# 使用 nn.Sequential 将三层操作打包:
# 1. 第一层线性变换:d_model → d_model * 4,升维
# 2. SiLU 激活函数:SiLU(x) = x * sigmoid(x),比 ReLU 更平滑
# 3. 第二层线性变换:d_model * 4 → d_model,降维
# 4. Dropout:随机置零部分神经元输出,防止过拟合
self.net = nn.Sequential(
nn.Linear(d_model, d_model * 4),
nn.SiLU(),
nn.Linear(d_model * 4, d_model),
nn.Dropout(dropout),
)
def forward(self, x):
# x: [B, T, d_model]
# 输出: [B, T, d_model],维度不变,内容经过非线性变换
return self.net(x)| 代码 | 作用 | 维度变化 |
|---|---|---|
nn.Linear(d_model, d_model * 4) |
第一层全连接升维 | [B, T, 64] → [B, T, 256] |
nn.SiLU() |
激活函数,引入非线性 | 不变 |
nn.Linear(d_model * 4, d_model) |
第二层全连接降维 | [B, T, 256] → [B, T, 64] |
nn.Dropout(dropout) |
随机丢弃,防止过拟合 | 不变 |
单头注意力的核心公式:
代码实现需要完成以下步骤:
- 通过线性层从输入
x生成 Q(Query)、K(Key)、V(Value) - 计算注意力分数:
Q @ K^T - 缩放:除以
√d_k,防止点积结果过大导致 Softmax 梯度消失 - 应用 Causal Mask(下三角掩码),防止模型看到未来的 Token
- Softmax 归一化,使每行和为 1
- 与 V 相乘,得到加权的输出表示
class ScaledDotProductAttention(nn.Module):
def __init__(self, d_model, d_k, context_length):
super().__init__()
self.d_k = d_k # 每个头的维度(单头的维度)
# 三个独立的线性层,分别将输入 x 投影到 Q、K、V 空间
# bias=False:注意力层通常不使用偏置,减少参数量
self.Wq = nn.Linear(d_model, d_k, bias=False)
self.Wk = nn.Linear(d_model, d_k, bias=False)
self.Wv = nn.Linear(d_model, d_k, bias=False)
# Causal Mask(因果掩码):下三角矩阵
# register_buffer:将 mask 注册为模型的持久状态(非参数,不参与训练)
# 但会随模型一起移动到 GPU,并被保存在 state_dict 中
self.register_buffer('mask', torch.tril(torch.ones(context_length, context_length)))
def forward(self, x):
B, T, C = x.shape # B: Batch size, T: 序列长度, C: d_model(输入维度)
# 1. 生成 Q, K, V:通过线性变换将输入投影到 d_k 维度
Q = self.Wq(x) # [B, T, d_k]
K = self.Wk(x) # [B, T, d_k]
V = self.Wv(x) # [B, T, d_k]
# 2. 计算注意力分数:Q 和 K 的点积
# Q: [B, T, d_k], K^T: [B, d_k, T] → 结果: [B, T, T]
# 每个位置 (i, j) 表示第 i 个 token 对第 j 个 token 的关注程度
attention = (Q @ K.transpose(-1, -2)) / math.sqrt(self.d_k)
# 3. 应用 Causal Mask(未来位置设为 -inf,Softmax 后变为 0)
# mask[:T, :T] 适配当前序列长度(可能小于最大 context_length)
# masked_fill:将 mask 为 0 的位置(上三角)填充为 -inf
attention = attention.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
# 4. Softmax 归一化:每行(每个 Query)的注意力权重和为 1
attention = F.softmax(attention, dim=-1)
# 5. 与 V 相乘:用注意力权重对 Value 做加权求和
# attention: [B, T, T], V: [B, T, d_k] → 输出: [B, T, d_k]
attention = attention @ V
return attentionCausal Mask 的构造:
self.register_buffer('mask', torch.tril(torch.ones(context_length, context_length)))torch.tril 生成下三角矩阵:
[[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 1, 1, 1]]
- 位置
i只能看到位置0到i的信息 - 位置
i看不到i+1及之后的信息(被掩码为-inf) - 这保证了模型的自回归特性:预测下一个 token 时不能偷看未来
为什么用 register_buffer?
- Mask 不是可学习的参数(不需要梯度更新)
- 但它需要随模型一起
.to(device)移动到 GPU register_buffer就是专门用于这种持久化的非参数张量
多头注意力 = 多个单头注意力并行计算,最后拼接并投影。
每个注意力头可以学习到不同的关注模式(如有的关注语法,有的关注语义),多个头的信息拼接后能得到更丰富的表示。
本项目采用的是物理上分开的实现方式:每个头有独立的 Wq/Wk/Wv。虽然效率略低于论文版的逻辑切分,但代码更直观、更易理解。
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads, context_length, dropout):
super().__init__()
# 计算每个头的维度:总维度均分给所有头
# 例如 d_model=512, num_heads=8 → d_k=64
self.d_k = d_model // num_heads
# 创建 num_heads 个独立的 ScaledDotProductAttention 头
# nn.ModuleList:PyTorch 管理子模块的列表,确保参数被正确注册
self.heads = nn.ModuleList([
ScaledDotProductAttention(d_model, self.d_k, context_length)
for _ in range(num_heads)
])
# 输出投影层 Wo:将拼接后的多头输出投影回 d_model 维度
# 这是多头注意力的最后一个线性变换
self.projection_layer = nn.Linear(d_model, d_model)
# Dropout,防止过拟合
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# 1. 并行运行所有注意力头,收集每个头的输出
# 每个 head(x) 的输出形状: [B, T, d_k]
output = torch.cat([head(x) for head in self.heads], dim=-1)
# 拼接后: [B, T, num_heads * d_k] = [B, T, d_model]
# 2. 通过输出投影层 Wo
output = self.projection_layer(output)
# 3. 应用 Dropout
output = self.dropout(output)
return output假设 d_model = 64,num_heads = 4:
输入 x: [B, T, 64]
↓
每个头输出: [B, T, 16] # d_k = 64 // 4 = 16
↓
拼接 8 个头: [B, T, 64] # 16 × 4 = 64
↓
Wo 投影: [B, T, 64]
↓
Dropout 后输出: [B, T, 64]
关键公式:d_k = d_model // num_heads,必须整除。
本项目采用 Pre-Norm 结构(GPT-2、LLaMA 等现代模型都在用),每个 Block 包含两个子层:
输入 x
↓
LayerNorm → Multi-Head Attention → 与 x 残差相加
↓
LayerNorm → Feed Forward Network → 与 x 残差相加
↓
输出
相比原始 Transformer 的 Post-Norm,Pre-Norm 将 LayerNorm 放在子层之前,训练更稳定,深层网络也不容易梯度爆炸或消失。
class TransformerBlock(nn.Module):
def __init__(self, d_model, num_heads, context_length, dropout, use_gradient_checkpointing=False):
super().__init__()
# 是否启用梯度检查点:用额外的计算换取显存节省
self.use_gradient_checkpointing = use_gradient_checkpointing
# 两个 LayerNorm:分别作用于 Attention 子层和 FFN 子层之前
self.layer_norm1 = nn.LayerNorm(d_model)
self.layer_norm2 = nn.LayerNorm(d_model)
# 多头注意力子层
self.multihead_attention = MultiHeadAttention(d_model, num_heads, context_length, dropout)
# 前馈网络子层
self.feed_forward_network = FeedForwardNet(d_model, dropout)
def _attn_block(self, x):
"""Attention 子层的前向计算(封装,用于梯度检查点)"""
return self.multihead_attention(x)
def _ffn_block(self, x):
"""FFN 子层的前向计算(封装,用于梯度检查点)"""
return self.feed_forward_network(x)
def forward(self, x):
if self.use_gradient_checkpointing and self.training:
# ========== 梯度检查点模式(训练时省显存)==========
# checkpoint 的原理:前向时不保存中间激活值,反向传播时重新计算
# use_reentrant=False:使用非重入检查点,推荐用于新版 PyTorch
# 代价:训练速度略微变慢(因为反向时要重算两次前向)
# 收益:显存占用可降低约 30%~50%,深层模型必开
# 子层1:LayerNorm → Attention → 残差连接
x = x + checkpoint(self._attn_block, self.layer_norm1(x), use_reentrant=False)
# 子层2:LayerNorm → FFN → 残差连接
x = x + checkpoint(self._ffn_block, self.layer_norm2(x), use_reentrant=False)
else:
# ========== 普通模式(推理或关闭 checkpoint 时)==========
# 先对输入做 LayerNorm,再过 Attention,最后与原始输入 x 做残差相加
x = x + self.multihead_attention(self.layer_norm1(x))
# 先对输入做 LayerNorm,再过 FFN,最后与原始输入 x 做残差相加
x = x + self.feed_forward_network(self.layer_norm2(x))
return x| 结构 | 公式 | 特点 |
|---|---|---|
| Pre-Norm(本项目) | x = x + Sublayer(LayerNorm(x)) |
LayerNorm 在子层之前,训练更稳定,现代模型主流 |
| Post-Norm(原始 Transformer) | x = LayerNorm(x + Sublayer(x)) |
LayerNorm 在子层之后,深层时训练容易不稳定 |
为什么 Pre-Norm 更稳定?
- Pre-Norm 中,子层的输入已经被归一化到均值为 0、方差为 1 的分布,梯度流经残差连接时不会爆炸
- 深层堆叠时,每一层的梯度都能通过残差连接直接回传,缓解了梯度消失
from torch.utils.checkpoint import checkpoint
x = x + checkpoint(self._attn_block, self.layer_norm1(x), use_reentrant=False)- 问题:训练深层 Transformer 时,中间激活值占用大量显存(和层数成正比)
- 方案:
checkpoint只保存输入,不保存中间结果;反向传播时,用保存的输入重新计算前向,再计算梯度 - trade-off:显存 ↓,计算时间 ↑(大约增加 20% 前向计算量)
- 适用场景:显存有限但想训大模型、大 batch 时必开
class Model(nn.Module):
def __init__(self, max_token_value, d_model, num_blocks, num_heads,
context_length, dropout, use_gradient_checkpointing=False):
super().__init__()
self.context_length = context_length
self.d_model = d_model
# 1. Token Embedding:将离散的 token ID 映射为稠密的 d_model 维向量
# 词表大小为 max_token_value,每个 token 对应一个 d_model 维的向量
self.token_embedding_lookup_table = nn.Embedding(max_token_value, d_model)
# 2. N 个 Transformer Block 堆叠
# 使用 nn.ModuleList 而不是 nn.Sequential,因为每个 Block 结构相同但参数独立
self.transformer_blocks = nn.ModuleList([
TransformerBlock(d_model, num_heads, context_length, dropout, use_gradient_checkpointing)
for _ in range(num_blocks)
])
# 3. 最后的 LayerNorm:在所有 Block 之后做一次归一化
self.finally_layer = nn.LayerNorm(d_model)
# 4. 输出投影层:将 d_model 维隐藏状态映射回词表大小,得到每个 token 的预测 logits
self.model_out_linear_layer = nn.Linear(d_model, max_token_value) def forward(self, idx, targets=None):
B, T = idx.shape # B: batch size, T: 当前序列长度
device = idx.device # 获取输入所在的设备(CPU/GPU)
# ========== 1. 正弦/余弦位置编码(Sinusoidal Positional Encoding)==========
# 创建一个 (context_length, d_model) 的零矩阵
position_encoding_lookup_table = torch.zeros(self.context_length, self.d_model, device=device)
# position: [context_length, 1],表示每个位置的索引 0, 1, 2, ...
position = torch.arange(0, self.context_length, dtype=torch.float).unsqueeze(1)
# div_term:不同维度的频率衰减项
# 偶数维和奇数维使用不同的频率,形成从低到高的正弦波
# 公式:10000^(-2i/d_model),其中 i 是维度索引
div_term = torch.exp(
torch.arange(0, self.d_model, 2).float() * (-math.log(10000.0) / self.d_model)
)
# 偶数维(0, 2, 4, ...)用 sin
position_encoding_lookup_table[:, 0::2] = torch.sin(position * div_term)
# 奇数维(1, 3, 5, ...)用 cos
position_encoding_lookup_table[:, 1::2] = torch.cos(position * div_term)
# 截取前 T 个位置(因为当前序列可能短于最大长度)
position_embedding = position_encoding_lookup_table[:T, :]
# ========== 2. Token Embedding + Position Encoding ==========
# Token Embedding 提供语义信息,Position Embedding 提供位置信息
# 两者相加后,模型既知道"是什么词",也知道"词在句子哪个位置"
x = self.token_embedding_lookup_table(idx) + position_embedding
# ========== 3. 通过所有 Transformer Blocks ==========
for block in self.transformer_blocks:
x = block(x)
# ========== 4. Final LayerNorm ==========
x = self.finally_layer(x)
# ========== 5. 输出投影到词表 ==========
# logits: [B, T, max_token_value],每个位置对应词表中每个词的得分
logits = self.model_out_linear_layer(x)
# ========== 6. 计算损失(仅在训练时)==========
if targets is not None:
B, T, C = logits.shape
# 将 logits 展平为 [B*T, C],targets 展平为 [B*T]
# 这样可以直接用 F.cross_entropy 计算每个位置的分类损失
logits = logits.reshape(B * T, C)
targets = targets.reshape(B * T)
loss = F.cross_entropy(logits, targets)
else:
loss = None
return logits, loss位置编码公式:
- 偶数维度用
sin,奇数维度用cos - 维度越高,波长越长(频率越低),模型可以通过不同频率感知相对位置
- 和可学习的位置编码不同,正弦编码是固定的、与训练无关的,但能很好地泛化到比训练时更长的序列
损失函数:
loss = F.cross_entropy(logits, targets)- 这是语言模型训练的核心目标:预测下一个 token
targets是输入idx向后偏移一位的结果(在train.py中构造)- 模型在每个位置
t的目标是预测位置t+1的真实 token
@torch.no_grad() # 禁用梯度计算,节省显存,加速推理
def generate(self, idx, max_new_tokens=100):
"""
自回归文本生成
Args:
idx: 初始 token 序列 [B, T]
max_new_tokens: 最多生成多少个新 token
Returns:
idx: 拼接后的完整序列 [B, T + max_new_tokens]
"""
for _ in range(max_new_tokens):
# 1. 截断到最大上下文长度:模型只能看最后 context_length 个 token
# 如果序列太长,前面的 token 会被丢弃
idx_crop = idx[:, -self.context_length:]
# 2. 前向传播,获取 logits
logits, loss = self.forward(idx_crop)
# 3. 只取最后一个时间步的 logits(预测下一个 token)
# logits_last_timestep: [B, max_token_value]
logits_last_timestep = logits[:, -1, :]
# 4. Softmax 转换为概率分布
probs = F.softmax(logits_last_timestep, dim=-1)
# 5. 多项式采样:从概率分布中随机抽取 1 个 token
# 概率越大的 token,被抽中的可能性越高
idx_next = torch.multinomial(probs, num_samples=1)
# 6. 将新 token 拼接到序列末尾
idx = torch.cat((idx, idx_next), dim=-1)
return idx自回归生成流程:
初始 prompt: [那天]
↓
预测第 1 个新 token → 拼接
[那天我]
↓
预测第 2 个新 token → 拼接
[那天我想]
↓
...重复 max_new_tokens 次
为什么每次只取 logits[:, -1, :]?
- 语言模型的训练目标就是预测下一个 token
- 序列最后一个位置的输出,编码了之前所有上下文的信息,专门用于预测下一个词
- 生成的新 token 又会作为下一步的输入,循环往复
模型定义好了,下面看如何训练。train.py 完整覆盖了数据加载、模型初始化、训练循环、验证保存和可视化。
import torch
import tiktoken
import matplotlib
matplotlib.use('Agg') # 使用非交互式后端,适合服务器/无显示器环境
import matplotlib.pyplot as plt
from config import * # 导入所有超参数
from model import Model # 导入模型
def main():
# 固定随机种子,保证每次训练结果可复现
torch.manual_seed(TORCH_SEED)
# ===================== 1. 加载数据 =====================
with open('Chinese_poetry.txt', 'r', encoding='utf-8') as f:
text = f.read()
print(f"✅ 数据集加载完成!总文本长度:{len(text):,} 字符")
# 使用 OpenAI 的 cl100k_base 分词器(GPT-4 同款)
encoding = tiktoken.get_encoding("cl100k_base")
tokenized_text = encoding.encode(text)
# 词表大小 = 最大 token ID + 1(因为 ID 从 0 开始)
max_token_value = max(tokenized_text) + 1
# 90% 训练集,10% 验证集
train_size = int(len(tokenized_text) * 0.9)
train_data = tokenized_text[:train_size]
valid_data = tokenized_text[train_size:]
# ===================== 2. 初始化模型 =====================
model = Model(
max_token_value=max_token_value,
d_model=d_model,
num_blocks=num_blocks,
num_heads=num_heads,
context_length=context_length,
dropout=dropout,
use_gradient_checkpointing=gradient_checkpointing
).to(device)
# 统计并打印模型参数量
print(f"模型参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.2f} M")
# ===================== 3. 获取批次函数 =====================
def get_batch(split: str):
"""
从训练集或验证集中随机采样一个 batch
Args:
split: 'train' 或 'valid'
Returns:
x: 输入序列 [batch_size, context_length]
y: 目标序列(x 向后偏移一位)[batch_size, context_length]
"""
data = train_data if split == 'train' else valid_data
# 随机采样 batch_size 个起始位置
idxs = torch.randint(low=0, high=len(data) - context_length, size=(batch_size,))
# 构建输入 x 和目标 y
x = torch.stack([torch.tensor(data[idx:idx + context_length]) for idx in idxs]).to(device)
y = torch.stack([torch.tensor(data[idx + 1:idx + context_length + 1]) for idx in idxs]).to(device)
return x, y
# ===================== 4. 评估损失函数 =====================
@torch.no_grad()
def estimate_loss():
"""
在训练集和验证集上分别计算平均损失
用于监控过拟合(train loss ↓ 但 valid loss ↑ 说明过拟合)
"""
out = {}
model.eval() # 切换到评估模式(关闭 Dropout)
for split in ['train', 'valid']:
losses = torch.zeros(eval_iters)
for k in range(eval_iters):
x_batch, y_batch = get_batch(split)
logits, loss = model(x_batch, y_batch)
losses[k] = loss.item()
out[split] = losses.mean()
model.train() # 切回训练模式
return out
# ===================== 5. 训练循环 =====================
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
tracked_losses = [] # 记录每次评估的 loss,用于绘图
best_val_loss = float('inf')
best_model_path = 'best_model.pt'
for step in range(max_iters):
# 每隔 eval_interval 步做一次验证,最后一步也验证
if step % eval_interval == 0 or step == max_iters - 1:
losses = estimate_loss()
tracked_losses.append(losses)
print(f"step: {step}, Training loss: {losses['train'].item():.4f}, "
f"Validation loss: {losses['valid'].item():.4f}")
# 保存验证损失最低的模型(早停/最优模型选择策略)
current_val_loss = losses['valid'].item()
if current_val_loss < best_val_loss:
best_val_loss = current_val_loss
torch.save(model.state_dict(), best_model_path)
print(f" 已保存最优模型!最优验证损失: {best_val_loss:.4f}")
# ===================== 6. 梯度累积 =====================
accumulated_loss = 0
for micro_step in range(gradient_accumulation_steps):
xb, yb = get_batch('train')
logits, loss = model(xb, yb)
# 损失除以累积步数:等效于扩大 batch size
# 例如 batch_size=4, accumulation=2 → 等效 batch=8
loss = loss / gradient_accumulation_steps
loss.backward() # 反向传播,累加梯度
accumulated_loss += loss.item()
# 累积完成后,统一更新参数并清空梯度
optimizer.step()
optimizer.zero_grad(set_to_none=True)
# ===================== 7. 绘制 Loss 曲线 =====================
def plot_loss_curve(tracked_losses, eval_interval):
train_losses = [loss['train'].item() for loss in tracked_losses]
val_losses = [loss['valid'].item() for loss in tracked_losses]
steps = [i * eval_interval for i in range(len(train_losses))]
plt.figure(figsize=(10, 5))
plt.plot(steps, train_losses, label='Training Loss', color='blue')
plt.plot(steps, val_losses, label='Validation Loss', color='red')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.title('Training & Validation Loss Curve')
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig('loss_curve.png', dpi=300, bbox_inches='tight')
plt.close()
print('已保存损失曲线图:loss_curve.png')
plot_loss_curve(tracked_losses, eval_interval)
print(f"训练完成,最优模型已保存至: {best_model_path}")
if __name__ == '__main__':
main()梯度累积(Gradient Accumulation):
for micro_step in range(gradient_accumulation_steps):
...
loss = loss / gradient_accumulation_steps
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)- 显存不够大 batch?那就用小 batch 多跑几次,梯度累加起来再更新
loss / accumulation_steps:保证等效学习率不变- 等效 batch size =
batch_size × gradient_accumulation_steps = 4 × 2 = 8
最优模型保存(Early Stopping):
if current_val_loss < best_val_loss:
best_val_loss = current_val_loss
torch.save(model.state_dict(), best_model_path)- 不保存最后一步的模型,而是保存验证损失最低的模型
- 防止过拟合:训练后期的模型可能在训练集上 loss 更低,但在验证集上表现更差
训练完成后,用 inference.py 加载最优模型并生成文本。
import torch
import tiktoken
from config import *
from model import Model
def main():
# ============= 1. 加载数据(仅用于获取 vocab size)====================
with open('Chinese_poetry.txt', 'r', encoding='utf-8') as f:
text = f.read()
encoding = tiktoken.get_encoding("cl100k_base")
tokenized_text = encoding.encode(text)
max_token_value = max(tokenized_text) + 1
# ===================== 2. 初始化模型并加载权重 =====================
model = Model(
max_token_value=max_token_value,
d_model=d_model,
num_blocks=num_blocks,
num_heads=num_heads,
context_length=context_length,
dropout=dropout,
use_gradient_checkpointing=gradient_checkpointing
).to(device)
# 加载训练时保存的最优模型权重
model.load_state_dict(torch.load('best_model.pt', map_location=device))
model.eval() # 推理时一定要设为 eval 模式!
print("模型加载成功!")
# ===================== 3. 文本生成 =====================
start = '经月愁闻雨' # 提示词(prompt)
start_idx = encoding.encode(start)
x = torch.tensor(start_idx, dtype=torch.long, device=device)[None, ...]
# [None, ...] 增加 batch 维度:从 [T] 变为 [1, T]
# 调用模型的自回归生成方法,生成 200 个新 token
y = model.generate(x, max_new_tokens=16)
print('---------------')
print(encoding.decode(y[0].tolist()))
print('---------------')
if __name__ == '__main__':
main()-
model.eval()必须在推理前调用- 关闭 Dropout,否则每次生成结果会随机变化
- 关闭 BatchNorm/LayerNorm 的训练时统计更新(虽然 LayerNorm 不影响,但养成好习惯)
-
map_location=device- 如果模型在 CUDA 上训练,但在 CPU 上推理,用这个参数自动映射设备
- 避免
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False
-
Prompt 编码
- 必须用和训练时完全相同的 tokenizer(
cl100k_base) - 输入需要增加 batch 维度:
[T] → [1, T]
- 必须用和训练时完全相同的 tokenizer(
[输入 Token IDs] shape: [B, T]
↓
Token Embedding shape: [B, T, d_model]
+ Positional Encoding shape: [T, d_model]
↓
┌─────────────────────────────────────┐
│ TransformerBlock × num_blocks │
│ ├── LayerNorm + Multi-Head Attn │
│ └── LayerNorm + FFN │
└─────────────────────────────────────┘
↓
Final LayerNorm shape: [B, T, d_model]
↓
Linear → Logits shape: [B, T, vocab_size]
↓
Softmax → 采样 → 下一个 Token
| 特性 | 实现方式 | 说明 |
|---|---|---|
| 位置编码 | 正弦/余弦(Sinusoidal) | 固定编码,可外推到更长序列 |
| 归一化 | Pre-LayerNorm | 训练稳定,深层堆叠友好 |
| 激活函数 | SiLU(Swish) | 比 ReLU 更平滑,现代 LLM 标配 |
| 多头实现 | 物理分开(多个独立头) | 代码直观,易于理解 |
| 显存优化 | Gradient Checkpointing | 用计算换显存,小显卡也能训大模型 |
| 训练策略 | 梯度累积 + 最优模型保存 | 等效大 batch,防止过拟合 |
| Tokenizer | tiktoken (cl100k_base) | GPT-4 同款,中文支持好 |
- 加入 Temperature 和 Top-K/Top-P 采样:当前
generate是贪婪/纯采样,加入这些参数可以控制生成的多样性和质量 - 使用学习率衰减(LR Decay):当前是固定学习率,使用 Cosine Annealing 或 Warmup 通常能提升最终效果
- 混合精度训练(AMP):
torch.cuda.amp可以进一步加速训练、降低显存 - 权重初始化:当前使用 PyTorch 默认初始化,可以参考 GPT-2 的初始化策略
- 旋转位置编码(RoPE):替代正弦编码,现代模型(LLaMA、Qwen)的主流选择
至此,你已经完整理解了 Transformer 的每一行代码。建议对照
model.py、train.py、config.py和inference.py亲手跑一遍训练,观察 loss 的下降和生成效果的变化,这才是"手写 Transformer"的最佳实践。