In [None]:
!pip install openai

In [None]:
!pip install langchain langchain_experimental

In [None]:
!pip install sentence-transformers chromadb

Reference: https://medium.com/@yernenip/few-shot-prompting-with-codellama-langchain-and-mysql-94020ee16a08

In [33]:
import os
os.environ["LANGCHAIN_TRACING_V2"]="true"
os.environ["LANGCHAIN_ENDPOINT"]="https://api.smith.langchain.com"
os.environ["LANGCHAIN_API_KEY"]="ls__922342c2ddcc4de0b72076736e6178c7"

In [34]:
from langchain.llms import OpenAI                         # Open AI (LLM)
# Set the OpenAI API key as an environment variable
os.environ["OPENAI_API_KEY"] = "USE-YOUR-OPEN-AI-KEY"

In [35]:
import os
import langchain
import openai
from langchain import OpenAI, SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
import sqlite3
from langchain.prompts import FewShotPromptTemplate
from langchain.chains.sql_database.prompt import PROMPT_SUFFIX, _mysql_prompt
from langchain import PromptTemplate, OpenAI
langchain.verbose = True

In [37]:
llm = OpenAI(temperature=0, verbose=True)

In [38]:
#Read database
def read_sql_query(sql, db):
    conn = sqlite3.connect(db)
    cur = conn.cursor()
    cur.execute(sql)
    rows = cur.fetchall()
    for row in rows:
        print(row)
    conn.close()
from langchain.utilities import SQLDatabase               # SQL Database
product_db = SQLDatabase.from_uri('sqlite:///product_db.sqlite')  # Create an instance of SQLDatabase engine

In [None]:
db_chain = SQLDatabaseChain.from_llm(llm, product_db, verbose=True, return_sql=False, use_query_checker=True)
db_chain.run("Find the average age of customers for each product preference category, and then identify customers who are above the average age within their respective product preference category?")

In [99]:
#Creating sample promts for Queries (not for agents):
examples = [
    {
        "input": "How many product categories?",
        "sql_cmd": "SELECT COUNT(DISTINCT product_category) FROM table_product",
        "result":"[(6,)]",
        "answer": "There are 6 product categories"
    },
    {
        "input": "How many customers between age 20-40?",
        "sql_cmd": "SELECT COUNT(*) FROM table_product WHERE age BETWEEN 20 AND 40",
        "result": "[(96,)]",
        "answer": "There are 96 customers with age between 20-40"
    },
    {
        "input": "How many customers are from New York and are male?",
        "sql_cmd": "SELECT COUNT(*) FROM table_product WHERE city = 'New York' AND gender = 'Male'",
        "result": "[(1,)]",
        "answer": "There is 1 customer from New York who is a male"
    },
    {
        "input": "How many unique customers are there from Canada in Product table?",
        "sql_cmd": "SELECT COUNT(DISTINCT customer_id) FROM table_product WHERE country = 'Canada'",
        "result": "[(26,)]",
        "answer": "There are 26 unique customers from Canada"
    },
    {
        "input": "Extract the username name of all emails from the database and find out how many are concatenation of lowecase first name and lowercase last name of users and they are from NY and is a female",
        "sql_cmd": "SELECT COUNT(*) FROM table_product WHERE email LIKE LOWER(first_name || last_name || '%') AND state = 'NY' AND gender = 'Female'",
        "result": "[(1,)]",
        "answer": "There is 1 such email addresses"
    }
]

In [105]:
print(product_db.table_info)


CREATE TABLE table_product (
	customer_id INTEGER, 
	first_name TEXT, 
	last_name TEXT, 
	email TEXT, 
	age INTEGER, 
	gender TEXT, 
	city TEXT, 
	state TEXT, 
	country TEXT, 
	product_category TEXT, 
	product_preference TEXT
)

/*
3 rows from table_product table:
customer_id	first_name	last_name	email	age	gender	city	state	country	product_category	product_preference
1	John	Doe	johndoe@example.com	35	Male	New York	NY	United States	Electronics	Smartphones
2	Jane	Smith	janesmith@example.com	28	Female	Los Angeles	CA	United States	Clothing	Dresses
3	Michael	Johnson	michaeljohnson@example.com	42	Male	Chicago	IL	United States	Home & Kitchen	Appliances
*/


In [106]:
#FEW-shot prompting without vectordb
from langchain import PromptTemplate, OpenAI
from langchain.prompts.few_shot import FewShotPromptTemplate


example_prompt = PromptTemplate(
    input_variables=["input", "sql_cmd", "result", "answer",],
    template="\nQuestion: {input}\nSQLQuery: {sql_cmd}\nSQLResult: {result}\nAnswer: {answer}",
)

few_shot_prompt = FewShotPromptTemplate(
    example_prompt=example_prompt,
    examples=examples,
    input_variables=["input", "table_info", "top_k"],
    prefix=_mysql_prompt,
    suffix=PROMPT_SUFFIX,
)

**SQL Query without few shot prompting**

In [121]:
db_chain.run("Find the number of people whose age is less than or equal to the average age with a 2 year variation in age allowed and they are from NY or CA or IL whose prefer to buy smartphones. Also tell me the average age")



[1m> Entering new SQLDatabaseChain chain...[0m
Find the number of people whose age is less than or equal to the average age with a 2 year variation in age allowed and they are from NY or CA or IL whose prefer to buy smartphones. Also tell me the average age
SQLQuery:

[1m> Entering new LLMChain chain...[0m
Prompt after formatting:
[32;1m[1;3mYou are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the c

OperationalError: ignored

**SQL Query with few shot prompting**

In [120]:
local_chain = SQLDatabaseChain.from_llm(llm, product_db, prompt=few_shot_prompt, use_query_checker=True,
                                        verbose=True, return_sql=False,)
local_chain.run("Find the number of people whose age is less than or equal to the average age with a 2 year variation in age allowed and they are from NY or CA or IL whose prefer to buy smartphones. Also tell me the average age")



[1m> Entering new SQLDatabaseChain chain...[0m
Find the number of people whose age is less than or equal to the average age with a 2 year variation in age allowed and they are from NY or CA or IL whose prefer to buy smartphones. Also tell me the average age
SQLQuery:

[1m> Entering new LLMChain chain...[0m
Prompt after formatting:
[32;1m[1;3mYou are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column n

'The average age is 35.0 and there are 1 people whose age is within 2 years of the average age and they are from NY, CA, or IL and prefer to buy smartphones.'