In [None]:
from datasets import Dataset
import pandas as pd
import torch
import json
import os


In [None]:
from unsloth import FastLanguageModel

# Prescription exchanges are short, so a 1024-token context window is used
max_seq_length = 1024

# Load the instruction-tuned Gemma-2 9B checkpoint together with its tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "google/gemma-2-9b-it",
    max_seq_length = max_seq_length,
    load_in_4bit = True,  # 4-bit quantized weights for memory efficiency
    dtype = None,         # None = let the loader auto-detect the dtype
)

In [None]:
# Wrap the base model with LoRA adapters on the attention and MLP projections
model = FastLanguageModel.get_peft_model(
    model,
    r=16,            # LoRA rank
    lora_alpha=16,   # scaling factor, kept equal to the rank
    lora_dropout=0,  # 0 is the Unsloth-optimized setting
    bias="none",     # "none" is the Unsloth-optimized setting
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention projections
        "gate_proj", "up_proj", "down_proj",     # MLP projections
    ],
    use_gradient_checkpointing="unsloth",  # Unsloth checkpointing to reduce VRAM use
    use_rslora=False,
    loftq_config=None,
    random_state=42,
)



In [None]:
from unsloth.chat_templates import get_chat_template

# Gemma-2 conversations are rendered with Unsloth's "gemma" chat template
tokenizer = get_chat_template(tokenizer, chat_template="gemma")

In [None]:
# Load clinic chatbot training data from the "data" folder next to the current working directory
data_path = os.path.join(os.path.dirname(os.getcwd()), 'data', 'clinic_chatbot_training_data.json')
# Read explicitly as UTF-8 so the result does not depend on the platform's default encoding
with open(data_path, 'r', encoding='utf-8') as f:
    clinic_data = json.load(f)

# Fail with a clear message instead of an IndexError on the sample print below
if not clinic_data:
    raise ValueError(f"No training examples found in {data_path}")

print(f"Loaded {len(clinic_data)} training examples")
print("Sample:", clinic_data[0])

In [None]:

# Add "speech" field to doctor_output - natural language explanation
def add_speech_field(item):
    """Return a new record whose doctor_output also carries a "speech" sentence.

    Args:
        item: dict with a 'patient_input' value and a 'doctor_output' dict that
            provides 'prescription_text', 'medicine_name', 'dose_size',
            'frequency' and 'duration'.

    Returns:
        dict with keys 'patient_input' (unchanged) and 'doctor_output' (a copy
        of the input's doctor_output with the added 'speech' key).

    Raises:
        KeyError: if any of the fields listed above is missing.
    """
    # Work on a shallow copy so the caller's record is not modified; re-running
    # the cell or inspecting the raw data afterwards still shows the original.
    doc_out = dict(item['doctor_output'])

    # Generate natural speech from prescription
    doc_out['speech'] = (
        f"{doc_out['prescription_text']} "
        f"I'm prescribing {doc_out['medicine_name']}. "
        f"Please take {doc_out['dose_size']} {doc_out['frequency'].lower()} "
        f"for {doc_out['duration']}. "
        f"Make sure to follow the dosage instructions carefully and contact me if symptoms persist or worsen."
    )

    return {
        "patient_input": item['patient_input'],
        "doctor_output": doc_out
    }

# Run every raw record through add_speech_field
processed_data = list(map(add_speech_field, clinic_data))
print("Sample with speech field:", processed_data[0])


In [None]:
# Convert to chat format for training
def convert_to_chat_format(item):
    """Wrap one clinic record in a system / user / assistant conversation.

    The assistant turn is the record's 'doctor_output' serialized as JSON with
    a 2-space indent; the user turn is the record's 'patient_input'.

    Returns:
        dict with a single "conversation" key holding the list of three turns.
    """
    # Serialize the target answer first; this is the text used as the assistant turn
    assistant_json = json.dumps(item['doctor_output'], indent=2)
    system_prompt = "You are a professional medical doctor. When a patient describes their symptoms, provide a structured prescription response in JSON format with: prescription_text, medicine_name, dose_size, frequency, duration, and speech (natural language explanation)."

    turns = (
        ("system", system_prompt),
        ("user", item['patient_input']),
        ("assistant", assistant_json),
    )
    return {
        "conversation": [{"role": role, "content": content} for role, content in turns]
    }

# Turn every processed record into a conversation
chat_data = list(map(convert_to_chat_format, processed_data))

# Wrap the conversations in a Dataset
dataset = Dataset.from_list(chat_data)
print(f"Dataset size: {len(dataset)}")
print("First conversation:", dataset[0])

