In [1]:
from openai import OpenAI
import pandas as pd
import json
import duckdb
from pydantic import BaseModel, Field
from IPython.display import Markdown

import phoenix as px
import os
from openinference.instrumentation.openai import OpenAIInstrumentor
from opentelemetry.trace import Status, StatusCode
from openinference.instrumentation import TracerProvider
from phoenix.otel import register
from dotenv import load_dotenv, find_dotenv
import re
from typing import List, Dict

In [2]:

# ---------- Environment helpers ----------
def load_env():
    """Load key/value pairs from the nearest .env file, overriding already-set environment values."""
    env_file = find_dotenv()
    load_dotenv(env_file, override=True)

def get_openai_api_key():
    """Return the OpenAI API key from the environment (after loading any .env file).

    Never hardcode the key in the notebook: notebook files and their outputs are
    shared and committed, so an embedded key leaks. Any key that was ever pasted
    into this notebook must be treated as compromised and rotated.

    Raises:
        RuntimeError: if OPENAI_API_KEY is unset or empty.
    """
    import os

    load_env()
    openai_api_key = os.getenv("OPENAI_API_KEY")
    if not openai_api_key:
        raise RuntimeError(
            "OPENAI_API_KEY is not set; export it or add it to a .env file."
        )
    return openai_api_key



def get_phoenix_endpoint():
    """Return the URL of the local Phoenix trace collector (port 6006)."""
    return "http://localhost:6006/v1/traces"

In [3]:
# ---------- Initialize client & tracing ----------
MODEL = "gpt-4o-mini"
PROJECT_NAME = "LLMEVALS"

openai_api_key = get_openai_api_key()
client = OpenAI(api_key=openai_api_key)

# Register a Phoenix tracer provider, instrument the OpenAI SDK with it, and
# expose a tracer for the manual spans / decorators used below.
tracer_provider = register(project_name=PROJECT_NAME, endpoint=get_phoenix_endpoint())
OpenAIInstrumentor().instrument(tracer_provider=tracer_provider)
tracer = tracer_provider.get_tracer(__name__)

OpenTelemetry Tracing Details
|  Phoenix Project: LLMEVALS
|  Span Processor: SimpleSpanProcessor
|  Collector Endpoint: http://localhost:6006/v1/traces
|  Transport: HTTP + protobuf
|  Transport Headers: {}
|  
|  Using a default SpanProcessor. `add_span_processor` will overwrite this default.
|  
|  
|  `register` has set this TracerProvider as the global OpenTelemetry default.
|  To disable this behavior, call `register` with `set_global_tracer_provider=False`.



In [4]:

# ---------- Constants / Files ----------
TRANSACTION_DATA_FILE_PATH = 'Sales.parquet'

# ---------- Guardrail helper functions ----------
# Regexes that reject a generated query outright (contains_forbidden_tokens matches
# them case-insensitively). Besides DDL/DML they block statement separators and
# comment markers, plus DuckDB statements that touch files, extensions or other
# databases (ATTACH, DETACH, COPY, PRAGMA, INSTALL, LOAD, EXPORT, IMPORT) as
# defense in depth.
FORBIDDEN_SQL_TOKENS = [
    r"\bDROP\b", r"\bDELETE\b", r"\bINSERT\b", r"\bUPDATE\b", r"\bALTER\b",
    r"\bCREATE\b", r"\bTRUNCATE\b", r";", r"`", r"--", r"/\*", r"\*/",
    r"\bATTACH\b", r"\bDETACH\b", r"\bCOPY\b", r"\bPRAGMA\b",
    r"\bINSTALL\b", r"\bLOAD\b", r"\bEXPORT\b", r"\bIMPORT\b",
]

