# In [None]:
import os
from transformers import MBartForConditionalGeneration, MBart50Tokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, EarlyStoppingCallback
from datasets import load_dataset
import torch

# Make sure Google Drive is mounted (the model output path below lives on it)
from google.colab import drive
drive.mount('/content/drive')

# Output directory for checkpoints and the final model
output_model_dir = "/content/drive/MyDrive/mbart_finetuned_full_training"

# Load the mBART tokenizer and model; place the model on the GPU when available
model_name = "facebook/mbart-large-50-many-to-many-mmt"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = MBart50Tokenizer.from_pretrained(model_name)
model = MBartForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)

# Source and target language codes
tokenizer.src_lang = "en_XX"
tokenizer.tgt_lang = "zh_CN"

# Load the first 20,000 training examples of the IWSLT 2017 English-Chinese dataset
dataset = load_dataset(
    'iwslt2017',
    'iwslt2017-en-zh',
    split='train[:20000]',
    trust_remote_code=True,
)

# Preprocessing function
def preprocess_function(examples):
    """Tokenize a batch of English/Chinese translation pairs for seq2seq training.

    Args:
        examples: A batched mapping whose ``"translation"`` entry is a list of
            dicts, each holding an ``"en"`` (source) and a ``"zh"`` (target)
            string.

    Returns:
        The tokenizer output for the source sentences, with a ``"labels"``
        entry holding the target token ids. Both sides are truncated and
        padded to 256 tokens; padding positions in the labels are set to -100.
    """
    inputs = [ex["en"] for ex in examples["translation"]]
    targets = [ex["zh"] for ex in examples["translation"]]

    # Passing the targets through ``text_target`` makes the tokenizer encode
    # them in target mode (using ``tokenizer.tgt_lang``) and store the ids
    # under "labels". A plain ``tokenizer(targets)`` call would encode them
    # as source text, i.e. with the source language code.
    model_inputs = tokenizer(
        inputs,
        text_target=targets,
        max_length=256,
        truncation=True,
        padding="max_length",
    )

    # Replace padding ids in the labels with -100, the index the loss ignores,
    # so that padded positions do not contribute to the training loss.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [token_id if token_id != pad_id else -100 for token_id in label_ids]
        for label_ids in model_inputs["labels"]
    ]
    return model_inputs

# Tokenize the dataset in batches
tokenized_dataset = dataset.map(preprocess_function, batched=True)

# Split the tokenized data: 90% for training, 10% for evaluation
train_test_split = tokenized_dataset.train_test_split(test_size=0.1)
train_dataset, eval_dataset = (
    train_test_split["train"],
    train_test_split["test"],
)

# Training configuration
training_args = Seq2SeqTrainingArguments(
    # Where checkpoints and logs are written
    output_dir=output_model_dir,
    logging_dir='./logs',
    # Evaluate and checkpoint once per epoch; keep at most 3 checkpoints
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    load_best_model_at_end=True,
    # Optimisation settings
    num_train_epochs=20,  # train for up to 20 epochs
    learning_rate=1e-5,
    weight_decay=0.01,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # Seq2seq-specific option
    predict_with_generate=True,
)

# Early stopping with a patience of 3
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

# Build the Trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[early_stopping],
)

# Run training
print("Starting full training...")
trainer.train()

# Save the model weights first, then the tokenizer files, to the output directory
for artifact in (model, tokenizer):
    artifact.save_pretrained(output_model_dir)
print(f"Final model and tokenizer saved to {output_model_dir}")
