In [38]:
import pandas as pd
import json

In [7]:
# Raw CSV exports live in ../data; each file name doubles as its SQLite table name.
_DATA_DIR = '../data'
_TABLE_NAMES = (
    "products",
    "users",
    "orders",
    "order_items",
    "inventory_items",
    "distribution_centers",
    "events",
)

# Read every table once, in this fixed order, keyed by table name.
dataframes = {
    table_name: pd.read_csv(f"{_DATA_DIR}/{table_name}.csv")
    for table_name in _TABLE_NAMES
}

# Per-table aliases kept for ad-hoc exploration in later cells.
products_df = dataframes["products"]
users_df = dataframes["users"]
orders_df = dataframes["orders"]
order_items_df = dataframes["order_items"]
inventory_items_df = dataframes["inventory_items"]
distribution_centers_df = dataframes["distribution_centers"]
events_df = dataframes["events"]

In [8]:
# Quick visual sanity check: the head of each loaded table, separated by a rule.
_separator = "-" * 80
for table_name, frame in dataframes.items():
    print(f"\n{table_name.upper()} DataFrame:")
    print(frame.head())
    print(_separator)


PRODUCTS DataFrame:
      id     cost     category  \
0  13842  2.51875  Accessories   
1  13928  2.33835  Accessories   
2  14115  4.87956  Accessories   
3  14157  4.64877  Accessories   
4  14273  6.50793  Accessories   

                                                name brand  retail_price  \
0   Low Profile Dyed Cotton Twill Cap - Navy W39S55D    MG          6.25   
1  Low Profile Dyed Cotton Twill Cap - Putty W39S55D    MG          5.95   
2       Enzyme Regular Solid Army Caps-Black W35S45D    MG         10.99   
3  Enzyme Regular Solid Army Caps-Olive W35S45D (...    MG         10.99   
4              Washed Canvas Ivy Cap - Black W11S64C    MG         15.99   

  department                               sku  distribution_center_id  
0      Women  EBD58B8A3F1D72F4206201DA62FB1204                       1  
1      Women  2EAC42424D12436BDD6A5B8A88480CC3                       1  
2      Women  EE364229B2791D1EF9355708EFF0BA34                       1  
3      Women  00BD13095D0

### SQLite Database

In [9]:
import sqlite3
import os

db_dir_path = 'db_data'
db = "ecommerce.db"

# exist_ok replaces the separate exists-check and avoids a check-then-create race.
os.makedirs(db_dir_path, exist_ok=True)

conn = sqlite3.connect(os.path.join(db_dir_path, db))
try:
    for name, df in dataframes.items():
        # if_exists='replace' drops and recreates the table, so re-running this cell is safe.
        df.to_sql(name, conn, if_exists='replace', index=False)
        print(f"Table '{name}' created in database '{db}'.")
finally:
    # Close even if a to_sql call fails, so the file handle is not leaked into the kernel.
    conn.close()

print("All tables have been created and data has been inserted successfully.")



Table 'products' created in database 'ecommerce.db'.
Table 'users' created in database 'ecommerce.db'.
Table 'orders' created in database 'ecommerce.db'.
Table 'order_items' created in database 'ecommerce.db'.
Table 'inventory_items' created in database 'ecommerce.db'.
Table 'distribution_centers' created in database 'ecommerce.db'.
Table 'events' created in database 'ecommerce.db'.
All tables have been created and data has been inserted successfully.


### Gemini Model

In [10]:
import getpass
import os

# Ask for the Gemini key interactively only when it is not already exported, so the
# secret never has to be written into the notebook itself.
_key_already_set = "GOOGLE_API_KEY" in os.environ
if not _key_already_set:
    os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter your Google AI API key: ")

In [11]:
from langchain_google_genai import ChatGoogleGenerativeAI

# Shared chat model used by every agent node below.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    temperature=0,    # lowest-variance sampling; SQL generation benefits from stable output
    max_retries=2,    # let the client retry failed API calls twice
    timeout=None,     # left unset (None)
    max_tokens=None,  # left unset (None): no explicit output-token cap
)

In [12]:
# Smoke test: one English-to-German translation confirms the model and API key work end to end.
_translator_system_prompt = (
    "You are a helpful assistant that translates English to German. Translate the user sentence."
)
messages = [
    ("system", _translator_system_prompt),
    ("human", "I love programming."),
]
ai_msg = llm.invoke(messages)
ai_msg.content

'Ich liebe Programmieren.'

### Schema Definition

In [13]:
import sqlite3
import pandas as pd
import os

# Connect to the database
# Connect to the database and print each table's column names and declared types.
conn = sqlite3.connect(os.path.join(db_dir_path, db))
try:
    cursor = conn.cursor()

    # Get all table names
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = cursor.fetchall()

    print("DATABASE SCHEMA\n" + "=" * 80)

    for table in tables:
        table_name = table[0]
        print(f"\nTable: {table_name.upper()}")
        print("-" * 80)

        # A table name cannot be passed to PRAGMA as a bound parameter, so quote the
        # identifier (doubling embedded quotes) instead of interpolating it raw.
        quoted_name = '"' + table_name.replace('"', '""') + '"'
        cursor.execute(f"PRAGMA table_info({quoted_name});")
        columns = cursor.fetchall()

        # table_info rows are (cid, name, type, notnull, dflt_value, pk); keep name + type.
        schema_df = pd.DataFrame(
            columns, columns=["cid", "name", "type", "notnull", "dflt_value", "pk"]
        )
        schema_df = schema_df[["name", "type"]]
        schema_df.columns = ["Column Name", "Data Type"]

        print(schema_df.to_string(index=False))
        print()
finally:
    # Close even if a query or print fails, so the handle is not leaked into the kernel.
    conn.close()

DATABASE SCHEMA

Table: PRODUCTS
--------------------------------------------------------------------------------
           Column Name Data Type
                    id   INTEGER
                  cost      REAL
              category      TEXT
                  name      TEXT
                 brand      TEXT
          retail_price      REAL
            department      TEXT
                   sku      TEXT
distribution_center_id   INTEGER


Table: USERS
--------------------------------------------------------------------------------
   Column Name Data Type
            id   INTEGER
    first_name      TEXT
     last_name      TEXT
         email      TEXT
           age   INTEGER
        gender      TEXT
         state      TEXT
street_address      TEXT
   postal_code      TEXT
          city      TEXT
       country      TEXT
      latitude      REAL
     longitude      REAL
traffic_source      TEXT
    created_at      TEXT


Table: ORDERS
--------------------------------------------

