# Training a summarization model

At the very beginning, it's essential to set up the environment.

In [None]:
import os

# Point both the HTTP and HTTPS proxy variables at 127.0.0.1:7890 so that
# the downloads performed by later cells go through that proxy.
os.environ.update(
    {
        "http_proxy": "127.0.0.1:7890",
        "https_proxy": "127.0.0.1:7890",
    }
)

In [None]:
%pip install py7zr

First, we need to prepare a dialogue summarization dataset (SAMSum).

In [None]:
from datasets import load_dataset
# Load the "samsum" dataset; the cells below read its "dialogue" and
# "summary" fields and its "train" / "validation" / "test" splits.
dataset = load_dataset("samsum")

Let's take a look at the size of this dataset

In [None]:
dataset

And some of the samples

In [None]:
def show_samples(dataset, num_samples=3, seed=42):
    """Print the summary and dialogue of a few shuffled training examples.

    Args:
        dataset: Mapping with a "train" split that supports
            ``shuffle(seed=...)`` and ``select(indices)``.
        num_samples: How many examples to print.
        seed: Seed passed to ``shuffle`` so the selection is reproducible.
    """
    shuffled_train = dataset["train"].shuffle(seed=seed)
    picked = shuffled_train.select(range(num_samples))
    for row in picked:
        # Summary first (preceded by a blank line), then the source dialogue.
        print(f"\n'>> :Summary: {row['summary']}'")
        print(f"'>> Dialogue: {row['dialogue']}'")


show_samples(dataset)

The next thing to do is to choose a proper model: mT5

Preprocessing: to test the model and datasets, we use a small model

In [None]:
from transformers import AutoTokenizer

# Checkpoint identifier; reused further down when the model itself is loaded.
model_checkpoint = "google/mT5-small"
# Load the tokenizer that belongs to this checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

Let's see if our tokenizer works well.

In [None]:
inputs = tokenizer("I love reading the course of natural language processing!")
inputs

We can also transform the ids to words by using convert_ids_to_tokens

In [None]:
tokenizer.convert_ids_to_tokens(inputs.input_ids)

To make sure we never feed the model text that is too long, we truncate the dialogues to `max_input_length` tokens and the summaries to `max_target_length` tokens.

In [None]:
max_input_length = 2048
max_target_length = 300

def preprocess_function(examples):
    """Tokenize a batch of dialogues and attach the tokenized summaries as labels.

    Args:
        examples: Batch with "dialogue" and "summary" entries.

    Returns:
        The tokenizer output for the dialogues, with an extra "labels" entry
        holding the input ids of the tokenized summaries.
    """
    # Dialogues are cut off at max_input_length tokens.
    encoded_dialogues = tokenizer(
        examples["dialogue"], truncation=True, max_length=max_input_length
    )
    # Summaries are cut off at max_target_length tokens.
    encoded_summaries = tokenizer(
        examples["summary"], truncation=True, max_length=max_target_length
    )
    encoded_dialogues["labels"] = encoded_summaries["input_ids"]
    return encoded_dialogues

With this function, we can tokenize the whole dataset.

In [None]:
tokenized_datasets = dataset.map(preprocess_function, batched=True)

The next step is to compute a baseline. ROUGE is a commonly used metric for summarization.

In [None]:
%pip install rouge_score

In [None]:
import evaluate

rouge_score =  evaluate.load("rouge")

We can check the rouge_score

In [None]:
generated_summary = "I absolutely loved reading the Hunger Games"
reference_summary = "I loved reading the Hunger Games"
scores = rouge_score.compute(
    predictions=[generated_summary], references=[reference_summary]
)
scores

Create a baseline: the first three sentences of each dialogue.

In [None]:
import nltk

# Sentence splitting with sent_tokenize needs the Punkt models. Older NLTK
# releases read the "punkt" resource, while NLTK 3.9+ looks for "punkt_tab",
# so fetch both to keep the notebook working across versions.
nltk.download("punkt")
nltk.download("punkt_tab")

In [None]:
from nltk.tokenize import sent_tokenize

def three_sentence_summary(text):
    """Return the first three sentences of ``text``, separated by newlines.

    Texts with fewer than three sentences are returned in full.
    """
    leading_sentences = sent_tokenize(text)[:3]
    return "\n".join(leading_sentences)


print(three_sentence_summary(dataset["train"][1]["dialogue"]))

We use the lead-3 as our baseline

