## Chat with the SuperStore Dataset

In [27]:
# Core dependencies: the standard library first, then third-party packages.
import os

from dotenv import load_dotenv
from langchain.llms import GooglePalm

# Read key/value pairs from a local .env file into os.environ so the API key used
# in the next cell (GOOGLE_API_KEY) does not have to be hardcoded in the notebook.
load_dotenv()

True

In [28]:
# Load the Google PaLM LLM; the API key is taken from the environment (populated by load_dotenv()).
llm = GooglePalm(
    google_api_key=os.environ['GOOGLE_API_KEY'],
    temperature=0.6,
)

In [29]:
#connecting to SQL database: Superstore database
from urllib.parse import quote_plus

from langchain.utilities import SQLDatabase

# Connection settings come from the environment (e.g. the .env file read by
# load_dotenv() in the first cell) so credentials are not saved in the notebook.
# DB_PASSWORD is required; the other settings fall back to the previous local values.
db_user = os.environ.get("DB_USER", "root")
db_password = os.environ["DB_PASSWORD"]
db_host = os.environ.get("DB_HOST", "localhost")
db_name = os.environ.get("DB_NAME", "sample_superstore")

# URL-encode the credentials so passwords containing special characters
# (such as '@', ':' or '/') no longer break the connection URI.
db_uri = f"mysql+pymysql://{quote_plus(db_user)}:{quote_plus(db_password)}@{db_host}/{db_name}"

# sample_rows_in_table_info=3 includes 3 example rows per table in db.table_info,
# the schema text that later fills the prompt's {table_info} placeholder.
db = SQLDatabase.from_uri(db_uri, sample_rows_in_table_info=3)

print(db.table_info)


CREATE TABLE orders (
	`Row_ID` INTEGER, 
	`Order_ID` TEXT, 
	`Order_Date` TEXT, 
	`Ship_Date` TEXT, 
	`Ship_Mode` TEXT, 
	`Customer_ID` TEXT, 
	`Customer_Name` TEXT, 
	`Segment` TEXT, 
	`Country` TEXT, 
	`City` TEXT, 
	`State` TEXT, 
	`Postal_Code` TEXT, 
	`Region` TEXT, 
	`Product_ID` TEXT, 
	`Category` TEXT, 
	`Sub_Category` TEXT, 
	`Product_Name` TEXT, 
	`Sales` DOUBLE, 
	`Quantity` INTEGER, 
	`Discount` DOUBLE, 
	`Profit` DOUBLE
)ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4

