**==============================================================**
#### **‚úÖ Retrieval-Augmented Chatbot Code**
#### **RAG Chatbot for Patient Risk Prediction Project**
**==============================================================**

#### **üß± 1Ô∏è‚É£ Import Libraries & Setup**

In [39]:
# -----------------------------------------------------
# üß† Step 1: Import all necessary libraries
# -----------------------------------------------------
import os
import json
import numpy as np
from opensearchpy import OpenSearch
from sentence_transformers import SentenceTransformer
from transformers import pipeline


#### **==============================================================**
#### **‚öôÔ∏è Configuration**
#### **==============================================================**

In [40]:
INDEX_NAME = "patient_risk_docs"
EMBED_CACHE_PATH = r"D:\Patient Risk Prediction\Patient-Risk-Prediction\chatbot\cache\embeddings_cache.json"
MODEL_NAME = "BAAI/bge-small-en"
OPENSEARCH_HOST = "localhost"
PORT = 9200

#### **üß† 2 Configure Databricks Connection**

In [41]:
from databricks import sql

# üîê Databricks Connection Config
DATABRICKS_CONFIG = {
    "server_hostname": "XXXXXX-e088.cloud.databricks.com",  # replace with yours
    "http_path": "/sql/1.0/warehouses/XXXXXXXXXXXXX",             # from SQL warehouse settings
    "access_token": "XXXXXXXXXXXXXX"             # your PAT
}

def run_databricks_query(query):
    try:
        with sql.connect(**DATABRICKS_CONFIG) as connection:
            with connection.cursor() as cursor:
                cursor.execute(query)
                result = cursor.fetchall()
                columns = [desc[0] for desc in cursor.description]
                return {"columns": columns, "rows": result}
    except Exception as e:
        print("‚ùå Databricks query error:", e)
        return None


#### **‚öôÔ∏è 2Ô∏è‚É£ OpenSearch Connection Setup**

In [42]:
# -----------------------------------------------------
# ‚öôÔ∏è Step 2: Connect to your local OpenSearch instance
# -----------------------------------------------------
client = OpenSearch(
    hosts=[{"host": OPENSEARCH_HOST, "port": PORT}],
    http_auth=("admin", "admin"),
    use_ssl=False,
)

info = client.info()
print(f"‚úÖ Connected to OpenSearch {info['version']['number']}")


‚úÖ Connected to OpenSearch 2.9.0


#### **üß† Cell 3 ‚Äî Load Embedding Model**

In [43]:
# ------------------------------------------------------------
# 3Ô∏è‚É£ Load Embedding Model (same as index builder)
# ------------------------------------------------------------
embed_model = SentenceTransformer(MODEL_NAME)
print("üß© Embedding model loaded successfully!")

üß© Embedding model loaded successfully!


#### **==============================================================**
#### **üí¨ 4Ô∏è‚É£ Load or Initialize Cache**
#### **==============================================================**

In [44]:
if os.path.exists(EMBED_CACHE_PATH):
    try:
        with open(EMBED_CACHE_PATH, "r", encoding="utf-8") as f:
            cache = json.load(f)
        print(f"‚ö° Loaded {len(cache)} cached embeddings from: {EMBED_CACHE_PATH}")
    except json.JSONDecodeError:
        print("‚ö†Ô∏è Cache file was empty or corrupted. Starting fresh.")
        cache = {}
else:
    cache = {}

‚ö° Loaded 6 cached embeddings from: D:\Patient Risk Prediction\Patient-Risk-Prediction\chatbot\cache\embeddings_cache.json


#### **==============================================================**
#### **üîç 5Ô∏è‚É£ Helper Functions**
#### **==============================================================**

In [45]:
def search_similar_docs(query, k=3):
    """
    Performs a vector similarity search in OpenSearch.
    """
    query_vector = embed_model.encode(query).tolist()

    search_body = {
        "size": k,
        "query": {
            "knn": {
                "embedding": {
                    "vector": query_vector,
                    "k": k
                }
            }
        }
    }

    response = client.search(index=INDEX_NAME, body=search_body)
    hits = response["hits"]["hits"]
    docs = [hit["_source"]["content"] for hit in hits]
    return docs


from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# Load a small, efficient CPU-friendly model
qa_model_name = "google/flan-t5-base"

print("‚öôÔ∏è Loading lightweight model for CPU inference...")
tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(qa_model_name)

qa_pipe = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    device=-1  # Force CPU
)
print("‚úÖ FLAN-T5 model loaded successfully!")

