In [1]:
import torch
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import sacrebleu

In [2]:
def read_binary_file(file_path):
    with open(file_path, 'rb') as file:
        lines = file.read().decode('utf-8').split('\n')
    return lines


gujarati_text = read_binary_file('test_datasets/dev.guj_Gujr')
nepali_text = read_binary_file('test_datasets/dev.npi_Deva')
burmese_text = read_binary_file('test_datasets/dev.mya_Mymr')
khmer_text = read_binary_file('test_datasets/dev.khm_Khmr')
galician_text = read_binary_file('test_datasets/dev.glg_Latn')
english_labels = read_binary_file('test_datasets/dev.eng_Latn')
english_labels = [[i] for i in english_labels]

In [3]:
model_name = "facebook/mbart-large-50-many-to-one-mmt"
tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
model = MBartForConditionalGeneration.from_pretrained(model_name).to('cuda')


def translate(src_lang, tokenizer, model, text):
    tokenizer.src_lang = src_lang
    encoded_text = tokenizer(text, return_tensors = "pt", padding = True).to('cuda')
    generated_tokens = model.generate(**encoded_text, 
                                      forced_bos_token_id = tokenizer.lang_code_to_id['en_XX'])
    translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens = True)
    return translation


def batch_translate(src_lang, tokenizer, model, texts, batch_size=16):
    results = []
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i+batch_size]
        batch_results = translate(src_lang, tokenizer, model, batch_texts)
        results.extend(batch_results)
    return results

gujarati_translations = batch_translate("gu_IN", tokenizer, model, gujarati_text)
torch.cuda.empty_cache()
nepali_translations = batch_translate("ne_NP", tokenizer, model, nepali_text)
torch.cuda.empty_cache()
burmese_translations = batch_translate("my_MM", tokenizer, model, burmese_text)
torch.cuda.empty_cache()
khmer_translations = batch_translate("km_KH", tokenizer, model, khmer_text)
torch.cuda.empty_cache()
galician_translations = batch_translate("gl_ES", tokenizer, model, galician_text)

tokenizer_config.json:   0%|          | 0.00/461 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/649 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/1.51k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.44G [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()


generation_config.json:   0%|          | 0.00/268 [00:00<?, ?B/s]

In [4]:
gujarati_bleu = sacrebleu.corpus_bleu(gujarati_translations, english_labels)
print(f"BLEU score on Gujarati: {gujarati_bleu.score}")

nepali_bleu = sacrebleu.corpus_bleu(nepali_translations, english_labels)
print(f"BLEU score on Nepali: {nepali_bleu.score}")

burmese_bleu = sacrebleu.corpus_bleu(burmese_translations, english_labels)
print(f"BLEU score on Burmese: {burmese_bleu.score}")

khmer_bleu = sacrebleu.corpus_bleu(khmer_translations, english_labels)
print(f"BLEU score on Khmer: {khmer_bleu.score}")

galician_bleu = sacrebleu.corpus_bleu(galician_translations, english_labels)
print(f"BLEU score on Galician: {galician_bleu.score}")

# overall_translations = [[gujarati_translations[i], nepali_translations[i], 
#                          burmese_translations[i], khmer_translations[i], 
#                          galician_translations[i]] for i in range(len(english_labels))]

# overall_bleu = sacrebleu.corpus_bleu(english_labels, overall_translations)
# print(f"BLEU score on Overall: {overall_bleu.score}")

BLEU score on Gujarati: 14.974611712121044
BLEU score on Nepali: 59.7077331719517
BLEU score on Burmese: 25.783134591199115
BLEU score on Khmer: 42.38979934239455
BLEU score on Galician: 59.34054545534634
