## Install dependencies and load dataset

In [30]:
import torch

In [None]:
# ! pip install -q transformers sentencepiece datasets accelerate evaluate sacrebleu

In [1]:
from datasets import load_dataset

# Load the English-Vietnamese parallel corpus by its Hub identifier.
# The result is a DatasetDict (displayed in the next cell).
ds = load_dataset("thainq107/iwslt2015-en-vi")

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
ds

DatasetDict({
    train: Dataset({
        features: ['en', 'vi'],
        num_rows: 133317
    })
    validation: Dataset({
        features: ['en', 'vi'],
        num_rows: 1268
    })
    test: Dataset({
        features: ['en', 'vi'],
        num_rows: 1268
    })
})

In [5]:
ds['train'][0]

{'en': 'Rachel Pike : The science behind a climate headline',
 'vi': 'Khoa học đằng sau một tiêu đề về khí hậu'}

## Tokenizer

### Import tokenizer

In [3]:
from transformers import AutoTokenizer

# Checkpoint identifier of the pretrained mBART-50 many-to-many model.
# The tokenizer is loaded from that checkpoint, requesting the fast implementation.
model_name = "facebook/mbart-large-50-many-to-many-mmt"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

In [15]:
tokenizer.pad_token_id

1

### Tokenize

In [22]:
MAX_LEN = 75

# mBART-50 language codes for this translation direction.
SRC_LANG = "en_XX"
TGT_LANG = "vi_VN"

# mBART-50 prefixes every sequence with a language-code token. The tokenizer
# must know both languages so that sources get the English code and targets
# get the Vietnamese code.
tokenizer.src_lang = SRC_LANG
tokenizer.tgt_lang = TGT_LANG


def preprocess_function(examples):
    """Tokenize a batch of English/Vietnamese sentence pairs for seq2seq training.

    Args:
        examples: Batch dict with an ``"en"`` list (source sentences) and a
            ``"vi"`` list (target sentences), as passed by ``Dataset.map``
            with ``batched=True``.

    Returns:
        Dict with ``input_ids`` and ``attention_mask`` for the source side and
        ``labels`` for the target side. Every sequence is truncated/padded to
        ``MAX_LEN`` tokens.
    """
    # `text_target` tokenizes the Vietnamese side in target mode, so labels
    # start with the Vietnamese language code rather than the source one.
    # Plain Python lists are returned; `Dataset.map` stores them as-is, so
    # requesting torch tensors here is unnecessary.
    model_inputs = tokenizer(
        examples["en"],
        text_target=examples["vi"],
        padding="max_length",
        truncation=True,
        max_length=MAX_LEN,
    )

    # Keep the attention mask: inputs are already padded, so without it the
    # collator would rebuild an all-ones mask and the encoder would attend to
    # padding positions.
    # NOTE: label padding is left as `pad_token_id` (not -100) so that labels
    # can still be decoded directly with `tokenizer.batch_decode`.
    return {
        'input_ids': model_inputs['input_ids'],
        'attention_mask': model_inputs['attention_mask'],
        'labels': model_inputs['labels'],
    }

In [23]:
# Tokenize in batches; calling `map` on the DatasetDict processes every split
# (one progress bar per split in the output below).
preprocessed_ds = ds.map(preprocess_function, batched=True)

Map: 100%|██████████| 133317/133317 [00:14<00:00, 9127.15 examples/s] 
Map: 100%|██████████| 1268/1268 [00:00<00:00, 8441.68 examples/s]
Map: 100%|██████████| 1268/1268 [00:00<00:00, 4063.97 examples/s]


## Model

In [24]:
from transformers import AutoModelForSeq2SeqLM

# Load the pretrained seq2seq weights. `model_name` is re-assigned here to the
# same checkpoint string used for the tokenizer.
model_name = "facebook/mbart-large-50-many-to-many-mmt"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

In [25]:
model