In [None]:
def evaluate_baseline(dataset, metric):
    """Score the lead-3 baseline on one dataset split.

    Args:
        dataset: Split exposing "dialogue" and "summary" columns.
        metric: Object whose ``compute(predictions=..., references=...)``
            produces the scores.

    Returns:
        Whatever ``metric.compute`` returns for the lead-3 summaries measured
        against the reference summaries.
    """
    lead3_predictions = list(map(three_sentence_summary, dataset["dialogue"]))
    return metric.compute(
        predictions=lead3_predictions,
        references=dataset["summary"],
    )

In [None]:
score = evaluate_baseline(dataset["validation"], rouge_score)
score

The last thing: fine-tune

In [None]:
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

In [None]:
%pip install huggingface-hub

In [None]:
from huggingface_hub import notebook_login

notebook_login()

In [None]:
from transformers import Seq2SeqTrainingArguments

batch_size = 8
num_train_epochs = 8
# Log the training loss about once per epoch: number of full batches in the
# training split (assumes a single device - TODO confirm for multi-GPU runs).
logging_steps = len(tokenized_datasets["train"]) // batch_size
# Keep only the part of the checkpoint id after the last "/".
model_name = model_checkpoint.rsplit("/", 1)[-1]

# NOTE(review): the "-finetuned-amazon-en-es" suffix does not describe the
# samsum data used here; it is kept because the inference cell loads a Hub
# repo with this name.
args = Seq2SeqTrainingArguments(
    # Where to write checkpoints, and Hub publishing.
    output_dir=f"{model_name}-finetuned-amazon-en-es",
    save_total_limit=3,
    push_to_hub=True,
    # Optimisation.
    learning_rate=5.6e-5,
    weight_decay=0.01,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Evaluation and logging.
    evaluation_strategy="epoch",
    predict_with_generate=True,
    logging_steps=logging_steps,
)

In [None]:
import numpy as np


def compute_metrics(eval_pred):
    """Compute ROUGE scores (as percentages) for a batch of generated summaries.

    Args:
        eval_pred: A ``(predictions, labels)`` pair of token-id arrays.

    Returns:
        dict mapping each ROUGE key to its score scaled to 0-100 and rounded
        to four decimal places.
    """
    predictions, labels = eval_pred
    # Some models return a tuple whose first element holds the generated ids.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    # -100 is an ignore/padding marker, not a vocabulary id, so it cannot be
    # decoded; swap it for the pad token in both predictions and labels.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    # Decode generated and reference summaries into text
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # ROUGE expects a newline after each sentence
    decoded_preds = ["\n".join(sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(sent_tokenize(label.strip())) for label in decoded_labels]
    # Compute ROUGE scores
    result = rouge_score.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # The `evaluate` ROUGE metric returns plain floats, whereas the older
    # `datasets` metric returned aggregate objects exposing `.mid.fmeasure`.
    # Accept both so the function does not fail with AttributeError.
    result = {
        key: float(value.mid.fmeasure if hasattr(value, "mid") else value) * 100
        for key, value in result.items()
    }
    return {k: round(v, 4) for k, v in result.items()}

In [None]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [None]:
from transformers import Seq2SeqTrainer

# Bundle the model, the training arguments, the tokenized train/validation
# splits, the data collator, the tokenizer and the ROUGE metric function
# into a single trainer object.
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
trainer.train()

To see the score

In [None]:
trainer.evaluate()

We can also push our weights to hub:

In [None]:
trainer.push_to_hub(commit_message="Training complete", tags="summarization")

# How to use a fine-tuned model?

Load it into memory:

In [None]:
from transformers import pipeline

hub_model_id = "Corkri/mt5-small-finetuned-amazon-en-es"
summarizer = pipeline("summarization", model=hub_model_id)

Here's an example:

In [None]:
def print_summary(idx):
    """Print a test dialogue, its reference summary and the generated summary.

    The dataset loaded in this notebook exposes "dialogue" and "summary"
    fields; the previous "review_body" / "review_title" keys do not exist in
    it and raised KeyError.

    Args:
        idx: Row index into ``dataset["test"]``.
    """
    example = dataset["test"][idx]
    dialogue = example["dialogue"]
    reference = example["summary"]
    # The pipeline returns a list with one dict per input text.
    summary = summarizer(dialogue)[0]["summary_text"]
    print(f"'>>> Dialogue: {dialogue}'")
    print(f"\n'>>> Reference: {reference}'")
    print(f"\n'>>> Summary: {summary}'")