**# Install required libraries**

In [None]:
!pip install langchain langchain-openai langchain-community sqlalchemy matplotlib pandas numpy

import os
import re
import json
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt

# Set up OpenAI API key
# Only prompt when the key is not already present in the environment, so the
# secret is entered interactively instead of being written into the notebook.
if "OPENAI_API_KEY" not in os.environ:
    from getpass import getpass
    os.environ["OPENAI_API_KEY"] = getpass("Enter your OpenAI API key: ")

# Import necessary components
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough


**# Create a simple database for demonstration**

In [None]:
# Progress message printed before the sample database is built.
print("Setting up a sample database for our examples...")

# Create a simple database for demonstration
def create_sample_database(db_path="retail_store.db"):
    """Open (or create) the SQLite file at ``db_path`` and ensure the retail tables exist.

    Four tables are created if missing: customers, products, orders and
    order_items. Every statement uses ``CREATE TABLE IF NOT EXISTS``, so
    calling this again on an existing database leaves its tables untouched.

    Returns:
        The open ``sqlite3.Connection``; the caller is responsible for closing it.
    """
    customers_ddl = '''
    CREATE TABLE IF NOT EXISTS customers (
        customer_id INTEGER PRIMARY KEY,
        first_name TEXT NOT NULL,
        last_name TEXT NOT NULL,
        email TEXT UNIQUE NOT NULL,
        registration_date DATE NOT NULL,
        city TEXT,
        state TEXT,
        lifetime_value REAL
    )
    '''

    products_ddl = '''
    CREATE TABLE IF NOT EXISTS products (
        product_id INTEGER PRIMARY KEY,
        product_name TEXT NOT NULL,
        category TEXT NOT NULL,
        price REAL NOT NULL,
        inventory_count INTEGER NOT NULL,
        description TEXT
    )
    '''

    orders_ddl = '''
    CREATE TABLE IF NOT EXISTS orders (
        order_id INTEGER PRIMARY KEY,
        customer_id INTEGER NOT NULL,
        order_date DATE NOT NULL,
        total_amount REAL NOT NULL,
        status TEXT NOT NULL,
        FOREIGN KEY (customer_id) REFERENCES customers (customer_id)
    )
    '''

    order_items_ddl = '''
    CREATE TABLE IF NOT EXISTS order_items (
        order_item_id INTEGER PRIMARY KEY,
        order_id INTEGER NOT NULL,
        product_id INTEGER NOT NULL,
        quantity INTEGER NOT NULL,
        price_per_unit REAL NOT NULL,
        FOREIGN KEY (order_id) REFERENCES orders (order_id),
        FOREIGN KEY (product_id) REFERENCES products (product_id)
    )
    '''

    connection = sqlite3.connect(db_path)
    db_cursor = connection.cursor()

    # Same creation order as the declarations above: referenced tables first.
    for ddl_statement in (customers_ddl, products_ddl, orders_ddl, order_items_ddl):
        db_cursor.execute(ddl_statement)

    connection.commit()
    return connection

