In [1]:
from openai import OpenAI
import pandas as pd
import json
import duckdb
from pydantic import BaseModel, Field
from IPython.display import Markdown
import instructor
from typing import Optional

import warnings
# Suppress all Python warnings so the notebook output stays readable.
warnings.filterwarnings('ignore')
# Fix for importing helper functions
import sys
import os
# Add the parent directory (Tutorials) to the Python path so the shared `helper`
# module can be imported.
# NOTE(review): assumes the kernel's working directory is this notebook's folder — confirm.
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
# Must run before the `helper` import on the next line.
sys.path.insert(0, parent_dir)
from helper import get_openai_api_key, get_phoenix_endpoint


In [2]:
import phoenix as px
import os
from phoenix.otel import register
from openinference.instrumentation.openai import OpenAIInstrumentor
from openinference.semconv.trace import SpanAttributes
from opentelemetry.trace import Status, StatusCode
from openinference.instrumentation import TracerProvider

In [3]:
# initialize the OpenAI clients
# Both clients are given the same explicitly fetched API key so they authenticate
# identically, instead of the instructor client silently relying on whatever
# OPENAI_API_KEY happens to be in the environment.
openai_api_key = get_openai_api_key()
# Instructor-patched client: calls accept `response_model=` and return the parsed
# Pydantic object (used by the SQL, analysis, chart and final-response steps).
client = instructor.from_openai(OpenAI(api_key=openai_api_key))
# Plain client used for native tool calling in the agent's router loop.
tool_calling_client = OpenAI(api_key=openai_api_key)

# Model used for every LLM call in this notebook.
MODEL = "gpt-4o-mini"

In [4]:
# Phoenix project name that all traces from this notebook are grouped under.
PROJECT_NAME = 'tracing-agent'

In [None]:
# Launch the Phoenix app from this kernel; the returned handle is kept in `session`.
# NOTE(review): the tracer endpoint configured below is hardcoded to localhost:6006,
# presumably this app's default address — confirm if the port is changed.
session = px.launch_app()

In [None]:
# Register a tracer provider that exports spans to the Phoenix traces endpoint,
# grouped under PROJECT_NAME.
# NOTE(review): the URL is hardcoded to the locally launched app; the imported
# get_phoenix_endpoint() helper is not used here.
tracer_provider = register(
    endpoint="http://localhost:6006/v1/traces",
    project_name=PROJECT_NAME,
)

In [7]:
# Auto-instrument the OpenAI SDK so its chat-completion calls are recorded as spans
# on `tracer_provider`.
OpenAIInstrumentor().instrument(tracer_provider = tracer_provider)

In [8]:
# Tracer used for the manual spans and the @tracer.tool() / @tracer.chain() decorators below.
tracer = tracer_provider.get_tracer(__name__)

In [9]:
# Store-sales transactions (parquet), relative to the notebook's working directory.
TRANSACTION_DATA_FILE_PATH = (
    '../data/'
    'Store_Sales_Price_Elasticity_Promotions_Data.parquet'
)

### Tool 1: Retrieve SQL Data

In [10]:
# Prompt template for step 2 of tool 1 (SQL generation), filled via str.format
# in generate_sql_query.
# Placeholders: {prompt} = the user's question, {columns} = the table's column
# names, {table_name} = the DuckDB table the query must target.
# NOTE(review): the template names neither the SQL dialect (DuckDB) nor forbids
# markdown fences; lookup_sales_data strips ``` fences after the fact.
SQL_GENERATION_PROMPT = """
The prompt is: {prompt}

The available columns are: {columns}
The table name is: {table_name}
"""

In [11]:
# Structured-output schema for SQL generation (passed as `response_model` to the
# instructor client in generate_sql_query).
# NOTE(review): no class docstring on purpose — instructor presumably includes it
# in the schema sent to the model, which would change the prompt.
class SQLQuery(BaseModel):
    query: str = Field(..., description="The SQL query to execute")