In [14]:
# Hand-written schema description injected verbatim into the SQL-generation and
# error-correction prompts. Compared with the PRAGMA printout above, it adds
# per-column descriptions, PK/FK hints and example values to guide the model.
# NOTE(review): the type labels are semantic, not physical. For example, USERS.created_at
# is labelled TIMESTAMP here, but the PRAGMA output above reports it as TEXT (presumably
# because read_csv leaves date columns as strings). The SQL-generation prompt says so
# explicitly. EVENTS.user_id is REAL, presumably because missing IDs make pandas read
# the column as float. The ORDERS..EVENTS column lists could not be checked against the
# truncated printout; keep this text in sync with the CSVs.
SCHEMA_DEFINITION = """
TABLE: PRODUCTS
Description: Catalog of items available for sale.
COLUMNS:
- id (INTEGER): PK. Unique identifier for the product.
- cost (REAL): The cost to manufacture or acquire the item (not the sale price).
- category (TEXT): High-level product category (e.g., 'Accessories', 'Outerwear').
- name (TEXT): The commercial name of the product.
- brand (TEXT): The brand manufacturer.
- retail_price (REAL): The suggested MSRP or list price of the item.
- department (TEXT): Gender or demographic target (e.g., 'Men', 'Women').
- sku (TEXT): Stock Keeping Unit code.
- distribution_center_id (INTEGER): FK. Links to DISTRIBUTION_CENTERS table (location where stocked).

TABLE: USERS
Description: Registered customers and their demographic data.
COLUMNS:
- id (INTEGER): PK. Unique identifier for the user.
- first_name (TEXT): User's first name.
- last_name (TEXT): User's last name.
- email (TEXT): User's email address.
- age (INTEGER): User's age.
- gender (TEXT): User's gender (M/F).
- state (TEXT): State of residence.
- street_address (TEXT): Street address.
- postal_code (TEXT): Zip/Postal code.
- city (TEXT): City of residence.
- country (TEXT): Country of residence.
- latitude (REAL): GPS latitude of user address.
- longitude (REAL): GPS longitude of user address.
- traffic_source (TEXT): Marketing channel that acquired the user (e.g., 'Search', 'Organic').
- created_at (TIMESTAMP): Date and time the account was created.

TABLE: ORDERS
Description: Summary of a purchase event (basket level).
COLUMNS:
- order_id (INTEGER): PK. Unique identifier for the order.
- user_id (INTEGER): FK. Links to USERS table.
- status (TEXT): Current state of the order (e.g., 'Complete', 'Cancelled', 'Returned').
- gender (TEXT): Gender associated with the order items (often redundant with User gender).
- created_at (TIMESTAMP): Timestamp when the order was placed.
- returned_at (TIMESTAMP): Timestamp if/when the order was returned.
- shipped_at (TIMESTAMP): Timestamp when the order left the warehouse.
- delivered_at (TIMESTAMP): Timestamp when the order reached the customer.
- num_of_item (INTEGER): Total count of items in this order.

TABLE: ORDER_ITEMS
Description: Individual line items within an order. Use this for revenue calculations.
COLUMNS:
- id (INTEGER): PK. Unique identifier for the line item.
- order_id (INTEGER): FK. Links to ORDERS table.
- user_id (INTEGER): FK. Links to USERS table.
- product_id (INTEGER): FK. Links to PRODUCTS table.
- inventory_item_id (INTEGER): FK. Links to INVENTORY_ITEMS table (specific stock instance).
- status (TEXT): Status of this specific item (e.g., 'Returned', 'Complete').
- created_at (TIMESTAMP): Purchase timestamp.
- shipped_at (TIMESTAMP): Shipping timestamp.
- delivered_at (TIMESTAMP): Delivery timestamp.
- returned_at (TIMESTAMP): Return timestamp.
- sale_price (REAL): The actual price the user paid for this item (Revenue).

TABLE: INVENTORY_ITEMS
Description: Historical log of every specific physical item in the warehouse.
COLUMNS:
- id (INTEGER): PK. Unique identifier for the inventory unit.
- product_id (INTEGER): FK. Links to PRODUCTS table.
- created_at (TIMESTAMP): When the item arrived in inventory.
- sold_at (TIMESTAMP): When the item was sold (NULL if currently in stock).
- cost (REAL): Cost of this specific inventory batch.
- product_category (TEXT): Redundant snapshot of product category.
- product_name (TEXT): Redundant snapshot of product name.
- product_brand (TEXT): Redundant snapshot of brand.
- product_retail_price (REAL): Redundant snapshot of retail price.
- product_department (TEXT): Redundant snapshot of department.
- product_sku (TEXT): Redundant snapshot of SKU.
- product_distribution_center_id (INTEGER): FK. Links to DISTRIBUTION_CENTERS table.

TABLE: DISTRIBUTION_CENTERS
Description: Physical warehouse locations.
COLUMNS:
- id (INTEGER): PK. Unique identifier for the distribution center.
- name (TEXT): Name of the facility (e.g., 'Memphis TN').
- latitude (REAL): GPS latitude of the facility.
- longitude (REAL): GPS longitude of the facility.

TABLE: EVENTS
Description: Web traffic logs (views, clicks, interactions).
COLUMNS:
- id (INTEGER): PK. Unique identifier for the event log.
- user_id (REAL): FK. Links to USERS (can be NULL for guest visitors).
- sequence_number (INTEGER): Order of events within a session.
- session_id (TEXT): Unique ID for the browsing session.
- created_at (TIMESTAMP): Timestamp of the event.
- ip_address (TEXT): User's IP address.
- city (TEXT): Estimated city based on IP.
- state (TEXT): Estimated state based on IP.
- postal_code (TEXT): Estimated zip code based on IP.
- browser (TEXT): Browser used (e.g., 'Chrome', 'Safari').
- traffic_source (TEXT): Marketing source for this session.
- uri (TEXT): The specific URL path visited.
- event_type (TEXT): Type of interaction (e.g., 'product', 'department', 'cart', 'purchase').
"""

### Graph State

In [15]:
from langgraph.graph import MessagesState

# Shared state passed between the LangGraph nodes. MessagesState (LangGraph) presumably
# contributes the `messages` key; the fields below are specific to this workflow.
class GraphState(MessagesState):
    # Written by guardrails_agent, read by check_relevance. The misspelling "relavant" is
    # part of the key name used across the notebook, so it must stay consistent.
    is_question_relavant: bool
    # The user's natural-language question; read by every agent (presumably set by the caller).
    user_query: str
    # SQL text written by sql_generation_agent / error_correction_agent; run by execute_sql.
    sql_query_generated: str
    # Formatted result text from execute_sql ("" when execution failed).
    result_for_sql_query: str
    # User-facing reply: canned text from guardrails_agent, the apology from
    # error_correction_agent, or the narrative from analysis_agent.
    final_answer: str
    # Set by execute_sql on failure, cleared on success and by error_correction_agent;
    # checked by should_retry and decide_visualization_agent.
    error_message: str
    # Attempt counter incremented by sql_generation_agent and error_correction_agent;
    # compared against a cap of 3 by should_retry / error_correction_agent.
    curr_iteration: int
    # Whether to render a chart (analysis_agent / decide_visualization_agent; reset on failure).
    needs_plotly_figure: bool
    # Chart kind chosen by decide_visualization_agent: 'bar', 'line', 'pie', 'scatter' or 'none'.
    type_of_plotly_figure: str
    # fig.to_json() output from visualization_agent ("" if generation failed).
    plotly_figure_json_string: str

