## 🚀 Day 5/15 — Fine-Tuning with Unsloth AI
# TibbScholar: A Medical AI Assistant

### Project Overview

This notebook documents the process of creating **TibbScholar**, a specialized Large Language Model (LLM) designed to function as a knowledgeable and reliable medical assistant.

By fine-tuning a powerful base model on a vast corpus of medical question-answer pairs, we aim to build an AI that can provide accurate, context-aware information for medical students, researchers, and professionals. The project leverages the Unsloth library for highly efficient, memory-optimized training on a single GPU.

---

### 🛠️ Technology Stack & Key Components

*   **Base Model:** Meta's Llama 3.2 (3B) - `unsloth/Llama-3.2-3B-unsloth-bnb-4bit`, a pre-quantized 4-bit build of the base (non-instruct) model. Chosen for its robust performance in the 3B parameter class and its small memory footprint; the question-answer format is taught during fine-tuning.
*   **Fine-tuning Framework:** Unsloth AI, which advertises about 2x faster training and up to 70% less memory usage.
*   **Technique:** QLoRA (4-bit Quantized Low-Rank Adaptation) to efficiently fine-tune the model with limited VRAM.
*   **Dataset:** MIRIAD (`miriad/miriad-4.4M`), a comprehensive collection of medical QA pairs, ideal for instilling deep domain knowledge.
*   **Platform:** Google Colab Pro with an NVIDIA A100 GPU.

---

### 📋 Dataset

To teach **TibbScholar** to act as a reliable assistant, we need to structure our training data in a clear and consistent format.

#### Source and Sampling
We use the `miriad/miriad-4.4M` dataset, which contains roughly 4.4 million medical question-answer pairs. For this training run, we use a random subset of **100,000** examples to balance training time and performance: the full training split is downloaded, shuffled with a fixed seed (`seed=40`), and the first 100,000 rows are kept, giving the model a diverse and representative sample.
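The loading cell further down reads the full split and shuffles it in memory. If you instead stream the split (for example with `load_dataset(..., streaming=True)`), one standard way to draw a uniform fixed-size subset in a single pass is reservoir sampling (Algorithm R). The sketch below is plain Python on a toy stream; `reservoir_sample` is a hypothetical helper, not part of the `datasets` library, whose streaming `shuffle` relies on a fixed-size buffer and is therefore only approximately uniform.

```python
import random
from typing import Iterable, List, TypeVar

T = TypeVar("T")

def reservoir_sample(stream: Iterable[T], k: int, seed: int = 40) -> List[T]:
    """Hypothetical helper: k items drawn uniformly without replacement
    from a stream of unknown length, holding only k items in memory."""
    rng = random.Random(seed)
    reservoir: List[T] = []
    for i, item in enumerate(stream):
        if i < k:
            # Fill the reservoir with the first k items.
            reservoir.append(item)
        else:
            # Keep item i with probability k / (i + 1), evicting a random slot.
            j = rng.randint(0, i)
            if j < k:
                reservoir[j] = item
    return reservoir

# Toy stream standing in for streamed MIRIAD rows.
rows = ({"question": f"q{i}", "answer": f"a{i}"} for i in range(10_000))
subset = reservoir_sample(rows, k=100)
print(len(subset))  # 100
```

Fixing the seed makes the subset reproducible across runs, which mirrors the role of `seed=40` in the notebook's `shuffle` call.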


### Installation

In [1]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth

### Load the base model

In [None]:
from unsloth import FastLanguageModel
import torch

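max_seq_length = 2048   # maximum sequence length in tokens; longer examples are truncated
dtype = None            # auto-detect: bfloat16 on Ampere GPUs such as the A100, float16 on older ones
load_in_4bit = True     # load the base weights in 4-bit (QLoRA)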

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Llama-3.2-3B-unsloth-bnb-4bit",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.7.11: Fast Llama patching. Transformers: 4.54.1.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 1. Max memory: 39.557 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.0. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
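Loading the base model in 4-bit is what keeps QLoRA light on memory. As a back-of-the-envelope check, weight storage is roughly parameters × bits ÷ 8 bytes. The sketch assumes about 3.2 billion parameters for Llama 3.2 3B and ignores activations, optimizer state, quantization constants, and any layers kept in higher precision, so real usage is higher.

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Approximate storage for the weights alone, in GB (10**9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

NUM_PARAMS = 3.2e9  # assumption: approximate parameter count of Llama 3.2 3B

for label, bits in [("fp32", 32), ("bf16/fp16", 16), ("4-bit", 4)]:
    print(f"{label:>9}: ~{weight_memory_gb(NUM_PARAMS, bits):.1f} GB")
```

On the 40 GB A100 reported in the log above, the roughly fourfold saving over 16-bit leaves much more headroom for activations, gradients of the LoRA weights, and longer sequences.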



We now add LoRA adapters, so only a small fraction of the parameters (typically 1 to 10%) is trained while the 4-bit base weights stay frozen.

In [3]:
model = FastLanguageModel.get_peft_model(
    model,
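    r = 32,                                  # LoRA rank
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",   # attention projections
                      "gate_proj", "up_proj", "down_proj",],     # MLP projections
    lora_alpha = 32,                         # adapter scale = lora_alpha / r = 1.0
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth",  # Unsloth's memory-saving gradient checkpointing
    random_state = 3407,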
    loftq_config = None,
)

Unsloth 2025.7.11 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.
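To put a number on "a small fraction": a LoRA adapter on a linear layer with `d_in` inputs and `d_out` outputs adds two matrices, A (r × d_in) and B (d_out × r), i.e. r·(d_in + d_out) trainable parameters. The sketch below assumes the published Llama 3.2 3B shapes (hidden size 3072, MLP size 8192, 8 key/value heads of dimension 128, 28 layers); check `model.config` if in doubt.

```python
# Assumed Llama 3.2 3B shapes (verify against model.config).
HIDDEN = 3072          # hidden_size
INTERMEDIATE = 8192    # intermediate_size
KV_DIM = 8 * 128       # num_key_value_heads * head_dim
NUM_LAYERS = 28        # matches "patched 28 layers" in the log above
RANK = 32              # r passed to get_peft_model

# (in_features, out_features) of each target module in one decoder layer
MODULE_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters of one LoRA adapter: A is r x d_in, B is d_out x r."""
    return r * (d_in + d_out)

per_layer = sum(lora_params(d_in, d_out, RANK) for d_in, d_out in MODULE_SHAPES.values())
total = per_layer * NUM_LAYERS
print(f"{per_layer:,} per layer, {total:,} in total")  # 1,736,704 per layer, 48,627,712 in total
print(f"~{total / 3.2e9:.1%} of ~3.2B base parameters")  # ~1.5% of ~3.2B base parameters
```

PEFT models also expose `model.print_trainable_parameters()`, which you can use to check this estimate against the real model.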


### Load the dataset

In [None]:
from datasets import load_dataset
dataset = load_dataset("miriad/miriad-4.4M", split = "train")

In [5]:
# 1. Define the prompt format
alpaca_prompt = """Question:
{}

Answer:
{}"""

# 2. Create the formatting function
def formatting_prompts_func(examples):
    questions = examples["question"]
    answers = examples["answer"]
    texts = []
    for question, answer in zip(questions, answers):
        text = alpaca_prompt.format(question, answer)
        texts.append(tokenizer.bos_token + text + tokenizer.eos_token)
    return { "text" : texts }

dataset = dataset.shuffle(seed=40).select(range(100_000))
dataset = dataset.map(formatting_prompts_func, batched=True)

print("Dataset prepared and formatted.")
print(f"First example:\n{dataset[0]['text']}")



Dataset prepared and formatted.
First example:
<|begin_of_text|>Question:
What are the symptoms and clinical manifestations of oculomotor dysfunction in patients with motor neuron disease?


Answer:
Oculomotor dysfunction in patients with motor neuron disease can manifest as difficulty in voluntary opening or closing of the eyelids. Patients may experience a conscious and sustained effort to close their eyes, and may develop a habit of gently closing the eyelids with their fingers. However, their eyelids remain closed normally during sleep. Other symptoms of motor neuron disease, such as slurred speech, difficulty in swallowing, muscle atrophy, weakness, and fasciculation in the upper limbs and shoulders, may also be present. Cognitive dysfunction, including attention and recall impairment, may occur. Neurological examination may reveal both upper and lower motor neuron signs with bulbar involvement. Reflex blinking to various stimuli is typically normal, but voluntary initiation of ey

In [65]:
import random
# randrange(len(dataset)) returns a valid index in [0, len(dataset) - 1]
print(dataset[random.randrange(len(dataset))]['text'])

<|begin_of_text|>Question:
What are the risk factors for nerve damage during dental implant surgery?


Answer:
The main risk factors for nerve damage during dental implant surgery include errors in evaluation and planning, errors in injection of local anesthetic, bone preparation mistakes, and placement errors of the implant. Other factors that may increase the risk include inadequate maintenance of oral hygiene, plaque accumulation, and misdirection of supra structure. These factors can contribute to implant failure and increase the likelihood of nerve damage.<|end_of_text|>
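Because the model only ever sees text in this layout, the inference prompt has to reproduce it exactly up to `Answer:` and its newline; a missing newline or a different label makes the prompt look unlike the training data. The plain-Python check below uses literal stand-ins for `tokenizer.bos_token` and `tokenizer.eos_token` (the Llama 3 strings visible in the outputs above) and assumes the tokenizer prepends BOS at inference time, as the `<|begin_of_text|>` in the later generation outputs suggests.

```python
# Stand-ins for tokenizer.bos_token / tokenizer.eos_token.
BOS, EOS = "<|begin_of_text|>", "<|end_of_text|>"

train_template = """Question:
{}

Answer:
{}"""

infer_template = """Question:
{}

Answer:
"""

def training_text(question: str, answer: str) -> str:
    """Mirror of formatting_prompts_func for a single example."""
    return BOS + train_template.format(question, answer) + EOS

def inference_prompt(question: str) -> str:
    """Prompt as the model sees it, with BOS added by the tokenizer."""
    return BOS + infer_template.format(question)

demo_question = "What are the risk factors for nerve damage during dental implant surgery?"
demo_answer = "Errors in evaluation and planning, among others."
demo_text = training_text(demo_question, demo_answer)
demo_prompt = inference_prompt(demo_question)

# The prompt must be an exact prefix of a training example, so the
# continuation the model learned is the answer followed by EOS.
assert demo_text.startswith(demo_prompt)
print(repr(demo_text[len(demo_prompt):]))
```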


<a name="Train"></a>
### Train the model


In [27]:
from trl import SFTConfig, SFTTrainer
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    packing = False,
    args = SFTConfig(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 500,
        max_steps = 10000,
        learning_rate = 2e-4,
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none",
    ),
)
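It helps to translate these settings into data volume. Each optimizer step consumes `per_device_train_batch_size × gradient_accumulation_steps` = 8 sequences on this single GPU, so 10,000 steps cover 80,000 examples, about 0.8 of an epoch over the 100,000-row subset; raise `max_steps` (or use `num_train_epochs` without `max_steps`) if every example should be seen. The sketch also traces the learning rate, assuming the linear-warmup-then-half-cosine shape that `lr_scheduler_type = "cosine"` selects in Hugging Face `transformers`; it is an illustration, not the library source.

```python
import math

# Values copied from the SFTConfig above
per_device_train_batch_size = 2
gradient_accumulation_steps = 4
num_gpus = 1                      # single A100 in this run
max_steps = 10_000
warmup_steps = 500
learning_rate = 2e-4
subset_size = 100_000

effective_batch = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
examples_seen = effective_batch * max_steps
epochs = examples_seen / subset_size
print(f"effective batch = {effective_batch}, examples seen = {examples_seen:,}, epochs = {epochs:.2f}")

def cosine_lr(step: int) -> float:
    """Linear warmup to learning_rate, then cosine decay to 0 at max_steps."""
    if step < warmup_steps:
        return learning_rate * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return learning_rate * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

for step in (0, 250, 500, 5_250, 10_000):
    print(f"step {step:>6}: lr = {cosine_lr(step):.2e}")
```

The peak of 2e-4 is reached at step 500, the rate halves at the midpoint of the decay phase (step 5,250), and it reaches zero at step 10,000.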

In [None]:
# @title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

In [None]:
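from transformers.trainer_utils import get_last_checkpoint  # resume only if outputs/ already holds a checkpoint
trainer_stats = trainer.train(resume_from_checkpoint = get_last_checkpoint("outputs") if os.path.isdir("outputs") else None)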
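Resuming relies on finding the newest `checkpoint-<step>` folder in `output_dir`; with default settings the Trainer writes one every 500 steps (`save_steps`), which is what lets a run survive a Colab disconnect. The sketch below mimics that lookup in plain Python so you can see which folder would be picked; `latest_checkpoint` is a hypothetical helper, and in practice `transformers.trainer_utils.get_last_checkpoint` does this job.

```python
import os
import re
import tempfile
from typing import Optional

# Folder names the Trainer writes, e.g. "checkpoint-500"
CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")

def latest_checkpoint(output_dir: str) -> Optional[str]:
    """Hypothetical helper: path of the highest-step checkpoint-<step> folder, or None."""
    if not os.path.isdir(output_dir):
        return None                      # nothing to resume from (first run)
    steps = []
    for name in os.listdir(output_dir):
        match = CHECKPOINT_RE.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            steps.append(int(match.group(1)))
    if not steps:
        return None
    # Compare step numbers as integers: as strings, "checkpoint-9500" > "checkpoint-10000".
    return os.path.join(output_dir, f"checkpoint-{max(steps)}")

# Demo on a throwaway directory laid out like an interrupted run
with tempfile.TemporaryDirectory() as tmp:
    for step in (500, 9_500, 10_000):
        os.makedirs(os.path.join(tmp, f"checkpoint-{step}"))
    os.makedirs(os.path.join(tmp, "runs"))   # unrelated folder is ignored
    resumed_from = latest_checkpoint(tmp)
    print(os.path.basename(resumed_from))    # checkpoint-10000

print(latest_checkpoint("outputs_that_do_not_exist"))  # None
```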

In [None]:
# @title Show final memory and time stats
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(
    f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
)
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

<a name="Inference"></a>
### Inference
Let's run the model! Change the question as you like, but leave the text after `Answer:` empty so the model generates it.



In [46]:
FastLanguageModel.for_inference(model) # Enable native 2x faster inference
prompt_template = """Question:
{}

Answer:
"""
question = "How does chronic exercise training impact in the elderly??"

# Format the prompt
prompt = prompt_template.format(question)

inputs = tokenizer(
    [prompt],
    return_tensors = "pt"
    ).to("cuda")

outputs = model.generate(
    **inputs,
    max_new_tokens = 512,
    use_cache = True,
    do_sample = True,
    temperature = 0.7,
    top_p = 0.9
    )

response_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

# Print the result
print(response_text)

Question:
How does chronic exercise training impact in the elderly??

Answer:
Chronic exercise training has been shown to improve functional status and muscle mass in the elderly. It also increases muscle fiber cross-sectional area and has a positive effect on the rate of decline in muscle strength.
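The `do_sample`, `temperature`, and `top_p` arguments control how each next token is chosen. Roughly: logits are divided by the temperature (values below 1 sharpen the distribution), turned into probabilities, and nucleus (top-p) sampling keeps only the smallest set of most likely tokens whose probabilities sum to at least `top_p` before sampling among them. The sketch applies this to a made-up five-token vocabulary; it illustrates the idea and is not the `transformers` implementation, which works on logit tensors.

```python
import math
import random
from typing import Dict, List

def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    """Scale logits by 1/temperature, then normalise into probabilities."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)                       # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def top_p_filter(probs: List[float], top_p: float) -> Dict[int, float]:
    """Keep the most likely tokens until their cumulative mass reaches top_p, then renormalise."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}

vocab = ["muscle", "strength", "function", "banana", "the"]   # toy vocabulary
logits = [3.0, 2.5, 2.0, -1.0, 0.5]                            # made-up scores

probs = softmax_with_temperature(logits, temperature=0.7)      # as in the cell above
nucleus = top_p_filter(probs, top_p=0.9)
rng = random.Random(0)
choice = rng.choices(list(nucleus), weights=list(nucleus.values()))[0]
print({vocab[i]: round(p, 3) for i, p in nucleus.items()})
print("sampled:", vocab[choice])
```

With these toy numbers the two unlikely tokens fall outside the 0.9 nucleus, so they can never be sampled, while the three plausible continuations keep some randomness.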


In [None]:
from transformers import TextStreamer

FastLanguageModel.for_inference(model)

inputs = tokenizer(
    [prompt_template.format("What is oculomotor dysfunction? What are its symptoms?")],
    return_tensors = "pt").to("cuda")


text_streamer = TextStreamer(tokenizer)
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 512)

<|begin_of_text|>Question:
What is oculomotor dysfunction? What are its symptoms?

Answer:
Oculomotor dysfunction is a condition characterized by impaired movement of the eye. It can result from various causes, including injuries, diseases, or neurological disorders. Symptoms of oculomotor dysfunction may include impaired eye movement, double vision, inability to focus on objects, and impaired pupil response to light. In the case of oculomotor palsy, the symptoms may be more severe and can lead to difficulty in performing daily activities.<|end_of_text|>


In [43]:

from transformers import TextStreamer

FastLanguageModel.for_inference(model) # Enable native 2x faster inference
inputs = tokenizer(
    [prompt_template.format("What are the early and late-stage changes observed in hip osteoarthritis?")],
    return_tensors = "pt").to("cuda")

text_streamer = TextStreamer(tokenizer)
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128)

<|begin_of_text|>Question:
What are the early and late-stage changes observed in hip osteoarthritis?

Answer:
Early-stage changes in hip osteoarthritis include increased cartilage thickness, increased cartilage volume, and increased cartilage density. Late-stage changes include reduced cartilage thickness and volume, and reduced cartilage density.<|end_of_text|>


### Save the model

In [None]:
from google.colab import userdata
hf_token = userdata.get('HF_TOKEN')

# Merge the LoRA adapters into the base weights and push the 16-bit model to the Hub
model.push_to_hub_merged("Aasher/TibbScholar", tokenizer, save_method = "merged_16bit", token = hf_token, maximum_memory_usage=0.6)
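Conceptually, merging folds each adapter back into its base weight, W_merged = W + (lora_alpha / r) · B · A, so the uploaded model loads without any PEFT code; with `lora_alpha = r = 32` in this notebook the scale is 1. The toy sketch below checks on small hand-written matrices that the merged weight gives the same output as running the base layer and the adapter separately. It is plain Python, not Unsloth's implementation, which also has to handle the 4-bit base weights.

```python
from typing import List

Matrix = List[List[float]]

def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python product of a (n x k) and b (k x m)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def add(a: Matrix, b: Matrix, scale: float = 1.0) -> Matrix:
    """Element-wise a + scale * b."""
    return [[x + scale * y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]

# Toy layer: d_out = 2, d_in = 3, LoRA rank r = 1 (shapes chosen for illustration)
W = [[1.0, 0.0, 2.0],
     [0.5, -1.0, 0.0]]          # frozen base weight, d_out x d_in
A = [[0.1, 0.2, 0.3]]           # r x d_in
B = [[2.0], [-1.0]]             # d_out x r
lora_alpha, r = 1.0, 1
scale = lora_alpha / r

x = [[1.0], [2.0], [3.0]]       # one input column vector, d_in x 1

# Adapter path: W x + scale * B (A x)
separate = add(matmul(W, x), matmul(B, matmul(A, x)), scale)
# Merged path: (W + scale * B A) x
W_merged = add(W, matmul(B, A), scale)
merged = matmul(W_merged, x)

print(separate, merged)
```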

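### Load the fine-tuned model

In [None]:
from unsloth import FastLanguageModel

# Redefine the loading settings in case the runtime was restarted after saving
max_seq_length = 2048
dtype = None
load_in_4bit = True

model, tokenizer = FastLanguageModel.from_pretrained(
    "Aasher/TibbScholar",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit
    )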


In [66]:
from transformers import TextStreamer

FastLanguageModel.for_inference(model)
prompt = """Question:
What are the risks in dental implant surgery?

Answer:
"""


inputs = tokenizer(
    [prompt],
    return_tensors = "pt").to("cuda")

text_streamer = TextStreamer(tokenizer)
# do_sample=True makes the temperature setting take effect
_ = model.generate(**inputs, streamer = text_streamer, do_sample = True, temperature = 1, max_new_tokens = 256)

<|begin_of_text|>Question:
What are the risks in dental implant surgery?

Answer:
The risks in dental implant surgery can include complications such as infection, failure to heal, or complications related to the implant placement. These risks should be clearly discussed with the patient and documented in the patient's records.<|end_of_text|>
