In [None]:
import json

# Hand-written medical FAQ corpus, kept as (question, answer) pairs and expanded into
# {"question", "answer"} records before being written to medical_faq.json
# (relative to the current working directory).
_medical_faq_pairs = [
    (
        "What are the common symptoms of diabetes?",
        "Common symptoms include increased thirst, frequent urination, extreme fatigue, blurred vision, and slow-healing wounds.",
    ),
    (
        "How can high blood pressure be managed?",
        "High blood pressure can be managed through lifestyle changes such as reducing salt intake, regular exercise, avoiding alcohol and smoking, and taking prescribed medications.",
    ),
    (
        "What are the side effects of antibiotics?",
        "Common side effects of antibiotics include nausea, diarrhea, and yeast infections. Rarely, they can cause allergic reactions.",
    ),
    (
        "What is the difference between viral and bacterial infections?",
        "Bacterial infections are caused by bacteria and are often treated with antibiotics, while viral infections are caused by viruses and usually resolve on their own with supportive care.",
    ),
    (
        "When should I see a doctor for a headache?",
        "You should see a doctor if your headache is sudden and severe, associated with vision changes, fever, neck stiffness, weakness, confusion, or after a head injury.",
    ),
]
medical_faq = [
    {"question": question, "answer": answer}
    for question, answer in _medical_faq_pairs
]

with open("medical_faq.json", "w") as f:
    json.dump(medical_faq, f, indent=4)

import json

# Fake hospital-logistics Q&A (hours, booking, contacts, location, staff, pharmacy),
# stored as (question, answer) pairs and written to hospital_logistics.json as a list of
# {"question", "answer"} records.
_hospital_logistics_pairs = [
    (
        "What are the hospital working hours?",
        "Our hospital is open 24/7 for emergency care. Outpatient departments operate from 8:00 AM to 6:00 PM, Monday through Saturday.",
    ),
    (
        "How can I book an appointment with a doctor?",
        "You can book appointments online via our hospital website or by calling the reception at +1-800-123-4567.",
    ),
    (
        "What is the hospital's contact number?",
        "For general inquiries, call +1-800-123-4567. For emergency, dial +1-800-999-9111.",
    ),
    (
        "Where is the hospital located?",
        "We are located at 123 Health Street, Wellness City, USA. You can find us on Google Maps by searching 'Wellness General Hospital'.",
    ),
    (
        "Which doctors are available today?",
        "Available doctors include Dr. Smith (Cardiology), Dr. Johnson (Orthopedics), Dr. Patel (Pediatrics), and Dr. Chen (Dermatology).",
    ),
    (
        "Does the hospital have a pharmacy?",
        "Yes, our pharmacy is open from 8:00 AM to 10:00 PM daily and is located on the ground floor near the reception.",
    ),
]
hospital_logistics = [
    {"question": question, "answer": answer}
    for question, answer in _hospital_logistics_pairs
]

with open("hospital_logistics.json", "w") as f:
    json.dump(hospital_logistics, f, indent=4)

print("hospital_logistics.json created with fake hospital dataset.")


In [None]:
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np

import os
import torch
import transformers
# Hugging Face access token, exported as HF_TOKEN. Never hardcode it in the notebook:
# the token that used to be written here was exposed in plain text and must be revoked.
# Read it from the environment and prompt for it only when it is not already set.
access_token = os.environ.get("HF_TOKEN")
if not access_token:
    import getpass

    access_token = getpass.getpass("Hugging Face access token: ")
os.environ["HF_TOKEN"] = access_token

# Load the medical FAQ. The path must match where the first cell wrote it
# ("medical_faq.json", relative to the working directory); the previous
# "/medical_faq.json" pointed at the filesystem root and raised FileNotFoundError.
with open("medical_faq.json", "r", encoding="utf-8") as f:
    faq_data = json.load(f)

# Embed the FAQ *questions* (not the answers); incoming queries are matched against them.
encoder = SentenceTransformer("google/embeddinggemma-300m")
faq_questions = [item["question"] for item in faq_data]
# FAISS requires float32 input.
faq_embeddings = np.asarray(encoder.encode(faq_questions), dtype="float32")

# Flat (brute-force) L2 index over the question embeddings.
dimension = faq_embeddings.shape[1]
faq_index = faiss.IndexFlatL2(dimension)
faq_index.add(faq_embeddings)

