## Chat with the SuperStore Dataset

In [1]:
#importing the necessary dependencies
import os
from langchain.llms import GooglePalm
from dotenv import load_dotenv
load_dotenv()

True

In [2]:
#loading the LLM
llm=GooglePalm(google_api_key=os.environ['GOOGLE_API_KEY'],temperature=0.6)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
#connecting to SQL database:Superstore database
from langchain.utilities import SQLDatabase

# Credentials come from the environment (populated from .env by load_dotenv() above)
# instead of being hard-coded in the notebook. DB_PASSWORD is required; the other
# settings fall back to the values this notebook was originally written with.
from urllib.parse import quote_plus

db_user = os.environ.get("DB_USER", "root")
db_password = os.environ["DB_PASSWORD"]
db_host = os.environ.get("DB_HOST", "localhost")
db_name = os.environ.get("DB_NAME", "sample_superstore")

# quote_plus escapes characters such as '@', ':' and '/' so that a user name or
# password containing special characters cannot break the connection URI.
db_uri = f"mysql+pymysql://{quote_plus(db_user)}:{quote_plus(db_password)}@{db_host}/{db_name}"
db = SQLDatabase.from_uri(db_uri, sample_rows_in_table_info=3)

print(db.table_info)


CREATE TABLE orders (
	`Row_ID` INTEGER, 
	`Order_ID` TEXT, 
	`Order_Date` TEXT, 
	`Ship_Date` TEXT, 
	`Ship_Mode` TEXT, 
	`Customer_ID` TEXT, 
	`Customer_Name` TEXT, 
	`Segment` TEXT, 
	`Country` TEXT, 
	`City` TEXT, 
	`State` TEXT, 
	`Postal_Code` TEXT, 
	`Region` TEXT, 
	`Product_ID` TEXT, 
	`Category` TEXT, 
	`Sub_Category` TEXT, 
	`Product_Name` TEXT, 
	`Sales` DOUBLE, 
	`Quantity` INTEGER, 
	`Discount` DOUBLE, 
	`Profit` DOUBLE
)COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4 ENGINE=InnoDB

