<a href="https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/T5/Fine_tuning_Dutch_T5_base_on_CNN_Daily_Mail_for_summarization_(on_TPU_using_HuggingFace_Accelerate).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Fine-tune T5-base-dutch to perform Dutch abstractive summarization on TPU

In this notebook, we are going to fine-tune a Dutch `T5ForConditionalGeneration` model (namely `t5-base-dutch`), whose weights came out of the [JAX/FLAX community week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104/57) at 🤗, in PyTorch on a Dutch summarization dataset: the Dutch translation of the CNN/Daily Mail dataset. We are going to fine-tune on Colab's TPU using [HuggingFace Accelerate](https://github.com/huggingface/accelerate). For data preparation, we are going to use [HuggingFace Datasets](https://github.com/huggingface/datasets).

Make sure to set Runtime to "TPU" before running this notebook 🤗.  

* T5 paper: https://arxiv.org/abs/1910.10683
* HuggingFace's T5 documentation: https://huggingface.co/transformers/master/model_doc/t5.html

Resources I used to make this notebook:
* Venelin Valkov's awesome YouTube videos, for example [this one](https://www.youtube.com/watch?v=r6XY80Z9eSA)
* The [official HuggingFace Accelerate TPU example](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/accelerate/simple_nlp_example.ipynb)

In [1]:
!pip install -q transformers datasets accelerate sentencepiece
!pip install -q cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl

## Load dataset

Here we load the Dutch translation of the CNN/Daily Mail dataset, created by the Belgian AI company [ML6](https://www.ml6.eu/). It is hosted on HuggingFace's hub, as can be seen [here](https://huggingface.co/datasets/ml6team/cnn_dailymail_nl).

In [2]:
from datasets import load_dataset

# for demonstration purposes, we only download a tiny subset
# (10 training, 5 validation and 2 test examples, matching the outputs below)
train_ds, val_ds, test_ds = load_dataset("ml6team/cnn_dailymail_nl", split=['train[:10]', 'validation[:5]', 'test[:2]'])

Using custom data configuration default
Reusing dataset cnn_dailymail_nl (/root/.cache/huggingface/datasets/cnn_dailymail_nl/default/0.0.0/73618cbc23f25331390bc0475f361e0531b47feee292c9cf84d396d6a3b9b608)


In [3]:
train_ds

Dataset({
    features: ['article', 'highlights', 'id'],
    num_rows: 10
})

Let's look at an example:

In [4]:
example = train_ds[0]
article = example['article']
summary = example['highlights']
print("Article:", article)
print("Summary:", summary)

Article: (CNN) -- de bewering van de Amerikaanse minister van Buitenlandse Zaken John Kerry dat terroristen "maar ze kunnen zich niet verbergen" na twee operaties in Afrika in het weekend is een herinnering dat Amerika's leger is steeds actiever op het continent. Het roept ook vragen op over de internationale wettigheid van dergelijke operaties, en hun lange termijn impact, vooral in zwakke Afrikaanse staten. In sommige gevallen Amerikaanse militaire engagementen in Afrika hebben al meer instabiliteit veroorzaakt in plaats van het verminderen van de risico's voor internationale vrede en veiligheid? Lees meer: Moet de VS vrezen Boko Haram? De Delta-eenheid van het Amerikaanse leger heeft de vermeende al-Qaeda-leider Abu Anas al Libi, die geboren werd nazih Abd al Hamid al Ruqhay, in Libië, is belangrijk voor de inspanningen van de VS tegen terrorisme. Een paar maanden geleden woonden president Barack Obama en voormalig president George W. Bush een herdenkingsdienst bij in Dar es Salaam 

Each example consists of an article and a corresponding summary. Easy, huh? Note that you can train T5 on any text-to-text problem. It could be text as input and an SQL query (as text) as output, it could be text as input and a question related to that text as output (a task called question generation), etc.
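
To make the text-to-text idea concrete, here is a minimal sketch of how different tasks reduce to (input text, target text) pairs. The helper `make_text_pair`, the task prefixes other than "Vat samen: ", and all example strings are made up for illustration; they are not part of the dataset or of T5 itself.

```python
# Hypothetical helper: every task becomes a (source text, target text) pair.
def make_text_pair(task_prefix, source, target):
    """Return the input and target strings for one text-to-text example."""
    return {"input_text": f"{task_prefix}{source}", "target_text": target}

# Illustrative examples only; the prefixes and strings are made up.
pairs = [
    make_text_pair("Vat samen: ", "Een lang nieuwsartikel ...", "Een korte samenvatting."),
    make_text_pair("translate to SQL: ", "How many users signed up in 2021?",
                   "SELECT COUNT(*) FROM users WHERE signup_year = 2021"),
    make_text_pair("generate question: ", "Brussels is the capital of Belgium.",
                   "What is the capital of Belgium?"),
]

for pair in pairs:
    print(pair["input_text"], "->", pair["target_text"])
```

In every case the model only ever sees a string going in and a string coming out, which is why the same training loop works for all of them.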

## Encode the dataset

The T5 model, like any other Transformer model, does not directly expect text as input. Rather, it expects `input_ids` and `attention_mask`. The `input_ids` are integer vocabulary indices of the tokens of the text (you can read more about them [here](https://huggingface.co/transformers/glossary.html#input-ids)). As labels, it expects the `input_ids` of the summary. 
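
As a rough illustration of what `input_ids` and `attention_mask` look like, here is a toy whitespace tokenizer. The vocabulary and the `toy_encode` function are hypothetical and much simpler than the real SentencePiece tokenizer; only the convention that id 0 is padding and id 1 is the end-of-sequence token mirrors the actual model.

```python
# Toy vocabulary (hypothetical); id 0 is the padding token and id 1 the end-of-sequence token.
toy_vocab = {"<pad>": 0, "</s>": 1, "<unk>": 2, "vat": 3, "samen": 4, "de": 5, "belgen": 6, "verloren": 7}

def toy_encode(text, max_length):
    """Map whitespace-separated words to ids, append </s>, then truncate and pad to max_length."""
    ids = [toy_vocab.get(word, toy_vocab["<unk>"]) for word in text.lower().split()]
    ids = ids[: max_length - 1] + [toy_vocab["</s>"]]
    attention_mask = [1] * len(ids)  # 1 = real token
    n_pad = max_length - len(ids)
    return {
        "input_ids": ids + [toy_vocab["<pad>"]] * n_pad,
        "attention_mask": attention_mask + [0] * n_pad,  # 0 = padding, ignored by attention
    }

encoded = toy_encode("vat samen de belgen verloren", max_length=8)
print(encoded["input_ids"])       # [3, 4, 5, 6, 7, 1, 0, 0]
print(encoded["attention_mask"])  # [1, 1, 1, 1, 1, 1, 0, 0]
```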

Let's encode them using the tokenizer. We also prepend a so-called task prefix to the input, as the authors of T5 did when fine-tuning the model. Here, the prefix is simply "Vat samen: " (which is Dutch for "Summarize: "), followed by a long document. Note that this prefix will probably not help a lot, since the Dutch model has not seen it during pre-training. However, it is definitely useful if you want to fine-tune the English T5 model for summarization, as summarization is also part of its pre-training.

As we'll train the model on TPU, we pad both the inputs and targets up to the max length. If we were to train this model on GPUs, we would instead pad them up to the longest in a batch, which is more efficient in terms of memory. However, TPUs work best with fixed tensor shapes: every new input shape triggers a recompilation by XLA, which slows training down considerably.
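
The difference between the two padding strategies can be sketched in plain Python. The functions `pad_to_max_length` and `pad_to_longest` below are illustrative stand-ins for what the tokenizer does with `padding="max_length"` and `padding="longest"`; the token ids are made up.

```python
def pad_to_max_length(batch, max_length, pad_id=0):
    """Pad (and truncate) every sequence to the same fixed length."""
    return [seq[:max_length] + [pad_id] * max(0, max_length - len(seq)) for seq in batch]

def pad_to_longest(batch, pad_id=0):
    """Pad every sequence to the length of the longest one in this batch."""
    longest = max(len(seq) for seq in batch)
    return [seq + [pad_id] * (longest - len(seq)) for seq in batch]

batch_a = [[5, 6, 1], [7, 1]]
batch_b = [[5, 6, 7, 8, 1], [9, 1]]

# Sequence lengths seen by the model across the two batches.
fixed_shapes = {len(pad_to_max_length(b, max_length=8)[0]) for b in (batch_a, batch_b)}
dynamic_shapes = {len(pad_to_longest(b)[0]) for b in (batch_a, batch_b)}

print(sorted(fixed_shapes))    # [8]
print(sorted(dynamic_shapes))  # [3, 5]
```

With fixed-length padding the model always sees a single shape, whereas padding to the longest sequence produces a different shape for almost every batch.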

In [5]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("flax-community/t5-base-dutch")

prefix = "Vat samen: "
max_input_length = 512
max_target_length = 64

def preprocess_examples(examples):
  # encode the documents
  articles = examples['article']
  summaries = examples['highlights']
  
  inputs = [prefix + article for article in articles]
  model_inputs = tokenizer(inputs, max_length=max_input_length, padding="max_length", truncation=True)

  # encode the summaries
  labels = tokenizer(summaries, max_length=max_target_length, padding="max_length", truncation=True).input_ids

  # important: we need to replace the index of the padding tokens by -100
  # such that they are not taken into account by the CrossEntropyLoss
  labels_with_ignore_index = []
  for labels_example in labels:
    labels_example = [label if label != tokenizer.pad_token_id else -100 for label in labels_example]
    labels_with_ignore_index.append(labels_example)
  
  model_inputs["labels"] = labels_with_ignore_index

  return model_inputs

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
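
To see why replacing padding ids by -100 matters, here is a small pure-Python sketch of a cross-entropy loss that skips ignored positions, which is how PyTorch's `CrossEntropyLoss` treats its default `ignore_index` of -100. The functions `mask_padding` and `mean_cross_entropy` and the toy logits are assumptions made for illustration; the real loss is computed inside the model.

```python
import math

IGNORE_INDEX = -100  # default ignore_index of PyTorch's CrossEntropyLoss

def mask_padding(label_ids, pad_token_id=0):
    """Replace padding ids with IGNORE_INDEX, as preprocess_examples does."""
    return [tok if tok != pad_token_id else IGNORE_INDEX for tok in label_ids]

def mean_cross_entropy(logits, labels):
    """Average negative log-likelihood over positions whose label is not IGNORE_INDEX."""
    losses = []
    for position_logits, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue  # padded position: contributes nothing to the loss
        log_normalizer = math.log(sum(math.exp(x) for x in position_logits))
        losses.append(log_normalizer - position_logits[label])
    return sum(losses) / len(losses)

labels = mask_padding([2, 1, 0, 0])          # two real tokens followed by two pads
logits = [[0.0, 0.0, 2.0], [0.0, 2.0, 0.0],  # predictions on the real tokens
          [2.0, 0.0, 0.0], [0.0, 0.0, 2.0]]  # arbitrary predictions on padded positions

print(labels)  # [2, 1, -100, -100]
print(round(mean_cross_entropy(logits, labels), 4))
```

Whatever the model predicts on the padded positions has no effect on the loss, so it is not rewarded for simply learning to predict padding.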


In [6]:
encoded_train_ds = train_ds.map(preprocess_examples, batched=True, remove_columns=train_ds.column_names)
encoded_val_ds = val_ds.map(preprocess_examples, batched=True, remove_columns=val_ds.column_names)
encoded_test_ds = test_ds.map(preprocess_examples, batched=True, remove_columns=test_ds.column_names)

Loading cached processed dataset at /root/.cache/huggingface/datasets/cnn_dailymail_nl/default/0.0.0/73618cbc23f25331390bc0475f361e0531b47feee292c9cf84d396d6a3b9b608/cache-f86ffb3f2743aab0.arrow


Loading cached processed dataset at /root/.cache/huggingface/datasets/cnn_dailymail_nl/default/0.0.0/73618cbc23f25331390bc0475f361e0531b47feee292c9cf84d396d6a3b9b608/cache-7c0388c292d6afd4.arrow


Let's verify an example, by decoding the `input_ids` back to text:

In [7]:
tokenizer.decode(encoded_train_ds[0]['input_ids'])

'Vat samen: (CNN) -- de bewering van de Amerikaanse minister van Buitenlandse Zaken John Kerry dat terroristen "maar ze kunnen zich niet verbergen" na twee operaties in Afrika in het weekend is een herinnering dat Amerika\'s leger is steeds actiever op het continent. Het roept ook vragen op over de internationale wettigheid van dergelijke operaties, en hun lange termijn impact, vooral in zwakke Afrikaanse staten. In sommige gevallen Amerikaanse militaire engagementen in Afrika hebben al meer instabiliteit veroorzaakt in plaats van het verminderen van de risico\'s voor internationale vrede en veiligheid? Lees meer: Moet de VS vrezen Boko Haram? De Delta-eenheid van het Amerikaanse leger heeft de vermeende al-Qaeda-leider Abu Anas al Libi, die geboren werd nazih Abd al Hamid al Ruqhay, in Libië, is belangrijk voor de inspanningen van de VS tegen terrorisme. Een paar maanden geleden woonden president Barack Obama en voormalig president George W. Bush een herdenkingsdienst bij in Dar es Sa

We also decode the `labels` back to text, skipping any that were set to -100:

In [8]:
labels = encoded_train_ds[0]['labels']
print(labels)

[7352, 18, 23388, 1298, 4128, 14, 5, 2460, 8, 10, 760, 604, 7, 6095, 8139, 12, 13, 873, 3, 4, 295, 11, 2355, 11, 3114, 227, 192, 19, 1694, 39, 13, 1312, 15, 246, 4188, 71, 14, 10, 8140, 3, 4, 46, 3, 12869, 8, 2206, 3, 649, 5319, 15, 530, 9, 306, 8, 5, 1041, 3, 7, 222, 8139, 12, 13, 3, 1]


In [9]:
tokenizer.decode([x for x in labels if x != -100])

"Anti-terrorisme beleid leeft op de rand van het internationale recht, Alex Vines schrijft. Amerikaanse invallen in Afrika laten zien dat Amerika's leger is steeds actiever op het continent. Het opbouwen van professionele verantwoordelijk militair is slechts een deel van de oplossing, zegt Vines </s>"

Next, let's set the format to PyTorch.

In [10]:
encoded_train_ds.set_format(type="torch")
encoded_val_ds.set_format(type="torch")
encoded_test_ds.set_format(type="torch")

In [11]:
print("Number of training examples:", len(encoded_train_ds))
print("Number of validation examples:", len(encoded_val_ds))
print("Number of test examples:", len(encoded_test_ds))

Number of training examples: 10
Number of validation examples: 5
Number of test examples: 2


We define a function to create PyTorch dataloaders.

In [12]:
from torch.utils.data import DataLoader

def create_dataloaders(train_batch_size=8, eval_batch_size=32):
    train_dataloader = DataLoader(encoded_train_ds, shuffle=True, batch_size=train_batch_size)
    val_dataloader = DataLoader(encoded_val_ds, shuffle=False, batch_size=eval_batch_size)
    
    return train_dataloader, val_dataloader

## Fine-tune a model 

Below, we define a `training_function`, which defines a regular training loop in native PyTorch. We only need to add a few lines to make sure the code will run on TPU. The Accelerator object will take care of that. Basically, the model as well as the data will be replicated across each of the 8 TPU cores. 

We also define a dictionary of training-related hyperparameters, which we can easily tweak.

In [13]:
hyperparameters = {
    "learning_rate": 0.0001,
    "num_epochs": 1000, # set to a very high number; early stopping ends training
    "train_batch_size": 2, # Actual batch size will be this x 8 (was 8 before but can cause OOM)
    "eval_batch_size": 2, # Actual batch size will be this x 8 (was 32 before but can cause OOM)
    "seed": 42,
    "patience": 3, # early stopping
    "output_dir": "/content/",
}
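
Because each of the 8 TPU cores processes its own batch, the number of examples consumed per optimizer step is the per-core batch size times 8. The sketch below spells out that arithmetic; the helper names are hypothetical, and the step count is only an approximation that assumes the last partial batch is kept.

```python
import math

def effective_batch_size(per_core_batch_size, num_cores=8):
    """Examples consumed per optimizer step when each core processes its own batch."""
    return per_core_batch_size * num_cores

def steps_per_epoch(num_examples, per_core_batch_size, num_cores=8):
    """Rough number of optimizer steps per epoch (assumes the last partial batch is kept)."""
    return math.ceil(num_examples / effective_batch_size(per_core_batch_size, num_cores))

print(effective_batch_size(2))  # 16
print(steps_per_epoch(10, 2))   # 1
```

With 10 training examples and a per-core batch size of 2, this gives a single step per epoch.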

In [14]:
import torch
from transformers import T5ForConditionalGeneration, AdamW, set_seed
from accelerate import Accelerator
from tqdm.notebook import tqdm
import datasets
import transformers

def training_function():
    # Initialize accelerator
    accelerator = Accelerator()

    # To have only one message (and not 8) per logs of Transformers or Datasets, we set the logging verbosity
    # to INFO for the main process only.
    if accelerator.is_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Set the seed before instantiating the model so that runs are reproducible (e.g. data shuffling and dropout).
    set_seed(hyperparameters["seed"])

    # Instantiate the model, let Accelerate handle the device placement.
    model = T5ForConditionalGeneration.from_pretrained("flax-community/t5-base-dutch")

    # Instantiate optimizer
    optimizer = AdamW(model.parameters(), lr=hyperparameters["learning_rate"])

    # Prepare everything
    train_dataloader, val_dataloader = create_dataloaders(
        train_batch_size=hyperparameters["train_batch_size"], eval_batch_size=hyperparameters["eval_batch_size"]
    )
    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the
    # prepare method.
    model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(model, optimizer, 
                                                                             train_dataloader, val_dataloader)
    
    # Now we train the model
    epochs_no_improve = 0
    min_val_loss = 1000000
    for epoch in range(hyperparameters["num_epochs"]):
        # We only enable the progress bar on the main process to avoid having 8 progress bars.
        progress_bar = tqdm(range(len(train_dataloader)), disable=not accelerator.is_main_process)
        progress_bar.set_description(f"Epoch: {epoch}")
        model.train()
        for batch in train_dataloader:
            outputs = model(**batch)
            loss = outputs.loss
            accelerator.backward(loss)
            
            optimizer.step()
            optimizer.zero_grad()
            progress_bar.set_postfix({'loss': loss.item()})
            progress_bar.update(1)

        # Evaluate at the end of the epoch (distributed evaluation as we have 8 TPU cores)
        model.eval()
        validation_losses = []
        for batch in val_dataloader:
            with torch.no_grad():
                outputs = model(**batch)
            loss = outputs.loss

            # We gather the loss from the 8 TPU cores to have them all.
            validation_losses.append(accelerator.gather(loss[None]))

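        # Aggregate the validation loss: sum of all gathered per-core losses divided by the number of batches.
        # Note that this is the per-batch sum over the cores rather than a true mean, which is fine for early stopping.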
        val_loss = torch.stack(validation_losses).sum().item() / len(validation_losses)
        # Use accelerator.print to print only on the main process.
        accelerator.print(f"epoch {epoch}: validation loss:", val_loss)
        if val_loss < min_val_loss:
          epochs_no_improve = 0
          min_val_loss = val_loss
          continue
        else:
          epochs_no_improve += 1
          # Check early stopping condition
          if epochs_no_improve == hyperparameters["patience"]:
            accelerator.print("Early stopping!")
            break

    # save trained model
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    # Use accelerator.save to save
    unwrapped_model.save_pretrained(hyperparameters["output_dir"], save_function=accelerator.save)
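
One detail of the validation step is worth spelling out. The expression `torch.stack(validation_losses).sum().item() / len(validation_losses)` sums the losses gathered from all cores and divides by the number of batches only, so the printed value is about 8 times the mean per-core loss. The plain-Python sketch below reproduces this with made-up numbers; the function names are hypothetical.

```python
def aggregate_as_in_loop(gathered_losses):
    """Mirror `torch.stack(...).sum().item() / len(validation_losses)` from the loop above."""
    total = sum(sum(per_core) for per_core in gathered_losses)
    return total / len(gathered_losses)

def mean_over_cores_and_batches(gathered_losses):
    """Mean over every per-core loss value."""
    values = [loss for per_core in gathered_losses for loss in per_core]
    return sum(values) / len(values)

# Made-up numbers: 2 validation batches, each gathered from 8 cores.
gathered = [[10.0] * 8, [12.0] * 8]

print(aggregate_as_in_loop(gathered))         # 88.0
print(mean_over_cores_and_batches(gathered))  # 11.0
```

Since early stopping only compares the value from one epoch to the next, the constant factor does not change when training stops, but it does matter when comparing the reported loss to other runs.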



Next, we can easily start training by wrapping the `training_function` in a `notebook_launcher`.

In [15]:
from accelerate import notebook_launcher

notebook_launcher(training_function)

Launching a training on 8 TPU cores.


loading configuration file https://huggingface.co/flax-community/t5-base-dutch/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/429863f709b0983f0f5730aa35d1b92a471bc989bd44b2de2a0b4a21cac1a00a.3002a7d60c9ed2cc52b423329227d69c542ddfb8997267a81d19761eac5e8ed6
Model config T5Config {
  "_name_or_path": ".",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "d_ff": 3072,
  "d_kv": 64,
  "d_model": 768,
  "decoder_start_token_id": 0,
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "relu",
  "gradient_checkpointing": false,
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 12,
  "num_heads": 12,
  "num_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_num_buckets": 32,
  "task_specific_params": {
    "summarization": {
      "early_stopping": true,
      "length_penalty": 2.0,
      "max_length": 200

epoch 0: validation loss: 117.00708770751953
epoch 1: validation loss: 102.19902801513672
epoch 2: validation loss: 98.3016357421875
epoch 3: validation loss: 94.3941421508789
epoch 4: validation loss: 88.50342559814453
epoch 5: validation loss: 82.73766326904297
epoch 6: validation loss: 79.30978393554688
epoch 7: validation loss: 80.51023864746094
epoch 8: validation loss: 82.93736267089844
epoch 9: validation loss: 81.85161590576172
Early stopping!


Configuration saved in /content/config.json
Model weights saved in /content/pytorch_model.bin
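
The early-stopping rule used in the training loop can be isolated in a few lines. The sketch below is a standalone re-implementation (the function `early_stopping_epoch` is not part of the notebook's code) that replays the validation losses from the log above, rounded to two decimals.

```python
def early_stopping_epoch(val_losses, patience):
    """Return the epoch index at which training stops, or None if patience is never exhausted."""
    best_loss = float("inf")
    epochs_no_improve = 0
    for epoch, val_loss in enumerate(val_losses):
        if val_loss < best_loss:
            best_loss = val_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve == patience:
                return epoch
    return None

# Validation losses from the log above, rounded to two decimals.
logged_losses = [117.01, 102.20, 98.30, 94.39, 88.50, 82.74, 79.31, 80.51, 82.94, 81.85]
print(early_stopping_epoch(logged_losses, patience=3))  # 9
```

The best loss is reached at epoch 6; epochs 7, 8 and 9 bring no improvement, so with a patience of 3 training stops after epoch 9. Note that the saved model is the one from the last epoch, not the one with the best validation loss.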


## Inference

Now that we have a trained model, let's use it to generate a summary on a new, unseen text.

In [16]:
text = """De Rode Duivels moeten hun droom op een Europese titel opbergen. 
De Belgen verloren vrijdag in München in de kwartfinales van een uitgekookt Italië met 2-1.
België kwam via Barella en Insigne op een dubbele achterstand, maar Lukaku gaf nieuwe hoop 
met een strafschop. De gelijkmaker zat er echter niet meer in.
De Italianen begonnen het beste aan de wedstrijd, al bleven grote kansen uit. Even voor het 
kwartier leek Bonucci uit het niets na een vrije trap met de buik de score te openen, 
maar de videoref keurde de goal af voor buitenspel.
De Belgen moesten het evenwicht herstellen. De Bruyne waagde zijn kans met een afstandschot, dat 
Donnarumma geweldig met de vlakke hand uit doel ranselde. Diezelfde Donnarumma moest 
even nadien Lukaku, na een nieuwe tegenaanval, van de 1-0 houden. De match ging goed op en af - 
de Italianen hadden het meeste balbezit, maar de Belgen loerden onder impuls van De Bruyne 
op de counter."""

trained_model = T5ForConditionalGeneration.from_pretrained(hyperparameters["output_dir"])

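# use the same task prefix and maximum input length as during training
input_ids = tokenizer(prefix + text, max_length=max_input_length, truncation=True, return_tensors="pt").input_ids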
 
generated_ids = trained_model.generate(input_ids, do_sample=True, 
    max_length=50, 
    top_k=0, 
    temperature=0.7
)

summary = tokenizer.decode(generated_ids.squeeze(), skip_special_tokens=True)
print(summary)

over Dat met ook Dat3 jaar het즌4 op het over aan voor op24 heeft- werddr jaar op voor overo een voor kan zich4 Dat op4 d

The generated text is not a meaningful summary yet, which is to be expected: for this demonstration the model was fine-tuned on only 10 training examples. Training on a much larger part of the dataset should give far better results.