In [12]:
# code for step 2 of tool 1
def generate_sql_query(prompt: str, columns: list, table_name: str) -> str:
    """Ask the LLM to write a SQL query that answers ``prompt`` against ``table_name``.

    Args:
        prompt: The user's natural-language question.
        columns: Column names of the table. Any iterable of names works
            (e.g. a list or a pandas ``Index``).
        table_name: Name of the table the query must target.

    Returns:
        The raw SQL string from the model's structured response. It may still
        contain markdown code fences; the caller is responsible for cleaning it.
    """
    # Render the columns as a plain comma-separated list. Formatting a pandas
    # Index directly would inject its repr (``Index([...], dtype='object')``) into
    # the prompt, and pandas truncates long reprs with "...", hiding column names.
    column_list = ", ".join(str(column) for column in columns)
    formatted_prompt = SQL_GENERATION_PROMPT.format(prompt=prompt,
                                                    columns=column_list,
                                                    table_name=table_name)

    response = client.chat.completions.create(
        model=MODEL,
        messages=[{"role": "user", "content": formatted_prompt}],
        response_model=SQLQuery
    )

    return response.query

In [13]:
# code for tool 1
@tracer.tool()
def lookup_sales_data(prompt: str) -> str:
    """Answer a natural-language data request by generating and running SQL.

    Loads the transaction parquet file into a DuckDB table named ``sales``
    (created only if it does not already exist), asks the LLM for a SQL query,
    runs it inside a traced ``execute_sql_query`` span, and returns the result
    rendered as text.

    Args:
        prompt: The user's question, passed unchanged to the SQL generator.

    Returns:
        The query result as ``DataFrame.to_string()``, or a string starting with
        ``"Error accessing data:"`` if any step fails. The error is returned
        rather than raised so the agent can see it and react.
    """
    try:
        # define the table name
        table_name = "sales"

        # step 1: read the parquet file into a DuckDB table.
        # The SQL below refers to the local pandas DataFrame `df` by name.
        df = pd.read_parquet(TRANSACTION_DATA_FILE_PATH)
        duckdb.sql(f"CREATE TABLE IF NOT EXISTS {table_name} AS SELECT * FROM df")

        # step 2: generate the SQL code
        sql_query = generate_sql_query(prompt, df.columns, table_name)
        # Remove markdown code fences first, THEN strip whitespace, so the newlines
        # that surrounded the fences do not survive into the executed/traced query.
        sql_query = sql_query.replace("```sql", "").replace("```", "").strip()

        # step 3: execute the SQL query
        # NOTE(review): this runs model-generated SQL verbatim; DuckDB SQL can
        # also read/write local files, so restrict it if prompts are untrusted.
        with tracer.start_as_current_span(
            "execute_sql_query",
            openinference_span_kind="chain"
        ) as span:
            span.set_input(sql_query)
            result = duckdb.sql(sql_query).df()
            span.set_output(value=result)
            span.set_status(StatusCode.OK)

        return result.to_string()
    except Exception as e:
        return f"Error accessing data: {str(e)}"

In [14]:
# example_data = lookup_sales_data("Show me all the sales for store 1320 on November 1st, 2021")
# print(example_data)

### Tool 2: Data Analysis

In [15]:
# Prompt template for tool 2 (data analysis), filled via str.format in
# analyze_sales_data.
# Placeholders: {data} = the data text to analyze (per the tool schema, the
# lookup_sales_data output), {prompt} = the question to answer about it.
DATA_ANALYSIS_PROMPT = """
Analyze the following data: {data}
Your job is to answer the following question: {prompt}
"""

In [16]:
# Structured-output schema for tool 2 (passed as `response_model` in analyze_sales_data).
# NOTE(review): no class docstring on purpose — instructor presumably includes it
# in the schema sent to the model, which would change the prompt.
class DataAnalysis(BaseModel):
    analysis: str = Field(..., description="The analysis of the data")

In [17]:
# Tool 2: LLM-based analysis of previously retrieved sales data.
@tracer.tool()
def analyze_sales_data(prompt: str, data: str) -> str:
    """Have the LLM answer ``prompt`` using ``data`` as context.

    Args:
        prompt: The question to answer about the data.
        data: Text containing the data to analyze (per the tool schema, the
            output of lookup_sales_data).

    Returns:
        The model's analysis text, or a fixed fallback message when the model
        returns an empty analysis.
    """
    request_messages = [
        {
            "role": "user",
            "content": DATA_ANALYSIS_PROMPT.format(data=data, prompt=prompt),
        }
    ]
    structured = client.chat.completions.create(
        model=MODEL,
        messages=request_messages,
        response_model=DataAnalysis
    )

    if not structured.analysis:
        return "No analysis could be generated"
    return structured.analysis

