In [0]:
%%capture --no-stderr
%pip install -U langgraph langgraph-supervisor wikipedia langchain_community databricks_langchain

In [0]:
from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper
# Build the Wikipedia API wrapper first, then hand it to the query tool that
# the research agent receives as `web_search`.
wikipedia_api = WikipediaAPIWrapper()
web_search = WikipediaQueryRun(api_wrapper=wikipedia_api)



### Research agent



In [0]:
from langgraph.prebuilt import create_react_agent

from databricks_langchain import (
    ChatDatabricks
)

# Endpoint name handed to ChatDatabricks; the resulting `llm` is reused by the
# other agents and the supervisor further down.
LLM_ENDPOINT_NAME = "databricks-gpt-oss-120b"
llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)

# Prompt for the research agent, kept in its own name so the constructor call
# below stays short.
research_prompt = (
    "You are a research agent.\n\n"
    "INSTRUCTIONS:\n"
    "- Assist ONLY with research-related tasks, DO NOT do any math\n"
    "- After you're done with your tasks, respond to the supervisor directly\n"
    "- Respond ONLY with the results of your work, do NOT include ANY other text."
)

# ReAct-style agent whose only tool is the Wikipedia search defined earlier.
research_agent = create_react_agent(
    name="research_agent",
    model=llm,
    tools=[web_search],
    prompt=research_prompt,
)

In [0]:
def add(a: float, b: float):
    """Add two numbers."""
    # NOTE(review): this function is passed as a tool to create_react_agent;
    # the docstring presumably doubles as the tool description shown to the
    # model, so its wording is left untouched — confirm before rewording.
    total = a + b
    return total


def multiply(a: float, b: float):
    """Multiply two numbers."""
    # NOTE(review): this function is passed as a tool to create_react_agent;
    # the docstring presumably doubles as the tool description shown to the
    # model, so its wording is left untouched — confirm before rewording.
    product = a * b
    return product


def divide(a: float, b: float):
    """Divide two numbers."""
    # NOTE(review): this function is passed as a tool to create_react_agent;
    # the docstring presumably doubles as the tool description shown to the
    # model, so its wording is left untouched — confirm before rewording.
    # True division: a zero divisor raises ZeroDivisionError, which is not
    # caught here.
    quotient = a / b
    return quotient


# Prompt for the math agent, kept in its own name so the constructor call
# below stays short.
math_prompt = (
    "You are a math agent.\n\n"
    "INSTRUCTIONS:\n"
    "- Assist ONLY with math-related tasks\n"
    "- After you're done with your tasks, respond to the supervisor directly\n"
    "- Respond ONLY with the results of your work, do NOT include ANY other text."
)

# ReAct-style agent whose tools are the three arithmetic helpers defined above.
math_agent = create_react_agent(
    name="math_agent",
    model=llm,
    tools=[add, multiply, divide],
    prompt=math_prompt,
)

In [0]:
from langgraph_supervisor import create_supervisor

# Routing instructions for the supervisor model.
supervisor_prompt = (
    "You are a supervisor managing two agents:\n"
    "- a research agent. Assign research-related tasks to this agent\n"
    "- a math agent. Assign math-related tasks to this agent\n"
    "Assign work to one agent at a time, do not call agents in parallel.\n"
    "Do not do any work yourself."
)

# Describe the supervisor workflow over both agents first ...
supervisor_workflow = create_supervisor(
    agents=[research_agent, math_agent],
    model=llm,
    prompt=supervisor_prompt,
    output_mode="full_history",
    add_handoff_back_messages=True,
)

# ... then compile it into the runnable graph used by the rest of the notebook.
supervisor = supervisor_workflow.compile()

In [0]:
from IPython.display import display, Image

# Render the compiled supervisor graph as a Mermaid PNG and show it inline.
# NOTE(review): draw_mermaid_png may need network access to render — confirm
# for offline environments.
graph_png = supervisor.get_graph().draw_mermaid_png()
display(Image(graph_png))

In [0]:
# Single user turn that needs both agents: a lookup followed by a percentage.
gdp_question = {
    "role": "user",
    "content": "find US and New York state GDP in 2024. what % of US GDP was New York state?",
}