/*
3 rows from orders table:
Row_ID	Order_ID	Order_Date	Ship_Date	Ship_Mode	Customer_ID	Customer_Name	Segment	Country	City	State	Postal_Code	Region	Product_ID	Category	Sub_Category	Product_Name	Sales	Quantity	Discount	Profit
1	CA-2020-152156	11/8/2020	11/11/2020	Second Class	CG-12520	Claire Gute	Consumer	United States	Henderson	Kentucky	42420	South	FUR-BO-10001798	Furniture	Bookcases	Bush Somerset Collection Bookcase	261.9600000000	2	0E-10	41.9136000000
2	CA-2020-152156	11/8/2020	11/11/202

In [51]:
# Baseline text-to-SQL chain (no few-shot examples).
# verbose=True echoes the generated SQL; return_direct=True makes the chain return
# the raw SQL result as 'result' instead of an LLM-written answer.
from langchain_experimental.sql import SQLDatabaseChain

db_chain = SQLDatabaseChain.from_llm(
    llm,
    db,
    verbose=True,
    return_direct=True,
)

### Exploring the Answers without adding Few Shot Learning

In [58]:
# Baseline question: total number of returns.
qns1 = db_chain(
    "what is the total number of returns made by the store"
)



[1m> Entering new SQLDatabaseChain chain...[0m
what is the total number of returns made by the store
SQLQuery:[32;1m[1;3mSELECT COUNT(*) FROM returns[0m
SQLResult: [33;1m[1;3m[(800,)][0m
[1m> Finished chain.[0m


In [59]:
# Display the chain's output dict: 'query' holds the question and, because
# db_chain uses return_direct=True, 'result' holds the raw SQL result string.
qns1

{'query': 'what is the total number of returns made by the store',
 'result': '[(800,)]'}

The question needs to be more specific to distinguish the number of products returned from the number of orders returned. Let's check whether the model can tell the difference.

In [24]:
# Ask explicitly about *orders* returned, to see whether the generated SQL changes.
qns1 = db_chain.run(
    "what is the total number of orders returned to the store"
)



[1m> Entering new SQLDatabaseChain chain...[0m
what is the total number of orders returned to the store
SQLQuery:[32;1m[1;3mSELECT COUNT(*) FROM returns[0m
SQLResult: [33;1m[1;3m[(800,)][0m
Answer:[32;1m[1;3m800[0m
[1m> Finished chain.[0m


In [25]:
# Ask explicitly about *products* returned, for comparison with the previous cell.
qns1 = db_chain.run(
    "what is the total number of products returned to the store"
)



[1m> Entering new SQLDatabaseChain chain...[0m
what is the total number of products returned to the store
SQLQuery:[32;1m[1;3mSELECT COUNT(*) AS Total_Products_Returned FROM returns[0m
SQLResult: [33;1m[1;3m[(800,)][0m
Answer:[32;1m[1;3m800[0m
[1m> Finished chain.[0m


It seems the model can't tell the difference: both questions produced the same `SELECT COUNT(*) FROM returns` query. This is the kind of question that should be covered by the few-shot examples.

In [26]:
# Aggregate question: overall sales total.
qns1 = db_chain.run(
    "what is the total sales made by the store"
)



[1m> Entering new SQLDatabaseChain chain...[0m
what is the total sales made by the store
SQLQuery:[32;1m[1;3mSELECT SUM(Sales) FROM orders[0m
SQLResult: [33;1m[1;3m[(2272449.8562999545,)][0m
Answer:[32;1m[1;3m2272449.8562999545[0m
[1m> Finished chain.[0m


In [32]:
# Group-and-rank question: top state by sales.
qns1 = db_chain.run(
    "which state made the highest sales"
)



[1m> Entering new SQLDatabaseChain chain...[0m
which state made the highest sales
SQLQuery:[32;1m[1;3mSELECT State, SUM(Sales) FROM orders GROUP BY State ORDER BY SUM(Sales) DESC LIMIT 1[0m
SQLResult: [33;1m[1;3m[('California', 450567.5915000007)][0m
Answer:[32;1m[1;3mCalifornia[0m
[1m> Finished chain.[0m


In [34]:
# Same as above, but also asking for the sales amount.
qns1 = db_chain.run(
    "Which state made the highest sales and by how much"
)



[1m> Entering new SQLDatabaseChain chain...[0m
Which state made the highest sales and by how much
SQLQuery:[32;1m[1;3mSELECT State, SUM(Sales) AS Sales FROM orders GROUP BY State ORDER BY Sales DESC LIMIT 1[0m
SQLResult: [33;1m[1;3m[('California', 450567.5915000007)][0m
Answer:[32;1m[1;3mCalifornia: 450567.5915000007[0m
[1m> Finished chain.[0m


In [35]:
# Group-and-rank question at region level, including the sales amount.
# (Typo "Which region the made" fixed so the LLM receives a well-formed question.)
qns1 = db_chain.run("Which region made the highest sales and by how much")



[1m> Entering new SQLDatabaseChain chain...[0m
Which region the made the highest sales and by how much
SQLQuery:[32;1m[1;3mSELECT 
  Region,
  SUM(Sales) AS TotalSales
FROM orders
GROUP BY Region
ORDER BY TotalSales DESC
LIMIT 1[0m
SQLResult: [33;1m[1;3m[('West', 713471.3445000004)][0m
Answer:[32;1m[1;3mWest with 713471.3445000004[0m
[1m> Finished chain.[0m


In [60]:
# Per-group question that returns many rows (one per city).
qns1 = db_chain(
    "How many orders were placed in each city?"
)



[1m> Entering new SQLDatabaseChain chain...[0m
How many orders were placed in each city?
SQLQuery:[32;1m[1;3mSELECT City, COUNT(*) AS Num_Orders FROM orders GROUP BY City ORDER BY Num_Orders DESC[0m
SQLResult: [33;1m[1;3m[('New York City', 893), ('Los Angeles', 731), ('Philadelphia', 507), ('San Francisco', 494), ('Seattle', 412), ('Houston', 370), ('Chicago', 301), ('Columbus', 214), ('San Diego', 166), ('Springfield', 158), ('Dallas', 149), ('Jacksonville', 121), ('Detroit', 111), ('Newark', 93), ('Richmond', 88), ('Jackson', 79), ('Columbia', 78), ('Aurora', 67), ('Phoenix', 61), ('San Antonio', 59), ('Arlington', 59), ('Long Beach', 58), ('Louisville', 57), ('Miami', 55), ('Charlotte', 52), ('Henderson', 51), ('Rochester', 51), ('Lakewood', 48), ('Milwaukee', 45), ('Lawrence', 44), ('Fairfield', 43), ('Denver', 43), ('Lancaster', 43), ('Cleveland', 42), ('Baltimore', 42), ('Pasadena', 41), ('San Jose', 41), ('Fayetteville', 40), ('Salem', 40), ('Atlanta', 39), ('Austin', 3

In [61]:
# Display the output dict; 'result' is the raw SQL rows rendered as one string,
# which the next cell parses into a DataFrame.
qns1

{'query': 'How many orders were placed in each city?',
 'result': "[('New York City', 893), ('Los Angeles', 731), ('Philadelphia', 507), ('San Francisco', 494), ('Seattle', 412), ('Houston', 370), ('Chicago', 301), ('Columbus', 214), ('San Diego', 166), ('Springfield', 158), ('Dallas', 149), ('Jacksonville', 121), ('Detroit', 111), ('Newark', 93), ('Richmond', 88), ('Jackson', 79), ('Columbia', 78), ('Aurora', 67), ('Phoenix', 61), ('San Antonio', 59), ('Arlington', 59), ('Long Beach', 58), ('Louisville', 57), ('Miami', 55), ('Charlotte', 52), ('Henderson', 51), ('Rochester', 51), ('Lakewood', 48), ('Milwaukee', 45), ('Lawrence', 44), ('Fairfield', 43), ('Denver', 43), ('Lancaster', 43), ('Cleveland', 42), ('Baltimore', 42), ('Pasadena', 41), ('San Jose', 41), ('Fayetteville', 40), ('Salem', 40), ('Atlanta', 39), ('Austin', 38), ('Wilmington', 35), ('Tampa', 35), ('Huntsville', 35), ('Decatur', 34), ('Franklin', 34), ('Concord', 31), ('Toledo', 31), ('Tucson', 30), ('Oceanside', 30), (

In [62]:
#Getting table of pandas dataframe from the result
import ast

import pandas as pd

# qns1['result'] is the string repr of the SQL rows (return_direct=True), e.g.
# "[('New York City', 893), ...]". ast.literal_eval parses only Python literals,
# whereas eval() would execute any code that ended up in this LLM/database-derived
# string. An empty result string becomes an empty DataFrame instead of raising.
result_list = ast.literal_eval(qns1['result']) if qns1['result'] else []
df = pd.DataFrame(result_list)
print(df)


                     0    1
0        New York City  893
1          Los Angeles  731
2         Philadelphia  507
3        San Francisco  494
4              Seattle  412
..                 ...  ...
524         Hagerstown    1
525  Arlington Heights    1
526    San Luis Obispo    1
527         Springdale    1
528               Lodi    1

[529 rows x 2 columns]


### Creating a Semantic-Similarity-Based Example Selector
- Create embeddings for the few-shot examples using Google Generative AI embeddings
- Store the embeddings in a vector store (FAISS index)
- Retrieve the most semantically similar examples from the vector store

In [63]:
# Embedding model (Google Generative AI) used to vectorize the few-shot examples.
from langchain_google_genai import GoogleGenerativeAIEmbeddings

embeddings = GoogleGenerativeAIEmbeddings(
    model='models/embedding-001',
)

In [64]:
# Flatten each few-shot example (a dict of strings) into a single space-separated
# text blob, so every example can be embedded as one string.
from few_shots import few_shots

to_vectorize = []
for example in few_shots:
    to_vectorize.append(" ".join(example.values()))

In [65]:
# FAISS vector store over the example blobs. Each vector keeps its original
# example dict as metadata, so the selector can hand back the full example.
from langchain.vectorstores import FAISS

vector_store = FAISS.from_texts(
    to_vectorize,
    embeddings,
    metadatas=few_shots,
)

In [66]:
# Semantic-similarity example selector: for an incoming question, retrieve the
# k most similar few-shot examples from the vector store.
from langchain.prompts import SemanticSimilarityExampleSelector

example_selector = SemanticSimilarityExampleSelector(vectorstore=vector_store, k=2)

# Sanity check: which stored examples are closest to this question?
example_selector.select_examples(
    {"Question": "How many orders were placed from the store?"}
)

[{'Question': 'How many orders were placed in total?',
  'SQLQuery': 'SELECT COUNT(*) FROM orders',
  'SQLResult': 'Result of the SQL query',
  'Answer': 'Answer'},
 {'Question': 'How many unique customers placed orders?',
  'SQLQuery': 'SELECT COUNT(DISTINCT Customer_ID) FROM orders',
  'SQLResult': 'Result of the SQL query',
  'Answer': 'Answer'}]

### Setting up Prompts

Add custom prompts to give the LLM better context.

In [67]:
# Prompt-building classes, plus LangChain's built-in MySQL prompt pieces for reference.
from langchain.chains.sql_database.prompt import PROMPT_SUFFIX, _mysql_prompt
from langchain.prompts import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate

# Inspect LangChain's default MySQL prompt (a private name) as the starting point for a custom one.
print(_mysql_prompt)

You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of

In [68]:
# Show LangChain's default prompt suffix, which carries the {table_info} and {input} placeholders.
print(PROMPT_SUFFIX)

Only use the following tables:
{table_info}

Question: {input}


In [69]:
mysql_prompt = """You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".
For results with decimals, round off the numbers to 2 decimal places unless specified otherwise by the user.
Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery

"""

In [70]:
# Template used to render each selected few-shot example inside the prompt.
# Only Question/SQLQuery/SQLResult are rendered; any other keys in the example dicts
# (e.g. 'Answer') are not part of this template.
example_prompt = PromptTemplate(
    template="\nQuestion: {Question}\nSQLQuery: {SQLQuery}\nSQLResult: {SQLResult}",
    input_variables=["Question", "SQLQuery", "SQLResult"],
)

In [71]:
# Full few-shot prompt: custom MySQL instructions (prefix), then the most similar
# examples picked per question, then LangChain's table-info/question suffix.
few_shot_prompt = FewShotPromptTemplate(
    prefix=mysql_prompt,
    example_selector=example_selector,
    example_prompt=example_prompt,
    suffix=PROMPT_SUFFIX,
    # {top_k} appears in the prefix; {table_info} and {input} appear in the suffix.
    input_variables=["input", "table_info", "top_k"],
)

In [74]:
# Few-shot SQL chain: same LLM and database as db_chain, but driven by few_shot_prompt.
new_chain = SQLDatabaseChain.from_llm(
    llm,
    db,
    prompt=few_shot_prompt,
    verbose=True,
    return_direct=True,
)

In [75]:
# Try the few-shot chain on a group-and-rank question.
new_chain(
    "Which city has the highest total sales?"
)



[1m> Entering new SQLDatabaseChain chain...[0m
Which city has the highest total sales?
SQLQuery:[32;1m[1;3mSELECT City, SUM(Sales) AS TotalSales FROM orders GROUP BY City ORDER BY TotalSales DESC LIMIT 1[0m
SQLResult: [33;1m[1;3m[('New York City', 255248.969)][0m
[1m> Finished chain.[0m


{'query': 'Which city has the highest total sales?',
 'result': "[('New York City', 255248.969)]"}

In [76]:
#Function to return dataframe
def return_df(response, columns=None):
    """Convert a SQLDatabaseChain response into a pandas DataFrame.

    With ``return_direct=True`` the chain's ``'result'`` value is the string
    repr of the SQL rows, e.g. ``"[('New York City', 255248.969)]"``.

    Args:
        response: dict returned by the chain; only its ``'result'`` key is read.
        columns: optional column labels passed to ``pd.DataFrame``; when omitted
            the columns are numbered 0..n-1, as before.

    Returns:
        A DataFrame with one row per result tuple, or an empty DataFrame when
        ``'result'`` is an empty string.

    Raises:
        ValueError / SyntaxError: if ``'result'`` is not a plain Python literal
        (e.g. it contains function calls).
    """
    import ast  # local import keeps this cell self-contained

    raw = response['result']
    # Assumed: the chain yields "" when the query returns no rows (TODO confirm for
    # the installed LangChain version). Handle it rather than failing to parse.
    if not raw:
        return pd.DataFrame(columns=columns)
    # literal_eval accepts only literals (tuples, lists, strings, numbers); eval()
    # would execute any code that ended up in this LLM/database-derived string.
    result_list = ast.literal_eval(raw)
    return pd.DataFrame(result_list, columns=columns)

In [77]:
# Ask the few-shot chain, then tabulate the raw SQL rows.
response = new_chain(
    "Which city has the highest total sales?"
)
df = return_df(response)
df.head()



[1m> Entering new SQLDatabaseChain chain...[0m
Which city has the highest total sales?
SQLQuery:[32;1m[1;3mSELECT City, SUM(Sales) AS TotalSales FROM orders GROUP BY City ORDER BY TotalSales DESC LIMIT 1[0m
SQLResult: [33;1m[1;3m[('New York City', 255248.969)][0m
[1m> Finished chain.[0m


Unnamed: 0,0,1
0,New York City,255248.969


In [78]:
# Distinct-count question through the few-shot chain, tabulated.
response = new_chain(
    "How many unique product names are there in the orders table?"
)
df = return_df(response)
df.head()



[1m> Entering new SQLDatabaseChain chain...[0m
How many unique product names are there in the orders table?
SQLQuery:[32;1m[1;3mSELECT COUNT(DISTINCT Product_Name) FROM orders[0m
SQLResult: [33;1m[1;3m[(1797,)][0m
[1m> Finished chain.[0m


Unnamed: 0,0
0,1797


In [79]:
# NOTE(review): in the recorded run this fails with "Unknown column 'Returned'".
# The generated SQL filters `orders` on a column that does not exist, while return
# data lives in the separate `returns` table queried earlier. Presumably the few-shot
# examples lack a returns-table example; TODO add one.
response = new_chain(
    "How many returns were made for each sub-category of office supplies?"
)
df = return_df(response)
df.head()



[1m> Entering new SQLDatabaseChain chain...[0m
How many returns were made for each sub-category of office supplies?
SQLQuery:[32;1m[1;3mSELECT Sub_Category, COUNT(*) FROM orders WHERE Category = 'Office Supplies' AND Returned = 'Yes' GROUP BY Sub_Category[0m

OperationalError: (pymysql.err.OperationalError) (1054, "Unknown column 'Returned' in 'where clause'")
[SQL: SELECT Sub_Category, COUNT(*) FROM orders WHERE Category = 'Office Supplies' AND Returned = 'Yes' GROUP BY Sub_Category]
(Background on this error at: https://sqlalche.me/e/20/e3q8)