Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 2 additions & 34 deletions examples/tic_tac_toe/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time

import openai
import weave
from dotenv import load_dotenv
from game_utils import (
apply_agent_move,
Expand All @@ -11,24 +12,18 @@
get_opponent_move,
render_board,
)
from openpipe.client import OpenPipe
from pydantic import BaseModel

import art

load_dotenv()

op_client = OpenPipe()
print("OpenPipe client initialized")


op_client = OpenPipe(api_key=os.getenv("OPENPIPE_API_KEY"))


class TicTacToeScenario(BaseModel):
    """Scenario describing a single tic-tac-toe rollout; passed to ``rollout``."""

    # NOTE(review): presumably the training step this rollout belongs to —
    # confirm against the caller that constructs the scenario.
    step: int


@weave.op
@art.retry(exceptions=(openai.LengthFinishReasonError,))
async def rollout(model: art.Model, scenario: TicTacToeScenario) -> art.Trajectory:
game = generate_game()
Expand Down Expand Up @@ -57,7 +52,6 @@ async def rollout(model: art.Model, scenario: TicTacToeScenario) -> art.Trajecto
{"role": "user", "content": render_board(game)}
)

requested_at = int(time.time() * 1000)
messages = trajectory.messages()

try:
Expand Down Expand Up @@ -110,30 +104,4 @@ async def rollout(model: art.Model, scenario: TicTacToeScenario) -> art.Trajecto
trajectory.metrics["num_moves"] = move_number
trajectory.metrics["invalid_move"] = 1 if invalid_move else 0

if op_client.api_key:
try:
reported_win = (
trajectory.metrics["win"] if "win" in trajectory.metrics else -1
)
op_client.report(
requested_at=requested_at,
received_at=int(time.time() * 1000),
req_payload={
"model": model.name,
"messages": messages,
"metadata": {
"notebook-id": "tic-tac-toe",
"step": str(scenario.step),
"num_moves": str(move_number),
"win": str(reported_win),
"reward": str(trajectory.reward),
"invalid_move": str(invalid_move),
},
},
resp_payload=chat_completion,
status_code=200,
)
except Exception as e:
print(f"Error reporting to OpenPipe: {e}")

return trajectory
4 changes: 4 additions & 0 deletions examples/tic_tac_toe/tic-tac-toe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import os
import random

import weave
from dotenv import load_dotenv
from rollout import TicTacToeScenario, rollout

import art
from art.utils.strip_logprobs import strip_logprobs

load_dotenv()

Expand All @@ -27,6 +29,8 @@
)
args = parser.parse_args()

weave.init("tic-tac-toe", global_postprocess_output=strip_logprobs)


async def main():
# Avoid importing unnecessary backend dependencies
Expand Down
58 changes: 12 additions & 46 deletions examples/tic_tac_toe_self_play/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time

import openai
import weave
from dotenv import load_dotenv
from game_utils import (
TicTacToeGame,
Expand All @@ -13,16 +14,13 @@
unwrap_move,
)
from openai.types.chat import ChatCompletion
from openpipe.client import OpenPipe
from pydantic import BaseModel

import art
from art.guided_completion import get_guided_completion_params

load_dotenv()

op_client = OpenPipe(api_key=os.getenv("OPENPIPE_API_KEY"))


class PlayerState(BaseModel):
trajectory: art.Trajectory
Expand Down Expand Up @@ -110,6 +108,7 @@ class TicTacToeScenario(BaseModel):


@art.retry(exceptions=(openai.LengthFinishReasonError,))
@weave.op
async def rollout(
x_model: art.Model, o_model: art.Model, scenario: TicTacToeScenario
) -> list[art.Trajectory]:
Expand Down Expand Up @@ -204,48 +203,15 @@ async def rollout(
1 if player_state.invalid_move else 0
)

if op_client.api_key:
for symbol in ["x", "o"]:
player_state = player_states[symbol]
trajectory = player_state.trajectory
messages = trajectory.messages()
# avoid double-reporting the last assistant completion message
if messages[-1]["role"] == "assistant":
messages = messages[:-1]

model = x_model if symbol == "x" else o_model
teacher = scenario.x_teacher if symbol == "x" else scenario.o_teacher
try:
reported_win = (
trajectory.metrics["win"] if "win" in trajectory.metrics else -1
)
op_client.report(
requested_at=start_time,
received_at=int(time.time() * 1000),
req_payload={
"model": model.name,
"messages": messages,
"metadata": {
"project": "tic-tac-toe",
"split": scenario.split,
"step": str(scenario.step),
"num_moves": str(move_number),
"win": str(reported_win),
"reward": str(trajectory.reward),
"invalid_move": str(player_state.invalid_move),
"symbol": symbol,
"teacher": teacher.name if teacher else "",
"initial_move": (
unwrap_move(scenario.initial_move)
if scenario.initial_move
else ""
),
},
},
resp_payload=player_state.last_completion,
status_code=200,
)
except Exception as e:
print(f"Error reporting to OpenPipe: {e}")
for symbol in ["x", "o"]:
player_state = player_states[symbol]
trajectory = player_state.trajectory
messages = trajectory.messages()
# avoid double-reporting the last assistant completion message
if messages[-1]["role"] == "assistant":
messages = messages[:-1]

model = x_model if symbol == "x" else o_model
teacher = scenario.x_teacher if symbol == "x" else scenario.o_teacher

return player_states["x"].trajectory, player_states["o"].trajectory
4 changes: 4 additions & 0 deletions examples/tic_tac_toe_self_play/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import os
import random

import weave
from dotenv import load_dotenv
from game_utils import possible_moves
from gather_trajectory_groups_by_index import gather_trajectory_groups_by_index
from rollout import ModelConfig, TicTacToeScenario, rollout

import art
from art.utils.strip_logprobs import strip_logprobs

load_dotenv()

Expand All @@ -23,6 +25,8 @@
BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
MODEL_NAME = "llama-8b-student-001"

weave.init("tic-tac-toe", global_postprocess_output=strip_logprobs)


async def main():
parser = argparse.ArgumentParser(description="Train a model to play Tic-Tac-Toe")
Expand Down
25 changes: 12 additions & 13 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,18 @@ sources = ["src"]

[tool.hatch.build.targets.sdist]
exclude = [
"/dev",
"/wandb",
"/.art",
"/.ruff_cache",
"/.venv",
"/dist",
"/.git",
"/.github",
"/examples/*/data",
"/examples/*/wandb",
"**/__pycache__",
"**/*.pyc",
"/dev",
"/wandb",
"/.art",
"/.ruff_cache",
"/.venv",
"/dist",
"/.git",
"/.github",
"/examples/*/data",
"/examples/*/wandb",
"**/__pycache__",
"**/*.pyc",
]

[tool.ruff.lint]
Expand All @@ -99,7 +99,6 @@ dev-dependencies = [
"black>=25.1.0",
"ipykernel>=6.29.5",
"ipywidgets>=8.1.5",
"openpipe>=4.49.0",
"hatch>=1.14.1",
"ruff>=0.12.1",
"pytest>=8.4.1",
Expand Down
Loading