In [1]:
from datasets import Dataset
import pandas as pd
import torch
import json
import os


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from unsloth import FastLanguageModel

# Optimal sequence length for medical prescriptions (short responses)
max_seq_length = 1024

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
    max_seq_length=max_seq_length,
    dtype=None,  # Auto-detect best dtype
    load_in_4bit=True,  # 4-bit quantization for efficiency
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
#### Unsloth: `hf_xet==1.1.10` and `ipykernel>6.30.1` breaks progress bars. Disabling for now in XET.
#### Unsloth: To re-enable progress bars, please downgrade to `ipykernel==6.30.1` or wait for a fix to
https://github.com/huggingface/xet-core/issues/526
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.10.7: Fast Gemma3 patching. Transformers: 4.56.2.
   \\   /|    NVIDIA GeForce RTX 4060. Num GPUs = 1. Max memory: 7.996 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu128. CUDA: 8.9. CUDA Toolkit: 12.8. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.32.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3 does not support SDPA - switching to fast eager.


In [3]:
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,  # LoRA rank - balanced for medical precision
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 16,  # Scaling factor (typically = r)
    lora_dropout = 0,  # 0 is optimized for Unsloth
    bias = "none",  # "none" is optimized
    use_gradient_checkpointing = "unsloth",  # 30% less VRAM
    random_state = 42,
    use_rslora = False,
    loftq_config = None,
)



Unsloth: Making `base_model.model.model.vision_tower.vision_model` require gradients


In [4]:
from unsloth.chat_templates import get_chat_template

tokenizer = get_chat_template(
    tokenizer,
    chat_template = "gemma",  # Gemma-2 uses "gemma" template
)

In [5]:
# Load clinic chatbot training data
data_path = os.path.join(os.path.dirname(os.getcwd()), 'data', 'clinic_chatbot_training_data.json')
with open(data_path, 'r') as f:
    clinic_data = json.load(f)

print(f"Loaded {len(clinic_data)} training examples")
print("Sample:", clinic_data[0])

Loaded 100 training examples
Sample: {'patient_input': 'I have nausea and vomiting.', 'doctor_output': {'prescription_text': 'You may have Gastric discomfort. Take Domperidone 10mg as prescribed for relief.', 'medicine_name': 'Domperidone 10mg', 'dose_size': '1 tablet', 'frequency': 'Before meals', 'duration': '3 days'}}


In [6]:

# Add "speech" field to doctor_output - natural language explanation
def add_speech_field(item):
    """Convert structured prescription into natural language speech"""
    doc_out = item['doctor_output']
    
    # Generate natural speech from prescription
    speech = (
        f"{doc_out['prescription_text']} "
        f"I'm prescribing {doc_out['medicine_name']}. "
        f"Please take {doc_out['dose_size']} {doc_out['frequency'].lower()} "
        f"for {doc_out['duration']}. "
        f"Make sure to follow the dosage instructions carefully and contact me if symptoms persist or worsen."
    )
    
    # Add speech field to doctor_output
    doc_out['speech'] = speech
    
    return {
        "patient_input": item['patient_input'],
        "doctor_output": doc_out
    }

# Process all data to add speech field
processed_data = [add_speech_field(item) for item in clinic_data]
print("Sample with speech field:", processed_data[0])


Sample with speech field: {'patient_input': 'I have nausea and vomiting.', 'doctor_output': {'prescription_text': 'You may have Gastric discomfort. Take Domperidone 10mg as prescribed for relief.', 'medicine_name': 'Domperidone 10mg', 'dose_size': '1 tablet', 'frequency': 'Before meals', 'duration': '3 days', 'speech': "You may have Gastric discomfort. Take Domperidone 10mg as prescribed for relief. I'm prescribing Domperidone 10mg. Please take 1 tablet before meals for 3 days. Make sure to follow the dosage instructions carefully and contact me if symptoms persist or worsen."}}


In [7]:
# Convert to chat format for training
def convert_to_chat_format(item):
    """Convert clinic data to chat conversation format"""
    
    # Format doctor output as JSON string for the model to learn
    doctor_response = json.dumps(item['doctor_output'], indent=2)
    
    conversation = [
        {
            "role": "system",
            "content": "You are a professional medical doctor. When a patient describes their symptoms, provide a structured prescription response in JSON format with: prescription_text, medicine_name, dose_size, frequency, duration, and speech (natural language explanation)."
        },
        {
            "role": "user",
            "content": item['patient_input']
        },
        {
            "role": "assistant",
            "content": doctor_response
        }
    ]
    
    return {"conversation": conversation}

