In [2]:
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Experiment configuration.
dataset = "xsum"                   # Hub dataset id passed to load_dataset
seed_num = 1                       # seed for the random number generators
model_name = "google-t5/t5-small"  # Hub checkpoint id; used to load the tokenizer

# Apply the seed so anything stochastic later in the notebook is reproducible.
# When training with torch, also call transformers.set_seed(seed_num), which
# seeds Python's `random`, NumPy and torch together.
np.random.seed(seed_num)

In [4]:
# Load every split of the configured dataset into a single DatasetDict.
loaded_dataset = load_dataset(path=dataset)

Downloading builder script: 100%|██████████| 5.76k/5.76k [00:00<00:00, 10.7MB/s]
Downloading readme: 100%|██████████| 6.24k/6.24k [00:00<00:00, 12.5MB/s]
Downloading data: 100%|██████████| 255M/255M [01:08<00:00, 3.71MB/s] 
Downloading data: 2.72MB [00:00, 16.1MB/s]                           
Generating train split: 100%|██████████| 204045/204045 [00:32<00:00, 6185.20 examples/s]
Generating validation split: 100%|██████████| 11332/11332 [00:18<00:00, 611.31 examples/s]
Generating test split: 100%|██████████| 11334/11334 [00:18<00:00, 611.93 examples/s]


In [6]:
# Inspect the available splits, their columns and row counts.
# (The training split is converted to a pandas DataFrame after tokenization.)
loaded_dataset

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11334
    })
})

In [8]:
# Load the tokenizer that belongs to the `model_name` checkpoint; the
# preprocessing cell below uses it for both the documents and the summaries.
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [9]:
# T5 is a text-to-text model that picks its task from a text prefix, so every
# document is prefixed with "summarize: " before tokenization.
prefix = "summarize: "


