In [1]:
# ===== IMPORTS =====
import sqlite3
import pandas as pd
from typing import Dict, List, Tuple, Any
import json

# ===== CREATE SAMPLE DATABASE =====
def create_sample_database():
    """Create sample database with employee and department data."""
    conn = sqlite3.connect('company.db')
    cursor = conn.cursor()

    # Create tables
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS departments (
            dept_id INTEGER PRIMARY KEY,
            dept_name TEXT NOT NULL,
            location TEXT,
            budget INTEGER
        )
    ''')

    cursor.execute('''
        CREATE TABLE IF NOT EXISTS employees (
            emp_id INTEGER PRIMARY KEY,
            name TEXT NOT NULL,
            dept_id INTEGER,
            salary INTEGER,
            hire_date TEXT,
            FOREIGN KEY (dept_id) REFERENCES departments(dept_id)
        )
    ''')

    # Insert department data
    departments = [
        (1, 'Engineering', 'San Francisco', 5000000),
        (2, 'Marketing', 'New York', 2000000),
        (3, 'Sales', 'Chicago', 3000000),
        (4, 'HR', 'Boston', 1000000)
    ]
    cursor.executemany('INSERT OR REPLACE INTO departments VALUES (?, ?, ?, ?)', departments)

    # Insert employee data
    employees = [
        (1, 'Alice Johnson', 1, 120000, '2022-01-15'),
        (2, 'Bob Smith', 1, 110000, '2022-03-20'),
        (3, 'Carol White', 2, 85000, '2021-06-10'),
        (4, 'David Brown', 2, 90000, '2023-02-01'),
        (5, 'Eve Davis', 3, 95000, '2022-07-12'),
        (6, 'Frank Miller', 3, 88000, '2023-05-18'),
        (7, 'Grace Lee', 1, 125000, '2020-11-05'),
        (8, 'Henry Wilson', 4, 75000, '2023-08-22'),
        (9, 'Iris Martinez', 3, 92000, '2022-09-30'),
        (10, 'Jack Taylor', 1, 115000, '2021-12-14')
    ]
    cursor.executemany('INSERT OR REPLACE INTO employees VALUES (?, ?, ?, ?, ?)', employees)

    conn.commit()
    conn.close()
    print("Database 'company.db' created successfully!")

# ===== QUERY INTENTS (HARD-CODED) =====
query_intents = [
    {
        "intent_id": "Q001",
        "natural_language": "Show all employees in the Engineering department",
        "expected_sql": "SELECT e.name, e.salary, e.hire_date FROM employees e JOIN departments d ON e.dept_id = d.dept_id WHERE d.dept_name = 'Engineering'",
        "expected_row_count": 4
    },
    {
        "intent_id": "Q002",
        "natural_language": "Find the average salary by department",
        "expected_sql": "SELECT d.dept_name, AVG(e.salary) as avg_salary FROM employees e JOIN departments d ON e.dept_id = d.dept_id GROUP BY d.dept_name",
        "expected_row_count": 4
    },
    {
        "intent_id": "Q003",
        "natural_language": "List employees hired after 2022",
        "expected_sql": "SELECT name, hire_date, salary FROM employees WHERE hire_date > '2022-12-31' ORDER BY hire_date",
        "expected_row_count": 3
    },
    {
        "intent_id": "Q004",
        "natural_language": "Get total budget by location",
        "expected_sql": "SELECT location, SUM(budget) as total_budget FROM departments GROUP BY location",
        "expected_row_count": 4
    },
    {
        "intent_id": "Q005",
        "natural_language": "Find employees earning more than 100000",
        "expected_sql": "SELECT name, salary, dept_id FROM employees WHERE salary > 100000 ORDER BY salary DESC",
        "expected_row_count": 5
    }
]

# ===== SQL VALIDATOR CLASS =====
class SQLQueryValidator:
    def __init__(self, db_path: str):
        self.db_path = db_path
        self.conn = None

    def connect(self):
        self.conn = sqlite3.connect(self.db_path)
        return self.conn

    def close(self):
        if self.conn:
            self.conn.close()

    def execute_query(self, sql: str) -> Tuple[bool, Any, str]:
        try:
            df = pd.read_sql_query(sql, self.conn)
            return True, df, ""
        except Exception as e:
            return False, None, str(e)

    def validate_query_intent(self, intent: Dict) -> Dict:
        result = {
            "intent_id": intent["intent_id"],
            "natural_language": intent["natural_language"],
            "sql_executed": intent["expected_sql"],
            "success": False,
            "row_count": 0,
            "expected_row_count": intent["expected_row_count"],
            "row_count_match": False,
            "error": None,
            "sample_results": []
        }
        success, data, error = self.execute_query(intent["expected_sql"])
        if success:
            result["success"] = True
            result["row_count"] = len(data)
            result["row_count_match"] = (len(data) == intent["expected_row_count"])
            result["sample_results"] = data.head(3).to_dict('records')
        else:
            result["error"] = error
        return result

    def generate_validation_report(self, intents: List[Dict]) -> Dict:
        self.connect()
        results = []
        successful_queries = 0
        row_count_matches = 0

        for intent in intents:
            validation_result = self.validate_query_intent(intent)
            results.append(validation_result)
            if validation_result["success"]:
                successful_queries += 1
            if validation_result["row_count_match"]:
                row_count_matches += 1

        self.close()
        total_queries = len(intents)
        report = {
            "summary": {
                "total_queries": total_queries,
                "successful_executions": successful_queries,
                "failed_executions": total_queries - successful_queries,
                "row_count_matches": row_count_matches,
                "execution_rate": f"{(successful_queries/total_queries*100):.2f}%",
                "accuracy_rate": f"{(row_count_matches/total_queries*100):.2f}%"
            },
            "detailed_results": results
        }
        return report

# ===== DISPLAY REPORT =====
def display_report(report: Dict):
    print("=" * 80)
    print("SQL QUERY VALIDATION REPORT")
    print("=" * 80)
    summary = report["summary"]
    print(f"\nTotal Queries: {summary['total_queries']}")
    print(f"Successful Executions: {summary['successful_executions']}")
    print(f"Failed Executions: {summary['failed_executions']}")
    print(f"Row Count Matches: {summary['row_count_matches']}")
    print(f"Execution Rate: {summary['execution_rate']}")
    print(f"Accuracy Rate: {summary['accuracy_rate']}")

    print("\n" + "=" * 80)
    print("DETAILED QUERY RESULTS")
    print("=" * 80)
    for result in report["detailed_results"]:
        print(f"\n[{result['intent_id']}] {result['natural_language']}")
        print(f"Status: {'✓ SUCCESS' if result['success'] else '✗ FAILED'}")
        if result['success']:
            print(f"Rows Returned: {result['row_count']} (Expected: {result['expected_row_count']})")
            match_status = "✓ MATCH" if result['row_count_match'] else "✗ MISMATCH"
            print(f"Row Count Validation: {match_status}")
            if result['sample_results']:
                print("\nSample Results (first 3 rows):")
                for i, row in enumerate(result['sample_results'], 1):
                    print(f"  Row {i}: {row}")
        else:
            print(f"Error: {result['error']}")
        print("-" * 80)

# ===== MAIN =====
def main():
    create_sample_database()
    validator = SQLQueryValidator('company.db')
    report = validator.generate_validation_report(query_intents)
    display_report(report)

    # Save report
    with open('sql_validation_report.json', 'w') as f:
        json.dump(report, f, indent=2)
    print("\nReport saved to 'sql_validation_report.json'")

if __name__ == "__main__":
    main()


Database 'company.db' created successfully!
SQL QUERY VALIDATION REPORT

Total Queries: 5
Successful Executions: 5
Failed Executions: 0
Row Count Matches: 4
Execution Rate: 100.00%
Accuracy Rate: 80.00%

DETAILED QUERY RESULTS

[Q001] Show all employees in the Engineering department
Status: ✓ SUCCESS
Rows Returned: 4 (Expected: 4)
Row Count Validation: ✓ MATCH

Sample Results (first 3 rows):
  Row 1: {'name': 'Alice Johnson', 'salary': 120000, 'hire_date': '2022-01-15'}
  Row 2: {'name': 'Bob Smith', 'salary': 110000, 'hire_date': '2022-03-20'}
  Row 3: {'name': 'Grace Lee', 'salary': 125000, 'hire_date': '2020-11-05'}
--------------------------------------------------------------------------------

[Q002] Find the average salary by department
Status: ✓ SUCCESS
Rows Returned: 4 (Expected: 4)
Row Count Validation: ✓ MATCH

Sample Results (first 3 rows):
  Row 1: {'dept_name': 'Engineering', 'avg_salary': 117500.0}
  Row 2: {'dept_name': 'HR', 'avg_salary': 75000.0}
  Row 3: {'dept_name'