# Transformers 架构

一个单文件、可运行的 PyTorch 实现：接近现代 decoder-only Transformer（GPT 风格），  
包含关键现代改进与工程要点：  
 - 多头自注意力（scaled dot-product）
 - Rotary positional embeddings(RoPE)（可选）
 - Gated Feed-Forward（GEGLU）
 - LayerNorm + Residual 结构
 - Causal mask（自回归）
 - Mixed precision training (AMP)
 - Activation checkpointing（可选）
 - 学习率线性 warmup -> cosine decay
 - 简单训练循环（含 gradient accumulation 与 checkpoint 保存）
 - 采样（top-k / top-p）用于生成

设计目标：可读、现代化、易修改，注释突出关键细节以便教学与工程化扩展。

注意：此实现为教学/原型级别。生产环境请使用经过高度优化的内核（FlashAttention、xFormers）、分布式并行与高效 IO。

运行示例：
  python modern_transformer.py --device cuda --epochs 2


In [1]:
import math
import argparse
import time
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [2]:
# -----------------------------
# Utilities
# -----------------------------

def exists(x):
    return x is not None


def gelu_new(x):
    # GELU approximation that's widely used
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)))

## 旋转位置编码（rope）
RoPE 的灵感来源于复数旋转。在复数空间中，用一个复数 $e^{im\theta_i}$ 去乘另一个复数 $(q_j + i q_{j+1})$，等价于将其旋转 $m\theta_i$ 弧度。  
$$
\begin{bmatrix}
   \cos{(m\theta_i)} & -\sin{(m\theta_i)} \\
   \sin{(m\theta_i)} & \cos{(m\theta_i)}
\end{bmatrix}
$$
所以，对于一对元素$[x_1, x_2]$旋转后的结果是：
$$
x_{1_rotated} = x_1 \cos{(m\theta_i)} - x_2 \sin{(m\theta_i)}\\
x_{2_rotated} = x_1 \sin{(m\theta_i)} + x_2 \cos{(m\theta_i)}
$$
代码中的 `x * cos + x_rotated * sin` 正是以向量化且高效的方式实现了上述两个公式。

In [3]:
# -----------------------------
# Rotary Positional Embeddings (RoPE)
# -----------------------------
# RoPE provides a way to inject relative-position bias into the attention by rotating query/key vectors.
# It has become very common in modern decoder-only LLMs (GPT-NeoX, etc.).

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        # dim is the per-head dim (d_head). We create cos/sin caches lazily.
        # 计算频率的倒数： theta_i = 10000^(-2i/dim)
        # 我们要把 dim 维的空间分成 dim/2 (视作复数i的两个维度)个二维子空间(步长为2)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq) #注册为不参与训练的缓冲区域
        self._seq_len_cached = 0 # 记录当前缓存的长度

    def _build_cache(self, seq_len, device):
        # 如果超过缓存则跳过
        if seq_len == self._seq_len_cached:
            return
        
        # 位置索引 m: [0, 1, 2, ..., seq_len-1]
        t = torch.arange(seq_len, device=device).type(self.inv_freq.dtype)
    
        # 外积: [seq_len] , [dim/2] -> [seq_len, dim/2]
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)  # [seq_len, dim/2]

        # 对dim/2复制，[seq_len, dim/2] -> [seq_len, dim]
        emb = torch.cat((freqs, freqs), dim=-1)  # duplicate for sin/cos pairing m*theta_i
        cos = emb.cos()[None, None, :, :]  # shape [1,1,seq_len,dim]
        sin = emb.sin()[None, None, :, :]

        # 缓存起来
        self.register_buffer('cos_cached', cos)
        self.register_buffer('sin_cached', sin)
        self._seq_len_cached = seq_len

    def forward(self, x, seq_dim=-2):
        # x是query 或者 key 向量 shape is [batch_size, n_head, seq_len, head_dim(dim))]
        seq_len = x.shape[seq_dim]
        device = x.device
        self._build_cache(seq_len, device)
        return self.cos_cached[..., :seq_len, :], self.sin_cached[..., :seq_len, :]