### Agents

#### Guardrails Agent

In [16]:
from pydantic import BaseModel, Field

# Structured-output schema for guardrails_agent (passed to llm.with_structured_output).
# NOTE(review): the Field descriptions are presumably forwarded to the model as part of the
# output schema, so they act as prompt text. Document with `#` comments rather than a class
# docstring, which pydantic would likely also place into the schema.
class GuardrailsResponse(BaseModel):
    # Copied by guardrails_agent into state["is_question_relavant"].
    is_question_relavant: bool = Field(
        description="Indicates if the user's query is relevant to the e-commerce database.",
    )
    # When true, guardrails_agent answers with a canned greeting.
    is_greeting: bool = Field(
        description="Indicates if the user's query is a greeting message.",
    )
    # Model's rationale; not read by guardrails_agent.
    reason: str = Field(
        description="Explanation for the classification decision.",
    )

In [17]:

def guardrails_agent(state: GraphState) -> GraphState:
    """Classify the user's question as a greeting, in-scope, or out-of-scope.

    Stores the model's relevance verdict in ``state["is_question_relavant"]``. For a
    greeting, or for an out-of-scope question, it also fills ``state["final_answer"]``
    with a canned reply; check_relevance uses a non-empty final answer to stop the run.
    """
    question = state["user_query"]

    prompt = f"""You are a guardrails agent for an e-commerce SQL query system. Analyze if the user's question can be answered using the available database.

DATABASE SCOPE:
The system has access to an e-commerce database containing:
- Products: catalog, pricing, categories, brands, departments
- Users: customer demographics, locations, registration info
- Orders: transactions, status tracking, delivery timestamps
- Order Items: line-level details, revenue data
- Inventory Items: stock levels, warehouse tracking
- Distribution Centers: warehouse locations
- Events: user behavior, web analytics, session tracking

CLASSIFICATION RULES:

1. GREETING - Casual conversational starters:
   - "Hi", "Hello", "Hey there"
   - "Good morning/afternoon/evening"
   - "How are you doing?"
   
2. IN-SCOPE - Questions answerable with database:
   - Sales analytics: "What was total revenue in 2022?"
   - Product queries: "Which brand has highest sales?"
   - Customer analysis: "How many users from Texas?"
   - Inventory questions: "Show products out of stock"
   - Trend analysis: "Monthly order trends"
   - Behavioral insights: "What pages do users visit most?"
   
3. OUT-OF-SCOPE - Cannot be answered with this database:
   - Personal information: "What's my order history?" (no authentication context)
   - Future predictions: "What will sell next month?" (no ML capability)
   - External data: "Compare our prices to competitors"
   - General knowledge: "How does e-commerce work?"
   - Unrelated topics: "Tell me a joke", "Weather forecast"
   - Real-time data: "Current inventory right now" (data is historical)

User Question: "{question}"

Guidelines:
- If greeting: set is_greeting=true, is_question_relavant=false
- If ambiguous but potentially answerable: mark is_question_relavant=true
- Be permissive - favor is_question_relavant=true when uncertain"""

    verdict = llm.with_structured_output(GuardrailsResponse).invoke(prompt)
    state["is_question_relavant"] = verdict.is_question_relavant

    # Greetings win over relevance; both branches short-circuit the graph via final_answer.
    if verdict.is_greeting:
        state["final_answer"] = (
            "Hello! How can I assist you with e-commerce data today?"
        )
    elif not state["is_question_relavant"]:
        state["final_answer"] = (
            "I'm sorry, but your question is outside the scope of the e-commerce database I have access to. "
            "Please ask something related to products, users, orders, inventory, or sales analytics."
        )

    return state

#### SQL Generation Agent

In [18]:
# Structured-output schema for sql_generation_agent. The Field descriptions are presumably
# sent to the model as part of the schema, so they are prompt text, not only docs.
class SQLGenerationResponse(BaseModel):
    # The SQL text; sql_generation_agent strips markdown fences from it before storing.
    sql_query: str = Field(
        description="The generated SQL query based on the user's question.",
    )
    # Requested from the model but not read by sql_generation_agent.
    explanation: str = Field(
        description="Brief explanation of what the query does (max 30 words).",
    )


def sql_generation_agent(state: GraphState) -> GraphState:
    """Translate the user's question into a single SQLite query.

    Builds a prompt from SCHEMA_DEFINITION plus SQL-writing rules, asks the LLM for a
    structured ``SQLGenerationResponse`` and stores the cleaned SQL text in
    ``state["sql_query_generated"]``. The attempt counter ``curr_iteration`` (default 0)
    is incremented by one.
    """
    question = state["user_query"]
    attempt = state.get("curr_iteration", 0)

    prompt = f"""You are an expert SQL developer specializing in SQLite databases. Convert the user's natural language question into a valid, optimized SQLite query.

{SCHEMA_DEFINITION}

QUERY GENERATION RULES:

1. SCHEMA COMPLIANCE:
   - Use ONLY tables and columns defined in the schema above
   - Respect data types (TEXT, INTEGER, REAL, TIMESTAMP)
   - Follow foreign key relationships for JOINs

2. SQL BEST PRACTICES:
   - Use explicit JOIN syntax (INNER JOIN, LEFT JOIN) with ON clauses
   - Apply WHERE filters before aggregations
   - Use meaningful table aliases (p for products, u for users, o for orders)
   - Add ORDER BY for ranked results
   - Include LIMIT 10 unless user specifies a different number

3. AGGREGATIONS & ANALYTICS:
   - Use COUNT, SUM, AVG, MIN, MAX appropriately
   - GROUP BY required columns when using aggregates
   - Use HAVING for post-aggregation filters
   - Calculate revenue using order_items.sale_price (NOT products.retail_price)

4. DATE HANDLING:
   - Dates are stored as TEXT in ISO format (YYYY-MM-DD HH:MM:SS)
   - Use DATE() function for date comparisons
   - Use strftime() for date formatting and extraction

5. REVENUE CALCULATIONS:
   - Always use order_items.sale_price for actual revenue
   - Join with orders table to filter by status (exclude 'Cancelled', 'Returned')
   - Consider order status when calculating metrics

6. COMMON PATTERNS:
   - Top N queries: ORDER BY ... DESC LIMIT N
   - Trend analysis: GROUP BY strftime('%Y-%m', created_at)
   - Customer segmentation: JOIN users with orders/order_items
   - Product analytics: JOIN products with order_items

User Question: "{question}"

Generate a single, executable SQL query. No markdown formatting, no explanations in the query itself.

SQL Query:"""

    generated = llm.with_structured_output(SQLGenerationResponse).invoke(prompt)

    # Models sometimes wrap SQL in markdown fences despite the instructions; drop them.
    sql_text = generated.sql_query.strip()
    for fence in ("```sql", "```"):
        sql_text = sql_text.replace(fence, "")

    state["sql_query_generated"] = sql_text.strip()
    state["curr_iteration"] = attempt + 1
    return state