In [None]:
# Format conversations with chat template
def formatting_prompts_func(examples):
    """Render a batch of conversations into plain training strings.

    Reads the "conversation" column of the batch and applies the tokenizer's
    chat template to each entry without tokenizing and without a trailing
    generation prompt. Returns a dict with the rendered strings under "text".
    """
    rendered = []
    for convo in examples["conversation"]:
        rendered.append(
            tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
        )
    return {"text": rendered}

# Add the rendered "text" column, processing the dataset in batches
dataset = dataset.map(formatting_prompts_func, batched=True)
print("Formatted prompt sample:")
print(f"{dataset[0]['text'][:500]}...")

In [None]:
from trl import SFTTrainer, SFTConfig

# Training configuration for the clinic chatbot fine-tune
bf16_available = torch.cuda.is_bf16_supported()

training_args = SFTConfig(
    dataset_text_field = "text",
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 4,  # 2 x 4 = effective batch size of 8
    warmup_steps = 10,  # fixed number of warmup steps, independent of dataset size
    num_train_epochs = 3,  # several passes over a small dataset
    learning_rate = 2e-5,
    logging_steps = 10,
    optim = "adamw_8bit",
    weight_decay = 0.01,
    lr_scheduler_type = "cosine",
    seed = 42,
    output_dir = "../models/clinic-chatbot",
    save_strategy = "epoch",  # write a checkpoint after each epoch
    save_total_limit = 2,  # retain at most 2 checkpoints in output_dir
    fp16 = not bf16_available,  # use fp16 only when bf16 is unavailable
    bf16 = bf16_available,
    report_to = "none",
)

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    eval_dataset = None,
    args = training_args,
)

In [None]:

# Train only on model responses (not user inputs or system prompts)
from unsloth.chat_templates import train_on_responses_only

# Restrict the loss to the model turns, identified by the Gemma turn markers
trainer = train_on_responses_only(
    trainer,
    instruction_part="<start_of_turn>user\n",
    response_part="<start_of_turn>model\n",
)

print("Training dataset preview:")
print(f"Total examples: {len(trainer.train_dataset)}")
first_example = trainer.train_dataset[0]
print(f"Sample input_ids length: {len(first_example['input_ids'])}")

In [None]:
# Verify training setup - compare the full prompt with the tokens that carry a label
sample_idx = 10
print("Full prompt:")
sample = trainer.train_dataset[sample_idx]
print(tokenizer.decode(sample["input_ids"]))
print("\n" + "="*80 + "\n")
print("Only training on (labels != -100):")
# Swap the -100 placeholders for the pad token so the sequence can be decoded,
# then strip the pad token text from the decoded string.
visible_label_ids = [
    tokenizer.pad_token_id if token_id == -100 else token_id
    for token_id in sample["labels"]
]
print(tokenizer.decode(visible_label_ids).replace(tokenizer.pad_token, ""))

In [None]:
# Start training
print("🚀 Starting training...")
# Rough estimate: examples x epochs // (per-device batch x accumulation steps)
estimated_steps = len(dataset) * 3 // (2 * 4)
print(f"Total steps: ~{estimated_steps} steps (3 epochs, batch_size=2, grad_accum=4)")
print("Expected training time: 15-30 minutes depending on GPU\n")

trainer_stats = trainer.train()

print("\n✅ Training completed!")
final_loss = trainer_stats.training_loss
print(f"Final loss: {final_loss:.4f}")

In [None]:
# Enable fast inference mode
FastLanguageModel.for_inference(model)

def get_prescription(patient_symptoms, max_new_tokens=512, temperature=0.3):
    """Get prescription from fine-tuned doctor model.

    Args:
        patient_symptoms: free-text symptom description used as the user turn.
        max_new_tokens: upper bound on the number of generated tokens.
        temperature: sampling temperature passed to generate().

    Returns:
        The parsed prescription dict when the model's reply is a JSON object;
        otherwise {"raw_response": <reply text>} so callers always get a dict.
    """

    conversation = [
        {
            "role": "system",
            "content": "You are a professional medical doctor. When a patient describes their symptoms, provide a structured prescription response in JSON format with: prescription_text, medicine_name, dose_size, frequency, duration, and speech (natural language explanation)."
        },
        {
            "role": "user",
            "content": patient_symptoms
        }
    ]

    prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    # The rendered chat template already contains the special tokens it needs,
    # so tokenize without adding them again (avoids a duplicated BOS token).
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,  # Lower temp for more consistent medical responses
            top_p=0.9,
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Drop the prompt tokens so only the newly generated reply is decoded
    prompt_length = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

    try:
        prescription = json.loads(response)
    except json.JSONDecodeError:
        # Not valid JSON: hand back the raw text instead of failing
        return {"raw_response": response}

    # Valid JSON that is not an object (string, list, number) would break
    # callers that look up keys such as 'speech', so treat it as raw text too.
    if not isinstance(prescription, dict):
        return {"raw_response": response}
    return prescription

