In [1]:
import os
from langchain_google_genai import GoogleGenerativeAI
# Read the Gemini API key from the environment and fail fast when it is absent.
# A conditional assignment would leave GOOGLE_API_KEY undefined and only surface
# later as a NameError in the cell that builds the LLM client.
GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
if not GOOGLE_API_KEY:
    raise EnvironmentError(
        "GOOGLE_API_KEY is not set; export it before starting the notebook kernel."
    )

In [2]:
import os
from urllib.parse import quote_plus

from langchain.utilities import SQLDatabase

# LLM client that the SQL chains below use to turn questions into queries.
llm = GoogleGenerativeAI(model='gemini-pro', google_api_key=GOOGLE_API_KEY, temperature=0.2)

# Connection settings. Each value can be overridden through an environment
# variable so real credentials never have to be written into the notebook;
# the fallbacks are the original local-development values.
db_user = os.environ.get('NEPTECH_DB_USER', 'root')
db_password = os.environ.get('NEPTECH_DB_PASSWORD', 'password')
db_host = os.environ.get('NEPTECH_DB_HOST', 'localhost')
db_name = os.environ.get('NEPTECH_DB_NAME', 'neptech_store')

# User and password are URL-escaped: characters such as '@', ':' or '/' would
# otherwise be read as part of the URL structure and break the connection string.
# sample_rows_in_table_info=3 adds three example rows per table to db.table_info.
db = SQLDatabase.from_uri(
    f"mysql+pymysql://{quote_plus(db_user)}:{quote_plus(db_password)}@{db_host}/{db_name}",
    sample_rows_in_table_info=3,
)

In [3]:
# Display the schema text the chain will see: CREATE TABLE statements plus sample rows per table.
db.table_info

"\nCREATE TABLE laptop_discounts (\n\tdiscount_id INTEGER NOT NULL AUTO_INCREMENT, \n\tlaptop_id INTEGER NOT NULL, \n\tpct_discount DECIMAL(5, 2), \n\tPRIMARY KEY (discount_id), \n\tCONSTRAINT laptop_discounts_ibfk_1 FOREIGN KEY(laptop_id) REFERENCES laptops (laptop_id), \n\tCONSTRAINT laptop_discounts_chk_1 CHECK ((`pct_discount` between 0 and 100))\n)DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB\n\n/*\n3 rows from laptop_discounts table:\ndiscount_id\tlaptop_id\tpct_discount\n1\t1\t10.00\n2\t2\t20.00\n3\t3\t15.00\n*/\n\n\nCREATE TABLE laptops (\n\tlaptop_id INTEGER NOT NULL AUTO_INCREMENT, \n\tbrand ENUM('Dell','HP','Lenovo','Apple','Asus') NOT NULL, \n\tprocessor ENUM('i3','i5','i7','Ryzen 5','Ryzen 7') NOT NULL, \n\tram ENUM('4GB','8GB','16GB','32GB') NOT NULL, \n\tstorage ENUM('256GB SSD','512GB SSD','1TB SSD','1TB HDD') NOT NULL, \n\tprice INTEGER, \n\tstock_quantity INTEGER NOT NULL, \n\tPRIMARY KEY (laptop_id), \n\tCONSTRAINT laptops_chk_1 CHECK ((`price` bet

In [4]:
from langchain_experimental.sql import SQLDatabaseChain
# Baseline chain with LangChain's default prompt.
# verbose=True prints the intermediate SQLQuery / SQLResult / Answer steps on every call.
db_chain = SQLDatabaseChain.from_llm(
    llm=llm,
    db=db,
    verbose=True,
)

In [5]:
# Simple count question. The trailing instruction asks the model to emit the SQL
# without quotes or any other wrapping.
result = db_chain.invoke(
    "how many Dell brand laptops are there?"
    "The SQL query should be outputted plainly, "
    "do not surround SQL Query in quotes or anything else."
)
print(result)



