## Develop AI Agent

In [0]:
%pip install -U -qqq mlflow-skinny[databricks] databricks-langchain databricks-agents uv langgraph-supervisor==0.0.29
# Restart the Python process so the packages installed above are the ones imported by the cells below
dbutils.library.restartPython()

In [0]:
import mlflow
import yaml
from databricks_langchain import  ChatDatabricks, DatabricksFunctionClient, UCFunctionToolkit, set_uc_function_client

from pprint import pprint
from langchain_core.messages import SystemMessage, HumanMessage


from langgraph.prebuilt import ToolNode, tools_condition

from IPython.display import Image, display
from langgraph.graph import MessagesState 
from langgraph.graph import StateGraph, START, END
from langgraph.types import Command
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.memory import MemorySaver

# Custom imports
from configs import variables
from prompts import data_inspector


from langchain.agents import AgentExecutor, create_tool_calling_agent

from pydantic import BaseModel, Field
from typing_extensions import TypedDict
from typing import Annotated, Dict, List, Optional
from langgraph.graph.message import add_messages, AnyMessage

In [0]:
# List every function registered under the configured Unity Catalog schema
client = DatabricksFunctionClient()
functions = client.list_functions(catalog=variables.CATALOG_NAME, schema=variables.SCHEMA_NAME)

# Fully qualified names follow the <catalog>.<schema>.<function> pattern
func_names = [
    f"{variables.CATALOG_NAME}.{variables.SCHEMA_NAME}.{uc_function.name}"
    for uc_function in functions
]

# Wrap the listed functions as tools through UCFunctionToolkit
toolkit = UCFunctionToolkit(function_names=func_names)
tools_uc = toolkit.tools


In [0]:
# State carried between the nodes of the data inspector graph
class DataInspectorState(TypedDict):
    """Graph state for the data inspector agent."""
    # Conversation messages; `add_messages` is attached to this key through `Annotated`
    messages: Annotated[list[AnyMessage], add_messages]

In [0]:
# Chat model backed by the serving endpoint named in configs.variables
chat_model = ChatDatabricks(endpoint=variables.LLM_ENDPOINT_NAME)
# Optional smoke test - uncomment to check that the endpoint responds:
# chat_model.invoke("Ciao")

In [0]:
# System prompt prepended to every model call made by the data inspector
prompt = SystemMessage(content=data_inspector.system_prompt)

# Expose the Unity Catalog tools to the model. Without binding, the LLM cannot emit
# tool calls, so the conditional route to the "tools" node would never be taken.
# An empty tool list is skipped so the endpoint is not sent an empty `tools` field.
chat_model_with_tools = chat_model.bind_tools(tools_uc) if tools_uc else chat_model


def data_inspector_agent(state: DataInspectorState):
    """Graph node: ask the tool-aware LLM for the next data inspector message.

    Args:
        state: Current graph state; ``state["messages"]`` holds the conversation so far.

    Returns:
        A state update whose ``"messages"`` list contains the single model reply.
    """
    reply = chat_model_with_tools.invoke([prompt] + state["messages"])
    return {"messages": [reply]}

In [0]:
# Build the graph: a model node that decides and a tool node that executes
builder = StateGraph(DataInspectorState)
builder.add_node("data_inspector", data_inspector_agent)
builder.add_node("tools", ToolNode(tools_uc))

# Entry point
builder.add_edge(START, "data_inspector")

# Routing after the model node is decided ONLY by tools_condition:
#   - latest assistant message contains a tool call -> "tools"
#   - otherwise                                     -> END
# There must be no additional unconditional edge from "data_inspector" to "tools":
# it would run the tool node after every model turn and the graph would keep
# cycling until the recursion limit is hit.
builder.add_conditional_edges(
    "data_inspector",
    tools_condition,
)

# Tool results are handed back to the model
builder.add_edge("tools", "data_inspector")


# Memory saver intended to keep the chat history between invocations.
# NOTE: it is created here but not passed to `compile()`, so nothing is persisted yet.
# To enable it, compile with `checkpointer=memory` and pass a `thread_id` in the
# config of every `invoke` call. This can be extended with external memory on Lakebase.
memory = MemorySaver()
react_graph = builder.compile()


# Show the graph
display(Image(react_graph.get_graph(xray=True).draw_mermaid_png()))

In [0]:

# Request for the data inspector, sent as a JSON document inside a human message.
# The payload must be valid JSON: every string value needs its closing quote.
input_messages = [HumanMessage(content="""
  {
    "domain": "retail",
    "num_records": 1000,
    "uc_catalog_source": "financial",
    "uc_schema_source": "sales",
    "uc_table_source": "sales",
    "uc_catalog_target": "financial",
    "uc_schema_target": "sales",
    "uc_table_target": "sales"

  }""")]

