# DIVA-SQL Framework Demo

This notebook demonstrates the key features of the DIVA-SQL framework:
- Semantic decomposition of natural language queries
- Step-by-step SQL clause generation
- In-line verification and error correction
- Interpretable query generation process

In [None]:
# Import required libraries
import sys
import json
from pathlib import Path

# Make the project importable from this notebook. The imports below are
# package-qualified (``src.core...``), so the directory that *contains* ``src``
# (assumed to be the notebook's parent directory) must be on sys.path.
# ``src`` itself is also kept on the path, as before, in case modules inside it
# use unqualified imports. Entries are only added once so re-running the cell
# does not keep growing sys.path.
PROJECT_ROOT = Path.cwd().parent
for _path in (PROJECT_ROOT, PROJECT_ROOT / "src"):
    if str(_path) not in sys.path:
        sys.path.append(str(_path))

from src.core.pipeline import DIVASQLPipeline
from src.core.semantic_dag import SemanticDAG, SemanticNode, NodeType
from src.utils.error_taxonomy import ErrorTaxonomy, analyze_sql_errors

## Set Up a Mock LLM Client

For demonstration purposes, we use a mock LLM client that returns predefined JSON responses, so the notebook runs without network access or an API key.

In [None]:
class MockLLMClient:
    """Mock LLM client for demonstration.

    Imitates the small part of an OpenAI-style client that the demo needs:
    ``client.chat.completions.create(messages=[...])`` returns an object whose
    ``choices[0].message.content`` is a JSON string. The canned payload is
    chosen by keyword-matching the content of the first message.
    """

    def __init__(self):
        # Canned JSON payloads, keyed by the pipeline stage they imitate.
        self.responses = {
            "decomposition": {
                "query_type": "COUNT",
                "complexity_indicators": ["FILTER", "GROUP", "JOIN"],
                "estimated_steps": 4,
                "reasoning": "This query requires filtering, joining, grouping, and counting"
            },
            "components": {
                "components": [
                    {
                        "type": "FILTER",
                        "description": "Filter employees hired after 2022",
                        "tables": ["Employees"],
                        "columns": ["HireDate"],
                        "priority": 1
                    },
                    {
                        "type": "JOIN",
                        "description": "Join employees with departments",
                        "tables": ["Employees", "Departments"],
                        "columns": ["DeptID"],
                        "priority": 2
                    },
                    {
                        "type": "GROUP",
                        "description": "Group by department",
                        "tables": ["Employees"],
                        "columns": ["DeptID"],
                        "priority": 3
                    },
                    {
                        "type": "SELECT",
                        "description": "Select department names with count > 10",
                        "tables": ["Departments"],
                        "columns": ["DeptName"],
                        "priority": 4
                    }
                ]
            },
            "sql_generation": {
                "sql_clause": "WHERE T1.HireDate > '2022-01-01'",
                "explanation": "Filter for employees hired after 2022",
                "confidence": 0.9
            },
            "verification": {
                "is_aligned": True,
                "issues": [],
                "confidence": 0.9
            },
            "final_composition": {
                "final_sql": "SELECT T2.DeptName FROM Employees AS T1 JOIN Departments AS T2 ON T1.DeptID = T2.DeptID WHERE T1.HireDate > '2022-01-01' GROUP BY T2.DeptID HAVING COUNT(*) > 10",
                "confidence": 0.9
            }
        }
        # Expose ``client.chat.completions`` like the real SDK. Both attributes
        # must be set in this single __init__: a second ``__init__`` in the
        # class body would replace this one and drop ``self.responses``.
        self.chat = self.Chat(self)

    class Chat:
        """Namespace mirroring ``client.chat``; owns ``completions``."""

        def __init__(self, parent):
            # Back-reference so completions can reach the client's responses.
            self.parent = parent
            self.completions = self.Completions(self)

        class Completions:
            """Namespace mirroring ``client.chat.completions``."""

            def __init__(self, parent):
                self.parent = parent

            def create(self, **kwargs):
                """Return a canned chat-completion response.

                Only ``messages`` is inspected; the first message's ``content``
                selects the payload. All other keyword arguments (model,
                temperature, ...) are accepted and ignored. Missing, empty, or
                ``None`` messages/content fall through to the clause-generation
                payload instead of raising.

                Returns:
                    An object exposing ``choices[0].message.content`` (a JSON
                    string), matching the attribute path of OpenAI-style SDKs.
                """
                from types import SimpleNamespace

                messages = kwargs.get('messages') or [{}]
                prompt_content = (messages[0].get('content') or '').lower()
                responses = self.parent.parent.responses

                # First matching keyword wins; anything else is treated as a
                # clause-generation request.
                if "structure" in prompt_content:
                    response_data = responses["decomposition"]
                elif "components" in prompt_content:
                    response_data = responses["components"]
                elif "alignment" in prompt_content:
                    response_data = responses["verification"]
                elif "compose" in prompt_content:
                    response_data = responses["final_composition"]
                else:
                    response_data = responses["sql_generation"]

                message = SimpleNamespace(content=json.dumps(response_data))
                return SimpleNamespace(choices=[SimpleNamespace(message=message)])


