Skip to content

Long-LLM/Transformer-model

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

7 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

手写 Transformer:从零开始构建一个 Decoder-only 语言模型

代码地址:https://github.com/Long-LLM/Transformer-model


目录

  1. 我们要实现什么?
  2. 项目文件结构
  3. 超参数配置:config.py
  4. 模型核心:model.py
  5. 训练流程:train.py
  6. 推理生成:inference.py
  7. 总结

1. 我们要实现什么?

我们要实现的是一个 Decoder-only 的 Transformer 语言模型,和 GPT 系列的结构一致。整体架构如下:

Model (完整模型)
├── Token Embedding (词嵌入)
├── Positional Encoding (位置编码)
├── N × TransformerBlock (多个 Transformer 块)
│   ├── LayerNorm
│   ├── Multi-Head Attention
│   ├── LayerNorm
│   └── Feed Forward Network
├── Final LayerNorm (最后的归一化)
└── Output Linear (输出投影到词表)

数据流向:输入 Token → Token Embedding + Positional Encoding → N 个 Transformer Block → Final LayerNorm → Linear → Logits → 预测下一个 Token。


2. 项目文件结构

MyLLM/
├── config.py              # 超参数集中配置
├── model.py               # Transformer 模型定义
├── train.py               # 训练脚本
├── inference.py           # 推理/生成脚本
├── Chinese_poetry.txt     # 本地数据集
├── best_model.pt          # 训练保存的最优模型
└── loss_curve.png         # 损失曲线图

3. 超参数配置:config.py

训练前,我们把所有超参数集中放在 config.py 中,方便统一管理和复现。

"""共享超参数配置"""

batch_size = 4                   # 每个批次的样本数
context_length = 16             # 模型能看到的最大上下文长度(序列长度)
d_model = 64                    # 模型隐藏层维度(Embedding 维度、Attention 维度等)
num_blocks = 4                   # Transformer Block 的堆叠层数
num_heads = 4                    # 多头注意力的头数
learning_rate = 1e-4             # Adam 优化器的学习率
dropout = 0.2                    # Dropout 概率,防止过拟合
max_iters = 500                # 训练总步数
eval_interval = 500              # 每隔多少步做一次验证
eval_iters = 10                  # 每次验证抽取多少个 batch 计算平均损失

import torch
# 自动选择设备:优先 CUDA,其次 Intel XPU,最后 CPU
if torch.cuda.is_available():
    device = 'cuda'
elif hasattr(torch, 'xpu') and torch.xpu.is_available():
    device = 'xpu'
else:
    device = 'cpu'

TORCH_SEED = 1337                # 固定随机种子,保证实验可复现

gradient_accumulation_steps = 2  # 梯度累积步数:变相扩大 batch size,训练更稳定
gradient_checkpointing = True    # 梯度检查点:用计算换显存,可节省约 50% 显存

关键设计说明

  • context_length = 16:模型一次最多看 16 个 token,超过的部分在生成时会被截断,用16仅用于演示流程。
  • d_model = 64num_heads = 4:每个头的维度 d_k = 64// 4 = 16
  • gradient_accumulation_steps = 2:实际等效 batch size = 4 × 2 = 8,小显存也能训大 batch。
  • gradient_checkpointing = True:Transformer Block 的前向传播不保存中间激活值,反向传播时重新计算,大幅降低显存占用。

4. 模型核心:model.py

model.py 是整个项目的核心,包含以下组件:

类名 作用
FeedForwardNet 前馈网络(FFN)
ScaledDotProductAttention 单头缩放点积注意力
MultiHeadAttention 多头注意力(并行多个单头)
TransformerBlock Transformer 基本块(Attention + FFN + 残差)
Model 完整模型(Embedding + N×Block + LayerNorm + Linear)

下面逐模块详细讲解。


4.1 Feed Forward Network(前馈网络)

4.1.1 结构回顾

FFN 是一个两层的全连接网络,作用是为模型引入非线性变换和更强的表达能力:

输入 [batch, seq, d_model]
     ↓