def get_qa_context(query, top_k=1):
    """Retrieve the medical-FAQ entries whose questions are closest to ``query``.

    Args:
        query: Free-text user question.
        top_k: Maximum number of FAQ entries to return.

    Returns:
        A list of at most ``top_k`` dicts with keys ``"question"``, ``"answer"`` and
        ``"score"``, in the order FAISS returns them. ``score`` is ``1 / (1 + distance)``:
        it is 1.0 for an exact match and shrinks as the distance grows.
    """
    query_vec = encoder.encode([query]).astype("float32")
    distances, indices = faq_index.search(query_vec, top_k)

    results = []
    for idx, dist in zip(indices[0], distances[0]):
        # FAISS pads with -1 when fewer than top_k vectors are indexed; without this
        # guard faq_data[-1] would silently return the last FAQ entry as a "match".
        if idx < 0:
            continue
        entry = faq_data[idx]
        results.append({
            "question": entry["question"],
            "answer": entry["answer"],
            "score": float(1 / (1 + dist)),
        })
    return results

# Load the hospital-logistics Q&A and build a second flat L2 FAISS index over its questions.
with open("hospital_logistics.json", "r") as f:
    logistics_data = json.load(f)

help_questions = [entry["question"] for entry in logistics_data]
# FAISS requires float32 input.
help_embeddings = np.asarray(encoder.encode(help_questions), dtype="float32")

# NOTE: this rebinds the module-level `dimension` previously set for the FAQ index.
dimension = help_embeddings.shape[1]
help_index = faiss.IndexFlatL2(dimension)
help_index.add(help_embeddings)

def get_help_context(query, top_k=2):
    """Retrieve hospital-logistics answers relevant to ``query`` as one context string.

    Args:
        query: Free-text user question.
        top_k: Maximum number of logistics entries to include.

    Returns:
        A one-element list holding the matched answers concatenated in FAISS order,
        each preceded by a newline character. The list wrapper is kept because
        ``get_context`` takes element ``[0]`` of whatever the retriever returns.
    """
    query_vec = encoder.encode([query]).astype("float32")
    _, indices = help_index.search(query_vec, top_k)

    # Skip FAISS's -1 padding (returned when fewer than top_k vectors are indexed);
    # logistics_data[-1] would otherwise inject an unrelated last entry.
    answers = [logistics_data[idx]["answer"] for idx in indices[0] if idx >= 0]
    return ["".join("\n" + answer for answer in answers)]


def get_context(query, query_type="FAQ"):
    """Dispatch ``query`` to the FAQ retriever or the hospital-logistics retriever.

    Args:
        query: Free-text user question.
        query_type: ``"FAQ"`` selects the medical-FAQ retriever; any other value
            selects the hospital-logistics retriever.

    Returns:
        The first element of the chosen retriever's result. For ``"FAQ"`` this is the
        top FAQ dict (``"question"``, ``"answer"``, ``"score"``); otherwise it is the
        logistics context string. Callers must handle both types.
    """
    retrieve = get_qa_context if query_type == "FAQ" else get_help_context
    return retrieve(query)[0]
        

In [None]:
import gradio as gr

import os
import torch
import transformers
# Reuse the Hugging Face token already in the environment instead of overwriting it.
# The placeholder string previously assigned here replaced HF_TOKEN with an invalid
# value right before the model below is loaded. Prompt only if no token is set yet.
access_token = os.environ.get("HF_TOKEN")
if not access_token:
    import getpass

    access_token = getpass.getpass("Hugging Face access token: ")
    os.environ["HF_TOKEN"] = access_token

from transformers import pipeline
from PIL import Image
import requests
import torch