In [18]:
# print(analyze_sales_data(prompt="what trends do you see in this data", 
#                          data=example_data))

### Tool 3: Data Visualization

In [19]:
# Prompt template for step 1 of tool 3 (chart configuration), filled via
# str.format in extract_chart_config.
# Placeholders: {data} = the data text to chart, {visualization_goal} = what the
# chart should show.
CHART_CONFIGURATION_PROMPT = """
Generate a chart configuration based on this data: {data}
The goal is to show: {visualization_goal}
"""

In [20]:
# Structured response format for step 1 of tool 3 (chart configuration),
# passed as `response_model` in extract_chart_config.
# NOTE(review): documented with comments instead of a class docstring because
# instructor presumably includes the docstring in the schema sent to the model.
class VisualizationConfig(BaseModel):
    chart_type: str = Field(..., description="Type of chart to generate")  # free-form string, not an enum
    x_axis: str = Field(..., description="Name of the x-axis column")
    y_axis: str = Field(..., description="Name of the y-axis column")
    title: str = Field(..., description="Title of the chart")

In [21]:
# code for step 1 of tool 3
@tracer.chain()
def extract_chart_config(data: str, visualization_goal: str) -> dict:
    """Generate a chart configuration for ``data`` using the LLM.

    Args:
        data: String containing the data to visualize.
        visualization_goal: Description of what the visualization should show.

    Returns:
        Dictionary with keys ``chart_type``, ``x_axis``, ``y_axis``, ``title``
        and ``data`` (the input data, passed through unchanged). If the LLM call
        or response handling fails, a default line-chart configuration
        (x ``"date"``, y ``"value"``, titled with ``visualization_goal``) is
        returned instead.
    """
    formatted_prompt = CHART_CONFIGURATION_PROMPT.format(data=data,
                                                         visualization_goal=visualization_goal)

    try:
        # The LLM call belongs inside the try: it is the step that can actually
        # fail (API errors, structured-output validation failures). Guarding only
        # the attribute access on an already-validated model would leave the
        # fallback below unreachable.
        config = client.chat.completions.create(
            model=MODEL,
            messages=[{"role": "user", "content": formatted_prompt}],
            response_model=VisualizationConfig,
        )
        return {
            "chart_type": config.chart_type,
            "x_axis": config.x_axis,
            "y_axis": config.y_axis,
            "title": config.title,
            "data": data
        }
    except Exception:
        # Best-effort fallback so create_chart still receives a usable config.
        return {
            "chart_type": "line",
            "x_axis": "date",
            "y_axis": "value",
            "title": visualization_goal,
            "data": data
        }

In [22]:
# Prompt template for step 2 of tool 3 (chart code generation), filled via
# str.format in create_chart.
# Placeholder: {config} = the str() of the chart-configuration dict, which also
# embeds the data itself.
CREATE_CHART_PROMPT = """
Write python code to create a chart based on the following configuration.
Only return the code, no other text.
config: {config}
"""

In [23]:
# Structured-output schema for step 2 of tool 3 (passed as `response_model` in create_chart).
# NOTE(review): no class docstring on purpose — instructor presumably includes it
# in the schema sent to the model, which would change the prompt.
class VisualizationCode(BaseModel):
    code: str = Field(..., description="The python code to generate the visualization. Use the data from the config")

In [24]:
# Step 2 of tool 3: turn a chart configuration into Python plotting code.
@tracer.chain()
def create_chart(config: dict) -> str:
    """Ask the LLM for Python code that renders the chart described by ``config``.

    Args:
        config: Chart configuration (as produced by extract_chart_config); it is
            embedded in the prompt via its ``str`` form.

    Returns:
        The generated code with ```python / ``` fence markers removed and
        surrounding whitespace stripped.
    """
    generated = client.chat.completions.create(
        model=MODEL,
        messages=[
            {"role": "user", "content": CREATE_CHART_PROMPT.format(config=config)}
        ],
        response_model=VisualizationCode
    )

    without_fences = generated.code.replace("```python", "").replace("```", "")
    return without_fences.strip()

