In [1]:
from datasets import load_dataset

# Load the SAMSum dataset from the hub
dataset = load_dataset('samsum')

# Report how many examples the train and test splits contain
train_size = len(dataset['train'])
test_size = len(dataset['test'])
print(f"Train dataset size: {train_size}")
print(f"Test dataset size: {test_size}")

Train dataset size: 14732
Test dataset size: 819


In [2]:
from random import randrange        


# Pick one training example at a random index and show its dialogue/summary pair
train_split = dataset["train"]
sample = train_split[randrange(len(train_split))]
for field in ("dialogue", "summary"):
    print(f"{field}: \n{sample[field]}\n---------------")

dialogue: 
Ethan: Which gas station has the best prices?
Alice: I always tank at the Tesco 
Sara: Which one?
Alice: The one out of town
Sara: I heard they have cheap gas there
Alice: It's always few cents per litre cheaper
Ethan: Good to know. Thanks!! 
---------------
summary: 
The Tesco that's out of town has the cheapest gas according to Alice and Sara.
---------------


In [3]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Hub id of the checkpoint to fine-tune; reused later to load the model and to
# derive the output directory name
model_id="google/flan-t5-small"

# Load the tokenizer belonging to the checkpoint named in `model_id` (FLAN-T5 small)
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [4]:
from datasets import concatenate_datasets

# Tokenize every example of train + test once (no padding) to find the longest
# sequence on each side. The two maxima are later passed as `max_length` to the
# tokenizer in `preprocess_function`.
full_corpus = concatenate_datasets([dataset["train"], dataset["test"]])
raw_text_columns = ["dialogue", "summary"]

# Maximum total input sequence length after tokenization.
# Sequences longer than this will be truncated, sequences shorter will be padded.
# NOTE(review): truncation=True presumably caps lengths at the tokenizer's own
# maximum, which would explain a reported value of exactly 512 — confirm.
tokenized_inputs = full_corpus.map(
    lambda batch: tokenizer(batch["dialogue"], truncation=True),
    batched=True,
    remove_columns=raw_text_columns,
)
max_source_length = max(len(ids) for ids in tokenized_inputs["input_ids"])
print(f"Max source length: {max_source_length}")

# Maximum total sequence length for target text after tokenization.
# Sequences longer than this will be truncated, sequences shorter will be padded.
tokenized_targets = full_corpus.map(
    lambda batch: tokenizer(batch["summary"], truncation=True),
    batched=True,
    remove_columns=raw_text_columns,
)
max_target_length = max(len(ids) for ids in tokenized_targets["input_ids"])
print(f"Max target length: {max_target_length}")

Max source length: 512
Max target length: 95


In [5]:
def preprocess_function(sample, padding="max_length"):
    """Tokenize a batch of examples into model inputs plus label ids.

    Args:
        sample: Batch mapping providing a "dialogue" and a "summary" sequence of
            strings (this function is applied with ``batched=True``).
        padding: Padding strategy forwarded to the tokenizer for both inputs and
            targets. With "max_length", pad tokens in the labels are replaced by
            -100 so the loss ignores them.

    Returns:
        The tokenizer output for the prefixed dialogues, extended with a "labels"
        entry holding the (optionally masked) target token ids.
    """
    # T5 expects a task prefix in front of every input
    prompts = ["summarize: " + dialogue for dialogue in sample["dialogue"]]

    # Encode the source side, truncated/padded to the measured maximum source length
    model_inputs = tokenizer(prompts, max_length=max_source_length, padding=padding, truncation=True)

    # Encode the target side via the `text_target` keyword argument
    target_encoding = tokenizer(
        text_target=sample["summary"], max_length=max_target_length, padding=padding, truncation=True
    )
    label_ids = target_encoding["input_ids"]

    # When padding happens here, swap pad ids for -100 so padded positions are
    # ignored in the loss.
    if padding == "max_length":
        pad_id = tokenizer.pad_token_id
        label_ids = [[-100 if token_id == pad_id else token_id for token_id in seq] for seq in label_ids]

    model_inputs["labels"] = label_ids
    return model_inputs

# Apply the preprocessing to every split in batches, dropping the raw columns so
# that only the tokenized features remain
raw_columns = ["dialogue", "summary", "id"]
tokenized_dataset = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=raw_columns,
)
feature_names = list(tokenized_dataset['train'].features)
print(f"Keys of tokenized dataset: {feature_names}")

