# Mosaic AI Agent Framework: Author and deploy a simple OpenAI agent

This notebook demonstrates how to author an OpenAI agent that's compatible with Mosaic AI Agent Framework features. In this notebook you learn to:
- Author an OpenAI agent with `ChatAgent`
- Manually test the agent's output
- Log and deploy the agent

To learn more about authoring an agent using Mosaic AI Agent Framework, see Databricks documentation ([AWS](https://docs.databricks.com/aws/generative-ai/agent-framework/author-agent) | [Azure](https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/create-chat-model)).

## Prerequisites

- Address all `TODO`s in this notebook.

In [0]:
%pip install -U -qqqq mlflow-skinny[databricks] openai databricks-agents uv
dbutils.library.restartPython()


## Define the agent in code
Define the agent code in a single cell below. This lets you easily write the agent code to a local Python file, using the `%%writefile` magic command, for subsequent logging and deployment.
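
As a rough illustration of why the agent lives in one cell: `%%writefile` saves the cell body to a file, and later cells import from that file. The sketch below mimics that round trip using only the standard library. The module name `demo_agent` and the `answer` function are hypothetical and not part of this notebook.

```python
import importlib.util
import tempfile
from pathlib import Path

# Hypothetical stand-in for the cell body that %%writefile would save.
AGENT_SOURCE = '''
LLM_ENDPOINT_NAME = "databricks-llama-4-maverick"

def answer(question: str) -> str:
    return f"[{LLM_ENDPOINT_NAME}] {question}"
'''


def write_and_import(source: str, module_name: str = "demo_agent"):
    """Write source to <tmpdir>/<module_name>.py and import it as a module."""
    path = Path(tempfile.mkdtemp()) / f"{module_name}.py"
    path.write_text(source)
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


demo = write_and_import(AGENT_SOURCE)
print(demo.answer("What is MLflow?"))
```

Keeping everything in one file means the same source is used for interactive testing and for logging the model from code.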


In [0]:
%%writefile rag_agent.py

from typing import Any, Generator, Optional

import mlflow
from databricks.sdk import WorkspaceClient
from mlflow.entities import SpanType
from mlflow.pyfunc.model import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)

mlflow.openai.autolog()

# TODO: Replace with your model serving endpoint
LLM_ENDPOINT_NAME = "databricks-llama-4-maverick"


class SimpleChatAgent(ChatAgent):
    def __init__(self):
        self.workspace_client = WorkspaceClient()
        self.client = self.workspace_client.serving_endpoints.get_open_ai_client()
        self.llm_endpoint = LLM_ENDPOINT_NAME
    
    def prepare_messages_for_llm(self, messages: list[ChatAgentMessage]) -> list[dict[str, Any]]:
        """Filter out ChatAgentMessage fields that are not compatible with LLM message formats"""
        compatible_keys = ["role", "content", "name", "tool_calls", "tool_call_id"]
        return [
            {k: v for k, v in m.model_dump_compat(exclude_none=True).items() if k in compatible_keys}
            for m in messages
        ]

    @mlflow.trace(span_type=SpanType.RETRIEVER)
    def retrieve_context(self, question):
        if "mlflow" in question.lower():
            return ["MLflow is an open-source platform for managing the end-to-end machine learning lifecycle. It provides tools for experiment tracking, model packaging, and deployment."]
        elif "genie" in question.lower():
            return ["Genie is a Databricks feature that allows business teams to interact with their data using natural language. It uses generative AI tailored to your organization's terminology and data, with the ability to monitor and refine its performance through user feedback."]
        else:
            return ["General information about machine learning and data science."]

    @mlflow.trace(span_type=SpanType.AGENT)
    def augment_query(self, question, question_context):
        return f"""

    You are a question answering bot. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

    ======================================================================

    Context:{'''
    ----------------------------------------------------------------------
    '''.join([c for c in question_context])}.
    ======================================================================

    Question: {question}

    """

    def predict_common(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> list[ChatAgentMessage]:
        
        question_context = self.retrieve_context(question=messages[-1].content)
        augmented_query = self.augment_query(
            question=messages[-1].content, question_context=question_context
        )
        messages[-1].content = augmented_query

        return messages

    @mlflow.trace(span_type=SpanType.AGENT)
    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        
        messages = self.predict_common(messages, context, custom_inputs)

        resp = self.client.chat.completions.create(
            model=self.llm_endpoint,
            messages=self.prepare_messages_for_llm(messages),
        )

        return ChatAgentResponse(
            messages=[ChatAgentMessage(**resp.choices[0].message.to_dict(), id=resp.id)],
        )

    @mlflow.trace(span_type=SpanType.AGENT)
    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        
        messages = self.predict_common(messages, context, custom_inputs)

        for chunk in self.client.chat.completions.create(
            model=self.llm_endpoint,
            messages=self.prepare_messages_for_llm(messages),
            stream=True,
        ):
            if not chunk.choices or not chunk.choices[0].delta.content:
                continue

            yield ChatAgentChunk(
                delta=ChatAgentMessage(
                    **{
                        "role": "assistant",
                        "content": chunk.choices[0].delta.content,
                        "id": chunk.id,
                    }
                )
            )


from mlflow.models import set_model

AGENT = SimpleChatAgent()
set_model(AGENT)
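
The retrieve-then-augment flow in `predict_common` can be checked offline, without a serving endpoint. The sketch below is a simplified, standard-library-only version of that flow: the context snippets are shortened, messages are plain dicts rather than `ChatAgentMessage` objects, and `prepare_last_message` is a hypothetical helper name. Unlike `predict_common`, it returns a copy instead of mutating the caller's messages.

```python
def retrieve_context(question: str) -> list[str]:
    """Keyword lookup standing in for a real retriever (shortened snippets)."""
    q = question.lower()
    if "mlflow" in q:
        return ["MLflow is an open-source platform for the ML lifecycle."]
    if "genie" in q:
        return ["Genie lets business teams query data in natural language."]
    return ["General information about machine learning and data science."]


def augment_query(question: str, context: list[str]) -> str:
    """Build the prompt: instructions, then context snippets, then the question."""
    separator = "\n" + "-" * 70 + "\n"
    return (
        "Use the following context to answer the question. "
        "If you don't know the answer, say so.\n\n"
        f"Context:\n{separator.join(context)}\n\n"
        f"Question: {question}"
    )


def prepare_last_message(messages: list[dict]) -> list[dict]:
    """Return a copy of the history with only the last message augmented."""
    *history, last = messages
    question = last["content"]
    augmented = {**last, "content": augment_query(question, retrieve_context(question))}
    return [*history, augmented]


original = [{"role": "user", "content": "What is a Genie space?"}]
prepared = prepare_last_message(original)
print(prepared[-1]["content"])
```

Returning a copy keeps the caller's conversation history unchanged, which can matter if the same message list is reused across calls.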

## Test the agent

Interact with the agent to test its output. 

Since you manually traced methods within `ChatAgent`, you can view the trace for each step the agent takes, with any LLM calls made via the OpenAI SDK automatically traced by autologging.

Replace this placeholder input with an appropriate domain-specific example for your agent.

In [0]:
dbutils.library.restartPython()

In [0]:
from rag_agent import AGENT

AGENT.predict({"messages": [{"role": "user", "content": "What is a genie space?"}]})
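
The agent also defines `predict_stream`, which skips chunks that carry no content and yields the rest. A consumer typically joins the yielded deltas into the full reply. The sketch below shows that pattern with a hypothetical `Delta` dataclass standing in for the streamed chunks, so it runs without an endpoint.

```python
from dataclasses import dataclass
from typing import Iterable, Iterator, Optional


@dataclass
class Delta:
    """Hypothetical stand-in for the delta of one streamed chunk."""

    content: Optional[str]


def non_empty_deltas(chunks: Iterable[Delta]) -> Iterator[str]:
    """Yield only deltas that carry text, mirroring the skip in predict_stream."""
    for chunk in chunks:
        if not chunk.content:
            continue
        yield chunk.content


stream = [Delta(None), Delta("Genie "), Delta(""), Delta("spaces "), Delta("answer questions.")]
full_text = "".join(non_empty_deltas(stream))
print(full_text)
```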


## Log the agent as an MLflow model
Log the agent as code from the `rag_agent.py` file. See [MLflow - Models from Code](https://mlflow.org/docs/latest/models.html#models-from-code).

In [0]:
import mlflow
from rag_agent import LLM_ENDPOINT_NAME
from mlflow.models.resources import DatabricksServingEndpoint
from importlib.metadata import version  # replaces the deprecated pkg_resources API


# Declare the serving endpoint the agent calls so that access is provisioned at deployment
resources = [
    DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME),
]

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        name="agent",
        python_model="rag_agent.py",
        # Pin each dependency to the version installed in this notebook
        pip_requirements=[
            f"databricks-connect=={version('databricks-connect')}",
            f"mlflow=={version('mlflow')}",
            f"openai=={version('openai')}",
        ],
        resources=resources,
    )
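
The `pip_requirements` list pins each dependency to the version installed in the notebook environment. A minimal sketch of the same idea using only the standard library's `importlib.metadata` follows. The helper name `pinned_requirements` is hypothetical, and skipping a missing package (with a printed warning) is an assumption about what you want. You may prefer to fail instead.

```python
from importlib.metadata import PackageNotFoundError, version


def pinned_requirements(packages: list[str]) -> list[str]:
    """Return 'name==version' for each installed package, skipping missing ones."""
    pins = []
    for name in packages:
        try:
            pins.append(f"{name}=={version(name)}")
        except PackageNotFoundError:
            # Surface the gap instead of raising, so the other pins are still built.
            print(f"Skipping {name}: not installed in this environment")
    return pins


print(pinned_requirements(["databricks-connect", "mlflow", "openai"]))
```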

## Pre-deployment agent validation
Before registering and deploying the agent, perform pre-deployment checks using the [mlflow.models.predict()](https://mlflow.org/docs/latest/python_api/mlflow.models.html#mlflow.models.predict) API. See Databricks documentation ([AWS](https://docs.databricks.com/en/machine-learning/model-serving/model-serving-debug.html#validate-inputs) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/machine-learning/model-serving/model-serving-debug#before-model-deployment-validation-checks)).

In [0]:
mlflow.models.predict(
    model_uri=f"runs:/{logged_agent_info.run_id}/agent",
    input_data={"messages": [{"role": "user", "content": "Hello!"}]},
    env_manager="uv",
)

## Evaluate the agent

Collect the traces logged so far, store them in an evaluation dataset, and score the agent's responses with `mlflow.genai.evaluate`.

In [0]:
import mlflow
traces = mlflow.search_traces()
traces

In [0]:

import mlflow.genai.datasets

# TODO: Replace with the Unity Catalog table that stores your evaluation dataset
evaluation_dataset_table_name = "ml_demo.default.simple_rag"

# eval_dataset = mlflow.genai.datasets.create_dataset(
#     uc_table_name=evaluation_dataset_table_name,
# )
# eval_dataset.merge_records(traces)


# Note: even if you just created the dataset by uncommenting the code above, reload it
# with get_dataset before passing it to mlflow.genai.evaluate. This may be a bug.

eval_dataset = mlflow.genai.datasets.get_dataset(evaluation_dataset_table_name)



In [0]:
import mlflow


# Load the agent logged earlier in this notebook. To evaluate a different run,
# pass an explicit URI such as "runs:/<run_id>/agent" instead.
model = mlflow.pyfunc.load_model(logged_agent_info.model_uri)

In [0]:
# The parameter names match the inputs recorded on the traced `predict` call
def predict_fn(messages: list, context: dict, custom_inputs: dict) -> dict:
    return model.predict({"messages": messages})

In [0]:
from mlflow.genai.scorers import RelevanceToQuery

mlflow.genai.evaluate(
    data=eval_dataset,
    predict_fn=predict_fn,
    scorers=[
        RelevanceToQuery(),
    ],
)
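
The signature of `predict_fn` matters: as far as I understand the `mlflow.genai.evaluate` contract, the function is called with each record's `inputs` unpacked as keyword arguments, so its parameter names need to match the input keys. The sketch below illustrates that calling convention with a stub in place of the model. The record shape and the echo behavior are hypothetical.

```python
def predict_fn(messages: list, context: dict = None, custom_inputs: dict = None) -> dict:
    """Hypothetical stub: echoes the last user message instead of calling a model."""
    reply = f"Echo: {messages[-1]['content']}"
    return {"messages": [{"role": "assistant", "content": reply}]}


# Hypothetical dataset record shaped like the inputs captured from the predict span.
record = {
    "inputs": {
        "messages": [{"role": "user", "content": "What is MLflow?"}],
        "context": None,
        "custom_inputs": None,
    }
}

# Unpack the record's inputs as keyword arguments, as the evaluation harness would.
outputs = predict_fn(**record["inputs"])
print(outputs["messages"][0]["content"])
```

If the parameter names and the input keys disagree, the call fails with a `TypeError`, which is a quick way to spot a mismatched dataset before running a full evaluation.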

## Register the model to Unity Catalog

Before you deploy the agent, you must register the agent to Unity Catalog.

- **TODO** Update the `catalog`, `schema`, and `model_name` below to register the MLflow model to Unity Catalog.

In [0]:
mlflow.set_registry_uri("databricks-uc")

# TODO: define the catalog, schema, and model name for your UC model
catalog = "ml_demo"
schema = "default"
model_name = "simple-rag-agent"
UC_MODEL_NAME = f"{catalog}.{schema}.{model_name}"

# register the model to UC
uc_registered_model_info = mlflow.register_model(model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME)
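
Unity Catalog model names have three levels, `catalog.schema.model_name`. A small check before registering can catch a malformed name early. The sketch below is a hypothetical helper, and the set of forbidden characters (period, space, forward slash) is an assumption based on my reading of the Unity Catalog naming rules. Confirm it against the documentation for your workspace.

```python
# Assumed forbidden characters inside a single name part: period, space, forward slash.
FORBIDDEN = set(". /")


def uc_model_name(catalog: str, schema: str, model_name: str) -> str:
    """Join the three name parts after checking each one individually."""
    for label, part in (("catalog", catalog), ("schema", schema), ("model_name", model_name)):
        if not part:
            raise ValueError(f"{label} must not be empty")
        bad = FORBIDDEN & set(part)
        if bad:
            raise ValueError(f"{label} {part!r} contains forbidden characters: {sorted(bad)}")
    return f"{catalog}.{schema}.{model_name}"


print(uc_model_name("ml_demo", "default", "simple-rag-agent"))
```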

## Deploy the agent

In [0]:
from databricks import agents

agents.deploy(UC_MODEL_NAME, uc_registered_model_info.version, scale_to_zero=True)

## Next steps

After your agent is deployed, you can chat with it in AI Playground to perform additional checks, share it with SMEs in your organization for feedback, or embed it in a production application. See Databricks documentation ([AWS](https://docs.databricks.com/en/generative-ai/deploy-agent.html) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/deploy-agent)).

## Collect traces

Chat with the deployed agent in AI Playground and check that traces are collected in the associated MLflow experiment.

Note: deploying the model appears to create a table of traces, but in this run it contained only older traces, not the ones from the Playground interaction.

