Skip to content

Commit

Permalink
chore(tests): ensure messages.create() and messages.stream() stay in …
Browse files Browse the repository at this point in the history
…sync
  • Loading branch information
RobertCraigie committed May 30, 2024
1 parent 02d482c commit 52bd67b
Showing 1 changed file with 36 additions and 3 deletions.
39 changes: 36 additions & 3 deletions tests/lib/streaming/test_messages.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import os
import inspect
from typing import Any, TypeVar, cast
from typing_extensions import Iterator, AsyncIterator, override

Expand All @@ -16,7 +17,7 @@
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
api_key = "my-anthropic-api-key"

client = Anthropic(base_url=base_url, api_key=api_key, _strict_response_validation=True)
sync_client = Anthropic(base_url=base_url, api_key=api_key, _strict_response_validation=True)
async_client = AsyncAnthropic(base_url=base_url, api_key=api_key, _strict_response_validation=True)

_T = TypeVar("_T")
Expand Down Expand Up @@ -113,7 +114,7 @@ class TestSyncMessages:
def test_basic_response(self, respx_mock: MockRouter) -> None:
respx_mock.post("/v1/messages").mock(return_value=httpx.Response(200, content=basic_response()))

with client.messages.stream(
with sync_client.messages.stream(
max_tokens=1024,
messages=[
{
Expand All @@ -133,7 +134,7 @@ def test_basic_response(self, respx_mock: MockRouter) -> None:
def test_context_manager(self, respx_mock: MockRouter) -> None:
respx_mock.post("/v1/messages").mock(return_value=httpx.Response(200, content=basic_response()))

with client.messages.stream(
with sync_client.messages.stream(
max_tokens=1024,
messages=[
{
Expand Down Expand Up @@ -190,3 +191,35 @@ async def test_context_manager(self, respx_mock: MockRouter) -> None:

# response should be closed even if the body isn't read
assert stream.response.is_closed


@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
def test_stream_method_definition_in_sync(sync: bool) -> None:
client: Anthropic | AsyncAnthropic = sync_client if sync else async_client

sig = inspect.signature(client.messages.stream)
generated_sig = inspect.signature(client.messages.create)

errors: list[str] = []

for name, generated_param in generated_sig.parameters.items():
if name == "stream":
# intentionally excluded
continue

custom_param = sig.parameters.get(name)
if not custom_param:
errors.append(f"the `{name}` param is missing")
continue

if custom_param.annotation != generated_param.annotation:
errors.append(
f"types for the `{name}` param are do not match; generated={repr(generated_param.annotation)} custom={repr(generated_param.annotation)}"
)
continue

if errors:
raise AssertionError(
f"{len(errors)} errors encountered with the {'sync' if sync else 'async'} client `messages.stream()` method:\n\n"
+ "\n\n".join(errors)
)

0 comments on commit 52bd67b

Please sign in to comment.