In [36]:
import pandas as pd
import duckdb as db
from openai import OpenAI
import os

In [37]:
# Prompt templates. The user template receives the natural-language question;
# the system template receives the table name and its "column TYPE, ..." list.
user_template = "Write a SQL query that returns - {}"

system_template = """
Given the following SQL table, your job is to write queries based on a user's request. \n
CREATE TABLE {} ({}) \n
"""

# The DataFrame's variable name doubles as the table name used in later SQL.
sales_csv_path = "../data/Red30 Tech US online retail sales.csv"
retail_sales = pd.read_csv(sales_csv_path)

In [38]:
question = "What are the total sales by state?"

tbl_name = "retail_sales"

# DuckDB resolves `retail_sales` to the pandas DataFrame of the same name in the
# notebook namespace; DESCRIBE yields one row per column (name, type, ...).
tbl_description = db.sql("DESCRIBE SELECT * FROM " + tbl_name + ";")
# `.assign` builds a new frame, avoiding SettingWithCopyWarning from adding a
# column to a column-subset selection.
col_attr = tbl_description.df()[["column_name", "column_type"]].assign(
    column_joint=lambda cols: cols["column_name"] + " " + cols["column_type"]
)
# Render "col1 TYPE1, col2 TYPE2, ..." for the CREATE TABLE line of the prompt.
# Joining directly (instead of stripping "[", "]" and "'" from a list repr)
# keeps list types such as VARCHAR[] and apostrophes in names intact.
tbl_schema = ", ".join(col_attr["column_joint"])

In [39]:
# Fill in the templates: the table name and schema go into the system message,
# the question into the user message.
system = system_template.format(tbl_name, tbl_schema)
user = user_template.format(question)

In [40]:
# Gemini exposes an OpenAI-compatible endpoint, so the OpenAI SDK is reused
# by pointing it at Google's base URL.
gemini_api_key = os.environ.get("GEMINI_API_KEY")
if not gemini_api_key:
    # With api_key=None the OpenAI client falls back to OPENAI_API_KEY, which
    # either fails with a confusing message or sends an OpenAI key to Google.
    raise RuntimeError(
        "GEMINI_API_KEY is not set; export it before running this notebook."
    )

gemini_client = OpenAI(
    api_key=gemini_api_key,
    base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
)

# Deterministic sampling so repeated runs give (near) identical SQL.
temperature = 0
max_tokens = 5000

response = gemini_client.chat.completions.create(
    model="gemini-2.5-pro",
    messages=[{"role": "system", "content": system}, {"role": "user", "content": user}],
    temperature=temperature,
    max_completion_tokens=max_tokens,
)

In [41]:
# Show the raw reply text (print renders its newlines, unlike a bare repr).
first_reply = response.choices[0].message
print(first_reply.content)


```sql
SELECT
  CustState,
  SUM(OrderTotal) AS TotalSales
FROM retail_sales
GROUP BY
  CustState
ORDER BY
  TotalSales DESC;
```


In [42]:
query = response.choices[0].message.content

# The reply may be wrapped in a Markdown ```sql fence, which DuckDB cannot
# parse. Catch and display the error instead of raising, so Restart & Run All
# continues to the following cells.
try:
    raw_query_result = db.sql(query)
except db.ParserException as err:
    raw_query_result = None
    print(f"{type(err).__name__}: {err}")

raw_query_result

ParserException: Parser Error: syntax error at or near "```"

In [43]:
import re


def is_markdown_code_chunk(text):
    """
    Checks if the given text contains a Markdown fenced code chunk.

    A chunk is any span that opens and closes with triple backticks. The
    code between the fences may itself contain single backticks (for
    example backtick-quoted SQL identifiers).

    Args:
        text (str): The text to check. None or "" is treated as plain text.

    Returns:
        bool: True if the text contains a Markdown code chunk, False otherwise.
    """
    # Model replies can have no content at all; re.search would raise on None.
    if not text:
        return False
    # Non-greedy ".*?" with DOTALL spans newlines and tolerates backticks
    # inside the code, stopping at the first closing fence.
    pattern = r"```.*?```"
    return bool(re.search(pattern, text, re.DOTALL))


