> **Warning:** There is currently a bug with tool-use functionality. The issue appears to be that vLLM does not return all of the token log probabilities for tool use. Further investigation is needed to determine the exact cause. For now, the recommended workaround is to teach use-case-specific tool use with non-tool-use models.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%%html
<style>
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [None]:
import art
import asyncio
from dotenv import load_dotenv
import json
import openai
import random
import re
from typing import TypedDict

# Load variables from a local .env file into the process environment before
# anything below reads them.
load_dotenv()


# Schema of one record in the temporal-clue puzzle dataset.
TemporalCluePuzzle = TypedDict(
    "TemporalCluePuzzle",
    {
        # Clue count recorded for the puzzle (not read by the code in this cell).
        "num_clues": int,
        # Full puzzle text; sent verbatim as the user message of a rollout.
        "prompt": str,
        # Maps each question key to its expected answer string.
        "solution": dict[str, str],
    },
)


puzzles: list[TemporalCluePuzzle] = json.load(open("./data/temporal-clue/puzzles.json"))
val_puzzles = puzzles[:64]
test_puzzles = puzzles[64:128]
train_puzzles = puzzles[128:]
random.seed(42)
random.shuffle(train_puzzles)


api = art.LocalAPI(wandb_project="agent-reinforcement-training")
model = await api.get_or_create_model(
    name="temporal-clue-tool-use-001",
    base_model="NousResearch/Hermes-3-Llama-3.1-8B",
)


def _count_correct_answers(content: str, solution: dict[str, str]) -> int:
    """Count the entries of ``solution`` that ``content`` answers correctly.

    For every key, the text is searched for ``"<key>. <answer>"``. When the
    pattern occurs several times only the last occurrence is graded, and the
    comparison ignores case and surrounding whitespace.
    """
    num_correct = 0
    for key, value in solution.items():
        # Escape the key so any regex metacharacters in it are matched
        # literally instead of being interpreted as pattern syntax.
        matches = re.findall(rf"{re.escape(key)}\. ([A-Za-z \.:-]+)", content)
        if matches and matches[-1].strip().lower() == value.lower():
            num_correct += 1
    return num_correct