# Convert all data
chat_data = [convert_to_chat_format(item) for item in processed_data]

# Create Dataset
dataset = Dataset.from_list(chat_data)
print(f"Dataset size: {len(dataset)}")
print("First conversation:", dataset[0])

Dataset size: 100
First conversation: {'conversation': [{'content': 'You are a professional medical doctor. When a patient describes their symptoms, provide a structured prescription response in JSON format with: prescription_text, medicine_name, dose_size, frequency, duration, and speech (natural language explanation).', 'role': 'system'}, {'content': 'I have nausea and vomiting.', 'role': 'user'}, {'content': '{\n  "prescription_text": "You may have Gastric discomfort. Take Domperidone 10mg as prescribed for relief.",\n  "medicine_name": "Domperidone 10mg",\n  "dose_size": "1 tablet",\n  "frequency": "Before meals",\n  "duration": "3 days",\n  "speech": "You may have Gastric discomfort. Take Domperidone 10mg as prescribed for relief. I\'m prescribing Domperidone 10mg. Please take 1 tablet before meals for 3 days. Make sure to follow the dosage instructions carefully and contact me if symptoms persist or worsen."\n}', 'role': 'assistant'}]}


In [8]:
# Format conversations with chat template
def formatting_prompts_func(examples):
    convos = examples["conversation"]
    texts = [
        tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
        for convo in convos
    ]
    return {"text": texts}

dataset = dataset.map(formatting_prompts_func, batched=True)
print("Formatted prompt sample:")
print(dataset[0]['text'][:500] + "...")

Map: 100%|██████████| 100/100 [00:00<00:00, 5983.83 examples/s]

Formatted prompt sample:
<bos><start_of_turn>user
You are a professional medical doctor. When a patient describes their symptoms, provide a structured prescription response in JSON format with: prescription_text, medicine_name, dose_size, frequency, duration, and speech (natural language explanation). I have nausea and vomiting.<end_of_turn>
<start_of_turn>model
{
  "prescription_text": "You may have Gastric discomfort. Take Domperidone 10mg as prescribed for relief.",
  "medicine_name": "Domperidone 10mg",
  "dose_size...





In [9]:
from trl import SFTTrainer, SFTConfig

# Optimal training configuration for medical data
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    eval_dataset = None,
    args = SFTConfig(
        dataset_text_field = "text",
        per_device_train_batch_size = 2,  # Lower for medical precision
        gradient_accumulation_steps = 4,  # Effective batch size = 8
        warmup_steps = 10,  # 10% of total steps (~50 examples)
        num_train_epochs = 3,  # Multiple epochs for small dataset
        learning_rate = 2e-5,  # Lower LR for medical accuracy
        logging_steps = 10,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",  # Cosine for smooth convergence
        seed = 42,
        output_dir = "../models/clinic-chatbot",
        save_strategy = "epoch",  # Save after each epoch
        save_total_limit = 2,  # Keep only 2 best checkpoints
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        report_to = "none",
    ),
)

Unsloth: Tokenizing ["text"] (num_proc=20): 100%|██████████| 100/100 [00:09<00:00, 10.02 examples/s]


In [10]:

# Train only on model responses (not user inputs or system prompts)
from unsloth.chat_templates import train_on_responses_only

trainer = train_on_responses_only(
    trainer,
    instruction_part = "<start_of_turn>user\n",
    response_part = "<start_of_turn>model\n",
)

print("Training dataset preview:")
print(f"Total examples: {len(trainer.train_dataset)}")
print(f"Sample input_ids length: {len(trainer.train_dataset[0]['input_ids'])}")

Map (num_proc=20): 100%|██████████| 100/100 [00:00<00:00, 141.45 examples/s]

Training dataset preview:
Total examples: 100
Sample input_ids length: 208





In [11]:
# Verify training setup - check what parts are being trained
sample_idx = 10
print("Full prompt:")
print(tokenizer.decode(trainer.train_dataset[sample_idx]["input_ids"]))
print("\n" + "="*80 + "\n")
print("Only training on (labels != -100):")
print(tokenizer.decode([x if x != -100 else tokenizer.pad_token_id for x in trainer.train_dataset[sample_idx]["labels"]]).replace(tokenizer.pad_token, ""))