Linear1: d_model → d_model × 4
     ↓
SiLU 激活(Swish)
     ↓
Linear2: d_model × 4 → d_model
     ↓
Dropout
     ↓
输出 [batch, seq, d_model]

先扩展 4 倍再缩回,这是 Transformer 论文中的经典设计。现代模型(如 LLaMA、Mistral)通常将 ReLU 替换为 SiLU(也称 Swish),训练更稳定。

4.1.2 代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.checkpoint import checkpoint


class FeedForwardNet(nn.Module):
    def __init__(self, d_model, dropout):
        super().__init__()
        # 使用 nn.Sequential 将三层操作打包:
        # 1. 第一层线性变换:d_model → d_model * 4,升维
        # 2. SiLU 激活函数:SiLU(x) = x * sigmoid(x),比 ReLU 更平滑
        # 3. 第二层线性变换:d_model * 4 → d_model,降维
        # 4. Dropout:随机置零部分神经元输出,防止过拟合
        self.net = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.SiLU(),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # x: [B, T, d_model]
        # 输出: [B, T, d_model],维度不变,内容经过非线性变换
        return self.net(x)

4.1.3 维度追踪与代码解读

代码 作用 维度变化
nn.Linear(d_model, d_model * 4) 第一层全连接升维 [B, T, 64] → [B, T, 256]
nn.SiLU() 激活函数,引入非线性 不变
nn.Linear(d_model * 4, d_model) 第二层全连接降维 [B, T, 256] → [B, T, 64]
nn.Dropout(dropout) 随机丢弃,防止过拟合 不变

4.2 Scaled Dot-Product Attention(单头注意力)

4.2.1 公式回顾

单头注意力的核心公式:

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

代码实现需要完成以下步骤:

  1. 通过线性层从输入 x 生成 Q(Query)、K(Key)、V(Value)
  2. 计算注意力分数:Q @ K^T
  3. 缩放:除以 √d_k,防止点积结果过大导致 Softmax 梯度消失
  4. 应用 Causal Mask(下三角掩码),防止模型看到未来的 Token
  5. Softmax 归一化,使每行和为 1
  6. V 相乘,得到加权的输出表示

4.2.2 代码实现

class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_model, d_k, context_length):
        super().__init__()
        self.d_k = d_k  # 每个头的维度(单头的维度)

        # 三个独立的线性层,分别将输入 x 投影到 Q、K、V 空间
        # bias=False:注意力层通常不使用偏置,减少参数量
        self.Wq = nn.Linear(d_model, d_k, bias=False)
        self.Wk = nn.Linear(d_model, d_k, bias=False)
        self.Wv = nn.Linear(d_model, d_k, bias=False)

        # Causal Mask(因果掩码):下三角矩阵
        # register_buffer:将 mask 注册为模型的持久状态(非参数,不参与训练)
        # 但会随模型一起移动到 GPU,并被保存在 state_dict 中
        self.register_buffer('mask', torch.tril(torch.ones(context_length, context_length)))

    def forward(self, x):
        B, T, C = x.shape  # B: Batch size, T: 序列长度, C: d_model(输入维度)

        # 1. 生成 Q, K, V:通过线性变换将输入投影到 d_k 维度
        Q = self.Wq(x)  # [B, T, d_k]
        K = self.Wk(x)  # [B, T, d_k]
        V = self.Wv(x)  # [B, T, d_k]

        # 2. 计算注意力分数:Q 和 K 的点积
        # Q: [B, T, d_k], K^T: [B, d_k, T] → 结果: [B, T, T]
        # 每个位置 (i, j) 表示第 i 个 token 对第 j 个 token 的关注程度
        attention = (Q @ K.transpose(-1, -2)) / math.sqrt(self.d_k)

        # 3. 应用 Causal Mask(未来位置设为 -inf,Softmax 后变为 0)
        # mask[:T, :T] 适配当前序列长度(可能小于最大 context_length)
        # masked_fill:将 mask 为 0 的位置(上三角)填充为 -inf
        attention = attention.masked_fill(self.mask[:T, :T] == 0, float('-inf'))

        # 4. Softmax 归一化:每行(每个 Query)的注意力权重和为 1
        attention = F.softmax(attention, dim=-1)

        # 5. 与 V 相乘:用注意力权重对 Value 做加权求和
        # attention: [B, T, T], V: [B, T, d_k] → 输出: [B, T, d_k]
        attention = attention @ V
        return attention

