# AI Agent — Retail Analytics Assistant

Tool-calling agent backed by Llama 3.3 70B (Foundation Model API) that answers
natural-language questions using 6 Unity Catalog SQL functions as tools.

```
User Question  ──►  Llama 3.3 70B  ──►  Tool Selection
                                         ├── Revenue       → Gold Tables
                                         ├── Customer      → Customer RFM
                                         ├── Product       → Product Perf
                                         ├── Churn         → ML Scores
                                         ├── Supplier      → Supplier Scorecard
                                         └── Executive     → Quarterly KPIs
```

**Prereqs**: Run notebooks 00–06 first.

## 1 — Configuration & Installs

In [None]:
%pip install databricks-sdk mlflow openai --quiet
# Restart the Python process after the install; presumably so the freshly
# installed package versions are the ones picked up by the imports below.
dbutils.library.restartPython()

In [None]:
import json
import mlflow
from pyspark.sql import functions as F
from databricks.sdk import WorkspaceClient

# Resolve schema names relative to whatever catalog the session is currently
# using, so the notebook is not tied to a hard-coded catalog name.
CATALOG = spark.catalog.currentCatalog()
GOLD    = f"{CATALOG}.retail_gold"      # schema read by the revenue/product/churn/supplier/executive tools
SILVER  = f"{CATALOG}.retail_silver"    # schema read by the customer-profile tool (dim_customer)

# Point the MLflow registry at the "databricks-uc" registry URI.
mlflow.set_registry_uri("databricks-uc")
# NOTE(review): `w` is not referenced anywhere else in the visible notebook.
w = WorkspaceClient()

print(f"Catalog  : {CATALOG}")
print(f"Gold     : {GOLD}")

---
## 2 — Create Unity Catalog SQL Functions (Agent Tools)

Each UC SQL function encapsulates a specific analytical capability. They are governed, discoverable in Unity Catalog, and can be reused by any agent or application.

In [None]:
# Schema that holds every agent tool function created below; created on first
# run and left untouched on re-runs (IF NOT EXISTS).
TOOLS_SCHEMA = f"{CATALOG}.retail_agent_tools"
spark.sql(f"CREATE SCHEMA IF NOT EXISTS {TOOLS_SCHEMA}")
print(f"Tools schema: {TOOLS_SCHEMA}")

### Tool 1: Revenue by Region and Period

In [None]:
# Tool 1: (re)create the table function get_revenue_by_region.
# - Passing 'ALL' as region_name disables the region filter; any other value
#   must match the `region` column exactly (the comparison is case-sensitive).
# - start_month / end_month are inclusive and compared as strings, which only
#   orders correctly for zero-padded yyyy-MM values.
# - The aggregate expressions in the SELECT are unaliased; their output names
#   come from the RETURNS TABLE column list, matched by position.
spark.sql(f"""
CREATE OR REPLACE FUNCTION {TOOLS_SCHEMA}.get_revenue_by_region(
    region_name STRING COMMENT 'Region name: AMERICA, EUROPE, ASIA, AFRICA, or MIDDLE EAST. Use ALL for all regions.',
    start_month STRING COMMENT 'Start month in yyyy-MM format, e.g. 1995-01',
    end_month STRING COMMENT 'End month in yyyy-MM format, e.g. 1995-12'
)
RETURNS TABLE(year_month STRING, region STRING, net_revenue DOUBLE, num_orders BIGINT, profit_margin_pct DOUBLE, yoy_growth_pct DOUBLE)
COMMENT 'Returns monthly revenue, orders, margin, and YoY growth for a region and time period.'
RETURN
    SELECT year_month, region, ROUND(SUM(net_revenue), 2), SUM(num_orders), ROUND(AVG(profit_margin_pct), 2), ROUND(AVG(yoy_growth_pct), 2)
    FROM {GOLD}.gold_monthly_sales
    WHERE (region_name = 'ALL' OR region = region_name) AND year_month >= start_month AND year_month <= end_month
    GROUP BY year_month, region ORDER BY year_month, region
""")
print("âœ“ get_revenue_by_region")

### Tool 2: Customer Profile