# MedGemma image+text -> text generation pipeline used by the chat handler.
# NOTE(review): float16 was previously forced on CPU, where many PyTorch half-precision
# kernels are missing or very slow. bfloat16 is used instead, and the GPU is picked
# when one is available.
pipe = pipeline(
    "image-text-to-text",
    model="google/medgemma-4b-it",
    torch_dtype=torch.bfloat16,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# System prompt for the medical assistant. (A second, earlier draft that was
# immediately overwritten by this assignment has been removed.)
# NOTE(review): the chat handler does not currently add this prompt to the messages it
# sends to the model. Rule 3 ("{Yes}") also conflicts with the handler answering
# logistics questions from retrieved context. Confirm the intent before wiring it in.
system_prompt = """
You are a helpful medical assistant. 

Rules for answering:

1. If the user query is about medical symptoms, conditions, or general health advice 
   (NOT hospital logistics), 
   then answer normally using your medical knowledge, not "{Yes}".

2. If the user query comes with a <ctx> ... </ctx> block, 
   then answer using only the information from that context.

3. If the user query is ONLY about hospital logistics 
   (examples: working hours, appointment booking, contact numbers, directions, available doctors), 
   then respond with exactly "{Yes}".
"""

# Routing prompt, filled with the user's text via .format(query=...).
# The chat handler only tests whether the reply *contains* "HOSPITAL_INFO". The model
# is therefore asked for the bare label, because a verbose reply that mentions both
# categories would be misrouted.
classification_prompt = """
Classify the user query into one of these categories:
- HOSPITAL_INFO → if about hospital/clinic logistics (working hours, appointment booking, contact numbers, directions, doctors available).
- GENERAL_MEDICAL → if asking about symptoms, conditions, treatments, image, or medical knowledge.

Respond with only the category name (HOSPITAL_INFO or GENERAL_MEDICAL) and nothing else.

User query: {query}
Category:
"""
# Minimum retrieval score (1 / (1 + distance), as computed by the FAQ retriever) for the
# best FAQ match to be injected into the prompt as context.
FAQ_SCORE_THRESHOLD = 0.8


def _generate(messages, max_new_tokens=50):
    """Run the MedGemma pipeline on chat ``messages`` and return the reply text."""
    output = pipe(text=messages, max_new_tokens=max_new_tokens)
    # The reply is the ``content`` of the last entry in ``generated_text``.
    return output[0]["generated_text"][-1]["content"]


def _build_user_message(message):
    """Wrap the user's text and, if present, the first attached image as a user turn."""
    content = [{"type": "text", "text": message["text"]}]
    files = message.get("files") or []
    if files:
        # Only the first upload is used. Converting to RGB normalises palette/RGBA
        # uploads and forces the file to be read now.
        content.append({"type": "image", "image": Image.open(files[0]).convert("RGB")})
    return {"role": "user", "content": content}


def _classify_query(query):
    """Ask the model to label ``query``; returns its raw reply (expected to name a category)."""
    classification_message = [{
        "role": "user",
        "content": [
            {"type": "text", "text": classification_prompt.format(query=query)}
        ]
    }]
    return _generate(classification_message)


def count_images(message, history):
    """Chat handler: classify the query, add retrieved context, and answer it.

    Despite its name, this is the chatbot's main response function. The model first
    labels the query. If the label contains "HOSPITAL_INFO", the logistics context from
    ``get_context(query, "HELP")`` is put into the prompt. Otherwise the best FAQ match
    is put into the prompt when its score reaches ``FAQ_SCORE_THRESHOLD``. Below the
    threshold, the model answers the user's text as-is.

    Args:
        message: Multimodal message dict with ``"text"`` and ``"files"``; the first
            file, if any, is opened with PIL and sent as an image.
        history: Conversation history; currently ignored.

    Returns:
        The model's reply text.
    """
    query = message["text"]
    messages = [_build_user_message(message)]

    if "HOSPITAL_INFO" in _classify_query(query):
        # Request the logistics retriever explicitly instead of forwarding the raw
        # classifier reply as ``query_type``.
        ctx = get_context(query, "HELP")
        help_prompt = """
            Look at the Context and Query, refer to that context and answer relevant part only.  
            If the Context and Query are totally unrelated, say "I don't have answer to this question".
            Context: {context}
            Query: {query}
            """.format(context=ctx, query=query)
        # Replace the user's text with the context-augmented prompt (an image, if any, is kept).
        messages[-1]["content"][0]["text"] = help_prompt
    else:
        context = get_context(query)
        if context["score"] >= FAQ_SCORE_THRESHOLD:
            faq_prompt = """
            Look at the Context and Query, if Context contains answer to the query, refer to that context only.
            Otherwise use your knowledge to answer the query.
            
            Context: {context}
            Query: {query}
            """.format(context=context["answer"], query=query)
            messages[-1]["content"][0]["text"] = faq_prompt

    return _generate(messages)

# Gradio UI: a disclaimer banner above a multimodal chat box wired to the chat handler.
with gr.Blocks() as demo:
    gr.Markdown(
        "⚠️ **Disclaimer:** This chatbot is only for educational purposes and should not be taken as medical advice."
    )

    gr.ChatInterface(
        # The handler defined earlier in this cell is `count_images`; the previous
        # `fn=chat` referenced a name that is never defined (NameError).
        fn=count_images,
        type="messages",
        multimodal=True,
        textbox=gr.MultimodalTextbox(
            # NOTE(review): several files can be attached, but the handler only uses the first.
            file_count="multiple",
            file_types=["image"],
            sources=["upload"],
        ),
    )


demo.launch(debug=True)