In [1]:
# Enable IPython's autoreload extension; `%autoreload 2` picks up edits to
# imported modules live, so local library changes apply without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
%%html
<style>
/* Force the ipywidget output background to be transparent. */
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
/* Point the Jupyter widget color and font-size variables at the
   corresponding VS Code editor theme variables. */
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [None]:
import art
import asyncio
from dotenv import load_dotenv
import json
import openai
import random
import re
from typing import TypedDict
from openpipe.client import OpenPipe
import hashlib
import time

# Load settings via python-dotenv before any client is constructed
# (presumably this is where the OpenPipe credentials come from -- confirm).
load_dotenv()

# Module-level OpenPipe client; rollout() calls op_client.report() on it
# for every chat completion.
op_client = OpenPipe()
print("OpenPipe client initialized")


class TemporalCluePuzzle(TypedDict):
    """Shape of one puzzle record loaded from the puzzles JSON file."""

    # Presumably the number of clues in the puzzle (confirm against the data);
    # generate_puzzle_key() uses it as the prefix of the puzzle key.
    num_clues: int
    # Puzzle text; rollout() sends it to the model as the single user message.
    prompt: str
    # Expected answers. rollout() searches the reply for "<key>. <answer>" and
    # compares the answer to the value case-insensitively.
    solution: dict[str, str]

def generate_puzzle_key(puzzle: TemporalCluePuzzle) -> str:
    """Return a short identifier for a puzzle.

    The key has the form ``"<num_clues>-<digest>"``, where ``<digest>`` is the
    first 10 hex characters of the SHA-256 hash of the puzzle prompt.
    """
    clue_count = str(puzzle["num_clues"])
    prompt_digest = hashlib.sha256(puzzle["prompt"].encode()).hexdigest()
    return "-".join((clue_count, prompt_digest[:10]))


# Load the puzzle set. The context manager closes the file handle as soon as
# parsing finishes (a bare open() left that to the garbage collector), and the
# explicit encoding keeps decoding independent of the machine's locale.
with open("./data/temporal-clue/puzzles.json", encoding="utf-8") as puzzles_file:
    puzzles: list[TemporalCluePuzzle] = json.load(puzzles_file)

# Fixed positional split: first 64 puzzles for validation, next 64 for test,
# everything after that for training.
val_puzzles = puzzles[:64]
test_puzzles = puzzles[64:128]
train_puzzles = puzzles[128:]

# Seeded so the training order is the same on every run. The slice above is a
# new list, so shuffling it in place leaves `puzzles` itself untouched.
random.seed(42)
random.shuffle(train_puzzles)


# ART API object, configured with the W&B project name
# "agent-reinforcement-training".
api = art.LocalAPI(wandb_project="agent-reinforcement-training")
# Per the method name, reuses the model registered under this name if there is
# one and otherwise creates it from the given base model -- confirm against the
# art library. Uses top-level `await`, so this cell needs a notebook / async REPL.
model = await api.get_or_create_model(
    name="temporal-clue-001", base_model="NousResearch/Hermes-2-Theta-Llama-3-8B"
)


async def rollout(
    client: openai.AsyncOpenAI, puzzle: TemporalCluePuzzle, iteration: int, is_validation: bool
) -> art.Trajectory:
    """Run one attempt at a puzzle and score the model's answer.

    Sends the puzzle prompt as a single user message, reports the
    request/response pair to OpenPipe, and scores the reply by the fraction of
    solution entries the model answered correctly.

    Args:
        client: Async OpenAI-compatible client used for the chat completion.
        puzzle: Puzzle to attempt; ``prompt`` is sent to the model and
            ``solution`` is what the reply is scored against.
        iteration: Training iteration, recorded as a string in the report metadata.
        is_validation: Whether this is a validation rollout; recorded as a
            string in the report metadata.

    Returns:
        A trajectory holding the prompt message followed by the sampled choice,
        with ``reward`` and the ``acc`` metric both equal to the fraction of
        correct answers (0.0 when the puzzle has no solution entries).

    Raises:
        AssertionError: If the completion's message content is not a string.
    """
    messages: art.Messages = [{"role": "user", "content": puzzle["prompt"]}]

    requested_at = int(time.time() * 1000)
    chat_completion = await client.chat.completions.create(
        messages=messages, model=model.name
    )
    # Timestamp the response right away so hashing/printing below does not
    # inflate the reported latency.
    received_at = int(time.time() * 1000)

    puzzle_key = generate_puzzle_key(puzzle)
    # NOTE(review): this prints one line per rollout and floods the cell
    # output; the key is already sent in the report metadata below.
    print(f"puzzle_key: {puzzle_key}")
    # NOTE(review): report() is called synchronously inside this coroutine, so
    # other rollouts wait while it runs -- consider offloading if it is slow.
    op_client.report(
        requested_at=requested_at,
        received_at=received_at,
        req_payload={
            "model": model.name,
            "messages": messages,
            "metadata": {
                "notebook-id": "temporal-clue",
                "iteration": str(iteration),
                "validation": str(is_validation),
                "puzzle_key": puzzle_key,
            },
        },
        resp_payload=chat_completion,
        status_code=200,
    )

    choice = chat_completion.choices[0]
    content = choice.message.content
    assert isinstance(content, str)

    solution = puzzle["solution"]
    num_correct = 0
    for key, value in solution.items():
        # Escape the key so it is matched literally even if it contains regex
        # metacharacters such as "(", "?" or ".".
        pattern = rf"{re.escape(key)}\. ([A-Za-z \.:-]+)"
        if matches := re.findall(pattern, content):
            # The model may restate an answer while reasoning; only the last
            # occurrence counts as its final answer.
            final_answer = matches[-1]
            if final_answer.strip().lower() == value.lower():
                num_correct += 1

    # A puzzle without solution entries scores 0.0 instead of raising
    # ZeroDivisionError and aborting the whole gather.
    reward = acc = num_correct / len(solution) if solution else 0.0
    return art.Trajectory(
        messages_and_choices=[*messages, choice], reward=reward, metrics={"acc": acc}
    )


