# SQL Agent E2E Evaluation Testing Notebook

This notebook provides an interactive environment for testing and developing SQL agent evaluation scenarios.

In [None]:
# Setup imports and environment
import json
import os
import sys
import time
import uuid
from pathlib import Path

from dotenv import load_dotenv

# Add parent directory to path for imports
sys.path.append(str(Path.cwd().parent))

# Load environment variables
load_dotenv(override=True)

# Import required modules
from src.agent import bootstrap
from src.agent.adapters.adapter import RouterAdapter
from src.agent.adapters.notifications import CollectingNotifications
from src.agent.domain import commands, events
from src.agent.service_layer import messagebus
from evals.utils import load_database_schema, normalize_sql
from evals.llm_judge import JudgeCriteria, LLMJudge

## Load Database Schema

The SQL agent needs the database schema to generate valid SQL queries. This cell loads it and prints a short summary of each table (up to five columns per table).

In [None]:
# Load the SQL agent's schema definition from the evals directory
schema_path = Path("../evals/sql_agent/schema/schema.json")
schema = load_database_schema(schema_path.parent, schema_path.name)

# Summarise it: database name, table count, and a preview of up to five
# columns per table (an ellipsis marks tables with more columns)
print(f"Database: {schema.database_name}")
print(f"Tables: {len(schema.tables)}")
for table in schema.tables:
    print(f"\n  Table: {table.name}")
    shown = ", ".join(f"{col.name} ({col.data_type})" for col in table.columns[:5])
    more = "..." if len(table.columns) > 5 else ""
    print(f"  Columns: {shown}{more}")

## Initialize the SQL Agent System

Bootstrap the message bus with a `CollectingNotifications` handler so that every response the agent emits is captured and can be inspected by the helpers below.

In [None]:
# Collector that records every notification the bus sends, keyed by session
notifications = CollectingNotifications()

# Wire the router adapter and the collector into a fresh message bus
adapter = RouterAdapter()
bus = bootstrap.bootstrap(
    notifications=[notifications],
    adapter=adapter,
)

print("SQL Agent system initialized!")

## Helper Functions

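Helpers that clear collected notifications, extract the generated SQL from the agent's response, and run a question end-to-end (optionally grading it against expected SQL with the LLM judge).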

In [None]:
def clear_notifications():
    """Remove every collected notification, for all sessions.

    Empties ``notifications.sent`` in place (the mapping object itself is kept),
    so the same collector instance passed to ``bootstrap`` continues to be used.
    """
    # dict.clear() is the idiomatic in-place equivalent of deleting each key.
    notifications.sent.clear()

def extract_sql_response(session_id: str) -> str:
    """Return the SQL query from the most recent Evaluation event of a session.

    Scans ``notifications.sent[session_id]`` newest-first. For the first
    ``events.Evaluation`` whose ``summary`` contains the marker
    ``"Here is the SQL query:"``, returns the text after the last occurrence of
    that marker with surrounding whitespace stripped.

    Args:
        session_id: Session whose collected notifications are searched.

    Returns:
        The extracted SQL, or ``""`` when no matching event exists.
    """
    marker = "Here is the SQL query:"
    for event in reversed(notifications.sent.get(session_id, [])):
        if not isinstance(event, events.Evaluation):
            continue
        summary = event.summary
        if marker in summary:
            # Split on the same bare marker the membership test uses. Splitting
            # on "\n\n<marker>\n\n" instead fell back to returning the whole
            # summary whenever the surrounding blank lines differed.
            return summary.rpartition(marker)[2].strip()
    return ""

def test_sql_question(
    question: str,
    expected_sql: str = None,
    session_id: str = None,
    max_wait: float = 30.0,
    poll_interval: float = 0.5,
):
    """Run a question through the SQL agent and optionally grade the result.

    Clears previously collected notifications, dispatches a
    ``commands.Question`` on the message bus, then polls (up to ``max_wait``
    seconds) for an ``events.Evaluation`` event for the session. It prints the
    extracted SQL and, when ``expected_sql`` is given, grades it with
    ``LLMJudge``.

    Args:
        question: Natural-language question for the agent.
        expected_sql: Optional reference SQL. When truthy, both it and the
            generated SQL are passed through ``normalize_sql`` and judged.
        session_id: Session to run under; a unique id is generated if omitted.
        max_wait: Maximum seconds to wait for an ``Evaluation`` event.
        poll_interval: Seconds to sleep between polls.

    Returns:
        ``(actual_sql, judge_result)``. ``actual_sql`` is ``""`` when no SQL
        could be extracted. ``judge_result`` is ``None`` when no
        ``expected_sql`` was supplied.
    """
    if session_id is None:
        # uuid4 keeps ids unique even for calls made within the same second,
        # which int(time.time()) did not guarantee.
        session_id = f"test-{uuid.uuid4()}"

    # Drop events from earlier runs so they cannot be mistaken for this one.
    clear_notifications()

    q_id = f"notebook-{uuid.uuid4()}"
    question_cmd = commands.Question(question=question, q_id=q_id)

    print(f"Processing question: {question}")
    print("-" * 80)

    start_time = time.perf_counter()
    messagebus.handle(question_cmd, bus, session_id)

    # Poll against a real monotonic deadline, so time spent outside sleep()
    # also counts toward the timeout.
    deadline = time.monotonic() + max_wait
    while not any(
        isinstance(e, events.Evaluation)
        for e in notifications.sent.get(session_id, [])
    ):
        if time.monotonic() >= deadline:
            print(f"WARNING: no Evaluation event received within {max_wait}s")
            break
        time.sleep(poll_interval)

    execution_time = time.perf_counter() - start_time

    actual_sql = extract_sql_response(session_id)

    print(f"\nGenerated SQL:\n{actual_sql}")
    print(f"\nExecution time: {execution_time:.2f}s")

    # Listing event types helps diagnose runs where no Evaluation arrived.
    event_names = [type(e).__name__ for e in notifications.sent.get(session_id, [])]
    print(f"\nEvents received: {event_names}")

    if not expected_sql:
        return actual_sql, None

    print("\n" + "=" * 80)
    print("EVALUATION")
    print("=" * 80)
    print(f"\nExpected SQL:\n{expected_sql}")

    judge = LLMJudge()
    criteria = JudgeCriteria()

    judge_result = judge.evaluate(
        question=question,
        expected=normalize_sql(expected_sql),
        actual=normalize_sql(actual_sql) if actual_sql else "NO SQL GENERATED",
        criteria=criteria,
        test_type="sql_e2e_notebook",
    )

    print(f"\nJudge Result: {'PASSED' if judge_result.passed else 'FAILED'}")
    print("Scores:")
    print(f"  - Accuracy: {judge_result.scores.accuracy}")
    print(f"  - Relevance: {judge_result.scores.relevance}")
    print(f"  - Completeness: {judge_result.scores.completeness}")
    print(f"  - Hallucination: {judge_result.scores.hallucination}")
    print(f"\nAssessment: {judge_result.overall_assessment}")

    return actual_sql, judge_result

## Test Basic SQL Questions

Let's test some basic SQL generation scenarios.

In [None]:
# Test 1: Simple SELECT query (single table, one equality filter)
sql, judge = test_sql_question(
    "Show me all customers from New York",
    "SELECT * FROM customers WHERE city = 'New York'",
)

In [None]:
# Test 2: JOIN query (orders joined to customers to pick up the name)
sql, judge = test_sql_question(
    "Show me all orders with customer names",
    "SELECT o.*, c.name FROM orders o JOIN customers c ON o.customer_id = c.id",
)

In [None]:
# Test 3: Aggregation query (SUM grouped by category)
sql, judge = test_sql_question(
    "What is the total revenue by product category?",
    "SELECT category, SUM(revenue) as total_revenue FROM products GROUP BY category",
)

## Create Custom Test Scenarios

Use this section to create and test your own SQL scenarios.

In [None]:
# Your custom test: rank customers by the summed value of their orders.
# The reference SQL is spelled out line by line; trailing spaces are kept as-is.
custom_question = "Find the top 5 customers by total order value"
custom_expected_sql = (
    "\n"
    "SELECT c.id, c.name, SUM(o.total_amount) as total_order_value \n"
    "FROM customers c \n"
    "JOIN orders o ON c.id = o.customer_id \n"
    "GROUP BY c.id, c.name \n"
    "ORDER BY total_order_value DESC \n"
    "LIMIT 5\n"
)

sql, judge = test_sql_question(custom_question, custom_expected_sql)

## Batch Testing

Test multiple scenarios at once.

In [None]:
# Each scenario pairs a question with the reference SQL it should produce
test_scenarios = [
    dict(
        name="simple_filter",
        question="Show all active products",
        expected_sql="SELECT * FROM products WHERE status = 'active'",
    ),
    dict(
        name="date_filter",
        question="Find orders from last month",
        expected_sql="SELECT * FROM orders WHERE order_date >= DATE_SUB(CURRENT_DATE, INTERVAL 1 MONTH)",
    ),
    dict(
        name="complex_join",
        question="Show customer orders with product details",
        expected_sql="""
            SELECT c.name as customer_name, o.order_date, p.name as product_name, oi.quantity, oi.price
            FROM customers c
            JOIN orders o ON c.id = o.customer_id
            JOIN order_items oi ON o.id = oi.order_id
            JOIN products p ON oi.product_id = p.id
        """,
    ),
]

# Run every scenario, keeping only the verdict and scores from the judge
results = []
for scenario in test_scenarios:
    print(f"\n{'='*80}\nTesting: {scenario['name']}\n{'='*80}")

    sql, judge = test_sql_question(
        question=scenario["question"],
        expected_sql=scenario["expected_sql"],
    )

    results.append(
        {"name": scenario["name"], "passed": judge.passed, "scores": judge.scores}
        if judge
        else {"name": scenario["name"], "passed": None, "scores": None}
    )

    # Wait between tests to avoid rate limiting
    time.sleep(2)

# Summary: verdict per scenario plus the mean of the four judge scores
print("\n\n" + "="*80)
print("SUMMARY")
print("="*80)
for result in results:
    if result["passed"] is None:
        status = "NOT EVALUATED"
    elif result["passed"]:
        status = "PASSED"
    else:
        status = "FAILED"
    print(f"{result['name']}: {status}")
    if result["scores"]:
        avg_score = sum(
            getattr(result["scores"], metric)
            for metric in ("accuracy", "relevance", "completeness", "hallucination")
        ) / 4
        print(f"  Average Score: {avg_score:.2f}")

## Export Test Cases to YAML

Convert your tested scenarios into YAML format for the evaluation framework.

In [None]:
import yaml

# Create YAML structure for new test cases
def create_yaml_test(name, question, expected_sql, judge_criteria=None):
    """Build one evaluation-framework test case as a plain dict.

    The result always has ``name``, ``question`` and ``sql`` (the reference
    query), in that order. ``judge_criteria`` is appended only when a truthy
    value is supplied.
    """
    entry = dict(name=name, question=question, sql=expected_sql)
    return {**entry, "judge_criteria": judge_criteria} if judge_criteria else entry

# Assemble a suite in the layout the e2e evaluation framework loads:
# schema location, default judge criteria, then the individual cases.
test_suite = dict(
    schema_file="schema/schema.json",
    default_judge_criteria=dict(
        check_accuracy=True,
        check_relevance=True,
        check_completeness=True,
        check_hallucination=True,
    ),
    tests=[
        create_yaml_test(
            "customer_orders_count",
            "How many orders does each customer have?",
            "SELECT c.name, COUNT(o.id) as order_count FROM customers c LEFT JOIN orders o ON c.id = o.customer_id GROUP BY c.id, c.name",
        ),
        create_yaml_test(
            "revenue_by_month",
            "Show monthly revenue for this year",
            "SELECT DATE_FORMAT(order_date, '%Y-%m') as month, SUM(total_amount) as revenue FROM orders WHERE YEAR(order_date) = YEAR(CURRENT_DATE) GROUP BY month ORDER BY month",
        ),
    ],
)

# Render as block-style YAML, keeping the insertion order of keys
yaml_content = yaml.dump(test_suite, sort_keys=False, default_flow_style=False)
print("Generated YAML for evaluation framework:")
print("-" * 80)
print(yaml_content)

# To persist the suite, uncomment:
# with open('../evals/sql_agent/e2e/custom_tests.yaml', 'w') as f:
#     yaml.dump(test_suite, f, default_flow_style=False, sort_keys=False)

## Debug SQL Pipeline Stages

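Run each stage of the SQL pipeline individually (pre-check, grounding, filter, join inference, aggregation, construction) and print its output for detailed debugging.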

In [None]:
from src.agent.adapters.llm import LLM
from src.agent.config import get_llm_config
from src.agent.domain import sql_model
import uuid

def test_sql_pipeline_stages(question: str):
    """Run each stage of the SQL pipeline by hand and print its output.

    Builds a fresh ``SQLBaseAgent`` for ``question``. Then, for each stage in
    order (pre-check, grounding, filter, join inference, aggregation,
    construction), it:

    1. calls ``agent.update`` with the previous stage's result;
    2. sends the returned command's ``.question`` to the LLM with that stage's
       structured ``response_model``;
    3. prints the parsed response;
    4. where shown below, copies the response fields into
       ``agent.construction`` before the next stage.

    Relies on the notebook-level ``schema`` loaded in the schema cell.

    Args:
        question: Natural-language question to debug.

    Returns:
        The ``sql_query`` from the construction stage, or ``None`` if the
        pre-check stage does not approve the question.
    """
    
    q_id = f"debug-{str(uuid.uuid4())}"
    sql_question = commands.SQLQuestion(question=question, q_id=q_id)
    
    llm = LLM(get_llm_config())
    # Hands the global `schema` to the agent under the "schema_info" key.
    # NOTE(review): the literal `kwargs=` keyword suggests SQLBaseAgent declares
    # a field named `kwargs` rather than accepting **kwargs — confirm.
    agent = sql_model.SQLBaseAgent(
        question=sql_question,
        kwargs={"schema_info": schema}
    )
    
    print(f"Testing SQL Pipeline for: {question}")
    print("=" * 80)
    
    # Stage 1: Pre-check — stop early (returning None) if the guardrail rejects
    print("\n1. PRE-CHECK STAGE")
    check_cmd = agent.update(sql_question)
    check_response = llm.use(check_cmd.question, response_model=commands.GuardrailPreCheckModel)
    print(f"   Approved: {check_response.approved}")
    if not check_response.approved:
        print(f"   Reason: {check_response.reason}")
        return
    
    # Stage 2: Grounding — map question terms to tables and columns.
    # Only reached after approval, so approved=True matches check_response.
    print("\n2. GROUNDING STAGE")
    check_result = commands.SQLCheck(question=question, q_id=q_id, approved=True)
    grounding_cmd = agent.update(check_result)
    grounding_response = llm.use(grounding_cmd.question, response_model=commands.SQLGrounding)
    
    print("   Table Mappings:")
    for tm in grounding_response.table_mapping:
        print(f"     '{tm.question_term}' -> {tm.table_name} (confidence: {tm.confidence})")
    
    print("   Column Mappings:")
    for cm in grounding_response.column_mapping:
        print(f"     '{cm.question_term}' -> {cm.table_name}.{cm.column_name} (confidence: {cm.confidence})")
    
    # Copy the grounding output into the agent's construction state before the
    # next update() call. NOTE(review): possibly redundant if agent.update()
    # already records these — confirm.
    agent.construction.table_mapping = grounding_response.table_mapping
    agent.construction.column_mapping = grounding_response.column_mapping
    
    # Stage 3: Filter — WHERE-style conditions
    print("\n3. FILTER STAGE")
    filter_cmd = agent.update(grounding_response)
    filter_response = llm.use(filter_cmd.question, response_model=commands.SQLFilter)
    
    print("   Filter Conditions:")
    for cond in filter_response.conditions:
        print(f"     {cond.column} {cond.operator} {cond.value}")
    
    agent.construction.conditions = filter_response.conditions
    
    # Stage 4: Join Inference — table-to-table join paths
    print("\n4. JOIN INFERENCE STAGE")
    join_cmd = agent.update(filter_response)
    join_response = llm.use(join_cmd.question, response_model=commands.SQLJoinInference)
    
    print("   Joins:")
    for join in join_response.joins:
        print(f"     {join.from_table}.{join.from_column} -> {join.to_table}.{join.to_column} ({join.join_type})")
    
    agent.construction.joins = join_response.joins
    
    # Stage 5: Aggregation — aggregate functions and GROUP BY columns
    print("\n5. AGGREGATION STAGE")
    agg_cmd = agent.update(join_response)
    agg_response = llm.use(agg_cmd.question, response_model=commands.SQLAggregation)
    
    print(f"   Is Aggregation Query: {agg_response.is_aggregation_query}")
    if agg_response.aggregations:
        print("   Aggregations:")
        for agg in agg_response.aggregations:
            print(f"     {agg.function}({agg.column}) as {agg.alias}")
    if agg_response.group_by_columns:
        print(f"   Group By: {', '.join(agg_response.group_by_columns)}")
    
    agent.construction.aggregations = agg_response.aggregations
    agent.construction.group_by_columns = agg_response.group_by_columns
    agent.construction.is_aggregation_query = agg_response.is_aggregation_query
    
    # Stage 6: Construction — final SQL text
    print("\n6. CONSTRUCTION STAGE")
    construction_cmd = agent.update(agg_response)
    construction_response = llm.use(construction_cmd.question, response_model=commands.SQLConstruction)
    
    print("   Generated SQL:")
    print(f"   {construction_response.sql_query}")
    
    return construction_response.sql_query

# Run the stage-by-stage debug pipeline on a sample question (None if rejected)
sql = test_sql_pipeline_stages("Show me the top 5 products by revenue last quarter")

## Load and Test Existing YAML Fixtures

Load existing test cases from YAML files to verify they work correctly.

In [None]:
from evals.utils import load_yaml_fixtures

# Discover the SQL agent's e2e YAML fixtures (a name -> fixture mapping)
existing_fixtures = load_yaml_fixtures(Path("../evals/sql_agent"), "e2e")

print(f"Found {len(existing_fixtures)} existing test cases:\n")

# Preview the first three; reference SQL longer than 100 chars is truncated
for i, (name, fixture) in enumerate(list(existing_fixtures.items())[:3]):
    print(f"{i+1}. {name}")
    print(f"   Question: {fixture['question']}")
    print(f"   Expected SQL: {fixture['sql'][:100]}{'...' if len(fixture['sql']) > 100 else ''}")
    print()

In [None]:
# Run the first discovered fixture end-to-end and grade it against its SQL
if existing_fixtures:
    test_name = next(iter(existing_fixtures))
    fixture = existing_fixtures[test_name]

    print(f"Testing fixture: {test_name}\n")

    sql, judge = test_sql_question(fixture['question'], fixture['sql'])

## Performance Testing

Measure SQL generation time and generated-query length for questions of increasing complexity.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Questions ordered from trivial to multi-step, each tagged with a level label
complexity_tests = [
    {"level": "Simple", "question": "Show all customers"},
    {"level": "Filter", "question": "Show customers from California"},
    {"level": "Join", "question": "Show orders with customer names"},
    {"level": "Aggregate", "question": "Count orders by customer"},
    {"level": "Complex", "question": "Show top 10 customers by total order value with their most recent order date"}
]

# Time each question end-to-end and record the length of the SQL it produced
performance_results = []
for test in complexity_tests:
    print(f"Testing {test['level']} query...")

    started = time.time()
    sql, _ = test_sql_question(test['question'])
    elapsed_s = time.time() - started

    performance_results.append(
        dict(level=test['level'], time=elapsed_s, sql_length=len(sql) if sql else 0)
    )

    # Pause between tests
    time.sleep(2)

# Side-by-side bar charts: generation time and SQL length per level
levels = [r['level'] for r in performance_results]
times = [r['time'] for r in performance_results]
lengths = [r['sql_length'] for r in performance_results]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
for ax, values, ylabel, title in (
    (ax1, times, 'Execution Time (seconds)', 'SQL Generation Time by Complexity'),
    (ax2, lengths, 'SQL Query Length (characters)', 'Generated SQL Length by Complexity'),
):
    ax.bar(levels, values)
    ax.set_xlabel('Query Complexity')
    ax.set_ylabel(ylabel)
    ax.set_title(title)

plt.tight_layout()
plt.show()

# Summary statistics over the measured times
print("\nPerformance Summary:")
print(f"Average execution time: {np.mean(times):.2f}s")
print(f"Min/Max execution time: {np.min(times):.2f}s / {np.max(times):.2f}s")

## Interactive Testing

Use this cell to interactively test SQL questions.

In [None]:
# Interactive testing: edit the question below and re-run this cell
interactive_question = "What are the total sales by product category for last month?"
sql, _ = test_sql_question(question=interactive_question)

## Save Session Results

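Save session metadata (a timestamp, the number of performance tests run, and the model settings read from the environment) to a JSON file for future reference. Individual test results are not included.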

In [None]:
# Session metadata: when it ran, how many performance tests ran in this kernel
# (0 if the performance cell was not executed), and model settings from env vars
session_results = {
    "timestamp": int(time.time()),
    "tests_run": len(locals().get("performance_results", ())),
    "model_info": {
        env_key: os.environ.get(env_key, "unknown")
        for env_key in ("llm_model_id", "llm_temperature")
    },
}

# Write it next to the notebook, named by the timestamp
results_file = "sql_eval_session_{}.json".format(session_results["timestamp"])
with open(results_file, 'w') as fh:
    json.dump(session_results, fh, indent=2)

print(f"Session results saved to: {results_file}")