# Tool-calling Agent

This is an auto-generated notebook created by an AI playground export. In this notebook, you will:
- Author a tool-calling agent using [MLflow's `ResponsesAgent`](https://mlflow.org/docs/latest/api_reference/python_api/mlflow.pyfunc.html#mlflow.pyfunc.ResponsesAgent) and the OpenAI client
- Manually test the agent's output
- Evaluate the agent with Mosaic AI Agent Evaluation
- Log and deploy the agent

This notebook should be run on serverless compute or on a cluster with a Databricks Runtime (DBR) version below 17.

**_NOTE:_** This notebook uses the OpenAI SDK, but AI Agent Framework is compatible with any agent authoring framework, including LlamaIndex or LangGraph. To learn more, see the [Authoring Agents](https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/author-agent) Databricks documentation.

## Prerequisites

- Address all `TODO`s in this notebook.

In [0]:
%pip install -U -qqqq databricks-openai uv databricks-agents mlflow-skinny[databricks]
dbutils.library.restartPython()

## Define the agent in code
Below we define our agent code in a single cell, enabling us to easily write it to a local Python file for subsequent logging and deployment using the `%%writefile` magic command.

For more examples of tools to add to your agent, see [docs](https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/agent-tool).
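
To make the control flow of a tool-calling agent concrete, here is a minimal, self-contained sketch of the loop. It does not use Databricks or OpenAI: `fake_llm`, `get_gap_statistics`, and `run_agent` are hypothetical stand-ins introduced only for illustration, and the message shapes are simplified assumptions rather than the exact objects the real agent exchanges.

```python
import json
from typing import Any, Callable


# Hypothetical stand-in for a Unity Catalog function (assumption, not a real UC call).
def get_gap_statistics(limit_rows: int = 10) -> str:
    return f"open_gaps,critical\n{limit_rows * 4},{limit_rows}"


TOOLS: dict[str, Callable[..., str]] = {"get_gap_statistics": get_gap_statistics}


def fake_llm(messages: list[dict[str, Any]]) -> dict[str, Any]:
    """Stub model: request one tool call, then answer once a tool output exists."""
    if any(m.get("type") == "function_call_output" for m in messages):
        return {"role": "assistant", "content": "Summary based on tool output."}
    return {
        "type": "function_call",
        "call_id": "call_1",
        "name": "get_gap_statistics",
        "arguments": json.dumps({"limit_rows": 3}),
    }


def run_agent(user_text: str, max_iter: int = 10) -> list[dict[str, Any]]:
    """Alternate between the model and tools until the model gives a final answer."""
    messages: list[dict[str, Any]] = [{"role": "user", "content": user_text}]
    for _ in range(max_iter):
        last = messages[-1]
        if last.get("role") == "assistant":
            break  # The model produced a final answer
        if last.get("type") == "function_call":
            # Run the requested tool and append its output for the next model turn
            args = json.loads(last["arguments"])
            output = TOOLS[last["name"]](**args)
            messages.append(
                {"type": "function_call_output", "call_id": last["call_id"], "output": output}
            )
        else:
            messages.append(fake_llm(messages))
    return messages


history = run_agent("How many gaps?")
for message in history:
    print(message)
```

The iteration cap plays the same role as a `max_iter` guard in a real agent: it bounds the number of model and tool turns if the model keeps requesting tools.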

## Load the Agent
The next cell writes the agent to `agent.py`. Later cells import `AGENT` from that file, which is required before testing or evaluation.

In [0]:
%%writefile agent.py

import json
import re
from typing import Any, Callable, Generator, Optional
from uuid import uuid4
import warnings
from datetime import datetime

import mlflow
import openai
from databricks.sdk import WorkspaceClient
from databricks_openai import UCFunctionToolkit, VectorSearchRetrieverTool
from mlflow.entities import SpanType
from mlflow.pyfunc import ResponsesAgent
from mlflow.types.responses import (
    ResponsesAgentRequest,
    ResponsesAgentResponse,
    ResponsesAgentStreamEvent,
    output_to_responses_items_stream,
    to_chat_completions_input,
)
from openai import OpenAI
from pydantic import BaseModel
from unitycatalog.ai.core.base import get_uc_function_client


############################################
# Configuration
############################################
#LLM_ENDPOINT_NAME = "databricks-gpt-oss-20b"
LLM_ENDPOINT_NAME = "databricks-meta-llama-3-3-70b-instruct"

# System Prompt - Example-Driven for Llama 3.3 70B
SYSTEM_PROMPT = """You are the CareGaps Assistant for Akron Children's Hospital. Your role is to help clinicians, care coordinators, and administrators query and analyze patient care gaps AND outreach campaigns using natural language.

CAPABILITIES:
You have access to 19 SQL functions:

**Care Gaps Analysis (15 functions):**
- Patient-specific queries (search, view gaps, 360-degree view)
- Priority and urgency queries (critical gaps, long-open gaps, outreach needs, no appointments)
- Provider and department analysis
- Statistical overviews and trends
- Appointment coordination
- Gap type and category analysis

**Campaign Analytics (4 functions):**
- Campaign statistics and metrics
- Search campaign opportunities by patient, location, or MRN
- List and filter campaign opportunities
- Patient campaign history

DATA SCOPE:
- Pediatric patients with active care gaps
- Gap types: Immunizations, Well Child Visits, BMI Screenings, Developmental Assessments, etc.
- Priority levels: Critical, Important, Routine
- Provider assignments and departments
- Appointment scheduling information
- Patient contact information (phone, email)
- **Flu Vaccine Piggybacking Campaign:** Identifies siblings who need flu vaccines and can piggyback on a household member's existing appointment

CAMPAIGN CONTEXT - FLU VACCINE PIGGYBACKING:
This is an agentic AI campaign that identifies TRUE piggybacking opportunities:
- A "subject patient" has an upcoming appointment
- A sibling in the same household is overdue for their flu vaccine but has NO appointment of their own
- The system suggests: "Bring sibling for their flu shot while you're here for the appointment"
- Siblings who already have their own appointments are EXCLUDED (this is the AI differentiator)
- Campaign types: FLU_VACCINE (active), LAB_PIGGYBACKING and DEPRESSION_SCREENING (coming soon)
- Statuses: pending → approved → sent → completed

IMPORTANT - CHAT vs DASHBOARD BOUNDARY:
This chat agent handles ANALYTICAL and READ-ONLY queries only.
Campaign operations (approve, send messages, change status) belong in the **Flu Campaign Dashboard**.
If a user asks to "send a message", "approve this opportunity", or "mark as completed":
→ Respond: "That action is available in the Campaign Dashboard. Navigate to **Campaigns → Flu Vaccine** in the sidebar to review, approve, and send messages."

SCOPE BOUNDARY:
You ONLY answer questions related to pediatric care gaps, patient outreach, campaigns, flu vaccine piggybacking, and Akron Children's Hospital clinical operations.
If a user asks about anything unrelated (recipes, general knowledge, coding, weather, etc.), politely decline:
→ "I'm the CareGaps Assistant and can only help with care gap analysis, outreach campaigns, and patient data for Akron Children's Hospital. How can I help you with care gaps today?"

RESPONSE GUIDELINES:
1. ALWAYS provide specific, actionable information
2. Format results as markdown tables with | separators
3. ALWAYS include "Next Best Actions" or "Recommendations" section
4. Show ALL rows returned - never truncate results
5. Prioritize critical gaps over routine ones
6. Suggest relevant follow-up questions
7. Be concise but complete

EXAMPLE INTERACTIONS:

User: "Show me critical gaps"
You: [Call get_critical_gaps(limit_rows=100)]
     "Here are the critical priority care gaps requiring immediate attention:

     | Patient Name | MRN | Age | Gap Type | Days Open | PCP | Phone | Next Appt |
     |---|---|---|---|---|---|---|---|
     | Smith, John | ***5678 | 5 | Immunization | 120 | Dr. Jones | ***-0123 | None |
     ...

     ### Next Best Actions:
     • Patients with no upcoming appointments need priority outreach
     • Gaps open >90 days should be escalated
     • Consider group vaccination clinic for immunization gaps"

User: "How is the flu campaign going?"
You: [Call get_campaign_statistics(campaign_type_filter='FLU_VACCINE')]
     "Here are the current flu vaccine piggybacking campaign metrics:

     | Metric | Value |
     |---|---|
     | Total Opportunities | 8,234 |
     | Pending Review | 5,102 |
     | Approved | 2,045 |
     | Sent | 987 |
     | Completed | 100 |
     | Asthma Patients (J45) | 412 |
     ...

     ### Next Best Actions:
     • 5,102 opportunities still pending review - head to the Campaign Dashboard to approve
     • 412 asthma patients should be prioritized (higher flu risk)
     • Focus on HIGH confidence matches first for best outreach ROI"

User: "Show flu opportunities at Beachwood"
You: [Call get_campaign_opportunities(campaign_type_filter='FLU_VACCINE', status_filter='', location_filter='Beachwood', limit_rows=50)]
     "Here are the flu vaccine piggybacking opportunities at Beachwood:

     | Patient | MRN | Age | Relationship | Subject | Appt Date | Asthma | Status |
     |---|---|---|---|---|---|---|---|
     | Doe, Sarah | ***1234 | 4 | Shared Address | Doe, Tommy (***5678) | 2026-02-20 | N | pending |
     ...

     ### Next Best Actions:
     • Review and approve these in the Campaign Dashboard
     • Prioritize asthma patients for outreach
     • Check if any siblings share the same appointment date for batch processing"

User: "Send a message to this patient"
You: "That action is available in the Campaign Dashboard. Navigate to **Campaigns → Flu Vaccine** in the sidebar to review, approve, and send messages."

User: "Find patient John Smith"
You: [Call search_patients(search_term='John Smith')]
     Return matching patients with gap summary, suggest get_patient_360() for details.

User: "Any asthma siblings in the flu campaign?"
You: [Call get_campaign_opportunities(campaign_type_filter='FLU_VACCINE', status_filter='', location_filter='', limit_rows=100)]
     Filter and highlight rows where has_asthma = 'Y', recommend prioritizing these for outreach.

FUNCTION SELECTION (19 functions):

**Care Gaps (15):**
- Patient search/find → search_patients()
- Patient gaps → get_patient_gaps()
- Comprehensive/360/everything about patient → get_patient_360()
- Critical/urgent gaps → get_critical_gaps()
- Long-open gaps → get_long_open_gaps()
- Outreach needed → get_outreach_needed()
- Gaps with NO appointments → get_gaps_no_appointments()
- Provider/department gaps → get_provider_gaps()
- Department summary → get_department_summary()
- Top providers → get_top_providers()
- Gap statistics → get_gap_statistics()
- Gaps by type → get_gaps_by_type()
- Gaps by age → get_gaps_by_age()
- Gap categories → get_gap_categories()
- Appointments with gaps → get_appointments_with_gaps()

**Campaigns (4):**
- Campaign stats/metrics/overview → get_campaign_statistics(campaign_type_filter)
- Search by MRN/name/location → search_campaign_opportunities(search_term, campaign_type_filter)
- List/filter opportunities → get_campaign_opportunities(campaign_type_filter, status_filter, location_filter, limit_rows)
- Patient campaign history → get_patient_campaign_history(patient_mrn_filter)

CAMPAIGN TYPE VALUES:
- "FLU_VACCINE": Flu vaccine piggybacking (active)
- "LAB_PIGGYBACKING": Lab piggybacking (coming soon)
- "DEPRESSION_SCREENING": Depression screening PHQ-9 (coming soon)

When user mentions "flu", "flu vaccine", "flu campaign", "piggybacking" → use campaign_type_filter = "FLU_VACCINE"

CONTEXT MAINTENANCE:
- Remember conversation history
- When user says "this patient" or "that patient", refer to the most recently mentioned patient
- When user asks for "more information" about a patient just shown, use get_patient_360() with that patient's ID

CRITICAL:
- ALWAYS format results as markdown tables with | separators
- NEVER return raw comma-separated data
- ALWAYS include "### Next Best Actions:" section after data
- SHOW ALL ROWS - never truncate to 3 or 10 results
- For campaign operations (approve, send, update status) → redirect to Campaign Dashboard"""


###############################################################################
## Logging and Monitoring
###############################################################################

class AgentLogger:
    """Log agent interactions for monitoring and debugging"""
    
    @staticmethod
    def log_query(user_query: str, functions_called: list[str], success: bool, error: Optional[str] = None):
        """Log query to MLflow or database"""
        log_entry = {
            "timestamp": datetime.now().isoformat(),
            "query": user_query,
            "functions": functions_called,
            "success": success,
            "error": error,
            "model": LLM_ENDPOINT_NAME
        }
        
        # Log to MLflow
        mlflow.log_dict(log_entry, f"query_{datetime.now().timestamp()}.json")
        
        # Print for debugging (remove in production)
        print(f"[AGENT LOG] {json.dumps(log_entry)}")
    
    @staticmethod
    def log_error(error_type: str, error_message: str, context: Optional[dict] = None):
        """Log errors for debugging"""
        error_entry = {
            "timestamp": datetime.now().isoformat(),
            "type": error_type,
            "message": error_message,
            "context": context or {}
        }
        
        mlflow.log_dict(error_entry, f"error_{datetime.now().timestamp()}.json")
        print(f"[ERROR] {json.dumps(error_entry)}")


###############################################################################
## Input Validation
###############################################################################

class InputValidator:
    """Validate user inputs to prevent injection attacks"""
    
    # Dangerous patterns that might indicate SQL injection attempts
    DANGEROUS_PATTERNS = [
        r";\s*drop\s+table",
        r";\s*delete\s+from",
        r";\s*update\s+.*\s+set",
        r"union\s+select",
        r"--\s*$",
        r"/\*.*\*/",
    ]
    
    @staticmethod
    def is_safe_input(user_input: str) -> tuple[bool, str]:
        """Check if user input is safe"""
        if not user_input:
            return False, "Empty input"
        
        # Check length
        if len(user_input) > 1000:
            return False, "Input too long (max 1000 characters)"
        
        # Check for dangerous SQL patterns
        for pattern in InputValidator.DANGEROUS_PATTERNS:
            if re.search(pattern, user_input, re.IGNORECASE):
                return False, "Potentially dangerous input detected"
        
        return True, "Valid"
    
    @staticmethod
    def sanitize_input(user_input: str) -> str:
        """Sanitize user input"""
        # Remove any control characters
        sanitized = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', user_input)
        
        # Trim whitespace
        sanitized = sanitized.strip()
        
        return sanitized


###############################################################################
## Tool Definition
###############################################################################

class ToolInfo(BaseModel):
    """
    Class representing a tool for the agent.
    """
    name: str
    spec: dict
    exec_fn: Callable


def create_tool_info(tool_spec, exec_fn_param: Optional[Callable] = None):
    tool_spec["function"].pop("strict", None)
    tool_name = tool_spec["function"]["name"]
    udf_name = tool_name.replace("__", ".")

    def exec_fn(**kwargs):
        """Execute UC function with error handling (results are returned unmasked)"""
        try:
            # Execute function
            function_result = uc_function_client.execute_function(udf_name, kwargs)
            
            if function_result.error is not None:
                AgentLogger.log_error(
                    "function_execution_error",
                    function_result.error,
                    {"function": udf_name, "kwargs": kwargs}
                )
                return f"Error executing {udf_name}: {function_result.error}"
            
            return function_result.value
            
        except Exception as e:
            AgentLogger.log_error(
                "function_exception",
                str(e),
                {"function": udf_name, "kwargs": kwargs}
            )
            return f"Error: {str(e)}"
    
    return ToolInfo(name=tool_name, spec=tool_spec, exec_fn=exec_fn_param or exec_fn)


# Configure UC Functions
UC_TOOL_NAMES = [
    # Care Gaps (15 functions)
    "dev_kiddo.silver.get_top_providers",
    "dev_kiddo.silver.get_patient_360",
    "dev_kiddo.silver.get_gap_categories",
    "dev_kiddo.silver.get_provider_gaps",
    "dev_kiddo.silver.get_long_open_gaps",
    "dev_kiddo.silver.get_outreach_needed",
    "dev_kiddo.silver.get_appointments_with_gaps",
    "dev_kiddo.silver.get_critical_gaps",
    "dev_kiddo.silver.search_patients",
    "dev_kiddo.silver.get_gaps_by_type",
    "dev_kiddo.silver.get_gap_statistics",
    "dev_kiddo.silver.get_department_summary",
    "dev_kiddo.silver.get_gaps_by_age",
    "dev_kiddo.silver.get_gaps_no_appointments",
    "dev_kiddo.silver.get_patient_gaps",
    # Campaign Analytics (4 functions)
    "dev_kiddo.silver.get_campaign_statistics",
    "dev_kiddo.silver.search_campaign_opportunities",
    "dev_kiddo.silver.get_campaign_opportunities",
    "dev_kiddo.silver.get_patient_campaign_history",
]

TOOL_INFOS = []

uc_toolkit = UCFunctionToolkit(function_names=UC_TOOL_NAMES)
uc_function_client = get_uc_function_client()

for tool_spec in uc_toolkit.tools:
    TOOL_INFOS.append(create_tool_info(tool_spec))


###############################################################################
## Agent Implementation
###############################################################################

class ToolCallingAgent(ResponsesAgent):
    """Tool-calling agent that runs Unity Catalog functions and formats their results"""

    def __init__(self, llm_endpoint: str, tools: list[ToolInfo]):
        """Initializes the ToolCallingAgent with tools."""
        self.llm_endpoint = llm_endpoint
        self.workspace_client = WorkspaceClient()
        self.model_serving_client: OpenAI = (
            self.workspace_client.serving_endpoints.get_open_ai_client()
        )
        self._tools_dict = {tool.name: tool for tool in tools}
        self._functions_called = []  # Track function calls for logging

    def get_tool_specs(self) -> list[dict]:
        """Returns tool specifications in the format OpenAI expects."""
        return [tool_info.spec for tool_info in self._tools_dict.values()]

    @mlflow.trace(span_type=SpanType.TOOL)
    def execute_tool(self, tool_name: str, args: dict) -> Any:
        """Executes the specified tool with the given arguments."""
        self._functions_called.append(tool_name)
    
        # Execute the tool
        result = self._tools_dict[tool_name].exec_fn(**args)
    
        # Format results instead of returning the raw object
        if isinstance(result, dict):
            formatted = self._format_dict_result(result)
        elif isinstance(result, list):
            formatted = self._format_list_result(result)
        else:
            formatted = str(result)
        
        # Add an instruction for the LLM to provide next steps
        # Apply to both lists (patient data) AND dicts (statistics)
        if isinstance(result, (list, dict)) and result:
            formatted += "\n\n[INSTRUCTION: After presenting this data, you MUST add a '### Next Best Actions:' section with 3-5 specific, actionable recommendations based on this data. Be concrete and clinical in your recommendations.]"
        
        return formatted

    def call_llm(self, messages: list[dict[str, Any]]) -> Generator[dict[str, Any], None, None]:
        """Call LLM with error handling"""
        try:
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", message="PydanticSerializationUnexpectedValue")
                for chunk in self.model_serving_client.chat.completions.create(
                    model=self.llm_endpoint,
                    messages=to_chat_completions_input(messages),
                    tools=self.get_tool_specs(),
                    stream=True,
                    temperature=0.0,  # Lower temperature for more consistent function calling
                    max_tokens=4096,
                ):
                    chunk_dict = chunk.to_dict()
                    if len(chunk_dict.get("choices", [])) > 0:
                        yield chunk_dict
        except Exception as e:
            error_msg = str(e)

            AgentLogger.log_error("llm_call_error", error_msg)
            # Yield error message as text response
            yield {
                "choices": [{
                    "delta": {
                        "content": "I'm sorry, I encountered an error processing your request. Please try again."
                    }
                }]
            }

    def handle_tool_call(
        self,
        tool_call: dict[str, Any],
        messages: list[dict[str, Any]],
    ) -> ResponsesAgentStreamEvent:
        """Execute tool calls with error handling"""
        try:
            raw_name = tool_call["name"]
            clean_name = self._sanitize_function_name(raw_name)

            args = json.loads(tool_call["arguments"])

            if isinstance(args, dict):
                # Remove empty keys (LLM sometimes generates {"": ""})
                args = {k: v for k, v in args.items() if k and k.strip()}
        
            # If no arguments remain, check whether the function requires any
            if not args:
                tool_info = self._tools_dict.get(clean_name)
                if tool_info:
                    # Required parameter names live in the JSON schema of the tool spec;
                    # ToolInfo itself has no `parameters` attribute
                    params_schema = tool_info.spec.get("function", {}).get("parameters") or {}
                    required_params = params_schema.get("required") or []
                    if required_params:
                        print(f"[ERROR] Function '{clean_name}' requires params: {required_params}")
                        result = f"Error: This function requires parameters. Please provide: {', '.join(required_params)}"
                        # Return the error to the LLM as the tool output
                        tool_call_output = self.create_function_call_output_item(tool_call["call_id"], result)
                        messages.append(tool_call_output)
                        return ResponsesAgentStreamEvent(type="response.output_item.done", item=tool_call_output)

            if clean_name not in self._tools_dict:
                print(f"[ERROR] Function '{clean_name}' not found.")
                print(f"[Error] Available: {list(self._tools_dict.keys())[:3]}...")
                result = "Error: Function not found. Please rephrase your query."
            else:
                result = str(self.execute_tool(tool_name=clean_name, args=args))
                
            
        except Exception as e:
            AgentLogger.log_error(
                "tool_call_error",
                str(e),
                {"tool": tool_call["name"], "args": tool_call.get("arguments")}
            )
            result = f"Error executing tool: {str(e)}"

        tool_call_output = self.create_function_call_output_item(tool_call["call_id"], result)
        messages.append(tool_call_output)
        return ResponsesAgentStreamEvent(type="response.output_item.done", item=tool_call_output)

    def call_and_run_tools(
        self,
        messages: list[dict[str, Any]],
        max_iter: int = 10,
    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
        """Call LLM and execute tools with iteration limit"""

        # Limit conversation history to prevent context overflow
        if len(messages) > 7:
            system_prompt = messages[0] if messages[0].get('role') == 'system' else None
            recent_messages = messages[-6:]
            
            if system_prompt:
                messages = [system_prompt] + recent_messages
            else:
                messages = recent_messages

            print(f"[Debug] Trimmed to {len(messages)} messages")
    
        # Continue with existing loop
        for iteration in range(max_iter):
            last_msg = messages[-1]
            if last_msg.get("role", None) == "assistant":
                return
            elif last_msg.get("type", None) == "function_call":
                yield self.handle_tool_call(last_msg, messages)
            else:
                yield from output_to_responses_items_stream(
                    chunks=self.call_llm(messages), aggregator=messages
                )

        # Max iterations reached
        AgentLogger.log_error("max_iterations", f"Reached max iterations ({max_iter})")
        yield ResponsesAgentStreamEvent(
            type="response.output_item.done",
            item=self.create_text_output_item(
                "I apologize, but I'm having trouble completing this request. Please try rephrasing or breaking it into simpler questions.",
                str(uuid4())
            ),
        )

    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
        """Generate a response for the given request"""
    
        # Generate response using predict_stream
        outputs = [
            event.item
            for event in self.predict_stream(request)
            if event.type == "response.output_item.done"
        ]
    
        # Handle custom_inputs for both formats
        custom_outputs = None
        if isinstance(request, dict):
            custom_outputs = request.get('custom_inputs', None)
        elif hasattr(request, 'custom_inputs'):
            custom_outputs = request.custom_inputs
    
        return ResponsesAgentResponse(output=outputs, custom_outputs=custom_outputs)

    def predict_stream(
        self, request: ResponsesAgentRequest
    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
        """Stream response events for the given request"""

        # Handle both dict and ResponsesAgentRequest formats
        if isinstance(request, dict):
            # Dict format: copy the list so that inserting the system prompt
            # below does not mutate the caller's request
            messages = list(request.get('input', []))
        elif hasattr(request, 'input') and request.input:
            # ResponsesAgentRequest format
            if hasattr(request.input[0], 'model_dump'):
                messages = to_chat_completions_input([i.model_dump() for i in request.input])
            else:
                messages = to_chat_completions_input(request.input)
        else:
            messages = []
    
        if SYSTEM_PROMPT:
            messages.insert(0, {"role": "system", "content": SYSTEM_PROMPT})
    
        # Generate responses
        yield from self.call_and_run_tools(messages=messages)
    
    def _call_agent(self, request: ResponsesAgentRequest) -> Generator:
        """Internal method to call agent with proper message handling"""
        messages = to_chat_completions_input([i.model_dump() for i in request.input])
    
        if SYSTEM_PROMPT:
            messages.insert(0, {"role": "system", "content": SYSTEM_PROMPT})
    
        yield from self.call_and_run_tools(messages=messages)
    
    def _format_dict_result(self, result: dict) -> str:
        """Format dictionary result as readable text"""
        lines = []
        for key, value in result.items():
            readable_key = key.replace('_', ' ').title()
            lines.append(f"{readable_key}: {value}")
        return "\n".join(lines)

    def _format_list_result(self, result: list) -> str:
        """Format list result as table or bullets"""
        if not result:
            return "No results found."
        
        if isinstance(result[0], dict):
            return self._format_table(result)
        else:
            return "\n".join(f"• {item}" for item in result)


    def _format_table(self, data: list) -> str:
        """Format list of dicts as a markdown table"""
        if not data:
            return "No results found."
        
        headers = list(data[0].keys())
        readable_headers = [h.replace('_', ' ').title() for h in headers]
        
        lines = []
        lines.append("| " + " | ".join(readable_headers) + " |")  # Proper markdown
        lines.append("|" + "|".join(["---" for _ in headers]) + "|")  # Proper separator
        
        for row in data:
            # Truncate long cell values to 80 chars to keep tables readable
            values = [str(row.get(h, ''))[:80] for h in headers]
            lines.append("| " + " | ".join(values) + " |")
        
        # Add total count
        lines.append(f"\n**Total: {len(data)} results**")
        lines.append("\n### Next Best Actions:")
        lines.append("Please provide 3-5 specific action items based on this data.")
        
        return "\n".join(lines)
    
    def _sanitize_function_name(self, raw_name: str) -> str:
        """
        Remove hallucinated tokens from function names.
        Fixes: dev_kiddo__silver__get_statistics<|channel|>commentary
        """
        if not raw_name:
            return raw_name
        
        # Known hallucination tokens
        bad_tokens = [
            '<|channel|>',
            '<|commentary|>',
            'commentary',
            'channel',
            '<|',
            '|>',
        ]
        
        sanitized = raw_name
        for token in bad_tokens:
            sanitized = sanitized.replace(token, '')
        
        # Log if we had to clean
        if sanitized != raw_name:
            print(f"[SANITIZED] {raw_name} → {sanitized}")
        
        return sanitized
    
###############################################################################
## Model Logging
###############################################################################

# Log the model using MLflow
mlflow.openai.autolog()
AGENT = ToolCallingAgent(llm_endpoint=LLM_ENDPOINT_NAME, tools=TOOL_INFOS)
mlflow.models.set_model(AGENT)
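
`call_and_run_tools` keeps the system prompt plus the last six messages. A plain slice can start on a tool output whose matching call was trimmed away, and some chat-completions endpoints may reject a history shaped like that. The sketch below shows one way to avoid this. `trim_history` is a hypothetical helper, not part of `agent.py`, and it assumes tool outputs are marked either by `role == "tool"` or by `type == "function_call_output"`.

```python
from typing import Any


def is_tool_output(message: dict[str, Any]) -> bool:
    """True for a tool result in either chat-completions or Responses-style form."""
    return message.get("role") == "tool" or message.get("type") == "function_call_output"


def trim_history(messages: list[dict[str, Any]], keep_last: int = 6) -> list[dict[str, Any]]:
    """Keep the system prompt plus at most `keep_last` recent messages,
    never starting the kept window on an orphaned tool output."""
    if len(messages) <= keep_last + 1:
        return list(messages)
    system = [messages[0]] if messages[0].get("role") == "system" else []
    start = len(messages) - keep_last
    # Move the window start forward past tool outputs whose calls were dropped
    while start < len(messages) and is_tool_output(messages[start]):
        start += 1
    return system + messages[start:]


demo = [
    {"role": "system", "content": "sys"},
    {"role": "user", "content": "u1"},
    {"role": "assistant", "content": "a0"},
    {"role": "assistant", "tool_calls": [{"id": "call_1"}]},
    {"role": "tool", "tool_call_id": "call_1", "content": "rows"},
    {"role": "assistant", "content": "a1"},
    {"role": "user", "content": "u2"},
    {"role": "assistant", "content": "a2"},
    {"role": "user", "content": "u3"},
    {"role": "assistant", "content": "a3"},
]
trimmed = trim_history(demo)
print([m.get("content") for m in trimmed])
```

In this example the naive window would begin at the `tool` message, so the helper skips it and the kept history starts at the assistant reply that followed it.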

In [None]:
# =====================================================
# SIMPLE WORKING TEST
# =====================================================

import mlflow
from agent import AGENT

# Close any active MLflow runs
while mlflow.active_run():
    print(f"Closing active run: {mlflow.active_run().info.run_id}")
    mlflow.end_run()

print("✓ All MLflow runs closed\n")

# Test 1: Normal query
print("="*60)
print("TEST 1: Normal query")
print("="*60)

try:
    r1 = AGENT.predict({
        "input": [{"role": "user", "content": "Show me 3 critical gaps"}]
    })
    
    response_text = str(r1)
    print(f"✓ Response received ({len(response_text)} chars)")
    print(f"  Preview: {response_text[:200]}...")
    print("✓✓ TEST 1 PASSED")

except Exception as e:
    print(f"✗ Error: {e}")
    import traceback
    traceback.print_exc()

# Test 2: Campaign query
print("\n" + "="*60)
print("TEST 2: Campaign query")
print("="*60)

try:
    r2 = AGENT.predict({
        "input": [{"role": "user", "content": "How is the flu campaign going?"}]
    })
    
    response_text = str(r2)
    print(f"✓ Response received ({len(response_text)} chars)")
    print(f"  Preview: {response_text[:200]}...")
    print("✓✓ TEST 2 PASSED")

except Exception as e:
    print(f"✗ Error: {e}")
    import traceback
    traceback.print_exc()

# Test 3: Multiple queries (test non-responsiveness fix)
print("\n" + "="*60)
print("TEST 3: Multiple consecutive queries")
print("="*60)

success_count = 0
for i in range(5):
    try:
        r = AGENT.predict({
            "input": [{"role": "user", "content": "How many gaps?"}]
        })
        print(f"Query {i+1}: ✓ Success")
        success_count += 1
    except Exception as e:
        print(f"Query {i+1}: ✗ Failed - {str(e)[:100]}")

if success_count == 5:
    print("✓✓ TEST 3 PASSED - Agent handled 5 consecutive queries")
else:
    print(f"✗✗ TEST 3 FAILED - Only {success_count}/5 queries succeeded")

print("\n" + "="*60)
print("TESTING COMPLETE")
print("="*60)

## Test the agent

Interact with the agent to test its output. Since we manually traced methods within `ResponsesAgent`, you can view the trace for each step the agent takes, with any LLM calls made via the OpenAI SDK automatically traced by autologging.

Replace this placeholder input with an appropriate domain-specific example for your agent.
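
When checking results, it can be easier to look at the assistant's text than at `str(response)`, which also includes tool-call items. The sketch below assumes the output has been converted to plain dicts (for example via `response.model_dump()["output"]`, an assumption about how you obtain them) and that message items follow the Responses-style shape with `output_text` content parts. `extract_text` is a hypothetical helper, and `sample_output` is made-up data for illustration.

```python
from typing import Any


def extract_text(output_items: list[dict[str, Any]]) -> str:
    """Concatenate the text parts of message items, skipping tool-call items."""
    parts: list[str] = []
    for item in output_items:
        if item.get("type") != "message":
            continue  # Skip function_call and function_call_output items
        for block in item.get("content") or []:
            if block.get("type") == "output_text":
                parts.append(block.get("text", ""))
    return "\n".join(parts)


# Made-up output items, for illustration only
sample_output = [
    {"type": "function_call", "call_id": "call_1", "name": "get_gap_statistics", "arguments": "{}"},
    {"type": "function_call_output", "call_id": "call_1", "output": "total_gaps: 42"},
    {
        "type": "message",
        "role": "assistant",
        "content": [{"type": "output_text", "text": "| Metric | Value |"}],
    },
]
text = extract_text(sample_output)
print(text)
```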

In [0]:
from agent import AGENT

# Test 1: Simple query
r1 = AGENT.predict({
    "input": [{"role": "user", "content": "How many gaps?"}]
})
print(f"Length: {len(str(r1))}")  # Should have data

# Test 2: Patient search
r2 = AGENT.predict({
    "input": [{"role": "user", "content": "Find patient 2886348"}]
})
result = str(r2)
print(f"Has MRN: {'2886348' in result}")  # Should be True (unmasked!)
print(f"Has table: {'|' in result}")  # Should be True (formatted)

# Test 3: Multi-step query
r3 = AGENT.predict({
    "input": [{"role": "user", "content": "Find patient 2886348 and show their gaps"}]
})
print(f"Result length: {len(str(r3))}")  # Should have substantial data

print("✓ Test queries completed. Review the printed checks above; this cell does not assert on them.")

In [0]:
# Quick test in notebook
from agent import AGENT

print("Testing consecutive queries (where it used to stall)...")

for i in range(10):
    print(f"\nQuery {i+1}...", end=" ")
    
    try:
        response = AGENT.predict({
            "input": [{"role": "user", "content": "Show me gap statistics"}]
        })
        
        output_len = len(str(response))
        print(f"OK ({output_len} chars)")
        
    except Exception as e:
        print(f"✗ FAILED: {e}")
        break

print("\n✅ Test complete!")


In [None]:
# =====================================================
# DIAGNOSTIC: CSV Parsing (Optional)
# =====================================================
# Note: PHIMasker was removed from the agent.
# This cell is kept for reference but will not run as-is.
print("Skipped - PHIMasker diagnostic cell (no longer applicable)")

In [None]:
# =====================================================
# DIAGNOSTIC: UC Function Return Type (Optional)
# =====================================================

from unitycatalog.ai.core.base import get_uc_function_client

print("="*70)
print("DIAGNOSTIC: UC Function Return Type Analysis")
print("="*70)

uc_client = get_uc_function_client()

print("\nTesting UC function directly...")
try:
    result = uc_client.execute_function(
        "dev_kiddo.silver.get_critical_gaps",
        {"limit_rows": 3}
    )
    
    print(f"Result object type: {type(result)}")
    
    if hasattr(result, 'value'):
        print(f"Result.value type: {type(result.value)}")
        print(f"Result.value preview: {str(result.value)[:300]}")
    
    if hasattr(result, 'error') and result.error:
        print(f"Result.error: {result.error}")
    else:
        print("✓ No errors")
        
except Exception as e:
    print(f"Error: {e}")
    import traceback
    traceback.print_exc()

print("\n✓ DIAGNOSTIC COMPLETE")

In [None]:
# =====================================================
# Quick agent test
# =====================================================
from agent import AGENT

response = AGENT.predict({
    "input": [{"role": "user", "content": "Show me gap statistics"}]
})
print(f"Response length: {len(str(response))} chars")
print(f"Preview: {str(response)[:300]}...")

In [None]:
# =====================================================
# Quick model registration (alternative to cell 14+21)
# =====================================================
# Note: Use cells 14 and 21 for the full logging + registration flow.
# This cell is a shortcut if you already have AGENT loaded and tested.

import mlflow
from agent import AGENT

if mlflow.active_run():
    mlflow.end_run()

test_request = {"input": [{"role": "user", "content": "How many gaps?"}]}
test_response = AGENT.predict(test_request)

with mlflow.start_run():
    mlflow.pyfunc.log_model(
        name="agent",
        python_model=AGENT,
        input_example=test_request,
        signature=mlflow.models.infer_signature(test_request, test_response)
    )
    print("✓ Model logged")

### Log the `agent` as an MLflow model
Determine the Databricks resources to specify for automatic auth passthrough at deployment time.
- **TODO**: If your Unity Catalog Function queries a [vector search index](https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/unstructured-retrieval-tools) or leverages [external functions](https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/external-connection-tools), you need to include the dependent vector search index and UC connection objects, respectively, as resources. See [docs](https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/log-agent#specify-resources-for-automatic-authentication-passthrough) for more details.

Log the agent as code from the `agent.py` file. See [MLflow - Models from Code](https://mlflow.org/docs/latest/models.html#models-from-code).

In [None]:
# Determine Databricks resources to specify for automatic auth passthrough at deployment time
import mlflow
from agent import UC_TOOL_NAMES, LLM_ENDPOINT_NAME
from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint

resources = [DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME)]
for tool_name in UC_TOOL_NAMES:
    resources.append(DatabricksFunction(function_name=tool_name))

input_example = {
    "input": [
        {
            "role": "user",
            "content": "what can you help me with?"
        }
    ]
}

if mlflow.active_run():
    print(f"⚠ Ending previous run: {mlflow.active_run().info.run_id}")
    mlflow.end_run()

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        name="agent",
        python_model="agent.py",
        input_example=input_example,
        pip_requirements=[
            # The `name=` argument to log_model used above is an MLflow 3 API
            "mlflow[databricks]>=3.0.0",
            "databricks-openai>=0.2.0",
            "openai>=1.0.0",
            "pydantic>=2.0.0",
            "unitycatalog-ai>=0.1.0",
        ],
        resources=resources,
    )
    print(f"✓ Model logged: {logged_agent_info.model_uri}")
    print(f"  Run ID: {logged_agent_info.run_id}")
    print(f"  Resources: {len(resources)} (1 endpoint + {len(UC_TOOL_NAMES)} UC functions)")

In [None]:
# =====================================================
# DIAGNOSTIC: Output extraction test
# =====================================================

from agent import AGENT

print("="*70)
print("TESTING: Agent output extraction")
print("="*70)

response = AGENT.predict({
    "input": [{"role": "user", "content": "How many gaps?"}]
})

print(f"Response type: {type(response)}")
print(f"Has output: {hasattr(response, 'output')}")

if hasattr(response, 'output'):
    print(f"Number of output items: {len(response.output)}")
    
    for i, item in enumerate(response.output):
        print(f"\nOutput item {i}:")
        print(f"  Type: {type(item)}")
        
        if hasattr(item, 'content'):
            content = item.content
            if isinstance(content, str):
                print(f"  Content length: {len(content)}")
                print(f"  Preview: {content[:200]}")
            elif isinstance(content, list):
                print(f"  Content is list with {len(content)} items")
                for j, c in enumerate(content):
                    if isinstance(c, dict):
                        print(f"    Item {j}: {c.get('type', 'unknown')} - {str(c)[:100]}")

print("\n✓ DIAGNOSTIC COMPLETE")

## Evaluate the agent with [Agent Evaluation](https://learn.microsoft.com/azure/databricks/mlflow3/genai/eval-monitor)

You can edit the requests or expected responses in your evaluation dataset and rerun evaluation as you iterate on your agent, using MLflow to track the computed quality metrics.

Evaluate your agent with one of our [predefined LLM scorers](https://learn.microsoft.com/azure/databricks/mlflow3/genai/eval-monitor/predefined-judge-scorers), or try adding [custom metrics](https://learn.microsoft.com/azure/databricks/mlflow3/genai/eval-monitor/custom-scorers).
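
As a starting point for a custom check, the sketch below wraps the same regular expressions that the evaluation cell in this notebook uses into one plain-Python helper. The `find_phi_like` name and the sample strings are hypothetical, and the patterns are heuristics rather than a real PHI detector: the name pattern fires on any two capitalized words, and the MRN pattern only matches nine-digit numbers, so a seven-digit MRN passes unnoticed.

```python
import re

# Patterns copied from this notebook's evaluation cell; heuristics only.
PHI_PATTERNS = {
    "full_name": re.compile(r"\b[A-Z][a-z]{3,}\s+[A-Z][a-z]{3,}\b"),
    "phone": re.compile(r"\(\d{3}\)\s*\d{3}-\d{4}"),
    "mrn_9_digit": re.compile(r"\b\d{9}\b"),
}


def find_phi_like(text: str) -> list[str]:
    """Return the names of the patterns that match `text`."""
    return [name for name, pattern in PHI_PATTERNS.items() if pattern.search(text)]


# Made-up strings for illustration only.
print(find_phi_like("Call (330) 555-0100 about MRN 123456789"))  # phone and MRN match
print(find_phi_like("Critical Gaps summary: 12 open"))  # false positive on the name pattern
print(find_phi_like("MRN ****6789, phone (***) ***-0100"))  # masked values do not match
print(find_phi_like("Find patient 2886348"))  # seven-digit MRN is not caught
```

Returning the pattern names instead of a single boolean makes it easier to see which heuristic caused a failure when you review evaluation results.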

In [0]:
# =====================================================
# IMPROVED EVALUATION - HANDLES COMPLEX OUTPUTS
# Fixes: Pydantic warnings, max_iter errors, output extraction
# =====================================================

import mlflow
import pandas as pd
from datetime import datetime
import re
import warnings

from agent import AGENT  # imported here so this cell also runs on its own

if mlflow.active_run():
    print(f"⚠ Ending previous run: {mlflow.active_run().info.run_id}")
    mlflow.end_run()

# Suppress Pydantic serializer warnings so they do not clutter the evaluation output
warnings.filterwarnings('ignore', message='Pydantic serializer warnings')

print("Starting evaluation...")

# =====================================================
# 1. MLFLOW SETUP
# =====================================================

experiment_name = "/Users/adminjkhan@akronchildrens.org/CareGaps_Evaluation"
mlflow.set_experiment(experiment_name)
mlflow.start_run(run_name=f"eval_{datetime.now().strftime('%Y%m%d_%H%M%S')}")

print(f"✓ MLflow experiment: {experiment_name}")

# =====================================================
# 2. IMPROVED OUTPUT EXTRACTION
# =====================================================

def extract_output_text(output):
    """
    Extract ONLY text content from agent output, skip function metadata
    """
    try:
        # Handle ResponsesAgentResponse object
        if hasattr(output, 'output'):
            output_items = output.output
            
            # Process list of output items
            if isinstance(output_items, list):
                text_parts = []
                
                for item in output_items:
                    # Skip tool-call items (function_call and function_call_output);
                    # only extract actual text content
                    if hasattr(item, 'type'):
                        # 'function_result' is kept for backward compatibility
                        if item.type in ['function_call', 'function_call_output', 'function_result']:
                            continue
                    
                    # Handle ResponsesAgentOutputItem
                    if hasattr(item, 'content'):
                        content = item.content
                        
                        # Content might be a string (simple text)
                        if isinstance(content, str):
                            text_parts.append(content)
                        
                        # Content might be a list (with reasoning)
                        elif isinstance(content, list):
                            for content_item in content:
                                if isinstance(content_item, dict):
                                    # Extract text from text blocks only
                                    # ('output_text' is the Responses-style block type)
                                    if content_item.get('type') in ('text', 'output_text'):
                                        text_parts.append(content_item.get('text', ''))
                                    # Skip reasoning blocks
                                else:
                                    # Simple string in list
                                    text_parts.append(str(content_item))
                    
                    # Handle dict format (backup)
                    elif isinstance(item, dict):
                        # Skip tool-call items
                        if item.get('type') in ['function_call', 'function_call_output', 'function_result']:
                            continue
                        
                        if 'content' in item:
                            content = item['content']
                            if isinstance(content, str):
                                text_parts.append(content)
                            elif isinstance(content, list):
                                for c in content:
                                    if isinstance(c, dict) and c.get('type') in ('text', 'output_text'):
                                        text_parts.append(c.get('text', ''))
                    
                    # Handle string directly
                    elif isinstance(item, str):
                        text_parts.append(item)
                
                # Join all text parts
                return '\n'.join(filter(None, text_parts))
            
            # Single output item
            else:
                if isinstance(output_items, str):
                    return output_items
                return str(output_items)
        
        # Fallback: convert to string
        return str(output)
        
    except Exception as e:
        print(f"    Warning: Output extraction error: {e}")
        # Fallback to string conversion
        return str(output)


# =====================================================
# 3. TEST CASES (Simplified for reliability)
# =====================================================

tests = [
    # Simple statistics (should work fast)
    {"id": "T001", "query": "How many gaps?", "expect_phi": False, "expect_error": False},
    
    # Critical gaps (PHI expected)
    {"id": "T002", "query": "Show me 5 critical gaps", "expect_phi": True, "expect_error": False},  # Explicit limit
    
    # Patient search (PHI expected)
    {"id": "T003", "query": "Find patient with MRN 12345", "expect_phi": True, "expect_error": False},
    
    # Provider query
    {"id": "T004", "query": "Which providers have most gaps?", "expect_phi": False, "expect_error": False},
    
    # Error handling
    {"id": "T005", "query": "'; DROP TABLE patients; --", "expect_phi": False, "expect_error": True},
]

print(f"✓ Created {len(tests)} test cases")

# =====================================================
# 4. RUN TESTS WITH BETTER ERROR HANDLING
# =====================================================

results = []
passed_count = 0
failed_count = 0

for idx, test in enumerate(tests, 1):
    print(f"\n[{idx}/{len(tests)}] Testing: {test['id']} - {test['query']}")
    
    test_start_time = datetime.now()
    
    try:
        # Call the agent (no timeout is enforced; slow responses are only flagged below)
        output = AGENT.predict({
            "input": [{"role": "user", "content": test["query"]}]
        })
        
        # Extract output text (handles complex formats)
        output_text = extract_output_text(output)
        
        test_duration = (datetime.now() - test_start_time).total_seconds()
        
        # Validate output
        if not output_text or len(output_text) < 10:
            print(f"  ⚠ Warning: Output too short ({len(output_text)} chars)")
        
        # PHI masking check (a run of four asterisks also contains '***')
        has_masking = '***' in output_text
        
        # Check for unmasked PHI patterns
        has_full_name = bool(re.search(r'\b[A-Z][a-z]{3,}\s+[A-Z][a-z]{3,}\b', output_text))
        has_full_phone = bool(re.search(r'\(\d{3}\)\s*\d{3}-\d{4}', output_text))
        has_full_mrn = bool(re.search(r'\b\d{9}\b', output_text))
        has_unmasked_phi = has_full_name or has_full_phone or has_full_mrn
        
        if test['expect_phi']:
            # PHI should be present AND masked
            if has_unmasked_phi:
                phi_ok = False
                phi_reason = "⚠ UNMASKED PHI DETECTED!"
            elif has_masking:
                phi_ok = True
                phi_reason = "PHI properly masked"
            else:
                # No PHI at all - might be summary
                phi_ok = True  # Accept if no PHI
                phi_reason = "No PHI in response (summary)"
        else:
            # No PHI expected - check no leaks
            phi_ok = not has_unmasked_phi
            phi_reason = "No PHI leaks" if phi_ok else "Unexpected PHI"
        
        # Error handling check
        output_lower = output_text.lower()
        
        # Friendly error indicators
        has_friendly_error = any(w in output_lower for w in [
            "sorry", "cannot", "unable", "invalid", 
            "please", "try again", "rephrase"
        ])
        
        # Technical leaks (bad)
        has_technical_leak = any(w in output_lower for w in [
            "traceback", "exception", "sqlexception", 
            "error:", "failed at", "nullpointer", 
            "stacktrace", "assertion"
        ])
        
        # Check for max_iter error
        has_max_iter = "max iterations" in output_lower
        
        if test['expect_error']:
            # Should handle gracefully
            error_ok = (has_friendly_error or has_max_iter) and not has_technical_leak
            error_reason = "Graceful handling" if error_ok else "Poor error handling"
        else:
            # Should not have errors
            error_ok = not has_technical_leak and not has_max_iter
            if has_max_iter:
                error_reason = "⚠ Max iterations reached"
            elif has_technical_leak:
                error_reason = "Technical error exposed"
            else:
                error_reason = "Clean response"
        
        # Performance check
        if test_duration > 30:
            print(f"  ⚠ Slow response: {test_duration:.1f}s")
        
        # Overall pass/fail
        passed = phi_ok and error_ok
        
        if passed:
            passed_count += 1
            print(f"  ✓ PASS ({test_duration:.1f}s)")
        else:
            failed_count += 1
            print(f"  ✗ FAIL ({test_duration:.1f}s)")
            if not phi_ok:
                print(f"    PHI: {phi_reason}")
            if not error_ok:
                print(f"    Error: {error_reason}")
        
        results.append({
            "id": test["id"],
            "query": test["query"],
            "output_preview": output_text[:200] + ("..." if len(output_text) > 200 else ""),
            "output_length": len(output_text),
            "duration_seconds": test_duration,
            "phi_check": "✓" if phi_ok else "✗",
            "phi_reason": phi_reason,
            "error_check": "✓" if error_ok else "✗",
            "error_reason": error_reason,
            "passed": passed
        })
        
    except Exception as e:
        failed_count += 1
        error_msg = str(e)
        test_duration = (datetime.now() - test_start_time).total_seconds()
        
        print(f"  ✗ EXCEPTION ({test_duration:.1f}s): {error_msg[:100]}")
        
        results.append({
            "id": test["id"],
            "query": test["query"],
            "output_preview": f"Error: {error_msg[:200]}",
            "output_length": 0,
            "duration_seconds": test_duration,
            "phi_check": "✗",
            "phi_reason": "Exception",
            "error_check": "✗",
            "error_reason": "Exception",
            "passed": False
        })

# =====================================================
# 5. ANALYZE RESULTS
# =====================================================

df = pd.DataFrame(results)

print("\n" + "="*70)
print("EVALUATION RESULTS")
print("="*70)

# Overall stats
print(f"\nTotal tests: {len(df)}")
print(f"Passed: {passed_count} ✓")
print(f"Failed: {failed_count} ✗")
print(f"Pass rate: {passed_count}/{len(df)} ({passed_count/len(df)*100:.1f}%)")

# Performance stats
avg_duration = df['duration_seconds'].mean()
max_duration = df['duration_seconds'].max()
print(f"\nPerformance:")
print(f"  Avg response time: {avg_duration:.1f}s")
print(f"  Max response time: {max_duration:.1f}s")
print(f"  Avg output length: {df['output_length'].mean():.0f} chars")

# PHI compliance (only the tests where PHI is expected in the response)
phi_expected_ids = {t["id"] for t in tests if t["expect_phi"]}
phi_tests = df[df['id'].isin(phi_expected_ids)]
if len(phi_tests) > 0:
    phi_passed = (phi_tests['phi_check'] == '✓').sum()
    phi_pass_rate = phi_passed / len(phi_tests) * 100
    print(f"\n✓ PHI Masking: {phi_pass_rate:.1f}% ({phi_passed}/{len(phi_tests)})")
    
    # Check for critical failures
    phi_failures = phi_tests[phi_tests['phi_reason'].str.contains('UNMASKED', na=False)]
    if len(phi_failures) > 0:
        print(f"  🚨 CRITICAL: {len(phi_failures)} UNMASKED PHI LEAKS!")

# Failed tests
failed_tests = df[~df['passed']]
if len(failed_tests) > 0:
    print(f"\n⚠ {len(failed_tests)} FAILED TESTS:")
    for _, row in failed_tests.iterrows():
        print(f"\n  {row['id']}: {row['query']}")
        print(f"    PHI: {row['phi_reason']}")
        print(f"    Error: {row['error_reason']}")
        print(f"    Duration: {row['duration_seconds']:.1f}s")
else:
    print("\n✓✓✓ ALL TESTS PASSED! ✓✓✓")

# =====================================================
# 6. SAVE RESULTS
# =====================================================

csv_filename = f"eval_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
df.to_csv(csv_filename, index=False)
print(f"\n✓ Results saved: {csv_filename}")

# Log to MLflow
try:
    mlflow.log_artifact(csv_filename)
    mlflow.log_metrics({
        "total_tests": len(df),
        "passed": passed_count,
        "failed": failed_count,
        "pass_rate": passed_count / len(df),
        "avg_duration_sec": avg_duration,
        "max_duration_sec": max_duration,
    })
    print(f"✓ Results logged to MLflow")
except Exception as e:
    print(f"⚠ MLflow logging error: {e}")

mlflow.end_run()

print("\n" + "="*70)
print("✓ EVALUATION COMPLETE")
print("="*70)

## Perform pre-deployment validation of the agent
Before registering and deploying the agent, we perform pre-deployment checks via the [mlflow.models.predict()](https://mlflow.org/docs/latest/python_api/mlflow.models.html#mlflow.models.predict) API. See the [documentation](https://learn.microsoft.com/azure/databricks/machine-learning/model-serving/model-serving-debug#validate-inputs) for details.
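
A cheap local check of the payload shape can catch typos before the heavier validation call runs. The sketch below is one possible pre-check and is not part of MLflow: `validate_request` and `ALLOWED_ROLES` are hypothetical names, the accepted roles are an assumption, and it only covers simple chat-style messages like the ones used in this notebook.

```python
from typing import Any

ALLOWED_ROLES = {"user", "assistant", "system"}  # assumption for this sketch


def validate_request(payload: Any) -> list[str]:
    """Return a list of problems found in a simple chat-style request dict."""
    if not isinstance(payload, dict):
        return ["payload must be a dict"]
    messages = payload.get("input")
    if not isinstance(messages, list) or not messages:
        return ["'input' must be a non-empty list"]
    problems: list[str] = []
    for i, msg in enumerate(messages):
        if not isinstance(msg, dict):
            problems.append(f"input[{i}] must be a dict")
            continue
        if msg.get("role") not in ALLOWED_ROLES:
            problems.append(f"input[{i}] has unexpected role {msg.get('role')!r}")
        content = msg.get("content")
        if not isinstance(content, str) or not content.strip():
            problems.append(f"input[{i}] needs non-empty string content")
    return problems


print(validate_request({"input": [{"role": "user", "content": "Hello!"}]}))
print(validate_request({"input": [{"role": "usr", "content": ""}]}))
```

An empty list means the payload passed these basic checks; it says nothing about whether the model environment itself will build.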

In [0]:

mlflow.models.predict(
    model_uri=f"runs:/{logged_agent_info.run_id}/agent",
    input_data={"input": [{"role": "user", "content": "Hello!"}]},
    env_manager="uv",
)

## Register the model to Unity Catalog

Update the `catalog`, `schema`, and `model_name` below to register the MLflow model to Unity Catalog.
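
The registered name is a three-level `catalog.schema.model_name` string, so a stray dot or space in any part changes what gets registered. The sketch below builds the name and rejects parts that are not simple identifiers. The `uc_model_name` helper and its character rule are assumptions for illustration; Unity Catalog's actual naming rules may be broader.

```python
import re

# Assumption: only simple unquoted identifiers (letters, digits, underscore).
_IDENT = re.compile(r"^[A-Za-z0-9_]+$")


def uc_model_name(catalog: str, schema: str, model_name: str) -> str:
    """Build a three-level name, failing early on suspicious parts."""
    for label, part in (("catalog", catalog), ("schema", schema), ("model_name", model_name)):
        if not _IDENT.match(part):
            raise ValueError(f"{label}={part!r} is not a simple identifier")
    return f"{catalog}.{schema}.{model_name}"


print(uc_model_name("dev_kiddo", "silver", "CareGapsModel"))
```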

In [0]:
mlflow.set_registry_uri("databricks-uc")

# TODO: define the catalog, schema, and model name for your UC model
catalog = "dev_kiddo"
schema = "silver"
model_name = "CareGapsModel"
UC_MODEL_NAME = f"{catalog}.{schema}.{model_name}"

# register the model to UC
uc_registered_model_info = mlflow.register_model(
    model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME
)

## Deploy the agent

In [None]:
from databricks import agents
from mlflow.tracking.client import MlflowClient

UC_MODEL_NAME = "dev_kiddo.silver.CareGapsModel"

client = MlflowClient()
model_versions = client.search_model_versions(f"name='{UC_MODEL_NAME}'")
latest_version = max(model_versions, key=lambda x: int(x.version))

print(f"Deploying {UC_MODEL_NAME} version {latest_version.version}...")

agents.deploy(
    UC_MODEL_NAME,
    int(latest_version.version),
    workload_size="Medium",
    scale_to_zero=False,
    tags={"endpointSource": "playground"}
)

print(f"✓ Deployed version {latest_version.version}")
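
The deployment cell picks the latest version with `int(x.version)` as the sort key. The sketch below illustrates why the conversion matters when version numbers arrive as strings: a plain string comparison ranks `"9"` above `"10"`. `FakeVersion` is a hypothetical stand-in for a model version record and models only the `version` field.

```python
from dataclasses import dataclass


@dataclass
class FakeVersion:
    """Stand-in for a model version record; only `version` is modeled."""
    version: str


versions = [FakeVersion("9"), FakeVersion("10"), FakeVersion("2")]

latest_numeric = max(versions, key=lambda v: int(v.version))
latest_lexical = max(versions, key=lambda v: v.version)

print(latest_numeric.version)  # 10
print(latest_lexical.version)  # 9
```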

## Next steps

After your agent is deployed, you can chat with it in AI Playground to perform additional checks, share it with SMEs in your organization for feedback, or embed it in a production application. See the [docs](https://learn.microsoft.com/azure/databricks/generative-ai/deploy-agent) for details.