In [None]:
# Execute the safe-execution notebook in this kernel's namespace before anything else.
# NOTE(review): clean_sql and validate_sql are called below but never defined in
# this notebook — presumably they are provided by 03_safe_execution.ipynb; confirm.
%run 03_safe_execution.ipynb

In [49]:
import yaml
from pathlib import Path
from typing import Tuple, Dict, Any, TypedDict, Optional

from langgraph.graph import StateGraph, END


In [50]:
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI

# Load a local .env file into the process environment (python-dotenv).
# Presumably this is where OPENAI_API_KEY, read by load_llm, comes from — confirm.
load_dotenv()

True

In [51]:
import sys
from pathlib import Path

# Make the project package `src` importable. The root is resolved relative to the
# kernel's working directory instead of a hard-coded absolute checkout path, the
# same "../src/..." convention this notebook uses for SCHEMA_PATH.
# NOTE(review): assumes the notebook runs from a directory one level below the
# project root (e.g. notebooks/) — adjust if the layout differs.
PROJECT_ROOT = Path("..").resolve()
if str(PROJECT_ROOT) not in sys.path:
    # Guard keeps the cell idempotent: re-running it does not append duplicates.
    sys.path.append(str(PROJECT_ROOT))

from src.db.db_connection import get_db_connection

conn = get_db_connection()
# NOTE(review): `cursor` is not used in the cells visible here; presumably it is
# consumed by code brought in via %run 03_safe_execution.ipynb — confirm.
cursor = conn.cursor()

In [52]:
# -------------------------------------------------------
# Load enriched schema (semantics + joins)
# -------------------------------------------------------
# Parsed with yaml.safe_load, so only plain YAML data (no arbitrary Python
# objects) is accepted. The resulting dict is passed into both the generation
# and the correction prompts further down.

SCHEMA_PATH = Path("../src/schema/schema_summary.yaml")

schema_context: Dict[str, Any] = yaml.safe_load(
    SCHEMA_PATH.read_text(encoding="utf-8")
)

schema_context.keys()


dict_keys(['tables', 'hints', 'common_joins'])

In [53]:
# -------------------------------------------------------
# Optimized SQL generation prompt
# -------------------------------------------------------

def build_optimized_prompt(
    question: str,
    schema: Dict[str, Any]
) -> str:
    """Assemble the first-pass SQL generation prompt.

    The prompt consists of blank-line-separated sections: the model role, the
    hard output rules, the schema, the user's question, and a trailing
    ``SQL:`` cue for the model to complete.

    Args:
        question: Natural-language question from the user.
        schema: Parsed schema description; embedded verbatim via its ``str()``.

    Returns:
        The complete prompt text.
    """
    critical_rules = "\n".join((
        "- Output ONLY one SQL SELECT query",
        "- Do NOT include markdown, backticks, or explanations",
        "- Use ONLY tables and columns from the schema",
        "- Follow join templates strictly",
        "- Never invent joins or columns",
        "- Prefer correctness over brevity",
    ))
    sections = (
        "You are an expert PostgreSQL SQL generator.",
        f"CRITICAL RULES:\n{critical_rules}",
        f"Database schema with semantics:\n{schema}",
        f"User question:\n{question}",
        "SQL:",
    )
    return "\n\n".join(sections)


In [54]:
from langchain_openai import ChatOpenAI
import os

def load_llm(model: str = "gpt-4.1-mini", temperature: float = 0):
    """Create the chat model used for SQL generation and correction.

    Args:
        model: OpenAI chat model name. Defaults to the previously hard-coded
            "gpt-4.1-mini", so existing ``load_llm()`` calls are unchanged.
        temperature: Sampling temperature; the default 0 minimizes sampling
            randomness, which suits SQL generation.

    Returns:
        A ``ChatOpenAI`` client whose API key is taken from the
        OPENAI_API_KEY environment variable.
    """
    return ChatOpenAI(
        model=model,
        temperature=temperature,
        openai_api_key=os.getenv("OPENAI_API_KEY"),
    )


In [55]:
# -------------------------------------------------------
# LLM call using ChatOpenAI
# -------------------------------------------------------

def call_llm(prompt: str) -> str:
    """Send ``prompt`` to a freshly created chat model and return its reply.

    A new client is obtained from ``load_llm()`` on every call. The reply
    text is returned with leading/trailing whitespace removed.
    """
    chat_model = load_llm()
    reply_text = chat_model.invoke(prompt).content
    return reply_text.strip()


In [56]:
# -------------------------------------------------------
# Generate SQL (optimized)
# -------------------------------------------------------

def generate_sql_optimized(question: str) -> str:
    """Ask the LLM for a SQL query answering ``question`` against ``schema_context``.

    Returns the model's reply as produced by ``call_llm`` (whitespace-stripped
    only); callers such as ``generate_sql_node`` apply ``clean_sql`` afterwards.
    """
    return call_llm(build_optimized_prompt(question, schema_context))


In [57]:
# -------------------------------------------------------
# Optimized correction prompt
# -------------------------------------------------------

def build_optimized_correction_prompt(
    question: str,
    schema: Dict[str, Any],
    previous_sql: str,
    error_reason: Optional[str]
) -> str:
    """Build the prompt asking the LLM to repair a query that failed validation.

    The output constraints mirror the CRITICAL RULES of the first-pass
    generation prompt (PostgreSQL role, a single SELECT, no markdown or
    explanations). A "corrected" query is therefore held to the same contract
    as the original one rather than a looser one.

    Args:
        question: The user's natural-language question.
        schema: Parsed schema description; embedded verbatim via its ``str()``.
        previous_sql: The SQL that was rejected.
        error_reason: Why the SQL was rejected. A falsy value (e.g. None) is
            replaced with a generic placeholder so the prompt never shows "None".

    Returns:
        The complete correction prompt text.
    """
    reason_text = error_reason or "Unknown validation failure"
    return f"""
You are an expert PostgreSQL SQL generator.

The SQL query below is INVALID.

Failure reason:
{reason_text}

Rules to fix:
- Output ONLY the corrected SQL, as one SQL SELECT query
- Do NOT include markdown, backticks, or explanations
- Use schema exactly as provided
- Follow join templates
- Do not invent columns or tables

Schema:
{schema}

Question:
{question}

Invalid SQL:
{previous_sql}

Corrected SQL:
""".strip()


