In [1]:
import os
from getpass import getpass
from urllib.parse import quote_plus

from langchain.llms import GooglePalm
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.prompts import SemanticSimilarityExampleSelector
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains.sql_database.prompt import PROMPT_SUFFIX, _mysql_prompt
from langchain.prompts.prompt import PromptTemplate
from langchain.prompts import FewShotPromptTemplate

In [2]:
# Read the PaLM API key from the environment, or prompt for it, instead of hard-coding it:
# a key written into a notebook ends up in version control and in shared copies.
# NOTE(review): the key that was previously hard-coded in this cell should be treated as leaked and revoked.
api_key = os.environ.get("GOOGLE_API_KEY") or getpass("Google PaLM API key: ")

# A low temperature reduces randomness in the generated SQL.
llm = GooglePalm(google_api_key=api_key, temperature=0.2)

In [3]:
# Connect to DB.
# Connection settings come from the environment so credentials are not stored in the notebook;
# the defaults reproduce the original local development setup.
db_user = os.environ.get("DB_USER", "root")
db_password = os.environ.get("DB_PASSWORD", "root")
db_host = os.environ.get("DB_HOST", "localhost")
db_name = os.environ.get("DB_NAME", "my_bookstore")

# quote_plus escapes characters such as '@', ':' or '/' in the user name or password that would
# otherwise corrupt the connection URI.
# sample_rows_in_table_info=3: number of example rows per table included in the table info.
db = SQLDatabase.from_uri(
    f"mysql+pymysql://{quote_plus(db_user)}:{quote_plus(db_password)}@{db_host}/{db_name}",
    sample_rows_in_table_info=3,
)

In [4]:
# Baseline chain with the default prompt (no few-shot examples); verbose=True prints the
# generated SQL, the SQL result and the answer for every call.
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)

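> Detecting the errors the LLM makes: each question is first sent to the baseline chain, then re-run with hand-corrected SQL whose answer is kept as the reference.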


In [5]:
# In the recorded run the model hard-coded price, discount and quantity (25.00, 10, 60) into the
# SQL instead of reading them from the tables. The returned dict is overwritten by the next cell.
qns1 = db_chain("What is the revenue i get if i sell all the books named 'book0' with the discount?")



[1m> Entering new SQLDatabaseChain chain...[0m
What is the revenue i get if i sell all the books named 'book0' with the discount?
SQLQuery:[32;1m[1;3mSELECT (25.00 - (25.00 * 10.00 / 100)) * 60 AS revenue[0m
SQLResult: [33;1m[1;3m[(Decimal('1350.00000000'),)][0m
Answer:[32;1m[1;3m1350.00[0m
[1m> Finished chain.[0m


In [6]:
# Reference answer: pass the hand-corrected SQL to the chain as the "question". In the recorded run
# the model reproduced it unchanged (see SQLQuery in the output), so the answer string is reused
# as the few-shot answer.
# NOTE(review): the LLM is not guaranteed to echo the SQL verbatim; db.run(...) would execute it
# directly, but returns the raw result string rather than a formatted answer.
qns1 = db_chain.run("SELECT price*(1 - pct_discount/100)*stock_quantity AS revenue FROM books JOIN discounts ON books.book_id = discounts.book_id WHERE title = 'Book0'")



[1m> Entering new SQLDatabaseChain chain...[0m
SELECT price*(1 - pct_discount/100)*stock_quantity AS revenue FROM books JOIN discounts ON books.book_id = discounts.book_id WHERE title = 'Book0'
SQLQuery:[32;1m[1;3mSELECT price*(1 - pct_discount/100)*stock_quantity AS revenue FROM books JOIN discounts ON books.book_id = discounts.book_id WHERE title = 'Book0'[0m
SQLResult: [33;1m[1;3m[(Decimal('1350.00000000'),)][0m
Answer:[32;1m[1;3m1350.00[0m
[1m> Finished chain.[0m


In [7]:
# In the recorded run the model counted book rows (count(*)) instead of summing stock_quantity.
qns2 = db_chain("What is the total number of fiction books that are in the store?")



[1m> Entering new SQLDatabaseChain chain...[0m
What is the total number of fiction books that are in the store?
SQLQuery:[32;1m[1;3mSELECT count(*) FROM books WHERE genre = 'Fiction'[0m
SQLResult: [33;1m[1;3m[(23,)][0m
Answer:[32;1m[1;3m23[0m
[1m> Finished chain.[0m


In [8]:
# Reference answer: total copies in stock is SUM(stock_quantity), not the number of rows.
qns2 = db_chain.run("SELECT SUM(stock_quantity) FROM books WHERE genre = 'Fiction'")



[1m> Entering new SQLDatabaseChain chain...[0m
SELECT SUM(stock_quantity) FROM books WHERE genre = 'Fiction'
SQLQuery:[32;1m[1;3mSELECT SUM(stock_quantity) FROM books WHERE genre = 'Fiction'[0m
SQLResult: [33;1m[1;3m[(Decimal('1349'),)][0m
Answer:[32;1m[1;3m1349[0m
[1m> Finished chain.[0m


In [9]:
# In the recorded run the model treated pct_discount as a fraction, (1 - pct_discount), instead of
# a percentage, which produced a negative revenue and the wrong book.
qns4 = db_chain("if i sold all the quantities in the store with discount, what book will make me the most money?")



[1m> Entering new SQLDatabaseChain chain...[0m
if i sold all the quantities in the store with discount, what book will make me the most money?
SQLQuery:[32;1m[1;3mSELECT title, SUM(price * stock_quantity * (1 - pct_discount)) AS total_revenue FROM books JOIN discounts ON books.book_id = discounts.book_id GROUP BY title ORDER BY total_revenue DESC LIMIT 1[0m
SQLResult: [33;1m[1;3m[('Book5', Decimal('-2070.0000'))][0m
Answer:[32;1m[1;3mBook5[0m
[1m> Finished chain.[0m


In [10]:
# Reference answer: pct_discount is a percentage, so it is divided by 100.
qns4 = db_chain.run("SELECT title, SUM(price * stock_quantity * (1 - pct_discount/100)) AS total_revenue FROM books JOIN discounts ON books.book_id = discounts.book_id GROUP BY title ORDER BY total_revenue DESC LIMIT 1")



[1m> Entering new SQLDatabaseChain chain...[0m
SELECT title, SUM(price * stock_quantity * (1 - pct_discount/100)) AS total_revenue FROM books JOIN discounts ON books.book_id = discounts.book_id GROUP BY title ORDER BY total_revenue DESC LIMIT 1
SQLQuery:[32;1m[1;3mSELECT title, SUM(price * stock_quantity * (1 - pct_discount/100)) AS total_revenue FROM books JOIN discounts ON books.book_id = discounts.book_id GROUP BY title ORDER BY total_revenue DESC LIMIT 1[0m
SQLResult: [33;1m[1;3m[('Book3', Decimal('3344.00000000'))][0m
Answer:[32;1m[1;3mBook3[0m
[1m> Finished chain.[0m


In [11]:
# In the recorded run the model counted books instead of summing their stock_quantity.
qns5 = db_chain("whats the total quantity of books that are written by authors who were born in 1967")



[1m> Entering new SQLDatabaseChain chain...[0m
whats the total quantity of books that are written by authors who were born in 1967
SQLQuery:[32;1m[1;3mSELECT count(*) FROM books WHERE author_id IN (SELECT author_id FROM authors WHERE birth_year = 1967)[0m
SQLResult: [33;1m[1;3m[(13,)][0m
Answer:[32;1m[1;3m13[0m
[1m> Finished chain.[0m


In [12]:
# Reference answer: "total quantity" means SUM(stock_quantity).
qns5 = db_chain.run("SELECT SUM(stock_quantity) FROM books WHERE author_id IN (SELECT author_id FROM authors WHERE birth_year = 1967)")



[1m> Entering new SQLDatabaseChain chain...[0m
SELECT SUM(stock_quantity) FROM books WHERE author_id IN (SELECT author_id FROM authors WHERE birth_year = 1967)
SQLQuery:[32;1m[1;3mSELECT SUM(stock_quantity) FROM books WHERE author_id IN (SELECT author_id FROM authors WHERE birth_year = 1967)[0m
SQLResult: [33;1m[1;3m[(Decimal('696'),)][0m
Answer:[32;1m[1;3m696[0m
[1m> Finished chain.[0m


In [13]:
# In the recorded run the model assumed author_id = 9 instead of matching author_name, returning 8
# where the corrected query below returns 10.
qns6 = db_chain("how many books were written by the authors name is 'author9'")



[1m> Entering new SQLDatabaseChain chain...[0m
how many books were written by the authors name is 'author9'
SQLQuery:[32;1m[1;3mSELECT count(*) FROM books WHERE author_id = 9[0m
SQLResult: [33;1m[1;3m[(8,)][0m
Answer:[32;1m[1;3m8[0m
[1m> Finished chain.[0m


In [14]:
# Reference answer: join authors and filter on the author's name, not on a guessed id.
qns6 = db_chain.run("SELECT count(*) FROM books JOIN authors ON books.author_id = authors.author_id WHERE author_name = 'Author9';")



[1m> Entering new SQLDatabaseChain chain...[0m
SELECT count(*) FROM books JOIN authors ON books.author_id = authors.author_id WHERE author_name = 'Author9';
SQLQuery:[32;1m[1;3mSELECT count(*) FROM books JOIN authors ON books.author_id = authors.author_id WHERE author_name = 'Author9';[0m
SQLResult: [33;1m[1;3m[(10,)][0m
Answer:[32;1m[1;3m10[0m
[1m> Finished chain.[0m


In [None]:
# The baseline chain may fail on this vague, misspelled question (e.g. by generating invalid SQL).
# The failure is the point of this demonstration, so display it instead of aborting "Run All";
# qns7 is reassigned from the corrected SQL in the following cell either way.
try:
    qns7 = db_chain("whos books we have the most by quantiity wise")
except Exception as exc:  # broad on purpose: any LLM/SQL/driver error is what we want to show
    print(f"Baseline chain failed: {type(exc).__name__}: {exc}")

In [16]:
# Reference answer: the author with the largest total stock (sum of stock_quantity over their books).
qns7 = db_chain.run("SELECT author_name, SUM(stock_quantity) AS total_stock FROM authors JOIN books USING (author_id) GROUP BY author_name ORDER BY total_stock DESC LIMIT 1;")



[1m> Entering new SQLDatabaseChain chain...[0m
SELECT author_name, SUM(stock_quantity) AS total_stock FROM authors JOIN books USING (author_id) GROUP BY author_name ORDER BY total_stock DESC LIMIT 1;
SQLQuery:[32;1m[1;3mSELECT author_name, SUM(stock_quantity) AS total_stock FROM authors JOIN books USING (author_id) GROUP BY author_name ORDER BY total_stock DESC LIMIT 1;[0m
SQLResult: [33;1m[1;3m[('Author3', Decimal('782'))][0m
Answer:[32;1m[1;3mAuthor3[0m
[1m> Finished chain.[0m


In [17]:
# In the recorded run the model grouped the authors table without joining books, so COUNT(*)
# counted author rows rather than books.
qns8 = db_chain("which author wrote the highest number of books?")



[1m> Entering new SQLDatabaseChain chain...[0m
which author wrote the highest number of books?
SQLQuery:[32;1m[1;3mSELECT author_name FROM authors GROUP BY author_id ORDER BY COUNT(*) DESC LIMIT 1[0m
SQLResult: [33;1m[1;3m[('Author0',)][0m
Answer:[32;1m[1;3mAuthor0[0m
[1m> Finished chain.[0m


In [18]:
# Reference answer: join books and count them per author.
# NOTE(review): the alias `total_stock` actually holds a book count, not a stock quantity.
qns8 = db_chain.run("SELECT author_name, COUNT(*) AS total_stock FROM authors JOIN books USING (author_id) GROUP BY author_name ORDER BY total_stock DESC LIMIT 1;")



[1m> Entering new SQLDatabaseChain chain...[0m
SELECT author_name, COUNT(*) AS total_stock FROM authors JOIN books USING (author_id) GROUP BY author_name ORDER BY total_stock DESC LIMIT 1;
SQLQuery:[32;1m[1;3mSELECT author_name, COUNT(*) AS total_stock FROM authors JOIN books USING (author_id) GROUP BY author_name ORDER BY total_stock DESC LIMIT 1;[0m
SQLResult: [33;1m[1;3m[('Author3', 14)][0m
Answer:[32;1m[1;3mAuthor3[0m
[1m> Finished chain.[0m


> Setting up few-shot learning: pair each question with its corrected SQL and reference answer

In [19]:
# Curated few-shot examples. Each entry pairs a question with its hand-corrected SQL and the
# answer string the baseline chain returned for that SQL. 'SQLResult' is a fixed placeholder.
# Key order (Question, SQLQuery, SQLResult, Answer) is preserved because the values are later
# joined in that order to build the text that gets embedded.
few_shots = [
    {
        'Question': question,
        'SQLQuery': sql_query,
        'SQLResult': "Result of the SQL query",
        'Answer': answer,
    }
    for question, sql_query, answer in (
        (
            "What is the revenue i get if i sell all the books named 'book0' with the discount?",
            "SELECT price*(1 - pct_discount/100)*stock_quantity AS revenue FROM books JOIN discounts ON books.book_id = discounts.book_id WHERE title = 'Book0'",
            qns1,
        ),
        (
            "What is the total number of fiction books that are in the store?",
            "SELECT SUM(stock_quantity) FROM books WHERE genre = 'Fiction'",
            qns2,
        ),
        (
            "if i sold all the quantities in the store with discount, what book will make me the most money?",
            "SELECT title, SUM(price * stock_quantity * (1 - pct_discount/100)) AS total_revenue FROM books JOIN discounts ON books.book_id = discounts.book_id GROUP BY title ORDER BY total_revenue DESC LIMIT 1",
            qns4,
        ),
        (
            "whats the total quantity of books that are written by authors who were born in 1967",
            "SELECT SUM(stock_quantity) FROM books WHERE author_id IN (SELECT author_id FROM authors WHERE birth_year = 1967)",
            qns5,
        ),
        (
            "how many books were written by the authors name is 'author9'",
            "SELECT count(*) FROM books JOIN authors ON books.author_id = authors.author_id WHERE author_name = 'Author9';",
            qns6,
        ),
        (
            "whos books we have the most by quantiity wise",
            "SELECT author_name, SUM(stock_quantity) AS total_stock FROM authors JOIN books USING (author_id) GROUP BY author_name ORDER BY total_stock DESC LIMIT 1;",
            qns7,
        ),
        (
            "which author wrote the highest number of books?",
            "SELECT author_name, COUNT(*) AS total_stock FROM authors JOIN books USING (author_id) GROUP BY author_name ORDER BY total_stock DESC LIMIT 1;",
            qns8,
        ),
    )
]

In [20]:
# One text per example: its field values, in key order, joined by single spaces.
to_vectorize = list(map(lambda shot: " ".join(shot.values()), few_shots))

In [21]:
# Sentence-transformers model used to embed the few-shot example texts for similarity search.
embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')

In [22]:
# Index the example texts; each example dict is attached as its document's metadata, which is what
# the example selector later returns. (The misspelled name "vectorestore" is what the selector cell uses.)
# NOTE(review): no collection_name is passed, so the default collection is used; re-running this
# cell in the same kernel may add duplicate documents — confirm against the installed Chroma version.
vectorestore = Chroma.from_texts(to_vectorize, embedding=embeddings, metadatas=few_shots)

In [23]:
# Select the stored examples most similar to the incoming question.
# input_keys=["input"] restricts the similarity query to the user's question. Without it the
# selector builds its query from every prompt variable it receives — here also the long
# `table_info` schema and `top_k` — so the schema text dominates the embedding and the same
# examples tend to be picked regardless of the question.
example_selector = SemanticSimilarityExampleSelector(
    vectorstore=vectorestore,
    k=2,        # number of examples to select per question
    input_keys=["input"],
)

# Test the example selector (the key must match input_keys above)
example_selector.select_examples({"input": "how many books were written by the authors name is 'author7'"})

[{'Answer': '10',
  'Question': "how many books were written by the authors name is 'author9'",
  'SQLQuery': "SELECT count(*) FROM books JOIN authors ON books.author_id = authors.author_id WHERE author_name = 'Author9';",
  'SQLResult': 'Result of the SQL query'},
 {'Answer': '696',
  'Question': 'whats the total quantity of books that are written by authors who were born in 1967',
  'SQLQuery': 'SELECT SUM(stock_quantity) FROM books WHERE author_id IN (SELECT author_id FROM authors WHERE birth_year = 1967)',
  'SQLResult': 'Result of the SQL query'}]

In [24]:
# Inspect LangChain's built-in MySQL instruction prompt, used below as the few-shot prefix.
# print() renders the newlines instead of showing the escaped string repr.
print(_mysql_prompt)    # Prompt

You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of

In [25]:
# Renders one example in the same Question / SQLQuery / SQLResult / Answer layout the MySQL
# prompt asks the model to follow; each example starts on a fresh line.
example_prompt = PromptTemplate(
    template=(
        "\nQuestion: {Question}"
        "\nSQLQuery: {SQLQuery}"
        "\nSQLResult: {SQLResult}"
        "\nAnswer: {Answer}"
    ),
    input_variables=["Question", "SQLQuery", "SQLResult", "Answer"],
)

In [26]:
# Complete prompt = MySQL instructions (prefix)
#                 + the examples chosen by the selector, each rendered by example_prompt
#                 + LangChain's standard SQL suffix.
# input_variables are the placeholders filled in at query time.
few_shots_prompt = FewShotPromptTemplate(
    prefix=_mysql_prompt,
    example_selector=example_selector,
    example_prompt=example_prompt,
    suffix=PROMPT_SUFFIX,
    input_variables=["input", "table_info", "top_k"],
)

In [27]:
# Few-Shot LLM
# Few-shot chain: same LLM and database as db_chain, but prompted with the selected curated examples.
tuned_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, prompt=few_shots_prompt)

> Testing the few-shot chain on new questions

In [28]:
# Unseen variant of the "most money" example; in the recorded run the generated SQL now divides
# pct_discount by 100.
tuned_chain("if i sold all the quantities in the store with discount, what book will make me the least money?")



[1m> Entering new SQLDatabaseChain chain...[0m
if i sold all the quantities in the store with discount, what book will make me the least money?
SQLQuery:[32;1m[1;3mSELECT title, SUM(price * stock_quantity * (1 - pct_discount/100)) AS total_revenue FROM books JOIN discounts ON books.book_id = discounts.book_id GROUP BY title ORDER BY total_revenue LIMIT 1[0m
SQLResult: [33;1m[1;3m[('Book5', Decimal('207.00000000'))][0m
Answer:[32;1m[1;3mBook5[0m
[1m> Finished chain.[0m


{'query': 'if i sold all the quantities in the store with discount, what book will make me the least money?',
 'result': 'Book5'}

In [29]:
# Range-condition variant of the birth-year example (before 1970).
tuned_chain("whats the number of books that are written by authors who were born before 1970")



[1m> Entering new SQLDatabaseChain chain...[0m
whats the number of books that are written by authors who were born before 1970
SQLQuery:[32;1m[1;3mSELECT COUNT(*) FROM books WHERE author_id IN (SELECT author_id FROM authors WHERE birth_year < 1970)[0m
SQLResult: [33;1m[1;3m[(62,)][0m
Answer:[32;1m[1;3m62[0m
[1m> Finished chain.[0m


{'query': 'whats the number of books that are written by authors who were born before 1970',
 'result': '62'}

In [30]:
# Complementary range (1970 and later); together with the previous cell it should cover every book
# whose author has a birth_year.
tuned_chain("whats the number of books that are written by authors who were born in and after 1970")



[1m> Entering new SQLDatabaseChain chain...[0m
whats the number of books that are written by authors who were born in and after 1970
SQLQuery:[32;1m[1;3mSELECT COUNT(*) FROM books WHERE author_id IN (SELECT author_id FROM authors WHERE birth_year >= 1970)[0m
SQLResult: [33;1m[1;3m[(38,)][0m
Answer:[32;1m[1;3m38[0m
[1m> Finished chain.[0m


{'query': 'whats the number of books that are written by authors who were born in and after 1970',
 'result': '38'}

In [34]:
# Answers list to be used in main.py
qns_answers = [qns1, qns2, qns4, qns5, qns6, qns7, qns8]
qns_answers

['1350.00', '1349', 'Book3', '696', '10', 'Author3', 'Author3']