In [4]:
# Importing the pandas library for data manipulation and analysis
import pandas as pd

# Importing train_test_split from scikit-learn for splitting data into training and testing sets
from sklearn.model_selection import train_test_split

# Importing necessary components from the Hugging Face transformers library
# T5Tokenizer: Tokenizer for the T5 model
# T5ForConditionalGeneration: T5 model for conditional generation tasks (e.g., text summarization, translation)
# AdamW: Optimizer used for fine-tuning the model
from transformers import T5Tokenizer, T5ForConditionalGeneration, AdamW

# Importing PyTorch for tensor operations and building datasets/dataloaders
import torch

# Importing Dataset and DataLoader from PyTorch for creating custom datasets and loading data in batches
from torch.utils.data import Dataset, DataLoader

# Importing PyTorch Lightning, a high-level library for PyTorch that simplifies the training process
import pytorch_lightning as pl

# Importing TensorBoardLogger for logging training progress and metrics to TensorBoard
from pytorch_lightning.loggers import TensorBoardLogger

# Importing ModelCheckpoint for saving the best model during training
# Importing LearningRateMonitor for tracking the learning rate throughout training
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor





# Load the medical QA dataset from the Kaggle input directory and, in the
# same expression, discard every row that has at least one missing value
# so that only complete rows reach the later stages.
df = pd.read_csv('/kaggle/input/medical-qa/intern_screening_dataset.csv').dropna()





# Partition the data into roughly 70% train / 15% validation / 15% test.
#
# Step 1: hold out 30% of the rows in 'temp_df'; the rest is 'train_df'.
train_df, temp_df = train_test_split(df, random_state=42, test_size=0.3)

# Step 2: cut the held-out rows in half to get 'val_df' and 'test_df'.
# Both calls use the same fixed seed so the split is reproducible.
val_df, test_df = train_test_split(temp_df, random_state=42, test_size=0.5)







# Defining a custom Dataset class for the Medical QA task

# This class MedicalQADataset extends Dataset from PyTorch and is tailored for the Medical QA task, facilitating the encoding of questions and answers using the T5 tokenizer.

"""
        Initializes the dataset with a dataframe, tokenizer, and maximum sequence length.
        
        Parameters:
        dataframe (pd.DataFrame): The dataframe containing the data.
        tokenizer (T5Tokenizer): The tokenizer for encoding the text.
        max_length (int): The maximum sequence length for the tokenizer.
        """