#### SQL Execution 

In [19]:
def _split_sql_statements(sql_text: str) -> list[str]:
    """Split a SQL script into statements without breaking on quoted or commented ';'.

    A plain ``str.split(";")`` also cuts inside string literals (``WHERE b = 'x;y'``)
    and comments. Here the semicolon-separated pieces are re-joined until
    ``sqlite3.complete_statement`` reports that the buffer forms a complete statement.
    Returned statements are stripped, carry no trailing ';', and empty ones are dropped.
    """
    statements = []
    buffer = None
    for piece in sql_text.split(";"):
        buffer = piece if buffer is None else f"{buffer};{piece}"
        # Re-append the ';' consumed by split only for the completeness check.
        if sqlite3.complete_statement(buffer + ";"):
            if buffer.strip():
                statements.append(buffer.strip())
            buffer = None
    # Trailing text with no terminating ';' (the common single-statement case).
    if buffer is not None and buffer.strip():
        statements.append(buffer.strip())
    return statements


def execute_sql(state: GraphState) -> GraphState:
    """
    Execute the generated SQL (one or more statements) against the SQLite database.

    Steps:
    1. Split ``state["sql_query_generated"]`` into statements; semicolons inside string
       literals or comments are not treated as separators.
    2. Open the database READ-ONLY and execute each statement in order.
    3. Render rows as DataFrame text. With several statements, each result is prefixed
       by its statement, and results are separated by a line of 80 '=' characters.
       A single statement without rows yields "No results found.".
    4. On success, store the text in ``state["result_for_sql_query"]`` and clear
       ``state["error_message"]``. On failure, store the error message and set the
       result to "".
    """
    from pathlib import Path

    sql_query = state["sql_query_generated"]
    conn = None

    try:
        # The SQL comes from an LLM, so open the file read-only. Without this, statements
        # such as DROP TABLE run outside sqlite3's implicit transaction and take effect
        # even though nothing is committed. mode=ro also fails loudly if the file is
        # missing, instead of silently creating an empty database.
        db_uri = Path(db_dir_path, db).resolve().as_uri() + "?mode=ro"
        conn = sqlite3.connect(db_uri, uri=True)
        cursor = conn.cursor()

        queries = _split_sql_statements(sql_query)
        multiple = len(queries) > 1
        all_results = []

        for idx, query in enumerate(queries):
            cursor.execute(query)
            results = cursor.fetchall()

            if results:
                column_names = [description[0] for description in cursor.description]
                df = pd.DataFrame(results, columns=column_names)
                if multiple:
                    all_results.append(
                        f"Query {idx + 1}:\n{query}\n\nResult:\n{df.to_string(index=False)}"
                    )
                else:
                    all_results.append(df.to_string(index=False))
            elif multiple:
                all_results.append(f"Query {idx + 1}:\n{query}\n\nResult: No rows returned")
            else:
                # decide_visualization_agent looks for this phrase to skip charting.
                all_results.append("No results found.")

        if all_results:
            # Separator goes BETWEEN results. The previous expression prepended a single
            # rule and glued it onto the first result's header row.
            separator = "\n\n" + "=" * 80 + "\n\n"
            state["result_for_sql_query"] = separator.join(all_results)
        else:
            state["result_for_sql_query"] = (
                "Query executed successfully but returned no results."
            )

        # Clear any previous error
        state["error_message"] = ""

    except sqlite3.Error as e:
        # SQLite errors (syntax, unknown column, read-only violation, ...) feed the retry loop.
        state["error_message"] = f"SQL Execution Error: {str(e)}"
        state["result_for_sql_query"] = ""

    except Exception as e:
        state["error_message"] = f"Unexpected Error: {str(e)}"
        state["result_for_sql_query"] = ""

    finally:
        if conn is not None:
            conn.close()

    return state

#### Error Agent

In [20]:
# Structured-output schema for error_correction_agent. The Field descriptions are presumably
# sent to the model as part of the schema, so they are prompt text, not only docs.
class ErrorCorrectionResponse(BaseModel):
    # Replacement SQL; error_correction_agent strips markdown fences before storing it.
    corrected_sql_query: str = Field(
        description="The fixed SQL query that should resolve the error.",
    )
    # Requested from the model but not read by error_correction_agent.
    error_analysis: str = Field(
        description="Brief explanation of what was wrong and how it was fixed (max 50 words).",
    )


def error_correction_agent(state: GraphState) -> GraphState:
    """
    Ask the LLM to repair a failed SQL query, bounded by a retry cap.

    If ``curr_iteration`` already exceeds the cap of 3, ``state["final_answer"]`` is set to
    an apology that quotes the last error, and nothing else changes. Otherwise the failed
    SQL, its error message and SCHEMA_DEFINITION are sent to the model. The cleaned
    correction replaces ``sql_query_generated``, ``error_message`` is cleared and
    ``curr_iteration`` is incremented.

    NOTE(review): should_retry routes here only while curr_iteration <= 3, so the apology
    branch looks unreachable through that router. Confirm how the graph produces a final
    answer once retries are exhausted.
    """
    error_message = state["error_message"]
    failed_sql = state["sql_query_generated"]
    question = state["user_query"]
    attempt = state.get("curr_iteration", 0)

    MAX_RETRIES = 3  # hard cap on correction rounds (kept in step with should_retry)

    if attempt > MAX_RETRIES:
        state["final_answer"] = (
            f"I apologize, but I'm unable to generate a correct SQL query for your question after {MAX_RETRIES} attempts. "
            f"The error encountered was: {error_message}\n\n"
            "Please try rephrasing your question or contact support for assistance."
        )
        return state

    prompt = f"""You are an expert SQL debugger. A SQL query has failed and you need to fix it.

{SCHEMA_DEFINITION}

ORIGINAL USER QUESTION: "{question}"

FAILED SQL QUERY:
{failed_sql}

ERROR MESSAGE:
{error_message}

COMMON SQL ERRORS AND FIXES:

1. COLUMN NOT FOUND:
   - Check spelling of column names against schema
   - Ensure table aliases match the columns being referenced
   - Use table.column notation for ambiguous columns

2. TABLE NOT FOUND:
   - Verify table name spelling matches schema exactly
   - Check for typos (e.g., 'order_item' vs 'order_items')

3. SYNTAX ERRORS:
   - Missing commas between column names
   - Unmatched parentheses in subqueries
   - Missing ON clause in JOIN statements
   - Incorrect GROUP BY usage (all non-aggregated columns must be in GROUP BY)

4. AGGREGATION ERRORS:
   - Ensure all non-aggregated columns appear in GROUP BY
   - Use HAVING for filtering aggregated results, WHERE for row-level filters
   - Don't mix aggregated and non-aggregated columns incorrectly

5. JOIN ERRORS:
   - Ensure foreign key relationships are correct
   - Use proper join types (INNER vs LEFT JOIN)
   - Include ON clause with valid join conditions

DEBUGGING STEPS:
1. Identify the exact error from the error message
2. Locate the problematic part of the query
3. Reference the schema to find correct column/table names
4. Fix the issue while preserving the original query intent
5. Ensure the corrected query still answers the user's question

Generate a corrected SQL query. No markdown formatting, no explanations in the query itself.

Corrected SQL Query:"""

    correction = llm.with_structured_output(ErrorCorrectionResponse).invoke(prompt)

    # Strip any markdown fences the model added despite the instructions.
    fixed_sql = correction.corrected_sql_query.strip()
    for fence in ("```sql", "```"):
        fixed_sql = fixed_sql.replace(fence, "")

    # Clearing the error lets execute_sql try the corrected query afresh.
    state.update(
        sql_query_generated=fixed_sql.strip(),
        error_message="",
        curr_iteration=attempt + 1,
    )
    return state