[1m> Entering new SQLDatabaseChain chain...[0m
how many Dell brand laptops are there?The SQL query should be outputted plainly, do not surround SQL Query in quotes or anything else.
SQLQuery:[32;1m[1;3mSELECT COUNT(*) AS `Dell_brand_laptops` 
FROM laptops 
WHERE brand = 'Dell';[0m
SQLResult: [33;1m[1;3m[(14,)][0m
Answer:[32;1m[1;3m14[0m
[1m> Finished chain.[0m
{'query': 'how many Dell brand laptops are there?The SQL query should be outputted plainly, do not surround SQL Query in quotes or anything else.', 'result': '14'}


In [23]:
# Two-part question (count + list). In the recorded run the generated SQL only
# counted the distinct brands, so the brand names in the answer did not come from
# a query result -- treat that list with caution.
result2 = db_chain.invoke(
    "How many unique brands are there? list them."
    "do not surround SQL Query in quotes or anything else."
)
print(result2)



[1m> Entering new SQLDatabaseChain chain...[0m
How many unique brands are there? list them.do not surround SQL Query in quotes or anything else.
SQLQuery:[32;1m[1;3mSELECT COUNT(DISTINCT `brand`) AS `num_unique_brands` FROM `laptops`;[0m
SQLResult: [33;1m[1;3m[(5,)][0m
Answer:[32;1m[1;3m5
Dell
HP
Lenovo
Apple
Asus[0m
[1m> Finished chain.[0m
{'query': 'How many unique brands are there? list them.do not surround SQL Query in quotes or anything else.', 'result': '5\nDell\nHP\nLenovo\nApple\nAsus'}


In [24]:
# Aggregate question: in the recorded run it was answered with SUM(price * stock_quantity)
# over the Dell rows.
result3 = db_chain.invoke(
    "How much is the price of inventory for all Dell Laptops? "
    "do not surround SQL Query in quotes or anything else."
)
print(result3)



[1m> Entering new SQLDatabaseChain chain...[0m
How much is the price of inventory for all Dell Laptops? do not surround SQL Query in quotes or anything else.
SQLQuery:[32;1m[1;3mSELECT SUM(`price` * `stock_quantity`) AS `total_inventory_price`
FROM `laptops`
WHERE `brand` = 'Dell';[0m
SQLResult: [33;1m[1;3m[(Decimal('788198'),)][0m
Answer:[32;1m[1;3m788198[0m
[1m> Finished chain.[0m
{'query': 'How much is the price of inventory for all Dell Laptops? do not surround SQL Query in quotes or anything else.', 'result': '788198'}


In [25]:
# FAILURE case for the default prompt. In the recorded run the generated SQL used
# (1 - pct_discount) although pct_discount is stored as a percentage (0-100), and it
# did not multiply by stock_quantity, which produced a negative "revenue".
result4 = db_chain.invoke(
    "If I sell all my Apple laptops today with discounts applied, "
    "how much revenue our store will generate (post discounts)?"
    "do not surround SQL Query in quotes or anything else."
)
print(result4)



[1m> Entering new SQLDatabaseChain chain...[0m
If I sell all my Apple laptops today with discounts applied, how much revenue our store will generate (post discounts)?do not surround SQL Query in quotes or anything else.
SQLQuery:[32;1m[1;3mSELECT SUM(`laptops`.`price` * (1 - `laptop_discounts`.`pct_discount`)) AS `total_revenue`
FROM `laptops`
JOIN `laptop_discounts` ON `laptops`.`laptop_id` = `laptop_discounts`.`laptop_id`
WHERE `laptops`.`brand` = 'Apple';[0m
SQLResult: [33;1m[1;3m[(Decimal('-39237.00'),)][0m
Answer:[32;1m[1;3m-39237.00[0m
[1m> Finished chain.[0m
{'query': 'If I sell all my Apple laptops today with discounts applied, how much revenue our store will generate (post discounts)?do not surround SQL Query in quotes or anything else.', 'result': '-39237.00'}


# FEW SHOT LEARNING

In [19]:
# Hand-written question/SQL/answer examples from the local few_shots module.
# The first entry, displayed below, has 'Question', 'SQLQuery', 'SQLResult' and 'Answer' keys.
from few_shots import few_shots_list
few_shots_list[0]

{'Question': 'How many Asus laptops with 8GB RAM and 512GB SSD are left in stock?',
 'SQLQuery': "SELECT sum(stock_quantity) FROM laptops WHERE brand = 'Asus' AND ram = '8GB' AND storage = '512GB SSD'",
 'SQLResult': 'Result of the SQL query',
 'Answer': '95'}

In [20]:
from langchain.embeddings import HuggingFaceEmbeddings
# Embedding model used to vectorize the few-shot examples for similarity search.
embeddings = HuggingFaceEmbeddings(
    model_name='sentence-transformers/all-MiniLM-L6-v2',
)



In [21]:
# Flatten every few-shot example into one space-separated string for embedding.
# Values go through str() so a non-string field (for example a numeric Answer)
# cannot make str.join raise TypeError; string values are left unchanged.
to_vectorize = [" ".join(map(str, example.values())) for example in few_shots_list]
to_vectorize[0]

"How many Asus laptops with 8GB RAM and 512GB SSD are left in stock? SELECT sum(stock_quantity) FROM laptops WHERE brand = 'Asus' AND ram = '8GB' AND storage = '512GB SSD' Result of the SQL query 95"

In [22]:
from langchain.vectorstores import Chroma

# Index the flattened examples; the original dicts are stored as metadata so the
# example selector can return them unchanged.
# Stable, position-based ids make this cell safe to re-run: without explicit ids
# each call stores the texts under newly generated ids, so executing the cell again
# in the same kernel adds the same examples a second time (the likely reason the
# selector check further down returns one example twice).
vectorstore = Chroma.from_texts(
    to_vectorize,
    embedding=embeddings,
    metadatas=few_shots_list,
    ids=[f"few-shot-{position}" for position in range(len(to_vectorize))],
)

In [23]:
from langchain.prompts import SemanticSimilarityExampleSelector

# Selector that picks the k=2 stored examples most similar to an incoming question.
example_selector = SemanticSimilarityExampleSelector(vectorstore=vectorstore, k=2)

# Sanity check: display which examples are chosen for a sample question.
example_selector.select_examples(
    {"Question": "How many Apple laptops I have in my store?"}
)

[{'Answer': '556261',
  'Question': 'How much revenue will the store generate if I sell all Apple laptops today without any discounts?',
  'SQLQuery': "SELECT SUM(price * stock_quantity) FROM laptops WHERE brand = 'Apple'",
  'SQLResult': 'Result of the SQL query'},
 {'Answer': '556261',
  'Question': 'How much revenue will the store generate if I sell all Apple laptops today without any discounts?',
  'SQLQuery': "SELECT SUM(price * stock_quantity) FROM laptops WHERE brand = 'Apple'",
  'SQLResult': 'Result of the SQL query'}]

In [24]:
# LangChain's stock MySQL prompt text and suffix; both are reused as prefix/suffix
# of the few-shot prompt built further down.
# NOTE(review): _mysql_prompt is an underscore-prefixed (private) name -- presumably not a
# stable public API, so confirm it still exists after upgrading langchain.
from langchain.chains.sql_database.prompt import PROMPT_SUFFIX,_mysql_prompt
print(_mysql_prompt)

You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of

In [25]:
# Show the suffix text; it carries the {table_info} and {input} placeholders.
print(PROMPT_SUFFIX)

Only use the following tables:
{table_info}

Question: {input}


In [26]:
from langchain.prompts.prompt import PromptTemplate

# Template that renders one few-shot example in the same
# Question / SQLQuery / SQLResult / Answer layout the MySQL prompt asks for.
example_prompt = PromptTemplate(
    template=(
        "\nQuestion: {Question}"
        "\nSQLQuery: {SQLQuery}"
        "\nSQLResult: {SQLResult}"
        "\nAnswer: {Answer}"
    ),
    input_variables=["Question", "SQLQuery", "SQLResult", "Answer"],
)

In [27]:
from langchain.prompts import FewShotPromptTemplate

# Full prompt = stock MySQL instructions + the selected examples + stock suffix.
few_shot_prompt = FewShotPromptTemplate(
    # Fixed text placed around the examples.
    prefix=_mysql_prompt,
    suffix=PROMPT_SUFFIX,
    # Examples are chosen per question and rendered with example_prompt.
    example_selector=example_selector,
    example_prompt=example_prompt,
    # Placeholders used by the prefix ({top_k}) and the suffix ({table_info}, {input}).
    input_variables=["input", "table_info", "top_k"],
)

In [28]:
# Same chain setup as db_chain, but driven by the few-shot prompt instead of the default one.
new_chain = SQLDatabaseChain.from_llm(
    llm=llm,
    db=db,
    verbose=True,
    prompt=few_shot_prompt,
)

In [30]:
# Inventory-value question through the few-shot chain; note that no extra
# "do not surround SQL Query in quotes" instruction is appended here.
res = new_chain.invoke("How much is the price of inventory for all Asus laptops?")
print(res)



[1m> Entering new SQLDatabaseChain chain...[0m
How much is the price of inventory for all Asus laptops?
SQLQuery:[32;1m[1;3mSELECT SUM(price * stock_quantity) AS total_inventory_price FROM laptops WHERE brand = 'Asus';[0m
SQLResult: [33;1m[1;3m[(Decimal('1062553'),)][0m
Answer:[32;1m[1;3m1062553[0m
[1m> Finished chain.[0m
{'query': 'How much is the price of inventory for all Asus laptops?', 'result': '1062553'}


In [31]:
# Discounted-revenue question through the few-shot chain. In the recorded run the
# generated SQL joined a non-existent `discounts` table (the schema names it
# `laptop_discounts`), so the database rejected the query.
# The error is caught and printed so that "Restart & Run All" can continue past this
# demonstration instead of stopping on a traceback; res2 is None when the call fails.
try:
    res2 = new_chain.invoke(
        "If I sell all my Apple laptops today with discounts applied, "
        "how much revenue our store will generate (post discounts)?"
    )
except Exception as exc:  # broad on purpose: the driver's error class is not imported in this notebook
    res2 = None
    print(f"{type(exc).__name__}: {exc}")



[1m> Entering new SQLDatabaseChain chain...[0m
If I sell all my Apple laptops today with discounts applied, how much revenue our store will generate (post discounts)?
SQLQuery:[32;1m[1;3mSELECT sum(a.total_amount * ((100 - COALESCE(discounts.pct_discount, 0)) / 100)) as total_revenue 
                     FROM (SELECT SUM(price * stock_quantity) as total_amount, laptop_id 
                           FROM laptops WHERE brand = 'Apple' 
                           GROUP BY laptop_id) a 
                     LEFT JOIN discounts ON a.laptop_id = discounts.laptop_id[0m

ProgrammingError: (pymysql.err.ProgrammingError) (1146, "Table 'neptech_store.discounts' doesn't exist")
[SQL: SELECT sum(a.total_amount * ((100 - COALESCE(discounts.pct_discount, 0)) / 100)) as total_revenue 
                     FROM (SELECT SUM(price * stock_quantity) as total_amount, laptop_id 
                           FROM laptops WHERE brand = 'Apple' 
                           GROUP BY laptop_id) a 
                     LEFT JOIN discounts ON a.laptop_id = discounts.laptop_id]
(Background on this error at: https://sqlalche.me/e/20/f405)

In [32]:
# Print the schema again to check the real table names against the failed query above
# (the discount table is called laptop_discounts).
print(db.table_info)


CREATE TABLE laptop_discounts (
	discount_id INTEGER NOT NULL AUTO_INCREMENT, 
	laptop_id INTEGER NOT NULL, 
	pct_discount DECIMAL(5, 2), 
	PRIMARY KEY (discount_id), 
	CONSTRAINT laptop_discounts_ibfk_1 FOREIGN KEY(laptop_id) REFERENCES laptops (laptop_id), 
	CONSTRAINT laptop_discounts_chk_1 CHECK ((`pct_discount` between 0 and 100))
)DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB

/*
3 rows from laptop_discounts table:
discount_id	laptop_id	pct_discount
1	1	10.00
2	2	20.00
3	3	15.00
*/


CREATE TABLE laptops (
	laptop_id INTEGER NOT NULL AUTO_INCREMENT, 
	brand ENUM('Dell','HP','Lenovo','Apple','Asus') NOT NULL, 
	processor ENUM('i3','i5','i7','Ryzen 5','Ryzen 7') NOT NULL, 
	ram ENUM('4GB','8GB','16GB','32GB') NOT NULL, 
	storage ENUM('256GB SSD','512GB SSD','1TB SSD','1TB HDD') NOT NULL, 
	price INTEGER, 
	stock_quantity INTEGER NOT NULL, 
	PRIMARY KEY (laptop_id), 
	CONSTRAINT laptops_chk_1 CHECK ((`price` between 300 and 2000))
)DEFAULT CHARSET=utf8mb4 COLLATE