In [14]:
import os
import warnings

import pandas as pd
import plotly.express as px
from dotenv import load_dotenv
from langchain_core.output_parsers.openai_tools import JsonOutputKeyToolsParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_experimental.tools import PythonAstREPLTool
from langchain_openai import AzureChatOpenAI


# Hide UserWarnings to keep cell output readable.
# NOTE(review): this blanket filter silences every UserWarning for the rest of
# the kernel session; narrow it with `module=`/`message=` once the noisy source
# is identified.
warnings.filterwarnings("ignore", category=UserWarning)

# Load settings (API key, API version, endpoint) from a local .env file into
# the process environment. The trailing semicolon stops Jupyter from echoing
# load_dotenv()'s boolean return value as cell output.
load_dotenv();

True

In [15]:

# Load the raw orders table. Later cells group by `user_id` and count
# `order_id` to chart orders per user.
file = './data_store/orders.csv'

df = pd.read_csv(file)

# Fail fast here, rather than inside LLM-generated code, if the schema changes.
assert {'order_id', 'user_id'} <= set(df.columns), f"unexpected columns in {file}: {list(df.columns)}"

# Rich display instead of print(); `display` is provided by the IPython kernel.
display(df.shape, df.head(3))

(1000, 6)
   order_id  user_id         product_name  quantity  price  order_date
0         1        8      Sunflower Seeds        12     98  2024-10-03
1         2       21        Sliced Olives        14     50  2024-09-15
2         3       15  Steak Seasoning Rub         9     84  2024-11-24


In [16]:
# Azure OpenAI chat model used to generate the plotting code.
# NOTE(review): os.getenv returns None for an unset variable, so a missing
# entry in .env is not caught here — confirm all three variables are defined.
llm = AzureChatOpenAI(
    openai_api_key=os.getenv("OPENAI_API_KEY"),
    api_version=os.getenv("OPENAI_API_VERSION"),
    azure_endpoint=os.getenv("OPENAI_ENDPOINT"),
    model="gpt-4o",
    temperature=0,  # minimise sampling randomness in the generated code
)

# SECURITY NOTE(review): this tool executes model-generated Python in the
# current kernel without sandboxing; use only with trusted prompts and data.
# The tool gets its own copy of `df`, so generated code that mutates the frame
# cannot silently change the notebook's `df` used by later cells.
tool = PythonAstREPLTool(locals={"df": df.copy()})

# Force the model to answer with a call to the Python tool.
llm_with_tools = llm.bind_tools([tool], tool_choice=tool.name)
# Pull out the arguments of the first call to that tool so they can be piped
# straight into the tool itself.
parser = JsonOutputKeyToolsParser(key_name=tool.name, first_tool_only=True)

In [19]:
# Sample rows embedded in the system prompt so the model sees the column names
# and value formats. Braces are doubled because ChatPromptTemplate treats
# `{...}` as template variables; a literal brace in the data would otherwise
# raise a missing-variable error at invoke time.
df_preview = df.head().to_markdown().replace("{", "{{").replace("}", "}}")

# Plain triple backticks fence the table: `\`` is not a valid Python escape
# sequence and would leave literal backslashes in the prompt text.
system = f"""You have access to a pandas dataframe `df`.
Here is the output of `df.head().to_markdown()`:

```
{df_preview}
```

Given a user question, write the Python code to answer it.
The Python code should use the `df` dataframe and generate a high-quality graph using this dataframe only.
Do NOT modify `df` in place; store any intermediate results in new variables.
Return ONLY the valid Python code and nothing else.
Don't assume you have access to any libraries other than built-in Python ones, pandas, and plotly.
"""
prompt = ChatPromptTemplate.from_messages([
    ("system", system),
    ("human", "{question}"),
])

# prompt -> model (forced tool call) -> tool-call arguments -> execute code.
chain = prompt | llm_with_tools | parser | tool
result = chain.invoke({"question": "plot a scatter chart showing the total number of orders per users."})
# Print the tool's return value. In the recorded runs it was empty; the chart
# itself is produced by the generated code.
print(result)




In [20]:
# Re-run the same code-generation chain, this time requesting a bar chart.
# The cell's displayed value is the tool's return (empty in the recorded run).
bar_chart_request = {"question": "Plot a bar chart showing the total number of orders per users."}
chain.invoke(bar_chart_request)

''

In [18]:
# Hand-written version of the per-user order-count scatter chart requested
# from the LLM above.
# One row per user: `user_id` plus the number of non-null `order_id` values.
user_order_counts = (
    df.groupby('user_id')['order_id']
    .count()
    .reset_index(name='total_orders')
)

# Plot scatter chart
fig = px.scatter(
    user_order_counts,
    x='user_id',
    y='total_orders',
    title='Total Orders per User',
    labels={'user_id': 'User ID', 'total_orders': 'Total Orders'},
)
fig.show()