In [1]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
pip install sentencepiece

You should consider upgrading via the '/usr/local/bin/python3 -m pip install --upgrade pip' command.[0m
Note: you may need to restart the kernel to use updated packages.


In [3]:
# Load the tokenizer and the seq2seq model from the same pretrained checkpoint.
checkpoint = "google-t5/t5-small"
tokenizer = T5Tokenizer.from_pretrained(checkpoint)
model = T5ForConditionalGeneration.from_pretrained(checkpoint)

Loading and processing data

In [4]:
from datasets import load_dataset

In [5]:
# WMT16 German-English pair: the first 50,000 training examples plus the
# complete validation and test splits.
wmt16_de_en = ("wmt16", "de-en")
traindata = load_dataset(*wmt16_de_en, split="train[:50000]")
valdata = load_dataset(*wmt16_de_en, split="validation")
testdata = load_dataset(*wmt16_de_en, split="test")

Found cached dataset wmt16 (/Users/harshvardhan/.cache/huggingface/datasets/wmt16/de-en/1.0.0/f5dc442f4d1c2cc487cd2d5591af56c03a5f03bb98a3bb92151d015c8c9cb7ad)
Found cached dataset wmt16 (/Users/harshvardhan/.cache/huggingface/datasets/wmt16/de-en/1.0.0/f5dc442f4d1c2cc487cd2d5591af56c03a5f03bb98a3bb92151d015c8c9cb7ad)
Found cached dataset wmt16 (/Users/harshvardhan/.cache/huggingface/datasets/wmt16/de-en/1.0.0/f5dc442f4d1c2cc487cd2d5591af56c03a5f03bb98a3bb92151d015c8c9cb7ad)


In [6]:
# en_list_test = []
# de_list_test = []
# en_list_val = []
# de_list_val = []

In [7]:
# One single-element reference list (the German sentence) per test example,
# kept in dataset order.
references_test = [[example['translation']['de']] for example in testdata]
print(references_test)

[['Obama empfängt Netanyahu'], ['Das Verhältnis zwischen Obama und Netanyahu ist nicht gerade freundschaftlich.'], ['Die beiden wollten über die Umsetzung der internationalen Vereinbarung sowie über Teherans destabilisierende Maßnahmen im Nahen Osten sprechen.'], ['Bei der Begegnung soll es aber auch um den Konflikt mit den Palästinensern und die diskutierte Zwei-Staaten-Lösung gehen.'], ['Das Verhältnis zwischen Obama und Netanyahu ist seit Jahren gespannt.'], ['Washington kritisiert den andauernden Siedlungsbau Israels und wirft Netanyahu mangelnden Willen beim Friedensprozess vor.'], ['Durch den von Obama beworbenen Deal um das iranische Atomprogramm hat sich die Beziehung der beiden weiter verschlechtert.'], ['Im März hatte Netanyahu auf Einladung der Republikaner vor dem US-Kongress eine umstrittene Rede gehalten, die teils als Affront gegen Obama gewertet wurde.'], ['Die Rede war mit Obama nicht abgesprochen, ein Treffen hatte dieser mit Hinweis auf die seinerzeit bevorstehende W

In [8]:
# One single-element reference list (the German sentence) per validation
# example, kept in dataset order.
references_val = [[example['translation']['de']] for example in valdata]
print(references_val)

[['Die Premierminister Indiens und Japans trafen sich in Tokio.'], ['Indiens neuer Premierminister Narendra Modi trifft bei seinem ersten wichtigen Auslandsbesuch seit seinem Wahlsieg im Mai seinen japanischen Amtskollegen Shinzo Abe in Toko, um wirtschaftliche und sicherheitspolitische Beziehungen zu besprechen.'], ['Herr Modi befindet sich auf einer fünftägigen Reise nach Japan, um die wirtschaftlichen Beziehungen mit der drittgrößten Wirtschaftsnation der Welt zu festigen.'], ['Pläne für eine stärkere kerntechnische Zusammenarbeit stehen ganz oben auf der Tagesordnung.'], ['Berichten zufolge hofft Indien darüber hinaus auf einen Vertrag zur Verteidigungszusammenarbeit zwischen den beiden Nationen.'], ['Polizei von Karratha verhaftet 20-Jährigen nach schneller Motorradjagd'], ['Ein Motorrad wurde beschlagnahmt, nachdem der Fahrer es mit 125 km/h in einer 70 km/h-Zone und durch Buschland gefahren hatte, um der Polizei in Bilbara zu entkommen.'], ['Verkehrspolizisten in Karratha versuc

Getting the translated texts

In [9]:
def translate(batch, tokenizer, device, translation_model=None):
    """Translate one batch of English sentences into German.

    Args:
        batch: Mapping where ``batch["translation"]["en"]`` is a sequence of
            English source sentences.
        tokenizer: Tokenizer used to encode the prompts and to decode the
            generated token ids.
        device: Device that the model and the encoded inputs are moved to.
        translation_model: Optional model to generate with. When omitted, the
            notebook-level ``model`` is used, so existing three-argument calls
            behave as before.

    Returns:
        A list with one decoded translation string per input sentence. An
        empty batch yields an empty list without touching the model.
    """
    sentences = batch["translation"]['en']
    if len(sentences) == 0:
        # Nothing to translate; avoid handing the tokenizer an empty list.
        return []

    # Fall back to the notebook-global model only when none was supplied.
    net = model if translation_model is None else translation_model

    # Ensure that the model and inputs are on the same device
    net.to(device)
    txts = [f"translate English to German: {sentence}" for sentence in sentences]
    inputs = tokenizer(txts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
    outputs = net.generate(
        input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], max_length=512
    )
    translations = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
    return translations

In [10]:
from tqdm import tqdm

In [11]:
from torch.utils.data import DataLoader

In [12]:
# Batches of 32 in fixed dataset order (no shuffling), so generated
# translations stay aligned with the reference lists built in dataset order.
loader_options = dict(batch_size=32, shuffle=False)
val_dataloader = DataLoader(valdata, **loader_options)
test_dataloader = DataLoader(testdata, **loader_options)

In [13]:
# Translate the validation split batch by batch on the CPU, collecting the
# decoded sentences in order.
german_translations_val = []
val_progress = tqdm(val_dataloader, desc="Translating validation set")
for batch in val_progress:
    german_translations_val += translate(batch, tokenizer, "cpu")

Translating validation set: 100%|██████████| 68/68 [11:26<00:00, 10.09s/it]


In [14]:
# Translate the test split batch by batch on the CPU, collecting the decoded
# sentences in order.
german_translations_test = []
test_progress = tqdm(test_dataloader, desc="Translating Testing set")
for batch in test_progress:
    german_translations_test += translate(batch, tokenizer, "cpu")

Translating Testing set: 100%|██████████| 94/94 [15:33<00:00,  9.93s/it]


Evaluation

In [15]:
from evaluate import load
bleu = load("bleu")

# BLEU on the test split, computed once per max_order value from 1 to 4.
bleu1, bleu2, bleu3, bleu4 = (
    bleu.compute(
        predictions=german_translations_test,
        references=references_test,
        max_order=order,
    )['bleu']
    for order in (1, 2, 3, 4)
)

for label, score in (
    ("BLEU-1 score (Test) : ", bleu1),
    ("BLEU-2 score (Test) : ", bleu2),
    ("BLEU-3 score (Test) : ", bleu3),
    ("BLEU-4 score (Test) : ", bleu4),
):
    print(label, score)

BLEU-1 score (Test) :  0.6173815483698936
BLEU-2 score (Test) :  0.4767904215770271
BLEU-3 score (Test) :  0.3812786762509227
BLEU-4 score (Test) :  0.31045530684044015


In [16]:
# BLEU on the validation split, computed once per max_order value from 1 to 4.
# NOTE: this rebinds bleu1..bleu4, replacing the test-split values.
bleu1, bleu2, bleu3, bleu4 = (
    bleu.compute(
        predictions=german_translations_val,
        references=references_val,
        max_order=order,
    )['bleu']
    for order in (1, 2, 3, 4)
)

for label, score in (
    ("BLEU-1 score (Val) : ", bleu1),
    ("BLEU-2 score (Val) : ", bleu2),
    ("BLEU-3 score (Val) : ", bleu3),
    ("BLEU-4 score (Val) : ", bleu4),
):
    print(label, score)

BLEU-1 score (Val) :  0.5896815800080613
BLEU-2 score (Val) :  0.4406826405673354
BLEU-3 score (Val) :  0.34410406607004834
BLEU-4 score (Val) :  0.27436246022629457


In [17]:
# BERTScore for the test split with German as the language; the complete
# result dictionary is printed as-is.
bertscore = load("bertscore")
results3 = bertscore.compute(
    predictions=german_translations_test,
    references=references_test,
    lang="de",
)
print("\nSentence-wise BERT score on test data:\n", results3)


Sentence-wise BERT score on test data:
 {'precision': [0.9144582152366638, 0.9650661945343018, 0.9685443639755249, 0.8797836303710938, 0.9306670427322388, 0.8365610837936401, 0.8025346994400024, 0.9454704523086548, 0.8436805605888367, 0.7950944900512695, 0.78751540184021, 0.8805647492408752, 0.8569702506065369, 0.9133842587471008, 0.5581915974617004, 0.8433668613433838, 0.9055695533752441, 0.8692854046821594, 0.884058952331543, 0.8737986087799072, 0.9497246742248535, 0.9169869422912598, 0.8986693620681763, 0.8915311098098755, 0.8915966153144836, 0.9064998626708984, 0.8764452934265137, 0.9407349228858948, 0.7961886525154114, 0.9097888469696045, 1.0, 0.8939446806907654, 0.8258265852928162, 0.9604731202125549, 0.9130258560180664, 0.9526132345199585, 0.9205688834190369, 0.8066465258598328, 0.9427056908607483, 0.9285317063331604, 0.9048322439193726, 0.806366503238678, 0.8659064769744873, 0.8386735916137695, 0.829745888710022, 0.8809801340103149, 0.9688714742660522, 0.6994218826293945, 0.92

In [18]:
# Arithmetic mean of the per-sentence 'precision' values in results3.
# NOTE(review): only the precision component is averaged here although the
# label says "BERT score" - confirm whether another component was intended.
test_precisions = results3['precision']
print("Average BERT score on test data : ", sum(test_precisions)/len(test_precisions))

Average BERT score on test data :  0.8728996489734878


In [19]:
# BERTScore for the validation split with German as the language.
# NOTE: results3 is rebound here, replacing the test-split result.
bertscore = load("bertscore")
results3 = bertscore.compute(
    predictions=german_translations_val,
    references=references_val,
    lang="de",
)
print("\nSentence-wise BERT score on Validation data:\n", results3)


Sentence-wise BERT score on Validation data:
 {'precision': [0.8879619240760803, 0.8536378741264343, 0.9581870436668396, 0.8868697285652161, 0.8901283740997314, 0.8531867265701294, 0.9224591255187988, 0.8628812432289124, 0.8858681917190552, 0.9476878643035889, 0.7949899435043335, 0.9396442174911499, 0.93113112449646, 0.8486418128013611, 0.8641343116760254, 0.8293249607086182, 0.9577564597129822, 0.7679856419563293, 0.869165301322937, 0.8831398487091064, 0.8126075267791748, 0.913248598575592, 0.8514277338981628, 0.8289997577667236, 0.8293420076370239, 0.8306722640991211, 0.9409663677215576, 0.855198860168457, 0.8567876815795898, 0.8525779247283936, 0.918196439743042, 0.8780519366264343, 0.921749472618103, 0.8556233644485474, 0.8353450894355774, 0.9243960380554199, 0.9005717039108276, 0.8808929324150085, 0.8155556917190552, 0.9676378965377808, 0.9380099773406982, 0.9131386280059814, 0.8555325269699097, 0.9450991749763489, 0.8979119062423706, 0.8978720307350159, 0.9318264722824097, 0.865

In [20]:
# Arithmetic mean of the per-sentence 'precision' values in results3.
# NOTE(review): only the precision component is averaged here although the
# label says "BERT score" - confirm whether another component was intended.
val_precisions = results3['precision']
print("Average BERT score on Validation data : ", sum(val_precisions)/len(val_precisions))

Average BERT score on Validation data :  0.8624865656736646


In [21]:
# METEOR for the test split.
# NOTE(review): the recorded output shows NLTK failing to download wordnet,
# punkt and omw-1.4 (SSL certificate errors) - confirm these resources are
# available locally before trusting the score.
meteor = load('meteor')
results4 = meteor.compute(
    predictions=german_translations_test,
    references=references_test,
)
print("Meteor score on test data : ", results4 , "\n")

[nltk_data] Error loading wordnet: <urlopen error [SSL:
[nltk_data]     CERTIFICATE_VERIFY_FAILED] certificate verify failed:
[nltk_data]     unable to get local issuer certificate (_ssl.c:1091)>
[nltk_data] Error loading punkt: <urlopen error [SSL:
[nltk_data]     CERTIFICATE_VERIFY_FAILED] certificate verify failed:
[nltk_data]     unable to get local issuer certificate (_ssl.c:1091)>
[nltk_data] Error loading omw-1.4: <urlopen error [SSL:
[nltk_data]     CERTIFICATE_VERIFY_FAILED] certificate verify failed:
[nltk_data]     unable to get local issuer certificate (_ssl.c:1091)>


Meteor score on test data :  {'meteor': 0.5772148087831891} 



In [22]:
# METEOR for the validation split; results4 is rebound, replacing the
# test-split result.
# NOTE(review): the recorded output shows NLTK failing to download wordnet,
# punkt and omw-1.4 (SSL certificate errors) - confirm these resources are
# available locally before trusting the score.
meteor = load('meteor')
results4 = meteor.compute(
    predictions=german_translations_val,
    references=references_val,
)
print("Meteor score on Validation data : ", results4 , "\n")

[nltk_data] Error loading wordnet: <urlopen error [SSL:
[nltk_data]     CERTIFICATE_VERIFY_FAILED] certificate verify failed:
[nltk_data]     unable to get local issuer certificate (_ssl.c:1091)>
[nltk_data] Error loading punkt: <urlopen error [SSL:
[nltk_data]     CERTIFICATE_VERIFY_FAILED] certificate verify failed:
[nltk_data]     unable to get local issuer certificate (_ssl.c:1091)>
[nltk_data] Error loading omw-1.4: <urlopen error [SSL:
[nltk_data]     CERTIFICATE_VERIFY_FAILED] certificate verify failed:
[nltk_data]     unable to get local issuer certificate (_ssl.c:1091)>


Meteor score on Validation data :  {'meteor': 0.5458072768029277} 