In [None]:
# Tool 2: (re)create the table function get_customer_profile.
# - Looks up one customer by key in the silver customer dimension.
# - RFM data is LEFT JOINed, so a customer with no row in gold_customer_rfm
#   still yields a profile row with NULLs in the RFM-derived columns.
# - The `lifetime_value` output column is the rounded `monetary` value from the
#   RFM table (matched by position against RETURNS TABLE).
spark.sql(f"""
CREATE OR REPLACE FUNCTION {TOOLS_SCHEMA}.get_customer_profile(
    cust_key INT COMMENT 'Customer key (integer ID)'
)
RETURNS TABLE(customer_key INT, customer_name STRING, market_segment STRING, nation_name STRING, region_name STRING, balance_tier STRING, rfm_segment STRING, rfm_score INT, lifetime_value DOUBLE, frequency BIGINT, recency_days INT, avg_order_value DOUBLE)
COMMENT 'Returns detailed customer profile including RFM segment, lifetime value, and geography.'
RETURN
    SELECT c.customer_key, c.customer_name, c.market_segment, c.nation_name, c.region_name, c.balance_tier,
           r.rfm_segment, r.rfm_score, ROUND(r.monetary, 2), r.frequency, r.recency_days, ROUND(r.avg_order_value, 2)
    FROM {SILVER}.dim_customer c LEFT JOIN {GOLD}.gold_customer_rfm r ON c.customer_key = r.customer_key
    WHERE c.customer_key = cust_key
""")
print("âœ“ get_customer_profile")

### Tool 3: Top Products

In [None]:
# Tool 3: (re)create the table function get_top_products.
# - `sort_by` selects the ranking metric; an unrecognised value falls back to
#   net_revenue. Ranking is always descending (highest value first).
# - `n` is honoured and capped at 50, as its parameter comment promises.
#   Previously the argument was ignored and 50 rows were always returned.
#   The cap is applied through ROW_NUMBER() instead of LIMIT so that it can
#   depend on the function argument. A NULL `n` keeps the old 50-row behaviour;
#   zero or a negative value yields no rows.
spark.sql(f"""
CREATE OR REPLACE FUNCTION {TOOLS_SCHEMA}.get_top_products(
    n INT COMMENT 'Number of top products to return (max 50)',
    sort_by STRING COMMENT 'Sort metric: net_revenue, profit_margin_pct, return_rate_pct, or total_quantity_sold'
)
RETURNS TABLE(brand STRING, part_type STRING, price_band STRING, net_revenue DOUBLE, profit_margin_pct DOUBLE, return_rate_pct DOUBLE, total_quantity_sold DOUBLE, num_orders BIGINT)
COMMENT 'Returns top N products sorted by the chosen metric.'
RETURN
    SELECT brand, part_type, price_band, net_revenue, profit_margin_pct, return_rate_pct, total_quantity_sold, num_orders
    FROM (
        SELECT brand, part_type, price_band, net_revenue, profit_margin_pct, return_rate_pct, total_quantity_sold, num_orders,
               ROW_NUMBER() OVER (
                   ORDER BY CASE sort_by WHEN 'net_revenue' THEN net_revenue WHEN 'profit_margin_pct' THEN profit_margin_pct WHEN 'return_rate_pct' THEN return_rate_pct WHEN 'total_quantity_sold' THEN total_quantity_sold ELSE net_revenue END DESC
               ) AS metric_rank
        FROM {GOLD}.gold_product_performance
    ) ranked
    WHERE metric_rank <= LEAST(COALESCE(n, 50), 50)
    ORDER BY metric_rank
""")
print("âœ“ get_top_products")

### Tool 4: Churn Risk