async def rollout(
    client: openai.AsyncOpenAI, puzzle: TemporalCluePuzzle
) -> art.Trajectory:
    """Run one puzzle episode with an optional hint tool and score the result.

    The puzzle prompt is sent as a user message together with a ``get_hints``
    tool. While the model keeps answering with tool calls, each call is
    executed locally and its result is fed back as a ``tool`` message. The
    final assistant message is then graded against ``puzzle["solution"]``.

    Args:
        client: OpenAI-compatible async client used for the chat completions.
        puzzle: The puzzle to solve; its ``prompt`` is the user message and its
            ``solution`` is both the hint source and the grading key.

    Returns:
        A trajectory whose reward is the fraction of correctly answered keys
        minus 0.05 per hint handed out. The unpenalised accuracy and the hint
        count are reported as the ``acc`` and ``hints`` metrics.
    """
    messages: art.Messages = [{"role": "user", "content": puzzle["prompt"]}]
    tools: art.Tools = [
        {
            "type": "function",
            "function": {
                "name": "get_hints",
                "description": "A function to retrieve one or two hints. No more than 3 hints may be retrieved total. "
                "Each retrieved hint decreases your final accuracy score by 5%.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "num_hints": {
                            "type": "integer",
                            "description": "Number of hints to retrieve (1 or 2)",
                            "enum": [1, 2],
                        }
                    },
                },
            },
        }
    ]

    async def request_completion():
        # Single place for the request parameters, so the first turn and every
        # follow-up turn are guaranteed to use identical settings.
        return await client.chat.completions.create(
            messages=messages,
            model=model.name,
            max_tokens=2048,
            tools=tools,
            stop=["<|end_of_text|>"],
        )

    choice = (await request_completion()).choices[0]
    messages_and_choices = [*messages, choice]

    # One hint per solution entry, handed out in random order.
    hints = [
        f"The answer for {key} is {value}" for key, value in puzzle["solution"].items()
    ]
    random.shuffle(hints)
    hints_shared = 0

    def get_hints(function_name: str, function_arguments: str) -> str:
        """Execute one tool call and return the text of the tool response.

        Every failure is reported as an ``"Error: ..."`` string instead of an
        exception, so a malformed call from the model never aborts the rollout.
        Failed calls hand out no hints and are not counted against the reward.
        """
        nonlocal hints_shared
        if function_name != "get_hints":
            return f"Error: unexpected function name {function_name}"
        try:
            arguments = json.loads(function_arguments or "{}")
        except (ValueError, TypeError, RecursionError):
            return f"Error: invalid JSON {function_arguments}"
        if not isinstance(arguments, dict):
            # Valid JSON, but not the object the tool schema requires.
            return f"Error: invalid JSON {function_arguments}"
        num_hints = arguments.get("num_hints", 1)
        # Require a genuine integer: unhashable values (lists, dicts) and
        # floats such as 1.0 would otherwise raise further down.
        if (
            isinstance(num_hints, bool)
            or not isinstance(num_hints, int)
            or num_hints not in (1, 2)
        ):
            return f"Error: invalid number of hints {num_hints}"
        if num_hints + hints_shared > 3:
            return f"Error: cannot retrieve {num_hints} hints, already retrieved {hints_shared} hints"
        if num_hints > len(hints):
            # Puzzles with few solution entries can run out of hints before
            # the limit of 3 is reached.
            return f"Error: cannot retrieve {num_hints} hints, only {len(hints)} hints remaining"
        hints_shared += num_hints
        return "\n".join(["Hints:", *(hints.pop() for _ in range(num_hints))])

    while tool_calls := choice.message.tool_calls:
        # Replay the assistant turn as a plain message so the follow-up request
        # carries the tool calls that the tool messages below respond to.
        assistant_message = {
            "role": "assistant",
            "content": choice.message.content,
            "tool_calls": [
                {
                    "id": tool_call.id,
                    "type": "function",
                    "function": {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments or "{}",
                    },
                }
                for tool_call in tool_calls
            ],
        }
        messages.append(assistant_message)
        for tool_call in tool_calls:
            tool_message = {
                "role": "tool",
                "tool_call_id": tool_call.id,
                "content": get_hints(
                    tool_call.function.name, tool_call.function.arguments
                ),
            }
            messages.append(tool_message)
            # The assistant turn is already recorded as a choice, so only the
            # tool response is added to the trajectory here.
            messages_and_choices.append(tool_message)
        try:
            chat_completion = await request_completion()
        except openai.BadRequestError:
            # Likely incorrectly formatted tool call arguments. We'll break
            # out of the loop and allow the model to (probably) fail.
            # Print the offending calls from the assistant message itself; it
            # is not at a fixed offset when a turn holds several tool calls.
            print(assistant_message["tool_calls"])
            break
        choice = chat_completion.choices[0]
        messages_and_choices.append(choice)

    solution = puzzle["solution"]
    num_correct = _count_correct_answers(choice.message.content or "", solution)
    # An empty solution has nothing to get right; score it 0 rather than
    # dividing by zero.
    acc = num_correct / len(solution) if solution else 0.0
    return art.Trajectory(
        messages_and_choices=messages_and_choices,
        reward=acc - hints_shared * 0.05,
        metrics={"acc": acc, "hints": hints_shared},
        tools=tools,
    )


stride = 32
for i in range(await model.get_iteration(), 1_000):
    async with model.openai_client(
        estimated_completion_tokens=180, tool_use=True, verbosity=2
    ) as openai_client:
        val_groups, train_groups = await asyncio.gather(
            art.gather_trajectories(
                (
                    (rollout(openai_client, puzzle) for _ in range(2))
                    for puzzle in val_puzzles
                ),
                pbar_desc="val",
                stream_chat_completions=8,
            ),
            art.gather_trajectories(
                (
                    (rollout(openai_client, puzzle) for _ in range(50))
                    for puzzle in train_puzzles[i * stride : (i + 1) * stride]
                ),
                pbar_desc="train",
            ),
        )
    await model.log(val_groups)
    await model.clear_iterations()
    await model.tune(
        train_groups, config=art.TuneConfig(plot_tensors=True, verbosity=2)
    )