# Symptom descriptions used to smoke-test the fine-tuned model
test_cases = [
    "I have severe headache and light sensitivity.",
    "I have nausea and vomiting.",
    "I have joint pain in my knees.",
    "I have persistent cough with phlegm.",
    "I have difficulty sleeping and anxiety.",
]

print("🏥 Testing Fine-tuned Clinic Chatbot\n" + "="*80 + "\n")

case_separator = "\n" + "-"*80 + "\n"
for symptom in test_cases:
    print(f"PATIENT: {symptom}")
    prescription = get_prescription(symptom)
    print("DOCTOR PRESCRIPTION:")
    print(json.dumps(prescription, indent=2))

    # Echo the spoken explanation on its own line when the reply has one
    if 'speech' in prescription:
        print(f"\n💬 DOCTOR SAYS: {prescription['speech']}")

    print(case_separator)

In [None]:
# Save the fine-tuned model
print("💾 Saving fine-tuned model...")

# Write the LoRA adapter (small, ~100MB) and then the tokenizer to the same folder
for artifact in (model, tokenizer):
    artifact.save_pretrained("../models/clinic-chatbot-lora")

print("✅ LoRA adapter saved to ../models/clinic-chatbot-lora")

# Optional: Merge and save full model (larger, but standalone)
# model.save_pretrained_merged("../models/clinic-chatbot-merged", tokenizer, save_method="merged_16bit")
# print("✅ Merged model saved to ../models/clinic-chatbot-merged")

# Optional: Export to GGUF for Ollama/llama.cpp
# model.save_pretrained_gguf("../models/clinic-chatbot-gguf", tokenizer, quantization_method="q4_k_m")
# print("✅ GGUF model saved to ../models/clinic-chatbot-gguf")


## 📊 Fine-tuning Parameters Summary

### Model Configuration
- **Model**: `google/gemma-2-9b-it` (Gemma 2, 9B parameters, instruction-tuned)
- **Quantization**: 4-bit (QLoRA) for memory efficiency
- **Max Sequence Length**: 1024 tokens (optimal for short medical prescriptions)

### LoRA Configuration
- **Rank (r)**: 16 - Balanced between quality and speed
- **Alpha**: 16 - Standard scaling (typically = r)
- **Dropout**: 0 - Optimized for Unsloth
- **Target Modules**: All attention + MLP layers (q/k/v/o_proj, gate/up/down_proj)

### Training Hyperparameters
- **Batch Size**: 2 (per device)
- **Gradient Accumulation**: 4 steps (effective batch size = 8)
- **Learning Rate**: 2e-5 (lower for medical precision)
- **Epochs**: 3 (small dataset needs multiple passes)
- **Warmup Steps**: 10
- **LR Scheduler**: Cosine (smooth convergence)
- **Optimizer**: AdamW 8-bit

### Dataset
- **Size**: ~500 examples (100 unique cases with duplicates)
- **Format**: Patient symptoms → Doctor prescription (JSON with speech field)
- **Output Structure**:
  - `prescription_text`: Brief diagnosis
  - `medicine_name`: Medication name
  - `dose_size`: Dosage amount
  - `frequency`: When to take
  - `duration`: How long
  - `speech`: Natural language explanation (NEW)

### Why These Parameters?

1. **Small Dataset (500 examples)**:
   - Multiple epochs (3) to learn patterns
   - Lower learning rate (2e-5) for stability
   - Smaller batch size to prevent overfitting

2. **Medical Domain**:
   - Low sampling temperature (0.3) at inference time for consistent outputs
   - Structured JSON output format
   - Training only on model responses (not user inputs)

3. **Resource Efficiency**:
   - 4-bit quantization cuts the memory needed for model weights by roughly 75% compared with 16-bit
   - Gradient checkpointing saves memory
   - LoRA rank 16 is efficient yet effective

4. **Expected Results**:
   - Training Loss: Should converge to ~0.5-1.0
   - Training Time: ~15-30 minutes on modern GPU
   - Model Size: ~100MB LoRA adapter
