# 自然语言transformer基础架构实现

In [10]:
import math
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader

########################
#   1) 构建词表
########################

def build_vocab(poem_file):
    """
    从文本文件中读取全部古诗，收集出现过的字符；
    返回:
      stoi (dict): {char: id}
      itos (list): 其中索引位置对应字符，索引即id
      poems (list of str): 每个元素是一首(或一段)诗的文本
    """
    # 预留三个特殊token
    special_tokens = ['<PAD>', '<START>', '<END>']
    stoi = {}
    itos = []
    
    for i, token in enumerate(special_tokens):
        stoi[token] = i
        itos.append(token)
    
    # 读取古诗
    with open(poem_file, 'r', encoding='utf-8') as f:
        # 以空行分割多首诗，也可以改为按行处理
        lines = f.read().split('\n\n')

    poems = []
    for line in lines:
        line = line.strip()
        if line:
            poems.append(line)  # 收集成一个list

    # 收集所有字符
    for poem in poems:
        for ch in poem:
            if ch not in stoi:
                stoi[ch] = len(stoi)  # 从3开始递增
                itos.append(ch)
    
    return stoi, itos, poems

########################
#   2) 数据集处理
########################

class PoetryDataset(Dataset):
    """
    将每首诗转成整数序列，并在开头结尾加 <START> 和 <END>。
    """
    def __init__(self, poems, stoi, max_len=64):
        super().__init__()
        self.stoi = stoi
        self.max_len = max_len
        self.data = []
        
        for poem in poems:
            # 去除空白符；也可在此做更多清洗
            poem = poem.replace(' ', '').replace('\n', '')
            # 转成 token id 序列
            seq = [stoi['<START>']] + [stoi[ch] for ch in poem if ch in stoi] + [stoi['<END>']]
            
            # 如果超过 max_len，则截断
            if len(seq) > max_len:
                seq = seq[:max_len]
                seq[-1] = stoi['<END>']  # 保证最后一位是END
                
            self.data.append(torch.LongTensor(seq))
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

def poetry_collate_fn(batch):
    """
    对 batch 中各序列做对齐(补PAD至本batch中最长的序列长度)，
    分别生成输入 inp 和目标 tgt。
    """
    max_len = max(len(seq) for seq in batch)
    inputs, targets = [], []
    
    for seq in batch:
        # inp = seq[:-1], tgt = seq[1:]
        inp = seq[:-1]
        tgt = seq[1:]
        
        # 对齐到max_len-1
        inp_pad = torch.cat([inp, torch.zeros(max_len - 1 - len(inp), dtype=torch.long)])
        tgt_pad = torch.cat([tgt, torch.zeros(max_len - 1 - len(tgt), dtype=torch.long)])
        
        inputs.append(inp_pad)
        targets.append(tgt_pad)
    
    inputs = torch.stack(inputs, dim=0)   # [batch_size, seq_len-1]
    targets = torch.stack(targets, dim=0) # [batch_size, seq_len-1]
    return inputs, targets

########################
#   3) 模型组件
########################

def create_masks(src, tgt):
    """
    根据 src, tgt 序列生成掩码:
    1) 忽略 <PAD>
    2) 自回归掩码 (nopeak_mask)
    """
    # [batch_size, src_seq_len]
    src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
    if tgt is None:
        return src_mask, None
    
    tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(2)
    seq_length = tgt.size(1)
    nopeak_mask = (1 - torch.triu(torch.ones((1, seq_length, seq_length)), diagonal=1)).bool()
    tgt_mask = tgt_mask & nopeak_mask.to(tgt.device)
    return src_mask, tgt_mask

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn, V)
        return output, attn
    
    def split_heads(self, x):
        batch_size = x.size(0)
        # x: [batch_size, seq_len, d_model] -> [batch_size, num_heads, seq_len, d_k]
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)
    
    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        Q = self.W_q(Q)
        K = self.W_k(K)
        V = self.W_v(V)
        Q = self.split_heads(Q)
        K = self.split_heads(K)
        V = self.split_heads(V)
        
        output, attn = self.scaled_dot_product_attention(Q, K, V, mask)
        # [batch_size, num_heads, seq_len, d_k] -> [batch_size, seq_len, d_model]
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.W_o(output)
        return output, attn

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length=5000):
        super().__init__()
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0)/d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # [1, max_seq_length, d_model]
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        # x: [batch_size, seq_len, d_model]
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len]

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionwiseFeedForward(d_model, d_ff)
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        attn_output, _ = self.mha(x, x, x, mask)
        x = self.layernorm1(x + self.dropout(attn_output))
        ffn_output = self.ffn(x)
        x = self.layernorm2(x + self.dropout(ffn_output))
        return x

class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.mha1 = MultiHeadAttention(d_model, num_heads)
        self.mha2 = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionwiseFeedForward(d_model, d_ff)
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.layernorm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        attn1, _ = self.mha1(x, x, x, tgt_mask)
        x = self.layernorm1(x + self.dropout(attn1))
        attn2, _ = self.mha2(x, enc_output, enc_output, src_mask)
        x = self.layernorm2(x + self.dropout(attn2))
        ffn_output = self.ffn(x)
        x = self.layernorm3(x + self.dropout(ffn_output))
        return x

class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length=5000, dropout=0.1):
        super().__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
        
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            TransformerDecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)
    
    def encode(self, src, src_mask):
        x = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
        return x
    
    def decode(self, tgt, enc_output, src_mask, tgt_mask):
        x = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))
        for layer in self.decoder_layers:
            x = layer(x, enc_output, src_mask, tgt_mask)
        return self.fc(x)
    
    def forward(self, src, tgt, src_mask, tgt_mask):
        enc_output = self.encode(src, src_mask)
        dec_output = self.decode(tgt, enc_output, src_mask, tgt_mask)
        return dec_output