# Sample data for our database
def populate_database(conn):
    """Insert the sample customers, products, orders and order items.

    Rows are written with ``INSERT OR REPLACE`` keyed on the primary-key
    column, so running this more than once overwrites the same rows rather
    than duplicating them. The changes are committed before returning.

    Args:
        conn: An open ``sqlite3.Connection`` whose database already contains
            the customers, products, orders and order_items tables.
    """
    cursor = conn.cursor()

    # Add sample customers
    customers = [
        (1, 'John', 'Smith', 'john.smith@email.com', '2021-01-15', 'New York', 'NY', 1250.75),
        (2, 'Sarah', 'Johnson', 'sarah.j@email.com', '2021-02-20', 'Los Angeles', 'CA', 890.25),
        (3, 'Michael', 'Brown', 'michael.b@email.com', '2021-03-10', 'Chicago', 'IL', 1475.50),
        (4, 'Emily', 'Davis', 'emily.d@email.com', '2021-04-05', 'Houston', 'TX', 760.80),
        (5, 'David', 'Wilson', 'david.w@email.com', '2021-05-22', 'Phoenix', 'AZ', 2100.30),
    ]

    # Add sample products
    products = [
        (1, 'Laptop', 'Electronics', 899.99, 25, 'High-performance laptop with 16GB RAM'),
        (2, 'Smartphone', 'Electronics', 699.99, 40, '5G smartphone with 128GB storage'),
        (3, 'Headphones', 'Electronics', 149.99, 60, 'Noise-cancelling wireless headphones'),
        (4, 'Coffee Maker', 'Kitchen', 79.99, 30, 'Programmable coffee maker with timer'),
        (5, 'T-shirt', 'Clothing', 19.99, 100, 'Cotton t-shirt, available in multiple colors'),
    ]

    # Add sample orders.
    # Each total_amount equals the sum of quantity * price_per_unit over the
    # order's rows in `order_items` below, so queries that aggregate either
    # table give the same revenue figures.
    orders = [
        (1, 1, '2023-01-05', 949.98, 'Delivered'),  # John: laptop + headphones (899.99 + 49.99)
        (2, 2, '2023-01-12', 699.99, 'Delivered'),  # Sarah: smartphone
        (3, 3, '2023-01-18', 229.95, 'Delivered'),  # Michael: headphones + 4 t-shirts (149.99 + 4 * 19.99)
        (4, 4, '2023-02-03', 79.99, 'Delivered'),   # Emily: coffee maker
        (5, 5, '2023-02-15', 169.98, 'Delivered'),  # David: headphones + t-shirt (149.99 + 19.99)
    ]

    # Add sample order items: (order_item_id, order_id, product_id, quantity, price_per_unit)
    order_items = [
        (1, 1, 1, 1, 899.99),  # Order 1: 1 laptop
        (2, 1, 3, 1, 49.99),   # Order 1: 1 headphones (differs from the current product price)
        (3, 2, 2, 1, 699.99),  # Order 2: 1 smartphone
        (4, 3, 3, 1, 149.99),  # Order 3: 1 headphones
        (5, 3, 5, 4, 19.99),   # Order 3: 4 t-shirts
        (6, 4, 4, 1, 79.99),   # Order 4: 1 coffee maker
        (7, 5, 3, 1, 149.99),  # Order 5: 1 headphones
        (8, 5, 5, 1, 19.99),   # Order 5: 1 t-shirt
    ]

    # Insert the data
    cursor.executemany('INSERT OR REPLACE INTO customers VALUES (?,?,?,?,?,?,?,?)', customers)
    cursor.executemany('INSERT OR REPLACE INTO products VALUES (?,?,?,?,?,?)', products)
    cursor.executemany('INSERT OR REPLACE INTO orders VALUES (?,?,?,?,?)', orders)
    cursor.executemany('INSERT OR REPLACE INTO order_items VALUES (?,?,?,?,?)', order_items)

    conn.commit()

# Create and populate our database
db_path = "retail_store.db"
conn = create_sample_database(db_path)
populate_database(conn)
# NOTE(review): `conn` is left open at module level; the helper functions
# below open their own connections from `db_path` instead of reusing it.

# Connect using LangChain's SQLDatabase
db = SQLDatabase.from_uri(f"sqlite:///{db_path}")

# Simple way to view the database schema
print("\nDatabase Schema:")
print(db.get_table_info())

# Initialize our LLM
# `llm` is a module-level global that every generation function below reads.
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo",
    temperature=0  # Use 0 for more deterministic outputs
)

