# Mosaic AI Agent Framework: Author and deploy a multi-agent system with Genie and Serving Endpoints

This notebook demonstrates how to build a multi-agent system using Mosaic AI Agent Framework and [LangGraph](https://blog.langchain.dev/langgraph-multi-agent-workflows/), where [Genie](https://www.databricks.com/product/ai-bi/genie) is one of the agents.
In this notebook, you:
1. Author a multi-agent system using LangGraph.
1. Wrap the LangGraph agent with MLflow `ResponsesAgent` to ensure compatibility with Databricks features.
1. Manually test the multi-agent system's output.
1. Log and deploy the multi-agent system.

This example is based on the [LangGraph documentation - Multi-agent supervisor example](https://github.com/langchain-ai/langgraph/blob/main/docs/docs/tutorials/multi_agent/agent_supervisor.md).

## Why use a Genie agent?

Multi-agent systems consist of multiple AI agents working together, each with specialized capabilities. As one of those agents, Genie lets users interact with their structured data using natural language. Unlike SQL functions, which can only run predefined queries, Genie can generate new queries to answer user questions.

## Prerequisites

- Address all `TODO`s in this notebook.
- Create a Genie Space. See Databricks documentation ([AWS](https://docs.databricks.com/aws/genie/set-up) | [Azure](https://learn.microsoft.com/azure/databricks/genie/set-up)).

In [0]:
%pip install -U -qqq mlflow-skinny[databricks] databricks-langchain databricks-agents uv langgraph-supervisor==0.0.29
dbutils.library.restartPython()


## Define the multi-agent system

Create a multi-agent system in LangGraph using a supervisor agent node with one or more of the following subagents:
- **GenieAgent**: A LangChain runnable that allows you to easily interact with your Genie Space to query structured data.
- **Custom serving agent**: An agent that is already hosted as an existing endpoint on Databricks.
- **In-code tool-calling agent**: An agent that calls Unity Catalog function tools, defined within this notebook. This example uses the functions in the `system.ai` schema, such as `system.ai.python_exec`. For examples of other tools you can add to your agents, see Databricks documentation ([AWS](https://docs.databricks.com/aws/generative-ai/agent-framework/agent-tool) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/agent-framework/agent-tool)).

The supervisor agent is responsible for creating and routing tool calls to each of your subagents, passing only the context necessary. You can modify this behavior and pass along the entire message history if desired. See the [LangGraph docs](https://langchain-ai.github.io/langgraph/reference/supervisor/) for more information.
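
The supervisor only knows about its subagents through the names and descriptions placed in its prompt, so the wording of each description directly affects routing. The following is a minimal, standard-library sketch of how that prompt is assembled, mirroring the string building in `create_langgraph_supervisor` in the next cell. `SubAgentInfo`, `build_supervisor_prompt`, and the sample descriptions are hypothetical names used only for illustration.

```python
from dataclasses import dataclass


@dataclass
class SubAgentInfo:
    """Hypothetical stand-in for the subagent config classes used in agent.py."""

    name: str
    description: str


def build_supervisor_prompt(instructions: str, agents: list[SubAgentInfo]) -> str:
    """Append one '- name: description' line per subagent to the instructions."""
    agent_descriptions = "".join(f"- {a.name}: {a.description}\n" for a in agents)
    return f"{instructions}\n{agent_descriptions}"


prompt = build_supervisor_prompt(
    "You are a supervisor in a multi-agent system. Route requests to the appropriate sub-agent:",
    [
        SubAgentInfo("revenue-genie", "Answers questions about company revenue."),
        SubAgentInfo("code-execution-agent", "Runs Python code to solve programming tasks."),
    ],
)
print(prompt)
```

Printing the prompt before you build the graph is a quick way to check that each description is specific enough for the supervisor to tell the subagents apart.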

### Write agent code to file

Define the agent code in a single cell below. This lets you write the agent code to a local Python file, using the `%%writefile` magic command, for subsequent logging and deployment.


In [0]:
%%writefile agent.py
import json
from typing import Any, Generator, Literal
from uuid import uuid4

import mlflow
from databricks_langchain import (
    ChatDatabricks,
    DatabricksFunctionClient,
    UCFunctionToolkit,
    set_uc_function_client,
)
from databricks_langchain.genie import GenieAgent
from langchain_core.messages import BaseMessage
from langchain_core.runnables import Runnable
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import create_react_agent
from langgraph_supervisor import create_supervisor
from mlflow.pyfunc import ResponsesAgent
from mlflow.types.responses import (
    ResponsesAgentRequest,
    ResponsesAgentResponse,
    ResponsesAgentStreamEvent,
)
from pydantic import BaseModel, model_validator

client = DatabricksFunctionClient()
set_uc_function_client(client)

########################################
# Create your LangGraph Supervisor Agent
########################################

GENIE = "genie"


class ServedSubAgent(BaseModel):
    endpoint_name: str
    name: str
    task: Literal["agent/v1/responses", "agent/v1/chat", "agent/v2/chat"]
    description: str


class Genie(BaseModel):
    space_id: str
    name: str
    task: str = GENIE
    description: str



class InCodeSubAgent(BaseModel):
    tools: list[str]
    name: str
    description: str


TOOLS = []

def stringify_content(state):
    msgs = state["messages"]
    if isinstance(msgs[-1].content, list):
        msgs[-1].content = json.dumps(msgs[-1].content, indent=4)
    return {"messages": msgs}


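def create_langgraph_supervisor(
    llm: Runnable,
    externally_served_agents: list[ServedSubAgent | Genie] | None = None,
    in_code_agents: list[InCodeSubAgent] | None = None,
):
    # Default to empty lists here rather than using mutable default arguments
    externally_served_agents = externally_served_agents or []
    in_code_agents = in_code_agents or []
    agents = []
    agent_descriptions = ""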

    # Process inline code agents
    for agent in in_code_agents:
        agent_descriptions += f"- {agent.name}: {agent.description}\n"
        uc_toolkit = UCFunctionToolkit(function_names=agent.tools)
        TOOLS.extend(uc_toolkit.tools)
        agents.append(create_react_agent(llm, tools=uc_toolkit.tools, name=agent.name))

    # Process served endpoints and Genie Spaces
    for agent in externally_served_agents:
        agent_descriptions += f"- {agent.name}: {agent.description}\n"
        if isinstance(agent, Genie):
            genie_agent = GenieAgent(
                genie_space_id=agent.space_id,
                genie_agent_name=agent.name,
                description=agent.description,
            )
            genie_agent.name = agent.name
            agents.append(genie_agent)
        else:
            model = ChatDatabricks(
                endpoint=agent.endpoint_name, use_responses_api="responses" in agent.task
            )
            # Disable streaming for subagents for ease of parsing
            model._stream = lambda x: model._stream(**x, stream=False)
            agents.append(
                create_react_agent(
                    model,
                    tools=[],
                    name=agent.name,
                    post_model_hook=stringify_content,
                )
            )

    # TODO: The supervisor prompt includes agent names/descriptions as well as general
    # instructions. You can modify this to improve quality or provide custom instructions.
    prompt = f"You are a supervisor in a multi-agent system. If you have enough information based on the chat history to answer the user's query, provide a summarized response to the query, even if it's been answered before. If you still need more information, route requests to the appropriate sub-agent:\n{agent_descriptions}"

    return create_supervisor(
        agents=agents,
        model=llm,
        prompt=prompt,
        add_handoff_messages=True,
    ).compile()


##########################################
# Wrap LangGraph Supervisor as a ResponsesAgent
##########################################


class LangGraphResponsesAgent(ResponsesAgent):
    def __init__(self, agent: CompiledStateGraph):
        self.agent = agent

    def _langchain_to_responses(self, message: BaseMessage) -> list[dict[str, Any]]:
        """Convert a LangChain message to Responses output item dictionaries. User and human messages are ignored."""
        message = message.model_dump()
        role = message["type"]
        output = []
        if role == "ai":
            if message.get("content"):
                output.append(
                    self.create_text_output_item(
                        text=message["content"],
                        id=message.get("id") or str(uuid4()),
                    )
                )
            if tool_calls := message.get("tool_calls"):
                output.extend(
                    [
                        self.create_function_call_item(
                            id=message.get("id") or str(uuid4()),
                            call_id=tool_call["id"],
                            name=tool_call["name"],
                            arguments=json.dumps(tool_call["args"]),
                        )
                        for tool_call in tool_calls
                    ]
                )

        elif role == "tool":
            output.append(
                self.create_function_call_output_item(
                    call_id=message["tool_call_id"],
                    output=message["content"],
                )
            )
        elif role in ("user", "human"):
            # User and human messages produce no output items
            pass
        return output

    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
        outputs = [
            event.item
            for event in self.predict_stream(request)
            if event.type == "response.output_item.done"
        ]
        return ResponsesAgentResponse(output=outputs, custom_outputs=request.custom_inputs)

    def predict_stream(
        self,
        request: ResponsesAgentRequest,
    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
        cc_msgs = self.prep_msgs_for_cc_llm([i.model_dump() for i in request.input])
        first_name = True
        seen_ids = set()

        for event_name, events in self.agent.stream({"messages": cc_msgs}, stream_mode=["updates"]):
            if event_name == "updates":
                if not first_name:
                    node_name = tuple(events.keys())[0]  # assumes one name per node
                    yield ResponsesAgentStreamEvent(
                        type="response.output_item.done",
                        item=self.create_text_output_item(
                            text=f"<name>{node_name}</name>",
                            id=str(uuid4()),
                        ),
                    )
                for node_data in events.values():
                    for msg in node_data["messages"]:
                        if msg.id not in seen_ids:
                            print(msg.id, msg)
                            seen_ids.add(msg.id)
                            for item in self._langchain_to_responses(msg):
                                yield ResponsesAgentStreamEvent(
                                    type="response.output_item.done", item=item
                                )
            first_name = False


#######################################################
# Configure the Foundation Model and Serving Sub-Agents
#######################################################

# TODO: Replace with your model serving endpoint
LLM_ENDPOINT_NAME = "databricks-claude-3-7-sonnet"
llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)

# TODO: Add the necessary information about each of your serving/Genie Space subagents
# in the list below, following the examples.
#  - Your agent descriptions are crucial for improving quality. Include as much detail as possible.
EXTERNALLY_SERVED_AGENTS = [
    Genie(
        space_id="01f0939cca1c1a1d87e507943f57333a",
        name="revenue-genie",
        description="This agent can answer questions about revenue related to the top 10 companies in the S&P 500.",
    ),
    # ServedSubAgent(
    #     endpoint_name="cities-agent",
    #     name="city-agent", # choose a semantically relevant name for your agent
    #     task="agent/v1/responses",
    #     description="This agent can answer questions about the best cities to visit in the world.",
    # ),
]

############################################################
# Create additional agents in code
############################################################

# TODO: Fill the following with UC function-calling agents. The tools parameter is a list of UC function names that you want your agent to call.
IN_CODE_AGENTS = [
    InCodeSubAgent(
        tools=["system.ai.*"],
        name="code execution agent",
        description="The code execution agent specializes in solving programming challenges, generating code snippets, debugging issues, and explaining complex coding concepts.",
    )
]

#################################################
# Create supervisor and set up MLflow for tracing
#################################################

supervisor = create_langgraph_supervisor(llm, EXTERNALLY_SERVED_AGENTS, IN_CODE_AGENTS)

mlflow.langchain.autolog()
AGENT = LangGraphResponsesAgent(supervisor)
mlflow.models.set_model(AGENT)

## Test the agent

Interact with the agent to test its output. Because `agent.py` calls `mlflow.langchain.autolog()`, you can view the trace for each step the agent takes.

**Important:** LangGraph internally raises exceptions, such as `ParentCommand`, to switch between nodes. These exceptions may appear in your MLflow traces as Events. This behavior is expected and is not a cause for concern.
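
When you read the streamed output, note that `predict_stream` in `agent.py` emits a `<name>...</name>` text item before the messages of each node after the first. The sketch below shows one way to use those markers to attribute text to the node that produced it. It works on plain dictionaries, and the item shape (`type`, `content`, `text`) is an assumption about what `create_text_output_item` returns. The helper names `text_item` and `group_text_by_node` and the default first node `"supervisor"` are hypothetical.

```python
import re

NAME_TAG = re.compile(r"^<name>(.*)</name>$")


def text_item(text: str) -> dict:
    """Build a text output item in the assumed Responses message shape."""
    return {
        "type": "message",
        "role": "assistant",
        "content": [{"type": "output_text", "text": text}],
    }


def group_text_by_node(items: list[dict], first_node: str = "supervisor") -> list[tuple[str, str]]:
    """Pair each text output item with the node that produced it."""
    current = first_node  # assumption: the first update comes from the supervisor
    grouped = []
    for item in items:
        if item.get("type") != "message":
            continue  # skip function_call and function_call_output items
        for part in item.get("content", []):
            text = part.get("text", "")
            match = NAME_TAG.match(text)
            if match:
                current = match.group(1)  # marker item: switch the active node
            elif text:
                grouped.append((current, text))
    return grouped


sample = [
    text_item("Routing to the revenue agent."),
    {"type": "function_call", "name": "transfer_to_revenue-genie", "arguments": "{}", "call_id": "call_1"},
    text_item("<name>revenue-genie</name>"),
    text_item("Revenue data is available for ten companies."),
    text_item("<name>supervisor</name>"),
    text_item("You have revenue data for ten companies."),
]
grouped = group_text_by_node(sample)
for node, text in grouped:
    print(f"{node}: {text}")
```

With real output, you could pass the `item` entry of each `event.model_dump(exclude_none=True)` dictionary to the same helper, provided the item shape matches the assumption above.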

In [0]:
dbutils.library.restartPython()

In [0]:
from agent import AGENT

# TODO: Replace this placeholder `input_example` with a domain-specific prompt for your agent.
input_example = {"input": [{"role": "user", "content": "Which companies do I have revenue data for?"}]}

AGENT.predict(input_example)

In [0]:
for event in AGENT.predict_stream(input_example):
    print(event.model_dump(exclude_none=True))

## Log the agent as an MLflow model

Log the agent as code from the `agent.py` file. See [MLflow - Models from Code](https://mlflow.org/docs/latest/models.html#models-from-code).

### Enable automatic authentication for Databricks resources
For the most common Databricks resource types, Databricks supports and recommends declaring resource dependencies for the agent upfront during logging. This enables automatic authentication passthrough when you deploy the agent. With automatic authentication passthrough, Databricks automatically provisions, rotates, and manages short-lived credentials to securely access these resource dependencies from within the agent endpoint.

To enable automatic authentication, specify the dependent Databricks resources when calling `mlflow.pyfunc.log_model()`.
  - **TODO**: If your Unity Catalog tool queries a vector search index or leverages external functions, you need to include the dependent vector search index and UC connection objects, respectively, as resources. See docs ([AWS](https://docs.databricks.com/aws/generative-ai/agent-framework/agent-authentication#supported-resources-for-automatic-authentication-passthrough) | [Azure](https://docs.databricks.com/aws/generative-ai/agent-framework/agent-authentication#supported-resources-for-automatic-authentication-passthrough)).

  - **TODO**: Add the SQL Warehouse or tables powering your Genie Space to enable passthrough authentication. See docs ([AWS](https://docs.databricks.com/aws/generative-ai/agent-framework/agent-authentication#supported-resources-for-automatic-authentication-passthrough) | [Azure](https://docs.databricks.com/aws/generative-ai/agent-framework/agent-authentication#supported-resources-for-automatic-authentication-passthrough)). If your Genie Space uses "embedded credentials", you do not need to add these resources.
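
Because the SQL Warehouse and table entries are optional, it can help to declare them only when their placeholders are filled in. The following standard-library sketch illustrates that idea with `(resource type, identifier)` pairs in place of the real `mlflow.models.resources` objects. `plan_resources` and the sample table name are hypothetical and are not part of MLflow.

```python
def plan_resources(
    llm_endpoint: str,
    warehouse_id: str = "",
    table_names: tuple[str, ...] = (),
) -> list[tuple[str, str]]:
    """Return (resource type, identifier) pairs, skipping unfilled placeholders."""
    plan = [("DatabricksServingEndpoint", llm_endpoint)]
    if warehouse_id.strip():
        plan.append(("DatabricksSQLWarehouse", warehouse_id.strip()))
    # Only keep table names that are actually filled in
    plan.extend(("DatabricksTable", name.strip()) for name in table_names if name.strip())
    return plan


plan = plan_resources(
    "databricks-claude-3-7-sonnet",
    warehouse_id="",  # left empty, for example when the Genie Space uses embedded credentials
    table_names=("main.finance.revenue", ""),  # hypothetical table name plus an unfilled placeholder
)
for resource_type, identifier in plan:
    print(resource_type, identifier)
```

In the logging cell, the same guard would wrap the `resources.append(...)` calls so that empty identifiers are never declared.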

In [0]:
# Determine Databricks resources to specify for automatic auth passthrough at deployment time
import mlflow
from agent import EXTERNALLY_SERVED_AGENTS, LLM_ENDPOINT_NAME, TOOLS, Genie
from databricks_langchain import UnityCatalogTool, VectorSearchRetrieverTool
from mlflow.models.resources import (
    DatabricksFunction,
    DatabricksGenieSpace,
    DatabricksServingEndpoint,
    DatabricksSQLWarehouse,
    DatabricksTable
)
# importlib.metadata replaces the deprecated pkg_resources API for reading installed versions
from importlib.metadata import version

# TODO: Manually include underlying resources if needed. See the TODO in the markdown above for more information.
resources = [DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME)]
# TODO: Add SQL Warehouses and Delta tables powering the Genie Space
resources.append(DatabricksSQLWarehouse(warehouse_id=""))
resources.append(DatabricksTable(table_name=""))

# Add tools from Unity Catalog
for tool in TOOLS:
    if isinstance(tool, VectorSearchRetrieverTool):
        resources.extend(tool.resources)
    elif isinstance(tool, UnityCatalogTool):
        resources.append(DatabricksFunction(function_name=tool.uc_function_name))

# Add serving endpoints and Genie Spaces
for agent in EXTERNALLY_SERVED_AGENTS:
    if isinstance(agent, Genie):
        resources.append(DatabricksGenieSpace(genie_space_id=agent.space_id))
    else:
        resources.append(DatabricksServingEndpoint(endpoint_name=agent.endpoint_name))

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        name="agent",
        python_model="agent.py",
        resources=resources,
        # Pin the agent's dependencies to the versions installed in this notebook
        pip_requirements=[
            f"databricks-connect=={version('databricks-connect')}",
            f"mlflow=={version('mlflow')}",
            f"databricks-langchain=={version('databricks-langchain')}",
            f"langgraph=={version('langgraph')}",
            f"langgraph-supervisor=={version('langgraph-supervisor')}",
        ],
    )

## Pre-deployment agent validation
Before registering and deploying the agent, perform pre-deployment checks using the [mlflow.models.predict()](https://mlflow.org/docs/latest/python_api/mlflow.models.html#mlflow.models.predict) API. See Databricks documentation ([AWS](https://docs.databricks.com/en/machine-learning/model-serving/model-serving-debug.html#validate-inputs) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/machine-learning/model-serving/model-serving-debug#before-model-deployment-validation-checks)).

In [0]:
import mlflow
mlflow.models.predict(
    model_uri=f"runs:/{logged_agent_info.run_id}/agent",
    input_data=input_example,
    env_manager="uv",
)

## Register the model to Unity Catalog

Update the `catalog`, `schema`, and `model_name` below to register the MLflow model to Unity Catalog.
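
Unity Catalog model names have three levels, `catalog.schema.model_name`, and the placeholders in the next cell are empty by default. The sketch below is one way to catch an unfilled placeholder before calling `mlflow.register_model`. `uc_model_name` and the sample values are hypothetical.

```python
def uc_model_name(catalog: str, schema: str, model_name: str) -> str:
    """Build a three-level Unity Catalog name, rejecting empty parts."""
    parts = {"catalog": catalog, "schema": schema, "model_name": model_name}
    empty = [key for key, value in parts.items() if not value.strip()]
    if empty:
        raise ValueError(f"Fill in the following before registering: {', '.join(empty)}")
    return ".".join(value.strip() for value in parts.values())


full_name = uc_model_name("main", "agents", "multi_agent_genie")  # hypothetical values
print(full_name)
```

Failing early with a named placeholder is easier to diagnose than a registration error caused by a name such as `..`.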

In [0]:
mlflow.set_registry_uri("databricks-uc")

# TODO: define the catalog, schema, and model name for your UC model
catalog = ""
schema = ""
model_name = ""
UC_MODEL_NAME = f"{catalog}.{schema}.{model_name}"

# register the model to UC
uc_registered_model_info = mlflow.register_model(
    model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME
)

## Deploy the agent

In [0]:
from databricks import agents

agents.deploy(UC_MODEL_NAME, uc_registered_model_info.version, tags={"endpointSource": "docs"})

## Next steps

After your agent is deployed, you can chat with it in AI Playground to perform additional checks, share it with SMEs in your organization for feedback, or embed it in a production application. See Databricks documentation ([AWS](https://docs.databricks.com/en/generative-ai/deploy-agent.html) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/deploy-agent)).