4.2.3 关键代码解读

Causal Mask 的构造:

self.register_buffer('mask', torch.tril(torch.ones(context_length, context_length)))

torch.tril 生成下三角矩阵:

[[1, 0, 0, 0],
 [1, 1, 0, 0],
 [1, 1, 1, 0],
 [1, 1, 1, 1]]
  • 位置 i 只能看到位置 0i 的信息
  • 位置 i 看不到 i+1 及之后的信息(被掩码为 -inf
  • 这保证了模型的自回归特性:预测下一个 token 时不能偷看未来

为什么用 register_buffer

  • Mask 不是可学习的参数(不需要梯度更新)
  • 但它需要随模型一起 .to(device) 移动到 GPU
  • register_buffer 就是专门用于这种持久化的非参数张量

4.3 Multi-Head Attention(多头注意力)

4.3.1 核心思想

多头注意力 = 多个单头注意力并行计算,最后拼接并投影。

每个注意力头可以学习到不同的关注模式(如有的关注语法,有的关注语义),多个头的信息拼接后能得到更丰富的表示。

本项目采用的是物理上分开的实现方式:每个头有独立的 Wq/Wk/Wv。虽然效率略低于论文版的逻辑切分,但代码更直观、更易理解。

4.3.2 代码实现

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, context_length, dropout):
        super().__init__()
        # 计算每个头的维度:总维度均分给所有头
        # 例如 d_model=512, num_heads=8 → d_k=64
        self.d_k = d_model // num_heads

        # 创建 num_heads 个独立的 ScaledDotProductAttention 头
        # nn.ModuleList:PyTorch 管理子模块的列表,确保参数被正确注册
        self.heads = nn.ModuleList([
            ScaledDotProductAttention(d_model, self.d_k, context_length)
            for _ in range(num_heads)
        ])

        # 输出投影层 Wo:将拼接后的多头输出投影回 d_model 维度
        # 这是多头注意力的最后一个线性变换
        self.projection_layer = nn.Linear(d_model, d_model)

        # Dropout,防止过拟合
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # 1. 并行运行所有注意力头,收集每个头的输出
        # 每个 head(x) 的输出形状: [B, T, d_k]
        output = torch.cat([head(x) for head in self.heads], dim=-1)
        # 拼接后: [B, T, num_heads * d_k] = [B, T, d_model]

        # 2. 通过输出投影层 Wo
        output = self.projection_layer(output)

        # 3. 应用 Dropout
        output = self.dropout(output)
        return output

4.3.3 维度追踪

假设 d_model = 64num_heads = 4

输入 x:        [B, T, 64]
     ↓
每个头输出:     [B, T, 16]      # d_k = 64 // 4 = 16
     ↓
拼接 8 个头:    [B, T, 64]     # 16 × 4 = 64
     ↓
Wo 投影:       [B, T, 64]
     ↓
Dropout 后输出: [B, T, 64]

关键公式d_k = d_model // num_heads,必须整除。


4.4 Transformer Block(Transformer 块)

4.4.1 结构回顾

本项目采用 Pre-Norm 结构(GPT-2、LLaMA 等现代模型都在用),每个 Block 包含两个子层:

输入 x
     ↓
LayerNorm → Multi-Head Attention → 与 x 残差相加
     ↓
LayerNorm → Feed Forward Network → 与 x 残差相加
     ↓
输出

相比原始 Transformer 的 Post-Norm,Pre-Norm 将 LayerNorm 放在子层之前,训练更稳定,深层网络也不容易梯度爆炸或消失。

