# 模型
我们会设计一个最基础的 GPT 风格（decoder-only）架构的模型来完成诗词生成这个任务。如果你想了解这个架构以及 Attention 机制的基本原理，我强烈推荐 Andrej Karpathy 的讲解视频——如果你还不熟悉 Transformer 或 Attention，它非常值得一看。
- [Andrej Karpathy - GPT from scratch](https://www.youtube.com/watch?v=kCc8FmEb1nY)
- [中文翻译](https://www.bilibili.com/video/BV1K4LPzLEoA/)
- [对应的 Notebook(Google Colab)](https://colab.research.google.com/drive/1JMLa53HDuA-i7ZBmqV7ZnA3c_fvtXnx-?usp=sharing)
- [相关文档(Github)](https://github.com/karpathy/ng-video-lecture)

我们会使用一个非常小型的 GPT 风格模型：先训练它适应诗词的文本和结构，再逐步让它能够按照我们指定的作者和体裁生成“以假乱真”的诗词。

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F


class FeedForward(nn.Module):
    """Position-wise feed-forward sub-layer of a Transformer block.

    Widens the channel dimension by a factor of 4, applies GELU, projects
    back to ``emb_size`` and finishes with dropout.

    Args:
        emb_size: input and output channel count.
        dropout: dropout probability applied to the output.
    """

    def __init__(self, emb_size, dropout=0.0):
        super().__init__()
        hidden_size = 4 * emb_size
        layers = [
            nn.Linear(emb_size, hidden_size),  # expand
            nn.GELU(),
            nn.Linear(hidden_size, emb_size),  # project back to the residual width
            nn.Dropout(dropout),
        ]
        # Kept as `self.net` (indices 0..3) so state_dict keys stay the same.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

In [2]:
class ParallelMultiHeadAttention(nn.Module):
    """
    Causal multi-head self-attention with all heads computed in parallel.

    A single fused Linear produces Q, K and V for every head at once. The
    heads are then separated with view/transpose, so one batched matmul
    computes the attention scores of all heads together.

    Args:
        emb_size: model width C; must be divisible by head_num.
        head_num: number of attention heads.
        block_size: largest sequence length the causal mask supports.
        dropout: dropout probability, applied both to the attention weights
            and to the projected output.
    """

    def __init__(self, emb_size, head_num, block_size, dropout=0.0):
        super().__init__()
        assert emb_size % head_num == 0, "emb_size 必须能被 head_num 整除"

        self.emb_size = emb_size
        self.head_num = head_num
        self.head_size = emb_size // head_num

        # Fused projection: output columns are laid out as [Q | K | V],
        # each emb_size wide (one Linear instead of 3 per head).
        self.qkv = nn.Linear(emb_size, 3 * emb_size, bias=False)
        # Mixes the concatenated head outputs back into the residual width.
        self.proj = nn.Linear(emb_size, emb_size)
        self.dropout = nn.Dropout(dropout)

        # Lower-triangular ones: position t may only attend to positions <= t.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        batch, seq_len, channels = x.shape

        def to_heads(t):
            # (B, T, C) -> (B, T, heads, head_size) -> (B, heads, T, head_size)
            return t.view(batch, seq_len, self.head_num, self.head_size).transpose(1, 2)

        query, key, value = (to_heads(part) for part in self.qkv(x).split(self.emb_size, dim=-1))

        # Scaled dot-product scores for every head at once: (B, heads, T, T)
        scale = self.head_size ** -0.5
        scores = (query @ key.transpose(-2, -1)) * scale

        # Hide future positions, normalise over keys, then apply dropout.
        future = self.tril[:seq_len, :seq_len] == 0
        weights = self.dropout(F.softmax(scores.masked_fill(future, float('-inf')), dim=-1))

        # Weighted values, heads concatenated back: (B, heads, T, hs) -> (B, T, C)
        merged = (weights @ value).transpose(1, 2).contiguous().view(batch, seq_len, channels)

        return self.dropout(self.proj(merged))

In [3]:
class TransformerBlock(nn.Module):
    """One pre-norm Transformer block: causal self-attention, then feed-forward.

    Each sub-layer reads a LayerNorm-ed copy of the residual stream and adds
    its result back onto the stream.
    """

    def __init__(self, emb_size, head_num, block_size, dropout=0.0):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them unchanged.
        self.sa = ParallelMultiHeadAttention(emb_size, head_num, block_size, dropout)
        self.ffwd = FeedForward(emb_size, dropout)
        # One LayerNorm in front of each sub-layer.
        self.ln1, self.ln2 = nn.LayerNorm(emb_size), nn.LayerNorm(emb_size)

    def forward(self, x):
        hidden = x + self.sa(self.ln1(x))
        out = hidden + self.ffwd(self.ln2(hidden))
        return out

In [4]:
class GPTLanguageModel(nn.Module):
    """Decoder-only (GPT-style) language model.

    Token embeddings plus learned absolute position embeddings are fed
    through ``layer_num`` pre-norm ``TransformerBlock``s. A final LayerNorm
    and a linear head then produce next-token logits over the vocabulary.

    Args:
        vocab_size: number of distinct token ids.
        emb_size: model width (embedding / hidden size).
        block_size: maximum context length; also the size of the position
            embedding table.
        layer_num: number of Transformer blocks.
        head_num: number of attention heads per block.
        dropout: dropout probability used inside every block.
    """

    def __init__(self, vocab_size, emb_size, block_size, layer_num, head_num, dropout):
        super().__init__()
        self.block_size = block_size
        self.token_embedding = nn.Embedding(vocab_size, emb_size)
        self.position_embedding = nn.Embedding(block_size, emb_size)

        # Stack of parallel-attention Transformer blocks.
        self.blocks = nn.Sequential(*[
            TransformerBlock(emb_size, head_num, block_size, dropout)
            for _ in range(layer_num)
        ])

        self.ln_f = nn.LayerNorm(emb_size)
        self.lm_head = nn.Linear(emb_size, vocab_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02); zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: LongTensor of token ids, shape (B, T), with T <= block_size
                (the position embedding table has block_size rows).
            targets: optional LongTensor of shape (B, T) holding the expected
                next token at every position.

        Returns:
            (logits, loss): logits of shape (B, T, vocab_size); loss is a
            scalar tensor, or None when ``targets`` is None.
        """
        B, T = idx.shape
        tok_emb = self.token_embedding(idx)  # (B, T, C)
        pos_emb = self.position_embedding(torch.arange(T, device=idx.device))  # (T, C), broadcast over B
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            return logits, None

        # reshape (not view) so non-contiguous targets, e.g. slices of a
        # larger tensor, are accepted as well.
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, stop_token_ids=None):
        """
        Autoregressively sample new tokens and append them to ``idx``.

        Runs under ``torch.no_grad()`` so no autograd graph is built. It does
        not change train/eval mode; call ``model.eval()`` first to disable
        dropout.

        Args:
            idx: LongTensor of prompt token ids, shape (B, T).
            max_new_tokens: maximum number of tokens to generate.
            temperature: the last-step logits are divided by this before the
                softmax (< 1 sharpens, > 1 flattens); must be non-zero.
            top_k: if given, sample only among the k highest-scoring tokens.
            stop_token_ids: optional stop sequence as a list, tuple or 1-D
                tensor of ids (e.g. the ids of the characters of '<EOP>').
                Generation ends early once every row of the batch has ended
                with this sequence at some step.

        Returns:
            LongTensor of shape (B, T + n), where n <= max_new_tokens is the
            number of sampling steps performed.
        """
        stop_seq = None
        if stop_token_ids is not None and len(stop_token_ids) > 0:
            # Normalise list/tuple/tensor into a 1-D tensor for comparison.
            stop_seq = torch.as_tensor(stop_token_ids, dtype=idx.dtype, device=idx.device).flatten()
        # Per-row flag: has this row already produced the stop sequence?
        finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)

        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens (position table limit).
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

            # Compare the tail of every row against the stop sequence.
            if stop_seq is not None and idx.size(1) >= stop_seq.numel():
                tail = idx[:, -stop_seq.numel():]
                finished |= (tail == stop_seq).all(dim=1)
                if bool(finished.all()):
                    break

        return idx

现在，结合准备阶段获得的数据统计信息，让我们来粗略设计一个合适的模型。
- Token 种类 11868
- PRE 数据量 11186141
- MID 数据量 16853134
- SFT 数据量 1806430
- MID 数据中最长序列 444
- SFT 数据中最长序列 242 --> 模型的上下文窗口可以设定为 256，以覆盖 SFT 阶段所有需要生成的情况（MID 数据中超过 256 的长序列无法一次性放入上下文窗口）

根据 Chinchilla scaling laws（最优训练tokens ≈ 20×模型参数量），这里我们以 PRE 阶段和 MID 阶段的数据总量作为参考，大概做一下下面的设计：
1. EMB_SIZE 256
2. BLOCK_SIZE 256
3. HEAD_NUM 8
4. LAYER_NUM 8

In [13]:
# Load the raw poem records, build a character-level tokenizer from all of
# their text, then instantiate the model.
from nanopoet.dataset import load_raw_data
from nanopoet.common import CharTokenizer, BEGIN

tokenizer = CharTokenizer(
    raw_text="".join("".join(record.values()) for record in load_raw_data("../raw"))
)
model = GPTLanguageModel(
    vocab_size=tokenizer.vocab_size,
    emb_size=256,    # EMB_SIZE
    block_size=256,  # BLOCK_SIZE: covers the longest SFT sample (242 tokens)
    layer_num=8,     # LAYER_NUM
    head_num=8,      # HEAD_NUM
    dropout=0.1,
)

# Parameter statistics: total count and the count that requires gradients.
param_stats = [(p.numel(), p.requires_grad) for p in model.parameters()]
total = sum(count for count, _ in param_stats)
trainable = sum(count for count, needs_grad in param_stats if needs_grad)
print("模型参数量", total, "可训练参数量", trainable)

# Smoke test: the untrained model continues a BEGIN prompt with random tokens.
model.eval()
context = torch.tensor([tokenizer.encode(BEGIN)], dtype=torch.long)
output = model.generate(context, max_new_tokens=10)
print("输出内容", tokenizer.decode(output[0].tolist()))

模型参数量： 12466268 可训练参数量 12466268
输出内容 B笏膴披覔覇裓住兢參耄


这样，我们的模型参数总量大约是 12M（12,466,268）。根据 Chinchilla scaling laws 的建议（最优训练 tokens ≈ 20 × 参数量），训练数据大约需要 0.25B tokens。考虑 PRE 和 MID 两个训练阶段：PRE 训练 10 个 Epoch，MID 也训练 10 个 Epoch，总的训练 token 量大约是 (11186141 + 16853134) × 10 ≈ 0.28B，基本满足要求。