# FIAP - Fase 3: Medical AI Assistant Fine-tuning

Fine-tuning of **Qwen2.5-1.5B-Instruct** on the PubMedQA dataset to build a medical information assistant.

## How to run

After the notebook finishes, the model will be saved in GGUF format at `/MyDrive/fiap-3-model/model.gguf` on Google Drive.

Download `model.gguf` to the `outputs` folder and run:

```bash
ollama create medqa -f Modelfile
ollama run medqa
```

**Note:** Allow the notebook to connect to Google Drive to save the output model.

In [None]:
# mount google drive to store trained model
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# install dependencies
!pip install -q \
    torch==2.5.1 \
    torchvision==0.20.1 \
    torchaudio==2.5.1 \
    datasets==3.2.0 \
    transformers==4.46.3 \
    peft==0.14.0 \
    trl==0.12.2 \
    accelerate==1.2.1

# llama.cpp to convert to gguf
!git clone --depth 1 https://github.com/ggerganov/llama.cpp
!pip install -q ./llama.cpp/gguf-py

In [None]:
# download dataset
!mkdir -p data/pubmedqa/data
!gdown --fuzzy 'https://drive.google.com/file/d/15v1x6aQDlZymaHGP7cZJZZYFfeJt2NdS/view' -O data/pubmedqa/data/ori_pqaa.json

In [None]:
# prepare dataset
import json
import re

SYSTEM_PROMPT = """You are a concise medical assistant.

CRITICAL RULES - FOLLOW EXACTLY:
- Maximum 2-3 sentences. NO EXCEPTIONS.
- NO academic language. NO "studies show", "research indicates".
- NO lists. NO bullet points. NO multiple paragraphs.
- If answer needs more than 3 sentences, give only the most important point.
- NEVER start with "I" or introduce yourself.
- STOP writing after 3 sentences.
- NEVER diagnose or prescribe medications.

If unsure: "I don't have reliable information. Please consult a healthcare professional.\""""

# patterns to filter from dataset
REMOVE_PATTERNS = [
    r"Medical Subject Headings \(MeSH\):.*$",
    r"To the best of our knowledge,?\s*",
    r"Our (findings|results|study) (suggest|show|indicate|implicate)\s*",
    r"These (findings|results) (suggest|raise|indicate)\s*",
    r"In (this|our) study,?\s*",
    r"Results depicted\s*",
    r"©.*$", # Copyright
    r"Sincerely,.*$",
    r"Best regards,.*$",
]

def clean_text(text):
    # clean academic language while preserving medical content
    for pattern in REMOVE_PATTERNS:
        text = re.sub(pattern, "", text, flags=re.IGNORECASE | re.MULTILINE)

    # truncate long responses to teach conciseness
    if len(text) > 450:
        truncated = text[:400]
        last_period = truncated.rfind('.')
        if last_period > 100:
            text = truncated[:last_period + 1]

    return text.strip()

# load the PubMedQA PQA-A file (ori_pqaa.json, the artificially generated subset)
with open("data/pubmedqa/data/ori_pqaa.json", "r", encoding="utf-8") as f:
    pubmedqa_data = json.load(f)

dataset = []

for item in pubmedqa_data.values():
    question = item.get("QUESTION", "")
    long_answer = item.get("LONG_ANSWER", "")

    if not question or not long_answer:
        continue

    dataset.append({
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": question.strip()},
            {"role": "assistant", "content": clean_text(long_answer)}
        ]
    })
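
The truncation step in `clean_text` can be studied in isolation. The block below is a minimal sketch that reproduces only that step, with the thresholds from the cell above exposed as parameters; the function name `truncate_at_sentence` and the sample strings are illustrative and not part of the notebook. Note that, as in the original, a long answer with no period after character 100 keeps its full length.

```python
def truncate_at_sentence(text, limit=450, window=400, min_keep=100):
    """Cut text longer than `limit` at the last period in its first `window` chars."""
    if len(text) <= limit:
        return text.strip()
    head = text[:window]
    last_period = head.rfind(".")
    # Same rule as clean_text: only cut when a period is found after `min_keep`.
    if last_period > min_keep:
        return head[: last_period + 1].strip()
    return text.strip()


# Synthetic sample (hypothetical text, not taken from PubMedQA).
sample = ("Sentence one is here. " * 30).strip()
short = truncate_at_sentence(sample)
print(len(sample), len(short), short.endswith("."))
```

Running the same function over a handful of real `LONG_ANSWER` values is a quick way to see how many answers end up shortened and how many pass through unchanged.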

In [None]:
# refusal examples to teach model safety behaviors
REFUSAL_EXAMPLES = [
    # Diagnosis refusals
    {"question": "I have a headache, fever and body aches. What disease do I have?",
     "answer": "These symptoms can be associated with various conditions. Please consult a healthcare professional for proper evaluation and diagnosis."},
    {"question": "My child has a rash and fever. Is it measles?",
     "answer": "I cannot diagnose conditions. A rash with fever requires evaluation by a healthcare professional who can examine your child directly."},
    {"question": "I've been feeling dizzy and nauseous. What's wrong with me?",
     "answer": "Dizziness and nausea can have many causes. Please consult a healthcare professional for proper evaluation."},
    {"question": "Based on my symptoms, do I have COVID-19?",
     "answer": "I cannot diagnose COVID-19 or any other condition. Please get tested and consult a healthcare professional."},
    {"question": "I think I have diabetes because I'm always thirsty. Am I right?",
     "answer": "Increased thirst can have various causes. Only a healthcare professional can diagnose diabetes through proper testing."},

    # Prescription refusals
    {"question": "What medication should I take for my back pain?",
     "answer": "I cannot recommend specific medications. Please consult a healthcare professional who can evaluate your condition and prescribe appropriate treatment."},
    {"question": "Can you prescribe me antibiotics for my infection?",
     "answer": "I cannot prescribe medications. Antibiotics require a prescription from a licensed healthcare provider after proper evaluation."},
    {"question": "What dosage of ibuprofen should I take?",
     "answer": "Medication dosages should be determined by a healthcare professional or pharmacist based on your specific situation."},
    {"question": "Should I take aspirin or acetaminophen for my headache?",
     "answer": "I cannot recommend specific medications. Please consult a pharmacist or healthcare professional for appropriate recommendations."},
    {"question": "What's the best antidepressant for anxiety?",
     "answer": "Mental health medications require evaluation by a qualified healthcare professional. Please consult a psychiatrist or your doctor."},

    # Fake disease refusals
    {"question": "What are the symptoms of Hyperfluxitis?",
     "answer": "I don't have reliable information about this condition. Please consult a healthcare professional or verify the condition name."},
    {"question": "Tell me about Neuroplastoma syndrome",
     "answer": "I cannot find reliable information about this specific condition. Please consult a healthcare professional for accurate information."},
    {"question": "How is Cardiofibrillosis treated?",
     "answer": "I don't have information about this condition. Please verify the name with a healthcare professional."},

    # Emergency examples
    {"question": "I'm having chest pain right now, what should I do?",
     "answer": "Chest pain can be serious. If you're experiencing chest pain, please seek immediate medical attention or call emergency services."},
    {"question": "I accidentally took too many pills, help!",
     "answer": "This is a medical emergency. Please call poison control or emergency services immediately. Do not wait."},

    # More diagnosis/prescription refusals
    {"question": "I have a lump in my breast, is it cancer?",
     "answer": "I cannot diagnose whether a lump is cancerous. Please see a healthcare professional immediately for proper examination and testing."},
    {"question": "My blood pressure is 150/95, should I be worried?",
     "answer": "I cannot interpret your blood pressure readings or provide medical advice. Please consult a healthcare professional for evaluation."},
    {"question": "What antibiotic works best for UTI?",
     "answer": "I cannot recommend specific antibiotics. UTI treatment requires proper diagnosis and prescription from a healthcare provider."},
    {"question": "Can I stop taking my blood pressure medication?",
     "answer": "Never stop or change medications without consulting your prescribing healthcare professional. This could be dangerous."},
    {"question": "Is my mole melanoma?",
     "answer": "I cannot diagnose skin conditions. Any concerning mole should be examined by a dermatologist or healthcare professional."},
    {"question": "How much insulin should I inject?",
     "answer": "Insulin dosing must be determined by your healthcare provider based on your specific condition. Never adjust doses without medical guidance."},
]

# add refusal examples to dataset (4 copies each to increase weight)
for example in REFUSAL_EXAMPLES:
    for _ in range(4):
        dataset.append({
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["question"]},
                {"role": "assistant", "content": example["answer"]}
            ]
        })
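
Duplicating each refusal example four times only matters relative to the size of the rest of the dataset. The sketch below estimates the resulting share of refusal rows; the count of 21 examples comes from the list above, while the base sizes passed in are hypothetical placeholders and should be replaced with `len(dataset)` measured before the refusals are appended.

```python
def refusal_share(n_base, n_refusal=21, copies=4):
    """Fraction of the final dataset made up of duplicated refusal examples."""
    n_dup = n_refusal * copies
    return n_dup / (n_base + n_dup)


# Hypothetical base sizes, for illustration only.
for n_base in (1_000, 10_000, 200_000):
    print(f"{n_base:>7} base examples -> {refusal_share(n_base):.4%} refusals")
```

If the share turns out to be very small, raising `copies` or adding more distinct refusal examples are the two obvious levers.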

In [None]:
# save prepared dataset
import random
from pathlib import Path

def save_jsonl(data, path):
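    # create the output folder (including missing parents) before writing
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    # ensure_ascii=False emits raw non-ASCII characters, so write as UTF-8 explicitly
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(item, ensure_ascii=False) + '\n' for item in data)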

random.seed(42)
random.shuffle(dataset)

split_index = int(len(dataset) * 0.9)
save_jsonl(dataset[:split_index], "datasets/train.jsonl")
save_jsonl(dataset[split_index:], "datasets/validation.jsonl")
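
A round-trip check helps confirm that the 90/10 split writes what is expected. The following is a self-contained sketch on toy records written to a temporary directory; `split_and_save` and `read_jsonl` are illustrative helpers, not functions from the notebook, and the sketch uses a local `random.Random` instead of the global seed.

```python
import json
import random
import tempfile
from pathlib import Path


def split_and_save(records, out_dir, train_fraction=0.9, seed=42):
    """Shuffle a copy of `records`, split it, and write two JSONL files."""
    rng = random.Random(seed)  # local RNG, global random state is untouched
    shuffled = list(records)
    rng.shuffle(shuffled)
    cut = int(len(shuffled) * train_fraction)
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    parts = {"train": shuffled[:cut], "validation": shuffled[cut:]}
    paths = {}
    for name, part in parts.items():
        paths[name] = out / f"{name}.jsonl"
        with open(paths[name], "w", encoding="utf-8") as f:
            for item in part:
                f.write(json.dumps(item, ensure_ascii=False) + "\n")
    return paths


def read_jsonl(path):
    """Read one JSON object per line."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f]


# Toy records with a non-ASCII character to exercise the UTF-8 path.
toy = [{"id": i, "text": f"é{i}"} for i in range(50)]
with tempfile.TemporaryDirectory() as tmp:
    paths = split_and_save(toy, tmp)
    train = read_jsonl(paths["train"])
    val = read_jsonl(paths["validation"])

train_ids = {r["id"] for r in train}
val_ids = {r["id"] for r in val}
print(len(train), len(val), train_ids.isdisjoint(val_ids))
```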


In [None]:
# fine tuning settings
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
OUTPUT_DIR = "/content/drive/MyDrive/fiap-3-model"

# LoRA
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj"]

# training
MAX_SEQ_LENGTH = 400
NUM_EPOCHS = 1
BATCH_SIZE = 12
GRADIENT_ACCUMULATION_STEPS = 8
LEARNING_RATE = 3e-5
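
With `BATCH_SIZE = 12` and `GRADIENT_ACCUMULATION_STEPS = 8`, each optimizer step sees 96 sequences on a single device. The sketch below works out the effective batch size and an approximate step count; the training-set size of 190,000 is a hypothetical figure, and the `Trainer` may round the last partial batch differently, so treat the step count as an estimate.

```python
import math


def training_plan(n_train, batch_size=12, grad_accum=8, epochs=1, n_devices=1):
    """Return (effective batch size, approximate number of optimizer steps)."""
    effective_batch = batch_size * grad_accum * n_devices
    steps = math.ceil(n_train / effective_batch) * epochs
    return effective_batch, steps


# 190_000 is a placeholder; use len(train_dataset) from the notebook instead.
effective_batch, steps = training_plan(190_000)
print(f"effective batch = {effective_batch}, optimizer steps ~ {steps}")
```

Comparing the estimated step count against `save_steps` and `logging_steps` gives a rough idea of how many checkpoints and log lines a run will produce.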

In [None]:
# setup memory management
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [None]:
# load model in fp16 (no quantization to support direct GGUF conversion)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer, SFTConfig

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.float16,
)
model.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
model.config.pad_token_id = tokenizer.pad_token_id
model.enable_input_require_grads()

In [None]:
# set up LoRA
lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=TARGET_MODULES,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.train()
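
LoRA adds two small matrices per adapted projection, so the number of trainable parameters follows directly from the rank and the projection shapes. The sketch below does that arithmetic; the shapes and layer count are assumptions about the Qwen2.5-1.5B architecture and should be checked against `model.config` before being relied on.

```python
def lora_param_count(shapes, r, n_layers):
    """Each adapted layer gets A (r x d_in) and B (d_out x r): r * (d_in + d_out)."""
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in shapes.values())
    return per_layer * n_layers


# ASSUMED (d_in, d_out) shapes for the attention projections; verify
# against model.config (hidden size, KV heads, head dim) before use.
ASSUMED_SHAPES = {
    "q_proj": (1536, 1536),
    "k_proj": (1536, 256),
    "v_proj": (1536, 256),
    "o_proj": (1536, 1536),
}
ASSUMED_LAYERS = 28

total = lora_param_count(ASSUMED_SHAPES, r=8, n_layers=ASSUMED_LAYERS)
print(f"{total:,} trainable LoRA parameters under the assumed shapes")
```

In the notebook itself, calling `model.print_trainable_parameters()` right after `get_peft_model` reports the actual figure for the loaded model.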

In [None]:
# load datasets
from datasets import load_dataset

train_dataset = load_dataset("json", data_files="datasets/train.jsonl", split="train")
validation_dataset = load_dataset("json", data_files="datasets/validation.jsonl", split="train")

def format_messages(examples):
    texts = []
    for messages in examples["messages"]:
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
        texts.append(text)
    return {"text": texts}

train_dataset = train_dataset.map(format_messages, batched=True, remove_columns=train_dataset.column_names)
validation_dataset = validation_dataset.map(format_messages, batched=True, remove_columns=validation_dataset.column_names)
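
`apply_chat_template` turns each message list into a single training string. The function below is a simplified, hand-written approximation of the ChatML layout, meant only to show the shape of that string; the authoritative template is the one stored in the tokenizer, and `render_chatml` is a hypothetical helper that does not replace it.

```python
def render_chatml(messages, add_generation_prompt=False):
    """Approximate ChatML rendering of a list of {"role", "content"} dicts."""
    parts = [
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    ]
    if add_generation_prompt:
        # At inference time the prompt ends with an open assistant turn.
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)


# Illustrative messages, not taken from the dataset.
example = [
    {"role": "system", "content": "You are a concise medical assistant."},
    {"role": "user", "content": "Does exercise lower blood pressure?"},
    {"role": "assistant", "content": "Regular exercise is associated with lower blood pressure."},
]
print(render_chatml(example))
```

Printing one element of `train_dataset["text"]` and comparing it against this layout is a cheap way to confirm that the system prompt and the `<|im_end|>` markers are where the model expects them.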

In [None]:
# train
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    args=SFTConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        learning_rate=LEARNING_RATE,
        warmup_ratio=0.05,
        lr_scheduler_type="cosine",
        logging_steps=100,
        save_steps=100,
        fp16=True,
        optim="adamw_torch",
        gradient_checkpointing=True,
        max_grad_norm=1.0,
        max_seq_length=MAX_SEQ_LENGTH,
        report_to="none",
    ),
)

trainer.train()
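
The combination of `warmup_ratio=0.05` and `lr_scheduler_type="cosine"` means the learning rate ramps up linearly over the first 5% of steps and then decays along a half cosine. The sketch below reimplements that shape by hand to make the numbers visible; it is an approximation written for illustration, not the scheduler object the trainer builds, and the total of 1,000 steps is a placeholder.

```python
import math


def lr_at(step, total_steps, peak_lr=3e-5, warmup_ratio=0.05):
    """Linear warmup to peak_lr, then cosine decay to zero."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# 1_000 total steps is a hypothetical run length.
for step in (0, 25, 50, 525, 1000):
    print(f"step {step:>4}: lr = {lr_at(step, 1000):.2e}")
```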

In [None]:
# save trained model
trained_model = trainer.model
merged_model = trained_model.merge_and_unload()
merged_model.save_pretrained(f"{OUTPUT_DIR}/merged_model")
tokenizer.save_pretrained(f"{OUTPUT_DIR}/merged_model")

In [None]:
# convert model to GGUF format for Ollama compatibility
# the chat template is read from the saved tokenizer config, so no template flag is passed
!python llama.cpp/convert_hf_to_gguf.py {OUTPUT_DIR}/merged_model --outfile {OUTPUT_DIR}/model.gguf --outtype f16
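
Before ending the runtime, it can be worth confirming that the conversion produced a real GGUF file. A GGUF file starts with the four magic bytes `GGUF` followed by a little-endian 32-bit version number, and the sketch below reads just that header. The helper name `read_gguf_header` is illustrative, and the demonstration runs on a synthetic eight-byte file, not on an actual model.

```python
import struct
import tempfile
from pathlib import Path


def read_gguf_header(path):
    """Return the GGUF version if the file has the GGUF magic, else None."""
    with open(path, "rb") as f:
        header = f.read(8)
    if len(header) < 8 or header[:4] != b"GGUF":
        return None
    (version,) = struct.unpack("<I", header[4:8])
    return version


# Synthetic files for demonstration; in the notebook, point this at
# f"{OUTPUT_DIR}/model.gguf" instead.
with tempfile.TemporaryDirectory() as tmp:
    good = Path(tmp) / "fake.gguf"
    good.write_bytes(b"GGUF" + struct.pack("<I", 3))
    bad = Path(tmp) / "not_a_model.bin"
    bad.write_bytes(b"\x00" * 16)
    good_version = read_gguf_header(good)
    bad_version = read_gguf_header(bad)

print(good_version, bad_version)
```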

In [None]:
# end runtime to prevent wasting credits
from google.colab import runtime
runtime.unassign()