4.4.2 代码实现

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, context_length, dropout, use_gradient_checkpointing=False):
        super().__init__()
        # 是否启用梯度检查点:用额外的计算换取显存节省
        self.use_gradient_checkpointing = use_gradient_checkpointing

        # 两个 LayerNorm:分别作用于 Attention 子层和 FFN 子层之前
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)

        # 多头注意力子层
        self.multihead_attention = MultiHeadAttention(d_model, num_heads, context_length, dropout)

        # 前馈网络子层
        self.feed_forward_network = FeedForwardNet(d_model, dropout)

    def _attn_block(self, x):
        """Attention 子层的前向计算(封装,用于梯度检查点)"""
        return self.multihead_attention(x)

    def _ffn_block(self, x):
        """FFN 子层的前向计算(封装,用于梯度检查点)"""
        return self.feed_forward_network(x)

    def forward(self, x):
        if self.use_gradient_checkpointing and self.training:
            # ========== 梯度检查点模式(训练时省显存)==========
            # checkpoint 的原理:前向时不保存中间激活值,反向传播时重新计算
            # use_reentrant=False:使用非重入检查点,推荐用于新版 PyTorch
            # 代价:训练速度略微变慢(因为反向时要重算两次前向)
            # 收益:显存占用可降低约 30%~50%,深层模型必开

            # 子层1:LayerNorm → Attention → 残差连接
            x = x + checkpoint(self._attn_block, self.layer_norm1(x), use_reentrant=False)
            # 子层2:LayerNorm → FFN → 残差连接
            x = x + checkpoint(self._ffn_block, self.layer_norm2(x), use_reentrant=False)
        else:
            # ========== 普通模式(推理或关闭 checkpoint 时)==========
            # 先对输入做 LayerNorm,再过 Attention,最后与原始输入 x 做残差相加
            x = x + self.multihead_attention(self.layer_norm1(x))
            # 先对输入做 LayerNorm,再过 FFN,最后与原始输入 x 做残差相加
            x = x + self.feed_forward_network(self.layer_norm2(x))
        return x

4.4.3 Pre-Norm vs Post-Norm

结构 公式 特点
Pre-Norm(本项目) x = x + Sublayer(LayerNorm(x)) LayerNorm 在子层之前,训练更稳定,现代模型主流
Post-Norm(原始 Transformer) x = LayerNorm(x + Sublayer(x)) LayerNorm 在子层之后,深层时训练容易不稳定

为什么 Pre-Norm 更稳定?

  • Pre-Norm 中,子层的输入已经被归一化到均值为 0、方差为 1 的分布,梯度流经残差连接时不会爆炸
  • 深层堆叠时,每一层的梯度都能通过残差连接直接回传,缓解了梯度消失

4.4.4 梯度检查点(Gradient Checkpointing)详解

from torch.utils.checkpoint import checkpoint

x = x + checkpoint(self._attn_block, self.layer_norm1(x), use_reentrant=False)
  • 问题:训练深层 Transformer 时,中间激活值占用大量显存(和层数成正比)
  • 方案checkpoint 只保存输入,不保存中间结果;反向传播时,用保存的输入重新计算前向,再计算梯度
  • trade-off:显存 ↓,计算时间 ↑(大约增加 20% 前向计算量)
  • 适用场景:显存有限但想训大模型、大 batch 时必开

4.5 完整 Model 类

4.5.1 整体结构

class Model(nn.Module):
    def __init__(self, max_token_value, d_model, num_blocks, num_heads,
                 context_length, dropout, use_gradient_checkpointing=False):
        super().__init__()
        self.context_length = context_length
        self.d_model = d_model

        # 1. Token Embedding:将离散的 token ID 映射为稠密的 d_model 维向量
        # 词表大小为 max_token_value,每个 token 对应一个 d_model 维的向量
        self.token_embedding_lookup_table = nn.Embedding(max_token_value, d_model)

        # 2. N 个 Transformer Block 堆叠
        # 使用 nn.ModuleList 而不是 nn.Sequential,因为每个 Block 结构相同但参数独立
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, context_length, dropout, use_gradient_checkpointing)
            for _ in range(num_blocks)
        ])

        # 3. 最后的 LayerNorm:在所有 Block 之后做一次归一化
        self.finally_layer = nn.LayerNorm(d_model)

        # 4. 输出投影层:将 d_model 维隐藏状态映射回词表大小,得到每个 token 的预测 logits
        self.model_out_linear_layer = nn.Linear(d_model, max_token_value)

