In [140]:
from langgraph.graph import StateGraph, START, END
from langchain_google_genai import ChatGoogleGenerativeAI
from typing import TypedDict
import sqlite3
import json
import os
import re
from dotenv import load_dotenv

In [141]:
load_dotenv()

True

In [142]:
model = ChatGoogleGenerativeAI(
    model = 'gemini-2.5-flash',
    google_api_key = os.getenv('GEMINI_API_KEY')
)

E0000 00:00:1759699116.507699   67601 alts_credentials.cc:93] ALTS creds ignored. Not running on GCP and untrusted ALTS is not enabled.


In [143]:
class AgentState(TypedDict):
    user_query: str
    sql_query: str
    validation_passed: bool
    feedback: str
    result: str
    iteration_count: int 


In [144]:
# # -----------------------------
# # 2️⃣ Node 1 - Generate SQL
# # -----------------------------
# def generate_sql(state: AgentState):

#     # Load metadata to tell the model the schema
#     with open("indian_deserts.json") as f:
#         metadata = json.load(f)

#     table = "indian_desserts"
#     schema_info = metadata[table]
#     columns = list(schema_info["columns"].keys())

#     prompt = f"""
#     You are a SQL expert. Generate a valid SQLite query for the user request.
#     Use table '{table}' and columns: {columns}.

#     User Query: {state['user_query']}
#     Validator Feedback: {state.get('feedback', 'None')}
#     """

#     response = model.invoke(prompt)
#     sql_query = response.content.strip("`").replace("sql", "").strip()

#     print("\n🧠 Generated SQL:", sql_query)
    
#     return {**state, "sql_query": sql_query}

def generate_sql(state: AgentState):
    with open("indian_deserts.json") as f:
        metadata = json.load(f)

    table = "indian_desserts"
    schema_info = metadata[table]
    
    # Build richer schema description
    schema_desc = []
    for col_name, col_info in schema_info["columns"].items():
        schema_desc.append(
            f"- {col_name} ({col_info['type']}): {col_info['description']}"
        )
    schema_text = "\n".join(schema_desc)

    prompt = f"""
You are a SQL expert. Generate a valid SQLite query for the user's request.

**Table:** {table}
**Schema:**
{schema_text}

**User Query:** {state['user_query']}
**Validator Feedback:** {state.get('feedback', 'None')}

**CRITICAL INSTRUCTIONS:**
- Return ONLY the raw SQL query with NO extra text, explanations, or prefixes
- Do NOT include markdown code blocks, backticks, or the word "sql"
- Do NOT end the query with a semicolon
- Use proper SQLite syntax
- **When asked about desserts, ALWAYS add: WHERE course = 'dessert'**
- If feedback is provided, fix the issues mentioned

Example: SELECT name, state FROM indian_desserts WHERE course = 'dessert' AND state = 'West Bengal'
"""

    response = model.invoke(prompt)
    
    # ROBUST CLEANING
    sql_query = response.content.strip()
    
    # Remove markdown code blocks
    sql_query = re.sub(r'^```sql\s*', '', sql_query, flags=re.IGNORECASE)
    sql_query = re.sub(r'^```\s*', '', sql_query)
    sql_query = re.sub(r'```$', '', sql_query)
    
    # Remove common prefixes
    sql_query = re.sub(r'^\s*sql\s*', '', sql_query, flags=re.IGNORECASE)
    sql_query = re.sub(r'^\s*sqlite\s*', '', sql_query, flags=re.IGNORECASE)
    
    # Remove trailing semicolon
    sql_query = sql_query.rstrip('; \n\t')
    
    # CRITICAL: Remove any garbage before SELECT
    if not sql_query.upper().startswith('SELECT'):
        match = re.search(r'\bSELECT\b', sql_query, re.IGNORECASE)
        if match:
            sql_query = sql_query[match.start():]
    
    sql_query = sql_query.strip()

    print(f"\n🧠 Generated SQL (Iteration {state['iteration_count']}): {sql_query}")
    
    return {
        **state, 
        "sql_query": sql_query,
        "iteration_count": state["iteration_count"] + 1
    }


