**Contents**
- Batch and mask construction
- Learning-rate schedule visualization
- Label smoothing visualization
- Copy-task training and greedy decoding
- Model save/load example

In [None]:
# Import training utilities
import torch
import matplotlib.pyplot as plt
from transformer_utils import (
    device,
    Batch, rate, visualize_learning_rate_schedule,
    LabelSmoothing, visualize_label_smoothing,
    data_gen, SimpleLossCompute, greedy_decode,
    run_copy_task_example, save_model, load_model, make_model,
)
print('Training utilities loaded, device:', device)
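
The `Batch` class imported above is what builds the masks listed under Contents. Its implementation lives in `transformer_utils` and is not shown here, so the following pure-Python sketch only assumes that it follows the common Annotated Transformer convention: a padding mask over the source, and for the target a causal (lower-triangular) mask combined with the padding mask. The helper names and `PAD = 0` are hypothetical.

```python
# Hypothetical helpers; the real Batch class in transformer_utils may differ.
PAD = 0  # assumed padding index

def pad_mask(tokens, pad=PAD):
    """1 where the token is real, 0 where it is padding (shape: 1 x len)."""
    return [[int(t != pad) for t in tokens]]

def subsequent_mask(size):
    """Lower-triangular mask: query position i may attend to key positions 0..i."""
    return [[int(j <= i) for j in range(size)] for i in range(size)]

def make_tgt_mask(tgt_in, pad=PAD):
    """Combine the causal mask with the key-padding mask via logical AND."""
    keep = pad_mask(tgt_in, pad)[0]
    causal = subsequent_mask(len(tgt_in))
    n = len(tgt_in)
    return [[causal[i][j] & keep[j] for j in range(n)] for i in range(n)]

tgt = [1, 5, 0, 0]                     # trailing zeros are padding
tgt_in, tgt_y = tgt[:-1], tgt[1:]      # decoder input vs. prediction target, shifted by one
ntokens = sum(t != PAD for t in tgt_y) # count of real target tokens
print(make_tgt_mask(tgt_in))           # [[1, 0, 0], [1, 1, 0], [1, 1, 0]]
```

In the Annotated Transformer, `ntokens` is what the training loss is divided by, so padded positions do not inflate the per-token loss.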

In [None]:
# Learning-rate schedule curve
visualize_learning_rate_schedule()
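
Assuming the imported `rate` implements the schedule from "Attention Is All You Need" (as the Annotated Transformer's version does), the plotted curve is $\text{lrate} = \text{factor}\cdot d_{\text{model}}^{-0.5}\cdot\min(\text{step}^{-0.5},\ \text{step}\cdot\text{warmup}^{-1.5})$. A standalone sketch, named `noam_rate` here so it does not shadow the imported function:

```python
def noam_rate(step, model_size, factor, warmup):
    """Learning-rate multiplier: linear warmup, then inverse square-root decay."""
    step = max(step, 1)  # guard against 0 ** -0.5 at the very first step
    return factor * model_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

# The two branches of min() meet at step == warmup, which is where the peak occurs.
for step in (1, 1000, 4000, 8000, 16000):
    print(step, noam_rate(step, model_size=512, factor=1.0, warmup=4000))
```

When a function like this is passed to `torch.optim.lr_scheduler.LambdaLR`, the returned value multiplies the optimizer's base `lr`, so the base `lr` acts as an extra scale factor.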

In [None]:
# Effect of label smoothing
visualize_label_smoothing()
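
The sketch below shows the target distribution label smoothing produces for a single token, assuming `LabelSmoothing` follows the Annotated Transformer: the correct class keeps `1 - smoothing`, the remaining mass is split evenly over every class except the target and the padding index, and padding positions get an all-zero row. In that implementation the result serves as the target of a KL-divergence loss (`nn.KLDivLoss`). `smoothed_target` is a hypothetical helper, not part of `transformer_utils`.

```python
def smoothed_target(target, vocab_size, padding_idx=0, smoothing=0.1):
    """Return the smoothed target distribution for a single token."""
    if target == padding_idx:
        return [0.0] * vocab_size  # padding positions contribute no loss
    # Spread `smoothing` over every class except the target and the padding index.
    dist = [smoothing / (vocab_size - 2)] * vocab_size
    dist[target] = 1.0 - smoothing  # confidence kept on the correct class
    dist[padding_idx] = 0.0
    return dist

print(smoothed_target(target=2, vocab_size=5, smoothing=0.4))
# roughly [0.0, 0.133, 0.6, 0.133, 0.133]
```

Dividing by `vocab_size - 2` (not `vocab_size - 1`) keeps the distribution summing to 1 once both the target and the padding column are excluded.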

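## Copy Task Training
By default this runs 20 epochs and finishes quickly; for an even faster sanity check, reduce the number of epochs or batches inside the function.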
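
In the copy task the model must reproduce its input. The sketch below shows what such data looks like, assuming `data_gen` behaves like the Annotated Transformer version: tokens drawn uniformly from 1 to V-1 (0 is reserved for padding), the first token fixed to 1, and the target identical to the source. `copy_task_batches` is a hypothetical stand-in, not the function from `transformer_utils`.

```python
import random

def copy_task_batches(vocab_size=11, batch_size=4, n_batches=2, seq_len=10, seed=0):
    """Yield (src, tgt) batches for the copy task; tgt is an exact copy of src."""
    rng = random.Random(seed)
    for _ in range(n_batches):
        batch = []
        for _ in range(batch_size):
            # randint is inclusive, so tokens fall in 1..vocab_size-1 (0 stays free for padding)
            seq = [rng.randint(1, vocab_size - 1) for _ in range(seq_len)]
            seq[0] = 1  # fixed start token, matching start_symbol=1 used when decoding
            batch.append(seq)
        yield batch, [row[:] for row in batch]

for src, tgt in copy_task_batches(n_batches=1, batch_size=2):
    print(src)
```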

In [None]:
# Train a small model on the copy task
copy_model, losses = run_copy_task_example()
print('Training finished, number of loss entries:', len(losses))

In [None]:
# Visualize training curve
plt.figure(figsize=(8,4))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Copy Task Training Curve')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
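# Greedy decoding with the trained model
copy_model.eval()  # disable dropout for inference
model_device = next(copy_model.parameters()).device  # keep inputs on the model's device
src = torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]).to(model_device)
src_mask = torch.ones(1, 1, src.size(1), device=model_device)  # no padding: every position is visible
with torch.no_grad():
    result = greedy_decode(copy_model, src, src_mask, max_len=src.size(1), start_symbol=1)
print('Input sequence :', src[0].tolist())
print('Output sequence:', result[0].tolist())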
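
Greedy decoding starts from `start_symbol` and, at each step, appends the single highest-scoring next token. In the Annotated Transformer's `greedy_decode` the encoder runs once and the decoder is re-run on the growing prefix with a causal mask; the sketch below keeps only the argmax loop and replaces the model with a hypothetical "perfect copy" scorer, so both `greedy_decode_sketch` and `copy_oracle` are illustrative names.

```python
def greedy_decode_sketch(next_token_scores, max_len, start_symbol):
    """Start from start_symbol and repeatedly append the highest-scoring token."""
    ys = [start_symbol]
    for _ in range(max_len - 1):  # max_len counts the start symbol
        scores = next_token_scores(ys)  # one score per vocabulary entry
        best = max(range(len(scores)), key=scores.__getitem__)  # argmax
        ys.append(best)
    return ys

# Hypothetical "perfect copy model": it always scores the next source token highest.
SRC = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
VOCAB_SIZE = 11

def copy_oracle(ys):
    scores = [0.0] * VOCAB_SIZE
    scores[SRC[len(ys)]] = 1.0
    return scores

print(greedy_decode_sketch(copy_oracle, max_len=len(SRC), start_symbol=1))
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
```

Because the source also begins with 1, a well-trained copy model's output should match the input sequence exactly.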

In [None]:
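# Save and reload example
save_model(copy_model, 'copy_task_model.pt')
# The architecture must match the saved model (vocab size 11, N=2 layers)
reloaded = make_model(11, 11, N=2)
load_model(reloaded, 'copy_task_model.pt')
reloaded.eval()  # inference mode before decoding with the reloaded model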
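
`save_model` and `load_model` come from `transformer_utils`, and their internals are not shown. A common PyTorch pattern they may wrap (an assumption) is saving the `state_dict` with `torch.save` and restoring it into a freshly built model of the same architecture with `load_state_dict`. The sketch uses a tiny `nn.Linear` instead of the Transformer so it runs on its own; `save_state` and `load_state` are hypothetical names.

```python
import os
import tempfile

import torch
import torch.nn as nn

def save_state(model, path):
    """Persist only the parameters and buffers, not the Python object."""
    torch.save(model.state_dict(), path)

def load_state(model, path):
    """Load parameters into a model that already has the matching architecture."""
    state = torch.load(path, map_location="cpu")  # works even if saved on a GPU
    model.load_state_dict(state)
    return model

torch.manual_seed(0)
original = nn.Linear(4, 2)
restored = nn.Linear(4, 2)  # same architecture, different random initialization

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "demo.pt")
    save_state(original, path)
    load_state(restored, path)

x = torch.randn(3, 4)
print(torch.equal(original(x), restored(x)))  # True
```

After loading with `map_location="cpu"`, the model can be moved to the target device with `.to(device)`.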