In [1]:
import openai
import pandas as pd
from sqlalchemy import create_engine, text

import os
from dotenv import load_dotenv

load_dotenv()

True

In [2]:
openai.api_key = os.getenv("OPENAI_API_KEY")
# print(openai.api_key)

In [3]:
file_path = "customer_orders.csv"
df = pd.read_csv(file_path)
df.head()

Unnamed: 0,ORDERNUMBER,QUANTITYORDERED,PRICEEACH,ORDERLINENUMBER,SALES,ORDERDATE,STATUS,QTR_ID,MONTH_ID,YEAR_ID,...,ADDRESSLINE1,ADDRESSLINE2,CITY,STATE,POSTALCODE,COUNTRY,TERRITORY,CONTACTLASTNAME,CONTACTFIRSTNAME,DEALSIZE
0,10107,30,95.7,2,2871.0,2/24/2003 0:00,Shipped,1,2,2003,...,897 Long Airport Avenue,,NYC,NY,10022.0,USA,,Yu,Kwai,Small
1,10121,34,81.35,5,2765.9,5/7/2003 0:00,Shipped,2,5,2003,...,59 rue de l'Abbaye,,Reims,,51100.0,France,EMEA,Henriot,Paul,Small
2,10134,41,94.74,2,3884.34,7/1/2003 0:00,Shipped,3,7,2003,...,27 rue du Colonel Pierre Avia,,Paris,,75508.0,France,EMEA,Da Cunha,Daniel,Medium
3,10145,45,83.26,6,3746.7,8/25/2003 0:00,Shipped,3,8,2003,...,78934 Hillside Dr.,,Pasadena,CA,90003.0,USA,,Young,Julie,Medium
4,10159,49,100.0,14,5205.27,10/10/2003 0:00,Shipped,4,10,2003,...,7734 Strong St.,,San Francisco,CA,,USA,,Brown,Julie,Medium


In [4]:
engine = create_engine("sqlite:///:memory:",echo=True)
df.to_sql("SalesTable",engine)