# Function to execute a query and return results
def execute_query(query, db_path, max_rows=5):
    """Run ``query`` against the SQLite file at ``db_path`` and return a preview.

    Never raises: every failure (including failure to open the database) is
    reported through the returned dictionary.

    Args:
        query: SQL text to execute. It is executed as given, so only pass
            statements you are willing to run against the database.
        db_path: Path of the SQLite database file.
        max_rows: Maximum number of rows to fetch for the preview.

    Returns:
        A dict with keys:
          - "success": True when the query ran and produced a result set.
          - "results": a ``pandas.DataFrame`` with at most ``max_rows`` rows,
            or None on failure.
          - "message": a row-count summary, or the error description.
    """
    # Bound before the try so the `finally` clause is safe even when
    # sqlite3.connect() itself raises.
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()
        cursor.execute(query)

        # Statements that yield no result set (DDL/DML) leave description as
        # None; report that explicitly instead of failing on iteration below.
        if cursor.description is None:
            return {
                "success": False,
                "results": None,
                "message": "Error executing query: statement did not return a result set"
            }

        # Get column names
        columns = [description[0] for description in cursor.description]

        # Fetch results (limited to max_rows)
        results = cursor.fetchmany(max_rows)

        # Create a DataFrame for nice display
        df = pd.DataFrame(results, columns=columns)

        return {
            "success": True,
            "results": df,
            # The total is shown as '?' whenever the driver reports a
            # negative rowcount.
            "message": f"Showing {len(results)} of {cursor.rowcount if cursor.rowcount >= 0 else '?'} rows"
        }
    except Exception as e:
        # Deliberately broad: LLM-generated SQL can fail in many ways and the
        # callers only need the message.
        return {
            "success": False,
            "results": None,
            "message": f"Error executing query: {str(e)}"
        }
    finally:
        if conn is not None:
            conn.close()

# Basic SQL generation function for comparison
def basic_sql_generation(question, db):
    """Ask the model for SQL with a minimal prompt (schema + question only).

    This is the baseline for comparison: the reply from the module-level
    ``llm`` is returned exactly as produced, with no post-processing.

    Args:
        question: Natural-language question to translate.
        db: Object exposing ``get_table_info()``; its output is inserted into
            the prompt as the schema.

    Returns:
        The model's reply as a string.
    """
    template_text = """Generate a SQL query to answer the following question:

        Database Schema:
        {schema}

        Question: {question}

        SQL Query:"""

    # prompt -> model -> plain-string output
    pipeline = PromptTemplate.from_template(template_text) | llm | StrOutputParser()

    prompt_inputs = {"schema": db.get_table_info(), "question": question}
    return pipeline.invoke(prompt_inputs)

# Test a simple query with the basic approach for comparison
# `test_question` is reused by the schema-aware demo further down.
test_question = "What are the total sales by product category?"
print(f"\nTesting basic approach with: '{test_question}'")
basic_sql = basic_sql_generation(test_question, db)
print(f"Generated SQL:\n{basic_sql}")

# Execute to see if it works
# The reply is executed unmodified (basic_sql_generation does no cleanup), so
# anything in it that is not valid SQL shows up here as a failed execution.
basic_result = execute_query(basic_sql, db_path)
print(f"Query execution {'succeeded' if basic_result['success'] else 'failed'}")
if basic_result["success"]:
    print(basic_result["results"])
else:
    print(basic_result["message"])


**# 11.3.1 Schema-aware prompting strategies**

In [None]:
# Section banner: a title framed by two 80-character rule lines.
print("\n" + "="*80)
print("SECTION 11.3.1: Schema-aware prompting strategies")
print("="*80)

def create_enhanced_schema(db):
    """Return a hand-written, annotated description of the retail schema.

    The text lists each table's columns with their meaning, the relationships
    between tables, and business rules. It is static: nothing is read from
    the database, so it must be edited by hand if the tables change.

    Args:
        db: Unused. Kept so existing callers that pass the database object
            continue to work.

    Returns:
        The schema description as a multi-line string.
    """
    # Add semantic details and relationships
    enhanced_schema = """
    # Retail Database Schema with Semantic Information

    ## Table: customers
    Stores information about customers who have made purchases.

    - customer_id (INTEGER): Primary key, unique identifier for each customer
    - first_name (TEXT): Customer's first name
    - last_name (TEXT): Customer's last name
    - email (TEXT): Customer's email address (unique)
    - registration_date (DATE): Date when customer first registered
    - city (TEXT): Customer's city of residence
    - state (TEXT): Customer's state of residence (2-letter code)
    - lifetime_value (REAL): Total value of all purchases made by customer

    ## Table: products
    Contains details about products available for purchase.

    - product_id (INTEGER): Primary key, unique identifier for each product
    - product_name (TEXT): Name of the product
    - category (TEXT): Product category (e.g., Electronics, Clothing)
    - price (REAL): Current price of the product
    - inventory_count (INTEGER): Number of units in stock
    - description (TEXT): Detailed description of the product

    ## Table: orders
    Records of customer orders.

    - order_id (INTEGER): Primary key, unique identifier for each order
    - customer_id (INTEGER): Foreign key referencing customers table
    - order_date (DATE): Date when the order was placed
    - total_amount (REAL): Total monetary value of the order
    - status (TEXT): Current status of the order (e.g., Processing, Shipped, Delivered)

    ## Table: order_items
    Details of individual items within each order.

    - order_item_id (INTEGER): Primary key, unique identifier for each order item
    - order_id (INTEGER): Foreign key referencing orders table
    - product_id (INTEGER): Foreign key referencing products table
    - quantity (INTEGER): Number of units of the product ordered
    - price_per_unit (REAL): Price of the product at the time of purchase

    ## Key Relationships:
    - Customers place Orders (one-to-many relationship)
    - Orders contain Order Items (one-to-many relationship)
    - Order Items reference Products (many-to-one relationship)

    ## Semantic Information:
    - The lifetime_value in customers represents the total amount that customer has spent
    - The price in products is the current price, which may differ from price_per_unit in order_items (historical price)
    - The total_amount in orders should equal the sum of (quantity * price_per_unit) for all items in that order
    - The status in orders follows a sequence: Processing → Shipped → Delivered
    """

    return enhanced_schema

