In [None]:
# Execute the sibling notebook 03_safe_execution.ipynb before anything else in
# this notebook runs.
# NOTE(review): preprocess_sql, normalize_sql and validate_sql are called further
# down (in optimized_sql_agent) but are not defined in this notebook — presumably
# they are brought into the kernel namespace by this %run; confirm.
%run 03_safe_execution.ipynb

In [3]:
import yaml
from pathlib import Path
from typing import Tuple, Dict, Any

In [9]:
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI

# Load environment variables through python-dotenv before any client is built;
# load_llm() reads OPENAI_API_KEY from the environment via os.getenv.
load_dotenv()

True

In [4]:
import sys
from pathlib import Path
# Make the project root importable so `src.*` modules resolve.
# The root is derived relative to the notebook's working directory (the same
# convention SCHEMA_PATH uses with "../src/...") instead of a machine-specific
# absolute path, so the notebook runs on any checkout.
PROJECT_ROOT = Path("..").resolve()

# Only add the entry once: re-running this cell must not grow sys.path.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))
# from src.db.db_connection import get_db_connection

# conn = get_db_connection()
# cursor = conn.cursor()

In [6]:
# -------------------------------------------------------
# Load enriched schema (semantics + joins)
# -------------------------------------------------------

# Location of the schema summary, relative to the notebook's working directory.
SCHEMA_PATH = Path("../src/schema/schema_summary.yaml")

# Read the file as UTF-8 text first, then parse it with the safe YAML loader.
schema_yaml_text = SCHEMA_PATH.read_text(encoding="utf-8")
schema_context: Dict[str, Any] = yaml.safe_load(schema_yaml_text)

# Last expression of the cell: display the top-level keys of the loaded schema.
schema_context.keys()


dict_keys(['tables', 'hints', 'common_joins'])

In [7]:
# -------------------------------------------------------
# Optimized SQL generation prompt
# -------------------------------------------------------

def build_optimized_prompt(
    question: str,
    schema: Dict[str, Any]
) -> str:
    """Assemble the SQL-generation prompt for a user question.

    Args:
        question: Natural-language question, placed under "User question:".
        schema: Schema context, interpolated with its default string
            formatting (for a dict, its Python repr) under
            "Database schema with semantics:".

    Returns:
        The prompt text, beginning with the role line and ending with "SQL:",
        with no leading or trailing whitespace.
    """
    critical_rules = (
        "- Output ONLY one SQL SELECT query",
        "- Do NOT include markdown, backticks, or explanations",
        "- Use ONLY tables and columns from the schema",
        "- Follow join templates strictly",
        "- Never invent joins or columns",
        "- Prefer correctness over brevity",
    )
    prompt_lines = [
        "You are an expert PostgreSQL SQL generator.",
        "",
        "CRITICAL RULES:",
        *critical_rules,
        "",
        "Database schema with semantics:",
        f"{schema}",
        "",
        "User question:",
        f"{question}",
        "",
        "SQL:",
    ]
    # Sections are separated by a blank line; the final line is the "SQL:" cue.
    return "\n".join(prompt_lines)


In [16]:
from langchain_openai import ChatOpenAI
import os

def load_llm(model: str = "gpt-4.1-mini", temperature: float = 0):
    """Build the ChatOpenAI client used for SQL generation.

    The model name and temperature are parameters so other models can be tried
    without editing this function; calling ``load_llm()`` with no arguments
    behaves exactly as before.

    Args:
        model: Chat model name passed to ``ChatOpenAI``.
        temperature: Sampling temperature passed to ``ChatOpenAI``.

    Returns:
        A new ``ChatOpenAI`` instance. The API key is read from the
        ``OPENAI_API_KEY`` environment variable at call time; if the variable
        is unset, ``None`` is passed through to ``ChatOpenAI``.
    """
    return ChatOpenAI(
        model=model,
        temperature=temperature,
        openai_api_key=os.getenv("OPENAI_API_KEY"),
    )


In [11]:
# -------------------------------------------------------
# LLM call using ChatOpenAI
# -------------------------------------------------------

def call_llm(prompt: str) -> str:
    """Send ``prompt`` to the LLM and return the reply text, whitespace-trimmed.

    A client is obtained from ``load_llm()`` on every call; the reply's
    ``content`` attribute is stripped of leading and trailing whitespace.
    """
    reply = load_llm().invoke(prompt)
    reply_text = reply.content
    return reply_text.strip()