Map:   0%|          | 0/14732 [00:00<?, ? examples/s]

Keys of tokenized dataset: ['input_ids', 'attention_mask', 'labels']


In [6]:
from transformers import AutoModelForSeq2SeqLM

# Load the pretrained seq2seq model from the hub, using the same checkpoint id
# (`model_id`) that the tokenizer was loaded from
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

In [7]:
import evaluate
import nltk
import numpy as np
from nltk.tokenize import sent_tokenize
# Fetch NLTK's "punkt" data, needed by sent_tokenize (used in postprocess_text).
# NOTE(review): recent NLTK releases reportedly look for "punkt_tab" instead —
# confirm which resource the installed version requires.
nltk.download("punkt")

# Metric
# ROUGE scorer loaded through `evaluate`; compute_metrics calls metric.compute(...)
# on the decoded predictions and references
metric = evaluate.load("rouge")

def postprocess_text(preds, labels):
    """Normalise decoded predictions and references for ROUGE scoring.

    Every text is stripped of surrounding whitespace and then re-joined with one
    sentence per line, because rougeLSum expects a newline after each sentence.

    Args:
        preds: Iterable of decoded prediction strings.
        labels: Iterable of decoded reference strings.

    Returns:
        A ``(preds, labels)`` tuple of lists containing the reformatted strings.
    """
    def _one_sentence_per_line(text):
        # Strip first, then split into sentences and put each on its own line
        return "\n".join(sent_tokenize(text.strip()))

    preds = list(map(_one_sentence_per_line, preds))
    labels = list(map(_one_sentence_per_line, labels))
    return preds, labels

