# T5 Overview
T5 was introduced in the paper [_Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer_](https://arxiv.org/abs/1910.10683). In that paper, the authors give a comprehensive picture of how they pre-trained a standard text-to-text Transformer model on a large text corpus, achieving state-of-the-art results on many NLP tasks after fine-tuning.



## A Shared Text-to-Text Framework

With T5, the authors propose reframing all NLP tasks into a unified text-to-text format where the input and output are always text strings, in contrast to BERT-style models, which can only output either a class label or a span of the input. This text-to-text framework makes it possible to use the same model, loss function, and hyperparameters on any NLP task, including machine translation, document summarization, question answering, and classification tasks (e.g., sentiment analysis). T5 can even be applied to regression tasks by training it to predict the string representation of a number instead of the number itself ([source](https://ai.googleblog.com/2020/02/exploring-transfer-learning-with-t5.html)).

<img src="https://1.bp.blogspot.com/-o4oiOExxq1s/Xk26XPC3haI/AAAAAAAAFU8/NBlvOWB84L0PTYy9TzZBaLf6fwPGJTR0QCLcBGAsYHQ/s1600/image3.gif" width="700" height="300" />

<font color="grey">Diagram of our text-to-text framework. Every task we consider uses text as input to the model, which is trained to generate some target text. This allows us to use the same model, loss function, and hyperparameters across our diverse set of tasks including translation (green), linguistic acceptability (red), sentence similarity (yellow), and **document summarization (blue)**. </font>
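
To make the framing concrete, here is a minimal sketch of how examples from different tasks can be cast into (input text, target text) pairs. The task prefixes mirror the ones shown in the diagram above; the helper name `to_text_to_text` and its keyword arguments are hypothetical, and rounding similarity scores to the nearest 0.2 follows the T5 paper's description of the STS-B regression task.

```python
def to_text_to_text(task, **fields):
    """Cast one example into an (input_text, target_text) pair of strings."""
    if task == "translation":
        source = f"translate English to German: {fields['text']}"
        target = fields["translation"]
    elif task == "acceptability":
        source = f"cola sentence: {fields['text']}"
        target = "acceptable" if fields["label"] == 1 else "not acceptable"
    elif task == "similarity":
        source = (
            f"stsb sentence1: {fields['sentence1']} "
            f"sentence2: {fields['sentence2']}"
        )
        # Regression becomes string prediction: round to the nearest 0.2
        # and emit the number as text.
        target = f"{round(fields['score'] * 5) / 5:.1f}"
    elif task == "summarization":
        source = f"summarize: {fields['text']}"
        target = fields["summary"]
    else:
        raise ValueError(f"unknown task: {task}")
    return source, target


pair = to_text_to_text(
    "similarity",
    sentence1="The rhino grazed on the grass.",
    sentence2="A rhino is grazing in a field.",
    score=3.83,
)
print(pair)
```

Because every task ends up as a pair of strings, a single sequence-to-sequence model and a single loss can serve all of them.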

## mT5: Multilingual T5

Multilingual T5 (mT5) is a massively multilingual pretrained text-to-text
transformer model, trained following a similar recipe as
[T5](https://github.com/google-research/text-to-text-transfer-transformer).


## Languages covered

mT5 is pretrained on the [mC4](https://www.tensorflow.org/datasets/catalog/c4#c4multilingual_nights_stay) corpus, covering 101 languages:

Afrikaans, Albanian, Amharic, Arabic, Armenian, Azerbaijani, Basque,
Belarusian, Bengali, Bulgarian, Burmese, Catalan, Cebuano, Chichewa, Chinese,
Corsican, Czech, Danish, Dutch, English, Esperanto, Estonian, Filipino,
Finnish, French, Galician, Georgian, German, Greek, Gujarati, Haitian Creole,
Hausa, Hawaiian, Hebrew, Hindi, Hmong, Hungarian, Icelandic, Igbo, Indonesian,
Irish, Italian, Japanese, Javanese, Kannada, Kazakh, Khmer, Korean, Kurdish,
Kyrgyz, Lao, Latin, Latvian, Lithuanian, Luxembourgish, Macedonian, Malagasy,
Malay, Malayalam, Maltese, Maori, Marathi, Mongolian, Nepali, Norwegian,
Pashto, Persian, Polish, Portuguese, Punjabi, Romanian, Russian, Samoan,
Scottish Gaelic, Serbian, Shona, Sindhi, Sinhala, Slovak, Slovenian, Somali,
Sotho, Spanish, Sundanese, **Swahili**, Swedish, Tajik, Tamil, Telugu, Thai,
Turkish, Ukrainian, Urdu, Uzbek, Vietnamese, Welsh, West Frisian, Xhosa,
Yiddish, Yoruba, Zulu.

# Tutorial

# Instruction tuning mT5 for Swahili

## Task: Instruct mT5 to summarize Swahili content

We use [**"XL-Sum: Large-Scale Multilingual Abstractive Summarization for 44 Languages"**](https://aclanthology.org/2021.findings-acl.413/) to evaluate summarization performance in Swahili.

In [2]:
!pip install datasets transformers[sentencepiece] evaluate rouge-score nltk -q



In [102]:
import transformers
import datasets
import random
import pandas as pd
import numpy as np
from datasets import load_dataset, concatenate_datasets
from datasets import Dataset
from evaluate import load
import torch
from IPython.display import display, HTML
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
import nltk

# This is required for summarization rouge metric
nltk.download('punkt')
metric = load("rouge")

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [174]:
model_checkpoint = "google/mt5-small"
max_input_length = 1024
max_target_length = 128
per_device_train_batch_size = 2
per_device_eval_batch_size = 16
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [140]:
xlsum_swahili = load_dataset("csebuetnlp/xlsum", "swahili")

In [141]:
def show_samples(dataset, num_examples=5):
    df = pd.DataFrame(dataset[:num_examples])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))

In [143]:
show_samples(xlsum_swahili["validation"], 1)

Unnamed: 0,id,url,title,summary,text
0,51950398,https://www.bbc.com/swahili/51950398,Virusi vya corona: Maswali yako kuhusu ugonjwa wa corona yanajibiwa,"Mlipuko wa maambukizi ya corona umeathiri mataifa mbalimbali duniani, na watu wakiwa wanajiuliza swali kuu moja: Maambukizi ya ugonjwa huu yakoje?","Wahudumu wa afya Ninawezaje kujilinda na namna gani virusi hivi vinasambaa? Maswali kadhaa yamekuwa yakiulizwa na wananchi kutoka maeneo mbalimbali duniani. Je, barakoa(mask)inasaidia kukinga maambukizi ya virusi vya corona Watu wengi wameanza kuvaa barakoa Kuna ushahidi mdogo sana kuwa barakoa inasaidia kwa namna moja au nyingine. Wataalamu wanasema kuwa usafi wa mara kwa mara wa watu kuosha mikono na kutojigusa mdomoni kunasaidia zaidi. Je, virusi vya corona vinaweza kupatikana katika vitasa vya milango na vinaweza kukaa kwa muda gani? Maeneo mengi wameongeza kasi ya kufanya usafi Kama mtu ana maambukizi na akikohoa katika mkono wake na baadae kushika kitu , je kitu hicho kinaweza kupata maambukizi. Vitasa vya mlango ni mfano mzuri zaidi kuwa kuna hatari kubwa ya kupata maambukizi kama mtu mwenye corona akishika mlango wakati alikoholea mkono wake. Wataalamu wanadhani kuwa virusi vya corona vinaweza kukaa kwa muda wa siku kadhaa. Hivyo namna nzuri ya kukabiliana na jambo hili ili kupunguza hatari ya kupata maambukizi ni kunawa mikono mara kwa mara. Je, nikikutana kingono naweza kupata maambukizi? Haijawa wazi kama watu wakikutana kimwili wanakuwa katika hatari ya kupata maambukizi ya corona. Kwa sasa ni kukohoa na kupiga chapya ndio mambo yanayotajwa kuwa hatari katika maambukizi. Kuna utofauti gani kati ya corona na mafua? Dalili za maambukizi ya virusi vya corona na mafua yanafanana kwa kiasi kikubwa, hivyo inafanya tiba kuwa ngumu bila kupimwa . Dalili za virusi vya corona vinaweza kuanza kwa homa na kukohoa. Mafua mara nyingi huwa yana dalili nyingine kama koo kuwasha, huku watu wenye virusi vya corona huwa wanaweza kuishiwa pumzi kidogo. Je, virusi vya corona vinaambukiza zaidi ya mafua? 
Ni mapema mno kuweza kulinganisha lakini virusi vyote vinaambukiza. Kwa wastani virusi vya corona vinaweza kuambukiza watu wawili au watatu huku virusi vya mafua huwa ni kama vinatoka kwa mmoja kwenda kwa mwingine. Ingawa maambukizi yote ya mafua na corona huwa yanasambaa kwa haraka. Je,mtu anaweza kupata virusi vya corona kwa kula chakula kilichoandalia na mtu mwenye maambukizi ya corona? - Mtu mwenye maambukizi ya corona kama amepika bila kuzingatia usafi basi moja kwa moja anaweza kumuambukiza mtu mwingine virusi vya corona. Virusi vya corona vinaweza kusambaa kwa matone ya kikohozi yaliyo kwenye mkono. Kuosha mikono kabla ya kushika na kula chakula ndio ushauri unaotakiwa kuuzingatia. Nichukue tahadhari gani? Kwa watu wanaoishi Italia na maelfu ambao wanasafiri maeneo mbalimbali duniani, wako kwenye hatari ya maambukizi zaidi. Maambukizi yanasambaa kutoka kwa mtu mmoja kwenda kwa mwingine, kwa njia ya matone ya kikohozi. Ni muhimu kwa watu kuzingatia kuosha mikono mara kwa mara kwa maji yanayotiririka na sabuni au dawa ya kuosha mikono(sanitiser). Ni vyema kujizuia kuwa karibu na mtu anayekohoa au mwenye homa. Mtu yeyote anayedhani kuwa amepata maambukizi ya corona ni bora kupigia simu daktari. Baadhi ya mataifa yamepiga marufuku kwa watu kukutana katika mikusanyiko Inawezekana ugonjwa huu kupata chanjo? Kwa sasa hakuna chanjo ya kujikinga na virusi vya corona, ingawa wanasayansi bado wanapambana kutengeneza chanjo ya aina hiyo. Hivi ni virusi vipya ambavyo havijawahi kumpata binadamu kabla. Je, mabadiliko ya tabia nchi yanachangia athari zinazojitokea katika virusi vya corona? Haijawekwa wazi kama mabadiliko ya hali ya hewa yanachangia maambukizi. Baadhi ya virusi kama vya mafua huwa vinakuja mara nyingi katika wakati wa baridi kali. Je, mtu aliyeugua ugonjwa wa corona anaweza kupata maambukizi tena? Virusi hivi vipya vya corona , vinaweza kukusababisha mtu uugue na wengi ni watu wenye shida ya mapafu. 
Lakini virusi hivi vipya hakuna ambaye ana kinga dhidi yake. Hivyo haijalishi kama uliugua mwanzo au la. Aidha shirika la afya duniani limesema kuwa kabla ya miezi 18 kupita chanjo ya corona itakuwa imepatikana."


In [144]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [145]:
def preprocess_xlsum(examples):
    inputs = [f'Summarize the following text:\n{text}'
              for text in examples["text"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    labels = tokenizer(text_target=examples["summary"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
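
The preprocessing step above does two things: it prepends an instruction to each article and truncates both the input and the summary to fixed maximum lengths. The sketch below illustrates that logic in isolation. It uses a toy whitespace tokenizer as a stand-in for the real SentencePiece tokenizer, so the token counts are only illustrative; `toy_tokenize` and `build_example` are hypothetical names.

```python
MAX_INPUT_LENGTH = 1024   # same values as the notebook's configuration cell
MAX_TARGET_LENGTH = 128


def toy_tokenize(text, max_length):
    """Stand-in tokenizer: split on whitespace, then keep at most max_length tokens."""
    return text.split()[:max_length]


def build_example(text, summary):
    """Prefix the article with the instruction and truncate input and target."""
    prompt = f"Summarize the following text:\n{text}"
    return {
        "input_ids": toy_tokenize(prompt, MAX_INPUT_LENGTH),
        "labels": toy_tokenize(summary, MAX_TARGET_LENGTH),
    }


long_article = " ".join(["neno"] * 5000)       # an article far beyond the limit
long_summary = " ".join(["muhtasari"] * 300)   # a summary far beyond the limit
example = build_example(long_article, long_summary)
print(len(example["input_ids"]), len(example["labels"]))
```

Note that truncation keeps the beginning of the article, so any content past the length limit never reaches the model.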

In [147]:
tokenized_xlsum_swahili = xlsum_swahili.map(preprocess_xlsum, batched=True)

Map:   0%|          | 0/7898 [00:00<?, ? examples/s]

Map:   0%|          | 0/987 [00:00<?, ? examples/s]

Map:   0%|          | 0/987 [00:00<?, ? examples/s]

In [None]:
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

In [149]:
inputs = torch.tensor(tokenized_xlsum_swahili['validation'][:1]['input_ids'])
tokenizer.decode(model.cpu().generate(inputs, max_length=max_target_length)[0])

'<pad> <extra_id_0> ya corona.</s>'
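
The `<extra_id_0>` in this output is a sentinel token from the span-corruption pre-training objective. mT5 is pre-trained only on that unsupervised objective, so before fine-tuning it tends to fill in masked spans rather than follow an instruction. The sketch below illustrates the input/target format of span corruption on a toy whitespace-tokenized sentence; the function name and the hand-picked spans are illustrative assumptions, not the actual pre-training pipeline.

```python
def span_corrupt(tokens, spans):
    """Replace each (start, end) span with a sentinel; return (input, target) strings.

    `spans` must be sorted, non-overlapping, half-open index ranges.
    """
    inputs, targets = [], []
    cursor = 0
    for i, (start, end) in enumerate(spans):
        sentinel = f"<extra_id_{i}>"
        inputs.extend(tokens[cursor:start])  # keep the text before the span
        inputs.append(sentinel)              # the span is replaced in the input
        targets.append(sentinel)             # the target announces the sentinel
        targets.extend(tokens[start:end])    # then spells out the dropped span
        cursor = end
    inputs.extend(tokens[cursor:])
    targets.append(f"<extra_id_{len(spans)}>")  # closing sentinel
    return " ".join(inputs), " ".join(targets)


tokens = "Thank you for inviting me to your party last week".split()
corrupted, target = span_corrupt(tokens, [(2, 4), (8, 9)])
print(corrupted)
print(target)
```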

# Can we leverage an existing high-quality instruction dataset for this task?

Note that such datasets are commonly available **only in English**.



# Let's start with Databricks' Dolly-15k dataset

Databricks-dolly-15k is an open source dataset of English instruction-following records generated by thousands of Databricks employees in several of the behavioral categories outlined in the InstructGPT paper, including brainstorming, classification, closed QA, generation, information extraction, open QA, and summarization.

In [189]:
dolly_english = load_dataset("databricks/databricks-dolly-15k")

In [151]:
show_samples(dolly_english["train"], 2)

Unnamed: 0,instruction,context,response,category
0,When did Virgin Australia start operating?,"Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.","Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.",closed_qa
1,Which is a species of fish? Tope or Rope,,Tope,classification


In [156]:
def preprocess_dolly(examples):
    # Build one prompt per example, appending the context only when present
    inputs = []
    for instruction, context in zip(examples["instruction"], examples["context"]):
        # `if context` also covers None, which can appear when a spreadsheet cell is empty
        if context:
            inputs.append(f'{instruction}\nContext: {context}')
        else:
            inputs.append(instruction)

    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    labels = tokenizer(text_target=examples["response"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
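
The prompt construction used for Dolly can be looked at on its own. The following sketch shows the two cases, with and without a context passage, using short stand-in strings. `format_prompt` is a hypothetical helper name, and treating `None` like an empty string is an assumption made so that records with a missing context are handled the same way as records with an empty one.

```python
def format_prompt(instruction, context=None):
    """Append the context on its own line only when one is provided."""
    if context:  # treats both None and "" as "no context"
        return f"{instruction}\nContext: {context}"
    return instruction


with_context = format_prompt(
    "When did Virgin Australia start operating?",
    "It commenced services on 31 August 2000 as Virgin Blue.",
)
without_context = format_prompt("Which is a species of fish? Tope or Rope", "")
print(with_context)
print(without_context)
```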

In [157]:
tokenized_dolly_english = dolly_english.map(preprocess_dolly, batched=True)

Map:   0%|          | 0/15011 [00:00<?, ? examples/s]

Map:   0%|          | 0/12125 [00:00<?, ? examples/s]

# Fine-tuning mT5 by using the English Dolly dataset

In [160]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [162]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    # Note that other metrics may not have a `use_aggregator` parameter
    # and thus will return a list, computing a metric for each sentence.
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True, use_aggregator=True)
    # Convert the scores from fractions to percentages
    result = {key: value * 100 for key, value in result.items()}

    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}
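
For intuition about what the ROUGE scores measure, here is a simplified sketch of ROUGE-1 F1 as unigram overlap between a prediction and a reference. It is an illustration only: the `rouge` metric used above applies its own tokenization and optional stemming and also reports ROUGE-2 and ROUGE-L, so its values will generally differ from this toy version. The function name `rouge1_f1` is hypothetical.

```python
from collections import Counter


def rouge1_f1(prediction, reference):
    """Unigram-overlap F1 between two whitespace-tokenized, lowercased strings."""
    pred_counts = Counter(prediction.lower().split())
    ref_counts = Counter(reference.lower().split())
    # Multiset intersection: each word counts at most as often as in both texts
    overlap = sum((pred_counts & ref_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)


score = rouge1_f1(
    "virusi vya corona vinaenea",
    "virusi vya corona vinaenea duniani",
)
print(round(score * 100, 4))
```

Here the prediction covers four of the five reference words with no extra words, so precision is 1.0 and recall is 0.8.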

In [177]:
args = Seq2SeqTrainingArguments(
    #evaluation_strategy = "steps",
    do_eval=False,
    #eval_steps=10,
    learning_rate=2e-5,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    max_steps=100,
    predict_with_generate=True,
    generation_max_length=128,
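    fp16=True,  # Caution: T5-family checkpoints can overflow in fp16 and yield NaN losses; consider bf16 or fp32 if that happens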
    push_to_hub=False,
)

trainer_en = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_dolly_english["train"],
    eval_dataset=tokenized_xlsum_swahili["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)
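
With `max_steps=100` and a per-device batch size of 2, the training budget in this configuration is small. The arithmetic below estimates how many examples the model sees, assuming a single device and no gradient accumulation (both are assumptions, not stated in the notebook), and compares it with the 15,011 Dolly training examples reported in the map output above.

```python
max_steps = 100
per_device_train_batch_size = 2
num_devices = 1                   # assumption: a single GPU
gradient_accumulation_steps = 1   # assumption: the Trainer default
dolly_train_size = 15011          # from the map output above

examples_seen = (
    max_steps
    * per_device_train_batch_size
    * num_devices
    * gradient_accumulation_steps
)
fraction_of_epoch = examples_seen / dolly_train_size
print(examples_seen, f"{fraction_of_epoch:.2%}")
```

Under these assumptions the run covers only a small fraction of one epoch, which is worth keeping in mind when comparing the three fine-tuning variants.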

In [137]:
# Evaluate the pretrained model
trainer_en.evaluate(max_length=128)

{'eval_loss': nan,
 'eval_rouge1': 0.0,
 'eval_rouge2': 0.0,
 'eval_rougeL': 0.0,
 'eval_rougeLsum': 0.0,
 'eval_gen_len': 0.0,
 'eval_runtime': 268.535,
 'eval_samples_per_second': 3.675,
 'eval_steps_per_second': 0.231}

In [None]:
# fine-tune the model
trainer_en.train()

# Evaluate the fine-tuned model
trainer_en.evaluate(max_length=128)

# Fine-tuning mT5 by using the translated Swahili (Dolly) dataset

In [180]:
dolly_swahili_df = pd.read_excel("translated_ds.xlsx")
dolly_swahili_df.head(2)

Unnamed: 0,task_id,INPUT:context_tr,INPUT:context_src,INPUT:response_tr,INPUT:response_src,INPUT:instruction_tr,INPUT:instruction_src,toloka probabilities
0,000287b55d--656f562fa7ccfa2fa62cbad5,"""I'm So Excited"" ni wimbo wa mwimbaji wa Aust...","""I'm So Excited"" is a song by Australian singe...","""I'm So Excited"" ni wimbo wa mwimbaji wa Austr...","""I'm So Excited"" is a song by Australian singe...",Ni nani mwimbaji wa wimbo wa I'm So Excited?,Who is the singer of the song I'm So Excited?,0.988446
1,000287b55d--656f562fa7ccfa2fa62cbb0b,,,Kupanga safari ya kwenda Ulaya ni sawa na kupa...,Planning a trip to Europe is similar to planni...,"Je, nifanyeje kuhusu kupanga safari ya kwenda...",How should I go about planning a trip to Europe?,0.982769


In [182]:
# Load the translated Dolly samples as dataset
dolly_swahili = Dataset.from_pandas(dolly_swahili_df)

# Modify dataset to make consistent with original dolly
dolly_swahili = dolly_swahili.rename_column("INPUT:context_tr", "context")
dolly_swahili = dolly_swahili.rename_column("INPUT:instruction_tr", "instruction")
dolly_swahili = dolly_swahili.rename_column("INPUT:response_tr", "response")

# Remove unused columns
dolly_swahili = dolly_swahili.remove_columns(
    ["INPUT:context_src", "INPUT:instruction_src", "INPUT:response_src", "toloka probabilities", "task_id"]
)

In [183]:
tokenized_dolly_swahili = dolly_swahili.map(preprocess_dolly, batched=True)

Map:   0%|          | 0/12125 [00:00<?, ? examples/s]

In [None]:
model_2 = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
trainer_swa = Seq2SeqTrainer(
    model_2,
    args,
    train_dataset=tokenized_dolly_swahili,
    eval_dataset=tokenized_xlsum_swahili["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

# fine-tune the model
trainer_swa.train()

# Evaluate the fine-tuned model
trainer_swa.evaluate(max_length=128)

# Fine-tuning mT5 by using the combined English and Swahili dataset

In [154]:
dolly_combined = concatenate_datasets([dolly_english['train'], dolly_swahili]).shuffle(seed=42)
dolly_combined

Dataset({
    features: ['instruction', 'context', 'response', 'category'],
    num_rows: 27136
})
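
The row count matches the sum of the two sources (15,011 English + 12,125 Swahili = 27,136). The translated dataset has no `category` column, and the sample shown further down has an empty category for the Swahili row. The sketch below mimics that outcome with plain Python lists of dicts; it is an illustration of the observed result, not of how the `datasets` library implements concatenation, and `concatenate_rows` is a hypothetical helper.

```python
import random


def concatenate_rows(*datasets):
    """Concatenate lists of dict rows, filling columns missing from a row with None."""
    columns = []
    for rows in datasets:
        for row in rows:
            for name in row:
                if name not in columns:
                    columns.append(name)
    return [
        {name: row.get(name) for name in columns}
        for rows in datasets
        for row in rows
    ]


english = [
    {"instruction": "When did Virgin Australia start operating?", "category": "closed_qa"},
]
swahili = [
    {"instruction": "Ni nani mwimbaji wa wimbo wa I'm So Excited?"},
]

combined = concatenate_rows(english, swahili)
random.Random(42).shuffle(combined)  # fixed seed for a reproducible order
print(combined)
```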

In [186]:
show_samples(dolly_combined, 2)

Unnamed: 0,instruction,context,response,category
0,"Given this paragraph, where is The Walt Disney Company headquarters?","The Walt Disney Company, commonly known as Disney (/ˈdɪzni/), is an American multinational, mass media and entertainment conglomerate that is headquartered at the Walt Disney Studios complex in Burbank, California. Disney was founded on October 16, 1923, by brothers Walt and Roy O. Disney as Disney Brothers Studio; it also operated under the names Walt Disney Studio and Walt Disney Productions before changing its name to The Walt Disney Company in 1986. Early in its existence, the company established itself as a leader in the animation industry, with the creation of the widely popular character Mickey Mouse, who first appeared in Steamboat Willie, which used synchronized sound, to become the first post-produced sound cartoon. The character would go on to become the company's mascot.","According to this text, The Walt Disney Company is headquartered in Burbank, California.",closed_qa
1,Ni kipindi gani cha msimu wa pili cha Game of Thrones ambacho Neil Marshall alielekeza?,,"Neil Marshall aliongoza kipindi cha 9 cha msimu wa pili wa Game of Thrones, kilichoitwa ""Blackwater""",


In [187]:
tokenized_dolly_combined = dolly_combined.map(preprocess_dolly, batched=True)

Map:   0%|          | 0/27136 [00:00<?, ? examples/s]

In [None]:
model_3 = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
trainer_en_swa = Seq2SeqTrainer(
    model_3,
    args,
    train_dataset=tokenized_dolly_combined,
    eval_dataset=tokenized_xlsum_swahili["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

# fine-tune the model
trainer_en_swa.train()

# Evaluate the fine-tuned model
trainer_en_swa.evaluate(max_length=128)