# Fine-Tuning GPT-2 on the MedQuAD Dataset

This notebook demonstrates how to fine-tune GPT-2 on MedQuAD, a medical question-answering dataset. We reformat and rephrase the original tutorial for clarity and include several improvements.

### Overview

1. **Install Dependencies:** Set up the required libraries and tools.
2. **Download Dataset:** Fetch the MedQuAD dataset from Kaggle.
3. **Preprocess Data:** Clean the dataset by lowercasing text, removing duplicates, and handling missing values.
4. **Train/Validation Split:** Use the top 100 focus areas to build balanced splits (4 samples for training and 1 for validation per category).
5. **Prepare Data:** Combine questions and answers with special tokens (`<question>`, `<answer>`, `<end>`).
6. **Load & Tokenize:** Load the GPT-2 model and tokenizer, add custom tokens, and tokenize the data.
7. **Set Up Training:** Define training arguments and create a `Trainer` instance.
8. **Training & Inference:** Fine-tune the model, save it, and define a helper function for inference.

### How to Improve Accuracy Further?

To improve the quality of the model's answers, consider the following strategies:

- **Increase Dataset Size:** More high-quality data will improve generalization.
- **Hyperparameter Tuning:** Experiment with learning rates, batch sizes, number of epochs, and warmup steps.
- **Regularization and Training Stability:** Early stopping and tuned weight decay can reduce overfitting, while gradient accumulation allows a larger effective batch size without using more GPU memory.
- **Data Augmentation:** Enrich the dataset using augmentation methods to enhance diversity.
- **Model Enhancements:** Consider larger GPT-2 variants or adapter modules for parameter-efficient fine-tuning.
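
Hyperparameter tuning and gradient accumulation both change how many optimizer updates a run performs. The sketch below enumerates a small, hypothetical search grid and estimates the effective batch size and step count for each configuration, assuming the 400-example training split built in section 4 and a single GPU. The grid values are illustrative placeholders rather than tuned recommendations, and the step count only approximates the `Trainer`'s own bookkeeping. The dictionary keys are real `TrainingArguments` parameter names, so each configuration could be passed as `TrainingArguments(output_dir=..., **cfg)`.

```python
import itertools
import math

def effective_batch_size(per_device_batch, grad_accum_steps, num_devices=1):
    """Number of examples contributing to each optimizer update."""
    return per_device_batch * grad_accum_steps * num_devices

def approx_total_steps(num_examples, per_device_batch, grad_accum_steps, epochs, num_devices=1):
    """Rough count of optimizer updates; the Trainer's exact rounding may differ slightly."""
    eff = effective_batch_size(per_device_batch, grad_accum_steps, num_devices)
    return math.ceil(num_examples / eff) * epochs

# Hypothetical search space (illustrative values only)
grid = {
    "learning_rate": [5e-5, 3e-5],
    "per_device_train_batch_size": [4],
    "gradient_accumulation_steps": [1, 4],
    "num_train_epochs": [5, 10],
}

# Cartesian product of all values -> one dict per configuration
configs = [dict(zip(grid, values)) for values in itertools.product(*grid.values())]
for cfg in configs:
    steps = approx_total_steps(
        400,
        cfg["per_device_train_batch_size"],
        cfg["gradient_accumulation_steps"],
        cfg["num_train_epochs"],
    )
    print(cfg, "-> ~", steps, "optimizer steps")
```

With gradient accumulation of 4, the same 5 epochs take about 125 updates instead of 500, which is worth remembering when choosing `warmup_steps`.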

## 1. Install Dependencies

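Run the following cell to install the necessary packages. Note that we also reinstall a specific version of `pyarrow` to ensure compatibility. The cell then mounts Google Drive and points the Kaggle CLI at the folder that contains your `kaggle.json` credentials.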

In [None]:
%%capture
!pip -q uninstall pyarrow -y
!pip -q install pyarrow==15.0.2 datasets accelerate transformers kaggle

import os
import json
import torch
import pandas as pd
import numpy as np
from zipfile import ZipFile
from sklearn.model_selection import train_test_split
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset

from google.colab import drive
drive.mount('/content/drive/')

# Set up Kaggle API credentials (make sure your kaggle.json is in the correct location)
os.environ['KAGGLE_CONFIG_DIR'] = "/content/drive/My Drive/Colab Notebooks/GPT2FineTune/.kaggle"
print('Dependencies installed and environment set up.')
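
The Kaggle CLI reads its API token from `kaggle.json` in the directory named by `KAGGLE_CONFIG_DIR`, and a missing or malformed file otherwise shows up only when the download cell runs. A small pre-flight check such as the hypothetical helper below (not part of the original notebook) fails early with a clearer message. It assumes the standard `kaggle.json` layout with `username` and `key` fields.

```python
import json
import os

def check_kaggle_credentials(config_dir):
    """Return the path to kaggle.json if it exists and has the expected keys; raise otherwise."""
    path = os.path.join(config_dir, "kaggle.json")
    if not os.path.isfile(path):
        raise FileNotFoundError(f"kaggle.json not found in {config_dir!r}")
    with open(path, encoding="utf-8") as f:
        creds = json.load(f)
    # Both fields are needed for the CLI to authenticate
    missing = {"username", "key"} - set(creds)
    if missing:
        raise ValueError(f"kaggle.json is missing keys: {sorted(missing)}")
    return path

# In the notebook, this would be called with the directory configured above:
# check_kaggle_credentials(os.environ["KAGGLE_CONFIG_DIR"])
```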

## 2. Download and Load the MedQuAD Dataset

This cell downloads the MedQuAD dataset from Kaggle and loads it into a Pandas DataFrame.

In [None]:
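import subprocess

# Define Kaggle dataset identifier (kept distinct from the `dataset` variable used in section 7)
kaggle_dataset = "gpreda/medquad"

# Download the dataset from Kaggle; check=True raises an error if the CLI call fails
subprocess.run(['kaggle', 'datasets', 'download', '-d', kaggle_dataset], check=True)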

# Unzip the dataset
with ZipFile('medquad.zip', 'r') as zip_ref:
    zip_ref.extractall('medquad')

# Load the CSV file into a DataFrame
df = pd.read_csv('medquad/medquad.csv')
print('First few rows of the dataset:')
print(df.head())

First few rows of the dataset:
                                 question  \
0                What is (are) Glaucoma ?   
1                  What causes Glaucoma ?   
2     What are the symptoms of Glaucoma ?   
3  What are the treatments for Glaucoma ?   
4                What is (are) Glaucoma ?   

                                              answer           source  \
0  Glaucoma is a group of diseases that can damag...  NIHSeniorHealth   
1  Nearly 2.7 million people have glaucoma, a lea...  NIHSeniorHealth   
2  Symptoms of Glaucoma  Glaucoma can develop in ...  NIHSeniorHealth   
3  Although open-angle glaucoma cannot be cured, ...  NIHSeniorHealth   
4  Glaucoma is a group of diseases that can damag...  NIHSeniorHealth   

  focus_area  
0   Glaucoma  
1   Glaucoma  
2   Glaucoma  
3   Glaucoma  
4   Glaucoma  


## 3. Data Preprocessing

Clean the dataset by converting text to lowercase, collapsing repeated whitespace (including newlines) into single spaces, dropping rows with missing values, and removing duplicate question-answer pairs. Saving the removed rows to a CSV file for auditing is optional.

In [None]:
# Clean text columns: convert to lowercase and remove extra spaces/newlines
for col in df.select_dtypes(include=['object']).columns:
    df[col] = df[col].str.lower().str.split().str.join(' ')

print('Cleaned DataFrame preview:')
print(df.head())

# Remove rows with missing values and record them
removed_records = pd.DataFrame()
missing_values = df[df.isnull().any(axis=1)]
removed_records = pd.concat([removed_records, missing_values])
df = df.dropna()

# Remove duplicate rows based on 'question' and 'answer', recording only the copies that are dropped
duplicates = df[df.duplicated(subset=['question', 'answer'], keep='first')]
removed_records = pd.concat([removed_records, duplicates])
df = df.drop_duplicates(subset=['question', 'answer'])

# Optionally, save removed records for audit
if not removed_records.empty:
    removed_records.to_csv('removed_records_audit.csv', index=False)

print('Data preprocessing complete.')

Cleaned DataFrame preview:
                                 question  \
0                what is (are) glaucoma ?   
1                  what causes glaucoma ?   
2     what are the symptoms of glaucoma ?   
3  what are the treatments for glaucoma ?   
4                what is (are) glaucoma ?   

                                              answer           source  \
0  glaucoma is a group of diseases that can damag...  nihseniorhealth   
1  nearly 2.7 million people have glaucoma, a lea...  nihseniorhealth   
2  symptoms of glaucoma glaucoma can develop in o...  nihseniorhealth   
3  although open-angle glaucoma cannot be cured, ...  nihseniorhealth   
4  glaucoma is a group of diseases that can damag...  nihseniorhealth   

  focus_area  
0   glaucoma  
1   glaucoma  
2   glaucoma  
3   glaucoma  
4   glaucoma  
Data preprocessing complete.
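
The cleaning steps can be mirrored in plain Python to show why normalization runs before deduplication: two rows that differ only in capitalization or whitespace become exact duplicates once normalized. This is an illustrative, stdlib-only sketch with made-up rows. `normalize` matches the pandas chain `str.lower().str.split().str.join(' ')`, and `dedupe` keeps the first occurrence of each question-answer pair, as `drop_duplicates` does by default, returning only the copies that are actually dropped.

```python
def normalize(text):
    """Lowercase and collapse every run of whitespace (spaces, tabs, newlines) into one space."""
    return " ".join(text.lower().split())

def dedupe(rows):
    """Keep the first occurrence of each (question, answer) pair; return (kept, removed)."""
    seen = set()
    kept, removed = [], []
    for row in rows:
        key = (row["question"], row["answer"])
        if key in seen:
            removed.append(row)
        else:
            seen.add(key)
            kept.append(row)
    return kept, removed

# Two rows that differ only in case and whitespace
rows = [
    {"question": "What is (are) Glaucoma ?", "answer": "Glaucoma is  a group\nof diseases"},
    {"question": "what is (are) glaucoma ?", "answer": "glaucoma is a group of diseases"},
]
rows = [{k: normalize(v) for k, v in r.items()} for r in rows]
kept, removed = dedupe(rows)
print(len(kept), len(removed))  # 1 1
```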


## 4. Create Training and Validation Splits

We build our training and validation sets by focusing on the top 100 focus areas. For each of these categories, 4 samples are used for training and 1 for validation.

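**Note:** Restricting the split to the 100 most frequent focus areas means each sampled category has enough rows for 4 training examples and 1 validation example, and it keeps both splits balanced across categories. The trade-off is size: the resulting 500 question-answer pairs are a small fraction of the cleaned dataset.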

In [None]:
# Select the top 100 focus areas based on frequency
top_100_categories = df['focus_area'].value_counts().nlargest(100).index.tolist()

train_data = pd.DataFrame()
val_data = pd.DataFrame()

for category in top_100_categories:
    # Sample 4 records for training
    train_samples = df[df['focus_area'] == category].sample(n=4, random_state=42)

    # Sample 1 record for validation (excluding training samples)
    val_samples = df[(df['focus_area'] == category) & (~df.index.isin(train_samples.index))].sample(n=1, random_state=42)

    train_data = pd.concat([train_data, train_samples])
    val_data = pd.concat([val_data, val_samples])

print(f"Training set size: {train_data.shape[0]}")
print(f"Validation set size: {val_data.shape[0]}")

Training set size: 400
Validation set size: 100
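
The per-category sampling loop can also be written without pandas, which makes two properties explicit: the training and validation rows for a category are drawn in a single sample, so they cannot overlap, and a category with too few rows raises an explicit error. This is a hypothetical stdlib sketch (`balanced_split` is not part of the notebook). Because it uses Python's `random` module, it selects different rows than pandas' `sample(random_state=42)`.

```python
import random
from collections import Counter, defaultdict

def balanced_split(rows, key, top_k, n_train, n_val, seed=42):
    """Take the top_k most frequent categories and draw n_train + n_val distinct rows from each."""
    rng = random.Random(seed)
    by_cat = defaultdict(list)
    for row in rows:
        by_cat[row[key]].append(row)
    top = [cat for cat, _ in Counter(row[key] for row in rows).most_common(top_k)]

    train, val = [], []
    for cat in top:
        group = by_cat[cat]
        if len(group) < n_train + n_val:
            raise ValueError(f"category {cat!r} has only {len(group)} rows")
        picked = rng.sample(group, n_train + n_val)  # one draw without replacement
        train.extend(picked[:n_train])
        val.extend(picked[n_train:])
    return train, val

# Toy data: categories "a" (6 rows), "b" (5 rows), "c" (2 rows)
rows = [{"id": i, "focus_area": c} for i, c in enumerate(["a"] * 6 + ["b"] * 5 + ["c"] * 2)]
train, val = balanced_split(rows, "focus_area", top_k=2, n_train=4, n_val=1)
print(len(train), len(val))  # 8 2
```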


## 5. Prepare the Data for Fine-Tuning

Each record is converted into a single text sequence using special tokens:

- `<question>`: Marks the beginning of a question
- `<answer>`: Separates the question from the answer
- `<end>`: Indicates the end of the sequence

This formatting helps the model learn the structure of Q&A pairs.

In [None]:
def combine_text(df):
    return df.apply(lambda row: f"<question>{row['question']}<answer>{row['answer']}<end>", axis=1)

train_sequences = combine_text(train_data)
val_sequences = combine_text(val_data)

# Join with newlines so each example occupies exactly one line. This is safe because the
# whitespace normalization in the preprocessing step removed newlines inside questions and answers.
train_text = '\n'.join(train_sequences)
val_text = '\n'.join(val_sequences)

# Save the prepared data to UTF-8 text files
with open('train_data.txt', 'w', encoding='utf-8') as f:
    f.write(train_text)

with open('val_data.txt', 'w', encoding='utf-8') as f:
    f.write(val_text)

print('Training and validation text files have been saved.')

Training and validation text files have been saved.
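
Because `load_dataset('text', ...)` turns each line of a text file into one example, a line break inside a question or answer would silently split a Q&A pair into two fragments. The helper below is a hypothetical, stdlib-only variant of the save step that checks this invariant before writing; `to_sequence` and `write_one_per_line` are illustrative names, not part of the original notebook.

```python
def to_sequence(question, answer):
    """Format one Q&A pair with the notebook's special-token markers."""
    return f"<question>{question}<answer>{answer}<end>"

def write_one_per_line(pairs, path):
    """Write one formatted sequence per line, refusing any sequence that contains a line break."""
    sequences = [to_sequence(q, a) for q, a in pairs]
    for seq in sequences:
        # splitlines() recognises \n, \r, \r\n and the other Unicode line boundaries
        if len(seq.splitlines()) != 1:
            raise ValueError(f"line break inside sequence: {seq[:60]!r}")
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(sequences))
    return len(sequences)
```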


## 6. Load the Pretrained GPT-2 Model and Tokenizer

We load the GPT-2 model and its tokenizer, then add custom special tokens that match our data format.

In [None]:
# Load the GPT-2 model and tokenizer
model = GPT2LMHeadModel.from_pretrained('gpt2', use_cache=False)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Add special tokens to accommodate our data structure
special_tokens = {'pad_token': '<pad>', 'bos_token': '<question>', 'eos_token': '<end>', 'sep_token': '<answer>'}
tokenizer.add_special_tokens(special_tokens)

print('Special tokens added and GPT-2 model/tokenizer loaded.')

Special tokens added and GPT-2 model/tokenizer loaded.


## 7. Tokenize the Dataset

We load the text files with Hugging Face's `load_dataset` (the `text` loader turns each line into one example) and then tokenize them. Each example is truncated and padded to 1024 tokens, GPT-2's maximum context length.

In [None]:
def tokenize_function(examples):
    # Truncate and pad every example to 1024 tokens
    return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=1024)

# Load dataset from text files (one example per line)
dataset = load_dataset('text', data_files={'train': 'train_data.txt', 'validation': 'val_data.txt'})

print('Dataset loaded:')
print(dataset)

# Tokenize the dataset and drop the raw text column
tokenized_datasets = dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=['text'])

# Return PyTorch tensors when indexing the tokenized (numeric) columns
tokenized_datasets.set_format('torch')

print('Tokenized dataset preview:')
print(tokenized_datasets)

Dataset loaded:
DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 400
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 100
    })
})


Tokenized dataset preview:
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 400
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 100
    })
})
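
`padding='max_length'` makes every example 1024 tokens long, so short Q&A pairs are mostly padding. One common alternative is dynamic padding: tokenize with `truncation=True` only and let the data collator pad each batch to its longest member (`DataCollatorForLanguageModeling` pads through the tokenizer when examples differ in length). The toy sketch below illustrates per-batch right padding and the resulting share of padding with plain lists; the token IDs and `pad_id=0` are made up.

```python
def pad_batch(sequences, pad_id):
    """Right-pad token-id lists to the longest sequence in the batch and build attention masks."""
    max_len = max(len(seq) for seq in sequences)
    input_ids = [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]
    attention_mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in sequences]
    return input_ids, attention_mask

def padding_fraction(attention_mask):
    """Share of positions in the batch that are padding (mask value 0)."""
    total = sum(len(row) for row in attention_mask)
    return 1 - sum(sum(row) for row in attention_mask) / total

# Toy token IDs; real IDs come from the GPT-2 tokenizer
ids, mask = pad_batch([[10, 11, 12, 13], [20, 21]], pad_id=0)
print(ids)                     # [[10, 11, 12, 13], [20, 21, 0, 0]]
print(mask)                    # [[1, 1, 1, 1], [1, 1, 0, 0]]
print(padding_fraction(mask))  # 0.25
```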


## 8. Create a Data Collator for Language Modeling

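The `DataCollatorForLanguageModeling` batches tokenized inputs and, with `mlm=False`, builds causal language-modeling labels: `labels` is a copy of `input_ids` in which padding positions are set to `-100` so the loss ignores them. The one-token shift (predicting each next token from the tokens before it) happens inside `GPT2LMHeadModel` when it computes the loss, not in the collator.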

In [None]:
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # GPT-2 uses causal language modeling
    return_tensors='pt'
)

# Optional: Inspect a few samples from the collated batches
num_samples = min(5, len(tokenized_datasets['train']))
for i in range(num_samples):
    batch = data_collator([tokenized_datasets['train'][i]])
    print(f"\nData Collator Output for sample {i}:")
    for key, value in batch.items():
        print(f"{key}: shape {value.shape}, dtype {value.dtype}")

print("\nDecoded sample:")
decoded = tokenizer.decode(tokenized_datasets['train'][0]['input_ids'], skip_special_tokens=True)
print(decoded)


Data Collator Output for sample 0:
input_ids: shape torch.Size([1, 1024]), dtype torch.int64
attention_mask: shape torch.Size([1, 1024]), dtype torch.int64
labels: shape torch.Size([1, 1024]), dtype torch.int64

Data Collator Output for sample 1:
input_ids: shape torch.Size([1, 1024]), dtype torch.int64
attention_mask: shape torch.Size([1, 1024]), dtype torch.int64
labels: shape torch.Size([1, 1024]), dtype torch.int64

Data Collator Output for sample 2:
input_ids: shape torch.Size([1, 1024]), dtype torch.int64
attention_mask: shape torch.Size([1, 1024]), dtype torch.int64
labels: shape torch.Size([1, 1024]), dtype torch.int64

Data Collator Output for sample 3:
input_ids: shape torch.Size([1, 1024]), dtype torch.int64
attention_mask: shape torch.Size([1, 1024]), dtype torch.int64
labels: shape torch.Size([1, 1024]), dtype torch.int64

Data Collator Output for sample 4:
input_ids: shape torch.Size([1, 1024]), dtype torch.int64
attention_mask: shape torch.Size([1, 1024]), dtype torch.i
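
How the causal-LM labels line up with the inputs can be traced with plain lists. This sketch uses made-up token IDs, with `pad_id=99` standing in for `tokenizer.pad_token_id`. It copies the inputs into labels, replaces padding with `-100` (the default `ignore_index` of PyTorch's cross-entropy loss), and then pairs each position with the next-token label it is scored against, which is the shift the model applies internally.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def make_labels(input_ids, pad_id):
    """Copy input_ids into labels, masking padding positions so they do not count in the loss."""
    return [IGNORE_INDEX if tok == pad_id else tok for tok in input_ids]

def next_token_pairs(input_ids, labels):
    """Pair each input position with the label it is scored against after the one-token shift."""
    return [(inp, lab) for inp, lab in zip(input_ids[:-1], labels[1:]) if lab != IGNORE_INDEX]

# Made-up IDs: 1-3 are ordinary tokens, 99 stands in for the pad token
input_ids = [1, 2, 3, 99, 99]
labels = make_labels(input_ids, pad_id=99)
print(labels)                               # [1, 2, 3, -100, -100]
print(next_token_pairs(input_ids, labels))  # [(1, 2), (2, 3)]
```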

## 9. Set Up Training Arguments and Train the Model

We move the model to the available device (GPU if possible), adjust the token embeddings to include our new tokens, and define training parameters such as batch size, number of epochs, and warmup steps. Finally, we initialize the `Trainer` and start the training process.
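
The warmup setting interacts with the length of the run. The arithmetic below, assuming a single GPU and no gradient accumulation, counts the optimizer steps implied by the next cell's settings and evaluates a from-scratch re-implementation of a linear-warmup, linear-decay multiplier, the shape of the `Trainer`'s default `linear` scheduler. It also assumes the default `logging_steps=500`, which would explain why the training-loss column reads "No log" until the final step.

```python
import math

# Settings from the training cell (single device, no gradient accumulation assumed)
num_examples = 400
per_device_batch_size = 4
num_epochs = 5
warmup_steps = 500
logging_steps = 500  # assumed TrainingArguments default

steps_per_epoch = math.ceil(num_examples / per_device_batch_size)  # 100
total_steps = steps_per_epoch * num_epochs                          # 500

def lr_multiplier(step, warmup, total):
    """Linear warmup from 0 to 1, then linear decay to 0 (illustrative re-implementation)."""
    if step < warmup:
        return step / max(1, warmup)
    return max(0.0, (total - step) / max(1, total - warmup))

# Multipliers applied at optimizer updates 0 .. total_steps - 1
multipliers = [lr_multiplier(s, warmup_steps, total_steps) for s in range(total_steps)]
print(total_steps, max(multipliers))  # 500 0.998
print("first training-loss log at step", logging_steps)

# With a 10% warmup instead, the learning rate peaks early and then decays
short = [lr_multiplier(s, total_steps // 10, total_steps) for s in range(total_steps)]
print(max(short), round(short[-1], 4))  # 1.0 0.0022
```

In other words, with `warmup_steps=500` the learning rate climbs for the entire run and never reaches its decay phase; `warmup_ratio=0.1` is one alternative to try.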

In [9]:
# Move the model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Resize model embeddings to account for the additional tokens
model.resize_token_embeddings(len(tokenizer))

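training_args = TrainingArguments(
    output_dir='./gpt_finetuned_MedQuAD',
    eval_strategy='epoch',
    save_strategy='epoch',
    # 400 examples / batch size 4 x 5 epochs = 500 optimizer steps, so this warmup spans the
    # whole run; a smaller value (or warmup_ratio) would let the learning rate peak and decay
    warmup_steps=500,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=5,
    weight_decay=0.01,
    logging_dir='./logs',
    fp16=torch.cuda.is_available()  # fp16 mixed precision requires a CUDA GPU
)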

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    data_collator=data_collator
)

print('Starting training...')
trainer.train()

Starting training...




| Epoch | Training Loss | Validation Loss |
|---|---|---|
| 1 | No log | 2.697938 |
| 2 | No log | 2.556359 |
| 3 | No log | 2.450659 |
| 4 | No log | 2.357926 |
| 5 | 2.493100 | 2.28003 |


TrainOutput(global_step=500, training_loss=2.49309619140625, metrics={'train_runtime': 1249.4097, 'train_samples_per_second': 1.601, 'train_steps_per_second': 0.4, 'total_flos': 1045168128000000.0, 'train_loss': 2.49309619140625, 'epoch': 5.0})
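
Cross-entropy loss is easier to read as perplexity, exp(loss), which is roughly the number of equally likely next tokens the model is choosing between. The snippet converts the validation losses reported in the epoch table, treating the `Trainer`'s evaluation loss as a mean per-token cross-entropy in nats (an approximation, since it averages batch-level means). Validation perplexity drops from roughly 14.8 after epoch 1 to roughly 9.8 after epoch 5; lower perplexity measures fit to the reference answers, not factual correctness.

```python
import math

# Validation losses from the epoch table (mean cross-entropy per token, natural log)
val_losses = {1: 2.697938, 2: 2.556359, 3: 2.450659, 4: 2.357926, 5: 2.28003}

perplexities = {epoch: math.exp(loss) for epoch, loss in val_losses.items()}
for epoch, ppl in perplexities.items():
    print(f"epoch {epoch}: validation perplexity ~ {ppl:.1f}")
```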

## 10. Save the Fine-Tuned Model

Once training is complete, save both the model and tokenizer for future inference.

In [10]:
model.save_pretrained(training_args.output_dir)
tokenizer.save_pretrained(training_args.output_dir)
print('Model and tokenizer saved.')

Model and tokenizer saved.


## 11. Inference

Define a helper function to generate responses using the fine-tuned model. Given a prompt (question), the function tokenizes it, generates a response, and then decodes the output.

In [11]:
def generate_response(model, tokenizer, prompt, max_length=1024):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    # Tokenize the prompt (returns input_ids and attention_mask)
    inputs = tokenizer(prompt, return_tensors='pt').to(device)

    # Generate with greedy decoding (temperature has no effect when do_sample=False)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_length=max_length,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,  # stop at the custom <end> token
            pad_token_id=tokenizer.pad_token_id,  # the dedicated <pad> token added earlier
        )

    # Decode only the newly generated tokens so the prompt is not repeated in the response
    prompt_length = inputs['input_ids'].shape[1]
    return tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True).strip()

# Example usage
prompt = "what is (are) breast cancer ?"
response = generate_response(model, tokenizer, prompt)
print('Generated Response:')
print(response)



Generated Response:
the most common type of breast cancers is breast carcinoma. the most commonly diagnosed type is colorectal cancer. colostomy is a type that is more common in women than men. it is the second most frequent cancer in the united states. in men, colotoma is about twice as common as colocal cancer, and colocutaneous cancer is twice the common cancer of the u.s. and other countries. breast tumors are the leading cause of death in children. they are most often found in infants, children, young adults, older adults and people with weakened immune systems.the second leading cancer cause in adults is cancer called colitis. this type causes a small amount of cancer to develop in breast tissue. these cancers are more likely to be found on the skin, in bone, or in other parts of a woman's body. there is no cure for coli. however, there are treatments that can help prevent colic.cancer treatment is not always the best way to prevent breast and ovarian cancer; however it can be a 
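
The sample response above drifts across unrelated conditions. One plausible contributor is a format mismatch: every training sequence began with `<question>` and placed the answer after `<answer>`, but the prompt passed to `generate_response` is the bare question. The helpers below are a hypothetical sketch (`build_prompt` and `extract_answer` are not from the original notebook) that wraps a question in the training format, applying the same lowercasing and whitespace normalization as the preprocessing step, and pulls the answer text out of a decoded generation.

```python
QUESTION, ANSWER, END = "<question>", "<answer>", "<end>"

def build_prompt(question):
    """Wrap a question in the training format, normalized as in preprocessing (lowercase, single spaces)."""
    return f"{QUESTION}{' '.join(question.lower().split())}{ANSWER}"

def extract_answer(generated_text):
    """Return the text after the first <answer> marker, cut at the next <end> if one is present."""
    _, sep, rest = generated_text.partition(ANSWER)
    if not sep:
        # No marker found: fall back to the whole text
        return generated_text.strip()
    return rest.split(END, 1)[0].strip()

print(build_prompt("What is (are) Breast Cancer ?"))
# <question>what is (are) breast cancer ?<answer>
print(extract_answer("<question>q ?<answer> an answer. <end> trailing text"))
# an answer.
```

In the notebook, one would tokenize `build_prompt(question)`, call `model.generate(...)` with `eos_token_id=tokenizer.eos_token_id` so decoding can stop at `<end>`, decode with `skip_special_tokens=False` so the markers survive, and pass the decoded text to `extract_answer`.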

## Conclusion

This notebook has demonstrated how to fine-tune GPT-2 on the MedQuAD dataset, with improvements in data handling and training setup.

### Key Points:

- **Data Quality:** Sampling from the 100 most frequent focus areas yields category-balanced splits, but only 500 question-answer pairs are used for training and validation.
- **Preprocessing:** Thorough cleaning and consistent formatting with special tokens are critical for model understanding.
- **Results:** Validation loss fell from about 2.70 to 2.28 over five epochs, yet the sample generation is still off-topic and medically unreliable.
- **Improvements:** Increasing dataset size, systematic hyperparameter tuning, and regularization techniques such as early stopping can further improve the model's answers.

### Download all files and folders

In [None]:
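import os
import zipfile
from google.colab import files

# Directory to archive (the Colab working directory)
dir_to_zip = '/content/'

# Write the archive outside /content so it does not try to include itself
output_path = '/tmp/colab_session_files.zip'

# Skip the mounted Google Drive, which can be very large
skip_dirs = {'/content/drive'}

with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zf:
    for root, dirs, filenames in os.walk(dir_to_zip):
        # Prune skipped directories in place so os.walk does not descend into them
        dirs[:] = [d for d in dirs if os.path.join(root, d) not in skip_dirs]
        for name in filenames:
            path = os.path.join(root, name)
            zf.write(path, os.path.relpath(path, dir_to_zip))

# Download the zip file
files.download(output_path)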