def preprocess_function(examples, max_input_length=None, max_target_length=None):
    """Tokenize a batch of documents (model inputs) and summaries (labels).

    Intended for ``Dataset.map(..., batched=True)``: ``examples`` maps column
    names to lists of values and must contain "document" and "summary".

    Args:
        examples: batch with "document" and "summary" columns.
        max_input_length: if given, truncate tokenized documents to this many
            tokens (e.g. ``tokenizer.model_max_length``). ``None`` (default)
            keeps full sequences, so length statistics reflect the raw data.
        max_target_length: same as ``max_input_length``, for the summaries.

    Returns:
        The tokenizer output for the prefixed documents, with an added
        ``labels`` entry holding the summary token ids.
    """
    inputs = [prefix + doc for doc in examples["document"]]
    # Truncation is only switched on when a limit is supplied; with no limit the
    # tokenizer warns about sequences longer than the model maximum.
    model_inputs = tokenizer(
        inputs,
        max_length=max_input_length,
        truncation=max_input_length is not None,
    )
    labels = tokenizer(
        text_target=examples["summary"],
        max_length=max_target_length,
        truncation=max_target_length is not None,
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# `model` in DataCollatorForSeq2Seq expects a model *instance* (it is used to
# build decoder_input_ids via prepare_decoder_input_ids_from_labels). Passing the
# checkpoint-name string was incorrect and had no effect, so it is omitted; pass
# the loaded seq2seq model here once it is instantiated.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

# Tokenize every split; the original columns (document, summary, id) are kept
# because later cells inspect them.
tokenized_dataset = loaded_dataset.map(preprocess_function, batched=True)

Map:   0%|          | 0/204045 [00:00<?, ? examples/s]Token indices sequence length is longer than the specified maximum sequence length for this model (541 > 512). Running this sequence through the model will result in indexing errors
Map: 100%|██████████| 204045/204045 [01:24<00:00, 2425.33 examples/s]
Map: 100%|██████████| 11332/11332 [00:05<00:00, 2237.08 examples/s]
Map: 100%|██████████| 11334/11334 [00:05<00:00, 2265.14 examples/s]


In [10]:
# Convert the tokenized training split into a pandas DataFrame for exploration
# and show its last few rows.
df = pd.DataFrame(data=tokenized_dataset["train"])
df.tail(n=5)

Unnamed: 0,document,summary,id,input_ids,attention_mask,labels
204040,The initial figure released in July was booste...,UK economic growth for the second quarter of t...,34084759,"[21603, 10, 37, 2332, 2320, 1883, 16, 1718, 47...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1270, 1456, 1170, 21, 8, 511, 2893, 13, 8, 21..."
204041,"MEPs, including European Parliament chief Brex...",Theresa May's offer to give EU citizens in the...,40552318,"[21603, 10, 283, 8569, 7, 6, 379, 1611, 12876,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[290, 7, 9, 932, 31, 7, 462, 12, 428, 3371, 51..."
204042,Lincoln Red Imps will bring a 1-0 lead to Glas...,Erik Sviatchenko is adamant that Celtic will p...,36781065,"[21603, 10, 9884, 1624, 14472, 7, 56, 830, 3, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[21173, 180, 2099, 14547, 18994, 19, 3, 9, 781..."
204043,Former Liverpool defender Mark Lawrenson expan...,People have spent a large part of this season ...,31579588,"[21603, 10, 18263, 15131, 3, 13720, 2185, 2402...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[2449, 43, 1869, 3, 9, 508, 294, 13, 48, 774, ..."
204044,The incident occurred at the headquarters of t...,Police in Thailand have charged two executives...,35809055,"[21603, 10, 37, 5415, 6935, 44, 8, 13767, 13, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[5076, 16, 10508, 43, 4977, 192, 13510, 45, 3,..."


In [12]:
# Show the first training document in full; pprint wraps the long string.
pprint(df.loc[0, "document"])

('The full cost of damage in Newton Stewart, one of the areas worst affected, '
 'is still being assessed.\n'
 'Repair work is ongoing in Hawick and many roads in Peeblesshire remain badly '
 'affected by standing water.\n'
 'Trains on the west coast mainline face disruption due to damage at the '
 'Lamington Viaduct.\n'
 'Many businesses and householders were affected by flooding in Newton Stewart '
 'after the River Cree overflowed into the town.\n'
 'First Minister Nicola Sturgeon visited the area to inspect the damage.\n'
 'The waters breached a retaining wall, flooding many commercial properties on '
 'Victoria Street - the main shopping thoroughfare.\n'
 'Jeanette Tate, who owns the Cinnamon Cafe which was badly affected, said she '
 'could not fault the multi-agency response once the flood hit.\n'
 'However, she said more preventative work could have been carried out to '
 'ensure the retaining wall did not fail.\n'
 '"It is difficult but I do think there is so much publicity fo

In [13]:
# Distribution of tokenized document lengths (percentiles + histogram), to help
# choose a truncation length. The dashed line marks the tokenizer's
# model_max_length; longer inputs trigger the tokenizer's length warning.
input_lengths = df["input_ids"].map(len)

hist_upper = input_lengths.quantile(0.99)  # clip the very long tail for readability
fig, ax = plt.subplots()
ax.hist(input_lengths, bins=100, range=(0, hist_upper))
if tokenizer.model_max_length < hist_upper:
    ax.axvline(tokenizer.model_max_length, color="red", linestyle="--",
               label="model_max_length")
    ax.legend()
ax.set(title="Tokenized document length (up to 99th percentile)",
       xlabel="tokens", ylabel="documents")

input_lengths.describe(percentiles=[0.25, 0.5, 0.75, 0.9, 0.95, 0.99])

count    204045.000000
mean        525.922223
std         438.174692
min           3.000000
25%         249.000000
50%         412.000000
75%         682.000000
90%        1061.000000
95%        1309.000000
99%        1937.000000
max       39490.000000
Name: input_ids, dtype: float64

In [14]:
# Percentile summary of the tokenized summary (label) lengths.
label_lengths = df["labels"].map(len)
label_lengths.describe(percentiles=[0.25, 0.5, 0.75, 0.9, 0.95, 0.99])

count    204045.000000
mean         30.383719
std           8.337640
min           3.000000
25%          25.000000
50%          30.000000
75%          35.000000
90%          39.000000
95%          43.000000
99%          54.000000
max         178.000000
Name: labels, dtype: float64