def apply_rotary_pos_emb(x, cos, sin):
    # x: [batch, n_head, seq_len, d_head]
    # 分成奇数和偶数两组
    x1 = x[..., ::2]
    x2 = x[..., 1::2]

    # 应用旋转嵌入公式：x_rotated = x * cos + x_rotated * sin
    # 这等价于对 x 的每一对元素 [x_i, x_{i+1}] 进行了一个旋转矩阵操作
    x_rotated = torch.stack((-x2, x1), dim=-1).reshape_as(x)
    return x * cos + x_rotated * sin

In [4]:
# -----------------------------
# Multi-head attention (with optional fused/flash path if available)
# -----------------------------

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, n_head, attn_dropout=0.0, causal=True, use_rope=True):
        super().__init__()
        assert d_model % n_head == 0, 'd_model must be divisible by n_head'
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_model // n_head
        self.causal = causal
        self.use_rope = use_rope

        # We project to QKV in one linear for efficiency
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)
        self.attn_dropout = nn.Dropout(attn_dropout)

        if use_rope:
            # Rotatory embeddings per head dimension
            self.rotary = RotaryEmbedding(self.d_head)
        else:
            self.rotary = None

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        # x: [B, T, d_model] == [批次大小，序列长度，模型维度]
        B, T, _ = x.shape
        qkv = self.qkv(x)  # [B, T, 3*d]
        
        # reshape+转置： [B, T, 3, n_head, d_head] ->  [3, B, n_head, T, d_head]
        qkv = qkv.view(B, T, 3, self.n_head, self.d_head).permute(2, 0, 3, 1, 4)

        # 取出合并对应的q k v
        q, k, v = qkv[0], qkv[1], qkv[2]  # each: [B, n_head, T, d_head]


        # 是否使用旋转位置编码 shape 不变
        if self.use_rope and self.rotary is not None:
            cos, sin = self.rotary(q)
            q = apply_rotary_pos_emb(q, cos, sin)
            k = apply_rotary_pos_emb(k, cos, sin)

        # compute scaled dot-product attention
        # scores: [B, n_head, T, T]-->
        # --> [B, n_head, T, d_head] × [B, n_head, d_head, T] = [B, n_head, T, T]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)

        # causal mask: prevent attending to future tokens in decoder-only models
        # 因果掩码: [1, 1, T, T]
        # causal_mask 现在是一个 [T, T] 的矩阵，看起来像这样（T=4）：
        # [[ True, False, False, False],
        #  [ True,  True, False, False],
        #  [ True,  True,  True, False],
        #  [ True,  True,  True, True]]
        if self.causal:
            causal_mask = torch.tril(torch.ones((T, T), device=x.device, dtype=torch.bool)).view(1, 1, T, T)
            scores = scores.masked_fill(~causal_mask, float('-inf'))

        # external attention mask (e.g., padding masks)
        if attn_mask is not None:
            # attn_mask expected broadcasting to [B, 1, 1, T] or [B, 1, T, T]
            scores = scores + attn_mask

        # 标准注意力计算
        attn = torch.softmax(scores, dim=-1)
        attn = self.attn_dropout(attn)
        out = torch.matmul(attn, v)  # [B, n_head, T, d_head]
        out = out.transpose(1, 2).contiguous().view(B, T, self.d_model)
        return self.out(out)