def schema_aware_sql_generation(question, db):
    """Generate SQL from a prompt that embeds the annotated schema and style rules.

    The schema text comes from ``create_enhanced_schema(db)``; the prompt is
    sent to the module-level ``llm``. Markdown code-fence markers are removed
    from the reply before it is returned.

    Args:
        question: Natural-language question to translate.
        db: Database object forwarded to ``create_enhanced_schema``.

    Returns:
        The generated SQL with fences and surrounding whitespace stripped.
    """
    schema_text = create_enhanced_schema(db)

    prompt_template = PromptTemplate.from_template(
        """You are an expert SQL query generator for a retail database.

        Given the schema below, write a SQL query to answer the user's question.

        {schema}

        User Question: {question}

        Important Guidelines:
        1. Always use proper table aliases when joining tables (e.g., c for customers)
        2. Round monetary values to 2 decimal places
        3. For date comparisons, use proper format (YYYY-MM-DD)
        4. Always check for NULL values where appropriate
        5. DO NOT include markdown formatting like ```sql in your response
        6. Just return the raw SQL query without any explanations

        SQL Query:"""
    )

    pipeline = prompt_template | llm | StrOutputParser()
    generated = pipeline.invoke({"schema": schema_text, "question": question})

    # The model may still wrap its answer in fences despite guideline 5;
    # drop the language-tagged opener first, then any bare fence.
    for fence in ("```sql", "```"):
        generated = generated.replace(fence, "")

    return generated.strip()

# Test the schema-aware approach
# Uses the same `test_question` as the basic-approach demo so the two
# generated queries can be compared directly.
print(f"\nTesting schema-aware approach with: '{test_question}'")
schema_aware_sql = schema_aware_sql_generation(test_question, db)
print(f"Generated SQL:\n{schema_aware_sql}")

# Execute to see if it works
schema_aware_result = execute_query(schema_aware_sql, db_path)
print(f"Query execution {'succeeded' if schema_aware_result['success'] else 'failed'}")
if schema_aware_result["success"]:
    print(schema_aware_result["results"])
else:
    print(schema_aware_result["message"])

**# 11.3.2 Few-shot examples for complex queries**

In [None]:
# Section banner: a title framed by two 80-character rule lines.
print("\n" + "="*80)
print("SECTION 11.3.2: Few-shot examples for complex queries")
print("="*80)

