# Fine-tuning a T5 to summarize customer feedback
In this notebook, the 🤗 T5 Transformers model is fine-tunded for a summarization task. A simulated dataset of customer feedback messages to an insurance platform is used for training and evaluation.

This notebook is built to run with any model checkpoint from the [Model Hub](https://huggingface.co/models) as long as that *model has a sequence-to-sequence version in the Transformers library*. Here we picked the t5-small checkpoint.

**SEE COLAB VERSION: [HERE](https://colab.research.google.com/drive/12OI_Q0SzhA-Ik17EkC-T0NNZN9CNSRPx?usp=sharing)**

In [1]:
!pip install datasets evaluate transformers[torch] rouge-score nltk jsonlines

Collecting datasets
  Downloading datasets-2.19.0-py3-none-any.whl (542 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m542.0/542.0 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting evaluate
  Downloading evaluate-0.4.2-py3-none-any.whl (84 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
Collecting rouge-score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting jsonlines
  Downloading jsonlines-4.0.0-py3-none-any.whl (8.7 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [8]:
import os
import tarfile
import random
import jsonlines
import json
import nltk
import numpy as np
import pandas as pd
from IPython.display import display, HTML
import datasets
from datasets import load_dataset, load_from_disk
from evaluate import load
import transformers
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

nltk.download("punkt")

os.environ["HF_TOKEN"] = "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"

print(transformers.__version__)

# check Python version
from platform import python_version
print(python_version())

tar_gz_path = 'hf_customer_feedback.tar.gz'
extraction_path = 'hf_customer_feedback'

# Training parameters
N_EPOCHS = 1
BATCH_SIZE = 16

4.40.1
3.10.12


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [9]:
metric = load("rouge")

model_checkpoint = "t5-small"
finetuned_model_name = f"{model_checkpoint}-finetuned-feedback"

model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

if model_checkpoint in ["t5-small", "t5-base", "t5-larg", "t5-3b", "t5-11b"]:
    prefix = "summarize: "
else:
    prefix = ""

max_input_length = 1024
max_target_length = 128

In [10]:
def preprocess_function(examples, label="summary", text="feedback"):
    inputs = [prefix + doc for doc in examples[text]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    labels = tokenizer(text_target=examples[label], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True, use_aggregator=True)
    result = {key: value * 100 for key, value in result.items()}

    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

In [12]:
with tarfile.open(tar_gz_path, 'r:gz') as tar:
    tar.extractall(path=extraction_path)

raw_datasets = load_from_disk(f'{extraction_path}/')
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['summary', 'feedback', 'id'],
        num_rows: 971
    })
    validation: Dataset({
        features: ['summary', 'feedback', 'id'],
        num_rows: 121
    })
    test: Dataset({
        features: ['summary', 'feedback', 'id'],
        num_rows: 122
    })
})

In [13]:
tokenized_datasets = raw_datasets.map(lambda x: preprocess_function(x, label="summary", text="feedback"), batched=True)
tokenized_datasets

Map:   0%|          | 0/971 [00:00<?, ? examples/s]

Map:   0%|          | 0/121 [00:00<?, ? examples/s]

Map:   0%|          | 0/122 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['summary', 'feedback', 'id', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 971
    })
    validation: Dataset({
        features: ['summary', 'feedback', 'id', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 121
    })
    test: Dataset({
        features: ['summary', 'feedback', 'id', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 122
    })
})

In [14]:
#model_name = model_checkpoint.split("/")[-1]
args = Seq2SeqTrainingArguments(
    finetuned_model_name,
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=N_EPOCHS,
    predict_with_generate=True,
    fp16=True,
    push_to_hub=True,
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()
# save model locally
trainer.save_model(finetuned_model_name)

Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1,No log,2.983158,24.9931,10.0881,21.9651,22.0687,16.4876
2,No log,2.182213,36.3348,17.5969,34.3034,34.2834,12.1653
3,No log,1.96071,43.7295,21.5907,41.8815,41.929,10.5372
4,No log,1.841159,48.7074,25.1744,46.8382,46.8399,10.405
5,No log,1.767351,50.1972,26.4116,48.1456,48.0538,10.2066
6,No log,1.71954,51.0984,27.8685,48.9483,49.0108,10.3554
7,No log,1.683163,50.272,27.3168,48.4083,48.4307,10.0331
8,No log,1.655822,50.6829,27.5132,48.6684,48.735,10.2727
9,2.363000,1.635744,50.0286,27.0674,48.0211,48.0783,10.1736
10,2.363000,1.623971,50.8207,26.8345,48.6528,48.6903,10.1983




model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

events.out.tfevents.1714573944.1c1f59b624b8.1026.0:   0%|          | 0.00/12.5k [00:00<?, ?B/s]

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

In [15]:
def generate_summary(document, model_checkpoint):
  model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
  if model_checkpoint in ["t5-small", "t5-base", "t5-larg", "t5-3b", "t5-11b"]:
    prefix = "summarize: "
  else:
    prefix = ""
  prompt = prefix+document
  model_inputs = tokenizer(prompt, max_length=max_input_length, truncation=True, return_tensors="pt")
  input_ids = model_inputs["input_ids"]
  outputs = model.generate(input_ids, max_length=max_target_length)
  summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
  return summary

def write_jsonl(data, fp):
  with open(fp, 'w') as file:
      for key, value in data.items():
          json_object = {'id': key, 'summary': value}
          json_line = json.dumps(json_object)
          file.write(json_line + '\n')
  return


def generate_summaries(test_data, model_checkpoint, fp):
  summaries = {}
  for sample in test_data:
    document = sample["feedback"]
    id = sample["id"]
    summary = generate_summary(document=document, model_checkpoint=model_checkpoint)
    summaries[id] = summary

  write_jsonl(data=summaries, fp=fp)
  return summaries


# Eval

In [16]:
test_data = raw_datasets["test"]
test_data

Dataset({
    features: ['summary', 'feedback', 'id'],
    num_rows: 122
})

## T5 small with fine-tuning

In [17]:
model_checkpoint = finetuned_model_name
t5_small_finetuned_fp = "t5-small-finetuned-summaries.jsonl"
t5_small_finetuned_summaries = generate_summaries(test_data=test_data, model_checkpoint=model_checkpoint, fp=t5_small_finetuned_fp)
print(len(t5_small_finetuned_summaries))

122


## T5-small without fine-tuning

In [18]:
model_checkpoint = "t5-small"
t5_small_fp = "t5-small-summaries.jsonl"
t5_small_summaries = generate_summaries(test_data=test_data, model_checkpoint=model_checkpoint, fp=t5_small_fp)
print(len(t5_small_summaries))

122


## T5-Base

In [19]:
model_checkpoint = "t5-base"
t5_base_fp = "t5-base-summaries.jsonl"
t5_base_summaries = generate_summaries(test_data=test_data, model_checkpoint=model_checkpoint, fp=t5_base_fp)
print(len(t5_base_summaries))

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

122


## Rouge and Bleu Scores

In [21]:
def score_summaries(summaries):
  ids = list(summaries.keys())
  predictions = [summaries[id] for id in ids]
  references = [test_data.filter(lambda x: x['id'] == id)['summary'][0] for id in ids]

  ## calculate rouge scores on summaries
  metric = load("rouge")
  metric.add_batch(predictions=predictions, references=references)
  rouge_results = metric.compute()

  # calculate bleu scores on summaries
  metric = load("bleu")
  metric.add_batch(predictions=predictions, references=references)
  bleu_results = metric.compute()
  bleu_results

  scores = {"rouge":rouge_results, "bleu":bleu_results}
  return scores

t5_small_finetuned_scores = score_summaries(summaries=t5_small_finetuned_summaries)
t5_small_scores = score_summaries(summaries=t5_small_summaries)
t5_base_scores = score_summaries(summaries=t5_base_summaries)
all_scores = {"t5-small-finetuned":t5_small_finetuned_scores, "t5-small":t5_small_scores, "t5-base":t5_base_scores}
all_scores_df  = pd.DataFrame(all_scores)
all_scores_df.to_csv("all_scores.csv")
all_scores_df

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter:   0%|          | 0/122 [00:00<?, ? examples/s]

Downloading builder script:   0%|          | 0.00/5.94k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.34k [00:00<?, ?B/s]

Unnamed: 0,t5-small-finetuned,t5-small,t5-base
rouge,"{'rouge1': 0.2535707655705478, 'rouge2': 0.113...","{'rouge1': 0.17416097478483022, 'rouge2': 0.06...","{'rouge1': 0.17704215935638012, 'rouge2': 0.05..."
bleu,"{'bleu': 0.04300318371763852, 'precisions': [0...","{'bleu': 0.024443221851953585, 'precisions': [...","{'bleu': 0.020017147677267606, 'precisions': [..."


In [22]:
all_scores_df

Unnamed: 0,t5-small-finetuned,t5-small,t5-base
rouge,"{'rouge1': 0.2535707655705478, 'rouge2': 0.113...","{'rouge1': 0.17416097478483022, 'rouge2': 0.06...","{'rouge1': 0.17704215935638012, 'rouge2': 0.05..."
bleu,"{'bleu': 0.04300318371763852, 'precisions': [0...","{'bleu': 0.024443221851953585, 'precisions': [...","{'bleu': 0.020017147677267606, 'precisions': [..."


In [23]:
import pickle
with open('all_scores.pickle', 'wb') as f:
    pickle.dump(all_scores, f)