# 在验证集上微调3个epoch

In [3]:
import torch
from transformers import BartForConditionalGeneration, PreTrainedTokenizerFast
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
import sacrebleu
from datasets import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm 
import os


# 加载分词器
tokenizer = PreTrainedTokenizerFast.from_pretrained("../user_data/bart_tokenizer")

# 读取source和target文件
def load_txt_data(source_path, target_path):
    with open(source_path, "r", encoding="utf-8") as src_file, open(target_path, "r", encoding="utf-8") as tgt_file:
        source_sentences = [f"<zh> {line.strip()} </s>" for line in src_file.readlines()]
        target_sentences = [line.strip() for line in tgt_file.readlines()]
    return source_sentences, target_sentences

data_files = {
    "source": "../xfdata/多语言机器翻译挑战赛数据集更新（以此测试集提交得分为准）/val/中文/en-zh.txt",
    "target": "../xfdata/多语言机器翻译挑战赛数据集更新（以此测试集提交得分为准）/val/其他语言/en-zh.txt"
}

# 加载txt文件中的句子
source_sentences, target_sentences = load_txt_data(data_files["source"], data_files["target"])

# 将数据转换为datasets格式
dataset_dict = {"source": source_sentences, "target": target_sentences}
dataset = Dataset.from_dict(dataset_dict)
# Tokenize 函数
def tokenize_function(examples):
    source_texts = examples["source"]
    target_texts = examples["target"]

    # Tokenize source texts
    model_inputs = tokenizer(source_texts, max_length=128, truncation=True, padding="max_length", return_tensors="pt")
    
    # Tokenize target texts without using as_target_tokenizer context
    labels = tokenizer(target_texts, max_length=128, truncation=True, padding="max_length", return_tensors="pt")

    # 将 labels 直接添加到 model_inputs
    model_inputs["labels"] = labels["input_ids"]

    # 转换成字典格式，便于 datasets 库使用
    return {key: value.tolist() for key, value in model_inputs.items()}


tokenized_dataset = dataset.map(tokenize_function, batched=True)
print("Data loaded.")
device = "cuda" if torch.cuda.is_available() else "cpu"
print("en-zh")
model = BartForConditionalGeneration.from_pretrained("../user_data/step1/en/results/checkpoint-154690").eval().to(device)
print("model is done!")

# 6. 设置训练参数
output_dir = "../user_data/step1/en/continue"
training_args = Seq2SeqTrainingArguments(
    output_dir=os.path.join(output_dir, "results"),         # 训练结果保存路径
    save_strategy="epoch",                                   # 按步数进行保存
    logging_strategy="epoch",
    logging_dir=os.path.join(output_dir, "logs"),           # 日志保存路径
    learning_rate=5e-5,                                     # 学习率
    per_device_train_batch_size=64,                         # 每个设备的训练批次大小
    per_device_eval_batch_size=256,                         # 每个设备的验证批次大小
    weight_decay=0.01,                                      # 权重衰减
    save_total_limit=3,                                    # 保存的 checkpoint 数量上限
    num_train_epochs=5,                                     # 训练 epoch 数
    predict_with_generate=True,                             # 使用生成模式进行评估
    bf16=True,                                              # 使用 bf16 精度
    
)

print("训练参数已设置完成！")


# 7. 使用 Seq2SeqTrainer 进行 微调
trainer = Seq2SeqTrainer(
    model=model,                                     # 模型
    args=training_args,                              # 训练参数
    train_dataset=tokenized_dataset,           # 训练数据集
    tokenizer=tokenizer,                             # 分词器
)

# 8. 开始训练
trainer.train()

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BartTokenizer'. 
The class this function is called from is 'PreTrainedTokenizerFast'.
Map: 100%|██████████| 500/500 [00:00<00:00, 4365.78 examples/s]


Data loaded.
en-zh
model is done!
训练参数已设置完成！


Step,Training Loss
8,0.5878
16,0.4179
24,0.3632
32,0.3303
40,0.314


TrainOutput(global_step=40, training_loss=0.402643746137619, metrics={'train_runtime': 22.788, 'train_samples_per_second': 109.707, 'train_steps_per_second': 1.755, 'total_flos': 133115412480000.0, 'train_loss': 0.402643746137619, 'epoch': 5.0})