<a href="https://colab.research.google.com/github/AlaFalaki/tutorial_notebooks/blob/main/translation/hf_bart_translation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-Tune BART for Translation on WMT16 Dataset (and Train new Tokenizer)
This notebook is the supplementary code for the story published on the NLPiation Medium blog. Follow [the link](https://medium.com/@nlpiation/fine-tune-bart-for-translation-on-wmt16-dataset-and-train-new-tokenizer-4d0fbdc4aa2e) for a detailed explanation of how to create a new tokenizer and use it in a translation task.

Run the cells below and experiment with them to build a firm understanding of the concepts. You can get better results by requesting a GPU and adjusting the fine-tuning hyperparameters.

In [1]:
!pip install -q transformers==4.26.1 datasets==2.10.1

# Load Dataset

In [2]:
import datasets

In [3]:
dataset = datasets.load_dataset("stas/wmt16-en-ro-pre-processed", cache_dir="./wmt16-en_ro")

Downloading and preparing dataset wmt16-en-ro-pre-processed/enro (download: 57.26 MiB, generated: 180.62 MiB, post-processed: Unknown size, total: 237.88 MiB) to /content/wmt16-en_ro/stas___wmt16-en-ro-pre-processed/enro/1.1.0/c4093132d2665734cbb5098992e5cdf3cdbd807b80a5913a456ab7cb8c34ab2b...

Dataset wmt16-en-ro-pre-processed downloaded and prepared to /content/wmt16-en_ro/stas___wmt16-en-ro-pre-processed/enro/1.1.0/c4093132d2665734cbb5098992e5cdf3cdbd807b80a5913a456ab7cb8c34ab2b. Subsequent calls will reuse this data.

In [8]:
print(dataset['train'][0])

{'translation': {'en': 'Membership of Parliament: see Minutes', 'ro': 'Componenţa Parlamentului: a se vedea procesul-verbal'}}


In [4]:
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['translation'],
        num_rows: 610320
    })
    validation: Dataset({
        features: ['translation'],
        num_rows: 1999
    })
    test: Dataset({
        features: ['translation'],
        num_rows: 1999
    })
})


In [9]:
def flatten(example):
    # Copy the nested 'translation' pair into top-level 'en' and 'ro' columns
    example['en'] = example['translation']['en']
    example['ro'] = example['translation']['ro']

    return example
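
By default `Dataset.map` calls `flatten` once per example. With `batched=True` it passes a dictionary of lists instead, so the function has to unpack a list of translation pairs. A minimal sketch of that variant, assuming the same `translation` column layout; the name `flatten_batched` and the second demo pair are illustrative and not part of the notebook:

```python
def flatten_batched(batch):
    """Batched counterpart of `flatten`: batch['translation'] is a list of dicts."""
    pairs = batch["translation"]
    return {
        "en": [pair["en"] for pair in pairs],
        "ro": [pair["ro"] for pair in pairs],
    }


# Plain-dict stand-in for what map(..., batched=True) would pass to the function.
demo_batch = {
    "translation": [
        {"en": "Membership of Parliament: see Minutes",
         "ro": "Componenţa Parlamentului: a se vedea procesul-verbal"},
        {"en": "hello", "ro": "salut"},  # made-up pair for illustration
    ]
}
flat = flatten_batched(demo_batch)
print(flat["en"])
```

It would be used as `dataset['train'].map(flatten_batched, batched=True)`, which usually runs faster on a split of this size because the function is called once per batch rather than once per row.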

In [10]:
train = dataset['train'].map(flatten)

In [14]:
print("en => ", train[0]['en'])
print("ro => ", train[0]['ro'])

en =>  Membership of Parliament: see Minutes
ro =>  Componenţa Parlamentului: a se vedea procesul-verbal


In [7]:
test = dataset['test'].map(flatten)
validation = dataset['validation'].map(flatten)

In [9]:
train.save_to_disk("./dataset/train")
test.save_to_disk("./dataset/test")
validation.save_to_disk("./dataset/validation")

# Create Tokenizer

In [15]:
from tokenizers import normalizers, pre_tokenizers, Tokenizer, models, trainers

In [16]:
# Build a BPE tokenizer that lowercases the text and splits it into words and punctuation
bpe_tokenizer = Tokenizer(models.BPE())
bpe_tokenizer.normalizer = normalizers.Lowercase()
bpe_tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()

In [17]:
trainer = trainers.BpeTrainer(
    vocab_size=50265,
    special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"],
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
)

In [18]:
def batch_iterator():
    batch_length = 1000
    for i in range(0, len(train), batch_length):
        yield train[i : i + batch_length]["ro"]
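
`batch_iterator` feeds the trainer slices of 1,000 Romanian sentences at a time instead of one large list. The same chunking pattern, sketched on a plain Python list so it can be run without the dataset; the function name `iter_batches` and the toy corpus are assumptions made for illustration:

```python
def iter_batches(texts, batch_size=1000):
    """Yield consecutive slices of `texts`, each at most `batch_size` items long."""
    for start in range(0, len(texts), batch_size):
        yield texts[start : start + batch_size]


# Toy corpus standing in for the 'ro' column.
corpus = [f"propoziţia {i}" for i in range(2500)]
batch_sizes = [len(batch) for batch in iter_batches(corpus)]
print(batch_sizes)  # [1000, 1000, 500]
```

The last batch is simply shorter when the corpus size is not a multiple of the batch size, which is also what happens with the 610,320 training rows.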

In [21]:
bpe_tokenizer.train_from_iterator(batch_iterator(), length=len(train), trainer=trainer)
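
To give an intuition for what BPE training does, here is a toy, pure-Python sketch of the core loop: count adjacent symbol pairs, merge the most frequent one, and repeat. It is a simplified illustration under assumed inputs (a three-word corpus with made-up frequencies), not the algorithm as implemented in the `tokenizers` library, which is far more optimized and handles special tokens and the initial alphabet.

```python
from collections import Counter


def most_frequent_pair(words):
    """Count adjacent symbol pairs over all words, weighted by word frequency."""
    pairs = Counter()
    for symbols, freq in words.items():
        for left, right in zip(symbols, symbols[1:]):
            pairs[(left, right)] += freq
    return pairs.most_common(1)[0][0] if pairs else None


def merge_pair(words, pair):
    """Rewrite every word so that each occurrence of `pair` becomes one symbol."""
    merged = {}
    for symbols, freq in words.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        key = tuple(out)
        merged[key] = merged.get(key, 0) + freq
    return merged


# Toy corpus (hypothetical frequencies): each word starts as a tuple of characters.
words = {tuple("parlament"): 5, tuple("parlamentului"): 3, tuple("lamentare"): 2}
merges = []
for _ in range(4):  # learn four merges; a real run continues until vocab_size is reached
    pair = most_frequent_pair(words)
    if pair is None:
        break
    merges.append(pair)
    words = merge_pair(words, pair)

print(merges)
print(list(words))
```

Each learned merge adds one entry to the vocabulary, so `vocab_size=50265` roughly bounds how many merges the real trainer performs on top of the initial alphabet and special tokens.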

In [22]:
bpe_tokenizer.save("./ro_tokenizer.json")

In [15]:
# To load the tokenizer later:
# from transformers import PreTrainedTokenizerFast
# tmp = PreTrainedTokenizerFast(tokenizer_file='./ro_tokenizer.json')

# Fine-Tuning

In [16]:
from transformers import BartForConditionalGeneration, AutoTokenizer, PreTrainedTokenizerFast
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import load_from_disk

## Load Model

In [17]:
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

## Load Tokenizers

In [18]:
en_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
# Load the Romanian tokenizer trained above directly from its JSON file
ro_tokenizer = PreTrainedTokenizerFast(tokenizer_file="./ro_tokenizer.json")

In [19]:
# The tokenizer loaded from the JSON file has no pad token registered, so reuse BART's ("<pad>")
ro_tokenizer.pad_token = en_tokenizer.pad_token

# Trainer

In [27]:
train = load_from_disk("./dataset/train")
test = load_from_disk("./dataset/test")
validation = load_from_disk("./dataset/validation")

In [28]:
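def tokenize_dataset(sample):
    # Tokenize the English source and the Romanian target, each padded/truncated to 120 tokens
    model_inputs = en_tokenizer(sample['en'], padding='max_length', max_length=120, truncation=True)
    label = ro_tokenizer(sample['ro'], padding='max_length', max_length=120, truncation=True)

    # The target ids serve both as decoder inputs (passed as-is, not shifted) and as labels
    model_inputs["decoder_input_ids"] = label["input_ids"]
    model_inputs["decoder_attention_mask"] = label["attention_mask"]
    model_inputs["labels"] = label["input_ids"]

    return model_inputs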
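
With `padding='max_length'` and `truncation=True`, every sequence comes out exactly 120 tokens long, and the attention mask records which positions hold real tokens. The labels built by `tokenize_dataset` still contain pad ids, so padded positions contribute to the loss. A common variant replaces those positions with `-100`, the index that PyTorch's cross-entropy loss ignores by default. The sketch below shows both steps in plain Python; the helper names and token ids are illustrative assumptions, not values taken from the notebook's tokenizers.

```python
IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss


def pad_or_truncate(ids, max_length, pad_id):
    """Return (input_ids, attention_mask), both exactly `max_length` long."""
    ids = ids[:max_length]
    n_pad = max_length - len(ids)
    return ids + [pad_id] * n_pad, [1] * len(ids) + [0] * n_pad


def mask_padding(label_ids, pad_id):
    """Replace pad ids in the labels so the loss skips padded positions."""
    return [IGNORE_INDEX if tok == pad_id else tok for tok in label_ids]


PAD_ID = 1  # illustrative pad id
ids, mask = pad_or_truncate([0, 57, 912, 2], max_length=6, pad_id=PAD_ID)
labels = mask_padding(ids, PAD_ID)
print(ids)     # [0, 57, 912, 2, 1, 1]
print(mask)    # [1, 1, 1, 1, 0, 0]
print(labels)  # [0, 57, 912, 2, -100, -100]
```

Applying this idea inside `tokenize_dataset` would change the reported losses, so the numbers logged later in this notebook should not be expected to match.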

In [29]:
# Keep a small subset so the demo runs quickly; remove these lines for full training
train = train.select(range(2000))
test = test.select(range(100))
validation = validation.select(range(100))

In [30]:
train_tokenized = train.map(tokenize_dataset, batched=True)
test_tokenized = test.map(tokenize_dataset, batched=True)
validation_tokenized = validation.map(tokenize_dataset, batched=True)
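
In teacher-forced training the decoder input at position t is normally the target token at position t-1. `tokenize_dataset` passes the target ids unchanged as `decoder_input_ids`; if that key is omitted and only `labels` is supplied, `BartForConditionalGeneration` builds the decoder inputs itself by shifting the labels one position to the right. A plain-Python sketch of that shift follows. The token ids are illustrative, and `decoder_start_id=2` is an assumption about the model configuration that should be checked against `model.config.decoder_start_token_id`.

```python
def shift_right(label_ids, pad_id, decoder_start_id):
    """Prepend the start token, drop the last label, and map -100 back to pad."""
    shifted = [decoder_start_id] + label_ids[:-1]
    return [pad_id if tok == -100 else tok for tok in shifted]


labels = [0, 57, 912, 2, -100, -100]  # illustrative ids; -100 marks ignored padding
decoder_inputs = shift_right(labels, pad_id=1, decoder_start_id=2)
print(decoder_inputs)  # [2, 0, 57, 912, 2, 1]
```

After the shift, the decoder predicts `labels[t]` while only seeing targets up to position t-1, which is what makes the training loss a meaningful measure of translation quality.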

In [34]:
# Set the training arguments. These values are not tuned, so feel free to change them.
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    evaluation_strategy="steps",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    predict_with_generate=True,
    logging_steps=2,  # set to 1000 for full training
    save_steps=64,  # set to 500 for full training
    eval_steps=64,  # set to 8000 for full training
    warmup_steps=1,  # set to 2000 for full training
    max_steps=128,  # delete for full training
    overwrite_output_dir=True,
    save_total_limit=3,
    fp16=False,  # set to True when training on a GPU
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
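
With `evaluation_strategy="steps"`, evaluation runs every `eval_steps` optimization steps, and checkpoints are written every `save_steps` steps. A simplified model of that schedule for the values above, ignoring options such as `logging_first_step`; the helper `scheduled_steps` is a hypothetical name used only for this sketch:

```python
def scheduled_steps(max_steps, every):
    """Steps in 1..max_steps (inclusive) that are multiples of `every`."""
    return list(range(every, max_steps + 1, every))


max_steps = 128
eval_steps = scheduled_steps(max_steps, every=64)  # eval_steps=64 and save_steps=64
log_steps = scheduled_steps(max_steps, every=2)    # logging_steps=2
print(eval_steps)      # [64, 128]
print(len(log_steps))  # 64
```

This matches the run below, which evaluates twice and writes `checkpoint-64` and `checkpoint-128`.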


In [35]:
# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=validation_tokenized,
)

max_steps is given, it will override any value given in num_train_epochs


In [36]:
trainer.train()

The following columns in the training set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: ro, translation, en. If ro, translation, en are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 2000
  Num Epochs = 1
  Instantaneous batch size per device = 2
  Total train batch size (w. parallel, distributed & accumulation) = 2
  Gradient Accumulation steps = 1
  Total optimization steps = 128
  Number of trainable parameters = 139420416


| Step | Training Loss | Validation Loss |
|------|---------------|-----------------|
| 64   | 0.8551        | 1.451837        |
| 128  | 0.2496        | 0.542955        |


The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: ro, translation, en. If ro, translation, en are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 100
  Batch size = 2
Saving model checkpoint to ./checkpoint-64
Configuration saved in ./checkpoint-64/config.json
Configuration saved in ./checkpoint-64/generation_config.json
Model weights saved in ./checkpoint-64/pytorch_model.bin
The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: ro, translation, en. If ro, translation, en are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 100
  Batch size = 2
Saving model checkpoint to ./checkpoint-128
Configuration saved in ./checkpoi

TrainOutput(global_step=128, training_loss=1.7540686773136258, metrics={'train_runtime': 33.0234, 'train_samples_per_second': 7.752, 'train_steps_per_second': 3.876, 'total_flos': 18292093747200.0, 'train_loss': 1.7540686773136258, 'epoch': 0.13})
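
The figures in `TrainOutput` can be cross-checked with a little arithmetic. The sketch below recomputes the epoch fraction and throughput from the values printed in the log, assuming a single device (consistent with the reported total train batch size of 2):

```python
max_steps = 128
per_device_batch = 2
num_devices = 1         # assumption: one device, as the total batch size of 2 suggests
grad_accum = 1          # "Gradient Accumulation steps = 1" in the log
num_examples = 2000
train_runtime = 33.0234  # seconds, from TrainOutput

samples_seen = max_steps * per_device_batch * num_devices * grad_accum
epoch = samples_seen / num_examples

print(samples_seen)                            # 256
print(round(epoch, 2))                         # 0.13
print(round(max_steps / train_runtime, 3))     # 3.876
print(round(samples_seen / train_runtime, 3))  # 7.752
```

So this demo run covers only about 13% of the 2,000-example subset. The reported `training_loss` of about 1.75 is an average over the whole run, which is why it is higher than the losses logged at steps 64 and 128.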