# 🏥 Medical Fine-tuning with QLoRA using Unsloth (Colab-ready)

**Educational demo — not for clinical use.**  
Use only de-identified data, and comply with dataset licenses and any required institutional approvals (e.g., IRB review).

In [None]:
# GPU check: report the Python version and what PyTorch can see of the GPU.
import sys, torch, os
print("Python:", sys.version)
print("Torch CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    try:
        print("GPU:", torch.cuda.get_device_name(0))
        # bf16 mixed precision needs hardware support; this shows whether bf16=True can work.
        print("bf16 supported:", torch.cuda.is_bf16_supported())
    except Exception as err:
        # Best-effort diagnostics: report the problem instead of silently swallowing it
        # (the previous bare `except:` also hid KeyboardInterrupt/SystemExit).
        print("Could not query GPU details:", err)
!nvidia-smi || echo "nvidia-smi not available"

## 1) Install required libraries
This installs Unsloth and the usual fine-tuning tooling (Transformers, Datasets, Accelerate, TRL, bitsandbytes, PEFT). Installation may take a minute or two.

In [None]:
%%bash
# Show the interpreter version, upgrade pip, then install the fine-tuning stack.
python -V
pip install -q --upgrade pip
# NOTE(review): only lower bounds are given, so the newest trl/transformers get installed.
# Argument names used by later cells have changed across releases of these libraries;
# pin exact versions here if you need reproducible runs.
pip install -q unsloth "transformers>=4.43" "datasets>=2.18" "accelerate>=0.30" "trl>=0.9.6" bitsandbytes peft
# Sanity check: both packages import, and Unsloth's version is printed when it exposes one.
python -c "import unsloth, transformers; print('unsloth', getattr(unsloth, '__version__', 'unknown'))"

## 2) (Optional) Hugging Face login
If your base model is gated (or you want to push the adapter), provide an HF token.

In [None]:
import os
from getpass import getpass

# Prefer a token from the HF_TOKEN environment variable; otherwise prompt for one
# (an empty answer skips the login entirely).
HF_TOKEN = os.environ.get("HF_TOKEN") or getpass("Hugging Face token (press Enter to skip): ")
if not HF_TOKEN:
    print("Skipping HF login")
else:
    from huggingface_hub import login
    login(token=HF_TOKEN)
    print("Logged into Hugging Face")

## 3) Config: choose base model, dataset and hyperparameters
Adjust these values to fit your GPU's memory; sequence length and batch size have the largest effect on VRAM use.

In [None]:
# Configuration - adjust as needed.

# --- Model and data -------------------------------------------------------
BASE_MODEL: str = "meta-llama/Llama-3.1-8B-Instruct"  # example; change if desired
DATASET_NAME: str = "bigbio/med_qa"                   # example domain dataset
SPLIT_TRAIN: str = "train"
SPLIT_EVAL: str = "validation"                        # may be "validation" or "test"

# --- Sequence length and batching -----------------------------------------
# Effective per-device batch size is MICRO_BATCH * GRAD_ACCUM.
MAX_SEQ_LEN: int = 2048
MICRO_BATCH: int = 1
GRAD_ACCUM: int = 8

# --- Optimisation ---------------------------------------------------------
EPOCHS: int = 1
LR: float = 2e-4

# --- LoRA adapter ---------------------------------------------------------
# NOTE(review): alpha < r halves the adapter scaling (alpha / r); many recipes use
# alpha = r or 2 * r — confirm this choice is intentional.
LORA_R: int = 32
LORA_ALPHA: int = 16
LORA_DROPOUT: float = 0.05

PROJECT_NAME: str = "med-qlora-unsloth"
print("Config loaded.")

## 4) Load & preprocess dataset
We convert dataset rows into a `messages`-style chat object: `system -> user -> assistant`.
Start with a small subset (e.g., `.select(range(200))`) for debugging if needed.

In [None]:
from datasets import load_dataset, DatasetDict
import random, textwrap

# Fetch every split of the configured hub dataset and display its structure.
raw = load_dataset(path=DATASET_NAME)
print(raw)

def standardize(example):
    """Convert one raw QA row into a ``system -> user -> assistant`` chat record.

    The question comes from the first non-empty of ``question``, ``questions`` or
    ``query``; the answer from the first non-empty of ``answer``, ``final_answer``
    or ``long_answer``. A list answer is reduced to its first element, and a
    ``None`` answer becomes ``""``. Missing fields yield ``""``. These are basic
    heuristics for common medical QA datasets - adjust them for your schema.

    Returns:
        dict with ``messages`` (three chat turns), ``prompt`` and ``answer``
        (both ``str``).
    """
    def first_present(keys):
        # Same semantics as ``a or b or c or ""``: first truthy value wins.
        for key in keys:
            candidate = example.get(key)
            if candidate:
                return candidate
        return ""

    question = first_present(("question", "questions", "query"))
    answer = first_present(("answer", "final_answer", "long_answer"))
    if isinstance(answer, list):
        # Only truthy (non-empty) lists get here; keep the first element.
        answer = next(iter(answer), "")
    if answer is None:
        answer = ""

    question_text, answer_text = str(question), str(answer)
    system_turn = {"role": "system", "content": "You are a helpful, careful medical assistant. Do not provide professional diagnosis; provide general information and safety guidance."}
    return {
        "messages": [
            system_turn,
            {"role": "user", "content": question_text},
            {"role": "assistant", "content": answer_text},
        ],
        "prompt": question_text,
        "answer": answer_text,
    }

# Map every split into the chat format. For quick debugging, subsample first, e.g.
#   ds = ds.select(range(min(200, len(ds))))
proc = {}
for split, ds in raw.items():
    try:
        proc[split] = ds.map(standardize, remove_columns=ds.column_names)
    except Exception as err:
        # Unknown schema: keep the split usable by stringifying each raw row as a single
        # user turn, but report it - these rows carry no assistant answer to learn from,
        # and silently swallowing the error hid real schema problems.
        print(
            f"WARNING: standardize() failed on split {split!r} "
            f"({type(err).__name__}: {err}); falling back to raw-row stringification."
        )
        proc[split] = ds.map(
            lambda ex: {"messages": [{"role": "user", "content": str(ex)}]},
            remove_columns=ds.column_names,
        )

dataset = DatasetDict(proc)
print(dataset)

## 5) Load 4-bit base model & attach QLoRA adapters using Unsloth
This uses Unsloth helpers to load the model in 4-bit (NF4 / bnb) and attach PEFT adapters.

In [None]:
from unsloth import FastLanguageModel
from transformers import AutoTokenizer
import torch

# FastLanguageModel.from_pretrained returns its own tokenizer alongside the model, and
# `tokenizer` is rebound to it here. All tokenizer configuration therefore happens on the
# returned object; settings applied to a separately loaded AutoTokenizer would be lost.
print("Loading model and tokenizer in 4-bit (this may take a while)...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = BASE_MODEL,
    max_seq_length = MAX_SEQ_LEN,
    load_in_4bit = True,
    use_gradient_checkpointing = "unsloth",
    use_cache = False,
)

# Tokenization pads to a fixed length, so a pad token is required. Fall back to EOS only
# when the tokenizer does not already define a dedicated pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Prepare model for training
FastLanguageModel.for_training(model)

# Attach LoRA adapters. Unsloth's get_peft_model validates target_modules as a list of
# module names, so PEFT's "all-linear" shorthand is replaced by the attention and MLP
# projections of Llama-style decoder layers.
model = FastLanguageModel.get_peft_model(
    model,
    r = LORA_R,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = LORA_ALPHA,
    lora_dropout = LORA_DROPOUT,
    bias = "none",
    use_gradient_checkpointing = "unsloth",
)
model.print_trainable_parameters()
print('Model ready.')

## 6) Tokenize / apply chat template
Format each conversation with the tokenizer's chat template (falling back to plain `role: content` lines if no template is available), then tokenize it for supervised fine-tuning.

In [None]:
def apply_template(example):
    """Render an example's chat ``messages`` into one training string.

    The tokenizer's chat template is tried first (text output, no trailing
    generation prompt). If that raises for any reason, each message becomes a
    ``role: content`` line instead.

    Returns:
        dict with a single ``"text"`` key.
    """
    try:
        rendered = tokenizer.apply_chat_template(
            example["messages"], tokenize=False, add_generation_prompt=False
        )
    except Exception:
        # Fallback: plain "role: content" lines joined by newlines.
        rendered = "\n".join(
            f"{msg.get('role', '')}: {msg.get('content', '')}"
            for msg in example.get("messages", [])
        )
    return {"text": rendered}

def tok(ex):
    """Tokenize one rendered example to exactly MAX_SEQ_LEN tokens and add labels.

    Labels mirror ``input_ids`` except at padding positions (``attention_mask == 0``),
    which are set to -100 so the causal-LM loss ignores them. Copying ``input_ids``
    verbatim would train the model to predict the pad token (possibly EOS) at every
    padded position. Real EOS tokens inside the text keep attention 1 and stay labelled.
    """
    out = tokenizer(ex["text"], truncation=True, padding="max_length", max_length=MAX_SEQ_LEN)
    # If the tokenizer returns no attention mask, fall back to labelling every position.
    mask = out.get("attention_mask") or [1] * len(out["input_ids"])
    out["labels"] = [tid if keep else -100 for tid, keep in zip(out["input_ids"], mask)]
    return out


tokenized_ds = {}
for split, ds in dataset.items():
    # small debug subset option: ds = ds.select(range(200))
    templated = ds.map(apply_template, remove_columns=ds.column_names)
    tokenized_ds[split] = templated.map(tok, remove_columns=templated.column_names)

from datasets import DatasetDict
tokenized = DatasetDict(tokenized_ds)
print("Tokenized dataset keys:", list(tokenized.keys()))

## 7) Train with TRL's `SFTTrainer`
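Start with a small subset and a single epoch while debugging. Packing can improve efficiency, but only for examples that are not already padded to `MAX_SEQ_LEN`; note that the tokenization step above pads every example to that length.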

In [None]:
from trl import SFTTrainer, SFTConfig
import inspect
import torch

# Fall back to the first available split when SPLIT_TRAIN is missing; evaluation is
# enabled only when SPLIT_EVAL exists.
train_data = tokenized.get(SPLIT_TRAIN, tokenized[list(tokenized.keys())[0]])
eval_data = tokenized.get(SPLIT_EVAL, None)

# Mixed precision: bf16 only where the GPU supports it (requesting bf16 on an
# unsupported GPU makes the training arguments raise); otherwise fp16 on CUDA.
has_cuda = torch.cuda.is_available()
use_bf16 = has_cuda and torch.cuda.is_bf16_supported()
use_fp16 = has_cuda and not use_bf16

# TRL has renamed some arguments across releases; use whichever name the installed
# version accepts rather than hard-coding one.
config_params = inspect.signature(SFTConfig.__init__).parameters
seq_len_kw = "max_length" if "max_length" in config_params else "max_seq_length"
trainer_params = inspect.signature(SFTTrainer.__init__).parameters
tokenizer_kw = "processing_class" if "processing_class" in trainer_params else "tokenizer"

args = SFTConfig(
    output_dir = f"outputs/{PROJECT_NAME}",
    num_train_epochs = EPOCHS,
    per_device_train_batch_size = MICRO_BATCH,
    gradient_accumulation_steps = GRAD_ACCUM,
    learning_rate = LR,
    logging_steps = 10,
    save_strategy = "epoch",
    # `evaluation_strategy` was renamed `eval_strategy` and later removed.
    eval_strategy = "no" if eval_data is None else "epoch",
    bf16 = use_bf16,
    fp16 = use_fp16,
    optim = "paged_adamw_8bit",
    lr_scheduler_type = "cosine",
    warmup_ratio = 0.05,
    report_to = "none",
    dataset_text_field = "text",
    # Examples are already tokenized and padded to MAX_SEQ_LEN in the previous step;
    # packing them would concatenate those padding tokens into training sequences.
    packing = False,
    **{seq_len_kw: MAX_SEQ_LEN},
)

trainer = SFTTrainer(
    model = model,
    train_dataset = train_data,
    eval_dataset = eval_data,
    args = args,
    **{tokenizer_kw: tokenizer},
)

print("Beginning training...")
trainer.train()
print("Training finished.")

## 8) Monitor VRAM & performance
Use this cell to inspect GPU memory during/after training.

In [None]:
import torch, gc

# Collect unreachable Python objects first so releasing the CUDA cache can free more.
gc.collect()
if not torch.cuda.is_available():
    print("No CUDA available.")
else:
    torch.cuda.empty_cache()
    allocated_mb = torch.cuda.memory_allocated() / 1e6
    print("Allocated (MB):", allocated_mb)
!nvidia-smi || echo "nvidia-smi not available"

## 9) Save LoRA adapter (small artifact)
Saves adapter weights and tokenizer.

In [None]:
import os

# Write the trained adapter and its tokenizer into one directory (created if needed).
save_dir = f"outputs/{PROJECT_NAME}-adapter"
os.makedirs(save_dir, exist_ok=True)
for artifact in (model, tokenizer):
    artifact.save_pretrained(save_dir)
print("Adapter saved to:", save_dir)

## 10) Quick test / inference
Generates a sample response with the in-memory fine-tuned model (base model plus adapter) as a quick sanity check.

In [None]:
from unsloth import FastLanguageModel

prompt = "A 65-year-old with chest pain and shortness of breath: what immediate red flags suggest ER referral?"
messages = [
    {"role":"system","content":"You are a careful, non-diagnostic medical assistant. Provide safety guidance and recommend clinician review."},
    {"role":"user","content": prompt}
]
# Build the prompt text with the chat template (ending in the assistant-turn header)
# when available; otherwise fall back to plain "role: content" lines.
try:
    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    templated = True
except Exception:
    input_text = "\n".join(f"{m['role']}: {m['content']}" for m in messages)
    templated = False

# Chat templates typically render the BOS token themselves (Llama 3's does), so don't let
# the tokenizer add special tokens again in that case - it would duplicate BOS.
inputs = tokenizer([input_text], return_tensors="pt", add_special_tokens=not templated).to(model.device)

# Switch the Unsloth-patched model into generation mode (counterpart of for_training).
FastLanguageModel.for_inference(model)
model.eval()
with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        pad_token_id=tokenizer.pad_token_id,
    )
# generate() returns prompt + continuation; decode only the newly generated tokens.
prompt_len = inputs["input_ids"].shape[-1]
print(tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True))

## 11) (Optional) Merge adapter and export for Ollama (GGUF)
This step can be memory-heavy: re-load base model in 16-bit, merge LoRA, then export to GGUF.
Enable only if you have enough RAM/GPU and want a GGUF file for Ollama.

In [None]:
MERGE_AND_EXPORT = False  # set True to run export steps
if MERGE_AND_EXPORT:
    import glob
    import os

    GGUF_DIR = "gguf/med_qlora"
    OLLAMA_NAME = "med-qlora-llama31-8b"
    os.makedirs("gguf", exist_ok=True)

    # Unsloth's save_pretrained_gguf merges the LoRA adapter into the base weights and
    # converts the result to GGUF in a single call. It can be memory- and disk-heavy.
    print("Merging adapter and exporting GGUF (Q8_0)...")
    model.save_pretrained_gguf(GGUF_DIR, tokenizer, quantization_method="q8_0")
    print("GGUF export complete.")

    # Output file names/locations differ between Unsloth releases, so search for them.
    found = sorted(glob.glob(f"{GGUF_DIR}*/**/*.gguf", recursive=True))
    q8_files = [p for p in found if "q8_0" in os.path.basename(p).lower()] or found
    if not q8_files:
        print("No .gguf file found under", GGUF_DIR)
    else:
        gguf_path = os.path.abspath(q8_files[0])
        modelfile = os.path.join(os.path.dirname(gguf_path), "Modelfile")
        # Keep a Modelfile the exporter may already have written (it can carry a chat
        # template); otherwise write a minimal one that just points at the GGUF file.
        if not os.path.exists(modelfile):
            with open(modelfile, "w", encoding="utf-8") as fh:
                fh.write(f"FROM {gguf_path}\n")
        print("Ollama modelfile:", modelfile)
        print(f"Create the Ollama model with: ollama create {OLLAMA_NAME} -f {modelfile}")

---

### Tips
- Start with a **small subset** to validate tokenization and generation before full training.
- If you hit out-of-memory (OOM) errors, reduce `MAX_SEQ_LEN` or `MICRO_BATCH`, and raise `GRAD_ACCUM` to keep the effective batch size.
- Always evaluate with clinicians before any real-world use.