def create_few_shot_examples():
    """Build the catalogue of worked question/SQL pairs used for few-shot prompting.

    Returns:
        A dict keyed by query type ("aggregation", "filtering", "temporal").
        Each value is a dict with a "question" string and the matching "sql"
        string for the retail schema.
    """
    # (query type, example question, example SQL)
    catalogue = [
        (
            "aggregation",
            "What's the average order value for each customer?",
            """
                SELECT
                    c.customer_id,
                    c.first_name,
                    c.last_name,
                    ROUND(AVG(o.total_amount), 2) as avg_order_value
                FROM customers c
                JOIN orders o ON c.customer_id = o.customer_id
                GROUP BY c.customer_id, c.first_name, c.last_name
                ORDER BY avg_order_value DESC;
            """,
        ),
        (
            "filtering",
            "Which customers from California ordered electronics products?",
            """
                SELECT DISTINCT
                    c.customer_id,
                    c.first_name,
                    c.last_name
                FROM customers c
                JOIN orders o ON c.customer_id = o.customer_id
                JOIN order_items oi ON o.order_id = oi.order_id
                JOIN products p ON oi.product_id = p.product_id
                WHERE c.state = 'CA'
                AND p.category = 'Electronics';
            """,
        ),
        (
            "temporal",
            "How many orders were placed in January 2023?",
            """
                SELECT
                    COUNT(*) as order_count
                FROM orders
                WHERE order_date BETWEEN '2023-01-01' AND '2023-01-31';
            """,
        ),
    ]

    return {
        query_type: {"question": sample_question, "sql": sample_sql}
        for query_type, sample_question, sample_sql in catalogue
    }

def few_shot_sql_generation(question, db):
    """Generate SQL from a prompt containing one worked example.

    The example is chosen by substring keyword matching on the lower-cased
    question, trying aggregation, then filtering, then temporal keywords;
    the aggregation example is used when nothing matches. The prompt is sent
    to the module-level ``llm`` and markdown fence markers are stripped from
    the reply.

    Args:
        question: Natural-language question to translate.
        db: Object exposing ``get_table_info()``; its output is inserted into
            the prompt as the schema.

    Returns:
        The generated SQL with fences and surrounding whitespace stripped.
    """
    examples = create_few_shot_examples()

    # Ordered routing table: the first category with a keyword contained in
    # the question wins, so earlier rows take priority over later ones.
    keyword_routes = (
        ("aggregation", ("average", "mean", "sum", "total", "count")),
        ("filtering", ("which", "who", "where", "find")),
        ("temporal", ("when", "date", "month", "year")),
    )
    lowered_question = question.lower()
    chosen_type = next(
        (
            query_type
            for query_type, keywords in keyword_routes
            if any(keyword in lowered_question for keyword in keywords)
        ),
        "aggregation",  # default when no keyword matches
    )
    selected_example = examples[chosen_type]

    few_shot_prompt = f"""
    Given the following database schema and example, write a SQL query to answer the user's question.

    SCHEMA:
    {db.get_table_info()}

    EXAMPLE QUESTION: {selected_example["question"]}
    EXAMPLE SQL QUERY: {selected_example["sql"]}

    USER QUESTION: {question}

    Important:
    - DO NOT include markdown formatting like ```sql in your response
    - Just return the raw SQL query without any explanations

    SQL QUERY:
    """

    reply_text = llm.invoke(few_shot_prompt).content

    # Remove the language-tagged fence opener first, then any bare fence.
    for fence in ("```sql", "```"):
        reply_text = reply_text.replace(fence, "")

    return reply_text.strip()

# Test with a complex query
# This question contains the keyword "total", so few_shot_sql_generation
# selects its aggregation example.
complex_question = "What is the most popular product category based on total quantity ordered?"
print(f"\nTesting few-shot approach with complex question: '{complex_question}'")
few_shot_sql = few_shot_sql_generation(complex_question, db)
print(f"Generated SQL:\n{few_shot_sql}")

# Execute to see if it works
few_shot_result = execute_query(few_shot_sql, db_path)
print(f"Query execution {'succeeded' if few_shot_result['success'] else 'failed'}")
if few_shot_result["success"]:
    print(few_shot_result["results"])
else:
    print(few_shot_result["message"])

**# 11.3.3 Chain-of-thought prompting for query decomposition**

In [None]:
# Section banner: a title framed by two 80-character rule lines.
print("\n" + "="*80)
print("SECTION 11.3.3: Chain-of-thought prompting for query decomposition")
print("="*80)

