### Chat With Multiple CSV

In [156]:
# Take csv or excel sheet & upload data in sqlite3
# Take User questions
# Parse questions & get relevant tables and columns
# Get Unique noun
# Generate SQL query
# Validate the SQL query; if something is wrong, fix it
# Execute SQL
#     If there is no error & SQL Query is relevant
#         Choose visualization
#         format data for visualization
#     Else
#         Format result


In [110]:
# from langchain_community.llms import Ollama
# llm = Ollama(model="llama3.1:latest")
import os

from langchain_groq import ChatGroq

# Read the Groq API key from the GROQ_API_KEY environment variable so the
# secret is never stored in the notebook. When the variable is unset this
# falls back to the empty string that was previously hard-coded here.
llm = ChatGroq(groq_api_key=os.environ.get("GROQ_API_KEY", ""), model="gemma2-9b-it")

In [238]:
from langchain_core.output_parsers import StrOutputParser
output_parser = StrOutputParser()

In [177]:
## Insert data into SQLite
import os
import pandas as pd
from sqlalchemy import create_engine

# Function to read CSV/Excel and insert into SQLite database
def insert_data_to_sqlite(file_path, db_url='sqlite:///lumin.db'):
    """Load a CSV or Excel file into a SQLite table named after the file.

    The table name is the file's base name without its extension, and any
    existing table with that name is replaced.

    Args:
        file_path: Path to the source file. Files ending in ``.xls`` or
            ``.xlsx`` (case-insensitive) are read with ``pd.read_excel``;
            everything else is read with ``pd.read_csv``.
        db_url: SQLAlchemy database URL. Defaults to the notebook's
            ``lumin.db`` SQLite file.
    """
    # Split the base name into table name and extension
    file_name, extension = os.path.splitext(os.path.basename(file_path))

    # Pick the reader from the extension so Excel sheets work as advertised
    if extension.lower() in ('.xls', '.xlsx'):
        data = pd.read_excel(file_path)
    else:
        data = pd.read_csv(file_path)

    # Create a SQLite database (or connect if it already exists)
    engine = create_engine(db_url)
    try:
        # Insert data into the database with the table name as the file name
        data.to_sql(file_name, con=engine, if_exists='replace', index=False)
    finally:
        # Release pooled connections; one engine is created per call
        engine.dispose()

    # Report the database by its file name (the last path component of the URL)
    db_name = db_url.rsplit('/', 1)[-1]
    print(f"Data from {file_path} has been inserted into the '{file_name}' table in the '{db_name}' database.")



In [183]:
# File names of the e-commerce CSV files to load. Each entry is a bare file
# name; the (currently disabled) loop below resolves it against ./ecommerce/.
ecom_data = [
  "olist_customers_dataset.csv",
  "olist_geolocation_dataset.csv",
  "olist_order_items_dataset.csv",
  "olist_order_payments_dataset.csv",
  "olist_order_reviews_dataset.csv",
  "olist_orders_dataset.csv",
  "olist_products_dataset.csv",
  "olist_sellers_dataset.csv",
  "product_category_name_translation.csv"
]

# Loader loop, left commented out. Un-comment it to (re)create the tables;
# insert_data_to_sqlite replaces a table that already exists.
# for data in ecom_data:
#     file_path = ("./ecommerce/{data}".format(data=data))
#     # print(file_path)
#     insert_data_to_sqlite(file_path)

In [191]:
from sqlalchemy import inspect
from typing import List, Dict

# Assuming you have an SQLAlchemy engine created
engine = create_engine('sqlite:///lumin.db')  

def get_schemas(table_names: List[str]) -> List[Dict]:
    """Describe the columns of each requested table.

    Args:
        table_names: Names of the tables to inspect through the module-level
            ``engine``.

    Returns:
        One dict per table, in input order, shaped as
        ``{"table_name": str, "schema": [{"name", "type", "nullable"}, ...]}``.
        If anything raises while inspecting, the error is printed and an
        empty list is returned (no partial results).
    """
    try:
        db_inspector = inspect(engine)
        return [
            {
                "table_name": name,
                "schema": [
                    {
                        "name": col['name'],
                        "type": str(col['type']),  # stringify the type object
                        "nullable": col['nullable'],
                    }
                    for col in db_inspector.get_columns(name)
                ],
            }
            for name in table_names
        ]
    except Exception as e:
        print(f"An error occurred: {e}")
        return []  # Any failure discards everything collected so far

