<a href="https://colab.research.google.com/github/NamitMani/NLP-BERT_Variants/blob/main/Summarization_T5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Summarization
## This notebook outlines the concepts behind fine-tuning a summarization model using T5, an encoder-decoder Transformer model

In [52]:
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [1]:
import torch
torch.cuda.empty_cache()

In [2]:
! pip install -q datasets transformers rouge-score nltk


# Fine-tuning a model on a summarization task

In this notebook, we will see how to fine-tune one of the [🤗 Transformers](https://github.com/huggingface/transformers) models for a summarization task. We will use the [XSum dataset](https://arxiv.org/pdf/1808.08745.pdf) (for extreme summarization), which contains BBC articles accompanied by single-sentence summaries.

![Widget inference on a summarization task](https://github.com/huggingface/notebooks/blob/master/examples/images/summarization.png?raw=1)

We will see how to easily load the dataset for this task using 🤗 Datasets and how to fine-tune a model on it using the `Trainer` API.

In [29]:
model_checkpoint = "t5-small"

This notebook is built to run with any model checkpoint from the [Model Hub](https://huggingface.co/models), as long as that model has a sequence-to-sequence version in the Transformers library. Here we picked the [`t5-small`](https://huggingface.co/t5-small) checkpoint.

## Loading the dataset

We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download the data and get the metric we need to use for evaluation (to compare our model to the benchmark). This can be easily done with the functions `load_dataset` and `load_metric`.  

In [30]:
from datasets import load_dataset, load_metric

raw_datasets = load_dataset("xsum")
metric = load_metric("rouge")

Using custom data configuration default
Reusing dataset xsum (/root/.cache/huggingface/datasets/xsum/default/1.2.0/32c23220eadddb1149b16ed2e9430a05293768cfffbdfd151058697d4c11f934)


  0%|          | 0/3 [00:00<?, ?it/s]

In [31]:
from datasets.dataset_dict import DatasetDict
train_dataset = raw_datasets["train"].shuffle(seed=42).select(range(2000))
val_dataset = raw_datasets["validation"].shuffle(seed=42).select(range(1000))
test_dataset = raw_datasets["test"].shuffle(seed=42).select(range(1000))
raw_datasets = DatasetDict({"train":train_dataset,"validation":val_dataset, "test":test_dataset})

Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/xsum/default/1.2.0/32c23220eadddb1149b16ed2e9430a05293768cfffbdfd151058697d4c11f934/cache-cdf5497ad00285d3.arrow
Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/xsum/default/1.2.0/32c23220eadddb1149b16ed2e9430a05293768cfffbdfd151058697d4c11f934/cache-fd6a73959df5c803.arrow
Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/xsum/default/1.2.0/32c23220eadddb1149b16ed2e9430a05293768cfffbdfd151058697d4c11f934/cache-5ac33534bf69a8ea.arrow


The `raw_datasets` object itself is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key for each of the training, validation and test sets:

In [32]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 2000
    })
    validation: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 1000
    })
})

To access an actual element, you need to select a split first, then give an index:

In [33]:
raw_datasets["train"][0]

{'document': 'In Wales, councils are responsible for funding and overseeing schools.\nBut in England, Mr Osborne\'s plan will mean local authorities will cease to have a role in providing education.\nAcademies are directly funded by central government and head teachers have more freedom over admissions and to change the way the school works.\nIt is a significant development in the continued divergence of schools systems on either side of Offa\'s Dyke.\nAnd although the Welsh Government will get extra cash to match the money for English schools to extend the school day, it can spend it on any devolved policy area.\nMinisters have no plans to follow suit.\nAt the moment, governing bodies are responsible for setting school hours and they need ministerial permission to make significant changes.\nThere are already more than 2,000 secondary academies in England and its extension to all state schools is unlikely to shake the Welsh Government\'s attachment to what they call a "community, compr

To get a sense of what the data looks like, the following function will show some examples picked randomly in the dataset.

In [34]:
import datasets
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=5):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))
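
The index-picking loop above amounts to sampling without replacement. As a minimal sketch (the helper name `pick_distinct_indices` and its `seed` parameter are hypothetical, not part of the notebook), the standard library's `random.sample` produces the same kind of result in a single call:

```python
import random

def pick_distinct_indices(dataset_len, num_examples=5, seed=None):
    """Return `num_examples` distinct indices in [0, dataset_len)."""
    if num_examples > dataset_len:
        raise ValueError("Can't pick more elements than there are in the dataset.")
    # A local generator keeps the global random state untouched.
    rng = random.Random(seed)
    return rng.sample(range(dataset_len), num_examples)

picks = pick_distinct_indices(2000, num_examples=5, seed=0)
print(picks)
```

The returned list could be passed to `dataset[picks]` in the same way as the `picks` list built by the loop.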

In [35]:
show_random_elements(raw_datasets["train"])

Unnamed: 0,document,summary,id
0,"The North Tees and Hartlepool NHS Foundation Trust is one of nine in England found to have higher than predicted mortality rates last year.\nThe figures include deaths in hospital, or within 30 days of discharge.\nNorth Tees and Hartlepool NHS Trust medical director, Dr David Emerton, said it was ""reviewing the care of all patients who die"".\nDr Emerton said the figures - higher than expected for the second year running - were skewed by the number of patients treated without being admitted to hospital.\nSome patients, such as those in care homes nearing the end of their lives, where sent to hospital ""when not much really can be done"", he said.\nThe Health and Social Care Information Centre (HSCIC), who compiled the data, says the categorisation does not mean a hospital is failing or unsafe and does not take levels of social deprivation into account.\nIts figures should ""instead should be viewed as a 'smoke alarm' which requires further investigation by the trust"", a spokeswoman said.\nThe SHMI is available from April 2010 and the latest data covers July 2013 to June 2014.\nThe expected risk of the patient dying in hospital, or within 30 days of discharge, is estimated based on the patient's condition, age, sex and how they were admitted to hospital.\nDr Mike Smith, from the Patients Association, said the figures should be viewed in the context of the area's poverty and unemployment rate.\n""If people have, locally, a high mortality and morbidity rate, that means to say people are dying more than the rest of the nation and/or are sicker than the rest of the nation, then you mustn't necessarily think it's the hospital's fault,"" he said.","More people than expected have died at a Teesside hospital trust, latest figures show.",31035720
1,"Sir Kenneth Calman, who was before the House of Lords Constitution Committee, also said he believed such a document could define what it was to be British.\nHe told members that not enough had been done to ""articulate"" the case for the Union.\nIn 2009, the Calman Commission recommended that Scotland take charge of half the income tax it raised.\nHis report, which included input from the three main Unionist parties but not the SNP, came five years before the Scottish independence referendum in September 2014.\nI don't think [the case for the Union] has been well enough articulated - it is still a case of 'them at Westminster'\nSir Kenneth appeared before the committee - which has no SNP members - alongside Sir Paul Silk, former chairman of the Commission on Devolution in Wales.\nCrossbench peer, Lord Judge, asked if there was merit in creating a Charter of the Union.\nSir Kenneth replied: ""The answer is yes. I think it would make the discussion easier about who does what, and how we benefit, how you benefit and how we all benefit. 
I think that would be very helpful.\n""Without that communication, the average citizen is left wondering who does what for me now - and it is Westminster's fault anyway - so, it is easier to put it the wrong way.""\nSir Kenneth was also asked, by Labour peer Lord Morgan, whether such a charter should include suggestions on what it was to be British.\nThe former chief medical officer of Scotland said: ""I think that is a good point, that seems to be part of the [means to an] end, what does it mean to be a citizen of this country and what are our values.\n""I can see that a group could look at a range of these issues, articulating, communicating them and seeing if people want to be part of it.""\nThe former commission chief added: ""I don't think [the case for the Union] has been well enough articulated - it is still a case of 'them at Westminster', as opposed to 'Westminster is here to help you'.\n""We need to show that it [the Union] is worthwhile, not only that Scotland has something worthwhile to contribute to the UK parliament and the UK in general, but that is what the UK is for.""","The man who led a review of Scottish devolution has given his backing to a so-called ""Charter of the Union"".",34591295
2,"Danny Healy-Rae told the Irish Times that issues with the N22 were caused by ""numerous fairy forts in the area"".\nThe road had previously been repaired but problems had reappeared.\nMr Healy-Rae said he shared local belief that ""there was something in these places you shouldn't touch"".\nHe added that the road passed through an area that was rich in fairy folklore and magic.\nThe N22 is the main road between Killarney in County Kerry and Cork.\nIn Irish folklore, it is believed that disturbing areas, said to have strong connections to fairies, could bring bad luck or a curse.\nThese areas include fairy forts, also known as raths or lios, which are the remains of hillforts or ancient circular dwellings, and fairy trees or thorn bushes.\nSome people believe that destroying or tampering with these forts, trees or bushes, could lead to them dying young or becoming seriously ill.\nMr Healy-Rae, an independent TD (Irish member of parliament) for County Kerry, said: ""I have a machine standing in the yard right now. 
And if someone told me to go out and knock a fairy fort or touch it, I would starve first.""\nThe issue was raised at Kerry County Council, where Mr Healy-Rae's daughter, Maura, is a councillor, last week.\nShe told a council meeting that her father was convinced fairies were in the area of the road problems.\nMr Healy-Rae also raised the issue at Kerry County Council in 2007 when he was a councillor, asking if a dip in the N22 near Curraglass was caused by ""fairies at work"".\nThe Irish Times reports that the council's road department replied that it was due to a ""deeper underlying subsoil/geotechnical problem"".\nMr Healy-Rae, whose brother Michael is also a TD, has previously hit the headlines for comments in which he denied any human impact on climate change and said that ""God above"" controlled the weather.","Bad luck caused by disturbed fairy forts is causing dips in a major road between County Kerry and County Cork, an Irish member of parliament has said.",40863737
3,"Northumbria Police said that out of a crowd of 52,000 there were a total of 20 arrests.\nA spokesman said: ""But both sets of fans showed that they can enjoy the passion of this game without the poison that has blighted some of the past meetings between our two clubs.""\nSunday's game at St James' Park ended in a 1-1 draw.\nThe arrests - including four inside the ground - were for offences including drunk and disorderly, breach of the peace, throwing missiles, encroaching on a football pitch, and obstructing police\nCh Supt Steve Neill said the force had worked closely with both football clubs, the local authorities, British Transport Police and the Tyne and Wear Metro.\nHe said: ""We understand how important this fixture is to the people of this region and we recognised our policing plan needed to reflect that.\n""There is a lot of attention on this fixture and we want to thank the footballing community for working with us to make it a derby we can all be proud of.""","Fans attending the Tyne-Wear derby have been praised by police for ""behaving impeccably"" at a ""very tense occasion"".",35860630
4,"A cub escapes deep snow by hitching a ride on its mother's backside in Wapusk National Park, Manitoba, Canada.\nTaken by Daisy Gilardini, from Switzerland, the photo is one of 25 shortlisted for the People's Choice Award in the latest Wildlife Photographer of the Year Competition - on show now at the Natural History Museum in London.\nScroll down to see all 25 images, pre-selected by the museum from almost 50,000 submissions from 95 countries.\nA mother's hand\nAlain Mafart Renodier, France\nAlain Mafart Renodier was on a winter visit to Japan's Jigokudani Snow Monkey Park when he took this photograph of a sleeping baby Japanese macaque, its mother's hand covering its head protectively.\nOpportunistic croc\nBence Mate, Hungary\nAlthough this shot was taken from a safe hide, Bence Mate says it was chilling to see the killing eyes of this 4m (13ft) Nile crocodile. This one had been baited with natural carcasses on an island in the Zimanga Private Game Reserve, South Africa, but crocodiles also come here just to bask in the Sun.\nThe stare of death\nJohan Kloppers, South Africa\nJohan Kloppers saw this little wildebeest shortly after it was born in the Kgalagadi Transfrontier Park, South Africa. Little did he know that he would witness its death later that same day. The small herd of wildebeest walked right past a pride of lions, and the calf was caught by a lioness and then taken by this male lion.\nMonkey ball\nThomas Kokta, Germany\nCold temperatures on Shodoshima Island, Japan, sometimes lead to monkey balls, where a group of five or more snow monkeys huddle together to keep warm. Thomas Kokta climbed a tree to get this image.\nFacing the storm\nGunther Riehle, Germany\nGunther Riehle arrived at the sea-ice in Antarctica in sunshine, but by the evening a storm had picked up - and then came snow. 
He concentrated on taking images of the emperor penguin chicks huddled together to shield themselves.\nGhostly snow geese\nGordon Illg, US\nThese snow geese almost seemed like ghosts in the pink early morning light as they landed among sandhill cranes in the Bosque del Apache National Wildlife Refuge, New Mexico, US.\nSisters\nBernd Wasiolka, Germany\nBernd Wasiolka encountered a large lion pride at a waterhole in the Kgalagadi Transfrontier Park, South Africa. One of the two males spray-marked the branches of a nearby tree. Later two females sniffed the markings and for a brief moment both adopted the same posture.\nInto the fray\nStephen Belcher, New Zealand\nStephen Belcher spent a week photographing golden snub-nosed monkeys in a valley in the Zhouzhi Nature Reserve in the Qinling Mountains, China. The monkeys have very thick fur, which they need to withstand the freezing nights in winter. This image shows two males about to fight, one already up on a rock, the other bounding in with a young male.\nHead-on\nTapio Kaisla, Finland\nTapio Kaisla took a trip to Dovrefjell-Sunndalsfjell National Park, Norway, to find these oxen in their natural habitat. Even though spring is not rutting season for these animals, they were already seriously testing their strength against each other. The air rang out with the loud bang of the head-on collision.\nColorado red\nAnnie Katz, US\nIt was a crisp, clear day in January when Annie Katz saw this Colorado red fox hunting in her neighbour's field in Aspen, Colorado, US. The light was perfect, and she took the photo as the fox approached her, looking right into the lens of her camera.\nThe couple\nSergio Sarta, Italy\nDuring a dive off the coast of Tulamben, Bali, Indonesia, Sergio Sarta saw a bright-coloured organism - a fire urchin with an elegant couple of little Coleman shrimps. 
The fire urchin has quills that are very toxic to humans - the shrimps avoid this danger by seeking out safe areas between the quills.\nJelly starburst\nAndrea Marshall, US\nAndrea Marshall was snorkelling off the coast of Mozambique when she came across hundreds of large jelly-fish. Many were covered with brittle stars - opportunistic riders, taking advantage of this transport system to disperse along the coast. Delicate lighting makes the jelly glow, so the viewer can focus on the subtle colours and textures.\nThe stand-off\nMichael Lambie, Canada\nIt was breeding season and all the male turkeys were putting on a show for the females, but a number of birds seemed a little confused. This one was more concerned with the potential suitor in front of it, not realising it was its own reflection.\nInto the night\nKarine Aigner, US\nDuring the summer months, 20 million Mexican free-tailed bats arrive at Bracken Cave in San Antonio, Texas, US, to give birth and raise their young. Each evening at dusk, the hungry mothers emerge into the night in a vortex, circling out through the entrance and rising into the sky to feed on insects.\nWillow up close\nDavid Maitland, UK\nDavid Maitland photographed the crystallised chemical salicin, which comes from willow tree bark. Salicin forms the basis of the analgesic Aspirin - no doubt this is why some animals seek out willow bark to chew on.\nThe blue trail\nMario Cea, Spain\nThe kingfisher frequented this natural pond every day, and Mario Cea used a high shutter speed with artificial light to photograph it. He used several units of flash for the kingfisher and a continuous light to capture the wake as the bird dived down towards the water.\nEye in focus\nAlly McDowell, US/UK\nAlly McDowell often focuses on colours and patterns underwater - and this is the eye of a parrotfish during a night dive.\nSpiral\nMarco Gargiulo, Italy\nSabella spallanzanii is a species of marine polychaete, also known as a bristle worm. 
The worm secretes mucus that hardens to form a stiff, sandy tube that protrudes from the sand. It has two layers of feeding tentacles that can be retracted into the tube, and one of the layers forms a distinct spiral.\nEye contact\nGuy Edwardes, UK\nThe Dalmatian pelican, seen here on Lake Kerkini, Greece, is the largest species of pelican in the world. It is native to eastern Europe, Russia and Asia. However, its population is currently threatened in some areas from hunting, water pollution and habitat loss, particularly a decline in wetlands.\nConfusion\nRudi Hulshof, South Africa\nRudi Hulshof wanted to capture the uncertainty of the future of the southern white rhino in the Welgevonden Game Reserve, South Africa, because of poaching. He anticipated the moment when these two rhinos would walk past each other, creating this silhouette effect and the illusion of a two-headed rhino.\nTasty delicacy\nCristobal Serrano, Spain\nThe natural world provides countless magical moments, none more so than the delicate moment a tiny, elegant hummingbird softly inserts its slender bill into the corolla of a flower to drink nectar. Cristobal Serrano was lucky enough to capture that exact moment in Los Quetzales National Park, Costa Rica.\nBreakfast time\nCari Hill, New Zealand\nShortly after purchasing the Giraffe Manor in Nairobi, Kenya, the owners learned that the only remaining Rothschild's giraffes in the country were at risk, as their sole habitat was being subdivided into smallholdings. So they began a breeding programme to reintroduce the Rothschild's giraffe into the wild. Today, guests can enjoy visits from resident giraffes in search of a treat.\nCaterpillar curl\nReinhold Schrank, Austria\nReinhold Schrank was at Lake Kerkini, Greece, taking pictures of birds, but the conditions were not ideal, so he looked for other options. He saw this caterpillar on a flower and encouraged it on to a piece of rolled dry straw. 
He had to work fast because the caterpillar was constantly moving.\nRainbow wings\nVictor Tyakht, Russia\nThe bird's wing acts as a diffraction grating - a surface structure with a repeating pattern of ridges or slits. The structure causes the incoming light rays to spread out, bend and split into spectral colours, producing this shimmering rainbow effect.\nVote for the People's Choice Award here before 10 January 2017.\nThe exhibition runs until 10 September 2017.\nTop image: Hitching a ride - by Daisy Gilardini, Switzerland.\nA female polar bear and her cub in Wapusk National Park, Manitoba, Canada.","Wildlife Photographer of the Year is developed and produced by the Natural History Museum, London.",38083691


The metric is an instance of [`datasets.Metric`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Metric):

In [36]:
metric

Metric(name: "rouge", features: {'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')}, usage: """
Calculates average rouge scores for a list of hypotheses and references
Args:
    predictions: list of predictions to score. Each prediction
        should be a string with tokens separated by spaces.
    references: list of reference for each prediction. Each
        reference should be a string with tokens separated by spaces.
    rouge_types: A list of rouge types to calculate.
        Valid names:
        `"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`) where: {n} is the n-gram based scoring,
        `"rougeL"`: Longest common subsequence based scoring.
        `"rougeLSum"`: rougeLsum splits text using `"
"`.
        See details in https://github.com/huggingface/datasets/issues/617
    use_stemmer: Bool indicating whether Porter stemmer should be used to strip word suffixes.
    use_aggregator: Return aggregates if this is set to True
Retu

You can call its `compute` method with your predictions and labels, which need to be lists of decoded strings:

In [37]:
fake_preds = ["hello there", "general kenobi"]
fake_labels = ["hello there", "general kenobi"]
metric.compute(predictions=fake_preds, references=fake_labels)

{'rouge1': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),
 'rouge2': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),
 'rougeL': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),
 'rougeLsum': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0))}
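
To make the `rouge1` numbers above less opaque, here is a simplified sketch of unigram ROUGE for a single prediction/reference pair. It assumes lowercase whitespace tokenization and no stemming, so it is an illustration of the idea rather than the implementation behind `load_metric("rouge")`, which tokenizes differently and also reports low/mid/high aggregates across examples. The function name `rouge1` is chosen here for illustration.

```python
from collections import Counter

def rouge1(prediction, reference):
    """Unigram-overlap precision, recall and F-measure for one pair of strings."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    # Clipped overlap: a token is matched at most as often as it occurs in both texts.
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    precision = overlap / len(pred_tokens) if pred_tokens else 0.0
    recall = overlap / len(ref_tokens) if ref_tokens else 0.0
    if precision + recall:
        fmeasure = 2 * precision * recall / (precision + recall)
    else:
        fmeasure = 0.0
    return {"precision": precision, "recall": recall, "fmeasure": fmeasure}

print(rouge1("hello there", "hello there"))
print(rouge1("the cat sat", "the cat sat on the mat"))
```

In the second call every predicted word appears in the reference (precision 1.0), but the prediction covers only half of the reference words (recall 0.5), which is the typical trade-off for very short summaries.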

## Preprocessing the data

Before we can feed those texts to our model, we need to preprocess them. This is done by a 🤗 Transformers `Tokenizer`, which will (as the name indicates) tokenize the inputs (including converting the tokens to their corresponding IDs in the pretrained vocabulary) and put them in a format the model expects, as well as generate the other inputs that the model requires.

To do all of this, we instantiate our tokenizer with the `AutoTokenizer.from_pretrained` method, which will ensure:

- we get a tokenizer that corresponds to the model architecture we want to use,
- we download the vocabulary used when pretraining this specific checkpoint.

That vocabulary will be cached, so it's not downloaded again the next time we run the cell.

In [38]:
from transformers import AutoTokenizer
    
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

Could not locate the tokenizer configuration file, will try to use the model config instead.
loading configuration file https://huggingface.co/t5-small/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/fe501e8fd6425b8ec93df37767fcce78ce626e34cc5edc859c662350cf712e41.406701565c0afd9899544c1cb8b93185a76f00b31e5ce7f6e18bbaef02241985
Model config T5Config {
  "_name_or_path": "t5-small",
  "architectures": [
    "T5WithLMHeadModel"
  ],
  "d_ff": 2048,
  "d_kv": 64,
  "d_model": 512,
  "decoder_start_token_id": 0,
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "relu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 6,
  "num_heads": 8,
  "num_layers": 6,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "task_specific_params": {
    "summarization": {
      "

By default, the call above will use one of the fast tokenizers (backed by Rust) from the 🤗 Tokenizers library.

You can directly call this tokenizer on one sentence or a pair of sentences:

In [39]:
tokenizer("Hello, this one sentence!")

{'input_ids': [8774, 6, 48, 80, 7142, 55, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}

Depending on the model you selected, you will see different keys in the dictionary returned by the cell above. They don't matter much for what we're doing here (just know they are required by the model we will instantiate later). You can learn more about them in [this tutorial](https://huggingface.co/transformers/preprocessing.html) if you're interested.

Instead of one sentence, we can pass along a list of sentences:

In [40]:
tokenizer(["Hello, this one sentence!", "This is another sentence."])

{'input_ids': [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}

To prepare the targets for our model, we need to tokenize them inside the `as_target_tokenizer` context manager. This will make sure the tokenizer uses the special tokens corresponding to the targets:

In [41]:
with tokenizer.as_target_tokenizer():
    print(tokenizer(["Hello, this one sentence!", "This is another sentence."]))

{'input_ids': [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}


If you are using one of the five T5 checkpoints, the inputs have to be prefixed with "summarize:" (the model can also translate, and it needs the prefix to know which task it has to perform).

In [42]:
if model_checkpoint in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
    prefix = "summarize: "
else:
    prefix = ""
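
As a small sketch of how the prefix ends up in front of every document, the same check can be wrapped in helper functions. The names `T5_CHECKPOINTS`, `task_prefix` and `add_prefix` are hypothetical and introduced only for illustration, and `"some-other-checkpoint"` is a placeholder rather than a real model name.

```python
# Assumption: only the five original T5 checkpoints named above need the task prefix.
T5_CHECKPOINTS = {"t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"}

def task_prefix(checkpoint):
    """Return the summarization prefix for T5 checkpoints, or an empty string otherwise."""
    return "summarize: " if checkpoint in T5_CHECKPOINTS else ""

def add_prefix(documents, checkpoint):
    """Prepend the task prefix (if any) to every document in a list."""
    prefix = task_prefix(checkpoint)
    return [prefix + doc for doc in documents]

docs = ["In Wales, councils are responsible for funding and overseeing schools."]
print(add_prefix(docs, "t5-small"))
print(add_prefix(docs, "some-other-checkpoint"))
```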

We can then write the function that will preprocess our samples. We just feed them to the `tokenizer` with the argument `truncation=True`. This ensures that an input longer than what the selected model can handle is truncated to the maximum length accepted by the model. Padding is dealt with later (in a data collator), so that examples are padded to the longest length in the batch rather than the longest in the whole dataset.

In [43]:
max_input_length = 1024
max_target_length = 128

def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
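
The padding step that is deferred to the data collator can be pictured with a plain-Python sketch. This is not the `DataCollatorForSeq2Seq` implementation; it only illustrates padding to the longest sequence in a batch. The function name `pad_batch` is hypothetical, the pad ID `0` matches `pad_token_id` in the T5 config printed above, and `-100` is assumed as the label padding value because it is the index that PyTorch's cross-entropy loss ignores by default. The label IDs in the toy batch are made up.

```python
def pad_batch(features, pad_token_id=0, label_pad_token_id=-100):
    """Pad a list of {'input_ids', 'labels'} dicts to the longest sequence in the batch."""
    max_inputs = max(len(f["input_ids"]) for f in features)
    max_labels = max(len(f["labels"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_inputs - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        # 1 marks real tokens, 0 marks padding the model should not attend to.
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_pad)
        # Padded label positions get -100 so they do not contribute to the loss.
        n_label_pad = max_labels - len(f["labels"])
        batch["labels"].append(f["labels"] + [label_pad_token_id] * n_label_pad)
    return batch

features = [
    {"input_ids": [8774, 6, 48, 80, 7142, 55, 1], "labels": [8774, 1]},
    {"input_ids": [100, 19, 430, 7142, 5, 1], "labels": [100, 19, 1]},
]
batch = pad_batch(features)
print(batch["input_ids"])
print(batch["attention_mask"])
print(batch["labels"])
```

Because the padding length depends only on the current batch, a batch of short articles costs far less compute than it would if every example were padded to the longest article in the dataset.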

This function works with one or several examples. In the case of several examples, the tokenizer will return a list of lists for each key:

In [44]:
preprocess_function(raw_datasets['train'][:2])

{'input_ids': [[21603, 10, 86, 10256, 6, 6098, 7, 33, 1966, 21, 3135, 11, 12162, 53, 2061, 5, 299, 16, 2789, 6, 1363, 411, 7, 12940, 31, 7, 515, 56, 1243, 415, 5779, 56, 18682, 12, 43, 3, 9, 1075, 16, 1260, 1073, 5, 30358, 7, 33, 1461, 11264, 57, 2069, 789, 11, 819, 3081, 43, 72, 4333, 147, 7209, 7, 11, 12, 483, 8, 194, 8, 496, 930, 5, 94, 19, 3, 9, 1516, 606, 16, 8, 2925, 12355, 122, 1433, 13, 2061, 1002, 30, 893, 596, 13, 4395, 9, 31, 7, 12991, 1050, 5, 275, 2199, 8, 22982, 3141, 56, 129, 996, 1723, 12, 1588, 8, 540, 21, 1566, 2061, 12, 4285, 8, 496, 239, 6, 34, 54, 1492, 34, 30, 136, 20, 4571, 162, 26, 1291, 616, 5, 3271, 7, 43, 150, 1390, 12, 1130, 3237, 5, 486, 8, 798, 6, 3, 19585, 5678, 33, 1966, 21, 1898, 496, 716, 11, 79, 174, 6323, 23, 138, 6059, 12, 143, 1516, 1112, 5, 290, 33, 641, 72, 145, 3, 8630, 6980, 3, 9, 6615, 2720, 7, 16, 2789, 11, 165, 4924, 12, 66, 538, 2061, 19, 9909, 12, 8944, 8, 22982, 3141, 31, 7, 11352, 12, 125, 79, 580, 3, 9, 96, 18782, 485, 6, 3452, 825, 121

To apply this function to all the document/summary pairs in our dataset, we just use the `map` method of the `raw_datasets` object we created earlier. This applies the function to all the elements of all the splits in `raw_datasets`, so our training, validation and test data are preprocessed with a single command.

In [45]:
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

Even better, the results are automatically cached by the 🤗 Datasets library to avoid spending time on this step the next time you run your notebook. The 🤗 Datasets library is normally smart enough to detect when the function you pass to `map` has changed (in which case the cached data must not be reused). For instance, it will properly detect if you edit the preprocessing function and rerun the notebook. 🤗 Datasets warns you when it uses cached files; you can pass `load_from_cache_file=False` in the call to `map` to ignore the cached files and force the preprocessing to be applied again.

Note that we passed `batched=True` to encode the texts in batches. This is to leverage the full benefit of the fast tokenizer we loaded earlier, which will use multi-threading to process the texts in a batch concurrently.
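
To illustrate what `batched=True` changes, the sketch below mimics the calling convention in plain Python: the function receives a dictionary mapping each column name to a list of values instead of a single example. The helper `map_batched` and the stand-in `toy_preprocess` are hypothetical and do not reproduce the library's internals (no caching, no Arrow tables). The default of 1000 examples per batch mirrors the default `batch_size` of `map`, which is consistent with the `0/2` progress bar shown above for the 2000 training examples.

```python
def map_batched(rows, function, batch_size=1000):
    """Apply `function` to column-wise batches of `rows` and collect the output columns."""
    output = {}
    for start in range(0, len(rows), batch_size):
        chunk = rows[start:start + batch_size]
        # A batch is a dict mapping each column name to a list of values.
        batch = {key: [row[key] for row in chunk] for key in chunk[0]}
        for key, values in function(batch).items():
            output.setdefault(key, []).extend(values)
    return output

def toy_preprocess(examples):
    # Stand-in for the real tokenizer: split on whitespace instead of producing token IDs.
    return {"tokens": [("summarize: " + doc).split() for doc in examples["document"]]}

rows = [{"document": f"article number {i}", "summary": f"summary {i}"} for i in range(2000)]
result = map_batched(rows, toy_preprocess)
print(len(result["tokens"]), result["tokens"][0])
```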

## Fine-tuning the model

Now that our data is ready, we can download the pretrained model and fine-tune it. Since our task is of the sequence-to-sequence kind, we use the `AutoModelForSeq2SeqLM` class. Like with the tokenizer, the `from_pretrained` method will download and cache the model for us.

In [46]:
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

loading configuration file https://huggingface.co/t5-small/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/fe501e8fd6425b8ec93df37767fcce78ce626e34cc5edc859c662350cf712e41.406701565c0afd9899544c1cb8b93185a76f00b31e5ce7f6e18bbaef02241985
Model config T5Config {
  "_name_or_path": "t5-small",
  "architectures": [
    "T5WithLMHeadModel"
  ],
  "d_ff": 2048,
  "d_kv": 64,
  "d_model": 512,
  "decoder_start_token_id": 0,
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "relu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 6,
  "num_heads": 8,
  "num_layers": 6,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "task_specific_params": {
    "summarization": {
      "early_stopping": true,
      "length_penalty": 2.0,
      "max_length": 200,
      "min_lengt

Note that we don't get a warning as in our classification example. This means we used all the weights of the pretrained model and there is no randomly initialized head in this case.

To instantiate a `Seq2SeqTrainer`, we will need to define three more things. The most important is the [`Seq2SeqTrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html#transformers.Seq2SeqTrainingArguments), which is a class that contains all the attributes to customize the training. It requires one folder name, which will be used to save the checkpoints of the model, and all other arguments are optional:

In [47]:
batch_size = 4
args = Seq2SeqTrainingArguments(
    "test-summarization",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    predict_with_generate=True,
    fp16=True,
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


Here we run evaluation at the end of each epoch, set the learning rate, use the `batch_size` defined at the top of the cell and set the weight decay. Since the `Seq2SeqTrainer` saves the model regularly and our dataset is quite large, we tell it to keep at most three checkpoints. Lastly, we use the `predict_with_generate` option (so that summaries are produced with the model's `generate` method during evaluation) and activate mixed precision training (to go a bit faster).
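As an illustration of what a limit of three checkpoints means in practice, the following simplified sketch keeps only the most recent checkpoint folders. It is not the `Seq2SeqTrainer` implementation: `rotate_checkpoints` is a hypothetical helper, and the only assumption taken from this notebook is the `checkpoint-<step>` folder naming that appears in the training log.

```python
def rotate_checkpoints(checkpoints, save_total_limit=3):
    """Return (kept, deleted), keeping the checkpoints with the highest step numbers."""
    # Sort by the numeric step, since "checkpoint-1000" < "checkpoint-500" as strings.
    ordered = sorted(checkpoints, key=lambda name: int(name.rsplit("-", 1)[1]))
    if len(ordered) <= save_total_limit:
        return ordered, []
    cut = len(ordered) - save_total_limit
    return ordered[cut:], ordered[:cut]

saved = []
deleted = []
for step in (500, 1000, 1500, 2000):
    saved.append(f"checkpoint-{step}")
    saved, deleted = rotate_checkpoints(saved, save_total_limit=3)

print(saved)
print(deleted)
```

After the fourth save, the oldest folder is the one dropped, so disk usage stays bounded however long training runs.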

Then, we need a special kind of data collator, which will not only pad the inputs to the maximum length in the batch, but also the labels:

In [48]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
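The padding behaviour described above can be sketched in plain Python. This is a simplified stand-in, not the `DataCollatorForSeq2Seq` code: `pad_batch` is a hypothetical helper, the input padding value `0` is the `pad_token_id` from the model config printed earlier, and labels are padded with `-100`, the value the loss ignores (and which `compute_metrics` later has to replace before decoding). The real collator also returns tensors and, when given the model, can prepare decoder inputs, which is omitted here.

```python
PAD_TOKEN_ID = 0      # pad_token_id from the T5 config shown earlier
LABEL_PAD_ID = -100   # label positions with this value are ignored by the loss

def pad_batch(features):
    """Right-pad input_ids and labels to the longest sequence in the batch."""
    max_input = max(len(f["input_ids"]) for f in features)
    max_label = max(len(f["labels"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_input_pad = max_input - len(f["input_ids"])
        n_label_pad = max_label - len(f["labels"])
        batch["input_ids"].append(f["input_ids"] + [PAD_TOKEN_ID] * n_input_pad)
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_input_pad)
        batch["labels"].append(f["labels"] + [LABEL_PAD_ID] * n_label_pad)
    return batch

# Made-up token ids, for illustration only.
features = [
    {"input_ids": [5, 6, 7, 1], "labels": [8, 1]},
    {"input_ids": [9, 1], "labels": [3, 4, 1]},
]
batch = pad_batch(features)
print(batch)
```

Padding to the longest sequence in each batch, rather than to a fixed global length, keeps batches as small as the data allows.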

The last thing to define for our `Seq2SeqTrainer` is how to compute the metrics from the predictions. We need to define a function for this, which will just use the `metric` we loaded earlier, and we have to do a bit of pre-processing to decode the predictions into texts:

In [49]:
import nltk
import numpy as np

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}
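To give an intuition for the `rouge1` value reported by `compute_metrics`, here is a deliberately simplified ROUGE-1 F-measure in plain Python. It is a sketch, not the metric used above: `rouge1_f` is a hypothetical helper that splits on whitespace and lowercases, whereas the real metric applies its own tokenization, optional stemming (`use_stemmer=True`) and aggregation across examples.

```python
from collections import Counter

def rouge1_f(prediction, reference):
    """Unigram-overlap F-measure between a predicted and a reference summary."""
    pred_counts = Counter(prediction.lower().split())
    ref_counts = Counter(reference.lower().split())
    # Clipped overlap: each word counts at most as often as it appears in both texts.
    overlap = sum((pred_counts & ref_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)

score = rouge1_f("the cat sat", "the cat sat on the mat")
print(round(score * 100, 4))
```

In this example every predicted word appears in the reference (precision 1.0) but only half of the reference words are covered (recall 0.5), so the F-measure is 2/3. Multiplying by 100, as `compute_metrics` does, puts the score on the same scale as the training table.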

Then we just need to pass all of this along with our datasets to the `Seq2SeqTrainer`:

In [50]:
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

Using amp half precision backend


We can now fine-tune our model by just calling the `train` method:

In [53]:
trainer.train()

The following columns in the training set don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: document, id, summary. If document, id, summary are not expected by `T5ForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 2000
  Num Epochs = 1
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 500


| Epoch | Training Loss | Validation Loss | Rouge1 | Rouge2 | Rougel | Rougelsum | Gen Len |
|---|---|---|---|---|---|---|---|
| 1 | 3.0584 | 2.828076 | 19.5112 | 3.4724 | 15.2733 | 15.3328 | 18.677 |


Saving model checkpoint to test-summarization/checkpoint-500
Configuration saved in test-summarization/checkpoint-500/config.json
Model weights saved in test-summarization/checkpoint-500/pytorch_model.bin
tokenizer config file saved in test-summarization/checkpoint-500/tokenizer_config.json
Special tokens file saved in test-summarization/checkpoint-500/special_tokens_map.json
The following columns in the evaluation set don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: document, id, summary. If document, id, summary are not expected by `T5ForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 4






Training completed. Do not forget to share your model on huggingface.co/models =)




TrainOutput(global_step=500, training_loss=3.058385986328125, metrics={'train_runtime': 191.1357, 'train_samples_per_second': 10.464, 'train_steps_per_second': 2.616, 'total_flos': 429342259150848.0, 'train_loss': 3.058385986328125, 'epoch': 1.0})
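The figures in the training log can be cross-checked with a little arithmetic. The sketch below uses only numbers printed above (2000 examples, batch size 4, one epoch, 191.1357 seconds). The single-device setting is an assumption, consistent with the reported total train batch size of 4, and the step formula is a simplification that matches the trainer's count when gradient accumulation is 1.

```python
import math

num_examples = 2000
per_device_batch_size = 4
gradient_accumulation_steps = 1
num_epochs = 1
num_devices = 1  # assumption: one GPU, consistent with "Total train batch size = 4"

total_batch_size = per_device_batch_size * num_devices * gradient_accumulation_steps
steps_per_epoch = math.ceil(num_examples / total_batch_size)
total_steps = steps_per_epoch * num_epochs

train_runtime = 191.1357  # seconds, from TrainOutput
samples_per_second = round(num_examples * num_epochs / train_runtime, 3)
steps_per_second = round(total_steps / train_runtime, 3)

print(total_steps, samples_per_second, steps_per_second)
```

The results, 500 steps, 10.464 samples per second and 2.616 steps per second, agree with `Total optimization steps`, `train_samples_per_second` and `train_steps_per_second` in the output.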

In [None]:
# !pip install -q GPUtil

# import torch
# from GPUtil import showUtilization as gpu_usage
# from numba import cuda

# def free_gpu_cache():
#     print("Initial GPU Usage")
#     gpu_usage()                             

#     torch.cuda.empty_cache()

#     cuda.select_device(0)
#     cuda.close()
#     cuda.select_device(0)

#     print("GPU Usage after emptying the cache")
#     gpu_usage()

# free_gpu_cache() 