########################
#   4) 训练函数
########################

def train_model(model, train_loader, criterion, optimizer, device, num_epochs=10):
    model.train()
    
    for epoch in range(num_epochs):
        total_loss = 0
        for batch_idx, (inp, tgt) in enumerate(train_loader):
            inp, tgt = inp.to(device), tgt.to(device)
            
            # mask
            src_mask, tgt_mask = create_masks(inp, tgt)
            
            optimizer.zero_grad()
            output = model(inp, tgt, src_mask, tgt_mask)
            # output: [batch_size, seq_len, vocab_size]
            # tgt:    [batch_size, seq_len]
            
            loss = criterion(
                output.reshape(-1, output.size(-1)),
                tgt.reshape(-1)
            )
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            total_loss += loss.item()
            
            if batch_idx % 50 == 0:
                print(f"[Epoch {epoch}] Batch {batch_idx}, loss={loss.item():.4f}")
        
        avg_loss = total_loss / len(train_loader)
        print(f"[Epoch {epoch}] Average Loss={avg_loss:.4f}")

########################
#   5) 推理(续写)函数
########################

def generate_poetry(model, prompt, stoi, itos, device, max_len=80):
    """
    prompt: str，如 "床前明月光，"
    让模型接着生成后面的字符，直到 <END> 或达到 max_len。
    """
    model.eval()
    
    # prompt -> token id，前面加上<START>
    seq = [stoi['<START>']] + [stoi[ch] for ch in prompt if ch in stoi]
    input_seq = torch.LongTensor(seq).unsqueeze(0).to(device)  # [1, len_prompt]
    
    # 对于Encoder-Decoder结构，这里可以让encoder的输入=整个prompt（含<START>），也可仅<START>
    src_mask, _ = create_masks(input_seq, None)
    enc_output = model.encode(input_seq, src_mask)
    
    # 解码端的初始序列
    generated = input_seq.clone()  # [1, len_prompt]
    
    for _ in range(max_len):
        _, tgt_mask = create_masks(input_seq, generated)
        dec_output = model.decode(generated, enc_output, src_mask, tgt_mask)
        # dec_output: [1, seq_len, vocab_size]
        
        logits = dec_output[:, -1, :]  # 取最后一个时间步
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.argmax(probs, dim=-1).item()
        
        # 拼接到 generated
        next_token_tensor = torch.LongTensor([[next_token]]).to(device)
        generated = torch.cat([generated, next_token_tensor], dim=1)
        
        # 如果是<END>，停止
        if next_token == stoi['<END>']:
            break
    
    # 去掉开头的 <START>，以及可能的 <END>
    gen_ids = generated.squeeze(0).tolist()
    if gen_ids[0] == stoi['<START>']:
        gen_ids = gen_ids[1:]
    if stoi['<END>'] in gen_ids:
        end_idx = gen_ids.index(stoi['<END>'])
        gen_ids = gen_ids[:end_idx]
    
    # 转成字符串
    result = ''.join(itos[i] for i in gen_ids)
    return result

########################
#   6) 主函数
########################

def main():
    # 配置
    poem_file = "poems.txt"  # 存放古诗的文件
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    
    # 1. 构建词表
    stoi, itos, poems = build_vocab(poem_file)
    vocab_size = len(stoi)
    print("vocab_size:", vocab_size)
    print("number_of_poems:", len(poems))
    
    # 2. 构建Dataset和DataLoader
    dataset = PoetryDataset(poems, stoi, max_len=64)
    train_loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=poetry_collate_fn)
    
    # 3. 构建Transformer模型
    d_model = 128
    num_heads = 8
    num_layers = 3
    d_ff = 512
    dropout = 0.1
    
    model = Transformer(
        src_vocab_size=vocab_size,
        tgt_vocab_size=vocab_size,
        d_model=d_model,
        num_heads=num_heads,
        num_layers=num_layers,
        d_ff=d_ff,
        dropout=dropout
    ).to(device)
    
    # 4. 损失和优化器
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9,0.98), eps=1e-9)
    
    # 5. 训练
    train_model(model, train_loader, criterion, optimizer, device, num_epochs=10)
    
    # 6. 测试续写
    prompt = "床前明月光，"  # 给定一个起始句
    result = generate_poetry(model, prompt, stoi, itos, device, max_len=80)
    print("Prompt:", prompt)
    print("Generated:", result)

if __name__ == "__main__":
    main()

Using device: cpu
vocab_size: 113
number_of_poems: 6
[Epoch 0] Batch 0, loss=4.9904
[Epoch 0] Average Loss=4.9904
[Epoch 1] Batch 0, loss=4.8774
[Epoch 1] Average Loss=4.8774
[Epoch 2] Batch 0, loss=4.8281
[Epoch 2] Average Loss=4.8281
[Epoch 3] Batch 0, loss=4.7173
[Epoch 3] Average Loss=4.7173
[Epoch 4] Batch 0, loss=4.6812
[Epoch 4] Average Loss=4.6812
[Epoch 5] Batch 0, loss=4.5705
[Epoch 5] Average Loss=4.5705
[Epoch 6] Batch 0, loss=4.5064
[Epoch 6] Average Loss=4.5064
[Epoch 7] Batch 0, loss=4.4613
[Epoch 7] Average Loss=4.4613
[Epoch 8] Batch 0, loss=4.3970
[Epoch 8] Average Loss=4.3970
[Epoch 9] Batch 0, loss=4.3426
[Epoch 9] Average Loss=4.3426
Prompt: 床前明月光，
Generated: 床前明月光，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，