In [25]:
# Tool 3: data + goal -> Python chart code.
@tracer.tool()
def generate_visualization(data: str, visualization_goal: str) -> str:
    """Produce Python chart code for ``data`` that fulfils ``visualization_goal``.

    Runs two traced steps: extract_chart_config (choose chart type, axes and
    title), then create_chart (write the plotting code).

    Returns:
        The generated Python source as a string; it is not executed here.
    """
    chart_config = extract_chart_config(data, visualization_goal)
    return create_chart(chart_config)

In [26]:
# code = generate_visualization(example_data, 
#                               "A bar chart of sales by product SKU. Put the product SKU on the x-axis and the sales on the y-axis.")
# print(code)

In [27]:
# exec(code)

In [28]:
# Structured final answer returned by the agent (response_model in generate_final_response).
# Every field is Optional and defaults to None, so consumers must check before
# using one (e.g. before exec-ing `code`).
# NOTE(review): no class docstring on purpose — instructor presumably includes it
# in the schema sent to the model, which would change the prompt.
class FinalResponse(BaseModel):
    code: Optional[str] = Field(None, description="The python code to execute")
    observed_trends: Optional[str] = Field(None, description="The observed trends in the data")
    natural_language_response: Optional[str] = Field(None, description="The natural language response to the user's question")

In [29]:
# Prompt template for the agent's final structured response, filled via
# str.format in generate_final_response.
# Placeholder: {messages} = the str() of the whole conversation list.
FINAL_RESPONSE_PROMPT = """
Generate a final, structured response for the user.
The context is: {messages}
"""

In [30]:
def generate_final_response(messages: list) -> "FinalResponse":
    """Summarize the conversation into a structured ``FinalResponse``.

    Args:
        messages: The full conversation so far (plain dicts and/or SDK message
            objects); it is embedded in the prompt via its ``str`` form.

    Returns:
        A ``FinalResponse`` instance, not a ``str``. Its ``code``,
        ``observed_trends`` and ``natural_language_response`` fields are
        Optional and may each be ``None``.
    """
    formatted_prompt = FINAL_RESPONSE_PROMPT.format(messages=messages)

    response = client.chat.completions.create(
        model=MODEL,
        messages=[{"role": "user", "content": formatted_prompt}],
        response_model=FinalResponse
    )

    return response