In [None]:
# Tool 4: (re)create the table function get_churn_risk.
# - Passing 'ALL' as risk_level disables the tier filter; any other value must
#   match the `risk_tier` column exactly.
# - Rows are ranked by churn_probability, highest first.
# - `top_n` is honoured and capped at 100, as its parameter comment promises.
#   Previously the argument was ignored and up to 100 rows were always
#   returned. The cap uses ROW_NUMBER() instead of LIMIT so that it can depend
#   on the function argument. A NULL `top_n` keeps the old 100-row behaviour.
spark.sql(f"""
CREATE OR REPLACE FUNCTION {TOOLS_SCHEMA}.get_churn_risk(
    risk_level STRING COMMENT 'Risk tier: Critical, High, Medium, or Low. Use ALL for all tiers.',
    top_n INT COMMENT 'Number of customers to return (max 100)'
)
RETURNS TABLE(customer_key INT, market_segment STRING, customer_region STRING, lifetime_value DOUBLE, rfm_score INT, churn_probability DOUBLE, risk_tier STRING)
COMMENT 'Returns customers by churn risk tier, sorted by highest churn probability.'
RETURN
    SELECT customer_key, market_segment, customer_region, lifetime_value, rfm_score, churn_probability, risk_tier
    FROM (
        SELECT customer_key, market_segment, customer_region, lifetime_value, rfm_score, churn_probability, risk_tier,
               ROW_NUMBER() OVER (ORDER BY churn_probability DESC) AS risk_rank
        FROM {GOLD}.gold_churn_scores
        WHERE (risk_level = 'ALL' OR risk_tier = risk_level)
    ) ranked
    WHERE risk_rank <= LEAST(COALESCE(top_n, 100), 100)
    ORDER BY risk_rank
""")
print("âœ“ get_churn_risk")

### Tool 5: Supplier Scorecard

In [None]:
# Tool 5: (re)create the table function get_supplier_scorecard.
# - Passing 'ALL' as region_filter disables the region filter.
# - `sort_metric` selects the ranking metric; an unrecognised value falls back
#   to on_time_delivery_pct. All metrics rank highest-first except
#   return_rate_pct, which is negated so the lowest return rate ranks first.
# - `top_n` is honoured and capped at 50, as its parameter comment promises.
#   Previously the argument was ignored and up to 50 rows were always
#   returned. The cap uses ROW_NUMBER() instead of LIMIT so that it can depend
#   on the function argument. A NULL `top_n` keeps the old 50-row behaviour.
spark.sql(f"""
CREATE OR REPLACE FUNCTION {TOOLS_SCHEMA}.get_supplier_scorecard(
    region_filter STRING COMMENT 'Supplier region or ALL',
    sort_metric STRING COMMENT 'Sort by: on_time_delivery_pct, net_revenue, profit_margin_pct, or return_rate_pct',
    top_n INT COMMENT 'Number of suppliers to return (max 50)'
)
RETURNS TABLE(supplier_name STRING, supplier_nation STRING, supplier_region STRING, net_revenue DOUBLE, profit_margin_pct DOUBLE, on_time_delivery_pct DOUBLE, return_rate_pct DOUBLE, total_line_items BIGINT)
COMMENT 'Returns supplier performance scorecard.'
RETURN
    SELECT supplier_name, supplier_nation, supplier_region, net_revenue, profit_margin_pct, on_time_delivery_pct, return_rate_pct, total_line_items
    FROM (
        SELECT supplier_name, supplier_nation, supplier_region, net_revenue, profit_margin_pct, on_time_delivery_pct, return_rate_pct, total_line_items,
               ROW_NUMBER() OVER (
                   ORDER BY CASE sort_metric WHEN 'on_time_delivery_pct' THEN on_time_delivery_pct WHEN 'net_revenue' THEN net_revenue WHEN 'profit_margin_pct' THEN profit_margin_pct WHEN 'return_rate_pct' THEN -return_rate_pct ELSE on_time_delivery_pct END DESC
               ) AS supplier_rank
        FROM {GOLD}.gold_supplier_scorecard
        WHERE (region_filter = 'ALL' OR supplier_region = region_filter)
    ) ranked
    WHERE supplier_rank <= LEAST(COALESCE(top_n, 50), 50)
    ORDER BY supplier_rank
""")
print("âœ“ get_supplier_scorecard")

### Tool 6: Executive Summary

