diff --git a/posthog/ai/anthropic/anthropic_async.py b/posthog/ai/anthropic/anthropic_async.py index 9b02e35c..bd0fd86f 100644 --- a/posthog/ai/anthropic/anthropic_async.py +++ b/posthog/ai/anthropic/anthropic_async.py @@ -8,6 +8,7 @@ import time import uuid +from posthog.ai.stream import AsyncStreamWrapper from typing import Any, Dict, List, Optional from posthog import setup @@ -225,7 +226,7 @@ async def generator(): stop_reason=stop_reason, ) - return generator() + return AsyncStreamWrapper(generator()) async def _capture_streaming_event( self, diff --git a/posthog/ai/openai/openai_async.py b/posthog/ai/openai/openai_async.py index cb25e138..4190a210 100644 --- a/posthog/ai/openai/openai_async.py +++ b/posthog/ai/openai/openai_async.py @@ -1,8 +1,9 @@ import time import uuid -from typing import Any, Dict, List, Optional +from typing import Any, AsyncIterator, Dict, List, Optional from posthog.ai.types import TokenUsage +from posthog.ai.stream import AsyncStreamWrapper try: import openai @@ -206,7 +207,7 @@ async def async_generator(): stop_reason=stop_reason, ) - return async_generator() + return AsyncStreamWrapper(async_generator()) async def _capture_streaming_event( self, @@ -486,7 +487,7 @@ async def async_generator(): stop_reason=stop_reason, ) - return async_generator() + return AsyncStreamWrapper(async_generator()) async def _capture_streaming_event( self, diff --git a/posthog/ai/stream.py b/posthog/ai/stream.py new file mode 100644 index 00000000..148f0920 --- /dev/null +++ b/posthog/ai/stream.py @@ -0,0 +1,61 @@ +"""Shared async streaming utilities for PostHog AI wrappers.""" + +from typing import Any, AsyncGenerator, TypeVar + +T = TypeVar("T") + + +class AsyncStreamWrapper: + """Wraps an async generator so it also implements the async context manager protocol. + + The OpenAI and Anthropic SDKs return stream objects that support both + ``async for`` iteration **and** ``async with`` (i.e. they are both async + iterators and async context managers). PostHog's streaming wrappers + previously returned a bare async generator, which only supports ``async + for``. Libraries such as pydantic-ai call ``async with response:`` before + iterating, causing:: + + TypeError: 'async_generator' object does not support the + asynchronous context manager protocol + + This class wraps the underlying async generator and adds the missing + ``__aenter__`` / ``__aexit__`` methods. On ``__aexit__`` the generator is + closed so that the ``finally`` block inside the generator (which fires the + PostHog usage event) always executes, even when the caller breaks out of + the loop early. + """ + + def __init__(self, generator: AsyncGenerator[T, None]) -> None: + self._generator = generator + + # ------------------------------------------------------------------ # + # Async iterator protocol # + # ------------------------------------------------------------------ # + + def __aiter__(self) -> "AsyncStreamWrapper": + return self + + async def __anext__(self) -> T: + return await self._generator.__anext__() + + # ------------------------------------------------------------------ # + # Async context manager protocol # + # ------------------------------------------------------------------ # + + async def __aenter__(self) -> "AsyncStreamWrapper": + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + # Close the generator so the finally block (PostHog event capture) runs + # even on early exit. If the generator is already exhausted this is a + # no-op. + await self._generator.aclose() + return False + + # ------------------------------------------------------------------ # + # Attribute proxy – forward any other attribute access to the # + # underlying generator (e.g. .response on an Anthropic stream). # + # ------------------------------------------------------------------ # + + def __getattr__(self, name: str) -> Any: + return getattr(self._generator, name) diff --git a/posthog/test/test_async_stream_wrapper.py b/posthog/test/test_async_stream_wrapper.py new file mode 100644 index 00000000..c34a11bc --- /dev/null +++ b/posthog/test/test_async_stream_wrapper.py @@ -0,0 +1,127 @@ +"""Regression tests for AsyncStreamWrapper. + +Ensures that PostHog AI streaming wrappers return objects that support both +the async iterator protocol (``async for``) and the async context manager +protocol (``async with``), as required by libraries such as pydantic-ai. + +Issue: https://github.com/PostHog/posthog-python/issues/393 +""" + +import pytest + +from posthog.ai.stream import AsyncStreamWrapper + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +async def _make_gen(items): + """Simple async generator that yields the given items.""" + for item in items: + yield item + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_async_for_iteration(): + """AsyncStreamWrapper must yield all items when used with ``async for``.""" + wrapper = AsyncStreamWrapper(_make_gen([1, 2, 3])) + result = [] + async for item in wrapper: + result.append(item) + assert result == [1, 2, 3] + + +@pytest.mark.asyncio +async def test_async_context_manager_protocol(): + """AsyncStreamWrapper must support ``async with`` without raising TypeError.""" + wrapper = AsyncStreamWrapper(_make_gen(["a", "b"])) + + # This is the call pattern that pydantic-ai uses and that previously raised: + # TypeError: 'async_generator' object does not support the asynchronous + # context manager protocol + async with wrapper as stream: + result = [] + async for chunk in stream: + result.append(chunk) + + assert result == ["a", "b"] + + +@pytest.mark.asyncio +async def test_context_manager_returns_self(): + """``async with wrapper as w`` should bind the wrapper itself.""" + wrapper = AsyncStreamWrapper(_make_gen([])) + async with wrapper as w: + assert w is wrapper + + +@pytest.mark.asyncio +async def test_finally_block_runs_on_early_exit(): + """The underlying generator's finally block must run even when the caller + breaks out of the loop early (i.e. doesn't fully exhaust the generator).""" + finally_ran = [] + + async def gen_with_finally(): + try: + for i in range(10): + yield i + finally: + finally_ran.append(True) + + wrapper = AsyncStreamWrapper(gen_with_finally()) + async with wrapper as stream: + async for chunk in stream: + if chunk == 2: + break # early exit + + # __aexit__ must have called aclose(), triggering the finally block + assert finally_ran == [True], "finally block in generator did not run on early exit" + + +@pytest.mark.asyncio +async def test_finally_block_runs_on_full_exhaustion(): + """The underlying generator's finally block must also run on normal + exhaustion (``aclose()`` on an exhausted generator is a no-op).""" + finally_ran = [] + + async def gen_with_finally(): + try: + yield 1 + yield 2 + finally: + finally_ran.append(True) + + wrapper = AsyncStreamWrapper(gen_with_finally()) + async with wrapper as stream: + async for _ in stream: + pass + + assert finally_ran == [True] + + +@pytest.mark.asyncio +async def test_attribute_proxy(): + """Attributes not on AsyncStreamWrapper itself should be forwarded to the + underlying generator (for provider-specific metadata access).""" + + class FakeStream: + extra_attr = "hello" + + def __aiter__(self): + return self + + async def __anext__(self): + raise StopAsyncIteration + + async def aclose(self): + pass + + wrapper = AsyncStreamWrapper(FakeStream()) # type: ignore[arg-type] + assert wrapper.extra_attr == "hello"