schema = get_schemas(['olist_products_dataset',"olist_orders_dataset","olist_customers_dataset","olist_order_items_dataset"])

print(schema)

[{'table_name': 'olist_products_dataset', 'schema': [{'name': 'product_id', 'type': 'TEXT', 'nullable': True}, {'name': 'product_category_name', 'type': 'TEXT', 'nullable': True}, {'name': 'product_name_lenght', 'type': 'FLOAT', 'nullable': True}, {'name': 'product_description_lenght', 'type': 'FLOAT', 'nullable': True}, {'name': 'product_photos_qty', 'type': 'FLOAT', 'nullable': True}, {'name': 'product_weight_g', 'type': 'FLOAT', 'nullable': True}, {'name': 'product_length_cm', 'type': 'FLOAT', 'nullable': True}, {'name': 'product_height_cm', 'type': 'FLOAT', 'nullable': True}, {'name': 'product_width_cm', 'type': 'FLOAT', 'nullable': True}]}, {'table_name': 'olist_orders_dataset', 'schema': [{'name': 'order_id', 'type': 'TEXT', 'nullable': True}, {'name': 'customer_id', 'type': 'TEXT', 'nullable': True}, {'name': 'order_status', 'type': 'TEXT', 'nullable': True}, {'name': 'order_purchase_timestamp', 'type': 'TEXT', 'nullable': True}, {'name': 'order_approved_at', 'type': 'TEXT', 'nu

In [287]:
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy import text  # Import text for raw SQL queries
from typing import List, Dict, Any

# Assuming you have an SQLAlchemy engine created
engine = create_engine('sqlite:///lumin.db')  # Example for SQLite

# Create a session factory bound to the engine
Session = sessionmaker(bind=engine)

def execute_query(query: str) -> List[Any]:
    """Run a raw SQL statement and return the rows it produces.

    Args:
        query: SQL text, wrapped in ``text()`` before execution.

    Returns:
        A list of result rows when the statement returns rows. Statements
        that return no rows are committed and yield ``[]``. Any exception is
        printed and also yields ``[]``.
    """
    try:
        with Session() as db_session:
            outcome = db_session.execute(text(query))

            # Statements without a result set only need committing
            if not outcome.returns_rows:
                db_session.commit()
                return []

            # Materialise the rows before the session closes
            return list(outcome.fetchall())

    except Exception as e:
        print(f"An error occurred: {e}")
        return []  # Errors are reported on stdout, not raised


In [7]:
execute_query("SELECT * FROM results;")

[('1872-11-30', 'Scotland', 'England', 0, 0, 'Friendly', 'Glasgow', 'Scotland', 0),
 ('1873-03-08', 'England', 'Scotland', 4, 2, 'Friendly', 'London', 'England', 0),
 ('1874-03-07', 'Scotland', 'England', 2, 1, 'Friendly', 'Glasgow', 'Scotland', 0),
 ('1875-03-06', 'England', 'Scotland', 2, 2, 'Friendly', 'London', 'England', 0),
 ('1876-03-04', 'Scotland', 'England', 3, 0, 'Friendly', 'Glasgow', 'Scotland', 0),
 ('1876-03-25', 'Scotland', 'Wales', 4, 0, 'Friendly', 'Glasgow', 'Scotland', 0),
 ('1877-03-03', 'England', 'Scotland', 1, 3, 'Friendly', 'London', 'England', 0),
 ('1877-03-05', 'Wales', 'Scotland', 0, 2, 'Friendly', 'Wrexham', 'Wales', 0),
 ('1878-03-02', 'Scotland', 'England', 7, 2, 'Friendly', 'Glasgow', 'Scotland', 0),
 ('1878-03-23', 'Scotland', 'Wales', 9, 0, 'Friendly', 'Glasgow', 'Scotland', 0),
 ('1879-01-18', 'England', 'Wales', 2, 1, 'Friendly', 'London', 'England', 0),
 ('1879-04-05', 'England', 'Scotland', 5, 4, 'Friendly', 'London', 'England', 0),
 ('1879-04-07'

In [159]:
import json

def simple_json_extraction(content):
    """Extract the first decodable JSON object in ``content`` and pretty-print it.

    Each ``{`` in the text is tried in order as the start of a JSON object,
    and the first one that decodes wins. Prose or code fences before or after
    the JSON are tolerated, even when that surrounding text contains braces.

    Args:
        content: Text that embeds a JSON object (e.g. an LLM reply).

    Returns:
        The object re-serialised with ``json.dumps(..., indent=2)``.

    Raises:
        json.JSONDecodeError: If no JSON object can be decoded from the text.
    """
    decoder = json.JSONDecoder()

    start = content.find('{')
    while start != -1:
        try:
            # raw_decode stops at the end of the object, ignoring trailing text
            parsed_json, _ = decoder.raw_decode(content, start)
        except json.JSONDecodeError:
            # Not a JSON object here; try the next opening brace
            start = content.find('{', start + 1)
            continue
        return json.dumps(parsed_json, indent=2)

    raise json.JSONDecodeError("No JSON object found", content, 0)

In [160]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import JsonOutputParser

def get_parse_question(state: dict) -> dict:
    """Ask the LLM which tables and columns are relevant to a user question.

    Args:
        state: Dict with
            - ``question``: the user's natural-language question.
            - ``schema``: the database schema description interpolated into
              the prompt.

    Returns:
        ``{"parsed_question": <dict>}`` where the inner dict is the JSON the
        model was asked to produce (``is_relevant`` plus ``relevant_tables``).

    Raises:
        json.JSONDecodeError: Propagated from ``simple_json_extraction`` when
            the model reply contains no parseable JSON.
    """
    question = state['question']
    schema = state['schema']

    prompt = ChatPromptTemplate.from_messages([
        ("system", '''
You are an expert data analyst tasked with analyzing SQL databases. Your goal is to interpret user questions, understand the provided schema, and identify relevant tables and columns.

Instructions:
1. Analyze the user question and database schema to identify relevant tables and columns.
2. Set "is_relevant" to false if the question is not applicable to the database or lacks sufficient information for an answer.
3. Focus on columns with meaningful nouns (e.g., names, entities) and exclude non-noun columns (e.g., IDs, numerical data) unless specifically relevant to the question.
4. Return the response in the following JSON format:
{{
  "is_relevant": boolean,
  "relevant_tables": [
    {{
      "table_name": "string",
      "columns": ["string"],
      "noun_columns": ["string"]
    }}
  ]
}}

Key Guidelines:
- Always verify column names against the provided schema.
- Include only existing schema column names in the "columns" and "noun_columns" lists.
- "noun_columns" shouldn't include any numeric value verify the type from schema, type must be not Int, Bigint or any type of integer value .
- Do not add query-mentioned values or entities to "columns" or "noun_columns" unless they are actual column names in the schema.
- Ensure "noun_columns" contains only valid column names from the schema, matching their exact format.
- Include in "noun_columns" only noun-based columns relevant to the question (e.g., "artist_name" for "Who are the top-selling artists?").
- Exclude numerical or ID columns from "noun_columns" unless they represent meaningful entities.
- If a term in the query matches a likely column value rather than a column name (e.g., "Brazil" in "matches where Brazil scored"), do not include it in the lists.

Example:
Question: "What is the total number of matches where the Brazil team scored more than 2 goals?"
- Do not include "Brazil team" in columns or noun_columns as it's likely a value, not a column name.
- Include relevant columns like "team_name", "goals_scored" if they exist in the schema.

    '''),
        ("human", "===Database Schema:\n{schema}\n\n===User Question:\n{question}\n\nIdentify the relevant tables and columns based on the provided information:")
    ])

    json_parser = JsonOutputParser()

    # Build a list of role-tagged messages. prompt.format() would collapse the
    # system and human parts into one "System: ...\nHuman: ..." string, which
    # a chat model then receives as a single message.
    messages = prompt.format_messages(schema=schema, question=question)

    # Invoke the LLM with the system/human messages
    response = llm.invoke(messages)

    # The reply may wrap the JSON in prose or code fences; isolate it first
    extracted_json = simple_json_extraction(response.content)
    parsed_response = json_parser.parse(extracted_json)
    return {"parsed_question": parsed_response}


In [199]:
# parse_question = get_parse_question({"schema":schema,"question":"What is the total number of matches where the Brazil team scored more than 2 goals?"})
parse_question = get_parse_question({"schema":schema,"question":"Which product categories generate the most revenue?"})
print(parse_question)


{'parsed_question': {'is_relevant': True, 'relevant_tables': [{'table_name': 'olist_orders_dataset', 'columns': ['order_id', 'customer_id'], 'noun_columns': []}, {'table_name': 'olist_customers_dataset', 'columns': ['customer_id'], 'noun_columns': ['customer_city', 'customer_state']}]}}


In [200]:
parse_question

{'parsed_question': {'is_relevant': True,
  'relevant_tables': [{'table_name': 'olist_orders_dataset',
    'columns': ['order_id', 'customer_id'],
    'noun_columns': []},
   {'table_name': 'olist_customers_dataset',
    'columns': ['customer_id'],
    'noun_columns': ['customer_city', 'customer_state']}]}}

In [201]:
def get_unique_nouns(state: dict) -> dict:
    """Collect the distinct values of every noun column named in the parsed question.

    Args:
        state: Dict holding ``parsed_question``, which provides ``is_relevant``
            and, when relevant, ``relevant_tables`` entries with
            ``table_name`` and ``noun_columns``.

    Returns:
        ``{"unique_nouns": {table: {column: set_of_str_values}}}``. Every
        relevant table gets an entry even if it has no noun columns. Falsy
        values (``None``, empty string, 0) are skipped. An irrelevant
        question yields ``{"unique_nouns": {}}``.
    """
    parsed = state['parsed_question']

    # Nothing to look up for questions the database cannot answer
    if not parsed['is_relevant']:
        return {"unique_nouns": {}}

    nouns_by_table = {}

    for entry in parsed['relevant_tables']:
        table = entry['table_name']
        noun_cols = entry['noun_columns']

        # Each table gets its own column -> values mapping
        per_column = nouns_by_table[table] = {}

        for col in noun_cols:
            stmt = f"SELECT DISTINCT `{col}` FROM `{table}`"

            with Session() as db:
                rows = db.execute(text(stmt))

                # A statement without a result set is only committed
                if not rows.returns_rows:
                    db.commit()
                    continue

                # Keep truthy values only, normalised to strings
                per_column[col] = {str(r[0]) for r in rows if r[0]}

    return {"unique_nouns": nouns_by_table}


In [202]:
unique_nouns = get_unique_nouns(parse_question)

In [205]:
# unique_nouns

### Generate SQL Query

In [221]:
# Chat prompt used to turn a user question into a SQL query.
# - The system message holds the generation rules and seven worked examples.
# - The human message interpolates three variables: `schema`, `question` and
#   `relevant_table_column`.
# NOTE(review): the system text refers to a "unique nouns list", but this
# template has no placeholder for one — confirm whether that is intentional.
prompt = ChatPromptTemplate.from_messages([
    ("system", '''
    You are an AI assistant that generates SQL queries based on user questions, database schema, and unique nouns found in the relevant tables. Your goal is to generate valid SQL queries that can directly answer the user's question.

    ### Instructions:
    1. Parse the user question, identify relevant tables and columns from the schema, and generate an SQL query using the correct table and column names.
    2. Ensure the SQL query answers the question using only two or three columns in the result.
    3. If there isn't enough information to generate a query, return "NOT_ENOUGH_INFO".
    4. Always enclose table and column names in backticks (`) for SQL syntax consistency.
    5. Skip rows where any column is NULL, empty (""), or contains "N/A".
    6. Use the exact spellings of nouns from the unique nouns list, but only include nouns that match actual column names in the schema.

    Here are some examples:

    1. **What is the top selling product?**
       **Type**: Simple Aggregation  
       **Answer**: 
       ```sql
       SELECT `product_name`, SUM(`quantity`) AS `total_quantity`
       FROM `sales`
       WHERE `product_name` IS NOT NULL AND `quantity` IS NOT NULL 
       AND `product_name` != "" AND `quantity` != "" 
       AND `product_name` != "N/A" AND `quantity` != "N/A" 
       GROUP BY `product_name` 
       ORDER BY `total_quantity` DESC 
       LIMIT 1```
         
    2. **What is the total revenue for each product?**
       **Type**: Revenue Calculation
       **Answer**: 
       ```sql
         SELECT `product_name`, SUM(`quantity` * `price`) AS `total_revenue`
         FROM `sales`
         WHERE `product_name` IS NOT NULL AND `quantity` IS NOT NULL 
         AND `price` IS NOT NULL AND `product_name` != "" 
         AND `quantity` != "" AND `price` != "" 
         AND `product_name` != "N/A" AND `quantity` != "N/A" 
         AND `price` != "N/A"
         GROUP BY `product_name`
         ORDER BY `total_revenue` DESC
         ```

    3. **What is the market share of each product?** 
       **Type**: Market Share Calculation
       **Answer**:
       ```sql
         SELECT `product_name`, 
         SUM(`quantity`) * 100.0 / (SELECT SUM(`quantity`) FROM `sales`) AS `market_share`
         FROM `sales`
         WHERE `product_name` IS NOT NULL AND `quantity` IS NOT NULL 
         AND `product_name` != "" AND `quantity` != "" 
         AND `product_name` != "N/A" AND `quantity` != "N/A"
         GROUP BY `product_name`
         ORDER BY `market_share` DESC
         ```
    4. **Which customers purchased the top-selling products?** 
       **Type**: Join Query
       **Answer**:
        ```sql
         SELECT `customers`.`customer_name`, `sales`.`product_name`, `sales`.`total_quantity`
         FROM `customers`
         JOIN `sales` ON `customers`.`customer_id` = `sales`.`customer_id`
         WHERE `sales`.`total_quantity` = (
             SELECT MAX(`total_quantity`) FROM `sales`
         )
        ```

    5. **Plot the distribution of income over time.**
       **Type**: Distribution Plot
       **Answer**:
         ```sql
         SELECT `income`, COUNT(*) AS `count`
         FROM `users`
         WHERE `income` IS NOT NULL AND `income` != "" AND `income` != "N/A"
         GROUP BY `income`
        ```

    6. **What is the total sales between 2021 and 2023?**
       **Type**: Date Range Query
       **Answer**:
         ```sql
         SELECT SUM(`quantity` * `price`) AS `total_sales`
         FROM `sales`
         WHERE `sale_date` BETWEEN '2021-01-01' AND '2023-12-31'
         ```
    7. **Find the total sales for each region, including customer count**
       **Type**: Complex Aggregation
       **Answer**:
         ```sql
         SELECT `regions`.`region_name`, SUM(`sales`.`quantity` * `sales`.`price`) AS `total_sales`, COUNT(DISTINCT `customers`.`customer_id`) AS `customer_count`
         FROM `sales`
         JOIN `customers` ON `sales`.`customer_id` = `customers`.`customer_id`
         JOIN `regions` ON `customers`.`region_id` = `regions`.`region_id`
         GROUP BY `regions`.`region_name`
         ORDER BY `total_sales` DESC
         ```
         
    ### Format for Results:
    - For simple queries (without labels): `[[x, y]]`
    - For queries with labels: `[[label, x, y]]`

    Just return the SQL query string based on the schema, question, and unique nouns provided.
    '''), 
    ("human", '''===Database schema: {schema}

    ===User question: {question}

    ===Relevant tables and columns: {relevant_table_column}
      

    Generate SQL query string:''')
])

    # ===Unique nouns in relevant tables: {unique_nouns}
    # This is the unique_noun format dont get confused with it :
    #   {
    #     "unique_nouns": {
    #         "table_name": {
    #             "column_name": {"noun1", "noun2",...},
    #             ...
    #         },
    #           "table_name": {
    #             "column_name": {"noun1", "noun2",...},
    #             ...
    #         },
    #          ...
    #     }
    # }


In [290]:
question = "Which product categories generate the most revenue?"
formatted_prompt = prompt.format(schema=schema, question=question, relevant_table_column=parse_question, unique_nouns=unique_nouns)

In [291]:
formatted_prompt

'System: \n    You are an AI assistant that generates SQL queries based on user questions, database schema, and unique nouns found in the relevant tables. Your goal is to generate valid SQL queries that can directly answer the user\'s question.\n\n    ### Instructions:\n    1. Parse the user question, identify relevant tables and columns from the schema, and generate an SQL query using the correct table and column names.\n    2. Ensure the SQL query answers the question using only two or three columns in the result.\n    3. If there isn\'t enough information to generate a query, return "NOT_ENOUGH_INFO".\n    4. Always enclose table and column names in backticks (`) for SQL syntax consistency.\n    5. Skip rows where any column is NULL, empty (""), or contains "N/A".\n    6. Use the exact spellings of nouns from the unique nouns list, but only include nouns that match actual column names in the schema.\n\n    Here are some examples:\n\n    1. **What is the top selling product?**\n     

In [253]:
def generete_sql_query():
    """Generate a SQL query for the current question via the LLM chain.

    Reads the module-level ``prompt``, ``llm``, ``output_parser``, ``schema``,
    ``question`` and ``parse_question`` variables.

    Returns:
        ``{"sql_query": <sql text>}`` with any Markdown code fence removed, or
        ``{"sql_query": "NOT_RELEVANT"}`` when the model answers
        ``NOT_ENOUGH_INFO``.
    """
    import re  # local import so this cell does not depend on another cell's imports

    chain = prompt | llm | output_parser
    response = chain.invoke({"schema":schema, "question":question, "relevant_table_column":parse_question})

    # Prefer the contents of a complete ```sql ... ``` fence. The closing fence
    # must not be followed by another backtick, so a query ending in a quoted
    # identifier directly before the fence keeps its final backtick.
    fenced = re.search(
        r"```(?:sql\b)?\s*(.*?)\s*```(?!`)",
        response,
        flags=re.DOTALL | re.IGNORECASE,
    )
    if fenced:
        clean_sql_query = fenced.group(1).strip()
    else:
        # No complete fence: drop a dangling opening fence only. Identifier
        # backticks at either end of an unfenced query are left untouched.
        clean_sql_query = re.sub(
            r"^```(?:sql\b)?\s*", "", response.strip(), flags=re.IGNORECASE
        ).strip()

    # The sentinel may arrive bare, fenced, or with stray punctuation
    if "NOT_ENOUGH_INFO" in clean_sql_query:
        return {"sql_query": "NOT_RELEVANT"}
    return {"sql_query": clean_sql_query}

In [292]:
generated_sql_que = generete_sql_query()

### Validate SQL Query


In [283]:
def validate_and_fix_sql(state: dict) -> dict:
    """Validate the generated SQL query with the LLM and apply its correction.

    Reads the module-level ``schema`` and ``llm``.

    Args:
        state: Dict with ``sql_query``, the query to validate (or the
            ``"NOT_RELEVANT"`` sentinel).

    Returns:
        - ``{"sql_query": "NOT_RELEVANT", "sql_valid": False}`` for the sentinel.
        - ``{"sql_query": <original>, "sql_valid": True}`` when the model
          reports the query valid with no issues.
        - Otherwise ``{"sql_query", "sql_valid", "sql_issues"}``, where
          ``sql_query`` is the model's corrected query. The original query is
          kept when the correction is missing, empty, or the ``"None"``
          placeholder.
    """
    sql_query = state['sql_query']
    if sql_query == "NOT_RELEVANT":
        return {"sql_query": "NOT_RELEVANT", "sql_valid": False}

    # Backticks need no escaping inside a Python string; the former "\`"
    # sequences were invalid escapes and put literal backslashes in the prompt.
    prompt = ChatPromptTemplate.from_messages([
        ("system", '''
    You are an AI assistant that validates and fixes SQL queries. Your task is to:
    1. Check if the SQL query is valid.
    2. Ensure all table and column names are correctly spelled and exist in the schema. All table and column names should be enclosed in backticks, especially if they contain spaces or special characters.
    3. Ensure the SQL query follows proper syntax (e.g., `JOIN`, `WHERE`, and other clauses are used correctly).
    4. Take into account case sensitivity based on the schema.
    5. If there are any issues, fix them and provide the corrected SQL query.
    6. If no issues are found, return the original query.

    Respond in JSON format with the following structure. Only respond with the JSON:
    {{
        "valid": boolean,
        "issues": string or null,
        "corrected_query": string
    }}
    '''),
        ("human", '''===Database schema:
    {schema}

    ===Generated SQL query:
    {sql_query}

    Respond in JSON format with the following structure. Only respond with the JSON:
    {{
        "valid": boolean,
        "issues": string or null,
        "corrected_query": string
    }}

    For example:
    1. {{
        "valid": true,
        "issues": null,
        "corrected_query": "None"
    }}
                
    2. {{
        "valid": false,
        "issues": "Column USERS does not exist",
        "corrected_query": "SELECT * FROM `users` WHERE age > 25"
    }}

    3. {{
        "valid": false,
        "issues": "Column names and table names should be enclosed in backticks if they contain spaces or special characters",
        "corrected_query": "SELECT * FROM `gross income` WHERE `age` > 25"
    }}
                
    '''),
    ])

    output_parser = JsonOutputParser()

    chain = prompt | llm | output_parser

    result = chain.invoke({"schema":schema, "sql_query":sql_query})

    # The reply is LLM-generated, so treat every key as optional
    is_valid = result.get("valid", False)
    issues = result.get("issues")

    if is_valid and issues is None:
        return {"sql_query": sql_query, "sql_valid": True}

    # The prompt's own example uses "None" as a placeholder; never execute it
    corrected = result.get("corrected_query")
    if not isinstance(corrected, str) or corrected.strip() in ("", "None", "null"):
        corrected = sql_query

    return {
        "sql_query": corrected,
        "sql_valid": is_valid,
        "sql_issues": issues
    }


In [293]:
valid_sql_query = validate_and_fix_sql(generated_sql_que)

In [296]:
query_result = execute_query(valid_sql_query["sql_query"])

### Get answer in sentence

In [305]:
def format_results(state: dict) -> dict:
    """Summarise query results as a one-sentence answer.

    Uses the module-level ``llm`` and ``output_parser``.

    Args:
        state: Dict with ``question`` (the user's question) and ``results``
            (the query output, or the ``"NOT_RELEVANT"`` sentinel).

    Returns:
        ``{"answer": <text>}`` — a fixed apology for the sentinel, otherwise
        the chain's output.
    """
    user_question, query_rows = state['question'], state['results']

    # Irrelevant questions never reach the model
    if query_rows == "NOT_RELEVANT":
        return {"answer": "Sorry, I can only give answers relevant to the database."}

    system_text = '''
    You are an AI assistant that converts database query results into a clear, concise human-readable response. Your goal is to provide a brief conclusion to the user's question based on the query results. 
    Instructions:
    1. Respond in one sentence.
    2. Highlight the key result by enclosing it in double asterisks (**).
    3. Avoid using markdown or unnecessary formatting.

    '''
    human_text = "User question: {question}\n\nQuery results: {results}\n\nConclusion:"

    summary_prompt = ChatPromptTemplate.from_messages([
        ("system", system_text),
        ("human", human_text),
    ])

    answer = (summary_prompt | llm | output_parser).invoke(
        {"question":user_question, "results":query_rows}
    )
    return {"answer": answer}

In [306]:
format_results({
    "question":question, 
    "results":query_result
})

{'answer': 'The product category that generates the most revenue is **beleza_saude** . \n'}

### Format response for visualization

In [313]:
def choose_visualization(state: dict) -> dict:
    """Ask the LLM to recommend a chart type for the query results.

    Uses the module-level ``llm``.

    Args:
        state: Dict with ``question``, ``results`` (query output or the
            ``"NOT_RELEVANT"`` sentinel) and ``sql_query``.

    Returns:
        The parsed model reply, which the prompt asks to contain
        ``recommended_visualization`` and ``reason``. For the sentinel, a
        fixed "none" recommendation is returned under those same two keys and
        under the legacy ``visualization`` / ``visualization_reasoning`` keys,
        so both code paths can be read the same way.
    """
    question = state['question']
    results = state['results']
    sql_query = state['sql_query']

    if results == "NOT_RELEVANT":
        no_viz_reason = "No visualization needed for irrelevant questions."
        return {
            # Legacy keys, kept for existing callers
            "visualization": "none",
            "visualization_reasoning": no_viz_reason,
            # Same keys as the LLM path below
            "recommended_visualization": "none",
            "reason": no_viz_reason,
        }

    prompt = ChatPromptTemplate.from_messages([
        ("system", '''
    You are an AI assistant recommending the best data visualizations. Based on the user's question, SQL query, and query results, suggest the most suitable graph or chart type.

    ### Chart Types:
    - **Bar Graph**: For comparing categorical data or showing changes over time with more than two categories.
    - **Horizontal Bar Graph**: For comparing few categories or when there's a large disparity between them.
    - **Scatter Plot**: For showing relationships or distributions between two continuous numerical variables.
    - **Pie Chart**: For displaying proportions or percentages of a whole.
    - **Line Graph**: For showing trends over time, where both x and y axes are continuous.
    - **None**: If no visualization is appropriate.

    ### Consider These Questions:
    1. **Aggregations**: Summarize data (e.g., average revenue by month) — Line Graph.
    2. **Comparisons**: Compare metrics (e.g., sales of Product A vs. Product B) — Line or Bar Graph.
    3. **Distributions**: Show data distribution (e.g., age distribution) — Scatter Plot.
    4. **Trends Over Time**: Show changes over time (e.g., website visits) — Line Graph.
    5. **Proportions**: Show percentages (e.g., market share) — Pie Chart.
    6. **Correlations**: Show relationships (e.g., marketing spend vs. revenue) — Scatter Plot.

    ### Format:
    Respond only with a JSON object:
         {{
            "recommended_visualization": "bar | horizontal_bar | line | pie | scatter | none",
            "reason": "Brief explanation of your recommendation"
         }}
    '''),
        ("human", '''
    User question: {question}
    SQL query: {sql_query}
    Query results: {results}

    Recommend a visualization:
        '''),
    ])

    # The reply is parsed as JSON, hence the quoted keys in the format spec above
    output_parser = JsonOutputParser()

    chain = prompt | llm | output_parser

    response = chain.invoke({"question":question,"sql_query":sql_query,"results":results})

    return response


In [314]:
# Pass the SQL text itself: valid_sql_query is the dict returned by
# validate_and_fix_sql, and the prompt's "SQL query" slot expects the query.
choose_visualization({
    "question":question, 
    "results":query_result,
    "sql_query":valid_sql_query["sql_query"]
})

{'recommended_visualization': 'bar',
 'reason': 'The query results show product categories and their total revenue. A bar graph is suitable for comparing categorical data (product categories) using the revenue as the value.'}