In [13]:
from datasets import load_dataset
import spacy
from transformers import MBartForConditionalGeneration, MBart50Tokenizer

# Load the first 2,000 training pairs of the IWSLT 2017 English-Chinese corpus.
# NOTE(review): trust_remote_code=True executes the dataset's own loading
# script; keep it only if that script is trusted.
dataset = load_dataset('iwslt2017', 'iwslt2017-en-zh', split='train[:2000]', trust_remote_code=True)

# 1. English NER pipeline used to locate entities before translation.
nlp = spacy.load("en_core_web_sm")

# 2. mBART-50 many-to-many translation checkpoint used as the base model.
# NOTE(review): the ".generate() ... doesn't have a language model head" TypeError
# recorded below appears to be caused by an inconsistent transformers install
# (see "ImportError: cannot import name 'StaticCache'"), not by this code;
# reinstalling a single consistent transformers version should clear it — TODO confirm.
model_name = "facebook/mbart-large-50-many-to-many-mmt"
tokenizer = MBart50Tokenizer.from_pretrained(model_name)
model = MBartForConditionalGeneration.from_pretrained(model_name)

# Source language is English, target language is Chinese (zh_CN).
tokenizer.src_lang = "en_XX"
tokenizer.tgt_lang = "zh_CN"
model.config.forced_bos_token_id = tokenizer.lang_code_to_id["zh_CN"]
# generate() takes its defaults from model.generation_config; modifying
# model.config after loading is a deprecated path that may be ignored, so
# mirror the forced target-language token there as well.
model.generation_config.forced_bos_token_id = model.config.forced_bos_token_id


In [14]:
def preprocess_entity_aware(text, ner=None):
    """Wrap every named entity in ``text`` with an inline ``<LABEL:entity>`` marker.

    Args:
        text: English source sentence.
        ner: Optional spaCy-style pipeline: a callable returning an object whose
            ``.ents`` yields spans with ``text``, ``label_``, ``start_char`` and
            ``end_char``. Defaults to the global ``nlp`` pipeline.

    Returns:
        A tuple ``(modified_text, entities)``. ``modified_text`` is ``text`` with
        each entity span replaced by its marker. ``entities`` is a list of
        ``(entity_text, label)`` pairs in the order the pipeline reports them.
    """
    doc = (nlp if ner is None else ner)(text)

    pieces = []
    entities = []
    cursor = 0
    # Splice markers in by character offset rather than str.replace():
    # replace() rewrites *every* occurrence (including substrings of other
    # words), and when the same entity text appears twice it re-wraps text
    # already inside a marker, producing nested markers like "<GPE:<GPE:Paris>>".
    # Assumes spans are non-overlapping and sorted by position (spaCy's Doc.ents is).
    for ent in doc.ents:
        pieces.append(text[cursor:ent.start_char])
        pieces.append(f"<{ent.label_}:{ent.text}>")
        cursor = ent.end_char
        entities.append((ent.text, ent.label_))
    pieces.append(text[cursor:])

    return "".join(pieces), entities


In [15]:
def entity_aware_translate(text):
    """Translate English ``text`` to Chinese while trying to keep entities verbatim.

    Entities are first wrapped in ``<LABEL:entity>`` markers by
    ``preprocess_entity_aware``. After translation, any marker that survived
    unchanged is replaced by the original English entity text.

    Args:
        text: English source sentence.

    Returns:
        The decoded translation with special tokens stripped.
    """
    # 1. Detect entities and insert inline markers.
    processed_text, entities = preprocess_entity_aware(text)

    # 2. Translate. Pass input_ids *and* attention_mask, and give the target
    #    language id explicitly: generate() reads its defaults from
    #    model.generation_config, so a value set only on model.config may not
    #    force Chinese output.
    inputs = tokenizer(processed_text, return_tensors="pt")
    translated_tokens = model.generate(
        **inputs,
        forced_bos_token_id=tokenizer.lang_code_to_id["zh_CN"],
    )
    translated_text = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

    # 3. Restore entities whose marker survived translation verbatim.
    # NOTE(review): mBART has no special handling for these markers, so they
    # may be translated or split up; in that case nothing is restored here.
    for ent_text, ent_label in entities:
        translated_text = translated_text.replace(f"<{ent_label}:{ent_text}>", ent_text)

    return translated_text


In [17]:
# Smoke test: translate the first 10 English sentences of the loaded split
# and print each source/translation pair separated by a rule.
for example in dataset.select(range(10)):
    input_text = example["translation"]["en"]
    translated_text = entity_aware_translate(input_text)

    print(f"Original text: {input_text}")
    print(f"Translated text: {translated_text}")
    print("=" * 50)


TypeError: The current model class (MBartForConditionalGeneration) is not compatible with `.generate()`, as it doesn't have a language model head. Please use one of the following classes instead: {'MBartForCausalLM', 'MBartForConditionalGeneration'}