4.5.2 前向传播(Forward)

    def forward(self, idx, targets=None):
        B, T = idx.shape  # B: batch size, T: 当前序列长度
        device = idx.device  # 获取输入所在的设备(CPU/GPU)

        # ========== 1. 正弦/余弦位置编码(Sinusoidal Positional Encoding)==========
        # 创建一个 (context_length, d_model) 的零矩阵
        position_encoding_lookup_table = torch.zeros(self.context_length, self.d_model, device=device)

        # position: [context_length, 1],表示每个位置的索引 0, 1, 2, ...
        position = torch.arange(0, self.context_length, dtype=torch.float).unsqueeze(1)

        # div_term:不同维度的频率衰减项
        # 偶数维和奇数维使用不同的频率,形成从低到高的正弦波
        # 公式:10000^(-2i/d_model),其中 i 是维度索引
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2).float() * (-math.log(10000.0) / self.d_model)
        )

        # 偶数维(0, 2, 4, ...)用 sin
        position_encoding_lookup_table[:, 0::2] = torch.sin(position * div_term)
        # 奇数维(1, 3, 5, ...)用 cos
        position_encoding_lookup_table[:, 1::2] = torch.cos(position * div_term)

        # 截取前 T 个位置(因为当前序列可能短于最大长度)
        position_embedding = position_encoding_lookup_table[:T, :]

        # ========== 2. Token Embedding + Position Encoding ==========
        # Token Embedding 提供语义信息,Position Embedding 提供位置信息
        # 两者相加后,模型既知道"是什么词",也知道"词在句子哪个位置"
        x = self.token_embedding_lookup_table(idx) + position_embedding

        # ========== 3. 通过所有 Transformer Blocks ==========
        for block in self.transformer_blocks:
            x = block(x)

        # ========== 4. Final LayerNorm ==========
        x = self.finally_layer(x)

        # ========== 5. 输出投影到词表 ==========
        # logits: [B, T, max_token_value],每个位置对应词表中每个词的得分
        logits = self.model_out_linear_layer(x)

        # ========== 6. 计算损失(仅在训练时)==========
        if targets is not None:
            B, T, C = logits.shape
            # 将 logits 展平为 [B*T, C],targets 展平为 [B*T]
            # 这样可以直接用 F.cross_entropy 计算每个位置的分类损失
            logits = logits.reshape(B * T, C)
            targets = targets.reshape(B * T)
            loss = F.cross_entropy(logits, targets)
        else:
            loss = None

        return logits, loss

4.5.3 关键代码解读

位置编码公式:

$$ PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) $$

$$ PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) $$

  • 偶数维度用 sin,奇数维度用 cos
  • 维度越高,波长越长(频率越低),模型可以通过不同频率感知相对位置
  • 和可学习的位置编码不同,正弦编码是固定的、与训练无关的,但能很好地泛化到比训练时更长的序列

损失函数:

loss = F.cross_entropy(logits, targets)
  • 这是语言模型训练的核心目标:预测下一个 token
  • targets 是输入 idx 向后偏移一位的结果(在 train.py 中构造)
  • 模型在每个位置 t 的目标是预测位置 t+1 的真实 token