# Lower-case words that validate_sql's identifier scan accepts without question:
# SQL keywords, type names used with CAST, and read-only scalar/aggregate/window
# functions. Never add table functions (read_csv, read_parquet, glob, ...) here:
# they read files, and listing them would let generated SQL call them.
SQL_KEYWORDS = {
    # clauses / operators
    "select", "from", "where", "group", "by", "order", "join", "left", "right", "inner", "outer",
    "on", "as", "limit", "offset", "having", "distinct", "union", "all", "with", "case", "when",
    "then", "else", "end", "and", "or", "not", "in", "is", "null", "like", "between", "exists",
    "asc", "desc", "nulls", "first", "last", "over", "partition", "filter", "interval",
    "true", "false",
    # aggregates / window functions
    "count", "sum", "avg", "min", "max", "rank", "dense_rank", "row_number",
    # read-only scalar functions
    "cast", "round", "abs", "coalesce", "nullif", "lower", "upper",
    "extract", "year", "quarter", "month", "day", "date_trunc", "date_part", "strftime",
    # type names (CAST targets)
    "integer", "bigint", "double", "decimal", "varchar", "date", "timestamp",
}

In [5]:

def extract_json_from_text(text: str) -> str:
    """Return the first complete JSON object embedded in `text`, as a substring.

    Each '{' is tried in turn as the start of an object with
    json.JSONDecoder.raw_decode, so braces inside JSON string values
    (e.g. "SELECT '}'") do not confuse the search, and any prose or code
    fences before/after the object are ignored.

    Raises:
        ValueError: if `text` is empty/None or has no '{', or if no '{' starts a
            decodable JSON object.
    """
    if not text or "{" not in text:
        raise ValueError("No JSON object found in model response.")

    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode returns the absolute index just past the decoded document.
            _, end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
            continue
        return text[start:end]
    raise ValueError("Could not extract JSON object from response.")


In [6]:


def contains_forbidden_tokens(sql: str) -> List[str]:
    """Return every FORBIDDEN_SQL_TOKENS pattern that matches `sql` (case-insensitive), in list order."""
    return [
        pattern
        for pattern in FORBIDDEN_SQL_TOKENS
        if re.search(pattern, sql, flags=re.IGNORECASE)
    ]


In [7]:


def ensure_select_or_with(sql: str) -> bool:
    """Return True when the query, after leading whitespace, starts with SELECT or WITH (a CTE).

    Only whitespace is skipped: a query beginning with '(' is rejected.
    """
    leading_keyword = re.compile(r"^(WITH|SELECT)\b", flags=re.IGNORECASE)
    return leading_keyword.match(sql.lstrip()) is not None


In [8]:

def ensure_single_statement(sql: str) -> bool:
    """Return False if `sql` contains a semicolon (a statement separator), True otherwise.

    Nested SELECTs (subqueries) are intentionally allowed; other guardrails handle the rest.
    """
    return ";" not in sql


In [9]:

def ensure_limit_clause(sql: str, max_rows: int = 1000) -> str:
    """Return `sql` (stripped) guaranteed to end with an outer LIMIT of at most `max_rows`.

    Only a trailing ``LIMIT n`` (optionally followed by ``OFFSET m``) counts as the
    outer limit; a LIMIT inside a subquery does not bound the final result.
    - no trailing LIMIT            -> " LIMIT <max_rows>" is appended
    - trailing LIMIT n <= max_rows -> returned unchanged
    - trailing LIMIT n >  max_rows -> n is clamped to max_rows (OFFSET kept)

    A trailing LIMIT in another form (e.g. an expression) gets a second LIMIT
    appended, which makes the SQL fail to execute rather than run unbounded.
    """
    stripped = sql.strip()
    trailing = re.search(
        r"\bLIMIT\s+(\d+)(\s+OFFSET\s+\d+)?\s*$", stripped, flags=re.IGNORECASE
    )
    if trailing is None:
        return f"{stripped} LIMIT {max_rows}"
    if int(trailing.group(1)) <= max_rows:
        return stripped
    offset = trailing.group(2) or ""
    return f"{stripped[:trailing.start()]}LIMIT {max_rows}{offset}"


In [10]:

def find_referenced_columns(sql: str, allowed_columns: List[str]) -> List[str]:
    """Return, in input order, the allowed columns that occur in `sql` as whole words (case-insensitive)."""
    def mentions(column: str) -> bool:
        pattern = rf"\b{re.escape(column)}\b"
        return re.search(pattern, sql, flags=re.IGNORECASE) is not None

    return [column for column in allowed_columns if mentions(column)]


In [11]:


def _mask_benign_string_literals(sql: str) -> str:
    """Replace single-quoted literals that cannot be file paths with '' so their words are not scanned.

    A literal is kept verbatim when it contains '/', a backslash or '.', because
    DuckDB can read a quoted path used as a table (e.g. FROM 'data.csv'); keeping
    it lets the identifier scan still reject it.
    NOTE(review): assumes DuckDB only treats quoted strings with a path separator or
    file extension as files - confirm for the DuckDB version in use.
    """
    def _mask(match):
        literal = match.group(0)
        return literal if any(ch in literal for ch in "/\\.") else "''"

    # SQL string literal: quote, then non-quote chars or doubled '' escapes, then quote.
    return re.sub(r"'(?:[^']|'')*'", _mask, sql)


def _declared_aliases(sql: str) -> set:
    """Return lower-cased names declared via `AS name` (aliases) or `name AS (` (CTE names)."""
    after_as = re.findall(r'\bAS\s+"?([A-Za-z_][A-Za-z0-9_]*)', sql, flags=re.IGNORECASE)
    cte_names = re.findall(r"\b([A-Za-z_][A-Za-z0-9_]*)\s+AS\s*\(", sql, flags=re.IGNORECASE)
    return {name.lower() for name in after_as + cte_names}


def _called_names(sql: str) -> set:
    """Return lower-cased identifiers directly followed by '(' (function / table-function calls)."""
    return {name.lower() for name in re.findall(r"\b([A-Za-z_][A-Za-z0-9_]*)\s*\(", sql)}


def validate_sql(sql: str, allowed_columns: List[str], table_name: str = "sales") -> Dict:
    """
    Validate a model-generated SQL string against the notebook's guardrails.

    Checks, in order: no forbidden tokens, starts with SELECT/WITH, no semicolons,
    `table_name` is referenced, at least one whitelisted column is referenced, and
    every remaining identifier is one of: a word in SQL_KEYWORDS, the table name,
    a whitelisted column, an alias/CTE name declared in the query (only when it is
    not used as a function call), or a single-letter alias t, s, a, b, c, d.

    Words inside ordinary string literals (e.g. region = 'North') are not treated
    as identifiers; literals that look like file paths still are.

    Returns:
        {'ok': bool, 'message': str, 'sql': str | None}. On success 'sql' is the
        input passed through ensure_limit_clause; on failure it is None.
    """
    raw_sql = sql

    forbidden = contains_forbidden_tokens(raw_sql)
    if forbidden:
        return {"ok": False, "message": f"Forbidden tokens present: {forbidden}", "sql": None}

    if not ensure_select_or_with(raw_sql):
        return {"ok": False, "message": "Only SELECT (or WITH ... SELECT ...) queries are allowed.", "sql": None}

    if not ensure_single_statement(raw_sql):
        return {"ok": False, "message": "Multiple statements or semicolons are not allowed.", "sql": None}

    # The remaining checks scan identifiers, so ignore words that are merely string data.
    scan_sql = _mask_benign_string_literals(raw_sql)

    # The query must reference the DuckDB table created from the dataset.
    if not re.search(r"\b" + re.escape(table_name) + r"\b", scan_sql, flags=re.IGNORECASE):
        return {"ok": False, "message": f"Query must reference the table '{table_name}' in FROM clause.", "sql": None}

    used_columns = find_referenced_columns(scan_sql, allowed_columns)

    allowed_lower = {c.lower() for c in allowed_columns}
    # Aliases may be referenced later (e.g. ORDER BY total_sales) but never called:
    # otherwise `x AS read_text` would smuggle a file-reading function past the scan.
    alias_names = _declared_aliases(scan_sql) - _called_names(scan_sql)

    # Conservative scan: every identifier must be explicitly known; anything else is suspicious.
    unknown_identifiers = set()
    for tok in re.findall(r"\b([A-Za-z_][A-Za-z0-9_]*)\b", scan_sql):
        lower = tok.lower()
        if (lower in SQL_KEYWORDS or lower == table_name.lower()
                or lower in allowed_lower or lower in alias_names):
            continue
        unknown_identifiers.add(tok)

    # Tolerate the conventional single-letter table aliases t, s, a, b, c, d.
    unknown_identifiers = {
        u for u in unknown_identifiers
        if len(u) > 1 or u.lower() not in {"t", "s", "a", "b", "c", "d"}
    }

    if not used_columns:
        return {"ok": False, "message": "No known (whitelisted) columns were referenced. The query must use dataset columns.", "sql": None}

    if unknown_identifiers:
        return {"ok": False, "message": f"Query contains unknown identifiers which may be unsafe: {sorted(list(unknown_identifiers))}", "sql": None}

    # Apply the row cap to the original SQL (literals intact).
    safe_sql = ensure_limit_clause(raw_sql)
    return {"ok": True, "message": "SQL validated", "sql": safe_sql}


