In [1]:
import os
from dotenv import load_dotenv
from sqlalchemy import create_engine
import pandas as pd
import re
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_openai import ChatOpenAI
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool
import gradio as gr

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from langchain_openai import AzureChatOpenAI

In [5]:
# 1. Load environment variables and establish a connection to the SQL database
load_dotenv()
connectionString = os.environ["py-connectionString"]
db_engine = create_engine(connectionString)

In [6]:
def get_schema_info():
    query = """
    SELECT [table_schema], [table_name], [column_name], [data_type], [description]
    FROM [AdventureWorks2022].[dbo].[schema_metadata]
    """
    df = pd.read_sql(query, db_engine)
    prompt_lines = []
    for (schema, table), group in df.groupby(['table_schema', 'table_name']):
        prompt_lines.append(f"Table: {table} (schema: {schema})")
        prompt_lines.append("Columns:")
        for _, row in group.iterrows():
            prompt_lines.append(f"  - {row['column_name']} ({row['data_type']}): {row['description']}")
        prompt_lines.append("")
    return "\n".join(prompt_lines)

custom_schema_info = get_schema_info()

  con = self.exit_stack.enter_context(con.connect())


In [7]:
# 3. Build a system prompt for the LLM
system_prompt = (
    "You are an expert SQL assistant. "
    "Use ONLY the following database schema information to answer questions and generate SQL queries. "
    "For every SQL query you generate, always use the '(NOLOCK)' table hint after every table name in the FROM and JOIN clauses to avoid locking. "
    "Do not assume the existence of any tables or columns not listed below. "
    "If the answer cannot be found using this schema, say so.\n\n"
    f"{custom_schema_info}"
)


In [9]:
# 4. LLM and SQLDatabase setup
llm = AzureChatOpenAI(
    azure_deployment=os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"],  # e.g., "gpt-4.1-nano"
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
    api_version=os.environ.get("AZURE_OPENAI_API_VERSION", "2024-02-15-preview"),
    temperature=0,
)

db = SQLDatabase(db_engine)
execute_query = QuerySQLDatabaseTool(db=db)
write_query = create_sql_query_chain(llm, db)

In [11]:
# 5. Utility to extract SQL code from LLM output
def extract_sql(text):
    match = re.search(r"```sql\s*(.*?)```", text, re.DOTALL | re.IGNORECASE)
    if match:
        return match.group(1).strip()
    match = re.search(r"(SELECT[\s\S]+?;)", text, re.IGNORECASE)
    if match:
        return match.group(1).strip()
    return text.strip()

In [12]:
def ask_sql_and_explain(question):
    # 1. Generate SQL query from LLM using schema info
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question}
    ]
    sql_query = llm.invoke(messages)
    sql_code = extract_sql(sql_query.content if hasattr(sql_query, "content") else sql_query)
    print("Executing SQL query:\n", sql_code)
    result = execute_query.invoke({"query": sql_code})

    # 2. Check for error in result
    result_str = str(result)
    error_detected = any(word in result_str.lower() for word in [
        "error", "exception", "traceback", "incorrect syntax", "invalid column", "failed", "invalid object", "does not exist"
    ])

    # 3. If error, ask LLM to correct the query and retry ONCE
    if error_detected:
        correction_prompt = (
            f"The following SQL query failed to execute:\n"
            f"{sql_code}\n\n"
            f"Error message:\n{result_str}\n\n"
            "Based ONLY on the schema below, please generate a corrected SQL query. "
            "Do not explain, just provide the corrected SQL in a code block.\n\n"
            f"{custom_schema_info}"
        )
        correction = llm.invoke(correction_prompt)
        corrected_sql = extract_sql(correction.content if hasattr(correction, "content") else correction)
        print("Retrying with corrected SQL query:\n", corrected_sql)
        result = execute_query.invoke({"query": corrected_sql})
        sql_code = corrected_sql  # Update to show the retried query
        result_str = str(result)

        # 4. If still error, check if LLM's explanation contains another SQL code block and try once more
        error_detected = any(word in result_str.lower() for word in [
            "error", "exception", "traceback", "incorrect syntax", "invalid column", "failed", "invalid object", "does not exist"
        ])
        if error_detected:
            # Try to extract another SQL code block from the LLM's explanation
            another_sql = extract_sql(result_str)
            if another_sql and another_sql != corrected_sql:
                print("Trying SQL code found in LLM's explanation:\n", another_sql)
                result = execute_query.invoke({"query": another_sql})
                sql_code = another_sql
                result_str = str(result)

    # 5. Format result for LLM explanation
    if hasattr(result, "to_markdown"):
        result_str = result.to_markdown(index=False)
    else:
        result_str = str(result)

    # 6. Ask LLM to explain the result
    followup_prompt = (
        f"Question: {question}\n"
        f"SQL Query Executed:\n{sql_code}\n"
        f"SQL Result:\n{result_str}\n\n"
        "Please provide a clear, concise, and well-formatted answer to the question based on the SQL result above."
    )
    answer = llm.invoke(followup_prompt)
    return answer, sql_code

In [13]:
# 7. Gradio integration
def gradio_ask(question):
    answer, sql_code = ask_sql_and_explain(question)
    if hasattr(answer, "content"):
        return answer.content, sql_code
    return str(answer), sql_code

demo = gr.Interface(
    fn=gradio_ask,
    inputs=gr.Textbox(label="Ask a database question"),
    outputs=[
        gr.Textbox(label="LLM Answer"),
        gr.Textbox(label="SQL Query Executed")
    ],
    title="SQL AI Assistant",
    description="Ask a question about your database and get a natural language answer. See the SQL query generated and executed."
)

demo.launch()

* Running on local URL:  http://127.0.0.1:7865
* To create a public link, set `share=True` in `launch()`.




Executing SQL query:
 SELECT TOP 10 
    p.ProductID,
    p.Name,
    SUM(sod.OrderQty) AS TotalQuantitySold
FROM 
    Production.Product p (NOLOCK)
JOIN 
    Sales.SalesOrderDetail sod (NOLOCK) ON p.ProductID = sod.ProductID
JOIN 
    Sales.SalesOrderHeader soh (NOLOCK) ON sod.SalesOrderID = soh.SalesOrderID
WHERE 
    YEAR(soh.OrderDate) = 2014
GROUP BY 
    p.ProductID,
    p.Name
ORDER BY 
    TotalQuantitySold DESC;
Executing SQL query:
 The provided schema does not include information about product names or specific product details such as "AWC Logo Cap." Therefore, I cannot determine how many "AWC Logo Cap" items were sold based on the available tables and columns.
Retrying with corrected SQL query:
 SELECT 
    p.Name AS ProductName,
    SUM(sod.OrderQty) AS TotalSold
FROM 
    Production.Product p
JOIN 
    Sales.SalesOrderDetail sod ON p.ProductID = sod.ProductID
WHERE 
    p.Name = 'AWC Logo Cap'
GROUP BY 
    p.Name;
Executing SQL query:
 The provided schema does not includ