4.5.4 生成函数(Generate)

    @torch.no_grad()  # 禁用梯度计算,节省显存,加速推理
    def generate(self, idx, max_new_tokens=100):
        """
        自回归文本生成

        Args:
            idx: 初始 token 序列 [B, T]
            max_new_tokens: 最多生成多少个新 token

        Returns:
            idx: 拼接后的完整序列 [B, T + max_new_tokens]
        """
        for _ in range(max_new_tokens):
            # 1. 截断到最大上下文长度:模型只能看最后 context_length 个 token
            # 如果序列太长,前面的 token 会被丢弃
            idx_crop = idx[:, -self.context_length:]

            # 2. 前向传播,获取 logits
            logits, loss = self.forward(idx_crop)

            # 3. 只取最后一个时间步的 logits(预测下一个 token)
            # logits_last_timestep: [B, max_token_value]
            logits_last_timestep = logits[:, -1, :]

            # 4. Softmax 转换为概率分布
            probs = F.softmax(logits_last_timestep, dim=-1)

            # 5. 多项式采样:从概率分布中随机抽取 1 个 token
            # 概率越大的 token,被抽中的可能性越高
            idx_next = torch.multinomial(probs, num_samples=1)

            # 6. 将新 token 拼接到序列末尾
            idx = torch.cat((idx, idx_next), dim=-1)

        return idx

自回归生成流程:

初始 prompt: [那天]
     ↓
预测第 1 个新 token → 拼接
[那天我]
     ↓
预测第 2 个新 token → 拼接
[那天我想]
     ↓
...重复 max_new_tokens 次

为什么每次只取 logits[:, -1, :]

  • 语言模型的训练目标就是预测下一个 token
  • 序列最后一个位置的输出,编码了之前所有上下文的信息,专门用于预测下一个词
  • 生成的新 token 又会作为下一步的输入,循环往复

5. 训练流程:train.py

模型定义好了,下面看如何训练。train.py 完整覆盖了数据加载、模型初始化、训练循环、验证保存和可视化。

import torch
import tiktoken
import matplotlib
matplotlib.use('Agg')  # 使用非交互式后端,适合服务器/无显示器环境
import matplotlib.pyplot as plt

from config import *            # 导入所有超参数
from model import Model         # 导入模型



