In [None]:
!pip install transformers torch spacy tqdm evaluate bert-score nltk datasets OpenHowNet jieba

!python -m spacy download en_core_web_trf
import spacy
from datasets import load_dataset
from transformers import MBartForConditionalGeneration, MBart50Tokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, EarlyStoppingCallback, TrainerCallback
import torch
import warnings
import os
from tqdm import tqdm
import jieba
import json
from bert_score import score
import OpenHowNet
from nltk.translate.bleu_score import sentence_bleu

# Make sure the OpenHowNet data files are available locally.
print("Checking OpenHowNet resources...")
OpenHowNet.download()

# Silence all Python warnings for the rest of the notebook.
warnings.filterwarnings("ignore")

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# spaCy English pipeline used for named-entity recognition (see mark_entities).
nlp = spacy.load("en_core_web_trf")

# IWSLT 2017 English-Chinese parallel corpus: training and test splits.
train_dataset, test_dataset = (
    load_dataset('iwslt2017', 'iwslt2017-en-zh', split=split_name, trust_remote_code=True)
    for split_name in ('train', 'test')
)

# Pretrained many-to-many mBART-50 checkpoint and its tokenizer.
model_name = "facebook/mbart-large-50-many-to-many-mmt"
tokenizer = MBart50Tokenizer.from_pretrained(model_name)
model = MBartForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)

# Translation direction: English (en_XX) source, Chinese (zh_CN) target.
tokenizer.src_lang = "en_XX"
tokenizer.tgt_lang = "zh_CN"

# Named-entity recognition and in-text entity marking.
def mark_entities(text, nlp_model=None):
    """Wrap each detected named entity in ``<ENTITY type="LABEL">...</ENTITY>``.

    Entities are spliced in by character offset, so every detected span is
    tagged exactly once. A per-entity ``str.replace`` would instead re-tag
    every occurrence of the entity text on each pass. That nests markers when
    a mention repeats, and can tag text inside a marker already inserted
    (e.g. when one entity's text is a substring of another's).

    Args:
        text: The source sentence.
        nlp_model: Optional spaCy-style callable returning a doc whose ``ents``
            expose ``start_char``, ``end_char`` and ``label_``. Defaults to the
            module-level ``nlp`` pipeline.

    Returns:
        ``text`` with entity markers inserted. It is returned unchanged when no
        entities are found.
    """
    if nlp_model is None:
        nlp_model = nlp
    doc = nlp_model(text)

    pieces = []
    cursor = 0
    # Walk the entities left to right and skip any span overlapping one already tagged.
    for ent in sorted(doc.ents, key=lambda e: e.start_char):
        if ent.start_char < cursor:
            continue
        pieces.append(text[cursor:ent.start_char])
        pieces.append(f"<ENTITY type=\"{ent.label_}\">{text[ent.start_char:ent.end_char]}</ENTITY>")
        cursor = ent.end_char
    pieces.append(text[cursor:])
    return "".join(pieces)

# Dataset preprocessing for seq2seq fine-tuning.
def preprocess_function(examples):
    """Entity-mark and tokenize a batch of IWSLT en->zh pairs.

    Args:
        examples: A batched ``datasets`` slice whose ``"translation"`` column
            holds dicts with ``"en"`` and ``"zh"`` strings.

    Returns:
        The tokenizer output for the marked English inputs, plus a ``labels``
        column holding the tokenized Chinese targets. Padding positions in
        ``labels`` are set to -100.
    """
    pairs = examples["translation"]
    inputs = [mark_entities(pair["en"]) for pair in pairs]
    targets = [pair["zh"] for pair in pairs]

    model_inputs = tokenizer(inputs, max_length=256, truncation=True, padding="max_length")
    # text_target= tokenizes in target mode, so labels carry the tgt_lang code
    # (zh_CN) rather than the src_lang code (en_XX) that a plain call would add.
    labels = tokenizer(text_target=targets, max_length=256, truncation=True, padding="max_length")

    # -100 is the ignore index of the model's cross-entropy loss; without it,
    # every padded label position would be trained on as a real token.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [token if token != pad_id else -100 for token in sequence]
        for sequence in labels["input_ids"]
    ]
    return model_inputs

# Entity-mark and tokenize the training and test splits (batched map).
tokenized_train_dataset, tokenized_test_dataset = (
    split.map(preprocess_function, batched=True)
    for split in (train_dataset, test_dataset)
)

