24 changes: 24 additions & 0 deletions tests/fixtures/completions.json
@@ -0,0 +1,24 @@
{
"completions": [
{
"messages": [
{
"role": "system",
"content": "I am instructions"
},
{
"role": "user",
"content": "I am user message"
}
],
"response": "This is a test response",
"usage": {
"completion_token_count": 222,
"completion_cost_usd": 0.00013319999999999999,
"prompt_token_count": 1230,
"prompt_cost_usd": 0.00018449999999999999,
"model_context_window_size": 1048576
}
}
]
}
16 changes: 0 additions & 16 deletions tests/fixtures/task_example.json

This file was deleted.

11 changes: 11 additions & 0 deletions workflowai/core/client/_models.py
@@ -5,6 +5,7 @@

from workflowai.core._common_types import OutputValidator
from workflowai.core.domain.cache_usage import CacheUsage
from workflowai.core.domain.completion import Completion
from workflowai.core.domain.run import Run
from workflowai.core.domain.task import AgentOutput
from workflowai.core.domain.tool_call import ToolCall as DToolCall
@@ -160,6 +161,7 @@ class CreateAgentResponse(BaseModel):

class ModelMetadata(BaseModel):
"""Metadata for a model."""

provider_name: str = Field(description="Name of the model provider")
price_per_input_token_usd: Optional[float] = Field(None, description="Cost per input token in USD")
price_per_output_token_usd: Optional[float] = Field(None, description="Cost per output token in USD")
@@ -170,6 +172,7 @@ class ModelMetadata(BaseModel):

class ModelInfo(BaseModel):
"""Information about a model."""

id: str = Field(description="Unique identifier for the model")
name: str = Field(description="Display name of the model")
icon_url: Optional[str] = Field(None, description="URL for the model's icon")
@@ -187,11 +190,19 @@ class ModelInfo(BaseModel):

T = TypeVar("T")


class Page(BaseModel, Generic[T]):
"""A generic paginated response."""

items: list[T] = Field(description="List of items in this page")
count: Optional[int] = Field(None, description="Total number of items available")


class ListModelsResponse(Page[ModelInfo]):
"""Response from the list models API endpoint."""


class CompletionsResponse(BaseModel):
"""Response from the completions API endpoint."""

completions: list[Completion]
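
For context, a quick sketch of how this client model deserializes a completions payload (a minimal illustration, assuming Pydantic v2, which the BaseModel + Generic[T] pattern above implies; the sample values are invented, and the Completion model it wraps is added in workflowai/core/domain/completion.py later in this PR):

from workflowai.core.client._models import CompletionsResponse

payload = {
    "completions": [
        {
            "messages": [{"role": "user", "content": "Hello"}],
            "response": "Hi!",
            "usage": {"prompt_token_count": 12, "completion_token_count": 3},
        },
    ],
}

# model_validate is the Pydantic v2 entry point for parsing a dict payload.
parsed = CompletionsResponse.model_validate(payload)
assert parsed.completions[0].response == "Hi!"
assert parsed.completions[0].usage.prompt_token_count == 12
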
17 changes: 17 additions & 0 deletions workflowai/core/client/agent.py
@@ -9,6 +9,7 @@
from workflowai.core._common_types import BaseRunParams, OutputValidator, VersionRunParams
from workflowai.core.client._api import APIClient
from workflowai.core.client._models import (
CompletionsResponse,
CreateAgentRequest,
CreateAgentResponse,
ListModelsResponse,
@@ -24,6 +25,7 @@
intolerant_validator,
tolerant_validator,
)
from workflowai.core.domain.completion import Completion
from workflowai.core.domain.errors import BaseError, WorkflowAIError
from workflowai.core.domain.run import Run
from workflowai.core.domain.task import AgentInput, AgentOutput
@@ -493,3 +495,18 @@ async def list_models(self) -> list[ModelInfo]:
returns=ListModelsResponse,
)
return response.items

async def fetch_completions(self, run_id: str) -> list[Completion]:
"""Fetch the completions for a run.

Args:
run_id (str): The id of the run to fetch completions for.

Returns:
list[Completion]: The completions for the run.
"""
raw = await self.api.get(
f"/v1/_/agents/{self.agent_id}/runs/{run_id}/completions",
returns=CompletionsResponse,
)
return raw.completions
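
For reference, a minimal usage sketch of the new method (assumes an already-constructed Agent instance and a run_id for an existing run; the names my_agent and show_completions are illustrative, and error handling is omitted):

import asyncio

async def show_completions(agent, run_id: str) -> None:
    # Fetch every LLM completion recorded for the given run.
    completions = await agent.fetch_completions(run_id)
    for completion in completions:
        print(completion.response)
        print(completion.usage.prompt_token_count, completion.usage.completion_token_count)

# asyncio.run(show_completions(my_agent, "some-run-id"))  # my_agent is assumed to exist
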
29 changes: 29 additions & 0 deletions workflowai/core/client/agent_test.py
@@ -19,6 +19,7 @@
from workflowai.core.client.client import (
WorkflowAI,
)
from workflowai.core.domain.completion import Completion, CompletionUsage, Message
from workflowai.core.domain.errors import WorkflowAIError
from workflowai.core.domain.run import Run
from workflowai.core.domain.version_properties import VersionProperties
@@ -539,3 +540,31 @@ async def test_list_models_registers_if_needed(
assert models[0].modes == ["chat"]
assert models[0].metadata is not None
assert models[0].metadata.provider_name == "OpenAI"


class TestFetchCompletions:
async def test_fetch_completions(self, agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_mock: HTTPXMock):
"""Test that fetch_completions correctly fetches and returns completions."""
# Mock the HTTP response instead of the API client method
httpx_mock.add_response(
url="http://localhost:8000/v1/_/agents/123/runs/1/completions",
json=fixtures_json("completions.json"),
)

completions = await agent.fetch_completions("1")
assert completions == [
Completion(
messages=[
Message(role="system", content="I am instructions"),
Message(role="user", content="I am user message"),
],
response="This is a test response",
usage=CompletionUsage(
completion_token_count=222,
completion_cost_usd=0.00013319999999999999,
prompt_token_count=1230,
prompt_cost_usd=0.00018449999999999999,
model_context_window_size=1048576,
),
),
]
39 changes: 39 additions & 0 deletions workflowai/core/domain/completion.py
@@ -0,0 +1,39 @@
from typing import Optional

from pydantic import BaseModel, Field


class CompletionUsage(BaseModel):
"""Usage information for a completion."""

completion_token_count: Optional[int] = None
completion_cost_usd: Optional[float] = None
reasoning_token_count: Optional[int] = None
prompt_token_count: Optional[int] = None
prompt_token_count_cached: Optional[int] = None
prompt_cost_usd: Optional[float] = None
prompt_audio_token_count: Optional[int] = None
prompt_audio_duration_seconds: Optional[float] = None
prompt_image_count: Optional[int] = None
model_context_window_size: Optional[int] = None


class Message(BaseModel):
"""A message in a completion."""

role: str = ""
content: str = ""


class Completion(BaseModel):
"""A completion from the model."""

messages: list[Message] = Field(default_factory=list)
response: Optional[str] = None
usage: CompletionUsage = Field(default_factory=CompletionUsage)


class CompletionsResponse(BaseModel):
"""Response from the completions API endpoint."""

completions: list[Completion]
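
As a rough sketch of how these domain models compose (illustrative only; the total_cost_usd helper below is not part of the PR and all sample values are invented):

from workflowai.core.domain.completion import Completion, CompletionUsage, Message

completion = Completion(
    messages=[
        Message(role="system", content="You are a helpful assistant"),
        Message(role="user", content="Hello"),
    ],
    response="Hi there!",
    usage=CompletionUsage(prompt_cost_usd=0.002, completion_cost_usd=0.001),
)

def total_cost_usd(completion: Completion) -> float:
    # Sum prompt and completion costs, treating missing values as zero.
    usage = completion.usage
    return (usage.prompt_cost_usd or 0.0) + (usage.completion_cost_usd or 0.0)

print(f"total cost: ${total_cost_usd(completion):.6f}")
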
20 changes: 20 additions & 0 deletions workflowai/core/domain/run.py
@@ -7,6 +7,7 @@
from workflowai import env
from workflowai.core import _common_types
from workflowai.core.client import _types
from workflowai.core.domain.completion import Completion
from workflowai.core.domain.errors import BaseError
from workflowai.core.domain.task import AgentOutput
from workflowai.core.domain.tool_call import ToolCall, ToolCallRequest, ToolCallResult
@@ -130,6 +131,23 @@ def __str__(self) -> str:
def run_url(self):
return f"{env.WORKFLOWAI_APP_URL}/_/agents/{self.agent_id}/runs/{self.id}"

async def fetch_completions(self) -> list[Completion]:
"""Fetch the completions for this run.

Returns:
list[Completion]: The completions for this run, each with its messages,
response and usage information.

Raises:
ValueError: If the agent is not set or if the run id is not set.
"""
if not self._agent:
raise ValueError("Agent is not set")
if not self.id:
raise ValueError("Run id is not set")

return await self._agent.fetch_completions(self.id)


class _AgentBase(Protocol, Generic[AgentOutput]):
async def reply(
@@ -141,3 +159,5 @@ async def reply(
) -> "Run[AgentOutput]":
"""Reply to a run. Either a user_message or tool_results must be provided."""
...

async def fetch_completions(self, run_id: str) -> list[Completion]: ...
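
A brief usage sketch of the convenience method (assumes run is a Run instance produced by the SDK, so its agent reference and id are already populated; if either is missing, the call raises ValueError before any HTTP request is made; the audit_run name is illustrative):

async def audit_run(run) -> None:
    # Pull the underlying LLM completions for an existing run.
    completions = await run.fetch_completions()
    for completion in completions:
        print(completion.response, completion.usage.model_context_window_size)
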
85 changes: 82 additions & 3 deletions workflowai/core/domain/run_test.py
@@ -3,7 +3,11 @@
import pytest
from pydantic import BaseModel

from workflowai.core.domain.run import Run
from workflowai.core.domain.completion import Completion, CompletionUsage, Message
from workflowai.core.domain.run import (
Run,
_AgentBase, # pyright: ignore [reportPrivateUsage]
)
from workflowai.core.domain.version import Version
from workflowai.core.domain.version_properties import VersionProperties

@@ -13,8 +17,14 @@ class _TestOutput(BaseModel):


@pytest.fixture
def run1() -> Run[_TestOutput]:
return Run[_TestOutput](
def mock_agent() -> Mock:
mock = Mock(spec=_AgentBase)
return mock


@pytest.fixture
def run1(mock_agent: Mock) -> Run[_TestOutput]:
run = Run[_TestOutput](
id="run-id",
agent_id="agent-id",
schema_id=1,
@@ -26,6 +36,8 @@ def run1() -> Run[_TestOutput]:
tool_calls=[],
tool_call_requests=[],
)
run._agent = mock_agent # pyright: ignore [reportPrivateUsage]
return run


@pytest.fixture
@@ -128,3 +140,70 @@ class TestRunURL:
@patch("workflowai.env.WORKFLOWAI_APP_URL", "https://workflowai.hello")
def test_run_url(self, run1: Run[_TestOutput]):
assert run1.run_url == "https://workflowai.hello/_/agents/agent-id/runs/run-id"


class TestFetchCompletions:
"""Tests for the fetch_completions method of the Run class."""

# Test that the underlying agent is called with the proper run id
async def test_fetch_completions_success(self, run1: Run[_TestOutput], mock_agent: Mock):
mock_agent.fetch_completions.return_value = [
Completion(
messages=[
Message(role="system", content="You are a helpful assistant"),
Message(role="user", content="Hello"),
Message(role="assistant", content="Hi there!"),
],
response="Hi there!",
usage=CompletionUsage(
completion_token_count=3,
completion_cost_usd=0.001,
reasoning_token_count=10,
prompt_token_count=20,
prompt_token_count_cached=0,
prompt_cost_usd=0.002,
prompt_audio_token_count=0,
prompt_audio_duration_seconds=0,
prompt_image_count=0,
model_context_window_size=32000,
),
),
]

# Call fetch_completions
completions = await run1.fetch_completions()

# Verify the API was called correctly
mock_agent.fetch_completions.assert_called_once_with("run-id")

# Verify the response
assert len(completions) == 1
completion = completions[0]
assert len(completion.messages) == 3
assert completion.messages[0].role == "system"
assert completion.messages[0].content == "You are a helpful assistant"
assert completion.response == "Hi there!"
assert completion.usage.completion_token_count == 3
assert completion.usage.completion_cost_usd == 0.001

# Test that fetch_completions fails appropriately when the agent is not set:
# 1. This is a common error case that occurs when a Run object is created without an agent
# 2. The method should fail fast with a clear error message before attempting any API calls
# 3. This protects users from confusing errors that would occur if we tried to use the API client
async def test_fetch_completions_no_agent(self, run1: Run[_TestOutput]):
run1._agent = None # pyright: ignore [reportPrivateUsage]
with pytest.raises(ValueError, match="Agent is not set"):
await run1.fetch_completions()

# Test that fetch_completions fails appropriately when the run ID is not set:
# 1. The run ID is required to construct the API endpoint URL
# 2. Without it, we can't make a valid API request
# 3. This validates that we fail fast with a clear error message
# 4. This should never happen in practice (as Run objects always have an ID),
# but we test it for completeness and to ensure robust error handling
async def test_fetch_completions_no_id(self, run1: Run[_TestOutput]):
mock_agent = Mock()
run1._agent = mock_agent # pyright: ignore [reportPrivateUsage]
run1.id = "" # Empty ID
with pytest.raises(ValueError, match="Run id is not set"):
await run1.fetch_completions()