In [None]:
# How to Fine-Tune GPT-2 for Medical Q&A (CPU-friendly)
# ==============================
# 0️⃣ Install & imports
# (In Colab: this will install required packages. In a local env you might skip install.)
!pip install -q transformers datasets

import inspect
import os
import re

import torch
from datasets import Dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments

# Hide every CUDA device from this process so the run stays on the CPU even
# when a GPU is installed (delete the next line to allow GPU use).
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Resolve the device string used for the model and for inference inputs.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

# ==============================
# 1️⃣ Prepare Q&A dataset (supervised)
# - We train on Q: <question> A: <answer> lines so the model learns to map questions → answers
# - Add or replace these pairs with your own medical Q&A to improve accuracy
# ==============================
# Supervised question/answer pairs. Every entry is a dict with exactly two
# keys: "q" (the question) and "a" (the reference answer).
qa_pairs = [
    # Cardiovascular disease
    dict(
        q="What are the main types of cardiovascular diseases?",
        a="Coronary artery disease, stroke, heart failure, and hypertension.",
    ),
    dict(
        q="What are the risk factors for cardiovascular diseases?",
        a="High blood pressure, smoking, obesity, diabetes, and high cholesterol.",
    ),
    dict(
        q="How can cardiovascular diseases be prevented?",
        a="Healthy diet, regular exercise, quitting smoking, and controlling blood pressure and cholesterol.",
    ),
    # Diabetes
    dict(
        q="What causes diabetes?",
        a="Diabetes results from problems with insulin production or insulin use in the body.",
    ),
    dict(
        q="How can diabetes be managed effectively?",
        a="With a balanced diet, regular exercise, medication or insulin, and monitoring blood sugar levels.",
    ),
    dict(
        q="What are complications of uncontrolled diabetes?",
        a="Kidney damage, vision loss, nerve damage, and increased risk of heart disease.",
    ),
    # Flu
    dict(
        q="What are common symptoms of flu?",
        a="Fever, cough, sore throat, body aches, and fatigue.",
    ),
    dict(
        q="How can flu be prevented?",
        a="Seasonal vaccination, hand washing, and avoiding close contact with sick people.",
    ),
    # Immunity
    dict(
        q="How can we improve our immune system naturally?",
        a="Eat nutritious food, get regular sleep, exercise, stay hydrated, and manage stress.",
    ),
    dict(
        q="What vitamins help immunity?",
        a="Vitamins C and D and minerals like zinc support immune function.",
    ),
    # Respiratory health
    dict(
        q="What are common respiratory diseases?",
        a="Asthma, chronic obstructive pulmonary disease (COPD), pneumonia, and bronchitis.",
    ),
    dict(
        q="How to maintain good respiratory health?",
        a="Avoid smoking, reduce pollution exposure, get vaccinations, and seek prompt care for infections.",
    ),
]

# Render each pair as one "Q: ... A: ..." training string and wrap the
# resulting records in a Hugging Face Dataset with a single "text" column.
def _as_training_record(item):
    """Return the {"text": ...} record for one question/answer pair."""
    return {"text": f"Q: {item['q']} A: {item['a']}"}


train_texts = list(map(_as_training_record, qa_pairs))
dataset = Dataset.from_list(train_texts)

# ==============================
# 2️⃣ Load GPT-2 tokenizer & model (small GPT-2)
# - Using GPT-2 (small) to keep runtime low on CPU
# ==============================
# "gpt2" is the small GPT-2 checkpoint.
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
model = model.to(device)

# GPT-2 ships without a padding token: reuse EOS for padding and record the
# same id in the model config.
if tokenizer.pad_token is None:
    model.config.pad_token_id = tokenizer.eos_token_id
    tokenizer.pad_token = tokenizer.eos_token

# ==============================
# 3️⃣ Tokenize dataset
# - We set a comfortable max_length (128). For longer Q/A, increase this.
# ==============================
def tokenize_fn(batch):
    """Tokenize the "text" entries of a batch for `Dataset.map(..., batched=True)`.

    Every example is truncated or padded to exactly 128 tokens so all rows
    have the same length.

    Args:
        batch: mapping with a "text" key holding the strings to tokenize.

    Returns:
        The tokenizer's output for those strings.
    """
    texts = batch["text"]
    encoded = tokenizer(
        texts,
        max_length=128,
        padding="max_length",
        truncation=True,
    )
    return encoded

# Only these two columns are handed to the collator; everything else produced
# by map() (including the raw "text") is dropped.
keep_columns = ["input_ids", "attention_mask"]

tokenized = dataset.map(tokenize_fn, batched=True)
drop_columns = [name for name in tokenized.column_names if name not in keep_columns]
tokenized = tokenized.remove_columns(drop_columns)

# Have the dataset return torch tensors for the kept columns.
tokenized.set_format(type="torch", columns=keep_columns)

# ==============================
# 4️⃣ Training arguments (CPU-friendly)
# - Small batch size and modest epochs for CPU runs
# ==============================
# Hyperparameters collected in one mapping, then expanded into
# TrainingArguments as keyword arguments.
training_config = {
    "output_dir": "./gpt2-medical-qa-cpu",
    "overwrite_output_dir": True,
    "num_train_epochs": 4,               # increase epochs if you have more data
    "per_device_train_batch_size": 1,    # keep 1 on CPU
    "gradient_accumulation_steps": 1,    # increase to simulate larger batch if desired
    "learning_rate": 5e-5,
    "weight_decay": 0.01,
    "logging_steps": 10,
    "save_steps": 50,
    "save_total_limit": 2,
    "fp16": False,                       # disabled for CPU
    "report_to": [],                     # disable wandb/other logging integrations
}
training_args = TrainingArguments(**training_config)

