<a href="https://www.kaggle.com/code/ahabbsheraz/flan-t5-fine-tuning?scriptVersionId=189974919" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# FLAN-T5 Fine-tuning
Notebook created by M. Ahabb Sheraz

## Download Medical Dataset

In [None]:
!pip install datasets

In [None]:
from datasets import load_dataset
# Load the medical flashcards dataset by its Hub identifier.
# NOTE(review): trust_remote_code=True allows a loading script shipped with the
# dataset repository to run locally — confirm it is actually required for this
# dataset and drop the flag if it is not.
dataset = load_dataset("medalpaca/medical_meadow_medical_flashcards", trust_remote_code=True)

In [None]:
# Exploration over train samples: look at one raw example (index 14)
# to see which fields each record carries.
dataset['train'][14]

In [None]:
!pip install evaluate

## Loading the FLAN-T5 Model

In [None]:
import nltk
import evaluate
import numpy as np
from datasets import load_dataset
from transformers import T5Tokenizer, DataCollatorForSeq2Seq
from transformers import T5ForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer

In [None]:
# Load the tokenizer, model, and data collator.
# The tokenizer and the model are both created from the same checkpoint name.
MODEL_NAME = "google/flan-t5-small"

tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
# The collator is given both the tokenizer and the model.
# NOTE(review): presumably it pads each batch dynamically — confirm against the
# DataCollatorForSeq2Seq documentation.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

## Preparing Data for Fine-Tuning

In [None]:
# Hold out 30% of the original train split for evaluation. A fixed seed makes
# the split (and therefore the reported metrics) reproducible across runs.
dataset = dataset["train"].train_test_split(test_size=0.3, seed=42)

In [None]:
# Display the result of the train/test split.
dataset

In [None]:
# We prefix every question with the instruction "Answer this question truthfully: "
prefix = "Answer this question truthfully: "

# Define the preprocessing function

def preprocess_function(examples):
    """Tokenize a batch of question/answer pairs for seq2seq fine-tuning.

    Every question in ``examples["input"]`` gets the task ``prefix`` prepended
    and is tokenized as the model input. Every answer in
    ``examples["output"]`` is tokenized as the target, and its token ids are
    stored under the ``"labels"`` key. Both sides are truncated to 512 tokens.

    Returns the tokenized inputs with the added ``"labels"`` entry.
    """
    token_limit = 512

    # Model inputs: the prefixed questions.
    prompts = []
    for question in examples["input"]:
        prompts.append(prefix + question)
    encoded = tokenizer(prompts, truncation=True, max_length=token_limit)

    # Labels: the token ids of the reference answers.
    targets = tokenizer(
        text_target=examples["output"],
        truncation=True,
        max_length=token_limit,
    )
    encoded["labels"] = targets["input_ids"]
    return encoded

In [None]:
# Map the preprocessing function across our dataset.
# batched=True hands preprocess_function a batch of examples per call, which
# is why it iterates over examples["input"].
tokenized_dataset = dataset.map(preprocess_function, batched=True)

In [None]:
# Inspect one tokenized training example (index 14) to check the fields
# produced by preprocess_function.
tokenized_dataset['train'][14]

## FLAN-T5 Training and Fine-Tuning


In [None]:
!pip install rouge_score

In [None]:
# Sentence-tokenizer data required by nltk.sent_tokenize in compute_metrics.
# Newer NLTK releases look up the "punkt_tab" resource instead of "punkt",
# so fetch both to work across NLTK versions.
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)
# ROUGE scorer used to compare generated answers with the references.
metric = evaluate.load("rouge")

In [None]:
def compute_metrics(eval_preds):
    """Compute ROUGE scores for a batch of generated predictions.

    Args:
        eval_preds: A ``(preds, labels)`` pair of token-id arrays. ``preds``
            may itself be a tuple, in which case its first element is used.

    Returns:
        The result of ``metric.compute`` for the decoded predictions and
        references.
    """
    preds, labels = eval_preds
    # Some models hand back a tuple of outputs; the token ids come first.
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 is a padding/ignore marker and not a real token id, so it cannot be
    # decoded. Swap it for the pad token in BOTH arrays: labels are padded
    # with it, and generated predictions can be padded with it as well when
    # batches of different lengths are gathered.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # rougeLSum expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    return metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

In [None]:
batch_size = 8
num_train_epochs = 3

# Show the training loss roughly once per epoch: log every
# (number of training examples // batch size) steps. Clamp to at least 1,
# because a training split smaller than one batch would otherwise give
# logging_steps=0, which the step-based logging strategy rejects.
logging_steps = max(1, len(tokenized_dataset["train"]) // batch_size)


# Mostly default hyperparameters
args = Seq2SeqTrainingArguments(
    output_dir="check",
    eval_strategy="epoch",
    learning_rate=5.6e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,  # only keep up to 3 checkpoints during training
    num_train_epochs=num_train_epochs,
    # The decoder performs inference by predicting tokens one by one, which is
    # implemented by the model's generate() method. predict_with_generate=True
    # tells the Seq2SeqTrainer to use that method for evaluation.
    predict_with_generate=True,
    logging_steps=logging_steps,
)

In [None]:
# Assemble the trainer from the model, the training arguments, the tokenized
# train/test splits, the tokenizer, the data collator and the metric function.
trainer = Seq2SeqTrainer(
   model=model,
   args=args,
   train_dataset=tokenized_dataset["train"],
   eval_dataset=tokenized_dataset["test"],
   tokenizer=tokenizer,
   data_collator=data_collator,
   compute_metrics=compute_metrics
)

In [None]:
# Start fine-tuning with the trainer configured above.
trainer.train()

In [None]:
# Save the fine-tuned model into the Kaggle working directory.
trainer.save_model("/kaggle/working/")

## Model Inference

In [None]:
# Load from the directory written by trainer.save_model("/kaggle/working/").
# The previous path ("/kaggle/working/results/checkpoint-1000") matched neither
# that directory nor the training output_dir ("check"), so it only resolved if
# a checkpoint from some other run happened to exist there.
last_checkpoint = "/kaggle/working/"

finetuned_model = T5ForConditionalGeneration.from_pretrained(last_checkpoint)
# The trainer was given the tokenizer, so it is expected to be saved alongside
# the model in the same directory.
tokenizer = T5Tokenizer.from_pretrained(last_checkpoint)

In [None]:
# Question to ask the fine-tuned model.
my_question = "What is the relationship between very low Mg2+ levels, PTH levels, and Ca2+ levels?"
# Build the prompt with the same instruction text that `prefix` applies during
# preprocessing; keep the two strings identical so inference matches training.
inputs = "Answer this question truthfully: " + my_question

In [None]:
from textwrap import fill

# Tokenize into a separate name so `inputs` keeps the raw prompt string and
# this cell can be re-run without feeding a tokenized batch back in.
encoded = tokenizer(inputs, return_tensors="pt")
# Give the model room for a full-sentence answer; without an explicit limit
# generate() falls back to the model's default generation length, which may
# cut the answer short.
outputs = finetuned_model.generate(**encoded, max_new_tokens=128)
# Strip special tokens (padding / end-of-sequence markers) from the text.
answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Reference answer for `my_question`. The train split comes from
# train_test_split, so its row 0 is not tied to this particular question.
actual = "Very low Mg2+ levels correspond to low PTH levels which in turn results in low Ca2+ levels."

print("Predicted answer: ", fill(answer, width=80))
print("Actual answer: ", fill(actual, width=80))