
# 6.3.2 - Fine-Tuning T5 for Generative QA

In this notebook, we fine-tune T5 on the SQuAD question-answering dataset. T5 (Text-to-Text Transfer Transformer) casts every task as text generation: the model reads an input string and produces the answer as an output string.

We will:
- Load the SQuAD dataset using `datasets`
- Preprocess the data for T5's format
- Tokenize the inputs and targets
- Fine-tune `t5-small` on QA using `Trainer`
- Evaluate and generate sample predictions
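
To make the text-to-text framing concrete, here is a minimal sketch of how one SQuAD-style record becomes an input/target string pair. The helper name `to_text_pair` and the sample record are illustrative assumptions; the `question: ... context: ...` prefix layout is the one used in the preprocessing step of this notebook.

```python
def to_text_pair(example):
    """Turn one SQuAD-style record into an (input, target) string pair."""
    source = f"question: {example['question']} context: {example['context']}"
    answers = example["answers"]["text"]
    # Use the first reference answer; fall back to an empty target if none exists.
    target = answers[0] if answers else ""
    return source, target


# Hand-written record in the SQuAD layout (illustrative, not taken from the dataset).
record = {
    "question": "What is the capital of France?",
    "context": "The capital of France is Paris.",
    "answers": {"text": ["Paris"], "answer_start": [25]},
}

source, target = to_text_pair(record)
print(source)
print(target)
```

The model is trained to generate `target` when given `source`, so no task-specific output head is needed.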


In [None]:

# sentencepiece is required by T5Tokenizer
!pip install transformers datasets accelerate sentencepiece


In [None]:

from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
from datasets import load_dataset
import torch


In [None]:

dataset = load_dataset("squad")
model_name = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)


In [None]:

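def preprocess(examples):
    # Called with batched=True, so every field is a list with one entry per example.
    input_texts = [
        f"question: {q} context: {c}"
        for q, c in zip(examples["question"], examples["context"])
    ]
    # Use the first reference answer; fall back to an empty string if none exists.
    target_texts = [a["text"][0] if a["text"] else "" for a in examples["answers"]]
    inputs = tokenizer(input_texts, max_length=512, padding="max_length", truncation=True)
    targets = tokenizer(target_texts, max_length=64, padding="max_length", truncation=True)
    # Replace pad token ids with -100 so padded positions are ignored by the loss.
    pad_id = tokenizer.pad_token_id
    inputs["labels"] = [[t if t != pad_id else -100 for t in ids] for ids in targets["input_ids"]]
    return inputs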


In [None]:

tokenized_dataset = dataset.map(preprocess, batched=True)
train_dataset = tokenized_dataset["train"]
val_dataset = tokenized_dataset["validation"]


In [None]:

training_args = TrainingArguments(
    output_dir="./t5-qa",
    num_train_epochs=1,
    per_device_train_batch_size=8,
    eval_strategy="epoch",  # named evaluation_strategy before transformers 4.41
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=10
)
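
As a rough sanity check on these settings, the arithmetic below estimates how many optimizer steps and log entries one epoch produces. It is a back-of-the-envelope sketch that assumes a single device, no gradient accumulation, and the 87,599-example SQuAD v1.1 training split; the variable names are illustrative.

```python
import math

train_examples = 87_599  # size of the SQuAD v1.1 training split
batch_size = 8           # per_device_train_batch_size
num_epochs = 1           # num_train_epochs
logging_steps = 10
num_devices = 1          # assumption: single GPU or CPU, no gradient accumulation

# The last, partially filled batch still counts as one step.
steps_per_epoch = math.ceil(train_examples / (batch_size * num_devices))
total_steps = steps_per_epoch * num_epochs
log_events = total_steps // logging_steps

print(f"steps per epoch: {steps_per_epoch}")
print(f"total steps:     {total_steps}")
print(f"log entries:     {log_events}")
```

With more devices or a larger batch size the step count shrinks proportionally, so `logging_steps` may need adjusting to keep the log readable.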


In [None]:

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset
)


In [None]:

trainer.train()


In [None]:

model.save_pretrained("./t5-finetuned-qa")
tokenizer.save_pretrained("./t5-finetuned-qa")


In [None]:

test_question = "What is the capital of France?"
test_context = "France is a country in Europe. The capital of France is Paris."

input_text = f"question: {test_question} context: {test_context}"
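# Move the inputs to the model's device; after training the model may sit on a GPU.
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)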

output_ids = model.generate(input_ids, max_length=64)
answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print("Generated Answer:", answer)
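
A single generated answer says little about quality, so generated answers are usually scored against the references. The sketch below implements SQuAD-style exact match and token-level F1 in plain Python. The function names are illustrative, and the normalization (lowercasing, stripping punctuation and English articles) follows the usual SQuAD convention rather than anything defined in this notebook.

```python
import re
import string
from collections import Counter


def normalize_answer(text):
    """Lowercase, drop punctuation and articles, and collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, reference):
    """True when the normalized strings are identical."""
    return normalize_answer(prediction) == normalize_answer(reference)


def token_f1(prediction, reference):
    """Harmonic mean of token precision and recall after normalization."""
    pred_tokens = normalize_answer(prediction).split()
    ref_tokens = normalize_answer(reference).split()
    if not pred_tokens or not ref_tokens:
        # Both empty counts as a match; one empty counts as a miss.
        return float(pred_tokens == ref_tokens)
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


print(exact_match("Paris.", "paris"))
print(token_f1("the city of Paris", "Paris"))
```

To score the model, one would generate an answer for each validation question and average both metrics, taking the maximum over the reference answers when a question has several.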
