In [1]:
from openai import OpenAI
import pandas as pd
import json
import duckdb
from pydantic import BaseModel, Field
from IPython.display import Markdown

In [2]:
# OpenAI-compatible client pointed at a server listening on localhost:1234.
# NOTE(review): the api_key value looks like a placeholder for a local server — confirm.
client = OpenAI(base_url="http://localhost:1234/v1", api_key="lm-studio")
# Model identifiers; CODING_MODEL is the one passed as `MODEL` to the tool
# calls in the cells below.
GENERIC_MODEL = "qwen3-0.6b@bf16"
CODING_MODEL = "olympiccoder-7b"

In [3]:
# relative path of the parquet file holding the transactional data;
# lookup_sales_data reads it with pandas and exposes it as the `sales` table
TRANSACTION_DATA_FILE_PATH = 'data/Store_Sales_Price_Elasticity_Promotions_Data.parquet'

In [4]:
# prompt template for step 2 of tool 1
# generate_sql_query fills the {prompt}, {columns} and {table_name}
# placeholders with str.format before sending it as the user message
SQL_GENERATION_PROMPT = """
Generate an SQL query based on a prompt. Do not reply with anything besides the SQL query.
The prompt is: {prompt}

The available columns are: {columns}
The table name is: {table_name}
"""

# show the raw template (placeholders unfilled)
print(SQL_GENERATION_PROMPT)


Generate an SQL query based on a prompt. Do not reply with anything besides the SQL query.
The prompt is: {prompt}

The available columns are: {columns}
The table name is: {table_name}



### Tool 1: Database Lookup

In [5]:
# code for step 2 of tool 1
def generate_sql_query(prompt: str, MODEL: str, columns: list, table_name: str) -> str:
    """Ask the chat model to turn a natural-language request into SQL.

    Args:
        prompt: The natural-language request to translate.
        MODEL: Name of the model the request is sent to.
        columns: Column names inserted into the prompt template.
        table_name: Table name inserted into the prompt template.

    Returns:
        The content of the first choice of the model's reply, unmodified.
    """
    template_fields = {"prompt": prompt, "columns": columns, "table_name": table_name}
    user_message = {
        "role": "user",
        "content": SQL_GENERATION_PROMPT.format(**template_fields),
    }

    completion = client.chat.completions.create(model=MODEL, messages=[user_message])

    first_choice = completion.choices[0]
    return first_choice.message.content

In [6]:
# code for tool 1
def lookup_sales_data(prompt: str, MODEL: str) -> str:
    """Implementation of sales data lookup from parquet file using SQL

    Reads the parquet file at TRANSACTION_DATA_FILE_PATH, registers it as the
    DuckDB table ``sales`` (only if that table does not exist yet), asks the
    model for a SQL query answering ``prompt`` and executes it.

    Args:
        prompt: Natural-language description of the rows wanted.
        MODEL: Name of the model used to generate the SQL.

    Returns:
        The query result rendered with ``DataFrame.to_string()``, or a string
        starting with "Error accessing data: " when any step fails.
    """
    try:

        # define the table name
        table_name = "sales"

        # step 1: read the parquet file into a DuckDB table
        # (`df` inside the SQL text refers to the DataFrame loaded here)
        df = pd.read_parquet(TRANSACTION_DATA_FILE_PATH)
        duckdb.sql(f"CREATE TABLE IF NOT EXISTS {table_name} AS SELECT * FROM df")

        # step 2: generate the SQL code
        # pass a plain list so the prompt shows the bare column names instead
        # of the pandas Index repr
        sql_query = generate_sql_query(prompt, MODEL, list(df.columns), table_name)

        # clean the response to make sure it only includes the SQL code.
        # Keep only what follows the first "</think>" when the tag is present;
        # split(...)[-1] returns the whole text when it is absent, whereas
        # find() + len() would silently drop the first 7 characters.
        sql_query = (sql_query or "").split("</think>", 1)[-1]
        sql_query = sql_query.replace("```sql", "").replace("```", "")
        # strip last so whitespace left behind by the removals is dropped too
        sql_query = sql_query.strip()
        if not sql_query:
            return "Error accessing data: the model returned no SQL query"

        # step 3: execute the SQL query
        print(sql_query)
        result = duckdb.sql(sql_query).df()

        return result.to_string()
    except Exception as e:
        # best-effort tool: report the failure as text instead of raising
        return f"Error accessing data: {str(e)}"

In [7]:
# Stand-alone run of step 1 of lookup_sales_data (same statements, executed
# at notebook level).
table_name = "sales"

# step 1: read the parquet file into a DuckDB table
# IF NOT EXISTS: the table is only created when it is not already present
df = pd.read_parquet(TRANSACTION_DATA_FILE_PATH)
duckdb.sql(f"CREATE TABLE IF NOT EXISTS {table_name} AS SELECT * FROM df")

