# SQL Query Agent with Ollama - Exploration Notebook

**Sprint 1 Deliverable** | **DSM 1.0 Track**

This notebook builds and evaluates a natural-language-to-SQL agent using open-source LLMs running locally via Ollama. It also serves as a text-to-code generation testbed: SQL is a constrained language, which makes it well suited to systematic evaluation.

**Architecture:** LangGraph state graph with schema filtering, SQL validation (sqlglot), and an error-correction retry loop. Design informed by DIN-SQL, MAC-SQL, and CHESS research.

**Goals:**
1. Verify environment (Ollama, models, database)
2. Build LangGraph agent: `schema_filter` → `generate_sql` → `validate_query` → `execute_query` → `handle_error`
3. Evaluate at least two models (`sqlcoder:7b` vs. `llama3.1:8b`) on a curated test suite

**References:**
- Research: `docs/research/text_to_sql_state_of_art.md`
- Plan: `docs/plans/PLAN.md`
- Sprint plan: `docs/plans/sprint-1-plan.md`


In [1]:
# Cell 2: Imports & Configuration
from sqlalchemy import create_engine, inspect, text
from langchain_ollama import ChatOllama
import sqlglot
import time

# Configuration
OLLAMA_BASE_URL = "http://172.27.64.1:11434"  # Ollama on Windows, accessed via WSL gateway
DB_PATH = "../data/chinook.db"
PRIMARY_MODEL = "sqlcoder:7b"
BASELINE_MODEL = "llama3.1:8b"

# Database engine
engine = create_engine(f"sqlite:///{DB_PATH}")
print(f"Database engine created: {engine.url}")


Database engine created: sqlite:///../data/chinook.db


In [2]:
# Cell 3: Database Schema Inspection

inspector = inspect(engine)
tables = inspector.get_table_names()

print(f"Chinook Database: {len(tables)} tables\n")

schema_info = {}
for table_name in tables:
    columns = inspector.get_columns(table_name)
    pk = inspector.get_pk_constraint(table_name)
    fks = inspector.get_foreign_keys(table_name)
    
    schema_info[table_name] = {
        "columns": columns,
        "pk": pk.get("constrained_columns", []),
        "fks": fks,
    }
    
    # Display
    pk_cols = set(pk.get("constrained_columns", []))
    print(f"{table_name}")
    for col in columns:
        marker = " (PK)" if col["name"] in pk_cols else ""
        print(f"   {col['name']}: {col['type']}{marker}")
    for fk in fks:
        print(f"   FK: {fk['constrained_columns']} -> {fk['referred_table']}.{fk['referred_columns']}")
    print()


Chinook Database: 11 tables

Album
   AlbumId: INTEGER (PK)
   Title: NVARCHAR(160)
   ArtistId: INTEGER
   FK: ['ArtistId'] -> Artist.['ArtistId']

Artist
   ArtistId: INTEGER (PK)
   Name: NVARCHAR(120)

Customer
   CustomerId: INTEGER (PK)
   FirstName: NVARCHAR(40)
   LastName: NVARCHAR(20)
   Company: NVARCHAR(80)
   Address: NVARCHAR(70)
   City: NVARCHAR(40)
   State: NVARCHAR(40)
   Country: NVARCHAR(40)
   PostalCode: NVARCHAR(10)
   Phone: NVARCHAR(24)
   Fax: NVARCHAR(24)
   Email: NVARCHAR(60)
   SupportRepId: INTEGER
   FK: ['SupportRepId'] -> Employee.['EmployeeId']

Employee
   EmployeeId: INTEGER (PK)
   LastName: NVARCHAR(20)
   FirstName: NVARCHAR(20)
   Title: NVARCHAR(30)
   ReportsTo: INTEGER
   BirthDate: DATETIME
   HireDate: DATETIME
   Address: NVARCHAR(70)
   City: NVARCHAR(40)
   State: NVARCHAR(40)
   Country: NVARCHAR(40)
   PostalCode: NVARCHAR(10)
   Phone: NVARCHAR(24)
   Fax: NVARCHAR(24)
   Email: NVARCHAR(60)
   FK: ['ReportsTo'] -> Employee.['Employe

In [3]:
# Cell 4: Row Counts

print("Row counts:")
with engine.connect() as conn:
    for table_name in tables:
        count = conn.execute(text(f"SELECT COUNT(*) FROM [{table_name}]")).scalar()
        print(f"  {table_name}: {count:,} rows")


Row counts:
  Album: 347 rows
  Artist: 275 rows
  Customer: 59 rows
  Employee: 8 rows
  Genre: 25 rows
  Invoice: 412 rows
  InvoiceLine: 2,240 rows
  MediaType: 5 rows
  Playlist: 18 rows
  PlaylistTrack: 8,715 rows
  Track: 3,503 rows


In [4]:
# Cell 5: Verify Ollama Connectivity & Available Models
import requests

try:
    response = requests.get(f"{OLLAMA_BASE_URL}/api/tags", timeout=10)
    response.raise_for_status()
    models = response.json().get("models", [])
    print(f"Ollama is running at {OLLAMA_BASE_URL}")
    print(f"Available models: {len(models)}\n")
    for m in models:
        size_gb = m.get("size", 0) / (1024**3)
        print(f"  {m['name']:30s} {size_gb:.1f} GB")
except requests.ConnectionError:
    print(f"ERROR: Cannot connect to Ollama at {OLLAMA_BASE_URL}")
    print("Make sure Ollama is running on the Windows host.")
except Exception as e:
    print(f"ERROR: {e}")


Ollama is running at http://172.27.64.1:11434
Available models: 4

  llama3.1:8b                    4.6 GB
  sqlcoder:7b                    3.8 GB
  gemma3:1b                      0.8 GB
  llama3:latest                  4.3 GB


In [5]:
# Cell 6: Test LLM Connectivity
for model_name in [PRIMARY_MODEL, BASELINE_MODEL]:
    print(f"Testing {model_name}...")
    llm = ChatOllama(model=model_name, base_url=OLLAMA_BASE_URL, temperature=0)
    t0 = time.time()
    response = llm.invoke("Return only the SQL: SELECT 1")
    elapsed = time.time() - t0
    print(f"  Response: {response.content.strip()[:100]}")
    print(f"  Latency: {elapsed:.1f}s\n")


Testing sqlcoder:7b...
  Response: AS "column_name" FROM DUAL UNION ALL SELECT 2 AS "column_name" FROM DUAL UNION ALL SELECT 3 AS "colu
  Latency: 31.9s

Testing llama3.1:8b...
  Response: `SELECT 1;`
  Latency: 10.9s



In [6]:
# Cell 7: Sample Rows for Few-Shot Prompting
sample_tables = ["Artist", "Album", "Track", "Customer", "Invoice", "InvoiceLine"]

sample_rows = {}
with engine.connect() as conn:
    for table_name in sample_tables:
        # Run the query once; column names and rows both come from the same result
        result = conn.execute(text(f"SELECT * FROM [{table_name}] LIMIT 3"))
        keys = list(result.keys())
        rows = result.fetchall()
        sample_rows[table_name] = {"columns": keys, "rows": rows}
        
        print(f"{table_name}:")
        print(f"  Columns: {list(keys)}")
        for row in rows:
            print(f"  {list(row)}")
        print()


Artist:
  Columns: ['ArtistId', 'Name']
  [1, 'AC/DC']
  [2, 'Accept']
  [3, 'Aerosmith']

Album:
  Columns: ['AlbumId', 'Title', 'ArtistId']
  [1, 'For Those About To Rock We Salute You', 1]
  [2, 'Balls to the Wall', 2]
  [3, 'Restless and Wild', 2]

Track:
  Columns: ['TrackId', 'Name', 'AlbumId', 'MediaTypeId', 'GenreId', 'Composer', 'Milliseconds', 'Bytes', 'UnitPrice']
  [1, 'For Those About To Rock (We Salute You)', 1, 1, 1, 'Angus Young, Malcolm Young, Brian Johnson', 343719, 11170334, 0.99]
  [2, 'Balls to the Wall', 2, 2, 1, 'U. Dirkschneider, W. Hoffmann, H. Frank, P. Baltes, S. Kaufmann, G. Hoffmann', 342562, 5510424, 0.99]
  [3, 'Fast As a Shark', 3, 2, 1, 'F. Baltes, S. Kaufman, U. Dirkscneider & W. Hoffman', 230619, 3990994, 0.99]

Customer:
  Columns: ['CustomerId', 'FirstName', 'LastName', 'Company', 'Address', 'City', 'State', 'Country', 'PostalCode', 'Phone', 'Fax', 'Email', 'SupportRepId']
  [1, 'Luís', 'Gonçalves', 'Embraer - Empresa Brasileira de Aeronáutica S.A.'

## Phase 2: Core Agent Build

LangGraph agent with 5 nodes:
1. **schema_filter** — select relevant tables for the question
2. **generate_sql** — LLM generates SQL from question + filtered schema
3. **validate_query** — sqlglot parses SQL, security check (block writes)
4. **execute_query** — run against SQLite, capture results or errors
5. **handle_error** — on failure, feed error back to LLM for retry (max 3)
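
The node flow above can be sketched as a plain Python loop before any LangGraph wiring. This is only an illustration of the control flow: `run_pipeline` and the four stub callables are hypothetical names, not part of the agent built below, and the retry budget of 3 is taken from the list above.

```python
from typing import Callable, Optional

MAX_RETRIES = 3  # retry budget taken from the node list above


def run_pipeline(
    question: str,
    generate: Callable[[str], str],
    validate: Callable[[str], Optional[str]],
    execute: Callable[[str], list],
    repair: Callable[[str, str, str], str],
) -> dict:
    """Plain-Python rendering of the node flow (hypothetical helper)."""
    sql = generate(question)
    retries = 0
    while True:
        error = validate(sql)  # None means the SQL passed validation
        results = None
        if error is None:
            try:
                results = execute(sql)
            except Exception as exc:  # execution errors feed the repair step
                error = str(exc)
        if error is None or retries >= MAX_RETRIES:
            return {"sql": sql, "results": results, "error": error or "", "retries": retries}
        sql = repair(question, sql, error)
        retries += 1


# Stub callables standing in for the LLM and the database.
def fake_generate(question: str) -> str:
    return "SELECT Foo FROM Artist"


def fake_validate(sql: str) -> Optional[str]:
    return None if sql.upper().startswith("SELECT") else "not a SELECT"


def fake_execute(sql: str) -> list:
    if "Foo" in sql:
        raise RuntimeError("no such column: Foo")
    return [["AC/DC"]]


def fake_repair(question: str, sql: str, error: str) -> str:
    return sql.replace("Foo", "Name")


outcome = run_pipeline("List artists", fake_generate, fake_validate, fake_execute, fake_repair)
print(outcome)
```

With these stubs the first attempt fails at execution, the repair step rewrites the column, and the second attempt succeeds after one retry.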


In [7]:
# Cell 9: Agent State Definition
from typing import TypedDict, Optional

class AgentState(TypedDict):
    question: str                    # User's natural language question
    relevant_tables: list[str]       # Tables selected by schema_filter
    schema_text: str                 # DDL/schema for relevant tables
    generated_sql: str               # SQL produced by LLM
    is_valid: bool                   # sqlglot parse + security check passed
    validation_error: str            # Error message if validation fails
    results: Optional[list]          # Query results from SQLite
    error: str                       # Execution error message
    retry_count: int                 # Current retry attempt (max 3)
    model_name: str                  # Which Ollama model to use

print("AgentState defined with fields:")
for field, ftype in AgentState.__annotations__.items():
    print(f"  {field}: {ftype}")


AgentState defined with fields:
  question: <class 'str'>
  relevant_tables: list[str]
  schema_text: <class 'str'>
  generated_sql: <class 'str'>
  is_valid: <class 'bool'>
  validation_error: <class 'str'>
  results: typing.Optional[list]
  error: <class 'str'>
  retry_count: <class 'int'>
  model_name: <class 'str'>


In [8]:
# Cell 10: Node 1 — Schema Filter
def schema_filter(state: AgentState) -> dict:
    """Select relevant tables based on question keywords."""
    question_lower = state["question"].lower()
    question_words = set(question_lower.replace("?", "").replace(",", "").split())
    
    scored_tables = []
    for table_name, info in schema_info.items():
        score = 0
        table_lower = table_name.lower()
        
        # Table name match (strongest signal)
        if table_lower in question_lower:
            score += 3
        # Partial table name match
        for word in question_words:
            if word in table_lower or table_lower in word:
                score += 2
        # Column name match
        for col in info["columns"]:
            col_lower = col["name"].lower()
            if col_lower in question_lower:
                score += 1
            for word in question_words:
                if word in col_lower:
                    score += 0.5
        
        if score > 0:
            scored_tables.append((table_name, score))
    
    # Sort by score, take top tables; always include at least FK-connected tables
    scored_tables.sort(key=lambda x: x[1], reverse=True)
    selected = [t[0] for t in scored_tables[:5]]  # max 5 tables
    
    # Add FK-connected tables for selected tables
    for table_name in list(selected):
        for fk in schema_info[table_name]["fks"]:
            referred = fk["referred_table"]
            if referred not in selected:
                selected.append(referred)
    
    # Fallback: if nothing matched, include all tables
    if not selected:
        selected = list(schema_info.keys())
    
    # Build schema text for selected tables
    schema_lines = []
    for table_name in selected:
        info = schema_info[table_name]
        cols = ", ".join(
            f"{c['name']} {c['type']}" for c in info["columns"]
        )
        schema_lines.append(f"CREATE TABLE {table_name} ({cols});")
        # Add sample rows
        if table_name in sample_rows:
            sr = sample_rows[table_name]
            schema_lines.append(f"-- Sample: {sr['rows'][0]}")
    
    schema_text = "\n".join(schema_lines)
    
    print(f"Question: {state['question']}")
    print(f"Selected tables ({len(selected)}): {selected}")
    print(f"Schema text length: {len(schema_text)} chars")
    
    return {"relevant_tables": selected, "schema_text": schema_text}

# Quick test
test_state = {
    "question": "How many albums does each artist have?",
    "relevant_tables": [], "schema_text": "", "generated_sql": "",
    "is_valid": False, "validation_error": "", "results": None,
    "error": "", "retry_count": 0, "model_name": PRIMARY_MODEL,
}
schema_filter(test_state)


Question: How many albums does each artist have?
Selected tables (2): ['Album', 'Artist']
Schema text length: 219 chars


{'relevant_tables': ['Album', 'Artist'],
 'schema_text': "CREATE TABLE Album (AlbumId INTEGER, Title NVARCHAR(160), ArtistId INTEGER);\n-- Sample: (1, 'For Those About To Rock We Salute You', 1)\nCREATE TABLE Artist (ArtistId INTEGER, Name NVARCHAR(120));\n-- Sample: (1, 'AC/DC')"}
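
The scorer above counts a question word whenever it appears as a substring of a table or column name, so short or generic words can pull in unrelated tables. A stricter variant, sketched here under assumptions, compares whole tokens instead. The stop-word list, the naive plural rule, and the helper names `split_identifier`, `question_tokens`, and `score_table` are all hypothetical.

```python
import re

# Hypothetical stop-word list; tune it for the actual question set.
STOP_WORDS = {"a", "all", "and", "are", "by", "does", "each", "for", "have",
              "how", "many", "of", "show", "the", "top", "what"}


def split_identifier(name: str) -> set[str]:
    """Split a CamelCase identifier into lowercase tokens, e.g. InvoiceLine -> invoice, line."""
    return {part.lower() for part in re.findall(r"[A-Z]?[a-z]+|[A-Z]+|\d+", name)}


def question_tokens(question: str) -> set[str]:
    """Lowercase word tokens minus stop words, plus a naive singular form."""
    tokens = set()
    for word in re.findall(r"[a-z0-9]+", question.lower()):
        if word in STOP_WORDS:
            continue
        tokens.add(word)
        if word.endswith("s") and len(word) > 3:
            tokens.add(word[:-1])  # "albums" also yields "album"
    return tokens


def score_table(question: str, table: str, columns: list[str]) -> float:
    """Score a table by whole-token overlap with the question."""
    q = question_tokens(question)
    score = 3.0 * len(q & split_identifier(table))
    for col in columns:
        score += 0.5 * len(q & split_identifier(col))
    return score


question = "How many albums does each artist have? Show artist name and count, top 5."
album_score = score_table(question, "Album", ["AlbumId", "Title", "ArtistId"])
genre_score = score_table(question, "Genre", ["GenreId", "Name"])
print(album_score, genre_score)
```

Generic column words such as `Name` still produce a small score for unrelated tables, so a minimum-score threshold would probably be needed on top of this.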

In [9]:
# Cell 11: Node 2 — Generate SQL
SQL_PROMPT_TEMPLATE = """You are a SQL expert. Generate a SQLite-compatible SELECT query for the question below.

Schema:
{schema_text}

Rules:
- Return ONLY the SQL query, no explanation
- Use only SELECT statements
- Use only tables and columns from the schema above
- Use SQLite syntax

Question: {question}

SQL:"""

def generate_sql(state: AgentState) -> dict:
    """Generate SQL from question + filtered schema using LLM."""
    prompt = SQL_PROMPT_TEMPLATE.format(
        schema_text=state["schema_text"],
        question=state["question"],
    )
    
    llm = ChatOllama(
        model=state["model_name"],
        base_url=OLLAMA_BASE_URL,
        temperature=0,
    )
    
    t0 = time.time()
    response = llm.invoke(prompt)
    elapsed = time.time() - t0
    
    # Clean response: strip markdown fences and whitespace
    sql = response.content.strip()
    sql = sql.replace("```sql", "").replace("```", "").strip()
    
    print(f"Model: {state['model_name']}")
    print(f"Generated SQL: {sql}")
    print(f"Latency: {elapsed:.1f}s")
    
    return {"generated_sql": sql}

# Quick test
test_state_filtered = {**test_state, **schema_filter(test_state)}
generate_sql(test_state_filtered)


Question: How many albums does each artist have?
Selected tables (2): ['Album', 'Artist']
Schema text length: 219 chars
Model: sqlcoder:7b
Generated SQL: 
Latency: 8.1s


{'generated_sql': ''}

In [10]:
# Cell 12: Diagnose sqlcoder prompt format
SQLCODER_PROMPT_TEMPLATE = """### Task
Generate a SQL query to answer the following question:
`{question}`

### Database Schema
{schema_text}

### Answer
Given the database schema, here is the SQL query that answers `{question}`:
```sql
"""

prompt_generic = SQL_PROMPT_TEMPLATE.format(
    schema_text=test_state_filtered["schema_text"],
    question=test_state_filtered["question"],
)
prompt_sqlcoder = SQLCODER_PROMPT_TEMPLATE.format(
    schema_text=test_state_filtered["schema_text"],
    question=test_state_filtered["question"],
)

llm = ChatOllama(model=PRIMARY_MODEL, base_url=OLLAMA_BASE_URL, temperature=0)

print("=== Generic prompt ===")
t0 = time.time()
r1 = llm.invoke(prompt_generic)
print(f"Response: '{r1.content.strip()[:200]}'")
print(f"Latency: {time.time() - t0:.1f}s\n")

print("=== sqlcoder-style prompt ===")
t0 = time.time()
r2 = llm.invoke(prompt_sqlcoder)
print(f"Response: '{r2.content.strip()[:200]}'")
print(f"Latency: {time.time() - t0:.1f}s")


=== Generic prompt ===
Response: ''
Latency: 22.5s

=== sqlcoder-style prompt ===
Response: 'SELECT a.artistid, COUNT(*) AS num_albums FROM album a GROUP BY a.artistid;
```'
Latency: 5.5s
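
The sqlcoder-style prompt ends with an opening fence, so the reply is SQL followed by a lone closing fence, while other models may wrap the SQL in a complete fenced block or in single backticks. A minimal sketch of an extractor that handles these shapes, using only the standard library; `extract_sql` is a hypothetical helper, not the cleanup used in this notebook.

```python
import re

FENCE = "`" * 3  # three backticks, built this way so the sketch can sit inside a fenced block


def extract_sql(raw: str) -> str:
    """Return the SQL portion of an LLM reply (hypothetical helper)."""
    text = raw.strip()
    block = re.search(FENCE + r"(?:sql)?\s*(.*?)" + FENCE, text, flags=re.DOTALL | re.IGNORECASE)
    if block:
        # A complete fenced block: keep only what is inside it.
        text = block.group(1)
    elif FENCE in text:
        before, after = text.split(FENCE, 1)
        # A lone fence either closes SQL that the prompt opened,
        # or opens a block that the model never closed.
        text = before if before.strip() else re.sub(r"^sql\s", "", after, flags=re.IGNORECASE)
    # Drop single backticks and a trailing semicolon.
    return text.strip().strip("`").strip().rstrip(";").strip()


closing_only = "SELECT a.artistid FROM album a;\n" + FENCE
inline_ticks = "`SELECT 1;`"
chatty = "Here is the query:\n" + FENCE + "sql\nSELECT Name FROM Artist;\n" + FENCE + "\nHope this helps."

for reply in (closing_only, inline_ticks, chatty):
    print(repr(extract_sql(reply)))
```

Unlike a plain `str.replace` on the fence markers, this drops any prose the model adds around a fenced block.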


In [11]:
# Cell 13: Node 2 (revised) — Generate SQL with model-aware prompts
GENERIC_PROMPT = """You are a SQL expert. Generate a SQLite-compatible SELECT query for the question below.

Schema:
{schema_text}

Rules:
- Return ONLY the SQL query, no explanation
- Use only SELECT statements
- Use only tables and columns from the schema above
- Use SQLite syntax

Question: {question}

SQL:"""

SQLCODER_PROMPT = """### Task
Generate a SQL query to answer the following question:
`{question}`

### Database Schema
{schema_text}

### Answer
Given the database schema, here is the SQL query that answers `{question}`:
```sql
"""

def generate_sql(state: AgentState) -> dict:
    """Generate SQL from question + filtered schema using LLM."""
    model = state["model_name"]
    
    # Select prompt template based on model
    if "sqlcoder" in model:
        template = SQLCODER_PROMPT
    else:
        template = GENERIC_PROMPT
    
    prompt = template.format(
        schema_text=state["schema_text"],
        question=state["question"],
    )
    
    llm = ChatOllama(model=model, base_url=OLLAMA_BASE_URL, temperature=0)
    
    t0 = time.time()
    response = llm.invoke(prompt)
    elapsed = time.time() - t0
    
    # Clean response: strip markdown fences and whitespace
    sql = response.content.strip()
    sql = sql.replace("```sql", "").replace("```", "").strip()
    # Remove trailing semicolons for consistency
    sql = sql.rstrip(";").strip()
    
    print(f"Model: {model}")
    print(f"Prompt: {'sqlcoder' if 'sqlcoder' in model else 'generic'}")
    print(f"Generated SQL: {sql}")
    print(f"Latency: {elapsed:.1f}s")
    
    return {"generated_sql": sql}

# Test with both models
test_state_filtered = {**test_state, **schema_filter(test_state)}
print("--- sqlcoder:7b ---")
generate_sql(test_state_filtered)
print()
print("--- llama3.1:8b ---")
generate_sql({**test_state_filtered, "model_name": BASELINE_MODEL})


Question: How many albums does each artist have?
Selected tables (2): ['Album', 'Artist']
Schema text length: 219 chars
--- sqlcoder:7b ---
Model: sqlcoder:7b
Prompt: sqlcoder
Generated SQL: SELECT a.artistid, COUNT(*) AS num_albums FROM album a GROUP BY a.artistid
Latency: 4.0s

--- llama3.1:8b ---
Model: llama3.1:8b
Prompt: generic
Generated SQL: SELECT A.Name, COUNT(AlbumId) AS AlbumCount
FROM Artist A
JOIN Album ON A.ArtistId = Album.ArtistId
GROUP BY A.Name
Latency: 17.2s


{'generated_sql': 'SELECT A.Name, COUNT(AlbumId) AS AlbumCount\nFROM Artist A\nJOIN Album ON A.ArtistId = Album.ArtistId\nGROUP BY A.Name'}

In [12]:
# Cell 14: Node 3 — Validate Query
import re

BLOCKED_KEYWORDS = {"INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "CREATE", "TRUNCATE"}

def validate_query(state: AgentState) -> dict:
    """Validate SQL with sqlglot and block write operations."""
    sql = state["generated_sql"]
    
    # Check for empty SQL
    if not sql.strip():
        print("INVALID: Empty SQL")
        return {"is_valid": False, "validation_error": "LLM returned empty SQL"}
    
    # Security check: block write operations.
    # Tokenize on letters so keywords next to punctuation ("DROP;", "(DELETE") are caught too.
    tokens = set(re.findall(r"[A-Z_]+", sql.upper()))
    for keyword in sorted(BLOCKED_KEYWORDS):
        if keyword in tokens:
            print(f"BLOCKED: {keyword} detected")
            return {"is_valid": False, "validation_error": f"Write operation blocked: {keyword}"}
    
    # Syntax check with sqlglot
    try:
        parsed = sqlglot.parse(sql, read="sqlite")
        if not parsed or parsed[0] is None:
            print("INVALID: sqlglot could not parse")
            return {"is_valid": False, "validation_error": "sqlglot failed to parse SQL"}
        print(f"VALID: {sql[:80]}...")
        return {"is_valid": True, "validation_error": ""}
    except sqlglot.errors.ParseError as e:
        print(f"INVALID: {e}")
        return {"is_valid": False, "validation_error": str(e)}

# Test valid query
print("--- Valid query ---")
validate_query({"generated_sql": "SELECT Name FROM Artist LIMIT 5"})
print()

# Test blocked query
print("--- Blocked query ---")
validate_query({"generated_sql": "DROP TABLE Artist"})
print()

# Test bad syntax
print("--- Bad syntax ---")
validate_query({"generated_sql": "SELEC Name FORM Artist"})


--- Valid query ---
VALID: SELECT Name FROM Artist LIMIT 5...

--- Blocked query ---
BLOCKED: DROP detected

--- Bad syntax ---
INVALID: Invalid expression / Unexpected token. Line 1, Col: 15.
  SELEC Name [4mFORM[0m Artist


{'is_valid': False,
 'validation_error': 'Invalid expression / Unexpected token. Line 1, Col: 15.\n  SELEC Name \x1b[4mFORM\x1b[0m Artist'}
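
The `validation_error` string above carries ANSI underline escapes (`\x1b[4m` ... `\x1b[0m`) that sqlglot uses to highlight the offending token. If this message is later placed in a repair prompt, it may be cleaner to strip them first. A minimal sketch; `clean_error` is a hypothetical helper and the pattern covers only SGR colour/style sequences.

```python
import re

# Matches SGR sequences such as ESC[4m (underline) and ESC[0m (reset).
ANSI_SGR = re.compile(r"\x1b\[[0-9;]*m")


def clean_error(message: str) -> str:
    """Remove ANSI style escapes from an error message (hypothetical helper)."""
    return ANSI_SGR.sub("", message)


raw_error = (
    "Invalid expression / Unexpected token. Line 1, Col: 15.\n"
    "  SELEC Name \x1b[4mFORM\x1b[0m Artist"
)
cleaned = clean_error(raw_error)
print(cleaned)
```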

In [13]:
# Cell 15: Node 4 — Execute Query
def execute_query(state: AgentState) -> dict:
    """Execute validated SQL against the database."""
    sql = state["generated_sql"]
    
    try:
        with engine.connect() as conn:
            rows = conn.execute(text(sql)).fetchmany(20)
            results = [list(row) for row in rows]
            print(f"Executed: {sql[:80]}")
            print(f"Rows returned: {len(results)}")
            for row in results[:5]:
                print(f"  {row}")
            if len(results) > 5:
                print(f"  ... ({len(results)} rows total)")
            return {"results": results, "error": ""}
    except Exception as e:
        print(f"Execution error: {e}")
        return {"results": None, "error": str(e)}

# Test with a real query
print("--- Valid execution ---")
execute_query({"generated_sql": "SELECT Name FROM Artist LIMIT 5"})
print()

# Test with a bad query (valid syntax but wrong column)
print("--- Runtime error ---")
execute_query({"generated_sql": "SELECT Foo FROM Artist"})


--- Valid execution ---
Executed: SELECT Name FROM Artist LIMIT 5
Rows returned: 5
  ['AC/DC']
  ['Accept']
  ['Aerosmith']
  ['Alanis Morissette']
  ['Alice In Chains']

--- Runtime error ---
Execution error: (sqlite3.OperationalError) no such column: Foo
[SQL: SELECT Foo FROM Artist]
(Background on this error at: https://sqlalche.me/e/20/e3q8)


{'results': None,
 'error': '(sqlite3.OperationalError) no such column: Foo\n[SQL: SELECT Foo FROM Artist]\n(Background on this error at: https://sqlalche.me/e/20/e3q8)'}

In [14]:
# Cell 16: Node 5 — Handle Error (Retry with Error Context)
ERROR_REPAIR_GENERIC = """The following SQL query produced an error. Fix it.

Schema:
{schema_text}

Original question: {question}

Failed SQL:
{generated_sql}

Error:
{error}

Return ONLY the corrected SQL query, no explanation.

SQL:"""

ERROR_REPAIR_SQLCODER = """### Task
The following SQL query produced an error. Fix the query to answer:
`{question}`

### Database Schema
{schema_text}

### Failed Query
{generated_sql}

### Error
{error}

### Corrected Answer
```sql
"""

def handle_error(state: AgentState) -> dict:
    """Feed error back to LLM for SQL repair."""
    model = state["model_name"]
    
    if "sqlcoder" in model:
        template = ERROR_REPAIR_SQLCODER
    else:
        template = ERROR_REPAIR_GENERIC
    
    prompt = template.format(
        schema_text=state["schema_text"],
        question=state["question"],
        generated_sql=state["generated_sql"],
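        # Validation failures arrive with an empty "error", so fall back to validation_error
        error=state["error"] or state.get("validation_error", ""),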
    )
    
    llm = ChatOllama(model=model, base_url=OLLAMA_BASE_URL, temperature=0)
    
    t0 = time.time()
    response = llm.invoke(prompt)
    elapsed = time.time() - t0
    
    sql = response.content.strip()
    sql = sql.replace("```sql", "").replace("```", "").strip()
    sql = sql.rstrip(";").strip()
    
    new_retry = state["retry_count"] + 1
    print(f"Retry {new_retry}: {sql[:80]}")
    print(f"Latency: {elapsed:.1f}s")
    
    return {"generated_sql": sql, "retry_count": new_retry, "error": ""}

# Test: simulate a failed query with wrong column name
test_error_state = {
    **test_state_filtered,
    "generated_sql": "SELECT Foo FROM Artist",
    "error": "(sqlite3.OperationalError) no such column: Foo",
    "retry_count": 0,
}
handle_error(test_error_state)


Retry 1: SELECT a."ArtistId", COUNT(a."AlbumId") AS "Number of Albums" FROM "Album" a GRO
Latency: 62.0s


{'generated_sql': 'SELECT a."ArtistId", COUNT(a."AlbumId") AS "Number of Albums" FROM "Album" a GROUP BY a."ArtistId"',
 'retry_count': 1,
 'error': ''}

In [15]:
# Cell 17: Wire LangGraph State Graph
from langgraph.graph import StateGraph, END

def should_retry(state: AgentState) -> str:
    """Route after execute_query: retry on error or finish."""
    if state["error"] and state["retry_count"] < 3:
        return "handle_error"
    return END

def check_validation(state: AgentState) -> str:
    """Route after validate_query: execute if valid, retry or stop."""
    if state["is_valid"]:
        return "execute_query"
    if state["retry_count"] < 3:
        return "handle_error"
    return END

# Build graph
workflow = StateGraph(AgentState)

# Add nodes
workflow.add_node("schema_filter", schema_filter)
workflow.add_node("generate_sql", generate_sql)
workflow.add_node("validate_query", validate_query)
workflow.add_node("execute_query", execute_query)
workflow.add_node("handle_error", handle_error)

# Set entry point
workflow.set_entry_point("schema_filter")

# Add edges
workflow.add_edge("schema_filter", "generate_sql")
workflow.add_edge("generate_sql", "validate_query")
workflow.add_conditional_edges("validate_query", check_validation)
workflow.add_conditional_edges("execute_query", should_retry)
workflow.add_edge("handle_error", "validate_query")

# Compile
agent = workflow.compile()
print("Agent compiled successfully")
print(f"Nodes: {list(agent.nodes.keys())}")


Agent compiled successfully
Nodes: ['__start__', 'schema_filter', 'generate_sql', 'validate_query', 'execute_query', 'handle_error']


In [16]:
# Cell 18: End-to-End Test — Single Question
initial_state = {
    "question": "How many albums does each artist have? Show artist name and count, top 5.",
    "relevant_tables": [],
    "schema_text": "",
    "generated_sql": "",
    "is_valid": False,
    "validation_error": "",
    "results": None,
    "error": "",
    "retry_count": 0,
    "model_name": PRIMARY_MODEL,
}

print("=" * 60)
print(f"Question: {initial_state['question']}")
print(f"Model: {initial_state['model_name']}")
print("=" * 60)

t0 = time.time()
final_state = agent.invoke(initial_state)
total_time = time.time() - t0

print(f"\n{'=' * 60}")
print(f"FINAL RESULTS")
print(f"SQL: {final_state['generated_sql']}")
print(f"Valid: {final_state['is_valid']}")
print(f"Results: {final_state['results']}")
print(f"Error: {final_state['error']}")
print(f"Retries: {final_state['retry_count']}")
print(f"Total time: {total_time:.1f}s")


Question: How many albums does each artist have? Show artist name and count, top 5.
Model: sqlcoder:7b
Question: How many albums does each artist have? Show artist name and count, top 5.
Selected tables (5): ['Artist', 'Album', 'Customer', 'Employee', 'Genre']
Schema text length: 1159 chars
Model: sqlcoder:7b
Prompt: sqlcoder
Generated SQL: SELECT a."Name", COUNT(b.AlbumId) AS "Number of Albums" FROM Artist a JOIN Album b ON a.ArtistId = b.ArtistId GROUP BY a."Name" ORDER BY "Number of Albums" DESC NULLS LAST LIMIT 5
Latency: 19.0s
VALID: SELECT a."Name", COUNT(b.AlbumId) AS "Number of Albums" FROM Artist a JOIN Album...
Executed: SELECT a."Name", COUNT(b.AlbumId) AS "Number of Albums" FROM Artist a JOIN Album
Rows returned: 5
  ['Iron Maiden', 21]
  ['Led Zeppelin', 14]
  ['Deep Purple', 11]
  ['U2', 10]
  ['Metallica', 10]

FINAL RESULTS
SQL: SELECT a."Name", COUNT(b.AlbumId) AS "Number of Albums" FROM Artist a JOIN Album b ON a.ArtistId = b.ArtistId GROUP BY a."Name" ORDER BY "Numbe

In [17]:
# Cell 19: End-to-End Test — Multiple Questions
test_questions = [
    "List all genres",
    "What are the top 3 customers by total spending?",
    "Find all tracks by AC/DC",
]

for q in test_questions:
    state = {
        "question": q,
        "relevant_tables": [],
        "schema_text": "",
        "generated_sql": "",
        "is_valid": False,
        "validation_error": "",
        "results": None,
        "error": "",
        "retry_count": 0,
        "model_name": PRIMARY_MODEL,
    }
    
    print("=" * 60)
    print(f"Q: {q}")
    t0 = time.time()
    result = agent.invoke(state)
    elapsed = time.time() - t0
    print(f"\nSQL: {result['generated_sql']}")
    print(f"Results: {result['results']}")
    print(f"Retries: {result['retry_count']} | Time: {elapsed:.1f}s")
    print()


Q: List all genres
Question: List all genres
Selected tables (4): ['Genre', 'Playlist', 'PlaylistTrack', 'Track']
Schema text length: 523 chars
Model: sqlcoder:7b
Prompt: sqlcoder
Generated SQL: SELECT Genre.Name FROM Genre
Latency: 4.9s
VALID: SELECT Genre.Name FROM Genre...
Executed: SELECT Genre.Name FROM Genre
Rows returned: 20
  ['Rock']
  ['Jazz']
  ['Metal']
  ['Alternative & Punk']
  ['Rock And Roll']
  ... (20 rows total)

SQL: SELECT Genre.Name FROM Genre
Results: [['Rock'], ['Jazz'], ['Metal'], ['Alternative & Punk'], ['Rock And Roll'], ['Blues'], ['Latin'], ['Reggae'], ['Pop'], ['Soundtrack'], ['Bossa Nova'], ['Easy Listening'], ['Heavy Metal'], ['R&B/Soul'], ['Electronica/Dance'], ['World'], ['Hip Hop/Rap'], ['Science Fiction'], ['TV Shows'], ['Sci Fi & Fantasy']]
Retries: 0 | Time: 5.0s

Q: What are the top 3 customers by total spending?
Question: What are the top 3 customers by total spending?
Selected tables (7): ['Customer', 'Invoice', 'Track', 'Employee', 'MediaType',

In [18]:
# Cell 20: Improved Prompts — SQLite Rules & Schema in Error Repair
GENERIC_PROMPT = """You are a SQL expert. Generate a SQLite-compatible SELECT query for the question below.

Schema:
{schema_text}

Rules:
- Return ONLY the SQL query, no explanation
- Use only SELECT statements
- Use only tables and columns from the schema above
- Use exact column names as shown in the schema (case-sensitive)
- SQLite syntax only: use LIKE not ILIKE, no NULLS FIRST/LAST
- For case-insensitive matching use: LOWER(column) LIKE LOWER('%value%')

Question: {question}

SQL:"""

SQLCODER_PROMPT = """### Task
Generate a SQL query to answer the following question:
`{question}`

### Database Schema
{schema_text}

### Rules
- SQLite dialect only
- Use exact column names from schema (case-sensitive: FirstName not first_name)
- Use LIKE not ILIKE (SQLite has no ILIKE)
- No NULLS FIRST/LAST

### Answer
Given the database schema, here is the SQL query that answers `{question}`:
```sql
"""

ERROR_REPAIR_GENERIC = """The following SQL query produced an error. Fix it using the schema below.

Schema:
{schema_text}

Question: {question}

Failed SQL:
{generated_sql}

Error:
{error}

Rules:
- Use exact column names from the schema above (case-sensitive)
- SQLite syntax only: use LIKE not ILIKE
- Return ONLY the corrected SQL query, no explanation

SQL:"""

ERROR_REPAIR_SQLCODER = """### Task
The following SQL query produced an error. Fix the query to answer:
`{question}`

### Database Schema
{schema_text}

### Failed Query
{generated_sql}

### Error
{error}

### Rules
- Use exact column names from the schema (case-sensitive: FirstName not first_name)
- SQLite dialect only: use LIKE not ILIKE
- No NULLS FIRST/LAST

### Corrected Answer
```sql
"""

# Update handle_error to use new templates
def handle_error(state: AgentState) -> dict:
    """Feed error back to LLM for SQL repair."""
    model = state["model_name"]
    
    if "sqlcoder" in model:
        template = ERROR_REPAIR_SQLCODER
    else:
        template = ERROR_REPAIR_GENERIC
    
    prompt = template.format(
        schema_text=state["schema_text"],
        question=state["question"],
        generated_sql=state["generated_sql"],
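        # Validation failures arrive with an empty "error", so fall back to validation_error
        error=state["error"] or state.get("validation_error", ""),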
    )
    
    llm = ChatOllama(model=model, base_url=OLLAMA_BASE_URL, temperature=0)
    
    t0 = time.time()
    response = llm.invoke(prompt)
    elapsed = time.time() - t0
    
    sql = response.content.strip()
    sql = sql.replace("```sql", "").replace("```", "").strip()
    sql = sql.rstrip(";").strip()
    
    new_retry = state["retry_count"] + 1
    print(f"Retry {new_retry}: {sql[:80]}")
    print(f"Latency: {elapsed:.1f}s")
    
    return {"generated_sql": sql, "retry_count": new_retry, "error": ""}

# Also update generate_sql to use new templates
def generate_sql(state: AgentState) -> dict:
    """Generate SQL from question + filtered schema using LLM."""
    model = state["model_name"]
    
    if "sqlcoder" in model:
        template = SQLCODER_PROMPT
    else:
        template = GENERIC_PROMPT
    
    prompt = template.format(
        schema_text=state["schema_text"],
        question=state["question"],
    )
    
    llm = ChatOllama(model=model, base_url=OLLAMA_BASE_URL, temperature=0)
    
    t0 = time.time()
    response = llm.invoke(prompt)
    elapsed = time.time() - t0
    
    sql = response.content.strip()
    sql = sql.replace("```sql", "").replace("```", "").strip()
    sql = sql.rstrip(";").strip()
    
    print(f"Model: {model}")
    print(f"Generated SQL: {sql}")
    print(f"Latency: {elapsed:.1f}s")
    
    return {"generated_sql": sql}

# Recompile the graph with updated functions
workflow = StateGraph(AgentState)
workflow.add_node("schema_filter", schema_filter)
workflow.add_node("generate_sql", generate_sql)
workflow.add_node("validate_query", validate_query)
workflow.add_node("execute_query", execute_query)
workflow.add_node("handle_error", handle_error)
workflow.set_entry_point("schema_filter")
workflow.add_edge("schema_filter", "generate_sql")
workflow.add_edge("generate_sql", "validate_query")
workflow.add_conditional_edges("validate_query", check_validation)
workflow.add_conditional_edges("execute_query", should_retry)
workflow.add_edge("handle_error", "validate_query")
agent = workflow.compile()

print("Agent recompiled with improved prompts")
print("Changes: exact column name rules, LIKE not ILIKE, no NULLS LAST, schema in error repair")


Agent recompiled with improved prompts
Changes: exact column name rules, LIKE not ILIKE, no NULLS LAST, schema in error repair
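
`generate_sql` and `handle_error` repeat the same three cleanup lines on the model response. A minimal sketch of a shared helper is shown below. `clean_llm_sql` is a hypothetical name, not something defined in this notebook, and unlike the inline version it also strips an uppercase `SQL` fence tag.

```python
import re

# Matches a Markdown code fence, with or without a "sql" tag (any letter case).
_FENCE = re.compile(r"```(?:sql)?", re.IGNORECASE)

def clean_llm_sql(raw: str) -> str:
    """Strip code fences, surrounding whitespace, and trailing semicolons."""
    sql = _FENCE.sub("", raw).strip()
    return sql.rstrip(";").strip()

print(clean_llm_sql("```sql\nSELECT Name FROM Artist;\n```"))
```

Both nodes could then call `clean_llm_sql(response.content)`, keeping the cleanup rules in one place if more model quirks turn up.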


In [19]:
# Cell 21: Retest Previously Failed Questions
failed_questions = [
    "What are the top 3 customers by total spending?",
    "Find all tracks by AC/DC",
]

for q in failed_questions:
    state = {
        "question": q,
        "relevant_tables": [],
        "schema_text": "",
        "generated_sql": "",
        "is_valid": False,
        "validation_error": "",
        "results": None,
        "error": "",
        "retry_count": 0,
        "model_name": PRIMARY_MODEL,
    }
    
    print("=" * 60)
    print(f"Q: {q}")
    t0 = time.time()
    result = agent.invoke(state)
    elapsed = time.time() - t0
    print(f"\nSQL: {result['generated_sql']}")
    print(f"Results: {result['results']}")
    print(f"Retries: {result['retry_count']} | Time: {elapsed:.1f}s")
    print()


Q: What are the top 3 customers by total spending?
Question: What are the top 3 customers by total spending?
Selected tables (7): ['Customer', 'Invoice', 'Track', 'Employee', 'MediaType', 'Genre', 'Album']
Schema text length: 1839 chars
Model: sqlcoder:7b
Generated SQL: SELECT p.first_name, p.last_name, SUM(i.total) AS total_spent FROM customer p JOIN invoice i ON p.customerid = i.customerid GROUP BY p.first_name, p.last_name ORDER BY total_spent DESC NULLS LAST LIMIT 3
Latency: 24.2s
VALID: SELECT p.first_name, p.last_name, SUM(i.total) AS total_spent FROM customer p JO...
Execution error: (sqlite3.OperationalError) no such column: p.first_name
[SQL: SELECT p.first_name, p.last_name, SUM(i.total) AS total_spent FROM customer p JOIN invoice i ON p.customerid = i.customerid GROUP BY p.first_name, p.last_name ORDER BY total_spent DESC NULLS LAST LIMIT 3]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
Retry 1: SELECT p.first_name, p.last_name, SUM(i.total) AS total_spent FRO
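
The run above shows the query passing the validation step (`VALID`) and then failing in SQLite with `no such column`, which suggests the check covers syntax but not identifiers. One possible pre-flight check, sketched here under the assumption that a `sqlite3` connection to the target database is available, asks SQLite to compile the statement with `EXPLAIN` without running it. `preflight_check` and the in-memory `demo` table are hypothetical stand-ins for illustration.

```python
import sqlite3

def preflight_check(conn: sqlite3.Connection, sql: str) -> str:
    """Return "" if SQLite can compile `sql`, otherwise the error message."""
    try:
        # EXPLAIN compiles the statement and lists its bytecode; the query
        # itself is not executed, but unknown tables/columns raise here.
        conn.execute(f"EXPLAIN {sql}")
        return ""
    except sqlite3.Error as exc:
        return str(exc)

# Hypothetical stand-in for the Chinook Customer table.
demo = sqlite3.connect(":memory:")
demo.execute(
    "CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, FirstName TEXT, LastName TEXT)"
)

print(repr(preflight_check(demo, "SELECT p.first_name FROM customer p")))
print(repr(preflight_check(demo, "SELECT p.FirstName FROM customer p")))
```

A check like this could feed its message into the repair prompt before any rows are fetched, at the cost of one extra round trip to the database.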

In [20]:
# Cell 22: SQL Post-Processing for SQLite Compatibility
import re

def build_column_map():
    """Build case-insensitive column name mapping from schema_info."""
    col_map = {}  # lowercase -> actual name
    for table_name, info in schema_info.items():
        for col in info["columns"]:
            col_map[col["name"].lower()] = col["name"]
        # Also map table names
        col_map[table_name.lower()] = table_name
    return col_map

COLUMN_MAP = build_column_map()

def postprocess_sql(sql: str) -> str:
    """Fix known SQLite incompatibilities in generated SQL."""
    original = sql
    
    # 1. Replace ILIKE with LIKE (SQLite LIKE is case-insensitive for ASCII)
    sql = re.sub(r'\bILIKE\b', 'LIKE', sql, flags=re.IGNORECASE)
    
    # 2. Remove NULLS FIRST / NULLS LAST (SQLite only accepts them from 3.30.0 onward)
    sql = re.sub(r'\s+NULLS\s+(FIRST|LAST)\b', '', sql, flags=re.IGNORECASE)
    
    # 3. Restore schema casing for identifiers found in COLUMN_MAP (e.g. customer -> Customer)
    # SQL keywords and function names that must never be rewritten
    sql_keywords = {
        'SELECT', 'FROM', 'WHERE', 'JOIN', 'ON', 'GROUP', 'BY', 'ORDER',
        'HAVING', 'LIMIT', 'AS', 'AND', 'OR', 'NOT', 'IN', 'LIKE', 'IS',
        'NULL', 'COUNT', 'SUM', 'AVG', 'MIN', 'MAX', 'DESC', 'ASC',
        'DISTINCT', 'BETWEEN', 'CASE', 'WHEN', 'THEN', 'ELSE', 'END',
        'INNER', 'LEFT', 'RIGHT', 'OUTER', 'UNION', 'ALL', 'EXISTS',
        'CAST', 'LOWER', 'UPPER', 'LENGTH', 'SUBSTR', 'TRIM',
    }

    def replace_identifier(match):
        word = match.group(0)
        if word.upper() in sql_keywords:
            return word
        lookup = word.lower().replace('"', '').replace("'", '')
        if lookup in COLUMN_MAP:
            return COLUMN_MAP[lookup]
        return word
    
    # Match identifiers (possibly double-quoted). Caveat: the pattern also matches
    # table aliases and words inside string literals, so those are rewritten too
    # whenever they collide with a schema name.
    sql = re.sub(r'"?\b[A-Za-z_]\w*\b"?', replace_identifier, sql)
    
    if sql != original:
        changes = []
        if 'ILIKE' in original.upper() and 'ILIKE' not in sql.upper():
            changes.append("ILIKE->LIKE")
        if re.search(r'NULLS\s+(FIRST|LAST)', original, re.IGNORECASE):
            changes.append("removed NULLS FIRST/LAST")
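        # Caveat: this comparison also fires when only ILIKE or NULLS was edited,
        # and it misses identifier fixes that change letter case alone.
        if sql.lower() != original.lower():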
            changes.append("fixed column casing")
        print(f"  Post-processed: {', '.join(changes)}")
    
    return sql

# Test it
test_cases = [
    "SELECT p.first_name, p.last_name FROM customer p ORDER BY p.last_name NULLS LAST",
    "SELECT Track.Name FROM Track WHERE Composer ILIKE '%AC/DC%'",
    "SELECT Name FROM Artist LIMIT 5",  # should be unchanged
]

for sql in test_cases:
    print(f"  Input:  {sql}")
    fixed = postprocess_sql(sql)
    print(f"  Output: {fixed}")
    print()


  Input:  SELECT p.first_name, p.last_name FROM customer p ORDER BY p.last_name NULLS LAST
  Post-processed: removed NULLS FIRST/LAST, fixed column casing
  Output: SELECT p.first_name, p.last_name FROM Customer p ORDER BY p.last_name

  Input:  SELECT Track.Name FROM Track WHERE Composer ILIKE '%AC/DC%'
  Post-processed: ILIKE->LIKE, fixed column casing
  Output: SELECT Track.Name FROM Track WHERE Composer LIKE '%AC/DC%'

  Input:  SELECT Name FROM Artist LIMIT 5
  Output: SELECT Name FROM Artist LIMIT 5
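
The identifier regex in Cell 22 runs over the whole statement, so a word inside a string literal is rewritten if it happens to match a schema name. A minimal sketch of one way around this, assuming only single-quoted literals need protecting, is to split the statement on literals and rewrite just the pieces in between. `fix_identifiers` and `demo_map` are hypothetical names, and the keyword guard from `postprocess_sql` is omitted for brevity.

```python
import re

# A single-quoted SQL string literal; '' inside it is an escaped quote.
_STRING_LITERAL = re.compile(r"('(?:[^']|'')*')")
_IDENTIFIER = re.compile(r"\b[A-Za-z_]\w*\b")

def fix_identifiers(sql: str, column_map: dict) -> str:
    """Apply column_map to identifiers, leaving string literals untouched."""
    # With a capturing group, re.split keeps the literals at odd indices.
    parts = _STRING_LITERAL.split(sql)
    for i in range(0, len(parts), 2):  # even indices lie outside literals
        parts[i] = _IDENTIFIER.sub(
            lambda m: column_map.get(m.group(0).lower(), m.group(0)),
            parts[i],
        )
    return "".join(parts)

# Hypothetical two-entry map for illustration.
demo_map = {"name": "Name", "artist": "Artist"}
print(fix_identifiers("SELECT name FROM artist WHERE name = 'name'", demo_map))
```

The same split could be reused to report changes more precisely, by comparing the non-literal pieces before and after the rewrite.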



In [21]:
# Cell 23: Fix Column Mapping — Handle snake_case to PascalCase
def build_column_map():
    """Build case-insensitive column name mapping, including snake_case variants."""
    col_map = {}
    for table_name, info in schema_info.items():
        for col in info["columns"]:
            actual = col["name"]
            # Map lowercase version
            col_map[actual.lower()] = actual
            # Map snake_case version (e.g., first_name -> FirstName)
            snake = re.sub(r'(?<!^)(?=[A-Z])', '_', actual).lower()
            col_map[snake] = actual
        # Table names
        col_map[table_name.lower()] = table_name
    return col_map

COLUMN_MAP = build_column_map()

# Verify the fix
print("Key mappings:")
for key in ["first_name", "firstname", "last_name", "lastname", 
            "customerid", "customer_id", "artistid", "artist_id",
            "supportrepid", "support_rep_id", "billingcountry", "billing_country"]:
    actual = COLUMN_MAP.get(key, "NOT FOUND")
    print(f"  {key:20s} -> {actual}")

# Retest postprocess
print("\nPost-process test:")
sql = "SELECT p.first_name, p.last_name FROM customer p ORDER BY p.last_name NULLS LAST"
print(f"  Input:  {sql}")
print(f"  Output: {postprocess_sql(sql)}")


Key mappings:
  first_name           -> FirstName
  firstname            -> FirstName
  last_name            -> LastName
  lastname             -> LastName
  customerid           -> CustomerId
  customer_id          -> CustomerId
  artistid             -> ArtistId
  artist_id            -> ArtistId
  supportrepid         -> SupportRepId
  support_rep_id       -> SupportRepId
  billingcountry       -> BillingCountry
  billing_country      -> BillingCountry

Post-process test:
  Input:  SELECT p.first_name, p.last_name FROM customer p ORDER BY p.last_name NULLS LAST
  Post-processed: removed NULLS FIRST/LAST, fixed column casing
  Output: SELECT p.FirstName, p.LastName FROM Customer p ORDER BY p.LastName
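
Cell 23 adds snake_case keys for columns only; table names are still mapped by their lowercase form, so a spelling such as `invoice_line` would not resolve. A small sketch of generating both spellings for any identifier, reusing the same regex, is given below. `to_snake` and `identifier_variants` are hypothetical helper names, and the sample table names are assumed to match the Chinook schema.

```python
import re

def to_snake(name: str) -> str:
    """Insert '_' before each interior capital letter, then lowercase."""
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()

def identifier_variants(name: str) -> set:
    """Lowercase and snake_case spellings a model might emit for `name`."""
    return {name.lower(), to_snake(name)}

# Assumed Chinook names, used only to illustrate the mapping.
for name in ["InvoiceLine", "MediaType", "PlaylistTrack", "SupportRepId"]:
    print(f"{name:15s} -> {sorted(identifier_variants(name))}")
```

Because the regex splits before every capital letter, a name containing an acronym would be broken up letter by letter, so the generated keys are worth checking against the real schema before relying on them.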