In [12]:

# ---------- Prompt templates with guardrails ----------
# All templates are filled with str.format(), so literal braces must be doubled
# (see the `{{ "sql": ... }}` example in SQL_GENERATION_PROMPT).

# Placeholders: {columns}, {table_name}, {prompt}. Filled by generate_sql_query.
# NOTE(review): rule 6 promises a hard LIMIT of 1000 - verify that
# ensure_limit_clause really caps an existing LIMIT, not only appends a missing one.
SQL_GENERATION_PROMPT = """
You are an assistant that generates SQL for querying a known table in a safe, auditable manner.

Constraints (MUST follow exactly):
1) Return only a single JSON object and nothing else. The JSON must have exactly one key: "sql".
   Example:
   {{ "sql": "SELECT columnA, columnB FROM sales WHERE region = 'North' LIMIT 100" }}

2) The SQL must be a SELECT (or WITH ... SELECT) query only. No DDL/DML (DROP, DELETE, INSERT, UPDATE, ALTER, CREATE, TRUNCATE, etc.).
3) Do not include any backticks, semicolons, comments (-- or /* */), or multiple statements.
4) Use only the available columns: {columns}
5) Reference the table name: {table_name}
6) If you expect many rows, include a sensible LIMIT (we will still ensure a hard LIMIT of 1000).
7) Do not include any surrounding markdown or extraneous text — only the JSON object.

User prompt:
{prompt}
"""

# Placeholders: {data}, {prompt}. Filled by analyze_sales_data.
DATA_ANALYSIS_PROMPT = """
Analyze the following data: {data}
Your job is to answer the following question: {prompt}
Return a clear text response (no code required).
"""

# Placeholders: {data}, {visualization_goal}. Filled by extract_chart_config, whose
# structured-output call constrains the reply to the VisualizationConfig fields.
CHART_CONFIGURATION_PROMPT = """
Generate a chart configuration based on this data: {data}
The goal is to show: {visualization_goal}
Return JSON-like object describing: chart_type, x_axis, y_axis, title.
"""

# Placeholder: {config} (the dict returned by extract_chart_config). Filled by create_chart.
CREATE_CHART_PROMPT = """
Write python code to create a chart based on the following configuration.
Only return the code, no other text.
config: {config}
"""


In [13]:

# ---------- SQL generation and verification ----------
def generate_sql_query(prompt: str, columns: List[str], table_name: str) -> str:
    """Ask the model for a SQL query and return the "sql" string from its JSON reply.

    The request uses JSON mode (`response_format={"type": "json_object"}`) so the
    reply is a JSON object; JSON mode requires the word "JSON" in the messages,
    which SQL_GENERATION_PROMPT contains. The returned SQL is NOT validated here -
    callers must run validate_sql on it.

    Raises:
        ValueError: if no JSON object can be extracted from the reply, or it has no
            string "sql" field. The raw reply is included in the message.
    """
    formatted_prompt = SQL_GENERATION_PROMPT.format(
        prompt=prompt,
        columns=", ".join(str(c) for c in columns),
        table_name=table_name,
    )

    response = client.chat.completions.create(
        model=MODEL,
        messages=[
            {"role": "system", "content": "You are a strict SQL generator."},
            {"role": "user", "content": formatted_prompt},
        ],
        # Enforce a JSON-object reply at the API level, not just via the prompt.
        response_format={"type": "json_object"},
    )

    raw_text = response.choices[0].message.content
    # extract_json_from_text still tolerates stray text around the object.
    try:
        json_text = extract_json_from_text(raw_text)
        parsed = json.loads(json_text)
        if not isinstance(parsed, dict) or not isinstance(parsed.get("sql"), str):
            raise ValueError("JSON does not contain 'sql' string field.")
        sql_candidate = parsed["sql"].strip()
    except Exception as e:
        raise ValueError(
            f"Failed to parse JSON from model response: {str(e)}. Raw response: {raw_text}"
        ) from e

    return sql_candidate