这里的多头注意力可以改进为GQA
```python
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, n_head, n_kv_heads=None, attn_dropout=0.0, causal=True, use_rope=True):
        super().__init__()
        assert d_model % n_head == 0, 'd_model must be divisible by n_head'
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_model // n_head
        self.causal = causal
        self.use_rope = use_rope
        
        # 设置KV头的数量（默认为查询头数的1/4或1/8）
        self.n_kv_heads = n_kv_heads if n_kv_heads is not None else max(1, n_head // 8)
        assert n_head % self.n_kv_heads == 0, 'n_head must be divisible by n_kv_heads'
        self.n_rep = n_head // self.n_kv_heads  # 每个KV头重复的次数
        
        # 分开的Q、K、V投影
        self.q_proj = nn.Linear(d_model, n_head * self.d_head)
        self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.d_head)
        self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.d_head)
        self.out = nn.Linear(d_model, d_model)
        self.attn_dropout = nn.Dropout(attn_dropout)
        
        if use_rope:
            self.rotary = RotaryEmbedding(self.d_head)
        else:
            self.rotary = None
    
    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        B, T, _ = x.shape
        
        # 分别计算Q、K、V
        q = self.q_proj(x).view(B, T, self.n_head, self.d_head).transpose(1, 2)  # [B, n_head, T, d_head]
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)  # [B, n_kv_heads, T, d_head]
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)  # [B, n_kv_heads, T, d_head]
        
        # 应用旋转位置编码
        if self.use_rope and self.rotary is not None:
            cos, sin = self.rotary(q)
            q = apply_rotary_pos_emb(q, cos, sin)
            k = apply_rotary_pos_emb(k, cos, sin)
        
        # 重复K和V以匹配Q的头数
        k = k.repeat_interleave(self.n_rep, dim=1)  # [B, n_head, T, d_head]
        v = v.repeat_interleave(self.n_rep, dim=1)  # [B, n_head, T, d_head]
        
        # 计算注意力分数
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)
        
        # 应用因果掩码
        if self.causal:
            causal_mask = torch.tril(torch.ones((T, T), device=x.device, dtype=torch.bool)).view(1, 1, T, T)
            scores = scores.masked_fill(~causal_mask, float('-inf'))
        
        # 应用外部注意力掩码
        if attn_mask is not None:
            scores = scores + attn_mask
        
        # 计算注意力权重
        attn = torch.softmax(scores, dim=-1)
        attn = self.attn_dropout(attn)
        
        # 应用注意力权重到值向量
        out = torch.matmul(attn, v)  # [B, n_head, T, d_head]
        
        # 合并多头输出
        out = out.transpose(1, 2).contiguous().view(B, T, self.d_model)
        
        # 最终输出投影
        return self.out(out)

```

In [5]:
# -----------------------------
# Feed-Forward (GEGLU) - gated linear unit variant that often performs better
# -----------------------------

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.0):
        super().__init__()
        # GEGLU: gate = W2(x) * GELU(W1(x)) but implemented as two linear layers
        self.w1 = nn.Linear(d_model, d_ff)   # 升维: d_model -> d_ff
        self.w2 = nn.Linear(d_model, d_ff)   # 另一个升维路径（用于门控）
        self.proj = nn.Linear(d_ff, d_model) # 降维/投影
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: [B, T, d_model]
        return self.proj(self.dropout(gelu_new(self.w1(x)) * self.w2(x)))
        #        ^^^^^^                          ^^               ^^
        #    投影回d_model                       d_ff             d_ff

这里的GEGLU可以切换为swishGLU，优化硬件提高效率。