stride = 32
for i in range(await model.get_iteration(), 10):
    async with model.openai_client(
        estimated_completion_tokens=350, verbosity=2
    ) as openai_client:
        val_groups, train_groups = await asyncio.gather(
            art.gather_groups(
                (
                    (rollout(openai_client, puzzle, i, is_validation=True) for _ in range(2))
                    for puzzle in val_puzzles
                ),
                pbar_desc="val",
                stream_chat_completions=8,
            ),
            art.gather_groups(
                (
                    (rollout(openai_client, puzzle, i, is_validation=False) for _ in range(50))
                    for puzzle in train_puzzles[i * stride : (i + 1) * stride]
                ),
                pbar_desc="train",
            ),
        )
    await model.log(val_groups)
    await model.clear_iterations()
    await model.tune(
        train_groups, config=art.TuneConfig(plot_tensors=True, verbosity=2)
    )

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
OpenPipe client initialized
$ vllm serve /root/sky_workdir/examples/.art/models/temporal-clue-001/0003 --block-size=32 --disable-log-requests --enable-chunked-prefill --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.95 --max-num-seqs=2048 --max-num-batched-tokens=16384 --num-scheduler-steps=8 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=80 --tensor-parallel-size=1 --tool-call-parser=hermes --served-model-name=temporal-clue-001 --port=8000 --api-key=default


Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:02,  1.42it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.24it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:02<00:00,  1.21it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.62it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.46it/s]

INFO:     Started server process [23754]
INFO:     Waiting for application startup.
INFO:     Application startup complete.


INFO:     127.0.0.1:39558 - "POST /v1/chat/completions HTTP/1.1" 200 OK


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


vLLM server started succesfully. Logs can be found at ./logs/vllm.log


[34m[1mwandb[0m: Currently logged in as: [33marctic_fly[0m ([33mbased-op[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Using previous iteration 94.875 completion tokens per request as estimate


val:   0%|          | 0/128 [00:00<?, ?it/s]

train:   0%|          | 0/1600 [00:00<?, ?it/s]

puzzle_key: 47-65a15b6933
puzzle_key: 25-1d0a0a995c
puzzle_key: 25-1d0a0a995c
puzzle_key: 14-6b84a74264
puzzle_key: 97-ab86d0715a
puzzle_key: 38-6f5312aa1c
puzzle_key: 97-ab86d0715a
puzzle_key: 25-1d0a0a995c
puzzle_key: 38-6f5312aa1c
puzzle_key: 47-65a15b6933
puzzle_key: 38-6f5312aa1c
puzzle_key: 25-1d0a0a995c
puzzle_key: 97-ab86d0715a
puzzle_key: 38-6f5312aa1c
puzzle_key: 68-4274330eb1
puzzle_key: 28-c2ca749b9a
puzzle_key: 38-6f5312aa1c
puzzle_key: 38-6f5312aa1c
puzzle_key: 38-6f5312aa1c
puzzle_key: 47-65a15b6933
puzzle_key: 97-ab86d0715a
puzzle_key: 47-65a15b6933
puzzle_key: 38-6f5312aa1c
puzzle_key: 47-65a15b6933
puzzle_key: 97-ab86d0715a
puzzle_key: 38-6f5312aa1c
puzzle_key: 14-6b84a74264
puzzle_key: 38-6f5312aa1c
puzzle_key: 25-1d0a0a995c
puzzle_key: 97-ab86d0715a
puzzle_key: 38-6f5312aa1c
puzzle_key: 47-65a15b6933
puzzle_key: 47-65a15b6933
puzzle_key: 38-6f5312aa1c
puzzle_key: 28-c2ca749b9a
puzzle_key: 97-ab86d0715a
puzzle_key: 97-ab86d0715a
puzzle_key: 97-ab86d0715a
puzzle_key: 