def main():
    # 固定随机种子,保证每次训练结果可复现
    torch.manual_seed(TORCH_SEED)

    # ===================== 1. 加载数据 =====================
    with open('Chinese_poetry.txt', 'r', encoding='utf-8') as f:
        text = f.read()
    print(f"✅ 数据集加载完成!总文本长度:{len(text):,} 字符")

    # 使用 OpenAI 的 cl100k_base 分词器(GPT-4 同款)
    encoding = tiktoken.get_encoding("cl100k_base")
    tokenized_text = encoding.encode(text)

    # 词表大小 = 最大 token ID + 1(因为 ID 从 0 开始)
    max_token_value = max(tokenized_text) + 1

    # 90% 训练集,10% 验证集
    train_size = int(len(tokenized_text) * 0.9)
    train_data = tokenized_text[:train_size]
    valid_data = tokenized_text[train_size:]

    # ===================== 2. 初始化模型 =====================
    model = Model(
        max_token_value=max_token_value,
        d_model=d_model,
        num_blocks=num_blocks,
        num_heads=num_heads,
        context_length=context_length,
        dropout=dropout,
        use_gradient_checkpointing=gradient_checkpointing
    ).to(device)

    # 统计并打印模型参数量
    print(f"模型参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.2f} M")

    # ===================== 3. 获取批次函数 =====================
    def get_batch(split: str):
        """
        从训练集或验证集中随机采样一个 batch

        Args:
            split: 'train' 或 'valid'

        Returns:
            x: 输入序列 [batch_size, context_length]
            y: 目标序列(x 向后偏移一位)[batch_size, context_length]
        """
        data = train_data if split == 'train' else valid_data
        # 随机采样 batch_size 个起始位置
        idxs = torch.randint(low=0, high=len(data) - context_length, size=(batch_size,))
        # 构建输入 x 和目标 y
        x = torch.stack([torch.tensor(data[idx:idx + context_length]) for idx in idxs]).to(device)
        y = torch.stack([torch.tensor(data[idx + 1:idx + context_length + 1]) for idx in idxs]).to(device)
        return x, y

    # ===================== 4. 评估损失函数 =====================
    @torch.no_grad()
    def estimate_loss():
        """
        在训练集和验证集上分别计算平均损失
        用于监控过拟合(train loss ↓ 但 valid loss ↑ 说明过拟合)
        """
        out = {}
        model.eval()  # 切换到评估模式(关闭 Dropout)
        for split in ['train', 'valid']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                x_batch, y_batch = get_batch(split)
                logits, loss = model(x_batch, y_batch)
                losses[k] = loss.item()
            out[split] = losses.mean()
        model.train()  # 切回训练模式
        return out

    # ===================== 5. 训练循环 =====================
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    tracked_losses = []      # 记录每次评估的 loss,用于绘图
    best_val_loss = float('inf')
    best_model_path = 'best_model.pt'

    for step in range(max_iters):
        # 每隔 eval_interval 步做一次验证,最后一步也验证
        if step % eval_interval == 0 or step == max_iters - 1:
            losses = estimate_loss()
            tracked_losses.append(losses)
            print(f"step: {step}, Training loss: {losses['train'].item():.4f}, "
                  f"Validation loss: {losses['valid'].item():.4f}")

            # 保存验证损失最低的模型(早停/最优模型选择策略)
            current_val_loss = losses['valid'].item()
            if current_val_loss < best_val_loss:
                best_val_loss = current_val_loss
                torch.save(model.state_dict(), best_model_path)
                print(f"  已保存最优模型!最优验证损失: {best_val_loss:.4f}")

        # ===================== 6. 梯度累积 =====================
        accumulated_loss = 0
        for micro_step in range(gradient_accumulation_steps):
            xb, yb = get_batch('train')
            logits, loss = model(xb, yb)
            # 损失除以累积步数:等效于扩大 batch size
            # 例如 batch_size=4, accumulation=2 → 等效 batch=8
            loss = loss / gradient_accumulation_steps
            loss.backward()  # 反向传播,累加梯度
            accumulated_loss += loss.item()

        # 累积完成后,统一更新参数并清空梯度
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    # ===================== 7. 绘制 Loss 曲线 =====================
    def plot_loss_curve(tracked_losses, eval_interval):
        train_losses = [loss['train'].item() for loss in tracked_losses]
        val_losses = [loss['valid'].item() for loss in tracked_losses]
        steps = [i * eval_interval for i in range(len(train_losses))]

        plt.figure(figsize=(10, 5))
        plt.plot(steps, train_losses, label='Training Loss', color='blue')
        plt.plot(steps, val_losses, label='Validation Loss', color='red')
        plt.xlabel('Steps')
        plt.ylabel('Loss')
        plt.title('Training & Validation Loss Curve')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.savefig('loss_curve.png', dpi=300, bbox_inches='tight')
        plt.close()
        print('已保存损失曲线图:loss_curve.png')

    plot_loss_curve(tracked_losses, eval_interval)
    print(f"训练完成,最优模型已保存至: {best_model_path}")


if __name__ == '__main__':
    main()

5.1 关键设计解读

梯度累积(Gradient Accumulation):

for micro_step in range(gradient_accumulation_steps):
    ...
    loss = loss / gradient_accumulation_steps
    loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
  • 显存不够大 batch?那就用小 batch 多跑几次,梯度累加起来再更新
  • loss / accumulation_steps:保证等效学习率不变
  • 等效 batch size = batch_size × gradient_accumulation_steps = 4 × 2 = 8

最优模型保存(Early Stopping):

if current_val_loss < best_val_loss:
    best_val_loss = current_val_loss
    torch.save(model.state_dict(), best_model_path)
  • 不保存最后一步的模型,而是保存验证损失最低的模型
  • 防止过拟合:训练后期的模型可能在训练集上 loss 更低,但在验证集上表现更差

6. 推理生成:inference.py

训练完成后,用 inference.py 加载最优模型并生成文本。

import torch
import tiktoken

from config import *
from model import Model



