In [1]:
import tqdm as notebook_tqdm
import ipywidgets as widgets
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import DataCollatorForSeq2Seq
import evaluate
import numpy as np

2024-04-05 12:31:15.077348: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-05 12:31:15.077409: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-05 12:31:15.078938: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [21]:
%pip install bert_score

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting bert_score
  Downloading bert_score-0.3.13-py3-none-any.whl.metadata (15 kB)
Downloading bert_score-0.3.13-py3-none-any.whl (61 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bert_score
Successfully installed bert_score-0.3.13
Note: you may need to restart the kernel to use updated packages.


In [2]:
train_words = load_dataset("wmt16", "de-en", split="train[:50000]")
eval_words = load_dataset("wmt16", "de-en", split="validation")
test_words = load_dataset("wmt16", "de-en", split="test")

In [4]:
checkpoint = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [5]:
source_lang = "en"
target_lang = "de"
prefix = "translate English to German: "

def preprocess_function(examples):
    inputs = [prefix + example[source_lang] for example in examples["translation"]]
    targets = [example[target_lang] for example in examples["translation"]]
    model_inputs = tokenizer(inputs, text_target=targets, max_length=128, truncation=True)
    return model_inputs

In [6]:
tokenized_train = train_words.map(preprocess_function, batched=True)
tokenized_eval= eval_words.map(preprocess_function, batched=True)
tokenized_test= test_words.map(preprocess_function, batched=True)

  0%|          | 0/50 [00:00<?, ?ba/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

In [7]:
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)

In [30]:
blue = evaluate.load("bleu")
meteor = evaluate.load("meteor")
bert = evaluate.load("bertscore")

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]

    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
                          
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result_bleu = blue.compute(predictions=decoded_preds, references=decoded_labels)
    result_meteor = meteor.compute(predictions=decoded_preds, references=decoded_labels)
    result_BERT = bert.compute(predictions=decoded_preds, references=decoded_labels, model_type="distilbert-base-uncased")
    result = {"bleu": result_bleu['bleu'], 
              "bleu_precision": sum(result_bleu['precisions'])/len(result_bleu['precisions']),
              "meteor":result_meteor['meteor'],
              "BERT_precision":sum(result_BERT['precision'])/len(result_BERT['precision']),
              "BERT_recall":sum(result_BERT['recall'])/len(result_BERT['recall']),
              "BERT_F1":sum(result_BERT['f1'])/len(result_BERT['f1'])}
    return result

[nltk_data] Downloading package wordnet to /usr/share/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /usr/share/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [9]:
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

In [31]:
training_args = Seq2SeqTrainingArguments(
    output_dir="translation",
    evaluation_strategy="epoch",
    learning_rate=1e-3,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
#     weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=10,
    predict_with_generate=True,
    fp16=True,
    push_to_hub=False,
    logging_strategy='epoch'
)

test_trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [32]:
store=test_trainer.predict(tokenized_test)
store.metrics

{'test_loss': 1.1487655639648438,
 'test_bleu': 0.13886489800968674,
 'test_bleu_precision': 0.3524607955529484,
 'test_meteor': 0.38955218595880303,
 'test_BERT_precision': 0.9218545817621313,
 'test_BERT_recall': 0.8812523247520698,
 'test_BERT_F1': 0.9006845349866416,
 'test_runtime': 35.2179,
 'test_samples_per_second': 85.156,
 'test_steps_per_second': 1.335}

In [33]:
store=test_trainer.predict(tokenized_eval)
store.metrics

{'test_loss': 1.2565711736679077,
 'test_bleu': 0.12778813916553747,
 'test_bleu_precision': 0.32624722234542136,
 'test_meteor': 0.36988885486383316,
 'test_BERT_precision': 0.9173205423706849,
 'test_BERT_recall': 0.8784083713854222,
 'test_BERT_F1': 0.8970375052241479,
 'test_runtime': 25.9865,
 'test_samples_per_second': 83.466,
 'test_steps_per_second': 1.308}