In [145]:
def validate_sql(state: AgentState) -> AgentState:
    
    sql_query = state['sql_query']
    feedback = []
    validation_passed = True
    
    # Load metadata and valid columns
    with open("indian_deserts.json") as f:
        metadata = json.load(f)
    
    valid_columns = set(metadata["indian_desserts"]["columns"].keys())
    
    # ------------------------------------------------
    # 1️⃣ LLM-based Column Name Extraction (YOUR IDEA)
    # ------------------------------------------------
    llm_prompt = f"""
Extract ONLY the column names from this SQL query:
---
{sql_query}
---

Rules:
- Extract columns from SELECT and WHERE clauses
- Ignore: SQL keywords, table names, string literals (like 'West Bengal')
- For SELECT *, return: *
- Return ONLY comma-separated column names, nothing else

Examples:
Query: SELECT name, state FROM indian_desserts WHERE prep_time < 30
Output: name, state, prep_time

Query: SELECT * FROM indian_desserts
Output: *
    """
    
    try:
        # NOTE: You'd use your existing 'model' object here
        response = model.invoke(llm_prompt)
        # Clean the response to get a list
        extracted_cols_list = [
            c.strip().lower() 
            for c in response.content.split(',') 
            if c.strip()
        ]
        extracted_columns = set(extracted_cols_list)
        
        print(f"\n🔍 LLM extracted columns: {extracted_columns}")
        
    except Exception as e:
        feedback.append(f"LLM failed to extract column names: {str(e)}")
        validation_passed = False
        extracted_columns = set()
    
    # ------------------------------------------------
    # 2️⃣ Validate column names against metadata (SAME LOGIC)
    # ------------------------------------------------
    # Convert valid_columns to a set for fast checking
    invalid_columns = extracted_columns - valid_columns
    
    if invalid_columns:
        feedback.append(f"Invalid column names found: {invalid_columns}. Valid columns are: {valid_columns}")
        validation_passed = False
        print(f"❌ Invalid columns: {invalid_columns}")
    else:
        print(f"✅ All columns are valid")
        
   # ------------------------------------------------
    # 3️⃣ Dry run the SQL query on desserts.db
    # ------------------------------------------------
    try:
        conn = sqlite3.connect("desserts.db")
        cursor = conn.cursor()
        
        # Execute query with LIMIT to avoid large results
        test_query = sql_query
        if 'LIMIT' not in sql_query.upper():
            test_query += " LIMIT 5"
        
        cursor.execute(test_query)
        results = cursor.fetchall()
        
        print(f"✅ Dry run successful. Sample results: {results[:2]}")
        conn.close()
        
    except sqlite3.Error as e:
        feedback.append(f"SQL execution error: {str(e)}")
        validation_passed = False
        print(f"❌ SQL Error: {str(e)}")
    
    # ------------------------------------------------
    # Return updated state
    # ------------------------------------------------
    feedback_str = " | ".join(feedback) if feedback else "All validations passed"
    
    return {
        **state,
        "validation_passed": validation_passed,
        "feedback": feedback_str
    }

In [146]:
# -----------------------------
# 4️⃣ Node 3 - Execute SQL
# -----------------------------
def execute_sql(state: AgentState):
    with open("indian_deserts.json") as f:
        metadata = json.load(f)

        # Assuming the database path is fixed outside the metadata
    db_path = "desserts.db"

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    try:
        cursor.execute(state["sql_query"])
        rows = cursor.fetchall()
        result = rows if rows else "No results found."
        print("📊 Query Result:", result)
    except Exception as e:
        result = f"Execution Error: {e}"
        print(result)
    finally:
        conn.close()

    return {**state, "result": str(result)}

In [147]:
# Conditional edge for looping logic
def validation_check(state: AgentState):
    return "execute_sql" if state["validation_passed"] else "generate_sql"

In [148]:
# -----------------------------
# 5️⃣ Define Graph Structure
# -----------------------------
graph = StateGraph(AgentState)

# Add nodes
graph.add_node("generate_sql", generate_sql)
graph.add_node("validate_sql", validate_sql)
graph.add_node("execute_sql", execute_sql)

# Add edges
graph.add_edge(START, "generate_sql")  # <-- Entry point!
graph.add_edge("generate_sql", "validate_sql")
graph.add_conditional_edges("validate_sql", validation_check, '')
graph.add_edge("execute_sql", END)

# Compile the graph
workflow = graph.compile()

In [149]:
if __name__ == "__main__":
    print("🍰 Welcome to the Indian Desserts SQL Agent!")
    user_query = input("Enter your question: ")

    initial_state = {
        "user_query": user_query,
        "sql_query": "",
        "validation_passed": False,
        "feedback": "",
        "result": "",
        "iteration_count": 0    # if generate_sql accesses columns
    }

    result = workflow.invoke(initial_state)

    print("\n🎯 Final Query:", result["sql_query"])
    print("📈 Final Result:", result["result"])

🍰 Welcome to the Indian Desserts SQL Agent!

🧠 Generated SQL (Iteration 0): SELECT * FROM indian_desserts WHERE state = 'West Bengal' AND course = 'dessert'

🔍 LLM extracted columns: {'course', 'state', '*'}
❌ Invalid columns: {'*'}
✅ Dry run successful. Sample results: [('Balu shahi', 'Maida flour, yogurt, oil, sugar', 'vegetarian', 45, 25, 'sweet', 'dessert', 'West Bengal', 'East'), ('Gulab jamun', 'Milk powder, plain flour, baking powder, ghee, milk, sugar, water, rose water', 'vegetarian', 15, 40, 'sweet', 'dessert', 'West Bengal', 'East')]

🧠 Generated SQL (Iteration 1): SELECT name, ingredients, diet, prep_time, cook_time, flavor_profile, course, state, region FROM indian_desserts WHERE course = 'dessert' AND state = 'West Bengal'

🔍 LLM extracted columns: {"operators)) + r')\\s*'\n        \n        for condition in conditions:\n            condition = condition.strip()\n            if not condition:\n                continue\n            \n            # split each condition by a