In [None]:
import os
from urllib.parse import quote

In [None]:
# Switch off `tokenizers` parallelism through its environment flag before anything
# else runs, to avoid the tokenizers parallelism error/warning.
os.environ.update(TOKENIZERS_PARALLELISM="false")

In [None]:
# Install the latest versions of required packages
%pip install langchain openai langchain_experimental langchain-openai -Uq

%pip install ipywidgets python-dotenv SQLAlchemy psycopg2-binary -Uq

%pip install sqlalchemy-cockroachdb

In [None]:
# List and check the versions of installed packages
%pip list | grep "langchain\|openai"

In [None]:
from langchain import SQLDatabase
from langchain_openai import OpenAI
from langchain_experimental.sql import SQLDatabaseChain
from langchain_experimental.sql.base import SQLDatabaseSequentialChain
from sqlalchemy.exc import ProgrammingError

In [None]:
# Load the python-dotenv IPython extension and read the local .env file into
# os.environ, so the configuration cell can pick up the CCK_* and OPENAI_API_KEY values.
%load_ext dotenv
%dotenv

In [None]:
# CockroachDB connection settings, read from the environment.
_REQUIRED_CCK_VARS = ("CCK_DB_NAME", "CCK_ENDPOINT", "CCK_PASSWORD", "CCK_PORT", "CCK_USERNAME")
_missing_cck_vars = [name for name in _REQUIRED_CCK_VARS if not os.environ.get(name)]
if _missing_cck_vars:
    # Fail fast: otherwise the URI silently contains the literal text "None" and the
    # problem only surfaces later as a confusing connection error.
    raise RuntimeError(
        "Missing required environment variable(s): " + ", ".join(_missing_cck_vars)
    )

CCK_DB_NAME = os.environ["CCK_DB_NAME"]
CCK_ENDPOINT = os.environ["CCK_ENDPOINT"]
CCK_PASSWORD = os.environ["CCK_PASSWORD"]
CCK_PORT = os.environ["CCK_PORT"]
CCK_USERNAME = os.environ["CCK_USERNAME"]

# Percent-encode the credentials: characters such as "@", ":" or "/" in a username or
# password would otherwise be parsed as URI delimiters and corrupt the connection string.
CCK_URI = (
    f"cockroachdb://{quote(CCK_USERNAME, safe='')}:{quote(CCK_PASSWORD, safe='')}"
    f"@{CCK_ENDPOINT}:{CCK_PORT}/{CCK_DB_NAME}?sslmode=verify-full"
)

# May be None; it is passed straight to the OpenAI LLM constructor.
OpenAI_API_KEY = os.environ.get("OPENAI_API_KEY")

In [None]:
# Sample natural-language questions used to exercise the chains and the agent.
Q1 = "Determine details about the user who ordered most items."
Q2 = "Determine order id, date and name of the customer of the most recently placed order"
Q3 = "List the name of the customers who have placed at least 5 orders"
Q4 = "What is the price of the most profitable sale and the name of the customer who placed that order?"
Q5 = "Fetch the most recent order placed by the customer whose id is 282"

In [None]:
# Completion-style OpenAI model shared by every chain and agent in this notebook;
# temperature=0 keeps SQL generation as repeatable as possible.
llm = OpenAI(
    model_name="gpt-3.5-turbo-instruct",
    openai_api_key=OpenAI_API_KEY,
    temperature=0,
    verbose=True,
)

In [None]:
# Baseline: reflect the database schema as-is (no table filter, no custom table
# info) and ask Q1 through the sequential chain. Chain failures are printed, not raised.
db = SQLDatabase.from_uri(CCK_URI)
db_chain = SQLDatabaseSequentialChain.from_llm(
    llm,
    db,
    use_query_checker=True,
    verbose=True,
)
try:
    db_chain.run(Q1)
except (ProgrammingError, ValueError) as err:
    print("\n\n" + str(err))