# Conversation thread identifier; it only has an effect once the graph is compiled with a checkpointer
config = {"configurable": {"thread_id": "alessandro2"} }

# The request lives in `input_messages` and the result in `messages`, so re-running
# this cell always sends the original request rather than a previous result.
messages = react_graph.invoke({"messages": input_messages}, config)

## Mosaic AI: wrap an agent in a ResponsesAgent
How to author a LangGraph agent and wrap it with the `ResponsesAgent` interface to make it compatible with Mosaic AI. See the [ResponsesAgent documentation](https://mlflow.org/docs/latest/genai/serving/responses-agent/).

`ResponsesAgent` extends MLflow's PyFunc model interface to support conversational AI applications that require advanced capabilities such as multi-turn dialogue, tool calling, multi-agent orchestration, and compatibility with OpenAI's Responses API and MLflow model tracking.

A reference implementation is available in this [Databricks example notebook](https://docs.databricks.com/aws/en/notebooks/source/generative-ai/responses-agent-langgraph.html).

### Create and wrap the agent

In [0]:
%%writefile agents/planner_agent.py
import json
from typing import Any, Generator, Literal
from typing import Annotated
from uuid import uuid4

import mlflow
import yaml
from databricks_langchain import  ChatDatabricks
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages, AnyMessage
from langgraph.graph.state import CompiledStateGraph
from mlflow.pyfunc import ResponsesAgent
from mlflow.types.responses import (
    ResponsesAgentRequest,
    ResponsesAgentResponse,
    ResponsesAgentStreamEvent,
)
from pydantic import BaseModel
from typing_extensions import TypedDict

from configs import variables
from prompts import planner


# LLM endpoint used by the planner, resolved from configs.variables
chat_model = ChatDatabricks(endpoint=variables.LLM_ENDPOINT_NAME)

# State carried through the planner graph
class PlannerState(TypedDict):
    """Graph state for the planner agent."""
    # Conversation messages; `add_messages` is attached to this key through `Annotated`
    messages: Annotated[list[AnyMessage], add_messages]

# System prompt that frames every planner invocation
prompt = SystemMessage(content=planner.system_prompt)


def planner_agent(state: PlannerState):
    """Graph node: ask the LLM for the next planner message.

    The system prompt is placed in front of the conversation held in
    ``state["messages"]``; the model reply is returned as a one-element
    ``"messages"`` update.
    """
    conversation = [prompt] + state["messages"]
    reply = chat_model.invoke(conversation)
    return {"messages": [reply]}


class LangGraphResponsesAgent(ResponsesAgent):
    """Serve a compiled LangGraph graph through MLflow's ``ResponsesAgent`` interface.

    The wrapped graph is streamed update by update and every message it produces
    is translated into Responses output items.
    """

    def __init__(self, agent: CompiledStateGraph):
        """Keep a reference to the compiled graph that answers each request."""
        self.agent = agent

    def _langchain_to_responses(self, message: BaseMessage) -> list[dict[str, Any]]:
        """Convert one LangChain message into Responses output item dictionaries.

        AI messages yield a text item (when they carry content) plus one
        function-call item per tool call; tool messages yield a function-call
        output item. Every other message type (human, user, system, ...) is
        input rather than output and produces an empty list.
        """
        message = message.model_dump()
        role = message["type"]
        output = []
        if role == "ai":
            if message.get("content"):
                output.append(
                    self.create_text_output_item(
                        text=message["content"],
                        # Fall back to a fresh id when the message has none
                        id=message.get("id") or str(uuid4()),
                    )
                )
            if tool_calls := message.get("tool_calls"):
                output.extend(
                    [
                        self.create_function_call_item(
                            id=message.get("id") or str(uuid4()),
                            call_id=tool_call["id"],
                            name=tool_call["name"],
                            arguments=json.dumps(tool_call["args"]),
                        )
                        for tool_call in tool_calls
                    ]
                )

        elif role == "tool":
            output.append(
                self.create_function_call_output_item(
                    call_id=message["tool_call_id"],
                    output=message["content"],
                )
            )
        # No branch for other roles: they intentionally emit nothing.
        return output

    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
        """Run the agent to completion and collect every finished output item.

        The request's ``custom_inputs`` are echoed back as ``custom_outputs``.
        """
        outputs = [
            event.item
            for event in self.predict_stream(request)
            if event.type == "response.output_item.done"
        ]
        return ResponsesAgentResponse(output=outputs, custom_outputs=request.custom_inputs)

    def predict_stream(self, request: ResponsesAgentRequest,) -> Generator[ResponsesAgentStreamEvent, None, None]:
        """Stream the graph and yield one ``response.output_item.done`` event per output item.

        From the second streamed update onwards, a ``<name>node</name>`` text item
        is emitted before the node's messages so consumers can tell which node spoke.
        """
        cc_msgs = self.prep_msgs_for_cc_llm([i.model_dump() for i in request.input])
        first_name = True
        seen_ids = set()

        for event_name, events in self.agent.stream({"messages": cc_msgs}, stream_mode=["updates"]):
            if event_name == "updates":
                if not first_name:
                    node_name = tuple(events.keys())[0]  # assumes one name per node
                    yield ResponsesAgentStreamEvent(
                        type="response.output_item.done",
                        item=self.create_text_output_item(
                            text=f"<name>{node_name}</name>",
                            id=str(uuid4()),
                        ),
                    )
                for node_data in events.values():
                    # A node update may be empty or carry no "messages" key
                    for msg in (node_data or {}).get("messages", []):
                        # De-duplicate only messages that have an id; id-less
                        # messages must not collapse into a single `None` entry.
                        if msg.id is not None:
                            if msg.id in seen_ids:
                                continue
                            seen_ids.add(msg.id)
                        for item in self._langchain_to_responses(msg):
                            yield ResponsesAgentStreamEvent(
                                type="response.output_item.done", item=item
                            )
            first_name = False


# Build the single-node planner graph
builder = StateGraph(PlannerState)
builder.add_node("planner", planner_agent)

# START -> planner -> END. The terminal edge is declared explicitly so the graph
# does not rely on implicit handling of a node without outgoing edges.
builder.add_edge(START, "planner")
builder.add_edge("planner", END)


# Memory saver intended to keep the chat history between invocations.
# NOTE: it is created here but not passed to `compile()`, so nothing is persisted yet.
# This can be extended with external memory on Lakebase.
memory = MemorySaver()
react_graph = builder.compile()

# Enable MLflow tracing for LangChain/LangGraph calls and register the agent
# object as the model served from this file
mlflow.langchain.autolog()
AGENT = LangGraphResponsesAgent(react_graph)
mlflow.models.set_model(AGENT)

### Test the Agent

In [0]:
# Restart the Python process before importing the agent module in the next cell
dbutils.library.restartPython()

In [0]:
from agents.planner_agent import AGENT
from configs import variables

# TODO: Replace this placeholder `input_example` with a domain-specific prompt for your agent.
input_example = {
    "input": [
        {"role": "user", "content": "Hi what can I do with you?"},
    ]
}

# Local smoke test of the wrapped agent; the response is shown as the cell output
AGENT.predict(input_example)

### Log the agent as an MLflow model

In [0]:
# `importlib.metadata` is the standard-library replacement for the deprecated
# `pkg_resources` API when looking up installed package versions
from importlib.metadata import version as installed_version

import mlflow
from mlflow.models.resources import DatabricksServingEndpoint

from configs import variables

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        name="planner_agent",
        # Models-from-code: the file written by the %%writefile cell above
        python_model="agents/planner_agent.py",
        # Local packages the agent file imports at load time
        code_paths=["configs", "prompts"],
        # Pin to the versions installed in this environment
        pip_requirements=[
            "databricks-langchain",
            f"langgraph=={installed_version('langgraph')}",
            f"databricks-connect=={installed_version('databricks-connect')}",
        ],
        # Serving endpoint the agent depends on, declared as a model resource
        resources=[DatabricksServingEndpoint(endpoint_name=variables.LLM_ENDPOINT_NAME)],
    )

### Pre Deployment Validation

In [0]:
# Validate the logged model with the same request shape used for the local test
validation_request = {"input": [{"role": "user", "content": "Hi what can I do with you?"}]}

mlflow.models.predict(
    model_uri=f"runs:/{logged_agent_info.run_id}/planner_agent",
    input_data=validation_request,
    env_manager="uv",
)

### Register the model to Unity Catalog

In [0]:
# Use Unity Catalog as the model registry
mlflow.set_registry_uri("databricks-uc")

# TODO: define the catalog, schema, and model name for your UC model
uc_schema_path = f"{variables.CATALOG_NAME}.{variables.SCHEMA_NAME}"
UC_MODEL_NAME = f"{uc_schema_path}.planner_agent"

# Register the logged agent under its three-level UC name
uc_registered_model_info = mlflow.register_model(
    model_uri=logged_agent_info.model_uri,
    name=UC_MODEL_NAME,
)

### Deploy Agent

In [0]:
from databricks import agents

# Deploy the registered model version registered in the previous cell
agents.deploy(
    UC_MODEL_NAME,
    uc_registered_model_info.version,
    tags={"endpointSource": "docs"},
)