2024-08-14 10:08:14,664 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2024-08-14 10:08:14,678 INFO sqlalchemy.engine.Engine PRAGMA main.table_info("SalesTable")
2024-08-14 10:08:14,679 INFO sqlalchemy.engine.Engine [raw sql] ()
2024-08-14 10:08:14,679 INFO sqlalchemy.engine.Engine PRAGMA temp.table_info("SalesTable")
2024-08-14 10:08:14,679 INFO sqlalchemy.engine.Engine [raw sql] ()
2024-08-14 10:08:14,679 INFO sqlalchemy.engine.Engine 
CREATE TABLE "SalesTable" (
	"index" BIGINT, 
	"ORDERNUMBER" BIGINT, 
	"QUANTITYORDERED" BIGINT, 
	"PRICEEACH" FLOAT, 
	"ORDERLINENUMBER" BIGINT, 
	"SALES" FLOAT, 
	"ORDERDATE" TEXT, 
	"STATUS" TEXT, 
	"QTR_ID" BIGINT, 
	"MONTH_ID" BIGINT, 
	"YEAR_ID" BIGINT, 
	"PRODUCTLINE" TEXT, 
	"MSRP" BIGINT, 
	"PRODUCTCODE" TEXT, 
	"CUSTOMERNAME" TEXT, 
	"PHONE" TEXT, 
	"ADDRESSLINE1" TEXT, 
	"ADDRESSLINE2" TEXT, 
	"CITY" TEXT, 
	"STATE" TEXT, 
	"POSTALCODE" TEXT, 
	"COUNTRY" TEXT, 
	"TERRITORY" TEXT, 
	"CONTACTLASTNAME" TEXT, 
	"CONTACTFIRSTNAME" TEXT, 
	"DEALSI

2823

In [5]:
def execute_sql_query(connection,query):
    result = connection.execute(text(query))
    return result.fetchall()

In [6]:
with engine.connect() as conn:
    sales_sum_result = execute_sql_query(conn,"select sum(sales) from SalesTable")

2024-08-14 10:08:14,739 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2024-08-14 10:08:14,740 INFO sqlalchemy.engine.Engine select sum(sales) from SalesTable
2024-08-14 10:08:14,740 INFO sqlalchemy.engine.Engine [generated in 0.00139s] ()
2024-08-14 10:08:14,742 INFO sqlalchemy.engine.Engine ROLLBACK


In [7]:
sales_sum_result

[(10032628.85,)]

In [8]:
def table_schema_prompt(df):
    columns = ", ".join(df.columns)
    return f'Knowing that the table Columns are: SalesTable({columns}). '

# sch_prompt = table_schema_prompt(df)
# print(table_schema_prompt(df))

In [9]:
def get_user_input():
    return input('Enter your query: ')
# get_user_input()

In [10]:
# def full_prompt(df,user_prompt):
#     schema_prompt = table_schema_prompt(df)
#     return f"{schema_prompt}### Query to answer: {user_prompt}. \r\nSELECT"

# complete_prompt = full_prompt(df,get_user_input())
# print(complete_prompt)

In [11]:
user_question = f"Query to answer: {get_user_input()}. Answer directly with 'SELECT', omit the ``` and the 'sql'."
table_columns = table_schema_prompt(df)
complete_prompt = table_columns+user_question

In [12]:
complete_prompt

"Knowing that the table Columns are: SalesTable(ORDERNUMBER, QUANTITYORDERED, PRICEEACH, ORDERLINENUMBER, SALES, ORDERDATE, STATUS, QTR_ID, MONTH_ID, YEAR_ID, PRODUCTLINE, MSRP, PRODUCTCODE, CUSTOMERNAME, PHONE, ADDRESSLINE1, ADDRESSLINE2, CITY, STATE, POSTALCODE, COUNTRY, TERRITORY, CONTACTLASTNAME, CONTACTFIRSTNAME, DEALSIZE). Query to answer: give me the top 5 most common customer names of each city. Answer directly with 'SELECT', omit the ``` and the 'sql'."

In [13]:
# response = openai.chat.completions.create(
#     model = "gpt-3.5",
#     prompt = complete_prompt,
#     temperature = 0,
#     max_tokens = 150
# )
response = openai.chat.completions.create(
    model = "gpt-4o",
    messages=[
        {"role": "system", "content": "You're an SQL query generating assistant."},
        {"role": "user", "content": complete_prompt}],
        temperature=0
)

In [14]:
def process_response(res):
    raw_query = res.choices[0].message.content.strip()
    if not raw_query.upper().startswith("SELECT"):
        # raw_query = f"SELECT {raw_query}"
        print(f"Unsupported query: {raw_query}. \nQuery must start with 'SELECT'")
        return ""
    elif raw_query.upper().startswith("SELECT"):
        return raw_query

In [15]:
print(process_response(response))

SELECT CITY, CUSTOMERNAME, COUNT(*) AS frequency
FROM SalesTable
GROUP BY CITY, CUSTOMERNAME
ORDER BY CITY, frequency DESC
LIMIT 5;


In [16]:
with engine.connect() as conn:
    final_query = process_response(response)
    resulting_table = execute_sql_query(conn, final_query)

2024-08-14 10:08:48,742 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2024-08-14 10:08:48,743 INFO sqlalchemy.engine.Engine SELECT CITY, CUSTOMERNAME, COUNT(*) AS frequency
FROM SalesTable
GROUP BY CITY, CUSTOMERNAME
ORDER BY CITY, frequency DESC
LIMIT 5;
2024-08-14 10:08:48,744 INFO sqlalchemy.engine.Engine [generated in 0.00201s] ()
2024-08-14 10:08:48,746 INFO sqlalchemy.engine.Engine ROLLBACK


In [17]:
type(resulting_table)

list

In [18]:
display(resulting_table)

[('Aaarhus', 'Heintze Collectables', 27),
 ('Allentown', 'Diecast Classics Inc.', 31),
 ('Barcelona', 'Enaco Distributors', 23),
 ('Bergamo', 'Rovelli Gifts', 48),
 ('Bergen', 'Herkku Gifts', 29)]

In [19]:
display(pd.DataFrame(resulting_table))

Unnamed: 0,CITY,CUSTOMERNAME,frequency
0,Aaarhus,Heintze Collectables,27
1,Allentown,Diecast Classics Inc.,31
2,Barcelona,Enaco Distributors,23
3,Bergamo,Rovelli Gifts,48
4,Bergen,Herkku Gifts,29
