In [1]:
import os
import google.generativeai as genai
from langchain.chains import create_sql_query_chain
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.agent_toolkits import create_sql_agent
from dotenv import load_dotenv
from langchain_community.utilities import SQLDatabase
from urllib.parse import quote_plus


In [2]:
# Load environment variables
load_dotenv(dotenv_path="google_apiKEY.env")
api_key = os.getenv("GOOGLE_API_KEY")

if not api_key:
    raise ValueError("API key not found")

genai.configure(api_key=api_key)

In [3]:
#Database connection
username = os.getenv("DB_USER")
password = os.getenv("DB_PASSWORD")
host = os.getenv("DB_HOST")
port = os.getenv("DB_PORT", "3306")
database = os.getenv("DB_NAME")

password_encoded = quote_plus(password)
mysql_uri = f"mysql+pymysql://{username}:{password_encoded}@{host}:{port}/{database}"
db = SQLDatabase.from_uri(mysql_uri)

In [4]:
print(db.get_table_names())
print(db.table_info)

['walmart_sales']

CREATE TABLE walmart_sales (
	`Invoice ID` TEXT, 
	`Branch` TEXT, 
	`City` TEXT, 
	`Customer type` TEXT, 
	`Gender` TEXT, 
	`Product line` TEXT, 
	`Unit price` DOUBLE, 
	`Quantity` INTEGER, 
	`Tax 5%%` DOUBLE, 
	`Total` DOUBLE, 
	`Date` TEXT, 
	`Time` TEXT, 
	`Payment` TEXT, 
	cogs DOUBLE, 
	`gross margin percentage` DOUBLE, 
	`gross income` DOUBLE, 
	`Rating` DOUBLE
)ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4

