Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/art/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, AsyncIterator, Iterable, Literal, TypedDict, cast

import httpx
from openai import AsyncOpenAI, BaseModel, _exceptions
from openai._base_client import AsyncAPIClient, AsyncPaginator, make_request_options
from openai._compat import cached_property
from openai._qs import Querystring
Expand All @@ -17,8 +18,6 @@
from openai.resources.models import AsyncModels # noqa: F401
from typing_extensions import override

from openai import AsyncOpenAI, BaseModel, _exceptions

from .trajectories import TrajectoryGroup


Expand Down Expand Up @@ -291,7 +290,9 @@ def events(self) -> TrainingJobEvents:

class TrainingJobEvent(BaseModel):
id: str
type: Literal["training_started", "gradient_step", "training_ended"]
type: Literal[
"training_started", "gradient_step", "training_ended", "training_failed"
]
data: dict[str, Any]


Expand Down
3 changes: 2 additions & 1 deletion src/art/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,8 @@ def record_metrics(context: "GatherContext", trajectory: Trajectory) -> None:
if logprobs:
# TODO: probably shouldn't average this
trajectory.metrics["completion_tokens"] = sum(
len(l.content or l.refusal or []) for l in logprobs # noqa: E741
len(l.content or l.refusal or [])
for l in logprobs # noqa: E741
) / len(logprobs)
context.metric_sums["reward"] += trajectory.reward # type: ignore
context.metric_divisors["reward"] += 1
Expand Down
6 changes: 3 additions & 3 deletions src/art/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ def update_chat_completion(
choice.message.tool_calls[tool_call.index].id = tool_call.id
if tool_call.function:
if tool_call.function.name:
choice.message.tool_calls[tool_call.index].function.name = (
tool_call.function.name
)
choice.message.tool_calls[
tool_call.index
].function.name = tool_call.function.name
if tool_call.function.arguments:
choice.message.tool_calls[
tool_call.index
Expand Down
5 changes: 5 additions & 0 deletions src/art/serverless/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ async def _train_model(
continue
elif event.type == "training_ended":
return
elif event.type == "training_failed":
error_message = event.data.get(
"error_message", "Training failed with an unknown error"
)
raise RuntimeError(f"Training job failed: {error_message}")
after = event.id

# ------------------------------------------------------------------
Expand Down
Loading