#### Analysis Agent

In [21]:
# Analysis Agent
# Structured-output schema for analysis_agent. The Field descriptions are presumably sent to
# the model as part of the schema, so they are prompt text, not only docs.
class AnalysisResponse(BaseModel):
    # Main narrative; becomes the first part of state["final_answer"].
    natural_language_answer: str = Field(
        description="Clear, concise natural language explanation of the query results.",
    )
    # Appended to the answer as a numbered "Key Insights" list when non-empty.
    key_insights: list[str] = Field(
        description="List of 2-3 key takeaways or insights from the data.",
    )
    # Copied into state["needs_plotly_figure"].
    needs_visualization: bool = Field(
        description="Whether the data would benefit from a chart/graph visualization.",
    )


def analysis_agent(state: GraphState) -> GraphState:
    """
    Turn the SQL result text into a user-facing answer.

    Sends the question, the executed SQL and its results to the LLM. Then composes
    ``state["final_answer"]`` from the returned narrative, followed by a numbered
    "**Key Insights:**" list when the model supplied insights. The model's visualization
    hint is stored in ``state["needs_plotly_figure"]``.
    """
    question = state["user_query"]
    executed_sql = state["sql_query_generated"]
    result_text = state["result_for_sql_query"]

    prompt = f"""You are a data analyst expert who explains database query results in clear, natural language.

ORIGINAL USER QUESTION: "{question}"

SQL QUERY EXECUTED:
{executed_sql}

QUERY RESULTS:
{result_text}

ANALYSIS GUIDELINES:

1. ANSWER FORMAT:
   - Start with a direct answer to the user's question
   - Use clear, conversational language (avoid technical jargon)
   - Present numbers with proper formatting (e.g., "$1,234.56" for money, "1,234" for counts)
   
2. DATA PRESENTATION:
   - For single values: state them clearly (e.g., "The total revenue was $45,678")
   - For lists/rankings: use bullet points or numbered lists
   - For comparisons: highlight differences explicitly
   - For trends: describe the pattern observed

3. CONTEXT & INSIGHTS:
   - Explain what the numbers mean in business terms
   - Identify notable patterns or outliers
   - Provide 2-3 key takeaways from the data
   
4. MULTI-PART QUESTIONS:
   - Address each part of the question separately
   - Use clear section headers if needed
   - Maintain logical flow in the answer

5. VISUALIZATION CONSIDERATION:
   - Determine if a chart would help visualize the data
   - Consider visualization for: trends over time, comparisons, distributions, rankings

Generate a comprehensive, user-friendly answer based on the query results."""

    analysis = llm.with_structured_output(AnalysisResponse).invoke(prompt)

    sections = [analysis.natural_language_answer]
    if analysis.key_insights:
        sections.append("\n\n**Key Insights:**")
        sections.extend(
            f"{number}. {insight}"
            for number, insight in enumerate(analysis.key_insights, start=1)
        )

    state["final_answer"] = "\n".join(sections)
    state["needs_plotly_figure"] = analysis.needs_visualization
    return state


#### Decide Visualization Agent

In [22]:

# Structured-output schema for decide_visualization_agent. The Field descriptions are
# presumably sent to the model as part of the schema, so they are prompt text, not only docs.
class VisualizationDecisionResponse(BaseModel):
    # Copied into state["needs_plotly_figure"].
    needs_visualization: bool = Field(
        description="Whether the data would benefit from visualization.",
    )
    # Copied into state["type_of_plotly_figure"]; should_visualize skips the value 'none'.
    visualization_type: str = Field(
        description="Type of chart: 'bar', 'line', 'pie', 'scatter', or 'none'.",
    )
    # Requested from the model but not read by decide_visualization_agent.
    reasoning: str = Field(
        description="Brief explanation for the decision (max 30 words).",
    )


def decide_visualization_agent(state: GraphState) -> GraphState:
    """
    Decide whether, and as which chart type, the SQL results should be plotted.

    The LLM is skipped when there is nothing to plot: an empty result, the
    "No results found" placeholder, or a pending error message. In that case it records
    ``needs_plotly_figure=False`` and ``type_of_plotly_figure="none"``. Otherwise the
    question and the first 500 characters of the result text go to the model, and its
    verdict and chart type are stored in the state.
    """
    question = state["user_query"]
    result_text = state["result_for_sql_query"]

    nothing_to_plot = (
        not result_text
        or "No results found" in result_text
        or state.get("error_message")
    )
    if nothing_to_plot:
        state["needs_plotly_figure"] = False
        state["type_of_plotly_figure"] = "none"
        return state

    prompt = f"""You are a data visualization expert. Analyze whether a chart would enhance understanding of this data.

USER QUESTION: "{question}"

QUERY RESULTS (first 500 chars):
{result_text[:500]}

VISUALIZATION DECISION RULES:

1. BAR CHART - Use for:
   - Comparing categories (top products, sales by region)
   - Ranking items (top 10 customers)
   - Discrete comparisons

2. LINE CHART - Use for:
   - Trends over time (monthly revenue, daily orders)
   - Time series data
   - Sequential patterns

3. PIE CHART - Use for:
   - Proportions/percentages (market share, category distribution)
   - Part-to-whole relationships
   - Maximum 5-7 categories

4. SCATTER PLOT - Use for:
   - Correlations between two variables
   - Distribution patterns
   - Outlier detection

5. NO VISUALIZATION - When:
   - Single value answers ("total: 42")
   - Simple yes/no responses
   - Text-heavy results
   - Already clear from numbers alone

Determine if visualization would add value and select the best chart type."""

    decision = llm.with_structured_output(VisualizationDecisionResponse).invoke(prompt)
    state["needs_plotly_figure"] = decision.needs_visualization
    state["type_of_plotly_figure"] = decision.visualization_type
    return state

#### Visualization Agent

In [23]:
# Structured-output schema for visualization_agent. The Field descriptions are presumably
# sent to the model as part of the schema, so they are prompt text, not only docs.
class PlotlyCodeResponse(BaseModel):
    # Python source that visualization_agent exec()s to obtain a `fig` object.
    plotly_code: str = Field(
        description="Python code to generate a Plotly visualization.",
    )
    # Requested from the model but not read by visualization_agent.
    chart_title: str = Field(
        description="Title for the chart.",
    )