/*
3 rows from walmart_sales table:
Invoice ID	Branch	City	Customer type	Gender	Product line	Unit price	Quantity	Tax 5%	Total	Date	Time	Payment	cogs	gross margin percentage	gross income	Rating
750-67-8428	A	Yangon	Member	Female	Health and beauty	74.6900000000	7	26.1415000000	548.9715000000	2019-01-05	13:08:00	Ewallet	522.8300000000	4.7619047620	26.1415000000	9.1000000000
226-31-3081	C	Naypyitaw	Normal	Female	Electronic accessories	15.2800000000	5	3.8200000000	80.2200000000	2019-03-08	10:29:00	Cash	76.4000000000	4.7619047620	3.8200000000	9

  print(db.get_table_names())


In [5]:

import pprint

# 1. Initialize your LLM
llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",  # Your model
    google_api_key=api_key,
    temperature=0
)
print("LLM initialized successfully.")
query_chain = create_sql_query_chain(llm, db)

print("LLM and Query Chain are ready.")


LLM initialized successfully.
LLM and Query Chain are ready.


In [6]:
def get_sql_answer(question: str, chain, model, database):
    """
    Takes a question, generates SQL, executes it, and gets a final answer.
    Prints all steps in the format requested.
    """
    
    # Step 1: Generate the query
    sql_query = chain.invoke({"question": question})
    print(f"Original, bad query: {sql_query}")
    
    # Step 2: Clean the query string - remove both SQLQuery: prefix AND markdown code blocks
    cleaned_sql = sql_query.replace("SQLQuery:", "").strip()
    
    # Remove markdown code blocks if present
    if cleaned_sql.startswith("```sql"):
        cleaned_sql = cleaned_sql[6:]  # Remove ```sql
    elif cleaned_sql.startswith("```"):
        cleaned_sql = cleaned_sql[3:]  # Remove ```
    
    if cleaned_sql.endswith("```"):
        cleaned_sql = cleaned_sql[:-3]  # Remove trailing ```
    
    cleaned_sql = cleaned_sql.strip()
    print(f"Cleaned, good query: {cleaned_sql}")
    
    # Step 3: Run the cleaned query against the database
    result = database.run(cleaned_sql)
    print("\n--- Query Result ---")
    print(result)
    
    # Step 4: Use the LLM to get a natural language answer
    final_prompt = f"""
    Based on this SQL query: {cleaned_sql}
    And this result: {result}
    
    Please provide a clean, natural language answer to the original question: {question}
    """
    
    final_answer = model.invoke(final_prompt)
    print("\n--- Final Answer ---")
    print(final_answer.content)

In [7]:
q1 = "Find the top-performing Product line by total revenue (Total) in each City. Show city, product line, total revenue, and rank."
get_sql_answer(question=q1, chain=query_chain, model=llm, database=db)

Original, bad query: SQLQuery: WITH RankedSales AS (
  SELECT
    `City`,
    `Product line`,
    SUM(`Total`) AS `Total Revenue`,
    RANK() OVER (PARTITION BY `City` ORDER BY SUM(`Total`) DESC) AS `Rank`
  FROM walmart_sales
  GROUP BY
    `City`,
    `Product line`
)
SELECT
  `City`,
  `Product line`,
  `Total Revenue`,
  `Rank`
FROM RankedSales
WHERE
  `Rank` = 1
ORDER BY
  `City`
LIMIT 5;
Cleaned, good query: WITH RankedSales AS (
  SELECT
    `City`,
    `Product line`,
    SUM(`Total`) AS `Total Revenue`,
    RANK() OVER (PARTITION BY `City` ORDER BY SUM(`Total`) DESC) AS `Rank`
  FROM walmart_sales
  GROUP BY
    `City`,
    `Product line`
)
SELECT
  `City`,
  `Product line`,
  `Total Revenue`,
  `Rank`
FROM RankedSales
WHERE
  `Rank` = 1
ORDER BY
  `City`
LIMIT 5;

--- Query Result ---
[('Mandalay', 'Sports and travel', 19988.198999999997, 1), ('Naypyitaw', 'Food and beverages', 23766.854999999992, 1), ('Yangon', 'Home and lifestyle', 22417.195499999998, 1)]

--- Final Answer 

In [8]:
q1 = "For each Branch, determine the hour of the day with the highest total sales."
get_sql_answer(question=q1, chain=query_chain, model=llm, database=db)

Original, bad query: SQLQuery: SELECT
  `Branch`,
  EXTRACT(HOUR FROM STR_TO_DATE(`Time`, '%H:%i:%s')) AS `SalesHour`,
  SUM(`Total`) AS `TotalSales`
FROM
  walmart_sales
GROUP BY
  `Branch`,
  `SalesHour`
ORDER BY
  `Branch`,
  `TotalSales` DESC;
Cleaned, good query: SELECT
  `Branch`,
  EXTRACT(HOUR FROM STR_TO_DATE(`Time`, '%H:%i:%s')) AS `SalesHour`,
  SUM(`Total`) AS `TotalSales`
FROM
  walmart_sales
GROUP BY
  `Branch`,
  `SalesHour`
ORDER BY
  `Branch`,
  `TotalSales` DESC;

--- Query Result ---
[('A', 11, 11349.891), ('A', 15, 11273.703), ('A', 10, 11208.414000000002), ('A', 16, 10869.736499999999), ('A', 13, 10443.751500000002), ('A', 19, 10330.257000000001), ('A', 12, 9485.070000000002), ('A', 17, 9043.734), ('A', 14, 8852.4135), ('A', 18, 7447.02), ('A', 20, 5896.38), ('B', 19, 16262.4525), ('B', 14, 11694.564), ('B', 13, 11272.411500000002), ('B', 11, 10481.813999999998), ('B', 15, 10241.164499999999), ('B', 18, 9555.283500000003), ('B', 10, 8865.843), ('B', 12, 8475.410999

In [9]:
q1 = "Calculate the percentage contribution of each gender and customer type combination to the total revenue."
get_sql_answer(question=q1, chain=query_chain, model=llm, database=db)

Original, bad query: SQLQuery: SELECT
  `Gender`,
  `Customer type`,
  SUM(`Total`) AS `Revenue`,
  (SUM(`Total`) / (
    SELECT
      SUM(`Total`)
    FROM walmart_sales
  )) * 100 AS `Percentage Contribution`
FROM walmart_sales
GROUP BY
  `Gender`,
  `Customer type`
ORDER BY
  `Percentage Contribution` DESC
LIMIT 5;
Cleaned, good query: SELECT
  `Gender`,
  `Customer type`,
  SUM(`Total`) AS `Revenue`,
  (SUM(`Total`) / (
    SELECT
      SUM(`Total`)
    FROM walmart_sales
  )) * 100 AS `Percentage Contribution`
FROM walmart_sales
GROUP BY
  `Gender`,
  `Customer type`
ORDER BY
  `Percentage Contribution` DESC
LIMIT 5;

--- Query Result ---
[('Female', 'Member', 88146.94349999996, 27.292885033189577), ('Female', 'Normal', 79735.98149999998, 24.688603934270635), ('Male', 'Normal', 79007.3235, 24.46298999653366), ('Male', 'Member', 76076.50049999995, 23.555521036006073)]

--- Final Answer ---
Here is the percentage contribution of each gender and customer type combination to the total

In [10]:
q1 = "Find the payment method that has the highest average customer rating for each Product line, considering only methods with total sales above the product line’s average."
get_sql_answer(question=q1, chain=query_chain, model=llm, database=db)

Original, bad query: SQLQuery: WITH ProductLineAvgSales AS (
    SELECT
        `Product line`,
        AVG(`Total`) AS `avg_product_line_total_sales`
    FROM
        walmart_sales
    GROUP BY
        `Product line`
),
PaymentMethodStats AS (
    SELECT
        `Product line`,
        `Payment`,
        SUM(`Total`) AS `payment_method_total_sales`,
        AVG(`Rating`) AS `avg_payment_method_rating`
    FROM
        walmart_sales
    GROUP BY
        `Product line`,
        `Payment`
),
RankedPaymentMethods AS (
    SELECT
        pms.`Product line`,
        pms.`Payment`,
        pms.`avg_payment_method_rating`,
        ROW_NUMBER() OVER (PARTITION BY pms.`Product line` ORDER BY pms.`avg_payment_method_rating` DESC) as rn
    FROM
        PaymentMethodStats pms
    JOIN
        ProductLineAvgSales plas ON pms.`Product line` = plas.`Product line`
    WHERE
        pms.`payment_method_total_sales` > plas.`avg_product_line_total_sales`
)
SELECT
    `Product line`,
    `Payment`,
    `

In [11]:
q1 = "Compute monthly total sales and growth rate (%) from the previous month across all branches."
get_sql_answer(question=q1, chain=query_chain, model=llm, database=db)

Original, bad query: SQLQuery: WITH MonthlySales AS (
    SELECT
        DATE_FORMAT(STR_TO_DATE(`Date`, '%Y-%m-%d'), '%Y-%m') AS `sales_month`,
        SUM(`Total`) AS `monthly_total_sales`
    FROM
        walmart_sales
    GROUP BY
        `sales_month`
)
SELECT
    `sales_month`,
    `monthly_total_sales`,
    CASE
        WHEN LAG(`monthly_total_sales`, 1) OVER (ORDER BY `sales_month`) IS NULL THEN NULL
        ELSE ((`monthly_total_sales` - LAG(`monthly_total_sales`, 1) OVER (ORDER BY `sales_month`)) / LAG(`monthly_total_sales`, 1) OVER (ORDER BY `sales_month`)) * 100
    END AS `growth_rate_percentage`
FROM
    MonthlySales
ORDER BY
    `sales_month`;
Cleaned, good query: WITH MonthlySales AS (
    SELECT
        DATE_FORMAT(STR_TO_DATE(`Date`, '%Y-%m-%d'), '%Y-%m') AS `sales_month`,
        SUM(`Total`) AS `monthly_total_sales`
    FROM
        walmart_sales
    GROUP BY
        `sales_month`
)
SELECT
    `sales_month`,
    `monthly_total_sales`,
    CASE
        WHEN LAG(`mont

In [12]:
test_questions = [
    "Find the total revenue generated by each branch.",
    "Identify the top 3 product lines with the highest average rating.",
    "Find the city with the maximum total gross income.",
    "Find which payment method generates the highest average gross income.",
    "Determine the busiest sales hour based on the number of invoices.",
    "Find the average rating given by Members vs Normal customers.",
    "Calculate monthly total sales and the month-over-month growth percentage.",
    "Find the product line with the highest average quantity per transaction.",
    "Identify the top 5 invoices with the highest total amount.",
    "Compute total and average gross income for each product line per gender.",
    "Find which customer type (Member/Normal) spends above the overall average total per transaction.",
    "Compute daily total sales and rank the top 5 sales days.",
    "Find the top 2 product lines in each city by total revenue.",
    "Find which branch has the highest customer satisfaction (average rating).",
    "Calculate the total quantity sold by gender and branch.",
]
   


In [13]:
# Comprehensive few-shot examples covering all difficulty levels
# All values must be strings for vectorization
few_shots = [
    # ============= NORMAL QUERIES =============
    {
        'Question': "Find total revenue generated across all transactions.",
        'SQLQuery': "SELECT ROUND(SUM(Total), 2) AS total_revenue FROM walmart_sales;",
        'SQLResult': "Result of the SQL query",
        'Answer': "The total revenue across all transactions is calculated by summing the Total column."
    },
    {
        'Question': "Count the total number of sales transactions per branch.",
        'SQLQuery': "SELECT Branch, COUNT(`Invoice ID`) AS total_transactions FROM walmart_sales GROUP BY Branch;",
        'SQLResult': "Result of the SQL query",
        'Answer': "The number of transactions per branch is found by counting Invoice IDs grouped by Branch."
    },
    {
        'Question': "Calculate average customer rating per product line.",
        'SQLQuery': "SELECT `Product line`, ROUND(AVG(Rating), 2) AS avg_rating FROM walmart_sales GROUP BY `Product line`;",
        'SQLResult': "Result of the SQL query",
        'Answer': "The average rating per product line is calculated by averaging Rating grouped by Product line."
    },
    {
        'Question': "Find the number of transactions done using each payment method.",
        'SQLQuery': "SELECT Payment, COUNT(*) AS total_payments FROM walmart_sales GROUP BY Payment;",
        'SQLResult': "Result of the SQL query",
        'Answer': "The number of transactions per payment method is found by counting and grouping by Payment."
    },
    {
        'Question': "Retrieve total quantity sold for each product line.",
        'SQLQuery': "SELECT `Product line`, SUM(Quantity) AS total_quantity FROM walmart_sales GROUP BY `Product line`;",
        'SQLResult': "Result of the SQL query",
        'Answer': "The total quantity sold per product line is calculated by summing Quantity grouped by Product line."
    },
    
    # ============= MEDIUM QUERIES =============
    {
        'Question': "Find average revenue per transaction for each branch.",
        'SQLQuery': "SELECT Branch, ROUND(SUM(Total)/COUNT(`Invoice ID`), 2) AS avg_revenue_per_txn FROM walmart_sales GROUP BY Branch;",
        'SQLResult': "Result of the SQL query",
        'Answer': "The average revenue per transaction for each branch is calculated by dividing total sales by transaction count."
    },
    {
        'Question': "Determine which product line has the highest total sales in each city.",
        'SQLQuery': "SELECT City, `Product line`, SUM(Total) AS total_sales FROM walmart_sales GROUP BY City, `Product line` ORDER BY City, total_sales DESC;",
        'SQLResult': "Result of the SQL query",
        'Answer': "Product lines with highest sales per city are found by grouping by City and Product line, then ordering by total sales."
    },
    {
        'Question': "Calculate monthly total sales trend.",
        'SQLQuery': "SELECT DATE_FORMAT(STR_TO_DATE(Date, '%Y-%m-%d'), '%Y-%m') AS month, SUM(Total) AS total_sales FROM walmart_sales GROUP BY month ORDER BY month;",
        'SQLResult': "Result of the SQL query",
        'Answer': "Monthly sales trend is calculated by extracting month from Date and summing Total for each month."
    },
    {
        'Question': "Find top 3 most popular product lines by number of transactions.",
        'SQLQuery': "SELECT `Product line`, COUNT(*) AS transactions FROM walmart_sales GROUP BY `Product line` ORDER BY transactions DESC LIMIT 3;",
        'SQLResult': "Result of the SQL query",
        'Answer': "The top 3 most popular product lines are found by counting transactions and limiting to top 3."
    },
    {
        'Question': "Compare average gross income of male vs. female customers.",
        'SQLQuery': "SELECT Gender, ROUND(AVG(`gross income`), 2) AS avg_gross_income FROM walmart_sales GROUP BY Gender;",
        'SQLResult': "Result of the SQL query",
        'Answer': "Average gross income by gender is calculated by grouping by Gender and averaging gross income."
    },
    
    # ============= COMPLEX QUERIES =============
    {
        'Question': "Find top-performing product line per city using ranking.",
        'SQLQuery': "SELECT City, `Product line`, SUM(Total) AS total_revenue, RANK() OVER (PARTITION BY City ORDER BY SUM(Total) DESC) AS rank_in_city FROM walmart_sales GROUP BY City, `Product line`;",
        'SQLResult': "Result of the SQL query",
        'Answer': "Top-performing product lines per city are identified using window functions with RANK to partition by City."
    },
    {
        'Question': "Identify the hour of the day with the highest total sales for each branch.",
        'SQLQuery': "SELECT Branch, HOUR(STR_TO_DATE(Time, '%H:%i:%s')) AS hour_of_day, SUM(Total) AS total_sales FROM walmart_sales GROUP BY Branch, hour_of_day ORDER BY Branch, total_sales DESC;",
        'SQLResult': "Result of the SQL query",
        'Answer': "The busiest hour per branch is found by extracting hour from Time and summing sales grouped by Branch and hour."
    },
    {
        'Question': "Calculate revenue contribution percentage by gender and customer type.",
        'SQLQuery': "SELECT Gender, `Customer type`, ROUND(SUM(Total) * 100 / (SELECT SUM(Total) FROM walmart_sales), 2) AS revenue_percent FROM walmart_sales GROUP BY Gender, `Customer type` ORDER BY revenue_percent DESC;",
        'SQLResult': "Result of the SQL query",
        'Answer': "Revenue contribution percentage is calculated by dividing group total by overall total and multiplying by 100."
    },
    {
        'Question': "Detect transactions where gross income is an outlier greater than mean plus two standard deviations per product line.",
        'SQLQuery': "WITH stats AS (SELECT `Product line`, AVG(`gross income`) AS avg_income, STDDEV(`gross income`) AS std_income FROM walmart_sales GROUP BY `Product line`) SELECT w.`Product line`, w.`Invoice ID`, w.`gross income`, s.avg_income, s.std_income, ROUND((w.`gross income` - s.avg_income) / s.std_income, 2) AS z_score FROM walmart_sales w JOIN stats s ON w.`Product line` = s.`Product line` WHERE w.`gross income` > s.avg_income + 2 * s.std_income ORDER BY z_score DESC;",
        'SQLResult': "Result of the SQL query",
        'Answer': "Outlier transactions are detected using statistical methods by comparing to mean plus 2 standard deviations with z-score calculation."
    },
    {
        'Question': "Compute monthly sales growth rate percentage over previous month.",
        'SQLQuery': "SELECT DATE_FORMAT(STR_TO_DATE(Date, '%Y-%m-%d'), '%Y-%m') AS month, SUM(Total) AS total_sales, ROUND((SUM(Total) - LAG(SUM(Total)) OVER (ORDER BY DATE_FORMAT(STR_TO_DATE(Date, '%Y-%m-%d'), '%Y-%m'))) / LAG(SUM(Total)) OVER (ORDER BY DATE_FORMAT(STR_TO_DATE(Date, '%Y-%m-%d'), '%Y-%m')) * 100, 2) AS growth_percent FROM walmart_sales GROUP BY month ORDER BY month;",
        'SQLResult': "Result of the SQL query",
        'Answer': "Monthly growth rate is calculated using LAG window function to compare current month sales with previous month."
    },
    
    # ============= EXTREME QUERIES =============
    {
        'Question': "Evaluate branch efficiency by weighted customer satisfaction versus revenue per transaction.",
        'SQLQuery': "WITH branch_perf AS (SELECT Branch, SUM(Total) AS total_sales, AVG(Total) AS avg_txn_value, SUM(Rating * Total) / SUM(Total) AS weighted_rating FROM walmart_sales GROUP BY Branch), global_avg AS (SELECT AVG(Total) AS global_avg_txn FROM walmart_sales) SELECT b.Branch, ROUND(b.total_sales, 2) AS total_sales, ROUND(b.avg_txn_value, 2) AS avg_txn_value, ROUND(b.weighted_rating, 2) AS performance_score, ROUND(b.avg_txn_value / g.global_avg_txn, 2) AS efficiency_ratio, RANK() OVER (ORDER BY b.avg_txn_value / g.global_avg_txn DESC) AS rank_by_efficiency FROM branch_perf b CROSS JOIN global_avg g;",
        'SQLResult': "Result of the SQL query",
        'Answer': "Branch efficiency is evaluated using multiple CTEs to calculate weighted ratings and efficiency ratios compared to global averages."
    },
    {
        'Question': "Detect underperforming product lines in each city below average total revenue.",
        'SQLQuery': "WITH city_avg AS (SELECT City, AVG(SUM_T) AS avg_revenue FROM (SELECT City, `Product line`, SUM(Total) AS SUM_T FROM walmart_sales GROUP BY City, `Product line`) t GROUP BY City) SELECT w.City, w.`Product line`, SUM(w.Total) AS total_revenue, c.avg_revenue FROM walmart_sales w JOIN city_avg c ON w.City = c.City GROUP BY w.City, w.`Product line` HAVING SUM(w.Total) < c.avg_revenue ORDER BY City, total_revenue;",
        'SQLResult': "Result of the SQL query",
        'Answer': "Underperforming product lines are identified by comparing their revenue to city-specific averages using nested CTEs."
    },
    {
        'Question': "Perform RFM recency frequency monetary segmentation of customers using invoice-level data.",
        'SQLQuery': "WITH rfm AS (SELECT `Customer type`, MAX(STR_TO_DATE(Date, '%Y-%m-%d')) AS last_purchase, COUNT(`Invoice ID`) AS frequency, SUM(Total) AS monetary_value FROM walmart_sales GROUP BY `Customer type`) SELECT `Customer type`, DATEDIFF('2019-12-31', last_purchase) AS recency_days, frequency, ROUND(monetary_value, 2) AS total_spent FROM rfm ORDER BY total_spent DESC;",
        'SQLResult': "Result of the SQL query",
        'Answer': "RFM analysis segments customers by calculating recency days since last purchase, frequency transaction count, and monetary value total spent."
    },
    {
        'Question': "Find the combination of payment method and product line that yields the highest profit margin per branch.",
        'SQLQuery': "SELECT Branch, `Product line`, Payment, ROUND(SUM(`gross income`) / SUM(cogs) * 100, 2) AS profit_margin_percent FROM walmart_sales GROUP BY Branch, `Product line`, Payment ORDER BY profit_margin_percent DESC LIMIT 10;",
        'SQLResult': "Result of the SQL query",
        'Answer': "The highest profit margin combinations are found by calculating gross income to COGS ratio for each branch product line and payment method combination."
    },
    {
        'Question': "Identify the day of week with the highest sales variability for staffing optimization.",
        'SQLQuery': "SELECT DAYNAME(STR_TO_DATE(Date, '%Y-%m-%d')) AS day_name, ROUND(AVG(Total), 2) AS avg_sales, ROUND(STDDEV(Total), 2) AS sales_variability FROM walmart_sales GROUP BY day_name ORDER BY sales_variability DESC;",
        'SQLResult': "Result of the SQL query",
        'Answer': "Sales variability by day of week is calculated using standard deviation to identify days with highest fluctuations for staffing optimization."
    }
]

# Recreate vectorization
to_vectorize = [" ".join(str(v) for v in example.values()) for example in few_shots]

print(f" Updated {len(few_shots)} few-shot examples with proper string formatting")

 Updated 20 few-shot examples with proper string formatting


In [14]:
# Import components for semantic similarity-based example selection
from langchain_community.vectorstores import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_community.embeddings import HuggingFaceEmbeddings  # Use community package instead


In [15]:
# Initialize the embeddings model using HuggingFace
# Set encode_kwargs to avoid TensorFlow/Keras issues
embeddings = HuggingFaceEmbeddings(
    model_name='sentence-transformers/all-MiniLM-L6-v2',
    model_kwargs={'device': 'cpu'},
    encode_kwargs={'normalize_embeddings': True}
)

print(" Imports completed and embeddings model initialized")
print("   - Using HuggingFace Embeddings (all-MiniLM-L6-v2)")

  embeddings = HuggingFaceEmbeddings(



 Imports completed and embeddings model initialized
   - Using HuggingFace Embeddings (all-MiniLM-L6-v2)


In [16]:
# Create Chroma vector store from the few-shot examples
# This stores embeddings and allows efficient semantic similarity search
vectorstore = Chroma.from_texts(
    texts=to_vectorize,  # The concatenated text strings to embed
    embedding=embeddings,  # The HuggingFace embedding model
    metadatas=few_shots  # Store the original examples as metadata for retrieval
)

print(" Chroma vector store created successfully")
print(f"   - Total vectors stored: {len(to_vectorize)}")
print(f"   - Embedding dimension: 384 (from all-MiniLM-L6-v2)")

# Create semantic similarity-based example selector
# This will dynamically select the most relevant examples based on the input question
example_selector = SemanticSimilarityExampleSelector(
    vectorstore=vectorstore,  # The Chroma vector store containing embedded examples
    k=2  # Number of most similar examples to retrieve
)

print("\n SemanticSimilarityExampleSelector initialized")
print(f"   - Top k examples to retrieve: 2")
print(f"   - Similarity metric: Cosine similarity")

 Chroma vector store created successfully
   - Total vectors stored: 20
   - Embedding dimension: 384 (from all-MiniLM-L6-v2)

 SemanticSimilarityExampleSelector initialized
   - Top k examples to retrieve: 2
   - Similarity metric: Cosine similarity


In [17]:
# Test the example selector with a single query
example_selector.select_examples({"Question": "Find the top-performing Product line by total revenue in each City with ranking."})

[{'Answer': 'Top-performing product lines per city are identified using window functions with RANK to partition by City.',
  'Question': 'Find top-performing product line per city using ranking.',
  'SQLResult': 'Result of the SQL query',
  'SQLQuery': 'SELECT City, `Product line`, SUM(Total) AS total_revenue, RANK() OVER (PARTITION BY City ORDER BY SUM(Total) DESC) AS rank_in_city FROM walmart_sales GROUP BY City, `Product line`;'},
 {'Answer': 'Product lines with highest sales per city are found by grouping by City and Product line, then ordering by total sales.',
  'Question': 'Determine which product line has the highest total sales in each city.',
  'SQLQuery': 'SELECT City, `Product line`, SUM(Total) AS total_sales FROM walmart_sales GROUP BY City, `Product line` ORDER BY City, total_sales DESC;',
  'SQLResult': 'Result of the SQL query'}]

In [18]:
# MySQL-based instruction prompt template
# This guides the AI to generate syntactically correct MySQL queries
mysql_prompt = """You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.

Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.

Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.

Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run (without any markdown formatting or code blocks)
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only provide the SQL query without any additional text, explanations, or markdown code blocks.
"""

print(" MySQL prompt template created")
print("   - Includes rules for query formatting")
print("   - Limits results with LIMIT clause")
print("   - Uses backticks for column names")
print("   - Handles 'today' with CURDATE()")

 MySQL prompt template created
   - Includes rules for query formatting
   - Limits results with LIMIT clause
   - Uses backticks for column names
   - Handles 'today' with CURDATE()


In [19]:
# Import FewShotPromptTemplate for creating prompts with example demonstrations
from langchain_core.prompts import FewShotPromptTemplate

# Import pre-built MySQL prompt components from LangChain
from langchain.chains.sql_database.prompt import PROMPT_SUFFIX, _mysql_prompt

print(" Imported FewShotPromptTemplate and MySQL prompt components")
print("\n" + "="*80)
print(" PROMPT_SUFFIX (Default ending template):")
print("="*80)
print(PROMPT_SUFFIX)

print("\n" + "="*80)
print(" _mysql_prompt (MySQL-specific instructions):")
print("="*80)
print(_mysql_prompt)

 Imported FewShotPromptTemplate and MySQL prompt components

 PROMPT_SUFFIX (Default ending template):
Only use the following tables:
{table_info}

Question: {input}

 _mysql_prompt (MySQL-specific instructions):
You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is i

In [20]:
from langchain_core.prompts import PromptTemplate

# Create a template for formatting individual few-shot examples
# This defines how each example (question + SQL + result + answer) should be displayed
example_prompt = PromptTemplate(
    input_variables=["Question", "SQLQuery", "SQLResult", "Answer"],
    template="\nQuestion: {Question}\nSQLQuery: {SQLQuery}\nSQLResult: {SQLResult}\nAnswer: {Answer}"
)

print(" Example prompt template created")
print("\n" + "="*80)
print(" Example Format Template:")
print("="*80)
print(example_prompt.template)

print("\n" + "="*80)
print(" Sample Usage (with dummy data):")
print("="*80)
sample_example = {
    "Question": "What is the total revenue by branch?",
    "SQLQuery": "SELECT Branch, SUM(Total) AS total_revenue FROM walmart_sales GROUP BY Branch;",
    "SQLResult": "[('A', 105861.56), ('B', 106200.37), ('C', 110568.71)]",
    "Answer": "Branch A has total revenue of 105,861.56, Branch B has 106,200.37, and Branch C has 110,568.71."
}
print(example_prompt.format(**sample_example))

 Example prompt template created

 Example Format Template:

Question: {Question}
SQLQuery: {SQLQuery}
SQLResult: {SQLResult}
Answer: {Answer}

 Sample Usage (with dummy data):

Question: What is the total revenue by branch?
SQLQuery: SELECT Branch, SUM(Total) AS total_revenue FROM walmart_sales GROUP BY Branch;
SQLResult: [('A', 105861.56), ('B', 106200.37), ('C', 110568.71)]
Answer: Branch A has total revenue of 105,861.56, Branch B has 106,200.37, and Branch C has 110,568.71.


In [22]:
# Create the complete few-shot prompt template by combining all components
few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,  # Dynamically selects relevant examples from vector store
    example_prompt=example_prompt,      # Formats each selected example
    prefix=mysql_prompt,                # MySQL expert instructions at the start
    suffix=PROMPT_SUFFIX,               # Output format instructions at the end
    input_variables=["input", "table_info", "top_k"]  # Variables to be filled in
)

print(" Few-shot prompt template assembled successfully!")
print("\n" + "="*80)
print(" Prompt Structure (Sandwich Model):")
print("="*80)
print("1.  PREFIX (Top Bun): MySQL expert instructions")
print("2.  EXAMPLES (Filling): 2 most relevant examples from vector store")
print("3.  SUFFIX (Bottom Bun): Output format rules")
print("\n Input Variables:")
print("   - input: User's question")
print("   - table_info: Database schema")
print("   - top_k: Maximum results to return")

print("\n" + "="*80)
print(" Test: Format a complete prompt")
print("="*80)

# Test the complete prompt with a sample question
test_input = {
    "input": "Find the top 3 product lines by total revenue",
    "table_info": db.table_info,
    "top_k": str(3)
}

formatted_prompt = few_shot_prompt.format(**test_input)
print("\n Generated Prompt (first 500 chars):")
print(formatted_prompt[:500] + "...")

 Few-shot prompt template assembled successfully!

 Prompt Structure (Sandwich Model):
1.  PREFIX (Top Bun): MySQL expert instructions
2.  EXAMPLES (Filling): 2 most relevant examples from vector store
3.  SUFFIX (Bottom Bun): Output format rules

 Input Variables:
   - input: User's question
   - table_info: Database schema
   - top_k: Maximum results to return

 Test: Format a complete prompt

 Generated Prompt (first 500 chars):
You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.

Unless the user specifies in the question a specific number of examples to obtain, query for at most 3 results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.

Never query for all columns from a table. You must query only the columns that are...


In [23]:
# Create a new chain with the few-shot prompt
chain_with_examples = create_sql_query_chain(llm, db, prompt=few_shot_prompt)

In [24]:
get_sql_answer(question="Identify pairs of product lines that are most frequently purchased together on the same day by the same customer type to uncover potential cross-selling opportunities.", chain=query_chain, model=llm, database=db)

Original, bad query: SQLQuery: SELECT
  s1.`Product line`,
  s2.`Product line`,
  COUNT(*) AS `purchase_count`
FROM walmart_sales AS s1
JOIN walmart_sales AS s2
  ON s1.`Invoice ID` = s2.`Invoice ID` AND s1.`Product line` < s2.`Product line`
GROUP BY
  s1.`Product line`,
  s2.`Product line`
ORDER BY
  `purchase_count` DESC
LIMIT 5;
Cleaned, good query: SELECT
  s1.`Product line`,
  s2.`Product line`,
  COUNT(*) AS `purchase_count`
FROM walmart_sales AS s1
JOIN walmart_sales AS s2
  ON s1.`Invoice ID` = s2.`Invoice ID` AND s1.`Product line` < s2.`Product line`
GROUP BY
  s1.`Product line`,
  s2.`Product line`
ORDER BY
  `purchase_count` DESC
LIMIT 5;

--- Query Result ---


--- Final Answer ---
Based on the SQL query, the analysis identifies the top 5 pairs of product lines that are most frequently purchased together within the same transaction.

For example, if the results showed:
1.  **Home and kitchen** and **Food and beverages**
2.  **Fashion** and **Health and beauty**
3.  **Electr

In [25]:
# First, create the chain WITH few-shot prompt
chain_with_fewshot = create_sql_query_chain(llm, db, prompt=few_shot_prompt)

# Then use it
get_sql_answer(question="Identify pairs of product lines that are most frequently purchased together on the same day by the same customer type to uncover potential cross-selling opportunities.", chain=chain_with_fewshot, model=llm, database=db)

Original, bad query: SELECT s1.`Product line` AS `Product Line 1`, s2.`Product line` AS `Product Line 2`, COUNT(DISTINCT s1.`Invoice ID`) AS `Number of Joint Purchases` FROM walmart_sales s1 JOIN walmart_sales s2 ON s1.`Invoice ID` = s2.`Invoice ID` AND s1.`Product line` < s2.`Product line` GROUP BY s1.`Product line`, s2.`Product line` ORDER BY `Number of Joint Purchases` DESC LIMIT 5;
Cleaned, good query: SELECT s1.`Product line` AS `Product Line 1`, s2.`Product line` AS `Product Line 2`, COUNT(DISTINCT s1.`Invoice ID`) AS `Number of Joint Purchases` FROM walmart_sales s1 JOIN walmart_sales s2 ON s1.`Invoice ID` = s2.`Invoice ID` AND s1.`Product line` < s2.`Product line` GROUP BY s1.`Product line`, s2.`Product line` ORDER BY `Number of Joint Purchases` DESC LIMIT 5;

--- Query Result ---


--- Final Answer ---
The analysis identifies the top 5 pairs of product lines that are most frequently purchased together on the same invoice. The "Number of Joint Purchases" indicates how many uniq