# Print every chunk as the supervisor graph streams it.
for chunk in supervisor.stream({"messages": [gdp_question]}):
    print(chunk)

In [0]:
# code is taken from https://docs.databricks.com/aws/en/notebooks/source/generative-ai/responses-agent-langgraph.html

import json
from typing import Any, Generator
from uuid import uuid4

import mlflow
from databricks_langchain import ChatDatabricks, UCFunctionToolkit
from langchain_core.messages import AIMessageChunk
from langgraph.prebuilt import create_react_agent
from langchain_core.messages import convert_to_messages


from mlflow.pyfunc import ResponsesAgent
from mlflow.types.responses import (
    ResponsesAgentRequest,
    ResponsesAgentResponse,
    ResponsesAgentStreamEvent,
)


############################################
# Bridge to MLflow ResponsesAgent
############################################
class LangGraphResponsesAgent(ResponsesAgent):
    """Adapter that exposes a LangGraph runnable through MLflow's ``ResponsesAgent`` API.

    Incoming Responses-style input items are converted to chat-completion
    style message dicts and passed to ``agent.stream``; the LangChain messages
    that come back are converted to Responses output items and text deltas.
    """

    def __init__(self, agent):
        """Store the graph to drive.

        Args:
            agent: Object exposing ``stream(inputs, stream_mode=...)``.
        """
        self.agent = agent

    @staticmethod
    def _content_to_text(content) -> str:
        """Flatten message content into plain text.

        Content is returned unchanged when it is already a string. When it is
        a list of content blocks, only string items and dict blocks carrying a
        string ``"text"`` of type ``text``/``output_text`` (or no type) are
        kept and concatenated; other blocks (for example reasoning blocks) are
        dropped. ``None`` becomes ``""`` and anything else is ``str()``-ed.
        """
        if isinstance(content, str):
            return content
        if content is None:
            return ""
        if isinstance(content, list):
            pieces = []
            for part in content:
                if isinstance(part, str):
                    pieces.append(part)
                elif (
                    isinstance(part, dict)
                    and part.get("type") in (None, "text", "output_text")
                    and isinstance(part.get("text"), str)
                ):
                    pieces.append(part["text"])
            return "".join(pieces)
        return str(content)

    @staticmethod
    def _first_time(seen, key) -> bool:
        """Return True when ``key`` should be emitted, recording it in ``seen``.

        ``key`` is a ``(kind, identifier)`` tuple. Nothing is de-duplicated
        when ``seen`` is None or the identifier is None, because there is
        nothing reliable to compare against.
        """
        if seen is None or key[1] is None:
            return True
        if key in seen:
            return False
        seen.add(key)
        return True

    def _responses_to_cc(self, message: dict[str, Any]) -> list[dict[str, Any]]:
        """Convert one Responses-style input item into chat-completion message dicts.

        Simplified: handles function calls, function call outputs and plain
        messages only.

        Args:
            message: The input item as a plain dict.

        Returns:
            Zero or more chat-completion style message dicts.
        """
        msg_type = message.get("type")
        if msg_type == "function_call":
            return [
                {
                    "role": "assistant",
                    "content": "tool call",
                    "tool_calls": [
                        {
                            "id": message["call_id"],
                            "type": "function",
                            "function": {
                                "arguments": message["arguments"],
                                "name": message["name"],
                            },
                        }
                    ],
                }
            ]
        elif msg_type == "function_call_output":
            return [
                {
                    "role": "tool",
                    "content": message["output"],
                    "tool_call_id": message["call_id"],
                }
            ]
        elif msg_type == "message" and isinstance(message.get("content"), list):
            # One chat message per textual content part. Parts without a
            # string "text" are skipped instead of raising KeyError.
            return [
                {"role": message["role"], "content": part["text"]}
                for part in message["content"]
                if isinstance(part, dict) and isinstance(part.get("text"), str)
            ]
        return [{"role": message.get("role", "assistant"), "content": message.get("content", "")}]

    def _langchain_to_responses(self, messages, seen=None) -> list[dict[str, Any]]:
        """Convert LangChain messages into Responses output items.

        AI messages with tool calls become one function-call item per call,
        other AI messages become a text output item, and tool messages become
        function-call-output items. Other message types are ignored.

        Args:
            messages: Iterable of objects exposing ``model_dump()``.
            seen: Optional set shared across calls. When given, a function
                call (by call id), a tool output (by tool call id) or a text
                message (by message id) that was already converted is skipped.
                This matters for graphs that repeat earlier messages in later
                node updates, such as a supervisor that returns full history.

        Returns:
            The list of output items, in message order.
        """
        outputs = []
        for message in messages:
            data = message.model_dump()
            kind = data["type"]
            if kind == "ai":
                if tool_calls := data.get("tool_calls"):
                    for tc in tool_calls:
                        if not self._first_time(seen, ("function_call", tc["id"])):
                            continue
                        outputs.append(
                            self.create_function_call_item(
                                id=data.get("id") or str(uuid4()),
                                call_id=tc["id"],
                                name=tc["name"],
                                arguments=json.dumps(tc["args"]),
                            )
                        )
                else:
                    if not self._first_time(seen, ("text", data.get("id"))):
                        continue
                    outputs.append(
                        self.create_text_output_item(
                            # Content may be a list of blocks; the item needs text.
                            text=self._content_to_text(data["content"]),
                            id=data.get("id") or str(uuid4()),
                        )
                    )
            elif kind == "tool":
                if not self._first_time(seen, ("function_call_output", data["tool_call_id"])):
                    continue
                outputs.append(
                    self.create_function_call_output_item(
                        call_id=data["tool_call_id"],
                        output=self._content_to_text(data["content"]),
                    )
                )
        return outputs

    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
        """Run the agent to completion and collect every finished output item.

        Text deltas from ``predict_stream`` are ignored; only
        ``response.output_item.done`` events contribute. The request's
        ``custom_inputs`` are echoed back as ``custom_outputs``.
        """
        outputs = [
            event.item
            for event in self.predict_stream(request)
            if event.type == "response.output_item.done"
        ]
        return ResponsesAgentResponse(output=outputs, custom_outputs=request.custom_inputs)

    def predict_stream(
        self,
        request: ResponsesAgentRequest,
    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
        """Stream the agent run as Responses events.

        Yields a ``response.output_item.done`` event for each new item found
        in a node update, and a text delta event for each non-empty
        ``AIMessageChunk``.
        """
        cc_msgs = []
        for msg in request.input:
            cc_msgs.extend(self._responses_to_cc(msg.model_dump()))

        # Per-request memory of what was already emitted, so a message that
        # shows up again in a later node update is reported only once.
        emitted = set()

        for event in self.agent.stream({"messages": cc_msgs}, stream_mode=["updates", "messages"]):
            if event[0] == "updates":
                for node_data in event[1].values():
                    # A node update may be None or a non-dict payload; those
                    # carry no messages to convert.
                    if not isinstance(node_data, dict):
                        continue
                    node_messages = node_data.get("messages") or []
                    for item in self._langchain_to_responses(node_messages, emitted):
                        yield ResponsesAgentStreamEvent(type="response.output_item.done", item=item)
            elif event[0] == "messages":
                chunk = event[1][0]
                if isinstance(chunk, AIMessageChunk):
                    # Chunk content may be a list of blocks; deltas must be text.
                    delta_text = self._content_to_text(chunk.content)
                    if delta_text:
                        yield ResponsesAgentStreamEvent(
                            **self.create_text_delta(delta=delta_text, item_id=chunk.id),
                        )

In [0]:
############################################
# Register the agent with MLflow
############################################
# Turn on MLflow's LangChain autologging before the agent is instantiated.
mlflow.langchain.autolog()
# Wrap the compiled supervisor graph in the ResponsesAgent bridge class.
AGENT = LangGraphResponsesAgent(supervisor)
# Hand the wrapped agent to MLflow as the model object.
# NOTE(review): presumably needed for models-from-code logging — confirm.
mlflow.models.set_model(AGENT)

In [0]:
# Smoke test: send one user message through the non-streaming entry point.
demo_request = {
    "input": [
        {
            "role": "user",
            "content": "find US and New York state GDP in 2024. what % of US GDP was New York state?",
        }
    ]
}
result = AGENT.predict(demo_request)

# Dump the response as a dict, leaving out fields that are None.
print(result.model_dump(exclude_none=True))