def visualization_agent(state: GraphState) -> GraphState:
    """
    Generate a Plotly figure for the SQL results and store it as JSON.

    Steps:
    1. Ask the LLM for Plotly code that rebuilds a DataFrame ``df`` from the textual
       query results and draws a ``state["type_of_plotly_figure"]`` chart into ``fig``.
    2. Strip markdown fences and exec the code with ``pd``, ``json``, ``go`` and ``px``
       pre-bound.
    3. Store ``fig.to_json()`` in ``state["plotly_figure_json_string"]``.

    Any failure is printed and degrades gracefully: an LLM error, missing Plotly, bad
    generated code, or no ``fig``. The JSON string is emptied and ``needs_plotly_figure``
    is set to False.
    """
    user_query = state["user_query"]
    query_result = state["result_for_sql_query"]
    chart_type = state["type_of_plotly_figure"]

    try:
        # The results only exist as the text rendered by execute_sql, and no DataFrame is
        # passed into the exec scope. The prompt therefore tells the model to build `df`
        # itself; it previously claimed `df` was predefined, which led to NameErrors.
        prompt = f"""Generate Python code using Plotly to create a {chart_type} chart for this data.

USER QUESTION: "{user_query}"

QUERY RESULTS:
{query_result}

REQUIREMENTS:
1. Use plotly.graph_objects (as 'go') or plotly.express (as 'px')
2. Build the data yourself as a pandas DataFrame named 'df' from the QUERY RESULTS above (no 'df' variable is predefined)
3. Create a {chart_type} chart
4. Add proper title, labels, and formatting
5. Variable must be named 'fig'
6. NO import statements ('pd', 'go' and 'px' are already available)
7. NO fig.show() or display commands
8. Limit to top 20 data points if there are many rows
9. Use appropriate colors and styling
10. Add hover information

EXAMPLE STRUCTURE:
```python
# Parse data from results
df = pd.DataFrame({{
    'column1': [values],
    'column2': [values]
}})

# Create figure
fig = go.Figure(...)
# or
fig = px.{chart_type}(df, ...)

# Update layout
fig.update_layout(
    title='Chart Title',
    xaxis_title='X Label',
    yaxis_title='Y Label'
)
```

Generate the complete Plotly code:"""

        response = llm.with_structured_output(PlotlyCodeResponse).invoke(prompt)

        plotly_code = response.plotly_code.strip()
        plotly_code = plotly_code.replace("```python", "").replace("```", "").strip()

        # Imported lazily so a missing Plotly install is handled by the except below.
        import plotly.express as px
        import plotly.graph_objects as go

        exec_globals = {"pd": pd, "json": json, "go": go, "px": px}

        # NOTE(security): this executes model-generated code with full interpreter
        # privileges (file system, network, ...). Only run it in a trusted or sandboxed
        # environment, because a prompt-injected question can steer the generated code.
        exec(plotly_code, exec_globals)

        fig = exec_globals.get("fig")
        if fig is None:
            raise ValueError("Generated code did not create a 'fig' variable")

        state["plotly_figure_json_string"] = fig.to_json()

    except Exception as e:
        # Best effort: a failed chart must not block the text answer.
        print(f"Visualization generation error: {e}")
        state["plotly_figure_json_string"] = ""
        state["needs_plotly_figure"] = False

    return state


### Helpers

In [24]:
def should_retry(state: GraphState) -> str:
    """Route the graph after SQL execution.

    Returns:
        "success" when ``error_message`` is empty or missing;
        "retry" while ``curr_iteration`` (default 0) is at most 3;
        "end" once that retry budget is exhausted.
    """
    if not state.get("error_message"):
        return "success"
    # NOTE(review): `<= 3` permits retries at iterations 0..3; confirm this matches
    # the intended budget where curr_iteration is incremented (not visible here).
    return "retry" if state.get("curr_iteration", 0) <= 3 else "end"


def should_visualize(state: GraphState) -> str:
    """Route to the visualization node only when a concrete chart was requested.

    Returns "visualize" if ``needs_plotly_figure`` is truthy and
    ``type_of_plotly_figure`` is anything other than "none"; otherwise "skip".
    """
    wants_chart = state.get("needs_plotly_figure", False)
    chart_kind = state.get("type_of_plotly_figure")
    if not wants_chart or chart_kind == "none":
        return "skip"
    return "visualize"


def check_relevance(state: GraphState) -> str:
    """Route after guardrails: "relevant" to continue to SQL generation, else "end".

    A non-empty ``final_answer`` means guardrails already replied (greeting or
    out-of-scope), so the graph ends even if the relevance flag is set.
    """
    already_answered = bool(state.get("final_answer"))
    # Key spelling "is_question_relavant" matches the state schema used elsewhere.
    if not already_answered and state.get("is_question_relavant", False):
        return "relevant"
    return "end"

### Graph

In [25]:
from langgraph.graph import StateGraph, END

def create_text2sql_graph():
    """Build and compile the Text-to-SQL LangGraph workflow with visualization.

    Flow: guardrails -> SQL generation -> SQL execution, with an
    error-correction loop back into execution, then analysis -> visualization
    decision -> optional chart generation.

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """
    builder = StateGraph(GraphState)

    # Register every agent node under the same name as its function, in a fixed order.
    agent_nodes = (
        ("guardrails_agent", guardrails_agent),
        ("sql_generation_agent", sql_generation_agent),
        ("execute_sql", execute_sql),
        ("error_correction_agent", error_correction_agent),
        ("analysis_agent", analysis_agent),
        ("decide_visualization_agent", decide_visualization_agent),
        ("visualization_agent", visualization_agent),
    )
    for node_name, node_fn in agent_nodes:
        builder.add_node(node_name, node_fn)

    builder.set_entry_point("guardrails_agent")

    # Guardrails -> SQL generation when the question is in scope, otherwise stop.
    builder.add_conditional_edges(
        "guardrails_agent",
        check_relevance,
        {"relevant": "sql_generation_agent", "end": END},
    )

    # SQL generation -> SQL execution.
    builder.add_edge("sql_generation_agent", "execute_sql")

    # SQL execution -> analysis on success, error correction on failure.
    # When retries are exhausted ("end") the flow still goes to analysis.
    builder.add_conditional_edges(
        "execute_sql",
        should_retry,
        {
            "success": "analysis_agent",
            "retry": "error_correction_agent",
            "end": "analysis_agent",
        },
    )

    # Error correction loops back into execution.
    builder.add_edge("error_correction_agent", "execute_sql")

    # Analysis -> visualization decision.
    builder.add_edge("analysis_agent", "decide_visualization_agent")

    # Visualization decision -> chart generation, or stop.
    builder.add_conditional_edges(
        "decide_visualization_agent",
        should_visualize,
        {"visualize": "visualization_agent", "skip": END},
    )

    # Chart generation is terminal.
    builder.add_edge("visualization_agent", END)

    return builder.compile()