Full prompt:
<bos><bos><start_of_turn>user
You are a professional medical doctor. When a patient describes their symptoms, provide a structured prescription response in JSON format with: prescription_text, medicine_name, dose_size, frequency, duration, and speech (natural language explanation). I have eye redness and itching.<end_of_turn>
<start_of_turn>model
{
  "prescription_text": "You may have Eye infection. Take Ofloxacin Eye Drops as prescribed for relief.",
  "medicine_name": "Ofloxacin Eye Drops",
  "dose_size": "2 drops",
  "frequency": "Twice daily",
  "duration": "5 days",
  "speech": "You may have Eye infection. Take Ofloxacin Eye Drops as prescribed for relief. I'm prescribing Ofloxacin Eye Drops. Please take 2 drops twice daily for 5 days. Make sure to follow the dosage instructions carefully and contact me if symptoms persist or worsen."
}<end_of_turn>



Only training on (labels != -100):
{
  "prescription_text": "You may have Eye infection. Take Ofloxacin Eye Drops as 

In [None]:
# Start training
print("🚀 Starting training...")
print(f"Total steps: ~{len(dataset) * 3 // (2 * 4)} steps (3 epochs, batch_size=2, grad_accum=4)")
print("Expected training time: 15-30 minutes depending on GPU\n")

trainer_stats = trainer.train()

print("\nTraining completed!")
print(f"Final loss: {trainer_stats.training_loss:.4f}")

🚀 Starting training...
Total steps: ~37 steps (3 epochs, batch_size=2, grad_accum=4)
Expected training time: 15-30 minutes depending on GPU



==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 100 | Num Epochs = 3 | Total steps = 39
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 32,788,480 of 4,332,867,952 (0.76% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss
10,1.9258
20,1.1812
30,0.7062



✅ Training completed!
Final loss: 1.1131


In [None]:
# Enable fast inference mode
FastLanguageModel.for_inference(model)

def get_prescription(patient_symptoms, max_new_tokens=512, temperature=0.3):
    """Get prescription from fine-tuned doctor model"""
    
    conversation = [
        {
            "role": "system",
            "content": "You are a professional medical doctor. When a patient describes their symptoms, provide a structured prescription response in JSON format with: prescription_text, medicine_name, dose_size, frequency, duration, and speech (natural language explanation)."
        },
        {
            "role": "user",
            "content": patient_symptoms
        }
    ]
    
    prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,  # Lower temp for more consistent medical responses
            top_p=0.9,
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    
    response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()
    
    try:
        # Try to parse JSON response
        prescription = json.loads(response)
        return prescription
    except:
        # If not valid JSON, return as text
        return {"raw_response": response}

# Test with medical symptoms
test_cases = [
    "I have severe headache and light sensitivity.",
    "I have nausea and vomiting.",
    "I have joint pain in my knees.",
    "I have persistent cough with phlegm.",
    "I have difficulty sleeping and anxiety.",
]

print("🏥 Testing Fine-tuned Clinic Chatbot\n" + "="*80 + "\n")

for symptom in test_cases:
    print(f"PATIENT: {symptom}")
    prescription = get_prescription(symptom)
    print(f"DOCTOR PRESCRIPTION:")
    print(json.dumps(prescription, indent=2))
    
    # Print speech field separately for clarity
    if 'speech' in prescription:
        print(f"\nDOCTOR SAYS: {prescription['speech']}")
    
    print("\n" + "-"*80 + "\n")

🏥 Testing Fine-tuned Clinic Chatbot

PATIENT: I have severe headache and light sensitivity.
DOCTOR PRESCRIPTION:
{
  "prescription_text": "Please see below for your prescription.",
  "medicine_name": "Sumatriptan 50mg",
  "dose_size": "1 tablet",
  "frequency": "As needed for headache, typically every 4-6 hours.",
  "duration": "Use as directed.  Do not exceed 2 tablets in 24 hours.",
  "speech": "You have been prescribed Sumatriptan 50mg tablets for your severe headache and light sensitivity. Take one tablet as needed for your headache, typically every 4-6 hours. Do not exceed two tablets in a 24-hour period.  If your symptoms persist or worsen, please contact your doctor immediately."
}

💬 DOCTOR SAYS: You have been prescribed Sumatriptan 50mg tablets for your severe headache and light sensitivity. Take one tablet as needed for your headache, typically every 4-6 hours. Do not exceed two tablets in a 24-hour period.  If your symptoms persist or worsen, please contact your doctor immed

In [None]:
# Save the fine-tuned model
print("Saving fine-tuned model...")

# Save LoRA adapter (small, ~100MB)
model.save_pretrained("../models/clinic-chatbot-lora")
tokenizer.save_pretrained("../models/clinic-chatbot-lora")

print("LoRA adapter saved to ../models/clinic-chatbot-lora")



💾 Saving fine-tuned model...
✅ LoRA adapter saved to ../models/clinic-chatbot-lora