def chain_of_thought_sql_generation(question, db):
    """Generate SQL by asking the model to reason step by step first.

    The module-level ``llm`` is prompted to work through tables, joins,
    filters, aggregations and ordering, then to finish with a line marked
    "FINAL SQL QUERY:". The SQL is extracted from the reply in two stages:

    1. Text following the "FINAL SQL QUERY:" marker, up to the first blank
       line or the end of the reply.
    2. If the marker is absent: the first line starting with SELECT plus the
       following lines, up to and including the first *following* line that
       ends with a semicolon.

    Args:
        question: Natural-language question to translate.
        db: Object exposing ``get_table_info()``; its output is inserted into
            the prompt as the schema.

    Returns:
        A dict with "reasoning" (the full model reply) and "sql_query" (the
        extracted SQL, or the text "Could not extract SQL query" when neither
        stage finds anything).
    """
    cot_prompt = f"""You are an expert SQL developer tasked with translating natural language questions into SQL.

    SCHEMA:
    {db.get_table_info()}

    QUESTION: {question}

    Let's break this down step by step:

    1) Identify the tables needed to answer this question.
    2) Determine the required joins between these tables.
    3) Identify any filtering conditions needed.
    4) Determine what aggregations or calculations are required.
    5) Consider how to order or limit the results appropriately.

    After reasoning through these steps, provide your final SQL query.

    IMPORTANT: At the end, include your final SQL query clearly marked with "FINAL SQL QUERY:" on its own line.
    DO NOT use markdown formatting or code blocks. Just provide the raw SQL.

    REASONING:
    """

    full_response = llm.invoke(cot_prompt)
    reply_text = full_response.content

    # Stage 1: the explicitly marked query.
    marker_match = re.search(
        r"FINAL SQL QUERY:\s*(.*?(?:\n.*?)*?)(?:\n\n|$)", reply_text, re.DOTALL
    )

    if marker_match:
        sql_query = marker_match.group(1).strip()
    else:
        # Stage 2: scan for something that looks like a SELECT statement.
        reply_lines = reply_text.split('\n')
        select_index = next(
            (
                index
                for index, text in enumerate(reply_lines)
                if text.strip().upper().startswith("SELECT")
            ),
            None,
        )

        collected = []
        if select_index is not None:
            collected.append(reply_lines[select_index])
            # Only lines after the SELECT line are checked for the closing
            # semicolon.
            for text in reply_lines[select_index + 1:]:
                collected.append(text)
                if text.strip().endswith(";"):
                    break

        if collected:
            sql_query = "\n".join(collected)
        else:
            sql_query = "Could not extract SQL query"

    # Remove the language-tagged fence opener first, then any bare fence.
    for fence in ("```sql", "```"):
        sql_query = sql_query.replace(fence, "")
    sql_query = sql_query.strip()

    return {
        "reasoning": reply_text,
        "sql_query": sql_query
    }

# Test with an analytical question that requires multiple steps
analytical_question = "Which customer has placed the most orders of Electronics products, and what's their total spend on that category?"
print(f"\nTesting chain-of-thought approach with analytical question: '{analytical_question}'")
cot_result = chain_of_thought_sql_generation(analytical_question, db)
# Show both the model's full reasoning and the SQL that was extracted from it.
print(f"\nReasoning process:\n{cot_result['reasoning']}")
print(f"\nExtracted SQL query:\n{cot_result['sql_query']}")

# Execute to see if it works
# If extraction failed, 'sql_query' holds a placeholder sentence rather than
# SQL, and the execution below reports the resulting error.
cot_execution = execute_query(cot_result['sql_query'], db_path)
print(f"\nQuery execution {'succeeded' if cot_execution['success'] else 'failed'}")
if cot_execution["success"]:
    print(cot_execution["results"])
else:
    print(cot_execution["message"])

**# 11.3.4 Iterative refinement approaches**

In [None]:
# Section banner: a title framed by two 80-character rule lines.
print("\n" + "="*80)
print("SECTION 11.3.4: Iterative refinement approaches")
print("="*80)

def validate_sql_query(query, db_path):
    """Check whether ``query`` runs against the SQLite file at ``db_path``.

    Validation is done by actually executing the statement on a fresh
    connection. No commit is issued here and the connection is always closed
    afterwards. Never raises: failure to open the database is reported the
    same way as a bad query.

    Args:
        query: SQL text to try.
        db_path: Path of the SQLite database file.

    Returns:
        ``{"valid": True, "error": None}`` when execution succeeds, otherwise
        ``{"valid": False, "error": <message>}``.
    """
    # Bound before the try so cleanup is safe even when connect() raises.
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()
        cursor.execute(query)
        return {"valid": True, "error": None}
    except Exception as e:
        # Deliberately broad: any failure means the candidate is unusable.
        return {"valid": False, "error": str(e)}
    finally:
        if conn is not None:
            conn.close()

