# Fine-tune FLAN-T5 for chat & dialogue summarization

In this blog, you will learn how to fine-tune [google/flan-t5-base](https://huggingface.co/google/flan-t5-base) for chat & dialogue summarization using Hugging Face Transformers. If you already know T5, FLAN-T5 is T5 with additional instruction fine-tuning: for the same number of parameters, the FLAN-T5 models have been fine-tuned on more than 1000 additional tasks covering more languages, and they generally outperform the original T5 checkpoints.

In this example, we will use the [samsum](https://huggingface.co/datasets/samsum) dataset, a collection of about 16k messenger-like conversations with summaries. The conversations were created and written down by linguists fluent in English.

You will learn how to:

1. [Setup Development Environment](#1-setup-development-environment)
2. [Load and prepare samsum dataset](#2-load-and-prepare-samsum-dataset)
3. [Fine-tune and evaluate FLAN-T5](#3-fine-tune-and-evaluate-flan-t5)
4. [Run Inference](#4-run-inference)

Before we can start, make sure you have a [Hugging Face Account](https://huggingface.co/join) to save artifacts and experiments. 

## Quick intro: FLAN-T5, just a better T5

FLAN-T5, released with the [Scaling Instruction-Finetuned Language Models](https://arxiv.org/pdf/2210.11416.pdf) paper, is an enhanced version of T5 that has been finetuned on a mixture of tasks. The paper explores instruction finetuning with a particular focus on (1) scaling the number of tasks, (2) scaling the model size, and (3) finetuning on chain-of-thought data. It finds that instruction finetuning is a general method for improving the performance and usability of pretrained language models.

![flan-t5](../assets/flan-t5.png)

* Paper: https://arxiv.org/abs/2210.11416
* Official repo: https://github.com/google-research/t5x

--- 

Now that we know what FLAN-T5 is, let's get started. 🚀

_Note: This tutorial was created and run on a g4dn.xlarge AWS EC2 instance with an NVIDIA T4._

## 1. Setup Development Environment

Our first step is to install the Hugging Face libraries, including `transformers` and `datasets`, along with the other packages used in this notebook. Uncomment and run the following cell to install them.

In [8]:
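# Uncomment to install the required packages
#! pip install -U ipykernel
#! pip install transformers datasets evaluate rouge_score nltk py7zr tensorboard torch accelerate
#! pip install langchain sentence_transformers chromadb unstructured pdfminer.six python-dotenv runhouse tiktoken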

In [9]:
# install git-lfs for pushing the model and logs to the Hugging Face Hub
#!sudo apt-get install git-lfs --yes

This example uses the [Hugging Face Hub](https://huggingface.co/models) as a remote model versioning service. To push our model to the Hub, you need an account on [Hugging Face](https://huggingface.co/join); if you already have one, you can skip this step. 
Once you have an account, we will use the `notebook_login` util from the `huggingface_hub` package to log into our account and store our token (access key) on disk. 

In [7]:
import os
from multiprocessing import Pool
from typing import List

from dotenv import load_dotenv
from tqdm import tqdm
from langchain import HuggingFacePipeline
from langchain.docstore.document import Document
from langchain.document_loaders import (
    CSVLoader,
    EverNoteLoader,
    PDFMinerLoader,
    TextLoader,
    UnstructuredEmailLoader,
    UnstructuredEPubLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
    UnstructuredODTLoader,
    UnstructuredPowerPointLoader,
    UnstructuredWordDocumentLoader,
)
from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

In [8]:
import os
import glob
from typing import List
from dotenv import load_dotenv
from multiprocessing import Pool
from tqdm import tqdm

# Custom document loaders
class MyElmLoader(UnstructuredEmailLoader):
    """Wrapper to fallback to text/plain when default does not work"""

    def load(self) -> List[Document]:
        """Wrapper adding fallback for elm without html"""
        try:
            try:
                doc = UnstructuredEmailLoader.load(self)
            except ValueError as e:
                if 'text/html content not found in email' in str(e):
                    # Try plain text
                    self.unstructured_kwargs["content_source"]="text/plain"
                    doc = UnstructuredEmailLoader.load(self)
                else:
                    raise
        except Exception as e:
            # Add file_path to exception message
            raise type(e)(f"{self.file_path}: {e}") from e

        return doc


# Map file extensions to document loaders and their arguments
LOADER_MAPPING = {
    ".csv": (CSVLoader, {}),
    # ".docx": (Docx2txtLoader, {}),
    ".doc": (UnstructuredWordDocumentLoader, {}),
    ".docx": (UnstructuredWordDocumentLoader, {}),
    ".enex": (EverNoteLoader, {}),
    ".eml": (MyElmLoader, {}),
    ".epub": (UnstructuredEPubLoader, {}),
    ".html": (UnstructuredHTMLLoader, {}),
    ".md": (UnstructuredMarkdownLoader, {}),
    ".odt": (UnstructuredODTLoader, {}),
    ".pdf": (PDFMinerLoader, {}),
    ".ppt": (UnstructuredPowerPointLoader, {}),
    ".pptx": (UnstructuredPowerPointLoader, {}),
    ".txt": (TextLoader, {"encoding": "utf8"}),
    # Add more mappings for other file extensions and loaders as needed
}


def load_single_document(file_path: str) -> List[Document]:
    """Load a single file with the loader registered for its extension."""
    ext = "." + file_path.rsplit(".", 1)[-1]
    if ext in LOADER_MAPPING:
        loader_class, loader_args = LOADER_MAPPING[ext]
        loader = loader_class(file_path, **loader_args)
        # Return all Documents: some loaders (e.g. CSVLoader, one Document per row)
        # produce more than one, so keeping only the first would drop content
        return loader.load()

    raise ValueError(f"Unsupported file extension '{ext}'")


def load_documents(source_dir: str, ignored_files: List[str] = []) -> List[Document]:
    """
    Loads all documents from the source documents directory, ignoring specified files
    """
    all_files = []
    for ext in LOADER_MAPPING:
        all_files.extend(
            glob.glob(os.path.join(source_dir, f"**/*{ext}"), recursive=True)
        )
    filtered_files = [file_path for file_path in all_files if file_path not in ignored_files]

    results = []
    with Pool(processes=os.cpu_count()) as pool:
        with tqdm(total=len(filtered_files), desc='Loading new documents', ncols=80) as pbar:
            # Files are loaded in parallel; results arrive in completion order
            for docs in pool.imap_unordered(load_single_document, filtered_files):
                results.extend(docs)
                pbar.update()

    return results

## 2. Load and prepare samsum dataset

We will use the [samsum](https://huggingface.co/datasets/samsum) dataset, a collection of about 16k messenger-like conversations with summaries written by linguists fluent in English. Each sample has the following format:

```json
{
  "id": "13818513",
  "summary": "Amanda baked cookies and will bring Jerry some tomorrow.",
  "dialogue": "Amanda: I baked cookies. Do you want some?\r\nJerry: Sure!\r\nAmanda: I'll bring you tomorrow :-)"
}
```
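The `dialogue` field is a single string in which each turn has the form `Speaker: utterance`; in the example above, turns are separated by `\r\n`. The tutorial feeds this string to the model as-is, but splitting it into turns can help when inspecting samples. The helper `split_turns` below is a hypothetical sketch (not part of `datasets` or `transformers`) and assumes that the first `": "` on a line separates the speaker from the utterance, and that a line without it continues the previous turn.

```python
from typing import List, Tuple


def split_turns(dialogue: str) -> List[Tuple[str, str]]:
    """Split a samsum-style dialogue string into (speaker, utterance) pairs."""
    turns: List[Tuple[str, str]] = []
    # str.splitlines() handles "\r\n" as well as "\n" line endings
    for line in dialogue.splitlines():
        if not line.strip():
            continue  # skip empty lines
        speaker, sep, utterance = line.partition(": ")
        if sep:
            turns.append((speaker.strip(), utterance.strip()))
        elif turns:
            # Assumption: a line without "Speaker: " continues the previous turn
            prev_speaker, prev_utterance = turns[-1]
            turns[-1] = (prev_speaker, f"{prev_utterance} {line.strip()}")
        else:
            # No speaker seen yet: keep the text with an empty speaker name
            turns.append(("", line.strip()))
    return turns


example = "Amanda: I baked cookies. Do you want some?\r\nJerry: Sure!\r\nAmanda: I'll bring you tomorrow :-)"
for speaker, utterance in split_turns(example):
    print(f"{speaker} -> {utterance}")
```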

In [10]:
import glob
import os

from dotenv import load_dotenv
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain.document_loaders import PDFMinerLoader, TextLoader, UnstructuredHTMLLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

# Text-splitter settings used when the documents are chunked later on
chunk_size = 1000
chunk_overlap = 20

load_dotenv()

# Settings read from a .env file (each is None when the variable is unset)
#embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
persist_directory = os.environ.get('PERSIST_DIRECTORY')
model_type = os.environ.get('MODEL_TYPE')
model_path = os.environ.get('MODEL_PATH')
model_n_ctx = os.environ.get('MODEL_N_CTX')

documents = []

# Plain-text files
files = glob.glob("source_documents/*.txt")
print(len(files))
print(files)
for file_path in files:
    documents.extend(TextLoader(file_path).load())

# HTML pages
files = glob.glob("source_documents/*.html")
print(files)
print(len(files))
for file_path in files:
    documents.extend(UnstructuredHTMLLoader(file_path).load())

# PDF files
files = glob.glob("source_documents/*.pdf")
print(files)
print(len(files))
for file_path in files:
    documents.extend(PDFMinerLoader(file_path).load())

print(documents)

2
['source_documents/bhagwan.txt', 'source_documents/anupam.txt']
['source_documents/Free Open Source Alternative to ChatGPT — GPT4All by Wei-Meng Lee May, 2023 Level Up Coding.html', 'source_documents/Automated Planning Tool makes work order allocation more efficient - Amazon Science.html', 'source_documents/Anupam Purwar - NSF PG Scholarship Programme.html', 'source_documents/Anupam Purwar - IEEE Xplore Author Profile.html', 'source_documents/Anupam Purwar - Amazon Science.html', 'source_documents/_Anupam Purwar_ - _Google Scholar_.html']
6
['source_documents/arimax-model-for-forecasting-maintenance-work-amfm-a-multi-stage-seasonal-arimax-model-for-work-order-time-series-forecasting.pdf', 'source_documents/hom.pdf', 'source_documents/automated-planning-tool-apt-a-mised-interger-non-linear-programming-problem-solver-for-workorder-scheduling.pdf']
3


In [12]:
from langchain.embeddings import HuggingFaceEmbeddings

# Sentence-embedding model used to index the chunks
# (alternatives tried: "sentence-transformers/LaBSE", "intfloat/e5-large-v2")
model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': False}
hf = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs
)

# Split the documents into chunks of about chunk_size characters; a piece that is
# already longer than chunk_size is kept whole, which triggers the warnings below
text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
texts = text_splitter.split_documents(documents)

# Index the chunks in an in-memory Chroma store
db = Chroma.from_documents(texts, hf)

# Retrieve the 20 most similar chunks for each query
retriever = db.as_retriever(search_type='similarity', search_kwargs={"k": 20})

# Wrap FLAN-T5 base as a LangChain LLM (alternative task: "text2text-generation");
# the "stuff" chain puts all retrieved chunks into a single prompt
model_id = "google/flan-t5-base"
llm = HuggingFacePipeline.from_model_id(model_id=model_id, task="summarization", model_kwargs={"temperature": 1e-1, "max_length": 512})
callbacks = []  # or [StreamingStdOutCallbackHandler()] to stream output
qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)

Created a chunk of size 1321, which is longer than the specified 1000
Created a chunk of size 1375, which is longer than the specified 1000
Created a chunk of size 1855, which is longer than the specified 1000
Created a chunk of size 2124, which is longer than the specified 1000
Created a chunk of size 1528, which is longer than the specified 1000
Created a chunk of size 1486, which is longer than the specified 1000
Created a chunk of size 1451, which is longer than the specified 1000
Created a chunk of size 1553, which is longer than the specified 1000
Created a chunk of size 2335, which is longer than the specified 1000
Created a chunk of size 1018, which is longer than the specified 1000
Created a chunk of size 2131, which is longer than the specified 1000
Created a chunk of size 1037, which is longer than the specified 1000
Created a chunk of size 1447, which is longer than the specified 1000
Created a chunk of size 1903, which is longer than the specified 1000
Created a chunk of s
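The `Created a chunk of size …` warnings follow from how `CharacterTextSplitter` works: it splits the text on a separator (by default `"\n\n"`) and then greedily merges the pieces into chunks of at most `chunk_size` characters, carrying up to `chunk_overlap` characters over into the next chunk. A single piece that is already longer than `chunk_size` cannot be split further, so it becomes an oversized chunk. The function below is a simplified re-implementation of that greedy merge for illustration only; `merge_splits` is a hypothetical name, it is not LangChain's code, and it ignores details such as whitespace stripping.

```python
from typing import List


def merge_splits(text: str, chunk_size: int, chunk_overlap: int, separator: str = "\n\n") -> List[str]:
    """Greedily merge separator-delimited pieces into chunks (simplified sketch)."""
    pieces = [p for p in text.split(separator) if p]
    sep_len = len(separator)
    chunks: List[str] = []
    current: List[str] = []
    total = 0  # always equals len(separator.join(current))
    for piece in pieces:
        added = len(piece) + (sep_len if current else 0)
        if current and total + added > chunk_size:
            # The current chunk is full: emit it
            chunks.append(separator.join(current))
            # Drop pieces from the front until at most chunk_overlap characters
            # remain and the next piece fits; what remains is the overlap
            while current and (
                total > chunk_overlap or total + len(piece) + sep_len > chunk_size
            ):
                total -= len(current[0]) + (sep_len if len(current) > 1 else 0)
                current = current[1:]
        current.append(piece)
        total += len(piece) + (sep_len if len(current) > 1 else 0)
    if current:
        chunks.append(separator.join(current))
    return chunks


chunks = merge_splits("aaaaa\n\nbbbbb\n\n" + "c" * 30, chunk_size=15, chunk_overlap=0)
print([len(c) for c in chunks])  # [12, 30]: the 30-character piece exceeds chunk_size
```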

In [None]:
from langchain.document_loaders import YoutubeLoader
loader = YoutubeLoader.from_youtube_url("https://www.youtube.com/watch?v=QsYGlZkevEg", add_video_info=True)
loader.load()

In [None]:

#! pip install youtube-transcript-api
! pip install einops

#! pip3 install youtube-transcript-api
from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers
import torch

model = "tiiuae/falcon-40b-instruct"

tokenizer = AutoTokenizer.from_pretrained(model)
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)
sequences = pipeline(
   "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
    max_length=200,
    do_sample=True,
    top_k=10,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
)
for seq in sequences:
    print(f"Result: {seq['generated_text']}")



In [4]:
#llm =  HuggingFacePipeline.from_model_id(model_id="google/flan-t5-base", task="text2text-generation", model_kwargs={"temperature":1e-1, "max_length" : 512})
from langchain import HuggingFacePipeline
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings import HuggingFaceInstructEmbeddings

model_id ="google/flan-t5-base"
#llm =  HuggingFacePipeline.from_model_id(model_id=model_id, task="question-answering", model_kwargs={"temperature":1e-1, "max_length" : 512}) 
llm =  HuggingFacePipeline.from_model_id(model_id=model_id, task="summarization", model_kwargs={"temperature":1e-1, "max_length" : 512}) 
retriever = db.as_retriever(search_type='similarity', search_kwargs={"k": 20} )
callbacks = []  #if args.mute_stream else [StreamingStdOutCallbackHandler()]
qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)

In [63]:
from langchain.embeddings import SentenceTransformerEmbeddings

text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
texts = text_splitter.split_documents(documents)

# Default HuggingFaceEmbeddings model (sentence-transformers/all-mpnet-base-v2)
embeddings = HuggingFaceEmbeddings()
print(embeddings)

# Chroma expects a LangChain Embeddings object (with embed_documents/embed_query), not a raw
# sentence_transformers.SentenceTransformer; other models can be passed via model_name,
# e.g. SentenceTransformerEmbeddings(model_name="sentence-transformers/LaBSE")
embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")

#from langchain.embeddings.openai import OpenAIEmbeddings
#embeddings = OpenAIEmbeddings()

db = Chroma.from_documents(texts, embeddings)
retriever = db.as_retriever(search_type='similarity', search_kwargs={"k": 20})
llm = HuggingFacePipeline.from_model_id(model_id="google/flan-t5-base", task="text2text-generation", model_kwargs={"temperature": 1e-1, "max_length": 256})

callbacks = []  # or [StreamingStdOutCallbackHandler()] to stream output
qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)


Created a chunk of size 1321, which is longer than the specified 1000
Created a chunk of size 1375, which is longer than the specified 1000
Created a chunk of size 1855, which is longer than the specified 1000
Created a chunk of size 2124, which is longer than the specified 1000
Created a chunk of size 1528, which is longer than the specified 1000
Created a chunk of size 1486, which is longer than the specified 1000
Created a chunk of size 1451, which is longer than the specified 1000
Created a chunk of size 1553, which is longer than the specified 1000
Created a chunk of size 2335, which is longer than the specified 1000
Created a chunk of size 1018, which is longer than the specified 1000
Created a chunk of size 2131, which is longer than the specified 1000
Created a chunk of size 1037, which is longer than the specified 1000
Created a chunk of size 1447, which is longer than the specified 1000
Created a chunk of size 1903, which is longer than the specified 1000
Created a chunk of s

client=SentenceTransformer(
  (0): Transformer({'max_seq_length': 384, 'do_lower_case': False}) with Transformer model: MPNetModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
) model_name='sentence-transformers/all-mpnet-base-v2' cache_folder=None model_kwargs={} encode_kwargs={}


Using embedded DuckDB without persistence: data will be transient


In [13]:
dataset_id = "samsum"
print(qa({'query': 'Who is Anupam Purwar'}))

Token indices sequence length is longer than the specified maximum sequence length for this model (3751 > 512). Running this sequence through the model will result in indexing errors


{'query': 'Who is Anupam Purwar', 'result': 'Bhagwan Chowdhry is a Professor of Finance at the Indian School of Business and Research Professor at UCLA Anderson where he has held an appointment since 1988', 'source_documents': [Document(page_content='Author\n\nAnupam Purwar\n\nResearch Scientist', metadata={'source': 'source_documents/Anupam Purwar - Amazon Science.html'}), Document(page_content='Anupam Purwar\n\nadmin', metadata={'source': 'source_documents/Anupam Purwar - NSF PG Scholarship Programme.html'}), Document(page_content='Anupam is currently working as a Research Scientist with Amazon developing tech products for Amazon’s global network. He specializes in solving problems related to Natural Language Processing and Optimization. In his previous role at Amazon, he was credited with developing Railways as the third mode of transport which fetched him two awards from Global VP. Prior to this, Anupam worked as a Research Scientist at Indian Institute of Science (IISc). At IISc, 

In [4]:
#! pip install -U sentence-transformers

from sentence_transformers import SentenceTransformer, util
sentences = ["I'm very happy", "I'm full of happiness"]

model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# Compute an embedding for each sentence
embedding_1 = model.encode(sentences[0], convert_to_tensor=True)
embedding_2 = model.encode(sentences[1], convert_to_tensor=True)

# Cosine similarity between the two embeddings
util.pytorch_cos_sim(embedding_1, embedding_2)


tensor([[0.5758]])
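The cell above scores a single pair of sentences. A retriever created with `db.as_retriever(search_type='similarity', search_kwargs={"k": 20})` does conceptually the same thing at scale: embed the query, score it against the stored chunk embeddings, and return the `k` best matches. The brute-force NumPy sketch below illustrates that idea with cosine similarity on toy vectors; `cosine_top_k` is a hypothetical helper, and Chroma itself uses its own index and configurable distance function, so its exact scores and ranking can differ.

```python
import numpy as np


def cosine_top_k(query: np.ndarray, corpus: np.ndarray, k: int):
    """Return indices and cosine scores of the k corpus rows most similar to query."""
    # Normalise so that a dot product equals cosine similarity
    q = query / np.linalg.norm(query)
    c = corpus / np.linalg.norm(corpus, axis=1, keepdims=True)
    scores = c @ q
    k = min(k, len(scores))
    # Highest scores first
    top = np.argsort(-scores)[:k]
    return top, scores[top]


# Toy 2-D "embeddings"; real sentence embeddings have hundreds of dimensions
corpus = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
query = np.array([1.0, 0.1])
indices, scores = cosine_top_k(query, corpus, k=2)
print(indices, scores)
```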

In [10]:

#print(qa({'query': 'Who is Anupam Purwar'}))

#print(qa({'query': 'What is APT'})['result'])

print(qa({'query': 'What is Supervised learning'})['result'])
print(qa({'query': 'What is ANN'})['result'])
print(qa({'query': 'What is RNN'})['result'])
#print(qa({'query': 'Who is Anupam Purwar'})['source_documents'])

Token indices sequence length is longer than the specified maximum sequence length for this model (2670 > 512). Running this sequence through the model will result in indexing errors


Supervised learning is an area of machine learning where the chosen algorithm tries to fit a target using the given input. A set of training data that contains labels is sup plied to the algorithm. Based on a massive set of data, the algorithm will learn a rule that it uses to predict the labels for new observations.
ANN is a primary algorithm used across all types of machine learning.
Recurrent neural networks (RNNs) are called “recurrent” because they perform the same task for every element of a sequence, with the output being dependent on the previous computations. RNNs have a memory, which captures information about what has been calculated so far.


In [50]:
#Token indices sequence length is longer than the specified maximum sequence length for this model (2670 > 512). Running this sequence through the model will result in indexing errors
print(qa({'query': 'What is Supervised learning'})['result'])
print(qa({'query': 'What is ANN'})['result'])
print(qa({'query': 'What is RNN'})['result'])

Supervised learning is an area of machine learning where the chosen algorithm tries to fit a target using the given input. A set of training data that contains labels is sup plied to the algorithm. Based on a massive set of data, the algorithm will learn a rule that it uses to predict the labels for new observations.
ANN is a primary algorithm used across all types of machine learning.
Recurrent neural networks (RNNs) are called “recurrent” because they perform the same task for every element of a sequence, with the output being dependent on the previous computations. RNN models have a memory, which captures information about what has been calculated so far. As shown in Figure 5-4, a recurrent neural net work can be thought of as multiple copies of the same network, each passing a mes sage to a successor.


In [16]:
#print(qa({'query': 'Who is Bhagwan'}))

#print(qa({'query': 'Where is Bhagwan'})['result'])
#print(qa({'query': 'Who is Bhagwan'})['source_documents'])

print(qa({'query': 'What is NLP'})['result'])

print(qa({'query': 'Tell about NLP methods'})['result'])

Natural language processing (NLP) is a subfield of artificial intelligence used to aid computers in understanding natural human language.
Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Processing Natural Language Proce

In [8]:

print(qa({'query': 'What is main idea of Bhagwan Chowdhry'})['result'])


to reduce the dimensionality of a dataset with a large number of variables, while retaining as much variance in the data as possible


In [15]:
print(qa({'query': 'Tell about NLP'})['result'])
print(qa({'query': 'Tell about NLP techniques'})['result'])

Natural language processing (NLP) is a subfield of artificial intelligence used to aid computers in understanding natural human language. Most NLP techniques rely on machine learning to derive meaning from human languages.
Natural Language Processing (NLP) is a subfield of artificial intelligence concerned with programming computers to process textual data in order to gain useful insights. Natural Language Processing (NLP) manifests itself in different forms across many disciplines under various aliases, including (but not limited to) textual analysis, text mining, computational linguistics, and content analysis. Natural Language Processing (NLP) manifests itself in different forms across many disciplines under various aliases, including (but not limited to) textual analysis, text mining, computational linguistics, and content analysis. Natural Language Processing (NLP) is a subfield of artificial intelligence concerned with programming computers to process textual data in order to gai


To load the `samsum` dataset, we use the `load_dataset()` method from the 🤗 Datasets library.


In [13]:
from datasets import load_dataset

# Hub id of the dataset
dataset_id = "samsum"

# Load dataset from the hub
dataset = load_dataset(dataset_id)

print(f"Train dataset size: {len(dataset['train'])}")
print(f"Test dataset size: {len(dataset['test'])}")

# Train dataset size: 14732
# Test dataset size: 819

Downloading builder script: 100%|██████████| 3.36k/3.36k [00:00<00:00, 25.9MB/s]
Downloading metadata: 100%|██████████| 1.58k/1.58k [00:00<00:00, 14.7MB/s]
Downloading readme: 100%|██████████| 7.04k/7.04k [00:00<00:00, 36.5MB/s]


Let's check out an example from the dataset.

In [14]:
from random import randrange        


sample = dataset['train'][randrange(len(dataset["train"]))]
print(f"dialogue: \n{sample['dialogue']}\n---------------")
print(f"summary: \n{sample['summary']}\n---------------")

To train our model, we need to convert our inputs (text) to token IDs. This is done by a 🤗 Transformers tokenizer. If you are not sure what this means, check out [chapter 6](https://huggingface.co/course/chapter6/1?fw=pt) of the Hugging Face Course.

In [15]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_id="google/flan-t5-base"

# Load tokenizer of FLAN-t5-base
tokenizer = AutoTokenizer.from_pretrained(model_id)


Before we can start training, we need to preprocess our data. Abstractive summarization is a text2text-generation task: the model takes a text as input and generates a summary as output. To batch our data efficiently, we first want to know how long our inputs and outputs are.

In [16]:
from datasets import concatenate_datasets

# The maximum total input sequence length after tokenization. 
# Sequences longer than this will be truncated, sequences shorter will be padded.
tokenized_inputs = concatenate_datasets([dataset["train"], dataset["test"]]).map(lambda x: tokenizer(x["dialogue"], truncation=True), batched=True, remove_columns=["dialogue", "summary"])
max_source_length = max([len(x) for x in tokenized_inputs["input_ids"]])
print(f"Max source length: {max_source_length}")

# The maximum total sequence length for target text after tokenization. 
# Sequences longer than this will be truncated, sequences shorter will be padded.
tokenized_targets = concatenate_datasets([dataset["train"], dataset["test"]]).map(lambda x: tokenizer(x["summary"], truncation=True), batched=True, remove_columns=["dialogue", "summary"])
max_target_length = max([len(x) for x in tokenized_targets["input_ids"]])
print(f"Max target length: {max_target_length}")
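The cell above records the longest tokenized dialogue and summary, and the preprocessing step pads every sample to those lengths (`padding="max_length"`). To see why sequence lengths matter for efficient batching, here is a small sketch with made-up token lengths that compares the share of pad tokens when padding everything to one fixed length versus padding each batch only to its own longest sequence (dynamic padding). The helper names and numbers are hypothetical and illustrative only, not measurements on samsum.

```python
from typing import List


def pad_fraction_fixed(lengths: List[int], max_length: int) -> float:
    """Share of pad tokens when every sequence is padded (or truncated) to max_length."""
    total = len(lengths) * max_length
    return 1 - sum(min(n, max_length) for n in lengths) / total


def pad_fraction_dynamic(lengths: List[int], batch_size: int) -> float:
    """Share of pad tokens when each batch is padded to its own longest sequence."""
    real = padded = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        real += sum(batch)
        padded += len(batch) * max(batch)
    return 1 - real / padded


lengths = [100, 120, 80, 500]  # hypothetical token counts
print(f"fixed padding to 512: {pad_fraction_fixed(lengths, 512):.0%} pad tokens")
print(f"dynamic padding (batch size 2): {pad_fraction_dynamic(lengths, 2):.0%} pad tokens")
```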

In [17]:
def preprocess_function(sample,padding="max_length"):
    # add prefix to the input for t5
    inputs = ["summarize: " + item for item in sample["dialogue"]]

    # tokenize inputs
    model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)

    # Tokenize targets with the `text_target` keyword argument
    labels = tokenizer(text_target=sample["summary"], max_length=max_target_length, padding=padding, truncation=True)

    # When padding to max_length, replace every tokenizer.pad_token_id in the labels with -100
    # so that padding is ignored by the loss.
    if padding == "max_length":
        labels["input_ids"] = [
            [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
        ]

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=["dialogue", "summary", "id"])
print(f"Keys of tokenized dataset: {list(tokenized_dataset['train'].features)}")
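Why -100? PyTorch's `torch.nn.CrossEntropyLoss` skips every target equal to its `ignore_index`, which defaults to -100, and T5 in 🤗 Transformers computes its loss that way, so masked label positions add nothing to the loss. The NumPy sketch below reproduces the idea with a hypothetical `masked_cross_entropy` helper on toy logits: the loss is averaged only over positions whose label is not -100.

```python
import numpy as np

IGNORE_INDEX = -100  # label value used for padded positions in preprocess_function


def masked_cross_entropy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean token-level cross-entropy, ignoring positions where labels == -100.

    logits: (batch, seq_len, vocab_size), labels: (batch, seq_len)
    """
    # Numerically stable log-softmax over the vocabulary axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    mask = labels != IGNORE_INDEX
    # Replace ignored labels by a valid index so the gather works; they are masked out afterwards
    safe_labels = np.where(mask, labels, 0)
    token_nll = -np.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    return float(token_nll[mask].mean())


# Toy batch: one sequence of 3 positions over a vocabulary of 4; the last position is padding
logits = np.zeros((1, 3, 4))
labels = np.array([[1, 2, IGNORE_INDEX]])
print(masked_cross_entropy(logits, labels))  # uniform logits give log(4) per real token
```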

## 3. Fine-tune and evaluate FLAN-T5

After we have processed our dataset, we can start training our model. To do so, we first load [FLAN-T5](https://huggingface.co/models?search=flan-t5) from the Hugging Face Hub. In this example we are using an instance with an NVIDIA V100, which means that we will fine-tune the `base` version of the model. 
_I plan to do a follow-up post on how to fine-tune the `xxl` version of the model using Deepspeed._


In [18]:
from transformers import AutoModelForSeq2SeqLM

# huggingface hub model id
model_id="google/flan-t5-base"

# load model from the hub
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

We want to evaluate our model during training. The `Trainer` supports evaluation during training through a `compute_metrics` function.  
The most commonly used metric for evaluating summarization is [ROUGE](https://en.wikipedia.org/wiki/ROUGE_(metric)) (short for Recall-Oriented Understudy for Gisting Evaluation). Unlike standard accuracy, it compares a generated summary against one or more reference summaries by measuring their overlap, for example of unigrams (ROUGE-1) or of the longest common subsequence (ROUGE-L).

We are going to use the `evaluate` library to compute the ROUGE score.

In [19]:
import evaluate
import nltk
import numpy as np
from nltk.tokenize import sent_tokenize
nltk.download("punkt")
nltk.download("punkt_tab")  # sentence-tokenizer data required by newer NLTK releases

# Metric
metric = evaluate.load("rouge")

# helper function to postprocess text
def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(sent_tokenize(label)) for label in labels]

    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
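    if isinstance(preds, tuple):
        preds = preds[0]
    # Generated predictions may be padded with -100; map those back to the pad token before decoding
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)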
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    result = {k: round(v * 100, 4) for k, v in result.items()}
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return result
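`metric.compute` above returns ROUGE-1, ROUGE-2, ROUGE-L and ROUGE-Lsum. To make the idea behind ROUGE-1 concrete, here is a simplified sketch (`rouge1_f1` is a hypothetical helper) that computes its F1 score from clipped unigram overlap. It lowercases and keeps only alphanumeric tokens but does no stemming, so its numbers can differ slightly from the `rouge_score` implementation used by `evaluate`, which is called here with `use_stemmer=True`.

```python
import re
from collections import Counter


def rouge1_f1(prediction: str, reference: str) -> float:
    """Unigram-overlap F1 between a prediction and a single reference (no stemming)."""
    pred_tokens = re.findall(r"[a-z0-9]+", prediction.lower())
    ref_tokens = re.findall(r"[a-z0-9]+", reference.lower())
    # Clipped overlap: each token is matched at most as often as it occurs in both texts
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


reference = "Amanda baked cookies and will bring Jerry some tomorrow."
prediction = "Amanda will bring Jerry cookies tomorrow."
print(round(rouge1_f1(prediction, reference), 4))  # 0.8 (precision 6/6, recall 6/9)
```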

Before we can start training, we need to create a `DataCollator` that will take care of padding our inputs and labels. We will use the `DataCollatorForSeq2Seq` from the 🤗 Transformers library. 

In [20]:
from transformers import DataCollatorForSeq2Seq

# we want to ignore tokenizer pad token in the loss
label_pad_token_id = -100
# Data collator
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=label_pad_token_id,
    pad_to_multiple_of=8
)
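Because `preprocess_function` already pads the labels to `max_target_length` with -100, the collator's job here is mainly to build tensors and, via `pad_to_multiple_of=8`, round the padded length up to a multiple of 8. The pure-Python sketch below mimics only the label side of that behaviour with a hypothetical `pad_labels` helper; the real `DataCollatorForSeq2Seq` also pads `input_ids` and `attention_mask` through the tokenizer and returns tensors.

```python
from typing import List


def pad_labels(batch: List[List[int]], label_pad_token_id: int = -100, pad_to_multiple_of: int = 8) -> List[List[int]]:
    """Right-pad label sequences to the batch maximum, rounded up to a multiple."""
    longest = max(len(labels) for labels in batch)
    # Ceiling division, then scale back up: e.g. 10 -> 16 for a multiple of 8
    target = -(-longest // pad_to_multiple_of) * pad_to_multiple_of
    return [labels + [label_pad_token_id] * (target - len(labels)) for labels in batch]


batch = [[37, 1712, 1], [8, 9, 10, 11, 12, 13, 14, 15, 16, 1]]  # toy token IDs
padded = pad_labels(batch)
print([len(labels) for labels in padded])  # [16, 16]
```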


The last step is to define the hyperparameters (`Seq2SeqTrainingArguments`) we want to use for our training. The `hub_*` arguments configure the [Hugging Face Hub](https://huggingface.co/models) integration of the `Trainer`: with `push_to_hub=True`, it would push checkpoints, logs and metrics to a repository during training. Here we keep `push_to_hub=False` and push the results once, after training.

In [21]:
from huggingface_hub import HfFolder
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Hugging Face repository id
repository_id = f"{model_id.split('/')[1]}-{dataset_id}"

# Define training args
training_args = Seq2SeqTrainingArguments(
    output_dir=repository_id,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    fp16=False, # Overflows with fp16
    learning_rate=5e-5,
    num_train_epochs=5,
    # logging & evaluation strategies
    logging_dir=f"{repository_id}/logs",
    logging_strategy="steps",
    logging_steps=500,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    # metric_for_best_model="overall_f1",
    # push to hub parameters
    report_to="tensorboard",
    push_to_hub=False,
    hub_strategy="every_save",
    hub_model_id=repository_id,
    hub_token=HfFolder.get_token(),
)

# Create Trainer instance
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    compute_metrics=compute_metrics,
)
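With these arguments, the number of optimizer steps follows from the dataset size printed earlier. A quick back-of-the-envelope calculation, assuming a single GPU and no gradient accumulation (the `Trainer` default of `gradient_accumulation_steps=1`):

```python
import math

train_size = 14732   # size of the samsum train split printed earlier
batch_size = 8       # per_device_train_batch_size
num_gpus = 1         # assumption: single-GPU run
epochs = 5           # num_train_epochs
logging_steps = 500  # logging_steps

# The last, smaller batch of each epoch still counts as a step
steps_per_epoch = math.ceil(train_size / (batch_size * num_gpus))
total_steps = steps_per_epoch * epochs
print(steps_per_epoch, total_steps, total_steps // logging_steps)  # 1842 9210 18
```

So the training loss is logged 18 times, while evaluation and checkpointing happen at each of the 5 epoch boundaries.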

We can start our training by using the `train` method of the `Trainer`.

In [22]:
# Start training
trainer.train()


![flan-t5-tensorboard](../assets/flan-t5-tensorboard.png)

Nice, we have trained our model. 🎉 Let's evaluate the best model again on the test set.


In [23]:
trainer.evaluate()

The best score we achieved is a `rouge1` score of `47.23`. 

Let's save our results and tokenizer to the Hugging Face Hub and create a model card. 

In [24]:
# Save our tokenizer and create model card
tokenizer.save_pretrained(repository_id)
trainer.create_model_card()
# Push the results to the hub
trainer.push_to_hub()

## 4. Run Inference

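Now that we have a trained model, we can use it to run inference. We will use the `pipeline` API from Transformers and a random sample from the `test` split. The cell below loads the `philschmid/flan-t5-base-samsum` checkpoint from the Hub; replace it with your own model id to use the model you just trained.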

In [25]:
from transformers import pipeline
from random import randrange        

# load model and tokenizer from huggingface hub with pipeline
summarizer = pipeline("summarization", model="philschmid/flan-t5-base-samsum", device=0)

# select a random test sample
sample = dataset['test'][randrange(len(dataset["test"]))]
print(f"dialogue: \n{sample['dialogue']}\n---------------")

# summarize dialogue
res = summarizer(sample["dialogue"])

print(f"flan-t5-base summary:\n{res[0]['summary_text']}")

Downloading (…)lve/main/config.json: 100%|██████████| 1.53k/1.53k [00:00<00:00, 9.29MB/s]
Downloading pytorch_model.bin: 100%|██████████| 990M/990M [00:48<00:00, 20.6MB/s] 
Downloading (…)okenizer_config.json: 100%|██████████| 2.54k/2.54k [00:00<00:00, 13.7MB/s]
Downloading spiece.model: 100%|██████████| 792k/792k [00:03<00:00, 222kB/s]
Downloading (…)/main/tokenizer.json: 100%|██████████| 2.42M/2.42M [00:01<00:00, 2.00MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 2.20k/2.20k [00:00<00:00, 24.0MB/s]
