# Query Engine Testing Notebook

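This notebook exercises Gemini-backed SQL query generation: a direct call to the shared Gemini helper, natural-language-to-SQL generation (synchronous, asynchronous, and with schema context), a batch of varied query types, and a simple latency/consistency check.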

In [None]:
import sys
import os

# Make the backend packages (e.g. `shared`, `query_engine`) importable. The
# path is resolved relative to the current working directory, so this assumes
# the notebook is launched from its own folder. Normalize the path and add it
# only once so re-running this cell doesn't keep growing sys.path.
backend_dir = os.path.abspath(os.path.join(os.getcwd(), '..', 'backend'))
if backend_dir not in sys.path:
    sys.path.append(backend_dir)

from dotenv import load_dotenv
# Load environment variables from the project-level .env (relative to the
# current working directory); presumably this supplies the Gemini credentials
# used below -- confirm against the backend config.
load_dotenv('../.env')

## Test Gemini Base Helper (`quick_gemini_call`)

In [None]:
from shared.gemini_base import GeminiBase, GeminiConfig, quick_gemini_call

# Smoke-test the one-shot helper: a short prompt, a system prompt and a low
# temperature, then echo whatever the helper returns.
quick_call_kwargs = {
    "prompt": "What is SQL?",
    "system_prompt": "You are a helpful database assistant.",
    "temperature": 0.3,
}
response = quick_gemini_call(**quick_call_kwargs)
print("Quick call response:", response)

## Test SQL Query Generator

In [None]:
from query_engine.llm_service import SQLQueryGenerator

sql_gen = SQLQueryGenerator()  # shared instance, reused by the cells below

# Baseline case: a plain natural-language request, no schema context and
# default generation settings.
natural_query = "Show me all users who signed up last month"
sql_result = sql_gen.generate_sql(natural_query)

print(f"Natural Language: {natural_query}")
print(f"Generated SQL:\n{sql_result}")

In [None]:
# Same generator, this time given table/column definitions through
# `schema_context` (intended to let the generated SQL target concrete
# tables such as users and orders).
schema_context = """
Table: users
Columns: id (INT), name (VARCHAR), email (VARCHAR), created_at (DATETIME)

Table: orders
Columns: id (INT), user_id (INT), total (DECIMAL), order_date (DATETIME)
"""

complex_query = "Find the top 5 customers by total order value in the last 6 months"
sql_with_schema = sql_gen.generate_sql(
    complex_query,
    schema_context=schema_context,
)

print(f"Natural Language: {complex_query}")
print(f"Generated SQL:\n{sql_with_schema}")

In [None]:
# Test async functionality
import asyncio

async def test_async_sql():
    query = "Count the number of active users"
    result = await sql_gen.generate_sql_async(query)
    return result

# Run async test
async_result = await test_async_sql()
print(f"Async SQL Result:\n{async_result}")

## Test Different Query Types

In [None]:
# Prompts spanning simple filters, aggregates, a "no recent orders" lookup,
# per-category ranking and a daily trend. Each is generated at a low
# temperature; a failure is printed and the loop moves on to the next prompt.
test_queries = [
    "Show me all products with price greater than $100",
    "Calculate average order value by month",
    "Find users who haven't placed an order in the last 3 months",
    "Get the most popular products by category",
    "Show daily sales trends for the last week"
]

for case_no, prompt in enumerate(test_queries, start=1):
    print(f"\n--- Test {case_no} ---")
    print(f"Query: {prompt}")

    try:
        generated = sql_gen.generate_sql(prompt, temperature=0.2)
        print(f"SQL:\n{generated}")
    except Exception as err:  # deliberate catch-all: keep the batch running
        print(f"Error: {err}")

## Performance Testing

In [None]:
import time


def measure_performance(query, iterations=3, temperature=0.1, generator=None):
    """Time repeated SQL generation for one prompt and print a summary.

    Calls ``generate_sql`` ``iterations`` times with the same prompt and
    temperature. It then prints the prompt, the mean latency per call, whether
    every result is identical once surrounding whitespace is stripped, and the
    first result.

    Args:
        query: Natural-language request passed to ``generate_sql``.
        iterations: Number of timed calls; must be at least 1.
        temperature: Sampling temperature forwarded to ``generate_sql``.
            Defaults to 0.1, the value this helper always used.
        generator: Object exposing ``generate_sql(query, temperature=...)``.
            Defaults to the notebook-global ``sql_gen``.

    Raises:
        ValueError: If ``iterations`` is less than 1 (previously this failed
            with ZeroDivisionError when averaging an empty list).
    """
    if iterations < 1:
        raise ValueError(f"iterations must be >= 1, got {iterations}")
    if generator is None:
        # Resolved at call time from the notebook namespace.
        generator = sql_gen

    times = []
    results = []
    for _ in range(iterations):
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock, which can jump and is coarser on some platforms.
        start = time.perf_counter()
        result = generator.generate_sql(query, temperature=temperature)
        times.append(time.perf_counter() - start)
        results.append(result)

    avg_time = sum(times) / len(times)
    baseline = results[0].strip()
    consistent = all(r.strip() == baseline for r in results)

    print(f"Query: {query}")
    print(f"Average time: {avg_time:.2f}s")
    print(f"Results consistent: {consistent}")
    print(f"Sample result:\n{results[0]}\n")

# Benchmark a trivial lookup and a grouped aggregation with default settings.
for benchmark_query in (
    "Select all users created today",
    "Calculate monthly revenue by product category",
):
    measure_performance(benchmark_query)