In [None]:
# Tool 6: (re)create the table function get_executive_summary.
# - start_quarter / end_quarter are inclusive and compared as strings against
#   the `year_quarter` column, so they must use the same format as that column
#   (the parameter comments give e.g. 1995-Q1).
# - Rows are returned in ascending year_quarter order.
spark.sql(f"""
CREATE OR REPLACE FUNCTION {TOOLS_SCHEMA}.get_executive_summary(
    start_quarter STRING COMMENT 'Start quarter, e.g. 1995-Q1',
    end_quarter STRING COMMENT 'End quarter, e.g. 1997-Q4'
)
RETURNS TABLE(year_quarter STRING, total_orders BIGINT, active_customers BIGINT, gross_order_value DOUBLE, avg_order_value DOUBLE, revenue_per_customer DOUBLE, qoq_revenue_growth_pct DOUBLE)
COMMENT 'Returns quarterly executive KPI summary.'
RETURN
    SELECT year_quarter, total_orders, active_customers, gross_order_value, avg_order_value, revenue_per_customer, qoq_revenue_growth_pct
    FROM {GOLD}.gold_executive_summary
    WHERE year_quarter >= start_quarter AND year_quarter <= end_quarter
    ORDER BY year_quarter
""")
print("âœ“ get_executive_summary")
print(f"\nAll 6 tools created in {TOOLS_SCHEMA}")

---
## 3 — Define Tool Specs (OpenAI function-calling format)

In [None]:
def _tool_spec(name, description, params):
    """Build one function-calling tool entry.

    `params` is a sequence of (parameter name, JSON type, description) tuples.
    Every listed parameter is marked as required, and the declaration order is
    kept both in "properties" and in "required".
    """
    properties = {
        param: {"type": json_type, "description": text}
        for param, json_type, text in params
    }
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": list(properties),
            },
        },
    }


# One entry per UC SQL function; parameters are declared in the same order as
# in the corresponding CREATE FUNCTION statement.
tools = [
    _tool_spec(
        "get_revenue_by_region",
        "Monthly revenue, orders, margin, YoY growth for a region and time period.",
        [
            ("region_name", "string", "AMERICA, EUROPE, ASIA, AFRICA, MIDDLE EAST, or ALL"),
            ("start_month", "string", "yyyy-MM"),
            ("end_month", "string", "yyyy-MM"),
        ],
    ),
    _tool_spec(
        "get_customer_profile",
        "Customer profile with RFM segment, lifetime value, geography.",
        [
            ("cust_key", "integer", "Customer ID"),
        ],
    ),
    _tool_spec(
        "get_top_products",
        "Top N products by revenue, margin, return rate, or quantity.",
        [
            ("n", "integer", "How many products"),
            ("sort_by", "string", "net_revenue, profit_margin_pct, return_rate_pct, or total_quantity_sold"),
        ],
    ),
    _tool_spec(
        "get_churn_risk",
        "Customers by churn risk tier (Critical/High/Medium/Low/ALL).",
        [
            ("risk_level", "string", "Critical, High, Medium, Low, or ALL"),
            ("top_n", "integer", "Number of customers"),
        ],
    ),
    _tool_spec(
        "get_supplier_scorecard",
        "Supplier reliability: on-time %, revenue, margin, return rate.",
        [
            ("region_filter", "string", "Region or ALL"),
            ("sort_metric", "string", "on_time_delivery_pct, net_revenue, profit_margin_pct, return_rate_pct"),
            ("top_n", "integer", "Number of suppliers"),
        ],
    ),
    _tool_spec(
        "get_executive_summary",
        "Quarterly KPIs: orders, customers, revenue, growth.",
        [
            ("start_quarter", "string", "yyyy-QN"),
            ("end_quarter", "string", "yyyy-QN"),
        ],
    ),
]

print(f"{len(tools)} tool definitions ready")
for t in tools:
    print(f"  â€¢ {t['function']['name']}")

## 4 — Tool Executor & Agent Loop

In [None]:
def _sql_literal(value):
    """Render one JSON-decoded scalar as a Spark SQL literal.

    Strings are single-quoted with backslashes and single quotes
    backslash-escaped, so model-supplied text cannot terminate the literal.
    NOTE(review): this relies on Spark's default string-literal parsing
    (`spark.sql.parser.escapedStringLiterals` left at false).

    Raises:
        ValueError: for NaN / infinite floats, which have no plain literal form.
        TypeError: for lists, dicts or any other non-scalar value.
    """
    if value is None:
        return "NULL"
    # bool must be tested before int because bool is a subclass of int.
    if isinstance(value, bool):
        return "TRUE" if value else "FALSE"
    if isinstance(value, int):
        return str(value)
    if isinstance(value, float):
        if value != value or value in (float("inf"), float("-inf")):
            raise ValueError(f"Non-finite number is not a valid argument: {value}")
        return repr(value)
    if isinstance(value, str):
        escaped = value.replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"
    raise TypeError(f"Unsupported argument type: {type(value).__name__}")