# Start every run from an empty output directory, removing whatever is there.
output_dir = "/content/mbart_finetuned"
if os.path.exists(output_dir):
    if not os.path.isfile(output_dir):
        import shutil
        shutil.rmtree(output_dir)
    else:
        os.remove(output_dir)
os.makedirs(output_dir)

# Custom callback that prints the training / validation loss.
class CustomLogCallback(TrainerCallback):
    """Print training and validation loss each time the Trainer logs."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        # logs defaults to None; return early so the membership tests below cannot raise TypeError.
        if not logs:
            return
        if "loss" in logs:
            print(f"Step {state.global_step}, Training Loss: {logs['loss']:.4f}")
        if "eval_loss" in logs:
            print(f"Step {state.global_step}, Validation Loss: {logs['eval_loss']:.4f}")

# Seq2Seq fine-tuning hyperparameters.
# `eval_strategy` replaces `evaluation_strategy`, which transformers deprecated in
# v4.41 and removed in v4.46; the transformers version is unpinned above, so a recent
# release is expected — confirm if you pin an older one.
training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    eval_strategy="epoch",
    save_strategy="epoch",  # must match eval_strategy for load_best_model_at_end
    logging_steps=1,  # log every step; CustomLogCallback prints each log entry
    learning_rate=1e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,
    logging_dir='/content/logs',
    load_best_model_at_end=True,
)

# Validation data for model selection. Evaluating on the test split during training
# would let early stopping and load_best_model_at_end choose the checkpoint using the
# same data the final evaluation reports on (test-set leakage).
validation_dataset = load_dataset('iwslt2017', 'iwslt2017-en-zh', split='validation', trust_remote_code=True)
tokenized_validation_dataset = validation_dataset.map(preprocess_function, batched=True)

# Initialize the Trainer.
# NOTE(review): with num_train_epochs=3 and one evaluation per epoch, a patience of 3
# can never trigger early stopping; lower it or train for more epochs if it should act.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_validation_dataset,
    tokenizer=tokenizer,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3), CustomLogCallback()],
)

# Start training.
print("Training started...")
trainer.train()

# Save the trained model (the best checkpoint, given load_best_model_at_end) and tokenizer.
final_model_dir = "/content/final_model"
trainer.save_model(final_model_dir)
tokenizer.save_pretrained(final_model_dir)
print(f"Model saved to {final_model_dir}")

# Load the HowNet dictionary and enable its word-similarity calculation.
# NOTE(review): hownet_dict is not referenced elsewhere in this file — presumably it is
# used by calculate_meteor, which is defined outside this file; confirm.
hownet_dict = OpenHowNet.HowNetDict()
hownet_dict.initialize_similarity_calculation()

# Load the HIT Tongyici Cilin (Harbin Institute of Technology synonym thesaurus).
def load_cilin(file_path):
    """Read synonym groups from a Cilin-format text file.

    Only lines containing ``=`` are used. The text after the first ``=`` and
    before any second ``=`` is split on whitespace into one set of words.
    Lines without ``=`` (e.g. ``#`` or ``@`` entries) are ignored.

    Args:
        file_path: Path to a UTF-8 encoded Cilin file.

    Returns:
        A list of word sets, in file order.
    """
    with open(file_path, encoding='utf-8') as handle:
        return [
            set(raw_line.strip().split('=')[1].split())
            for raw_line in handle
            if "=" in raw_line
        ]

# Load the Cilin synonym groups that evaluate_translations passes to METEOR scoring.
# NOTE(review): nothing in this notebook creates /content/cilin.txt — presumably it is
# uploaded manually before running; confirm.
cilin_path = "/content/cilin.txt"
synonym_groups = load_cilin(cilin_path)

# Evaluation: translate, pick the best candidate by METEOR, and score it.
def evaluate_translations(model, tokenizer, dataset, output_path, synonym_groups, num_translations=5):
    """Translate ``dataset`` and report BLEU, METEOR and BERTScore.

    For each example, ``num_translations`` beam-search candidates are generated
    and segmented with jieba. The candidate with the highest METEOR score
    against the jieba-segmented reference is kept as the best translation.
    BLEU and BERTScore are computed for that candidate only. Per-sentence
    results are written to ``output_path`` as JSON, and the averages over all
    sentences are printed.

    Args:
        model: The mBART model to evaluate. Inputs are moved to the
            module-level ``device``.
        tokenizer: The mBART-50 tokenizer. Its ``tgt_lang`` code is forced as
            the first generated token.
        dataset: Iterable of examples shaped like
            ``{"translation": {"en": ..., "zh": ...}}``.
        output_path: Path of the JSON file that receives per-sentence results.
        synonym_groups: Synonym sets passed through to ``calculate_meteor``.
        num_translations: Beam width, and number of candidates returned per sentence.
    """
    print("Starting evaluation...")
    # The model may still be in training mode after Trainer.train(); disable dropout.
    model.eval()
    # mBART-50 many-to-many selects the output language from the first decoder
    # token, so the target language code has to be forced during generation.
    forced_bos_id = tokenizer.convert_tokens_to_ids(tokenizer.tgt_lang)

    translated_results = []
    total_meteor = 0
    total_bleu = 0

    for example in tqdm(dataset):
        input_text = example["translation"]["en"]
        reference_text = jieba.lcut(example["translation"]["zh"])  # jieba-segmented reference
        reference_text_str = "".join(reference_text)

        # preprocess_function entity-marked every training input; mark evaluation
        # inputs the same way so the model sees the format it was fine-tuned on.
        model_input = mark_entities(input_text)
        inputs = tokenizer(model_input, return_tensors="pt", truncation=True, max_length=256).to(device)
        outputs = model.generate(
            **inputs,  # passes attention_mask along with input_ids
            max_length=256,
            num_return_sequences=num_translations,
            num_beams=num_translations,
            early_stopping=True,
            forced_bos_token_id=forced_bos_id,
        )
        translations = [
            jieba.lcut(tokenizer.decode(output, skip_special_tokens=True))
            for output in outputs
        ]

        # NOTE(review): calculate_meteor is not defined anywhere in this file;
        # confirm it is provided by another notebook cell before running.
        meteor_scores = [
            calculate_meteor(translation, reference_text, synonym_groups)["METEOR"]
            for translation in translations
        ]

        # Keep the highest-METEOR candidate (the first one wins on ties).
        best_idx = max(range(len(meteor_scores)), key=meteor_scores.__getitem__)
        best_translation = translations[best_idx]
        best_meteor_score = meteor_scores[best_idx]

        # Sentence-level BLEU over jieba tokens (nltk defaults, no smoothing).
        bleu_score = sentence_bleu([reference_text], best_translation)

        total_meteor += best_meteor_score
        total_bleu += bleu_score

        translated_results.append({
            "Original Text": input_text,
            "Reference Text": reference_text_str,
            "All Translations": ["".join(translation) for translation in translations],
            "Best Translation": "".join(best_translation),
            "BLEU Score": bleu_score,
            "METEOR Score": best_meteor_score,
        })

    num_sentences = len(translated_results)
    total_bert_p = total_bert_r = total_bert_f1 = 0.0
    if num_sentences:
        # Score all sentences in one call: bert_score.score() loads the BERT model
        # on every call, so scoring sentence by sentence reloads it per example.
        bert_p, bert_r, bert_f1 = score(
            cands=[res["Best Translation"] for res in translated_results],
            refs=[res["Reference Text"] for res in translated_results],
            lang="zh",
            verbose=False,
        )
        for result, prec, rec, f1 in zip(translated_results, bert_p.tolist(), bert_r.tolist(), bert_f1.tolist()):
            result["BERTScore Precision"] = prec
            result["BERTScore Recall"] = rec
            result["BERTScore F1"] = f1
            total_bert_p += prec
            total_bert_r += rec
            total_bert_f1 += f1

    # Save the per-sentence results.
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(translated_results, f, ensure_ascii=False, indent=4)

    if not num_sentences:
        # Avoid ZeroDivisionError when the dataset is empty.
        print("No examples were evaluated; skipping overall scores.")
        print(f"Evaluation results saved to {output_path}")
        return

    # Print the averages over all sentences.
    print(f"Overall BLEU: {total_bleu / num_sentences:.4f}")
    print(f"Overall METEOR: {total_meteor / num_sentences:.4f}")
    print(f"Overall BERTScore Precision: {total_bert_p / num_sentences:.4f}")
    print(f"Overall BERTScore Recall: {total_bert_r / num_sentences:.4f}")
    print(f"Overall BERTScore F1: {total_bert_f1 / num_sentences:.4f}")
    print(f"Evaluation results saved to {output_path}")

# Run the final evaluation on the test split with one (greedy, num_beams=1) candidate per sentence.
evaluate_translations(
    model,
    tokenizer,
    test_dataset,
    output_path="/content/translated_results.json",
    synonym_groups=synonym_groups,
    num_translations=1,
)