In [None]:
# Hand-written DDL plus three sample rows per table, passed to
# SQLDatabase.from_uri(custom_table_info=...) so the LLM is shown this text for these
# tables. Sample-row headers must match the real column names exactly: quoted
# identifiers are case-sensitive in CockroachDB/PostgreSQL, so a header such as
# "product_ID" could lead the LLM to generate a non-existent quoted column.
table_info = {
    "customers": """CREATE TABLE customers (
        customer_id INT8 NOT NULL DEFAULT unique_rowid(),
        customer_name VARCHAR(255) NULL,
        gender VARCHAR(100) NULL,
        age INT8 NULL,
        home_address VARCHAR(255) NULL,
        zip_code VARCHAR(20) NULL,
        city VARCHAR(100) NULL,
        state VARCHAR(100) NULL,
        country VARCHAR(100) NULL,
        CONSTRAINT customers_pkey PRIMARY KEY (customer_id ASC))

/*
3 rows from customers table:
"customer_id" "customer_name" "gender" "age" "home_address" "zip_code" "city" "state" "country"
1 "Leanna Busson" "Female" 30 "8606 Victoria TerraceSuite 560" "5464" "Johnstonhaven" "Northern Territory" "Australia"
2 "Zabrina Harrowsmith" "Genderfluid" 69 "8327 Kirlin SummitApt. 461" "8223" "New Zacharyfort" "South Australia" "Australia"
3 "Shina Dullaghan" "Polygender" 59 "269 Gemma SummitSuite 109" "5661" "Aliburgh" "Australian Capital Territory" "Australia"
*/""",
    "products": """CREATE TABLE products (
        product_id INT8 NOT NULL DEFAULT unique_rowid(),
        product_type VARCHAR(100) NULL,
        product_name VARCHAR(100) NULL,
        size VARCHAR(50) NULL,
        colour VARCHAR(50) NULL,
        price DECIMAL(10,2) NULL,
        quantity INT8 NULL,
        description VARCHAR(255) NULL,
        CONSTRAINT products_pkey PRIMARY KEY (product_id ASC))

/*
3 rows from products table:
"product_id" "product_type" "product_name" "size" "colour" "price" "quantity" "description"
0 "Shirt" "Oxford Cloth" "XS" "red" 114.00 66 "A red coloured, XS sized, Oxford Cloth Shirt"
1 "Shirt" "Oxford Cloth" "S" "red" 114.00 53 "A red coloured, S sized, Oxford Cloth Shirt"
2 "Shirt" "Oxford Cloth" "M" "red" 114.00 54 "A red coloured, M sized, Oxford Cloth Shirt"
*/""",
    "orders": """CREATE TABLE orders (
        order_id INT8 NOT NULL DEFAULT unique_rowid(),
        customer_id INT8 NOT NULL,
        payment DECIMAL(10,2) NOT NULL,
        order_date DATE NULL DEFAULT current_date(),
        delivery_date DATE NULL,
        CONSTRAINT orders_pkey PRIMARY KEY (order_id ASC),
        CONSTRAINT fk_customer_id FOREIGN KEY (customer_id) REFERENCES public.customers(customer_id))

/*
3 rows from orders table:
"order_id" "customer_id" "payment" "order_date" "delivery_date"
1 64 30811.00 "2021-08-30" "2021-09-24"
2 473 50490.00 "2021-02-03" "2021-02-13"
3 774 46763.00 "2021-10-08" "2021-11-03"
*/""",
    "sales": """CREATE TABLE sales (
        sales_id INT8 NOT NULL DEFAULT unique_rowid(),
        order_id INT8 NULL,
        product_id INT8 NULL,
        price_per_unit DECIMAL(10,2) NULL,
        quantity INT8 NULL,
        total_price DECIMAL(10,2) NULL,
        CONSTRAINT sales_pkey PRIMARY KEY (sales_id ASC),
        CONSTRAINT fk_product_id FOREIGN KEY (product_id) REFERENCES public.products(product_id))

/*
3 rows from sales table:
"sales_id" "order_id" "product_id" "price_per_unit" "quantity" "total_price"
0 1 218 106.00 2 212.00
1 1 481 118.00 1 118.00
2 1 2 96.00 3 288.00
*/""",
}

