In [1]:
# Re-import edited modules automatically before each cell executes, so changes
# to imported code take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
%%html
<style>
/* Drop the fixed background behind ipywidget outputs so they blend with the editor. */
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
/* Map the Jupyter widget color / font-size variables onto the VS Code editor theme. */
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [None]:
import asyncio
import json

import art
from dotenv import load_dotenv
from openai.types.chat.chat_completion import ChatCompletion

load_dotenv()

MODEL_NAME = "001"
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"

model = art.TrainableModel(
    name=MODEL_NAME, project="rock-paper-tool-use", base_model=BASE_MODEL
)
await model.register(art.LocalAPI())
client = model.openai_client()


def get_move(chat_completion: ChatCompletion) -> str:
    """Extract the rock-paper-scissors move from a model's ``play_move`` tool call.

    The tool call's ``function.arguments`` is a JSON-encoded string such as
    ``'{"move": "rock"}'``. It is decoded here and its ``move`` value is
    validated, so callers always receive a key of the move table.

    Args:
        chat_completion: Completion returned by ``client.chat.completions.create``
            with the ``play_move`` tool offered.

    Returns:
        ``"rock"``, ``"paper"`` or ``"scissors"`` (lower-cased). Returns ``"none"``
        when the model made no tool call, when the arguments are not valid JSON
        or not a JSON object, when ``move`` is missing, or when ``move`` names an
        unknown move.
    """
    tool_calls = chat_completion.choices[0].message.tool_calls
    if not tool_calls:
        return "none"
    try:
        arguments = json.loads(tool_calls[0].function.arguments)
    except (json.JSONDecodeError, TypeError):
        # Malformed tool arguments count as "no move", not a crash of the rollout.
        return "none"
    move = arguments.get("move") if isinstance(arguments, dict) else None
    if isinstance(move, str):
        move = move.strip().lower()
        if move in ("rock", "paper", "scissors"):
            return move
    return "none"


async def rollout(max_rounds: int = 20, wins_needed: int = 3) -> art.Trajectory:
    """Play one rock-paper-scissors match between the trainable and base models.

    Each round, both players are queried concurrently. Each player sees its own
    conversation history and is offered the ``play_move`` tool. The winner of
    a round gets +1 ``reward``. Ties score nothing, and so do rounds in which
    either player makes no valid move. Between rounds, each player is told its
    opponent's last move.

    Args:
        max_rounds: Maximum number of rounds to play. This bounds the match
            when neither player ever reaches ``wins_needed`` (for example,
            repeated ties or a player that never makes a valid move). Without
            it, such a match would loop forever and stall the whole training
            step.
        wins_needed: Number of round wins that ends the match as soon as
            either player reaches it.

    Returns:
        The trainable model's trajectory. Its ``reward`` is the number of rounds
        it won.
    """
    tools: art.Tools = [
        {
            "type": "function",
            "function": {
                "name": "play_move",
                "description": "Play a move in rock-paper-scissors",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "move": {
                            "type": "string",
                            "enum": ["rock", "paper", "scissors"],
                            "description": "The move to play",
                        }
                    },
                    "required": ["move"],
                },
            },
        }
    ]
    # Each key beats its value; "none" (no valid move) is absent, so it beats nothing.
    beats = {"rock": "scissors", "paper": "rock", "scissors": "paper"}
    trajectories = [
        art.Trajectory(
            messages_and_choices=[
                {
                    "role": "user",
                    "content": "You are playing rock-paper-scissors with another player. What will your first move be?",
                }
            ],
            reward=0,
        )
        for _ in range(2)
    ]
    # Index 0 is the model being trained; index 1 is queried as BASE_MODEL.
    player_models = (MODEL_NAME, BASE_MODEL)
    for round_number in range(1, max_rounds + 1):
        chat_completions = await asyncio.gather(
            *(
                client.chat.completions.create(
                    messages=trajectory.messages(),
                    model=model_name,
                    tools=tools,
                )
                for trajectory, model_name in zip(trajectories, player_models)
            )
        )
        for trajectory, chat_completion in zip(trajectories, chat_completions):
            trajectory.messages_and_choices.append(chat_completion.choices[0])
        move1, move2 = map(get_move, chat_completions)
        # .get() so an unexpected move string scores nothing instead of raising.
        if beats.get(move1) == move2:
            trajectories[0].reward += 1
        elif beats.get(move2) == move1:
            trajectories[1].reward += 1
        match_over = max(t.reward for t in trajectories) >= wins_needed
        if match_over or round_number == max_rounds:
            # Stop before appending a prompt that would never be answered.
            break
        for trajectory, opponent_move in zip(trajectories, (move2, move1)):
            # NOTE(review): the assistant turn just appended may carry tool calls
            # that no "tool"-role message answers. The OpenAI API requires such
            # answers, so confirm the local backend accepts a user turn instead.
            trajectory.messages_and_choices.append(
                {
                    "role": "user",
                    "content": "The other player played "
                    + opponent_move
                    + ". What will your next move be?",
                }
            )
    return trajectories[0]


for i in range(await model.get_step(), 1_000):
    trajectories = await art.gather_trajectories(rollout() for _ in range(64))
    await model.train(
        [art.TrajectoryGroup(trajectories)],
        config=art.TrainConfig(learning_rate=5e-5),
    )