def compute_metrics(eval_preds):
    """Compute ROUGE scores and the mean generated length for an evaluation pass.

    Args:
        eval_preds: A ``(predictions, labels)`` pair of token-id arrays as passed
            by the trainer. ``predictions`` may itself be a tuple, in which case
            its first element is used.

    Returns:
        Dict with the ROUGE values returned by ``metric.compute`` scaled to
        percentages (rounded to 4 decimals) plus ``gen_len``, the mean number of
        non-padding tokens per prediction.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    pad_id = tokenizer.pad_token_id
    # -100 is not a vocabulary id and cannot be decoded. Labels carry it as the
    # ignore index, and predictions can carry it too when generated batches of
    # different lengths are padded together, so map it to the pad id in both.
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    result = {k: round(v * 100, 4) for k, v in result.items()}

    # Count only real tokens; computed after the -100 replacement so filler
    # positions are not counted as generated tokens.
    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return result

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\linha\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [8]:
from transformers import DataCollatorForSeq2Seq

# Labels are padded with -100 instead of the tokenizer's pad id so that padded
# label positions are ignored in the loss
label_pad_token_id = -100
# Data collator: assembles batches for the seq2seq model, padding labels with
# `label_pad_token_id`; pad_to_multiple_of=8 asks for padded lengths that are a
# multiple of 8
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=label_pad_token_id,
    pad_to_multiple_of=8
)

In [9]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Hugging Face repository id / output directory name: the checkpoint name without
# its organisation prefix plus the dataset name, e.g. "flan-t5-small-samsum".
# Taking the last path component also works for ids that have no "/" at all
# (indexing [1] would raise IndexError for those).
repository_id = f"{model_id.split('/')[-1]}-samsum"

# Training configuration
training_args = Seq2SeqTrainingArguments(
    output_dir=repository_id,
    # optimisation
    learning_rate=5e-5,
    num_train_epochs=10,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    fp16=False,  # Overflows with fp16
    # run generation during evaluation so compute_metrics receives generated ids
    predict_with_generate=True,
    # logging: every 50 steps, written for TensorBoard under <output_dir>/logs
    logging_dir=f"{repository_id}/logs",
    logging_strategy="steps",
    logging_steps=50,
    report_to="tensorboard",
    # evaluation & checkpointing: both once per epoch, keeping at most 2 checkpoints.
    # NOTE(review): newer transformers releases reportedly rename
    # `evaluation_strategy` to `eval_strategy` — confirm for the installed version.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    # metric_for_best_model is not set, so the library default decides which
    # checkpoint counts as "best" (presumably the eval loss — confirm in the docs)
    load_best_model_at_end=True,
)

# Create Trainer instance
# Trains on the tokenized train split and evaluates on the tokenized test split;
# compute_metrics turns the evaluation predictions into ROUGE scores and gen_len.
# NOTE(review): with load_best_model_at_end=True the test split also drives
# checkpoint selection, so scores reported on it are optimistic — consider a
# separate validation split if the dataset provides one.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    compute_metrics=compute_metrics,
)

In [10]:
# Run fine-tuning with the trainer configured above; per the training arguments,
# evaluation and checkpointing happen once per epoch
trainer.train()

  0%|          | 0/36830 [00:00<?, ?it/s]

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'loss': 1.9209, 'learning_rate': 4.993212055389628e-05, 'epoch': 0.01}
{'loss': 1.9404, 'learning_rate': 4.986424110779257e-05, 'epoch': 0.03}
{'loss': 1.9285, 'learning_rate': 4.9796361661688847e-05, 'epoch': 0.04}
{'loss': 1.7874, 'learning_rate': 4.972848221558512e-05, 'epoch': 0.05}
{'loss': 1.8714, 'learning_rate': 4.96606027694814e-05, 'epoch': 0.07}
{'loss': 1.8926, 'learning_rate': 4.9592723323377684e-05, 'epoch': 0.08}
{'loss': 1.7895, 'learning_rate': 4.952484387727396e-05, 'epoch': 0.1}
{'loss': 1.8691, 'learning_rate': 4.945696443117024e-05, 'epoch': 0.11}
{'loss': 1.8162, 'learning_rate': 4.938908498506653e-05, 'epoch': 0.12}
{'loss': 1.8048, 'learning_rate': 4.932120553896281e-05, 'epoch': 0.14}
{'loss': 1.8271, 'learning_rate': 4.9253326092859086e-05, 'epoch': 0.15}
{'loss': 1.8404, 'learning_rate': 4.9185446646755365e-05, 'epoch': 0.16}
{'loss': 1.8631, 'learning_rate': 4.9117567200651645e-05, 'epoch': 0.18}
{'loss': 1.8776, 'learning_rate': 4.9049687754547924e-05, 'ep



  0%|          | 0/205 [00:00<?, ?it/s]

{'eval_loss': 1.6491508483886719, 'eval_rouge1': 43.5063, 'eval_rouge2': 19.2741, 'eval_rougeL': 36.212, 'eval_rougeLsum': 39.6495, 'eval_gen_len': 16.95848595848596, 'eval_runtime': 54.9921, 'eval_samples_per_second': 14.893, 'eval_steps_per_second': 3.728, 'epoch': 1.0}
{'loss': 1.7974, 'learning_rate': 4.497692098832474e-05, 'epoch': 1.0}
{'loss': 1.7524, 'learning_rate': 4.490904154222102e-05, 'epoch': 1.02}
{'loss': 1.734, 'learning_rate': 4.48411620961173e-05, 'epoch': 1.03}
{'loss': 1.7572, 'learning_rate': 4.4773282650013576e-05, 'epoch': 1.05}
{'loss': 1.7186, 'learning_rate': 4.4705403203909855e-05, 'epoch': 1.06}
{'loss': 1.6856, 'learning_rate': 4.4637523757806134e-05, 'epoch': 1.07}
{'loss': 1.6907, 'learning_rate': 4.456964431170242e-05, 'epoch': 1.09}
{'loss': 1.7638, 'learning_rate': 4.45017648655987e-05, 'epoch': 1.1}
{'loss': 1.7183, 'learning_rate': 4.443388541949498e-05, 'epoch': 1.11}
{'loss': 1.6803, 'learning_rate': 4.4366005973391264e-05, 'epoch': 1.13}
{'loss':



  0%|          | 0/205 [00:00<?, ?it/s]

{'eval_loss': 1.6354368925094604, 'eval_rouge1': 43.1225, 'eval_rouge2': 19.2661, 'eval_rougeL': 36.0203, 'eval_rougeLsum': 39.2956, 'eval_gen_len': 16.532356532356534, 'eval_runtime': 55.9297, 'eval_samples_per_second': 14.643, 'eval_steps_per_second': 3.665, 'epoch': 2.0}
{'loss': 1.6372, 'learning_rate': 3.995384197664947e-05, 'epoch': 2.01}
{'loss': 1.6198, 'learning_rate': 3.988596253054575e-05, 'epoch': 2.02}
{'loss': 1.707, 'learning_rate': 3.9818083084442033e-05, 'epoch': 2.04}
{'loss': 1.654, 'learning_rate': 3.975020363833831e-05, 'epoch': 2.05}
{'loss': 1.586, 'learning_rate': 3.968232419223459e-05, 'epoch': 2.06}
{'loss': 1.7352, 'learning_rate': 3.961444474613088e-05, 'epoch': 2.08}
{'loss': 1.7057, 'learning_rate': 3.954656530002716e-05, 'epoch': 2.09}
{'loss': 1.6854, 'learning_rate': 3.9478685853923436e-05, 'epoch': 2.1}
{'loss': 1.7343, 'learning_rate': 3.941080640781971e-05, 'epoch': 2.12}
{'loss': 1.6417, 'learning_rate': 3.9342926961715994e-05, 'epoch': 2.13}
{'loss



  0%|          | 0/205 [00:00<?, ?it/s]

{'eval_loss': 1.6240123510360718, 'eval_rouge1': 43.0724, 'eval_rouge2': 19.4419, 'eval_rougeL': 35.9736, 'eval_rougeLsum': 39.4537, 'eval_gen_len': 16.67032967032967, 'eval_runtime': 135.7165, 'eval_samples_per_second': 6.035, 'eval_steps_per_second': 1.511, 'epoch': 3.0}
{'loss': 1.6287, 'learning_rate': 3.4998642411077926e-05, 'epoch': 3.0}
{'loss': 1.6522, 'learning_rate': 3.4930762964974205e-05, 'epoch': 3.01}
{'loss': 1.6321, 'learning_rate': 3.4862883518870484e-05, 'epoch': 3.03}
{'loss': 1.601, 'learning_rate': 3.479500407276677e-05, 'epoch': 3.04}
{'loss': 1.5625, 'learning_rate': 3.472712462666305e-05, 'epoch': 3.05}
{'loss': 1.5314, 'learning_rate': 3.465924518055933e-05, 'epoch': 3.07}
{'loss': 1.5877, 'learning_rate': 3.4591365734455614e-05, 'epoch': 3.08}
{'loss': 1.5271, 'learning_rate': 3.4523486288351886e-05, 'epoch': 3.1}
{'loss': 1.6178, 'learning_rate': 3.4455606842248166e-05, 'epoch': 3.11}
{'loss': 1.6185, 'learning_rate': 3.4387727396144445e-05, 'epoch': 3.12}
{'



  0%|          | 0/205 [00:00<?, ?it/s]

{'eval_loss': 1.6198571920394897, 'eval_rouge1': 44.1204, 'eval_rouge2': 20.2736, 'eval_rougeL': 36.8841, 'eval_rougeLsum': 40.4502, 'eval_gen_len': 16.605616605616607, 'eval_runtime': 151.8226, 'eval_samples_per_second': 5.394, 'eval_steps_per_second': 1.35, 'epoch': 4.0}
{'loss': 1.5812, 'learning_rate': 2.9975563399402662e-05, 'epoch': 4.0}
{'loss': 1.558, 'learning_rate': 2.990768395329894e-05, 'epoch': 4.02}
{'loss': 1.5103, 'learning_rate': 2.9839804507195224e-05, 'epoch': 4.03}
{'loss': 1.573, 'learning_rate': 2.9771925061091503e-05, 'epoch': 4.05}
{'loss': 1.5692, 'learning_rate': 2.9704045614987785e-05, 'epoch': 4.06}
{'loss': 1.5176, 'learning_rate': 2.963616616888406e-05, 'epoch': 4.07}
{'loss': 1.5898, 'learning_rate': 2.956828672278034e-05, 'epoch': 4.09}
{'loss': 1.6016, 'learning_rate': 2.9500407276676623e-05, 'epoch': 4.1}
{'loss': 1.5455, 'learning_rate': 2.9432527830572902e-05, 'epoch': 4.11}
{'loss': 1.5184, 'learning_rate': 2.9364648384469184e-05, 'epoch': 4.13}
{'l



  0%|          | 0/205 [00:00<?, ?it/s]

{'eval_loss': 1.6169418096542358, 'eval_rouge1': 44.153, 'eval_rouge2': 20.1367, 'eval_rougeL': 36.8736, 'eval_rougeLsum': 40.4585, 'eval_gen_len': 16.804639804639805, 'eval_runtime': 56.39, 'eval_samples_per_second': 14.524, 'eval_steps_per_second': 3.635, 'epoch': 5.0}
{'loss': 1.5238, 'learning_rate': 2.49524843877274e-05, 'epoch': 5.01}
{'loss': 1.4867, 'learning_rate': 2.4884604941623678e-05, 'epoch': 5.02}
{'loss': 1.5461, 'learning_rate': 2.4816725495519957e-05, 'epoch': 5.04}
{'loss': 1.4787, 'learning_rate': 2.474884604941624e-05, 'epoch': 5.05}
{'loss': 1.5038, 'learning_rate': 2.468096660331252e-05, 'epoch': 5.06}
{'loss': 1.5701, 'learning_rate': 2.4613087157208798e-05, 'epoch': 5.08}
{'loss': 1.4932, 'learning_rate': 2.4545207711105077e-05, 'epoch': 5.09}
{'loss': 1.5636, 'learning_rate': 2.447732826500136e-05, 'epoch': 5.1}
{'loss': 1.492, 'learning_rate': 2.440944881889764e-05, 'epoch': 5.12}
{'loss': 1.4649, 'learning_rate': 2.4341569372793918e-05, 'epoch': 5.13}
{'loss



  0%|          | 0/205 [00:00<?, ?it/s]

{'eval_loss': 1.6206824779510498, 'eval_rouge1': 44.733, 'eval_rouge2': 20.4625, 'eval_rougeL': 37.1198, 'eval_rougeLsum': 40.871, 'eval_gen_len': 16.978021978021978, 'eval_runtime': 54.0477, 'eval_samples_per_second': 15.153, 'eval_steps_per_second': 3.793, 'epoch': 6.0}
{'loss': 1.4734, 'learning_rate': 1.9997284822155853e-05, 'epoch': 6.0}
{'loss': 1.501, 'learning_rate': 1.992940537605213e-05, 'epoch': 6.01}
{'loss': 1.5319, 'learning_rate': 1.9861525929948414e-05, 'epoch': 6.03}
{'loss': 1.4842, 'learning_rate': 1.9793646483844693e-05, 'epoch': 6.04}
{'loss': 1.5514, 'learning_rate': 1.9725767037740972e-05, 'epoch': 6.05}
{'loss': 1.3935, 'learning_rate': 1.965788759163725e-05, 'epoch': 6.07}
{'loss': 1.4688, 'learning_rate': 1.9590008145533534e-05, 'epoch': 6.08}
{'loss': 1.4952, 'learning_rate': 1.9522128699429813e-05, 'epoch': 6.1}
{'loss': 1.5246, 'learning_rate': 1.9454249253326092e-05, 'epoch': 6.11}
{'loss': 1.4749, 'learning_rate': 1.9386369807222375e-05, 'epoch': 6.12}
{'



  0%|          | 0/205 [00:00<?, ?it/s]

{'eval_loss': 1.61181640625, 'eval_rouge1': 44.4626, 'eval_rouge2': 20.3587, 'eval_rougeL': 36.8987, 'eval_rougeLsum': 40.6886, 'eval_gen_len': 16.912087912087912, 'eval_runtime': 56.3581, 'eval_samples_per_second': 14.532, 'eval_steps_per_second': 3.637, 'epoch': 7.0}
{'loss': 1.5232, 'learning_rate': 1.4974205810480587e-05, 'epoch': 7.01}
{'loss': 1.4667, 'learning_rate': 1.4906326364376868e-05, 'epoch': 7.02}
{'loss': 1.4187, 'learning_rate': 1.4838446918273147e-05, 'epoch': 7.03}
{'loss': 1.4442, 'learning_rate': 1.4770567472169428e-05, 'epoch': 7.05}
{'loss': 1.4759, 'learning_rate': 1.4702688026065709e-05, 'epoch': 7.06}
{'loss': 1.5311, 'learning_rate': 1.463480857996199e-05, 'epoch': 7.07}
{'loss': 1.451, 'learning_rate': 1.4566929133858267e-05, 'epoch': 7.09}
{'loss': 1.5007, 'learning_rate': 1.4499049687754548e-05, 'epoch': 7.1}
{'loss': 1.4924, 'learning_rate': 1.4431170241650829e-05, 'epoch': 7.11}
{'loss': 1.5019, 'learning_rate': 1.436329079554711e-05, 'epoch': 7.13}
{'lo



  0%|          | 0/205 [00:00<?, ?it/s]

{'eval_loss': 1.615966796875, 'eval_rouge1': 44.2972, 'eval_rouge2': 19.9653, 'eval_rougeL': 36.8154, 'eval_rougeLsum': 40.4101, 'eval_gen_len': 16.746031746031747, 'eval_runtime': 53.6182, 'eval_samples_per_second': 15.275, 'eval_steps_per_second': 3.823, 'epoch': 8.0}
{'loss': 1.4716, 'learning_rate': 9.951126798805322e-06, 'epoch': 8.01}
{'loss': 1.4544, 'learning_rate': 9.883247352701603e-06, 'epoch': 8.02}
{'loss': 1.5091, 'learning_rate': 9.815367906597884e-06, 'epoch': 8.04}
{'loss': 1.4929, 'learning_rate': 9.747488460494163e-06, 'epoch': 8.05}
{'loss': 1.4317, 'learning_rate': 9.679609014390444e-06, 'epoch': 8.06}
{'loss': 1.4626, 'learning_rate': 9.611729568286723e-06, 'epoch': 8.08}
{'loss': 1.496, 'learning_rate': 9.543850122183004e-06, 'epoch': 8.09}
{'loss': 1.4512, 'learning_rate': 9.475970676079284e-06, 'epoch': 8.1}
{'loss': 1.4381, 'learning_rate': 9.408091229975565e-06, 'epoch': 8.12}
{'loss': 1.3879, 'learning_rate': 9.340211783871844e-06, 'epoch': 8.13}
{'loss': 1.



  0%|          | 0/205 [00:00<?, ?it/s]

{'eval_loss': 1.6165939569473267, 'eval_rouge1': 44.1914, 'eval_rouge2': 19.9611, 'eval_rougeL': 36.6953, 'eval_rougeLsum': 40.3145, 'eval_gen_len': 16.913308913308914, 'eval_runtime': 52.872, 'eval_samples_per_second': 15.49, 'eval_steps_per_second': 3.877, 'epoch': 9.0}
{'loss': 1.4229, 'learning_rate': 4.995927233233777e-06, 'epoch': 9.0}
{'loss': 1.4434, 'learning_rate': 4.9280477871300576e-06, 'epoch': 9.01}
{'loss': 1.4977, 'learning_rate': 4.8601683410263375e-06, 'epoch': 9.03}
{'loss': 1.4452, 'learning_rate': 4.7922888949226175e-06, 'epoch': 9.04}
{'loss': 1.4682, 'learning_rate': 4.7244094488188975e-06, 'epoch': 9.06}
{'loss': 1.4926, 'learning_rate': 4.656530002715178e-06, 'epoch': 9.07}
{'loss': 1.4597, 'learning_rate': 4.588650556611458e-06, 'epoch': 9.08}
{'loss': 1.4088, 'learning_rate': 4.520771110507738e-06, 'epoch': 9.1}
{'loss': 1.5238, 'learning_rate': 4.452891664404018e-06, 'epoch': 9.11}
{'loss': 1.4693, 'learning_rate': 4.385012218300299e-06, 'epoch': 9.12}
{'los



  0%|          | 0/205 [00:00<?, ?it/s]

{'eval_loss': 1.617746114730835, 'eval_rouge1': 44.4606, 'eval_rouge2': 20.1139, 'eval_rougeL': 36.8539, 'eval_rougeLsum': 40.5396, 'eval_gen_len': 16.94871794871795, 'eval_runtime': 56.2279, 'eval_samples_per_second': 14.566, 'eval_steps_per_second': 3.646, 'epoch': 10.0}


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].


{'train_runtime': 6683.0116, 'train_samples_per_second': 22.044, 'train_steps_per_second': 5.511, 'train_loss': 1.5758347343757624, 'epoch': 10.0}


TrainOutput(global_step=36830, training_loss=1.5758347343757624, metrics={'train_runtime': 6683.0116, 'train_samples_per_second': 22.044, 'train_steps_per_second': 5.511, 'train_loss': 1.5758347343757624, 'epoch': 10.0})

In [11]:
# Evaluate the model held by the trainer on its eval_dataset (the tokenized test
# split). Kept as the cell's last expression so the metrics dict is displayed.
trainer.evaluate()



  0%|          | 0/205 [00:00<?, ?it/s]

{'eval_loss': 1.3694512844085693,
 'eval_rouge1': 47.5774,
 'eval_rouge2': 24.1721,
 'eval_rougeL': 40.1577,
 'eval_rougeLsum': 43.7501,
 'eval_gen_len': 17.024420024420024,
 'eval_runtime': 127.8578,
 'eval_samples_per_second': 6.406,
 'eval_steps_per_second': 1.603,
 'epoch': 10.0}