You will need to install Transformers and Datasets, as well as the other dependencies used below (`rouge-score` for the ROUGE metric and `nltk` for sentence splitting).

In [1]:
! pip install datasets transformers rouge-score nltk

Collecting datasets
  Downloading datasets-1.18.3-py3-none-any.whl (311 kB)
[?25l[K     |█                               | 10 kB 42.2 MB/s eta 0:00:01[K     |██                              | 20 kB 40.2 MB/s eta 0:00:01[K     |███▏                            | 30 kB 14.8 MB/s eta 0:00:01[K     |████▏                           | 40 kB 7.4 MB/s eta 0:00:01[K     |█████▎                          | 51 kB 6.8 MB/s eta 0:00:01[K     |██████▎                         | 61 kB 8.0 MB/s eta 0:00:01[K     |███████▍                        | 71 kB 8.5 MB/s eta 0:00:01[K     |████████▍                       | 81 kB 8.8 MB/s eta 0:00:01[K     |█████████▌                      | 92 kB 9.8 MB/s eta 0:00:01[K     |██████████▌                     | 102 kB 8.2 MB/s eta 0:00:01[K     |███████████▋                    | 112 kB 8.2 MB/s eta 0:00:01[K     |████████████▋                   | 122 kB 8.2 MB/s eta 0:00:01[K     |█████████████▊                  | 133 kB 8.2 MB/s eta 0:00:01

In [2]:
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [3]:
!apt install git-lfs

Reading package lists... Done
Building dependency tree       
Reading state information... Done
The following package was automatically installed and is no longer required:
  libnvidia-common-470
Use 'apt autoremove' to remove it.
The following NEW packages will be installed:
  git-lfs
0 upgraded, 1 newly installed, 0 to remove and 39 not upgraded.
Need to get 2,129 kB of archives.
After this operation, 7,662 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu bionic/universe amd64 git-lfs amd64 2.3.4-1 [2,129 kB]
Fetched 2,129 kB in 1s (1,624 kB/s)
Selecting previously unselected package git-lfs.
(Reading database ... 155320 files and directories currently installed.)
Preparing to unpack .../git-lfs_2.3.4-1_amd64.deb ...
Unpacking git-lfs (2.3.4-1) ...
Setting up git-lfs (2.3.4-1) ...
Processing triggers for man-db (2.8.3-2ubuntu0.1) ...


# Fine-tuning a pre-trained model on a summarization task

For this, we can use any model checkpoint from the [Model Hub](https://huggingface.co/models) as long as that model has a sequence-to-sequence version in the Transformers library. Here we pick the [`t5-small`](https://huggingface.co/t5-small) checkpoint.

In [5]:
model_checkpoint = "t5-small"

## Loading the dataset

We will use a sampled, preprocessed version of a publicly available text summarization dataset built from the news articles collected by Hermann et al. (2015); it can be downloaded from [here](https://drive.google.com/drive/folders/1CP36015srVw9Q0-1kTwOfAuo4z2t2aKh?usp=sharing). The three CSV splits are loaded with `load_dataset`, and the ROUGE metric used to measure performance is loaded with `load_metric`.

In [7]:
from datasets import load_dataset, load_metric

raw_datasets = load_dataset('csv', data_files={'train': 'train.csv', 'test': 'test.csv', 'val': 'val.csv'})
metric = load_metric("rouge")

Using custom data configuration default-8cc7cca13617b5a6


Downloading and preparing dataset csv/default to /root/.cache/huggingface/datasets/csv/default-8cc7cca13617b5a6/0.0.0/6b9057d9e23d9d8a2f05b985917a0da84d70c5dae3d22ddd8a3f22fb01c69d9e...



Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-8cc7cca13617b5a6/0.0.0/6b9057d9e23d9d8a2f05b985917a0da84d70c5dae3d22ddd8a3f22fb01c69d9e. Subsequent calls will reuse this data.



The `raw_datasets` object is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key for each of the training, test and validation sets:

In [8]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['document', 'summary'],
        num_rows: 50000
    })
    test: Dataset({
        features: ['document', 'summary'],
        num_rows: 5000
    })
    val: Dataset({
        features: ['document', 'summary'],
        num_rows: 5000
    })
})

To access an actual element, you need to select a split first, then give an index:

In [9]:
raw_datasets["train"][0]

{'document': "editor 's note : in our behind the scenes series , cnn correspondents share their experiences in covering news and analyze the stories behind the events . here , soledad o'brien takes users inside a jail where many of the inmates are mentally ill . an inmate housed on the `` forgotten floor , '' where many mentally ill inmates are housed in miami before trial . miami , florida -lrb- cnn -rrb- -- the ninth floor of the miami-dade pretrial detention facility is dubbed the `` forgotten floor . '' here , inmates with the most severe mental illnesses are incarcerated until they 're ready to appear in court . most often , they face drug charges or charges of assaulting an officer -- charges that judge steven leifman says are usually `` avoidable felonies . '' he says the arrests often result from confrontations with police . mentally ill people often wo n't do what they 're told when police arrive on the scene -- confrontation seems to exacerbate their illness and they become m

To get a sense of what the data looks like, the following function shows some examples picked at random from the dataset.

In [10]:
import datasets
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=5):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))
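The rejection loop in `show_random_elements` works, but the standard library can draw distinct indices directly. The sketch below uses a hypothetical helper, `pick_distinct_indices`, built on `random.Random.sample`; its result could be passed to `dataset[picks]` exactly as the original `picks` list is. The optional `seed` is an addition for reproducibility, not part of the original function.

```python
import random


def pick_distinct_indices(num_rows, num_examples, seed=None):
    """Return num_examples distinct row indices in [0, num_rows)."""
    if num_examples > num_rows:
        raise ValueError("Can't pick more elements than there are in the dataset.")
    rng = random.Random(seed)  # local generator, so a seed makes the draw reproducible
    return rng.sample(range(num_rows), num_examples)  # sampling without replacement


picks = pick_distinct_indices(50000, 5, seed=0)
print(picks)
```

`random.sample` never repeats an element, so no retry loop is needed, and sampling from a `range` does not materialize the whole index list.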

In [11]:
show_random_elements(raw_datasets["train"])

Unnamed: 0,document,summary
0,"-lrb- cnn -rrb- -- a man suspected in the slayings of his girlfriend and her four children admitted choking the oklahoma woman to death , but said the children were not present at the time , according to an affidavit filed in the case . joshua steven durcho was arrested tuesday after a car chase with police . joshua steven durcho , 25 , was arrested tuesday night in hamilton county , texas , officials said . he is suspected of killing summer rust , 25 ; her son teagin , 4 ; and daughters evynn , 3 , and autumn and kirsten , both 7 . all five bodies were found in rust 's apartment in el reno , oklahoma , about 30 miles west of oklahoma city , on monday . durcho 's first cousin notified authorities he found the body of rust , who is identified in the affidavit as summer dawn garas . police also found the children 's bodies in the apartment , according to the affidavit , written by a special agent with the oklahoma state bureau of investigation and filed tuesday in canadian county , oklahoma , district court . `` the medical examiner 's office has reported to our agents that the preliminary assessment of the cause and manner of death for all five individuals was asphyxiation , suffocation and strangulation , '' the affidavit said . `` it was also reported that each body had ligature marks around the neck . the ligature marks were also observed by osbi crime scene investigators . '' a spokeswoman for the state medical examiner 's office told cnn on wednesday that the cause of death for summer rust and teagin was strangulation , and that a ligature -- which could include a string , cord or wire -- was used to strangle them . autopsies on the three girls were being conducted wednesday , the spokeswoman said . a woman told police durcho came to her apartment monday afternoon and told her he had `` choked '' summer rust to death and that he was leaving oklahoma , according to the affidavit . 
the woman asked durcho about rust 's children , the affidavit said , and `` durcho told her that the children were at their grandmother 's residence ... while he and summer worked out their relationship problems . '' the woman called durcho 's mother and told her what he had said about killing rust , the affidavit said . durcho 's mother drove to the apartment to check on the woman , but no one answered her knocks . she then called her nephew , durcho 's cousin , to accompany her , leading to the discovery of rust 's body , according to the document . about 6:30 p.m. monday , the affidavit said , durcho went to the home of another cousin , a female , and told her `` he was in trouble and that he was headed out of state . '' durcho was driving rust 's 1989 white ford thunderbird , the document said , and asked his cousin to swap cars with him , but she declined . a surveillance video showed durcho at a truck stop on interstate 40 about three hours later , driving the thunderbird , the affidavit said . early tuesday morning , a text message was sent from a cell phone in durcho 's possession to his mother 's cell phone , according to the affidavit . tracking and cell phone records showed durcho 's phone was located in wichita falls , texas , at the time . later that morning , durcho called his mother , with the call shown to be from the abilene , texas , area , the affidavit said . durcho 's mother said `` durcho told her he loved her and had to go , '' according to the document . police said durcho was arrested after a car chase tuesday night . a texas state trooper attempted to stop the car durcho was driving because the trooper suspected the driver was drunk , according to erin mangrum of the canadian county sheriff 's office . when the trooper ran the license plate on the car , it matched the tag number of a vehicle sought by oklahoma police . the car sped off , mangrum said , and during the ensuing chase the car crashed . 
durcho suffered only minor injuries and was taken into custody , mangrum said . a court hearing was to be held for durcho on wednesday in hamilton county , according to cnn affiliates . the hamilton county district attorney 's office did not immediately return a call from cnn . durcho was being held in the county jail tuesday night , mangrum said . rust 's mother , susan rust of carson city , nevada , said durcho was unemployed and had been living with rust and her children .\n","<t> new : affidavit describes suspect 's actions after slayings . </t> <t> mother , son strangled with ligature , autopsy shows . </t> <t> suspect arrested after chase in texas . </t> <t> family found dead in oklahoma apartment on monday . </t>\n"
1,"kathmandu , nepal -lrb- cnn -rrb- -- the effort to rescue hundreds of trekkers stranded for six days in a town near nepal 's iconic mount everest due to bad weather continued monday . with weather conditions improving monday , the process of transporting the tourists to the nation 's capital was fully under way , the tourism ministry said . `` the target is to transport 1,500 tourists to kathmandu today , '' hari basyal , spokesman of the nepal ministry of tourism and civil aviation told cnn . more than 2,200 tourists have been stranded in the village of lukla since last week where food supplies are limited , basyal said . lukla , a popular starting point for people on their way to the world 's tallest peak , is located in northeast nepal . stranded tourists took 48 flights and helicopter rides out of lukla on monday . a day earlier , at least 500 trekkers were flown to the capital . last week , some of the tourists began a four-day walk to the town of jiri to take buses to kathmandu .\n",<t> bad weather conditions have stranded hundreds of trekkers . </t> <t> tourists are being evacuated from the town of lukla to nepal 's capital . </t>\n
2,"havana , cuba -lrb- cnn -rrb- -- cuban president raul castro said sunday that his government would allow more private businesses and make it easier for those businesses to hire workers , as the socialist economy struggles to get back on its feet and shed up to one million redundant state jobs . the government `` agreed to broaden the exercise of self employment and its use as another alternative for the employment of those excess workers , '' castro said during a biannual session of the national assembly . he went on to say that the government would eliminate `` numerous '' prohibitions to the granting of licenses for private businesses and to the sales of some products , as well as `` make the contracting of a work force more flexible . '' in exchange , those businesses will pay taxes on income and sales , and pay contributions for employees , he said . the measures `` constitute a structural and conceptual change in the interest of preserving and developing our social system to make it sustainable in the future , '' castro said . the decision was part of a series of measures approved by the council of ministers to reduce `` the considerably inflated payroll in the state sector , '' he added . earlier this year , castro said that more than one million state jobs , out of a total of 5.1 million , could be redundant . the government has launched a few , small free-market reforms . in april , for example , barbershops were handed over to employees , who pay rent and taxes but charge what they want . licenses have also been granted to private taxis . for a couple of years , fallow land in the countryside has been turned over to private farmers . the more they produce , the more they earn . not surprisingly , output is up . in his speech , castro also mentioned for the first time the release of political prisoners that started last month . `` the revolution can be generous because it is strong , '' he said . 
the catholic church and spain negotiated the release of 52 prisoners jailed in 2003 as part of a crackdown on people the government accused of receiving money from the enemy government in washington . so far , 21 prisoners have been released and flown abroad . castro did not give details about why cuba agreed to free the `` counterrevolutionaries '' or if he expected any gesture in return . former president fidel castro did not make an appearance despite expectations that he might .\n",<t> the cuban president says it will become easier to have a private business . </t> <t> raul castro says jobs will be cut `` in the considerably inflated '' state sector . </t> <t> fidel castro did not attend his brother 's speech to the national assembly . </t>\n
3,"-lrb- cnn -rrb- -- fending off elimination for the third straight game , the san francisco giants thrashed the st. louis cardinals on monday night to earn the right to battle for their second world series title in three years . the giants beat the cardinals 9-0 monday in a game that lacked the drama of some of the other games during the highly competitive series . san francisco had been on the brink since last thursday , when st. louis jumped to a three games to one advantage in the best-of-seven national league championship series . but san francisco has been on a roll since -- beating the cardinals twice by five-run margins , before finishing off the reigning champs in convincing fashion monday at at&t park , the northern california team 's home . with the victory , the giants earned the right to face the detroit tigers in the world series , which begins wednesday . the tigers handily swept the new york yankees last week to become american league titlists and earn their shot at baseball 's top prize . on monday night , san francisco used the same formula that 's worked for them so well in recent days : timely hitting and exceptional pitching . matt cain was stellar on the mound , pitching 5 and 2/3 scoreless innings to start the game . he even helped his cause at the plate with a run-scoring single in the second inning . the next frame , the giants exploded for five runs to build a comfortable 7-0 lead . they never looked back , fending off a st. louis rally in the sixth inning and tacking on more runs in the seventh and eighth innings . monday 's win also marked the second furious comeback the giants had completed during the postseason . the team was able to come back from a 2-0 deficit and win a best-of-five series against the cincinati reds earlier in the playoffs . `` they did n't want to go home . they kept believing , '' said giants manager bruce bochy . giants catcher buster posey said it is the team 's attitude that helps them beat the odds . 
`` it is a lot of want and a lot of will power , '' posey told reporters . `` you have to believe you can do it . '' the giants easily finished the regular season atop the national league west with a 94-68 record , tied with atlanta for fourth best in the league . detroit , meanwhile , had an 88-74 mark in pulling past the chicago white sox in the final month to earn a ticket to the playoffs . that said , while their record is inferior , the tigers feature two of the biggest weapons in baseball in pitcher and 2011 american league mvp justin verlander and this year 's triple crown winner miguel cabrera .\n","<t> st. louis led the national league series 3-1 just a few days ago . </t> <t> but san francisco wins 3 straight to advance to the world series . </t> <t> they 'll face the detroit tigers , led by justin verlander and miguel cabrera . </t>\n"
4,"-lrb- cnn -rrb- -- endurance swimmer diana nyad was stung by a sea creature for a second time saturday night and was being treated by doctors , her team said in a blog post , leaving the continuation of her swim in question . the 62-year-old , in transit from havana to florida , was stung by some kind of presumed jellyfish , the blog said . `` her face and eyes and the area around her eyes are affected . she is out of the water and aboard the transom of the voyager where she is being treated by doctors . it will be up to diana to decide whether or not to continue to the swim . '' the incident was another setback for the athlete , who was stung by portuguese man o ' war earlier in the day . at 6:30 p.m. , nyad was 49 miles from havana . having passed the 24-hour mark , she was entering a critical time in her quest to cover the 103 miles . `` tonight , her second night in the open water , may be the most critical , '' the team wrote . `` steve munatones , the independent observer for the international swim federation who is accompanying the expedition , says that swimmers have a much better chance of success if they can make it through the second night . '' there was a bit of excitement early saturday afternoon . an oceanic whitetip shark swam near nyad , but a diver on her team faced it off and it meandered away . the swimmer improved her performance late saturday morning after struggling to maintain her usual stroke rate , her support team said . fortified by chicken soup , nyad was making good progress until the saturday evening incident . `` this afternoon -- it is stunning to actually witness -- diana is swimming stronger and stronger , '' one post said . `` her strokes are up to 50 per minute , she is eating pasta , gobbling bananas , bits of peanut butter sandwiches , along with high-carb & high calorie liquid concoctions . 
'' the going was rough before dawn saturday , when nyad had stopped her freestyle stroke and complained that she could n't breathe properly after getting stung . doctors from the university of miami gave the swimmer a shot to reduce inflammation , oxygen and other medication , the blog said , and after treading water for an hour nyad said she felt better . candace hogan , a friend who has been on most of nyad 's swims since 1978 , said she could recover and complete the 100-plus mile marathon , the blog said . the team initially said nyad had likely been stung by a moon jellyfish but revised that to say it was the more troublesome portuguese man o ' war . a national institutes of health report says an encounter with a portuguese man o ' war can lead to `` significant systemic reactions '' but rarely death . while mild stings generally produce localized pain , severe ones can provoke symptoms ranging from headaches to seizures , delirium , coma and paralysis , as well as breathing problems , cramping and vomiting , the nih says . chief handler bonnie stoll said on twitter that nyad had been `` stung along both arms the side of her body and her face . '' nyad had to clear herself of tentacles , change her swimsuit and put on a shirt for protection after the incident . another member of nyad 's support team said the way she handled the setback was a `` testament to her strength . '' `` it was scary , '' stoll said . `` but diana is happy that this happened early while she is still at her strongest . '' a safety diver who entered the water to help nyad was also stung numerous times , her blog said . he received treatment back on one of the flotilla of boats accompanying her . nyad began her swim just after 6 p.m. friday from havana 's hemingway marina . the former world champion swimmer expects the swim to take close to 60 hours , which would put her into florida sometime monday . 
nyad last attempted this swim in early august and had to be pulled from the water after some 60 miles , and almost 29 hours of swimming . she blamed a shoulder injury she suffered early in the journey , and an 11-hour-long asthma attack . her first cuba-florida attempt , back in 1978 , was brought to an end by strong currents and bad weather after almost 42 hours in the water , according to her website . cnn 's matt sloane and shasta darlington contributed to this report .\n","<t> new : she is deciding whether she can continue the swim . </t> <t> a shark approached the swim area , but swam away . </t> <t> this is her third attempt to swim the 100-plus miles from cuba to florida . </t> <t> her first attempt , in 1978 , was cut short by bad weather and strong currents . </t>\n"


The metric is an instance of [`datasets.Metric`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Metric):

In [12]:
metric

Metric(name: "rouge", features: {'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')}, usage: """
Calculates average rouge scores for a list of hypotheses and references
Args:
    predictions: list of predictions to score. Each predictions
        should be a string with tokens separated by spaces.
    references: list of reference for each prediction. Each
        reference should be a string with tokens separated by spaces.
    rouge_types: A list of rouge types to calculate.
        Valid names:
        `"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`) where: {n} is the n-gram based scoring,
        `"rougeL"`: Longest common subsequence based scoring.
        `"rougeLSum"`: rougeLsum splits text using `"\n"`.
        See details in https://github.com/huggingface/datasets/issues/617
    use_stemmer: Bool indicating whether Porter stemmer should be used to strip word suffixes.
    use_agregator: Return aggregates if this is set to True
Retu
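To make the ROUGE numbers reported later easier to read, here is a simplified, pure-Python sketch of ROUGE-N (clipped n-gram overlap) and ROUGE-L (longest common subsequence), each reported as an F-measure. It is an illustration only: the real `rouge-score` package behind this metric applies its own tokenization, optional Porter stemming (`use_stemmer=True` in `compute_metrics` below) and bootstrap aggregation, and `rougeLsum` additionally splits the text on newlines. The helper names are hypothetical.

```python
from collections import Counter


def ngrams(tokens, n):
    """Multiset of the n-grams in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def f_measure(overlap, pred_total, ref_total):
    """Harmonic mean of precision (overlap / prediction) and recall (overlap / reference)."""
    if overlap == 0:
        return 0.0
    precision = overlap / pred_total
    recall = overlap / ref_total
    return 2 * precision * recall / (precision + recall)


def rouge_n(prediction, reference, n):
    pred, ref = ngrams(prediction.split(), n), ngrams(reference.split(), n)
    overlap = sum((pred & ref).values())  # each n-gram counted at most min(count_pred, count_ref) times
    return f_measure(overlap, sum(pred.values()), sum(ref.values()))


def lcs_length(a, b):
    """Length of the longest common subsequence, via a row-by-row DP table."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l(prediction, reference):
    pred, ref = prediction.split(), reference.split()
    return f_measure(lcs_length(pred, ref), len(pred), len(ref))


pred = "the cat sat on the mat"
ref = "the cat lay on the mat"
print(round(rouge_n(pred, ref, 1), 4))  # 0.8333
print(round(rouge_n(pred, ref, 2), 4))  # 0.6
print(round(rouge_l(pred, ref), 4))     # 0.8333
```

Five of the six unigrams match (5/6), three of the five bigrams match (3/5), and the longest common subsequence "the cat on the mat" has five tokens (5/6). The metric in this notebook reports the same kind of F-measure, scaled by 100 in `compute_metrics`.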

## Preprocessing the data

Before we can feed those texts to our model, we need to preprocess them. This is done by a Transformers `Tokenizer`, which will (as the name indicates) tokenize the inputs (including converting the tokens to their corresponding IDs in the pretrained vocabulary), put them in the format the model expects, and generate the other inputs that the model requires.


In [14]:
from transformers import AutoTokenizer
    
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)


If you are using one of the five T5 checkpoints, we have to prefix the inputs with "summarize:" (T5 can also perform other tasks, such as translation, and needs the prefix to know which task to perform).

In [18]:
if model_checkpoint in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
    prefix = "summarize: "
else:
    prefix = ""

We can then write the function that will preprocess our samples. We just feed them to the `tokenizer` with the argument `truncation=True`, which ensures that any article longer than `max_input_length` tokens (and any summary longer than `max_target_length` tokens) is truncated. Padding is dealt with later on (in a data collator), so that examples are padded to the longest length in each batch rather than in the whole dataset.

In [19]:
max_input_length = 1024
max_target_length = 128

def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

This function works with one or several examples. In the case of several examples, the tokenizer will return a list of lists for each key:

In [20]:
preprocess_function(raw_datasets['train'][:2])

{'input_ids': [[21603, 10, 6005, 3, 31, 7, 2232, 3, 10, 16, 69, 1187, 8, 8073, 939, 3, 6, 3, 75, 29, 29, 25688, 7, 698, 70, 2704, 16, 6013, 1506, 11, 8341, 8, 1937, 1187, 8, 984, 3, 5, 270, 3, 6, 4199, 14677, 3, 32, 31, 2160, 35, 1217, 1105, 1096, 3, 9, 11796, 213, 186, 13, 8, 16, 11171, 33, 19367, 3, 1092, 3, 5, 46, 16, 5058, 629, 26, 30, 8, 3, 2, 11821, 1501, 3, 6, 3, 31, 31, 213, 186, 19367, 3, 1092, 16, 11171, 33, 629, 26, 16, 1337, 3690, 274, 3689, 3, 5, 1337, 3690, 3, 6, 12215, 26, 9, 3, 18, 40, 52, 115, 18, 3, 75, 29, 29, 3, 18, 52, 52, 115, 18, 1636, 8, 24651, 1501, 13, 8, 1337, 3690, 18, 14677, 15, 7140, 12042, 20, 9174, 3064, 19, 3, 26, 17344, 8, 3, 2, 11821, 1501, 3, 5, 3, 31, 31, 270, 3, 6, 16, 11171, 28, 8, 167, 5274, 2550, 21154, 33, 3, 14736, 15, 4094, 552, 79, 3, 31, 60, 1065, 12, 2385, 16, 1614, 3, 5, 167, 557, 3, 6, 79, 522, 2672, 3991, 42, 3991, 13, 12710, 53, 46, 5502, 1636, 3991, 24, 5191, 3, 849, 1926, 90, 99, 348, 845, 33, 1086, 3, 2, 1792, 179, 3110, 106, 725, 3
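The truncate-now, pad-later strategy described above can be sketched in plain Python. The token ids, the tiny length limit and the helper names below are hypothetical, and special tokens such as T5's `</s>` are ignored; the pad id of 0 matches T5's tokenizer. The sketch counts how many padding positions are needed when padding per batch versus padding every example to the dataset-wide maximum.

```python
# Toy token-id sequences standing in for tokenized articles (hypothetical values).
examples = [[5, 6, 7], [8, 9], [10, 11, 12, 13, 14, 15], [16]]
max_input_length = 4  # tiny limit, for illustration only
PAD_ID = 0  # T5's pad token id


def truncate(ids, max_length):
    """Keep at most max_length tokens (special tokens ignored in this sketch)."""
    return ids[:max_length]


def pad_batch(batch, pad_id=PAD_ID):
    """Pad every sequence to the longest one in this batch and build the attention mask."""
    longest = max(len(ids) for ids in batch)
    input_ids = [ids + [pad_id] * (longest - len(ids)) for ids in batch]
    attention_mask = [[1] * len(ids) + [0] * (longest - len(ids)) for ids in batch]
    return input_ids, attention_mask


def count_padding(attention_masks):
    """Number of padded positions across a list of attention masks."""
    return sum(row.count(0) for mask in attention_masks for row in mask)


truncated = [truncate(ids, max_input_length) for ids in examples]
batches = [truncated[:2], truncated[2:]]  # batch size 2

dynamic = count_padding([pad_batch(b)[1] for b in batches])  # pad per batch
static = count_padding([pad_batch(truncated)[1]])            # pad to the dataset-wide maximum
print(truncated)        # [[5, 6, 7], [8, 9], [10, 11, 12, 13], [16]]
print(dynamic, static)  # 4 6
```

Even on four examples, per-batch padding wastes fewer positions; with real articles of very different lengths, the saving in compute is usually much larger.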

In [21]:
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)


## Fine-tuning the model

Now that our data is ready, we can download the pretrained model and fine-tune it. Since our task is of the sequence-to-sequence kind, we use the `AutoModelForSeq2SeqLM` class. Like with the tokenizer, the `from_pretrained` method will download and cache the model for us.

In [22]:
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)


To instantiate a `Seq2SeqTrainer`, we will need to define three more things. The most important is the [`Seq2SeqTrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html#transformers.Seq2SeqTrainingArguments), which is a class that contains all the attributes to customize the training. It requires one folder name, which will be used to save the checkpoints of the model, and all other arguments are optional:

In [23]:
batch_size = 4
model_name = model_checkpoint.split("/")[-1]
args = Seq2SeqTrainingArguments(
    f"{model_name}-finetuned-sum",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=2,
    predict_with_generate=True,
    fp16=True,
    push_to_hub=False,
)
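As a sanity check on these arguments, the number of optimization steps reported in the training log can be worked out by hand. The sketch assumes a single GPU and relies on the Trainer's default `save_steps` of 500, which is consistent with the `checkpoint-500`, `checkpoint-1000`, ... directories in the log. With `save_total_limit=3`, only the three most recent of those checkpoints are kept on disk.

```python
import math

num_examples = 50_000            # rows in tokenized_datasets["train"]
per_device_train_batch_size = 4
num_devices = 1                  # assumption: a single GPU
gradient_accumulation_steps = 1
num_train_epochs = 2

effective_batch_size = per_device_train_batch_size * num_devices * gradient_accumulation_steps
steps_per_epoch = math.ceil(num_examples / effective_batch_size)
total_steps = steps_per_epoch * num_train_epochs

save_steps = 500  # assumed Trainer default; matches checkpoint-500, checkpoint-1000, ...
num_saves = total_steps // save_steps

print(effective_batch_size, steps_per_epoch, total_steps, num_saves)  # 4 12500 25000 50
```

Doubling `per_device_train_batch_size` (memory permitting) or using gradient accumulation would halve the number of steps for the same number of epochs.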

Finally, we need a special kind of data collator, which pads not only the inputs but also the labels to the maximum length in the batch:

In [24]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
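The label side of this collator can be sketched as follows. By default `DataCollatorForSeq2Seq` pads labels with `label_pad_token_id=-100`, which is the default `ignore_index` of PyTorch's cross-entropy loss, so padded positions do not contribute to the loss. (According to its documentation, the `model` argument is used to prepare `decoder_input_ids` from the labels.) The helpers and token ids below are hypothetical and only mimic right-padding of the labels, which is the side T5's tokenizer pads on.

```python
LABEL_PAD_ID = -100  # default label_pad_token_id of DataCollatorForSeq2Seq


def pad_labels(label_batch, label_pad_token_id=LABEL_PAD_ID):
    """Right-pad label sequences to the longest one in the batch."""
    longest = max(len(labels) for labels in label_batch)
    return [labels + [label_pad_token_id] * (longest - len(labels)) for labels in label_batch]


def loss_positions(padded_labels, ignore_index=LABEL_PAD_ID):
    """1 where a position contributes to the cross-entropy loss, 0 where it is ignored."""
    return [[int(tok != ignore_index) for tok in row] for row in padded_labels]


batch_labels = [[363, 19, 1], [27, 1]]  # hypothetical target ids; 1 is T5's </s>
padded = pad_labels(batch_labels)
print(padded)                  # [[363, 19, 1], [27, 1, -100]]
print(loss_positions(padded))  # [[1, 1, 1], [1, 1, 0]]
```

This is also why `compute_metrics` has to replace `-100` before decoding the labels: it is a loss-masking sentinel, not a token in the vocabulary.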

The last thing to define for our `Seq2SeqTrainer` is how to compute the metrics from the predictions. We need to define a function for this, which will just use the `metric` we loaded earlier, and we have to do a bit of pre-processing to decode the predictions into texts:

In [25]:
import nltk
import numpy as np

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # rougeLsum expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}
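The two NumPy steps in `compute_metrics` (replacing `-100` and counting generated tokens) can be checked on toy arrays without a tokenizer. The arrays below are hypothetical; the sketch assumes T5's pad id of 0, which T5 also uses as its decoder start token, so the leading token of each generated sequence is not counted in `gen_len`.

```python
import numpy as np

PAD_ID = 0  # T5's pad token id (assumed here); also T5's decoder start token

# Hypothetical padded label batch (-100 marks positions the collator padded).
labels = np.array([[71, 1, -100, -100],
                   [55, 66, 77, 1]])
# Hypothetical generated ids: each row starts with the decoder start token (0).
predictions = np.array([[0, 71, 1, 0],
                        [0, 55, 66, 1]])

# -100 is not a valid vocabulary id, so swap it for the pad id before decoding.
clean_labels = np.where(labels != -100, labels, PAD_ID)

# Generated length = number of non-pad tokens per prediction, averaged over the batch.
gen_len = np.mean([np.count_nonzero(pred != PAD_ID) for pred in predictions])

print(clean_labels.tolist())  # [[71, 1, 0, 0], [55, 66, 77, 1]]
print(gen_len)                # 2.5
```

If every prediction ends up with the same length, as the constant `gen_len` of 19.0 in the results may suggest, generation is probably being cut off by a length limit rather than ending naturally; depending on your Transformers version, `Seq2SeqTrainingArguments` may accept a `generation_max_length` argument to raise it.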

Then we just need to pass all of this along with our datasets to the `Seq2SeqTrainer`:

In [26]:
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

Using amp half precision backend


We can now fine-tune our model by just calling the `train` method:

In [27]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: summary, document.
***** Running training *****
  Num examples = 50000
  Num Epochs = 2
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 25000



Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-500
Configuration saved in t5-small-finetuned-xsum/checkpoint-500/config.json
Model weights saved in t5-small-finetuned-xsum/checkpoint-500/pytorch_model.bin
tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-500/tokenizer_config.json
Special tokens file saved in t5-small-finetuned-xsum/checkpoint-500/special_tokens_map.json


| Epoch | Training Loss | Validation Loss | Rouge1 | Rouge2 | RougeL | RougeLsum | Gen Len |
|---|---|---|---|---|---|---|---|
| 1 | 1.1773 | 1.084538 | 19.9926 | 9.3496 | 17.6834 | 19.1233 | 19.0 |
| 2 | 1.1649 | 1.079481 | 20.043 | 9.3599 | 17.7692 | 19.1832 | 19.0 |


Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-1000
Configuration saved in t5-small-finetuned-xsum/checkpoint-1000/config.json
Model weights saved in t5-small-finetuned-xsum/checkpoint-1000/pytorch_model.bin
tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-1000/tokenizer_config.json
Special tokens file saved in t5-small-finetuned-xsum/checkpoint-1000/special_tokens_map.json
Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-1500
Configuration saved in t5-small-finetuned-xsum/checkpoint-1500/config.json
Model weights saved in t5-small-finetuned-xsum/checkpoint-1500/pytorch_model.bin
tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-1500/tokenizer_config.json
Special tokens file saved in t5-small-finetuned-xsum/checkpoint-1500/special_tokens_map.json
Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-2000
Configuration saved in t5-small-finetuned-xsum/checkpoint-2000/config.json
Model weights saved in t5-small-finetune

TrainOutput(global_step=25000, training_loss=1.1962815795898438, metrics={'train_runtime': 7104.5361, 'train_samples_per_second': 14.076, 'train_steps_per_second': 3.519, 'total_flos': 2.6782396813541376e+16, 'train_loss': 1.1962815795898438, 'epoch': 2.0})

Now we can evaluate our fine-tuned model on the test data (the `eval_dataset` passed to the trainer).

In [29]:
trainer.evaluate()

The following columns in the evaluation set  don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: summary, document.
***** Running Evaluation *****
  Num examples = 5000
  Batch size = 4


{'epoch': 2.0,
 'eval_gen_len': 19.0,
 'eval_loss': 1.0794814825057983,
 'eval_rouge1': 20.043,
 'eval_rouge2': 9.3599,
 'eval_rougeL': 17.7692,
 'eval_rougeLsum': 19.1832,
 'eval_runtime': 355.6721,
 'eval_samples_per_second': 14.058,
 'eval_steps_per_second': 3.514}

Finally, we save the model so that it can be reused later as a predictor.

In [30]:
trainer.save_model("sum_model_")

Saving model checkpoint to sum_model_
Configuration saved in sum_model_/config.json
Model weights saved in sum_model_/pytorch_model.bin
tokenizer config file saved in sum_model_/tokenizer_config.json
Special tokens file saved in sum_model_/special_tokens_map.json
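A minimal sketch of using the saved checkpoint as a predictor, assuming Transformers and PyTorch are installed and `sum_model_` is the directory written above (the log shows the tokenizer was saved there too, because it was passed to the trainer). The helper names, the beam-search setting `num_beams=4` and the defaults are illustrative choices, not part of the original notebook; evaluation above used the trainer's own generation settings.

```python
from functools import lru_cache


@lru_cache(maxsize=1)
def load_summarizer(model_dir="sum_model_"):
    """Load the tokenizer and model saved by trainer.save_model (once, then cached)."""
    # Imported lazily so that defining these helpers does not require Transformers.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
    model.eval()  # disable dropout for inference
    return tokenizer, model


def summarize(text, model_dir="sum_model_", prefix="summarize: ",
              max_input_length=1024, max_target_length=128, num_beams=4):
    """Generate a summary for one article with the fine-tuned checkpoint."""
    tokenizer, model = load_summarizer(model_dir)
    # Reuse the same task prefix and input truncation as during fine-tuning.
    inputs = tokenizer(prefix + text, max_length=max_input_length,
                       truncation=True, return_tensors="pt")
    output_ids = model.generate(**inputs, max_length=max_target_length, num_beams=num_beams)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
```

For example, `summarize(raw_datasets["test"][0]["document"])` would return a summary string for the first test article. Caching the loaded model avoids reloading the weights on every call.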