def extract_code_from_markdown(markdown_text):
    """
    Extracts code from the first Markdown code chunk in the text.

    Supports the usual multi-line form (an opening fence with an optional
    info string such as "sql", then code lines, then a closing fence), a
    closing fence placed right after the last line of code, and a chunk
    written entirely on one line.

    Args:
        markdown_text (str): The Markdown text containing the code chunk.

    Returns:
        str: The extracted code, or None if no complete code chunk is found
        (including when markdown_text is empty or None).
    """
    if not markdown_text:
        return None

    # Multi-line form: the info string must stay on the opening line, and the
    # closing fence may or may not sit on its own line.
    multiline = re.search(
        r"```[^\n`]*\n(?P<code>.*?)\n?```", markdown_text, re.DOTALL
    )
    if multiline:
        return multiline.group("code")

    # Single-line form, e.g. ```SELECT 1```: there is no info string to skip.
    inline = re.search(r"```(?P<code>.*?)```", markdown_text, re.DOTALL)
    if inline:
        return inline.group("code").strip()

    return None


In [53]:
# Strip Markdown fences, if any, so DuckDB receives bare SQL.
clean_query = (
    extract_code_from_markdown(markdown_text=query)
    if is_markdown_code_chunk(text=query)
    else None
)
if clean_query is None:
    # No fence, or a fence that could not be parsed: use the reply as-is
    # rather than handing None to db.sql().
    clean_query = query

print(clean_query)


SELECT
  CustState,
  SUM(OrderTotal) AS TotalSales
FROM retail_sales
GROUP BY
  CustState
ORDER BY
  TotalSales DESC;


In [45]:
# Execute the fence-free SQL; the relation is displayed as the cell's output.
sales_by_state = db.sql(clean_query)
sales_by_state


┌────────────────┬────────────────────┐
│   CustState    │     TotalSales     │
│    varchar     │       double       │
├────────────────┼────────────────────┤
│ New York       │  616925.8400000001 │
│ California     │  540285.5199999997 │
│ Florida        │  394483.7399999998 │
│ Texas          │ 349925.48000000004 │
│ North Carolina │  345118.2600000001 │
│ Pennsylvania   │ 290120.09000000014 │
│ Minnesota      │ 267308.95000000007 │
│ Washington     │ 257140.79999999993 │
│ Virginia       │ 248797.89000000004 │
│ Georgia        │ 237607.26000000004 │
│    ·           │          ·         │
│    ·           │          ·         │
│    ·           │          ·         │
│ New Mexico     │ 36696.380000000005 │
│ New Jersey     │            35971.9 │
│ North Dakota   │            29907.5 │
│ West Virginia  │           23967.97 │
│ Nebraska       │           22604.36 │
│ Wyoming        │  9712.550000000001 │
│ Wisconsin      │  7345.029999999999 │
│ South Dakota   │            4071.62 │


In [50]:
# Same system prompt as before plus an explicit instruction to return bare SQL,
# so the reply can be executed without stripping Markdown fences.
system_template2 = """
Given the following SQL table, your job is to write queries based on a user's request. \n
Return just the SQL query as plain text, without additional text and don't use markdown format. \n
CREATE TABLE {} ({}) \n
"""

system2 = system_template2.format(tbl_name, tbl_schema)


response2 = gemini_client.chat.completions.create(
    model="gemini-2.5-pro",
    messages=[{"role": "system", "content": system2}, {"role": "user", "content": user}],
    temperature=temperature,
    max_completion_tokens=max_tokens,
)


In [51]:
# Show the raw reply produced under the stricter system prompt.
second_reply = response2.choices[0].message
print(second_reply.content)


SELECT CustState, SUM(OrderTotal) FROM retail_sales GROUP BY CustState


In [52]:
# The reply is executed exactly as returned (no fence stripping) to check
# whether the prompt instruction alone yields runnable SQL.
reply2 = response2.choices[0].message
query2 = reply2.content
db.sql(query2)

┌──────────────────────┬────────────────────┐
│      CustState       │  sum(OrderTotal)   │
│       varchar        │       double       │
├──────────────────────┼────────────────────┤
│ Virginia             │ 248797.89000000004 │
│ Hawaii               │           61579.92 │
│ District of Columbia │ 131662.74000000002 │
│ Louisiana            │ 113639.52999999996 │
│ Washington           │ 257140.79999999993 │
│ Connecticut          │ 57114.099999999984 │
│ Montana              │ 51841.609999999986 │
│ Minnesota            │ 267308.95000000007 │
│ New Hampshire        │ 59687.459999999985 │
│ Mississippi          │           78850.98 │
│     ·                │               ·    │
│     ·                │               ·    │
│     ·                │               ·    │
│ New Mexico           │ 36696.380000000005 │
│ Iowa                 │  68893.29000000001 │
│ Kentucky             │          115596.76 │
│ Wisconsin            │  7345.029999999999 │
│ New York             │  616925.8