# MedMistral — QLoRA Fine-Tuning on ruslanmv/ai-medical-chatbot
A **Colab/Kaggle-ready** notebook that fine-tunes **Mistral-7B-Instruct** with **QLoRA**.

> ⚠️ Educational triage guidance only; not medical advice.

In [None]:
# !pip -q install -U transformers peft trl bitsandbytes datasets accelerate sentencepiece

In [None]:

import torch, os
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer
from huggingface_hub import login

# --- Run configuration ---------------------------------------------------
# Hub id of the base checkpoint the LoRA adapters are trained on top of.
BASE_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.1"
# Hub id of the dataset used for supervised fine-tuning.
DATASET_ID = "ruslanmv/ai-medical-chatbot"
# Directory used as the trainer output dir and for the saved adapters/tokenizer.
OUTPUT_DIR = "../artifacts/lora/medical_chatbot_lora"

os.makedirs(OUTPUT_DIR, exist_ok=True)

# Read the Hugging Face access token from the environment rather than
# hard-coding a secret in the notebook. When HF_TOKEN is unset or empty,
# `token` is None: `login()` then asks for a token interactively, and the
# `from_pretrained` calls that receive `token` use the stored credentials.
token = os.environ.get("HF_TOKEN") or None
login(token=token)

In [None]:

# Load the dataset identified by DATASET_ID. No split is requested, and the
# rest of the notebook treats `ds` as a mapping of split name -> dataset.
ds = load_dataset(DATASET_ID)
# Bare expression so the notebook displays a summary of `ds` as cell output.
ds


In [None]:

def to_text(ex, eos_token="</s>"):
    """Render one dataset row as a single instruction-formatted training string.

    The question and the answer are looked up under several alternative
    column names; the first non-empty value wins and a missing (or empty)
    field is rendered as an empty string. Both values are converted with
    ``str`` and stripped of surrounding whitespace.

    Args:
        ex: Mapping for one row (anything that provides ``.get``).
        eos_token: End-of-sequence marker appended after the answer so every
            training example carries an explicit end-of-answer signal. It is
            not appended twice when the answer already ends with it. Pass
            ``""`` to get the text without any marker.

    Returns:
        A dict with a single ``"text"`` key holding the rendered example.
    """
    question_keys = ("question", "input", "prompt", "query", "Patient", "user")
    answer_keys = ("answer", "output", "response", "Doctor", "assistant")

    # First truthy value under any of the candidate keys, else "".
    q = next((ex.get(key) for key in question_keys if ex.get(key)), "")
    a = next((ex.get(key) for key in answer_keys if ex.get(key)), "")

    question = str(q).strip()
    answer = str(a).strip()
    if eos_token and not answer.endswith(eos_token):
        answer += eos_token

    # The leading "<s>" and the system sentence are kept exactly as before so
    # the training text still matches the prompt used for inference in this
    # notebook.
    return {
        "text": (
            "<s>[INST] You are a cautious, supportive medical assistant.\n"
            f"User: {question}\n[/INST]{answer}"
        )
    }

# Add the rendered "text" column to every split. The split names are
# snapshotted up front because entries of `ds` are reassigned in the loop.
split_names = list(ds.keys())
for split_name in split_names:
    rendered_split = ds[split_name].map(to_text)
    ds[split_name] = rendered_split
# Preview the first 400 characters of the first rendered training example.
ds['train'][0]['text'][:400]


In [None]:
# QLoRA config
# Pick the 4-bit compute dtype from what the hardware actually offers:
# bfloat16 only when the CUDA device reports support for it, float16 on other
# CUDA devices, and float32 when no GPU is available.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    compute_dtype = torch.bfloat16
elif torch.cuda.is_available():
    compute_dtype = torch.float16
else:
    compute_dtype = torch.float32

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

# Load tokenizer
# `token=` replaces the deprecated `use_auth_token=` keyword.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, token=token)
# NOTE(review): if the tokenizer ships without a pad token, give it one that
# is distinct from EOS. Trainers commonly fall back to pad == EOS, and a
# language-modeling collator that masks pad positions would then also mask
# every EOS label. Confirm against the installed trl/transformers versions.
if tokenizer.pad_token is None and tokenizer.unk_token is not None:
    tokenizer.pad_token = tokenizer.unk_token

# Load model
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    token=token,
)
# Prepare the quantised model for training before the adapters are attached.
model = prepare_model_for_kbit_training(model)

# LoRA config
# Adapters are attached to the attention projections (q/k/v/o) and the MLP
# projections (gate/up/down) listed in `target_modules`.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
)
model = get_peft_model(model, lora_config)


In [None]:

# Prefer a non-empty "train" split; otherwise train on whichever split is
# listed first.
train_split = "train" if ds.get("train") else list(ds.keys())[0]
train_dataset = ds[train_split]

# Evaluate only when the dataset ships a separate, non-empty held-out split.
# Requesting step-wise evaluation without an eval dataset makes the Trainer
# fail, so the strategy is switched to "no" when nothing is available.
eval_dataset = next(
    (
        ds[name]
        for name in ("validation", "test")
        if name != train_split and name in ds and len(ds[name]) > 0
    ),
    None,
)

# Mixed precision: bf16 only where the GPU supports it, fp16 on other CUDA
# devices, full precision on CPU.
cuda_available = torch.cuda.is_available()
use_bf16 = cuda_available and torch.cuda.is_bf16_supported()
use_fp16 = cuda_available and not use_bf16

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,  # effective batch of 16 sequences per device
    learning_rate=2e-4,
    logging_steps=10,
    num_train_epochs=1.0,
    fp16=use_fp16,
    bf16=use_bf16,
    optim="paged_adamw_32bit",
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    save_steps=200,
    save_total_limit=2,
    evaluation_strategy="steps" if eval_dataset is not None else "no",
    report_to="none",
)
# `model` has already been wrapped with the LoRA adapters via
# `get_peft_model`, so no `peft_config` is passed here; handing the trainer
# both a PEFT model and a PEFT config is redundant.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    dataset_text_field="text",
    max_seq_length=1024,
    packing=False,
    args=training_args,
)
trainer.train()
# Persist the trained adapters together with the tokenizer files.
trainer.model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print("Saved LoRA adapters to:", OUTPUT_DIR)


In [None]:

from transformers import pipeline

# Quick smoke test of the fine-tuned model.
# Leave training mode first so training-only behaviour (such as the LoRA
# dropout configured for the adapters) does not affect generation.
model.eval()
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
# Same template as the training text: system sentence, user turn, [/INST].
prompt = "<s>[INST] You are a cautious, supportive medical assistant.\nUser: I have a fever and sore throat for 2 days.\n[/INST]"
# `temperature` and `top_p` only take effect when sampling is enabled, so
# `do_sample=True` is passed explicitly. The bare expression displays the
# generated text (prompt included) as the cell output.
pipe(prompt, max_new_tokens=200, do_sample=True, temperature=0.3, top_p=0.9)[0]["generated_text"]