def _evaluate_candidate_sql(sql_query):
    """Validate and run one candidate query against the module-level ``db_path``.

    Returns:
        A ``(results, problem)`` tuple. ``results`` is the result DataFrame
        when the query ran and returned at least one row, otherwise None.
        ``problem`` describes what went wrong, or is None on success.
    """
    validation = validate_sql_query(sql_query, db_path)
    if not validation["valid"]:
        print(f"✗ SQL query has an error: {validation['error']}")
        return None, validation["error"]

    print("✓ SQL query is valid!")

    # Execute to check results
    execution = execute_query(sql_query, db_path)
    if execution["success"] and not execution["results"].empty:
        print("✓ Query returned results!")
        return execution["results"], None

    print("⚠ Query executed but returned no results. Continuing refinement...")
    if not execution["success"]:
        return None, execution["message"]
    return None, "Query returned no results"


def iterative_refinement(question, db, max_attempts=3):
    """Generate SQL for ``question``, retrying with different techniques.

    Attempt 1 uses schema-aware prompting, attempt 2 few-shot prompting, and
    every later attempt chain-of-thought prompting. After each failed attempt
    except the last, the module-level ``llm`` is asked to repair the failing
    query, and the repaired query is itself validated and executed before
    moving on to the next technique.

    A candidate counts as successful when it executes and returns at least
    one row.

    Args:
        question: Natural-language question to translate.
        db: Database object passed to the generation functions.
        max_attempts: Maximum number of generation attempts.

    Returns:
        On success: ``{"success": True, "sql": <query>, "results": <DataFrame>}``.
        On failure: ``{"success": False, "sql": <last query tried>, "error": <message>}``.
    """

    print(f"Initial question: '{question}'")

    # Initialised up front so the final return is well-defined even when the
    # loop body never runs (max_attempts <= 0).
    sql_query = ""
    last_problem = None

    for attempt in range(max_attempts):
        print(f"\nAttempt {attempt+1}:")

        # Generate SQL - use different techniques in each attempt
        if attempt == 0:
            # First try schema-aware
            sql_query = schema_aware_sql_generation(question, db)
        elif attempt == 1:
            # Then try few-shot
            sql_query = few_shot_sql_generation(question, db)
        else:
            # Finally try chain-of-thought
            cot_result = chain_of_thought_sql_generation(question, db)
            sql_query = cot_result["sql_query"]

        print(f"Generated SQL:\n{sql_query}")

        results, last_problem = _evaluate_candidate_sql(sql_query)
        if results is not None:
            return {
                "success": True,
                "sql": sql_query,
                "results": results
            }

        # If still here, we need to refine
        if attempt < max_attempts - 1:
            refinement_prompt = f"""
            The SQL query has issues:
            {last_problem}

            Original question: {question}

            Original SQL attempt:
            {sql_query}

            Please correct the SQL query to fix this issue.
            Return only the raw SQL query, without explanations or markdown formatting.
            """

            refined_sql = llm.invoke(refinement_prompt).content
            refined_sql = refined_sql.replace("```sql", "").replace("```", "").strip()
            print(f"Refined SQL:\n{refined_sql}")

            # Give the repaired query its own chance before switching
            # technique; otherwise the repair would be thrown away.
            results, refined_problem = _evaluate_candidate_sql(refined_sql)
            if results is not None:
                return {
                    "success": True,
                    "sql": refined_sql,
                    "results": results
                }
            sql_query, last_problem = refined_sql, refined_problem

    print(f"⚠ Reached maximum attempts ({max_attempts}). Returning best effort.")
    return {
        "success": False,
        "sql": sql_query,
        "error": last_problem or 'Failed to generate valid SQL after multiple attempts'
    }