In [26]:
# Create the compiled graph once; process_question, process_question_stream and
# generate_graph_visualization all read this module-level name.
text2sql_graph = create_text2sql_graph()


# Render the LangGraph workflow to a PNG file (useful for docs and debugging).
def generate_graph_visualization(
    output_path: str = "text2sql_workflow.png", graph=None
) -> "str | None":
    """
    Generate a PNG visualization of a LangGraph workflow.

    Args:
        output_path: Path where the PNG file will be saved (default: "text2sql_workflow.png")
        graph: Compiled graph to render. Defaults to the module-level ``text2sql_graph``.

    Returns:
        str | None: Path to the generated PNG file, or None if rendering or
        saving failed (the error is printed, not raised).
    """
    try:
        # Resolved inside the try so a missing module-level graph is reported, not raised.
        compiled_graph = text2sql_graph if graph is None else graph

        # draw_mermaid_png() returns the PNG bytes.
        graph_image = compiled_graph.get_graph().draw_mermaid_png()

        with open(output_path, "wb") as f:
            f.write(graph_image)

        print(f"Graph visualization saved to: {output_path}")
        return output_path

    except Exception as e:
        print(f"Error generating graph visualization: {e}")
        # draw_mermaid_png() does not use pygraphviz/grandalf; by default it
        # renders through the Mermaid.ink web API.
        print("draw_mermaid_png() renders via the Mermaid.ink web API by default,")
        print("so it needs network access. For offline rendering, install pyppeteer")
        print("and call draw_mermaid_png(draw_method=MermaidDrawMethod.PYPPETEER).")
        return None

In [27]:
# Writes the workflow diagram to the default output_path ("text2sql_workflow.png").
generate_graph_visualization()

Graph visualization saved to: text2sql_workflow.png


'text2sql_workflow.png'

### Streaming Function for LangGraph Execution

In [28]:

async def process_question_stream(user_query: str, graph=None):
    """
    Process a natural language question and stream node execution events.

    This async generator streams events from the LangGraph workflow execution,
    allowing for real-time updates in the Chainlit UI. The workflow runs exactly
    once: the final state is taken from the event stream itself (the root run's
    ``on_chain_end`` output) instead of invoking the graph a second time.

    Args:
        user_query: The natural language question from the user
        graph: Compiled graph to run. Defaults to the module-level ``text2sql_graph``.

    Yields:
        dict: Event dictionaries with structure:
            - type: 'node_start', 'node_end', 'error', or 'final'
            - node: Name of the agent node (node_start / node_end)
            - output: State returned by the node (node_end)
            - timestamp: Event time; falls back to time.time() when the
              stream event carries none (node_start / node_end)
            - result: Final graph state (final)
            - error / error_type: Exception text and class name (error)

    Event Types:
        - node_start: When an agent node begins execution
        - node_end: When an agent node completes execution
        - final: When the entire workflow completes
        - error: When an exception occurs
    """
    import time

    # Only events from these graph nodes are forwarded to the caller.
    agent_nodes = frozenset(
        {
            "guardrails_agent",
            "sql_generation_agent",
            "execute_sql",
            "error_correction_agent",
            "analysis_agent",
            "decide_visualization_agent",
            "visualization_agent",
        }
    )

    # Initialize the graph state
    initial_state = {
        "user_query": user_query,
        "is_question_relavant": False,
        "sql_query_generated": "",
        "result_for_sql_query": "",
        "final_answer": "",
        "error_message": "",
        "curr_iteration": 0,
        "needs_plotly_figure": False,
        "type_of_plotly_figure": "none",
        "plotly_figure_json_string": "",
        "messages": [],  # Required by MessagesState
    }

    try:
        compiled_graph = text2sql_graph if graph is None else graph

        # Final state reported by the root run, if the stream provides it.
        root_output = None
        # Fallback: initial state updated with each node's returned state.
        merged_state = dict(initial_state)

        async for event in compiled_graph.astream_events(
            initial_state,
            config={"recursion_limit": 50},
            version="v2",  # v2 events carry parent_ids; the root run has an empty list
        ):
            event_type = event.get("event")
            event_name = event.get("name", "")

            if event_type == "on_chain_start":
                if event_name in agent_nodes:
                    yield {
                        "type": "node_start",
                        "node": event_name,
                        "timestamp": event.get("timestamp") or time.time(),
                    }

            elif event_type == "on_chain_end":
                is_root_run = "parent_ids" in event and not event["parent_ids"]
                if is_root_run:
                    root_output = event.get("data", {}).get("output")
                elif event_name in agent_nodes:
                    output = event.get("data", {}).get("output", {})
                    if isinstance(output, dict):
                        merged_state.update(output)
                    yield {
                        "type": "node_end",
                        "node": event_name,
                        "output": output,
                        "timestamp": event.get("timestamp") or time.time(),
                    }

        final_state = root_output if isinstance(root_output, dict) else merged_state

        yield {
            "type": "final",
            "result": final_state,
        }

    except Exception as e:
        yield {
            "type": "error",
            "error": str(e),
            "error_type": type(e).__name__,
        }



In [29]:
# ============================================================================
# SYNCHRONOUS PROCESSING FUNCTION (For Notebook Testing)
# ============================================================================


def process_question(question: str) -> dict:
    """
    Run one natural-language question through the Text-to-SQL graph synchronously.

    Intended for notebook usage. Any exception raised while invoking the graph
    is caught and turned into a state-shaped dict, so callers can always read
    ``final_answer``.

    Args:
        question: Natural language question about the e-commerce data

    Returns:
        dict: The graph's final state on success. On failure, a fallback dict
        whose ``error_message`` holds the exception text and whose
        ``final_answer`` explains what went wrong.
    """
    starting_state = dict(
        user_query=question,
        is_question_relavant=False,
        sql_query_generated="",
        result_for_sql_query="",
        final_answer="",
        error_message="",
        curr_iteration=0,
        needs_plotly_figure=False,
        type_of_plotly_figure="none",
        plotly_figure_json_string="",
        messages=[],  # Required by MessagesState
    )
    run_config = {"recursion_limit": 50}

    try:
        return text2sql_graph.invoke(starting_state, config=run_config)
    except Exception as exc:
        reason = str(exc)
        return {
            "error_message": reason,
            "final_answer": f"An error occurred while processing your question: {reason}",
            "user_query": question,
            "sql_query_generated": "",
            "result_for_sql_query": "",
            "needs_plotly_figure": False,
            "type_of_plotly_figure": "none",
            "plotly_figure_json_string": "",
        }

In [30]:
# ============================================================================
# TEST FUNCTION WITH VISUALIZATION SUPPORT
# ============================================================================


