In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
%%html
<style>
/* Make the ipywidget output background transparent. */
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
/* Point the Jupyter widget color and font-size variables at the
   corresponding VS Code editor variables. */
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [None]:
import art
import asyncio
from dotenv import load_dotenv
import json
import random
import re
from typing import TypedDict

from art.local import LocalBackend

load_dotenv()


class TemporalCluePuzzle(TypedDict, total=True):
    """Schema of one puzzle record as used by this notebook.

    All three keys are required.
    """

    # Clue count for the puzzle (not read by the code visible here).
    num_clues: int
    # Text sent to the model as the user message.
    prompt: str
    # Expected answers, keyed by the label searched for in the model output.
    solution: dict[str, str]


# Load every puzzle from disk, then carve out fixed validation and test
# splits; everything after the first 128 records is training data.
puzzles_path = "../data/temporal-clue/puzzles.json"
# Use a context manager so the file handle is closed as soon as parsing is
# done, and read as UTF-8 rather than relying on the locale default encoding.
with open(puzzles_path, encoding="utf-8") as puzzles_file:
    puzzles: list[TemporalCluePuzzle] = json.load(puzzles_file)
val_puzzles = puzzles[:64]
test_puzzles = puzzles[64:128]
train_puzzles = puzzles[128:]
# Seed the global RNG before shuffling so the training order is the same on
# every run; only the training split is shuffled.
random.seed(42)
random.shuffle(train_puzzles)


async def rollout(model: art.Model, puzzle: TemporalCluePuzzle) -> art.Trajectory:
    """Ask the model to solve one puzzle and score its answer.

    The puzzle prompt is sent as a single user message through the model's
    OpenAI-style client. For every entry in ``puzzle["solution"]`` the
    completion is searched for ``"<key>. <answer>"``; the last occurrence is
    compared to the expected value case-insensitively after stripping
    surrounding whitespace.

    Args:
        model: Model whose ``openai_client()`` and ``name`` are used for the
            chat completion request.
        puzzle: Puzzle providing the ``prompt`` text and the ``solution``
            mapping to grade against.

    Returns:
        A trajectory holding the user message and the returned choice, with
        ``reward`` and the ``acc`` metric both set to the fraction of
        solution entries answered correctly.

    Raises:
        AssertionError: If the completion has no string content.
        ZeroDivisionError: If ``puzzle["solution"]`` is empty.
    """
    # NOTE(review): " /no_think" looks like the Qwen3 soft switch for disabling
    # thinking, while the base model configured in this notebook is Qwen2.5 --
    # confirm the suffix is intended for this model.
    messages: art.Messages = [
        {"role": "user", "content": puzzle["prompt"] + " /no_think"}
    ]
    client = model.openai_client()
    chat_completion = await client.chat.completions.create(
        messages=messages, model=model.name
    )
    choice = chat_completion.choices[0]
    content = choice.message.content
    assert isinstance(content, str)

    solution = puzzle["solution"]
    num_correct = 0
    for key, value in solution.items():
        # Escape the key so that labels containing regex metacharacters are
        # matched literally instead of being interpreted as a pattern.
        matches = re.findall(rf"{re.escape(key)}\. ([A-Za-z \.:-]+)", content)
        # Grade only the final occurrence, so an answer restated later in the
        # completion overrides earlier mentions of the same label.
        if matches and matches[-1].strip().lower() == value.lower():
            num_correct += 1

    reward = acc = num_correct / len(solution)
    return art.Trajectory(
        messages_and_choices=[*messages, choice], reward=reward, metrics={"acc": acc}
    )


# Declare the trainable model for this run: run name "005" in the
# "temporal-clue" project, built on the Qwen2.5-7B-Instruct base model.
model = art.TrainableModel(
    name="005",
    project="temporal-clue",
    base_model="Qwen/Qwen2.5-7B-Instruct"
)
# Register the model with a local backend before it is used for rollouts
# and training below. Top-level `await` requires a notebook/IPython kernel.
backend = LocalBackend()
await model.register(backend)

# Number of training puzzles consumed per step.
stride = 8
# Start from the step the model reports (presumably so an interrupted run
# resumes where it stopped -- confirm) and continue up to step 1,000.
for i in range(await model.get_step(), 1_000):
    # Gather validation and training rollouts concurrently.
    val_groups, train_groups = await asyncio.gather(
        art.gather_trajectory_groups(
            (
                # Validation: a single rollout per puzzle over the whole
                # validation split.
                art.TrajectoryGroup(rollout(model, puzzle) for _ in range(1))
                for puzzle in val_puzzles
            ),
            pbar_desc="val",
            pbar_total_completion_tokens=False,
        ),
        art.gather_trajectory_groups(
            (
                # Training: 32 rollouts of the same puzzle form one group.
                art.TrajectoryGroup(rollout(model, puzzle) for _ in range(32))
                # Step i takes the i-th window of `stride` shuffled puzzles.
                # NOTE(review): this slice is empty once
                # i * stride >= len(train_puzzles) -- confirm the dataset is
                # large enough for all 1,000 steps.
                for puzzle in train_puzzles[i * stride : (i + 1) * stride]
            ),
            pbar_desc="train",
            pbar_total_completion_tokens=False,
        ),
    )
    # Log the validation groups collected for this step.
    await model.log(val_groups)
    # NOTE(review): presumably prunes older checkpoints before the next one
    # is written -- confirm which checkpoints are kept.
    await model.delete_checkpoints()
    # One training update on this step's training groups.
    await model.train(
        train_groups,
        config=art.TrainConfig(learning_rate=1e-6),
    )

[34m[1mwandb[0m: Currently logged in as: [33mbradhilton[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


INFO 07-16 17:11:20 [__init__.py:244] Automatically detected platform cuda.



Please restructure your imports with 'import unsloth' at the top of your file.
  import unsloth  # type: ignore # noqa: F401


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 07-16 17:11:39 [__init__.py:244] Automatically detected platform cuda.
==((====))==  Unsloth 2025.6.12: Fast Qwen2 patching. Transformers: 4.53.2. vLLM: 0.9.1.
   \\   /|    NVIDIA H100 PCIe. Num GPUs = 1. Max memory: 79.097 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.3.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/qwen2.5-7b-instruct-unsloth-bnb-4bit with actual GPU utilization = 78.47%
Unsloth: Your GPU has CUDA compute capability 9.0 with VRAM = 79.1 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 32768. Num Sequences = 368.
Unsloth: vLLM's KV Cache can use up to 56.2 GB. 