#### **üîπ Step 1: Unify Models**

In [2]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer
from opensearchpy import OpenSearch
from databricks import sql
import json, os

  from .autonotebook import tqdm as notebook_tqdm


#### **-----------------------------------------------------------**
#### **üß† Load Models**
#### **-----------------------------------------------------------**

In [3]:
print("‚öôÔ∏è Loading models...")

# QA Model (Text Understanding / Summarization)
qa_model_name = "google/flan-t5-base"
qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
qa_model = AutoModelForSeq2SeqLM.from_pretrained(qa_model_name)
qa_pipe = pipeline("text2text-generation", model=qa_model, tokenizer=qa_tokenizer, device=-1)

# Code / SQL Model (Generative)
code_model_name = "microsoft/phi-1_5"
code_tokenizer = AutoTokenizer.from_pretrained(code_model_name)
code_model = AutoModelForCausalLM.from_pretrained(code_model_name, low_cpu_mem_usage=True)
code_pipe = pipeline("text-generation", model=code_model, tokenizer=code_tokenizer, device=-1)

print("‚úÖ Both QA & CodeGen models loaded successfully!")

‚öôÔ∏è Loading models...


Device set to use cpu
Device set to use cpu


‚úÖ Both QA & CodeGen models loaded successfully!


#### **üîπ Step 2: Set Up Databricks + OpenSearch**

# -----------------------------------------------------------
# ‚öôÔ∏è Databricks Connection
# -----------------------------------------------------------

In [4]:
DATABRICKS_CONFIG = {
    "server_hostname": "XXXXX-e088.cloud.databricks.com",
    "http_path": "/sql/1.0/warehouses/XXXXXXXXXXXXX",
    "access_token": "XXXXXXXXXXXXXXXX"
}

def run_databricks_query(query):
    try:
        with sql.connect(**DATABRICKS_CONFIG) as connection:
            with connection.cursor() as cursor:
                cursor.execute(query)
                result = cursor.fetchall()
                columns = [desc[0] for desc in cursor.description]
                return {"columns": columns, "rows": result}
    except Exception as e:
        print("‚ùå Databricks query error:", e)
        return None

#### **-----------------------------------------------------------**
#### **‚öôÔ∏è OpenSearch Connection**
#### **-----------------------------------------------------------**

In [5]:
client = OpenSearch(
    hosts=[{"host": "localhost", "port": 9200}],
    http_auth=("admin", "admin"),
    use_ssl=False,
)
print("‚úÖ Connected to OpenSearch:", client.info()["version"]["number"])

embed_model = SentenceTransformer("BAAI/bge-small-en")
print("‚úÖ Embedding model ready.")

‚úÖ Connected to OpenSearch: 2.9.0
‚úÖ Embedding model ready.


#### **üîπ Step 3: Smart Routing Function**

In [9]:
import re

def clean_sql_output(raw_text: str) -> str:
    """
    Cleans the raw model output to extract only the SQL code.
    Removes explanations, triple quotes, markdown blocks, etc.
    """
    # Extract text between triple backticks or triple quotes
    match = re.search(r"```sql(.*?)```", raw_text, re.DOTALL | re.IGNORECASE)
    if not match:
        match = re.search(r'"""(.*?)"""', raw_text, re.DOTALL)
    sql = match.group(1).strip() if match else raw_text.strip()

    # Remove leading junk lines like "Generate SQL ..." etc.
    sql = re.sub(r"(?i).*?select", "SELECT", sql, count=1, flags=re.DOTALL)
    # Remove stray markdown or 'python' tags
    sql = re.sub(r"```|python|sql|#.*", "", sql)
    return sql.strip()


