In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%%html
<style>
/* Notebook display tweaks: drop the widget output background and point the
   Jupyter widget color/font-size variables at the VS Code editor variables. */
.cell-output-ipywidget-background {
    /* !important so this wins over any background set elsewhere. */
    background-color: transparent !important;
}
:root {
    /* Widget text color follows the editor foreground color variable. */
    --jp-widgets-color: var(--vscode-editor-foreground);
    /* Widget font size follows the editor font size variable. */
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [None]:
import openai
from dotenv import load_dotenv
from generate_images import generate_yes_no_maybe_prompts, save_prompt_images

import art
from art.local import LocalBackend

# Load environment variables from a .env file (python-dotenv) before any
# client or backend is constructed.
# NOTE(review): which variables are required is not visible here — confirm.
load_dotenv()

# Backend the model is registered with below; built with default arguments.
backend = LocalBackend()

# Trainable model handle. `name` and `project` identify this run; the weights
# start from the Qwen2.5-VL 7B instruct checkpoint named in `base_model`.
model = art.TrainableModel(
    name="009",
    project="yes-no-maybe-vision",
    base_model="Qwen/Qwen2.5-VL-7B-Instruct",
)
# Top-level await: this only runs in a notebook / async-aware REPL.
await model.register(backend)


async def rollout(client: openai.AsyncOpenAI, image_path: str) -> art.Trajectory:
    """Query the model with a single image and score its text reply.

    The request contains one user message whose only content part is the
    image at ``image_path`` (passed through as the ``image_url`` URL).

    Args:
        client: Async OpenAI-compatible client used to request the completion.
        image_path: URL of the image to send.

    Returns:
        A trajectory holding the request messages followed by the first
        completion choice, with a reward of 0.5 for ``"yes"``, 0.75 for
        ``"no"``, 1.0 for ``"maybe"`` and 0.0 for any other reply.

    Raises:
        AssertionError: If the completion's message content is not a string.
    """
    # Exact-match scoring: case, whitespace or extra words all fall through
    # to the 0.0 default.
    reward_by_reply = {"yes": 0.5, "no": 0.75, "maybe": 1.0}

    messages: art.Messages = [
        {
            "role": "user",
            "content": [{"type": "image_url", "image_url": {"url": image_path}}],
        }
    ]
    response = await client.chat.completions.create(
        model=model.name,
        messages=messages,
        max_tokens=100,
        timeout=100,
    )
    choice = response.choices[0]
    reply = choice.message.content
    assert isinstance(reply, str)

    return art.Trajectory(
        messages_and_choices=[*messages, choice],
        reward=reward_by_reply.get(reply, 0.0),
    )


# Generate the yes/no/maybe prompts and save them as images under the /tmp
# directory below, with the given image size and pixel margin.
# The training loop calls `.as_uri()` on each returned item, so these are
# presumably `pathlib.Path` objects — confirm in generate_images.
# NOTE(review): font_path=None presumably selects a default font — confirm.
image_paths = save_prompt_images(
    generate_yes_no_maybe_prompts(),
    "/tmp/yes-no-maybe-vision/images",
    image_size=(256, 256),
    margin_px=16,
    font_path=None,
)


openai_client = model.openai_client()
for _ in range(await model.get_step(), 1_000):
    train_groups = await art.gather_trajectory_groups(
        (
            art.TrajectoryGroup(
                rollout(openai_client, image_path.as_uri()) for _ in range(32)
            )
            for image_path in image_paths
        )
    )
    await model.train(
        train_groups,
        config=art.TrainConfig(learning_rate=1e-4),
    )