In [None]:
# Second attempt: limit the schema to the four tables of interest, supply the
# hand-written `table_info`, and pass top_k=3 to the chain. Ask Q2; chain failures
# are printed, not raised.
db = SQLDatabase.from_uri(
    CCK_URI,
    custom_table_info=table_info,
    include_tables=["customers", "orders", "products", "sales"],
    sample_rows_in_table_info=3,
)

db_chain = SQLDatabaseSequentialChain.from_llm(
    llm,
    db,
    top_k=3,
    use_query_checker=True,
    verbose=True,
)

try:
    db_chain.run(Q2)
except (ProgrammingError, ValueError) as err:
    print("\n\n" + str(err))

In [None]:
from langchain.prompts.prompt import PromptTemplate

# Custom prompt for SQLDatabaseChain. It declares three variables: {dialect},
# {table_info} and {input}.
# The format placeholders are deliberately NOT wrapped in quotes: text written after
# "SQLQuery:" is executed as SQL, and a quoted example nudges the model into emitting
# `"SELECT ..."`, which the database would reject.
EXAMPLE_TEMPLATE = """Given an input question, please create a syntactically correct {dialect} SQL query to retrieve the required information. Then, execute the query, observe the results, and provide a concise answer.

Follow this format strictly:

Question: Original question here
SQLQuery: SQL query to execute
SQLResult: Results of the SQLQuery
Answer: Final answer based on the SQLResult

Only reference the following tables:

{table_info}

Important Notes:
- If joins are necessary to answer the question, include them in the query.
- Ensure that conditions or filters in the question are correctly applied in the SQL query.
- Write the query after "SQLQuery:" as plain SQL, without surrounding quotes or markdown code fences.

Question: {input}"""

PROMPT = PromptTemplate(
    input_variables=["input", "table_info", "dialect"], template=EXAMPLE_TEMPLATE
)


In [None]:
# Third attempt: reflected schema for the four tables (no custom_table_info), the
# custom PROMPT, and return_intermediate_steps=True so the chain's intermediate steps
# can be inspected after asking Q2.
db = SQLDatabase.from_uri(
    CCK_URI,
    include_tables=["customers", "orders", "products", "sales"],
)

db_chain = SQLDatabaseChain.from_llm(
    llm,
    db,
    prompt=PROMPT,
    verbose=True,
    use_query_checker=True,
    return_intermediate_steps=True,
)

# Reset before running: if the chain raises, `result` would otherwise be undefined
# (NameError on a fresh kernel) or hold stale output from an earlier run.
result = None
try:
    result = db_chain(Q2)
except (ProgrammingError, ValueError) as exc:
    print(f"\n\n{exc}")

# Display the intermediate steps only when the chain actually succeeded.
result["intermediate_steps"] if result is not None else None

In [None]:
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit

# Build a SQL agent over the `db` most recently defined (the four-table database).
# SQLDatabase is not re-imported here: the cell does not use it, and the old
# `langchain.sql_database` path would only shadow the notebook's existing import.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
agent_executor = create_sql_agent(llm=llm, toolkit=toolkit, verbose=True)

# Example of describing a table; failures are printed, not raised.
try:
    agent_executor.run("Describe the products table.")
except (ProgrammingError, ValueError) as exc:
    print(f"\n\n{exc}")

In [None]:
# Ask the SQL agent sample question Q4; SQL / validation errors are printed rather
# than interrupting the notebook.
try:
    agent_executor.run(Q4)
except (ProgrammingError, ValueError) as err:
    print("\n\n" + str(err))