好的，词表已经构建好了，下一步是准备 Transformer 的训练数据。

🎯 目标
我们要将原始的上下联文本对（如：“上联：春风入喜财入户”，“下联：岁月更新福满门”）转换为：

输入序列（上联）
输出序列（下联）
Token ID 形式
满足 Transformer seq2seq 模型输入格式
📦 假设你已经有了以下内容：
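vocab.pkl：包含 word2idx 和 idx2word（其中包含 <PAD>、<BOS>、<EOS>、<UNK> 四个特殊符号）
shanglian.txt：每行一个上联（下文代码中对应 fixed_couplets_in.txt）
xialian.txt：每行一个下联（下文代码中对应 fixed_couplets_out.txt），与上联文件逐行对齐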
🧱 一、数据预处理流程
✅ 步骤概览：
加载词表映射
加载上下联数据
将每个句子转换为 token ID 序列
添加特殊标记（SOS 与 EOS；词表中起始符记为 <BOS>）
填充或截断到统一长度
构建 PyTorch Dataset 或 DataLoader

In [19]:
import torch
from torch.utils.data import Dataset, DataLoader
import pickle
import re

# Data-cleaning helper: removes whitespace from a couplet line and normalizes punctuation.
def clean_couplet(line, keep_punct=True):
    """Normalize one raw couplet line.

    Args:
        line: raw text line (may include a trailing newline).
        keep_punct: when True, ASCII ',', '.' and ';' are converted to the
            full-width '，', '。' and '；'; when False, a fixed set of Chinese
            and ASCII punctuation marks is deleted instead.

    Returns:
        The cleaned string with every whitespace character removed.
    """
    text = line.strip()
    if keep_punct:
        # Map ASCII marks to their full-width counterparts, in a fixed order.
        for ascii_mark, cjk_mark in ((',', '，'), ('.', '。'), (';', '；')):
            text = text.replace(ascii_mark, cjk_mark)
    else:
        text = re.sub(r'[，。：；、！？,"\'\.\?\!]', '', text)
    # Interior whitespace is dropped too, not just the ends.
    return re.sub(r'\s+', '', text)

# -----------------------------
# 1. Load the vocabulary
# -----------------------------
# SECURITY: pickle.load can execute arbitrary code; only open a vocab.pkl you built yourself.
with open('vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)
word2idx, idx2word = vocab['word2idx'], vocab['idx2word']

# Special-token ids. The start-of-sequence token is stored as '<BOS>' in the vocabulary.
PAD, SOS, EOS, UNK = (word2idx[token] for token in ('<PAD>', '<BOS>', '<EOS>', '<UNK>'))

# -----------------------------
# 2. Data loading
# -----------------------------
def load_data(shang_path, xia_path) -> list:
    """Read two parallel text files and pair them line by line.

    Every line of both files is passed through ``clean_couplet``; line *i* of
    ``shang_path`` (upper line) is paired with line *i* of ``xia_path``
    (lower line).

    Returns:
        A list of ``(shang, xia)`` string tuples.
    """
    def _read_clean(path):
        # Iterate the file line by line and clean each raw line.
        with open(path, 'r', encoding='utf-8') as handle:
            return [clean_couplet(raw) for raw in handle]

    shangs = _read_clean(shang_path)
    xias = _read_clean(xia_path)
    # NOTE: zip stops at the shorter list, so surplus lines in either file are silently dropped.
    return list(zip(shangs, xias))

# -----------------------------
# 3. Text -> id conversion
# -----------------------------
def text_to_ids(text, word2idx, max_len=64, add_sos_eos=True):
    """Encode ``text`` character by character into a fixed-length id list.

    Special-token ids are read from ``word2idx`` itself (keys ``'<PAD>'``,
    ``'<BOS>'``, ``'<EOS>'``, ``'<UNK>'``), so the function is consistent with
    whatever vocabulary the caller passes in.

    Args:
        text: the sentence to encode; each character is one token.
        word2idx: char -> id mapping containing the four special tokens above.
        max_len: length of the returned list.
        add_sos_eos: wrap the sequence in ``<BOS>`` ... ``<EOS>``.

    Returns:
        A list of exactly ``max_len`` ints. Unknown characters map to the
        ``<UNK>`` id and the tail is filled with the ``<PAD>`` id. When the
        text is too long, the *content* is truncated so that the closing
        ``<EOS>`` is kept (for ``max_len >= 2``). Previously, a plain slice
        could cut the EOS off, and the model never learned to stop on long
        targets.
    """
    unk_id = word2idx['<UNK>']
    ids = [word2idx.get(c, unk_id) for c in text]
    if add_sos_eos:
        # Reserve two slots for BOS/EOS so truncation never removes EOS.
        ids = ids[:max(max_len - 2, 0)]
        ids = [word2idx['<BOS>']] + ids + [word2idx['<EOS>']]
    # Covers max_len < 2 and the add_sos_eos=False path.
    ids = ids[:max_len]
    ids += [word2idx['<PAD>']] * (max_len - len(ids))
    return ids

# -----------------------------
# 4. Custom Dataset
# -----------------------------
class CoupletDataset(Dataset):
    """Map-style dataset over ``(shang, xia)`` string pairs.

    Each item is a dict with ``'src'`` (upper line, the model input) and
    ``'tgt'`` (lower line, the target). Both are ``torch.long`` tensors
    produced by ``text_to_ids(..., self.max_len, add_sos_eos=True)``.
    """

    def __init__(self, data, word2idx, max_len=32):
        self.data = data
        self.word2idx = word2idx
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        upper, lower = self.data[idx]
        # Encode both sides identically; 'src' comes first, then 'tgt'.
        return {
            key: torch.tensor(
                text_to_ids(sentence, self.word2idx, self.max_len, add_sos_eos=True),
                dtype=torch.long,
            )
            for key, sentence in (('src', upper), ('tgt', lower))
        }

# -----------------------------
# 5. Build the Dataset and DataLoader
# -----------------------------
# Forward slashes are accepted by open() on Windows, Linux and macOS alike;
# the previous '\\'-separated paths only resolved on Windows.
data_pairs = load_data('../week08/data/fixed_couplets_in.txt', '../week08/data/fixed_couplets_out.txt')
dataset = CoupletDataset(data_pairs, word2idx, max_len=64)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

接下来构建模型。

In [25]:
import torch
import torch.nn as nn

class TransformerModel(nn.Module):
    """Encoder-decoder Transformer mapping source id sequences to target-vocabulary logits.

    Args:
        vocab_size: size of the shared source/target vocabulary.
        d_model, nhead, num_encoder_layers, num_decoder_layers,
        dim_feedforward, dropout: forwarded to ``nn.Transformer``.
        max_len: longest sequence the positional encoding supports.
        pad_idx: optional padding-token id. When given, positions equal to it
            are excluded from attention via key-padding masks. Defaults to
            ``None`` (no padding masks).
    """

    def __init__(self, vocab_size, d_model=512, nhead=8, num_encoder_layers=3,
                 num_decoder_layers=3,dim_feedforward=2048,max_len=64,dropout=0.1,
                 pad_idx=None):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.pad_idx = pad_idx
        # Single embedding table shared by source and target (same vocabulary).
        self.embedding = nn.Embedding(vocab_size,d_model)
        # Sinusoidal positional encoding (defined below).
        self.pos_encoder = PositionalEncoding(d_model,dropout,max_len)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        # Projection from hidden states to vocabulary logits.
        self.fc_out = nn.Linear(d_model,vocab_size)

    @staticmethod
    def _causal_mask(size, device):
        """Return a (size, size) bool mask; True marks future positions a query may not attend to."""
        return torch.triu(torch.ones(size, size, dtype=torch.bool, device=device), diagonal=1)

    def _padding_mask(self, ids):
        """Return a bool mask marking padding positions in ``ids``, or None when pad_idx is unset."""
        if self.pad_idx is None:
            return None
        return ids.eq(self.pad_idx)

    def forward(self,src,tgt):
        """Compute next-token logits for every target position.

        Args:
            src: source ids, batch-first (batch, src_len).
            tgt: decoder input ids, batch-first (batch, tgt_len).

        Returns:
            Logits of shape (batch, tgt_len, vocab_size). Position ``i`` only
            depends on ``tgt[:, :i + 1]`` because a causal mask is applied.
            Without it, teacher-forced training let the decoder read the very
            tokens it was asked to predict.
        """
        src_pad_mask = self._padding_mask(src)
        tgt_pad_mask = self._padding_mask(tgt)
        src_emb = self.pos_encoder(self.embedding(src))
        tgt_emb = self.pos_encoder(self.embedding(tgt))
        memory = self.transformer.encoder(src_emb, src_key_padding_mask=src_pad_mask)
        output = self.transformer.decoder(
            tgt_emb,
            memory,
            tgt_mask=self._causal_mask(tgt.size(1), tgt.device),
            tgt_key_padding_mask=tgt_pad_mask,
            memory_key_padding_mask=src_pad_mask,
        )
        return self.fc_out(output)
# Positional-encoding module
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first embedding tensor, then apply dropout."""

    def __init__(self, d_model, dropout=0.1, max_len=512):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than frequencies; trim so shapes match.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        # Buffer: moves with .to(device) and is saved in state_dict, but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add encodings to ``x``, assumed (batch, seq_len, d_model) with seq_len <= max_len."""
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
        

In [33]:
# 模型训练+验证+推理+保存/加载
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# -------------------------------------
# Hyper-parameters
# -------------------------------------
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
epochs = 20
learning_rate = 1e-4
save_path = './couplet_transformer.pth'  # where the best state_dict is written
# -------------------------------------
# Model, optimizer and loss
# -------------------------------------
model = TransformerModel(
    vocab_size=len(word2idx),
    d_model=512,
    nhead=8,
    num_encoder_layers=3,
    num_decoder_layers=3,
    dim_feedforward=2048,
    max_len=64,
)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Positions whose target id is PAD do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=PAD)

# -------------------------------------
# Train for one epoch
# -------------------------------------
def train_epoch(model,dataloader,optimizer,criterion,device):
    """Run one teacher-forced training pass over ``dataloader``.

    Args:
        model: called as ``model(src, tgt_input)``; its output's last dim is
            treated as the class dimension by the loss.
        dataloader: yields dicts with ``'src'`` and ``'tgt'`` id tensors.
        optimizer: stepped once per batch.
        criterion: loss over flattened logits and labels.
        device: device the batch tensors are moved to.

    Returns:
        float: mean per-batch loss over the epoch. The caller formats it with
        ``:.4f``; the previous version returned None.
    """
    # Must be *called*: the previous bare ``model.train`` did nothing, so after
    # eval_epoch switched to eval mode, dropout stayed disabled for training.
    model.train()
    total_loss = 0.0
    progress_bar = tqdm(dataloader,desc='Training',leave=False)
    for batch in progress_bar:
        src = batch['src'].to(device)
        tgt = batch['tgt'].to(device)
        # Shift by one: the decoder reads tokens [0, L-1) and predicts [1, L).
        tgt_input = tgt[:,:-1]
        tgt_label = tgt[:,1:]
        logits = model(src,tgt_input)
        loss = criterion(logits.reshape(-1,logits.size(-1)),tgt_label.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        total_loss += batch_loss
        progress_bar.set_postfix(loss=batch_loss)
    return total_loss / len(dataloader)
# -------------------------------------
# Validate for one epoch
# -------------------------------------
def eval_epoch(model, dataloader, criterion, device):
    """Compute the mean per-batch teacher-forced loss without updating weights.

    Switches ``model`` to eval mode and disables autograd. Each batch is a dict
    with ``'src'`` and ``'tgt'`` tensors; the decoder input is ``tgt[:, :-1]``
    and the labels are ``tgt[:, 1:]``.

    Returns:
        float: sum of batch losses divided by ``len(dataloader)``.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validating", leave=False):
            src, tgt = batch['src'].to(device), batch['tgt'].to(device)
            logits = model(src, tgt[:, :-1])
            flat_logits = logits.reshape(-1, logits.size(-1))
            flat_labels = tgt[:, 1:].reshape(-1)
            batch_losses.append(criterion(flat_logits, flat_labels).item())
    return sum(batch_losses) / len(dataloader)

In [None]:
out = train_epoch(model=model, dataloader=dataloader, optimizer=optimizer, criterion=criterion, device=device)  # one standalone training pass

In [None]:
# -------------------------------------
# Autoregressive generation (inference)
# -------------------------------------
def generate(model,src_sentence,word2idx,idx2word,device,max_len=64):
    """Greedily decode a lower line for ``src_sentence``.

    Starts from the SOS id and repeatedly appends the arg-max token predicted
    at the last position, stopping at EOS or after ``max_len`` steps.

    Args:
        model: called as ``model(src_ids, tgt_ids)`` returning per-position logits.
        src_sentence: upper-line text; encoded with ``text_to_ids`` to ``max_len``.
        word2idx / idx2word: vocabulary mappings.
        device: device for the input tensors.
        max_len: source encoding length and maximum number of decoding steps.

    Returns:
        str: generated characters, excluding the leading SOS and the EOS.

    The model is put in eval mode for decoding. The previous ``model.eval``
    lacked parentheses, so dropout stayed active. The model's original
    train/eval mode is restored afterwards, so calling this between training
    epochs has no side effect.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            src_ids = text_to_ids(src_sentence,word2idx,max_len)
            src_tensor = torch.tensor([src_ids],device=device) # [1, max_len]
            # Decoding starts from the SOS token alone.
            tgt_ids = [SOS]

            for _ in range(max_len):
                # ``device`` must be a keyword: torch.tensor takes a single
                # positional argument, so the old ``torch.tensor(x, device)`` raised TypeError.
                tgt_tensor = torch.tensor([tgt_ids],device=device) # [1, current_len]
                logits = model(src_tensor,tgt_tensor)
                pred_id = logits.argmax(dim=-1)[0,-1].item()
                if pred_id == EOS:
                    break
                tgt_ids.append(pred_id)
    finally:
        model.train(was_training)

    # Map ids back to characters, dropping the leading SOS.
    return ''.join(idx2word[i] for i in tgt_ids[1:])

In [None]:
# -------------------------------------
# Main training loop: train, validate, checkpoint the best model, print a sample.
# -------------------------------------
# NOTE(review): validation below reuses the training ``dataloader``, so "Val Loss" is the
# training loss measured in eval mode, not held-out performance; a separate split is needed
# for meaningful best-model selection.
best_val_loss = float('inf')

for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}/{epochs}")

    # NOTE(review): the ``:.4f`` format requires train_epoch to return a number;
    # a train_epoch that returns None makes this line raise TypeError.
    train_loss = train_epoch(model, dataloader, optimizer, criterion, device)
    print(f"Train Loss: {train_loss:.4f}")

    val_loss = eval_epoch(model, dataloader, criterion, device)
    print(f"Val Loss: {val_loss:.4f}")

    # Save the weights (state_dict only) whenever the monitored loss improves.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), save_path)
        print("Best model saved!")

    # Qualitative check: greedy-decode a fixed upper line after every epoch.
    example_shang = "国泰民安春满地"
    generated_xia = generate(model, example_shang, word2idx, idx2word, device)
    print(f"上联：{example_shang}")
    print(f"下联：{generated_xia}")

# -------------------------------------
# Save and load the model
# -------------------------------------

# Pickle the whole module (architecture + weights). This ties the file to this
# exact class definition; the state_dict written to ``save_path`` during training
# is the more portable checkpoint. Note this saves the *last* epoch, not the best one.
torch.save(model, 'full_couplet_model.pth')

# Load the model back.
# - ``map_location`` lets a checkpoint saved on GPU be restored on a CPU-only machine.
# - ``weights_only=False`` is required to unpickle a full nn.Module: PyTorch 2.6+
#   defaults to ``weights_only=True``, which rejects arbitrary classes
#   (the argument exists since PyTorch 1.13).
# SECURITY: this runs pickle; only load checkpoint files you produced yourself.
loaded_model = torch.load('full_couplet_model.pth', map_location=device, weights_only=False)
loaded_model.to(device)
loaded_model.eval()