In [None]:
import transformers

# Record the installed transformers version: later cells rely on APIs
# (`as_target_tokenizer`, `evaluation_strategy`) that may be deprecated or
# renamed in newer releases, so the version matters for reproducing results.
print(transformers.__version__)

In [None]:
from huggingface_hub import notebook_login

# Interactive Hugging Face Hub login; required because the training
# arguments further down set `push_to_hub=True`.
notebook_login()

In [39]:
from transformers.utils import send_example_telemetry

# Optional usage telemetry for this example notebook; nothing below uses its
# result, so the call can be removed without affecting training.
send_example_telemetry("translation_notebook", framework="pytorch")

In [40]:
# Pretrained Helsinki-NLP OPUS-MT Russian->English checkpoint to fine-tune.
# The mBART / T5 branches below only activate for other checkpoints.
model_checkpoint = "Helsinki-NLP/opus-mt-ru-en"

In [41]:
from datasets import load_dataset, load_metric

# WMT16 Russian-English parallel corpus with train / validation / test splits.
raw_datasets = load_dataset("wmt16", "ru-en")
# SacreBLEU metric used by `compute_metrics` during evaluation.
# NOTE(review): `datasets.load_metric` is deprecated (removed in datasets 3.x);
# the maintained replacement is `evaluate.load("sacrebleu")`.
metric = load_metric("sacrebleu")

Found cached dataset wmt16 (C:/Users/Арина/.cache/huggingface/datasets/wmt16/ru-en/1.0.0/746749a11d25c02058042da7502d973ff410e73457f3d305fc1177dc0e8c4227)


  0%|          | 0/3 [00:00<?, ?it/s]

In [42]:
# Inspect split names and sizes.
raw_datasets


DatasetDict({
    train: Dataset({
        features: ['translation'],
        num_rows: 1516162
    })
    validation: Dataset({
        features: ['translation'],
        num_rows: 2818
    })
    test: Dataset({
        features: ['translation'],
        num_rows: 2998
    })
})

In [43]:
# Each example is {"translation": {"en": ..., "ru": ...}}.
raw_datasets["train"][0]

{'translation': {'en': 'iron cement is a ready for use paste which is laid as a fillet by putty knife or finger in the mould edges (corners) of the steel ingot mould.',
  'ru': 'iron cement - это готовая к использованию паста, которая наносится шпателем или пальцами в виде закругленного перехода в углы сталелитейного кокиля.'}}

In [44]:
import datasets
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=5):
    """Render `num_examples` distinct, randomly chosen rows of `dataset` as an HTML table.

    Args:
        dataset: an indexable dataset (e.g. a 🤗 `datasets.Dataset`) that supports
            `len()`, list-indexing `dataset[[i, j, ...]]` returning column data
            accepted by `pd.DataFrame`, and a `.features` mapping.
        num_examples: number of distinct rows to show.

    Raises:
        AssertionError: if more rows are requested than the dataset contains.
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # Draw distinct indices without replacement in a single call instead of
    # re-drawing on collisions (the previous rejection loop was O(k^2)).
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    # Replace integer class ids by their human-readable label names.
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i, names=typ.names: names[i])
    display(HTML(df.to_html()))

In [45]:
# Spot-check a few random training pairs.
# NOTE(review): the captured sample includes a misaligned pair (en and ru
# describe different news items) and title-like fragments, so the corpus is
# noisy; consider filtering before training.
show_random_elements(raw_datasets["train"])

Unnamed: 0,translation
0,"{'en': 'The report’s argument is that while the stated motivation for ultra-loose monetary policy might be to guard against deflation and promote economic growth at a time when demand is weak, low interest rates also help governments fund their debt very cheaply. Moreover, as we enter the eighth year of aggressive easing, unintended consequences are starting to appear – notably asset-price bubbles, increasing economic inequality (as wealthier investors able to hold equities benefit at the expense of small savers), and the risk of higher inflation in the future.', 'ru': 'В докладе утверждается, что, хотя декларируемыми целями сверхмягкой монетарной политики являются борьба с дефляцией и содействие экономическому росту в условиях низкого спроса, низкие процентные ставки одновременно помогают властям очень дешево занимать в долг.'}"
1,"{'en': 'A drugs PR-test 2010-03-11 22:36 The Georgian government is going to pass a drugs test at the suggestion of the parliamentary opposition.', 'ru': 'Об Ардзинбе вместо эпитафии 04.03.2010 | 18:47 В московской больнице скончался первый президент Абхазии Владислав Ардзинба.'}"
2,"{'en': 'Málaga del Fresno', 'ru': 'Малага-дель-Фресно'}"
3,"{'en': 'List of minor planets/142701-142800', 'ru': 'Список астероидов'}"
4,"{'en': 'Alexandria Zoo', 'ru': 'Александрийский зоопарк'}"


In [46]:
# Display the metric's description and expected input format.
metric

Metric(name: "sacrebleu", features: {'predictions': Value(dtype='string', id='sequence'), 'references': Sequence(feature=Value(dtype='string', id='sequence'), length=-1, id='references')}, usage: """
Produces BLEU scores along with its sufficient statistics
from a source against one or more references.

