In [2]:
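# Run 03_safe_execution.ipynb inside this kernel so the names it defines are available here.
# clean_sql and validate_sql, used by the nodes below, are not defined in this notebook and presumably come from it.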
%run 03_safe_execution.ipynb

SQL blocked by guardrails: Forbidden SQL operation detected

Query: ```SELECT * FROM orders```
Valid: True
Reason: SQL is safe to execute

Query: WITH t AS (SELECT * FROM orders) SELECT * FROM t
Valid: True
Reason: SQL is safe to execute

Query: DELETE FROM orders
Valid: False
Reason: Forbidden SQL operation detected

Query: DROP TABLE products
Valid: False
Reason: Forbidden SQL operation detected

Query: SELECT * FROM orders; DELETE FROM orders
Valid: False
Reason: Forbidden SQL operation detected

Query: UPDATE orders SET order_dow = 1
Valid: False
Reason: Forbidden SQL operation detected


In [3]:
import yaml
from pathlib import Path
from typing import Tuple, Dict, Any, TypedDict, Optional

from langgraph.graph import StateGraph, END 

from dotenv import load_dotenv
from langchain_openai import ChatOpenAI

load_dotenv()


True

## Database Connection

In [4]:
import sys
from pathlib import Path

# Make the project root importable so that `src.*` resolves. The root is
# derived from the notebook's working directory (one level up, the same
# convention the relative schema path in the next section relies on) rather
# than from a machine-specific absolute path.
# NOTE(review): assumes the notebook is run from a direct child of the project root - confirm.
PROJECT_ROOT = Path("..").resolve()
if str(PROJECT_ROOT) not in sys.path:
    # Guard against piling up duplicate entries when the cell is re-run.
    sys.path.append(str(PROJECT_ROOT))

from src.db.db_connection import get_db_connection

# A single connection and cursor are shared by every cell that executes SQL.
conn = get_db_connection()
cursor = conn.cursor()

## Load Schema

In [5]:
SCHEMA_PATH = Path("../src/schema/schema_summary.yaml")

# Read the YAML schema summary as UTF-8 text and parse it with the safe loader.
schema_context: Dict[str, Any] = yaml.safe_load(
    SCHEMA_PATH.read_text(encoding="utf-8")
)

# Last expression of the cell: display the top-level sections that were loaded.
schema_context.keys()


dict_keys(['tables', 'hints', 'common_joins'])

## Initialize LLM

In [6]:
from langchain_openai import ChatOpenAI
import os

from functools import lru_cache


@lru_cache(maxsize=1)
def load_llm():
    """Return the shared ChatOpenAI client used for SQL generation.

    The client is constructed on the first call and then reused, so repeated
    prompts do not rebuild it every time. Call ``load_llm.cache_clear()`` to
    force a rebuild, for example after changing ``OPENAI_API_KEY``.

    Returns:
        A ``ChatOpenAI`` instance for the ``gpt-4.1-mini`` model with
        ``temperature=0``; the API key is read from the ``OPENAI_API_KEY``
        environment variable at construction time.
    """
    return ChatOpenAI(
        model="gpt-4.1-mini",
        temperature=0,
        openai_api_key=os.getenv("OPENAI_API_KEY"),
    )

def call_llm(prompt: str) -> str:
    """Send ``prompt`` to the chat model and return its reply as trimmed text.

    Args:
        prompt: Prompt text handed to the model's ``invoke`` method.

    Returns:
        The ``content`` of the model response with leading and trailing
        whitespace removed.
    """
    reply = load_llm().invoke(prompt)
    text = reply.content
    return text.strip()


## Prompts

In [7]:
# Optimized SQL generation prompt

def build_optimized_prompt(
    question: str,
    schema: Dict[str, Any]
) -> str:
    """Build the prompt that asks the model for a single SQL SELECT query.

    Args:
        question: The user's natural-language question.
        schema: Schema description; it is embedded using its string form.

    Returns:
        The prompt text: a role line, the generation rules, the schema, the
        question and a trailing ``SQL:`` cue, separated by blank lines.
    """
    rules = (
        "- Output ONLY one SQL SELECT query",
        "- Do NOT include markdown, backticks, or explanations",
        "- Use ONLY tables and columns from the schema",
        "- Follow join templates strictly",
        "- Never invent joins or columns",
        "- Prefer correctness over brevity",
    )
    sections = (
        "You are an expert PostgreSQL SQL generator.",
        "CRITICAL RULES:\n" + "\n".join(rules),
        f"Database schema with semantics:\n{schema}",
        f"User question:\n{question}",
        "SQL:",
    )
    # Sections are separated by exactly one blank line.
    return "\n\n".join(sections)


# Optimized correction prompt

def build_optimized_correction_prompt(
    question: str,
    schema: Dict[str, Any],
    previous_sql: str,
    error_reason: str
) -> str:
    """Build the prompt that asks the model to repair a rejected SQL query.

    Args:
        question: The user's natural-language question.
        schema: Schema description; it is embedded using its string form.
        previous_sql: The SQL text that was rejected.
        error_reason: Why the previous SQL was rejected.

    Returns:
        The prompt text: the failure reason, the repair rules, the schema, the
        question, the invalid SQL and a trailing ``Corrected SQL:`` cue,
        separated by blank lines.
    """
    rules = (
        "- Use schema exactly as provided",
        "- Follow join templates",
        "- Do not invent columns or tables",
        "- Output ONLY corrected SQL",
    )
    sections = (
        "The SQL query below is INVALID.",
        f"Failure reason:\n{error_reason}",
        "Rules to fix:\n" + "\n".join(rules),
        f"Schema:\n{schema}",
        f"Question:\n{question}",
        f"Invalid SQL:\n{previous_sql}",
        "Corrected SQL:",
    )
    # Sections are separated by exactly one blank line.
    return "\n\n".join(sections)

In [8]:
# -------------------------------------------------------
# Generate SQL (optimized)
# -------------------------------------------------------

def generate_sql_optimized(question: str) -> str:
    """Ask the model for SQL answering ``question`` against the loaded schema.

    Args:
        question: The user's natural-language question.

    Returns:
        The model's reply exactly as ``call_llm`` returns it; no SQL cleaning
        is applied here.
    """
    return call_llm(build_optimized_prompt(question, schema_context))


In [9]:
class SQLAgentState(TypedDict):
    """State dictionary passed between the nodes of the SQL agent graph."""

    # Natural-language question supplied by the user.
    question: str
    # Current SQL candidate; None until the generation node has run.
    sql: Optional[str]
    # Outcome of the most recent validation step.
    valid: bool
    # Why the last validation or execution failed; None when there is no failure.
    reason: Optional[str]
    # Number of correction attempts made so far.
    retries: int
    # True once the current SQL has been executed without raising.
    executed: bool
    # Rows fetched from the cursor; None when nothing was fetched.
    results: Optional[list]

In [10]:
def generate_sql_node(state: SQLAgentState) -> SQLAgentState:
    """Graph node: generate SQL for the question and store it in the state.

    The raw model reply is passed through ``clean_sql`` (defined outside this
    notebook) before being stored.

    Args:
        state: Current agent state; only ``question`` is read.

    Returns:
        A copy of ``state`` with ``sql`` replaced by the cleaned candidate.
    """
    print("üîÑ Generating SQL...")
    candidate = clean_sql(generate_sql_optimized(state["question"]))
    # Only the first 100 characters are echoed to keep the log short.
    print(f"Generated: {candidate[:100]}...")
    updated = dict(state)
    updated["sql"] = candidate
    return updated


In [11]:
def validate_sql_node(state: SQLAgentState) -> SQLAgentState:
    """Graph node: check the current SQL before it is executed.

    Delegates to ``validate_sql`` (defined outside this notebook), which
    returns a ``(is_valid, reason)`` pair.

    Args:
        state: Current agent state; only ``sql`` is read.

    Returns:
        A copy of ``state`` with ``valid`` set to the check result, ``reason``
        set to the failure reason (``None`` on success) and ``executed`` reset
        to ``False``.
    """
    print("üîç Validating syntax...")
    passed, failure = validate_sql(state["sql"])
    print(f"Syntax valid: {passed}")
    if passed:
        # The validator also returns a message on success; it is not kept.
        failure = None
    else:
        print(f"Reason: {failure}")
    outcome = dict(state)
    outcome.update(valid=passed, reason=failure, executed=False)
    return outcome

In [12]:
def execute_sql_node(state: SQLAgentState) -> SQLAgentState:
    """Graph node: run the current SQL on the shared cursor and fetch all rows.

    Any exception raised while executing or fetching is caught and reported
    through the returned state instead of propagating. After a failure the
    shared connection is rolled back: on PostgreSQL a failed statement leaves
    the open transaction aborted, and every later statement on that connection
    is rejected until a rollback, which would make each corrected retry fail
    regardless of its SQL.

    Args:
        state: Current agent state; only ``sql`` is read.

    Returns:
        A copy of ``state``. On success ``executed`` is ``True``, ``results``
        holds the fetched rows and ``reason`` is ``None``. On failure
        ``executed`` is ``False``, ``results`` is ``None`` and ``reason`` is
        ``"Execution error: <message>"``.
    """
    print("‚ö° Executing SQL...")
    try:
        cursor.execute(state["sql"])
        results = cursor.fetchall()
    except Exception as e:
        print(f"‚ùå Execution failed: {str(e)[:100]}")
        # Reset the transaction so the next attempt starts from a usable
        # connection. NOTE(review): assumes `conn` is a DB-API connection.
        try:
            conn.rollback()
        except Exception as rollback_error:
            # Best effort: keep the original error as the reported reason.
            print(f"Rollback failed: {str(rollback_error)[:100]}")
        return {**state, "executed": False, "results": None, "reason": f"Execution error: {str(e)}"}
    print(f"‚úÖ Executed! Got {len(results)} rows")
    return {**state, "executed": True, "results": results, "reason": None}


In [13]:
def validate_execution_node(state: SQLAgentState) -> SQLAgentState:
    """Graph node: decide whether the executed query counts as a success.

    Args:
        state: Current agent state; ``executed`` and ``results`` are read.

    Returns:
        A copy of ``state``. If the query was not executed, ``valid`` is
        ``False`` and ``reason`` is left as it was. If it ran but produced no
        rows, ``valid`` is ``False`` and ``reason`` is
        ``"Query returned no results"``. Otherwise ``valid`` is ``True`` and
        ``reason`` is ``None``.
    """
    print("üîç Validating execution...")
    if state["executed"]:
        rows = state["results"]
        if rows:
            print("‚úÖ Validation passed!")
            return {**state, "valid": True, "reason": None}
        # An empty or missing result set is treated as a failed attempt.
        print("Failed - no results")
        return {**state, "valid": False, "reason": "Query returned no results"}
    print("Failed - execution error")
    return {**state, "valid": False}

In [14]:
def correct_sql_node(state: SQLAgentState) -> SQLAgentState:
    """Graph node: ask the model to repair the rejected SQL.

    The correction prompt is built from the question, the loaded schema, the
    rejected SQL and the recorded failure reason. The reply is passed through
    ``clean_sql`` (defined outside this notebook) before being stored.

    Args:
        state: Current agent state; ``question``, ``sql``, ``reason`` and
            ``retries`` are read.

    Returns:
        A copy of ``state`` with ``sql`` replaced by the cleaned correction and
        ``retries`` increased by one.
    """
    attempt = state["retries"] + 1
    print(f"üîß Correcting SQL (retry {attempt}/{MAX_RETRIES})...")
    fix_prompt = build_optimized_correction_prompt(
        question=state["question"],
        schema=schema_context,
        previous_sql=state["sql"],
        error_reason=state["reason"],
    )
    repaired = clean_sql(call_llm(fix_prompt))
    # Only the first 100 characters are echoed to keep the log short.
    print(f"Corrected: {repaired[:100]}...")
    updated = dict(state)
    updated.update(sql=repaired, retries=attempt)
    return updated

In [15]:
# Number of correction attempts after which the routing functions below stop
# retrying and end the run; also shown in the correction node's progress message.
MAX_RETRIES = 2

def route_after_syntax_check(state: SQLAgentState):
    """Pick the next step after the pre-execution validation.

    Returns:
        ``"execute_sql"`` when the SQL passed validation; otherwise ``END``
        once ``retries`` has reached ``MAX_RETRIES``, else ``"correct_sql"``.
    """
    if state["valid"]:
        return "execute_sql"
    # Invalid SQL: retry only while the correction budget is not used up.
    return END if state["retries"] >= MAX_RETRIES else "correct_sql"

def route_after_execution(state: SQLAgentState):
    """Pick the next step after the post-execution validation.

    Returns:
        ``END`` when the run succeeded or ``retries`` has reached
        ``MAX_RETRIES``; otherwise ``"correct_sql"``.
    """
    # Stop on success, and also once the correction budget is used up.
    if state["valid"] or state["retries"] >= MAX_RETRIES:
        return END
    return "correct_sql"


In [16]:
graph = StateGraph(SQLAgentState)

# Register each node under the name that the edges and routers refer to.
for _node_name, _node_fn in (
    ("generate_sql", generate_sql_node),
    ("validate_sql", validate_sql_node),
    ("execute_sql", execute_sql_node),
    ("validate_execution", validate_execution_node),
    ("correct_sql", correct_sql_node),
):
    graph.add_node(_node_name, _node_fn)

# Every run starts by generating SQL, which goes straight to the pre-execution check.
graph.set_entry_point("generate_sql")
graph.add_edge("generate_sql", "validate_sql")

# After the pre-execution check: execute, ask for a correction, or stop.
graph.add_conditional_edges(
    "validate_sql",
    route_after_syntax_check,
    {"execute_sql": "execute_sql", "correct_sql": "correct_sql", END: END},
)

# Execution is always followed by the post-execution check.
graph.add_edge("execute_sql", "validate_execution")

# After the post-execution check: ask for a correction or stop.
graph.add_conditional_edges(
    "validate_execution",
    route_after_execution,
    {"correct_sql": "correct_sql", END: END},
)

# A corrected query goes back through the pre-execution check.
graph.add_edge("correct_sql", "validate_sql")

sql_agent = graph.compile()

In [17]:
# Starting state: only the question is filled in; every other field is at its empty value.
initial_state: SQLAgentState = SQLAgentState(
    question="how many products are there?",
    sql=None,
    valid=False,
    reason=None,
    retries=0,
    executed=False,
    results=None,
)

final_state = sql_agent.invoke(initial_state)

# Report the outcome: a run succeeds only if the SQL was executed and then validated.
if not (final_state["valid"] and final_state["executed"]):
    print("‚ùå Failed")
    print(f"Reason: {final_state['reason']}")
else:
    print("‚úÖ Success!")
    print(f"SQL: {final_state['sql']}")
    print(f"Results: {len(final_state['results'])} rows")
    # Preview at most the first five rows.
    print(final_state['results'][:5])

üîÑ Generating SQL...
Generated: select count(*) as product_count from products;...
üîç Validating syntax...
Syntax valid: True
‚ö° Executing SQL...
‚úÖ Executed! Got 1 rows
üîç Validating execution...
‚úÖ Validation passed!
‚úÖ Success!
SQL: select count(*) as product_count from products;
Results: 1 rows
[(49688,)]
