## Setup and Imports

In [1]:
import os
import sys
import logging
from pathlib import Path

# Add the project root to Python path
project_root = Path().absolute().parent
sys.path.append(str(project_root))

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

print(f"Project root: {project_root}")
print(f"Current working directory: {os.getcwd()}")

Project root: /home/vlad/dev/data-analyser
Current working directory: /home/vlad/dev/data-analyser/notebooks


In [2]:
from src.agent.agent import DataAnalysisAgent
from src.models.schemas import JiraTicket, TicketStatus
from src.tools.sql_tool import SQLTool
from dotenv import load_dotenv

from src.tools.validator_tool import ValidatorTool
from src.models.schemas import SQLQuery, ValidationResult

load_dotenv()

True

## Initialize Agent and SQL Tool

In [3]:
# agent
agent = DataAnalysisAgent(config_path=str(project_root / "config" / "config.yaml"))

# tools
sql_tool = SQLTool(llm=agent.llm)
validator_tool = ValidatorTool(llm=agent.llm, schema_dict=agent.schema)

In [4]:
# display database schema
print("📋 Available Database Schema:")
print("=" * 50)

for table_name, columns in agent.schema.items():
    print(f"\n🗂️  Table: {table_name}")
    print("\tColumns:")
    for column in columns:
        print(f"\t - {column['column_name']} ({column['data_type']})")
        
print(f"\n📊 Total tables: {len(agent.schema)}")

📋 Available Database Schema:

🗂️  Table: models
	Columns:
	 - model_id (INTEGER)
	 - model_name (TEXT)
	 - model_code (TEXT)
	 - production_start_year (INTEGER)
	 - production_end_year (INTEGER)
	 - segment (TEXT)
	 - base_price (REAL)
	 - horsepower (INTEGER)
	 - body_type (TEXT)
	 - is_electric (INTEGER)
	 - description (TEXT)

🗂️  Table: dealerships
	Columns:
	 - dealership_id (INTEGER)
	 - name (TEXT)
	 - address (TEXT)
	 - city (TEXT)
	 - country (TEXT)
	 - region (TEXT)
	 - opening_date (TEXT)
	 - service_center (INTEGER)
	 - sales_capacity (INTEGER)
	 - rating (REAL)
	 - manager_name (TEXT)

🗂️  Table: customers
	Columns:
	 - customer_id (INTEGER)
	 - first_name (TEXT)
	 - last_name (TEXT)
	 - email (TEXT)
	 - phone (TEXT)
	 - address (TEXT)
	 - city (TEXT)
	 - country (TEXT)
	 - date_of_birth (TEXT)
	 - registration_date (TEXT)
	 - loyalty_points (INTEGER)
	 - preferred_dealership_id (INTEGER)

🗂️  Table: sales
	Columns:
	 - sale_id (INTEGER)
	 - customer_id (INTEGER)
	 - deale

In [5]:
# tickets to test
JIRA_PROJECT_KEY="KAN"

tickets = [
    {
        'project': JIRA_PROJECT_KEY,
        'summary': 'Car Models Analysis',
        'description': 'How many unqiue car models we have per car category? Sort the results in descending order!',
        'issuetype': 'Task',
    },
    {
        'project': JIRA_PROJECT_KEY,
        'summary': 'Dealership Performance by Region Analysis',
        'description': 'Analyze the average dealership rating and sales capacity by region. Which regions have the highest performing dealerships? Sort the results by average rating in descending order.',
        'issuetype': 'Task',
    },
    {
        'project': JIRA_PROJECT_KEY,
        'summary': 'Service Cost Analysis by Model and Service Type',
        'description': 'Analyze the average service costs by model and service type. Identify which models have higher maintenance costs and which service types contribute most to overall service revenue.',
        'issuetype': 'Task',
    },
    # irrelevant task that doesn't match the schema
    {
        "id": "DA-101",
        "summary": "Total Sales Overview",
        "description": "What is the average user basket size?"
    },
]

### SQL Generation

In [6]:
sql_queires = [
    sql_tool.generate_query(task_description=ticket['description'], schema_dict=agent.schema)
    for ticket in tickets
]

