In [21]:
import pandas as pd
import duckdb as db
from openai import OpenAI
import os


In [22]:
user_template = "Write a SQL query that returns - {}"

In [23]:
system_template = """
Given the following SQL table, your job is to write queries based on a user's request. \n
CREATE TABLE {} ({}) \n
"""


In [24]:
retail_sales = pd.read_csv("../data/Red30 Tech US online retail sales.csv")

In [25]:
question = "What are the total sales by state?"

tbl_name = "retail_sales"


In [26]:
tbl_description = db.sql("DESCRIBE SELECT * FROM " + tbl_name + ";")
tbl_description

┌──────────────┬─────────────┬─────────┬─────────┬─────────┬─────────┐
│ column_name  │ column_type │  null   │   key   │ default │  extra  │
│   varchar    │   varchar   │ varchar │ varchar │ varchar │ varchar │
├──────────────┼─────────────┼─────────┼─────────┼─────────┼─────────┤
│ OrderNum     │ BIGINT      │ YES     │ NULL    │ NULL    │ NULL    │
│ OrderDate    │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ OrderType    │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ CustomerType │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ CustName     │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ CustState    │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ ProdCategory │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ ProdNumber   │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ ProdName     │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ Quantity     │ BIGINT      │ YES     │ NULL    │ NULL    │ NULL    │
│ Pric

In [27]:
col_attr = tbl_description.df()[["column_name", "column_type"]]
col_attr["column_joint"] = col_attr["column_name"] + " " + col_attr["column_type"]
tbl_schema = (
    str(list(col_attr["column_joint"].values))
    .replace("[", "")
    .replace("]", "")
    .replace("'", "")
)


In [28]:
print(tbl_schema)

OrderNum BIGINT, OrderDate VARCHAR, OrderType VARCHAR, CustomerType VARCHAR, CustName VARCHAR, CustState VARCHAR, ProdCategory VARCHAR, ProdNumber VARCHAR, ProdName VARCHAR, Quantity BIGINT, Price DOUBLE, Discount DOUBLE, OrderTotal DOUBLE


In [29]:
user = user_template.format(question)
print(user)


Write a SQL query that returns - What are the total sales by state?


In [30]:
system = system_template.format(tbl_name,tbl_schema)
print(system)



Given the following SQL table, your job is to write queries based on a user's request. 

CREATE TABLE retail_sales (OrderNum BIGINT, OrderDate VARCHAR, OrderType VARCHAR, CustomerType VARCHAR, CustName VARCHAR, CustState VARCHAR, ProdCategory VARCHAR, ProdNumber VARCHAR, ProdName VARCHAR, Quantity BIGINT, Price DOUBLE, Discount DOUBLE, OrderTotal DOUBLE) 




In [12]:
openai_api_key = os.environ.get("OPENAI_API_KEY")
max_tokens = 5000
client = OpenAI(api_key=openai_api_key)

In [13]:
response = client.chat.completions.create(
    model="gpt-5",
    messages=[{"role": "system", "content": system}, 
              {"role": "user", "content": user}],
    max_completion_tokens=max_tokens,
)


In [14]:
print(response.choices[0].message.content)


SELECT
  CustState,
  SUM(OrderTotal) AS total_sales
FROM retail_sales
GROUP BY CustState
ORDER BY total_sales DESC;


In [15]:
query = response.choices[0].message.content


In [16]:
db.sql(query)


┌────────────────┬────────────────────┐
│   CustState    │    total_sales     │
│    varchar     │       double       │
├────────────────┼────────────────────┤
│ New York       │  616925.8400000001 │
│ California     │  540285.5199999997 │
│ Florida        │  394483.7399999998 │
│ Texas          │ 349925.48000000004 │
│ North Carolina │  345118.2600000001 │
│ Pennsylvania   │ 290120.09000000014 │
│ Minnesota      │ 267308.95000000007 │
│ Washington     │ 257140.79999999993 │
│ Virginia       │ 248797.89000000004 │
│ Georgia        │ 237607.26000000004 │
│    ·           │          ·         │
│    ·           │          ·         │
│    ·           │          ·         │
│ New Mexico     │ 36696.380000000005 │
│ New Jersey     │            35971.9 │
│ North Dakota   │            29907.5 │
│ West Virginia  │           23967.97 │
│ Nebraska       │           22604.36 │
│ Wyoming        │  9712.550000000001 │
│ Wisconsin      │  7345.029999999999 │
│ South Dakota   │            4071.62 │


In [31]:
question2 = "Which product had the highest sales?"

In [33]:
user2 = user_template.format(question2)

response2 = client.chat.completions.create(
    model="gpt-5",
    messages=[{"role": "system", "content": system}, {"role": "user", "content": user2}],
    max_completion_tokens=max_tokens,
)

In [34]:
print(response2.choices[0].message.content)


SELECT ProdNumber, ProdName, total_sales
FROM (
  SELECT
    ProdNumber,
    ProdName,
    SUM(OrderTotal) AS total_sales,
    DENSE_RANK() OVER (ORDER BY SUM(OrderTotal) DESC) AS rnk
  FROM retail_sales
  GROUP BY ProdNumber, ProdName
) s
WHERE rnk = 1;


In [35]:
query2 = response2.choices[0].message.content
db.sql(query2)


┌────────────┬──────────────┬───────────────────┐
│ ProdNumber │   ProdName   │    total_sales    │
│  varchar   │   varchar    │      double       │
├────────────┼──────────────┼───────────────────┤
│ RS706      │ RWW-75 Robot │ 653773.2000000004 │
└────────────┴──────────────┴───────────────────┘