# Fine-tune FLAN-T5 for chat & dialogue summarization

In this blog, you will learn how to fine-tune [google/flan-t5-xl](https://huggingface.co/google/flan-t5-xl) for chat & dialogue summarization using Hugging Face Transformers. If you already know T5, FLAN-T5 is just better at everything. For the same number of parameters, these models have been fine-tuned on more than 1000 additional tasks covering also more languages. 

In this example we will use the [samsum](https://huggingface.co/datasets/samsum) dataset a collection of about 16k messenger-like conversations with summaries. Conversations were created and written down by linguists fluent in English.

You will learn how to:

1. [Setup Development Environment](#1-setup-development-environment)
2. [Load and prepare samsum dataset](#2-load-and-prepare-samsum-dataset)
3. [Fine-tune and evaluate FLAN-T5](#3-fine-tune-and-evaluate-flan-t5)
4. [Run Inference and summarize ChatGPT dialogues](#4-run-inference-and-summarize-chatgpt-dialogues)

Before we can start, make sure you have a [Hugging Face Account](https://huggingface.co/join) to save artifacts and experiments. 

## Quick intro: FLAN-T5, just a better T5

FLAN-T5 released with the [Scaling Instruction-Finetuned Language Models](https://arxiv.org/pdf/2210.11416.pdf) paper is an enhanced version of T5 that has been finetuned in a mixture of tasks. The paper explores instruction finetuning with a particular focus on (1) scaling the number of tasks, (2) scaling the model size, and (3) finetuning on chain-of-thought data. The paper discovers that overall instruction finetuning is a general method for improving the performance and usability of pretrained language models. 

![flan-t5](../assets/flan-t5.png)

* Paper: https://arxiv.org/abs/2210.11416
* Official repo: https://github.com/google-research/t5x

--- 

Now we know what FLAN-T5 is, let's get started. 🚀

_Note: This tutorial was created and run on a g4dn.xlarge AWS EC2 Instance including a NVIDIA T4._

## 1. Setup Development Environment

Our first step is to install the Hugging Face Libraries, including transformers and datasets. Running the following cell will install all the required packages. 

In [3]:
import torch

if torch.cuda.is_available():
    print("GPU is available")
    print(f"GPU device name: {torch.cuda.get_device_name(0)}")
else:
    print("GPU is not available")


GPU is available
GPU device name: NVIDIA GeForce RTX 3090


In [2]:
# python
!pip install pytesseract transformers datasets rouge-score nltk tensorboard py7zr --upgrade

Collecting pytesseract
  Obtaining dependency information for pytesseract from https://files.pythonhosted.org/packages/7a/33/8312d7ce74670c9d39a532b2c246a853861120486be9443eebf048043637/pytesseract-0.3.13-py3-none-any.whl.metadata
  Downloading pytesseract-0.3.13-py3-none-any.whl.metadata (11 kB)
Collecting transformers
  Obtaining dependency information for transformers from https://files.pythonhosted.org/packages/75/35/07c9879163b603f0e464b0f6e6e628a2340cfc7cdc5ca8e7d52d776710d4/transformers-4.44.2-py3-none-any.whl.metadata
  Downloading transformers-4.44.2-py3-none-any.whl.metadata (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.7/43.7 kB[0m [31m366.2 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting datasets
  Obtaining dependency information for datasets from https://files.pythonhosted.org/packages/72/b3/33c4ad44fa020e3757e9b2fad8a5de53d9079b501e6bbc45bdd18f82f893/datasets-2.21.0-py3-none-any.whl.metadata
  Downloading datasets-2.21.0-py3

In [None]:
# install git-fls for pushing model and logs to the hugging face hub
!sudo apt-get install git-lfs --yes

This example will use the [Hugging Face Hub](https://huggingface.co/models) as a remote model versioning service. To be able to push our model to the Hub, you need to register on the [Hugging Face](https://huggingface.co/join). 
If you already have an account, you can skip this step. 
After you have an account, we will use the `notebook_login` util from the `huggingface_hub` package to log into our account and store our token (access key) on the disk. 

In [4]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
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
T5_DialogueSum

## 2. Load and prepare dialogueSum dataset from local
- This DialogueSum dataset was originally in English but was translated into Korean by teachers using the Solar API for educational purposes. However, the translation seemed somewhat unnatural for native Korean speakers, so I used the Solar API to retranslate it into English to facilitate a more accurate summarization.

To load the `dialogueSum` dataset, we use the `load_dataset()` method from the 🤗 Datasets library.


In [6]:
dataset_id = "dialoguSum_Solar_koen"
# huggingface hub model id
model_id="google/flan-t5-large"

In [8]:
from datasets import load_dataset

# Load dataset from the hub
dataset = load_dataset('csv', data_files={'train': "/data/ephemeral/home/data/train_en.csv", 'val': "/data/ephemeral/home/data/dev_en.csv"})

print(f"Train dataset size: {len(dataset['train'])}")
print(f"Test dataset size: {len(dataset['val'])}")

# Train dataset size: 12457
# Test dataset size: 499

Generating train split: 0 examples [00:00, ? examples/s]

Generating val split: 0 examples [00:00, ? examples/s]

Train dataset size: 12457
Test dataset size: 499


In [9]:
dataset['train']

Dataset({
    features: ['fname', 'dialogue', 'summary', 'topic', 'dialogue_en', 'summary_en', 'topic_en'],
    num_rows: 12457
})

Lets checkout an example of the dataset.

In [12]:
from random import randrange        


sample = dataset['train'][randrange(len(dataset["train"]))]
print(f"dialogue: \n{sample['dialogue_en']}\n---------------")
print(f"summary: \n{sample['summary_en']}\n---------------")

dialogue: 
#Person1#: Hi, Lily. Are you enjoying your graduation party?
#Person2#: Yes, everyone here seems to be having fun. Do you have any plans for the future?
#Person1#: Well, I'm interested in finance and my uncle has a company in Hong Kong, so I've decided to go to the University of Hong Kong.
#Person2#: I hear that's great. Hong Kong is an international financial center. Surely you'll go far there.
#Person1#: What about you? What are you going to do?
#Person2#: I want to go to a university in Beijing.
#Person1#: What do you want to major in, computer science or medicine?
#Person2#: I prefer medicine. It's always been my dream to become a doctor.
---------------
summary: 
#Person1# is going to study finance at the University of Hong Kong. Lily is going to study medicine at a university in Beijing.
---------------


To train our model we need to convert our inputs (text) to token IDs. This is done by a 🤗 Transformers Tokenizer. If you are not sure what this means check out [chapter 6](https://huggingface.co/course/chapter6/1?fw=tf) of the Hugging Face Course.

In [13]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load tokenizer of FLAN-t5-base
tokenizer = AutoTokenizer.from_pretrained(model_id)


tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]



before we can start training we need to preprocess our data. Abstractive Summarization is a text2text-generation task. This means our model will take a text as input and generate a summary as output. For this we want to understand how long our input and output will be to be able to efficiently batch our data. 

In [15]:
from datasets import concatenate_datasets

# The maximum total input sequence length after tokenization. 
# Sequences longer than this will be truncated, sequences shorter will be padded.
tokenized_inputs = concatenate_datasets([dataset["train"], dataset["val"]]).map(lambda x: tokenizer(x["dialogue_en"], truncation=True), batched=True, remove_columns=["dialogue_en", "summary_en"])
max_source_length = max([len(x) for x in tokenized_inputs["input_ids"]])
print(f"Max source length: {max_source_length}")
min_source_length = min([len(x) for x in tokenized_inputs["input_ids"]])
print(f"Min source length: {min_source_length}")


# The maximum total sequence length for target text after tokenization. 
# Sequences longer than this will be truncated, sequences shorter will be padded."
tokenized_targets = concatenate_datasets([dataset["train"], dataset["val"]]).map(lambda x: tokenizer(x["summary_en"], truncation=True), batched=True, remove_columns=["dialogue_en", "summary_en"])
max_target_length = max([len(x) for x in tokenized_targets["input_ids"]])
print(f"Max target length: {max_target_length}")
min_target_length = min([len(x) for x in tokenized_targets["input_ids"]])
print(f"Min target length: {min_target_length}")

Map:   0%|          | 0/12956 [00:00<?, ? examples/s]

Max source length: 512
Min source length: 49
Max target length: 231
Min target length: 9


In [16]:
special_tokens = ['#CarNumber#', '#SSN#', '#PhoneNumber#', '#PassportNumber#', '#Email#', '#CardNumber#', '#Address#', '#DateOfBirth#', \
'#Person4#', '#Person7#', '#Person3#', '#Person2#', '#Person#', '#Person6#', '#Person5#', '#Person1#']
for token in special_tokens:
    if token in tokenizer.get_vocab():
        print(f"'{token}' is already in the vocabulary.")
    else:
        print(f"'{token}' is not in the vocabulary.")


'#CarNumber#' is not in the vocabulary.
'#SSN#' is not in the vocabulary.
'#PhoneNumber#' is not in the vocabulary.
'#PassportNumber#' is not in the vocabulary.
'#Email#' is not in the vocabulary.
'#CardNumber#' is not in the vocabulary.
'#Address#' is not in the vocabulary.
'#DateOfBirth#' is not in the vocabulary.
'#Person4#' is not in the vocabulary.
'#Person7#' is not in the vocabulary.
'#Person3#' is not in the vocabulary.
'#Person2#' is not in the vocabulary.
'#Person#' is not in the vocabulary.
'#Person6#' is not in the vocabulary.
'#Person5#' is not in the vocabulary.
'#Person1#' is not in the vocabulary.


In [17]:
original_vocab_size = len(tokenizer)

special_tokens = ['#CarNumber#', '#SSN#', '#PhoneNumber#', '#PassportNumber#', '#Email#', '#CardNumber#', '#Address#', '#DateOfBirth#', \
'#Person4#', '#Person7#', '#Person3#', '#Person2#', '#Person#', '#Person6#', '#Person5#', '#Person1#']
tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})
new_vocab_size = len(tokenizer)

print(f"Original vocab size: {original_vocab_size}")
print(f"New vocab size: {new_vocab_size}")

Original vocab size: 32100
New vocab size: 32116


In [44]:
# Define a test sentence
sentence = dataset["train"]['dialogue_en'][0]


# Encode the sentence using the tokenizer, returning PyTorch tensors
sentence_encoded = tokenizer(sentence, 
                             max_length=max_source_length, 
                             padding="max_length", 
                             truncation=True, 
                             add_special_tokens=True)

# Decode the encoded sentence, skipping special tokens
sentence_decoded = tokenizer.decode(
        sentence_encoded["input_ids"], 
        max_length=max_target_length, 
        padding="max_length", 
        truncation=True, 
        add_special_tokens=True,
        skip_special_tokens=False
    )

# Print SENTENCE
print('SENTENCE:')
print(sentence)

# Print the encoded sentence's representation
print('\nENCODED SENTENCE:')
print(sentence_encoded["input_ids"])

# Print the decoded sentence
print('\nDECODED SENTENCE:')
print(sentence_decoded)

SENTENCE:
#Person1#: Hello, Mr. Smith. I'm Dr. Hawkins. Why are you here today?
#Person2#: I thought it would be a good idea to have a checkup.
#Person1#: I see, you haven't had one in five years. You should have one every year.
#Person2#: I know. But if nothing is wrong, why should I go to see a doctor?
#Person1#: The best way to avoid serious illness is to catch these early. So for your own good, come at least once a year.
#Person2#: I see.
#Person1#: Look here. Your eyes and ears seem fine. Take a deep breath. Do you smoke, Mr. Smith?
#Person2#: Yes.
#Person1#: As you know, smoking is the leading cause of lung cancer and heart disease. You really should quit.
#Person2#: I've tried hundreds of times, but I find it hard to break the habit.
#Person1#: We have classes and medications that can help. I'll give you more information before you leave.
#Person2#: OK, thank you, doctor.

ENCODED SENTENCE:
[32115, 3, 10, 8774, 6, 1363, 5, 3931, 5, 27, 31, 51, 707, 5, 12833, 77, 7, 5, 1615, 33, 

In [57]:
from transformers import AutoModelForSeq2SeqLM

# load model from the hub
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

# Define a test sentence
sentence = dataset["train"]['dialogue_en'][10]
golden = dataset["train"]['summary_en'][10]

#instrunction = ["Please provide a detailed summary of the conversation, including all key points and important information. : " + sentence]
instruction = ["In this '#Person1#: Hello, Mr. Smith. I'm Dr. Hawkins.' dialogue, the speaker is #Person1#. \
    Summarize the conversation with a focus on the speakers, ensuring that each speaker's name or identifier, such as #Person1#, is accurately used as the subject in the summary. : " + sentence]
# Encode the sentence using the tokenizer, returning PyTorch tensors
sentence_encoded = tokenizer(instrunction, 
                             max_length=max_source_length, 
                             padding="max_length", 
                             truncation=True, 
                             add_special_tokens=True,
                             return_tensors="pt")  # Ensure tensors are returned for model input

# Generate the summary using the model
summary_ids = model.generate(
    sentence_encoded["input_ids"], 
    max_length=max_target_length, 
    num_beams=5,  # Optional: control the generation strategy
    early_stopping=True,  # Optional: stop early when all beams are finished
    no_repeat_ngram_size=2
)

# Decode the encoded sentence, skipping special tokens
sentence_decoded = tokenizer.decode(
    summary_ids[0],  # Select the first (and usually only) sequence generated
    skip_special_tokens=True  # Skip special tokens in the final output
    )

# Print the encoded sentence's representation
# print('\nENCODED SENTENCE:')
# print(sentence_encoded["input_ids"])

# Print the decoded sentence
print('\nDECODED SENTENCE:')
print(sentence_decoded)

# Print SENTENCE
print('\nGOLDEN:')
print(golden)


DECODED SENTENCE:
Orient will go to the store and buy sugar, four oranges and a half gallon of milk.

GOLDEN:
#Person1# asks #Person2# for a favor. #Person2# agrees and buys a small bag of sugar, six oranges, and a half-gallon of milk.


In [55]:
sentence

"#Person1#: Can you do me a favor?\n#Person2#: Sure. What is it?\n#Person1#: Can you go to the store for me? I need a few things.\n#Person2#: Okay. What do you want me to get for you?\n#Person1#: Well, can you get me some sugar?\n#Person2#: Okay. How much do you want me to get?\n#Person1#: Just a small bag. And I need some oranges.\n#Person2#: How many do you want?\n#Person1#: Well, let's see... about six.\n#Person2#: Is there anything else you need?\n#Person1#: Yes. We're also out of milk.\n#Person2#: Okay. How much do you want me to get? A gallon?\n#Person1#: No. I think a half gallon will be enough.\n#Person2#: Is that all you need?\n#Person1#: I think so. Did you remember everything?\n#Person2#: Yes. A small bag of sugar, four oranges, and a half gallon of milk.\n#Person1#: Do you have enough money?\n#Person2#: I think so.\n#Person1#: Thank you very much. I appreciate it."

In [18]:
def preprocess_function(sample,padding="max_length"):
    # add prefix to the input for t5
    inputs = ["summarize: " + item for item in sample["dialogue_en"]]

    # tokenize inputs
    model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)

    # Tokenize targets with the `text_target` keyword argument
    labels = tokenizer(text_target=sample["summary_en"], max_length=max_target_length, padding=padding, truncation=True)

    # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore padding in the loss.
    if padding == "max_length":
        
        if isinstance(labels["input_ids"][0], list):  # Check if it is a list of lists
            print(f'labels["input_ids"][0]: {labels["input_ids"][0]}')
            labels["input_ids"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
            ]
        else:  # Handle single instance case
            print(f'labels["input_ids"]: {labels["input_ids"]}')
            labels["input_ids"] = [(l if l != tokenizer.pad_token_id else -100) for l in labels["input_ids"]]

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=['fname', 'dialogue', 'summary', 'topic', 'dialogue_en', 'summary_en', 'topic_en'])
print(f"Keys of tokenized dataset: {list(tokenized_dataset['train'].features)}")

Map:   0%|          | 0/12457 [00:00<?, ? examples/s]

labels["input_ids"][0]: [1363, 5, 3931, 19, 578, 3, 9, 1722, 3631, 6, 11, 707, 5, 12833, 77, 7, 1568, 7, 24, 3, 88, 43, 3, 9, 1722, 3631, 334, 215, 5, 707, 5, 12833, 77, 7, 56, 428, 1363, 5, 3931, 251, 81, 2287, 11, 11208, 24, 54, 199, 376, 10399, 10257, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
labels["input_ids"][0]: [32115, 11, 3, 32111, 1350, 81, 3059, 31, 7, 1112, 147, 97, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

Map:   0%|          | 0/499 [00:00<?, ? examples/s]

labels["input_ids"][0]: [32111, 65, 8565, 10882, 5, 37, 2472, 987, 7, 3, 32115, 81, 48, 11, 19, 352, 12, 1299, 3, 32111, 12, 3, 9, 5084, 4253, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Keys of tokenized dataset: ['input_ids', 'attention_mask', 'labels']


## one dialogue test

## 3. Fine-tune and evaluate FLAN-T5

After we have processed our dataset, we can start training our model. Therefore we first need to load our [FLAN-T5](https://huggingface.co/models?search=flan-t5) from the Hugging Face Hub. In the example we are using a instance with a NVIDIA V100 meaning that we will fine-tune the `base` version of the model. 
_I plan to do a follow-up post on how to fine-tune the `xxl` version of the model using Deepspeed._


In [12]:
from transformers import AutoModelForSeq2SeqLM

# load model from the hub
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

We want to evaluate our model during training. The `Trainer` supports evaluation during training by providing a `compute_metrics`.  
The most commonly used metrics to evaluate summarization task is [rogue_score](https://en.wikipedia.org/wiki/ROUGE_(metric)) short for Recall-Oriented Understudy for Gisting Evaluation). This metric does not behave like the standard accuracy: it will compare a generated summary against a set of reference summaries

We are going to use `evaluate` library to evaluate the `rogue` score.

In [13]:
import evaluate
import nltk
import numpy as np
from nltk.tokenize import sent_tokenize
nltk.download("punkt")

# Metric
metric = evaluate.load("rouge")

# helper function to postprocess text
def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(sent_tokenize(label)) for label in labels]

    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    result = {k: round(v * 100, 4) for k, v in result.items()}
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return result

[nltk_data] Downloading package punkt to
[nltk_data]     /data/ephemeral/home/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Before we can start training is to create a `DataCollator` that will take care of padding our inputs and labels. We will use the `DataCollatorForSeq2Seq` from the 🤗 Transformers library. 

In [14]:
from transformers import DataCollatorForSeq2Seq

# we want to ignore tokenizer pad token in the loss
label_pad_token_id = -100
# Data collator
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=label_pad_token_id,
    pad_to_multiple_of=8
)


The last step is to define the hyperparameters (`TrainingArguments`) we want to use for our training. We are leveraging the [Hugging Face Hub](https://huggingface.co/models) integration of the `Trainer` to automatically push our checkpoints, logs and metrics during training into a repository.

In [15]:
from huggingface_hub import HfFolder
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Hugging Face repository id
repository_id = f"{model_id.split('/')[1]}-{dataset_id}"

# Define training args
training_args = Seq2SeqTrainingArguments(
    output_dir=repository_id,
    per_device_train_batch_size=1, #8
    per_device_eval_batch_size=1, #8
    predict_with_generate=True,
    fp16=False, # Overflows with fp16
    learning_rate=1e-5, #5e-5
    num_train_epochs=5,
    # logging & evaluation strategies
    logging_dir=f"{repository_id}/logs",
    logging_strategy="steps",
    logging_steps=100, #500
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    # metric_for_best_model="overall_f1",
    # push to hub parameters
    report_to="tensorboard",
    push_to_hub=False,
    hub_strategy="every_save",
    hub_model_id=repository_id,
    hub_token=HfFolder.get_token(),
)

# Create Trainer instance
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["val"],
    compute_metrics=compute_metrics,
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


We can start our training by using the `train` method of the `Trainer`.

In [16]:
# Start training
trainer.train()

#2.16일 예상

KeyError: 0


![flan-t5-tensorboard](../assets/flan-t5-tensorboard.png)

Nice, we have trained our model. 🎉 Lets run evaluate the best model again on the test set.


In [None]:
trainer.evaluate()

***** Running Evaluation *****
  Num examples = 819
  Batch size = 8


{'eval_loss': 1.3715944290161133,
 'eval_rouge1': 47.2358,
 'eval_rouge2': 23.5135,
 'eval_rougeL': 39.6266,
 'eval_rougeLsum': 43.3458,
 'eval_gen_len': 17.39072039072039,
 'eval_runtime': 108.99,
 'eval_samples_per_second': 7.514,
 'eval_steps_per_second': 0.945,
 'epoch': 5.0}

The best score we achieved is an `rouge1` score of `47.23`. 

Lets save our results and tokenizer to the Hugging Face Hub and create a model card. 

In [None]:
# Save our tokenizer and create model card
tokenizer.save_pretrained(repository_id)
trainer.create_model_card()
# Push the results to the hub
trainer.push_to_hub()

## 4. Run Inference

Now we have a trained model, we can use it to run inference. We will use the `pipeline` API from transformers and a `test` example from our dataset.

In [None]:
from transformers import pipeline
from random import randrange        

# load model and tokenizer from huggingface hub with pipeline
summarizer = pipeline("summarization", model="philschmid/flan-t5-base-samsum", device=0)

# select a random test sample
sample = dataset['test'][randrange(len(dataset["test"]))]
print(f"dialogue: \n{sample['dialogue']}\n---------------")

# summarize dialogue
res = summarizer(sample["dialogue"])

print(f"flan-t5-base summary:\n{res[0]['summary_text']}")

dialogue: 
Abby: Have you talked to Miro?
Dylan: No, not really, I've never had an opportunity
Brandon: me neither, but he seems a nice guy
Brenda: you met him yesterday at the party?
Abby: yes, he's so interesting
Abby: told me the story of his father coming from Albania to the US in the early 1990s
Dylan: really, I had no idea he is Albanian
Abby: he is, he speaks only Albanian with his parents
Dylan: fascinating, where does he come from in Albania?
Abby: from the seacoast
Abby: Duress I believe, he told me they are not from Tirana
Dylan: what else did he tell you?
Abby: That they left kind of illegally
Abby: it was a big mess and extreme poverty everywhere
Abby: then suddenly the border was open and they just left 
Abby: people were boarding available ships, whatever, just to get out of there
Abby: he showed me some pictures, like <file_photo>
Dylan: insane
Abby: yes, and his father was among the people
Dylan: scary but interesting
Abby: very!
---------------
flan-t5-base summary:
A