In [1]:
import os, json, glob, subprocess, pandas as pd

# Locations default to the original course-machine layout. Set the matching
# environment variable to run the notebook on another machine without editing
# this cell; when a variable is unset the value is exactly what it was before.
DATA = os.environ.get("NMT_DATA_DIR", "/root/Course/final_project/nmt_data_jieba_100k")
PY = os.environ.get("NMT_PYTHON", "python")  # interpreter used to launch the training script
SCRIPT = os.environ.get("NMT_TRAIN_SCRIPT", "/root/Course/final_project/train_rnn.py")  # path to your training script

# Passed to the training script as --save_root. This default is fine, but you
# can point it at any directory you like.
SAVE_ROOT = os.path.join(DATA, "checkpoints")

In [2]:
# Settings shared by every run in this sweep.
base = {
    "data_dir": DATA,
    "vocab_zh": f"{DATA}/vocab_zh.json",
    "vocab_en": f"{DATA}/vocab_en.json",
    "epochs": 10,
    "batch_size": 64,
    "emb": 256,
    "hidden": 512,
    "rnn_type": "lstm",
    "lr": 3e-4,
    "seed": 42,
    # Use greedy decoding while training; the beam-search comparison is done
    # separately later on.
    "decode": "greedy",
    "beam_size": 4,
}

# One run per attention variant; everything else is identical across runs.
# free_running stays off for all of them (presumably meaning teacher forcing --
# confirm against the training script).
experiments = [
    {**base, "attn": attn_kind, "free_running": False}
    for attn_kind in ("dot", "general", "additive")
]

In [3]:
def to_cmd(cfg):
    """Build the argv list that launches the training script for one config.

    Args:
        cfg: mapping holding every option listed in ``option_spec`` below; an
            optional truthy ``"free_running"`` entry appends the
            ``--free_running`` switch.

    Returns:
        A list starting with the interpreter (``PY``) and script path
        (``SCRIPT``), followed by ``--option value`` pairs in a fixed order,
        then ``--save_root SAVE_ROOT``, then the optional switch.

    Raises:
        KeyError: if ``cfg`` lacks one of the required options.
    """
    # (option name, whether the value is passed through str()).
    # Options marked False are forwarded exactly as stored in cfg.
    option_spec = (
        ("data_dir", False),
        ("vocab_zh", False),
        ("vocab_en", False),
        ("epochs", True),
        ("batch_size", True),
        ("emb", True),
        ("hidden", True),
        ("rnn_type", False),
        ("attn", False),
        ("lr", True),
        ("decode", False),
        ("beam_size", True),
        ("seed", True),
    )

    argv = [PY, SCRIPT]
    for option, stringify in option_spec:
        value = cfg[option]
        argv.append("--" + option)
        argv.append(str(value) if stringify else value)

    argv.extend(["--save_root", SAVE_ROOT])

    # Boolean switch: present only when requested, and it takes no value.
    if cfg.get("free_running", False):
        argv.append("--free_running")
    return argv

def run_experiment(cfg):
    """Launch one training run as a subprocess and wait for it to finish.

    The command is echoed before it runs. Each argument is shell-quoted so the
    echoed line can be pasted into a terminal verbatim, even when a path
    contains spaces or other shell metacharacters; arguments that need no
    quoting are printed unchanged.

    Args:
        cfg: experiment configuration mapping, passed unchanged to ``to_cmd``.

    Raises:
        subprocess.CalledProcessError: if the training script exits with a
            non-zero status (``check=True``).
    """
    import shlex  # local import: only needed for the echo below

    cmd = to_cmd(cfg)
    # The subprocess receives the argv list directly (no shell involved), so
    # quoting affects only the human-readable echo, never what is executed.
    print(" ".join(shlex.quote(arg) for arg in cmd))
    subprocess.run(cmd, check=True)

# Run every configured experiment one after another. Each call blocks until its
# training subprocess exits; a run that exits non-zero raises and stops the loop.
for cfg in experiments:
    run_experiment(cfg)


python /root/Course/final_project/train_rnn.py --data_dir /root/Course/final_project/nmt_data_jieba_100k --vocab_zh /root/Course/final_project/nmt_data_jieba_100k/vocab_zh.json --vocab_en /root/Course/final_project/nmt_data_jieba_100k/vocab_en.json --epochs 10 --batch_size 64 --emb 256 --hidden 512 --rnn_type lstm --attn dot --lr 0.0003 --decode greedy --beam_size 4 --seed 42 --save_root /root/Course/final_project/nmt_data_jieba_100k/checkpoints
[Epoch 1] train_loss=5.9477  valid_BLEU4=1.86
[Epoch 2] train_loss=4.9657  valid_BLEU4=2.16
[Epoch 3] train_loss=4.3846  valid_BLEU4=2.82
[Epoch 4] train_loss=3.9308  valid_BLEU4=4.00
[Epoch 5] train_loss=3.5847  valid_BLEU4=4.18
[Epoch 6] train_loss=3.3108  valid_BLEU4=5.19
[Epoch 7] train_loss=3.0839  valid_BLEU4=5.70
[Epoch 8] train_loss=2.8908  valid_BLEU4=5.78
[Epoch 9] train_loss=2.7207  valid_BLEU4=5.99
[Epoch 10] train_loss=2.5702  valid_BLEU4=5.58
Saved run dir: /root/Course/final_project/nmt_data_jieba_100k/checkpoints/rnn/nmt_data_ji