In [14]:

# ---------- Tools implementations with guardrails ----------
@tracer.tool()
def lookup_sales_data(prompt: str) -> str:
    """Answer `prompt` by generating, validating and running SQL over the sales parquet file.

    Steps: load TRANSACTION_DATA_FILE_PATH into a DuckDB table "sales", ask the model
    for a query (generate_sql_query), strip markdown fences, check it with
    validate_sql, then execute the validated SQL inside an "execute_sql_query" span.

    Returns:
        The result DataFrame rendered with `to_string()`, or a string starting with
        "Error" naming the failed step. Errors are returned (not raised) so the
        router model can read them and retry.
    """
    try:
        table_name = "sales"
        df = pd.read_parquet(TRANSACTION_DATA_FILE_PATH)
        # DuckDB's replacement scan resolves `df` in the SQL to the local DataFrame.
        duckdb.sql(f"CREATE OR REPLACE TABLE {table_name} AS SELECT * FROM df")

        sql_query = generate_sql_query(prompt, list(df.columns), table_name)
        # Remove any markdown fences first, then trim what they leave behind.
        sql_query = sql_query.replace("```sql", "").replace("```", "").strip()

        validation = validate_sql(sql_query, list(df.columns), table_name=table_name)
        if not validation["ok"]:
            return f"Error: SQL validation failed: {validation['message']}"

        safe_sql = validation["sql"]

        with tracer.start_as_current_span("execute_sql_query", openinference_span_kind="chain") as span:
            span.set_input(value=safe_sql)
            try:
                result_df = duckdb.sql(safe_sql).df()
                span.set_output(value=str(result_df.head(10).to_string()))
                span.set_status(StatusCode.OK)
                return result_df.to_string()
            except Exception as e:
                # The exception is handled here, so the span context manager never
                # sees it: record it explicitly so the trace shows why it failed.
                span.record_exception(e)
                span.set_status(Status(StatusCode.ERROR, str(e)))
                return f"Error executing SQL: {str(e)}"

    except Exception as e:
        return f"Error accessing data: {str(e)}"


In [15]:

@tracer.tool()
def analyze_sales_data(prompt: str, data: str) -> str:
    """Ask the model to answer `prompt` about the textual `data`; return its reply or a fallback message."""
    user_message = DATA_ANALYSIS_PROMPT.format(data=data, prompt=prompt)
    chat = [
        {"role": "system", "content": "You are an analysis assistant."},
        {"role": "user", "content": user_message},
    ]
    reply = client.chat.completions.create(model=MODEL, messages=chat).choices[0].message.content
    if not reply:
        return "No analysis could be generated"
    return reply


In [16]:

# Structured-output schema for extract_chart_config: passed as `response_format`
# to client.beta.chat.completions.parse, so these field names and descriptions are
# sent to the model. Documented with comments rather than a docstring because
# pydantic uses a class docstring as the schema description sent to the API.
class VisualizationConfig(BaseModel):
    chart_type: str = Field(..., description="Type of chart to generate")
    x_axis: str = Field(..., description="Name of the x-axis column")
    y_axis: str = Field(..., description="Name of the y-axis column")
    title: str = Field(..., description="Title of the chart")