# Initialize mock client
mock_client = MockLLMClient()
print("Mock LLM client initialized successfully!")

## Example 1: Basic DIVA-SQL Pipeline

Let's demonstrate the complete DIVA-SQL pipeline with a complex query.

In [None]:
# Build the DIVA-SQL pipeline on top of the mock client from the previous cell.
pipeline = DIVASQLPipeline(mock_client, model_name="gpt-4")

# Toy schema: each table name maps to the list of its column names.
schema = {
    "tables": {
        "Employees": ["EmpID", "Name", "DeptID", "HireDate", "Salary"],
        "Departments": ["DeptID", "DeptName", "ManagerID"],
        "Projects": ["ProjectID", "ProjectName", "Budget", "DeptID"],
    }
}

# Natural-language question to translate into SQL.
nl_query = "What are the names of departments with more than 10 employees hired after 2022?"

pretty_schema = json.dumps(schema, indent=2)
print(f"Natural Language Query: {nl_query}")
print(f"Database Schema: {pretty_schema}")
print(f"\n{'=' * 60}\n")

In [None]:
# Run the full DIVA-SQL pipeline on the question and report its outcome.
result = pipeline.generate_sql(nl_query, schema)

print("DIVA-SQL Results:")
print(f"Status: {result.status.value}")
print(f"Execution Time: {result.execution_time:.2f} seconds")
print(f"Confidence Score: {result.confidence_score:.2f}")
print("\nGenerated SQL:")
print(result.final_sql)

# Show the intermediate semantic decomposition when the pipeline produced one.
decomposition = result.semantic_dag
if decomposition:
    print("\nSemantic Decomposition:")
    print(decomposition.visualize())

## Example 2: Semantic DAG Visualization

Let's create and visualize a semantic DAG manually to understand the decomposition process.

In [None]:
# Hand-build a semantic DAG for the Example 1 question so the decomposition
# can be inspected without calling the pipeline.
dag = SemanticDAG("demo_query")

# Step 1: keep only employees hired after 2022.
filter_node = SemanticNode(
    id="filter_employees",
    node_type=NodeType.FILTER,
    description="Filter employees hired after 2022",
    tables=["Employees"],
    columns=["HireDate"],
    conditions=["HireDate > '2022-01-01'"],
)

# Step 2: bring in department information.
join_node = SemanticNode(
    id="join_departments",
    node_type=NodeType.JOIN,
    description="Join employees with departments",
    tables=["Employees", "Departments"],
    columns=["DeptID"],
)

# Step 3: one group per department.
group_node = SemanticNode(
    id="group_by_dept",
    node_type=NodeType.GROUP,
    description="Group employees by department",
    tables=["Employees"],
    columns=["DeptID"],
)

# Step 4: post-aggregation filter on the group size.
having_node = SemanticNode(
    id="filter_count",
    node_type=NodeType.HAVING,
    description="Keep departments with more than 10 employees",
    conditions=["COUNT(*) > 10"],
)

# Step 5: project the department names.
select_node = SemanticNode(
    id="select_names",
    node_type=NodeType.SELECT,
    description="Select department names",
    tables=["Departments"],
    columns=["DeptName"],
)

# Register every node, in step order.
ordered_nodes = [filter_node, join_node, group_node, having_node, select_node]
for node in ordered_nodes:
    dag.add_node(node)

# Link consecutive steps in execution order.
step_ids = [
    "filter_employees",
    "join_departments",
    "group_by_dept",
    "filter_count",
    "select_names",
]
for upstream, downstream in zip(step_ids, step_ids[1:]):
    dag.add_edge(upstream, downstream)

print("Manual Semantic DAG:")
print(dag.visualize())

## Example 3: Error Taxonomy Demonstration

Let's explore the error taxonomy system that helps detect common SQL errors.

In [None]:
# Build the error taxonomy and print an overview of the patterns it knows.
taxonomy = ErrorTaxonomy()
summary = taxonomy.get_taxonomy_summary()

print("Error Taxonomy Summary:")
print(f"Total Error Patterns: {summary['total_patterns']}")
category_names = list(summary['categories'].keys())
print(f"Categories: {category_names}")
print(f"Severity Distribution: {summary['severity_distribution']}")