In [8]:
# try tool 1 end to end; `prompt` and `example_data` are reused by later cells
prompt = "Show me all the sales for store 1320 on November 1st, 2021"
example_data = lookup_sales_data(prompt, CODING_MODEL)
print(example_data)



SELECT * FROM sales WHERE Store_Number = 1320 AND Sold_Date = '2021-11-01'
    Store_Number  SKU_Coded  Product_Class_Code  Sold_Date  Qty_Sold  Total_Sale_Value  On_Promo
0           1320    6173050               22875 2021-11-01         1          4.990000         0
1           1320    6174250               22875 2021-11-01         1          0.890000         0
2           1320    6176200               22975 2021-11-01         2         99.980003         0
3           1320    6176800               22800 2021-11-01         1         14.970000         0
4           1320    6177250               22975 2021-11-01         1          6.890000         0
5           1320    6177300               22800 2021-11-01         1          9.990000         0
6           1320    6177350               22800 2021-11-01         2         16.980000         0
7           1320    6177700               22875 2021-11-01         1          3.190000         0
8           1320    6178000               22875 20

### Tool 2: Data Analysis

In [9]:
# prompt template for tool 2: {data} receives the data as text and {prompt}
# the question to answer; both are filled in with str.format
DATA_ANALYSIS_PROMPT = """
Analyze the following data: {data}
Your job is to answer the following question: {prompt}
"""
# show the raw template (placeholders unfilled)
print(DATA_ANALYSIS_PROMPT)


Analyze the following data: {data}
Your job is to answer the following question: {prompt}



In [10]:
# code for tool 2
def analyze_sales_data(prompt: str, MODEL:str, data: str) -> str:
    """Implementation of AI-powered sales data analysis

    Args:
        prompt: Question to answer about the data.
        MODEL: Name of the model the request is sent to.
        data: The data to analyse, already rendered as text.

    Returns:
        The model's answer with any leading reasoning block (everything up to
        and including the first "</think>") removed, or
        "No analysis could be generated" when nothing is left.
    """
    formatted_prompt = DATA_ANALYSIS_PROMPT.format(data=data, prompt=prompt)

    response = client.chat.completions.create(
        model=MODEL,
        messages=[
            {"role": "system", "content": "/no_think"},
            {"role": "user", "content": formatted_prompt}
        ],
    )

    # content may be None; treat that like an empty reply
    analysis = response.choices[0].message.content or ""
    # Keep only the text after the first "</think>" if the tag is present.
    # split(...)[-1] leaves the reply untouched when there is no tag, whereas
    # find() + len() would cut off its first 7 characters.
    analysis = analysis.split("</think>", 1)[-1].strip()
    return analysis if analysis else "No analysis could be generated"

In [11]:
# ad-hoc run of the analysis request (no system message) against the example data
formatted_prompt = DATA_ANALYSIS_PROMPT.format(data=example_data, prompt=prompt)

response = client.chat.completions.create(
    model=CODING_MODEL,
    messages=[
        {"role": "user", "content": formatted_prompt}
    ],
)

analysis = response.choices[0].message.content or ""
# Keep only the text after the first "</think>" if the tag is present;
# find() + len() would drop the first 7 characters when the tag is missing.
analysis = analysis.split("</think>", 1)[-1]

In [12]:
# show the untrimmed model reply (this is the raw content, not `analysis`)
print(response.choices[0].message.content)

Okay, let's see. The task is to find all the sales for Store 1320 on 2021-11-01. So I need to look through the data and filter the rows where Store_Number is 1320 and Sold_Date is '2021-11-01'. Then, present those entries.

Looking at the data provided, all the entries here are from Store 1320. Wait, in the first entry, Store_Number is 1320, and all subsequent entries in this snippet also have 1320. So maybe the entire dataset here is for store 1320? Because the sample given shows only entries from that store.

But the question says to show all sales for that store on that date. So even if there are other dates mixed in, we need to filter for Sold_Date '2021-11-01'.

Wait, but looking at the data provided, all the entries have the same Sold_Date of 2021-11-01. So maybe this is a test run. But according to the sample data given, all rows are from that store and date. So perhaps in the actual full dataset, there are other stores and dates, but here we only see Store 1320's entries for th

In [13]:
# try tool 2 on the text returned by the example lookup
analysis_prompt = "what trends do you see in this data"
print(analyze_sales_data(prompt=analysis_prompt,
                         MODEL=CODING_MODEL,
                         data=example_data))

APITimeoutError: Request timed out.

### Tool 3: Data Visualization

In [None]:
# prompt template for step 1 of tool 3
# extract_chart_config fills {data} and {visualization_goal}
CHART_CONFIGURATION_PROMPT = """
Generate a chart configuration based on this data: {data}
The goal is to show: {visualization_goal}
"""

