In [1]:
from datasets import load_dataset

# Load dataset
ds = load_dataset("ruslanmv/ai-medical-chatbot")

# Work with 'train' split
train_ds = ds["train"]

# Shuffle
train_ds = train_ds.shuffle(seed=42)

# Select 100_000 examples
train_ds = train_ds.select(range(100_000))

# Train/test split
train_test_split = train_ds.train_test_split(test_size=0.1)

# Access your new train and test datasets
train_dataset = train_test_split['train']
test_dataset = train_test_split['test']


In [2]:
train_dataset


Dataset({
    features: ['Description', 'Patient', 'Doctor'],
    num_rows: 90000
})

In [3]:
import wandb


In [4]:
wandb.login()


wandb: Currently logged in as: aadhil-aseena (aadhil-aseena-rutgers-university) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


True

In [5]:
# Cell 2: Load model and configure LoRA + quantization
import torch
import bitsandbytes as bnb
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training  # <-- Added prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    BitsAndBytesConfig,
)

# Local path to the downloaded LLaMA-3-8B model
base = "./llama-3-8b"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base)

# 4-bit quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",            # Normal Float 4 quantization
    bnb_4bit_compute_dtype=torch.float16, # Computation in float16
    bnb_4bit_use_double_quant=True,       # Use double quantization for better compression
)

# Load model with quantization
model = AutoModelForCausalLM.from_pretrained(
    base,
    quantization_config=bnb_config,
    device_map="auto"   # Automatically map layers to your GPU
)

# Enable gradient checkpointing to save VRAM (~30% memory savings)
model.gradient_checkpointing_enable()

# 💡 Prepare model for 4-bit LoRA fine-tuning (important for gradients!)
model = prepare_model_for_kbit_training(model)

# LoRA configuration
lora_cfg = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=[
        "q_proj", "k_proj", "v_proj",
        "o_proj", "up_proj", "down_proj", "gate_proj"
    ],
    task_type="CAUSAL_LM"
)

# Apply LoRA adapters to the model
model = get_peft_model(model, lora_cfg)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [6]:
print(torch.__version__)

2.8.0.dev20250607+cu128


In [10]:
# Cell 3: Fine-tuning setup and execution
from trl import SFTTrainer
from transformers import TrainingArguments

# Define training arguments
args = TrainingArguments(
    output_dir="llama3-med",                   # Directory to save checkpoints and outputs
    per_device_train_batch_size=1,              # Small batch size (good for 16 GB VRAM)
    gradient_accumulation_steps=4,              # Effective batch size = 4
    num_train_epochs=3,                         # 1 epoch for fast experiment (increase later if needed)
    learning_rate=2e-4,                         # Learning rate suitable for LoRA
    logging_steps=10,                           # Log every 10 steps
    fp16=True,                                  # Enable mixed precision (good for memory)
    optim="paged_adamw_32bit",                  # Memory-efficient optimizer
    save_strategy="epoch",                      # Save at end of each epoch
    report_to="wandb",                          # Report training metrics to Weights and Biases
    run_name="llama3-medical-chatbot-qlora-evaloss",     # Name for the run on WandB
)

# Define formatting function for your dataset
def formatting_func(example):
    return f"<|user|>\n{example['Description']}\nPatient says: {example['Patient']}\n<|assistant|>\nDoctor replies: {example['Doctor']}"

# Create the SFTTrainer (supervised fine-tuning trainer)
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    peft_config=lora_cfg,
    args=args,
    formatting_func=formatting_func,    # <-- custom formatting function
)

# Start the training
trainer.train()

# Save the final LoRA adapter
trainer.model.save_pretrained("adapter")


Applying formatting function to train dataset:   0%|          | 0/90000 [00:00<?, ? examples/s]

Converting train dataset to ChatML:   0%|          | 0/90000 [00:00<?, ? examples/s]

Adding EOS to train dataset:   0%|          | 0/90000 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/90000 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/90000 [00:00<?, ? examples/s]

Applying formatting function to eval dataset:   0%|          | 0/10000 [00:00<?, ? examples/s]

Converting eval dataset to ChatML:   0%|          | 0/10000 [00:00<?, ? examples/s]

Adding EOS to eval dataset:   0%|          | 0/10000 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/10000 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/10000 [00:00<?, ? examples/s]

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  return fn(*args, **kwargs)


Step,Training Loss
10,2.717
20,2.3825
30,2.3817
40,2.3968
50,2.37
60,2.3188
70,2.392
80,2.3527
90,2.2888
100,2.3666


  return fn(*args, **kwargs)
  return fn(*args, **kwargs)


In [None]:
print(train_dataset.column_names)

In [2]:
# Cell 4: Merge LoRA adapter into base LLaMA model

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Paths
base_model_path = "./llama-3-8b"     # Local path to base LLaMA model
adapter_path = "adapter"             # Path where LoRA adapter is saved
merged_model_path = "llama3-merged"  # Output folder for merged model

# Load base model
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    device_map="auto"
)

# Load and merge LoRA adapter
model = PeftModel.from_pretrained(base_model, adapter_path)
merged_model = model.merge_and_unload()

# Save the merged model
merged_model.save_pretrained(merged_model_path)

# Save tokenizer (optional but recommended)
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
tokenizer.save_pretrained(merged_model_path)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the cpu.


Saving checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

('llama3-merged\\tokenizer_config.json',
 'llama3-merged\\special_tokens_map.json',
 'llama3-merged\\chat_template.jinja',
 'llama3-merged\\tokenizer.json')

In [3]:
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch

# Load merged model and tokenizer
model_path = "llama3-merged"

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    device_map="auto"
)

# Construct a prompt using ChatML format
messages = [
    {"role": "user", "content": "Hello doctor, I have bad acne. How do I get rid of it?"}
]

# Use chat template formatting
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

# Create text generation pipeline
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
    device_map="auto"
)

# Generate response
outputs = pipe(
    prompt,
    max_new_tokens=120,
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.95
)

# Print output
print(outputs[0]["generated_text"])


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the cpu.
Device set to use cuda:0
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Hello doctor, I have bad acne. How do I get rid of it?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Hello. Thank you for writing to us at healthcaremagicYou seem to have acne vulgaris type of acne. Acne is a multifactorial disorder. It is associated with genetic predisposition, hormonal imbalance and environmental factors. Most common factors that cause acne are oily skin, oily hair, oily food and stress. Acne is a dynamic condition and tends to occur in cycles of remission and relapse. It is best to treat acne in the early stages before it progresses to scarring. You can use a salicylic acid face wash. This is the most important step in the management of