print("\nMost Common Error Categories:")
for category_name, pattern_count in summary['most_common_categories']:
    print(f"  {category_name}: {pattern_count} patterns")

In [None]:
# Deliberately flawed SQL:
# - it mixes a non-aggregated column with COUNT(*) and has no GROUP BY;
# - it queries `Employee`, but the schema's table is `Employees`;
# - it compares EmpID against a quoted string literal.
problematic_sql = "SELECT Name, COUNT(*) FROM Employee WHERE EmpID = '123'"

print(f"Analyzing SQL: {problematic_sql}")
print("-" * 50)

analysis = analyze_sql_errors(problematic_sql, taxonomy)

# Headline metrics, printed as "<label>: <value>" in a fixed order.
headline_fields = (
    ("Issues Found", 'total_issues'),
    ("Risk Score", 'risk_score'),
    ("Categories Affected", 'categories_affected'),
)
for label, field_key in headline_fields:
    print(f"{label}: {analysis[field_key]}")
print("\nDetailed Issues:")

for matched in analysis['patterns_matched']:
    print(f"  - {matched['name']} ({matched['severity']})")
    print(f"    {matched['description']}")
    print(f"    Fix: {matched['suggested_fix']}")
    print()

print("Recommendations:")
for recommendation in analysis['recommended_actions']:
    print(f"  {recommendation}")

## Example 4: Verification Process Simulation

Let's simulate the verification process that DIVA-SQL uses to check generated SQL.

In [None]:
from src.agents.verifier import VerificationAgent, VerificationStatus

# Verifier backed by the same mock LLM client.
verifier = VerificationAgent(mock_client)

# A single generated clause and the semantic step it is meant to implement.
test_sql = "WHERE T1.HireDate > '2022-01-01'"
test_node = SemanticNode(
    id="test_filter",
    node_type=NodeType.FILTER,
    description="Filter employees hired after 2022",
    tables=["Employees"],
    columns=["HireDate"],
)

print(f"Testing SQL Clause: {test_sql}")
print(f"Semantic Intent: {test_node.description}")
print("-" * 40)

# Ask the verifier whether the clause matches the node's intent.
verification_result = verifier.verify_clause(test_node, test_sql, schema)

print(f"Verification Status: {verification_result.status.value}")
print(f"Confidence: {verification_result.confidence:.2f}")
print(f"Issues Found: {len(verification_result.issues)}")

feedback = verification_result.detailed_feedback
if feedback:
    print("\nDetailed Feedback:")
    print(feedback)

## Example 5: Comparing Multiple Queries

Let's process multiple queries to see how DIVA-SQL handles different complexity levels.

In [None]:
# Test queries ordered roughly from a simple lookup to multi-step analytics.
# NOTE: with MockLLMClient the LLM responses are canned, so the generated SQL
# is not expected to differ meaningfully between these queries.
test_queries = [
    "Show me all employee names",
    "How many employees are in each department?",
    "Which departments have more than 5 employees?",
    "What is the average salary of employees hired after 2020 by department?",
    "Find departments with the highest average salary among employees hired in the last 2 years"
]

print("Processing Multiple Queries:")
print("=" * 50)

for i, query in enumerate(test_queries, 1):
    print(f"\nQuery {i}: {query}")
    print("-" * 30)

    # Use a dedicated name so Example 1's `result` is not overwritten when
    # this cell runs.
    query_result = pipeline.generate_sql(query, schema)

    print(f"Status: {query_result.status.value}")
    print(f"Generated SQL: {query_result.final_sql}")
    print(f"Confidence: {query_result.confidence_score:.2f}")
    print(f"Time: {query_result.execution_time:.2f}s")

    query_dag = query_result.semantic_dag
    if query_dag:
        print(f"Semantic Nodes: {len(query_dag.nodes)}")

        # Show the semantic steps; only the node objects are needed, not their ids.
        print("Steps:")
        for step in query_dag.nodes.values():
            print(f"  - {step.description}")

## Conclusion

This demo showcased the key features of DIVA-SQL:

1. **Semantic Decomposition**: Breaking complex queries into interpretable steps
2. **Stepwise Generation**: Creating SQL clauses for individual semantic operations
3. **In-line Verification**: Checking each clause before composition
4. **Error Detection**: Using a comprehensive taxonomy to identify common mistakes
5. **Interpretability**: Providing clear visibility into the reasoning process

These capabilities are designed to make DIVA-SQL more reliable, debuggable, and trustworthy than traditional monolithic Text-to-SQL approaches.