# 🏥 Medical Prescription Model - Inference Only

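This notebook loads the fine-tuned clinic chatbot model and provides an interface for generating medical prescriptions from patient-described symptoms. The outputs are unverified model generations for research and demonstration only; they are not medical advice.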


In [None]:
import torch
import json
from unsloth import FastLanguageModel


## Load Fine-tuned Model

Load the base model together with the fine-tuned LoRA adapter from `../models/clinic-chatbot-lora`, quantized to 4-bit.


In [None]:
# Context length used at load time; the model-information cell reports it too.
max_seq_length = 1024

# Load the base model plus the fine-tuned LoRA adapter with 4-bit quantization.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="../models/clinic-chatbot-lora",
    load_in_4bit=True,
    dtype=None,  # NOTE(review): None presumably lets unsloth pick the dtype — confirm in unsloth docs
    max_seq_length=max_seq_length,
)

# Put unsloth into its inference mode before any generate() calls.
FastLanguageModel.for_inference(model)

print("✅ Model loaded successfully!")


## Inference Function

Function to generate structured prescription responses.


In [1]:
def _parse_prescription_response(response):
    """
    Decode the model's raw reply into a prescription dict.

    Chat models frequently wrap their JSON in Markdown code fences or add a
    sentence before/after it. So if the reply as a whole is not a JSON object,
    the text between the first "{" and the last "}" is tried as a fallback.

    Args:
        response (str): Decoded, whitespace-stripped model output.

    Returns:
        dict: The parsed JSON object, or
        {"raw_response": response, "error": "Failed to parse JSON response"}
        when no JSON object can be recovered. This includes valid JSON that
        is not an object, such as a bare number or string.
    """
    candidates = [response]
    start, end = response.find("{"), response.rfind("}")
    if 0 <= start < end:
        candidates.append(response[start:end + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # Callers test membership (e.g. 'speech' in result), so only dicts qualify.
        if isinstance(parsed, dict):
            return parsed

    return {"raw_response": response, "error": "Failed to parse JSON response"}


def get_prescription(patient_symptoms, max_new_tokens=512, temperature=0.3, top_p=0.9):
    """
    Generate a prescription from patient symptoms using the loaded model.

    Relies on the notebook-level `model` and `tokenizer` created in the
    loading cell. The result is unverified model output, not medical advice.

    Args:
        patient_symptoms (str): Description of patient symptoms
        max_new_tokens (int): Maximum tokens to generate
        temperature (float): Sampling temperature (lower = more deterministic).
            A value <= 0 (or None) switches to greedy decoding, since sampling
            requires a strictly positive temperature.
        top_p (float): Nucleus sampling parameter (only used when sampling)

    Returns:
        dict: Structured prescription with medicine details and speech, or
        {"raw_response": ..., "error": ...} when the reply is not a JSON object.
    """
    conversation = [
        {
            "role": "system",
            "content": "You are a professional medical doctor. When a patient describes their symptoms, provide a structured prescription response in JSON format with: prescription_text, medicine_name, dose_size, frequency, duration, and speech (natural language explanation)."
        },
        {
            "role": "user",
            "content": patient_symptoms
        }
    ]

    prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Some tokenizers define no pad token; fall back to EOS so generate() gets a concrete id.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": pad_token_id,
    }
    if temperature is not None and temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # Greedy decoding: temperature/top_p have no effect without sampling.
        generation_kwargs["do_sample"] = False

    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    # generate() returns prompt + continuation; decode only the new tokens.
    prompt_length = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

    return _parse_prescription_response(response)

print("✅ Inference function ready!")


✅ Inference function ready!


## Test Cases

Run inference on multiple test cases.


In [None]:
# Smoke-test the model on a fixed set of complaints.
test_cases = [
    "I have severe headache and light sensitivity.",
    "I have nausea and vomiting.",
    "I have joint pain in my knees.",
    "I have persistent cough with phlegm.",
    "I have difficulty sleeping and anxiety.",
]

print("🏥 Testing Fine-tuned Clinic Chatbot", "=" * 80, "", sep="\n")

for symptom in test_cases:
    print(f"PATIENT: {symptom}")
    prescription = get_prescription(symptom)
    print("DOCTOR PRESCRIPTION:", json.dumps(prescription, indent=2), sep="\n")

    # A parse-failure result (raw_response/error) has no 'speech' key.
    if 'speech' in prescription:
        print(f"\n💬 DOCTOR SAYS: {prescription['speech']}")

    print("", "-" * 80, "", sep="\n")


## Interactive Mode

Edit `custom_symptom` in the cell below to get a prescription for your own symptoms.


In [None]:
# Replace this text with your own symptom description.
custom_symptom = "I have fever and body aches."

print(f"PATIENT SYMPTOM: {custom_symptom}", end="\n\n")
prescription = get_prescription(custom_symptom)

print("PRESCRIPTION:", json.dumps(prescription, indent=2), sep="\n")

# Parse failures return raw_response/error without a 'speech' key.
if 'speech' in prescription:
    print("", "💬 DOCTOR SAYS:", prescription['speech'], sep="\n")


In [None]:
def batch_inference(symptoms_list, temperature=0.3, **generation_kwargs):
    """
    Process multiple symptoms and return all prescriptions.

    Args:
        symptoms_list (list | str): Symptom descriptions. A single string is
            treated as one symptom instead of being iterated character by
            character.
        temperature (float): Sampling temperature
        **generation_kwargs: Extra keyword arguments forwarded unchanged to
            `get_prescription` (e.g. `max_new_tokens`, `top_p`).

    Returns:
        list: One {"symptom": ..., "prescription": ...} dict per input, in
        input order.
    """
    if isinstance(symptoms_list, str):
        symptoms_list = [symptoms_list]

    return [
        {
            "symptom": symptom,
            "prescription": get_prescription(symptom, temperature=temperature, **generation_kwargs),
        }
        for symptom in symptoms_list
    ]

# Demonstrate the batch helper on a few more complaints.
batch_symptoms = [
    "I have chest pain and shortness of breath.",
    "I have skin rash and itching.",
    "I have stomach pain and diarrhea."
]

batch_results = batch_inference(batch_symptoms)

print("📋 Batch Inference Results", "=" * 80, "", sep="\n")
for result in batch_results:
    print(f"SYMPTOM: {result['symptom']}")
    print("PRESCRIPTION: " + json.dumps(result['prescription'], indent=2))
    print("-" * 80, end="\n\n")


## Export Results

Save inference results to a JSON file.


In [None]:
import os
from datetime import datetime

def save_results(results, filename=None, output_dir=None):
    """
    Save inference results to a JSON file.

    Args:
        results (list): List of result dictionaries; must be JSON-serializable.
        filename (str): Output filename (optional). Defaults to
            ``inference_results_<YYYYmmdd_HHMMSS>.json``.
        output_dir (str): Directory to write into (optional). Defaults to
            ``<parent of the current working directory>/results``, which only
            matches the project layout when the notebook runs from its own
            subdirectory.

    Returns:
        str: Path of the file that was written.

    Raises:
        TypeError: If ``results`` contains values ``json`` cannot serialize.
            In that case no file is created.
    """
    # Serialize before touching the filesystem. json.dump streams into an
    # already-open file, so a failure midway would leave a truncated file.
    payload = json.dumps(results, indent=2)

    if filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"inference_results_{timestamp}.json"

    if output_dir is None:
        output_dir = os.path.join(os.path.dirname(os.getcwd()), 'results')

    output_path = os.path.join(output_dir, filename)
    # dirname(output_path) also covers filenames that contain subdirectories.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(payload)

    print(f"✅ Results saved to: {output_path}")
    return output_path

# Set to True to write `batch_results` to disk via save_results().
SAVE_RESULTS = False

if SAVE_RESULTS:
    save_results(batch_results)
else:
    print("💾 Export function ready. Set SAVE_RESULTS = True to save results.")


## Model Information


In [None]:
# Summarise the loaded model; base-model and training facts are hand-recorded literals.
print(
    "📊 Model Information:",
    *(
        f"{label}: {value}"
        for label, value in (
            ("Base Model", "unsloth/gemma-3-4b-it-unsloth-bnb-4bit"),
            ("Fine-tuned Adapter", "../models/clinic-chatbot-lora"),
            ("Max Sequence Length", max_seq_length),
            ("Quantization", "4-bit"),
            ("Device", model.device),
            ("Training Data", "100 medical symptom-prescription pairs"),
            ("Training Epochs", 3),
        )
    ),
    sep="\n",
)