In [6]:
# -----------------------------
# Transformer Block
# -----------------------------

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_head, d_ff, dropout=0.1, 
                 attn_dropout=0.0, use_rope=True, checkpoint=False):
        super().__init__()
        # Pre-LN architecture is the modern default for stability for deep stacks
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(d_model, n_head, attn_dropout=attn_dropout, 
                                           causal=True, use_rope=use_rope)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = FeedForward(d_model, d_ff, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.checkpoint = checkpoint

    def forward(self, x, attn_mask=None):
        # Pre-LN then residual
        # (标准的是post-LN)
        def _attn_forward(x, attn_mask):
            return self.attn(self.ln1(x), attn_mask)

        attn_out = _attn_forward(x, attn_mask)
        x = x + self.dropout(attn_out)

        ff_out = self.ff(self.ln2(x))
        x = x + self.dropout(ff_out)
        return x

![Traditional Transformers](img/transformers.png)

In [7]:
# -----------------------------
# Decoder-only Transformer (stack of blocks) + final lm head
# -----------------------------

class DecoderOnlyTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=768, n_layer=12, n_head=12, d_ff=3072, max_seq_len=1024,
                 dropout=0.1, attn_dropout=0.0, use_rope=True, checkpoint=False):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model

        # 首先把 token_id 映射到连续的向量空间，embeding 维度 = d_model
        self.token_emb = nn.Embedding(vocab_size, d_model)

        # 位置编码矩阵 [1, T, d]
        self.pos_emb = nn.Parameter(torch.zeros(1, max_seq_len, d_model))  # optional learned pos
        
        # 堆叠 n_layer 个TransformerBlock
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_head, d_ff, dropout=dropout, 
                             attn_dropout=attn_dropout, use_rope=use_rope, checkpoint=checkpoint)
            for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(d_model)

        # weight tying between token_emb and lm_head is common 
        # (improves sample quality)
        # 把隐藏层映射回词表大小，得到logits
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.lm_head.weight = self.token_emb.weight  # 权重共享减少参数量，经典论文：Press & Wolf, 2016
        self.max_seq_len = max_seq_len

    def forward(self, input_ids, attn_mask=None):
        # input_ids: [B, T]
        B, T = input_ids.shape
        assert T <= self.max_seq_len, 'sequence length exceeds model maximum'

        x = self.token_emb(input_ids)  # [B, T, d]
        # add learned positional embeddings (alternative: use RoPE only)
        x = x + self.pos_emb[:, :T, :]

        for layer in self.layers:
            x = layer(x, attn_mask=attn_mask)

        x = self.ln_f(x)
        logits = self.lm_head(x) 
        return logits

    # 生成函数，推理不计算梯度
    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=100, device=None, temperature=1.0, top_k=0, top_p=0.0):
        # Simple autoregressive sampling. Not optimized for long sequences.
        if device is None:
            device = next(self.parameters()).device
        input_ids = input_ids.to(device)


        B, T = input_ids.shape
        out = input_ids
        for _ in range(max_new_tokens):
            if out.shape[1] > self.max_seq_len:
                inp = out[:, -self.max_seq_len:]  # 只保留最近 max_seq_len 个
            else:
                inp = out
            logits = self.forward(inp)
            # temperature：softmax 平滑控制（>1 更随机，<1 更确定）
            next_logits = logits[:, -1, :] / max(temperature, 1e-8)

            # Top-k filtering
            # 只保留概率前 k 个 token，其余置为 -Inf
            if top_k > 0:
                vals, _ = torch.topk(next_logits, top_k)
                min_vals = vals[:, -1].unsqueeze(1)
                next_logits = torch.where(next_logits < min_vals, torch.full_like(next_logits, -float('Inf')), next_logits)

            # Top-p (nucleus) filtering
            if top_p > 0.0:
                # 按概率排序，保留累计概率 ≤ top_p 的 token（常设置为0.9）
                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p

                # shift right to keep first above-threshold token
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_logits[..., indices_to_remove] = -float('Inf')
            
            # 归一化概率后采样一个token
            probs = torch.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            out = torch.cat([out, next_token], dim=1)  # [B, T + max_new_tokens]
        return out


In [8]:
# -----------------------------
# Mini dataset for language modeling (character-level or small subword)
# -----------------------------

class SimpleTextDataset(Dataset):
    def __init__(self, texts, tokenizer, seq_len=128):
        # texts: list of str; tokenizer: simple function mapping str->list[int]
        # seq_len: 每个训练样本的最大长度
        self.examples = []
        for t in texts:
            ids = tokenizer(t)

            # break into blocks
            for i in range(0, max(1, len(ids) - 1), seq_len):
                block = ids[i:i + seq_len]
                if len(block) < 2:
                    continue

                # 把 token 序列存成 torch.Tensor，方便后续训练。
                # 所有块存到 self.examples 里。
                self.examples.append(torch.tensor(block, dtype=torch.long))

    # 必要实现
    # 返回长度和一个样本
    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

In [9]:
# 举个例子：
texts = ["hello world", "transformers are great"]
tokenizer = lambda x: [ord(c) for c in x]  # 简单用ASCII码作为tokenizer

ds = SimpleTextDataset(texts, tokenizer, seq_len=5)
print(len(ds))       # 多少个样本
print(ds[0])         # 第一个样本


7
tensor([104, 101, 108, 108, 111])


In [10]:
# -----------------------------
# Simple tokenizer (byte-level BPE is heavy, we provide a toy tokenizer for demo)
# In practical use, use Hugging Face tokenizers / SentencePiece / BPE.
# -----------------------------