def execute_tool(tool_name, arguments_json):
    """Execute a UC SQL function and return its rows as a JSON string.

    The call is validated against the module-level `tools` specs before any
    SQL is built:

    * `tool_name` must be one of the declared tools, so arbitrary text from
      the model can never end up in the function-name position.
    * Arguments are bound in the order the spec declares them (which mirrors
      the SQL function signature), not in the key order of the model's JSON.
    * Unknown or missing arguments are rejected instead of silently shifting
      the positional arguments.
    * String values are escaped before being embedded as SQL literals.

    Args:
        tool_name: Name of the tool requested by the model.
        arguments_json: JSON object (as text) mapping parameter names to values.

    Returns:
        A JSON array of row dicts (at most 50 rows), or a JSON object
        ``{"error": "..."}`` when anything goes wrong. Errors are returned
        rather than raised so that they can be fed back to the model.
    """
    try:
        args = json.loads(arguments_json) if arguments_json else {}
        if not isinstance(args, dict):
            raise ValueError("Tool arguments must be a JSON object")

        spec = next(
            (t["function"] for t in tools if t["function"]["name"] == tool_name),
            None,
        )
        if spec is None:
            raise ValueError(f"Unknown tool: {tool_name}")

        param_names = list(spec["parameters"]["properties"])
        unexpected = [key for key in args if key not in param_names]
        if unexpected:
            raise ValueError(f"Unexpected arguments for {tool_name}: {unexpected}")
        missing = [
            name for name in spec["parameters"].get("required", [])
            if name not in args
        ]
        if missing:
            raise ValueError(f"Missing required arguments for {tool_name}: {missing}")

        fqn = f"{TOOLS_SCHEMA}.{tool_name}"
        arg_list = [_sql_literal(args.get(name)) for name in param_names]
        sql = f"SELECT * FROM {fqn}({', '.join(arg_list)})"
        rows = spark.sql(sql).limit(50).toPandas().to_dict(orient="records")
        # default=str stringifies values json cannot encode natively.
        return json.dumps(rows, default=str)
    except Exception as e:
        # Deliberately broad: every failure is reported back to the caller as data.
        return json.dumps({"error": str(e)})


# System message sent as the first message of every conversation started by
# ask_agent(); it sets the assistant's role and its answer-formatting rules.
SYSTEM_PROMPT = """You are a Retail Analytics AI Assistant. You have access to tools that query the company data warehouse.

RULES:
- Always use tools to look up data. Never make up numbers.
- Format currency with $ and commas. Format percentages with %.
- Provide business insights and actionable recommendations.
- Keep responses concise but informative.
"""


def ask_agent(question, max_iterations=5, verbose=True):
    """Answer `question` with the tool-calling retail analytics agent.

    Each round sends the conversation so far, together with the tool specs,
    to the chat-completions endpoint. If the model asks for tool calls, each
    one is run through execute_tool() and its result is appended as a "tool"
    message before the next round. Otherwise the model's text is returned.

    Args:
        question: The user's natural-language question.
        max_iterations: Upper bound on the number of model round-trips.
        verbose: When true, print each tool call before it is executed.

    Returns:
        The content of the model's final message, or a fixed fallback string
        if the model is still requesting tools after `max_iterations` rounds.
    """
    from openai import OpenAI

    # Credentials and workspace host are read from the notebook context.
    notebook_ctx = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
    api_token = notebook_ctx.apiToken().get()
    workspace_url = spark.conf.get('spark.databricks.workspaceUrl')
    llm = OpenAI(
        api_key=api_token,
        base_url=f"https://{workspace_url}/serving-endpoints",
    )

    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": question},
    ]

    for _ in range(max_iterations):
        completion = llm.chat.completions.create(
            model="databricks-meta-llama-3-3-70b-instruct",
            messages=conversation,
            tools=tools,
        )
        reply = completion.choices[0]
        assistant_msg = reply.message

        if reply.finish_reason != "stop" and assistant_msg.tool_calls:
            # Keep the assistant turn that requested the tools, then answer
            # every requested call in order.
            conversation.append(assistant_msg)
            for call in assistant_msg.tool_calls:
                if verbose:
                    print(f"  ðŸ”§ Calling {call.function.name}({call.function.arguments})")
                tool_output = execute_tool(call.function.name, call.function.arguments)
                conversation.append(
                    {"role": "tool", "tool_call_id": call.id, "content": tool_output}
                )
            continue

        # No (further) tool calls requested: this is the final answer.
        return assistant_msg.content

    return "Reached max iterations. Try a simpler question."