In [31]:
# Tool schemas advertised to the router model (OpenAI function-calling format).
def _function_tool(name: str, description: str, properties: dict, required: list) -> dict:
    """Build one ``{"type": "function", ...}`` tool entry whose parameters form an object schema."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }


tools = [
    _function_tool(
        "lookup_sales_data",
        "Look up data from Store Sales Price Elasticity Promotions dataset",
        {"prompt": {"type": "string", "description": "The unchanged prompt that the user provided."}},
        ["prompt"],
    ),
    _function_tool(
        "analyze_sales_data",
        "Analyze sales data to extract insights",
        {
            "data": {"type": "string", "description": "The lookup_sales_data tool's output."},
            "prompt": {"type": "string", "description": "The unchanged prompt that the user provided."},
        },
        ["data", "prompt"],
    ),
    _function_tool(
        "generate_visualization",
        "Generate Python code to create data visualizations",
        {
            "data": {"type": "string", "description": "The lookup_sales_data tool's output."},
            "visualization_goal": {"type": "string", "description": "The goal of the visualization."},
        },
        ["data", "visualization_goal"],
    ),
]

# Map each advertised tool name to the Python callable that implements it.
tool_implementations = dict(
    lookup_sales_data=lookup_sales_data,
    analyze_sales_data=analyze_sales_data,
    generate_visualization=generate_visualization,
)

In [32]:
# code for executing the tools returned in the model's response
@tracer.chain()
def handle_tool_calls(tool_calls, messages) -> list:
    """Execute each requested tool call and append its result to ``messages``.

    Args:
        tool_calls: Tool calls from the router model's response; each exposes
            ``.id``, ``.function.name`` and ``.function.arguments`` (a JSON string).
        messages: The running conversation; it is modified in place.

    Returns:
        The same ``messages`` list, extended with one ``"tool"`` message per
        tool call. An unknown tool name or malformed JSON arguments produce an
        error string as that call's content instead of raising. Every
        ``tool_call_id`` needs a matching tool message, or the next
        chat-completions request is rejected.
    """
    for tool_call in tool_calls:
        name = tool_call.function.name
        function = tool_implementations.get(name)
        if function is None:
            # The model requested a tool that is not registered.
            result = f"Error: unknown tool '{name}'"
        else:
            try:
                function_args = json.loads(tool_call.function.arguments)
            except json.JSONDecodeError as e:
                result = f"Error: invalid JSON arguments for tool '{name}': {e}"
            else:
                result = function(**function_args)
        messages.append({"role": "tool", "content": result, "tool_call_id": tool_call.id})
        print(name)

    return messages

In [33]:
# System message run_agent adds when the conversation does not already contain one.
SYSTEM_PROMPT = """
You are a helpful assistant that can answer questions about the Store Sales Price Elasticity Promotions dataset.
"""

In [34]:
def run_agent(messages, max_iterations: int = 10):
    """Run the tool-calling agent loop until the model stops requesting tools.

    Args:
        messages: Either the user's question as a string, or a list of chat
            messages. A list argument is modified in place.
        max_iterations: Maximum number of router calls. When the budget is used
            up while the model is still requesting tools, the final response is
            produced from what has been gathered so far. This prevents an
            unbounded loop, and unbounded API spend.

    Returns:
        The FinalResponse produced by generate_final_response.
    """
    print("Running agent with messages:", messages)

    if isinstance(messages, str):
        messages = [{"role": "user", "content": messages}]

    # Add the system prompt if none is present. Insert it at the front so the
    # instructions precede the conversation instead of following the user's question.
    if not any(
        isinstance(message, dict) and message.get("role") == "system" for message in messages
    ):
        system_prompt = {"role": "system", "content": SYSTEM_PROMPT}
        messages.insert(0, system_prompt)

    for _ in range(max_iterations):
        # Router span: one model call that either requests tools or signals completion.
        print("Starting router call span")
        with tracer.start_as_current_span(
            "router_call", openinference_span_kind="chain",
        ) as span:
            span.set_input(value=messages)
            print("Making router call to OpenAI")
            response = tool_calling_client.chat.completions.create(
                model=MODEL,
                messages=messages,
                tools=tools,
            )
            # Keep the assistant turn (including any tool_calls) in the history;
            # the tool results appended next must follow it.
            messages.append(response.choices[0].message)
            tool_calls = response.choices[0].message.tool_calls
            print("Received response with tool calls:", bool(tool_calls))
            span.set_status(StatusCode.OK)

            # if the model decides to call function(s), call handle_tool_calls
            if tool_calls:
                print("Starting tool calls span")
                messages = handle_tool_calls(tool_calls, messages)
                span.set_output(value=tool_calls)
            else:
                print("No tool calls, returning final response")
                response = generate_final_response(messages)
                span.set_output(value=response)
                return response

    # Iteration budget exhausted while the model was still requesting tools:
    # summarize what has been gathered rather than looping forever.
    print(f"Reached max_iterations={max_iterations}; returning final response")
    return generate_final_response(messages)

In [35]:
def start_main_span(messages):
    """Run the agent inside a top-level ``AgentRun`` span of kind "agent".

    The span records ``messages`` as its input and the agent's return value as
    its output, and is marked OK once run_agent returns.

    Args:
        messages: Passed straight through to run_agent.

    Returns:
        Whatever run_agent returns.
    """
    print("Starting main span with messages:", messages)

    with tracer.start_as_current_span("AgentRun", openinference_span_kind="agent") as agent_span:
        agent_span.set_input(value=messages)
        agent_result = run_agent(messages)
        print("Main span completed with return value:", agent_result)
        agent_span.set_output(value=agent_result)
        agent_span.set_status(StatusCode.OK)

    return agent_result

In [None]:
# Run the full agent (traced under an "AgentRun" span) for a single user question.
query = "Show me all the sales for store 1320 on November 1st, 2021 and visualize them"
result = start_main_span(
    [
        {"role": "user", "content": query},
    ]
)

In [None]:
# Display the FinalResponse returned by the agent.
result

In [None]:
# Trend summary from the agent; may be None because the field is Optional.
result.observed_trends

In [None]:
# NOTE(review): this executes model-generated Python in the kernel with full
# privileges — inspect `result.code` first if that matters in your environment.
# `code` is Optional on FinalResponse; exec(None) would raise TypeError.
if result.code:
    exec(result.code)
else:
    print("The agent did not return any visualization code.")