# Fine-tune FLAN-T5 for chat & dialogue summarization

In this blog, you will learn how to fine-tune [google/flan-t5-xl](https://huggingface.co/google/flan-t5-xl) for chat & dialogue summarization using Hugging Face Transformers. If you already know T5, FLAN-T5 is just better at everything. For the same number of parameters, these models have been fine-tuned on more than 1000 additional tasks, also covering more languages. 

In this example we will use the [samsum](https://huggingface.co/datasets/samsum) dataset, a collection of about 16k messenger-like conversations with summaries. The conversations were created and written down by linguists fluent in English.

You will learn how to:

1. [Setup Development Environment](#1-setup-development-environment)
2. [Load and prepare samsum dataset](#2-load-and-prepare-samsum-dataset)
3. [Fine-tune and evaluate FLAN-T5](#3-fine-tune-and-evaluate-flan-t5)
4. [Run Inference and summarize ChatGPT dialogues](#4-run-inference-and-summarize-chatgpt-dialogues)

Before we can start, make sure you have a [Hugging Face Account](https://huggingface.co/join) to save artifacts and experiments. 

## Quick intro: FLAN-T5, just a better T5

FLAN-T5, released with the [Scaling Instruction-Finetuned Language Models](https://arxiv.org/pdf/2210.11416.pdf) paper, is an enhanced version of T5 that has been fine-tuned on a mixture of tasks. The paper explores instruction fine-tuning with a particular focus on (1) scaling the number of tasks, (2) scaling the model size, and (3) fine-tuning on chain-of-thought data. It finds that, overall, instruction fine-tuning is a general method for improving the performance and usability of pretrained language models. 

![flan-t5](../assets/flan-t5.png)

* Paper: https://arxiv.org/abs/2210.11416
* Official repo: https://github.com/google-research/t5x

--- 

Now we know what FLAN-T5 is, let's get started. 🚀

_Note: This tutorial was created and run on a g4dn.xlarge AWS EC2 Instance including a NVIDIA T4._

## 1. Setup Development Environment

Our first step is to install the Hugging Face Libraries, including transformers and datasets. Running the following cell will install all the required packages. 

In [1]:
import torch

if torch.cuda.is_available():
    print("GPU is available")
    print(f"GPU device name: {torch.cuda.get_device_name(0)}")
else:
    print("GPU is not available")


GPU is available
GPU device name: NVIDIA GeForce RTX 3090


In [2]:
# python
!pip install pytesseract transformers datasets rouge-score nltk tensorboard py7zr --upgrade


In [None]:
# install git-lfs for pushing model and logs to the Hugging Face Hub
!sudo apt-get install git-lfs --yes

This example uses the [Hugging Face Hub](https://huggingface.co/models) as a remote model versioning service. To be able to push our model to the Hub, you need to register on [Hugging Face](https://huggingface.co/join). 
If you already have an account, you can skip this step. 
Once you have an account, we will use the `notebook_login` util from the `huggingface_hub` package to log into our account and store our token (access key) on disk. 

In [2]:
from huggingface_hub import notebook_login

notebook_login()



## 2. Load and prepare dialogueSum dataset from local
- This DialogueSum dataset was originally in English and was translated into Korean by teachers using the Solar API for educational purposes. Because the Korean translation read somewhat unnaturally to native speakers, I used the Solar API to translate it back into English to make summarization more accurate.

To load the `dialogueSum` dataset from local CSV files, we use the `load_dataset()` method from the 🤗 Datasets library.


In [7]:
dataset_id = "dialoguSum_Solar_koen"
# huggingface hub model id
model_id="google/flan-t5-large"

In [3]:
from datasets import load_dataset

# Load the train and validation splits from local CSV files
dataset = load_dataset('csv', data_files={'train': "/data/ephemeral/home/data/train_en.csv", 'val': "/data/ephemeral/home/data/dev_en.csv"})

print(f"Train dataset size: {len(dataset['train'])}")
print(f"val dataset size: {len(dataset['val'])}")

# Train dataset size: 12457
# val dataset size: 499

Train dataset size: 12457
val dataset size: 499


In [4]:
dataset['train']

Dataset({
    features: ['fname', 'dialogue', 'summary', 'topic', 'dialogue_en', 'summary_en', 'topic_en'],
    num_rows: 12457
})

Let's check out an example from the dataset.

In [5]:
from random import randrange        


sample = dataset['train'][randrange(len(dataset["train"]))]
print(f"dialogue: \n{sample['dialogue_en']}\n---------------")
print(f"summary: \n{sample['summary_en']}\n---------------")

dialogue: 
#Person1#: It's amazing how international business has developed. Take my store, for example. On any given day, you can find goods from more than 20 different countries on our shelves.
#Person2#: How many different types of products do you import from China?
#Person1#: China certainly supplies the majority of our product inventory. We import more than 40 different items from China. Most of the imports from China are low-grade plastics or toys. Japan exports a lot of electronics, and Germany produces excellent machinery products.
#Person2#: Do you import any food items?
#Person1#: Generally speaking, food items are difficult to import. Food with a short shelf life is likely to spoil during the time it takes to ship it from one place to another. The only food items we import are specialty canned or preserved foods, because these products have a longer shelf life.
---------------
summary: 
#Person1# tells #Person2# that people can find goods imported from more than 20 different

To train our model, we need to convert our inputs (text) to token IDs. This is done by a 🤗 Transformers Tokenizer. If you are not sure what this means, check out [chapter 6](https://huggingface.co/course/chapter6/1?fw=tf) of the Hugging Face Course.

In [77]:
!pip install SentencePiece


In [8]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Load the tokenizer that matches model_id (google/flan-t5-large in this notebook)
tokenizer = AutoTokenizer.from_pretrained(model_id)



Before we can start training, we need to preprocess our data. Abstractive summarization is a text2text-generation task: the model takes a text as input and generates a summary as output. To batch our data efficiently, we first want to understand how long our inputs and outputs are. 

In [9]:
from datasets import concatenate_datasets

# The maximum total input sequence length after tokenization. 
# Sequences longer than this will be truncated, sequences shorter will be padded.
tokenized_inputs = concatenate_datasets([dataset["train"], dataset["val"]]).map(lambda x: tokenizer(x["dialogue_en"], truncation=True), batched=True, remove_columns=["dialogue_en", "summary_en"])
max_source_length = max([len(x) for x in tokenized_inputs["input_ids"]])
print(f"Max source length: {max_source_length}")
min_source_length = min([len(x) for x in tokenized_inputs["input_ids"]])
print(f"Min source length: {min_source_length}")


# The maximum total sequence length for target text after tokenization. 
# Sequences longer than this will be truncated, sequences shorter will be padded.
tokenized_targets = concatenate_datasets([dataset["train"], dataset["val"]]).map(lambda x: tokenizer(x["summary_en"], truncation=True), batched=True, remove_columns=["dialogue_en", "summary_en"])
max_target_length = max([len(x) for x in tokenized_targets["input_ids"]])
print(f"Max target length: {max_target_length}")
min_target_length = min([len(x) for x in tokenized_targets["input_ids"]])
print(f"Min target length: {min_target_length}")

Max source length: 512
Min source length: 49
Max target length: 231
Min target length: 9
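
The source maximum of 512 above likely reflects the tokenizer's truncation limit (`truncation=True`) rather than the longest dialogue, and a maximum is in any case driven by a few outliers. As a hedged alternative, a percentile of the token lengths can serve as the padding budget. The sketch below uses a hypothetical helper `length_budget` and made-up lengths; in the notebook the list would come from `[len(x) for x in tokenized_inputs["input_ids"]]`.

```python
import math


def length_budget(token_lengths, percentile=95):
    """Nearest-rank percentile of a list of token lengths (hypothetical helper)."""
    if not token_lengths:
        raise ValueError("token_lengths must not be empty")
    ordered = sorted(token_lengths)
    # Nearest-rank method: smallest value that covers `percentile` percent of samples
    rank = max(1, math.ceil(percentile * len(ordered) / 100))
    return ordered[rank - 1]


# Made-up token lengths, for illustration only
example_lengths = [49, 120, 180, 200, 230, 260, 300, 340, 410, 512]

print(f"90th percentile budget: {length_budget(example_lengths, percentile=90)}")
print(f"Maximum length:         {length_budget(example_lengths, percentile=100)}")
```

A smaller budget truncates the longest few samples but reduces the amount of padding every other sample carries when `padding="max_length"` is used.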


In [10]:
special_tokens = ['#CarNumber#', '#SSN#', '#PhoneNumber#', '#PassportNumber#', '#Email#', '#CardNumber#', '#Address#', '#DateOfBirth#', \
'#Person4#', '#Person7#', '#Person3#', '#Person2#', '#Person#', '#Person6#', '#Person5#', '#Person1#']
for token in special_tokens:
    if token in tokenizer.get_vocab():
        print(f"'{token}' is already in the vocabulary.")
    else:
        print(f"'{token}' is not in the vocabulary.")


'#CarNumber#' is not in the vocabulary.
'#SSN#' is not in the vocabulary.
'#PhoneNumber#' is not in the vocabulary.
'#PassportNumber#' is not in the vocabulary.
'#Email#' is not in the vocabulary.
'#CardNumber#' is not in the vocabulary.
'#Address#' is not in the vocabulary.
'#DateOfBirth#' is not in the vocabulary.
'#Person4#' is not in the vocabulary.
'#Person7#' is not in the vocabulary.
'#Person3#' is not in the vocabulary.
'#Person2#' is not in the vocabulary.
'#Person#' is not in the vocabulary.
'#Person6#' is not in the vocabulary.
'#Person5#' is not in the vocabulary.
'#Person1#' is not in the vocabulary.


In [11]:
original_vocab_size = len(tokenizer)

special_tokens = ['#CarNumber#', '#SSN#', '#PhoneNumber#', '#PassportNumber#', '#Email#', '#CardNumber#', '#Address#', '#DateOfBirth#', \
'#Person4#', '#Person7#', '#Person3#', '#Person2#', '#Person#', '#Person6#', '#Person5#', '#Person1#']
tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})
new_vocab_size = len(tokenizer)

print(f"Original vocab size: {original_vocab_size}")
print(f"New vocab size: {new_vocab_size}")

# Original vocab size: 32100
# New vocab size: 32116

Original vocab size: 32100
New vocab size: 32116


In [6]:
# Add tokens
import pandas as pd

Original_vocab_size = len(tokenizer)
print(f"Original vocab size: {Original_vocab_size}")
# Step 1: Extract unique words from the dataset
unique_words1 = set()
for sentence in dataset['train']['dialogue_en']:
    words = sentence.split()  # Simple split, you might want to use a tokenizer for better results
    unique_words1.update(words)
unique_words2 = set()
for sentence in dataset['train']['summary_en']:
    words = sentence.split()  # Simple split, you might want to use a tokenizer for better results
    unique_words2.update(words)

unique_words3 = set()
for sentence in dataset['val']['dialogue_en']:
    words = sentence.split()  # Simple split, you might want to use a tokenizer for better results
    unique_words3.update(words)
unique_words4 = set()
for sentence in dataset['val']['summary_en']:
    words = sentence.split()  # Simple split, you might want to use a tokenizer for better results
    unique_words4.update(words)

test = pd.read_csv(r'/data/ephemeral/home/data/test_en.csv')
unique_words5 = set()
for sentence in test['dialogue_en']:
    words = sentence.split()  # Simple split, you might want to use a tokenizer for better results
    unique_words5.update(words)    
    
# Step 2: Add these words to the tokenizer vocabulary
# The tokenizer will automatically handle the splitting and add only those not already in the vocab
tokenizer.add_tokens(list(unique_words1))
tokenizer.add_tokens(list(unique_words2))
tokenizer.add_tokens(list(unique_words3))
tokenizer.add_tokens(list(unique_words4))
tokenizer.add_tokens(list(unique_words5))
tokenizer.add_tokens(list(special_tokens))
# Step 3: Check the new vocabulary size
new_vocab_size = len(tokenizer)
print(f"New vocab size: {new_vocab_size}")

Original vocab size: 32100
New vocab size: 92921
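
Adding tokens changes the tokenizer only; the model's embedding matrix keeps its original number of rows. The sketch below is a plain-Python sanity check (the helper name `embedding_resize_needed` is hypothetical) using the sizes reported in this notebook: 32,128 embedding rows in the printed model, 32,116 tokens after adding the 16 special tokens, and 92,921 tokens after adding every unique word. When the check is true, Transformers models are usually resized with `model.resize_token_embeddings(len(tokenizer))` before training.

```python
def embedding_resize_needed(tokenizer_size, embedding_rows):
    """Return True when some token ID would fall outside the embedding matrix."""
    return tokenizer_size > embedding_rows


# Rows of the `shared` embedding as shown by print(model) in this notebook
EMBEDDING_ROWS = 32128

for name, size in [("16 special tokens added", 32116), ("all unique words added", 92921)]:
    needed = embedding_resize_needed(size, EMBEDDING_ROWS)
    print(f"{name}: tokenizer size {size}, resize needed: {needed}")
```

Even when no resize is required, the embeddings of newly added tokens start out untrained, so they only become meaningful after fine-tuning.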


In [12]:
# Check that the tokenizer works as expected
# Define a test sentence
sentence = dataset["train"]['dialogue_en'][0]


# Encode the sentence using the tokenizer (returns plain Python lists of token IDs)
sentence_encoded = tokenizer(sentence, 
                             max_length=max_source_length, 
                             padding="max_length", 
                             truncation=True, 
                             add_special_tokens=True)

# Decode the token IDs back to text, keeping special tokens so that
# padding and the added speaker tokens stay visible
sentence_decoded = tokenizer.decode(
        sentence_encoded["input_ids"], 
        skip_special_tokens=False
    )

# Print SENTENCE
print('SENTENCE:')
print(sentence)

# Print the encoded sentence's representation
print('\nENCODED SENTENCE:')
print(sentence_encoded["input_ids"])

# Print the decoded sentence
print('\nDECODED SENTENCE:')
print(sentence_decoded)

SENTENCE:
#Person1#: Hello, Mr. Smith. I'm Dr. Hawkins. Why are you here today?
#Person2#: I thought it would be a good idea to have a checkup.
#Person1#: I see, you haven't had one in five years. You should have one every year.
#Person2#: I know. But if nothing is wrong, why should I go to see a doctor?
#Person1#: The best way to avoid serious illness is to catch these early. So for your own good, come at least once a year.
#Person2#: I see.
#Person1#: Look here. Your eyes and ears seem fine. Take a deep breath. Do you smoke, Mr. Smith?
#Person2#: Yes.
#Person1#: As you know, smoking is the leading cause of lung cancer and heart disease. You really should quit.
#Person2#: I've tried hundreds of times, but I find it hard to break the habit.
#Person1#: We have classes and medications that can help. I'll give you more information before you leave.
#Person2#: OK, thank you, doctor.

ENCODED SENTENCE:
[32115, 3, 10, 8774, 6, 1363, 5, 3931, 5, 27, 31, 51, 707, 5, 12833, 77, 7, 5, 1615, 33, 

## Summarizing Using Prompt Engineering

### Applying Zero Shot Inference

In [16]:
# zero shot
from transformers import AutoModelForSeq2SeqLM

# load model from the hub
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

# Define a test sentence
sentence = dataset["train"]['dialogue_en'][20]
golden = dataset["train"]['summary_en'][20]

instruction = f"""
Dialogue:

{sentence}

What was going on?
"""
#instruction = ["Please summarize the conversation by clearly stating what each speaker did or said. : " +sentence]
# instruction = ["In this '#Person1#: Hello, Mr. Smith. I'm Dr. Hawkins.' dialogue, the speaker is #Person1#. \
#     Summarize the conversation with a focus on the speakers, ensuring that each speaker's name or identifier, such as #Person1#, is accurately used as the subject in the summary. : " + sentence]
# Encode the sentence using the tokenizer, returning PyTorch tensors
sentence_encoded = tokenizer(instruction, 
                             max_length=max_source_length, 
                             padding="max_length", 
                             truncation=True, 
                             add_special_tokens=True,
                             return_tensors="pt")  # Ensure tensors are returned for model input

# Generate the summary using the model
summary_ids = model.generate(
    sentence_encoded["input_ids"], 
    attention_mask=sentence_encoded["attention_mask"],  # ignore the padding added by padding="max_length"
    max_length=max_target_length, 
    min_length=40, 
    num_beams=5,  # Optional: use beam search with 5 beams
    early_stopping=True,  # Optional: stop early when all beams are finished
    no_repeat_ngram_size=2
)

# Decode the generated summary, skipping special tokens
sentence_decoded = tokenizer.decode(
    summary_ids[0],  # Select the first (and usually only) sequence generated
    skip_special_tokens=True  # Skip special tokens in the final output
    )

# Print the encoded sentence's representation
print('\nENCODED SENTENCE:')
print(sentence_encoded["input_ids"])

# Print the decoded sentence
print('\nDECODED SENTENCE:')
print(sentence_decoded)

# Print the reference (golden) summary
print('\nGOLDEN:')
print(golden)
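
The `no_repeat_ngram_size=2` argument used above asks generation to avoid producing any bigram twice. The sketch below is not the Transformers implementation; it is a small, hypothetical checker (`has_repeated_ngram`) that shows what the constraint means on a list of tokens.

```python
def has_repeated_ngram(tokens, n=2):
    """Return True if any n-gram occurs more than once in `tokens`."""
    seen = set()
    for start in range(len(tokens) - n + 1):
        ngram = tuple(tokens[start:start + n])
        if ngram in seen:
            return True
        seen.add(ngram)
    return False


# Made-up outputs, split on whitespace for illustration
repetitive = "the cable is down and the cable is broken".split()
clean = "the cable is down and will be fixed".split()

print(has_repeated_ngram(repetitive, n=2))
print(has_repeated_ngram(clean, n=2))
```

With such a constraint active, a sequence like the first one could not be generated, because the bigram "the cable" would appear a second time.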



ENCODED SENTENCE:
tensor([[ 5267, 10384,    10,     3, 32115,     3,    10,   571,   103,    27,
          5026,   747,    48,  3143,    58,    27,   214,   132,    31,     7,
             3,     9, 19540,  5775,     5,     3, 32111,     3,    10,   363,
            33,    25,   692,    58,     3, 32115,     3,    10,    27,    31,
            51,   652, 13205,     6,   149,   405,    34,   320,    58,     3,
         32111,     3,    10,    27,   317,    25,    31,    60,  1119,    12,
           129, 13205,     5,  3963,    25,  2612,    62,    31,    60,    16,
             3,     9,   443,    30,     8,  1373,    58,     3, 32115,     3,
            10,    27,    31,    51,   207,    44,    48,     5,   465,    80,
            54,   217,   140,     5,     3, 32111,     3,    10,  1521,    25,
         29744,   140,    58,   148,    31,    60,   352,    12,  1137,    46,
          3125,   116,   151,  8876,    16,     3,     9,  1123,    55,     3,
         32115,     3,    10,   9

### Applying One Shot Inference

In [21]:
def make_prompt(example_indices_full, example_index_to_summarize):
    prompt = ''
    for index in example_indices_full:
        # Index the row first; this avoids loading the whole column on every iteration
        dialogue = dataset['train'][index]['dialogue_en']
        summary = dataset['train'][index]['summary_en']

        # The stop sequence '{summary}\n\n\n' is important for FLAN-T5. Other models may have their own preferred stop sequence.
        prompt += f"""
Dialogue:

{dialogue}

What was going on?
{summary}


"""

    dialogue = dataset['train'][example_index_to_summarize]['dialogue_en']

    prompt += f"""
Dialogue:

{dialogue}

What was going on?
"""

    return prompt
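
To experiment with the prompt layout without loading the dataset, the same structure can be built from plain Python strings. The sketch below mirrors the format of `make_prompt`; the function name `build_few_shot_prompt`, the `question` parameter, and the two short dialogues are made up for illustration.

```python
def build_few_shot_prompt(examples, dialogue_to_summarize, question="What was going on?"):
    """Build a prompt from (dialogue, summary) pairs plus one unsolved dialogue."""
    parts = []
    for dialogue, summary in examples:
        # Each solved example ends with its summary followed by blank lines
        parts.append(f"\nDialogue:\n\n{dialogue}\n\n{question}\n{summary}\n\n\n")
    # The final dialogue is left without a summary for the model to complete
    parts.append(f"\nDialogue:\n\n{dialogue_to_summarize}\n\n{question}\n")
    return "".join(parts)


# Made-up examples; the notebook takes these from dataset['train']
solved = [("#Person1#: Is the shop open?\n#Person2#: Yes, until six.",
           "#Person2# tells #Person1# the shop is open until six.")]
unsolved = "#Person1#: Can I book a table?\n#Person2#: Sure, for how many people?"

prompt = build_few_shot_prompt(solved, unsolved)
print(prompt)
```

Passing an empty `examples` list yields the zero-shot prompt, one pair the one-shot prompt, and several pairs the few-shot prompt.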

In [24]:
example_indices_full = [20]
example_index_to_summarize = 100

one_shot_prompt = make_prompt(example_indices_full, example_index_to_summarize)

print(one_shot_prompt)


Dialogue:

#Person1#: How do I recline this seat? I know there's a lever somewhere.
#Person2#: What are you doing?
#Person1#: I'm getting dressed, how does it look?
#Person2#: I think you're trying to get dressed. Did you forget we're in a car on the road?
#Person1#: I'm good at this. No one can see me.
#Person2#: Are you kidding me? You're going to cause an accident when people stare in awe!
#Person1#: Alright, pull over at that gas station. I'll get dressed in the ladies room.
#Person2#: I'd be happy to do that.

What was going on?
#Person1# is getting dressed in the car, and #Person2# warns her not to. #Person1# will get dressed at the gas station.



Dialogue:

#Person1#: I have a problem with my cable.
#Person2#: What is the problem?
#Person1#: My cable has not been working since last week or so.
#Person2#: The cable is down at the moment. We are really sorry about that.
#Person1#: When will it be working again?
#Person2#: It should be working again in a few days.
#Person1#: Do I

In [27]:

# model understanding more context of the conversation with one shot inference

summary = dataset['train']['summary_en'][example_index_to_summarize]

inputs = tokenizer(one_shot_prompt, return_tensors='pt')
output = tokenizer.decode(
    model.generate(
        inputs["input_ids"],
        max_new_tokens=50,
    )[0],
    skip_special_tokens=True
)

#print(dash_line)
print(f'BASELINE HUMAN SUMMARY:\n{summary}\n')
#print(dash_line)
print(f'MODEL GENERATION - ONE SHOT:\n{output}')

BASELINE HUMAN SUMMARY:
#Person1# is having trouble with their cable. #Person2# promises to get it working again, and that #Person1# won't have to pay while the cable is down.

MODEL GENERATION - ONE SHOT:
Severin's cable hasn't been working for a few days. He will get a discount while the cable is down.


### Applying Few Shot Inference

In [35]:
example_indices_full = [11, 21, 51]
example_index_to_summarize = 101

few_shot_prompt = make_prompt(example_indices_full, example_index_to_summarize)

print(few_shot_prompt)


Dialogue:

#Person1#: Look! It's a picture of Mom in her cap and gown.
#Person2#: Isn't that great! That's when she got her Master's degree from Miami University.
#Person1#: Yes, we're all so proud of her.
#Person2#: Oh, I like this one of all of you together. Do you have the negative film? Can I have a copy?
#Person1#: Sure, I'll make one for you. Do you want a print?
#Person2#: No. I want a slide, I have a new projector.
#Person1#: I'd like to see that.
#Person2#: Make me a wallet size print too.
#Person1#: You bet.

What was going on?
#Person2# thinks the photos are beautiful and asks #Person1# for a slide and a wallet-sized print.



Dialogue:

#Person1#: We should check in at the Air China counter half an hour before takeoff, Joy.
#Person2#: Yeah, I know. The boarding time on the ticket is 17:05, and it's 16:15 now. I think we have enough time.
#Person1#: Do we need to show our IDs when we check in?
#Person2#: Yeah, that's a must.
#Person1#: What about our luggage?
#Person2#: We 

In [36]:
summary = dataset['train']['summary_en'][example_index_to_summarize]

inputs = tokenizer(few_shot_prompt, return_tensors='pt')
output = tokenizer.decode(
    model.generate(
        inputs["input_ids"],
        max_new_tokens=50,
    )[0],
    skip_special_tokens=True
)

#print(dash_line)
print(f'BASELINE HUMAN SUMMARY:\n{summary}\n')
#print(dash_line)
print(f'MODEL GENERATION - FEW SHOT:\n{output}')

BASELINE HUMAN SUMMARY:
#Person2# is looking for an MP-3 player. #Person1# recommends a pioneer and #Person2# chooses yellow.

MODEL GENERATION - FEW SHOT:
rien is looking for an MP-3 player. He wants to buy it in yellow.


In [37]:
print(model)

T5ForConditionalGeneration(
  (shared): Embedding(32128, 1024)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 1024)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=1024, out_features=1024, bias=False)
              (k): Linear(in_features=1024, out_features=1024, bias=False)
              (v): Linear(in_features=1024, out_features=1024, bias=False)
              (o): Linear(in_features=1024, out_features=1024, bias=False)
              (relative_attention_bias): Embedding(32, 16)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=1024, out_features=2816, bias=False)
              (wi_1): Linear(in_features=1024, out_features=2816, bias=False)
       

In [17]:
def preprocess_function(sample, padding="max_length"):
    # add prefix to the input for t5
    inputs = ["summarize: " + item for item in sample["dialogue_en"]]

    # tokenize inputs
    model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)

    # Tokenize targets with the `text_target` keyword argument
    labels = tokenizer(text_target=sample["summary_en"], max_length=max_target_length, padding=padding, truncation=True)

    # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100
    # so that padding is ignored in the loss. With batched=True, labels["input_ids"]
    # is always a list of lists.
    if padding == "max_length":
        labels["input_ids"] = [
            [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
        ]

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=['fname', 'dialogue', 'summary', 'topic', 'dialogue_en', 'summary_en', 'topic_en'])
print(f"Keys of tokenized dataset: {list(tokenized_dataset['train'].features)}")

Keys of tokenized dataset: ['input_ids', 'attention_mask', 'labels']
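
The label masking step in `preprocess_function` can be looked at in isolation. The sketch below is a minimal stand-alone version with a hypothetical helper `mask_padding` and made-up token IDs; it assumes a pad token ID of 0 (as for T5 tokenizers) and uses -100, the index that PyTorch's cross-entropy loss ignores by default.

```python
def mask_padding(label_ids, pad_token_id=0, ignore_index=-100):
    """Replace padding IDs in each label sequence so the loss skips them."""
    return [
        [(token if token != pad_token_id else ignore_index) for token in sequence]
        for sequence in label_ids
    ]


# Made-up, already padded label sequences (1 is used as the end-of-sequence ID here)
padded_labels = [
    [1263, 107, 2862, 1, 0, 0],
    [1768, 4009, 1, 0, 0, 0],
]

for row in mask_padding(padded_labels):
    print(row)
```

Only the padding positions change; real tokens, including the end-of-sequence token, still contribute to the loss.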


## 3. Fine-tune and evaluate FLAN-T5

After we have processed our dataset, we can start training our model. First, we need to load our [FLAN-T5](https://huggingface.co/models?search=flan-t5) model from the Hugging Face Hub. The original example used an instance with an NVIDIA V100, which is why it fine-tuned the `base` version of the model; the cell below loads whichever checkpoint `model_id` points to (`google/flan-t5-large` in this notebook). 
_I plan to do a follow-up post on how to fine-tune the `xxl` version of the model using DeepSpeed._


In [12]:
from transformers import AutoModelForSeq2SeqLM

# load model from the hub
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

We want to evaluate our model during training. The `Trainer` supports evaluation during training when you provide a `compute_metrics` function.  
The most commonly used metric for evaluating summarization is the [ROUGE score](https://en.wikipedia.org/wiki/ROUGE_(metric)) (short for Recall-Oriented Understudy for Gisting Evaluation). This metric does not behave like standard accuracy: it compares a generated summary against a set of reference summaries.

We are going to use the `evaluate` library to compute the `rouge` score.
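
To make the metric less abstract, here is a simplified ROUGE-1 computation in plain Python. It is only a sketch under stated assumptions: lowercase whitespace tokenization, no stemming, and a single reference. The `rouge` metric loaded through `evaluate` does more than this; the function name `rouge1_f1` and the two sentences are made up for illustration.

```python
from collections import Counter


def rouge1_f1(prediction, reference):
    """Unigram-overlap F1 between a prediction and one reference (simplified ROUGE-1)."""
    pred_counts = Counter(prediction.lower().split())
    ref_counts = Counter(reference.lower().split())
    # Clipped overlap: each word counts at most as often as it appears in both texts
    overlap = sum((pred_counts & ref_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)


reference = "the cat sat on the mat"
prediction = "the cat is on the mat"

print(f"ROUGE-1 F1: {rouge1_f1(prediction, reference):.4f}")
```

Five of the six words overlap here, so precision and recall are both 5/6. ROUGE-2 applies the same idea to bigrams, and ROUGE-L is based on the longest common subsequence instead.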

In [13]:
import evaluate
import nltk
import numpy as np
from nltk.tokenize import sent_tokenize
nltk.download("punkt")

# Metric
metric = evaluate.load("rouge")

# helper function to postprocess text
def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(sent_tokenize(label)) for label in labels]

    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    result = {k: round(v * 100, 4) for k, v in result.items()}
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return result

[nltk_data] Downloading package punkt to
[nltk_data]     /data/ephemeral/home/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Before we can start training, we need to create a `DataCollator` that takes care of padding our inputs and labels. We will use the `DataCollatorForSeq2Seq` from the 🤗 Transformers library. 

In [14]:
from transformers import DataCollatorForSeq2Seq

# we want to ignore tokenizer pad token in the loss
label_pad_token_id = -100
# Data collator
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=label_pad_token_id,
    pad_to_multiple_of=8
)
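
The effect of `pad_to_multiple_of=8` can be illustrated without the library: each batch is padded to its longest sequence, rounded up to the next multiple of 8. The helper `pad_batch` below is a hypothetical sketch with made-up token IDs, not the collator's actual code. In this notebook the inputs are already padded to `max_length` during preprocessing, so the collator has little left to pad.

```python
def pad_batch(sequences, pad_id, multiple_of=8):
    """Pad every sequence to the longest length, rounded up to a multiple of `multiple_of`."""
    longest = max(len(sequence) for sequence in sequences)
    # Ceiling division, then scale back up to get the padded length
    target = -(-longest // multiple_of) * multiple_of
    return [sequence + [pad_id] * (target - len(sequence)) for sequence in sequences]


# Made-up token IDs: the longest sequence has 9 tokens, so the batch is padded to 16
batch = [[11, 12, 13], [21, 22, 23, 24, 25, 26, 27, 28, 29]]

input_ids = pad_batch(batch, pad_id=0)
labels = pad_batch(batch, pad_id=-100)  # labels are padded with the ignore index

print([len(sequence) for sequence in input_ids])
print(labels[0])
```

Padding to a multiple of 8 is a common choice because it tends to map well onto GPU tensor cores when mixed precision is enabled.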


The last step is to define the hyperparameters (`TrainingArguments`) we want to use for our training. We are leveraging the [Hugging Face Hub](https://huggingface.co/models) integration of the `Trainer` to automatically push our checkpoints, logs and metrics during training into a repository.

In [15]:
from huggingface_hub import HfFolder
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Hugging Face repository id
repository_id = f"{model_id.split('/')[1]}-{dataset_id}"

# Define training args
training_args = Seq2SeqTrainingArguments(
    output_dir=repository_id,
    per_device_train_batch_size=1, #8
    per_device_eval_batch_size=1, #8
    predict_with_generate=True,
    fp16=False, # Overflows with fp16
    learning_rate=1e-5, #5e-5
    num_train_epochs=5,
    # logging & evaluation strategies
    logging_dir=f"{repository_id}/logs",
    logging_strategy="steps",
    logging_steps=100, #500
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    # metric_for_best_model="overall_f1",
    # push to hub parameters
    report_to="tensorboard",
    push_to_hub=False,
    hub_strategy="every_save",
    hub_model_id=repository_id,
    hub_token=HfFolder.get_token(),
)

# Create Trainer instance
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["val"],
    compute_metrics=compute_metrics,
)
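
Before launching a long run, it helps to estimate how many optimizer steps the arguments above imply. The sketch below is an approximation with a hypothetical helper `total_optimizer_steps`; the `num_devices` and `grad_accum_steps` parameters are assumptions added for illustration and are left at 1, matching the configuration above. It uses the 12,457 training examples and 5 epochs from this notebook.

```python
import math


def total_optimizer_steps(num_examples, per_device_batch_size, num_epochs,
                          num_devices=1, grad_accum_steps=1):
    """Approximate number of optimizer steps for a training run."""
    effective_batch_size = per_device_batch_size * num_devices * grad_accum_steps
    steps_per_epoch = math.ceil(num_examples / effective_batch_size)
    return steps_per_epoch * num_epochs


steps_batch_1 = total_optimizer_steps(12457, per_device_batch_size=1, num_epochs=5)
steps_batch_8 = total_optimizer_steps(12457, per_device_batch_size=8, num_epochs=5)

print(f"batch size 1: {steps_batch_1} steps")
print(f"batch size 8: {steps_batch_8} steps")
```

With a batch size of 1 the run needs roughly eight times as many steps as with the commented-out batch size of 8, which is one reason training takes long on a single GPU.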

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


We can start our training by using the `train` method of the `Trainer`.

In [16]:
# Start training
trainer.train()

# Estimated to take about 2.16 days

KeyError: 0


![flan-t5-tensorboard](../assets/flan-t5-tensorboard.png)

Nice, we have trained our model. 🎉 Let's evaluate the best model again on the evaluation set.


In [None]:
trainer.evaluate()

***** Running Evaluation *****
  Num examples = 819
  Batch size = 8


{'eval_loss': 1.3715944290161133,
 'eval_rouge1': 47.2358,
 'eval_rouge2': 23.5135,
 'eval_rougeL': 39.6266,
 'eval_rougeLsum': 43.3458,
 'eval_gen_len': 17.39072039072039,
 'eval_runtime': 108.99,
 'eval_samples_per_second': 7.514,
 'eval_steps_per_second': 0.945,
 'epoch': 5.0}

The best score we achieved is a `rouge1` score of `47.24`. 

Let's save our results and tokenizer to the Hugging Face Hub and create a model card. 

In [None]:
# Save our tokenizer and create model card
tokenizer.save_pretrained(repository_id)
trainer.create_model_card()
# Push the results to the hub
trainer.push_to_hub()

## 4. Run Inference

Now we have a trained model, we can use it to run inference. We will use the `pipeline` API from transformers and a `test` example from our dataset.

In [None]:
from transformers import pipeline
from random import randrange        

# load model and tokenizer from huggingface hub with pipeline
summarizer = pipeline("summarization", model="philschmid/flan-t5-base-samsum", device=0)

# select a random test sample
sample = dataset['test'][randrange(len(dataset["test"]))]
print(f"dialogue: \n{sample['dialogue']}\n---------------")

# summarize dialogue
res = summarizer(sample["dialogue"])

print(f"flan-t5-base summary:\n{res[0]['summary_text']}")

dialogue: 
Abby: Have you talked to Miro?
Dylan: No, not really, I've never had an opportunity
Brandon: me neither, but he seems a nice guy
Brenda: you met him yesterday at the party?
Abby: yes, he's so interesting
Abby: told me the story of his father coming from Albania to the US in the early 1990s
Dylan: really, I had no idea he is Albanian
Abby: he is, he speaks only Albanian with his parents
Dylan: fascinating, where does he come from in Albania?
Abby: from the seacoast
Abby: Duress I believe, he told me they are not from Tirana
Dylan: what else did he tell you?
Abby: That they left kind of illegally
Abby: it was a big mess and extreme poverty everywhere
Abby: then suddenly the border was open and they just left 
Abby: people were boarding available ships, whatever, just to get out of there
Abby: he showed me some pictures, like <file_photo>
Dylan: insane
Abby: yes, and his father was among the people
Dylan: scary but interesting
Abby: very!
---------------
flan-t5-base summary:
A