2025-07-02 11:43:25,431 - src.tools.sql_tool - INFO - Generating SQL query for task: How many unqiue car models we have per car category? Sort the results in descending order!
2025-07-02 11:43:27,943 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-07-02 11:43:30,859 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-07-02 11:43:30,861 - src.tools.sql_tool - INFO - Successfully generated SQL query
2025-07-02 11:43:30,862 - src.tools.sql_tool - INFO - Generating SQL query for task: Analyze the average dealership rating and sales capacity by region. Which regions have the highest performing dealerships? Sort the results by average rating in descending order.
2025-07-02 11:43:33,714 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-07-02 11:43:37,675 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1

In [7]:
sql_query = sql_queires[0]

print('ReturnedType -> ', type(sql_query))
print('\nGenerated SQL:\n', sql_query.query)
print('\nDescription:\n', sql_query.description)
print('\nUsed Tables:\n', sql_query.tables_used)

ReturnedType ->  <class 'src.models.schemas.SQLQuery'>

Generated SQL:
 SELECT segment, COUNT(DISTINCT model_id) AS unique_models
FROM models
GROUP BY segment
ORDER BY unique_models DESC;

Description:
 This query counts the number of unique 'model_id' for each 'segment' in the 'models' table. It then orders the results in descending order based on the count of unique 'model_id'.

Used Tables:
 ['models']


### SQL Validation
Before moving to more complex queries, let's see how the validation system works. The validator ensures that generated SQL queries are safe, syntactically correct, and use valid schema elements.

In [8]:
sql_queires