In [12]:
# -------------------------------------------------------
# Generate SQL (optimized)
# -------------------------------------------------------

def generate_sql_optimized(question: str) -> str:
    """Return the LLM's raw SQL answer for ``question``.

    The prompt is built from the question and the module-level
    ``schema_context``; the LLM response is returned as-is, with no
    preprocessing or validation applied here.
    """
    return call_llm(build_optimized_prompt(question, schema_context))


In [13]:
# -------------------------------------------------------
# Optimized correction prompt
# -------------------------------------------------------

def build_optimized_correction_prompt(
    question: str,
    schema: Dict[str, Any],
    previous_sql: str,
    error_reason: str
) -> str:
    """Assemble the prompt asking the LLM to repair a rejected SQL query.

    Args:
        question: Original natural-language question.
        schema: Schema context, interpolated with its default string
            formatting (for a dict, its Python repr) under "Schema:".
        previous_sql: The SQL that was rejected, placed under "Invalid SQL:".
        error_reason: Why it was rejected, placed under "Failure reason:".

    Returns:
        The prompt text, beginning with the "INVALID" notice and ending with
        "Corrected SQL:", with no leading or trailing whitespace.
    """
    fix_rules = (
        "- Use schema exactly as provided",
        "- Follow join templates",
        "- Do not invent columns or tables",
        "- Output ONLY corrected SQL",
    )
    prompt_lines = [
        "The SQL query below is INVALID.",
        "",
        "Failure reason:",
        f"{error_reason}",
        "",
        "Rules to fix:",
        *fix_rules,
        "",
        "Schema:",
        f"{schema}",
        "",
        "Question:",
        f"{question}",
        "",
        "Invalid SQL:",
        f"{previous_sql}",
        "",
        "Corrected SQL:",
    ]
    # Sections are separated by a blank line; the final line is the answer cue.
    return "\n".join(prompt_lines)


In [None]:
def optimized_sql_agent(
    question: str,
    max_retries: int = 2
) -> Tuple[bool, str]:
    """Generate SQL for a question, validating it and self-correcting via the LLM.

    The first candidate comes from ``generate_sql_optimized``. Every candidate
    is passed through ``preprocess_sql`` and ``normalize_sql`` and then checked
    with ``validate_sql``. When validation fails, the rejected SQL and the
    failure reason are sent back to the LLM in a correction prompt, at most
    ``max_retries`` times.

    NOTE(review): ``preprocess_sql``, ``normalize_sql`` and ``validate_sql`` are
    not defined in this notebook; presumably they are provided by the
    ``%run 03_safe_execution.ipynb`` cell — confirm.

    Args:
        question: Natural-language question to translate into SQL.
        max_retries: Maximum number of correction rounds after the initial
            attempt. With 0 (or a negative value) no correction is attempted.

    Returns:
        ``(True, sql)`` with the first candidate that validates, or
        ``(False, reason)`` with the failure reason of the last candidate.
    """
    sql = normalize_sql(preprocess_sql(generate_sql_optimized(question)))
    is_valid, reason = validate_sql(sql)
    if is_valid:
        return True, sql

    for _ in range(max_retries):
        # Feed back the SQL that just failed validation. The previous code
        # passed `cleaned_sql`, a name never defined in this function, which
        # raised NameError on a fresh kernel (or silently used a stale global).
        correction_prompt = build_optimized_correction_prompt(
            question,
            schema_context,
            sql,
            reason,
        )

        sql = normalize_sql(preprocess_sql(call_llm(correction_prompt)))
        is_valid, reason = validate_sql(sql)
        if is_valid:
            return True, sql

    return False, reason


In [21]:
# Example run: ask one question and print the agent's outcome.
question = "Which products are most frequently reordered? give top 10"

success, result = optimized_sql_agent(question)

# `result` is the validated SQL on success, otherwise the last failure reason.
print("Final SQL:" if success else "Failed:")
print(result)


Final SQL:
select p.product_id, p.product_name, sum(op.reordered) as total_reorders from order_products_prior op inner join products p on op.product_id = p.product_id group by p.product_id, p.product_name order by total_reorders desc limit 10;