class MedicalQADataset(Dataset):
    """Question/answer pairs tokenized for fine-tuning a T5 model.

    Column 0 of the dataframe is read as the question (model input) and
    column 1 as the answer (target). Rows are looked up by position, so the
    dataframe's index does not have to be contiguous.

    Parameters:
        dataframe (pd.DataFrame): The dataframe containing the data.
        tokenizer (T5Tokenizer): The tokenizer for encoding the text.
        max_length (int): Length that every question and every answer is
            padded or truncated to.
    """

    def __init__(self, dataframe, tokenizer, max_length=512):
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows (question/answer pairs)."""
        return len(self.dataframe)

    def _encode(self, text):
        """Tokenize ``text``, padded/truncated to ``max_length``, as tensors."""
        return self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        """Return the encoded example at position ``idx``.

        Returns:
            dict: ``input_ids`` and ``attention_mask`` for the question, and
            ``labels`` for the answer with pad tokens replaced by -100.
        """
        question = self.dataframe.iloc[idx, 0]
        answer = self.dataframe.iloc[idx, 1]

        input_encoding = self._encode(question)
        target_ids = self._encode(answer)["input_ids"]

        # Pad positions become -100 so they are ignored in the loss
        # calculation. masked_fill builds a new tensor instead of mutating
        # the tokenizer's output in place.
        labels = target_ids.masked_fill(
            target_ids == self.tokenizer.pad_token_id, -100
        )

        # squeeze(0) removes only the leading batch axis added by
        # return_tensors="pt". A bare squeeze() would also collapse the
        # sequence axis when max_length == 1.
        return {
            'input_ids': input_encoding['input_ids'].squeeze(0),
            'attention_mask': input_encoding['attention_mask'].squeeze(0),
            'labels': labels.squeeze(0),
        }
# Name of the pre-trained T5 checkpoint used throughout this notebook.
MODEL_NAME = 't5-base'


# Creating datasets and dataloaders

# Tokenizer belonging to the pre-trained checkpoint named by MODEL_NAME.
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)

# One MedicalQADataset per split (train, validation, test), all sharing the
# same tokenizer and the dataset's default maximum sequence length.
train_dataset, val_dataset, test_dataset = (
    MedicalQADataset(split_df, tokenizer)
    for split_df in (train_df, val_df, test_df)
)

# Every loader yields batches of 8 samples. Only the training loader
# reshuffles its data at the start of each epoch; the validation and test
# loaders keep the default shuffle=False and iterate in dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=8)
val_loader, test_loader = (
    DataLoader(split_dataset, batch_size=8)
    for split_dataset in (val_dataset, test_dataset)
)





# Define the model

# The T5FineTuner class extends pl.LightningModule and fine-tunes a pre-trained T5 model for a specific task. It includes methods for the forward pass, the training, validation, and testing steps, configuring the optimizer, and logging example questions with their generated answers every `log_interval` training batches (the `on_epoch_end` method is called from `training_step`, not once per epoch).

# Defining a PyTorch Lightning module for fine-tuning a T5 model
class T5FineTuner(pl.LightningModule):
    """LightningModule that fine-tunes a pre-trained T5 model.

    Parameters:
        model_name (str): Name of the pre-trained T5 checkpoint to load.
        tokenizer_name (str): Name of the pre-trained tokenizer to load.
        learning_rate (float): Learning rate passed to the optimizer.
        log_interval (int): Every this many training batches, answers to a
            few fixed example questions are generated, printed and logged.
            A falsy value (0 or None) disables the example logging.
    """

    def __init__(self, model_name=MODEL_NAME, tokenizer_name=MODEL_NAME, learning_rate=5e-5, log_interval=30):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)  # Pre-trained T5 model
        self.tokenizer = T5Tokenizer.from_pretrained(tokenizer_name)  # Matching tokenizer
        self.learning_rate = learning_rate
        self.log_interval = log_interval

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the wrapped T5 model and return its output object."""
        return self.model(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels
        )

    def _compute_loss(self, batch):
        """Run a forward pass on ``batch`` and return the model's loss."""
        outputs = self(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels']
        )
        return outputs.loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self._compute_loss(batch)
        # Guarding on log_interval avoids a ZeroDivisionError when it is 0
        # and lets callers switch the example logging off.
        if self.log_interval and batch_idx % self.log_interval == 0:
            self.on_epoch_end()
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        loss = self._compute_loss(batch)
        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss for one batch."""
        loss = self._compute_loss(batch)
        self.log('test_loss', loss)
        return loss

    def configure_optimizers(self):
        """Return the AdamW optimizer over all parameters of this module."""
        # torch.optim.AdamW replaces the deprecated transformers.AdamW.
        # eps and weight_decay are pinned to the defaults of the previous
        # optimizer, because torch's own defaults differ.
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=self.learning_rate, eps=1e-6, weight_decay=0.0
        )
        return optimizer

    def on_epoch_end(self):
        """Generate, print and log answers to a few fixed example questions.

        Despite its name, this method is called from ``training_step``
        every ``log_interval`` batches. Questions and answers are logged as
        text to the logger's experiment, keyed by the global step so that
        several calls within one epoch do not share a step.
        """
        example_questions = [
            "What are the symptoms of diabetes?",
            "How is hypertension diagnosed?",
            "How to prevent Glaucoma? "
        ]
        # self.logger is None when the Trainer runs without a logger.
        experiment = self.logger.experiment if self.logger is not None else None
        step = self.global_step

        was_training = self.training
        try:
            for i, question in enumerate(example_questions):
                answer = generate_answer(self, question, self.tokenizer)
                print(answer)
                if experiment is not None:
                    experiment.add_text(f"example_question_{i}", question, step)
                    experiment.add_text(f"example_answer_{i}", answer, step)
        finally:
            # Generation puts the module into eval mode. Restore the mode
            # that was active before, otherwise the rest of the epoch
            # would train with training-only layers (e.g. dropout) off.
            self.train(was_training)








# Function to generate answers

# The generate_answer function takes a fine-tuned T5 model, a question, and a tokenizer as inputs and generates an answer with the model. It uses beam search to improve generation quality and decodes the output token IDs into a readable string.
"""
    Generates an answer to a given question using the fine-tuned T5 model.
    
    Parameters:
    model (T5FineTuner): The fine-tuned T5 model.
    question (str): The input question as a string.
    tokenizer (T5Tokenizer): The tokenizer for encoding the question.
    max_length (int): The maximum length of the generated answer.
    
    Returns:
    answer (str): The generated answer as a string.
    """
def generate_answer(model, question, tokenizer, max_length=150, **generate_kwargs):
    """Generate an answer to ``question`` with the fine-tuned T5 model.

    Generation runs in eval mode without gradient tracking. The model's
    previous train/eval mode is restored afterwards, so calling this in
    the middle of training does not leave the model in eval mode.

    Parameters:
        model (T5FineTuner): The fine-tuned model; its ``model`` attribute
            must provide ``generate`` and it must expose ``device``.
        question (str): The input question as a string.
        tokenizer (T5Tokenizer): The tokenizer for encoding the question.
        max_length (int): The maximum length of the generated answer.
        **generate_kwargs: Extra keyword arguments forwarded to
            ``model.model.generate``; they override the defaults used here
            (beam search with 5 beams and early stopping).

    Returns:
        str: The generated answer with special tokens removed.
    """
    options = {"max_length": max_length, "num_beams": 5, "early_stopping": True}
    options.update(generate_kwargs)

    was_training = model.training
    model.eval()
    try:
        # Encode the question; inputs longer than 512 tokens are truncated.
        inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512)
        # Tensors must live on the same device as the model.
        input_ids = inputs["input_ids"].to(model.device)
        attention_mask = inputs["attention_mask"].to(model.device)
        with torch.no_grad():
            outputs = model.model.generate(input_ids, attention_mask=attention_mask, **options)
    finally:
        # Put the model back into whichever mode the caller had it in.
        model.train(was_training)

    # Decode the first (best) sequence to text, skipping special tokens.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)






# Set up logging and checkpoints

# TensorBoard logger: event files are written under "tb_logs", in a run
# folder named "T5_finetuning", for later visualization.
logger = TensorBoardLogger(save_dir="tb_logs", name="T5_finetuning")

# Checkpointing: watch the logged "val_loss" metric and keep the three
# checkpoints with the lowest value ("min" mode) in the "checkpoints"
# directory. File names embed the epoch number and the validation loss.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    filename="T5-{epoch:02d}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_top_k=3,
)

# Learning-rate monitor: records the learning rate at every training step.
lr_monitor = LearningRateMonitor(logging_interval='step')






# Training the model

# Build the fine-tuning module with its default settings, then hand it to a
# Lightning Trainer that runs for at most three epochs and is wired to the
# TensorBoard logger, the checkpoint callback and the learning-rate monitor.
model = T5FineTuner()
trainer = pl.Trainer(
    callbacks=[checkpoint_callback, lr_monitor],
    logger=logger,
    max_epochs=3,
)

# Fit on the training loader, validating on the validation loader.
trainer.fit(model, train_loader, val_loader)


# Evaluate the trained model on the held-out test loader.
trainer.test(model, test_loader)





# A few hand-written questions for spot-checking the fine-tuned model.
example_questions = [
    "What are the symptoms of diabetes?",
    "How is hypertension diagnosed?",
    "How to prevent Glaucoma? "
]
# For each question: echo it, generate an answer with the fine-tuned model,
# then print that answer.
for question in example_questions:
    print(f"Question: {question}")
    generated = generate_answer(model, question, tokenizer)
    print(f"Answer: {generated}")

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
2024-05-21 15:16:40.025417: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-21 15:16:40.025552: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-21 15:16:40.163075: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

What are the symptoms of diabetes?
How is hypertension diagnosed?
How to prevent Glaucoma?
What are the symptoms of diabetes?
How is hypertension diagnosed?
How to prevent Glaucoma?
What are the signs and symptoms of diabetes? What are the signs and symptoms of diabetes? What are the signs and symptoms of diabetes?
How is hypertension diagnosed?
How to prevent Glaucoma? How to prevent Glaucoma? How to prevent Glaucoma?
What are the signs and symptoms of diabetes? What are the signs and symptoms of diabetes? What are the signs and symptoms of diabetes?
How is hypertension diagnosed?
How to prevent Glaucoma? How to prevent Glaucoma? How to prevent Glaucoma? How to prevent Glaucoma?
What are the signs and symptoms of diabetes? What are the signs and symptoms of diabetes? What are the signs and symptoms of diabetes? What are the signs and symptoms of diabetes? What are the signs and symptoms of diabetes?
How is hypertension diagnosed?
How to prevent Glaucoma? How to prevent Glaucoma? How t

Validation: |          | 0/? [00:00<?, ?it/s]

What are the signs and symptoms of diabetes? The signs and symptoms of diabetes can vary from person to person. The signs and symptoms of diabetes can vary from person to person. The signs and symptoms of diabetes may vary from person to person. The signs and symptoms of diabetes may vary from person to person. The signs and symptoms of diabetes may vary from person to person. The signs and symptoms of diabetes may vary from person to person. The signs and symptoms of diabetes may vary from person to person. The signs and symptoms of diabetes may vary from person to person. The signs and symptoms of diabetes may vary from person to person. The signs and symptoms of diabetes may vary from person to person. The signs and symptoms of diabetes may vary from person
How is hypertension diagnosed? The diagnosis of hypertension is based on a number of factors, including the type of hypertension, the type of hypertension, the type of hypertension, the type of hypertension, the type of hypertens

Validation: |          | 0/? [00:00<?, ?it/s]

Signs and symptoms of diabetes include high blood glucose (HDL) and high blood sugar (HDL). High blood glucose (HDL) is a condition in which the body does not make enough insulin. High blood glucose (HDL) is a condition in which the body does not make enough insulin. High blood glucose (HDL) is a condition in which the body does not make enough insulin. High blood glucose (HDL) is a condition in which the body does not make enough insulin. High blood glucose (HDL) is a condition in which the body does not make enough insulin. High blood glucose (hypoglycemia) is a condition in which the body does not make
How is hypertension diagnosed? A diagnosis of hypertension is made based on the signs and symptoms of the condition. A diagnosis of hypertension is made based on the signs and symptoms of the condition. A diagnosis of hypertension is made based on the signs and symptoms of the condition.
How is glaucoma treated? The treatment of glaucoma is based on the following factors: - How is gla

Validation: |          | 0/? [00:00<?, ?it/s]

/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

Question: What are the symptoms of diabetes?
Answer: Signs and symptoms of diabetes include high blood sugar (HDL), high blood sugar (hyperglycemia), high blood sugar (hyperglycemia), high blood sugar (hyperglycemia), high blood sugar (hyperglycemia), high blood sugar (hyperglycemia), high blood sugar (hyperglycemia), high blood sugar (hyperglycemia), high blood sugar (hyperglycemia), high blood sugar (hyperglycemia), high blood sugar (hyperglycemia), high blood sugar (hyperglycemia), high blood sugar (hyperglycemia), high blood sugar
Question: How is hypertension diagnosed?
Answer: How is hypertension diagnosed? A diagnosis of hypertension is made based on the signs and symptoms present in each person. The diagnosis of hypertension is made based on the signs and symptoms present in each person. The diagnosis of hypertension is made based on the signs and symptoms present in each person.
Question: How to prevent Glaucoma? 
Answer: How is glaucoma treated? The goal of glaucoma treatment