@tracer.chain()
def extract_chart_config(data: str, visualization_goal: str) -> dict:
    """Ask the model for a chart configuration for `data` and return it as a plain dict.

    Uses structured outputs (client.beta.chat.completions.parse with
    response_format=VisualizationConfig). The validated VisualizationConfig is read
    from `message.parsed`; `message.content` is only the raw JSON text and has no
    chart_type/x_axis/... attributes.

    Returns:
        dict with keys chart_type, x_axis, y_axis, title and data (the input, passed
        through). If no parsed config is available, a default line chart of "date"
        vs "value" titled with `visualization_goal` is returned instead.
    """
    formatted_prompt = CHART_CONFIGURATION_PROMPT.format(data=data, visualization_goal=visualization_goal)

    response = client.beta.chat.completions.parse(
        model=MODEL,
        messages=[{"role": "system", "content": "You are a chart-config generator."},
                  {"role": "user", "content": formatted_prompt}],
        response_format=VisualizationConfig,
    )

    # `parsed` can be None when the model gives no usable structured output (e.g. a refusal).
    config = getattr(response.choices[0].message, "parsed", None)
    if not isinstance(config, VisualizationConfig):
        return {
            "chart_type": "line",
            "x_axis": "date",
            "y_axis": "value",
            "title": visualization_goal,
            "data": data
        }

    return {
        "chart_type": config.chart_type,
        "x_axis": config.x_axis,
        "y_axis": config.y_axis,
        "title": config.title,
        "data": data
    }


In [17]:

@tracer.chain()
def create_chart(config: dict) -> str:
    """Ask the model for Python code that draws the chart described by `config`.

    Markdown code fences (```python / ```) are removed and the result is trimmed.
    If the model returns no content, a one-line Python comment is returned instead
    of failing on None, so the caller still receives valid (no-op) code.
    """
    formatted_prompt = CREATE_CHART_PROMPT.format(config=config)
    response = client.chat.completions.create(
        model=MODEL,
        messages=[{"role": "system", "content": "You are a Python code generator."},
                  {"role": "user", "content": formatted_prompt}],
    )
    code = response.choices[0].message.content
    if not code:
        return "# No chart code could be generated"
    return code.replace("```python", "").replace("```", "").strip()

@tracer.tool()
def generate_visualization(data: str, visualization_goal: str) -> str:
    """Produce Python plotting code for `data`: derive a chart config, then generate code from it."""
    return create_chart(extract_chart_config(data, visualization_goal))



In [18]:

# ---------- Tools metadata ----------
def _function_tool(name: str, description: str, properties: dict, required: list) -> dict:
    """Wrap a JSON-schema parameter spec in the OpenAI Chat Completions function-tool envelope."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }


# Tool schemas offered to the router model; names must match tool_implementations keys.
tools = [
    _function_tool(
        "lookup_sales_data",
        "Look up data from Store Sales Price Elasticity Promotions dataset",
        {"prompt": {"type": "string", "description": "The unchanged prompt that the user provided."}},
        ["prompt"],
    ),
    _function_tool(
        "analyze_sales_data",
        "Analyze sales data to extract insights",
        {
            "data": {"type": "string", "description": "The lookup_sales_data tool's output."},
            "prompt": {"type": "string", "description": "The unchanged prompt that the user provided."},
        },
        ["data", "prompt"],
    ),
    _function_tool(
        "generate_visualization",
        "Generate Python code to create data visualizations",
        {
            "data": {"type": "string", "description": "The lookup_sales_data tool's output."},
            "visualization_goal": {"type": "string", "description": "The goal of the visualization."},
        },
        ["data", "visualization_goal"],
    ),
]

# Maps each tool name the model may call to the Python function that runs it.
tool_implementations = dict(
    lookup_sales_data=lookup_sales_data,
    analyze_sales_data=analyze_sales_data,
    generate_visualization=generate_visualization,
)


In [19]:

@tracer.chain()
def handle_tool_calls(tool_calls, messages):
    """Run each requested tool and append its result to `messages` as a "tool" message.

    Exactly one tool message is appended per tool call, even when the call cannot
    be executed (unknown tool name, malformed JSON arguments, or arguments that do
    not fit the function). In that case the error text is sent back instead, so
    the model can correct itself and every tool_call_id still gets the reply the
    Chat Completions API expects before the next request.

    Returns:
        The same `messages` list, extended in place.
    """
    for tool_call in tool_calls:
        name = tool_call.function.name
        function = tool_implementations.get(name)
        if function is None:
            result = f"Error: unknown tool '{name}'."
        else:
            try:
                function_args = json.loads(tool_call.function.arguments or "{}")
                if not isinstance(function_args, dict):
                    raise TypeError("tool arguments must be a JSON object")
                result = function(**function_args)
            except (json.JSONDecodeError, TypeError) as e:
                result = f"Error: could not call tool '{name}': {e}"
        messages.append({"role": "tool", "content": str(result), "tool_call_id": tool_call.id})
    return messages

# System message that run_agent adds when the incoming conversation has none.
SYSTEM_PROMPT = """
You are a helpful assistant that can answer questions about the Store Sales Price Elasticity Promotions dataset.
"""

In [20]:


def run_agent(messages, max_iterations: int = 20):
    """Route the conversation through the model, running tool calls until it answers in text.

    Args:
        messages: a user prompt string, or a list of chat messages. A list is
            copied, so the caller's list is not modified. If no system message is
            present, SYSTEM_PROMPT is inserted as the first message.
        max_iterations: upper bound on router calls; stops a model that keeps
            requesting tools from looping forever.

    Returns:
        The model's final text content, or an error string if every one of the
        `max_iterations` router calls requested tools.
    """
    print("Running agent with messages:", messages)
    if isinstance(messages, str):
        messages = [{"role": "user", "content": messages}]
    else:
        messages = list(messages)
    if not any(isinstance(message, dict) and message.get("role") == "system" for message in messages):
        # The system prompt must lead the conversation, not follow the user turn.
        messages.insert(0, {"role": "system", "content": SYSTEM_PROMPT})

    for _ in range(max_iterations):
        print("Starting router call span")
        with tracer.start_as_current_span(
            "router_call",
            openinference_span_kind="chain",
        ) as span:
            span.set_input(value=messages)
            response = client.chat.completions.create(
                model=MODEL,
                messages=messages,
                tools=tools,
            )
            message = response.choices[0].message
            messages.append(message.model_dump())
            tool_calls = message.tool_calls
            print("Received response with tool calls:", bool(tool_calls))
            span.set_status(StatusCode.OK)

            if not tool_calls:
                print("No tool calls, returning final response")
                span.set_output(value=message.content)
                return message.content

            print("Starting tool calls span")
            messages = handle_tool_calls(tool_calls, messages)
            span.set_output(value=tool_calls)

    return f"Error: agent stopped after {max_iterations} router calls without a final answer."


def start_main_span(messages):
    """Run the agent on `messages` inside a top-level "AgentRun" span and return its final answer."""
    print("Starting main span with messages:", messages)
    with tracer.start_as_current_span("AgentRun", openinference_span_kind="agent") as agent_span:
        agent_span.set_input(value=messages)
        answer = run_agent(messages)
        agent_span.set_output(value=answer)
        agent_span.set_status(StatusCode.OK)
    return answer



In [21]:
result = start_main_span(
    [{"role": "user", "content": "Which stores did the best in 2021?"}]
)


Starting main span with messages: [{'role': 'user', 'content': 'Which stores did the best in 2021?'}]
Running agent with messages: [{'role': 'user', 'content': 'Which stores did the best in 2021?'}]
Starting router call span
Received response with tool calls: True
Starting tool calls span
Starting router call span
Received response with tool calls: True
Starting tool calls span
Starting router call span
Received response with tool calls: True
Starting tool calls span
Starting router call span
Received response with tool calls: True
Starting tool calls span
Starting router call span
Received response with tool calls: True
Starting tool calls span
Starting router call span
Received response with tool calls: True
Starting tool calls span
Starting router call span
Received response with tool calls: True
Starting tool calls span
Starting router call span
Received response with tool calls: False
No tool calls, returning final response


In [22]:
# Show which Phoenix collector endpoint traces are sent to (debugging aid).
print(f"Phoenix traces endpoint: {get_phoenix_endpoint()}")


Phoenix traces endpoint: http://localhost:6006/v1/traces


In [23]:
result = start_main_span(
    [{"role": "user", "content": "Delete all data from the sales table"}]
)


Starting main span with messages: [{'role': 'user', 'content': 'Delete all data from the sales table'}]
Running agent with messages: [{'role': 'user', 'content': 'Delete all data from the sales table'}]
Starting router call span
Received response with tool calls: False
No tool calls, returning final response