# prompt template for step 2 of tool 3
# create_chart fills {config} with the configuration dict from step 1
CREATE_CHART_PROMPT = """
Write python code to create a chart based on the following configuration.
Only return the code, no other text.
config: {config}
"""

# show both raw templates (placeholders unfilled)
print(CHART_CONFIGURATION_PROMPT, CREATE_CHART_PROMPT)

In [None]:
# class defining the response format of step 1 of tool 3
class VisualizationConfig(BaseModel):
    # Passed as `response_format` by extract_chart_config. No field has a
    # default, so all four are required.
    chart_type: str = Field(description="Type of chart to generate")
    x_axis: str = Field(description="Name of the x-axis column")
    y_axis: str = Field(description="Name of the y-axis column")
    title: str = Field(description="Title of the chart")

In [None]:
# code for step 1 of tool 3
def extract_chart_config(data: str, MODEL:str, visualization_goal: str) -> dict:
    """Generate chart visualization configuration

    Args:
        data: String containing the data to visualize
        MODEL: Name of the model the request is sent to
        visualization_goal: Description of what the visualization should show

    Returns:
        Dictionary with the keys chart_type, x_axis, y_axis, title and data.
        The first four come from the model's structured reply; if that reply
        cannot be read, a default line-chart configuration titled with
        ``visualization_goal`` is returned instead.
    """
    formatted_prompt = CHART_CONFIGURATION_PROMPT.format(data=data,
                                                         visualization_goal=visualization_goal)

    response = client.beta.chat.completions.parse(
        model=MODEL,
        messages=[{"role": "user", "content": formatted_prompt}],
        response_format=VisualizationConfig,
    )

    try:
        message = response.choices[0].message
        # The parse() helper exposes the validated model instance as `parsed`.
        # `content` is only the raw text, so reading .chart_type from it
        # always raised and forced the fallback below.
        config = getattr(message, "parsed", None)
        if config is None:
            # Fall back to decoding the raw text ourselves. Drop a leading
            # reasoning block first; split(...)[-1] keeps the whole text when
            # there is no "</think>" tag.
            content = (message.content or "").split("</think>", 1)[-1].strip()
            config = VisualizationConfig(**json.loads(content))
        # Return structured chart config
        return {
            "chart_type": config.chart_type,
            "x_axis": config.x_axis,
            "y_axis": config.y_axis,
            "title": config.title,
            "data": data
        }
    except Exception:
        # best-effort: an unreadable reply yields a generic line chart
        return {
            "chart_type": "line",
            "x_axis": "date",
            "y_axis": "value",
            "title": visualization_goal,
            "data": data
        }

In [None]:
# code for step 2 of tool 3
def create_chart(config: dict, MODEL:str) -> str:
    """Create a chart based on the configuration

    Args:
        config: Chart configuration inserted into CREATE_CHART_PROMPT.
        MODEL: Name of the model the request is sent to.

    Returns:
        The model's reply with any leading reasoning block and markdown code
        fences removed, stripped of surrounding whitespace.
    """
    formatted_prompt = CREATE_CHART_PROMPT.format(config=config)

    response = client.chat.completions.create(
        model=MODEL,
        messages=[{"role": "user", "content": formatted_prompt}],
    )

    # content may be None; treat that like an empty reply
    code = response.choices[0].message.content or ""
    # Keep only the text after the first "</think>" if the tag is present.
    # find() + len() would chop the first 7 characters of the generated code
    # whenever the reply has no reasoning block.
    code = code.split("</think>", 1)[-1]
    code = code.replace("```python", "").replace("```", "")
    # strip last so whitespace left behind by the removals is dropped too
    return code.strip()

In [None]:
# code for tool 3
def generate_visualization(data: str, visualization_goal: str, MODEL:str) -> str:
    """Produce chart code for ``data`` that serves ``visualization_goal``.

    Runs the two steps of tool 3 in order: derive a chart configuration, then
    request code for that configuration. Both steps use ``MODEL``.
    """
    chart_config = extract_chart_config(data, MODEL, visualization_goal)
    return create_chart(chart_config, MODEL)

In [None]:
# try tool 3 on the text returned by the example lookup
chart_prompt = "A bar chart of sales by product SKU. Put the product SKU on the x-axis and the sales on the y-axis."
# generate_visualization's parameter is `visualization_goal`; passing
# `chart_prompt=` raised TypeError (unexpected keyword argument)
code = generate_visualization(example_data,
                              visualization_goal=chart_prompt, MODEL=CODING_MODEL)
print(code)

In [None]:
# WARNING: this executes model-generated code in the notebook's namespace.
# Review the printed `code` from the previous cell before running this.
exec(code)