See our docs for an explanation of what this code is doing!

In [None]:
import os
from docent import Docent

# Build the Docent API client. Reading DOCENT_API_KEY from the environment is
# already the client's default, so passing it explicitly here is optional.
client = Docent(
    api_key=os.environ.get("DOCENT_API_KEY"),
    # Self-hosting? Uncomment these and point them at your own deployment:
    # server_url="http://localhost:8889",
    # web_url="http://localhost:3001",
)

In [None]:
# Create a collection to hold the sample run; its id is needed for the upload.
collection_id = client.create_collection(
    name="tau-bench example",
    description="example tau-bench-airline log that comes with the Docent repo",
)

In [None]:
from docent.samples import get_tau_bench_airline_fpath
import json
# Load the bundled tau-bench airline sample log. JSON text is UTF-8, so the
# encoding is pinned rather than relying on the platform's locale default
# (e.g. cp1252 on Windows), which can mis-decode non-ASCII content.
with open(get_tau_bench_airline_fpath(), "r", encoding="utf-8") as f:
    tb_log = json.load(f)
print(tb_log)

In [None]:
import json
from typing import Any

from docent.data_models import AgentRun, Transcript
from docent.data_models.chat import ChatMessage, ToolCall, parse_chat_message


def _convert_tool_call(raw_call: dict[str, Any]) -> ToolCall:
    """Build a Docent ``ToolCall`` from one raw ``tool_calls`` entry.

    ``function.arguments`` may already be a mapping or, as in the OpenAI
    chat-completions format, a JSON-encoded string. Strings are decoded here.
    If decoding fails or yields something other than a JSON object, the call
    is kept with empty arguments and the problem is recorded in
    ``parse_error`` rather than aborting the whole log.
    """
    # ``or {}`` also covers an explicit ``"function": null``.
    function = raw_call.get("function") or {}
    arguments = function.get("arguments")
    parse_error = None

    if arguments is None or (isinstance(arguments, str) and not arguments.strip()):
        # Missing, null or blank arguments: treat as a zero-argument call.
        arguments = {}
    elif isinstance(arguments, str):
        raw_arguments = arguments
        try:
            arguments = json.loads(raw_arguments)
        except json.JSONDecodeError as err:
            arguments = {}
            parse_error = f"Invalid JSON in tool-call arguments ({err}): {raw_arguments!r}"

    if not isinstance(arguments, dict):
        # Valid JSON that is not an object (e.g. a list), or an unexpected type.
        parse_error = (
            "Expected tool-call arguments to be a JSON object, "
            f"got {type(arguments).__name__}: {arguments!r}"
        )
        arguments = {}

    return ToolCall(
        id=raw_call.get("id"),
        function=function.get("name"),
        arguments=arguments,
        type="function",
        parse_error=parse_error,
    )


def _convert_message(msg: dict[str, Any]) -> ChatMessage:
    """Convert one raw tau-bench trajectory message into a Docent ``ChatMessage``."""
    role = msg.get("role")
    # Treat an explicit ``"content": null`` the same as a missing key (empty
    # text). OpenAI-style logs can emit null content on tool-calling turns.
    content = msg.get("content") or ""

    message_data: dict[str, Any] = {"role": role, "content": content}

    # For tool messages, include the tool name and the id of the call answered.
    if role == "tool":
        message_data["name"] = msg.get("name")
        message_data["tool_call_id"] = msg.get("tool_call_id")

    # For assistant messages, include tool calls if present.
    raw_tool_calls = msg.get("tool_calls")
    if role == "assistant" and raw_tool_calls:
        message_data["tool_calls"] = [_convert_tool_call(tc) for tc in raw_tool_calls]

    # Let Docent pick the concrete message class from the role.
    return parse_chat_message(message_data)


def load_tau_bench_log(data: dict[str, Any], model: str = "sonnet-35-new") -> AgentRun:
    """Convert one tau-bench result record into a Docent ``AgentRun``.

    Args:
        data: A tau-bench result record. Reads ``traj`` (a list of message
            dicts), ``reward`` (a number, rounded to 3 places), and ``info``.
            ``info`` must contain ``info["task"]["user_id"]`` and
            ``info["reward_info"]``.
        model: Model name recorded in the run metadata. The default is the
            value this loader previously hard-coded.

    Returns:
        An ``AgentRun`` wrapping a single ``Transcript`` of the converted
        messages. The same metadata is attached to both the transcript and
        the run.

    Raises:
        KeyError: if any of the required keys listed above is missing.
    """
    traj, info, reward = data["traj"], data["info"], data["reward"]

    messages: list[ChatMessage] = [_convert_message(msg) for msg in traj]

    # NOTE(review): both id fields are filled from the task's user_id, not from
    # the record's top-level ``task_id``. The previous version read that value
    # and then overwrote it, so this preserves what was uploaded before.
    # Confirm this is the intended identifier.
    run_id = info["task"]["user_id"]
    scores = {"reward": round(reward, 3)}

    metadata = {
        "benchmark_id": run_id,
        "task_id": run_id,
        "model": model,
        "scores": scores,
        "additional_metadata": info,
        "scoring_metadata": info["reward_info"],
    }

    # Create the transcript and wrap it in an AgentRun.
    transcript = Transcript(
        messages=messages,
        metadata=metadata,
    )
    return AgentRun(
        transcripts=[transcript],
        metadata=metadata,
    )

In [None]:
# Convert the sample log; the runs are uploaded below as a list.
agent_runs = [load_tau_bench_log(tb_log)]

# Preview the (single) converted run as plain text before uploading it.
print(agent_runs[-1].to_text_new())

In [None]:
# Upload the converted runs into the collection created earlier.
client.add_agent_runs(
    collection_id,
    agent_runs,
)