# Transfer Learning with Flan-T5 on SQuAD Dataset

**Author**: Axel Sirota

In this notebook, we will explore the process of transfer learning using the `flan-t5-base` model (or large if you have sufficient hardware), fine-tuning it on the SQuAD (Stanford Question Answering Dataset) for a question-answering task. We'll focus on preparing the data, fine-tuning the model, and evaluating its performance.

### Introduction to SQuAD and Flan-T5-base

The SQuAD dataset is a large-scale dataset for question answering. It contains a series of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to each question is a segment of text, or span, from the corresponding reading passage.
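
To make the span structure concrete, here is a minimal sketch built on a hand-written toy record (invented for illustration, not taken from the dataset). It mimics the field layout of the Hugging Face `squad` dataset: `context`, `question`, and an `answers` dictionary with parallel `text` and `answer_start` lists. The helper name `extract_span` is hypothetical.

```python
# Toy record mimicking the `squad` field layout; the text is invented.
context = "The Eiffel Tower is located in Paris. It was completed in 1889."

record = {
    "id": "toy-0001",
    "title": "Toy_Article",
    "context": context,
    "question": "Where is the Eiffel Tower located?",
    # `answer_start` is the character offset of the answer inside `context`
    "answers": {"text": ["Paris"], "answer_start": [context.index("Paris")]},
}


def extract_span(record, i=0):
    """Recover the i-th answer by slicing the context at its character offset."""
    start = record["answers"]["answer_start"][i]
    text = record["answers"]["text"][i]
    return record["context"][start:start + len(text)]


print(extract_span(record))
```

Because every answer is a literal slice of the passage, the offset and the answer text must always agree, which makes this a handy sanity check on any record.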

`flan-t5-base` is an instruction-tuned variant of the T5 (Text-To-Text Transfer Transformer) model: it was further fine-tuned on a wide variety of tasks phrased as instructions so that it generalizes well to new tasks. In this notebook, we will use this model for a question-answering task on the SQuAD dataset.

## Preparation

First, let's ensure we have all necessary libraries installed. We'll be using the `transformers` library for the model and tokenizer, as well as the `datasets` library to load and process the SQuAD dataset, and TensorFlow for training.

In [None]:
# Install required libraries
!pip install -U transformers datasets tensorflow

### Importing Libraries

After installing the required packages, let's import them. We'll also set up some global parameters such as `EPOCHS` and `BATCH_SIZE` for fine-tuning.

In [None]:
# Import necessary libraries
import tensorflow as tf  # TensorFlow for deep learning
from transformers import TFAutoModelForSeq2SeqLM, AutoTokenizer
from datasets import load_dataset
import numpy as np
import warnings

# Set up global parameters for fine-tuning
EPOCHS = 3  # Number of training epochs
BATCH_SIZE = 8  # Batch size for training (adjust based on your hardware)
warnings.filterwarnings('ignore')

## Loading and Preparing the SQuAD Dataset

We'll load the SQuAD dataset using the `datasets` library. This dataset contains a collection of articles from Wikipedia, each paired with a set of questions and their corresponding answers. We'll preprocess the dataset and convert it into a format suitable for TensorFlow.

In [None]:
# Load the SQuAD dataset
dataset = None

# Display a sample from the training set
print(dataset['train'][0])

### Tokenization and Data Preparation

We need to tokenize the input data and prepare it for the `flan-t5-base` model. We'll use the `AutoTokenizer` from the `transformers` library to handle this step. The tokenizer will convert text into tokens that the model can understand, and we'll create TensorFlow datasets for training and validation.
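
When `map` is called with `batched=True`, the preprocessing function receives a dictionary of lists rather than a single example. The sketch below shows one possible way to turn such a batch into input and target strings before tokenization. The `question: ... context: ...` template is an assumption borrowed from common T5 usage, not something this notebook prescribes, and the helper names `build_inputs` and `build_targets` are hypothetical.

```python
def build_inputs(examples):
    """Join each question with its context into one prompt string."""
    return [
        f"question: {q.strip()} context: {c.strip()}"
        for q, c in zip(examples["question"], examples["context"])
    ]


def build_targets(examples):
    """Use the first reference answer of each example as the target text."""
    return [a["text"][0] for a in examples["answers"]]


# Toy batch in the dict-of-lists layout used by batched `map` (invented text)
batch = {
    "question": ["Where is the Eiffel Tower located? "],
    "context": ["The Eiffel Tower is located in Paris."],
    "answers": [{"text": ["Paris"], "answer_start": [31]}],
}

print(build_inputs(batch))
print(build_targets(batch))
```

The resulting lists of strings are what the tokenizer would then receive, one list for the encoder inputs and one for the targets.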

Remember to truncate and pad: every sequence in a batch must have the same length. Also, in seq2seq models we need to provide, in addition to the `input_ids`, the `decoder_input_ids`, which serve as the input to the decoder.
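
The following sketch illustrates both ideas on plain Python lists with made-up token ids. It assumes T5's conventions, where the pad token id is 0 (also used as the decoder start token) and the end-of-sequence id is 1. The helpers `pad_or_truncate` and `shift_right` are hypothetical names; in practice the tokenizer handles padding and truncation, and unlike this simplified version it keeps the end-of-sequence token when truncating.

```python
PAD_ID = 0  # T5 pad token id, also used as the decoder start token
EOS_ID = 1  # T5 end-of-sequence token id


def pad_or_truncate(ids, max_length, pad_id=PAD_ID):
    """Cut a sequence to max_length, then right-pad it up to max_length."""
    ids = list(ids)[:max_length]
    return ids + [pad_id] * (max_length - len(ids))


def shift_right(labels, start_id=PAD_ID):
    """Build decoder inputs: prepend the start token and drop the last label."""
    return [[start_id] + list(row[:-1]) for row in labels]


# Made-up target token ids of different lengths, each ending in EOS
targets = [[71, 12, EOS_ID], [5, 9, 33, 40, EOS_ID]]

labels = [pad_or_truncate(t, max_length=4) for t in targets]
decoder_input_ids = shift_right(labels)

print(labels)             # every row now has length 4
print(decoder_input_ids)  # each row is its label row delayed by one step
```

With this layout, the decoder sees only the tokens before position `t` when it predicts the label at position `t`, which is what teacher forcing requires.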

In [None]:
# Load the tokenizer for flan-t5-base
tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base')

def preprocess_function(examples):
    # Combine the context and the question for the input to the model as a list of strings
    inputs = None

    # Tokenize the inputs (questions + contexts) with max_length 512 and remember to truncate and pad
    model_inputs = None

    # Tokenize the targets (answers) to max_length 128
    labels = None

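    # The labels are the target token ids the decoder must predict
    model_inputs['labels'] = labels['input_ids']
    # Teacher forcing: the decoder input is the label sequence shifted one position
    # to the right, starting with the pad token (T5's decoder start token)
    model_inputs['decoder_input_ids'] = [
        [tokenizer.pad_token_id] + ids[:-1] for ids in labels['input_ids']
    ]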

    return model_inputs

# Apply the preprocessing to the training and validation sets
train_tokenized_datasets = dataset['train'].select(range(10000)).map(preprocess_function, batched=True)   # We just use 10000 examples to make training go faster
validation_tokenized_datasets = dataset['validation'].select(range(1000)).map(preprocess_function, batched=True)  # We just use 1000 examples to make training go faster

In [None]:

# Convert datasets to TensorFlow format
train_dataset = None

val_dataset = None

In [None]:
# Inspect the shapes of one batch (`inputs` avoids shadowing the built-in `input`)
for inputs, labels in train_dataset.take(1):
    print(inputs['input_ids'].shape)
    print(inputs['decoder_input_ids'].shape)
    print(labels.shape)

## Fine-Tuning the Flan-T5-Base Model

With the data prepared, we can now fine-tune the `flan-t5-base` model on the SQuAD dataset using TensorFlow.

In [None]:
# Load the pre-trained flan-t5-base model
model = None

# Compile the model
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
    loss=model.hf_compute_loss,  # Use the model's built-in loss function
)

# Print the model summary
model.summary()

In [None]:
# Freeze the encoder, embedding and decoder layers by making them non-trainable

model.summary()

In [None]:
# Fit the model

## Evaluating the Model

After fine-tuning, we will evaluate the model's performance on the validation set. The cell below reports the validation loss. Exact Match (EM) and F1, the common metrics for question answering tasks, can then be computed by comparing the model's generated answers with the reference answers.
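
As a reference for those two metrics, here is a minimal sketch that follows the normalization used by the SQuAD v1.1 evaluation script: lowercase, strip punctuation, drop the articles "a", "an" and "the", and collapse whitespace. The prediction strings are assumed to come from decoding the model's generated token ids, and the function names are illustrative.

```python
import re
import string
from collections import Counter


def normalize_answer(text):
    """Lowercase, remove punctuation and articles, and collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, reference):
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize_answer(prediction) == normalize_answer(reference))


def f1_score(prediction, reference):
    """Token-level F1 between the normalized prediction and reference."""
    pred_tokens = normalize_answer(prediction).split()
    ref_tokens = normalize_answer(reference).split()
    num_same = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


print(exact_match("The Paris!", "paris"))
print(f1_score("in Paris France", "Paris"))
```

Validation examples in SQuAD can carry several reference answers. The usual convention is to score a prediction against each reference and keep the maximum, then average over all examples.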

In [None]:
# Evaluate the model on the validation set
results = model.evaluate(val_dataset)

# Display the evaluation metrics
print('Validation Loss:', results)

**Homework**: Think of a nice way to manually test this model!