def rag_answer(question, k=3):
    """
    Enhanced RAG pipeline:
    - Retrieves docs from OpenSearch for context.
    - If the question sounds analytical, run live SQL from Databricks.
    - Then generate a final summary answer.
    """
    # Step 1: Retrieve context
    top_docs = search_similar_docs(question, k=k)
    context = "\n\n".join(top_docs)

    # Step 2: Detect if the question needs SQL
    sql_keywords = ["average", "count", "total", "sum", "by", "show", "list", "how many", "compare"]
    if any(kw in question.lower() for kw in sql_keywords):
        print("üßÆ Detected analytical intent ‚Äî generating SQL query...")

        # Ask your QA model to generate SQL
        sql_prompt = f"""
You are an expert data engineer.
Given the question and context, generate an SQL query that can be run on the Databricks patient_risk_prediction database.

Context:
{context}

Question: {question}

Return only the SQL query.
"""
        sql_query = qa_pipe(sql_prompt, max_new_tokens=150, truncation=True)[0]["generated_text"]
        print("üßæ Generated SQL:\n", sql_query)

        # Try running on Databricks
        result = run_databricks_query(sql_query)
        if result:
            rows = result["rows"]
            cols = result["columns"]
            print("üìä Databricks Query Result:")
            print(cols)
            print(rows[:5])  # preview top rows

            # Summarize results
            summary_prompt = f"""
You are a data analyst. Summarize this Databricks SQL result clearly.

Question: {question}
Columns: {cols}
Rows: {rows[:5]}
"""
            summary = qa_pipe(summary_prompt, max_new_tokens=100, truncation=True)[0]["generated_text"]
            return summary

    # Step 3: Default to text-only RAG answer
    answer = generate_answer(question, context)
    return answer





def rag_answer(question, k=3):
    """
    Full RAG pipeline: retrieve -> generate -> answer
    """
    top_docs = search_similar_docs(question, k=k)
    context = "\n\n".join(top_docs)
    answer = generate_answer(question, context)
    return answer


‚öôÔ∏è Loading lightweight model for CPU inference...


Device set to use cpu


‚úÖ FLAN-T5 model loaded successfully!


üí¨ 6Ô∏è‚É£ Test the RAG Chatbot**

In [46]:
# ------------------------------------------------------------
# 7Ô∏è‚É£ Test the RAG Chatbot
# ------------------------------------------------------------
question = "What is the average billing amount by insurance company?"
answer = rag_answer(question)
print("ü§ñ Answer:", answer)


ü§ñ Answer: vw_high_risk_patients.sql  config(materialized='view')  SELECT p.patient_sk, h.name, h.age, h.gender, h.medical_condition, h.hospital, h.insurance_provider, ROUND(h.billing_amount, 2) AS billing_amount, h.stay_duration_days, h.date_of_admission, h.discharge_date, CASE WHERE h.medical_condition IN ('Cancer', 'Heart Disease'


In [47]:
question = "Show total billing by insurance?"
answer = rag_answer(question)
print("ü§ñ Answer:", answer)


ü§ñ Answer: vw_high_risk_patients.sql  config(materialized='view')  SELECT p.patient_sk, h.name, h.age, h.gender, h.medical_condition, h.hospital, h.insurance_provider, ROUND(h.billing_amount, 2) AS billing_amount, h.stay_duration_days, h.date_of_admission, h.discharge_date, CASE WHERE h.medical_condition IN ('Cancer', 'Heart Disease'


In [48]:
question = "How many patients were readmitted within 30 days?"
answer = rag_answer(question)
print("ü§ñ Answer:", answer)

ü§ñ Answer: patient_readmission_30d.sql  config( materialized='table', schema='ml')  WITH ordered AS ( SELECT name AS patient_name, gender, age, medical_condition, hospital, insurance_provider, date_of_admission, discharge_date, stay_duration_days, billing_amount, ROW_NUMBER() OVER ( PARTITION BY name ORDER BY date_of_admission ) AS encounter_id, LEAD(date_of_admission) OVER ( PARTITION BY name ORDER BY date_of_


In [49]:
question = "List all tables in the gold layer."
answer = rag_answer(question)
print("ü§ñ Answer:", answer)

ü§ñ Answer: gold_objects_validation.sql USE patient_risk_prediction.gold; SHOW TABLES; select * from patient_risk_prediction.gold.dim_doctor; select * from patient_risk_prediction.gold.dim_patient; DESCRIBE patient_risk_prediction.gold.dim_patient; select count(*) from patient_risk_prediction.gold.dim_doctor; --50000 select * from patient_risk_prediction.gold.dim_doctor limit 10; select doctor_sk,count(*) from patient_risk_prediction.gold.d
