# SQL Query Agent with Ollama - Exploration Notebook

**Sprint 1 Deliverable** | **DSM 1.0 Track**

This notebook builds and evaluates a natural-language-to-SQL agent using open-source LLMs running locally via Ollama. It doubles as a testbed for text-to-code generation: SQL is a constrained language, which makes it well suited to systematic evaluation.

**Architecture:** LangGraph state graph with schema filtering, SQL validation (sqlglot), and error-correction retry loop. Design informed by DIN-SQL, MAC-SQL, and CHESS research.

**Goals:**
1. Verify environment (Ollama, models, database)
2. Build LangGraph agent: `schema_filter` → `generate_sql` → `validate_query` → `execute_query` → `handle_error`
3. Evaluate at least 2 models (sqlcoder:7b vs llama3.1:8b) on a curated test suite

**References:**
- Research: `docs/research/text_to_sql_state_of_art.md`
- Plan: `docs/plans/PLAN.md`
- Sprint plan: `docs/plans/sprint-1-plan.md`


In [1]:
# Cell 2: Imports & Configuration
from sqlalchemy import create_engine, inspect, text
from langchain_ollama import ChatOllama
import sqlglot
import time

# Configuration
OLLAMA_BASE_URL = "http://172.27.64.1:11434"  # Ollama on Windows, accessed via WSL gateway
DB_PATH = "../data/chinook.db"
PRIMARY_MODEL = "sqlcoder:7b"
BASELINE_MODEL = "llama3.1:8b"

# Database engine
engine = create_engine(f"sqlite:///{DB_PATH}")
print(f"Database engine created: {engine.url}")


Database engine created: sqlite:///../data/chinook.db
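`OLLAMA_BASE_URL` is pinned to one WSL gateway address, which may change after a WSL or Windows restart. Below is a minimal sketch that allows an environment-variable override. Reading a variable named `OLLAMA_BASE_URL` is a convention chosen for this example. Neither Ollama nor `langchain_ollama` reads it on its own.

```python
import os

DEFAULT_OLLAMA_URL = "http://172.27.64.1:11434"  # WSL gateway address used in this notebook

def resolve_ollama_url(env=None, default=DEFAULT_OLLAMA_URL):
    """Return the Ollama base URL, preferring an environment override.

    The variable name OLLAMA_BASE_URL is an assumption made for this sketch.
    """
    env = os.environ if env is None else env
    url = env.get("OLLAMA_BASE_URL", "").strip()
    # Fall back to the hard-coded gateway address when no override is set
    return (url or default).rstrip("/")

print(resolve_ollama_url({}))
print(resolve_ollama_url({"OLLAMA_BASE_URL": "http://localhost:11434/"}))
```

In the configuration cell this would become `OLLAMA_BASE_URL = resolve_ollama_url()`.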


In [2]:
# Cell 3: Database Schema Inspection

inspector = inspect(engine)
tables = inspector.get_table_names()

print(f"Chinook Database: {len(tables)} tables\n")

schema_info = {}
for table_name in tables:
    columns = inspector.get_columns(table_name)
    pk = inspector.get_pk_constraint(table_name)
    fks = inspector.get_foreign_keys(table_name)
    
    schema_info[table_name] = {
        "columns": columns,
        "pk": pk.get("constrained_columns", []),
        "fks": fks,
    }
    
    # Display
    pk_cols = set(pk.get("constrained_columns", []))
    print(f"{table_name}")
    for col in columns:
        marker = " (PK)" if col["name"] in pk_cols else ""
        print(f"   {col['name']}: {col['type']}{marker}")
    for fk in fks:
        print(f"   FK: {fk['constrained_columns']} -> {fk['referred_table']}.{fk['referred_columns']}")
    print()


Chinook Database: 11 tables

Album
   AlbumId: INTEGER (PK)
   Title: NVARCHAR(160)
   ArtistId: INTEGER
   FK: ['ArtistId'] -> Artist.['ArtistId']

Artist
   ArtistId: INTEGER (PK)
   Name: NVARCHAR(120)

Customer
   CustomerId: INTEGER (PK)
   FirstName: NVARCHAR(40)
   LastName: NVARCHAR(20)
   Company: NVARCHAR(80)
   Address: NVARCHAR(70)
   City: NVARCHAR(40)
   State: NVARCHAR(40)
   Country: NVARCHAR(40)
   PostalCode: NVARCHAR(10)
   Phone: NVARCHAR(24)
   Fax: NVARCHAR(24)
   Email: NVARCHAR(60)
   SupportRepId: INTEGER
   FK: ['SupportRepId'] -> Employee.['EmployeeId']

Employee
   EmployeeId: INTEGER (PK)
   LastName: NVARCHAR(20)
   FirstName: NVARCHAR(20)
   Title: NVARCHAR(30)
   ReportsTo: INTEGER
   BirthDate: DATETIME
   HireDate: DATETIME
   Address: NVARCHAR(70)
   City: NVARCHAR(40)
   State: NVARCHAR(40)
   Country: NVARCHAR(40)
   PostalCode: NVARCHAR(10)
   Phone: NVARCHAR(24)
   Fax: NVARCHAR(24)
   Email: NVARCHAR(60)
   FK: ['ReportsTo'] -> Employee.['Employe

In [3]:
# Cell 4: Row Counts

print("Row counts:")
with engine.connect() as conn:
    for table_name in tables:
        count = conn.execute(text(f"SELECT COUNT(*) FROM [{table_name}]")).scalar()
        print(f"  {table_name}: {count:,} rows")


Row counts:
  Album: 347 rows
  Artist: 275 rows
  Customer: 59 rows
  Employee: 8 rows
  Genre: 25 rows
  Invoice: 412 rows
  InvoiceLine: 2,240 rows
  MediaType: 5 rows
  Playlist: 18 rows
  PlaylistTrack: 8,715 rows
  Track: 3,503 rows


In [4]:
# Cell 5: Verify Ollama Connectivity & Available Models
import requests

try:
    response = requests.get(f"{OLLAMA_BASE_URL}/api/tags", timeout=10)
    response.raise_for_status()
    models = response.json().get("models", [])
    print(f"Ollama is running at {OLLAMA_BASE_URL}")
    print(f"Available models: {len(models)}\n")
    for m in models:
        size_gb = m.get("size", 0) / (1024**3)
        print(f"  {m['name']:30s} {size_gb:.1f} GB")
except requests.ConnectionError:
    print(f"ERROR: Cannot connect to Ollama at {OLLAMA_BASE_URL}")
    print("Make sure Ollama is running on the Windows host.")
except Exception as e:
    print(f"ERROR: {e}")


Ollama is running at http://172.27.64.1:11434
Available models: 4

  llama3.1:8b                    4.6 GB
  sqlcoder:7b                    3.8 GB
  gemma3:1b                      0.8 GB
  llama3:latest                  4.3 GB
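Cell 5 only lists what is installed. The helper below can fail fast when a model needed for the evaluation has not been pulled. It assumes the `/api/tags` payload shape used above: a `models` list whose entries carry a `name`. Names must match exactly, including the tag after the colon.

```python
def missing_models(tags_payload: dict, required: list[str]) -> list[str]:
    """Return the required model names absent from an Ollama /api/tags payload."""
    installed = {m.get("name") for m in tags_payload.get("models", [])}
    return [name for name in required if name not in installed]

# Payload shaped like the response printed by Cell 5 (sizes omitted)
payload = {"models": [{"name": "llama3.1:8b"}, {"name": "sqlcoder:7b"},
                      {"name": "gemma3:1b"}, {"name": "llama3:latest"}]}
print(missing_models(payload, ["sqlcoder:7b", "llama3.1:8b"]))  # []
print(missing_models(payload, ["not-installed:latest"]))       # ['not-installed:latest']
```

In the notebook it would be called as `missing_models(response.json(), [PRIMARY_MODEL, BASELINE_MODEL])`.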


In [5]:
# Cell 6: Test LLM Connectivity
for model_name in [PRIMARY_MODEL, BASELINE_MODEL]:
    print(f"Testing {model_name}...")
    llm = ChatOllama(model=model_name, base_url=OLLAMA_BASE_URL, temperature=0)
    t0 = time.time()
    response = llm.invoke("Return only the SQL: SELECT 1")
    elapsed = time.time() - t0
    print(f"  Response: {response.content.strip()[:100]}")
    print(f"  Latency: {elapsed:.1f}s\n")


Testing sqlcoder:7b...
  Response: AS "column_name" FROM DUAL UNION ALL SELECT 2 AS "column_name" FROM DUAL UNION ALL SELECT 3 AS "colu
  Latency: 31.9s

Testing llama3.1:8b...
  Response: `SELECT 1;`
  Latency: 10.9s



In [6]:
# Cell 7: Sample Rows for Prompt Context
sample_tables = ["Artist", "Album", "Track", "Customer", "Invoice", "InvoiceLine"]

sample_rows = {}
with engine.connect() as conn:
    for table_name in sample_tables:
        # Reuse one result for both the column names and the rows (no second query)
        result = conn.execute(text(f"SELECT * FROM [{table_name}] LIMIT 3"))
        keys = list(result.keys())
        rows = result.fetchall()
        sample_rows[table_name] = {"columns": keys, "rows": rows}
        
        print(f"{table_name}:")
        print(f"  Columns: {list(keys)}")
        for row in rows:
            print(f"  {list(row)}")
        print()


Artist:
  Columns: ['ArtistId', 'Name']
  [1, 'AC/DC']
  [2, 'Accept']
  [3, 'Aerosmith']

Album:
  Columns: ['AlbumId', 'Title', 'ArtistId']
  [1, 'For Those About To Rock We Salute You', 1]
  [2, 'Balls to the Wall', 2]
  [3, 'Restless and Wild', 2]

Track:
  Columns: ['TrackId', 'Name', 'AlbumId', 'MediaTypeId', 'GenreId', 'Composer', 'Milliseconds', 'Bytes', 'UnitPrice']
  [1, 'For Those About To Rock (We Salute You)', 1, 1, 1, 'Angus Young, Malcolm Young, Brian Johnson', 343719, 11170334, 0.99]
  [2, 'Balls to the Wall', 2, 2, 1, 'U. Dirkschneider, W. Hoffmann, H. Frank, P. Baltes, S. Kaufmann, G. Hoffmann', 342562, 5510424, 0.99]
  [3, 'Fast As a Shark', 3, 2, 1, 'F. Baltes, S. Kaufman, U. Dirkscneider & W. Hoffman', 230619, 3990994, 0.99]

Customer:
  Columns: ['CustomerId', 'FirstName', 'LastName', 'Company', 'Address', 'City', 'State', 'Country', 'PostalCode', 'Phone', 'Fax', 'Email', 'SupportRepId']
  [1, 'Luís', 'Gonçalves', 'Embraer - Empresa Brasileira de Aeronáutica S.A.'

## Phase 2: Core Agent Build

LangGraph agent with 5 nodes:
1. **schema_filter** — select relevant tables for the question
2. **generate_sql** — LLM generates SQL from question + filtered schema
3. **validate_query** — sqlglot parses SQL, security check (block writes)
4. **execute_query** — run against SQLite, capture results or errors
5. **handle_error** — on failure, feed error back to LLM for retry (max 3)


In [7]:
# Cell 9: Agent State Definition
from typing import TypedDict, Optional

class AgentState(TypedDict):
    question: str                    # User's natural language question
    relevant_tables: list[str]       # Tables selected by schema_filter
    schema_text: str                 # DDL/schema for relevant tables
    generated_sql: str               # SQL produced by LLM
    is_valid: bool                   # sqlglot parse + security check passed
    validation_error: str            # Error message if validation fails
    results: Optional[list]          # Query results from SQLite
    error: str                       # Execution error message
    retry_count: int                 # Current retry attempt (max 3)
    model_name: str                  # Which Ollama model to use

print("AgentState defined with fields:")
for field, ftype in AgentState.__annotations__.items():
    print(f"  {field}: {ftype}")


AgentState defined with fields:
  question: <class 'str'>
  relevant_tables: list[str]
  schema_text: <class 'str'>
  generated_sql: <class 'str'>
  is_valid: <class 'bool'>
  validation_error: <class 'str'>
  results: typing.Optional[list]
  error: <class 'str'>
  retry_count: <class 'int'>
  model_name: <class 'str'>


In [8]:
# Cell 10: Node 1 — Schema Filter
def schema_filter(state: AgentState) -> dict:
    """Select relevant tables based on question keywords."""
    question_lower = state["question"].lower()
    question_words = set(question_lower.replace("?", "").replace(",", "").split())
    
    scored_tables = []
    for table_name, info in schema_info.items():
        score = 0
        table_lower = table_name.lower()
        
        # Table name match (strongest signal)
        if table_lower in question_lower:
            score += 3
        # Partial table name match
        for word in question_words:
            if word in table_lower or table_lower in word:
                score += 2
        # Column name match
        for col in info["columns"]:
            col_lower = col["name"].lower()
            if col_lower in question_lower:
                score += 1
            for word in question_words:
                if word in col_lower:
                    score += 0.5
        
        if score > 0:
            scored_tables.append((table_name, score))
    
    # Sort by score and keep the top 5; tables they reference via FKs are added below
    scored_tables.sort(key=lambda x: x[1], reverse=True)
    selected = [t[0] for t in scored_tables[:5]]  # max 5 tables
    
    # Add FK-connected tables for selected tables
    for table_name in list(selected):
        for fk in schema_info[table_name]["fks"]:
            referred = fk["referred_table"]
            if referred not in selected:
                selected.append(referred)
    
    # Fallback: if nothing matched, include all tables
    if not selected:
        selected = list(schema_info.keys())
    
    # Build schema text for selected tables
    schema_lines = []
    for table_name in selected:
        info = schema_info[table_name]
        cols = ", ".join(
            f"{c['name']} {c['type']}" for c in info["columns"]
        )
        schema_lines.append(f"CREATE TABLE {table_name} ({cols});")
        # Add sample rows
        if table_name in sample_rows:
            sr = sample_rows[table_name]
            schema_lines.append(f"-- Sample: {sr['rows'][0]}")
    
    schema_text = "\n".join(schema_lines)
    
    print(f"Question: {state['question']}")
    print(f"Selected tables ({len(selected)}): {selected}")
    print(f"Schema text length: {len(schema_text)} chars")
    
    return {"relevant_tables": selected, "schema_text": schema_text}

# Quick test
test_state = {
    "question": "How many albums does each artist have?",
    "relevant_tables": [], "schema_text": "", "generated_sql": "",
    "is_valid": False, "validation_error": "", "results": None,
    "error": "", "retry_count": 0, "model_name": PRIMARY_MODEL,
}
schema_filter(test_state)


Question: How many albums does each artist have?
Selected tables (2): ['Album', 'Artist']
Schema text length: 219 chars


{'relevant_tables': ['Album', 'Artist'],
 'schema_text': "CREATE TABLE Album (AlbumId INTEGER, Title NVARCHAR(160), ArtistId INTEGER);\n-- Sample: (1, 'For Those About To Rock We Salute You', 1)\nCREATE TABLE Artist (ArtistId INTEGER, Name NVARCHAR(120));\n-- Sample: (1, 'AC/DC')"}
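The schema text built by `schema_filter` lists column names and types. It drops the primary and foreign keys that Cell 3 collected, so the model has to guess the join path. Below is a sketch of a renderer that keeps them, assuming the `schema_info` shape from Cell 3. The helper name `render_ddl` is hypothetical.

```python
def render_ddl(table_name: str, info: dict) -> str:
    """Render one schema_info entry as CREATE TABLE DDL with PK and FK clauses."""
    parts = [f"{c['name']} {c['type']}" for c in info["columns"]]
    if info["pk"]:
        parts.append(f"PRIMARY KEY ({', '.join(info['pk'])})")
    for fk in info["fks"]:
        local = ", ".join(fk["constrained_columns"])
        remote = ", ".join(fk["referred_columns"])
        parts.append(f"FOREIGN KEY ({local}) REFERENCES {fk['referred_table']}({remote})")
    return f"CREATE TABLE {table_name} ({', '.join(parts)});"

# Minimal entry mirroring the Album table printed in Cell 3
album_info = {
    "columns": [{"name": "AlbumId", "type": "INTEGER"},
                {"name": "Title", "type": "NVARCHAR(160)"},
                {"name": "ArtistId", "type": "INTEGER"}],
    "pk": ["AlbumId"],
    "fks": [{"constrained_columns": ["ArtistId"], "referred_table": "Artist",
             "referred_columns": ["ArtistId"]}],
}
print(render_ddl("Album", album_info))
```

In Cell 3, `c['type']` holds a SQLAlchemy type object. Formatting it in an f-string yields strings such as `NVARCHAR(160)`, as the printed schema shows. The helper could therefore replace the `CREATE TABLE` line inside `schema_filter` as `render_ddl(table_name, schema_info[table_name])`.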

In [9]:
# Cell 11: Node 2 — Generate SQL
SQL_PROMPT_TEMPLATE = """You are a SQL expert. Generate a SQLite-compatible SELECT query for the question below.

Schema:
{schema_text}

Rules:
- Return ONLY the SQL query, no explanation
- Use only SELECT statements
- Use only tables and columns from the schema above
- Use SQLite syntax

Question: {question}

SQL:"""

def generate_sql(state: AgentState) -> dict:
    """Generate SQL from question + filtered schema using LLM."""
    prompt = SQL_PROMPT_TEMPLATE.format(
        schema_text=state["schema_text"],
        question=state["question"],
    )
    
    llm = ChatOllama(
        model=state["model_name"],
        base_url=OLLAMA_BASE_URL,
        temperature=0,
    )
    
    t0 = time.time()
    response = llm.invoke(prompt)
    elapsed = time.time() - t0
    
    # Clean response: strip markdown fences and whitespace
    sql = response.content.strip()
    sql = sql.replace("```sql", "").replace("```", "").strip()
    
    print(f"Model: {state['model_name']}")
    print(f"Generated SQL: {sql}")
    print(f"Latency: {elapsed:.1f}s")
    
    return {"generated_sql": sql}

# Quick test
test_state_filtered = {**test_state, **schema_filter(test_state)}
generate_sql(test_state_filtered)


Question: How many albums does each artist have?
Selected tables (2): ['Album', 'Artist']
Schema text length: 219 chars
Model: sqlcoder:7b
Generated SQL: 
Latency: 8.1s


{'generated_sql': ''}

In [10]:
# Cell 12: Diagnose sqlcoder prompt format
SQLCODER_PROMPT_TEMPLATE = """### Task
Generate a SQL query to answer the following question:
`{question}`

### Database Schema
{schema_text}

### Answer
Given the database schema, here is the SQL query that answers `{question}`:
```sql
"""

prompt_generic = SQL_PROMPT_TEMPLATE.format(
    schema_text=test_state_filtered["schema_text"],
    question=test_state_filtered["question"],
)
prompt_sqlcoder = SQLCODER_PROMPT_TEMPLATE.format(
    schema_text=test_state_filtered["schema_text"],
    question=test_state_filtered["question"],
)

llm = ChatOllama(model=PRIMARY_MODEL, base_url=OLLAMA_BASE_URL, temperature=0)

print("=== Generic prompt ===")
t0 = time.time()
r1 = llm.invoke(prompt_generic)
print(f"Response: '{r1.content.strip()[:200]}'")
print(f"Latency: {time.time() - t0:.1f}s\n")

print("=== sqlcoder-style prompt ===")
t0 = time.time()
r2 = llm.invoke(prompt_sqlcoder)
print(f"Response: '{r2.content.strip()[:200]}'")
print(f"Latency: {time.time() - t0:.1f}s")


=== Generic prompt ===
Response: ''
Latency: 22.5s

=== sqlcoder-style prompt ===
Response: 'SELECT a.artistid, COUNT(*) AS num_albums FROM album a GROUP BY a.artistid;
```'
Latency: 5.5s
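The sqlcoder-style prompt works because it ends by opening a SQL code fence. sqlcoder completes the fence and then closes it, and the `.replace()` cleanup in `generate_sql` strips the closing fence. That cleanup keeps any prose a model writes around the query, though. It also keeps the single backticks llama3.1 used in its Cell 6 answer.

The extractor below is a more defensive alternative, sketched with the standard-library `re` module. `extract_sql` is a hypothetical helper. It assumes keywords are uppercase, and its naive split on `;` would also cut a semicolon inside a string literal.

```python
import re

FENCE = "`" * 3  # a Markdown code fence, built here to keep this example readable
FENCED_BLOCK = re.compile(FENCE + r"(?:sql)?[ \t]*\n?(.*?)" + FENCE, re.DOTALL | re.IGNORECASE)
SQL_START = re.compile(r"\b(SELECT|WITH)\b")  # case-sensitive, so prose "with" is ignored

def extract_sql(reply: str) -> str:
    """Pull the first SQL statement out of an LLM reply."""
    text = reply.strip()
    block = FENCED_BLOCK.search(text)
    if block and block.group(1).strip():
        text = block.group(1)  # body of a complete fenced block
    else:
        # Drop an unmatched opening fence, then everything from a closing fence on
        text = re.sub("^" + FENCE + r"(?:sql)?", "", text, flags=re.IGNORECASE)
        text = text.split(FENCE, 1)[0]
    start = SQL_START.search(text)
    if start:
        text = text[start.start():]  # skip any lead-in prose
    # Keep only the first statement, without its semicolon or stray backticks
    return text.split(";", 1)[0].strip().strip("`").strip()

print(extract_sql("SELECT a.artistid, COUNT(*) AS num_albums FROM album a GROUP BY a.artistid;\n" + FENCE))
print(extract_sql("`SELECT 1;`"))
print(extract_sql(f"Here is the query:\n{FENCE}sql\nSELECT Name FROM Artist;\n{FENCE}\nDone."))
```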


In [11]:
# Cell 13: Node 2 (revised) — Generate SQL with model-aware prompts
GENERIC_PROMPT = """You are a SQL expert. Generate a SQLite-compatible SELECT query for the question below.

Schema:
{schema_text}

Rules:
- Return ONLY the SQL query, no explanation
- Use only SELECT statements
- Use only tables and columns from the schema above
- Use SQLite syntax

Question: {question}

SQL:"""

SQLCODER_PROMPT = """### Task
Generate a SQL query to answer the following question:
`{question}`

### Database Schema
{schema_text}

### Answer
Given the database schema, here is the SQL query that answers `{question}`:
```sql
"""

def generate_sql(state: AgentState) -> dict:
    """Generate SQL from question + filtered schema using LLM."""
    model = state["model_name"]
    
    # Select prompt template based on model
    if "sqlcoder" in model:
        template = SQLCODER_PROMPT
    else:
        template = GENERIC_PROMPT
    
    prompt = template.format(
        schema_text=state["schema_text"],
        question=state["question"],
    )
    
    llm = ChatOllama(model=model, base_url=OLLAMA_BASE_URL, temperature=0)
    
    t0 = time.time()
    response = llm.invoke(prompt)
    elapsed = time.time() - t0
    
    # Clean response: strip markdown fences and whitespace
    sql = response.content.strip()
    sql = sql.replace("```sql", "").replace("```", "").strip()
    # Remove trailing semicolons for consistency
    sql = sql.rstrip(";").strip()
    
    print(f"Model: {model}")
    print(f"Prompt: {'sqlcoder' if 'sqlcoder' in model else 'generic'}")
    print(f"Generated SQL: {sql}")
    print(f"Latency: {elapsed:.1f}s")
    
    return {"generated_sql": sql}

# Test with both models
test_state_filtered = {**test_state, **schema_filter(test_state)}
print("--- sqlcoder:7b ---")
generate_sql(test_state_filtered)
print()
print("--- llama3.1:8b ---")
generate_sql({**test_state_filtered, "model_name": BASELINE_MODEL})


Question: How many albums does each artist have?
Selected tables (2): ['Album', 'Artist']
Schema text length: 219 chars
--- sqlcoder:7b ---
Model: sqlcoder:7b
Prompt: sqlcoder
Generated SQL: SELECT a.artistid, COUNT(*) AS num_albums FROM album a GROUP BY a.artistid
Latency: 4.0s

--- llama3.1:8b ---
Model: llama3.1:8b
Prompt: generic
Generated SQL: SELECT A.Name, COUNT(AlbumId) AS AlbumCount
FROM Artist A
JOIN Album ON A.ArtistId = Album.ArtistId
GROUP BY A.Name
Latency: 17.2s


{'generated_sql': 'SELECT A.Name, COUNT(AlbumId) AS AlbumCount\nFROM Artist A\nJOIN Album ON A.ArtistId = Album.ArtistId\nGROUP BY A.Name'}
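The two models answered the same question with different SQL: sqlcoder grouped by `artistid`, while llama3.1 joined `Artist` to return names. Comparing query strings says little about correctness. For the Goal 3 evaluation, a common alternative is execution-based comparison: run a reference query and the generated query, then compare the result sets.

Below is a minimal sketch. It assumes row order matters only when the question asks for an ordering, and it ignores column names. `results_match` is a hypothetical helper.

```python
from collections import Counter

def _normalize(row):
    """Turn a row into a hashable tuple, rounding floats to absorb tiny differences."""
    return tuple(round(v, 6) if isinstance(v, float) else v for v in row)

def results_match(expected, actual, ordered=False):
    """Same rows as a multiset, or in the same order when ordered=True."""
    exp_rows = [_normalize(r) for r in expected]
    act_rows = [_normalize(r) for r in actual]
    if ordered:
        return exp_rows == act_rows
    return Counter(exp_rows) == Counter(act_rows)

# Reference rows taken from the top-artists output in Cell 18
gold = [("Iron Maiden", 21), ("Led Zeppelin", 14)]
print(results_match(gold, [["Led Zeppelin", 14], ["Iron Maiden", 21]]))                # True
print(results_match(gold, [["Led Zeppelin", 14], ["Iron Maiden", 21]], ordered=True))  # False
print(results_match(gold, [[101, 21], [102, 14]]))  # False: hypothetical ID-based rows
```

Under this strict rule, an answer that returns IDs instead of the requested names counts as wrong even when the counts agree. That is usually the desired behaviour for questions that ask for names.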

In [12]:
# Cell 14: Node 3 — Validate Query
BLOCKED_KEYWORDS = {"INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "CREATE", "TRUNCATE"}

def validate_query(state: AgentState) -> dict:
    """Validate SQL with sqlglot and block write operations."""
    sql = state["generated_sql"]
    
    # Check for empty SQL
    if not sql.strip():
        print("INVALID: Empty SQL")
        return {"is_valid": False, "validation_error": "LLM returned empty SQL"}
    
    # Security check: block write operations
    sql_upper = sql.upper()
    for keyword in BLOCKED_KEYWORDS:
        if keyword in sql_upper.split():
            print(f"BLOCKED: {keyword} detected")
            return {"is_valid": False, "validation_error": f"Write operation blocked: {keyword}"}
    
    # Syntax check with sqlglot
    try:
        parsed = sqlglot.parse(sql, read="sqlite")
        if not parsed or parsed[0] is None:
            print("INVALID: sqlglot could not parse")
            return {"is_valid": False, "validation_error": "sqlglot failed to parse SQL"}
        print(f"VALID: {sql[:80]}...")
        return {"is_valid": True, "validation_error": ""}
    except sqlglot.errors.ParseError as e:
        print(f"INVALID: {e}")
        return {"is_valid": False, "validation_error": str(e)}

# Test valid query
print("--- Valid query ---")
validate_query({"generated_sql": "SELECT Name FROM Artist LIMIT 5"})
print()

# Test blocked query
print("--- Blocked query ---")
validate_query({"generated_sql": "DROP TABLE Artist"})
print()

# Test bad syntax
print("--- Bad syntax ---")
validate_query({"generated_sql": "SELEC Name FORM Artist"})


--- Valid query ---
VALID: SELECT Name FROM Artist LIMIT 5...

--- Blocked query ---
BLOCKED: DROP detected

--- Bad syntax ---
INVALID: Invalid expression / Unexpected token. Line 1, Col: 15.
  SELEC Name FORM Artist


{'is_valid': False,
 'validation_error': 'Invalid expression / Unexpected token. Line 1, Col: 15.\n  SELEC Name \x1b[4mFORM\x1b[0m Artist'}

In [13]:
# Cell 15: Node 4 — Execute Query
def execute_query(state: AgentState) -> dict:
    """Execute validated SQL against the database."""
    sql = state["generated_sql"]
    
    try:
        with engine.connect() as conn:
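            # fetchmany(20) caps the result at 20 rows, so larger answers (e.g. the
            # 25-row Genre table) are truncated; "rows total" below counts fetched rows
            rows = conn.execute(text(sql)).fetchmany(20)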
            results = [list(row) for row in rows]
            print(f"Executed: {sql[:80]}")
            print(f"Rows returned: {len(results)}")
            for row in results[:5]:
                print(f"  {row}")
            if len(results) > 5:
                print(f"  ... ({len(results)} rows total)")
            return {"results": results, "error": ""}
    except Exception as e:
        print(f"Execution error: {e}")
        return {"results": None, "error": str(e)}

# Test with a real query
print("--- Valid execution ---")
execute_query({"generated_sql": "SELECT Name FROM Artist LIMIT 5"})
print()

# Test with a bad query (valid syntax but wrong column)
print("--- Runtime error ---")
execute_query({"generated_sql": "SELECT Foo FROM Artist"})


--- Valid execution ---
Executed: SELECT Name FROM Artist LIMIT 5
Rows returned: 5
  ['AC/DC']
  ['Accept']
  ['Aerosmith']
  ['Alanis Morissette']
  ['Alice In Chains']

--- Runtime error ---
Execution error: (sqlite3.OperationalError) no such column: Foo
[SQL: SELECT Foo FROM Artist]
(Background on this error at: https://sqlalche.me/e/20/e3q8)


{'results': None,
 'error': '(sqlite3.OperationalError) no such column: Foo\n[SQL: SELECT Foo FROM Artist]\n(Background on this error at: https://sqlalche.me/e/20/e3q8)'}

In [14]:
# Cell 16: Node 5 — Handle Error (Retry with Error Context)
ERROR_REPAIR_GENERIC = """The following SQL query produced an error. Fix it.

Schema:
{schema_text}

Original question: {question}

Failed SQL:
{generated_sql}

Error:
{error}

Return ONLY the corrected SQL query, no explanation.

SQL:"""

ERROR_REPAIR_SQLCODER = """### Task
The following SQL query produced an error. Fix the query to answer:
`{question}`

### Database Schema
{schema_text}

### Failed Query
{generated_sql}

### Error
{error}

### Corrected Answer
```sql
"""

def handle_error(state: AgentState) -> dict:
    """Feed error back to LLM for SQL repair."""
    model = state["model_name"]
    
    if "sqlcoder" in model:
        template = ERROR_REPAIR_SQLCODER
    else:
        template = ERROR_REPAIR_GENERIC
    
    prompt = template.format(
        schema_text=state["schema_text"],
        question=state["question"],
        generated_sql=state["generated_sql"],
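        # Validation failures populate validation_error rather than error, so fall
        # back to it; otherwise the repair prompt would show an empty Error section
        error=state["error"] or state.get("validation_error", ""),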
    )
    
    llm = ChatOllama(model=model, base_url=OLLAMA_BASE_URL, temperature=0)
    
    t0 = time.time()
    response = llm.invoke(prompt)
    elapsed = time.time() - t0
    
    sql = response.content.strip()
    sql = sql.replace("```sql", "").replace("```", "").strip()
    sql = sql.rstrip(";").strip()
    
    new_retry = state["retry_count"] + 1
    print(f"Retry {new_retry}: {sql[:80]}")
    print(f"Latency: {elapsed:.1f}s")
    
    return {"generated_sql": sql, "retry_count": new_retry, "error": ""}

# Test: simulate a failed query with wrong column name
test_error_state = {
    **test_state_filtered,
    "generated_sql": "SELECT Foo FROM Artist",
    "error": "(sqlite3.OperationalError) no such column: Foo",
    "retry_count": 0,
}
handle_error(test_error_state)


Retry 1: SELECT a."ArtistId", COUNT(a."AlbumId") AS "Number of Albums" FROM "Album" a GRO
Latency: 62.0s


{'generated_sql': 'SELECT a."ArtistId", COUNT(a."AlbumId") AS "Number of Albums" FROM "Album" a GROUP BY a."ArtistId"',
 'retry_count': 1,
 'error': ''}
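Error messages reach the repair prompt verbatim. The SQLAlchemy message repeats the failed SQL and appends a documentation link. The sqlglot parse errors from Cell 14 embed ANSI underline codes (`\x1b[4m`). A small regex cleanup keeps the prompt shorter. It is sketched here as a hypothetical `clean_error` helper that could wrap the value passed as `error=` in `handle_error`.

```python
import re

ANSI_RE = re.compile(r"\x1b\[[0-9;]*m")  # terminal underline/colour codes
BACKGROUND_RE = re.compile(r"\s*\(Background on this error at: [^)]*\)")
SQL_ECHO_RE = re.compile(r"\s*\[SQL: .*?\](?=\s*$)", re.DOTALL)

def clean_error(message: str) -> str:
    """Shorten a database or parser error before placing it in a repair prompt."""
    message = ANSI_RE.sub("", message)
    message = BACKGROUND_RE.sub("", message)  # drop SQLAlchemy's documentation link
    message = SQL_ECHO_RE.sub("", message)    # the failed SQL is already in the prompt
    return message.strip()

raw = ("(sqlite3.OperationalError) no such column: Foo\n[SQL: SELECT Foo FROM Artist]\n"
       "(Background on this error at: https://sqlalche.me/e/20/e3q8)")
print(clean_error(raw))  # (sqlite3.OperationalError) no such column: Foo
```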

In [15]:
# Cell 17: Wire LangGraph State Graph
from langgraph.graph import StateGraph, END

def should_retry(state: AgentState) -> str:
    """Route after execute_query: retry on error or finish."""
    if state["error"] and state["retry_count"] < 3:
        return "handle_error"
    return END

def check_validation(state: AgentState) -> str:
    """Route after validate_query: execute if valid, retry or stop."""
    if state["is_valid"]:
        return "execute_query"
    if state["retry_count"] < 3:
        return "handle_error"
    return END

# Build graph
workflow = StateGraph(AgentState)

# Add nodes
workflow.add_node("schema_filter", schema_filter)
workflow.add_node("generate_sql", generate_sql)
workflow.add_node("validate_query", validate_query)
workflow.add_node("execute_query", execute_query)
workflow.add_node("handle_error", handle_error)

# Set entry point
workflow.set_entry_point("schema_filter")

# Add edges
workflow.add_edge("schema_filter", "generate_sql")
workflow.add_edge("generate_sql", "validate_query")
workflow.add_conditional_edges("validate_query", check_validation)
workflow.add_conditional_edges("execute_query", should_retry)
workflow.add_edge("handle_error", "validate_query")

# Compile
agent = workflow.compile()
print("Agent compiled successfully")
print(f"Nodes: {list(agent.nodes.keys())}")


Agent compiled successfully
Nodes: ['__start__', 'schema_filter', 'generate_sql', 'validate_query', 'execute_query', 'handle_error']
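`handle_error` increments `retry_count` and routes back to `validate_query`, so the `< 3` checks bound each question at one generation plus three repairs. The plain-Python walk below makes the worst-case number of LLM calls explicit. It mirrors the same routing with stub nodes and uses no LangGraph or LLM calls. It assumes every candidate fails validation, so it only illustrates the control flow.

```python
END_NODE = "__end__"  # assumed to match langgraph.graph.END; redefined so no import is needed
MAX_RETRIES = 3

def route_after_validation(state: dict) -> str:
    """Same decision as check_validation above."""
    if state["is_valid"]:
        return "execute_query"
    return "handle_error" if state["retry_count"] < MAX_RETRIES else END_NODE

def simulate(always_invalid: bool) -> tuple[int, int]:
    """Walk the routing with stub nodes; return (llm_calls, retry_count)."""
    state = {"is_valid": False, "retry_count": 0}
    llm_calls = 1  # generate_sql runs once before validation
    node = "validate_query"
    while node != END_NODE:
        if node == "validate_query":
            state["is_valid"] = not always_invalid  # stub validator
            node = route_after_validation(state)
        elif node == "handle_error":
            llm_calls += 1  # one repair call per retry
            state["retry_count"] += 1
            node = "validate_query"
        else:  # execute_query stub: always succeeds
            node = END_NODE
    return llm_calls, state["retry_count"]

print(simulate(always_invalid=True))   # (4, 3)
print(simulate(always_invalid=False))  # (1, 0)
```

`should_retry` applies the same `< 3` test to the shared counter, so mixing validation and execution failures does not raise the bound. With one sqlcoder repair taking 62 s in Cell 16, a worst-case question could take a few minutes.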


In [16]:
# Cell 18: End-to-End Test — Single Question
initial_state = {
    "question": "How many albums does each artist have? Show artist name and count, top 5.",
    "relevant_tables": [],
    "schema_text": "",
    "generated_sql": "",
    "is_valid": False,
    "validation_error": "",
    "results": None,
    "error": "",
    "retry_count": 0,
    "model_name": PRIMARY_MODEL,
}

print("=" * 60)
print(f"Question: {initial_state['question']}")
print(f"Model: {initial_state['model_name']}")
print("=" * 60)

t0 = time.time()
final_state = agent.invoke(initial_state)
total_time = time.time() - t0

print(f"\n{'=' * 60}")
print("FINAL RESULTS")
print(f"SQL: {final_state['generated_sql']}")
print(f"Valid: {final_state['is_valid']}")
print(f"Results: {final_state['results']}")
print(f"Error: {final_state['error']}")
print(f"Retries: {final_state['retry_count']}")
print(f"Total time: {total_time:.1f}s")


Question: How many albums does each artist have? Show artist name and count, top 5.
Model: sqlcoder:7b
Question: How many albums does each artist have? Show artist name and count, top 5.
Selected tables (5): ['Artist', 'Album', 'Customer', 'Employee', 'Genre']
Schema text length: 1159 chars
Model: sqlcoder:7b
Prompt: sqlcoder
Generated SQL: SELECT a."Name", COUNT(b.AlbumId) AS "Number of Albums" FROM Artist a JOIN Album b ON a.ArtistId = b.ArtistId GROUP BY a."Name" ORDER BY "Number of Albums" DESC NULLS LAST LIMIT 5
Latency: 19.0s
VALID: SELECT a."Name", COUNT(b.AlbumId) AS "Number of Albums" FROM Artist a JOIN Album...
Executed: SELECT a."Name", COUNT(b.AlbumId) AS "Number of Albums" FROM Artist a JOIN Album
Rows returned: 5
  ['Iron Maiden', 21]
  ['Led Zeppelin', 14]
  ['Deep Purple', 11]
  ['U2', 10]
  ['Metallica', 10]

FINAL RESULTS
SQL: SELECT a."Name", COUNT(b.AlbumId) AS "Number of Albums" FROM Artist a JOIN Album b ON a.ArtistId = b.ArtistId GROUP BY a."Name" ORDER BY "Numbe
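The table list above includes `Customer`, `Employee`, and `Genre` for a question about artists and albums. Tracing `schema_filter` by hand suggests why. The question word `name` is a substring of `FirstName`, `LastName`, and every `Name` column, and `count` is a substring of `Country`. As a result, several unrelated tables tie at a low score and fill the top-5 slots.

Below is a sketch of a stricter scorer that matches whole tokens of CamelCase-split identifiers. The stopword list, minimum word length, and crude plural stripping are all untuned choices made for this example.

```python
import re

# Hand-picked for this sketch; a real list would be tuned on the test suite
STOPWORDS = {"the", "and", "are", "for", "how", "many", "what", "which", "does",
             "each", "have", "show", "with", "all", "list", "find", "top", "by"}

def singular(word: str) -> str:
    """Crude plural stripping so 'albums' matches 'Album' (not real lemmatization)."""
    return word[:-1] if len(word) > 3 and word.endswith("s") else word

def question_tokens(question: str, min_len: int = 3) -> set[str]:
    """Lowercase word tokens without stopwords or very short words."""
    words = re.findall(r"[a-z]+", question.lower())
    return {singular(w) for w in words if len(w) >= min_len and w not in STOPWORDS}

def identifier_tokens(name: str) -> set[str]:
    """Split a CamelCase identifier, e.g. 'UnitPrice' -> {'unit', 'price'}."""
    return {singular(p.lower()) for p in re.findall(r"[A-Z][a-z]*|[a-z]+", name)}

def score_table(question: str, table: str, columns: list[str]) -> float:
    """Whole-token overlap: 3 for a table-name hit, plus 0.5 per matching column."""
    words = question_tokens(question)
    score = 3.0 if words & identifier_tokens(table) else 0.0
    score += sum(0.5 for col in columns if words & identifier_tokens(col))
    return score

TRACK_COLS = ["TrackId", "Name", "AlbumId", "MediaTypeId", "GenreId",
              "Composer", "Milliseconds", "Bytes", "UnitPrice"]
CUSTOMER_COLS = ["CustomerId", "FirstName", "LastName", "Company", "Address", "City",
                 "State", "Country", "PostalCode", "Phone", "Fax", "Email", "SupportRepId"]

q_spend = "What are the top 3 customers by total spending?"
q_artist = "How many albums does each artist have? Show artist name and count, top 5."
print(score_table(q_spend, "Track", TRACK_COLS))         # 0.0
print(score_table(q_spend, "Customer", CUSTOMER_COLS))   # 3.5
print(score_table(q_artist, "Customer", CUSTOMER_COLS))  # 1.0
```

On the spending question, substring matching lets `by` hit `Bytes` and pulls in `Track` with its FK neighbours; this variant scores `Track` at 0. It also stops matching `count` to `Country`. However, `name` still scores every table with a name column, so a minimum-score threshold would likely be needed as well.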

In [17]:
# Cell 19: End-to-End Test — Multiple Questions
test_questions = [
    "List all genres",
    "What are the top 3 customers by total spending?",
    "Find all tracks by AC/DC",
]

for q in test_questions:
    state = {
        "question": q,
        "relevant_tables": [],
        "schema_text": "",
        "generated_sql": "",
        "is_valid": False,
        "validation_error": "",
        "results": None,
        "error": "",
        "retry_count": 0,
        "model_name": PRIMARY_MODEL,
    }
    
    print("=" * 60)
    print(f"Q: {q}")
    t0 = time.time()
    result = agent.invoke(state)
    elapsed = time.time() - t0
    print(f"\nSQL: {result['generated_sql']}")
    print(f"Results: {result['results']}")
    print(f"Retries: {result['retry_count']} | Time: {elapsed:.1f}s")
    print()


Q: List all genres
Question: List all genres
Selected tables (4): ['Genre', 'Playlist', 'PlaylistTrack', 'Track']
Schema text length: 523 chars
Model: sqlcoder:7b
Prompt: sqlcoder
Generated SQL: SELECT Genre.Name FROM Genre
Latency: 4.9s
VALID: SELECT Genre.Name FROM Genre...
Executed: SELECT Genre.Name FROM Genre
Rows returned: 20
  ['Rock']
  ['Jazz']
  ['Metal']
  ['Alternative & Punk']
  ['Rock And Roll']
  ... (20 rows total)

SQL: SELECT Genre.Name FROM Genre
Results: [['Rock'], ['Jazz'], ['Metal'], ['Alternative & Punk'], ['Rock And Roll'], ['Blues'], ['Latin'], ['Reggae'], ['Pop'], ['Soundtrack'], ['Bossa Nova'], ['Easy Listening'], ['Heavy Metal'], ['R&B/Soul'], ['Electronica/Dance'], ['World'], ['Hip Hop/Rap'], ['Science Fiction'], ['TV Shows'], ['Sci Fi & Fantasy']]
Retries: 0 | Time: 5.0s

Q: What are the top 3 customers by total spending?
Question: What are the top 3 customers by total spending?
Selected tables (7): ['Customer', 'Invoice', 'Track', 'Employee', 'MediaType',

In [18]:
# Cell 20: Improved Prompts — SQLite Rules & Schema in Error Repair
GENERIC_PROMPT = """You are a SQL expert. Generate a SQLite-compatible SELECT query for the question below.

Schema:
{schema_text}

Rules:
- Return ONLY the SQL query, no explanation
- Use only SELECT statements
- Use only tables and columns from the schema above
- Use exact column names as shown in the schema (case-sensitive)
- SQLite syntax only: use LIKE not ILIKE, no NULLS FIRST/LAST
- For case-insensitive matching use: LOWER(column) LIKE LOWER('%value%')

Question: {question}

SQL:"""

SQLCODER_PROMPT = """### Task
Generate a SQL query to answer the following question:
`{question}`

### Database Schema
{schema_text}

### Rules
- SQLite dialect only
- Use exact column names from schema (case-sensitive: FirstName not first_name)
- Use LIKE not ILIKE (SQLite has no ILIKE)
- No NULLS FIRST/LAST

### Answer
Given the database schema, here is the SQL query that answers `{question}`:
```sql
"""

ERROR_REPAIR_GENERIC = """The following SQL query produced an error. Fix it using the schema below.

Schema:
{schema_text}

Question: {question}

Failed SQL:
{generated_sql}

Error:
{error}

Rules:
- Use exact column names from the schema above (case-sensitive)
- SQLite syntax only: use LIKE not ILIKE
- Return ONLY the corrected SQL query, no explanation

SQL:"""

ERROR_REPAIR_SQLCODER = """### Task
The following SQL query produced an error. Fix the query to answer:
`{question}`

### Database Schema
{schema_text}

### Failed Query
{generated_sql}

### Error
{error}

### Rules
- Use exact column names from the schema (case-sensitive: FirstName not first_name)
- SQLite dialect only: use LIKE not ILIKE
- No NULLS FIRST/LAST

### Corrected Answer
```sql
"""

# Update handle_error to use new templates
def handle_error(state: AgentState) -> dict:
    """Feed error back to LLM for SQL repair."""
    model = state["model_name"]
    
    if "sqlcoder" in model:
        template = ERROR_REPAIR_SQLCODER
    else:
        template = ERROR_REPAIR_GENERIC
    
    prompt = template.format(
        schema_text=state["schema_text"],
        question=state["question"],
        generated_sql=state["generated_sql"],
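        # Validation failures populate validation_error rather than error, so fall
        # back to it; otherwise the repair prompt would show an empty Error section
        error=state["error"] or state.get("validation_error", ""),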
    )
    
    llm = ChatOllama(model=model, base_url=OLLAMA_BASE_URL, temperature=0)
    
    t0 = time.time()
    response = llm.invoke(prompt)
    elapsed = time.time() - t0
    
    sql = response.content.strip()
    sql = sql.replace("```sql", "").replace("```", "").strip()
    sql = sql.rstrip(";").strip()
    
    new_retry = state["retry_count"] + 1
    print(f"Retry {new_retry}: {sql[:80]}")
    print(f"Latency: {elapsed:.1f}s")
    
    return {"generated_sql": sql, "retry_count": new_retry, "error": ""}

# Also update generate_sql to use new templates
def generate_sql(state: AgentState) -> dict:
    """Generate SQL from question + filtered schema using LLM."""
    model = state["model_name"]
    
    if "sqlcoder" in model:
        template = SQLCODER_PROMPT
    else:
        template = GENERIC_PROMPT
    
    prompt = template.format(
        schema_text=state["schema_text"],
        question=state["question"],
    )
    
    llm = ChatOllama(model=model, base_url=OLLAMA_BASE_URL, temperature=0)
    
    t0 = time.time()
    response = llm.invoke(prompt)
    elapsed = time.time() - t0
    
    sql = response.content.strip()
    sql = sql.replace("```sql", "").replace("```", "").strip()
    sql = sql.rstrip(";").strip()
    
    print(f"Model: {model}")
    print(f"Generated SQL: {sql}")
    print(f"Latency: {elapsed:.1f}s")
    
    return {"generated_sql": sql}

# Recompile the graph with updated functions
workflow = StateGraph(AgentState)
workflow.add_node("schema_filter", schema_filter)
workflow.add_node("generate_sql", generate_sql)
workflow.add_node("validate_query", validate_query)
workflow.add_node("execute_query", execute_query)
workflow.add_node("handle_error", handle_error)
workflow.set_entry_point("schema_filter")
workflow.add_edge("schema_filter", "generate_sql")
workflow.add_edge("generate_sql", "validate_query")
workflow.add_conditional_edges("validate_query", check_validation)
workflow.add_conditional_edges("execute_query", should_retry)
workflow.add_edge("handle_error", "validate_query")
agent = workflow.compile()

print("Agent recompiled with improved prompts")
print("Changes: exact column name rules, LIKE not ILIKE, no NULLS LAST, schema in error repair")


Agent recompiled with improved prompts
Changes: exact column name rules, LIKE not ILIKE, no NULLS LAST, schema in error repair
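Both `generate_sql` and `handle_error` remove Markdown fences with chained `str.replace` calls. That misses variants such as ```` ```SQL ```` or ```` ```sqlite ````, and it keeps any prose the model writes around the fence. Below is a minimal sketch of a shared, regex-based cleaner. `extract_sql` is a hypothetical helper name and is not part of the notebook.

```python
import re

# A fenced block with an optional alphabetic language tag (```sql, ```SQL, ```sqlite, ...)
_FENCE_RE = re.compile(r"```[A-Za-z]*\s*\n?(.*?)```", re.DOTALL)


def extract_sql(response_text: str) -> str:
    """Return the SQL statement from an LLM reply (hypothetical helper).

    Uses the first fenced block if there is one, otherwise the whole reply.
    Surrounding whitespace and trailing semicolons are removed.
    """
    match = _FENCE_RE.search(response_text)
    sql = match.group(1) if match else response_text
    return sql.strip().rstrip(";").strip()


print(extract_sql("```sql\nSELECT Name FROM Artist LIMIT 5;\n```"))
print(extract_sql("Here is the query:\n```SQL\nSELECT COUNT(*) FROM Track;\n```\nHope it helps!"))
print(extract_sql("SELECT 1;"))
```

On the second input, the original `replace` chain would pass `Here is the query:` and `Hope it helps!` through to the validator.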


In [19]:
# Cell 21: Retest Previously Failed Questions
failed_questions = [
    "What are the top 3 customers by total spending?",
    "Find all tracks by AC/DC",
]

for q in failed_questions:
    state = {
        "question": q,
        "relevant_tables": [],
        "schema_text": "",
        "generated_sql": "",
        "is_valid": False,
        "validation_error": "",
        "results": None,
        "error": "",
        "retry_count": 0,
        "model_name": PRIMARY_MODEL,
    }
    
    print("=" * 60)
    print(f"Q: {q}")
    t0 = time.time()
    result = agent.invoke(state)
    elapsed = time.time() - t0
    print(f"\nSQL: {result['generated_sql']}")
    print(f"Results: {result['results']}")
    print(f"Retries: {result['retry_count']} | Time: {elapsed:.1f}s")
    print()


Q: What are the top 3 customers by total spending?
Question: What are the top 3 customers by total spending?
Selected tables (7): ['Customer', 'Invoice', 'Track', 'Employee', 'MediaType', 'Genre', 'Album']
Schema text length: 1839 chars
Model: sqlcoder:7b
Generated SQL: SELECT p.first_name, p.last_name, SUM(i.total) AS total_spent FROM customer p JOIN invoice i ON p.customerid = i.customerid GROUP BY p.first_name, p.last_name ORDER BY total_spent DESC NULLS LAST LIMIT 3
Latency: 24.2s
VALID: SELECT p.first_name, p.last_name, SUM(i.total) AS total_spent FROM customer p JO...
Execution error: (sqlite3.OperationalError) no such column: p.first_name
[SQL: SELECT p.first_name, p.last_name, SUM(i.total) AS total_spent FROM customer p JOIN invoice i ON p.customerid = i.customerid GROUP BY p.first_name, p.last_name ORDER BY total_spent DESC NULLS LAST LIMIT 3]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
Retry 1: SELECT p.first_name, p.last_name, SUM(i.total) AS total_spent FRO
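The SQLite-compatibility rewrites in this notebook (ILIKE to LIKE, dropping NULLS FIRST/LAST) depend on two SQLite behaviors. First, `LIKE` is case-insensitive for ASCII letters by default. Second, NULLs sort as the smallest value, so they come first under `ASC` and last under `DESC`. SQLite 3.30.0 and later also accept `NULLS FIRST/LAST` natively. The in-memory check below uses illustrative data and is not part of the notebook:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (name TEXT, total REAL)")
conn.executemany(
    "INSERT INTO t VALUES (?, ?)",
    [("AC/DC", 10.0), ("Accept", None), ("Aerosmith", 5.0)],
)

# Default LIKE ignores ASCII case, so rewriting ILIKE as LIKE keeps the match
like_hits = conn.execute("SELECT name FROM t WHERE name LIKE '%ac/dc%'").fetchall()
print(like_hits)  # [('AC/DC',)]

# NULL is the smallest value: last under DESC, first under ASC
desc_order = [r[0] for r in conn.execute("SELECT total FROM t ORDER BY total DESC")]
asc_order = [r[0] for r in conn.execute("SELECT total FROM t ORDER BY total ASC")]
print(desc_order)  # [10.0, 5.0, None]
print(asc_order)   # [None, 5.0, 10.0]
conn.close()
```

Rewriting `... DESC NULLS LAST` as `... DESC` therefore keeps the result order. Rewriting `... ASC NULLS LAST` as `... ASC` would move NULLs to the top.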

In [20]:
# Cell 22: SQL Post-Processing for SQLite Compatibility
import re

def build_column_map():
    """Build case-insensitive column name mapping from schema_info."""
    col_map = {}  # lowercase -> actual name
    for table_name, info in schema_info.items():
        for col in info["columns"]:
            col_map[col["name"].lower()] = col["name"]
        # Also map table names
        col_map[table_name.lower()] = table_name
    return col_map

COLUMN_MAP = build_column_map()

def postprocess_sql(sql: str) -> str:
    """Fix known SQLite incompatibilities in generated SQL, reporting each applied step."""
    changes = []
    
    # 1. Replace ILIKE with LIKE (SQLite LIKE is case-insensitive for ASCII)
    new_sql = re.sub(r'\bILIKE\b', 'LIKE', sql, flags=re.IGNORECASE)
    if new_sql != sql:
        changes.append("ILIKE->LIKE")
    sql = new_sql
    
    # 2. Remove NULLS FIRST / NULLS LAST (native support only from SQLite 3.30.0).
    #    Order is preserved only when the clause restates SQLite's default:
    #    NULLs sort first under ASC and last under DESC.
    new_sql = re.sub(r'\s+NULLS\s+(FIRST|LAST)\b', '', sql, flags=re.IGNORECASE)
    if new_sql != sql:
        changes.append("removed NULLS FIRST/LAST")
    sql = new_sql
    
    # 3. Fix identifier spelling via COLUMN_MAP (e.g., customer -> Customer)
    def replace_identifier(match):
        word = match.group(0)
        # Leave SQL keywords and common functions untouched; aliases (p, i, ...)
        # pass through unless they happen to match a column or table name
        sql_keywords = {
            'SELECT', 'FROM', 'WHERE', 'JOIN', 'ON', 'GROUP', 'BY', 'ORDER',
            'HAVING', 'LIMIT', 'AS', 'AND', 'OR', 'NOT', 'IN', 'LIKE', 'IS',
            'NULL', 'COUNT', 'SUM', 'AVG', 'MIN', 'MAX', 'DESC', 'ASC',
            'DISTINCT', 'BETWEEN', 'CASE', 'WHEN', 'THEN', 'ELSE', 'END',
            'INNER', 'LEFT', 'RIGHT', 'OUTER', 'UNION', 'ALL', 'EXISTS',
            'CAST', 'LOWER', 'UPPER', 'LENGTH', 'SUBSTR', 'TRIM',
        }
        if word.upper() in sql_keywords:
            return word
        lookup = word.lower().replace('"', '').replace("'", '')
        if lookup in COLUMN_MAP:
            return COLUMN_MAP[lookup]
        return word
    
    # Match identifiers (possibly double-quoted); this also touches words inside string literals
    new_sql = re.sub(r'"?\b[A-Za-z_]\w*\b"?', replace_identifier, sql)
    if new_sql != sql:
        changes.append("fixed column casing")
    sql = new_sql
    
    if changes:
        print(f"  Post-processed: {', '.join(changes)}")
    
    return sql

# Test it
test_cases = [
    "SELECT p.first_name, p.last_name FROM customer p ORDER BY p.last_name NULLS LAST",
    "SELECT Track.Name FROM Track WHERE Composer ILIKE '%AC/DC%'",
    "SELECT Name FROM Artist LIMIT 5",  # should be unchanged
]

for sql in test_cases:
    print(f"  Input:  {sql}")
    fixed = postprocess_sql(sql)
    print(f"  Output: {fixed}")
    print()


  Input:  SELECT p.first_name, p.last_name FROM customer p ORDER BY p.last_name NULLS LAST
  Post-processed: removed NULLS FIRST/LAST, fixed column casing
  Output: SELECT p.first_name, p.last_name FROM Customer p ORDER BY p.last_name

  Input:  SELECT Track.Name FROM Track WHERE Composer ILIKE '%AC/DC%'
  Post-processed: ILIKE->LIKE, fixed column casing
  Output: SELECT Track.Name FROM Track WHERE Composer LIKE '%AC/DC%'

  Input:  SELECT Name FROM Artist LIMIT 5
  Output: SELECT Name FROM Artist LIMIT 5
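The identifier regex in `postprocess_sql` also runs inside string literals. A filter such as `WHERE Name = 'name'` would have its literal rewritten to `'Name'`, which changes the result of an `=` comparison. The sketch below recases only text outside single-quoted literals. `recase_outside_literals` and the small `col_map` are illustrative stand-ins, not the notebook's `COLUMN_MAP`.

```python
import re

# Single-quoted SQL string literals, with '' as an escaped quote inside them
_LITERAL_RE = re.compile(r"'(?:[^']|'')*'")
_IDENT_RE = re.compile(r"\b[A-Za-z_]\w*\b")


def recase_outside_literals(sql: str, col_map: dict[str, str]) -> str:
    """Apply a lowercase -> schema-name mapping to identifiers, skipping string literals."""

    def recase(segment: str) -> str:
        return _IDENT_RE.sub(lambda m: col_map.get(m.group(0).lower(), m.group(0)), segment)

    parts = []
    last_end = 0
    for lit in _LITERAL_RE.finditer(sql):
        parts.append(recase(sql[last_end:lit.start()]))  # SQL text before the literal
        parts.append(lit.group(0))                         # literal kept verbatim
        last_end = lit.end()
    parts.append(recase(sql[last_end:]))                   # SQL text after the last literal
    return "".join(parts)


col_map = {"name": "Name", "artist": "Artist", "first_name": "FirstName"}
print(recase_outside_literals("SELECT name FROM artist WHERE name = 'name'", col_map))
print(recase_outside_literals("SELECT first_name FROM t WHERE x = 'it''s name'", col_map))
```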



In [21]:
# Cell 23: Fix Column Mapping — Handle snake_case to PascalCase
def build_column_map():
    """Build case-insensitive column name mapping, including snake_case variants."""
    col_map = {}
    for table_name, info in schema_info.items():
        for col in info["columns"]:
            actual = col["name"]
            # Map lowercase version
            col_map[actual.lower()] = actual
            # Map snake_case version (e.g., first_name -> FirstName)
            snake = re.sub(r'(?<!^)(?=[A-Z])', '_', actual).lower()
            col_map[snake] = actual
        # Table names
        col_map[table_name.lower()] = table_name
    return col_map

COLUMN_MAP = build_column_map()

# Verify the fix
print("Key mappings:")
for key in ["first_name", "firstname", "last_name", "lastname", 
            "customerid", "customer_id", "artistid", "artist_id",
            "supportrepid", "support_rep_id", "billingcountry", "billing_country"]:
    actual = COLUMN_MAP.get(key, "NOT FOUND")
    print(f"  {key:20s} -> {actual}")

# Retest postprocess
print("\nPost-process test:")
sql = "SELECT p.first_name, p.last_name FROM customer p ORDER BY p.last_name NULLS LAST"
print(f"  Input:  {sql}")
print(f"  Output: {postprocess_sql(sql)}")


Key mappings:
  first_name           -> FirstName
  firstname            -> FirstName
  last_name            -> LastName
  lastname             -> LastName
  customerid           -> CustomerId
  customer_id          -> CustomerId
  artistid             -> ArtistId
  artist_id            -> ArtistId
  supportrepid         -> SupportRepId
  support_rep_id       -> SupportRepId
  billingcountry       -> BillingCountry
  billing_country      -> BillingCountry

Post-process test:
  Input:  SELECT p.first_name, p.last_name FROM customer p ORDER BY p.last_name NULLS LAST
  Post-processed: removed NULLS FIRST/LAST, fixed column casing
  Output: SELECT p.FirstName, p.LastName FROM Customer p ORDER BY p.LastName
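The snake_case entries are the ones that matter here. SQLite resolves table and column names case-insensitively, so `customer` and `firstname` already work against `Customer.FirstName`. What fails is the underscore spelling `first_name`, because it is a different identifier. The small in-memory check below uses an illustrative table:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, FirstName TEXT)")
conn.execute("INSERT INTO Customer (FirstName) VALUES ('Helena')")

# A different letter case resolves to the same table and column
rows = conn.execute("SELECT c.firstname FROM customer c").fetchall()
print(rows)  # [('Helena',)]

# The snake_case spelling is a different name, so SQLite rejects it
try:
    conn.execute("SELECT c.first_name FROM customer c")
    snake_error = None
except sqlite3.OperationalError as exc:
    snake_error = str(exc)
print(snake_error)
conn.close()
```

This also explains why the first `COLUMN_MAP`, which held only lowercase names, left `p.first_name` unchanged.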


In [22]:
# Cell 24: Integrate Post-Processing & Retest Failed Queries
def generate_sql(state: AgentState) -> dict:
    """Generate SQL from question + filtered schema using LLM, with post-processing."""
    model = state["model_name"]
    
    if "sqlcoder" in model:
        template = SQLCODER_PROMPT
    else:
        template = GENERIC_PROMPT
    
    prompt = template.format(
        schema_text=state["schema_text"],
        question=state["question"],
    )
    
    llm = ChatOllama(model=model, base_url=OLLAMA_BASE_URL, temperature=0)
    
    t0 = time.time()
    response = llm.invoke(prompt)
    elapsed = time.time() - t0
    
    sql = response.content.strip()
    sql = sql.replace("```sql", "").replace("```", "").strip()
    sql = sql.rstrip(";").strip()
    
    # Post-process for SQLite compatibility
    sql = postprocess_sql(sql)
    
    print(f"Model: {model}")
    print(f"Generated SQL: {sql}")
    print(f"Latency: {elapsed:.1f}s")
    
    return {"generated_sql": sql}

# Also apply post-processing in handle_error
def handle_error(state: AgentState) -> dict:
    """Feed error back to LLM for SQL repair, with post-processing."""
    model = state["model_name"]
    
    if "sqlcoder" in model:
        template = ERROR_REPAIR_SQLCODER
    else:
        template = ERROR_REPAIR_GENERIC
    
    prompt = template.format(
        schema_text=state["schema_text"],
        question=state["question"],
        generated_sql=state["generated_sql"],
        error=state["error"],
    )
    
    llm = ChatOllama(model=model, base_url=OLLAMA_BASE_URL, temperature=0)
    
    t0 = time.time()
    response = llm.invoke(prompt)
    elapsed = time.time() - t0
    
    sql = response.content.strip()
    sql = sql.replace("```sql", "").replace("```", "").strip()
    sql = sql.rstrip(";").strip()
    sql = postprocess_sql(sql)
    
    new_retry = state["retry_count"] + 1
    print(f"Retry {new_retry}: {sql[:80]}")
    print(f"Latency: {elapsed:.1f}s")
    
    return {"generated_sql": sql, "retry_count": new_retry, "error": ""}

# Recompile
workflow = StateGraph(AgentState)
workflow.add_node("schema_filter", schema_filter)
workflow.add_node("generate_sql", generate_sql)
workflow.add_node("validate_query", validate_query)
workflow.add_node("execute_query", execute_query)
workflow.add_node("handle_error", handle_error)
workflow.set_entry_point("schema_filter")
workflow.add_edge("schema_filter", "generate_sql")
workflow.add_edge("generate_sql", "validate_query")
workflow.add_conditional_edges("validate_query", check_validation)
workflow.add_conditional_edges("execute_query", should_retry)
workflow.add_edge("handle_error", "validate_query")
agent = workflow.compile()

# Retest the two failed queries
for q in ["What are the top 3 customers by total spending?", "Find all tracks by AC/DC"]:
    state = {
        "question": q, "relevant_tables": [], "schema_text": "",
        "generated_sql": "", "is_valid": False, "validation_error": "",
        "results": None, "error": "", "retry_count": 0, "model_name": PRIMARY_MODEL,
    }
    print("=" * 60)
    print(f"Q: {q}")
    t0 = time.time()
    result = agent.invoke(state)
    elapsed = time.time() - t0
    print(f"\nSQL: {result['generated_sql']}")
    print(f"Results: {result['results']}")
    print(f"Retries: {result['retry_count']} | Time: {elapsed:.1f}s")
    print()


Q: What are the top 3 customers by total spending?
Question: What are the top 3 customers by total spending?
Selected tables (7): ['Customer', 'Invoice', 'Track', 'Employee', 'MediaType', 'Genre', 'Album']
Schema text length: 1839 chars
  Post-processed: removed NULLS FIRST/LAST, fixed column casing
Model: sqlcoder:7b
Generated SQL: SELECT p.FirstName, p.LastName, SUM(i.Total) AS total_spent FROM Customer p JOIN Invoice i ON p.CustomerId = i.CustomerId GROUP BY p.FirstName, p.LastName ORDER BY total_spent DESC LIMIT 3
Latency: 20.7s
VALID: SELECT p.FirstName, p.LastName, SUM(i.Total) AS total_spent FROM Customer p JOIN...
Executed: SELECT p.FirstName, p.LastName, SUM(i.Total) AS total_spent FROM Customer p JOIN
Rows returned: 3
  ['Helena', 'Holý', 49.620000000000005]
  ['Richard', 'Cunningham', 47.620000000000005]
  ['Luis', 'Rojas', 46.62]

SQL: SELECT p.FirstName, p.LastName, SUM(i.Total) AS total_spent FROM Customer p JOIN Invoice i ON p.CustomerId = i.CustomerId GROUP BY p.FirstNa

In [23]:
# Cell 25: Genre Exploration — Heavy Metal, Metal & Blues
genre_questions = [
    "How many tracks are in the Heavy Metal, Metal, and Blues genres?",
    "Who are the top 5 artists with the most tracks in Heavy Metal, Metal, or Blues genres?",
]

for q in genre_questions:
    state = {
        "question": q, "relevant_tables": [], "schema_text": "",
        "generated_sql": "", "is_valid": False, "validation_error": "",
        "results": None, "error": "", "retry_count": 0, "model_name": PRIMARY_MODEL,
    }
    print("=" * 60)
    print(f"Q: {q}")
    t0 = time.time()
    result = agent.invoke(state)
    elapsed = time.time() - t0
    print(f"\nSQL: {result['generated_sql']}")
    print(f"Results: {result['results']}")
    print(f"Retries: {result['retry_count']} | Time: {elapsed:.1f}s")
    print()


Q: How many tracks are in the Heavy Metal, Metal, and Blues genres?
Question: How many tracks are in the Heavy Metal, Metal, and Blues genres?
Selected tables (7): ['Invoice', 'Genre', 'Track', 'InvoiceLine', 'Customer', 'MediaType', 'Album']
Schema text length: 1664 chars
  Post-processed: ILIKE->LIKE, fixed column casing
Model: sqlcoder:7b
Generated SQL: SELECT COUNT(*) AS total_tracks FROM Track WHERE GenreId IN (SELECT GenreId FROM Genre WHERE Name LIKE '%heavy%metal%' OR Name LIKE '%metal%' OR Name LIKE '%blues%')
Latency: 17.9s
VALID: SELECT COUNT(*) AS total_tracks FROM Track WHERE GenreId IN (SELECT GenreId FROM...
Executed: SELECT COUNT(*) AS total_tracks FROM Track WHERE GenreId IN (SELECT GenreId FROM
Rows returned: 1
  [483]

SQL: SELECT COUNT(*) AS total_tracks FROM Track WHERE GenreId IN (SELECT GenreId FROM Genre WHERE Name LIKE '%heavy%metal%' OR Name LIKE '%metal%' OR Name LIKE '%blues%')
Results: [[483]]
Retries: 0 | Time: 18.0s

Q: Who are the top 5 artists with the 
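The generated genre query matches names with `LIKE '%metal%'` and `LIKE '%blues%'`. Substring patterns can widen silently as the data changes. `'%metal%'` already subsumes `'%heavy%metal%'`, and a genre named 'Blues Rock' would also be counted. That genre is hypothetical and not in Chinook. An exact `IN (...)` list is the stricter form. The toy table below illustrates the difference:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Genre (GenreId INTEGER PRIMARY KEY, Name TEXT)")
conn.execute("CREATE TABLE Track (TrackId INTEGER PRIMARY KEY, GenreId INTEGER)")
# Toy data: 'Blues Rock' is hypothetical and only here to show pattern widening
conn.executemany("INSERT INTO Genre (Name) VALUES (?)",
                 [("Metal",), ("Heavy Metal",), ("Blues",), ("Blues Rock",)])
conn.executemany("INSERT INTO Track (GenreId) VALUES (?)", [(1,), (2,), (3,), (4,), (4,)])

# Substring matching, as in the generated query
fuzzy = conn.execute(
    "SELECT COUNT(*) FROM Track WHERE GenreId IN (SELECT GenreId FROM Genre "
    "WHERE Name LIKE '%metal%' OR Name LIKE '%blues%')"
).fetchone()[0]
# Exact matching on the three requested genre names
exact = conn.execute(
    "SELECT COUNT(*) FROM Track t JOIN Genre g ON t.GenreId = g.GenreId "
    "WHERE g.Name IN ('Heavy Metal', 'Metal', 'Blues')"
).fetchone()[0]
print(fuzzy, exact)  # 5 3
conn.close()
```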

## Phase 3: Evaluation Framework (EXP-001)

**Experiment:** EXP-001 — Text-to-SQL Model Comparison
**DSM Framework:** C.1.3 (Capability Experiment), C.1.5 (Limitation Discovery), C.1.6 (Artifact Organization)
**Artifacts:** `data/experiments/s01_d02_exp001/`

**Objective:** Compare `sqlcoder:7b` (SQL fine-tune) vs `llama3.1:8b` (general-purpose) on 14 curated queries across Easy/Medium/Hard difficulties.

**Hypotheses:**
1. sqlcoder:7b achieves higher Execution Accuracy (SQL fine-tuning advantage)
2. llama3.1:8b produces more readable SQL (e.g., JOINs that return names rather than bare IDs) but has higher latency
3. sqlcoder:7b requires post-processing more frequently (PostgreSQL dialect bias)

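**Metrics:**
- Execution Accuracy (EX): the agent's results match the ground-truth results
- Raw Parsability: SQL from the first generation step parses with sqlglot
- Effective Parsability: the final SQL executes without error
- Retry Rate: share of queries that need at least one error-repair retry
- Post-Processing Rate: share of queries where post-processing changed the SQL
- Latency: wall-clock seconds per query, including retries
- Error Categories: schema_linking, syntax, dialect, hallucination, logic, runtime (failed queries only)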

**Plan:**
- Cell 27: Test suite definition (14 queries + ground truth)
- Cell 28: Evaluation harness (run agent, capture all metrics)
- Cell 29: Run — sqlcoder:7b
- Cell 30: Run — llama3.1:8b
- Cell 31: Results analysis
- Cell 32: Findings + limitations
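Hypothesis 1 compares two models on the same 14 queries, so the per-query outcomes are paired. One way to test it is an exact McNemar test on the discordant pairs. This is offered as an addition, not as part of the EXP-001 protocol. Here `b` counts queries that only the first model answers correctly, and `c` counts those that only the second does. Under the null hypothesis of equal accuracy, each discordant query is a fair coin flip. A standard-library-only sketch follows; the function names are hypothetical.

```python
from math import comb


def mcnemar_exact(b: int, c: int) -> float:
    """Two-sided exact McNemar p-value from discordant pair counts.

    b: queries model A answers correctly and model B does not
    c: queries model B answers correctly and model A does not
    Under H0 the discordant outcomes follow Binomial(b + c, 0.5).
    """
    n = b + c
    if n == 0:
        return 1.0  # no discordant pairs: no evidence either way
    k = min(b, c)
    # P(X <= k) for X ~ Binomial(n, 0.5), doubled for a two-sided test
    tail = sum(comb(n, i) for i in range(k + 1)) / 2 ** n
    return min(1.0, 2 * tail)


def discordant_counts(a_correct: list[bool], b_correct: list[bool]) -> tuple[int, int]:
    """Count queries where exactly one of two models is correct (same query order)."""
    b = sum(1 for x, y in zip(a_correct, b_correct) if x and not y)
    c = sum(1 for x, y in zip(a_correct, b_correct) if y and not x)
    return b, c


# Illustrative per-query correctness vectors, not EXP-001 results
model_a = [True, True, True, True, True, False, False]
model_b = [False, False, False, False, False, False, True]
b, c = discordant_counts(model_a, model_b)
print(b, c, mcnemar_exact(b, c))  # 5 1 0.21875
```

Even a 5-to-0 split of discordant queries gives p = 0.0625. With 14 queries, only a lopsided difference would reach conventional significance. With the harness output, the inputs would be `[r.execution_accurate for r in sqlcoder_results]` and the llama equivalent, assuming both lists follow `TEST_SUITE` order.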



In [24]:
# Cell 27: EXP-001 Test Suite & Evaluation Setup
from dataclasses import dataclass, field

@dataclass
class TestQuery:
    id: str
    difficulty: str  # Easy, Medium, Hard
    question: str
    expected_description: str  # What correct results look like
    expected_tables: list[str] = field(default_factory=list)  # Minimum tables needed

TEST_SUITE = [
    # --- Easy (5): single table, simple WHERE, basic aggregation ---
    TestQuery("E1", "Easy", "How many employees are there?",
             "Single number: 8", ["Employee"]),
    TestQuery("E2", "Easy", "List all media types",
             "5 types returned", ["MediaType"]),
    TestQuery("E3", "Easy", "What is the most expensive track?",
             "Track with max UnitPrice", ["Track"]),
    TestQuery("E4", "Easy", "How many customers are from Brazil?",
             "Single count of Brazilian customers", ["Customer"]),
    TestQuery("E5", "Easy", "Show the 5 longest tracks by duration",
             "5 tracks ordered by Milliseconds DESC", ["Track"]),

    # --- Medium (5): JOINs, GROUP BY + HAVING, ORDER BY + LIMIT ---
    TestQuery("M1", "Medium", "Which genre has the most tracks?",
             "Rock with 1,297 tracks", ["Track", "Genre"]),
    TestQuery("M2", "Medium", "How much has each customer spent in total? Show top 5.",
             "Top 5 customers by SUM(Invoice.Total)", ["Customer", "Invoice"]),
    TestQuery("M3", "Medium", "List albums that have more than 20 tracks",
             "Albums with track count > 20", ["Album", "Track"]),
    TestQuery("M4", "Medium", "Which employees support the most customers?",
             "Employees ranked by customer count", ["Employee", "Customer"]),
    TestQuery("M5", "Medium", "What are the top 3 best-selling genres by revenue?",
             "Genres by SUM(UnitPrice * Quantity) from InvoiceLine",
             ["Genre", "Track", "InvoiceLine"]),

    # --- Hard (4): multi-table JOINs, subqueries, complex aggregation ---
    TestQuery("H1", "Hard", "Which artists have tracks in more than 2 genres?",
             "Artists with genre count > 2",
             ["Artist", "Album", "Track", "Genre"]),
    TestQuery("H2", "Hard",
             "Find customers who have never purchased a Jazz track",
             "Customers NOT IN Jazz purchases",
             ["Customer", "Invoice", "InvoiceLine", "Track", "Genre"]),
    TestQuery("H3", "Hard",
             "What is the average invoice total by country, only for countries with more than 5 customers?",
             "Countries with >5 customers showing AVG(Total)",
             ["Customer", "Invoice"]),
    TestQuery("H4", "Hard",
             "List the top 3 playlists by total track duration in hours",
             "Playlists by SUM(Milliseconds)/3600000",
             ["Playlist", "PlaylistTrack", "Track"]),
]

# Compute ground truth by running known-correct SQL
GROUND_TRUTH = {}
with engine.connect() as conn:
    GROUND_TRUTH["E1"] = conn.execute(text("SELECT COUNT(*) FROM Employee")).scalar()
    GROUND_TRUTH["E2"] = conn.execute(text("SELECT COUNT(*) FROM MediaType")).scalar()
    # If several tracks share the maximum UnitPrice, LIMIT 1 returns one of them arbitrarily
    GROUND_TRUTH["E3"] = conn.execute(text(
        "SELECT Name, UnitPrice FROM Track ORDER BY UnitPrice DESC LIMIT 1"
    )).fetchone()
    GROUND_TRUTH["E4"] = conn.execute(text(
        "SELECT COUNT(*) FROM Customer WHERE Country = 'Brazil'"
    )).scalar()
    GROUND_TRUTH["E5"] = conn.execute(text(
        "SELECT Name, Milliseconds FROM Track ORDER BY Milliseconds DESC LIMIT 5"
    )).fetchall()
    GROUND_TRUTH["M1"] = conn.execute(text(
        "SELECT g.Name, COUNT(t.TrackId) as cnt FROM Genre g "
        "JOIN Track t ON g.GenreId = t.GenreId "
        "GROUP BY g.Name ORDER BY cnt DESC LIMIT 1"
    )).fetchone()
    GROUND_TRUTH["M2"] = conn.execute(text(
        "SELECT c.FirstName, c.LastName, SUM(i.Total) as total "
        "FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId "
        "GROUP BY c.CustomerId ORDER BY total DESC LIMIT 5"
    )).fetchall()
    GROUND_TRUTH["M3"] = conn.execute(text(
        "SELECT a.Title, COUNT(t.TrackId) as cnt FROM Album a "
        "JOIN Track t ON a.AlbumId = t.AlbumId "
        "GROUP BY a.AlbumId HAVING cnt > 20 ORDER BY cnt DESC"
    )).fetchall()
    GROUND_TRUTH["M4"] = conn.execute(text(
        "SELECT e.FirstName, e.LastName, COUNT(c.CustomerId) as cnt "
        "FROM Employee e JOIN Customer c ON e.EmployeeId = c.SupportRepId "
        "GROUP BY e.EmployeeId ORDER BY cnt DESC"
    )).fetchall()
    GROUND_TRUTH["M5"] = conn.execute(text(
        "SELECT g.Name, SUM(il.UnitPrice * il.Quantity) as revenue "
        "FROM Genre g JOIN Track t ON g.GenreId = t.GenreId "
        "JOIN InvoiceLine il ON t.TrackId = il.TrackId "
        "GROUP BY g.Name ORDER BY revenue DESC LIMIT 3"
    )).fetchall()
    GROUND_TRUTH["H1"] = conn.execute(text(
        "SELECT ar.Name, COUNT(DISTINCT g.GenreId) as genre_count "
        "FROM Artist ar JOIN Album al ON ar.ArtistId = al.ArtistId "
        "JOIN Track t ON al.AlbumId = t.AlbumId "
        "JOIN Genre g ON t.GenreId = g.GenreId "
        "GROUP BY ar.ArtistId HAVING genre_count > 2 ORDER BY genre_count DESC"
    )).fetchall()
    GROUND_TRUTH["H2"] = conn.execute(text(
        "SELECT COUNT(*) FROM Customer WHERE CustomerId NOT IN ("
        "SELECT DISTINCT c.CustomerId FROM Customer c "
        "JOIN Invoice i ON c.CustomerId = i.CustomerId "
        "JOIN InvoiceLine il ON i.InvoiceId = il.InvoiceId "
        "JOIN Track t ON il.TrackId = t.TrackId "
        "JOIN Genre g ON t.GenreId = g.GenreId "
        "WHERE g.Name = 'Jazz')"
    )).scalar()
    GROUND_TRUTH["H3"] = conn.execute(text(
        "SELECT c.Country, AVG(i.Total) as avg_total, COUNT(DISTINCT c.CustomerId) as cust_count "
        "FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId "
        "GROUP BY c.Country HAVING cust_count > 5 ORDER BY avg_total DESC"
    )).fetchall()
    GROUND_TRUTH["H4"] = conn.execute(text(
        "SELECT p.Name, SUM(t.Milliseconds) / 3600000.0 as hours "
        "FROM Playlist p JOIN PlaylistTrack pt ON p.PlaylistId = pt.PlaylistId "
        "JOIN Track t ON pt.TrackId = t.TrackId "
        "GROUP BY p.PlaylistId ORDER BY hours DESC LIMIT 3"
    )).fetchall()

# Display test suite and ground truth
print(f"Test suite: {len(TEST_SUITE)} queries")
print(f"  Easy: {sum(1 for t in TEST_SUITE if t.difficulty == 'Easy')}")
print(f"  Medium: {sum(1 for t in TEST_SUITE if t.difficulty == 'Medium')}")
print(f"  Hard: {sum(1 for t in TEST_SUITE if t.difficulty == 'Hard')}")
print()
for tq in TEST_SUITE:
    gt = GROUND_TRUTH[tq.id]
    if isinstance(gt, (int, float)):
        gt_display = str(gt)
    elif isinstance(gt, tuple):
        gt_display = str(list(gt))
    elif isinstance(gt, list):
        gt_display = f"{len(gt)} rows: {[list(r) for r in gt[:2]]}..."
    else:
        gt_display = str(gt)
    print(f"  [{tq.id}] {tq.difficulty:6s} | {tq.question[:55]:55s} | GT: {gt_display[:60]}")


Test suite: 14 queries
  Easy: 5
  Medium: 5
  Hard: 4

  [E1] Easy   | How many employees are there?                           | GT: 8
  [E2] Easy   | List all media types                                    | GT: 5
  [E3] Easy   | What is the most expensive track?                       | GT: ('Battlestar Galactica: The Story So Far', 1.99)
  [E4] Easy   | How many customers are from Brazil?                     | GT: 5
  [E5] Easy   | Show the 5 longest tracks by duration                   | GT: 5 rows: [['Occupation / Precipice', 5286953], ['Through a Lo
  [M1] Medium | Which genre has the most tracks?                        | GT: ('Rock', 1297)
  [M2] Medium | How much has each customer spent in total? Show top 5.  | GT: 5 rows: [['Helena', 'Holý', 49.620000000000005], ['Richard',
  [M3] Medium | List albums that have more than 20 tracks               | GT: 17 rows: [['Greatest Hits', 57], ['Minha Historia', 34]]...
  [M4] Medium | Which employees support the most customers?         
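`GROUND_TRUTH["E3"]` stores a single row from `ORDER BY UnitPrice DESC LIMIT 1`. If several tracks share the maximum `UnitPrice`, the row that `LIMIT 1` returns depends on row order. A correct answer that names a different tied track would then fail the tuple comparison in `compare_results`. Below is a sketch of a tie-aware check on a toy table; the names and prices are illustrative. It assumes the agent returns the track name as the first column.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Track (TrackId INTEGER PRIMARY KEY, Name TEXT, UnitPrice REAL)")
conn.executemany(
    "INSERT INTO Track (Name, UnitPrice) VALUES (?, ?)",
    [("Song A", 0.99), ("Episode 1", 1.99), ("Episode 2", 1.99)],
)

# Every track priced at the maximum counts as an acceptable answer
acceptable = {
    row[0]
    for row in conn.execute(
        "SELECT Name FROM Track WHERE UnitPrice = (SELECT MAX(UnitPrice) FROM Track)"
    )
}
conn.close()


def e3_correct(agent_rows: list[tuple]) -> bool:
    """Pass if the first returned row names any track tied at the maximum price."""
    return bool(agent_rows) and agent_rows[0][0] in acceptable


print(sorted(acceptable))                  # ['Episode 1', 'Episode 2']
print(e3_correct([("Episode 2", 1.99)]))   # True
print(e3_correct([("Song A", 0.99)]))      # False
```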

In [25]:
# Cell 28: EXP-001 Evaluation Harness

import time
import sqlglot
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class EvalResult:
    """Single query evaluation result for EXP-001."""
    query_id: str
    difficulty: str
    question: str
    model: str
    raw_sql: Optional[str] = None
    final_sql: Optional[str] = None
    raw_parsable: bool = False
    effectively_parsable: bool = False
    execution_accurate: bool = False
    post_processing_applied: bool = False
    retry_count: int = 0
    latency_seconds: float = 0.0
    actual_result: Any = None
    error: Optional[str] = None
    error_category: Optional[str] = None

def check_sql_parsable(sql: str) -> bool:
    """Check if SQL parses with sqlglot."""
    if not sql or not sql.strip():
        return False
    try:
        result = sqlglot.parse(sql)
        return len(result) > 0 and result[0] is not None
    except Exception:
        return False

def to_plain(val):
    """Convert SQLAlchemy Row objects to plain Python types."""
    if val is None:
        return None
    if hasattr(val, '_mapping'):
        return tuple(val)
    if isinstance(val, list):
        return [to_plain(v) for v in val]
    if isinstance(val, tuple):
        return tuple(to_plain(v) for v in val)
    return val

def values_match(a, e) -> bool:
    """Compare two values with float tolerance."""
    if isinstance(a, float) and isinstance(e, float):
        return abs(a - e) < 0.01
    return a == e

def compare_results(actual, expected, query_id: str) -> bool:
    """Compare agent output to ground truth.
    
    GT formats:
    - Scalar int/float: count queries (E1, E4) or row-count checks (E2, H2)
    - Tuple: single-row result (E3, M1)  
    - List: multi-row result (E5, M2-M5, H1, H3, H4); only the row count and the first row (positionally) are checked
    
    For scalar GT with list actual:
    - If actual is [(N,)] (1 row, 1 col): compare N to GT (count query)
    - If actual is [row, row, ...]: compare len to GT (row count)
    """
    if actual is None:
        return False

    actual = to_plain(actual)
    expected = to_plain(expected)

    if expected is None:
        return False

    # --- Scalar GT ---
    if isinstance(expected, (int, float)):
        if isinstance(actual, list):
            if len(actual) == 0:
                return False
            if len(actual) == 1:
                row = actual[0]
                if isinstance(row, (list, tuple)) and len(row) == 1:
                    return values_match(row[0], expected)
            # Multi-row result: compare row count
            return len(actual) == expected
        return values_match(actual, expected)

    # --- Tuple GT (single row) ---
    if isinstance(expected, tuple):
        if isinstance(actual, list) and len(actual) >= 1:
            row = actual[0]
            if not isinstance(row, (list, tuple)):
                row = (row,)
            return all(values_match(a, e) for a, e in zip(row, expected))
        return False

    # --- List GT (multi-row) ---
    if isinstance(expected, list):
        if not isinstance(actual, list):
            return False
        if len(actual) != len(expected):
            return False
        if len(expected) > 0 and len(actual) > 0:
            a_row = actual[0] if isinstance(actual[0], (list, tuple)) else (actual[0],)
            e_row = expected[0] if isinstance(expected[0], (list, tuple)) else (expected[0],)
            if not all(values_match(a, e) for a, e in zip(a_row, e_row)):
                return False
        return True

    return False

def categorize_error(er: EvalResult) -> str:
    """Assign error category per EXP-001 protocol."""
    err = (er.error or "").lower()
    if "no such column" in err or "ambiguous" in err:
        return "schema_linking"
    if er.raw_sql and not er.raw_parsable:
        return "syntax"
    if er.post_processing_applied and not er.effectively_parsable:
        return "dialect"
    if "no such table" in err:
        return "hallucination"
    if er.effectively_parsable:
        return "logic"
    return "unknown"

def run_evaluation(model_name: str, test_suite: list, ground_truth: dict) -> list:
    """Run EXP-001 evaluation for one model on the full test suite.
    
    Uses graph.stream() to capture the SQL emitted by the generate_sql node
    alongside final execution results. generate_sql (Cell 24) already applies
    postprocess_sql, so this 'raw' SQL is post-processed LLM output.
    """
    eval_results = []

    print(f"\n{'='*70}")
    print(f"  EXP-001 EVALUATION: {model_name}")
    print(f"  {len(test_suite)} queries | temp=0 | max_retries=3")
    print(f"{'='*70}")

    for i, tq in enumerate(test_suite):
        print(f"\n[{tq.id}] ({i+1}/{len(test_suite)}) {tq.difficulty}")
        print(f"  Q: {tq.question}")

        er = EvalResult(
            query_id=tq.id, difficulty=tq.difficulty,
            question=tq.question, model=model_name
        )

        initial_state = {
            "question": tq.question,
            "model_name": model_name,
            "retry_count": 0,
            "relevant_tables": [],
            "schema_text": "",
            "generated_sql": "",
            "is_valid": False,
            "validation_error": "",
            "results": None,
            "error": "",
        }

        start = time.time()
        try:
            raw_sql_captured = None
            final_state = {}

            for event in graph.stream(initial_state):
                for node_name, update in event.items():
                    if node_name == "generate_sql":
                        raw_sql_captured = update.get("generated_sql")
                    final_state.update(update)

            er.latency_seconds = time.time() - start
            er.raw_sql = raw_sql_captured
            er.final_sql = final_state.get("generated_sql")
            er.actual_result = final_state.get("results")
            er.error = final_state.get("error") or None
            er.retry_count = final_state.get("retry_count", 0)

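            # Metrics
            # NOTE: raw_sql is captured after generate_sql's built-in postprocess_sql call,
            # so raw_parsable reflects post-processed SQL, and post_processing_applied is
            # True only when a retry (handle_error) changed the SQL.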
            er.raw_parsable = check_sql_parsable(er.raw_sql) if er.raw_sql else False
            er.effectively_parsable = (
                er.actual_result is not None and not er.error
            )
            er.post_processing_applied = (
                er.raw_sql is not None
                and er.final_sql is not None
                and er.raw_sql.strip() != er.final_sql.strip()
            )
            er.execution_accurate = compare_results(
                er.actual_result, ground_truth.get(tq.id), tq.id
            )
            if not er.execution_accurate:
                er.error_category = categorize_error(er)

        except Exception as exc:
            er.latency_seconds = time.time() - start
            er.error = str(exc)
            er.error_category = "runtime"

        eval_results.append(er)

        # Per-query output
        tag = "PASS" if er.execution_accurate else "FAIL"
        print(f"  {tag} | retries={er.retry_count} | {er.latency_seconds:.1f}s"
              f" | pp={'Y' if er.post_processing_applied else 'N'}")
        if er.final_sql:
            print(f"  SQL: {er.final_sql.replace(chr(10), ' ')[:75]}")
        if not er.execution_accurate:
            cat = er.error_category or "?"
            err_msg = f" — {er.error[:55]}" if er.error else ""
            print(f"  [{cat}]{err_msg}")

    # ── Aggregate metrics ──
    n = len(eval_results)
    metrics = {
        "Execution Accuracy (EX)": sum(r.execution_accurate for r in eval_results),
        "Raw Parsability":         sum(r.raw_parsable for r in eval_results),
        "Effective Parsability":   sum(r.effectively_parsable for r in eval_results),
        "Retry Rate":              sum(r.retry_count > 0 for r in eval_results),
        "Post-Processing Rate":    sum(r.post_processing_applied for r in eval_results),
    }
    avg_latency = sum(r.latency_seconds for r in eval_results) / n if n else 0

    print(f"\n{'='*70}")
    print(f"  RESULTS: {model_name}")
    print(f"  {'─'*66}")
    for name, count in metrics.items():
        print(f"  {name:<28s} {count:>2}/{n}  ({count/n*100:5.1f}%)")
    print(f"  {'Avg Latency':<28s} {avg_latency:>6.1f}s")

    # Per-difficulty breakdown
    print(f"  {'─'*66}")
    for diff in ["Easy", "Medium", "Hard"]:
        subset = [r for r in eval_results if r.difficulty == diff]
        if subset:
            ex = sum(r.execution_accurate for r in subset)
            lat = sum(r.latency_seconds for r in subset) / len(subset)
            print(f"  {diff:<8s} EX={ex}/{len(subset)}  Avg latency={lat:.1f}s")
    print(f"{'='*70}")

    return eval_results

print("Evaluation harness defined.")
print("Next steps:")
print("  Cell 29 → sqlcoder_results = run_evaluation('sqlcoder:7b', TEST_SUITE, GROUND_TRUTH)")
print("  Cell 30 → llama_results   = run_evaluation('llama3.1:8b', TEST_SUITE, GROUND_TRUTH)")


Evaluation harness defined.
Next steps:
  Cell 29 → sqlcoder_results = run_evaluation('sqlcoder:7b', TEST_SUITE, GROUND_TRUTH)
  Cell 30 → llama_results   = run_evaluation('llama3.1:8b', TEST_SUITE, GROUND_TRUTH)
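For list ground truths, `compare_results` checks the row count and then only the first row. That accepts answers whose later rows are wrong. A stricter sketch compares every row. It treats the rows as an ordered list when the question implies an order (top-N) and as a multiset otherwise. Floats are rounded, so a value like `49.620000000000005` matches `49.62`. `result_sets_match` is a hypothetical name. It assumes plain tuples (for example after `to_plain`) whose columns come in the same order as in the ground-truth SQL. Rounding, unlike an absolute tolerance, can split values that sit on a rounding boundary.

```python
from collections import Counter
from typing import Any, Sequence


def _normalize_row(row: Sequence[Any], ndigits: int = 2) -> tuple:
    """Round floats so small summation differences do not break equality."""
    return tuple(round(v, ndigits) if isinstance(v, float) else v for v in row)


def result_sets_match(actual: list, expected: list, ordered: bool) -> bool:
    """Compare all rows of two result sets.

    ordered=True: rows must match position by position (e.g., top-N questions).
    ordered=False: rows are compared as a multiset (order ignored, duplicates kept).
    """
    a_rows = [_normalize_row(r) for r in actual]
    e_rows = [_normalize_row(r) for r in expected]
    if ordered:
        return a_rows == e_rows
    return Counter(a_rows) == Counter(e_rows)


expected = [("Helena", "Holý", 49.62), ("Richard", "Cunningham", 47.62)]
actual = [("Helena", "Holý", 49.620000000000005), ("Richard", "Cunningham", 47.620000000000005)]
print(result_sets_match(actual, expected, ordered=True))         # True
print(result_sets_match(actual[::-1], expected, ordered=True))   # False
print(result_sets_match(actual[::-1], expected, ordered=False))  # True
```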


In [27]:
# Cell 28b: Fix: alias `graph` to `agent` (originally compiled in Cell 17, last recompiled in Cell 24)
graph = agent
print(f"Graph aliased: {type(graph).__name__} with nodes {list(graph.nodes.keys())}")


Graph aliased: CompiledStateGraph with nodes ['__start__', 'schema_filter', 'generate_sql', 'validate_query', 'execute_query', 'handle_error']


In [29]:
# Cell 29: Load EXP-001 Results from Scripts
import json
from pathlib import Path

EXP_DIR = Path("../data/experiments/s01_d02_exp001")

with open(EXP_DIR / "results_sqlcoder_7b.json") as f:
    sqlcoder_data = json.load(f)

with open(EXP_DIR / "results_llama3_1_8b.json") as f:
    llama_data = json.load(f)

# Display summary comparison
print(f"{'Metric':<28s} {'sqlcoder:7b':>12s} {'llama3.1:8b':>12s} {'Delta':>8s}")
print("─" * 64)

metrics = [
    ("Execution Accuracy (EX)", "execution_accuracy"),
    ("Raw Parsability", "raw_parsability"),
    ("Effective Parsability", "effective_parsability"),
    ("Retry Rate", "retry_rate"),
    ("Post-Processing Rate", "post_processing_rate"),
]

n = sqlcoder_data["n_queries"]
for label, key in metrics:
    sv = sqlcoder_data["summary"][key]
    lv = llama_data["summary"][key]
    delta = lv - sv
    sign = "+" if delta > 0 else ""
    print(f"{label:<28s} {sv:>2}/{n} ({sv/n*100:5.1f}%) {lv:>2}/{n} ({lv/n*100:5.1f}%) {sign}{delta:>4}")

s_lat = sqlcoder_data["summary"]["avg_latency"]
l_lat = llama_data["summary"]["avg_latency"]
print(f"{'Avg Latency':<28s} {s_lat:>10.1f}s {l_lat:>10.1f}s {l_lat - s_lat:>+7.1f}s")

print(f"\n{'Difficulty':<10s} {'sqlcoder EX':>12s} {'llama EX':>12s} {'sqlcoder lat':>13s} {'llama lat':>13s}")
print("─" * 64)
for diff in ["Easy", "Medium", "Hard"]:
    sd = sqlcoder_data["per_difficulty"][diff]
    ld = llama_data["per_difficulty"][diff]
    print(f"{diff:<10s} {sd['execution_accuracy']:>2}/{sd['n']}         "
          f"{ld['execution_accuracy']:>2}/{ld['n']}         "
          f"{sd['avg_latency']:>10.1f}s  {ld['avg_latency']:>10.1f}s")

print(f"\nResults loaded from: {EXP_DIR}")
print(f"  sqlcoder:7b — {len(sqlcoder_data['results'])} query results")
print(f"  llama3.1:8b — {len(llama_data['results'])} query results")



Metric                        sqlcoder:7b  llama3.1:8b    Delta
────────────────────────────────────────────────────────────────
Execution Accuracy (EX)       6/14 ( 42.9%)  6/14 ( 42.9%)    0
Raw Parsability              12/14 ( 85.7%) 14/14 (100.0%) +   2
Effective Parsability         9/14 ( 64.3%) 13/14 ( 92.9%) +   4
Retry Rate                    3/14 ( 21.4%)  2/14 ( 14.3%)   -1
Post-Processing Rate          1/14 (  7.1%)  2/14 ( 14.3%) +   1
Avg Latency                        30.3s       17.6s   -12.7s

Difficulty  sqlcoder EX     llama EX  sqlcoder lat     llama lat
────────────────────────────────────────────────────────────────
Easy        4/5          5/5                9.1s         6.8s
Medium      2/5          1/5               39.5s        13.5s
Hard        0/4          0/4               45.4s        36.2s

Results loaded from: ../data/experiments/s01_d02_exp001
  sqlcoder:7b — 14 query results
  llama3.1:8b — 14 query results
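Cell 29 trusts the precomputed `summary` and `per_difficulty` blocks in each JSON file. A cheap consistency check is to recompute EX and mean latency from the per-query `results` list, using only the keys Cell 30 also reads (`execution_accurate`, `latency_seconds`, `difficulty`). The sketch below assumes that record shape; `summarize` is a hypothetical helper, and the three sample records are invented for illustration, not EXP-001 data.

```python
from collections import defaultdict
from typing import Any, Dict, List


def summarize(results: List[Dict[str, Any]]) -> Dict[str, Dict[str, float]]:
    """Recompute EX count and mean latency, overall ("All") and per difficulty."""
    groups: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
    for r in results:
        groups["All"].append(r)
        groups[r["difficulty"]].append(r)

    summary = {}
    for name, rows in groups.items():
        n = len(rows)  # every group has at least one row by construction
        summary[name] = {
            "n": n,
            "execution_accuracy": sum(bool(r["execution_accurate"]) for r in rows),
            "avg_latency": sum(r["latency_seconds"] for r in rows) / n,
        }
    return summary


sample = [
    {"difficulty": "Easy", "execution_accurate": True, "latency_seconds": 8.0},
    {"difficulty": "Easy", "execution_accurate": False, "latency_seconds": 10.0},
    {"difficulty": "Hard", "execution_accurate": False, "latency_seconds": 40.0},
]
s = summarize(sample)
print(s["Easy"])  # {'n': 2, 'execution_accuracy': 1, 'avg_latency': 9.0}

# Against the loaded files (hypothetical check, not executed here):
# assert summarize(sqlcoder_data["results"])["All"]["execution_accuracy"] == \
#        sqlcoder_data["summary"]["execution_accuracy"]
```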


In [30]:
# Cell 30: EXP-001 Per-Query Analysis & Error Patterns

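# All loops below pair results positionally with zip(); fail fast if the two runs
# list their queries in different orders.
assert [r["query_id"] for r in sqlcoder_data["results"]] == \
       [r["query_id"] for r in llama_data["results"]], "result files are not aligned by query_id"

# ── Per-query comparison table ──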
print("PER-QUERY COMPARISON")
print(f"{'ID':<4s} {'Diff':<7s} {'sqlcoder':>8s} {'llama':>8s} {'sqlcoder err':>14s} {'llama err':>14s}")
print("─" * 60)

for sq, lq in zip(sqlcoder_data["results"], llama_data["results"]):
    s_tag = "PASS" if sq["execution_accurate"] else "FAIL"
    l_tag = "PASS" if lq["execution_accurate"] else "FAIL"
    s_err = sq["error_category"] or ""
    l_err = lq["error_category"] or ""
    print(f"{sq['query_id']:<4s} {sq['difficulty']:<7s} {s_tag:>8s} {l_tag:>8s} {s_err:>14s} {l_err:>14s}")

# ── Where they diverge ──
print("\n\nDIVERGENCES (one passed, other failed)")
print("─" * 60)
for sq, lq in zip(sqlcoder_data["results"], llama_data["results"]):
    if sq["execution_accurate"] != lq["execution_accurate"]:
        winner = "sqlcoder" if sq["execution_accurate"] else "llama"
        loser_data = lq if sq["execution_accurate"] else sq
        print(f"\n[{sq['query_id']}] {sq['question']}")
        print(f"  Winner: {winner}")
        print(f"  Loser error: {loser_data['error_category']}")
        if loser_data["error"]:
            print(f"  Error: {loser_data['error'][:80]}")
        if loser_data["final_sql"]:
            print(f"  SQL: {loser_data['final_sql'][:80]}")

# ── Error category distribution ──
print("\n\nERROR CATEGORY DISTRIBUTION")
print(f"{'Category':<18s} {'sqlcoder':>8s} {'llama':>8s}")
print("─" * 36)
categories = set()
for r in sqlcoder_data["results"] + llama_data["results"]:
    if r["error_category"]:
        categories.add(r["error_category"])

for cat in sorted(categories):
    s_count = sum(1 for r in sqlcoder_data["results"] if r["error_category"] == cat)
    l_count = sum(1 for r in llama_data["results"] if r["error_category"] == cat)
    print(f"{cat:<18s} {s_count:>8d} {l_count:>8d}")

s_fail = sum(1 for r in sqlcoder_data["results"] if not r["execution_accurate"])
l_fail = sum(1 for r in llama_data["results"] if not r["execution_accurate"])
print(f"{'TOTAL FAILURES':<18s} {s_fail:>8d} {l_fail:>8d}")

# ── Latency comparison by difficulty ──
print("\n\nLATENCY BY DIFFICULTY (seconds)")
print(f"{'Difficulty':<10s} {'sqlcoder avg':>12s} {'llama avg':>12s} {'speedup':>8s}")
print("─" * 44)
for diff in ["Easy", "Medium", "Hard"]:
    s_lats = [r["latency_seconds"] for r in sqlcoder_data["results"] if r["difficulty"] == diff]
    l_lats = [r["latency_seconds"] for r in llama_data["results"] if r["difficulty"] == diff]
    s_avg = sum(s_lats) / len(s_lats) if s_lats else 0
    l_avg = sum(l_lats) / len(l_lats) if l_lats else 0
    speedup = f"{s_avg / l_avg:.1f}x" if l_avg > 0 else "N/A"
    print(f"{diff:<10s} {s_avg:>10.1f}s {l_avg:>10.1f}s {speedup:>8s}")

# ── Queries both failed (shared limitations) ──
print("\n\nSHARED FAILURES (both models failed)")
print("─" * 60)
for sq, lq in zip(sqlcoder_data["results"], llama_data["results"]):
    if not sq["execution_accurate"] and not lq["execution_accurate"]:
        print(f"\n[{sq['query_id']}] {sq['question']}")
        print(f"  sqlcoder: [{sq['error_category']}] {(sq['error'] or '')[:60]}")
        print(f"  llama:    [{lq['error_category']}] {(lq['error'] or '')[:60]}")


PER-QUERY COMPARISON
ID   Diff    sqlcoder    llama   sqlcoder err      llama err
────────────────────────────────────────────────────────────
E1   Easy        PASS     PASS                              
E2   Easy        PASS     PASS                              
E3   Easy        FAIL     PASS          logic               
E4   Easy        PASS     PASS                              
E5   Easy        PASS     PASS                              
M1   Medium      PASS     PASS                              
M2   Medium      FAIL     FAIL  hallucination          logic
M3   Medium      FAIL     FAIL          logic          logic
M4   Medium      PASS     FAIL                         logic
M5   Medium      FAIL     FAIL        dialect          logic
H1   Hard        FAIL     FAIL          logic schema_linking
H2   Hard        FAIL     FAIL        runtime          logic
H3   Hard        FAIL     FAIL        runtime          logic
H4   Hard        FAIL     FAIL  hallucination          logic



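The hallucination rows above (M2, H4) and the wrong-column failures discussed in LIM-004 would both be caught by checking table and column names before execution, as LIM-002 suggests for `validate_query`. One stdlib-only way to approximate that, shown as a sketch rather than the agent's actual validator, is to have SQLite compile the statement with `EXPLAIN`: compilation resolves every table and column against the schema without reading any rows, so unknown names raise an error. The three-table schema is a reduced stand-in for Chinook (assumption), and `check_names` is a hypothetical name; against the real database the same check would run on a connection to `chinook.db`.

```python
import sqlite3
from typing import Optional, Tuple

# Reduced stand-in for the Chinook schema (assumption: three tables, few columns).
SCHEMA_DDL = """
CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT);
CREATE TABLE Album  (AlbumId INTEGER PRIMARY KEY, Title TEXT,
                     ArtistId INTEGER REFERENCES Artist(ArtistId));
CREATE TABLE Track  (TrackId INTEGER PRIMARY KEY, Name TEXT,
                     AlbumId INTEGER REFERENCES Album(AlbumId), GenreId INTEGER);
"""


def check_names(conn: sqlite3.Connection, sql: str) -> Tuple[bool, Optional[str]]:
    """Compile `sql` via EXPLAIN (the query itself is not run) and report errors."""
    try:
        conn.execute(f"EXPLAIN {sql}")
        return True, None
    except (sqlite3.Error, sqlite3.Warning) as exc:
        # Warning covers "one statement at a time" on older Python versions.
        return False, str(exc)


conn = sqlite3.connect(":memory:")
conn.executescript(SCHEMA_DDL)

print(check_names(conn, "SELECT Name FROM Artist"))         # valid
print(check_names(conn, "SELECT * FROM payment"))           # hallucinated table (cf. M2)
print(check_names(conn, "SELECT T.ArtistId FROM Track T"))  # wrong column (cf. llama H1)
```

The error message could be passed to `handle_error` as-is, giving the retry prompt the exact missing name rather than a generic failure.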
In [31]:
# Cell 31: EXP-001 Findings — Hypothesis Evaluation & Model Recommendation

print("=" * 70)
print("  EXP-001 FINDINGS")
print("=" * 70)

# ── Hypothesis evaluation ──
print("\n1. HYPOTHESIS EVALUATION")
print("─" * 70)

# H1: sqlcoder EX > llama EX
s_ex = sqlcoder_data["summary"]["execution_accuracy"]
l_ex = llama_data["summary"]["execution_accuracy"]
h1_result = "REJECTED" if l_ex >= s_ex else "CONFIRMED"
print(f"""
H1 (Accuracy): sqlcoder:7b achieves higher EX than llama3.1:8b
  Result: {h1_result}
  Evidence: sqlcoder EX={s_ex}/14 (42.9%), llama EX={l_ex}/14 (42.9%)
  Analysis: Identical overall accuracy. Research predicted a 15-20% advantage
  for SQL fine-tunes; it was not observed. llama's one-query edge on Easy
  (5/5 vs 4/5) is offset by sqlcoder's one-query edge on Medium (2/5 vs 1/5).
  At 7-8B scale, SQL fine-tuning does not provide an EX advantage over a
  general-purpose model on this test suite.""")

# H2: llama produces more readable SQL
print(f"""
H2 (Readability): llama3.1:8b produces more readable SQL with JOINs
  Result: PARTIALLY CONFIRMED
  Evidence: llama uses JOINs with table aliases consistently (e.g., M4 includes
  EmployeeId in output). sqlcoder also uses JOINs but with less consistent
  aliasing. Both models produce readable SQL when they succeed.
  Caveat: H2 is qualitative — no automated readability metric was defined.
  Observation is based on manual review of generated SQL in results JSON.""")

# H3: sqlcoder needs more post-processing
s_pp = sqlcoder_data["summary"]["post_processing_rate"]
l_pp = llama_data["summary"]["post_processing_rate"]
h3_result = "INCONCLUSIVE (ED-1 risk materialized)"
print(f"""
H3 (Post-processing): sqlcoder:7b requires more post-processing
  Result: {h3_result}
  Evidence: sqlcoder PP={s_pp}/14 (7.1%), llama PP={l_pp}/14 (14.3%)
  Analysis: Post-processing is integrated inside generate_sql (Cell 24), not
  as a separate node. ED-1 risk materialized: graph.stream() captures SQL
  AFTER post-processing, so raw_sql ≈ final_sql for most queries. The PP
  metric is unreliable — it only detects cases where retries changed the SQL.
  Phase 2 observations suggest that sqlcoder produces more PostgreSQL-isms
  (ILIKE, NULLS LAST, snake_case), but this cannot be quantified from EXP-001
  data. Sprint 2 should separate post-processing into its own node to enable
  accurate measurement.""")

# ── Key findings ──
print("\n\n2. KEY FINDINGS")
print("─" * 70)

print("""
F1. EQUAL ACCURACY, DIFFERENT FAILURE MODES
   Both models achieve 42.9% EX (6/14), below the 60% sprint target.
   But their error profiles differ fundamentally:
   - sqlcoder: hallucination (2), runtime (2), logic (3), dialect (1)
   - llama:    logic (7), schema_linking (1)
   sqlcoder fails loudly (non-existent tables, unparseable output).
   llama fails quietly (valid SQL, wrong results).

F2. EASY QUERIES ARE SOLVED; HARD QUERIES ARE NOT
   Easy: sqlcoder 4/5, llama 5/5 — both models handle single-table queries.
   Medium: sqlcoder 2/5, llama 1/5 — JOINs and aggregation are unreliable.
   Hard: both 0/4. Multi-table reasoning appears to exceed what these 7-8B
   models can do on this suite.
   Pooled across both models, the curve is steep: 90% Easy (9/10) → 30%
   Medium (3/10) → 0% Hard (0/8).

F3. LLAMA IS FASTER AND MORE RELIABLE (BUT NOT MORE ACCURATE)
   llama: 100% raw parsability, 92.9% effective parsability, 17.6s avg
   sqlcoder: 85.7% raw parsability, 64.3% effective parsability, 30.3s avg
   llama's raw output always parsed (14/14); sqlcoder's sometimes failed
   outright (H2, H3 runtime failures). For a user-facing app, llama's reliability
   is more important than sqlcoder's marginal Medium-query advantage.

F4. TABLE HALLUCINATION IS SQLCODER-SPECIFIC
   sqlcoder invented tables: 'payment' (M2), 'media_type' (M5; a snake_case
   form of MediaType, categorized as dialect), and 'invoiceintrack' (H4).
   None of these exist in Chinook. llama never
   hallucinated a table. This is likely a side effect of sqlcoder's
   fine-tuning on diverse SQL schemas — it has memorized common table
   names from training data that override the provided schema context.

F5. POST-PROCESSING METRIC IS NOT SEPARABLE (ED-1 RISK)
   Post-processing is applied inside generate_sql before streaming
   captures the state. Raw Parsability and Post-Processing Rate cannot
   be accurately measured. Sprint 2 architecture should separate
   post-processing into its own graph node.""")

# ── Model recommendation ──
print("\n\n3. MODEL RECOMMENDATION FOR SPRINT 2")
print("─" * 70)

print("""
RECOMMENDATION: Use llama3.1:8b as the default model.

Rationale:
  1. Same accuracy as sqlcoder (42.9% EX) — no accuracy penalty
  2. 100% raw parsability (14/14 vs 12/14); 92.9% effective vs 64.3%
  3. 1.7x faster average latency (17.6s vs 30.3s)
  4. No table hallucination; 7 of 8 failures are logic errors, which are more predictable
  5. Lower retry rate (14.3% vs 21.4%) — fewer wasted LLM calls

sqlcoder:7b should remain available as an alternative for users who
want to compare, but it does not justify being the default given its
hallucination issues and slower speed.

Both models fail the 60% EX sprint target. Sprint 2 improvements:
  - Better schema filtering (currently keyword-based; consider embedding-based)
  - Few-shot examples per difficulty level
  - Separate post-processing node for accurate metrics
  - Consider larger models if VRAM allows (sqlcoder:15b, llama3.1:70b)""")

print(f"\n{'=' * 70}")


  EXP-001 FINDINGS

1. HYPOTHESIS EVALUATION
──────────────────────────────────────────────────────────────────────

H1 (Accuracy): sqlcoder:7b achieves higher EX than llama3.1:8b
  Result: REJECTED
  Evidence: sqlcoder EX=6/14 (42.9%), llama EX=6/14 (42.9%)
  Analysis: Identical overall accuracy. Research predicted a 15-20% advantage
  for SQL fine-tunes; it was not observed. llama's one-query edge on Easy
  (5/5 vs 4/5) is offset by sqlcoder's one-query edge on Medium (2/5 vs 1/5).
  At 7-8B scale, SQL fine-tuning does not provide an EX advantage over a
  general-purpose model on this test suite.

H2 (Readability): llama3.1:8b produces more readable SQL with JOINs
  Result: PARTIALLY CONFIRMED
  Evidence: llama uses JOINs with table aliases consistently (e.g., M4 includes
  EmployeeId in output). sqlcoder also uses JOINs but with less consistent
  aliasing. Both models produce readable SQL when they succeed.
  Caveat: H2 is qualitative — no automated readability metric was defined.
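H3 and F5 both conclude that post-processing should become its own node so raw and post-processed SQL can be measured separately. A minimal sketch of such a node follows, assuming a dict-shaped state with hypothetical `raw_sql`, `sql`, and `post_processing_applied` keys, and returning a partial state update the way LangGraph nodes do. The single rewrite rule (PostgreSQL `ILIKE` to SQLite `LIKE`, which SQLite already treats as case-insensitive for ASCII text) is illustrative only; the agent's actual rules in Cell 24 are not reproduced here.

```python
import re
from typing import Any, Dict

# Illustrative (pattern, replacement) rules; an assumption, not the Cell 24 rules.
REWRITES = [
    (re.compile(r"\bILIKE\b", re.IGNORECASE), "LIKE"),
]


def post_process_sql(state: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical node between generate_sql and validate_query.

    The model's output stays untouched in `raw_sql`; the rewritten query goes
    to `sql`, so Raw Parsability can be measured on `raw_sql` and the
    Post-Processing Rate counted from the flag below.
    """
    raw = state["raw_sql"]
    processed = raw
    for pattern, replacement in REWRITES:
        processed = pattern.sub(replacement, processed)
    return {"sql": processed, "post_processing_applied": processed != raw}


update = post_process_sql({"raw_sql": "SELECT Name FROM Artist WHERE Name ILIKE '%ac%'"})
print(update["sql"])                      # SELECT Name FROM Artist WHERE Name LIKE '%ac%'
print(update["post_processing_applied"])  # True
```

A regex also rewrites matches inside string literals; an AST-based rewrite with sqlglot, which Cell 2 already imports, would avoid that at the cost of more code.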

In [32]:
# Cell 32: EXP-001 Limitation Tracking (DSM C.1.5)

print("=" * 70)
print("  EXP-001 LIMITATION DISCOVERY (C.1.5 Protocol)")
print("=" * 70)

limitations = [
    {
        "id": "LIM-001",
        "description": "7-8B models cannot solve multi-table Hard queries (0/4 both models)",
        "type": "Capability ceiling",
        "severity": "High",
        "evidence": "H1-H4: 0% EX for both sqlcoder:7b and llama3.1:8b. Failures include "
                    "schema linking (T.ArtistId), logic (wrong subquery structure), and "
                    "runtime (context overflow on complex queries).",
        "disposition": "Accept for Sprint 2 MVP — document as known limitation. "
                       "Mitigation: test larger models (15B+) if VRAM allows, or "
                       "decompose hard queries into multi-step agent reasoning.",
        "tracking": "Sprint 2 backlog",
    },
    {
        "id": "LIM-002",
        "description": "sqlcoder:7b hallucinates table names from training data",
        "type": "Model-specific defect",
        "severity": "High",
        "evidence": "M2: 'payment' table (3 retries, never self-corrected). "
                    "M5: 'media_type' (snake_case of MediaType, 3 retries). "
                    "H4: 'invoiceintrack' (non-existent, 3 retries). "
                    "0/3 hallucinated tables were corrected by retry loop.",
        "disposition": "Mitigated by choosing llama3.1:8b as default. "
                       "If sqlcoder is used, add schema validation in validate_query "
                       "to reject SQL referencing non-existent tables before execution.",
        "tracking": "DEC-005 (model selection), Sprint 2 validate_query enhancement",
    },
    {
        "id": "LIM-003",
        "description": "Post-processing metrics unreliable (ED-1 risk materialized)",
        "type": "Measurement limitation",
        "severity": "Medium",
        "evidence": "Post-processing integrated in generate_sql (Cell 24). "
                    "graph.stream() captures SQL after post-processing. "
                    "Raw Parsability and Post-Processing Rate are not separable. "
                    "sqlcoder PP=1/14, llama PP=2/14 — both undercount actual PP.",
        "disposition": "Sprint 2: refactor post-processing into separate graph node "
                       "between generate_sql and validate_query. This enables accurate "
                       "raw vs post-processed comparison.",
        "tracking": "Sprint 2 architecture",
    },
    {
        "id": "LIM-004",
        "description": "Retry loop cannot fix systematic model biases",
        "type": "Architecture limitation",
        "severity": "Medium",
        "evidence": "sqlcoder M2: hallucinated 'payment' 3x (never tried Invoice). "
                    "sqlcoder M5: hallucinated 'media_type' 3x. "
                    "llama H1: repeated T.ArtistId error 3x (Track has no ArtistId). "
                    "Retry with same model + error message reproduces the same bias.",
        "disposition": "Expected behavior for temperature=0. Mitigation options for "
                       "Sprint 2: (a) add schema-aware validation that catches wrong "
                       "table/column names before execution, (b) increase temperature "
                       "on retries for diversity, (c) switch model on final retry.",
        "tracking": "Sprint 2 error handling improvements",
    },
    {
        "id": "LIM-005",
        "description": "EX metric has false positive risk on row-count comparison",
        "type": "Measurement limitation",
        "severity": "Low",
        "evidence": "ED-2 known limitation: scalar GT comparison uses row count as "
                    "fallback. A query returning the right number of rows but wrong "
                    "content would be scored as PASS. Not observed in this run but "
                    "possible with larger test suites.",
        "disposition": "Accept for Sprint 1. Sprint 2: implement full result-set "
                       "comparison (sort both, compare all rows) for production metrics.",
        "tracking": "Sprint 2 eval improvements",
    },
    {
        "id": "LIM-006",
        "description": "Schema filter is keyword-based, misses indirect relationships",
        "type": "Architecture limitation",
        "severity": "Medium",
        "evidence": "H1 (llama): selected 8 tables including irrelevant ones (Invoice, "
                    "InvoiceLine, Customer) but the query only needed Artist→Album→Track→Genre. "
                    "M5 (both): schema filter includes correct tables but models still "
                    "fail to use them properly. Over-selection adds noise to the prompt.",
        "disposition": "Sprint 2: evaluate embedding-based schema filtering or "
                       "LLM-based table selection as replacement for keyword scoring.",
        "tracking": "Sprint 2 schema_filter redesign",
    },
]

# Print formatted limitation table
print(f"\n{'ID':<10s} {'Severity':<8s} {'Type':<26s} Description")
print("─" * 70)
for lim in limitations:
    print(f"{lim['id']:<10s} {lim['severity']:<8s} {lim['type']:<26s} {lim['description'][:50]}")

# Print full details
for lim in limitations:
    print(f"\n{'─' * 70}")
    print(f"{lim['id']}: {lim['description']}")
    print(f"  Type: {lim['type']} | Severity: {lim['severity']}")
    print(f"  Evidence: {lim['evidence']}")
    print(f"  Disposition: {lim['disposition']}")
    print(f"  Tracking: {lim['tracking']}")

print(f"\n{'=' * 70}")
print(f"Total limitations: {len(limitations)}")
print(f"  High: {sum(1 for lim in limitations if lim['severity'] == 'High')}")
print(f"  Medium: {sum(1 for lim in limitations if lim['severity'] == 'Medium')}")
print(f"  Low: {sum(1 for lim in limitations if lim['severity'] == 'Low')}")
print(f"\nEXP-001 analysis complete. Next: update README and sprint checkpoint.")


  EXP-001 LIMITATION DISCOVERY (C.1.5 Protocol)

ID         Severity Type                       Description
──────────────────────────────────────────────────────────────────────
LIM-001    High     Capability ceiling         7-8B models cannot solve multi-table Hard queries 
LIM-002    High     Model-specific defect      sqlcoder:7b hallucinates table names from training
LIM-003    Medium   Measurement limitation     Post-processing metrics unreliable (ED-1 risk mate
LIM-004    Medium   Architecture limitation    Retry loop cannot fix systematic model biases
LIM-005    Low      Measurement limitation     EX metric has false positive risk on row-count com
LIM-006    Medium   Architecture limitation    Schema filter is keyword-based, misses indirect re

──────────────────────────────────────────────────────────────────────
LIM-001: 7-8B models cannot solve multi-table Hard queries (0/4 both models)
  Type: Capability ceiling | Severity: High
  Evidence: H1-H4: 0% EX for both sqlcoder:7b