class CharTokenizer:
    def __init__(self, texts):
        # build char-level vocab
        chars = sorted(list(set(''.join(texts))))

        # reserve 0 for padding
        # 从 字符 --> 整数id 的映射，而外加一个<unk>处理此词表以外的字符
        self.stoi = {c: i + 1 for i, c in enumerate(chars)}  
        self.stoi['<unk>'] = len(self.stoi) + 1
        # 逆运算itos
        self.itos = {i: s for s, i in self.stoi.items()}
        self.vocab_size = len(self.stoi) + 1

    # 遍历str寻找id
    def encode(self, text):
        return [self.stoi.get(c, self.stoi['<unk>']) for c in text]

    # 解码id查字符str
    def decode(self, ids):
        return ''.join([self.itos.get(i, '?') for i in ids])
    
    # 可以直接使用：tokenizer("hello") 等价于 tokenizer.encode("hello")
    def __call__(self, text):
        return self.encode(text)

In [11]:
texts = ["hello", "world"]
tokenizer = CharTokenizer(texts)

print("词表大小:", tokenizer.vocab_size)
print("编码:", tokenizer.encode("hello"))
print("解码:", tokenizer.decode(tokenizer.encode("hello")))


词表大小: 9
编码: [3, 2, 4, 4, 5]
解码: hello


In [12]:
# -----------------------------
# Training utilities: optimizer, scheduler, save/load
# -----------------------------

# 训练时动态调整学习率
class WarmupCosineScheduler:
    def __init__(self, optimizer, warmup_steps, total_steps, min_lr=1e-6):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps    # 前多少步用 线性 warmup
        self.total_steps = total_steps
        self.min_lr = min_lr
        self.step_num = 0                   # 当前的训练步数

    def step(self):
        self.step_num += 1
        if self.step_num < self.warmup_steps:
            # Warmup 阶段
            lr_mult = float(self.step_num) / float(max(1, self.warmup_steps))
        else:
            # Cosine Decay 阶段
            # 学习率从 initial_lr 逐渐下降到 0，呈余弦曲线。
            progress = float(self.step_num - self.warmup_steps) / float(max(1, self.total_steps - self.warmup_steps))
            lr_mult = max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = max(self.min_lr, param_group['initial_lr'] * lr_mult) 

# 保存模型
def save_checkpoint(model, optimizer, step, path='./checkpoint/checkpoint.pt'):
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'step': step
    }
    torch.save(state, path)

In [13]:
# -----------------------------
# Training loop
# -----------------------------

def train(model, dataset, device='cuda', epochs=1, batch_size=8, seq_len=128, lr=2e-4, weight_decay=0.01,
          warmup_steps=100, total_steps=1000, grad_accum_steps=1, save_every=500):
    model.to(device)
    model.train()

    def collate_fn(batch):
        # batch: list of 1D tensors of variable length -> pad to max_len
        # 对于batch内的样本pad 到统一长度
        max_len = max([b.size(0) for b in batch])
        max_len = min(max_len, seq_len)
        input_ids = torch.zeros((len(batch), max_len), dtype=torch.long)
        for i, b in enumerate(batch):
            l = min(b.size(0), max_len)
            input_ids[i, :l] = b[:l]
        return input_ids  # shape [batch_size, seq_len]

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    for g in optimizer.param_groups:
        g.setdefault('initial_lr', g['lr'])

    scheduler = WarmupCosineScheduler(optimizer, warmup_steps, total_steps)
    
    # 用于混合精度训练（AMP），节省显存、加速训练
    scaler = torch.cuda.amp.GradScaler()

    # 主循环
    global_step = 0
    for epoch in range(epochs):
        t0 = time.time()
        for batch in loader:
            input_ids = batch.to(device)
            # For language modeling, inputs and labels are the same but shifted.
            # Simple approach: predict next token for each position -> labels = input_ids (shift internally)
            with torch.cuda.amp.autocast():
                logits = model(input_ids)
                # shift logits and labels for next-token prediction
                # 向右移动标签（滑动一位）:
                # input = [a, b, c]  / tabels = [b, c, d]
                shift_logits = logits[:, :-1, :].contiguous()  # 丢掉最后一个预测，因为没有target
                shift_labels = input_ids[:, 1:].contiguous()   # 丢掉第一个token，因为没有前文可以预测
                loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
                loss = loss / grad_accum_steps

            scaler.scale(loss).backward()

            # 反向传播 + 梯度累积
            if (global_step + 1) % grad_accum_steps == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) #梯度裁剪，防止梯度爆炸
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            if global_step % 50 == 0:
                print(f"step {global_step:6d} epoch {epoch:3d} loss {loss.item()*grad_accum_steps:.6f}")

            global_step += 1
            if global_step >= total_steps:
                break

            if global_step % save_every == 0:
                save_checkpoint(model, optimizer, global_step, path=f'ckpt_step_{global_step}.pt')

        print(f'Epoch {epoch} took {time.time()-t0:.2f}s')
        if global_step >= total_steps:
            break

    save_checkpoint(model, optimizer, global_step, path='model/ckpt_final.pt')