def test_process_question(
    question: str,
    show_sql: bool = True,
    show_results: bool = False,
    max_result_chars: int = 1000,
):
    """
    Run one question through ``process_question`` and pretty-print the outcome.

    Prints the question, any error, the generated SQL, optionally the raw SQL
    results (truncated), the final answer, and whether a chart was produced.
    It also renders the Plotly figure when one was returned.

    Args:
        question: Natural language question to test
        show_sql: Whether to display the generated SQL query
        show_results: Whether to display raw SQL results
        max_result_chars: Raw results longer than this many characters are
            truncated for display (default 1000)

    Returns:
        dict: The state dict returned by ``process_question``.
    """
    divider = "-" * 80

    print("=" * 80)
    print(f"QUESTION: {question}")
    print("=" * 80)

    # Process the question
    result = process_question(question)

    # Display error if present
    if result.get("error_message"):
        print(f"\n❌ ERROR: {result['error_message']}\n")

    # Display SQL query
    if show_sql and result.get("sql_query_generated"):
        print("\n📝 GENERATED SQL QUERY:")
        print(divider)
        print(result["sql_query_generated"])
        print(divider)

    # Display raw results (optional)
    if show_results and result.get("result_for_sql_query"):
        print("\n📊 RAW QUERY RESULTS:")
        print(divider)
        # str() guards against non-string results; long output is truncated.
        raw_result = str(result["result_for_sql_query"])
        if len(raw_result) > max_result_chars:
            print(raw_result[:max_result_chars] + "\n... (truncated)")
        else:
            print(raw_result)
        print(divider)

    # Display final answer (an empty answer also falls back to the placeholder)
    print("\n💬 FINAL ANSWER:")
    print(divider)
    print(result.get("final_answer") or "No answer generated.")
    print(divider)

    # Display visualization info
    print("\n📈 Visualization: ", end="")
    if result.get("needs_plotly_figure"):
        # `or` also covers an explicit None, which would break .upper()
        chart_type = result.get("type_of_plotly_figure") or "unknown"
        print(f"Yes ({str(chart_type).upper()} chart)")
    else:
        print("No")

    # Show the graph if available
    if result.get("plotly_figure_json_string"):
        try:
            import json

            import plotly.graph_objects as go

            # Parse and display the figure
            fig = go.Figure(json.loads(result["plotly_figure_json_string"]))
            print("\n📊 Displaying interactive visualization...")
            fig.show()
        except Exception as e:
            print(f"\n⚠️  Could not display visualization: {e}")

    print("\n" + "=" * 80 + "\n")

    return result

In [31]:
# ============================================================================
# BATCH TEST RUNNER
# ============================================================================


def run_test_suite(queries: list[str], delay: float = 0.5):
    """
    Execute each query through ``test_process_question`` in order, printing a banner per test.

    Args:
        queries: Questions to run, in order
        delay: Seconds to sleep between consecutive tests (no sleep after the last one)

    Returns:
        list[dict]: One ``{"query": ..., "result": ...}`` entry per query, in input order.
    """
    import time

    total = len(queries)
    banner = "#" * 80
    collected = []

    for position, question in enumerate(queries, start=1):
        print("\n" + banner)
        print(f"TEST {position}/{total}")
        print(banner + "\n")

        outcome = test_process_question(question, show_sql=True, show_results=False)
        collected.append({"query": question, "result": outcome})

        # Pause between tests, but not after the final one.
        is_last = position == total
        if not is_last:
            time.sleep(delay)

    return collected

In [39]:
# ============================================================================
# PREDEFINED TEST CASES
# ============================================================================

# Basic test queries covering different scenarios
# Each list below is a list[str] of natural-language questions; any of them can be
# passed to run_test_suite(). The per-item comments state the EXPECTED behavior
# (scenario / chart type) and are not asserted anywhere in this notebook.
BASIC_TEST_QUERIES = [
    # Simple aggregation
    "What is the total revenue?",
    # Top N query
    "Show me the top 5 products by sales",
    # Greeting (guardrails test)
    "Hello!",
    # Out of scope
    "What's the weather like today?",
]

# Advanced test queries
ADVANCED_TEST_QUERIES = [
    # Trend analysis
    "What are the monthly order trends for 2023?",
    # Customer segmentation
    "How many users are from California?",
    # Product analytics
    "What are the top 3 product categories by revenue?",
    # Complex join
    "Which brands have the highest average order value?",
    # Time-based analysis
    "Show me daily sales for the last 30 days",
    # Multi-table join
    "What is the average delivery time by state?",
]

# Visualization test queries (should trigger chart generation)
VISUALIZATION_TEST_QUERIES = [
    # Should trigger bar chart
    "Show the top 10 selling products",
    # Should trigger line chart
    "What are the monthly revenue trends?",
    # Should trigger pie chart
    "What is the distribution of orders by status?",
    # Should not trigger visualization (single value)
    "What is the total number of orders?",
]

# Error handling test queries
ERROR_TEST_QUERIES = [
    # Intentionally complex (may cause errors)
    "Show me the correlation between user age and purchase frequency",
    # Ambiguous query
    "Show me everything about sales",
]

In [41]:
# ============================================================================
# EXAMPLE USAGE
# ============================================================================

# Test a single question: prints SQL, final answer and chart info, then returns the
# final state dict (displayed as the cell's output below).
test_process_question("What are the monthly revenue trends?")

QUESTION: What are the monthly revenue trends?

üìù GENERATED SQL QUERY:
--------------------------------------------------------------------------------
SELECT strftime('%Y-%m', created_at) AS sales_month, SUM(sale_price) AS total_revenue FROM ORDER_ITEMS WHERE status = 'Complete' GROUP BY sales_month ORDER BY sales_month;
--------------------------------------------------------------------------------

üí¨ FINAL ANSWER:
--------------------------------------------------------------------------------
Monthly revenue has shown a consistent and significant upward trend from January 2019 to January 2024. Starting at just over $500 in early 2019, it has grown dramatically to over $200,000 by January 2024.


**Key Insights:**
1. Monthly revenue has experienced substantial year-over-year growth, increasing from an average of a few thousand dollars in 2019 to over $100,000 per month by late 2023.
2. There is a noticeable pattern of revenue peaking towards the end of each year, particularly

{'messages': [],
 'is_question_relavant': True,
 'user_query': 'What are the monthly revenue trends?',
 'sql_query_generated': "SELECT strftime('%Y-%m', created_at) AS sales_month, SUM(sale_price) AS total_revenue FROM ORDER_ITEMS WHERE status = 'Complete' GROUP BY sales_month ORDER BY sales_month;",
 'final_answer': 'Monthly revenue has shown a consistent and significant upward trend from January 2019 to January 2024. Starting at just over $500 in early 2019, it has grown dramatically to over $200,000 by January 2024.\n\n\n**Key Insights:**\n1. Monthly revenue has experienced substantial year-over-year growth, increasing from an average of a few thousand dollars in 2019 to over $100,000 per month by late 2023.\n2. There is a noticeable pattern of revenue peaking towards the end of each year, particularly in the last quarter (October, November, December), followed by a strong start to the new year.\n3. The most significant growth appears to have occurred between 2022 and 2023, with mon