# Multi-Stage Medical LLM Fine-tuning (Llama 3 8B)

This notebook implements a 4-stage training pipeline to create a high-quality medical chatbot.

## Pipeline Stages
1.  **Stage 1: Instruction Tuning** (Alpaca / General Instructions)
    - Goal: Teach the model to follow instructions and format responses.
2.  **Stage 2: Domain Adaptation** (HealthCareMagic, iCliniq, MedDialog, PDF)
    - Goal: Infuse medical knowledge and reasoning capabilities.
3.  **Stage 3: Medicine Recommendation** (MIMIC-IV, DrugBank)
    - Goal: Learn safe medication recommendations and drug interactions.
4.  **Stage 4: Follow-up Questions** (FollowupQ)
    - Goal: Learn to ask relevant follow-up questions to clarify patient queries.

## Requirements
- Google Colab (T4 GPU or better)
- Google Drive mounted with datasets in `doctor_online_data/`

In [1]:
%%capture
# Install Unsloth and the training stack. The %%capture magic above hides the pip output.
# NOTE(review): none of the packages are version-pinned, so a re-run may resolve different versions.
import torch
# CUDA compute capability of the current GPU, unpacked as (major, minor); only `major_version` is used below.
# NOTE(review): presumably this raises when no CUDA device is available — confirm before running on CPU-only hosts.
major_version, minor_version = torch.cuda.get_device_capability()
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
# Compute capability >= 8 additionally installs flash-attn and its build helpers; older GPUs skip them.
if major_version >= 8:
    !pip install --no-deps packaging ninja einops flash-attn xformers trl peft accelerate bitsandbytes
else:
    !pip install --no-deps xformers trl peft accelerate bitsandbytes
# Libraries used by the later cells for PDF parsing and dataset handling.
!pip install PyPDF2 pandas datasets

In [2]:
from unsloth import FastLanguageModel
import torch
from datasets import Dataset, concatenate_datasets, load_dataset
from trl import SFTTrainer
from transformers import TrainingArguments
import pandas as pd
import json
import os
import PyPDF2

# Resolve the data folder: Google Drive when running on Colab, the working directory otherwise.
try:
    from google.colab import drive
    drive.mount('/content/drive')
    BASE_PATH = "/content/drive/MyDrive/doctor_online_data/"
except Exception as mount_error:
    # Covers both "not on Colab" (import fails) and a failed mount.
    print(f"Drive mount failed or not on Colab: {mount_error}")
    print("Using local paths instead.")
    BASE_PATH = "./doctor_online_data/"

# Checkpoints always live in a "checkpoints/" folder directly under the data folder.
OUTPUT_DIR = BASE_PATH + "checkpoints/"

# First run on a fresh machine: create the data folder and tell the user where to put the files.
if not os.path.exists(BASE_PATH):
    os.makedirs(BASE_PATH, exist_ok=True)
    print(f"Created local data directory at {BASE_PATH}. Please put your datasets here.")

os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"Base Path: {BASE_PATH}")
print(f"Output Dir: {OUTPUT_DIR}")

# Model Config - LOW MEMORY SETTINGS
max_seq_length = 1024  # Reduced from 2048 to save memory
dtype = None           # passed through unchanged to Unsloth
load_in_4bit = True    # the checkpoint below is a pre-quantized 4-bit build

# Arguments for loading the base model and its tokenizer.
base_model_kwargs = dict(
    model_name = "unsloth/llama-3-8b-bnb-4bit",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)
model, tokenizer = FastLanguageModel.from_pretrained(**base_model_kwargs)

# LoRA adapter settings: rank-16 adapters on the attention and MLP projection layers.
lora_kwargs = dict(
    r = 16,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
)
# Wrap the base model with the adapters; every training stage updates this same `model` object.
model = FastLanguageModel.get_peft_model(model, **lora_kwargs)

# Alpaca-style prompt template. It has three positional `{}` slots that are
# filled, in order, with the instruction, the input (context) and the response.
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""
# End-of-sequence token string of the loaded tokenizer; formatting_prompts_func
# appends it to every rendered training example.
EOS_TOKEN = tokenizer.eos_token

def formatting_prompts_func(examples):
    """Render a batch of rows into Alpaca-style training text.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        examples: Mapping with the keys ``"instruction"``, ``"input"`` and
            ``"output"``, each holding a list of equal length (one entry per row).

    Returns:
        ``{"text": [...]}`` with one string per row: ``alpaca_prompt`` filled
        with the row's fields, followed by ``EOS_TOKEN``.
    """
    def _as_text(value):
        # A missing field (None) would otherwise be formatted as the literal
        # string "None" and end up in the training text.
        return "" if value is None else value

    rows = zip(examples["instruction"], examples["input"], examples["output"])
    texts = [
        alpaca_prompt.format(_as_text(instruction), _as_text(context), _as_text(response)) + EOS_TOKEN
        for instruction, context, response in rows
    ]
    return { "text" : texts, }

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
Drive mount failed or not on Colab: mount failed
Using local paths instead.
Base Path: ./doctor_online_data/
Output Dir: ./doctor_online_data/checkpoints/
==((====))==  Unsloth 2025.11.4: Fast Llama patching. Transformers: 4.57.1.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.5.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.33.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Unsloth 2025.11.4 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


## Training Helper Function

In [3]:
def train_stage(stage_name, dataset, epochs=1, max_steps=-1):
    """Fine-tune the shared LoRA model on one pipeline stage and save its adapter.

    Uses the notebook-level ``model``, ``tokenizer``, ``max_seq_length`` and
    ``OUTPUT_DIR``. Each call continues training the same ``model`` object, so
    stages build on one another.

    Args:
        stage_name: Label used in log lines and as the sub-folder name for outputs.
        dataset: Training dataset with a ``"text"`` column.
        epochs: Value passed as ``num_train_epochs``.
        max_steps: Value passed as ``max_steps`` (default ``-1``).

    Side effects:
        Trainer outputs go to ``<OUTPUT_DIR>/<stage_name>`` and the adapter
        weights to ``<OUTPUT_DIR>/<stage_name>_adapter``.
    """
    print(f"\n=== Starting {stage_name} ===")
    print(f"Dataset size: {len(dataset)}")

    # Clear cache before training
    torch.cuda.empty_cache()

    # OUTPUT_DIR already ends with "/", so join paths properly instead of
    # concatenating with another "/" (which produced ".../checkpoints//Stage").
    stage_output_dir = os.path.join(OUTPUT_DIR, stage_name)
    adapter_dir = os.path.join(OUTPUT_DIR, f"{stage_name}_adapter")

    # Probe the GPU once; exactly one of fp16/bf16 is enabled.
    use_bf16 = torch.cuda.is_bf16_supported()

    # NOTE(review): `report_to` is left at its default, so an installed and
    # logged-in tracker (e.g. wandb) may attach to the run — confirm that is intended.
    training_args = TrainingArguments(
        per_device_train_batch_size = 1, # Reduced to 1
        gradient_accumulation_steps = 8, # Increased to 8 to maintain effective batch size
        warmup_steps = 5,
        num_train_epochs = epochs,
        max_steps = max_steps,
        learning_rate = 2e-4,
        fp16 = not use_bf16,
        bf16 = use_bf16,
        logging_steps = 10,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = stage_output_dir,
    )

    trainer = SFTTrainer(
        model = model,
        tokenizer = tokenizer,
        train_dataset = dataset,
        dataset_text_field = "text",
        max_seq_length = max_seq_length,
        dataset_num_proc = 2,
        packing = False,
        args = training_args,
    )
    trainer.train()
    # Save adapter for this stage
    model.save_pretrained(adapter_dir)
    print(f"=== Completed {stage_name} ===\n")

## Stage 1: Instruction Tuning (Alpaca)
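We use your local `cleaned_dataset_with_english_translation.csv` if it is present in the data folder; otherwise we fall back to the standard `yahma/alpaca-cleaned` dataset from Hugging Face. Either way, the goal is to establish basic instruction following.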

In [None]:
# Stage 1 data: prefer the local cleaned CSV, else fall back to HF Alpaca.
local_path = f"{BASE_PATH}cleaned_dataset_with_english_translation.csv"

if os.path.exists(local_path):
    print("Loading local Stage 1 dataset...")
    df = pd.read_csv(local_path)
    # Normalize columns: fill in whichever of instruction/input/output the CSV lacks.
    if 'instruction' not in df.columns: df['instruction'] = "If you are a doctor, please answer the medical questions based on the patient's description."
    if 'input' not in df.columns: df['input'] = df.get('description', '')
    if 'output' not in df.columns: df['output'] = df.get('doctor_response', '')

    # Empty CSV cells load as NaN; turn them into empty strings so they do not
    # reach the prompt as "nan"/null, and make every column plain text.
    stage1_frame = df[['instruction', 'input', 'output']].astype(object)
    stage1_frame = stage1_frame.where(stage1_frame.notna(), "").astype(str)

    # A row without an answer carries no training signal.
    has_answer = stage1_frame['output'].str.strip() != ""
    if not has_answer.all():
        print(f"Dropping {int((~has_answer).sum())} rows without an output.")

    # preserve_index=False keeps the filtered frame's index out of the dataset columns.
    ds_stage1 = Dataset.from_pandas(stage1_frame[has_answer], preserve_index=False)
else:
    print("Local dataset not found. Loading yahma/alpaca-cleaned from Hugging Face...")
    ds_stage1 = load_dataset("yahma/alpaca-cleaned", split = "train")

if len(ds_stage1) > 0:
    ds_stage1 = ds_stage1.map(formatting_prompts_func, batched = True)
    train_stage("Stage1_Instruction", ds_stage1, max_steps=100) # Short run for demo, increase steps for real training
else:
    print("Skipping Stage 1: dataset is empty.")

Local dataset not found. Loading yahma/alpaca-cleaned from Hugging Face...


Map:   0%|          | 0/51760 [00:00<?, ? examples/s]


=== Starting Stage1_Instruction ===
Dataset size: 51760


Unsloth: Tokenizing ["text"] (num_proc=6):   0%|          | 0/51760 [00:00<?, ? examples/s]

The model is already on multiple devices. Skipping the move to device specified in `args`.
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 51,760 | Num Epochs = 1 | Total steps = 100
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 8 x 1) = 8
 "-____-"     Trainable parameters = 41,943,040 of 8,072,204,288 (0.52% trained)
  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: Currently logged in as: [33manaslari610[0m ([33manaslari610-bengal-institute-of-technology[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Detected [huggingface_hub.inference, openai] in use.
[34m[1mwandb[0m: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
[34m[1mwandb[0m: For more information, check out the docs at: https://weave-docs.wandb.ai/


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss
10,1.4957
20,1.0029
30,0.9724
40,0.9316
50,0.9598
60,0.9398
70,0.9612


## Stage 2: Domain Adaptation (Medical)
Training on `HealthCareMagic`, `iCliniq`, and your PDF to learn medical reasoning.

In [None]:
# Every Stage 2 source is reduced to these three text columns so that the
# per-source datasets share one schema and can be concatenated.
STAGE2_COLUMNS = ("instruction", "input", "output")
datasets_stage2 = []

# HealthCareMagic
hcm_path = f"{BASE_PATH}HealthCareMagic-100k.json"
if os.path.exists(hcm_path):
    # JSON files are read as UTF-8 explicitly instead of relying on the platform default.
    with open(hcm_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    # NOTE(review): assumes each record is a dict with instruction/input/output keys — confirm against the file.
    # Records without an answer are skipped; extra keys are dropped.
    hcm_rows = [
        {col: str(item.get(col) or "") for col in STAGE2_COLUMNS}
        for item in data
        if item.get("output")
    ]
    if hcm_rows:
        datasets_stage2.append(Dataset.from_list(hcm_rows))
    print(f"Loaded HealthCareMagic: {len(hcm_rows)} rows")

# iCliniq
icliniq_path = f"{BASE_PATH}iCliniq.json"
if os.path.exists(icliniq_path):
    with open(icliniq_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    formatted = []
    for item in data:
        # Prefer the ChatDoctor answer, fall back to the original iCliniq answer.
        ans = item.get('answer_chatdoctor') or item.get('answer_icliniq')
        if ans:
            formatted.append({
                "instruction": "If you are a doctor, please answer the medical questions based on the patient's description.",
                "input": str(item.get('input') or ''),
                "output": str(ans)
            })
    # An empty list would create a column-less dataset that cannot be concatenated.
    if formatted:
        datasets_stage2.append(Dataset.from_list(formatted))
    print(f"Loaded iCliniq: {len(formatted)} rows")

# PDF (2503.17509v1.pdf)
pdf_path = f"{BASE_PATH}2503.17509v1.pdf"
if os.path.exists(pdf_path):
    try:
        with open(pdf_path, 'rb') as f:
            reader = PyPDF2.PdfReader(f)
            # Guard against pages with no extractable text so one such page
            # does not discard the whole document.
            page_texts = [(page.extract_text() or "") for page in reader.pages]
        pdf_text = "".join(text + "\n" for text in page_texts)

        # Chunk text into fixed-size character windows.
        chunk_size = 1000
        chunks = [pdf_text[i:i+chunk_size] for i in range(0, len(pdf_text), chunk_size)]

        # NOTE(review): every chunk is paired with the same placeholder output,
        # so these rows train the model to emit that sentence — confirm this is intended.
        pdf_data = [
            {
                "instruction": "Analyze this medical text and summarize key findings.",
                "input": chunk,
                "output": "The text discusses medical concepts found in the document. (Self-supervised context)"
            }
            for chunk in chunks
            if chunk.strip()  # skip whitespace-only chunks
        ]
        if pdf_data:
            datasets_stage2.append(Dataset.from_list(pdf_data))
        print(f"Loaded PDF: {len(pdf_data)} chunks")
    except Exception as e:
        print(f"Error loading PDF: {e}")

if datasets_stage2:
    ds_stage2 = concatenate_datasets(datasets_stage2)
    ds_stage2 = ds_stage2.map(formatting_prompts_func, batched = True)
    train_stage("Stage2_Domain", ds_stage2, max_steps=200)
else:
    print("Skipping Stage 2: No datasets found.")

## Stage 3: Medicine Recommendation (Safety)
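Training on MIMIC-IV / DrugBank for safe prescribing.
**Note**: Requires `mimic_iv.csv` (with `patient_profile` and `medication_plan` columns) in your Drive folder. A `drugbank.json` path is defined in the cell below, but the file is not loaded yet.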

In [None]:
ds_stage3 = None
mimic_path = f"{BASE_PATH}mimic_iv.csv"
# NOTE(review): this path is defined but no DrugBank loading exists in this cell yet.
drugbank_path = f"{BASE_PATH}drugbank.json"

if os.path.exists(mimic_path):
    print("Loading MIMIC-IV...")
    df = pd.read_csv(mimic_path)
    # Expects columns: patient_profile, medication_plan
    required_columns = ['patient_profile', 'medication_plan']
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        # Report the real reason instead of silently falling through to "not found".
        print(f"MIMIC-IV file is missing required columns: {missing_columns}")
    else:
        # Rows lacking either field would otherwise be trained on as "nan" text.
        usable = df.dropna(subset=required_columns)
        stage3_frame = pd.DataFrame({
            "instruction": "Based on the patient profile, recommend a safe medication plan.",
            "input": usable['patient_profile'].astype(str),
            "output": usable['medication_plan'].astype(str),
        })
        # preserve_index=False keeps the filtered frame's index out of the dataset columns.
        ds_stage3 = Dataset.from_pandas(stage3_frame, preserve_index=False)

if ds_stage3 is not None and len(ds_stage3) > 0:
    ds_stage3 = ds_stage3.map(formatting_prompts_func, batched = True)
    train_stage("Stage3_Meds", ds_stage3, max_steps=100)
else:
    print("Skipping Stage 3: MIMIC/DrugBank data not found.")

## Stage 4: Follow-up Questions
Training the model to ask clarifying questions.

In [None]:
ds_stage4 = None
followup_path = f"{BASE_PATH}followup_q.json"

if os.path.exists(followup_path):
    print("Loading FollowupQ...")
    # Read as UTF-8 explicitly instead of relying on the platform default encoding.
    with open(followup_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    # Expects: context, answer, followup
    formatted = []
    for item in data:
        followup = item.get('followup')
        if not followup:
            # No target question: the row would teach the model to answer "None".
            continue
        # Missing context/answer become empty text rather than the literal "None".
        context = item.get('context') or ''
        answer = item.get('answer') or ''
        formatted.append({
            "instruction": "Given the patient context and doctor answer, generate a relevant follow-up question.",
            "input": f"Context: {context}\nAnswer: {answer}",
            "output": str(followup)
        })
    if formatted:
        ds_stage4 = Dataset.from_list(formatted)

if ds_stage4 is not None and len(ds_stage4) > 0:
    ds_stage4 = ds_stage4.map(formatting_prompts_func, batched = True)
    train_stage("Stage4_Followup", ds_stage4, max_steps=50)
else:
    print("Skipping Stage 4: FollowupQ data not found.")

## Export Final Model
We free up memory before exporting to avoid crashes.

In [None]:
import gc

# Free up memory from training: drop the notebook-level references to the
# stage datasets (and a `trainer`, if one was ever created at notebook level).
# pop(..., None) makes the cell safe to re-run and safe when a stage was skipped.
for _stale_name in ("trainer", "ds_stage1", "ds_stage2", "ds_stage3", "ds_stage4"):
    globals().pop(_stale_name, None)

# Collect unreachable Python objects first so that any GPU tensors they held
# are released, THEN hand the now-unused cached blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Save final merged model to GGUF
import shutil

# Quantization method -> working name passed to save_pretrained_gguf.
GGUF_EXPORTS = {
    "q4_k_m": "model_q4",    # balanced; recommended for most users
    # "q8_0": "model_q8",    # high precision; uncomment to export this variant too
}

if 'model' not in globals():
    print("Error: 'model' variable not found. Please run the training cells first.")
else:
    print("Saving final model to GGUF format...")
    for quant_method, work_name in GGUF_EXPORTS.items():
        quant_label = quant_method.upper()
        try:
            model.save_pretrained_gguf(work_name, tokenizer, quantization_method = quant_method)
            # NOTE(review): the produced file name is assumed to follow the
            # "<name>-unsloth.<QUANT>.gguf" pattern — verify for the installed Unsloth version.
            gguf_source = f"{work_name}-unsloth.{quant_label}.gguf"
            gguf_target = os.path.join(OUTPUT_DIR, f"medical_llama3_{quant_method}.gguf")
            # shutil.copy raises on failure (unlike a shell `cp`), so "Saved ..."
            # is only printed when the file really reached the output folder.
            shutil.copy(gguf_source, gguf_target)
            print(f"Saved {quant_label} to {gguf_target}")
        except Exception as e:
            print(f"Error saving {quant_label}: {e}")