In [14]:
# 训练参数配置
device = 'cuda' if torch.cuda.is_available() else 'cpu'

epochs = 15
batch_size = 8
seq_len = 128
d_model = 256
n_layer = 6
n_head = 8
d_ff = 1024
lr = 2e-4
total_steps = 1000
warmup_steps = 100
grad_accum_steps = 1


In [15]:
# 数据准备
texts = [
    "Hello world! This is a tiny dataset to demonstrate training a small transformer.",
    "Transformers use attention to mix context. Rotatory embeddings help for autoregressive models.",
    "This dataset is tiny; in real training you need GBs to TBs of quality data."
] * 100  # toy dataset 复制放大

tokenizer = CharTokenizer(texts)
dataset = SimpleTextDataset(texts, tokenizer, seq_len=seq_len)

print("Vocab size:", tokenizer.vocab_size, "Dataset size:", len(dataset))


Vocab size: 34 Dataset size: 300


In [16]:
# 模型初始化
model = DecoderOnlyTransformer(
    vocab_size=tokenizer.vocab_size,
    d_model=d_model,
    n_layer=n_layer,
    n_head=n_head,
    d_ff=d_ff,
    max_seq_len=seq_len,
    dropout=0.1,
    attn_dropout=0.0,
    use_rope=True,
    checkpoint=False
)

In [17]:
# 训练
train(
    model, dataset, device=device, epochs=epochs,
    batch_size=batch_size, seq_len=seq_len,
    lr=lr, warmup_steps=warmup_steps, total_steps=total_steps,
    grad_accum_steps=grad_accum_steps
)


step      0 epoch   0 loss 185.290329
Epoch 0 took 4.25s
step     50 epoch   1 loss 15.663038
Epoch 1 took 3.43s
step    100 epoch   2 loss 2.465168
Epoch 2 took 3.40s
step    150 epoch   3 loss 0.105163
Epoch 3 took 3.37s
Epoch 4 took 3.50s
step    200 epoch   5 loss 0.016550
Epoch 5 took 3.54s
step    250 epoch   6 loss 0.011935
Epoch 6 took 3.51s
step    300 epoch   7 loss 0.010812
Epoch 7 took 3.33s
Epoch 8 took 3.20s
step    350 epoch   9 loss 0.017897
Epoch 9 took 3.26s
step    400 epoch  10 loss 0.018343
Epoch 10 took 3.25s
step    450 epoch  11 loss 0.008710
Epoch 11 took 3.37s
Epoch 12 took 3.22s
step    500 epoch  13 loss 0.000295
Epoch 13 took 3.50s
step    550 epoch  14 loss 0.020455
Epoch 14 took 3.25s


In [20]:
# 推理示例（采样）
model.to(device)
model.eval()

def generate(model, tokenizer, prompt, max_new_tokens=50, temperature=1.0, top_k=None):
    ids = torch.tensor([tokenizer(prompt)], dtype=torch.long).to(device)
    for _ in range(max_new_tokens):
        logits = model(ids)[:, -1, :] / temperature
        if top_k is not None:
            v, ix = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float('Inf')
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        ids = torch.cat([ids, next_id], dim=1)
    return tokenizer.decode(ids[0].tolist())

print(generate(model, tokenizer, "Hellow", max_new_tokens=64, top_k=20))


Helloworld! This is a tiny dataset to demo tratratratraing a smalll tr