# Test with a question that might require refinement
tricky_question = "What's the average order value per month, showing only months where the average is above $500?"
print(f"\nTesting iterative refinement with tricky question: '{tricky_question}'")
refinement_result = iterative_refinement(tricky_question, db)

# The result dict carries "results" only on success and "error" only on
# failure, hence the different keys read in each branch.
if refinement_result["success"]:
    print("\nFinal successful SQL:")
    print(refinement_result["sql"])
    print("\nResults:")
    print(refinement_result["results"])
else:
    print("\nFailed to generate valid SQL:")
    print(refinement_result["sql"])
    print(f"Error: {refinement_result.get('error')}")

**# Comparison of Approaches**

In [None]:
# Section banner: a title framed by two 80-character rule lines.
print("\n" + "="*80)
print("Comparing Different Prompting Approaches")
print("="*80)

# Define a set of test questions to compare approaches
# Every question is run through all four generation functions by
# compare_approaches below.
test_questions = [
    "What are the top 3 customers by lifetime value?",
    "How many orders has each customer placed?",
    "What's the total sales amount for each product category?",
    "Which customer spent the most on Electronics products?",
    "What's the average order value by month in 2023?"
]

# Function to test all approaches
def compare_approaches(questions):
    """Run every prompting technique on each question and record SQL validity.

    For each question, SQL is generated with the basic, schema-aware,
    few-shot and chain-of-thought functions (in that order, using the
    module-level ``db``), and each query is then checked with
    ``validate_sql_query`` against the module-level ``db_path``. If anything
    raises while processing a question, the error is printed and that
    question is left out of the returned list.

    Args:
        questions: Iterable of natural-language questions.

    Returns:
        A list with one dict per successfully processed question:
        ``{"question": ..., "approaches": {<label>: {"valid": bool, "sql": str}}}``.
    """
    results = []

    for position, question in enumerate(questions, start=1):
        print(f"\nProcessing question {position}: '{question}'")

        try:
            # Generate with all four techniques before validating any of
            # them, so a generation failure skips the whole question.
            generated = {
                "Basic": basic_sql_generation(question, db),
                "Schema-aware": schema_aware_sql_generation(question, db),
                "Few-shot": few_shot_sql_generation(question, db),
                "Chain-of-thought": chain_of_thought_sql_generation(question, db)["sql_query"],
            }

            # Check validity
            report = {
                label: {"valid": validate_sql_query(sql_text, db_path)["valid"], "sql": sql_text}
                for label, sql_text in generated.items()
            }

            # Store results
            results.append({"question": question, "approaches": report})

            # Show results for this question
            for label, entry in report.items():
                print(f"  {label}: {'✓' if entry['valid'] else '✗'}")

        except Exception as exc:
            print(f"Error processing question: {str(exc)}")

    return results

# Run the comparison
comparison_results = compare_approaches(test_questions)

# Calculate success rates
approaches = ["Basic", "Schema-aware", "Few-shot", "Chain-of-thought"]
success_counts = {approach: 0 for approach in approaches}

for result in comparison_results:
    for approach in approaches:
        if result["approaches"][approach]["valid"]:
            success_counts[approach] += 1

# Convert to percentages
# compare_approaches drops any question that raised, so the list can be
# empty (e.g. when the API is unreachable); report 0% instead of dividing
# by zero.
total_questions = len(comparison_results)
success_rates = {
    approach: (count / total_questions) * 100 if total_questions else 0.0
    for approach, count in success_counts.items()
}

print("\nSuccess rates by approach:")
for approach, rate in success_rates.items():
    print(f"{approach}: {rate:.1f}%")

# Visualize the results
plt.figure(figsize=(10, 6))
plt.bar(approaches, [success_rates[a] for a in approaches], color=['#FF9999', '#66B2FF', '#99FF99', '#FFCC99'])
# Dashed reference line at the 80% mark used as the "good performance" threshold.
plt.axhline(y=80, color='r', linestyle='--', alpha=0.7, label='Good Performance (80%)')
plt.ylim(0, 105)  # a little headroom above 100% so full bars are not clipped
plt.xlabel('Prompting Technique')
plt.ylabel('Success Rate (%)')
plt.title('SQL Generation Success Rate by Prompting Technique')
plt.legend()
plt.tight_layout()
plt.show()