/*
3 rows from orders table:
Row_ID	Order_ID	Order_Date	Ship_Date	Ship_Mode	Customer_ID	Customer_Name	Segment	Country	City	State	Postal_Code	Region	Product_ID	Category	Sub_Category	Product_Name	Sales	Quantity	Discount	Profit
1	CA-2020-152156	11/8/2020	11/11/2020	Second Class	CG-12520	Claire Gute	Consumer	United States	Henderson	Kentucky	42420	South	FUR-BO-10001798	Furniture	Bookcases	Bush Somerset Collection Bookcase	261.9600000000	2	0E-10	41.9136000000
2	CA-2020-152156	11/8/2020	11/11/202

In [4]:
#using SQL chain
from langchain_experimental.sql import SQLDatabaseChain

# verbose=True makes the chain print every generated SQL statement
db_chain = SQLDatabaseChain.from_llm(
    llm,
    db,
    verbose=True,
    return_direct=True,
)

### Exploring the Answers without adding Few Shot Learning

In [5]:
qns1 = db_chain("what is the total number of returns made by the store")



[1m> Entering new SQLDatabaseChain chain...[0m
what is the total number of returns made by the store
SQLQuery:[32;1m[1;3mSELECT count(*) as Total_Returns FROM returns[0m
SQLResult: [33;1m[1;3m[(800,)][0m
[1m> Finished chain.[0m


In [6]:
qns1

{'query': 'what is the total number of returns made by the store',
 'result': '[(800,)]'}

The question needs to be more specific to distinguish between the number of products returned and the number of orders returned. Let's check whether the model can tell the difference.

In [7]:
qns1 = db_chain.run("what is the total number of orders returned to the store")



[1m> Entering new SQLDatabaseChain chain...[0m
what is the total number of orders returned to the store
SQLQuery:[32;1m[1;3mSELECT COUNT(*) FROM returns;[0m
SQLResult: [33;1m[1;3m[(800,)][0m
[1m> Finished chain.[0m


In [8]:
qns1 = db_chain.run("what is the total number of products returned to the store")



[1m> Entering new SQLDatabaseChain chain...[0m
what is the total number of products returned to the store
SQLQuery:[32;1m[1;3mSELECT COUNT(*) FROM returns[0m
SQLResult: [33;1m[1;3m[(800,)][0m
[1m> Finished chain.[0m


It seems the model can't tell the difference: all three questions produce the same `COUNT(*)` query. This is an example of a question that should be included in the few-shot examples.

In [9]:
qns1 = db_chain.run("what is the total sales made by the store")



[1m> Entering new SQLDatabaseChain chain...[0m
what is the total sales made by the store
SQLQuery:[32;1m[1;3mSELECT SUM(Sales) FROM orders[0m
SQLResult: [33;1m[1;3m[(2272449.8562999545,)][0m
[1m> Finished chain.[0m


In [10]:
qns1 = db_chain.run("which state made the highest sales")



[1m> Entering new SQLDatabaseChain chain...[0m
which state made the highest sales
SQLQuery:[32;1m[1;3mSELECT State, SUM(Sales) AS Total_Sales FROM orders GROUP BY State ORDER BY Total_Sales DESC LIMIT 1[0m
SQLResult: [33;1m[1;3m[('California', 450567.5915000007)][0m
[1m> Finished chain.[0m


In [11]:
qns1 = db_chain.run("Which state made the highest sales and by how much")



[1m> Entering new SQLDatabaseChain chain...[0m
Which state made the highest sales and by how much
SQLQuery:[32;1m[1;3mSELECT State, SUM(Sales) AS Sales FROM orders GROUP BY State ORDER BY Sales DESC LIMIT 1[0m
SQLResult: [33;1m[1;3m[('California', 450567.5915000007)][0m
[1m> Finished chain.[0m


In [12]:
qns1 = db_chain.run("Which region the made the highest sales and by how much")



[1m> Entering new SQLDatabaseChain chain...[0m
Which region the made the highest sales and by how much
SQLQuery:[32;1m[1;3mSELECT `Region`, SUM(`Sales`) AS `Total_Sales` FROM `orders` GROUP BY `Region` ORDER BY `Total_Sales` DESC LIMIT 1[0m
SQLResult: [33;1m[1;3m[('West', 713471.3445000004)][0m
[1m> Finished chain.[0m


In [13]:
qns1 = db_chain("How many orders were placed in each city?")



[1m> Entering new SQLDatabaseChain chain...[0m
How many orders were placed in each city?
SQLQuery:[32;1m[1;3mSELECT City, COUNT(*) AS num_orders FROM orders GROUP BY City[0m
SQLResult: [33;1m[1;3m[('Henderson', 51), ('Los Angeles', 731), ('Fort Lauderdale', 15), ('Concord', 31), ('Seattle', 412), ('Fort Worth', 27), ('Madison', 10), ('West Jordan', 5), ('San Francisco', 494), ('Fremont', 8), ('Philadelphia', 507), ('Orem', 9), ('Houston', 370), ('Richardson', 2), ('Naperville', 8), ('Melbourne', 1), ('Eagan', 8), ('Westland', 12), ('Dover', 18), ('New Albany', 4), ('New York City', 893), ('Troy', 28), ('Chicago', 301), ('Gilbert', 14), ('Springfield', 158), ('Jackson', 79), ('Memphis', 27), ('Decatur', 34), ('Durham', 9), ('Columbia', 78), ('Rochester', 51), ('Minneapolis', 23), ('Portland', 24), ('Saint Paul', 2), ('Aurora', 67), ('Charlotte', 52), ('Orland Park', 1), ('Urbandale', 3), ('Columbus', 214), ('Bristol', 12), ('Wilmington', 35), ('Bloomington', 15), ('Phoenix', 61

In [14]:
qns1

{'query': 'How many orders were placed in each city?',
 'result': "[('Henderson', 51), ('Los Angeles', 731), ('Fort Lauderdale', 15), ('Concord', 31), ('Seattle', 412), ('Fort Worth', 27), ('Madison', 10), ('West Jordan', 5), ('San Francisco', 494), ('Fremont', 8), ('Philadelphia', 507), ('Orem', 9), ('Houston', 370), ('Richardson', 2), ('Naperville', 8), ('Melbourne', 1), ('Eagan', 8), ('Westland', 12), ('Dover', 18), ('New Albany', 4), ('New York City', 893), ('Troy', 28), ('Chicago', 301), ('Gilbert', 14), ('Springfield', 158), ('Jackson', 79), ('Memphis', 27), ('Decatur', 34), ('Durham', 9), ('Columbia', 78), ('Rochester', 51), ('Minneapolis', 23), ('Portland', 24), ('Saint Paul', 2), ('Aurora', 67), ('Charlotte', 52), ('Orland Park', 1), ('Urbandale', 3), ('Columbus', 214), ('Bristol', 12), ('Wilmington', 35), ('Bloomington', 15), ('Phoenix', 61), ('Roseville', 25), ('Independence', 2), ('Pasadena', 41), ('Newark', 93), ('Franklin', 34), ('Scottsdale', 12), ('San Jose', 41), ('Edm

In [15]:
#Getting table of pandas dataframe from the result
import ast
import pandas as pd

# The result is the textual repr of a list of tuples. ast.literal_eval parses
# Python literals only, so (unlike eval) it cannot execute code embedded in the
# model/database output.
result_list = ast.literal_eval(qns1['result'])
df = pd.DataFrame(result_list)
print(df)


                   0    1
0          Henderson   51
1        Los Angeles  731
2    Fort Lauderdale   15
3            Concord   31
4            Seattle  412
..               ...  ...
524     San Clemente    2
525  San Luis Obispo    1
526       Springdale    1
527             Lodi    1
528            Mason    3

[529 rows x 2 columns]


### Creating Semantic Similarity Based example selector
- create embedding on the few_shots: Using Google Gen Embeddings
- Store the embeddings in a vector store: FAISS index
- Retrieve the most semantically similar examples from the vector store

In [16]:
#Using Google Generative AI embeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings
embeddings=GoogleGenerativeAIEmbeddings(model='models/embedding-001')

In [17]:
#Need to combine all values in the few shot examples into one sentence
from few_shots import few_shots
print(few_shots)
# Flatten each example into one space-separated string of its values so it can be embedded
to_vectorize = list(map(lambda shot: " ".join(shot.values()), few_shots))

[{'Question': 'What is the total sales amount for all orders?', 'SQLQuery': 'SELECT SUM(Sales) FROM orders', 'SQLResult': 'Result of the SQL query', 'Answer': 'Answer'}, {'Question': 'How many orders were returned in total?', 'SQLQuery': 'SELECT COUNT(DISTINCT Order_ID) FROM returns', 'SQLResult': 'Result of the SQL query', 'Answer': 'Answer'}, {'Question': 'Who is the regional manager for the region with the highest sales?', 'SQLQuery': 'SELECT p.Regional_Manager FROM people p JOIN (SELECT Region, SUM(Sales) AS TotalSales FROM orders GROUP BY Region ORDER BY TotalSales DESC LIMIT 1) AS max_sales ON p.Region = max_sales.Region', 'SQLResult': 'Result of the SQL query', 'Answer': 'Sadie Pawthorne'}, {'Question': 'What is the total profit earned?', 'SQLQuery': 'SELECT SUM(Profit) FROM orders', 'SQLResult': 'Result of the SQL query', 'Answer': 'Answer'}, {'Question': 'How many orders were placed in total?', 'SQLQuery': 'SELECT COUNT(DISTINCT Order_ID) FROM orders', 'SQLResult': 'Result of 

In [18]:
#Using the FAISS vector database
from langchain.vectorstores import FAISS

#generating a vector store: one entry per few-shot example, with the example dict kept as metadata
vector_store = FAISS.from_texts(
    to_vectorize,
    embeddings,
    metadatas=few_shots,
)

In [19]:
#Checking semantic similarity: the selector pulls the stored examples closest to a new question
from langchain.prompts import SemanticSimilarityExampleSelector

# k is the number of examples returned per lookup
example_selector = SemanticSimilarityExampleSelector(vectorstore=vector_store, k=2)

example_selector.select_examples({"Question": "How many orders were placed from the store?"})

[{'Question': 'How many orders were placed in total?',
  'SQLQuery': 'SELECT COUNT(DISTINCT Order_ID) FROM orders',
  'SQLResult': 'Result of the SQL query',
  'Answer': 'Answer'},
 {'Question': 'How many unique customers placed orders?',
  'SQLQuery': 'SELECT COUNT(DISTINCT Customer_ID) FROM orders',
  'SQLResult': 'Result of the SQL query',
  'Answer': 'Answer'}]

### Setting up Prompts

Adding custom prompts to provide better context to the LLM

In [20]:
from langchain.prompts import FewShotPromptTemplate
from langchain.chains.sql_database.prompt import PROMPT_SUFFIX, _mysql_prompt
from langchain.prompts.prompt import PromptTemplate
#adding mysql prompt
print(_mysql_prompt)

You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of

In [21]:
print(PROMPT_SUFFIX)

Only use the following tables:
{table_info}

Question: {input}


In [96]:
mysql_prompt = """You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".
Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery

"""

In [97]:
#Example prompt format: how each selected few-shot example is rendered inside the final prompt.
#Only Question, SQLQuery and SQLResult are referenced by the template.
example_prompt = PromptTemplate(
    input_variables=["Question", "SQLQuery", "SQLResult"],
    template="\nQuestion: {Question}\nSQLQuery: {SQLQuery}\nSQLResult: {SQLResult}",
)

In [98]:
# Entire Prompt: prefix (MySQL instructions) + examples picked by the selector
# (each rendered with example_prompt) + suffix (table info and the user's question)
few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix=mysql_prompt,
    suffix=PROMPT_SUFFIX,
    input_variables=["input", "table_info", "top_k"], #These variables are used in the prefix and suffix
)

In [99]:
#SQL chain that uses the few-shot prompt and also returns its intermediate steps
new_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, prompt=few_shot_prompt,
                                      return_direct=True,return_intermediate_steps=True)

In [103]:
#Checking how the new chain performs
response = new_chain("Which product category has the highest average quantity sold per order?")



[1m> Entering new SQLDatabaseChain chain...[0m
Which product category has the highest average quantity sold per order?
SQLQuery:[32;1m[1;3mSELECT Category, AVG(Quantity) AS AvgQuantity FROM orders GROUP BY Category ORDER BY AvgQuantity DESC LIMIT 1[0m
SQLResult: [33;1m[1;3m[('Office Supplies', Decimal('3.8038'))][0m
[1m> Finished chain.[0m


In [102]:
response

{'query': 'What is the average discount rate for each product category in the western region?',
 'result': "[('Office Supplies', 0.09355016538037439), ('Furniture', 0.13309143686502267), ('Technology', 0.13355704697986673)]",
 'intermediate_steps': [{'input': 'What is the average discount rate for each product category in the western region?\nSQLQuery:',
   'top_k': '5',
   'dialect': 'mysql',
   'table_info': '\nCREATE TABLE orders (\n\t`Row_ID` INTEGER, \n\t`Order_ID` TEXT, \n\t`Order_Date` TEXT, \n\t`Ship_Date` TEXT, \n\t`Ship_Mode` TEXT, \n\t`Customer_ID` TEXT, \n\t`Customer_Name` TEXT, \n\t`Segment` TEXT, \n\t`Country` TEXT, \n\t`City` TEXT, \n\t`State` TEXT, \n\t`Postal_Code` TEXT, \n\t`Region` TEXT, \n\t`Product_ID` TEXT, \n\t`Category` TEXT, \n\t`Sub_Category` TEXT, \n\t`Product_Name` TEXT, \n\t`Sales` DOUBLE, \n\t`Quantity` INTEGER, \n\t`Discount` DOUBLE, \n\t`Profit` DOUBLE\n)COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4 ENGINE=InnoDB\n\n/*\n3 rows from orders table:\nRo

In [104]:
response['result']

"[('Office Supplies', Decimal('3.8038'))]"

In [110]:
import ast
from decimal import Decimal
response['result']

"[('Office Supplies', Decimal('3.8038'))]"

In [54]:
#prints the SQL code used
response['intermediate_steps'][2]['sql_cmd']

'SELECT City, SUM(Sales) AS TotalSales FROM orders GROUP BY City ORDER BY TotalSales DESC LIMIT 1'

In [55]:
sql_cmd = response['intermediate_steps'][2]['sql_cmd']

In [56]:
sql_cmd

'SELECT City, SUM(Sales) AS TotalSales FROM orders GROUP BY City ORDER BY TotalSales DESC LIMIT 1'

### Extracting Column names from the generated sql code

In [74]:
# Using the Google Palm llm
column_names = llm.invoke("Could you generate the column names based on this SQL Query {} and store in a list called column_names".format(sql_cmd))

In [76]:
column_names

"```python\ncolumn_names = ['City', 'TotalSales']\n```"

In [78]:
column_names.split('=')[1]

" ['City', 'TotalSales']\n```"

In [113]:
import re
import ast
from decimal import Decimal
#defining a function that takes the chain response and returns the column names of its SQL query
def get_column_names(response):
    """Return the result-column names of the SQL query recorded in a chain response.

    The SQL text is read from ``response['intermediate_steps']`` and sent to the
    notebook-level ``llm``, which is asked to reply with a Python list called
    ``column_names``. The first ``[...]`` literal found in the reply is parsed
    and returned.

    Args:
        response: Mapping returned by the chain. It must contain an
            ``'intermediate_steps'`` sequence in which one entry is a dict
            holding the key ``'sql_cmd'``.

    Returns:
        The parsed list of column names, or ``None`` when no step carries the
        SQL command, the reply contains no bracketed list, or the list cannot
        be parsed as a Python literal.

    Raises:
        KeyError: if ``response`` has no ``'intermediate_steps'`` entry.
    """
    # Find the step that holds the SQL text by its key rather than by a fixed
    # position, so the lookup does not depend on how many steps were recorded.
    sql_cmd = next(
        (
            step['sql_cmd']
            for step in response['intermediate_steps']
            if isinstance(step, dict) and 'sql_cmd' in step
        ),
        None,
    )
    if sql_cmd is None:
        return None

    #Using the llm to extract the name of the columns from the sql query
    reply = llm.invoke("Could you generate the column names based on this SQL Query {} and store in a list called column_names".format(sql_cmd))

    # Non-greedy match of the first bracketed list; DOTALL lets the list span
    # several lines when the model formats one item per line.
    match = re.search(r"\[.*?\]", reply, re.DOTALL)
    if match is None:
        return None

    try:
        # literal_eval only accepts Python literals, so the model's reply is never executed
        return ast.literal_eval(match.group())
    except (ValueError, SyntaxError):
        # The bracketed text was not a valid Python literal
        return None

    

In [115]:
get_column_names(response)

['Category', 'AvgQuantity']


### Creating a Dataframe from the SQL Result

In [111]:
#Function to return dataframe
def return_df(response):
    """Build a pandas DataFrame from the textual SQL result in a chain response.

    Args:
        response: Mapping with a ``'result'`` entry holding the repr of the
            query result, e.g. ``"[('Office Supplies', Decimal('3.8038'))]"``.

    Returns:
        pandas.DataFrame with one row per result tuple; columns are positional
        (0, 1, ...).

    Raises:
        ValueError: if the result text is empty, i.e. there are no rows to tabulate.
    """
    from decimal import Decimal

    result_text = response['result']
    if not result_text or not result_text.strip():
        raise ValueError("The SQL result is empty; there are no rows to build a DataFrame from")

    # SECURITY NOTE: the text originates from LLM-generated SQL run against the
    # database. ast.literal_eval cannot be used because results may contain
    # Decimal(...) values, so eval is kept but limited to an explicit namespace
    # with builtins removed. This reduces exposure but is NOT a sandbox; do not
    # feed it text from untrusted sources.
    result_list = eval(result_text, {"__builtins__": {}}, {"Decimal": Decimal})
    return pd.DataFrame(result_list)

In [114]:
response = new_chain("Which product category has the highest average quantity sold per order?")
df = return_df(response)
df.head()



[1m> Entering new SQLDatabaseChain chain...[0m
Which product category has the highest average quantity sold per order?
SQLQuery:[32;1m[1;3mSELECT Category, AVG(Quantity) AS AvgQuantity FROM orders GROUP BY Category ORDER BY AvgQuantity DESC LIMIT 1[0m
SQLResult: [33;1m[1;3m[('Office Supplies', Decimal('3.8038'))][0m
[1m> Finished chain.[0m


Unnamed: 0,0,1
0,Office Supplies,3.8038


In [34]:
response = new_chain("How many unique product names are there in the orders table? in the SQLResult return the possible column names as the first tuple")
df = return_df(response)
df.head()



[1m> Entering new SQLDatabaseChain chain...[0m
How many unique product names are there in the orders table? in the SQLResult return the possible column names as the first tuple
SQLQuery:[32;1m[1;3mSELECT COUNT(DISTINCT Product_Name) FROM orders[0m
SQLResult: [33;1m[1;3m[(1797,)][0m
[1m> Finished chain.[0m


Unnamed: 0,0
0,1797


In [33]:
response

{'query': 'How many unique product names are there in the orders table?',
 'result': '[(1797,)]'}

In [30]:
new_chain("How many orders were returned to the store")



[1m> Entering new SQLDatabaseChain chain...[0m
How many orders were returned to the store
SQLQuery:[32;1m[1;3mSELECT COUNT(DISTINCT o.Order_ID) AS TotalOrders FROM orders o JOIN returns r ON o.Order_ID = r.Order_ID[0m
SQLResult: [33;1m[1;3m[(292,)][0m
[1m> Finished chain.[0m


{'query': 'How many orders were returned to the store', 'result': '[(292,)]'}

In [31]:
try:
    response = new_chain("How many returns were made for each sub-category of Stattionery?")
    df = return_df(response)
except Exception as e:
    print("Sorry!Failed to generate the SQL Query, Please try rewording your question")
    # Show the underlying cause so that an empty result can be told apart from
    # a failed query instead of being silently swallowed.
    print(f"Reason: {type(e).__name__}: {e}")
else:
    # Displayed explicitly: a bare expression inside a try block is never rendered
    display(df.head())



[1m> Entering new SQLDatabaseChain chain...[0m
How many returns were made for each sub-category of Stattionery?
SQLQuery:[32;1m[1;3mSELECT o.Sub_Category, COUNT(DISTINCT r.Order_ID) AS TotalReturns FROM returns r JOIN orders o ON r.Order_ID = o.Order_ID WHERE o.Category = 'Stationery' GROUP BY o.Sub_Category[0m
SQLResult: [33;1m[1;3m[0m
[1m> Finished chain.[0m
Sorry!Failed to generate the SQL Query, Please try rewording your question
