In [1]:
from datasets import load_dataset
from transformers import MBartForConditionalGeneration, MBart50Tokenizer
import spacy

# 加載 IWSLT 2017 英中翻譯資料集
dataset = load_dataset('iwslt2017', 'iwslt2017-en-zh', split='test[:100]')

# 加載微調後的模型和 tokenizer
model = MBartForConditionalGeneration.from_pretrained("./mbart_finetuned")
tokenizer = MBart50Tokenizer.from_pretrained("./mbart_finetuned")

# 設定源語言和目標語言
tokenizer.src_lang = "en_XX"
model.config.forced_bos_token_id = tokenizer.lang_code_to_id["zh_CN"]


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# 加載英文的 NER 模型
nlp = spacy.load("en_core_web_sm")

# 命名實體識別函數
def identify_entities(text):
    doc = nlp(text)
    entities = [(ent.text, ent.label_) for ent in doc.ents]
    return entities


In [3]:
def mark_entities(text):
    entities = identify_entities(text)
    modified_text = text
    for entity, label in entities:
        entity_marker = f"<{label}:{entity}>"
        modified_text = modified_text.replace(entity, entity_marker)
    return modified_text, entities


In [4]:
def translate_with_entities(text):
    # 將實體標記並翻譯
    marked_text, entities = mark_entities(text)
    inputs = tokenizer(marked_text, return_tensors="pt")
    
    translated_tokens = model.generate(
        inputs["input_ids"],
        max_length=128,
        forced_bos_token_id=tokenizer.lang_code_to_id["zh_CN"]
    )
    translated_text = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
    
    return translated_text, entities


In [5]:
def postprocess_translation(translated_text, entities):
    # 使用原實體替換翻譯結果中的標記
    for entity, label in entities:
        entity_marker = f"<{label}:{entity}>"
        translated_text = translated_text.replace(entity_marker, entity)
    return translated_text


In [6]:
def entity_aware_translate(text):
    # 1. 實體標記並翻譯
    translated_text, entities = translate_with_entities(text)
    
    # 2. 還原實體
    final_translation = postprocess_translation(translated_text, entities)
    return final_translation

# 測試 Entity-Aware 翻譯
for example in dataset:
    input_text = example["translation"]["en"]
    final_translation = entity_aware_translate(input_text)
    print("Original:", input_text)
    print("Translation:", final_translation)
    print("=" * 50)




Original: Several years ago here at TED, Peter Skillman  introduced a design challenge  called the marshmallow challenge.
Translation: 几年前,在person:TED这里person:Peter Skillman 提出一个设计难题,叫做棉花糖难题
Original: And the idea's pretty simple:  Teams of four have to build the tallest free-standing structure  out of 20 sticks of spaghetti,  one yard of tape, one yard of string  and a marshmallow.
Translation: 这个想法很简单:  teams of four 必须用20个意大利面棒 建造最高的自由站立结构, 一 yard 的胶带, 一 yard 的绳子和一个棉花糖。
Original: The marshmallow has to be on top.
Translation: 棉花糖必须在上面。
Original: And, though it seems really simple, it's actually pretty hard  because it forces people  to collaborate very quickly.
Translation: 尽管它看起来很简单, 但它却很难因为 它迫使人们迅速合作。
Original: And so, I thought this was an interesting idea,  and I incorporated it into a design workshop.
Translation: 因此,我觉得这是一个有趣的主意, 我把它融入了一个设计工作室。
Original: And it was a huge success.
Translation: 这是一次巨大的成功。
Original: And since then, I've conducted  about 70 design workshops acros