# ==============================
# 5️⃣ Data collator
# - Stacks input_ids / attention_mask and derives causal-LM labels from input_ids
# - Padding positions get label -100 so they are ignored by the loss; only the
#   first padding position of each row keeps its token id (EOS in this script,
#   because pad_token is set to eos_token) so the model still learns to stop
# ==============================
def data_collator(batch):
    """Collate tokenized examples into one causal-LM training batch.

    Args:
        batch: list of examples, each a mapping holding 1-D ``input_ids`` and
            ``attention_mask`` tensors of the same length across the batch.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors of
        shape ``(len(batch), seq_len)``. ``labels`` is a copy of ``input_ids``
        in which every padding position (``attention_mask == 0``) except the
        first one in each row is replaced by -100.
    """
    input_ids = torch.stack([example["input_ids"] for example in batch])
    attention_mask = torch.stack([example["attention_mask"] for example in batch])

    # Clone so that masking the labels never touches the model inputs.
    labels = input_ids.clone()
    is_pad = attention_mask == 0
    # A running count of pad positions equals 1 exactly at the first pad of a row.
    first_pad = is_pad & (is_pad.long().cumsum(dim=1) == 1)
    # Without this mask the loss is dominated by the long run of padding tokens.
    labels[is_pad & ~first_pad] = -100

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }

# ==============================
# 6️⃣ Trainer setup
# ==============================
# Newer transformers releases take the tokenizer through `processing_class`
# and deprecate the older `tokenizer` keyword. Use whichever name the installed
# Trainer accepts so this cell works on both old and new versions.
_trainer_params = inspect.signature(Trainer.__init__).parameters
_tokenizer_kwarg = "processing_class" if "processing_class" in _trainer_params else "tokenizer"

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    data_collator=data_collator,
    **{_tokenizer_kwarg: tokenizer},
)

# ==============================
# 7️⃣ Fine-tune the model (this will run on CPU)
# - Training time depends on number of epochs and dataset size; be patient on CPU.
# ==============================
trainer.train()

# ==============================
# 8️⃣ Save the fine-tuned model and tokenizer
# ==============================
save_dir = "./gpt2-medical-qa-cpu"
trainer.save_model(save_dir)
tokenizer.save_pretrained(save_dir)

# Load both artifacts back from disk: this confirms the save is usable, and the
# rebound `model` / `tokenizer` are the ones used for inference afterwards.
tokenizer = GPT2Tokenizer.from_pretrained(save_dir)
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained(save_dir)
model = model.to(device)

# ==============================
# 9️⃣ Inference: generate answers (deterministic beam search)
# - We use the "Q: ... A:" prompt format the model was trained on.
# - Beam search (do_sample=False, num_beams>1) gives more stable factual outputs.
# - We also post-process to remove repeated question text and tidy whitespace.
# ==============================
def ask_question(question, max_length=80):
    """Generate an answer to *question* with the fine-tuned model.

    The question is wrapped in the "Q: ... A:" prompt format and decoded with
    deterministic beam search. Only the tokens produced after the prompt are
    decoded, so the answer is extracted correctly even when the question
    itself contains the text "A:".

    Args:
        question: the question text.
        max_length: value passed to ``generate`` as ``max_length``; it bounds
            the total sequence length, prompt tokens included.

    Returns:
        The generated answer as one line of text: any echo of the question is
        removed, newlines become spaces, runs of whitespace are collapsed and
        the ends are stripped.
    """
    prompt = f"Q: {question} A:"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    out = model.generate(
        **inputs,
        max_length=max_length,
        do_sample=False,          # deterministic decoding
        num_beams=4,              # beam size — higher = more compute but often more accurate
        early_stopping=True,
        no_repeat_ngram_size=2,
        pad_token_id=tokenizer.eos_token_id
    )
    # The returned sequence starts with the prompt tokens; skip them instead of
    # searching the decoded text for "A:", which is ambiguous when the question
    # already contains that marker.
    prompt_len = inputs["input_ids"].shape[1]
    decoded = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
    # Drop any verbatim (case-insensitive) repetition of the question.
    decoded = re.sub(re.escape(question), "", decoded, flags=re.IGNORECASE)
    decoded = re.sub(r"\n+", " ", decoded)
    decoded = re.sub(r"\s{2,}", " ", decoded).strip()
    return decoded

# ==============================
# 10️⃣ Test sample questions (prints Q / A pairs cleanly)
# ==============================
# Sample prompts used to spot-check the fine-tuned model.
test_questions = [
    "What are the risk factors for cardiovascular diseases?",
    "How can diabetes be managed effectively?",
    "What are the common symptoms of flu?",
    "How can we improve our immune system naturally?",
    "What are common respiratory diseases?",
]

# Print each question followed by the generated answer, numbered from 1.
for number, sample in enumerate(test_questions, start=1):
    print(f"Q{number}: {sample}")
    answer = ask_question(sample)
    print(f"A{number}: {answer}\n")


Using device: cpu


Map:   0%|          | 0/12 [00:00<?, ? examples/s]

  trainer = Trainer(
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 50256}.


Step,Training Loss
10,2.2153
20,0.4781
30,0.3313
40,0.2961


Q1: What are the risk factors for cardiovascular diseases?
A1: High blood pressure, smoking, high cholesterol, diabetes, heart disease, and hypertension.

Q2: How can diabetes be managed effectively?
A2: Diabetes is caused by insulin deficiency, insulin resistance, high blood pressure, and high cholesterol.

Q3: What are the common symptoms of flu?
A3: Fever, cough, sore throat, and body aches.

Q4: How can we improve our immune system naturally?
A4: Healthy diet, regular exercise, and regular blood pressure control.

Q5: What are common respiratory diseases?
A5: Chronic obstructive pulmonary disease (COPD), pneumonia, and bronchitis.