Args:
    predictions (`list` of `str`): list of translations to score. Each translation should be tokenized into a list of tokens.
    references (`list` of `list` of `str`): A list of lists of references. The contents of the first sub-list are the references for the first prediction, the contents of the second sub-list are for the second prediction, etc. Note that there must be the same number of references for each prediction (i.e. all sub-lists must be of the same length).
    smooth_method (`str`): The smoothing method to use, defaults to `'exp'`. Possible values are:
        - `'none'`: no smoothing
        - `'floor'`: increment zero counts
        - `'add-k'`: increment num/deno

In [47]:
# Sanity check: predictions identical to their references. Each reference
# entry is itself a list because SacreBLEU accepts several references per
# prediction. The score is still 0.0 because these sentences are too short to
# contain any 3- or 4-grams (see `counts` in the output).
fake_preds = ["hello there", "general kenobi"]
fake_labels = [[sentence] for sentence in fake_preds]
metric.compute(predictions=fake_preds, references=fake_labels)

{'score': 0.0,
 'counts': [4, 2, 0, 0],
 'totals': [4, 2, 0, 0],
 'precisions': [100.0, 100.0, 0.0, 0.0],
 'bp': 1.0,
 'sys_len': 4,
 'ref_len': 4}

In [None]:
from transformers import AutoTokenizer
    
# Tokenizer matching the checkpoint. It is used for both sides; target-side
# encoding is switched on with `as_target_tokenizer()` further down.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [50]:
# mBART tokenizers need explicit language codes in "xx_XX" form. This notebook
# translates Russian -> English, so the source is ru_RU and the target en_XX
# (the previous "en-XX"/"ro-RO" values were for English->Romanian and used an
# invalid separator).
if "mbart" in model_checkpoint:
    tokenizer.src_lang = "ru_RU"
    tokenizer.tgt_lang = "en_XX"

In [None]:
# Tokenize a single sentence to inspect the encoded ids.
tokenizer("Hello, this one sentence!")

In [None]:
# Batch tokenization: a list of strings yields one id sequence per sentence.
tokenizer(["Hello, this one sentence!", "This is another sentence."])

In [None]:
# Same sentences encoded with the tokenizer switched to target-side mode.
# NOTE(review): `as_target_tokenizer` is deprecated in newer transformers
# releases in favor of `tokenizer(text_target=...)`.
with tokenizer.as_target_tokenizer():
    print(tokenizer(["Hello, this one sentence!", "This is another sentence."]))

In [17]:
# T5 checkpoints are multi-task and pick the task from a text prefix of the
# form "translate <Source> to <Target>: "; other models (e.g. OPUS-MT) need none.
# Fixed the "t5-larg" typo, which silently disabled the prefix for t5-large.
# NOTE(review): the original T5 checkpoints were presumably not pretrained on
# Russian, so fine-tuning them for ru->en may need much more data.
if model_checkpoint in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
    prefix = "translate Russian to English: "
