In [2]:
import os, torch, threading
from unsloth import FastLanguageModel
from transformers import TextIteratorStreamer

### 📚 Dataset Overview: `qiaojin/PubMedQA`

The **PubMedQA** dataset is a biomedical question-answering benchmark derived from **PubMed** abstracts.  
It was introduced by *Qiao Jin et al.* (EMNLP 2019) to evaluate language models on factual reasoning over scientific literature.

The Hugging Face release, **`qiaojin/PubMedQA`**, provides three configurations:

- **`pqa_labeled`** – 1,000 expert-annotated *yes/maybe/no* questions with their corresponding PubMed abstracts and detailed long answers.  
  *Recommended for supervised fine-tuning.*
- **`pqa_artificial`** – automatically generated Q&A pairs created from PubMed titles and abstracts.  
- **`pqa_unlabeled`** – questions collected from PubMed titles without gold answers.

**Summary:**  

| Config name        | Description                                                              | Use case                                   |
| ------------------ | ------------------------------------------------------------------------ | ------------------------------------------ |
| `"pqa_labeled"`    | 1,000 expert-annotated questions with yes/maybe/no answers and abstracts | **Recommended** for supervised fine-tuning |
| `"pqa_artificial"` | Automatically generated Q–A pairs                                        | Good for pre-training / augmentation       |
| `"pqa_unlabeled"`  | Questions without human answers                                          | For semi-supervised or retrieval testing   |


Each entry in the labeled set contains:

| Field | Description |
|--------|--------------|
| `pubid` | PubMed ID of the source article |
| `question` | A research-style yes/no/maybe question |
| `context` | A dict whose `contexts` list holds the sections of the relevant abstract |
| `long_answer` | A natural-language justification of the answer |
| `final_decision` | The categorical label (`yes`, `no`, or `maybe`) |

> The dataset supports tasks such as biomedical question answering, literature-based reasoning, and fine-tuning domain-specific LLMs for evidence-grounded responses.

**Citation:**  
> Jin Q., Dhingra B., Liu Z., Cohen W.W., & Lu X. (2019).  
> *PubMedQA: A Dataset for Biomedical Research Question Answering.*  
> Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing (EMNLP).



In [1]:
from datasets import load_dataset


ds = load_dataset("sentence-transformers/pubmedqa", name="triplet-all", split="train")
print(ds[0])

{'anchor': 'Does a history of unintended pregnancy lessen the likelihood of desire for sterilization reversal?', 'positive': 'Unintended pregnancy has been significantly associated with subsequent female sterilization. Whether women who are sterilized after experiencing an unintended pregnancy are less likely to express desire for sterilization reversal is unknown.', 'negative': 'Changes in serum hormone levels induced by combined contraceptives.'}


In [2]:
ds = load_dataset("qiaojin/PubMedQA", name="pqa_labeled", split="train")
print(ds)


Dataset({
    features: ['pubid', 'question', 'context', 'long_answer', 'final_decision'],
    num_rows: 1000
})


In [3]:
def to_messages(ex):
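    return {
        "messages": [
            {"role": "system", "content": "You are a biomedical research assistant. Be concise, cite PMIDs when possible."},
            # NOTE: ex['context'] is a dict, so this f-string embeds its repr (visible in the
            # printed example); join ex['context']['contexts'] to pass plain abstract text.
            {"role": "user", "content": f"Question: {ex['question']}\n\nAbstract (PMID:{ex['pubid']}): {ex['context']}"},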
            {"role": "assistant", "content": f"{ex['long_answer']} [PMID:{ex['pubid']}]"}
        ]
    }

chat_ds = ds.map(to_messages, remove_columns=ds.column_names)
print(chat_ds[0])


{'messages': [{'content': 'You are a biomedical research assistant. Be concise, cite PMIDs when possible.', 'role': 'system'}, {'content': "Question: Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?\n\nAbstract (PMID:21645374): {'contexts': ['Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants.', 'The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A. madagascariensis. A single areole within a window stage 
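The printed example shows the raw `{'contexts': [...]}` dict inside the user prompt, because `context` is a dict rather than a string. A minimal sketch of one way to flatten it first, assuming the layout visible in the output (a `contexts` key holding a list of section strings). `flatten_context` and `to_messages_flat` are hypothetical helper names, and the toy record is invented for illustration.

```python
def flatten_context(context):
    """Return the abstract as plain text.

    Assumes a dict whose 'contexts' key holds a list of section strings;
    plain strings (or None) are passed through as stripped text.
    """
    if isinstance(context, dict):
        return " ".join(s.strip() for s in context.get("contexts", []))
    return (context or "").strip()


def to_messages_flat(ex):
    # Hypothetical variant of to_messages that uses the flattened abstract
    abstract = flatten_context(ex["context"])
    return {
        "messages": [
            {"role": "system", "content": "You are a biomedical research assistant. Be concise, cite PMIDs when possible."},
            {"role": "user", "content": f"Question: {ex['question']}\n\nAbstract (PMID:{ex['pubid']}): {abstract}"},
            {"role": "assistant", "content": f"{ex['long_answer']} [PMID:{ex['pubid']}]"},
        ]
    }


# Toy record, invented for illustration (not taken from the dataset)
toy = {
    "pubid": 123,
    "question": "Does X affect Y?",
    "context": {"contexts": ["Background text.", "Results text."]},
    "long_answer": "X appears to affect Y.",
}
print(to_messages_flat(toy)["messages"][1]["content"])
```

With this variant the prompt contains only the abstract sentences, so no tokens are spent on dict punctuation and key names.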

In [4]:
import torch, os, random
from datasets import load_dataset, disable_progress_bar
disable_progress_bar()  # avoids widget issues in VS Code notebooks

# Pick a model size: start small, then scale up
MODEL_NAME = "unsloth/gpt-oss-7b-instruct-bnb-4bit"  # placeholder for a small prototype; overridden in cell [6], which loads the 20B checkpoint
# For larger run later: "unsloth/gpt-oss-20b-unsloth-bnb-4bit"

OUT_DIR = "gptoss_pubmedqa_sft"
SEED = 3407
MAX_SEQ_LEN = 2048
MAX_STEPS = 800          # small pilot; increase when happy
BATCH_PER_DEVICE = 1
GRAD_ACCUM = 8           # effective batch = BATCH_PER_DEVICE * GRAD_ACCUM * n_gpus
LR = 2e-4
WARMUP_STEPS = 50
PACKING = True
EVAL_SAMPLES = 256       # small, fast evaluation subset

random.seed(SEED)
torch.manual_seed(SEED)


<torch._C.Generator at 0x7fe47c0e8350>
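The comment on `GRAD_ACCUM` gives the effective batch size formula. As a quick sanity check, the sketch below works out how many sequences the pilot run would consume and how many passes over the data that implies. It assumes a single GPU and a 900-row training split (1,000 labeled rows minus a 10% hold-out), and it ignores packing, which reduces the number of sequences. `training_budget` is a hypothetical helper name.

```python
def training_budget(batch_per_device, grad_accum, n_gpus, max_steps, n_examples):
    # Sequences contributing to each optimizer step
    effective_batch = batch_per_device * grad_accum * n_gpus
    # Total sequences consumed over the whole run
    sequences_seen = effective_batch * max_steps
    # Approximate number of passes over the training split
    epochs = sequences_seen / n_examples
    return effective_batch, sequences_seen, epochs


eff, seen, epochs = training_budget(
    batch_per_device=1, grad_accum=8, n_gpus=1, max_steps=800, n_examples=900
)
print(eff, seen, round(epochs, 2))  # 8 6400 7.11
```

Roughly seven passes over 900 examples is a lot for a dataset this small, so a lower `MAX_STEPS` or a held-out loss check may be worth considering.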

In [5]:
from datasets import load_dataset

# Load labeled split
raw = load_dataset("qiaojin/PubMedQA", name="pqa_labeled")
# The hub config may expose only "train"; in that case hold out 10% for evaluation:
if "train" in raw and "validation" not in raw:
    raw = raw["train"].train_test_split(test_size=0.1, seed=SEED)
    train_ds, eval_ds = raw["train"], raw["test"]
else:
    train_ds = raw["train"]
    eval_ds  = raw.get("validation", raw["train"].select(range(min(EVAL_SAMPLES, len(raw["train"])))))

def to_messages(ex):
    # Build instruction style with citations
    q   = ex.get("question", "")
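    # NOTE: "context" is a dict, so the prompt below embeds its repr (see the printed example);
    # join ctx["contexts"] to pass plain abstract text instead.
    ctx = ex.get("context", "") or ""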
    pmid = ex.get("pubid", "")
    long_ans = (ex.get("long_answer", "") or "").strip()
    final = (ex.get("final_decision","") or "").strip()  # yes/no/maybe

    user = f"Question: {q}\n\nAbstract (PMID:{pmid}): {ctx}\n\n" \
           f"Task: Answer concisely and cite the PMID in brackets."
    assistant = f"{long_ans} [PMID:{pmid}]".strip()

    return {"messages":[
        {"role":"system",    "content":"You are a biomedical research assistant. Be accurate, concise, and always cite PMIDs in brackets."},
        {"role":"user",      "content": user},
        {"role":"assistant", "content": assistant}
    ],
    "final_decision": final, "pubid": pmid}

train_chat = train_ds.map(to_messages, remove_columns=train_ds.column_names)
eval_chat  = eval_ds.map(to_messages,  remove_columns=eval_ds.column_names)
print(train_chat[0])


{'pubid': 11458136, 'final_decision': 'maybe', 'messages': [{'content': 'You are a biomedical research assistant. Be accurate, concise, and always cite PMIDs in brackets.', 'role': 'system'}, {'content': "Question: Does managed care enable more low income persons to identify a usual source of care?\n\nAbstract (PMID:11458136): {'contexts': ['By requiring or encouraging enrollees to obtain a usual source of care, managed care programs hope to improve access to care without incurring higher costs.', '(1) To examine the effects of managed care on the likelihood of low-income persons having a usual source of care and a usual physician, and; (2) To examine the association between usual source of care and access.', 'Cross-sectional survey of households conducted during 1996 and 1997.', 'A nationally representative sample of 14,271 low-income persons.', 'Usual source of care, usual physician, managed care enrollment, managed care penetration.', 'High managed care penetration in the community 
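Since every target answer ends with a bracketed `[PMID:...]` tag and `pubid` is kept alongside the messages, a simple check can tell whether a generated answer cites the expected article. This is a minimal sketch, assuming the tag format used in `to_messages`. `cited_pmids` and `cites_expected` are hypothetical helper names and the sample answer text is invented.

```python
import re

# Matches the citation tag format used in the assistant targets, e.g. [PMID:11458136]
PMID_TAG = re.compile(r"\[PMID:(\d+)\]")


def cited_pmids(text):
    """Return every PMID cited in the text, in order of appearance."""
    return [int(m) for m in PMID_TAG.findall(text)]


def cites_expected(text, pubid):
    """True if the text cites the given PubMed ID at least once."""
    return int(pubid) in cited_pmids(text)


answer = "Managed care may improve access to a usual source of care. [PMID:11458136]"
print(cited_pmids(answer), cites_expected(answer, 11458136))
```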

In [6]:
from unsloth import FastLanguageModel
from peft import LoraConfig, TaskType

MODEL_NAME = "unsloth/gpt-oss-20b-unsloth-bnb-4bit"  # confirmed on HF
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name      = MODEL_NAME,
    dtype           = None,
    max_seq_length  = MAX_SEQ_LEN,
    load_in_4bit    = True,
    full_finetuning = False,
)

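# NOTE: for_inference() switches the model to Unsloth's inference mode. It is meant for
# generation after training and is not a required step before attaching LoRA adapters.
FastLanguageModel.for_inference(model)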





🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.9.11: Fast Gpt_Oss patching. Transformers: 4.56.2.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.564 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu128. CUDA: 7.5. CUDA Toolkit: 12.8. Triton: 3.4.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.32.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Using float16 precision for gpt_oss won't work! Using float32.



GptOssForCausalLM(
  (model): GptOssModel(
    (embed_tokens): Embedding(201088, 2880, padding_idx=199999)
    (layers): ModuleList(
      (0-23): 24 x GptOssDecoderLayer(
        (self_attn): GptOssAttention(
          (q_proj): Linear4bit(in_features=2880, out_features=4096, bias=True)
          (k_proj): Linear4bit(in_features=2880, out_features=512, bias=True)
          (v_proj): Linear4bit(in_features=2880, out_features=512, bias=True)
          (o_proj): Linear4bit(in_features=4096, out_features=2880, bias=True)
        )
        (mlp): GptOssMLP(
          (router): GptOssTopKRouter(
            (linear): Linear(in_features=2880, out_features=32, bias=True)
          )
          (experts): GptOssExperts(
            (gate_up_projs): ModuleList(
              (0-31): 32 x Linear4bit(in_features=2880, out_features=5760, bias=True)
            )
            (down_projs): ModuleList(
              (0-31): 32 x Linear4bit(in_features=2880, out_features=2880, bias=True)
            )


In [7]:
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
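    lora_dropout=0.05,   # Unsloth's fast path expects 0; a non-zero value works but is slower (see the warning in the output)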
    bias="none",
    target_modules="all-linear",   # or a list like ["q_proj","k_proj","v_proj","o_proj","up_proj","down_proj","gate_proj"]
)


Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = 0.05.
Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.


Unsloth: Making `model.base_model.model.model` require gradients
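A LoRA adapter on a linear layer adds two matrices, `A` (r x in_features) and `B` (out_features x r), so it contributes `r * (in_features + out_features)` trainable parameters. The sketch below applies that formula to the attention projection shapes printed in the model summary, assuming adapters on `q_proj`, `k_proj`, `v_proj`, and `o_proj` in all 24 layers. The total equals the trainable-parameter count in the training banner (7,962,624), which suggests, though it does not prove, that the expert layers received no adapters in this run.

```python
def lora_params(in_features, out_features, r):
    # A is (r x in_features), B is (out_features x r)
    return r * (in_features + out_features)


# (in_features, out_features) as printed in the model summary
ATTN_SHAPES = {
    "q_proj": (2880, 4096),
    "k_proj": (2880, 512),
    "v_proj": (2880, 512),
    "o_proj": (4096, 2880),
}
N_LAYERS = 24
R = 16

per_layer = sum(lora_params(i, o, R) for i, o in ATTN_SHAPES.values())
total = per_layer * N_LAYERS
print(per_layer, total)  # 331776 7962624
```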


In [8]:
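# Fall back to EOS only when no pad token is defined (a plain `or` would also replace a valid id of 0)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id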
device = next(model.parameters()).device
device

device(type='cuda', index=0)

In [9]:
from functools import partial
from transformers import PreTrainedTokenizerBase

def render_chat(messages, tokenizer: PreTrainedTokenizerBase, max_len: int):
    # Render using model's chat template; include assistant text (SFT)
    rendered = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=False,   # training sees assistant outputs
        tokenize=False,
    )
    ids = tokenizer(
        rendered, truncation=True, max_length=max_len, return_tensors="pt"
    )
    return {"input_ids": ids["input_ids"][0], "attention_mask": ids["attention_mask"][0]}

def build_supervised(ex, tokenizer, max_len):
    out = render_chat(ex["messages"], tokenizer, max_len)
    out["labels"] = out["input_ids"].clone()
    return out

proc = partial(build_supervised, tokenizer=tokenizer, max_len=MAX_SEQ_LEN)
train_tok = train_chat.map(proc, remove_columns=train_chat.column_names)
eval_tok  = eval_chat.map(proc,  remove_columns=eval_chat.column_names)


train_dataset = train_tok

len(train_dataset)


900
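`build_supervised` copies `input_ids` into `labels`, so the loss also covers the system and user tokens. A common variant, not what the cell above does, trains on the assistant reply only by setting prompt positions to `-100`, the index that PyTorch cross-entropy ignores by default. A minimal sketch on plain lists follows. `mask_prompt_labels` is a hypothetical helper, and obtaining `prompt_len` (for example by tokenizing the conversation without the assistant turn) is left as an assumption.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch cross-entropy by default


def mask_prompt_labels(input_ids, prompt_len):
    """Copy input_ids to labels and blank out the first prompt_len positions."""
    labels = list(input_ids)
    n = min(prompt_len, len(labels))
    labels[:n] = [IGNORE_INDEX] * n
    return labels


# Toy ids: the first three tokens stand for the prompt, the rest for the answer
print(mask_prompt_labels([11, 12, 13, 21, 22], prompt_len=3))
# [-100, -100, -100, 21, 22]
```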

In [10]:
import torch, os

has_cuda = torch.cuda.is_available()
bf16_ok  = has_cuda and torch.cuda.is_bf16_supported()   # Ampere+ with proper CUDA
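# fp16 is widely supported on NVIDIA GPUs even if bf16 isn't.
# NOTE: the loader log above reports that gpt_oss cannot run in float16 and that Unsloth
# falls back to float32, so fp16=True may be overridden for this model.
fp16_ok  = has_cuda and not bf16_ok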

# If you sometimes run on CPU, also switch the optimizer (bitsandbytes needs CUDA)
optim_name = "paged_adamw_8bit" if has_cuda else "adamw_torch"

from trl import SFTConfig

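# NOTE: the SFTTrainer cell further down builds its own SFTConfig, so these settings
# only take effect if args=training_args is passed there.
training_args = SFTConfig(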
    output_dir=OUT_DIR,
    per_device_train_batch_size=BATCH_PER_DEVICE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LR,
    warmup_steps=WARMUP_STEPS,
    max_steps=MAX_STEPS,           # or use num_train_epochs
    logging_steps=10,
    save_steps=200,
    save_total_limit=2,
    bf16=bf16_ok,                  # <-- only True on Ampere+
    fp16=fp16_ok,                  # <-- True on pre-Ampere NVIDIA GPUs
    gradient_checkpointing=True,
    optim=optim_name,              # bitsandbytes on GPU, torch AdamW on CPU
    lr_scheduler_type="linear",
    report_to="none",
    packing=PACKING,
    seed=SEED,
)
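The precision flags above assume that fp16 is usable whenever bf16 is not, but the loader log says gpt_oss cannot run in float16. The sketch below makes that third case explicit. `pick_precision` and its `fp16_safe_for_model` parameter are hypothetical names, and whether to set the flags by hand or let Unsloth override them is a judgment call.

```python
def pick_precision(has_cuda, bf16_supported, fp16_safe_for_model=True):
    """Choose mixed-precision flags; both False means plain float32."""
    bf16 = has_cuda and bf16_supported
    fp16 = has_cuda and not bf16 and fp16_safe_for_model
    return {"bf16": bf16, "fp16": fp16}


# Tesla T4 from the log: CUDA present, no bf16, model not fp16-safe -> float32
print(pick_precision(has_cuda=True, bf16_supported=False, fp16_safe_for_model=False))
# {'bf16': False, 'fp16': False}
```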




In [11]:
# Do this once, BEFORE instantiating SFTTrainer
import unsloth_zoo.logging_utils as _lu
_lu.PatchRLStatistics = lambda *args, **kwargs: None  # disable Unsloth's RL stats patch temporarily


In [12]:
def to_text(ex):
    return {
        "text": tokenizer.apply_chat_template(
            ex["messages"],
            add_generation_prompt=False,
            tokenize=False,
        )
    }

train_text = train_chat.map(to_text, remove_columns=train_chat.column_names)

from trl import SFTTrainer, SFTConfig
import torch

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_text,   # now has a "text" field
    dataset_text_field = "text",  # <-- tell TRL which field to use
    args = SFTConfig(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        max_steps = 30,
        learning_rate = 2e-4,
        logging_steps = 1,
        optim = "paged_adamw_8bit" if torch.cuda.is_available() else "adamw_torch",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none",
        packing = True,
        bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
        fp16 = torch.cuda.is_available() and not torch.cuda.is_bf16_supported(),
        gradient_checkpointing = True,
        save_steps = 200,
        save_total_limit = 2,
    ),
)


Unsloth: Switching to float32 training since model cannot work with float16


  super().__init__(
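With `lr_scheduler_type = "linear"`, `warmup_steps = 5`, and `max_steps = 30`, the learning rate ramps up linearly for five steps and then decays linearly to zero. The sketch below reproduces that shape in plain Python, following the usual linear-with-warmup rule. `linear_schedule_lr` is a hypothetical helper, not a library function.

```python
def linear_schedule_lr(step, base_lr, warmup_steps, total_steps):
    """Learning rate at a given optimizer step for linear warmup plus linear decay."""
    if step < warmup_steps:
        # Ramp from 0 up to base_lr over the warmup phase
        return base_lr * step / max(1, warmup_steps)
    # Decay from base_lr down to 0 over the remaining steps
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


for s in (0, 2, 5, 10, 30):
    print(s, linear_schedule_lr(s, base_lr=2e-4, warmup_steps=5, total_steps=30))
```

The peak rate of `2e-4` is reached at step 5, so with only 30 steps most of the run happens on the decaying part of the schedule.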


In [None]:
trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': 199998, 'pad_token_id': 200017}.
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 302 | Num Epochs = 1 | Total steps = 30
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 4 x 1) = 4
 "-____-"     Trainable parameters = 7,962,624 of 20,922,719,808 (0.04% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss
1,5.8727
2,5.8685
3,5.5847
4,5.0864
5,4.1728
6,3.2695
7,2.8301
8,2.2462
9,2.1626
10,1.8458
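The logged losses are mean token-level cross-entropy values, so `exp(loss)` gives the corresponding perplexity, which can be easier to interpret. A small sketch that parses the log above and converts each value; `parse_loss_log` is a hypothetical helper.

```python
import math

# Loss log copied from the training output above
LOG = """Step,Training Loss
1,5.8727
2,5.8685
3,5.5847
4,5.0864
5,4.1728
6,3.2695
7,2.8301
8,2.2462
9,2.1626
10,1.8458"""


def parse_loss_log(text):
    """Parse 'step,loss' lines (after the header) into (int, float) pairs."""
    rows = []
    for line in text.strip().splitlines()[1:]:
        step, loss = line.split(",")
        rows.append((int(step), float(loss)))
    return rows


losses = parse_loss_log(LOG)
for step, loss in losses:
    print(step, loss, round(math.exp(loss), 1))  # step, loss, perplexity
```

The steep drop over the first ten steps mostly reflects the model adapting to the chat format and citation pattern; a held-out loss would be needed to say anything about generalization.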


In [None]:
trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': 199998, 'pad_token_id': 200017}.
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 281 | Num Epochs = 1 | Total steps = 30
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 4 x 1) = 4
 "-____-"     Trainable parameters = 7,962,624 of 20,922,719,808 (0.04% trained)

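Both banners report far fewer examples (302 and 281) than the 900 rows in the training split. That is consistent with `packing = True` merging several short conversations into one sequence of up to the maximum length. The sketch below illustrates the idea with a greedy first-fit strategy over token lengths. It is an illustration only and not necessarily the algorithm TRL uses; `pack_first_fit` is a hypothetical helper and the lengths are made up.

```python
def pack_first_fit(lengths, max_len):
    """Group sequence lengths into bins whose totals never exceed max_len."""
    bins = []
    for n in lengths:
        n = min(n, max_len)  # treat over-long sequences as truncated
        for b in bins:
            if sum(b) + n <= max_len:
                b.append(n)  # fits into an existing packed sequence
                break
        else:
            bins.append([n])  # start a new packed sequence
    return bins


packed = pack_first_fit([1000, 900, 800, 300, 2048], max_len=2048)
print(len(packed), packed)  # 3 [[1000, 900], [800, 300], [2048]]
```

Because packing changes the number of sequences per epoch, step counts and epoch estimates based on the raw row count will be off by roughly the packing ratio.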