In [None]:
import torch
from torch import nn

`nn.Transformer` 只包含编码器和解码器，因此词嵌入、位置编码以及最终的线性输出层仍然需要我们自己实现。

# 词嵌入与位置编码

In [None]:
class EmbeddingPositionEncode(nn.Module):
    """Token embedding followed by sinusoidal positional encoding and dropout.

    Args:
        d_model: Width of the embedding vectors. Both even and odd widths
            are supported.
        dropout: Dropout probability applied to the summed result.
        vocab_size: Number of rows in the embedding table.
    """

    def __init__(self, d_model, dropout: float, vocab_size):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Embed token ids and add the position signal.

        Args:
            input_tensor: Integer token ids indexed as (batch_size, seq_len).

        Returns:
            Tensor of shape (batch_size, seq_len, d_model).
        """
        # token_embed: (batch_size, seq_len, d_model)
        token_embed = self.embedding(input_tensor)
        seq_len = input_tensor.size(1)
        device = input_tensor.device

        # Sinusoidal table: one row per position, one frequency per sin/cos pair.
        position = torch.arange(seq_len, device=device).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, device=device)
            * (-torch.log(torch.tensor(10000.0, device=device)) / self.d_model)
        )
        angles = position * div_term

        pos_encoding = torch.zeros(seq_len, self.d_model, device=device)
        pos_encoding[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so drop the surplus frequency instead of failing on a
        # shape mismatch. For an even d_model the slice is a no-op.
        pos_encoding[:, 1::2] = torch.cos(angles[:, : self.d_model // 2])

        # Broadcast the (seq_len, d_model) table over the batch dimension.
        return self.dropout(token_embed + pos_encoding.unsqueeze(0))

# 完整模型

In [None]:
class ChEnTransformer(nn.Module):
    """Chinese-to-English translation model built around ``nn.Transformer``.

    ``nn.Transformer`` only provides the encoder/decoder stacks, so this
    wrapper adds the source and target embedding + positional-encoding
    layers and the final projection onto the target vocabulary.

    Args:
        src_vocab_size: Size of the source vocabulary.
        tgt_vocab_size: Size of the target vocabulary.
        d_model: Embedding / hidden width.
        nhead: Number of attention heads.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        dim_feedforward: Width of the feed-forward sub-layers.
        dropout: Dropout probability used by the embedding layers and by
            the transformer stacks.
    """

    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        d_model: int = 512,
        nhead: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.pos_embed_encoder_layer = EmbeddingPositionEncode(
            d_model, dropout, src_vocab_size
        )
        self.pos_embed_decoder_layer = EmbeddingPositionEncode(
            d_model, dropout, tgt_vocab_size
        )
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            # Forward the configured rate; otherwise nn.Transformer silently
            # falls back to its own default of 0.1.
            dropout=dropout,
            batch_first=True,
        )
        self.linear = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, tgt_mask, src_key_padding_mask, tgt_key_padding_mask):
        """Run a teacher-forced pass and return target-vocabulary logits.

        Args:
            src: Source token ids, (batch_size, src_len).
            tgt: Decoder input token ids, (batch_size, tgt_len).
            tgt_mask: Causal attention mask for the decoder input.
            src_key_padding_mask: Mask marking padded source positions.
            tgt_key_padding_mask: Mask marking padded target positions.

        Returns:
            Logits whose last dimension has size ``tgt_vocab_size``.
        """
        src_processed = self.pos_embed_encoder_layer(src)
        tgt_processed = self.pos_embed_decoder_layer(tgt)
        out = self.transformer(
            src=src_processed,
            tgt=tgt_processed,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            # Only a hint that tgt_mask is causal; the mask must still be passed.
            tgt_is_causal=True,
        )

        return self.linear(out)

# 训练

In [None]:
import matplotlib.pyplot as plt
from TranslationDataset import TranslationDataset
from torch.utils.data import DataLoader


def save_checkpoint(epoch, model, optimizer, scheduler, loss, path):
    """Write the current training state to ``path`` with ``torch.save``.

    Args:
        epoch: Epoch number to record in the checkpoint.
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        scheduler: LR scheduler; its ``state_dict()`` and class name are stored.
        loss: Loss value to record alongside the state.
        path: Destination passed straight to ``torch.save``.
    """
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state=scheduler.state_dict(),
        scheduler_type=type(scheduler).__name__,
        loss=loss,
    )
    torch.save(state, path)


def load_checkpoint(model, optimizer, scheduler, path, map_location=None):
    """Restore training state from a checkpoint written by ``save_checkpoint``.

    Args:
        model: Module to restore, or ``None`` to skip it.
        optimizer: Optimizer to restore, or ``None`` to skip it.
        scheduler: LR scheduler to restore, or ``None`` to skip it.
        path: Checkpoint file; ``None`` means there is nothing to load.
        map_location: Forwarded to ``torch.load``, e.g. ``"cpu"`` to open a
            checkpoint saved on a GPU on a machine without one.

    Returns:
        ``(epoch, loss)`` as stored in the checkpoint, or ``(0, inf)`` when
        ``path`` is ``None``.
    """
    if path is None:
        print("未发现检查点")
        return 0, float("inf")

    checkpoint = torch.load(path, map_location=map_location)
    # Compare against None explicitly: a module that defines __len__ and is
    # empty is falsy, and a truthiness check would silently skip it.
    if model is not None:
        model.load_state_dict(checkpoint["model_state_dict"])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        print(f"从{checkpoint['epoch']}开始训练")
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint["scheduler_state"])
    return checkpoint["epoch"], checkpoint["loss"]


dataset = TranslationDataset(
    file_path="../data/translate/TranslationData.csv",
    max_lines=10,
)
# Fall back to CPU so the demo also runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
src_vocab_size = len(dataset.ch_token_to_index)
tgt_vocab_size = len(dataset.en_token_to_index)
model = ChEnTransformer(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    d_model=64,
    dropout=0,
).to(device)

epochs = 150
# Base LR of 1.0: LambdaLR multiplies it by the factor computed below.
lr = 1.0
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=lr,
    betas=(0.9, 0.98),
    eps=1e-9,  # weight_decay=1e-4
)
src_padding_idx = dataset.ch_token_to_index["<pad>"]
padding_idx = dataset.en_token_to_index["<pad>"]
loss_fn = nn.CrossEntropyLoss(ignore_index=padding_idx)
batch_size = 10
dataloader = DataLoader(dataset, batch_size=batch_size)

# One optimizer step per batch. len(dataloader) also counts a final partial
# batch, which len(dataset) / batch_size would round away.
total_steps = epochs * len(dataloader)
# At least one warm-up step, otherwise warmup_steps ** (-1.5) divides by zero.
warmup_steps = max(1, int(total_steps * 0.4))
print(f"total_steps: {total_steps},warmup_steps: {warmup_steps}")


def noam_lr_factor(step):
    """Warm-up then inverse-square-root decay factor for LambdaLR."""
    step_num = step + 1
    # NOTE(review): the scale uses 512 although this model is built with
    # d_model=64; kept unchanged to preserve the existing learning rate.
    return 512 ** (-0.5) * min(
        step_num ** (-0.5),  # decay phase: inverse square root of the step
        step_num * (warmup_steps ** (-1.5)),  # warm-up phase: linear growth
    )


scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer=optimizer,
    lr_lambda=noam_lr_factor,
)

error = []  # average loss of every epoch, used for the plot below
path = None  # set this to a checkpoint file path to resume from it
start_epoch = 0
# start_epoch, loss = load_checkpoint(model, optimizer, scheduler, path)
model.train()
for epoch in range(start_epoch, epochs):
    total_loss = 0
    for src, tgt in dataloader:
        # Decoder input drops the last token (<eos>/<pad>); the prediction
        # target drops the first token (<bos>).
        tgt_input = tgt[:, :-1]
        tgt_target = tgt[:, 1:]

        # Padding masks: True where the token is <pad>.
        src_key_padding_mask = (src == src_padding_idx).to(device)
        tgt_key_padding_mask = (tgt_input == padding_idx).to(device)

        # Causal mask from nn.Transformer's built-in helper.
        tgt_input_seq_len = tgt_input.shape[1]
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_input_seq_len).to(
            device
        )
        src, tgt_input, tgt_target = (
            src.to(device),
            tgt_input.to(device),
            tgt_target.to(device),
        )

        # tgt_mask is passed explicitly: the model's tgt_is_causal=True is
        # only a hint about this mask, it does not create one.
        pred = model(
            src=src,
            tgt=tgt_input,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
        )

        # Loss over flattened (token, vocab) logits, then one optimizer step.
        loss = loss_fn(pred.reshape(-1, tgt_vocab_size), tgt_target.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    if (epoch + 1) % 10 == 0:
        print(
            f"epoch {epoch + 1}, loss: {avg_loss:.6f}, perplexity: {torch.exp(torch.tensor(avg_loss)).item():.6f}"
        )
    # To save checkpoints, set path_to_save to a file path and uncomment:
    # path_to_save = None
    # if (epoch + 1) % 50 == 0:
    #     save_checkpoint(epoch + 1, model, optimizer, scheduler, avg_loss, path_to_save)
    #     print(f"已保存为{path_to_save}")

    # Record the epoch average so the curve matches its "epoch" axis.
    error.append(avg_loss)
plt.plot(error)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.show()

In [None]:
# Teacher-forced sanity check on the first batch: feed the reference target
# (minus its last token) to the decoder and compare the arg-max predictions
# with the reference.
model.eval()


def _to_tokens(index_row):
    """Map a 1-D tensor of target indices to their token strings."""
    return [dataset.en_index_to_token[idx] for idx in index_row.cpu().numpy()]


with torch.no_grad():
    src, tgt = next(iter(dataloader))
    tgt_input_eval = tgt[:, :-1]  # decoder input: drop the last token

    # Padding masks (True where the token is <pad>).
    src_key_padding_mask = (src == dataset.ch_token_to_index["<pad>"]).to(device)
    tgt_key_padding_mask = (tgt_input_eval == dataset.en_token_to_index["<pad>"]).to(device)

    # Causal mask sized to the decoder input.
    tgt_input_seq_len = tgt_input_eval.shape[1]
    tgt_mask_eval = nn.Transformer.generate_square_subsequent_mask(tgt_input_seq_len).to(device)

    # Move the batch to the model's device.
    src = src.to(device)
    tgt_input_eval = tgt_input_eval.to(device)
    tgt = tgt.to(device)

    pred = model(
        src=src,
        tgt=tgt_input_eval,
        tgt_mask=tgt_mask_eval,
        src_key_padding_mask=src_key_padding_mask,
        tgt_key_padding_mask=tgt_key_padding_mask,
    )

    # pred_tokens: (batch_size, tgt_input_seq_len)
    pred_tokens = pred.argmax(dim=-1)
    # One <eos> column with the same batch size, appended along the
    # sequence dimension.
    eos_token = torch.full(
        (pred_tokens.size(0), 1),
        dataset.en_token_to_index["<eos>"],
        device=device,
        dtype=torch.long,
    )
    pred_tokens = torch.cat([pred_tokens, eos_token], dim=1)

    print("预测:", _to_tokens(pred_tokens[0]))
    print("真实:", _to_tokens(tgt[0]))

In [None]:
def preprocess_chinese(sentence, ch_token_to_index, max_length):
    """Turn a whitespace-tokenised Chinese sentence into a (1, max_length) id tensor.

    Unknown tokens map to ``<unk>``. The sequence is wrapped in ``<bos>`` /
    ``<eos>``, then right-padded with ``<pad>`` or truncated to
    ``max_length``.
    """
    body = []
    for word in sentence.split():
        body.append(ch_token_to_index.get(word, ch_token_to_index["<unk>"]))

    ids = [ch_token_to_index["<bos>"], *body, ch_token_to_index["<eos>"]]

    shortfall = max_length - len(ids)
    if shortfall > 0:
        ids.extend([ch_token_to_index["<pad>"]] * shortfall)
    else:
        ids = ids[:max_length]

    # Wrap in an outer list to add the batch dimension.
    return torch.tensor([ids], dtype=torch.long)


def translate(chinese_sentence, model, dataset, device):
    """Greedy-decode an English translation with the ``nn.Transformer`` model.

    The source sentence is encoded once. The target is then grown one token
    at a time from ``<bos>``, always taking the arg-max of the last position,
    until ``<eos>`` is produced or ``dataset.en_max_length`` tokens exist.

    Args:
        chinese_sentence: Whitespace-tokenised source sentence.
        model: Model exposing ``pos_embed_encoder_layer``,
            ``pos_embed_decoder_layer``, ``transformer`` and ``linear``.
        dataset: Object providing the vocabularies and maximum lengths.
        device: Device the tensors are moved to.

    Returns:
        List of generated target tokens, starting with ``<bos>``.
    """
    model.eval()
    src_vocab = dataset.ch_token_to_index
    tgt_vocab = dataset.en_token_to_index
    index_to_word = dataset.en_index_to_token

    # Encode the source sentence once.
    src_tensor = preprocess_chinese(
        chinese_sentence, src_vocab, dataset.ch_max_length
    ).to(device)
    src_pad_mask = (src_tensor == src_vocab["<pad>"]).to(device)

    with torch.no_grad():
        memory = model.transformer.encoder(
            src=model.pos_embed_encoder_layer(src_tensor),
            src_key_padding_mask=src_pad_mask,
            is_causal=False,  # the encoder attends in both directions
        )

        generated = [tgt_vocab["<bos>"]]
        for _ in range(dataset.en_max_length - 1):
            # Rebuild the decoder input from everything generated so far.
            tgt_tensor = torch.tensor(generated, dtype=torch.long).unsqueeze(0).to(device)
            causal_mask = nn.Transformer.generate_square_subsequent_mask(
                tgt_tensor.shape[1]
            ).to(device)
            tgt_pad_mask = (tgt_tensor == tgt_vocab["<pad>"]).to(device)

            hidden = model.transformer.decoder(
                tgt=model.pos_embed_decoder_layer(tgt_tensor),
                memory=memory,
                tgt_mask=causal_mask,
                memory_key_padding_mask=src_pad_mask,  # reuse the source padding mask
                tgt_key_padding_mask=tgt_pad_mask,
                tgt_is_causal=True,
                memory_is_causal=False,
            )

            # Only the last position predicts the next token.
            logits = model.linear(hidden)
            next_id = torch.argmax(logits[:, -1, :], dim=-1).item()
            generated.append(next_id)

            if next_id == tgt_vocab["<eos>"]:
                break

    return [index_to_word[idx] for idx in generated]



# Sample sentence, already split into space-separated tokens.
chinese_input = '缔约国 根据 《 任择 议定书 》 第 12 条 第 1 款 提交 的 初次 报告 的 准则 由 委员会 在 2002 年 2 月 1 日 第 777 次 会议 上 通过 。'

# Run the greedy decoder and join the tokens back into one line.
translated_result = translate(chinese_input, model, dataset, device)
english_output = ' '.join(translated_result)

# Show the source next to the generated translation.
print("中文输入:", chinese_input)
print("英文翻译:", english_output)


# 模型泛化

In [None]:
# Rebuild the dataset and the model at full scale (large sample).
dataset = TranslationDataset(
    file_path="../data/translate/2m_WMT21.csv",
    max_lines=1_500_000,
)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

src_vocab_size = len(dataset.ch_token_to_index)
tgt_vocab_size = len(dataset.en_token_to_index)

model = ChEnTransformer(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    d_model=512,
    dropout=0.3,
)
model = model.to(device)
# The training procedure is unchanged; 10-15 epochs are enough.
# Training steps ...