In [24]:
import spacy
import torch
from transformers import MBartForConditionalGeneration, MBart50Tokenizer

# Load the English NER pipeline.
nlp = spacy.load("en_core_web_sm")

# Load the mBART-50 many-to-many translation model.
model_name = "facebook/mbart-large-50-many-to-many-mmt"
tokenizer = MBart50Tokenizer.from_pretrained(model_name)
model = MBartForConditionalGeneration.from_pretrained(model_name)

# Source language is English, target language is Chinese (zh_CN).
tokenizer.src_lang = "en_XX"
model.config.forced_bos_token_id = tokenizer.lang_code_to_id["zh_CN"]
# generate() takes its defaults from model.generation_config; mirror the forced
# target-language token there so it applies even when a caller does not pass
# forced_bos_token_id explicitly.
model.generation_config.forced_bos_token_id = model.config.forced_bos_token_id

# Pre-processing: mark named entities inline before translation.
def preprocess_entity_aware(text, ner=None):
    """Wrap every named entity in ``text`` with an inline ``<LABEL:entity>`` marker.

    Args:
        text: English source sentence.
        ner: Optional spaCy-style pipeline: a callable returning an object whose
            ``.ents`` yields spans with ``text``, ``label_``, ``start_char`` and
            ``end_char``. Defaults to the global ``nlp`` pipeline.

    Returns:
        A tuple ``(modified_text, entities)``. ``modified_text`` is ``text`` with
        each entity span replaced by its marker. ``entities`` is a list of
        ``(entity_text, label)`` pairs in the order the pipeline reports them.
    """
    doc = (nlp if ner is None else ner)(text)

    pieces = []
    entities = []
    cursor = 0
    # Splice markers in by character offset rather than str.replace():
    # replace() rewrites *every* occurrence (including substrings of other
    # words), and when the same entity text appears twice it re-wraps text
    # already inside a marker, producing nested markers like "<GPE:<GPE:Paris>>".
    # Assumes spans are non-overlapping and sorted by position (spaCy's Doc.ents is).
    for ent in doc.ents:
        pieces.append(text[cursor:ent.start_char])
        pieces.append(f"<{ent.label_}:{ent.text}>")
        cursor = ent.end_char
        entities.append((ent.text, ent.label_))
    pieces.append(text[cursor:])

    return "".join(pieces), entities

# Entity-aware translation: mark entities, translate with generate(), restore markers.
def entity_aware_translate(text, max_length=128):
    """Translate English ``text`` to Chinese while trying to keep entities verbatim.

    Entities are first wrapped in ``<LABEL:entity>`` markers by
    ``preprocess_entity_aware``. After translation, any marker that survived
    unchanged is replaced by the original English entity text.

    Args:
        text: English source sentence.
        max_length: Maximum generated sequence length passed to ``generate()``
            (default 128, as before).

    Returns:
        The decoded translation with special tokens stripped.
    """
    # 1. Detect entities and insert inline markers.
    processed_text, entities = preprocess_entity_aware(text)
    inputs = tokenizer(processed_text, return_tensors="pt")

    # 2. Translate. Pass input_ids and attention_mask, and force the Chinese
    #    language token as the first generated token.
    translated_tokens = model.generate(
        **inputs,
        max_length=max_length,
        forced_bos_token_id=tokenizer.lang_code_to_id["zh_CN"],
    )
    translated_text = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

    # 3. Restore entities whose marker survived translation verbatim.
    # NOTE(review): mBART has no special handling for these markers, so they
    # may be translated or split up; in that case nothing is restored here.
    for ent_text, ent_label in entities:
        translated_text = translated_text.replace(f"<{ent_label}:{ent_text}>", ent_text)

    return translated_text

# Demo: translate one sentence that contains several named entities
# (a company, a year, and people's names) and print both versions.
input_text = "Microsoft was founded in 1975 by Bill Gates and Paul Allen."
translated_text = entity_aware_translate(input_text)

print(f"Original text: {input_text}")
print(f"Translated text: {translated_text}")


TypeError: The current model class (MBartForConditionalGeneration) is not compatible with `.generate()`, as it doesn't have a language model head. Please use one of the following classes instead: {'MBartForCausalLM', 'MBartForConditionalGeneration'}

Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "C:\Users\USER\anaconda3\envs\pytorch\lib\site-packages\transformers\utils\import_utils.py", line 1382, in _get_module
  File "C:\Users\USER\anaconda3\envs\pytorch\lib\importlib\__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1014, in _gcd_import
  File "<frozen importlib._bootstrap>", line 991, in _find_and_load
  File "<frozen importlib._bootstrap>", line 975, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 671, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 843, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "C:\Users\USER\anaconda3\envs\pytorch\lib\site-packages\transformers\models\gpt_neo\modeling_gpt_neo.py", line 26, in <module>
    from ...cache_utils import Cache, DynamicCache, StaticCache
ImportError: cannot import name 'StaticCache'