In [58]:
class SQLAgentState(TypedDict):
    """State dict passed between the LangGraph nodes of the text-to-SQL agent."""

    # Natural-language question being answered.
    question: str
    # Latest cleaned SQL candidate; None until the generate step has run.
    sql: Optional[str]
    # Outcome of the most recent validate_sql check.
    valid: bool
    # Validator's failure reason, or None when the last check passed.
    reason: Optional[str]
    # Number of correction attempts made so far (compared against MAX_RETRIES).
    retries: int

In [59]:
def generate_sql_node(state: SQLAgentState) -> SQLAgentState:
    """LangGraph node: draft SQL for the state's question.

    Returns a shallow copy of ``state`` whose "sql" entry holds the LLM's
    answer after ``clean_sql``; the input state is not mutated.
    """
    drafted = clean_sql(generate_sql_optimized(state["question"]))
    next_state = dict(state)
    next_state["sql"] = drafted
    return next_state


In [60]:
def validate_sql_node(state: SQLAgentState) -> SQLAgentState:
    """LangGraph node: check the current SQL with ``validate_sql``.

    Returns a shallow copy of ``state`` with "valid" set to the check's
    verdict and "reason" set to the failure reason, or None when it passed.
    """
    passed, failure = validate_sql(state["sql"])
    if passed:
        failure = None

    next_state = dict(state)
    next_state["valid"] = passed
    next_state["reason"] = failure
    return next_state


In [63]:
def correct_sql_node(state: SQLAgentState) -> SQLAgentState:
    """LangGraph node: ask the LLM to repair the query that failed validation.

    The correction prompt is built from the question, ``schema_context``, the
    rejected SQL and the validator's reason. Returns a shallow copy of
    ``state`` with the cleaned replacement SQL and "retries" incremented by one.
    """
    repair_prompt = build_optimized_correction_prompt(
        question=state["question"],
        schema=schema_context,
        previous_sql=state["sql"],
        error_reason=state["reason"],
    )
    repaired = clean_sql(call_llm(repair_prompt))

    next_state = dict(state)
    next_state["sql"] = repaired
    next_state["retries"] = state["retries"] + 1
    return next_state


In [64]:
# Upper bound on correct_sql passes before the agent stops trying.
MAX_RETRIES = 2


def route_after_validation(state: SQLAgentState):
    """Choose the node that follows validate_sql.

    Ends the run once the query is valid or MAX_RETRIES corrections have
    already been attempted; otherwise routes to "correct_sql".
    """
    finished = state["valid"] or state["retries"] >= MAX_RETRIES
    return END if finished else "correct_sql"


In [65]:
# Agent wiring: generate_sql -> validate_sql, then route_after_validation either
# stops or sends the query through correct_sql and back into validate_sql.
graph = StateGraph(SQLAgentState)

for node_name, node_fn in (
    ("generate_sql", generate_sql_node),
    ("validate_sql", validate_sql_node),
    ("correct_sql", correct_sql_node),
):
    graph.add_node(node_name, node_fn)

graph.set_entry_point("generate_sql")
graph.add_edge("generate_sql", "validate_sql")

routing_targets = {"correct_sql": "correct_sql", END: END}
graph.add_conditional_edges("validate_sql", route_after_validation, routing_targets)

graph.add_edge("correct_sql", "validate_sql")

sql_agent = graph.compile()


In [66]:
# NOTE(review): this cell redefines correct_sql_node with a body identical to the
# earlier definition. It has no effect on `sql_agent`: the graph registered the
# earlier function object via graph.add_node("correct_sql", correct_sql_node) and
# was compiled in the preceding cell, so rebinding the name here changes nothing.
# Consider deleting this duplicate cell.
def correct_sql_node(state: SQLAgentState) -> SQLAgentState:
    """LangGraph node: ask the LLM to repair the query that failed validation.

    Builds a correction prompt from the question, schema_context, the rejected
    SQL and the validator's reason, cleans the reply with clean_sql, and returns
    a copy of the state with the new "sql" and "retries" incremented by one.
    """
    prompt = build_optimized_correction_prompt(
        question=state["question"],
        schema=schema_context,
        previous_sql=state["sql"],
        error_reason=state["reason"]
    )

    corrected_sql = call_llm(prompt)
    sql = clean_sql(corrected_sql)

    return {
        **state,
        "sql": sql,
        "retries": state["retries"] + 1
    }


In [67]:
# Run the agent end to end on one sample question, starting from an empty state.
initial_state: SQLAgentState = {
    "question": "Which products are most frequently reordered? give top 10",
    "sql": None,
    "valid": False,
    "reason": None,
    "retries": 0,
}

final_state = sql_agent.invoke(initial_state)

# On success show the accepted SQL; otherwise show the validator's last reason.
succeeded = final_state["valid"]
print("Final SQL:" if succeeded else "Failed:")
print(final_state["sql"] if succeeded else final_state["reason"])

Final SQL:
select p.product_id, p.product_name, sum(op.reordered) as total_reorders from order_products_prior op inner join products p on op.product_id = p.product_id group by p.product_id, p.product_name order by total_reorders desc limit 10;