print("âœ“ Agent loop ready")

## 5 — Test the Agent

In [None]:
# Smoke test: a revenue question for one region and year. Presumably answered
# via get_revenue_by_region, but the tool choice is made by the model.
answer = ask_agent("What was the total revenue for AMERICA in 1996?")
print(f"\n{answer}")

In [None]:
# Smoke test: a product-ranking question. Presumably answered via
# get_top_products, but the tool choice is made by the model.
answer = ask_agent("Show me the top 5 products by profit margin.")
print(f"\n{answer}")

In [None]:
# Smoke test: a churn-risk question. Presumably answered via get_churn_risk,
# but the tool choice is made by the model.
answer = ask_agent("Which customers are at critical churn risk? Show me the top 5.")
print(f"\n{answer}")

In [None]:
# Smoke test: a multi-year KPI question. Presumably answered via
# get_executive_summary, but the tool choice is made by the model.
answer = ask_agent("Give me the executive summary for 1995 to 1997.")
print(f"\n{answer}")

In [None]:
# Smoke test: a supplier-reliability question. Presumably answered via
# get_supplier_scorecard, but the tool choice is made by the model.
answer = ask_agent("Who are the most reliable suppliers in EUROPE? Rank the top 10.")
print(f"\n{answer}")

## 6 — Log Agent Interactions to MLflow

In [None]:
# The experiment lives under the current user's workspace folder.
current_user = spark.sql('SELECT current_user()').collect()[0][0]
experiment_name = f"/Users/{current_user}/retail_agent_experiment"
mlflow.set_experiment(experiment_name)

test_questions = [
    "What was total revenue for AMERICA in 1996?",
    "Show me top 5 products by revenue.",
    "Which customers have the highest churn risk?",
    "Give me the quarterly executive summary for 1997.",
    "Who are the most reliable suppliers in EUROPE?",
]

# One run holds everything: each question is logged as a param (truncated to
# 100 characters) and each answer as a text artifact under answers/.
with mlflow.start_run(run_name="agent_test_v1"):
    for number, q in enumerate(test_questions, start=1):
        answer = ask_agent(q, verbose=False)
        mlflow.log_param(f"q{number}", q[:100])
        # The agent can return None; log a placeholder in that case.
        mlflow.log_text(answer or "No answer", f"answers/q{number}.txt")
        print(f"  âœ“ Q{number}: {q[:60]}...")

    # Run-level metadata describing the agent configuration.
    tool_names = [spec['function']['name'] for spec in tools]
    mlflow.log_param("model", "databricks-meta-llama-3-3-70b-instruct")
    mlflow.log_param("num_tools", len(tools))
    mlflow.log_param("tool_names", str(tool_names))

print("\nâœ“ Agent interactions logged to MLflow")

## 7 — Interactive Agent (Run your own questions)

In [None]:
# Change the question below and re-run this cell to ask the agent anything!
# ask_agent() is called with its default verbose=True, so each tool call the
# model makes is printed between the "Q:" and "A:" lines.
question = "Compare revenue across all regions for 1997. Which region grew the fastest?"

print(f"Q: {question}\n")
answer = ask_agent(question)
print(f"\nA: {answer}")

---
Agent is working — 6 UC SQL tools, Llama 3.3 70B, interactions logged to MLflow.

Continue with `08_ai_bi_dashboard.ipynb`.