In [8]:
import datasets

train_data = datasets.load_dataset('wmt16','de-en',split='train[:50000]')
val_data = datasets.load_dataset('wmt16',"de-en", split="validation")
test_data = datasets.load_dataset('wmt16','de-en', split="test")

In [11]:
from datasets import load_metric
from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoModelForSeq2SeqLM, AutoTokenizer
from nltk.translate.bleu_score import sentence_bleu
import evaluate


bleu_metric = evaluate.load('bleu')
meteor_metric = load_metric('meteor')
bertscore_metric = load_metric('bertscore')


checkpoint = "google-t5/t5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model_2b = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

def translate_2B(model_2b, tokenizer, text):
    input_text = "translate English to German: " + text
    input_ids = tokenizer.encode(input_text, return_tensors="pt")

    outputs = model_2b.generate(input_ids)
    german_translation = tokenizer.decode(outputs[0])

    return german_translation

translations = []
references = []
for item in val_data:
    english_text = item['translation']['en']
    german_translation = translate_2B(model_2b, tokenizer, english_text)
    translations.append(german_translation)
    references.append(item['translation']['de'])
    if(len(translations) % 500 == 0):
      print(len(translations))

bleu = bleu_metric.compute(predictions=translations, references=references)
meteor = meteor_metric.compute(predictions=translations, references=references)
bertscore = bertscore_metric.compute(predictions=translations, references=references, lang='en')

print(f"Validation BLEU-1: {bleu['precisions'][0]}")
print(f"Validation BLEU-2: {bleu['precisions'][1]}")
print(f"Validation BLEU-3: {bleu['precisions'][2]}")
print(f"Validation BLEU-4: {bleu['precisions'][3]}")
print(f"Validation METEOR: {meteor['meteor']}")
print(f"Validation BERTScore: {bertscore['f1']}")

translations = []
references = []
for item in test_data:
    english_text = item['translation']['en']
    german_translation = translate_2B(model_2b, tokenizer, english_text)
    translations.append(german_translation)
    references.append(item['translation']['de'])
    if(len(translations) % 50 == 0):
      print(len(translations))
      
bleu = bleu_metric.compute(predictions=translations, references=references)
meteor = meteor_metric.compute(predictions=translations, references=references)
bertscore = bertscore_metric.compute(predictions=translations, references=references, lang='en')

print(f"Test BLEU-1: {bleu['precisions'][0]}")
print(f"Test BLEU-2: {bleu['precisions'][1]}")
print(f"Test BLEU-3: {bleu['precisions'][2]}")
print(f"Test BLEU-4: {bleu['precisions'][3]}")
print(f"Test METEOR: {meteor['meteor']}")
print(f"Test BERTScore: {bertscore['f1']}")

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\HP\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\HP\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     C:\Users\HP\AppData\Roaming\nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


500
1000
1500
2000


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Validation BLEU-1: 0.43757681124666525
Validation BLEU-2: 0.24355604001025905
Validation BLEU-3: 0.15256865244805842
Validation BLEU-4: 0.10020853504133462
Validation METEOR: 0.32681816960476157
Validation BERTScore: [0.9524190425872803, 0.8304336667060852, 0.8981911540031433, 0.9013258218765259, 0.8705294132232666, 0.8789345026016235, 0.8828963041305542, 0.87613445520401, 0.8759353756904602, 0.9117261171340942, 0.881777822971344, 0.934824526309967, 0.9171419143676758, 0.8552533388137817, 0.9110256433486938, 0.8622555732727051, 0.9005396962165833, 0.8267114758491516, 0.8509339690208435, 0.8591240048408508, 0.8765894174575806, 0.8849284648895264, 0.8635934591293335, 0.8607895374298096, 0.8441227674484253, 0.8325853943824768, 0.8952022790908813, 0.8507253527641296, 0.8261086940765381, 0.8616601228713989, 0.8789239525794983, 0.872600793838501, 0.8791310787200928, 0.8523179292678833, 0.8767527937889099, 0.8250548243522644, 0.8661631345748901, 0.8598069548606873, 0.8187151551246643, 0.96578

In [12]:
print(f"Test BLEU-1: {bleu['precisions'][0]}")
print(f"Test BLEU-2: {bleu['precisions'][1]}")
print(f"Test BLEU-3: {bleu['precisions'][2]}")
print(f"Test BLEU-4: {bleu['precisions'][3]}")
print(f"Test METEOR: {meteor['meteor']}")
print(f"Test BERTScore: {bertscore['f1']}")

Test BLEU-1: 0.45546615170046134
Test BLEU-2: 0.268678622207953
Test BLEU-3: 0.17319673947841505
Test BLEU-4: 0.11587960008508828
Test METEOR: 0.34849179722066537
Test BERTScore: [0.863649845123291, 0.921355664730072, 0.9038230776786804, 0.8924747705459595, 0.9258562326431274, 0.8644993305206299, 0.8429644703865051, 0.8770416975021362, 0.8773070573806763, 0.8750308156013489, 0.8455091118812561, 0.8470861911773682, 0.8262258172035217, 0.8300788402557373, 0.7915123701095581, 0.871814489364624, 0.8867651224136353, 0.8347976207733154, 0.8287848234176636, 0.8852449059486389, 0.9002132415771484, 0.848119854927063, 0.8794133067131042, 0.8976000547409058, 0.8844802379608154, 0.8714832663536072, 0.8623412847518921, 0.9692016243934631, 0.8655409216880798, 0.8796126246452332, 0.9730011224746704, 0.8910902738571167, 0.7956358194351196, 0.9542884230613708, 0.9006202816963196, 0.9619002938270569, 0.8976296186447144, 0.8671607971191406, 0.919253945350647, 0.9048522710800171, 0.9022316932678223, 0.870