mBART-50 — without fine-tuning (SWRC)

In [None]:
import os
import json
from datasets import load_dataset
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import torch
from evaluate import load

# Korean→Chinese test set: one CSV file, which `datasets` exposes under the "train" split.
test_dataset = load_dataset(
    "csv",
    data_files="ko_zh_test_dataset_donga.csv",
)

# SacreBLEU scorer from `evaluate` (not used in this cell; presumably used for scoring later).
bleu_metric = load("sacrebleu")

# Tokenizer for the pretrained many-to-many mBART-50 checkpoint, shared by all translation calls.
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50-many-to-many-mmt",
)

def generate_translation(examples, model, max_length=200, src_lang="ko_KR", tgt_lang="zh_CN"):
    """Translate a batch of sentences with an mBART-50 many-to-many model.

    Uses the module-level ``tokenizer`` (an ``MBart50TokenizerFast``) and sets
    its ``src_lang`` attribute as a side effect.

    Args:
        examples: Mapping whose ``'source'`` entry is an iterable of source
            sentences. ``None`` entries (e.g. empty CSV cells) are treated as
            empty strings.
        model: mBART-50 seq2seq model. Inputs are moved to ``model.device``.
        max_length: Maximum number of input tokens per sentence; longer
            inputs are truncated.
        src_lang: mBART-50 language code of the source sentences.
        tgt_lang: mBART-50 language code forced as the first generated token.

    Returns:
        ``{"generated_text": [...]}`` with one decoded translation per input
        sentence, in input order (an empty list for empty input).

    Raises:
        ValueError: If ``src_lang`` or ``tgt_lang`` is not a token in the
            tokenizer's vocabulary.
    """
    sentences = ["" if s is None else str(s) for s in examples['source']]
    if not sentences:
        # Nothing to translate; skip the tokenizer/generate calls entirely.
        return {"generated_text": []}

    # An unknown code would silently map to <unk> and produce garbage output.
    for code in (src_lang, tgt_lang):
        if tokenizer.convert_tokens_to_ids(code) == tokenizer.unk_token_id:
            raise ValueError(f"Unknown mBART-50 language code: {code!r}")

    # With src_lang set, the tokenizer itself emits "<src_lang> tokens </s>".
    # Writing the code and </s> into the raw text instead yields
    # "en_XX ko_KR ... </s> </s>": the tokenizer still prepends its default
    # source code and appends its own </s>.
    tokenizer.src_lang = src_lang
    encoded_inputs = tokenizer(
        sentences,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    ).to(model.device)  # inputs must be on the same device as the model

    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=encoded_inputs['input_ids'],
            attention_mask=encoded_inputs['attention_mask'],
            # convert_tokens_to_ids is base tokenizer API, so this does not
            # depend on the model-specific lang_code_to_id mapping.
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
        )

    generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return {"generated_text": generated_texts}

# Load the pretrained (not fine-tuned) mBART-50 many-to-many model, on GPU when available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt").to(device)

# Sentences per generate() call. Translating one sentence at a time leaves the
# device mostly idle; batching amortizes the per-call overhead.
batch_size = 16

# Materialize the column so it can be sliced regardless of the `datasets` version.
sources = list(test_dataset['train']['source'])

# NOTE(review): the file name mentions "1st_checkpoint_24225", but the model
# loaded above is the base checkpoint without fine-tuning — confirm the name.
with open('generated_translation_Donga_1st_checkpoint_24225_Donga.txt', 'w', encoding='utf-8') as f:
    for start in range(0, len(sources), batch_size):
        batch = sources[start:start + batch_size]
        translations = generate_translation({"source": batch}, model)['generated_text']
        # Exactly one output line per source sentence, in order.
        f.writelines(text + '\n' for text in translations)

        done = start + len(batch)
        # Report progress whenever another multiple of 100 sentences is passed.
        if done // 100 > start // 100:
            print(f"Processed {done} sentences")