# Fine-Tuning GPT-2 on the MedQuAD Dataset: Second Iteration for Enhanced Accuracy

In this iteration, we introduce several improvements aimed at boosting the model's generation accuracy by up to 10×. Enhancements include:

- **Model Upgrade:** We switch from `gpt2` to `gpt2-medium` for increased model capacity (if GPU memory permits).
- **Training Enhancements:** We add gradient accumulation steps, early stopping callbacks, and more frequent logging to optimize training.
- **Rigorous Evaluation:** We calculate perplexity from evaluation loss to quantitatively measure performance.
- **Data Quality:** We continue to focus on high-quality data by processing only the 100 most frequent focus areas.

Let's dive in!

## 1. Install Dependencies

Install the required libraries. We also ensure that the correct version of `pyarrow` is installed for compatibility.

In [1]:
%%capture
!pip -q uninstall pyarrow -y
!pip -q install pyarrow==15.0.2 datasets accelerate transformers kaggle

import os
import json
import torch
import pandas as pd
import numpy as np
import math
from zipfile import ZipFile
from sklearn.model_selection import train_test_split

from google.colab import drive
drive.mount('/content/drive/')

# Set up Kaggle API credentials (make sure your kaggle.json is in the correct location)
os.environ['KAGGLE_CONFIG_DIR'] = "/content/drive/My Drive/Colab Notebooks/GPT2FineTune/.kaggle"
print('Dependencies installed and environment set up.')

## 2. Download and Load the MedQuAD Dataset

Download the MedQuAD dataset (`gpreda/medquad`) from Kaggle and load it into a Pandas DataFrame.

In [2]:
# Define Kaggle dataset identifier
dataset = "gpreda/medquad"

# Download the dataset from Kaggle
os.system(f'kaggle datasets download -d {dataset}')

# Unzip the dataset
with ZipFile('medquad.zip', 'r') as zip_ref:
    zip_ref.extractall('medquad')

# Load the CSV file into a DataFrame
df = pd.read_csv('medquad/medquad.csv')
print('First few rows of the dataset:')
print(df.head())

First few rows of the dataset:
                                 question  \
0                What is (are) Glaucoma ?   
1                  What causes Glaucoma ?   
2     What are the symptoms of Glaucoma ?   
3  What are the treatments for Glaucoma ?   
4                What is (are) Glaucoma ?   

                                              answer           source  \
0  Glaucoma is a group of diseases that can damag...  NIHSeniorHealth   
1  Nearly 2.7 million people have glaucoma, a lea...  NIHSeniorHealth   
2  Symptoms of Glaucoma  Glaucoma can develop in ...  NIHSeniorHealth   
3  Although open-angle glaucoma cannot be cured, ...  NIHSeniorHealth   
4  Glaucoma is a group of diseases that can damag...  NIHSeniorHealth   

  focus_area  
0   Glaucoma  
1   Glaucoma  
2   Glaucoma  
3   Glaucoma  
4   Glaucoma  


## 3. Data Preprocessing

Clean the dataset by converting text to lowercase, removing extra spaces/newlines, dropping rows with missing values, and removing duplicate Q&A pairs. Removed rows are optionally saved for audit.

In [3]:
# Convert all text columns to lowercase and remove extra spaces/newlines
for col in df.select_dtypes(include=['object']).columns:
    df[col] = df[col].str.lower().str.split().str.join(' ')

print('Cleaned DataFrame preview:')
print(df.head())

# Remove rows with missing values and record them
removed_records = pd.DataFrame()
missing_values = df[df.isnull().any(axis=1)]
removed_records = pd.concat([removed_records, missing_values])
df = df.dropna()

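# Remove duplicate Q&A pairs based on 'question' and 'answer'.
# The default keep='first' flags only the later copies, i.e. exactly the rows
# that drop_duplicates() removes below, so the audit file lists no kept rows.
duplicates = df[df.duplicated(subset=['question', 'answer'])]
removed_records = pd.concat([removed_records, duplicates])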
df = df.drop_duplicates(subset=['question', 'answer'])

# Optionally, save removed records for auditing
if not removed_records.empty:
    removed_records.to_csv('removed_records_audit.csv', index=False)

print('Data preprocessing complete.')

Cleaned DataFrame preview:
                                 question  \
0                what is (are) glaucoma ?   
1                  what causes glaucoma ?   
2     what are the symptoms of glaucoma ?   
3  what are the treatments for glaucoma ?   
4                what is (are) glaucoma ?   

                                              answer           source  \
0  glaucoma is a group of diseases that can damag...  nihseniorhealth   
1  nearly 2.7 million people have glaucoma, a lea...  nihseniorhealth   
2  symptoms of glaucoma glaucoma can develop in o...  nihseniorhealth   
3  although open-angle glaucoma cannot be cured, ...  nihseniorhealth   
4  glaucoma is a group of diseases that can damag...  nihseniorhealth   

  focus_area  
0   glaucoma  
1   glaucoma  
2   glaucoma  
3   glaucoma  
4   glaucoma  
Data preprocessing complete.
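
The whitespace cleanup above relies on the split-then-join idiom. As a minimal sketch in plain Python (the helper name `normalize_text` and the sample string are made up for illustration, not part of the notebook), the same transformation on a single value looks like this:

```python
def normalize_text(text: str) -> str:
    """Lowercase the text and collapse any run of whitespace into one space."""
    # str.split() with no arguments splits on spaces, tabs, and newlines alike,
    # so joining the pieces with ' ' removes double spaces and line breaks.
    return " ".join(text.lower().split())


# Illustrative input with a double space and an embedded newline
raw = "Symptoms of Glaucoma  Glaucoma can\ndevelop slowly."
print(normalize_text(raw))
```

Because every newline is removed, each cleaned answer fits on a single line, which matters later when the sequences are written one per line to a text file.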


## 4. Create Training and Validation Splits

We restrict the data to the 100 most frequent focus areas and draw the same number of records from each one, which keeps the splits balanced: 4 records per category for training and 1 for validation.

In [4]:
# Select the top 100 focus areas based on record counts
top_100_categories = df['focus_area'].value_counts().nlargest(100).index.tolist()

train_data = pd.DataFrame()
val_data = pd.DataFrame()

for category in top_100_categories:
    # Sample 4 records for training
    train_samples = df[df['focus_area'] == category].sample(n=4, random_state=42)

    # Sample 1 record for validation (excluding training samples)
    val_samples = df[(df['focus_area'] == category) & (~df.index.isin(train_samples.index))].sample(n=1, random_state=42)

    train_data = pd.concat([train_data, train_samples])
    val_data = pd.concat([val_data, val_samples])

print(f"Training set size: {train_data.shape[0]}")
print(f"Validation set size: {val_data.shape[0]}")

Training set size: 400
Validation set size: 100
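
The per-category sampling above assumes every selected focus area has at least 5 records. A rough sketch of the same idea in plain Python, with an explicit guard for small categories, might look like the following. The function `split_per_category` and the toy records are hypothetical, and it uses the standard `random` module rather than the pandas sampling used in the notebook:

```python
import random
from collections import defaultdict


def split_per_category(records, n_train=4, n_val=1, seed=42):
    """Draw n_train + n_val records per category; skip categories that are too small."""
    rng = random.Random(seed)
    by_category = defaultdict(list)
    for record in records:
        by_category[record["focus_area"]].append(record)

    train, val = [], []
    for rows in by_category.values():
        if len(rows) < n_train + n_val:
            continue  # not enough records to fill both splits without overlap
        picked = rng.sample(rows, n_train + n_val)  # sampling without replacement
        train.extend(picked[:n_train])
        val.extend(picked[n_train:])
    return train, val


# Toy data: one category with 6 records and one with only 3
toy = [{"focus_area": "glaucoma", "id": i} for i in range(6)]
toy += [{"focus_area": "rare condition", "id": i} for i in range(6, 9)]
train, val = split_per_category(toy)
print(len(train), len(val))  # 4 1
```

Drawing both splits from one sample without replacement guarantees that no record appears in both training and validation.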


## 5. Prepare the Data for Fine-Tuning

Combine each Q&A pair into a single text sequence using custom special tokens:

- `<question>` marks the beginning of a question
- `<answer>` separates the question and answer
- `<end>` indicates the end of the sequence

This structured formatting helps the model understand the Q&A pattern.

In [5]:
def combine_text(df):
    return df.apply(lambda row: f"<question>{row['question']}<answer>{row['answer']}<end>", axis=1)

train_sequences = combine_text(train_data)
val_sequences = combine_text(val_data)

# Join sequences with newline delimiters
train_text = '\n'.join(train_sequences)
val_text = '\n'.join(val_sequences)

# Save to text files
with open('train_data.txt', 'w') as f:
    f.write(train_text)

with open('val_data.txt', 'w') as f:
    f.write(val_text)

print('Training and validation text files have been saved.')

Training and validation text files have been saved.
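
To make the sequence layout concrete, here is a small sketch of the formatting applied to one pair. The helper `format_pair` and the placeholder answer are illustrative only:

```python
def format_pair(question: str, answer: str) -> str:
    """Wrap one Q&A pair in the special tokens used for fine-tuning."""
    return f"<question>{question}<answer>{answer}<end>"


example = format_pair("what causes glaucoma ?", "placeholder answer text")
print(example)  # <question>what causes glaucoma ?<answer>placeholder answer text<end>
```

Since the sequences are joined with newlines and later read back one example per line, each formatted sequence must not contain a newline itself. The whitespace normalization in the preprocessing step takes care of that.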


## 6. Load the Pretrained GPT-2 Model and Tokenizer

For enhanced performance, we load the `gpt2-medium` model (a larger variant of GPT-2) along with its tokenizer. We then add our custom special tokens.

In [6]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Use a larger model variant for improved capacity (if your hardware allows)
model = GPT2LMHeadModel.from_pretrained('gpt2-medium', use_cache=False)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')

# Add special tokens matching our Q&A format
special_tokens = {
    'pad_token': '<pad>',
    'bos_token': '<question>',
    'eos_token': '<end>',
    'sep_token': '<answer>'
}
tokenizer.add_special_tokens(special_tokens)

print('Special tokens added and gpt2-medium model/tokenizer loaded.')

Special tokens added and gpt2-medium model/tokenizer loaded.


## 7. Tokenize the Dataset

We load the prepared text files using Hugging Face's `load_dataset` and tokenize each sequence. Sequences are truncated/padded to a maximum of 1024 tokens.

In [7]:
from datasets import load_dataset

def tokenize_function(examples):
    return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=1024)

# Load dataset from our text files
dataset = load_dataset('text', data_files={'train': 'train_data.txt', 'validation': 'val_data.txt'})

# Set dataset format to PyTorch tensors
dataset.set_format('torch')

print('Dataset loaded:')
print(dataset)

# Tokenize the dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=['text'])
print('Tokenized dataset preview:')
print(tokenized_datasets)

Dataset loaded:
DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 400
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 100
    })
})


Tokenized dataset preview:
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 400
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 100
    })
})
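
The effect of `truncation=True` and `padding='max_length'` can be illustrated without a tokenizer. The sketch below uses made-up token IDs and a hypothetical helper `pad_or_truncate`, and assumes padding is added on the right, which is the usual default for the GPT-2 tokenizer:

```python
def pad_or_truncate(token_ids, max_length, pad_id):
    """Return (input_ids, attention_mask), each with exactly max_length entries."""
    ids = list(token_ids[:max_length])   # truncation: drop anything past max_length
    mask = [1] * len(ids)                # 1 marks a real token
    n_pad = max_length - len(ids)
    # Right-padding: pad IDs are appended and masked out with 0
    return ids + [pad_id] * n_pad, mask + [0] * n_pad


ids, mask = pad_or_truncate([11, 22, 33], max_length=5, pad_id=0)
print(ids)   # [11, 22, 33, 0, 0]
print(mask)  # [1, 1, 1, 0, 0]
```

Padding every example to the full 1024 tokens keeps tensor shapes uniform, at the cost of spending compute on padding for short answers.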


## 8. Create a Data Collator for Language Modeling

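With `mlm=False`, `DataCollatorForLanguageModeling` batches the already padded examples and builds the `labels` tensor for causal language modeling (GPT-2) by copying `input_ids` and masking the padding positions with -100.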

In [8]:
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # GPT-2 uses causal language modeling
    return_tensors='pt'
)

# Optionally, inspect a few samples
num_samples = min(5, len(tokenized_datasets['train']))
for i in range(num_samples):
    batch = data_collator([tokenized_datasets['train'][i]])
    print(f"\nData Collator Output for sample {i}:")
    for key, value in batch.items():
        print(f"{key}: shape {value.shape}, dtype {value.dtype}")

print("\nDecoded sample:")
decoded = tokenizer.decode(tokenized_datasets['train'][0]['input_ids'], skip_special_tokens=True)
print(decoded)


Data Collator Output for sample 0:
input_ids: shape torch.Size([1, 1024]), dtype torch.int64
attention_mask: shape torch.Size([1, 1024]), dtype torch.int64
labels: shape torch.Size([1, 1024]), dtype torch.int64

Data Collator Output for sample 1:
input_ids: shape torch.Size([1, 1024]), dtype torch.int64
attention_mask: shape torch.Size([1, 1024]), dtype torch.int64
labels: shape torch.Size([1, 1024]), dtype torch.int64

Data Collator Output for sample 2:
input_ids: shape torch.Size([1, 1024]), dtype torch.int64
attention_mask: shape torch.Size([1, 1024]), dtype torch.int64
labels: shape torch.Size([1, 1024]), dtype torch.int64

Data Collator Output for sample 3:
input_ids: shape torch.Size([1, 1024]), dtype torch.int64
attention_mask: shape torch.Size([1, 1024]), dtype torch.int64
labels: shape torch.Size([1, 1024]), dtype torch.int64

Data Collator Output for sample 4:
input_ids: shape torch.Size([1, 1024]), dtype torch.int64
attention_mask: shape torch.Size([1, 1024]), dtype torch.int64
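
The `labels` tensor in the collator output has the same shape as `input_ids`. As a simplified sketch of how such labels are typically derived for causal language modeling (plain lists and made-up token IDs; `build_causal_lm_labels` is a hypothetical name, not the library's function):

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default


def build_causal_lm_labels(input_ids, pad_id):
    """Copy input_ids and replace padding positions with IGNORE_INDEX."""
    return [IGNORE_INDEX if tok == pad_id else tok for tok in input_ids]


labels = build_causal_lm_labels([11, 22, 33, 0, 0], pad_id=0)
print(labels)  # [11, 22, 33, -100, -100]
```

The labels are not shifted here. For GPT-2 the one-position shift between inputs and targets happens inside the model's loss computation, so the collator only needs to mark which positions to ignore.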

## 9. Set Up Training Arguments and Trainer with Early Stopping

We now define our training parameters. Enhancements in this iteration include:

- **Gradient Accumulation:** Set to 4, so that with a per-device batch size of 2 each optimizer step sees 8 sequences per device without the memory cost of a larger batch.
- **Early Stopping:** Using `EarlyStoppingCallback` to halt training if evaluation loss does not improve for 2 consecutive evaluations.
- **Frequent Logging:** Training loss is logged every 100 optimizer steps to monitor progress.
- **Load Best Model:** Save and load the best checkpoint based on evaluation loss.

These modifications aim to stabilize and optimize training for higher accuracy.
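
The interaction between batch size, gradient accumulation, and logging frequency is easy to check with a little arithmetic. The sketch below uses the values from this notebook's training arguments and assumes a single GPU. The step formula is an approximation of how the trainer counts optimizer steps, not its exact code:

```python
import math

num_train_examples = 400          # training set size reported earlier
per_device_batch_size = 2
gradient_accumulation_steps = 4
num_epochs = 7
logging_steps = 100
num_devices = 1                   # assumption: a single GPU

# Sequences contributing to each optimizer update
effective_batch = per_device_batch_size * gradient_accumulation_steps * num_devices

# Mini-batches per epoch, then optimizer steps per epoch
batches_per_epoch = math.ceil(num_train_examples / (per_device_batch_size * num_devices))
steps_per_epoch = batches_per_epoch // gradient_accumulation_steps
total_steps = steps_per_epoch * num_epochs

# Epochs at which the training loss gets logged
logged_at_epochs = [step / steps_per_epoch
                    for step in range(logging_steps, total_steps + 1, logging_steps)]

print(effective_batch, steps_per_epoch, total_steps)  # 8 50 350
print(logged_at_epochs)                               # [2.0, 4.0, 6.0]
```

Under these assumptions the total of 350 steps agrees with the `global_step=350` in the training output, and logging only at the ends of epochs 2, 4, and 6 would explain why epoch 1 shows "No log" and why each training loss value appears twice in the results table.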

In [11]:
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Resize model embeddings to include the added tokens
model.resize_token_embeddings(len(tokenizer))

training_args = TrainingArguments(
    output_dir='./gpt_finetuned_MedQuAD_v2',
    evaluation_strategy='epoch',  # newer transformers releases rename this argument to eval_strategy
    save_strategy='epoch',
    gradient_accumulation_steps=4,  # Increased gradient accumulation steps
    logging_steps=100,
    per_device_train_batch_size=2,  # Reduced batch size
    per_device_eval_batch_size=2,  # Reduced batch size
    num_train_epochs=7,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model='eval_loss',
    logging_dir='./logs',
    fp16=True,  # Enable mixed precision training if supported by your hardware
    gradient_checkpointing=True,  # Enable gradient checkpointing
)

# Define Trainer with an EarlyStoppingCallback (patience set to 2 epochs)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    data_collator=data_collator,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
)

print('Starting training with enhanced settings...')
trainer.train()



Starting training with enhanced settings...


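| Epoch | Training Loss | Validation Loss |
|-------|---------------|-----------------|
| 1     | No log        | 2.135555        |
| 2     | 2.093800      | 2.019944        |
| 3     | 2.093800      | 1.95826         |
| 4     | 1.648700      | 1.928721        |
| 5     | 1.648700      | 1.929083        |
| 6     | 1.443000      | 1.927209        |
| 7     | 1.443000      | 1.936986        |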


There were missing keys in the checkpoint model loaded: ['lm_head.weight'].


TrainOutput(global_step=350, training_loss=1.675875222342355, metrics={'train_runtime': 1485.1547, 'train_samples_per_second': 1.885, 'train_steps_per_second': 0.236, 'total_flos': 5200723889356800.0, 'train_loss': 1.675875222342355, 'epoch': 7.0})
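
Training ran for all 7 epochs, so early stopping never triggered. A simplified model of a patience-based stopping rule, applied to the validation losses above, shows why. This is a sketch of the general idea, not the callback's actual implementation, and `first_stop_epoch` is a hypothetical helper:

```python
def first_stop_epoch(eval_losses, patience=2, threshold=0.0):
    """Return (stop_epoch or None, best_epoch) for a patience-based stopping rule."""
    best_loss, best_epoch, bad_evals = float("inf"), None, 0
    for epoch, loss in enumerate(eval_losses, start=1):
        if best_loss - loss > threshold:      # strict improvement over the best so far
            best_loss, best_epoch, bad_evals = loss, epoch, 0
        else:
            bad_evals += 1                    # one more evaluation without improvement
            if bad_evals >= patience:
                return epoch, best_epoch
    return None, best_epoch


val_losses = [2.135555, 2.019944, 1.95826, 1.928721, 1.929083, 1.927209, 1.936986]
print(first_stop_epoch(val_losses))              # (None, 6)
print(first_stop_epoch(val_losses, patience=1))  # (5, 4)
```

With a patience of 2, the slight regression at epoch 5 is forgiven because epoch 6 sets a new best. Epoch 6 is therefore the checkpoint that `load_best_model_at_end=True` would restore.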

## 10. Evaluate and Compute Perplexity

After training, we evaluate the model and compute perplexity as an indicator of model performance.

Perplexity is computed as the exponential of the evaluation loss.

In [12]:
eval_metrics = trainer.evaluate()
perplexity = math.exp(eval_metrics['eval_loss'])
print(f"Evaluation Loss: {eval_metrics['eval_loss']}")
print(f"Perplexity: {perplexity}")

Evaluation Loss: 1.927209496498108
Perplexity: 6.870311838189805
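
The reported loss matches the epoch 6 validation loss, which is consistent with the best checkpoint having been restored. As a small standalone sketch of the conversion (assuming the loss is the mean per-token cross-entropy in nats; `perplexity_from_loss` is an illustrative helper), the same function can be applied to any epoch's loss for comparison:

```python
import math


def perplexity_from_loss(mean_nll: float) -> float:
    """Perplexity is the exponential of the mean per-token negative log-likelihood."""
    return math.exp(mean_nll)


best_ppl = perplexity_from_loss(1.927209496498108)   # best checkpoint (epoch 6)
first_ppl = perplexity_from_loss(2.135555)           # epoch 1, for comparison
print(f"best: {best_ppl:.2f}, epoch 1: {first_ppl:.2f}")
```

A perplexity of about 6.87 means the model is, on average, roughly as uncertain as if it were choosing uniformly among about 7 tokens at each position. Because the loss is exponentiated, small differences in loss translate into larger relative differences in perplexity.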


## 11. Save the Fine-Tuned Model and Tokenizer

Save the best model checkpoint and tokenizer for later use in inference.

In [13]:
model.save_pretrained(training_args.output_dir)
tokenizer.save_pretrained(training_args.output_dir)
print('Model and tokenizer saved to:', training_args.output_dir)

Model and tokenizer saved to: ./gpt_finetuned_MedQuAD_v2


## 12. Inference

Define a helper function to generate responses with the fine-tuned model. The function tokenizes the input prompt, generates text, and then decodes the output.

In [14]:
def generate_response(model, tokenizer, prompt):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    # Tokenize prompt
    inputs = tokenizer.encode(prompt, return_tensors='pt').to(device)

    # Generate response
    with torch.no_grad():
        output = model.generate(
            inputs,
            max_length=1024,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,          # Enable sampling for more diverse outputs
            top_k=50,                # Use top-k sampling
            top_p=0.95,              # Use nucleus sampling
            temperature=0.7,         # Slight randomness
            pad_token_id=tokenizer.eos_token_id
        )
        response = tokenizer.decode(output[0], skip_special_tokens=True)
        generated_response = response[len(prompt):].strip()
    return generated_response

# Example inference
prompt = "what is (are) breast cancer ?"
response = generate_response(model, tokenizer, prompt)
print('Generated Response:')
print(response)

Generated Response:
breast cancers are among the most common cancers in women. about 10 to 20 percent of all women will have breast disease, and many women with breast tumors never get an early diagnosis. if you have a breast tumor, you can help your doctor make an informed decision about whether to remove it or treat it. what is the diagnosis and prognosis for breast and ovarian cancer? the doctor will make a diagnosis of breast or ovarian cancers based on your symptoms and the information available at the time of diagnosis, your age, the type of cancer, how advanced the disease is, where the cancer is located, what medicines you take, signs of treatment response and your family history. the following chart shows how often a cancer diagnosis is made in each type and how long the treatment will take. - breast (top): - ovarian (bottom): breast breast: 85 to 90 percent ovarian: 60 to 70 percent breast/ovarian cancer: 50 to 60 percent - invasive breast carcinoma: 40 to 50 percent invasive
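
The training sequences were wrapped as `<question>…<answer>…<end>`, whereas the example prompt above is plain text. One option worth trying is to format prompts the same way as the training data. The sketch below shows only the string handling. The helpers `build_prompt` and `extract_answer` are hypothetical, and `extract_answer` assumes the generated sequence was decoded with `skip_special_tokens=False` so the markers are still present:

```python
def build_prompt(question: str) -> str:
    """Normalize a question like the training data and wrap it in the Q&A tokens."""
    cleaned = " ".join(question.lower().split())
    return f"<question>{cleaned}<answer>"


def extract_answer(decoded: str) -> str:
    """Take the text between <answer> and the first <end> in a decoded sequence."""
    after_answer = decoded.split("<answer>", 1)[-1]
    return after_answer.split("<end>", 1)[0].strip()


formatted_prompt = build_prompt("What is (are)  Breast Cancer ?")
print(formatted_prompt)  # <question>what is (are) breast cancer ?<answer>

# Stand-in for a decoded model output, used only to exercise the parsing
fake_output = formatted_prompt + " placeholder answer text <end>trailing tokens"
print(extract_answer(fake_output))  # placeholder answer text
```

Ending the prompt with `<answer>` asks the model to continue with an answer, and cutting at `<end>` discards anything generated after the model signals that the answer is complete.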

## Conclusion

This second-iteration notebook demonstrates a refined approach to fine-tuning GPT-2 on the MedQuAD dataset. By using a larger model variant, enhanced training strategies (including gradient accumulation and early stopping), and a quantitative evaluation metric (perplexity), we aim to significantly improve generation accuracy. The best checkpoint reached an evaluation loss of about 1.927, which corresponds to a perplexity of about 6.87.

### Download all files and folders

In [19]:
import os
import shutil
from google.colab import files

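# Archive only the fine-tuned model directory. Zipping all of '/content/' also
# pulls in the mounted Google Drive and every intermediate checkpoint, and the
# original run of this cell failed with "OSError: [Errno 28] No space left on device".
dir_to_zip = './gpt_finetuned_MedQuAD_v2'

# Write the archive outside the directory being zipped so it cannot include itself
output_filename = '/tmp/colab_session_files'

# Create the zip archive
shutil.make_archive(output_filename, 'zip', dir_to_zip)

# Download the zip file
files.download(output_filename + '.zip')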