<a href="https://colab.research.google.com/github/Znerual/TensorFlowCodeSnippets/blob/main/examples/summarization-tf.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

If you're opening this notebook on Colab, you will probably need to install 🤗 Transformers and 🤗 Datasets as well as other dependencies. Run the following cell to install them. Note the `rouge-score` and `nltk` dependencies: even if you've used 🤗 Transformers before, you may not have these installed!

In [2]:
! pip install transformers datasets
! pip install rouge-score nltk
! pip install huggingface_hub

If you're opening this notebook locally, make sure your environment has the latest versions of those libraries installed.

To be able to share your model with the community and generate results like the one shown in the picture below via the inference API, there are a few more steps to follow.

First, store your authentication token from the Hugging Face website (sign up [here](https://huggingface.co/join) if you haven't already!). Then run the following cell and enter your token:

In [3]:
from huggingface_hub import notebook_login

notebook_login()

Login successful
Your token has been saved to /root/.huggingface/token
Authenticated through git-credential store but this isn't the helper defined on your machine.
You might have to re-authenticate when pushing to the Hugging Face Hub. Run the following command in your terminal in case you want to set this credential helper as the default

git config --global credential.helper store


Then you need to install Git-LFS and set up Git if you haven't already. Adapt the following commands with your own name and email:

In [4]:
!apt install git-lfs
!git config --global user.email "laurenz_ruzicka@gmx.at"
!git config --global user.name "Laurenz Ruzicka"

Reading package lists... Done
Building dependency tree       
Reading state information... Done
git-lfs is already the newest version (2.3.4-1).
The following package was automatically installed and is no longer required:
  libnvidia-common-460
Use 'apt autoremove' to remove it.
0 upgraded, 0 newly installed, 0 to remove and 20 not upgraded.


Make sure your version of Transformers is at least 4.16.0 since some of the functionality we use was introduced in that version:

In [4]:
import transformers

print(transformers.__version__)

4.21.3
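The version above is checked by eye. Here is a minimal sketch of doing the same check in code. It assumes plain dotted version numbers, and `version_tuple` and `meets_minimum` are hypothetical helpers, not part of 🤗 Transformers. In the notebook you would call `meets_minimum(transformers.__version__)`.

```python
import re


def version_tuple(version: str) -> tuple:
    """Turn a dotted version string such as "4.21.3" into (4, 21, 3).

    Only the leading digits of the first three components are kept, so a
    development tag like "4.22.0.dev0" becomes (4, 22, 0).
    """
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    # Pad short versions such as "4.16" to three components.
    parts += [0] * (3 - len(parts))
    return tuple(parts)


def meets_minimum(version: str, minimum: str = "4.16.0") -> bool:
    # Compare numerically: as raw strings, "4.9.0" >= "4.16.0" would wrongly be True.
    return version_tuple(version) >= version_tuple(minimum)


print(meets_minimum("4.21.3"))  # True
print(meets_minimum("4.9.0"))   # False
```

This simplification treats "4.22.0.dev0" as equal to "4.22.0". If the `packaging` library is installed, `packaging.version.parse` orders pre-release versions more carefully.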


You can find a script version of this notebook to fine-tune your model in a distributed fashion using multiple GPUs or TPUs [here](https://github.com/huggingface/transformers/tree/master/examples/seq2seq).

# Fine-tuning a model on a summarization task

In this notebook, we will see how to fine-tune one of the [🤗 Transformers](https://github.com/huggingface/transformers) models for a summarization task. We will use the [XSum dataset](https://arxiv.org/pdf/1808.08745.pdf) (for extreme summarization), which contains BBC articles accompanied by single-sentence summaries.

![Widget inference on a summarization task](https://github.com/huggingface/notebooks/blob/main/examples/images/summarization.png?raw=1)

We will see how to easily load the dataset for this task using 🤗 Datasets and how to fine-tune a model on it using Keras.

In [5]:
model_checkpoint = "t5-small"

This notebook is built to run with any model checkpoint from the [Model Hub](https://huggingface.co/models) as long as that model has a sequence-to-sequence version in the Transformers library. Here we pick the [`t5-small`](https://huggingface.co/t5-small) checkpoint.

## Loading the dataset

We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download the data and get the metric we need to use for evaluation (to compare our model to the benchmark). This can be easily done with the functions `load_dataset` and `load_metric`.  

In [6]:
from datasets import load_dataset, load_metric

raw_datasets = load_dataset("xsum")
metric = load_metric("rouge")

Downloading and preparing dataset xsum/default (download: 245.38 MiB, generated: 507.60 MiB, post-processed: Unknown size, total: 752.98 MiB) to /root/.cache/huggingface/datasets/xsum/default/1.2.0/32c23220eadddb1149b16ed2e9430a05293768cfffbdfd151058697d4c11f934...


Dataset xsum downloaded and prepared to /root/.cache/huggingface/datasets/xsum/default/1.2.0/32c23220eadddb1149b16ed2e9430a05293768cfffbdfd151058697d4c11f934. Subsequent calls will reuse this data.


The `raw_datasets` object itself is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key each for the training, validation and test sets:

In [7]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11334
    })
})

To access an actual element, you need to select a split first, then give an index:

In [8]:
raw_datasets["train"][0]

 'summary': 'Clean-up operations are continuing across the Scottish Borders and Dumfries and Galloway after flooding caused by Storm Frank.',
 'id': '35232142'}

To get a sense of what the data looks like, the following function shows some examples picked randomly from the dataset.

In [7]:
import datasets
import random
import pandas as pd
from IPython.display import display, HTML


def show_random_elements(dataset, num_examples=5):
    assert num_examples <= len(
        dataset
    ), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset) - 1)
        while pick in picks:
            pick = random.randint(0, len(dataset) - 1)
        picks.append(pick)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))
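`show_random_elements` draws indices in a retry loop until it has `num_examples` distinct ones. As a sketch of an alternative, `random.sample` draws distinct indices in a single call. The helper name `pick_distinct_indices` is hypothetical, and its result could be used as `dataset[picks]` exactly like `picks` in the function above.

```python
import random


def pick_distinct_indices(n_rows: int, num_examples: int = 5, seed=None) -> list:
    """Return `num_examples` distinct row indices drawn uniformly from range(n_rows)."""
    if num_examples > n_rows:
        raise ValueError("Can't pick more elements than there are in the dataset.")
    rng = random.Random(seed)  # a seeded generator makes the picks reproducible
    # random.sample draws without replacement, so no retry loop is needed.
    return rng.sample(range(n_rows), num_examples)


picks = pick_distinct_indices(204045, num_examples=5, seed=0)
print(picks)  # five distinct indices into the training split
```

Raising a `ValueError` instead of using `assert` keeps the check active even when Python runs with `-O`, which strips assertions.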

In [8]:
show_random_elements(raw_datasets["train"])

Unnamed: 0,document,summary,id
0,"More pupils than ever have achieved the literacy and maths scores needed for secondary school, according to figures from the Department for Education\nFour out of five pupils got good grades in all the tests, says the DfE.\nHowever, Schools Minister Nick Gibb said schools in some council areas had performed poorly.\nThe results of this year's tests, taken in May by all 11-year-old state school pupils, show a one percentage point rise in those meeting the standard in mathematics (to 87%) and two percentage points in writing (to 87%).\nThere was a four percentage point rise in scores in the grammar, punctuation and spelling test (to 80%), while attainment in reading was unchanged on the year before, with 89% meeting the expected standard.\nThe government says 80% of pupils achieved the required ""Level Four"" standard or above in all subjects, compared with 78% in 2014 and 62% in 2009.\nBut Mr Gibb said schools in some local authority areas were still not doing well enough.\nHe announced a ""crack-down"" on councils, including Medway, Poole, Luton, Doncaster and Bedford, whose schools had performed poorly.\nIn these areas 73% of pupils achieved the required standard in all subjects, compared with in Kensington and Chelsea, the strongest performing area, where 90% of pupils met the grade.\nMr Gibb said the government was ""committed to driving up standards as a matter of social justice"".\n""That is why I will be writing to the director of children's services and directors of education of councils that are bottom of the league tables and asking that they meet me as a matter of urgency to explain how they intend to improve the teaching of reading and arithmetic in the primary schools under their control,"" he said.\nOverall, Mr Gibb said, he was ""delighted that 90,000 more children are starting secondary school with a firm grasp of the basics compared to just five years ago"".\nIn particular he highlighted improvements in sponsored primary academies, which have 
taken over some of the most seriously underperforming schools.\nSponsored academies that had been open a year saw a rise of five percentage points (to 71%) over the schools they replaced, the statistics suggest.\n""These results vindicate our decision to expand the valuable academies programme into primary schools with thousands of children on course to receive a better education,"" he said.\n""Our reform programme is driven by social justice, and we will continue to raise the bar so young people are prepared to succeed in modern Britain.""\nThis year, some 580,000 primary pupils took the tests, but this is the last year these tests will be used.\nFrom next summer, pupils will be assessed on a ""tough"" new national curriculum, which came into effect in September 2014, and will be given a scaled score where 100 will represent the expected standard.","The performance of children in England in tests at the end of primary school has edged upwards, the government has announced.",34074235
1,"Police will be able to seize valuables worth more than 10,000 kroner (1,340 euros; £1,000) from refugees to cover housing and food costs.\nMPs also approved plans to delay family reunions for asylum seekers.\nA spokesman for UN chief Ban Ki-moon criticised the decision, saying refugees deserved compassion.\n""People who have suffered tremendously, who have escaped war and conflict, who've literally walked hundreds of kilometres if not more and put their lives at risk by crossing the Mediterranean should be treated with compassion and respect, and within their full rights as refugees,"" said Stephane Dujarric.\nThe bill has been widely criticised by human rights groups.\nWhy are countries seizing refugees' valuables?\nMigrants feel chill of tighter borders\nEurope's migrant crisis\nThe prospect of refugees having possessions seized has drawn comparisons to the confiscation of valuables from Jews during World War Two.\nThe government has said that items of sentimental value, such as wedding rings, will be exempt. It also raised the amount refugees will be allowed to keep from 3,000 kroner to 10,000 following objections.\nThe government has said that the policy brings refugees in line with unemployed Danes, who also face having to sell assets above a certain level to claim benefits.\nHowever, critics have said that many Danes have unemployment insurance that saves them having to sell assets, and anyway would not face the kind of searches proposed under the new law.\nThe new measures also mean the period migrants will have to wait before applying for relatives to join them will be extended from one year to three - a move aimed at discouraging new arrivals.\nTemporary residence permits will be shortened and the conditions for obtaining a permanent permit will be restricted.\nDenmark received more than 21,000 asylum seekers in 2015.\nMPs approved the measures by 81 votes to 27 following a lengthy, and at times angry, debate. One MP abstained and 70 others were absent. The centre-left opposition Social Democrats and the anti-immigration Danish People's Party both voted in favour.\nMartin Henriksen, immigration spokesman for the Danish People's Party, described the numbers of migrants entering Europe as an ""exodus"".\n""More needs to be done. We need more border controls, we need tighter immigration rules,"" he said.\nBut Johanne Schmidt-Nielsen, of the opposition left Red-Green Alliance that opposed the bill, said it was ""a symbolic move to scare people away"".\nPrime Minister Lars Lokke Rasmussen of the centre-right Venstre party had previously shrugged off criticism of the proposals calling them ""the most misunderstood bill in Denmark's history"".\nThe UN refugee agency (UNHCR), the European Commission and other groups have criticised the proposals.\nAmnesty International regional director John Dalhuisen described the vote as ""mean spirited"".\n""This is a sad reflection of how far Denmark has strayed from its historic support of international norms enshrined in the Refugee Convention,"" he said.\nAndreas Kamm, of the Danish Refugee Council, said they were concerned about the new limitations on family reunification.\n""It hampers the integration process for those who already arrived and it leaves alone those who are back in the region, as vulnerable groups,"" he said.\n""It's very worrying and it's very inhumane.""\nDenmark is not the first European country to demand the assets of asylum seekers.\nEarlier this month, Switzerland was criticised by a refugee group for seizing assets from some 100 people in 2015. Under Swiss rules, asylum seekers have to hand over assets above $1,000 (£700; €900).",The Danish parliament has backed a controversial proposal to confiscate asylum seekers' valuables to pay for their upkeep.,35406436
2,"The US software giant paid $7.2bn (£5.5bn) for Nokia's handset business in 2014, but failed to make a success of new devices.\nIn May, Windows-powered smartphones accounted for fewer than 1% of global smartphone sales.\nOne industry analyst suggested the firm was too late to the market.\n""They spent all that money because they recognised that the smartphone market was important,"" said Eddie Murphy, telecoms analyst at Priory Consulting.\n""They were right - but just too late. Apple and Android devices have dominated the space and Windows hasn't made an impact.""\nOne of the problems Microsoft faced was the so-called ""app gap""- a shortage of popular titles appearing on its smartphones. The problem also blighted Blackberry's BB10 operating system.\n""It was a tremendous problem,"" said Mr Murphy. ""I have a lot of sympathy because I have a Windows phone and the number of apps is very small in comparison to Android. That was a real disadvantage for Microsoft.""\nThe job losses were initially announced in May as part of a plan to close 1,850 posts in Microsoft's smartphone business.\nThe firm's latest operating system Windows 10 can still be used to power smartphones and in February computing giant HP announced a smartphone running Windows 10 that could transform into a desktop PC.\n""I think they have looked at what Google did with Android. The dominance of the Android platform is because it is open and other companies can use it,"" said Mr Murphy.\n""One of the things Microsoft has done is introduce the common Windows 10 platform, that allows apps to work across desktop, tablet and mobile. It's a good idea and hopefully one that will generate some new apps for the platform.\n""I hope we haven't seen the last of Windows Phones. Having more platforms in the market is in the best interests of consumers.""","Microsoft has confirmed it will close its mobile phone unit in Finland, cutting 1,350 jobs.",36763904
3,"Jonas Knudsen, who netted in the reverse fixture in August, met Jordan Spence's second-half cross to put Ipswich ahead against the run of play.\nBut winger Murphy fired a 20-yard shot past Bartosz Bialkowski at his near post to level with 21 minutes left.\nBialkowski later made amends with a stunning stop to deny Alex Tettey.\nNorwich moved up a place to eighth but Alex Neil's side are six points off the Championship play-off places, while Ipswich remained 15th with their fourth draw in five outings.\nBoth derbies this season have ended as 1-1 draws, while Norwich have still not been beaten by their local rivals since April 2009.\nReferee Oliver Langford had been the centre of attention before the two goals, refusing to give Ipswich a penalty when David McGoldrick appeared to have been dragged to the ground at a corner and then disallowing a goal from Norwich full-back Mitchell Dijks for handball.\nThe Canaries pushed for a winner late on but Bialkowski kept out Tettey's powerful 18-yard drive and also a low shot from substitute Alex Pritchard.\nNorwich boss Alex Neil told BBC Radio Norfolk:\n""I think if you look at pretty much every statistic connected with the game or if you watched the game, I think we were deserving of a win.\n""I think the frustrating thing from us was the fact that we've been very good going forward and scoring goals and we just lacked that little bit of a cutting edge that we've had in most games this season.\n""Even when we did threaten to score, their keeper pulled off probably three really good saves to keep the game level.""\nIpswich boss Mick McCarthy told BBC Radio Suffolk:\n""They've had to be really well-organised and resolute and tough and all those horrible things that people don't like when you're getting beat, but when you go to Norwich and get a point everybody respects those qualities.\n""(Bartosz Bialkowski) kept them out, he's been brilliant, and he's annoyed that he's let one in - he was just outstanding.\n""What was 
good is that we didn't concede again, and maybe eight weeks ago we'd have lost that game.\n""Against a good team, it's just great that we came away with a point.""\nMatch ends, Norwich City 1, Ipswich Town 1.\nSecond Half ends, Norwich City 1, Ipswich Town 1.\nAttempt missed. Alex Pritchard (Norwich City) left footed shot from outside the box is high and wide to the left following a corner.\nCorner, Norwich City. Conceded by Tommy Smith.\nAttempt missed. Timm Klose (Norwich City) right footed shot from the centre of the box misses to the left. Assisted by Alex Pritchard with a cross following a corner.\nCorner, Norwich City. Conceded by Tommy Smith.\nOffside, Norwich City. Alex Pritchard tries a through ball, but Cameron Jerome is caught offside.\nAttempt saved. Mitchell Dijks (Norwich City) left footed shot from outside the box is saved in the centre of the goal. Assisted by Alexander Tettey.\nJordan Spence (Ipswich Town) is shown the yellow card for a bad foul.\nJosh Murphy (Norwich City) wins a free kick on the left wing.\nFoul by Jordan Spence (Ipswich Town).\nAttempt missed. Cameron Jerome (Norwich City) header from the right side of the six yard box is high and wide to the right. Assisted by Jacob Murphy with a cross.\nCorner, Norwich City. Conceded by Bartosz Bialkowski.\nAttempt saved. Alex Pritchard (Norwich City) right footed shot from the centre of the box is saved in the bottom left corner. Assisted by Mitchell Dijks.\nAttempt missed. Josh Murphy (Norwich City) right footed shot from the left side of the box is high and wide to the right. Assisted by Alex Pritchard.\nSubstitution, Ipswich Town. Toumani Diagouraga replaces Emyr Huws.\nSubstitution, Norwich City. Alex Pritchard replaces Wes Hoolahan.\nOffside, Norwich City. Josh Murphy tries a through ball, but Cameron Jerome is caught offside.\nSubstitution, Norwich City. Josh Murphy replaces Steven Naismith.\nAttempt missed. Timm Klose (Norwich City) header from the centre of the box is too high. 
Assisted by Wes Hoolahan with a cross following a corner.\nCorner, Norwich City. Conceded by Bartosz Bialkowski.\nAttempt saved. Alexander Tettey (Norwich City) right footed shot from the centre of the box is saved in the top centre of the goal.\nAttempt blocked. Alexander Tettey (Norwich City) right footed shot from outside the box is blocked. Assisted by Wes Hoolahan.\nSubstitution, Ipswich Town. Kieffer Moore replaces Freddie Sears.\nAttempt missed. Cameron Jerome (Norwich City) header from the centre of the box misses to the left. Assisted by Mitchell Dijks with a cross.\nDelay over. They are ready to continue.\nDelay in match Jonny Howson (Norwich City) because of an injury.\nGoal! Norwich City 1, Ipswich Town 1. Jacob Murphy (Norwich City) right footed shot from the right side of the box to the bottom right corner. Assisted by Wes Hoolahan.\nAttempt saved. Steven Naismith (Norwich City) left footed shot from the left side of the box is saved in the centre of the goal. Assisted by Jonny Howson.\nJordan Spence (Ipswich Town) wins a free kick on the right wing.\nFoul by Alexander Tettey (Norwich City).\nGoal! Norwich City 0, Ipswich Town 1. Jonas Knudsen (Ipswich Town) header from the centre of the box to the bottom right corner. Assisted by Jordan Spence with a cross.\nAttempt blocked. Freddie Sears (Ipswich Town) right footed shot from the centre of the box is blocked.\nAttempt saved. Cameron Jerome (Norwich City) header from the centre of the box is saved in the centre of the goal. Assisted by Wes Hoolahan.\nFreddie Sears (Ipswich Town) is shown the yellow card for a bad foul.\nIvo Pinto (Norwich City) wins a free kick in the attacking half.\nFoul by Freddie Sears (Ipswich Town).\nSubstitution, Ipswich Town. 
Tommy Smith replaces Myles Kenlock.\nFoul by Jonny Howson (Norwich City).\nCole Skuse (Ipswich Town) wins a free kick in the defensive half.",Jacob Murphy's equaliser for Norwich City at Carrow Road extended Ipswich Town's wait of almost eight years for a victory in an East Anglian derby.,39018806
4,"It has ordered that the country's Hindu marriage act should be altered to allow irretrievable breakdown of marriage as grounds for divorce.\nUp until now, a divorce would in most cases be granted by the courts only if there were mutual consent.\nCorrespondents say that marriage breakdowns are becoming more common and India's divorce rate is increasing.\nMinister of information Ambika Soni said that the proposed change in the law would help an estranged partner get a divorce ""if any party does not come to court or wilfully avoids the court"".\nLast year the Supreme Court said the judiciary should strive to keep married people together, but it also ruled that couples who had completely split should not be denied a divorce.\nThe latest proposed amendment, passed by a cabinet meeting chaired by Prime Minister Manmohan Singh, will include irretrievable breakdown of marriage as a legal justification for divorce for the first time.\n""In today's day and age it may be a welcome step but it will only really help urban women,"" Kamini Jaiswal, a Supreme Court advocate, told the AFP news agency.\n""Rural women will still get a raw deal as they are more oppressed by their husbands.\n""Divorce is definitely more socially acceptable in urban India,"" she said. ""I have seen a rapid rise in divorces, but in order to obtain a divorce it can take anywhere from six months to 20 years.""\nOfficial figures on the divorce rate are unavailable but experts say that roughly 11 Indian marriages in every 1,000 end in divorce. The rate in the United States is about 400 in every 1,000.",The Indian government has proposed a new law which will make it easier for couples to get divorced.,10284416


The metric is an instance of [`datasets.Metric`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Metric):

In [11]:
metric

Metric(name: "rouge", features: {'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')}, usage: """
Calculates average rouge scores for a list of hypotheses and references
Args:
    predictions: list of predictions to score. Each prediction
        should be a string with tokens separated by spaces.
    references: list of reference for each prediction. Each
        reference should be a string with tokens separated by spaces.
    rouge_types: A list of rouge types to calculate.
        Valid names:
        `"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`) where: {n} is the n-gram based scoring,
        `"rougeL"`: Longest common subsequence based scoring.
        `"rougeLSum"`: rougeLsum splits text using `"\n"`.
        See details in https://github.com/huggingface/datasets/issues/617
    use_stemmer: Bool indicating whether Porter stemmer should be used to strip word suffixes.
    use_aggregator: Return aggregates if this is set to True
Retu
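To get a feel for what the `rouge{n}` and `rougeL` types listed in the docstring measure, here is a simplified sketch. It assumes whitespace tokenization and lowercasing only (no stemming and no `rougeLsum` sentence splitting), so its numbers will not exactly match the `rouge_score` package that backs this metric. The names `rouge_n` and `rouge_l` are illustrative.

```python
from collections import Counter


def _f_measure(overlap: int, pred_total: int, ref_total: int) -> float:
    """Harmonic mean of precision and recall; 0.0 if there is no overlap."""
    if overlap == 0:
        return 0.0
    precision = overlap / pred_total
    recall = overlap / ref_total
    return 2 * precision * recall / (precision + recall)


def _ngrams(tokens: list, n: int) -> Counter:
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def rouge_n(prediction: str, reference: str, n: int = 1) -> float:
    """ROUGE-N F-measure: clipped n-gram overlap between prediction and reference."""
    pred = _ngrams(prediction.lower().split(), n)
    ref = _ngrams(reference.lower().split(), n)
    overlap = sum((pred & ref).values())  # & keeps the minimum count of each n-gram
    return _f_measure(overlap, sum(pred.values()), sum(ref.values()))


def rouge_l(prediction: str, reference: str) -> float:
    """ROUGE-L F-measure: based on the longest common subsequence (LCS) of tokens."""
    a = prediction.lower().split()
    b = reference.lower().split()
    # dp[i][j] holds the LCS length of a[:i] and b[:j].
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, tok_a in enumerate(a, start=1):
        for j, tok_b in enumerate(b, start=1):
            if tok_a == tok_b:
                dp[i][j] = dp[i - 1][j - 1] + 1
            else:
                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
    return _f_measure(dp[len(a)][len(b)], len(a), len(b))


pred = "the cat sat on the mat"
ref = "the cat lay on the mat"
print(rouge_n(pred, ref, n=1), rouge_n(pred, ref, n=2), rouge_l(pred, ref))
```

For this pair, 5 of the 6 unigrams and 3 of the 5 bigrams match, and the longest common subsequence ("the cat on the mat") has 5 tokens, so the three scores are 5/6, 3/5 and 5/6.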

You can call its `compute` method with your predictions and labels, which need to be lists of decoded strings:

In [12]:
fake_preds = ["hello there", "general kenobi"]
fake_labels = ["hello there", "general kenobi"]
metric.compute(predictions=fake_preds, references=fake_labels)

{'rouge1': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),
 'rouge2': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),
 'rougeL': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),
 'rougeLsum': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0))}
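The `low`, `mid` and `high` fields of each `AggregateScore` summarize a bootstrap: the per-example scores are resampled with replacement many times, and percentiles of the resampled means are reported. The sketch below is a simple percentile bootstrap under assumed settings (1,000 resamples, a 95% interval). It is not the `rouge_score` aggregator's exact code, and `bootstrap_aggregate` is a hypothetical name. With two perfect predictions every resampled mean is 1.0, which is why all three fields equal 1.0 above.

```python
import random


def bootstrap_aggregate(scores, n_samples=1000, confidence=0.95, seed=0):
    """Percentile-bootstrap (low, mid, high) estimates of the mean of per-example scores."""
    if not scores:
        raise ValueError("need at least one score")
    rng = random.Random(seed)
    means = []
    for _ in range(n_samples):
        # Resample the per-example scores with replacement and record the mean.
        resample = [rng.choice(scores) for _ in scores]
        means.append(sum(resample) / len(resample))
    means.sort()
    tail = (1 - confidence) / 2
    low = means[int(tail * (n_samples - 1))]
    mid = means[int(0.5 * (n_samples - 1))]
    high = means[int((1 - tail) * (n_samples - 1))]
    return low, mid, high


print(bootstrap_aggregate([1.0, 1.0]))  # (1.0, 1.0, 1.0), as in the output above
```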

## Preprocessing the data

Before we can feed those texts to our model, we need to preprocess them. This is done by a 🤗 Transformers `Tokenizer`, which (as the name indicates) tokenizes the inputs, converts the tokens to their corresponding IDs in the pretrained vocabulary, puts them in the format the model expects and generates the other inputs that the model requires.

To do all of this, we instantiate our tokenizer with the `AutoTokenizer.from_pretrained` method, which will ensure:

- we get a tokenizer that corresponds to the model architecture we want to use,
- we download the vocabulary used when pretraining this specific checkpoint.

That vocabulary will be cached, so it's not downloaded again the next time we run the cell.

In [9]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-small automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


By default, the call above will use one of the fast tokenizers (backed by Rust) from the 🤗 Tokenizers library.

You can directly call this tokenizer on one sentence or a pair of sentences:

In [14]:
tokenizer("Hello, this is a sentence!")

{'input_ids': [8774, 6, 48, 19, 3, 9, 7142, 55, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1]}

Depending on the model you selected, you will see different keys in the dictionary returned by the cell above. They don't matter much for what we're doing here (just know they are required by the model we will instantiate later); you can learn more about them in [this tutorial](https://huggingface.co/transformers/preprocessing.html) if you're interested.

Instead of one sentence, we can pass along a list of sentences:

In [15]:
tokenizer(["Hello, this is a sentence!", "This is another sentence."])

{'input_ids': [[8774, 6, 48, 19, 3, 9, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}
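The dictionaries above pair each sequence of token IDs with an attention mask, and every sequence ends with ID `1`, T5's end-of-sequence token. A toy word-level tokenizer with a hypothetical vocabulary reproduces the same shape of output, including the list-of-lists form for a batch. This is only an illustration: real T5 tokenization uses a SentencePiece model rather than whitespace splitting, and the unknown-token ID of `2` is an assumption.

```python
def toy_tokenize(texts, vocab, eos_id=1, unk_id=2):
    """Map a string (or a list of strings) to input IDs and an attention mask.

    A single string gives flat lists; a list of strings gives lists of lists.
    """
    single = isinstance(texts, str)
    batch = [texts] if single else list(texts)
    input_ids = []
    for text in batch:
        # Whitespace split stands in for SentencePiece; unknown words get unk_id.
        ids = [vocab.get(word, unk_id) for word in text.split()]
        ids.append(eos_id)  # every sequence ends with the end-of-sequence ID
        input_ids.append(ids)
    # Nothing is padded yet, so every position is a real token and the mask is all ones.
    attention_mask = [[1] * len(ids) for ids in input_ids]
    if single:
        return {"input_ids": input_ids[0], "attention_mask": attention_mask[0]}
    return {"input_ids": input_ids, "attention_mask": attention_mask}


toy_vocab = {"hello": 10, "there": 11}  # hypothetical word-to-ID table
print(toy_tokenize("hello there", toy_vocab))
# {'input_ids': [10, 11, 1], 'attention_mask': [1, 1, 1]}
print(toy_tokenize(["hello", "hello world there"], toy_vocab))
# {'input_ids': [[10, 1], [10, 2, 11, 1]], 'attention_mask': [[1, 1], [1, 1, 1, 1]]}
```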

To prepare the targets for our model, we need to tokenize them inside the `as_target_tokenizer` context manager. This will make sure the tokenizer uses the special tokens corresponding to the targets:

In [16]:
with tokenizer.as_target_tokenizer():
    print(tokenizer(["Hello, this is a sentence!", "This is another sentence."]))

{'input_ids': [[8774, 6, 48, 19, 3, 9, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}


If you are using one of the five T5 checkpoints, we have to prefix the inputs with "summarize:" (the model can also translate, and it needs the prefix to know which task it has to perform).

In [10]:
if model_checkpoint in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
    prefix = "summarize: "
else:
    prefix = ""

We can then write the function that will preprocess our samples. We just feed them to the `tokenizer` with the argument `truncation=True`. This ensures that any input or target longer than the `max_length` we pass is truncated to that length. The padding will be dealt with later on (in a data collator), so that we pad examples to the longest length in the batch rather than in the whole dataset.
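A minimal pure-Python sketch of the pad-later step, assuming T5's padding ID of `0` and `-100` as the label padding value (the default `label_pad_token_id` of `DataCollatorForSeq2Seq`), so that padded label positions can be skipped by the loss. `pad_batch` is hypothetical and leaves out things the real collator handles, such as returning tensors and preparing decoder inputs.

```python
def pad_batch(features, pad_token_id=0, label_pad_token_id=-100):
    """Pad a list of tokenized examples to the longest sequence in this batch only."""
    max_input = max(len(f["input_ids"]) for f in features)
    max_label = max(len(f["labels"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_input - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        # Padding positions get 0 in the attention mask so the model ignores them.
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_pad)
        # Labels are padded with label_pad_token_id so the loss can skip those positions.
        n_label_pad = max_label - len(f["labels"])
        batch["labels"].append(f["labels"] + [label_pad_token_id] * n_label_pad)
    return batch


features = [
    {"input_ids": [5, 6, 1], "labels": [7, 1]},
    {"input_ids": [8, 1], "labels": [9, 10, 1]},
]
print(pad_batch(features))
```

Because the target length is the longest sequence in the current batch, a batch of short articles is not padded up to the longest article in the whole dataset.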

In [11]:
max_input_length = 1024
max_target_length = 128


def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            examples["summary"], max_length=max_target_length, truncation=True
        )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

This function works with one or several examples. In the case of several examples, the tokenizer will return a list of lists for each key:

In [12]:
preprocess_function(raw_datasets["train"][:2])

{'input_ids': [[21603, 10, 37, 423, 583, 13, 1783, 16, 20126, 16496, 6, 80, 13, 8, 844, 6025, 4161, 6, 19, 341, 271, 14841, 5, 7057, 161, 19, 4912, 16, 1626, 5981, 11, 186, 7540, 16, 1276, 15, 2296, 7, 5718, 2367, 14621, 4161, 57, 4125, 387, 5, 15059, 7, 30, 8, 4653, 4939, 711, 747, 522, 17879, 788, 12, 1783, 44, 8, 15763, 6029, 1813, 9, 7472, 5, 1404, 1623, 11, 5699, 277, 130, 4161, 57, 18368, 16, 20126, 16496, 227, 8, 2473, 5895, 15, 147, 89, 22411, 139, 8, 1511, 5, 1485, 3271, 3, 21926, 9, 472, 19623, 5251, 8, 616, 12, 15614, 8, 1783, 5, 37, 13818, 10564, 15, 26, 3, 9, 3, 19513, 1481, 6, 18368, 186, 1328, 2605, 30, 7488, 1887, 3, 18, 8, 711, 2309, 9517, 89, 355, 5, 3966, 1954, 9233, 15, 6, 113, 293, 7, 8, 16548, 13363, 106, 14022, 84, 47, 14621, 4161, 6, 243, 255, 228, 59, 7828, 8, 1249, 18, 545, 11298, 1773, 728, 8, 8347, 1560, 5, 611, 6, 255, 243, 72, 1709, 1528, 161, 228, 43, 118, 4006, 91, 12, 766, 8, 3, 19513, 1481, 410, 59, 5124, 5, 96, 196, 17, 19, 1256, 68, 27, 103, 317, 132

To apply this function to all the document/summary pairs in our dataset, we just use the `map` method of the `raw_datasets` object we created earlier. This applies the function to all the elements of all the splits in `raw_datasets`, so our training, validation and test data will be preprocessed in one single command.

In [13]:
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)

Even better, the results are automatically cached by the 🤗 Datasets library to avoid spending time on this step the next time you run your notebook. The 🤗 Datasets library is normally smart enough to detect when the function you pass to `map` has changed (and thus the cached data must not be used). For instance, it will properly detect if you change the task in the first cell and rerun the notebook. 🤗 Datasets warns you when it uses cached files; you can pass `load_from_cache_file=False` in the call to `map` to ignore the cached files and force the preprocessing to be applied again.

Note that we passed `batched=True` to encode the texts in batches. This lets us get the full benefit of the fast tokenizer we loaded earlier, which uses multi-threading to process the texts in a batch concurrently.
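With `batched=True`, `map` hands your function slices of (by default) 1,000 rows at a time, so the number of calls is the row count divided by the batch size, rounded up. A quick sketch of that arithmetic, assuming the default `batch_size=1000` of `Dataset.map`; for the 204,045-row training split (see the `Dataset` summary further down) it gives 205 batches.

```python
import math


def num_map_batches(num_rows, batch_size=1000):
    """Batches processed by Dataset.map(batched=True), keeping the last partial batch."""
    return math.ceil(num_rows / batch_size)


print(num_map_batches(204045))  # 205
```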

## Fine-tuning the model

Now that our data is ready, we can download the pretrained model and fine-tune it. Since our task is sequence-to-sequence (both the input and output are text sequences), we use the `TFAutoModelForSeq2SeqLM` class. As with the tokenizer, the `from_pretrained` method will download and cache the model for us.

In [14]:
from transformers import TFAutoModelForSeq2SeqLM, DataCollatorForSeq2Seq

model = TFAutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

All model checkpoint layers were used when initializing TFT5ForConditionalGeneration.

All the layers of TFT5ForConditionalGeneration were initialized from the model checkpoint at t5-small.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


Note that we don't get a warning like in our classification example. This means we used all the weights of the pretrained model and there is no randomly initialized head in this case.

Next we set some parameters like the learning rate and the `batch_size`, and customize the weight decay.

The last two lines set up the model ID we'll use to push the model to the [Hub](https://huggingface.co/models) at the end of training. If you didn't follow the authentication steps at the top of the notebook, you can drop them along with the `PushToHubCallback` below; otherwise, you can change the value of `push_to_hub_model_id` to something you prefer.

In [15]:
batch_size = 8
learning_rate = 2e-5
weight_decay = 0.01
num_train_epochs = 1

model_name = model_checkpoint.split("/")[-1]
push_to_hub_model_id = f"{model_name}-finetuned-xsum"

Then, we need a special kind of data collator, which will not only pad the inputs to the maximum length in the batch, but also the labels. Note that our data collators are multi-framework, so make sure you set `return_tensors='tf'` so you get `tf.Tensor` objects back and not something else!

We also want to compute `ROUGE` metrics, which will require us to generate text from our model. To speed things up, we can compile our generation loop with XLA. This results in a *huge* speedup - up to 100X! The downside of XLA generation, though, is that it doesn't like variable input shapes, because it needs to run a new compilation for each new input shape! To compensate for that, let's use `pad_to_multiple_of` for the dataset we use for text generation. This will reduce the number of unique input shapes a lot, meaning we can get the benefits of XLA generation with only a few compilations.
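The arithmetic behind this is just rounding up. The collator pads each batch to its longest sequence and then up to the next multiple of 128, and our inputs were truncated to `max_input_length = 1024`, so the sequence dimension can take only 8 distinct values instead of up to 1,024. A small sketch (`padded_length` is an illustrative helper, not a library function):

```python
def padded_length(length, multiple=128):
    """Round a sequence length up to the next multiple, as pad_to_multiple_of does."""
    return ((length + multiple - 1) // multiple) * multiple


# Truncated inputs have between 1 and max_input_length=1024 tokens
padded_shapes = {padded_length(n) for n in range(1, 1025)}
print(len(padded_shapes))  # 8
print(sorted(padded_shapes))  # [128, 256, 384, 512, 640, 768, 896, 1024]
```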

In [16]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="tf")

generation_data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="tf", pad_to_multiple_of=128)
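Roughly speaking, `DataCollatorForSeq2Seq` pads `input_ids` and `attention_mask` to the batch's longest sequence, pads `labels` with `-100` (its default `label_pad_token_id`, which the loss ignores), and, because we passed `model=model`, also builds `decoder_input_ids` from the labels. The NumPy sketch below reproduces only the padding steps. It assumes a pad id of 0 (T5's pad token), the token ids are made up, and `pad_batch` is a hypothetical helper rather than part of the library.

```python
import numpy as np


def pad_batch(sequences, pad_value, pad_to_multiple_of=None):
    """Right-pad a list of id lists to one common length (hypothetical helper)."""
    max_len = max(len(seq) for seq in sequences)
    if pad_to_multiple_of is not None:
        # Round the batch length up to the next multiple, like pad_to_multiple_of
        max_len = -(-max_len // pad_to_multiple_of) * pad_to_multiple_of
    padded = np.full((len(sequences), max_len), pad_value, dtype=np.int64)
    for row, seq in enumerate(sequences):
        padded[row, : len(seq)] = seq
    return padded


# Two made-up tokenized examples of different lengths
features = [
    {"input_ids": [21603, 10, 37, 1], "attention_mask": [1, 1, 1, 1], "labels": [37, 1]},
    {"input_ids": [21603, 10, 1], "attention_mask": [1, 1, 1], "labels": [5, 6, 1]},
]
PAD_ID = 0  # T5's pad token id (assumed)
batch = {
    "input_ids": pad_batch([f["input_ids"] for f in features], PAD_ID),
    "attention_mask": pad_batch([f["attention_mask"] for f in features], 0),
    "labels": pad_batch([f["labels"] for f in features], -100),  # -100 = ignored by the loss
}
print(batch["labels"].tolist())  # [[37, 1, -100], [5, 6, 1]]
```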

In [17]:
tokenized_datasets["train"]

Dataset({
    features: ['document', 'summary', 'id', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 204045
})

Next, we convert our datasets to `tf.data.Dataset`, which Keras understands natively. There are two ways to do this - we can use the slightly more low-level [`Dataset.to_tf_dataset()`](https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset.to_tf_dataset) method, or we can use [`Model.prepare_tf_dataset()`](https://huggingface.co/docs/transformers/main_classes/model#transformers.TFPreTrainedModel.prepare_tf_dataset). The main difference between these two is that the `Model` method can inspect the model to determine which column names it can use as input, which means you don't need to specify them yourself. Make sure to specify the collator we just created as our `collate_fn`!

In [18]:
train_dataset = model.prepare_tf_dataset(
    tokenized_datasets["train"],
    batch_size=batch_size,
    shuffle=True,
    collate_fn=data_collator,
)

validation_dataset = model.prepare_tf_dataset(
    tokenized_datasets["validation"],
    batch_size=batch_size,
    shuffle=False,
    collate_fn=data_collator,
)

generation_dataset = model.prepare_tf_dataset(
    tokenized_datasets["validation"],
    batch_size=8,
    shuffle=False,
    collate_fn=generation_data_collator
)
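It can help to sanity-check how many steps one epoch will take. With 204,045 training rows and `batch_size = 8`, there are 25,505 full batches plus a partial batch of 5 rows. The training log further down shows 25505 steps, which suggests the shuffled training set drops that last partial batch. As far as we know, `prepare_tf_dataset` defaults `drop_remainder` to the value of `shuffle`, but treat that as an assumption and check the documentation for your version.

```python
def steps_per_epoch(num_rows, batch_size, drop_remainder):
    """Number of batches per epoch, with or without the final partial batch."""
    full_batches, remainder = divmod(num_rows, batch_size)
    return full_batches + (1 if remainder and not drop_remainder else 0)


print(steps_per_epoch(204045, 8, drop_remainder=True))   # 25505
print(steps_per_epoch(204045, 8, drop_remainder=False))  # 25506
```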

Now we initialize our optimizer and compile the model. Note that most Transformers models compute their loss internally; we can train on this loss simply by not specifying a loss when we call `compile()`.

In [19]:
from transformers import AdamWeightDecay
import tensorflow as tf

optimizer = AdamWeightDecay(learning_rate=learning_rate, weight_decay_rate=weight_decay)
model.compile(optimizer=optimizer)

No loss specified in compile() - the model's internal loss computation will be used as the loss. Don't panic - this is a common way to train TensorFlow models in Transformers! To disable this behaviour please pass a loss argument, or explicitly pass `loss=None` if you do not want your model to compute a loss.
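For intuition about what the "internal loss" is here: for seq2seq models it is a token-level cross-entropy in which label positions equal to `-100` (the padding the collator added) are ignored. The NumPy sketch below illustrates that idea only; it is a simplified stand-in, not the library's actual implementation, which may for example reduce over the batch differently.

```python
import numpy as np


def masked_cross_entropy(logits, labels, ignore_index=-100):
    """Mean token-level cross-entropy over positions whose label != ignore_index."""
    # Numerically stable log-softmax over the vocabulary axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    mask = labels != ignore_index
    # Swap ignored labels for a valid index so indexing works; they are masked out below
    safe_labels = np.where(mask, labels, 0)
    token_log_probs = np.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    return -(token_log_probs * mask).sum() / mask.sum()


# Batch of 1, sequence of 3, vocabulary of 4; the last position is padding
logits = np.zeros((1, 3, 4))  # uniform predictions
labels = np.array([[2, 1, -100]])
print(masked_cross_entropy(logits, labels))  # log(4), about 1.386
```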


Now we can train our model. We can also add a few optional callbacks here, which you can remove if they aren't useful to you. In no particular order, these are:
- PushToHubCallback will sync up our model with the Hub - this allows us to resume training from other machines, share the model after training is finished, and even test the model's inference quality midway through training!
- TensorBoard is a built-in Keras callback that logs TensorBoard metrics.
- KerasMetricCallback is a callback for computing advanced metrics. There are a number of common metrics in NLP like ROUGE which are hard to fit into your compiled training loop because they depend on decoding predictions and labels back to strings with the tokenizer, and calling arbitrary Python functions to compute the metric. The KerasMetricCallback will wrap a metric function, outputting metrics as training progresses.

If this is the first time you've seen `KerasMetricCallback`, it's worth explaining what exactly is going on here. The callback takes two main arguments - a `metric_fn` and an `eval_dataset`. It then iterates over the `eval_dataset` and collects the model's outputs for each sample, before passing the `list` of predictions and the associated `list` of labels to the user-defined `metric_fn`. If the `predict_with_generate` argument is `True`, then it will call `model.generate()` for each input sample instead of `model.predict()` - this is useful for metrics that expect generated text from the model, like `ROUGE`.

This callback allows complex metrics to be computed each epoch that would not function as a standard Keras Metric. Metric values are printed each epoch, and can be used by other callbacks like `TensorBoard` or `EarlyStopping`.
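To make that contract concrete, here is a simplified, hypothetical version of the evaluation loop the callback runs: collect predictions and labels over the whole `eval_dataset`, then call `metric_fn` once with a `(predictions, labels)` tuple and use the returned dict. `toy_eval_loop`, `fake_generate` and `exact_match_metric` are made-up names for illustration; the real `KerasMetricCallback` also pads predictions to a common length, supports XLA generation and logs the results.

```python
def toy_eval_loop(eval_batches, generate_fn, metric_fn):
    """Collect predictions and labels batch by batch, then call metric_fn once."""
    all_predictions, all_labels = [], []
    for inputs, labels in eval_batches:
        all_predictions.extend(generate_fn(inputs))
        all_labels.extend(labels)
    # metric_fn receives a (predictions, labels) tuple and returns a dict of numbers
    return metric_fn((all_predictions, all_labels))


def fake_generate(inputs):
    # Stand-in for model.generate(): echoes the inputs back unchanged
    return inputs


def exact_match_metric(eval_predictions):
    predictions, labels = eval_predictions
    matches = sum(p == l for p, l in zip(predictions, labels))
    return {"exact_match": 100.0 * matches / len(labels)}


# Two batches of (inputs, labels); two of the three predictions match their labels
batches = [([[1, 2], [3, 4]], [[1, 2], [9, 9]]), ([[5, 6]], [[5, 6]])]
result = toy_eval_loop(batches, fake_generate, exact_match_metric)
print(result)
```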

In [20]:
import numpy as np
import nltk


def metric_fn(eval_predictions):
    predictions, labels = eval_predictions
    decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    for label in labels:
        label[label < 0] = tokenizer.pad_token_id  # Replace masked label tokens
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # Rouge expects a newline after each sentence
    decoded_predictions = [
        "\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_predictions
    ]
    decoded_labels = [
        "\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels
    ]
    result = metric.compute(
        predictions=decoded_predictions, references=decoded_labels, use_stemmer=True
    )
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    # Add mean generated length
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions
    ]
    result["gen_len"] = np.mean(prediction_lens)

    return result
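The `metric` object (ROUGE, loaded earlier in the notebook) does the heavy lifting in `metric_fn`. For intuition about what an F-measure such as `rouge1` reports, here is a from-scratch ROUGE-1 sketch: the harmonic mean of unigram precision and recall. It is deliberately simplified (lowercasing and whitespace splitting only); the real metric also normalizes text, applies stemming when `use_stemmer=True`, reports ROUGE-2 and ROUGE-L, and its `mid` value is a central estimate from bootstrap resampling across examples.

```python
from collections import Counter


def rouge1_f(prediction, reference):
    """Simplified ROUGE-1 F-measure: lowercase, whitespace split, no stemming."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    # Clipped overlap: a token counts at most as often as it appears in both texts
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


# Precision 4/4, recall 4/6, so F = 0.8 (reported as 80 after the * 100 scaling)
print(100 * rouge1_f("Flooding hit the town", "The town was hit by flooding"))
```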

And now we can try training our model. By default, we only do a single epoch of training here, as the inputs are very long, which means training is quite slow. However, you may wish to experiment with larger pre-trained models and longer training runs if you want to maximize the quality of your summaries.

In [None]:
from transformers.keras_callbacks import PushToHubCallback, KerasMetricCallback
from tensorflow.keras.callbacks import TensorBoard

tensorboard_callback = TensorBoard(log_dir="./summarization_model_save/logs")

push_to_hub_callback = PushToHubCallback(
    output_dir="./summarization_model_save",
    tokenizer=tokenizer,
    hub_model_id=push_to_hub_model_id,
)

metric_callback = KerasMetricCallback(
    metric_fn, eval_dataset=generation_dataset, predict_with_generate=True, use_xla_generation=True
)

callbacks = [metric_callback, tensorboard_callback, push_to_hub_callback]

model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=num_train_epochs,  # set to 1 in the hyperparameter cell above
    callbacks=callbacks,
)

Cloning https://huggingface.co/Znerual/t5-small-finetuned-xsum into local empty directory.


    6/25505 [..............................] - ETA: 4:17:09 - loss: 3.8957





If you used the callback above, you can now share this model with all your friends, family and favorite pets: they can all load it with the identifier `"your-username/the-name-you-picked"`, for instance:

```python
from transformers import TFAutoModelForSeq2SeqLM

model = TFAutoModelForSeq2SeqLM.from_pretrained("your-username/my-awesome-model")
```

## Inference

Now that we've trained our model, let's see how we can load it and use it to summarize text in the future! First, let's load it from the Hub. This means we can resume the code from here without needing to rerun everything above every time.

In [None]:
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

# You can of course substitute your own username and model here if you've trained and uploaded it!
model_name = 'Rocketknight1/t5-small-finetuned-xsum'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)

Now let's try tokenizing a document from the training set. Don't forget to add 'summarize:' at the start if you're using a `T5` model.

In [None]:
document = 'The full cost of damage in Newton Stewart, one of the areas worst affected, is still being assessed.\nRepair work is ongoing in Hawick and many roads in Peeblesshire remain badly affected by standing water.\nTrains on the west coast mainline face disruption due to damage at the Lamington Viaduct.\nMany businesses and householders were affected by flooding in Newton Stewart after the River Cree overflowed into the town.\nFirst Minister Nicola Sturgeon visited the area to inspect the damage.\nThe waters breached a retaining wall, flooding many commercial properties on Victoria Street - the main shopping thoroughfare.\nJeanette Tate, who owns the Cinnamon Cafe which was badly affected, said she could not fault the multi-agency response once the flood hit.\nHowever, she said more preventative work could have been carried out to ensure the retaining wall did not fail.\n"It is difficult but I do think there is so much publicity for Dumfries and the Nith - and I totally appreciate that - but it is almost like we\'re neglected or forgotten," she said.\n"That may not be true but it is perhaps my perspective over the last few days.\n"Why were you not ready to help us a bit more when the warning and the alarm alerts had gone out?"\nMeanwhile, a flood alert remains in place across the Borders because of the constant rain.\nPeebles was badly hit by problems, sparking calls to introduce more defences in the area.\nScottish Borders Council has put a list on its website of the roads worst affected and drivers have been urged not to ignore closure signs.\nThe Labour Party\'s deputy Scottish leader Alex Rowley was in Hawick on Monday to see the situation first hand.\nHe said it was important to get the flood protection plan right but backed calls to speed up the process.\n"I was quite taken aback by the amount of damage that has been done," he said.\n"Obviously it is heart-breaking for people who have been forced out of their homes and the impact on businesses."\nHe said it was important that "immediate steps" were taken to protect the areas most vulnerable and a clear timetable put in place for flood prevention plans.\nHave you been affected by flooding in Dumfries and Galloway or the Borders? Tell us about your experience of the situation and how it was handled. Email us on selkirk.news@bbc.co.uk or dumfries@bbc.co.uk.'
if 't5' in model_name: 
    document = "summarize: " + document
tokenized = tokenizer([document], return_tensors='np')
out = model.generate(**tokenized, max_length=128)

In [None]:
with tokenizer.as_target_tokenizer():
    print(tokenizer.decode(out[0]))

Not bad for a single epoch of training! Of course, the flood warning isn't much use to them after they've been flooded, but the summary correctly identified flooding in Dumfries and the Nith as the key event. 

## Using XLA in inference

If you just want to generate a few summaries, the code above is all you need. However, generation can be **much** faster if you use XLA, and if you want to generate data in bulk, you should probably use it! If you're using XLA, though, remember that you'll need to do a new XLA compilation for every input size you pass to the model. This means that you should keep your batch size constant, and consider padding inputs to the same length, or using `pad_to_multiple_of` in your tokenizer to reduce the number of different input shapes you pass. Let's show an example of that:

In [None]:
import tensorflow as tf

@tf.function(jit_compile=True)
def generate(inputs):
    return model.generate(**inputs, max_length=128)

# pad_to_multiple_of only takes effect when padding is enabled, so pass padding=True
tokenized_data = tokenizer([document], return_tensors="np", padding=True, pad_to_multiple_of=128)
out = generate(tokenized_data)

In [None]:
with tokenizer.as_target_tokenizer():
    print(tokenizer.decode(out[0]))
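To see why a constant batch size matters as well, note that XLA compiles once per distinct input *shape*, that is, per (batch size, padded sequence length) pair. The sketch below counts those shapes for a hypothetical list of document lengths, batching them in order and padding each batch to a multiple of 128; the smaller final batch adds a shape of its own. `distinct_shapes` is an illustrative helper, not a library function.

```python
def distinct_shapes(lengths, batch_size, multiple=128):
    """Set of (batch, padded_length) shapes produced by batching lengths in order."""
    shapes = set()
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start : start + batch_size]
        padded = -(-max(batch) // multiple) * multiple  # round the longest up to a multiple
        shapes.add((len(batch), padded))
    return shapes


# Hypothetical token counts for ten documents
lengths = [541, 300, 610, 120, 90, 1000, 450, 500, 130, 70]
print(sorted(distinct_shapes(lengths, batch_size=4)))  # [(2, 256), (4, 640), (4, 1024)]
```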

When using XLA generation, you'll notice that the first call to generate with a new input shape takes a long time because XLA has to compile your function, but subsequent calls are extremely quick. Also, XLA always generates to the maximum length, which can lead to a lot of padding tokens in your output! These are easy to remove, however:

In [None]:
with tokenizer.as_target_tokenizer():
    print(tokenizer.decode(out[0], skip_special_tokens=True))
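`skip_special_tokens=True` drops the pad and end-of-sequence tokens during decoding. If you would rather trim the id array itself (for instance to count how many tokens were actually generated, like the `gen_len` value in `metric_fn`), a NumPy mask does it. This sketch assumes T5's ids (pad = 0, which T5 also uses as the decoder start token, and EOS = 1); the generated ids in the example are made up.

```python
import numpy as np

PAD_ID, EOS_ID = 0, 1  # T5's pad and end-of-sequence token ids (assumed)


def strip_padding(token_ids, pad_id=PAD_ID, eos_id=EOS_ID):
    """Drop everything from the first EOS onwards, then remove any pad tokens."""
    ids = np.asarray(token_ids)
    eos_positions = np.flatnonzero(ids == eos_id)
    if eos_positions.size:
        ids = ids[: eos_positions[0]]
    return ids[ids != pad_id]


# Made-up XLA output: decoder start (0), three tokens, EOS, then padding to max_length
generated = [0, 37, 1783, 16, 1, 0, 0, 0]
print(strip_padding(generated).tolist())  # [37, 1783, 16]
```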

## Pipeline API

The pipeline API offers a convenient shortcut for all of this, but doesn't (yet!) support XLA generation:

In [None]:
from transformers import pipeline

summarizer = pipeline('text2text-generation', model_name, framework="tf")

All model checkpoint layers were used when initializing TFT5ForConditionalGeneration.

All the layers of TFT5ForConditionalGeneration were initialized from the model checkpoint at Rocketknight1/t5-small-finetuned-xsum.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


Remember that if we're using a T5 model, we prepended "summarize: " to the start of our input above. Don't forget to do the same when you're getting summaries for new texts!
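A small helper can apply the prefix to a batch of new texts. This is a minimal sketch that mirrors the `if 't5' in model_name` check used earlier; `add_task_prefix` is a hypothetical name, and the substring check is only as reliable as your model naming.

```python
def add_task_prefix(texts, model_name, prefix="summarize: "):
    """Prepend the T5 task prefix when the model name looks like a T5 checkpoint."""
    if "t5" in model_name:
        return [prefix + text for text in texts]
    return list(texts)


new_texts = ["Heavy rain flooded the high street on Monday."]
inputs = add_task_prefix(new_texts, "Rocketknight1/t5-small-finetuned-xsum")
print(inputs[0])  # summarize: Heavy rain flooded the high street on Monday.
# The prefixed list could then be passed to summarizer(inputs, max_length=128)
```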

In [None]:
summarizer(document, max_length=128)

Token indices sequence length is longer than the specified maximum sequence length for this model (541 > 512). Running this sequence through the model will result in indexing errors
2022-07-25 15:05:51.802359: I tensorflow/compiler/xla/service/service.cc:170] XLA service 0x563f5be5ff70 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2022-07-25 15:05:51.802390: I tensorflow/compiler/xla/service/service.cc:178]   StreamExecutor device (0): Host, Default Version




Easy!