else:
    prefix = ""

In [20]:
max_input_length = 128
max_target_length = 128
source_lang = "ru"
target_lang = "en"

def preprocess_function(examples):
    """Tokenize a batch of translation pairs for seq2seq training.

    Args:
        examples: a batch dict whose "translation" entry is a list of
            {source_lang: str, target_lang: str} dicts.

    Returns:
        The tokenizer's encoding of the (prefixed) source sentences, with an
        added "labels" key holding the target sentences' input_ids. Both sides
        are truncated to their respective max lengths.
    """
    pairs = examples["translation"]
    sources = [f"{prefix}{pair[source_lang]}" for pair in pairs]
    references = [pair[target_lang] for pair in pairs]

    encoded = tokenizer(sources, max_length=max_input_length, truncation=True)

    # Targets must be encoded in the tokenizer's target-side mode.
    with tokenizer.as_target_tokenizer():
        encoded_targets = tokenizer(references, max_length=max_target_length, truncation=True)

    encoded["labels"] = encoded_targets["input_ids"]
    return encoded

In [None]:
# Preview the preprocessing on the first two training examples.
preprocess_function(raw_datasets['train'][:2])

In [None]:
# Tokenize every split in batches.
# NOTE(review): the train split has ~1.5M pairs, so this step is slow;
# consider subsampling while iterating on the notebook.
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)

In [23]:
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

# Pretrained seq2seq weights for the same checkpoint as the tokenizer.
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

In [None]:
import torch

batch_size = 16
model_name = model_checkpoint.split("/")[-1]
args = Seq2SeqTrainingArguments(
    # Output directory for checkpoints, e.g. "opus-mt-ru-en-finetuned-ru-to-en".
    f"{model_name}-finetuned-{source_lang}-to-{target_lang}",
    # NOTE(review): renamed to `eval_strategy` in newer transformers releases.
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    # Generate full translations during evaluation so BLEU can be computed.
    predict_with_generate=True,
    # fp16 mixed precision requires a CUDA device; enabling it unconditionally
    # makes this cell fail on CPU-only machines.
    fp16=torch.cuda.is_available(),
    push_to_hub=True,
)

In [None]:
# Pads each batch dynamically, inputs and labels alike. Label padding uses -100,
# which `compute_metrics` maps back to the pad token before decoding.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [None]:
import numpy as np

def postprocess_text(preds, labels):
    """Strip surrounding whitespace and shape references for SacreBLEU.

    Returns:
        (preds, labels): `preds` is a list of stripped strings. `labels` is a
        list of single-element lists, i.e. one reference per prediction, which
        is the nesting `metric.compute(references=...)` expects.
    """
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]

    return preds, labels


def compute_metrics(eval_preds):
    """Compute corpus BLEU and mean generated length for an evaluation run.

    Args:
        eval_preds: a (predictions, label_ids) pair of token-id arrays.
            Predictions may be wrapped in a tuple, whose first item is used.

    Returns:
        {"bleu": float, "gen_len": float}, each rounded to 4 decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    preds = np.asarray(preds)
    labels = np.asarray(labels)

    # Generated ids can contain -100 where the Trainer pads predictions from
    # different eval batches to a common length. -100 is not a valid token id:
    # it breaks batch_decode and would be counted as a real token in gen_len.
    # Map it to the pad token, exactly as is done for the labels.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    # Generated length = number of non-pad tokens per prediction.
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result

In [None]:
# Wire the model, training arguments, tokenized splits, padding collator and
# BLEU metric into the seq2seq trainer. Evaluation generates text because
# `predict_with_generate=True` is set in `args`.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune for `num_train_epochs`, evaluating BLEU and gen_len after each
# epoch (evaluation_strategy="epoch") and pushing to the Hub (push_to_hub=True).
trainer.train()