MBartForConditionalGeneration(
  (model): MBartModel(
    (shared): MBartScaledWordEmbedding(250054, 1024, padding_idx=1)
    (encoder): MBartEncoder(
      (embed_tokens): MBartScaledWordEmbedding(250054, 1024, padding_idx=1)
      (embed_positions): MBartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0-11): 12 x MBartEncoderLayer(
          (self_attn): MBartSdpaAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): ReLU()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
    

## Trainer

### Sanity check: generation before training

In [37]:
# Sanity check before fine-tuning: translate the first training example with
# the pretrained checkpoint.
# NOTE: despite its name, `preds_sample` holds the *source* token ids that are
# fed to the model; `labels_sample` holds the reference token ids.
preds_sample = torch.tensor(preprocessed_ds['train'][0]['input_ids']).unsqueeze(0)
labels_sample = torch.tensor(preprocessed_ds['train'][0]['labels']).unsqueeze(0)

# mBART-50 selects the output language through the first generated token, so
# force the Vietnamese language code instead of relying on the model's guess.
preds = model.generate(
    input_ids=preds_sample,
    forced_bos_token_id=tokenizer.convert_tokens_to_ids("vi_VN"),
)
preds

tensor([[     2, 250024, 127055,  66937,     13,     12,  67766,   2546, 218877,
            858,    889,  10037,   6248,   1893,  17964,  42254,      2]])

In [38]:
# Convert the generated ids and the reference ids back to text, using the
# same decoding options for both so the two strings are comparable.
decode_options = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)

decoded_pred = tokenizer.batch_decode(preds, **decode_options)
decoded_label = tokenizer.batch_decode(labels_sample, **decode_options)

print("decoded_pred", decoded_pred)
print("decoded_label", decoded_label)

decoded_pred ['Rachel Pike: Khoa học đằng sau một tiêu đề về khí hậu']
decoded_label ['Khoa học đằng sau một tiêu đề về khí hậu']


### Compute metrics

In [None]:
import numpy as np
import evaluate

# Load the "sacrebleu" metric through `evaluate`; `compute_metrics` below
# calls `metric.compute` on the decoded predictions and references.
metric = evaluate.load("sacrebleu")

def postprocess_text(preds, labels):
    """Strip surrounding whitespace from decoded predictions and references.

    Args:
        preds: Iterable of decoded prediction strings.
        labels: Iterable of decoded reference strings.

    Returns:
        Tuple ``(predictions, references)``: ``predictions`` is a flat list of
        stripped strings, and ``references`` wraps each stripped label in its
        own single-element list (one reference per prediction).
    """
    cleaned_preds = []
    for text in preds:
        cleaned_preds.append(text.strip())

    wrapped_labels = []
    for text in labels:
        wrapped_labels.append([text.strip()])

    return cleaned_preds, wrapped_labels

def compute_metrics(eval_preds):
    """Compute the corpus BLEU score for a batch of evaluation predictions.

    Args:
        eval_preds: Pair ``(predictions, labels)`` of token-id arrays.
            ``predictions`` may also be a tuple whose first element holds the
            generated ids.

    Returns:
        Dict ``{"bleu": score}`` where ``score`` is the ``"score"`` entry
        returned by the sacrebleu metric.
    """
    preds, labels = eval_preds

    # Some models return extra outputs next to the generated ids; keep the ids.
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 is the ignore index used to pad labels/predictions. It is not a
    # valid token id and makes `batch_decode` fail, so map it back to the pad
    # token, which `skip_special_tokens=True` then removes.
    pad_id = tokenizer.pad_token_id
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_pred = tokenizer.batch_decode(
        preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    decoded_label = tokenizer.batch_decode(
        labels, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )

    decoded_pred, decoded_label = postprocess_text(decoded_pred, decoded_label)

    result = metric.compute(predictions=decoded_pred,
                            references=decoded_label)

    return {"bleu": result["score"]}

### Trainer

In [None]:
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, Seq2SeqTrainer
import os

# Disable wandb
os.environ["WANDB_DISABLED"] = "true"


# Evaluate, log and checkpoint every 1000 steps; evaluation runs `generate`
# so that `compute_metrics` receives generated token ids.
training_args = Seq2SeqTrainingArguments(
    output_dir="./mBart50/en-vi-mbart50",
    logging_dir="logs",
    logging_steps=1000,
    predict_with_generate=True,
    eval_strategy="steps",
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    save_total_limit=1,
    num_train_epochs=3,
    load_best_model_at_end=True,
    # Pick the "best" checkpoint by the BLEU score reported by
    # `compute_metrics`; without this the Trainer falls back to eval loss.
    metric_for_best_model="bleu",
    greater_is_better=True,
)

# Batches the tokenized examples; the model is passed so the collator can
# prepare decoder inputs from the labels.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=preprocessed_ds["train"],
    eval_dataset=preprocessed_ds["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,  # You can also use 'processing_class=tokenizer' if needed
    compute_metrics=compute_metrics,
)

## Inference (greedy and beam search)

In [None]:
# Download model
import sacrebleu
from transformers import pipeline

# Download the fine-tuned model from the Hub and wrap it in a pipeline.
translator = pipeline(model="thainq107/en-vi-mbart50")


def _to_texts(outputs):
    """Extract the plain strings from a pipeline result.

    The pipeline yields one single-entry dict per input sentence (possibly
    wrapped in a list), whereas sacrebleu expects a sequence of strings.
    """
    texts = []
    for item in outputs:
        if isinstance(item, list):
            item = item[0]
        if isinstance(item, dict):
            # The key name depends on the pipeline task; take the only value.
            item = next(iter(item.values()))
        texts.append(item)
    return texts


# Materialise the columns as plain lists of strings so they can be passed to
# both the pipeline and sacrebleu.
test_sources = list(ds["test"]["en"])
test_references = list(ds["test"]["vi"])

# Test a sample with beam search
translated_text = translator("I go to school", num_beams=2)
print(translated_text)

# Greedy search for test set (kept in its own variable so the beam-search run
# below does not overwrite it before it is scored).
greedy_sentences = _to_texts(
    translator(test_sources, batch_size=32, num_beams=1, do_sample=False)
)

# Beam search for test set
pred_sentences = _to_texts(translator(test_sources, batch_size=32, num_beams=5))

# Evaluate both decoding strategies against the same references.
greedy_bleu_score = sacrebleu.corpus_bleu(
    greedy_sentences, [test_references], force=True)
print("greedy:", greedy_bleu_score)

bleu_score = sacrebleu.corpus_bleu(
    pred_sentences, [test_references], force=True)
print("beam:", bleu_score)