def smart_chat(question: str):
    q_lower = question.lower()

    # 1Ô∏è‚É£ Metadata or project-structure questions ‚Üí use RAG
    if any(kw in q_lower for kw in ["table", "schema", "model", "gold layer", "dim_", "fact_"]):
        print("üîç Detected metadata/RAG query ‚Üí Using OpenSearch context.")
        docs = search_similar_docs(question, k=3)
        context = "\n\n".join(docs)

        prompt = f"""
You are a SQL and Data Engineering assistant.

Below is context with SQL snippets from a Databricks project.
Use it to extract and list ALL distinct tables from the 'gold' schema, 
clearly grouped as:
- Fact Tables:
- Dimension Tables:

If you cannot find any, say "No gold tables found."

Context:
{context}

Question: {question}

Answer:
"""
        result = qa_pipe(prompt, max_new_tokens=180, truncation=True)
        answer = result[0]["generated_text"].strip()
        return answer

    # 2Ô∏è‚É£ Analytical (requires SQL)
    elif any(kw in q_lower for kw in ["average", "count", "sum", "total", "compare", "by", "show", "how many", "list"]):
        print("üßÆ Detected analytical question ‚Üí Generate + Execute SQL.")
        sql_prompt = f"""
Generate a **pure SQL query only** (no explanation, no markdown) 
for Databricks database `patient_risk_prediction`.
It has schemas: bronze, silver, gold, ml.
The gold schema includes: dim_date, dim_doctor, dim_hospital, dim_patient, fact_admissions, fact_billing_summary.

Question: {question}

Return ONLY SQL.
"""
        raw_sql = code_pipe(sql_prompt, max_new_tokens=150, truncation=True)[0]["generated_text"]
        sql_query = clean_sql_output(raw_sql)
        print("üßæ Cleaned SQL:\n", sql_query)

        # Execute on Databricks
        result = run_databricks_query(sql_query)
        if result and result["rows"]:
            cols, rows = result["columns"], result["rows"]
            summary_prompt = f"""
You are a healthcare data analyst.
Summarize this SQL result clearly in plain English.

Question: {question}
Columns: {cols}
Sample rows: {rows[:5]}
"""
            summary = qa_pipe(summary_prompt, max_new_tokens=120, truncation=True)[0]["generated_text"]
            return summary
        else:
            return f"‚ö†Ô∏è SQL failed or returned no data.\n\nGenerated SQL:\n{sql_query}"

    # 3Ô∏è‚É£ Code generation (Python / ETL)
    elif any(kw in q_lower for kw in ["python", "etl", "pipeline", "spark", "code", "function"]):
        print("üíª Detected code request ‚Üí Generating code...")
        result = code_pipe(question, max_new_tokens=150, temperature=0.7, top_p=0.9)
        return result[0]["generated_text"].strip()

    # 4Ô∏è‚É£ General factual QA
    else:
        print("üìò Default factual QA route.")
        result = qa_pipe(question, max_new_tokens=100)
        return result[0]["generated_text"].strip()


#### **üîπ Step 4: Support Function (RAG Search)**

In [7]:
def search_similar_docs(query, k=3):
    qv = embed_model.encode(query).tolist()
    body = {
        "size": k,
        "query": {"knn": {"embedding": {"vector": qv, "k": k}}}
    }
    res = client.search(index="patient_risk_docs", body=body)
    return [hit["_source"]["content"] for hit in res["hits"]["hits"]]


#### **üîπ Step 5: Test the Smart Chatbot**

In [10]:
print("\nüí¨ Q1: What tables are in the gold layer?")
print("ü§ñ", smart_chat("What tables are in the gold layer?"))

print("\nüí¨ Q2: What is the average billing amount by insurance company?")
print("ü§ñ", smart_chat("What is the average billing amount by insurance company?"))

print("\nüí¨ Q3: Write Python code to load patient data from Snowflake into Databricks.")
print("ü§ñ", smart_chat("Write Python code to load patient data from Snowflake into Databricks."))

print("\nüí¨ Q4: Explain the purpose of patient readmission prediction model.")
print("ü§ñ", smart_chat("Explain the purpose of patient readmission prediction model."))



üí¨ Q1: What tables are in the gold layer?
üîç Detected metadata/RAG query ‚Üí Using OpenSearch context.
ü§ñ gold_objects_validation.sql USE patient_risk_prediction.gold; SHOW TABLES; select * from patient_risk_prediction.gold.dim_doctor; select * from patient_risk_prediction.gold.dim_patient limit 10; DESCRIBE patient_risk_prediction.gold.dim_patient; select count(*) from patient_risk_prediction.gold.dim_doctor; --50000 select * from patient_risk_prediction.gold.dim_doctor limit 10; select doctor_sk,count(*) from patient_risk_prediction.gold.dim_doctor group by doctor_sk having count(*) > 1 ORDER BY doctor_sk ASC

üí¨ Q2: What is the average billing amount by insurance company?
üßÆ Detected analytical question ‚Üí Generate + Execute SQL.
üßæ Cleaned SQL:
 Generate a **pure SQL query only** (no explanation, no markdown) 
for Databricks database `patient_risk_prediction`.
It has schemas: bronze, silver, gold, ml.
The gold schema includes: dim_date, dim_doctor, dim_hospital, dim_p