[SQLQuery(query='SELECT segment, COUNT(DISTINCT model_id) AS unique_models\nFROM models\nGROUP BY segment\nORDER BY unique_models DESC;', description="This query counts the number of unique 'model_id' for each 'segment' in the 'models' table. It then orders the results in descending order based on the count of unique 'model_id'.", tables_used=['models']),
 SQLQuery(query='SELECT \n    region, \n    AVG(rating) AS average_rating, \n    AVG(sales_capacity) AS average_sales_capacity \nFROM \n    dealerships \nGROUP BY \n    region \nORDER BY \n    average_rating DESC;', description="This query selects the 'region', the average 'rating', and the average 'sales_capacity' from the 'dealerships' table. It groups the results by 'region' and orders them in descending order based on the average 'rating'.", tables_used=['dealerships']),
 SQLQuery(query='SELECT \n    m.model_name, \n    sr.service_type, \n    AVG(sr.cost) AS average_service_cost, \n    SUM(sr.cost) AS total_service_revenue\nFROM \

In [6]:
queries = [
"""
SELECT segment, COUNT(DISTINCT model_id) AS unique_models
FROM models
GROUP BY segment
ORDER BY unique_models DESC;
""",

"""
SELECT 
    region, 
    AVG(rating) AS average_rating, 
    AVG(sales_capacity) AS average_sales_capacity 
FROM 
    dealerships 
GROUP BY 
    region 
ORDER BY 
    average_rating DESC;
""",


"""
SELECT 
    m.model_name, 
    sr.service_type, 
    AVG(sr.cost) AS average_service_cost, 
    SUM(sr.cost) AS total_service_revenue
FROM 
    service_records sr
JOIN 
    sales s ON sr.vin = s.vin
JOIN 
    models m ON s.model_id = m.model_id
GROUP BY 
    m.model_name, 
    sr.service_type
ORDER BY 
    total_service_revenue DESC;
""",

"""
SELECT AVG(basket_size) AS average_basket_size
FROM (
  SELECT COUNT(sale_id) AS basket_size
  FROM sales
  GROUP BY customer_id
) AS baskets;
"""
]

In [7]:
for sql_query in queries:
    print('\nQuery:\n', sql_query)

    # validation steps
    syntax_check = validator_tool.check_syntax(sql_query)
    dangerous_code_check = validator_tool.check_dangerous_patterns(sql_query)
    shema_check = validator_tool.check_schema_compatibility(sql_query)

    if syntax_check[0]:
        print("✅ Syntax check passed")
    else:
        print("❌ Syntax check failed:", syntax_check[1])

    if dangerous_code_check[0]:
        print("✅ Dangerous code check passed")
    else:
        print("❌ Dangerous code check failed:", dangerous_code_check[1])

    if shema_check[0]:
        print("✅ Schema compatibility check passed")
    else:
        print("❌ Schema compatibility check failed:", shema_check[1])



Query:
 
SELECT segment, COUNT(DISTINCT model_id) AS unique_models
FROM models
GROUP BY segment
ORDER BY unique_models DESC;

✅ Syntax check passed
✅ Dangerous code check passed
✅ Schema compatibility check passed

Query:
 
SELECT 
    region, 
    AVG(rating) AS average_rating, 
    AVG(sales_capacity) AS average_sales_capacity 
FROM 
    dealerships 
GROUP BY 
    region 
ORDER BY 
    average_rating DESC;

✅ Syntax check passed
✅ Dangerous code check passed
✅ Schema compatibility check passed

Query:
 
SELECT 
    m.model_name, 
    sr.service_type, 
    AVG(sr.cost) AS average_service_cost, 
    SUM(sr.cost) AS total_service_revenue
FROM 
    service_records sr
JOIN 
    sales s ON sr.vin = s.vin
JOIN 
    models m ON s.model_id = m.model_id
GROUP BY 
    m.model_name, 
    sr.service_type
ORDER BY 
    total_service_revenue DESC;

✅ Syntax check passed
✅ Dangerous code check passed
✅ Schema compatibility check passed

Query:
 
SELECT AVG(basket_size) AS average_basket_size
FROM (

In [20]:
sql_query = queries[3]

errors = []
warnings = [] # TODO: align in the future if it's needed
suggestion = None

# syntax validation
syntax_valid, syntax_error = validator_tool.check_syntax(sql_query)
if not syntax_valid:
    errors.append(syntax_error)
    
# dangerous patterns validation
safe, safety_error = validator_tool.check_dangerous_patterns(sql_query)
if not safe:
    errors.append(safety_error)
    
# schema compatibility validation
schema_valid, schema_error = validator_tool.check_schema_compatibility(sql_query)
if not schema_valid:
    errors.append(schema_error)

# final validation
is_valid = len(errors) == 0

is_valid, warnings, suggestion

(True, [], None)

In [21]:
ValidationResult(
    is_valid=is_valid,
    errors=errors,
    warnings=warnings,
    suggestion=None
)



In [None]:
# complete check
val_results = []

for i in range(len(queries)):
    sql_query = queries[i]
    task_description = tickets[i]['description']

    val_results.append(
        validator_tool.validate_sql(sql_query=sql_query, task_description=task_description)
    )

val_results

### Validation Testing

In [26]:
test_cases = [
    {
        "name": "❌ Incompatible Schema",
        "query": "SELECT region, SUM(sales_amount) as total_sales FROM sales_data GROUP BY region ORDER BY total_sales DESC",
        "task": "Show total sales by region",
        "expected": "VALID"
    },
    {
        "name": "❌ Dangerous DROP Query",
        "query": "DROP TABLE sales_data",
        "task": "Remove sales data table",
        "expected": "INVALID - Dangerous operation"
    },
    {
        "name": "❌ Dangerous DELETE Query", 
        "query": "DELETE FROM sales_data WHERE region = 'North'",
        "task": "Remove northern region data",
        "expected": "INVALID - Dangerous operation"
    },
    {
        "name": "❌ Invalid Table Reference",
        "query": "SELECT * FROM nonexistent_table",
        "task": "Query non-existent table",
        "expected": "INVALID - Schema validation"
    },
    {
        "name": "❌ Syntax Error Query",
        "query": "SEL * FORM sales_data WHRE region = 'North'",
        "task": "Query with multiple typos",
        "expected": "INVALID - Syntax error"
    },
]

In [27]:
val_results = []

for i in range(len(test_cases)):
    sql_query = test_cases[i]['query']
    task_description = test_cases[i]['task']

    val_results.append(
        validator_tool.validate_sql(sql_query=sql_query, task_description=task_description)
    )

val_results

2025-07-02 12:08:44,186 - Parser - ERROR - Not supported query type: SEL * FORM


