In [2]:
from langchain import OpenAI, SQLDatabase, SQLDatabaseChain
from langchain.prompts.prompt import PromptTemplate
from langchain.tools.sql_database.tool import (
    QueryCheckerTool,
    ListSQLDatabaseTool,
    InfoSQLDatabaseTool,
    QuerySQLDataBaseTool,
)

In [9]:
import traceback
from langchain.chat_models import ChatOpenAI
from langchain_visualizer import visualize
from langchain.callbacks import get_openai_callback
import os
import langchain

langchain.debug = True


db_path = os.path.abspath("./orders.db")

db = SQLDatabase.from_uri(
    f"sqlite:////{db_path}",
    include_tables=["return_policy", "category", "product", "order"],
    sample_rows_in_table_info=3,
)
print(db.table_info)
llm = ChatOpenAI(
    temperature=0,
    model_name="gpt-3.5-turbo",
)

_DEFAULT_PROMPT = """You are an agent designed to interact with a SQL database.\nGiven an input question, create a syntactically correct sqlite query to run, then look at the results of the query and return the answer.
\nUnless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 10 results.
\nYou can order the results by a relevant column to return the most interesting examples in the database.
\nNever query for all the columns from a specific table, only ask for the relevant columns given the question.
\nYou MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
\n\nDO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n\nIf the question does not seem related to the database, just return \"I don't know\" as the answer.
\n If the query is not correct, an error message will be returned.
\n If an error is returned, rewrite the query, check the query, and try again.
Use the following format:
Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Only use the following tables:

{table_info}

Question: {input}"""

PROMPT = PromptTemplate(
    template=_DEFAULT_PROMPT, input_variables=["input", "table_info"]
)

chain = SQLDatabaseChain.from_llm(
    llm=llm,
    db=db,
    verbose=True,
    prompt=PROMPT,
    use_query_checker=True,
)


def chain_executor(query):
    with get_openai_callback() as cb:
        response = chain.run(query)
        print(f"Total Tokens: {cb.total_tokens}")
        print(f"Prompt Tokens: {cb.prompt_tokens}")
        print(f"Completion Tokens: {cb.completion_tokens}")
        print(f"Total Cost (USD): ${cb.total_cost}")
        return response


CREATE TABLE category (
	id INTEGER, 
	name VARCHAR, 
	PRIMARY KEY (id), 
	UNIQUE (name)
)

/*
3 rows from category table:
id	name
1	Electronics
2	Fashion
3	Home & Kitchen
*/


CREATE TABLE product (
	id INTEGER, 
	name VARCHAR, 
	category_id INTEGER, 
	PRIMARY KEY (id), 
	FOREIGN KEY(category_id) REFERENCES category (id)
)

/*
3 rows from product table:
id	name	category_id
1	Samsung Galaxy S21	1
2	Apple MacBook Pro	1
3	Sony 65" 4K Smart TV	1
*/


CREATE TABLE return_policy (
	id INTEGER, 
	category_id INTEGER, 
	return_policy VARCHAR, 
	return_window INTEGER, 
	policy_name VARCHAR, 
	PRIMARY KEY (id), 
	FOREIGN KEY(category_id) REFERENCES category (id)
)

/*
3 rows from return_policy table:
id	category_id	return_policy	return_window	policy_name
1	1	Items can be returned within 30 days of purchase with the original receipt.	30	Electronics Return Policy
2	2	Clothing and accessories can be returned within 14 days of purchase in unused condition with tags at	14	Fashion Return Policy
3	3	

In [10]:
res = chain.run("how many orders are in order table")

[32;1m[1;3m[chain/start][0m [1m[1:chain:SQLDatabaseChain] Entering Chain run with input:
[0m{
  "query": "how many orders are in order table"
}
[32;1m[1;3m[chain/start][0m [1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] Entering Chain run with input:
[0m{
  "input": "how many orders are in order table\nSQLQuery:",
  "top_k": "5",
  "dialect": "sqlite",
  "table_info": "\nCREATE TABLE category (\n\tid INTEGER, \n\tname VARCHAR, \n\tPRIMARY KEY (id), \n\tUNIQUE (name)\n)\n\n/*\n3 rows from category table:\nid\tname\n1\tElectronics\n2\tFashion\n3\tHome & Kitchen\n*/\n\n\nCREATE TABLE product (\n\tid INTEGER, \n\tname VARCHAR, \n\tcategory_id INTEGER, \n\tPRIMARY KEY (id), \n\tFOREIGN KEY(category_id) REFERENCES category (id)\n)\n\n/*\n3 rows from product table:\nid\tname\tcategory_id\n1\tSamsung Galaxy S21\t1\n2\tApple MacBook Pro\t1\n3\tSony 65\" 4K Smart TV\t1\n*/\n\n\nCREATE TABLE return_policy (\n\tid INTEGER, \n\tcategory_id INTEGER, \n\treturn_policy VARCHAR, \n\treturn_w

OperationalError: (sqlite3.OperationalError) near "The": syntax error
[SQL: The original query seems to be correct. However, it is important to note that "order" is a reserved keyword in SQLite and should be enclosed in double quotes to avoid syntax errors. Therefore, the corrected query is:

SELECT COUNT(*) FROM "order"]
(Background on this error at: https://sqlalche.me/e/14/e3q8)

In [8]:
chain_executor("list all the products")

[32;1m[1;3m[chain/start][0m [1m[1:chain:SQLDatabaseChain] Entering Chain run with input:
[0m{
  "query": "list all the products"
}
[32;1m[1;3m[chain/start][0m [1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] Entering Chain run with input:
[0m{
  "input": "list all the products\nSQLQuery:",
  "top_k": "5",
  "dialect": "sqlite",
  "table_info": "\nCREATE TABLE category (\n\tid INTEGER, \n\tname VARCHAR, \n\tPRIMARY KEY (id), \n\tUNIQUE (name)\n)\n\n/*\n3 rows from category table:\nid\tname\n1\tElectronics\n2\tFashion\n3\tHome & Kitchen\n*/\n\n\nCREATE TABLE product (\n\tid INTEGER, \n\tname VARCHAR, \n\tcategory_id INTEGER, \n\tPRIMARY KEY (id), \n\tFOREIGN KEY(category_id) REFERENCES category (id)\n)\n\n/*\n3 rows from product table:\nid\tname\tcategory_id\n1\tSamsung Galaxy S21\t1\n2\tApple MacBook Pro\t1\n3\tSony 65\" 4K Smart TV\t1\n*/\n\n\nCREATE TABLE return_policy (\n\tid INTEGER, \n\tcategory_id INTEGER, \n\treturn_policy VARCHAR, \n\treturn_window INTEGER, \n\tpolicy_

OperationalError: (sqlite3.OperationalError) near "The": syntax error
[SQL: The original query is correct:

SELECT id, name FROM product LIMIT 10;]
(Background on this error at: https://sqlalche.me/e/14/e3q8)