def main():
    # ============= 1. 加载数据(仅用于获取 vocab size)====================
    with open('Chinese_poetry.txt', 'r', encoding='utf-8') as f:
        text = f.read()

    encoding = tiktoken.get_encoding("cl100k_base")
    tokenized_text = encoding.encode(text)
    max_token_value = max(tokenized_text) + 1

    # ===================== 2. 初始化模型并加载权重 =====================
    model = Model(
        max_token_value=max_token_value,
        d_model=d_model,
        num_blocks=num_blocks,
        num_heads=num_heads,
        context_length=context_length,
        dropout=dropout,
        use_gradient_checkpointing=gradient_checkpointing
    ).to(device)

    # 加载训练时保存的最优模型权重
    model.load_state_dict(torch.load('best_model.pt', map_location=device))
    model.eval()  # 推理时一定要设为 eval 模式!
    print("模型加载成功!")

    # ===================== 3. 文本生成 =====================
    start = '经月愁闻雨'  # 提示词(prompt)
    start_idx = encoding.encode(start)
    x = torch.tensor(start_idx, dtype=torch.long, device=device)[None, ...]
    # [None, ...] 增加 batch 维度:从 [T] 变为 [1, T]

    # 调用模型的自回归生成方法,生成 200 个新 token
    y = model.generate(x, max_new_tokens=16)

    print('---------------')
    print(encoding.decode(y[0].tolist()))
    print('---------------')


if __name__ == '__main__':
    main()

6.1 推理注意事项

  1. model.eval() 必须在推理前调用

    • 关闭 Dropout,否则每次生成结果会随机变化
    • 关闭 BatchNorm/LayerNorm 的训练时统计更新(虽然 LayerNorm 不影响,但养成好习惯)
  2. map_location=device

    • 如果模型在 CUDA 上训练,但在 CPU 上推理,用这个参数自动映射设备
    • 避免 RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False
  3. Prompt 编码

    • 必须用和训练时完全相同的 tokenizercl100k_base
    • 输入需要增加 batch 维度:[T] → [1, T]

7. 总结

7.1 完整数据流回顾

[输入 Token IDs]              shape: [B, T]
        ↓
Token Embedding               shape: [B, T, d_model]
+ Positional Encoding         shape: [T, d_model]
        ↓
┌─────────────────────────────────────┐
│  TransformerBlock × num_blocks      │
│  ├── LayerNorm + Multi-Head Attn    │
│  └── LayerNorm + FFN                │
└─────────────────────────────────────┘
        ↓
Final LayerNorm               shape: [B, T, d_model]
        ↓
Linear → Logits               shape: [B, T, vocab_size]
        ↓
Softmax → 采样 → 下一个 Token

7.2 本项目的技术特点

特性 实现方式 说明
位置编码 正弦/余弦(Sinusoidal) 固定编码,可外推到更长序列
归一化 Pre-LayerNorm 训练稳定,深层堆叠友好
激活函数 SiLU(Swish) 比 ReLU 更平滑,现代 LLM 标配
多头实现 物理分开(多个独立头) 代码直观,易于理解
显存优化 Gradient Checkpointing 用计算换显存,小显卡也能训大模型
训练策略 梯度累积 + 最优模型保存 等效大 batch,防止过拟合
Tokenizer tiktoken (cl100k_base) GPT-4 同款,中文支持好

7.3 如何进一步扩展

  1. 加入 Temperature 和 Top-K/Top-P 采样:当前 generate 是贪婪/纯采样,加入这些参数可以控制生成的多样性和质量
  2. 使用学习率衰减(LR Decay):当前是固定学习率,使用 Cosine Annealing 或 Warmup 通常能提升最终效果
  3. 混合精度训练(AMP)torch.cuda.amp 可以进一步加速训练、降低显存
  4. 权重初始化:当前使用 PyTorch 默认初始化,可以参考 GPT-2 的初始化策略
  5. 旋转位置编码(RoPE):替代正弦编码,现代模型(LLaMA、Qwen)的主流选择

至此,你已经完整理解了 Transformer 的每一行代码。建议对照 model.pytrain.pyconfig.pyinference.py 亲手跑一遍训练,观察 loss 的下降和生成效果的变化,这才是"手写 Transformer"的最佳实践。

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages