In [15]:
import gradio as gr
import torch 
from transformers import T5Tokenizer, T5ForConditionalGeneration
from sentence_transformers import SentenceTransformer, util
import numpy as np
import sqlite3

In [None]:
# Maps a Hugging Face model identifier (passed to ChatBot.switch_model) to the
# human-readable label offered in the UI model dropdown.
AVAILABLE_MODELS = {
    "google/flan-t5-base": "Flan-T5 Base (Recommended)",
    "google/flan-t5-small": "Flan-T5 Small (Faster)",
    
}
# Sentence-embedding model used to compare a user query with the stored FAQ questions.
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# SQLite file that holds the FAQ knowledge base.
DB_FILE = 'faq_database.db'
# A query must score strictly above this cosine similarity for the matched FAQ
# answer to be used as context; otherwise the model answers without FAQ context.
SIMILARITY_TH = 0.65
# Appended verbatim to every generated answer.
DISCLAIMER = "\n\n--- \n *DISCLAIMER: This is for educational purposes only and not a substitute for professional medical advice.*"

# Prompt templates. Each one is filled with str.format and must keep the
# {context} and {query} placeholders.
PROMPT_TEMPLATE_DIRECT = """
Directly answer the user's question using the provided context.

Context: {context}
Question: {query}
Answer:
"""

PROMPT_TEMPLATE_GOOD = """ 
You are a helpful and cautious Healthcare Information Assistant. Your task is to answer the user's question based *primarily* on the provided context.
If the context contains the answer, use it to form a clear and concise response. If the context is not relevant, use your general knowledge to answer, but state that the information is general.
Never provide a medical diagnosis or prescribe treatment.

CONTEXT:
{context}

USER QUESTION:
{query}

ANSWER:
"""


# Template name shown in the UI radio selector -> template text.
# ChatBot.__call__ falls back to PROMPT_TEMPLATE_GOOD for unknown names.
PROMPT_TEMPLATES = {
    "Cautious Assistant": PROMPT_TEMPLATE_GOOD,
    "Direct Assistant": PROMPT_TEMPLATE_DIRECT,
}

In [17]:
# Database 

def setup_database(db_file=None):
    """Create the ``faqs`` table if it is missing and seed it when it is empty.

    Args:
        db_file: Optional path of the SQLite database to initialise. When
            omitted, the module-level ``DB_FILE`` is used.

    The function is idempotent: the seed rows are inserted only when the table
    contains no rows. The connection is always closed, even if a statement
    raises; changes are committed only when every statement succeeded.
    """
    initial_data = [
        ("What are the common symptoms of the seasonal flu?", "Common flu symptoms include a high fever, body aches, headache, fatigue, a dry cough, and a sore throat. It's distinct from a cold as it comes on more suddenly and is more severe."),
        ("How can I differentiate between a cold and the flu?", "The flu is generally more severe than the common cold. Flu symptoms appear more suddenly and intensely, and are more likely to include fever and body aches, whereas a cold typically involves a runny or stuffy nose."),
        ("What is the proper first aid for a minor burn?", "For a minor burn, immediately cool the area by running it under cool (not cold) water for 10-20 minutes. Cover it with a sterile, non-adhesive bandage. Do not apply ice, butter, or oils."),
        ("How much water should an average adult drink per day?", "A general guideline is to drink about 8 glasses (around 2 liters or half a gallon) of water per day. However, needs can vary based on activity level, climate, and overall health."),
        ("What is the recommended amount of sleep for adults?", "Most adults require 7-9 hours of quality sleep per night for optimal physical health, cognitive function, and emotional well-being."),
        ("What are the hospital's standard operating hours?", "Our general clinic and outpatient services are open from 9 AM to 6 PM on weekdays. The emergency department remains open 24 hours a day, 7 days a week."),
        ("What is the contact number for medical emergencies?", "For any life-threatening medical emergency, please dial 911 immediately. For our hospital's dedicated emergency room, the direct line is (555) 765-4321."),
    ]

    # Resolve the default lazily so the module constant is only read when needed.
    conn = sqlite3.connect(DB_FILE if db_file is None else db_file)
    try:
        cursor = conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS faqs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            question TEXT NOT NULL UNIQUE,
            answer TEXT NOT NULL )
                """)
        cursor.execute("SELECT COUNT (*) FROM faqs")

        if cursor.fetchone()[0] == 0:
            print("Database is empty. Population with initial data..")
            cursor.executemany("INSERT INTO faqs (question, answer) VALUES (?, ?)", initial_data)

        conn.commit()
    finally:
        # Closing without a commit discards a half-finished seed, so a failed
        # run never leaves a partially populated table behind.
        conn.close()

def add_faqs(question: str, answer: str, db_file=None) -> str:
    """Insert a new question/answer pair into the ``faqs`` table.

    Args:
        question: FAQ question text; must contain non-whitespace characters.
        answer: FAQ answer text; must contain non-whitespace characters.
        db_file: Optional path of the SQLite database. When omitted, the
            module-level ``DB_FILE`` is used.

    Returns:
        A status message. It starts with ``"Added"`` on success and with
        ``"Error"`` otherwise (empty input, duplicate question, or any other
        database failure). No exception is propagated to the caller.
    """
    # Reject missing values and whitespace-only text alike.
    if not question or not answer or not question.strip() or not answer.strip():
        return "Error : Question and Answer must not be empty"

    conn = None
    try:
        conn = sqlite3.connect(DB_FILE if db_file is None else db_file)
        cursor = conn.cursor()
        cursor.execute("INSERT INTO faqs (question, answer) VALUES (?, ?)", (question, answer))
        conn.commit()
        return f"Added : {question[:30]}... to database."

    except sqlite3.IntegrityError:
        # Raised by the UNIQUE constraint on the question column.
        return "Error : This quesiton already exists in the database."

    except Exception as e:
        return f"Error: {e}"

    finally:
        # Close on every path; previously the error paths leaked the connection.
        if conn is not None:
            conn.close()

In [18]:
# Chat-Bot pipeline

class ChatBot:
    """FAQ-backed question answering pipeline.

    A user query is embedded and compared with the stored FAQ questions. If the
    best cosine similarity exceeds ``SIMILARITY_TH`` the matching FAQ answer is
    given to the T5 model as context; otherwise the model answers without FAQ
    context and the query is counted so that frequently asked, unmatched
    questions can be flagged for review.

    Attributes:
        device: Torch device string used for both models.
        embedding_model: SentenceTransformer used for query/FAQ embeddings.
        tokenizer, model: Currently active T5 tokenizer and generator.
        faqs_question: FAQ questions, in the same order as ``faq_embeddings``.
        faqs_data_map: Maps each FAQ question to its answer.
        faq_embeddings: Tensor of FAQ question embeddings (empty when no FAQs).
        query_counts: Occurrences of each normalised unmatched query since it
            was last flagged.
        pending_for_review: Normalised queries flagged for review, in order.
        REVIEW_THRE: Number of occurrences after which a query is flagged.
    """

    def __init__(self, default_model="google/flan-t5-base"):
        """Load the embedding model and ``default_model``, then read the FAQs.

        Raises:
            RuntimeError: If ``default_model`` cannot be loaded; without a
                generator the pipeline cannot answer anything.
        """
        print("Initializing Chatbot pipleline.... ")
        self.device = self._select_device()
        print(f"Using device : {self.device}")
        self.embedding_model = SentenceTransformer(EMBEDDING_MODEL, device=self.device)
        self.tokenizer = None
        self.model = None
        status = self.switch_model(default_model)
        if self.model is None:
            # Fail here with the real cause instead of crashing later on None.
            raise RuntimeError(status)

        self.query_counts = {}
        # A list keeps flagging order and supports append (a dict does not).
        self.pending_for_review = []
        self.REVIEW_THRE = 3
        self.refresh()

    @staticmethod
    def _select_device() -> str:
        """Return "mps" or "cuda" when available, otherwise "cpu"."""
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"

    def switch_model(self, model_name: str):
        """Replace the active tokenizer/model pair with ``model_name``.

        Returns:
            A status message. On failure the previously active pair is left
            untouched and the message starts with "Failed to switch model.".
        """
        print(f"Switching model to : {model_name}")

        try:
            # Load into locals first so a failure cannot leave a new tokenizer
            # paired with the old model.
            tokenizer = T5Tokenizer.from_pretrained(model_name)
            model = T5ForConditionalGeneration.from_pretrained(model_name).to(self.device)
        except Exception as e:
            return f"Failed to switch model. Error : {e}"

        self.tokenizer = tokenizer
        self.model = model
        print("Model swithched succesfully.")
        return f"Switched to {model_name}"

    def refresh(self):
        """Reload all FAQs from the database and recompute their embeddings."""
        print('Refreshing.......')
        conn = sqlite3.connect(DB_FILE)
        try:
            cursor = conn.cursor()
            cursor.execute('SELECT question, answer FROM faqs')
            rows = cursor.fetchall()
        finally:
            # Close even when the query fails (e.g. the table does not exist).
            conn.close()

        if not rows:
            print("Warning: Knowledge base is empty.")
            self.faqs_question = []
            self.faqs_data_map = {}
            self.faq_embeddings = torch.empty(0, device=self.device)
            return

        self.faqs_question = [row[0] for row in rows]
        self.faqs_data_map = {row[0]: row[1] for row in rows}

        self.faq_embeddings = self.embedding_model.encode(
            self.faqs_question,
            convert_to_tensor=True,
            show_progress_bar=False
        ).to(self.device)

        print(f"Database now contains : {len(self.faqs_question)} FAQs")

    def find_most_similar(self, user_query_embedding):
        """Return ``(best_matching_question, cosine_score)`` for the query.

        Returns ``(None, 0.0)`` when the knowledge base is empty or the query
        embedding has no elements.
        """
        # With no FAQs there is nothing to compare against; argmax over an
        # empty similarity tensor would raise.
        if not self.faqs_question or len(user_query_embedding) == 0:
            return None, 0.0

        similarities = util.pytorch_cos_sim(user_query_embedding, self.faq_embeddings)[0]
        best_match_indx = torch.argmax(similarities).item()
        best_match_score = similarities[best_match_indx].item()
        return self.faqs_question[best_match_indx], best_match_score

    def _record_unmatched_query(self, user_query: str) -> None:
        """Count an unmatched query and flag it once it reaches ``REVIEW_THRE``.

        Queries are normalised (lower-cased, stripped); blank queries are
        ignored. After flagging, the counter restarts at zero, and a query that
        is already pending is not listed a second time.
        """
        normalized_query = user_query.lower().strip()
        if not normalized_query:
            return

        self.query_counts[normalized_query] = self.query_counts.get(normalized_query, 0) + 1

        if self.query_counts[normalized_query] >= self.REVIEW_THRE:
            if normalized_query not in self.pending_for_review:
                self.pending_for_review.append(normalized_query)
                print(f"Question '{normalized_query}' has been flagged for review.")
            # Reset only after the threshold was reached; resetting on every
            # call would keep the count from ever reaching the threshold.
            self.query_counts[normalized_query] = 0

    def __call__(self, user_query: str, prompt_template_name: str) -> tuple[str, str, str, str]:
        """Answer ``user_query`` using the template named ``prompt_template_name``.

        Returns:
            ``(answer_with_disclaimer, source, matched_question, confidence)``
            where ``confidence`` is the best similarity score formatted with
            two decimals. Unknown template names fall back to
            ``PROMPT_TEMPLATE_GOOD``.
        """
        query_embedding = self.embedding_model.encode(user_query, convert_to_tensor=True).to(self.device)
        matched_query, score = self.find_most_similar(query_embedding)
        prompt_template = PROMPT_TEMPLATES.get(prompt_template_name, PROMPT_TEMPLATE_GOOD)

        if score > SIMILARITY_TH:
            source = "FAQ Database"
            context = self.faqs_data_map[matched_query]
            prompt = prompt_template.format(context=context, query=user_query)
            matched_question = matched_query

        else:
            source = "Dr. Antik"    # alias shown when the model answers without FAQ context
            matched_question = "N/A"
            prompt = prompt_template.format(context="No specific context found.", query=user_query)
            self._record_unmatched_query(user_query)

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        output = self.model.generate(**inputs, max_length=512)
        response_text = self.tokenizer.decode(output[0], skip_special_tokens=True)

        final_response = response_text + DISCLAIMER
        confidence_str = f"{score: .2f}"

        return final_response, source, matched_question, confidence_str

In [19]:
# Make sure the faqs table exists (and is seeded when empty) before the chatbot reads it.
setup_database()
# Module-level chatbot instance shared by the UI callbacks; construction loads
# the embedding model and the default T5 model, then reads the FAQs.
pipeline = ChatBot()

Initializing Chatbot pipleline.... 
Using device : mps
Switching model to : google/flan-t5-base
Model swithched succesfully.
Refreshing.......
Database now contains : 7 FAQs


In [None]:
# UI 

def handle_add_faqs(question, answer):
    """Store a new FAQ and, when that worked, rebuild the chatbot's FAQ index.

    Returns the status message produced by ``add_faqs`` unchanged.
    """
    status = add_faqs(question, answer)
    was_added = "Added" in status

    # Only a successful insert changes the knowledge base, so only then refresh.
    if was_added:
        pipeline.refresh()

    return status

def get_pending_reviews():
    """Render the questions flagged for review as a numbered Markdown list.

    Returns a fixed notice when nothing is pending.
    """
    pending = pipeline.pending_for_review
    if not pending:
        return "No questions are currently pending review."

    heading = "### 📋 Questions Flagged for Review:\n"
    # One numbered line per flagged question, numbering starts at 1.
    entries = [f"{position}. `{text}`\n" for position, text in enumerate(pending, 1)]
    return heading + "".join(entries)


def handle_switch_model(model_name_from_ui):
    """Translate a dropdown label into its model id and switch the chatbot to it.

    Args:
        model_name_from_ui: Display label as listed in ``AVAILABLE_MODELS`` values.

    Returns:
        The status message from ``pipeline.switch_model``, or a failure message
        when the label does not correspond to any known model (for example when
        the dropdown holds no selection).
    """
    model_path = next(
        (path for path, label in AVAILABLE_MODELS.items() if label == model_name_from_ui),
        None,
    )
    if model_path is None:
        # Report through the status box instead of raising IndexError in the callback.
        return f"Failed to switch model. Error : unknown model '{model_name_from_ui}'"
    return pipeline.switch_model(model_path)


# Build the Gradio interface: one tab for chatting, one for managing the FAQ
# knowledge base. Components must be created before the .click() wiring that
# references them.
with gr.Blocks(theme=gr.themes.Soft()) as iface :
    gr.Markdown("# Health Chat Bot")

    with gr.Tabs():

        # --- Tab 1: ask questions and inspect how the answer was produced ---
        with gr.TabItem("ChatBot"):
            gr.Markdown("This chatbot uses a **Retrieval-Augmented Generation (RAG)** pipeline to answer your questions.")

            with gr.Row():
                # Left column: question input, prompt template choice, submit.
                with gr.Column(scale=2):
                    question_box = gr.Textbox(lines = 4, placeholder="e.g., How do I treat a small burn?", label="Ask a Health Question")

                    prompt_selector = gr.Radio(
                        label="Select a Prompt Template",
                        choices=list(PROMPT_TEMPLATES.keys()),
                        value="Cautious Assistant"
                    )

                    submit_btn = gr.Button("Submit", variant="primary")


                # Right column: example questions and model switching.
                with gr.Column(scale=1):
                    examples = gr.Examples(examples=["How do I know if I have the flu?", "what's the number for emergencies?", "how much sleep should i get?"], inputs=[question_box])

                    gr.Markdown("### ⚙️ Model Selection")
                    # The dropdown shows display labels; handle_switch_model maps
                    # the chosen label back to its model identifier.
                    model_selector = gr.Dropdown(
                        label="Switch AI Model",
                        choices=list(AVAILABLE_MODELS.values()),
                        value=AVAILABLE_MODELS[pipeline.model.name_or_path] # Set default
                    )

                    switch_model_btn = gr.Button("Apply Model Switch")
                    model_status_box = gr.Textbox(label="Status", interactive=False)

            gr.Markdown("### Response Details")
            with gr.Row():
                source_label = gr.Label(label="Response Source")
                matched_question_label = gr.Label(label="Matched FAQ Question")
                confidence_label = gr.Label(label="Confidence Score")
            answer_box = gr.Textbox(lines=6, label="Chatbot Answer", interactive=False)

            # pipeline is called as pipeline(question, template_name); the order of
            # outputs matches the 4-tuple returned by ChatBot.__call__.
            submit_btn.click(
                fn=pipeline, 
                inputs=[question_box, prompt_selector], 
                outputs=[answer_box, source_label, matched_question_label, confidence_label]
            )

            switch_model_btn.click(
                fn=handle_switch_model,
                inputs=[model_selector],
                outputs=[model_status_box]
            )

        # --- Tab 2: add FAQs and view questions flagged for review ---
        with gr.TabItem("Manage Knowledge Base"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("## Add a New FAQ")
                    new_question_box = gr.Textbox(lines=2, label="New Question")
                    new_answer_box = gr.Textbox(lines=5, label="New Answer")
                    add_faq_btn = gr.Button("Add New FAQ", variant="primary")
                    add_status_box = gr.Textbox(label="Status", interactive=False)

                with gr.Column(scale=1):
                    gr.Markdown("## ✍️ Pending Review")
                    pending_review_display = gr.Markdown(value="Click 'Refresh' to see pending questions.")
                    refresh_review_btn = gr.Button("Refresh Review List")

            add_faq_btn.click(
                fn=handle_add_faqs,
                inputs=[new_question_box, new_answer_box],
                outputs=[add_status_box]
            )

            # The review list is only updated when this button is pressed.
            refresh_review_btn.click(
                fn=get_pending_reviews,
                inputs=None,
                outputs=[pending_review_display]
            )

In [None]:
# Start serving the Gradio interface defined in the UI cell.
iface.launch()

* Running on local URL:  http://127.0.0.1:7864
* To create a public link, set `share=True` in `launch()`.




Switching model to : google/flan-t5-base
Model swithched succesfully.
