# In [None]:
import mlflow
from mlflow.entities.span import SpanType
from mlflow.tracing.constant import SpanAttributeKey
from mlflow.tracing import set_span_chat_messages, set_span_chat_tools

# Sample conversation handed to the chat model: a system instruction
# followed by a single user question.
messages = [
    {
        "role": "system",
        "content": "please use the provided tool to answer the user's questions",
    },
    {
        "role": "user",
        "content": "what is 1 + 1?",
    },
]

# Tool definitions offered to the model: one function tool named "add"
# whose JSON-schema parameters require two numbers, "a" and "b".
tools = [
    {
        "type": "function",
        "function": {
            "name": "add",
            "description": "Add two numbers",
            "parameters": {
                "type": "object",
                "properties": {
                    operand: {"type": "number"} for operand in ("a", "b")
                },
                "required": ["a", "b"],
            },
        },
    },
]


@mlflow.trace(span_type=SpanType.CHAT_MODEL)
def call_chat_model(messages, tools):
    """Return a mocked assistant reply and record chat data on the active span.

    No real model is called: the reply is a hard-coded assistant message
    containing a single tool call to ``add``.

    Args:
        messages: Chat message dicts sent to the (mock) model.
        tools: Tool definition dicts offered to the (mock) model.

    Returns:
        The mocked assistant message dict.
    """
    # Mocked response: the tool-call arguments match the example user
    # question in this file ("what is 1 + 1?").
    response = {
        "role": "assistant",
        "tool_calls": [
            {
                "id": "123",
                "function": {"arguments": '{"a": 1, "b": 1}', "name": "add"},
                "type": "function",
            }
        ],
    }

    span = mlflow.get_current_active_span()
    # get_current_active_span() returns None when no span is active; skip
    # the logging in that case instead of failing, and still return the reply.
    if span is not None:
        # Record the whole conversation: the request messages plus the reply.
        set_span_chat_messages(span, [*messages, response])
        set_span_chat_tools(span, tools)

    return response


# Invoke the @mlflow.trace-decorated function once with the example inputs.
call_chat_model(messages, tools)

# Look up the trace by the id of the last active trace.
trace_id = mlflow.get_last_active_trace_id()
trace = mlflow.get_trace(trace_id)
# Take the first span in the trace -- presumably the CHAT_MODEL span created
# by call_chat_model, since it is the only traced call here (TODO confirm if
# more spans are added).
span = trace.data.spans[0]

# Print the values stored on the span under the chat-messages and
# chat-tools attribute keys.
print("Messages: ", span.get_attribute(SpanAttributeKey.CHAT_MESSAGES))
print("Tools: ", span.get_attribute(SpanAttributeKey.CHAT_TOOLS))