Medical Fine-tuning with QLoRA using Unsloth
==============================================================

This notebook fine-tunes Qwen2.5-1.5B-Instruct for medical question answering using QLoRA (Quantized Low-Rank Adaptation) with Unsloth. The base model is loaded with 4-bit quantized weights, and only small low-rank adapter matrices are trained.
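Most of QLoRA's memory saving comes from storing the frozen base weights in 4 bits instead of 16. The sketch below gives a rough weight-memory estimate. It uses the base parameter count implied by the training log in STEP 7 (1,545,893,376 total minus 2,179,072 trainable LoRA parameters). It ignores quantization constants, activations, optimizer state, and layers that bitsandbytes leaves in 16-bit (such as the embeddings), so treat the numbers as a lower bound, not a measurement.

```python
def weight_memory_gb(num_params: int, bits_per_param: float) -> float:
    """Approximate storage for model weights in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


# Base parameters = total minus trainable LoRA parameters, both from the training log.
BASE_PARAMS = 1_545_893_376 - 2_179_072

for label, bits in [("fp16", 16), ("4-bit", 4)]:
    print(f"{label:>6}: {weight_memory_gb(BASE_PARAMS, bits):.2f} GB")
```

This prints about 3.09 GB for fp16 and 0.77 GB for 4-bit. The 1.14 GB `model.safetensors` download in STEP 2 is larger than the 4-bit figure, partly because the unquantized layers are still stored in 16-bit.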

Environment: Google Colab (Tesla T4 GPU)

# STEP 1: Install dependencies

In [None]:
# Install Unsloth and dependencies
!pip install unsloth --upgrade
!pip install bitsandbytes transformers accelerate datasets peft trl


Collecting unsloth
  Downloading unsloth-2025.10.9-py3-none-any.whl.metadata (59 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/59.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.0/59.0 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting unsloth_zoo>=2025.10.10 (from unsloth)
  Downloading unsloth_zoo-2025.10.10-py3-none-any.whl.metadata (31 kB)
Collecting tyro (from unsloth)
  Downloading tyro-0.9.35-py3-none-any.whl.metadata (12 kB)
Collecting xformers>=0.0.27.post2 (from unsloth)
  Downloading xformers-0.0.32.post2-cp39-abi3-manylinux_2_28_x86_64.whl.metadata (1.1 kB)
Collecting bitsandbytes!=0.46.0,!=0.48.0,>=0.45.5 (from unsloth)
  Downloading bitsandbytes-0.48.1-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting datasets!=4.0.*,!=4.1.0,>=3.4.1 (from unsloth)
  Downloading datasets-4.3.0-py3-none-any.whl.metadata (18 kB)
Collecting transformers!=4.52.0,!=4.52.1,!=4.52.2,!=
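The install log above is cut off, so it does not record which versions ended up in the environment. A small stdlib helper built on `importlib.metadata` can print them. The package list mirrors the `pip install` lines plus `torch`. A package that is not installed is reported as `None` instead of raising an error. This is a convenience sketch, not part of Unsloth.

```python
from importlib.metadata import PackageNotFoundError, version

PACKAGES = ["unsloth", "unsloth_zoo", "bitsandbytes", "transformers",
            "accelerate", "datasets", "peft", "trl", "torch"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions(PACKAGES).items():
    print(f"{name:>14}: {ver or 'not installed'}")
```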

# STEP 2: Import & load model (Qwen2.5-1.5B-Instruct)

In [None]:
from unsloth import FastLanguageModel
from datasets import load_dataset
import torch

model_name = "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit"

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=2048,
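    dtype=None,         # auto-detect: float16 on the T4, bfloat16 on GPUs that support it
    load_in_4bit=True,  # load 4-bit (bitsandbytes) base weights: the "Q" in QLoRA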
)


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.




🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.10.9: Fast Qwen2 patching. Transformers: 4.56.2.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.4.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.32.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!



# STEP 3: Load your medical reasoning dataset

In [None]:
dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", 'en')
print(dataset["train"][0])

{'Question': 'Given the symptoms of sudden weakness in the left arm and leg, recent long-distance travel, and the presence of swollen and tender right lower leg, what specific cardiac abnormality is most likely to be found upon further evaluation that could explain these findings?', 'Complex_CoT': "Okay, let's see what's going on here. We've got sudden weakness in the person's left arm and leg - and that screams something neuro-related, maybe a stroke?\n\nBut wait, there's more. The right lower leg is swollen and tender, which is like waving a big flag for deep vein thrombosis, especially after a long flight or sitting around a lot.\n\nSo, now I'm thinking, how could a clot in the leg end up causing issues like weakness or stroke symptoms?\n\nOh, right! There's this thing called a paradoxical embolism. It can happen if there's some kind of short circuit in the heart - like a hole that shouldn't be there.\n\nLet's put this together: if a blood clot from the leg somehow travels to the le
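It is worth confirming the column names before writing a formatter. The printed row shows `Question` and `Complex_CoT`. The answer column is expected to be `Response`, which `dataset["train"].column_names` will confirm. The hypothetical helper below (not part of `datasets`) reports which expected keys a row lacks, so a schema mismatch shows up before any mapping happens.

```python
REQUIRED_COLUMNS = {"Question", "Complex_CoT", "Response"}


def missing_columns(example: dict, required=REQUIRED_COLUMNS) -> set:
    """Return the required keys that are absent from one dataset row."""
    return set(required) - set(example)


# Abbreviated stand-in for dataset["train"][0]; the values are placeholders.
sample = {"Question": "...", "Complex_CoT": "...", "Response": "..."}
print(missing_columns(sample))                                           # set()
print(sorted(missing_columns(sample, {"instruction", "input", "output"})))  # ['input', 'instruction', 'output']
```

In the notebook, call it on a real row, for example `missing_columns(dataset["train"][0])`.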

# STEP 4: Format dataset for supervised fine-tuning

In [None]:
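EOS_TOKEN = tokenizer.eos_token  # appended so the model learns where an answer ends

INSTRUCTION = "Answer the following medical question. Reason step by step before giving the final answer."

def format_example(example):
    # Column names in FreedomIntelligence/medical-o1-reasoning-SFT
    # (verify with dataset["train"].column_names)
    question = example["Question"]
    reasoning = example["Complex_CoT"]  # chain-of-thought text
    answer = example["Response"]
    return {
        "text": f"### Instruction:\n{INSTRUCTION}\n\n### Input:\n{question}\n\n### Response:\n{reasoning}\n\n{answer}{EOS_TOKEN}"
    }

train_dataset = dataset["train"].map(format_example)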


Map:   0%|          | 0/19704 [00:00<?, ? examples/s]
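`example.get(key, "")` never raises an error, so a formatter keyed on column names the dataset lacks still produces a `text` column. That column simply contains empty sections. A cheap guard is to scan a few formatted rows for empty template sections before training. The helper below is a hypothetical check written for the `### Instruction` / `### Input` / `### Response` layout used in this notebook.

```python
SECTION_HEADERS = ["### Instruction:", "### Input:", "### Response:"]


def empty_sections(text: str, headers=SECTION_HEADERS) -> list:
    """List template headers that are missing or followed only by whitespace."""
    flagged = []
    for i, header in enumerate(headers):
        start = text.find(header)
        if start == -1:
            flagged.append(header)  # header not present at all
            continue
        start += len(header)
        end = len(text)
        if i + 1 < len(headers):
            nxt = text.find(headers[i + 1], start)
            if nxt != -1:
                end = nxt
        if not text[start:end].strip():
            flagged.append(header)
    return flagged


# What the template produces when every field lookup falls back to "".
blank = "### Instruction:\n\n\n### Input:\n\n\n### Response:\n"
filled = "### Instruction:\n<instruction>\n\n### Input:\n<question>\n\n### Response:\n<answer>\n"
print(empty_sections(blank))   # all three headers
print(empty_sections(filled))  # []
```

In the notebook, `empty_sections(train_dataset[0]["text"])` should return an empty list. Any flagged header points to a column-name mismatch in the formatter.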

# STEP 5: Apply LoRA adapters (QLoRA)

In [None]:
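model = FastLanguageModel.get_peft_model(
    model,
    r=16,                                 # rank of the low-rank update matrices
    target_modules=["q_proj", "v_proj"],  # adapt only the attention query and value projections
    lora_alpha=32,                        # update is scaled by lora_alpha / r = 2
    lora_dropout=0,                       # 0 keeps Unsloth's optimized path
    bias="none",                          # do not train bias terms
)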


Not an error, but Unsloth cannot patch MLP layers with our manual autograd engine since either LoRA adapters
are not enabled or a bias term (like in Qwen) is used.
Not an error, but Unsloth cannot patch O projection layer with our manual autograd engine since either LoRA adapters
are not enabled or a bias term (like in Qwen) is used.
Unsloth 2025.10.9 patched 28 layers with 28 QKV layers, 0 O layers and 0 MLP layers.
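The trainable-parameter count in the STEP 7 log can be reproduced by hand. LoRA attaches two matrices, A (r × d_in) and B (d_out × r), to each targeted linear layer, so each one adds r·(d_in + d_out) parameters. The sketch assumes Qwen2.5-1.5B's published dimensions and the 28 layers reported above. Those dimensions are a hidden size of 1536 and grouped-query attention with 2 key/value heads of size 128, so `v_proj` maps 1536 → 256.

```python
def lora_param_count(r: int, shapes: dict, num_layers: int) -> int:
    """Trainable LoRA parameters: each Linear(d_in, d_out) gets A (r x d_in) and B (d_out x r)."""
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in shapes.values())
    return per_layer * num_layers


# Assumed Qwen2.5-1.5B dimensions.
HIDDEN = 1536
KV_DIM = 2 * 128  # 2 key/value heads x head size 128
shapes = {"q_proj": (HIDDEN, HIDDEN), "v_proj": (HIDDEN, KV_DIM)}

print(lora_param_count(r=16, shapes=shapes, num_layers=28))  # 2179072, as in the training log
print("scaling lora_alpha / r =", 32 / 16)                   # 2.0
```

Adding `o_proj` or the MLP projections to `target_modules` would increase this count. It would also let Unsloth patch the O and MLP layers that the log above reports as skipped.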


# STEP 6: Configure fine-tuning

In [None]:
from transformers import TrainingArguments
from trl import SFTTrainer

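training_args = TrainingArguments(
    output_dir="./medical-qwen-qlora",
    num_train_epochs=2,             # ignored: max_steps > 0 takes precedence
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,  # effective batch size = 1 x 8 = 8
    warmup_steps=5,
    learning_rate=2e-4,
    max_steps=60,                   # short demo run (about 2.4% of one epoch)
    fp16=True,                      # the T4 has no bfloat16 support
    logging_steps=10,
    save_strategy="epoch",
    report_to="none",               # disable W&B/TensorBoard reporting
)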

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    dataset_text_field="text",
    max_seq_length=2048,
    args=training_args,
)
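Several numbers in the STEP 7 log follow directly from these arguments. When `max_steps` is set, the Trainer ignores `num_train_epochs`. At an effective batch size of 8, 60 optimizer steps cover 480 of the 19,704 examples, which is the `epoch` value of about 0.024 in `TrainOutput`. The third helper is a plain-Python sketch of the default `linear` schedule that `TrainingArguments` uses: a linear warmup over `warmup_steps`, then linear decay to zero at the final step. It illustrates the formula and is not the library code.

```python
def effective_batch_size(per_device: int, grad_accum: int, num_gpus: int = 1) -> int:
    """Examples consumed per optimizer step."""
    return per_device * grad_accum * num_gpus


def epochs_covered(max_steps: int, batch_size: int, num_examples: int) -> float:
    """Fraction of the training set seen when max_steps stops training early."""
    return max_steps * batch_size / num_examples


def linear_warmup_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Learning rate at `step`: linear ramp from 0 during warmup, then linear decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


batch = effective_batch_size(per_device=1, grad_accum=8)
print("effective batch size:", batch)                        # 8
print("epochs covered:", epochs_covered(60, batch, 19_704))  # ~0.0244
for step in (0, 5, 30, 60):
    print(step, linear_warmup_lr(step, 2e-4, warmup_steps=5, total_steps=60))
```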


# STEP 7: Start training

In [None]:
trainer.train()


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 19,704 | Num Epochs = 1 | Total steps = 60
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 8 x 1) = 8
 "-____-"     Trainable parameters = 2,179,072 of 1,545,893,376 (0.14% trained)


Step,Training Loss
10,0.0017
20,2.7843
30,0.2986
40,0.0
50,0.0
60,0.0


TrainOutput(global_step=60, training_loss=0.5141167039711338, metrics={'train_runtime': 275.9578, 'train_samples_per_second': 1.739, 'train_steps_per_second': 0.217, 'total_flos': 34020510105600.0, 'train_loss': 0.5141167039711338, 'epoch': 0.024360535931790498})
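For next-token prediction on varied medical text, a logged loss of exactly 0.0 (steps 40 to 60 above) is a warning sign, not a sign of success. It usually means the model is trained on targets it can reproduce trivially, such as nearly identical rows. One plausible cause is the STEP 4 formatter. It reads `instruction`, `input`, and `output` keys that do not appear in the printed dataset sample, so every row may reduce to the bare template. The sketch below flags such steps from the logged values. Its threshold is an arbitrary choice. The mean of the six logged values also agrees with the reported `training_loss` of about 0.514, because each logged value averages the steps since the previous log.

```python
# Step -> logged training loss, copied from the STEP 7 output.
logged_loss = {10: 0.0017, 20: 2.7843, 30: 0.2986, 40: 0.0, 50: 0.0, 60: 0.0}


def near_zero_steps(losses: dict, threshold: float = 1e-3) -> list:
    """Logged steps whose loss is at or below `threshold`."""
    return [step for step, loss in sorted(losses.items()) if loss <= threshold]


mean_loss = sum(logged_loss.values()) / len(logged_loss)
print("near-zero loss at steps:", near_zero_steps(logged_loss))  # [40, 50, 60]
print("mean of logged losses:", round(mean_loss, 4))             # 0.5141
```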

# STEP 8: Save fine-tuned adapter

In [None]:
model.save_pretrained("medical-qwen-qlora-adapter")      # saves only the LoRA adapter weights and config
tokenizer.save_pretrained("medical-qwen-qlora-adapter")  # its return value is the tuple printed below

('medical-qwen-qlora-adapter/tokenizer_config.json',
 'medical-qwen-qlora-adapter/special_tokens_map.json',
 'medical-qwen-qlora-adapter/chat_template.jinja',
 'medical-qwen-qlora-adapter/vocab.json',
 'medical-qwen-qlora-adapter/merges.txt',
 'medical-qwen-qlora-adapter/added_tokens.json',
 'medical-qwen-qlora-adapter/tokenizer.json')
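The tuple above lists only tokenizer files because it is the return value of the last call, `tokenizer.save_pretrained`. A quick filesystem check confirms that the adapter itself was written. The file names below are assumptions based on PEFT's default safetensors serialization. Older PEFT releases wrote `adapter_model.bin` instead.

```python
from pathlib import Path

# Assumed default PEFT adapter file names (safetensors serialization).
ADAPTER_FILES = ("adapter_config.json", "adapter_model.safetensors")


def missing_adapter_files(directory: str) -> list:
    """Return the expected adapter files that are not present in `directory`."""
    folder = Path(directory)
    return [name for name in ADAPTER_FILES if not (folder / name).is_file()]


print(missing_adapter_files("medical-qwen-qlora-adapter"))  # [] if the adapter was written
```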

# STEP 9: Test the fine-tuned model on new medical queries

In [None]:
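# Test 1: Stroke-related query
from transformers import pipeline

model.eval()  # switch to inference mode (disables dropout)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    do_sample=True,          # temperature only takes effect when sampling
    temperature=0.7,
    repetition_penalty=1.1,  # discourage repeated phrases
)

query = "A 60-year-old hypertensive patient presents with sudden weakness on one side. What is the differential diagnosis and immediate management?"
# By default the pipeline returns the prompt followed by the completion
response = pipe(query)[0]['generated_text']
print(response)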


Device set to use cuda:0


A 60-year-old hypertensive patient presents with sudden weakness on one side. What is the differential diagnosis and immediate management?:
A: Acute cerebrovascular disease, administer intravenous heparin
B: Acute ischemic stroke, administer intravenous alteplase
C: Acute hemorrhagic stroke, administer intravenous antihypertensives
D: Acute cerebral edema, administer diuretics
E: None of above

Acute ischemic stroke (AIS) is a common cause of acute neurological symptoms in elderly patients with hypertension.
Answer:

B
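Both test queries are passed to the pipeline as raw text, so the model never sees the `### Instruction` / `### Input` / `### Response` layout it was fine-tuned on. That may be why it drifts into a multiple-choice format. The sketch below prompts in the training layout instead. `INSTRUCTION` is a hypothetical placeholder. Replace it with exactly the instruction text used when formatting the training data, or an empty string if that field was empty. `extract_response` strips the echoed prompt, because the text-generation pipeline returns the prompt plus the completion by default.

```python
# Hypothetical: must match the instruction text used when formatting the training data.
INSTRUCTION = "Answer the following medical question. Reason step by step before giving the final answer."


def build_prompt(question: str, instruction: str = INSTRUCTION) -> str:
    """Wrap a raw question in the fine-tuning layout, leaving Response open for the model."""
    return f"### Instruction:\n{instruction}\n\n### Input:\n{question}\n\n### Response:\n"


def extract_response(generated_text: str, prompt: str) -> str:
    """Remove the echoed prompt from the pipeline's generated_text, if present."""
    if generated_text.startswith(prompt):
        return generated_text[len(prompt):].strip()
    return generated_text.strip()
```

With `pipe` and `query` from Test 1, use `prompt = build_prompt(query)` followed by `print(extract_response(pipe(prompt)[0]["generated_text"], prompt))`. Passing `return_full_text=False` to the pipeline call returns only the completion and makes the stripping step unnecessary.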


In [None]:
# Test 2: Chest pain & diagnostic reasoning
from transformers import pipeline

model.eval()  # switch to inference mode (disables dropout)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    do_sample=True,          # temperature only takes effect when sampling
    temperature=0.7,
    repetition_penalty=1.1,  # discourage repeated phrases
)

new_query = "A 45-year-old man presents with central chest pain radiating to the left arm and jaw, sweating, and nausea. What are the next steps in management and key investigations?"
# By default the pipeline returns the prompt followed by the completion
new_response = pipe(new_query)[0]['generated_text']
print(new_response)

Device set to use cuda:0


A 45-year-old man presents with central chest pain radiating to the left arm and jaw, sweating, and nausea. What are the next steps in management and key investigations? A:
A: ECG
B: Coronary Angiography
C: Cardiac CT scan
D: Blood tests for CK-MB levels
E: None of above

The most appropriate answer is:

A: ECG

Explanation:

### Step-by-Step Reasoning:

1. **Symptom Presentation**:
   - The patient presents with central chest pain that is radiating to the left arm and jaw.
   - This presentation suggests a possible cardiac cause.

2. **Key Diagnostic Criteria**:
   - Central chest pain can be indicative of various conditions including angina, myocardial infarction (heart attack), or other serious cardiovascular issues.
   - Radiating pain from the heart to extremities is particularly concerning and warrants immediate investigation.

3. **Diagnostic Tests**:
   - **Electrocardiogram (ECG)**: 
     - An ECG is essential because it provides information about the electrical activity of th