From 96ed062368cbf89b7007f52d67b89cc1572370c8 Mon Sep 17 00:00:00 2001
From: Radu Raicea
Date: Wed, 26 Nov 2025 11:31:40 -0500
Subject: [PATCH] feat(llma): add Gemini async

---
 CHANGELOG.md                                |   4 +
 posthog/ai/gemini/__init__.py               |   3 +
 posthog/ai/gemini/gemini.py                 |   2 +-
 posthog/ai/gemini/gemini_async.py           | 423 ++++++++++
 posthog/test/ai/gemini/test_gemini_async.py | 843 ++++++++++++++++++++
 posthog/version.py                          |   2 +-
 6 files changed, 1275 insertions(+), 2 deletions(-)
 create mode 100644 posthog/ai/gemini/gemini_async.py
 create mode 100644 posthog/test/ai/gemini/test_gemini_async.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2f00a91e..96aaa67d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,7 @@
+# 7.1.0 - 2025-11-26
+
+Add support for the async version of the Gemini client.
+
 # 7.0.2 - 2025-11-18
 
 Add support for Python 3.14.
diff --git a/posthog/ai/gemini/__init__.py b/posthog/ai/gemini/__init__.py
index eb17989d..c250b8b6 100644
--- a/posthog/ai/gemini/__init__.py
+++ b/posthog/ai/gemini/__init__.py
@@ -1,4 +1,5 @@
 from .gemini import Client
+from .gemini_async import AsyncClient
 from .gemini_converter import (
     format_gemini_input,
     format_gemini_response,
@@ -9,12 +10,14 @@
 # Create a genai-like module for perfect drop-in replacement
 class _GenAI:
     Client = Client
+    AsyncClient = AsyncClient
 
 
 genai = _GenAI()
 
 __all__ = [
     "Client",
+    "AsyncClient",
     "genai",
     "format_gemini_input",
     "format_gemini_response",
diff --git a/posthog/ai/gemini/gemini.py b/posthog/ai/gemini/gemini.py
index 2b3eeb94..41ec6619 100644
--- a/posthog/ai/gemini/gemini.py
+++ b/posthog/ai/gemini/gemini.py
@@ -304,7 +304,7 @@ def _generate_content_streaming(
 
         def generator():
             nonlocal usage_stats
-            nonlocal accumulated_content  # noqa: F824
+            nonlocal accumulated_content
             try:
                 for chunk in response:
                     # Extract usage stats from chunk
diff --git a/posthog/ai/gemini/gemini_async.py b/posthog/ai/gemini/gemini_async.py
new file mode 100644
index 00000000..60acc006
--- /dev/null
+++ b/posthog/ai/gemini/gemini_async.py
@@ -0,0 +1,423 @@
+import os
+import time
+import uuid
+from typing import Any, Dict, Optional
+
+from posthog.ai.types import TokenUsage, StreamingEventData
+from posthog.ai.utils import merge_system_prompt
+
+try:
+    from google import genai
+except ImportError:
+    raise ModuleNotFoundError(
+        "Please install the Google Gemini SDK to use this feature: 'pip install google-genai'"
+    )
+
+from posthog import setup
+from posthog.ai.utils import (
+    call_llm_and_track_usage_async,
+    capture_streaming_event,
+    merge_usage_stats,
+)
+from posthog.ai.gemini.gemini_converter import (
+    extract_gemini_usage_from_chunk,
+    extract_gemini_content_from_chunk,
+    format_gemini_streaming_output,
+)
+from posthog.ai.sanitization import sanitize_gemini
+from posthog.client import Client as PostHogClient
+
+
+class AsyncClient:
+    """
+    An async drop-in replacement for genai.Client that automatically sends LLM usage events to PostHog.
+ + Usage: + client = AsyncClient( + api_key="your_api_key", + posthog_client=posthog_client, + posthog_distinct_id="default_user", # Optional defaults + posthog_properties={"team": "ai"} # Optional defaults + ) + response = await client.models.generate_content( + model="gemini-2.0-flash", + contents=["Hello world"], + posthog_distinct_id="specific_user" # Override default + ) + """ + + _ph_client: PostHogClient + + def __init__( + self, + api_key: Optional[str] = None, + vertexai: Optional[bool] = None, + credentials: Optional[Any] = None, + project: Optional[str] = None, + location: Optional[str] = None, + debug_config: Optional[Any] = None, + http_options: Optional[Any] = None, + posthog_client: Optional[PostHogClient] = None, + posthog_distinct_id: Optional[str] = None, + posthog_properties: Optional[Dict[str, Any]] = None, + posthog_privacy_mode: bool = False, + posthog_groups: Optional[Dict[str, Any]] = None, + **kwargs, + ): + """ + Args: + api_key: Google AI API key. If not provided, will use GOOGLE_API_KEY or API_KEY environment variable (not required for Vertex AI) + vertexai: Whether to use Vertex AI authentication + credentials: Vertex AI credentials object + project: GCP project ID for Vertex AI + location: GCP location for Vertex AI + debug_config: Debug configuration for the client + http_options: HTTP options for the client + posthog_client: PostHog client for tracking usage + posthog_distinct_id: Default distinct ID for all calls (can be overridden per call) + posthog_properties: Default properties for all calls (can be overridden per call) + posthog_privacy_mode: Default privacy mode for all calls (can be overridden per call) + posthog_groups: Default groups for all calls (can be overridden per call) + **kwargs: Additional arguments (for future compatibility) + """ + + self._ph_client = posthog_client or setup() + + if self._ph_client is None: + raise ValueError("posthog_client is required for PostHog tracking") + + self.models = AsyncModels( + api_key=api_key, + vertexai=vertexai, + credentials=credentials, + project=project, + location=location, + debug_config=debug_config, + http_options=http_options, + posthog_client=self._ph_client, + posthog_distinct_id=posthog_distinct_id, + posthog_properties=posthog_properties, + posthog_privacy_mode=posthog_privacy_mode, + posthog_groups=posthog_groups, + **kwargs, + ) + + +class AsyncModels: + """ + Async Models interface that mimics genai.Client().aio.models with PostHog tracking. + """ + + _ph_client: PostHogClient # Not None after __init__ validation + + def __init__( + self, + api_key: Optional[str] = None, + vertexai: Optional[bool] = None, + credentials: Optional[Any] = None, + project: Optional[str] = None, + location: Optional[str] = None, + debug_config: Optional[Any] = None, + http_options: Optional[Any] = None, + posthog_client: Optional[PostHogClient] = None, + posthog_distinct_id: Optional[str] = None, + posthog_properties: Optional[Dict[str, Any]] = None, + posthog_privacy_mode: bool = False, + posthog_groups: Optional[Dict[str, Any]] = None, + **kwargs, + ): + """ + Args: + api_key: Google AI API key. 
If not provided, will use GOOGLE_API_KEY or API_KEY environment variable (not required for Vertex AI) + vertexai: Whether to use Vertex AI authentication + credentials: Vertex AI credentials object + project: GCP project ID for Vertex AI + location: GCP location for Vertex AI + debug_config: Debug configuration for the client + http_options: HTTP options for the client + posthog_client: PostHog client for tracking usage + posthog_distinct_id: Default distinct ID for all calls + posthog_properties: Default properties for all calls + posthog_privacy_mode: Default privacy mode for all calls + posthog_groups: Default groups for all calls + **kwargs: Additional arguments (for future compatibility) + """ + + self._ph_client = posthog_client or setup() + + if self._ph_client is None: + raise ValueError("posthog_client is required for PostHog tracking") + + # Store default PostHog settings + self._default_distinct_id = posthog_distinct_id + self._default_properties = posthog_properties or {} + self._default_privacy_mode = posthog_privacy_mode + self._default_groups = posthog_groups + + # Build genai.Client arguments + client_args: Dict[str, Any] = {} + + # Add Vertex AI parameters if provided + if vertexai is not None: + client_args["vertexai"] = vertexai + + if credentials is not None: + client_args["credentials"] = credentials + + if project is not None: + client_args["project"] = project + + if location is not None: + client_args["location"] = location + + if debug_config is not None: + client_args["debug_config"] = debug_config + + if http_options is not None: + client_args["http_options"] = http_options + + # Handle API key authentication + if vertexai: + # For Vertex AI, api_key is optional + if api_key is not None: + client_args["api_key"] = api_key + else: + # For non-Vertex AI mode, api_key is required (backwards compatibility) + if api_key is None: + api_key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("API_KEY") + + if api_key is None: + raise ValueError( + "API key must be provided either as parameter or via GOOGLE_API_KEY/API_KEY environment variable" + ) + + client_args["api_key"] = api_key + + self._client = genai.Client(**client_args) + self._base_url = "https://generativelanguage.googleapis.com" + + def _merge_posthog_params( + self, + call_distinct_id: Optional[str], + call_trace_id: Optional[str], + call_properties: Optional[Dict[str, Any]], + call_privacy_mode: Optional[bool], + call_groups: Optional[Dict[str, Any]], + ): + """Merge call-level PostHog parameters with client defaults.""" + + # Use call-level values if provided, otherwise fall back to defaults + distinct_id = ( + call_distinct_id + if call_distinct_id is not None + else self._default_distinct_id + ) + privacy_mode = ( + call_privacy_mode + if call_privacy_mode is not None + else self._default_privacy_mode + ) + groups = call_groups if call_groups is not None else self._default_groups + + # Merge properties: default properties + call properties (call properties override) + properties = dict(self._default_properties) + + if call_properties: + properties.update(call_properties) + + if call_trace_id is None: + call_trace_id = str(uuid.uuid4()) + + return distinct_id, call_trace_id, properties, privacy_mode, groups + + async def generate_content( + self, + model: str, + contents, + posthog_distinct_id: Optional[str] = None, + posthog_trace_id: Optional[str] = None, + posthog_properties: Optional[Dict[str, Any]] = None, + posthog_privacy_mode: Optional[bool] = None, + posthog_groups: Optional[Dict[str, Any]] = 
None, + **kwargs: Any, + ): + """ + Generate content using Gemini's API while tracking usage in PostHog. + + This method signature exactly matches genai.Client().aio.models.generate_content() + with additional PostHog tracking parameters. + + Args: + model: The model to use (e.g., 'gemini-2.0-flash') + contents: The input content for generation + posthog_distinct_id: ID to associate with the usage event (overrides client default) + posthog_trace_id: Trace UUID for linking events (auto-generated if not provided) + posthog_properties: Extra properties to include in the event (merged with client defaults) + posthog_privacy_mode: Whether to redact sensitive information (overrides client default) + posthog_groups: Group analytics properties (overrides client default) + **kwargs: Arguments passed to Gemini's generate_content + """ + + # Merge PostHog parameters + distinct_id, trace_id, properties, privacy_mode, groups = ( + self._merge_posthog_params( + posthog_distinct_id, + posthog_trace_id, + posthog_properties, + posthog_privacy_mode, + posthog_groups, + ) + ) + + kwargs_with_contents = {"model": model, "contents": contents, **kwargs} + + return await call_llm_and_track_usage_async( + distinct_id, + self._ph_client, + "gemini", + trace_id, + properties, + privacy_mode, + groups, + self._base_url, + self._client.aio.models.generate_content, + **kwargs_with_contents, + ) + + async def _generate_content_streaming( + self, + model: str, + contents, + distinct_id: Optional[str], + trace_id: Optional[str], + properties: Optional[Dict[str, Any]], + privacy_mode: bool, + groups: Optional[Dict[str, Any]], + **kwargs: Any, + ): + start_time = time.time() + usage_stats: TokenUsage = TokenUsage(input_tokens=0, output_tokens=0) + accumulated_content = [] + + kwargs_without_stream = {"model": model, "contents": contents, **kwargs} + response = await self._client.aio.models.generate_content_stream( + **kwargs_without_stream + ) + + async def async_generator(): + nonlocal usage_stats + nonlocal accumulated_content + + try: + async for chunk in response: + # Extract usage stats from chunk + chunk_usage = extract_gemini_usage_from_chunk(chunk) + + if chunk_usage: + # Gemini reports cumulative totals, not incremental values + merge_usage_stats(usage_stats, chunk_usage, mode="cumulative") + + # Extract content from chunk (now returns content blocks) + content_block = extract_gemini_content_from_chunk(chunk) + + if content_block is not None: + accumulated_content.append(content_block) + + yield chunk + + finally: + end_time = time.time() + latency = end_time - start_time + + self._capture_streaming_event( + model, + contents, + distinct_id, + trace_id, + properties, + privacy_mode, + groups, + kwargs, + usage_stats, + latency, + accumulated_content, + ) + + return async_generator() + + def _capture_streaming_event( + self, + model: str, + contents, + distinct_id: Optional[str], + trace_id: Optional[str], + properties: Optional[Dict[str, Any]], + privacy_mode: bool, + groups: Optional[Dict[str, Any]], + kwargs: Dict[str, Any], + usage_stats: TokenUsage, + latency: float, + output: Any, + ): + # Prepare standardized event data + formatted_input = self._format_input(contents, **kwargs) + sanitized_input = sanitize_gemini(formatted_input) + + event_data = StreamingEventData( + provider="gemini", + model=model, + base_url=self._base_url, + kwargs=kwargs, + formatted_input=sanitized_input, + formatted_output=format_gemini_streaming_output(output), + usage_stats=usage_stats, + latency=latency, + 
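            # Remaining fields are the merged PostHog identity/context values resolved in _merge_posthog_params. +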
distinct_id=distinct_id, + trace_id=trace_id, + properties=properties, + privacy_mode=privacy_mode, + groups=groups, + ) + + # Use the common capture function + capture_streaming_event(self._ph_client, event_data) + + def _format_input(self, contents, **kwargs): + """Format input contents for PostHog tracking""" + + # Create kwargs dict with contents for merge_system_prompt + input_kwargs = {"contents": contents, **kwargs} + return merge_system_prompt(input_kwargs, "gemini") + + async def generate_content_stream( + self, + model: str, + contents, + posthog_distinct_id: Optional[str] = None, + posthog_trace_id: Optional[str] = None, + posthog_properties: Optional[Dict[str, Any]] = None, + posthog_privacy_mode: Optional[bool] = None, + posthog_groups: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ): + # Merge PostHog parameters + distinct_id, trace_id, properties, privacy_mode, groups = ( + self._merge_posthog_params( + posthog_distinct_id, + posthog_trace_id, + posthog_properties, + posthog_privacy_mode, + posthog_groups, + ) + ) + + return await self._generate_content_streaming( + model, + contents, + distinct_id, + trace_id, + properties, + privacy_mode, + groups, + **kwargs, + ) diff --git a/posthog/test/ai/gemini/test_gemini_async.py b/posthog/test/ai/gemini/test_gemini_async.py new file mode 100644 index 00000000..624095f9 --- /dev/null +++ b/posthog/test/ai/gemini/test_gemini_async.py @@ -0,0 +1,843 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +try: + from google import genai as google_genai + + from posthog.ai.gemini import AsyncClient + + GEMINI_AVAILABLE = True +except ImportError: + GEMINI_AVAILABLE = False + +pytestmark = [ + pytest.mark.skipif( + not GEMINI_AVAILABLE, reason="Google Gemini package is not available" + ), + pytest.mark.asyncio, +] + + +@pytest.fixture +def mock_client(): + with patch("posthog.client.Client") as mock_client: + mock_client.privacy_mode = False + yield mock_client + + +@pytest.fixture +def mock_gemini_response(): + mock_response = MagicMock() + mock_response.text = "Test response from Gemini" + + mock_usage = MagicMock() + mock_usage.prompt_token_count = 20 + mock_usage.candidates_token_count = 10 + # Ensure cache and reasoning tokens are not present (not MagicMock) + mock_usage.cached_content_token_count = 0 + mock_usage.thoughts_token_count = 0 + mock_response.usage_metadata = mock_usage + + mock_candidate = MagicMock() + mock_candidate.text = "Test response from Gemini" + mock_content = MagicMock() + mock_part = MagicMock() + mock_part.text = "Test response from Gemini" + mock_content.parts = [mock_part] + mock_candidate.content = mock_content + mock_response.candidates = [mock_candidate] + + return mock_response + + +@pytest.fixture +def mock_google_genai_client(): + """Mock for the google-genai Client with async support""" + with patch.object(google_genai, "Client") as mock_client_class: + mock_client_instance = MagicMock() + mock_models = MagicMock() + mock_aio = MagicMock() + mock_aio_models = MagicMock() + + mock_client_instance.models = mock_models + mock_client_instance.aio = mock_aio + mock_aio.models = mock_aio_models + + mock_client_class.return_value = mock_client_instance + yield mock_client_instance + + +@pytest.fixture +def mock_gemini_response_with_function_calls(): + mock_response = MagicMock() + + # Mock usage metadata + mock_usage = MagicMock() + mock_usage.prompt_token_count = 25 + mock_usage.candidates_token_count = 15 + mock_usage.cached_content_token_count = 0 + 
mock_usage.thoughts_token_count = 0 + mock_response.usage_metadata = mock_usage + + # Mock function call + mock_function_call = MagicMock() + mock_function_call.name = "get_current_weather" + mock_function_call.args = {"location": "San Francisco"} + + # Mock text part 1 + mock_text_part1 = MagicMock() + mock_text_part1.text = "I'll check the weather for you." + type(mock_text_part1).text = mock_text_part1.text + + # Mock text part 2 + mock_text_part2 = MagicMock() + mock_text_part2.text = " Let me look that up." + type(mock_text_part2).text = mock_text_part2.text + + # Mock function call part + mock_function_part = MagicMock() + mock_function_part.function_call = mock_function_call + type(mock_function_part).function_call = mock_function_part.function_call + del mock_function_part.text + + # Mock content with 2 text parts and 1 function call part + mock_content = MagicMock() + mock_content.parts = [mock_text_part1, mock_text_part2, mock_function_part] + + # Mock candidate + mock_candidate = MagicMock() + mock_candidate.content = mock_content + mock_response.candidates = [mock_candidate] + + return mock_response + + +async def test_async_client_basic_generation( + mock_client, mock_google_genai_client, mock_gemini_response +): + """Test the async Client/AsyncModels API structure""" + mock_google_genai_client.aio.models.generate_content = AsyncMock( + return_value=mock_gemini_response + ) + + client = AsyncClient(api_key="test-key", posthog_client=mock_client) + + response = await client.models.generate_content( + model="gemini-2.0-flash", + contents=["Tell me a fun fact about hedgehogs"], + posthog_distinct_id="test-id", + posthog_properties={"foo": "bar"}, + ) + + assert response == mock_gemini_response + assert mock_client.capture.call_count == 1 + + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + + assert call_args["distinct_id"] == "test-id" + assert call_args["event"] == "$ai_generation" + assert props["$ai_provider"] == "gemini" + assert props["$ai_model"] == "gemini-2.0-flash" + assert props["$ai_input_tokens"] == 20 + assert props["$ai_output_tokens"] == 10 + assert props["foo"] == "bar" + assert "$ai_trace_id" in props + assert props["$ai_latency"] > 0 + + +async def test_async_client_streaming_with_generate_content_stream( + mock_client, mock_google_genai_client +): + """Test the async generate_content_stream method""" + + async def mock_streaming_response(): + mock_chunk1 = MagicMock() + mock_chunk1.text = "Hello " + mock_usage1 = MagicMock() + mock_usage1.prompt_token_count = 10 + mock_usage1.candidates_token_count = 5 + mock_usage1.cached_content_token_count = 0 + mock_usage1.thoughts_token_count = 0 + mock_chunk1.usage_metadata = mock_usage1 + yield mock_chunk1 + + mock_chunk2 = MagicMock() + mock_chunk2.text = "world!" 
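+        # Gemini reports usage cumulatively, so this chunk repeats prompt_token_count=10 and raises the candidates total to its final value of 10.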
+ mock_usage2 = MagicMock() + mock_usage2.prompt_token_count = 10 + mock_usage2.candidates_token_count = 10 + mock_usage2.cached_content_token_count = 0 + mock_usage2.thoughts_token_count = 0 + mock_chunk2.usage_metadata = mock_usage2 + yield mock_chunk2 + + # Mock the async generate_content_stream method + mock_google_genai_client.aio.models.generate_content_stream = AsyncMock( + return_value=mock_streaming_response() + ) + + client = AsyncClient(api_key="test-key", posthog_client=mock_client) + + response = await client.models.generate_content_stream( + model="gemini-2.0-flash", + contents=["Write a short story"], + posthog_distinct_id="test-id", + posthog_properties={"feature": "streaming"}, + ) + + chunks = [] + async for chunk in response: + chunks.append(chunk) + + assert len(chunks) == 2 + assert chunks[0].text == "Hello " + assert chunks[1].text == "world!" + + # Check that the streaming event was captured + assert mock_client.capture.call_count == 1 + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + + assert call_args["distinct_id"] == "test-id" + assert call_args["event"] == "$ai_generation" + assert props["$ai_provider"] == "gemini" + assert props["$ai_model"] == "gemini-2.0-flash" + assert props["$ai_input_tokens"] == 10 + assert props["$ai_output_tokens"] == 10 + assert props["feature"] == "streaming" + assert isinstance(props["$ai_latency"], float) + + +async def test_async_client_streaming_with_tools(mock_client, mock_google_genai_client): + """Test that tools are captured in async streaming mode""" + + async def mock_streaming_response(): + mock_chunk1 = MagicMock() + mock_chunk1.text = "I'll check " + mock_usage1 = MagicMock() + mock_usage1.prompt_token_count = 15 + mock_usage1.candidates_token_count = 5 + mock_usage1.cached_content_token_count = 0 + mock_usage1.thoughts_token_count = 0 + mock_chunk1.usage_metadata = mock_usage1 + yield mock_chunk1 + + mock_chunk2 = MagicMock() + mock_chunk2.text = "the weather" + mock_usage2 = MagicMock() + mock_usage2.prompt_token_count = 15 + mock_usage2.candidates_token_count = 10 + mock_usage2.cached_content_token_count = 0 + mock_usage2.thoughts_token_count = 0 + mock_chunk2.usage_metadata = mock_usage2 + yield mock_chunk2 + + # Mock the async generate_content_stream method + mock_google_genai_client.aio.models.generate_content_stream = AsyncMock( + return_value=mock_streaming_response() + ) + + client = AsyncClient(api_key="test-key", posthog_client=mock_client) + + # Create mock tools configuration + mock_tool = MagicMock() + mock_tool.function_declarations = [ + MagicMock( + name="get_current_weather", + description="Gets the current weather for a given location.", + parameters=MagicMock( + type="OBJECT", + properties={ + "location": MagicMock( + type="STRING", + description="The city and state, e.g. 
San Francisco, CA", + ) + }, + required=["location"], + ), + ) + ] + + mock_config = MagicMock() + mock_config.tools = [mock_tool] + + response = await client.models.generate_content_stream( + model="gemini-2.0-flash", + contents=["What's the weather in SF?"], + config=mock_config, + posthog_distinct_id="test-id", + posthog_properties={"feature": "streaming_with_tools"}, + ) + + chunks = [] + async for chunk in response: + chunks.append(chunk) + + assert len(chunks) == 2 + assert chunks[0].text == "I'll check " + assert chunks[1].text == "the weather" + + # Check that the streaming event was captured with tools + assert mock_client.capture.call_count == 1 + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + + assert call_args["distinct_id"] == "test-id" + assert call_args["event"] == "$ai_generation" + assert props["$ai_provider"] == "gemini" + assert props["$ai_model"] == "gemini-2.0-flash" + assert props["$ai_input_tokens"] == 15 + assert props["$ai_output_tokens"] == 10 + assert props["feature"] == "streaming_with_tools" + assert isinstance(props["$ai_latency"], float) + + # Verify that tools are captured in the $ai_tools property in streaming mode + assert props["$ai_tools"] == [mock_tool] + + +async def test_async_client_groups( + mock_client, mock_google_genai_client, mock_gemini_response +): + """Test groups functionality with async Client API""" + mock_google_genai_client.aio.models.generate_content = AsyncMock( + return_value=mock_gemini_response + ) + + client = AsyncClient(api_key="test-key", posthog_client=mock_client) + + await client.models.generate_content( + model="gemini-2.0-flash", + contents=["Hello"], + posthog_distinct_id="test-id", + posthog_groups={"company": "company_123"}, + ) + + call_args = mock_client.capture.call_args[1] + assert call_args["groups"] == {"company": "company_123"} + + +async def test_async_client_privacy_mode_local( + mock_client, mock_google_genai_client, mock_gemini_response +): + """Test local privacy mode with async Client API""" + mock_google_genai_client.aio.models.generate_content = AsyncMock( + return_value=mock_gemini_response + ) + + client = AsyncClient(api_key="test-key", posthog_client=mock_client) + + await client.models.generate_content( + model="gemini-2.0-flash", + contents=["Hello"], + posthog_distinct_id="test-id", + posthog_privacy_mode=True, + ) + + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + assert props["$ai_input"] is None + assert props["$ai_output_choices"] is None + + +async def test_async_client_privacy_mode_global( + mock_client, mock_google_genai_client, mock_gemini_response +): + """Test global privacy mode with async Client API""" + mock_client.privacy_mode = True + + mock_google_genai_client.aio.models.generate_content = AsyncMock( + return_value=mock_gemini_response + ) + + client = AsyncClient(api_key="test-key", posthog_client=mock_client) + + await client.models.generate_content( + model="gemini-2.0-flash", + contents=["Hello"], + posthog_distinct_id="test-id", + ) + + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + assert props["$ai_input"] is None + assert props["$ai_output_choices"] is None + + +async def test_async_client_different_input_formats( + mock_client, mock_google_genai_client, mock_gemini_response +): + """Test different input formats with async Client API""" + mock_google_genai_client.aio.models.generate_content = AsyncMock( + return_value=mock_gemini_response + ) + + client = 
AsyncClient(api_key="test-key", posthog_client=mock_client) + + # Test string input + await client.models.generate_content( + model="gemini-2.0-flash", contents="Hello", posthog_distinct_id="test-id" + ) + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + assert props["$ai_input"] == [{"role": "user", "content": "Hello"}] + + # Test Gemini-specific format with parts array + mock_client.reset_mock() + await client.models.generate_content( + model="gemini-2.0-flash", + contents=[{"role": "user", "parts": [{"text": "hey"}]}], + posthog_distinct_id="test-id", + ) + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + assert props["$ai_input"] == [{"role": "user", "content": "hey"}] + + # Test multiple parts in the parts array + mock_client.reset_mock() + await client.models.generate_content( + model="gemini-2.0-flash", + contents=[{"role": "user", "parts": [{"text": "Hello "}, {"text": "world"}]}], + posthog_distinct_id="test-id", + ) + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + assert props["$ai_input"] == [{"role": "user", "content": "Hello world"}] + + # Test list input with string + mock_client.capture.reset_mock() + await client.models.generate_content( + model="gemini-2.0-flash", contents=["List item"], posthog_distinct_id="test-id" + ) + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + assert props["$ai_input"] == [{"role": "user", "content": "List item"}] + + +async def test_async_client_model_parameters( + mock_client, mock_google_genai_client, mock_gemini_response +): + """Test model parameters with async Client API""" + mock_google_genai_client.aio.models.generate_content = AsyncMock( + return_value=mock_gemini_response + ) + + client = AsyncClient(api_key="test-key", posthog_client=mock_client) + + await client.models.generate_content( + model="gemini-2.0-flash", + contents=["Hello"], + posthog_distinct_id="test-id", + temperature=0.7, + max_tokens=100, + ) + + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + assert props["$ai_model_parameters"]["temperature"] == 0.7 + assert props["$ai_model_parameters"]["max_tokens"] == 100 + + +async def test_async_client_default_settings( + mock_client, mock_google_genai_client, mock_gemini_response +): + """Test async client with default PostHog settings""" + mock_google_genai_client.aio.models.generate_content = AsyncMock( + return_value=mock_gemini_response + ) + + client = AsyncClient( + api_key="test-key", + posthog_client=mock_client, + posthog_distinct_id="default_user", + posthog_properties={"team": "ai"}, + posthog_privacy_mode=False, + posthog_groups={"company": "acme_corp"}, + ) + + # Call without overriding defaults + await client.models.generate_content(model="gemini-2.0-flash", contents=["Hello"]) + + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + + assert call_args["distinct_id"] == "default_user" + assert call_args["groups"] == {"company": "acme_corp"} + assert props["team"] == "ai" + + +async def test_async_client_override_defaults( + mock_client, mock_google_genai_client, mock_gemini_response +): + """Test overriding async client defaults per call""" + mock_google_genai_client.aio.models.generate_content = AsyncMock( + return_value=mock_gemini_response + ) + + client = AsyncClient( + api_key="test-key", + posthog_client=mock_client, + posthog_distinct_id="default_user", + posthog_properties={"team": "ai"}, + posthog_privacy_mode=False, 
+ posthog_groups={"company": "acme_corp"}, + ) + + # Override defaults in call + await client.models.generate_content( + model="gemini-2.0-flash", + contents=["Hello"], + posthog_distinct_id="specific_user", + posthog_properties={"feature": "chat", "urgent": True}, + posthog_privacy_mode=True, + posthog_groups={"organization": "special_org"}, + ) + + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + + # Check overrides + assert call_args["distinct_id"] == "specific_user" + assert call_args["groups"] == {"organization": "special_org"} + assert props["$ai_input"] is None # privacy mode was overridden + + # Check merged properties (defaults + call-specific) + assert props["team"] == "ai" # from defaults + assert props["feature"] == "chat" # from call + assert props["urgent"] is True # from call + + +async def test_async_vertex_ai_parameters_passed_through( + mock_client, mock_google_genai_client, mock_gemini_response +): + """Test that Vertex AI parameters are properly passed to genai.Client""" + mock_google_genai_client.aio.models.generate_content = AsyncMock( + return_value=mock_gemini_response + ) + + # Mock credentials object + mock_credentials = MagicMock() + mock_debug_config = MagicMock() + mock_http_options = MagicMock() + + # Create client with Vertex AI parameters + AsyncClient( + vertexai=True, + credentials=mock_credentials, + project="test-project", + location="us-central1", + debug_config=mock_debug_config, + http_options=mock_http_options, + posthog_client=mock_client, + ) + + # Verify genai.Client was called with correct parameters + google_genai.Client.assert_called_once_with( + vertexai=True, + credentials=mock_credentials, + project="test-project", + location="us-central1", + debug_config=mock_debug_config, + http_options=mock_http_options, + ) + + +async def test_async_api_key_mode(mock_client, mock_google_genai_client): + """Test API key authentication mode with async client""" + + # Create async client with just API key (traditional mode) + AsyncClient( + api_key="test-api-key", + posthog_client=mock_client, + ) + + # Verify genai.Client was called with only api_key + google_genai.Client.assert_called_once_with(api_key="test-api-key") + + +async def test_async_function_calls_in_output_choices( + mock_client, mock_google_genai_client, mock_gemini_response_with_function_calls +): + """Test that function calls are properly included in $ai_output_choices with async""" + mock_google_genai_client.aio.models.generate_content = AsyncMock( + return_value=mock_gemini_response_with_function_calls + ) + + client = AsyncClient(api_key="test-key", posthog_client=mock_client) + + response = await client.models.generate_content( + model="gemini-2.5-flash", + contents=["What's the weather in San Francisco?"], + posthog_distinct_id="test-id", + ) + + assert response == mock_gemini_response_with_function_calls + assert mock_client.capture.call_count == 1 + + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + + assert call_args["distinct_id"] == "test-id" + assert call_args["event"] == "$ai_generation" + assert props["$ai_provider"] == "gemini" + assert props["$ai_model"] == "gemini-2.5-flash" + assert props["$ai_output_choices"] == [ + { + "role": "assistant", + "content": [ + {"type": "text", "text": "I'll check the weather for you."}, + {"type": "text", "text": " Let me look that up."}, + { + "type": "function", + "function": { + "name": "get_current_weather", + "arguments": {"location": "San Francisco"}, + }, + }, + ], + } + ] 
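+    # Text parts and the function call are flattened into a single assistant message with typed content blocks.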
+ + # Check token usage + assert props["$ai_input_tokens"] == 25 + assert props["$ai_output_tokens"] == 15 + assert props["$ai_http_status"] == 200 + + +async def test_async_cache_and_reasoning_tokens(mock_client, mock_google_genai_client): + """Test that cache and reasoning tokens are properly extracted with async""" + # Create a mock response with cache and reasoning tokens + mock_response = MagicMock() + mock_response.text = "Test response with cache" + + mock_usage = MagicMock() + mock_usage.prompt_token_count = 100 + mock_usage.candidates_token_count = 50 + mock_usage.cached_content_token_count = 30 # Cache tokens + mock_usage.thoughts_token_count = 10 # Reasoning tokens + mock_response.usage_metadata = mock_usage + + # Mock candidates + mock_candidate = MagicMock() + mock_candidate.text = "Test response with cache" + mock_response.candidates = [mock_candidate] + + mock_google_genai_client.aio.models.generate_content = AsyncMock( + return_value=mock_response + ) + + client = AsyncClient(api_key="test-key", posthog_client=mock_client) + + response = await client.models.generate_content( + model="gemini-2.5-pro", + contents="Test with cache", + posthog_distinct_id="test-id", + ) + + assert response == mock_response + assert mock_client.capture.call_count == 1 + + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + + # Check that all token types are present + assert props["$ai_input_tokens"] == 100 + assert props["$ai_output_tokens"] == 50 + assert props["$ai_cache_read_input_tokens"] == 30 + assert props["$ai_reasoning_tokens"] == 10 + + +async def test_async_streaming_cache_and_reasoning_tokens( + mock_client, mock_google_genai_client +): + """Test that cache and reasoning tokens are properly extracted in async streaming""" + + async def mock_streaming_response(): + # Create mock chunks with cache and reasoning tokens + chunk1 = MagicMock() + chunk1.text = "Hello " + chunk1_usage = MagicMock() + chunk1_usage.prompt_token_count = 100 + chunk1_usage.candidates_token_count = 5 + chunk1_usage.cached_content_token_count = 30 # Cache tokens + chunk1_usage.thoughts_token_count = 0 + chunk1.usage_metadata = chunk1_usage + yield chunk1 + + chunk2 = MagicMock() + chunk2.text = "world!" 
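+        # Final cumulative snapshot: cached tokens hold at 30 while 5 reasoning tokens appear for the first time.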
+ chunk2_usage = MagicMock() + chunk2_usage.prompt_token_count = 100 + chunk2_usage.candidates_token_count = 10 + chunk2_usage.cached_content_token_count = 30 # Same cache tokens + chunk2_usage.thoughts_token_count = 5 # Reasoning tokens + chunk2.usage_metadata = chunk2_usage + yield chunk2 + + mock_google_genai_client.aio.models.generate_content_stream = AsyncMock( + return_value=mock_streaming_response() + ) + + client = AsyncClient(api_key="test-key", posthog_client=mock_client) + + response = await client.models.generate_content_stream( + model="gemini-2.5-pro", + contents="Test streaming with cache", + posthog_distinct_id="test-id", + ) + + # Consume the stream + result = [] + async for chunk in response: + result.append(chunk) + + assert len(result) == 2 + + # Check PostHog capture was called + assert mock_client.capture.call_count == 1 + + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + + # Check that all token types are present (should use final chunk's usage) + assert props["$ai_input_tokens"] == 100 + assert props["$ai_output_tokens"] == 10 + assert props["$ai_cache_read_input_tokens"] == 30 + assert props["$ai_reasoning_tokens"] == 5 + + +async def test_async_web_search_grounding(mock_client, mock_google_genai_client): + """Test async web search detection via grounding_metadata.""" + + # Create mock response with grounding metadata + mock_response = MagicMock() + + # Mock usage metadata + mock_usage = MagicMock() + mock_usage.prompt_token_count = 60 + mock_usage.candidates_token_count = 40 + mock_usage.cached_content_token_count = 0 + mock_usage.thoughts_token_count = 0 + mock_response.usage_metadata = mock_usage + + # Mock grounding metadata + mock_grounding_chunk = MagicMock() + mock_grounding_chunk.uri = "https://example.com" + + mock_grounding_metadata = MagicMock() + mock_grounding_metadata.grounding_chunks = [mock_grounding_chunk] + + # Mock text part + mock_text_part = MagicMock() + mock_text_part.text = "According to search results..." + type(mock_text_part).text = mock_text_part.text + + # Mock content with parts + mock_content = MagicMock() + mock_content.parts = [mock_text_part] + + # Mock candidate with grounding metadata + mock_candidate = MagicMock() + mock_candidate.content = mock_content + mock_candidate.grounding_metadata = mock_grounding_metadata + type(mock_candidate).grounding_metadata = mock_candidate.grounding_metadata + + mock_response.candidates = [mock_candidate] + mock_response.text = "According to search results..." 
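+    # A candidate carrying non-empty grounding_metadata.grounding_chunks is what flags the response as a web search.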
+ + # Mock the async generate_content method + mock_google_genai_client.aio.models.generate_content = AsyncMock( + return_value=mock_response + ) + + client = AsyncClient(api_key="test-key", posthog_client=mock_client) + response = await client.models.generate_content( + model="gemini-2.5-flash", + contents="What's the latest news?", + posthog_distinct_id="test-id", + ) + + assert response == mock_response + assert mock_client.capture.call_count == 1 + + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + + # Verify web search count is detected (binary for grounding) + assert props["$ai_web_search_count"] == 1 + assert props["$ai_input_tokens"] == 60 + assert props["$ai_output_tokens"] == 40 + + +async def test_async_streaming_with_web_search(mock_client, mock_google_genai_client): + """Test that web search count is properly captured in async streaming mode.""" + + async def mock_streaming_response(): + # Create chunk 1 with grounding metadata + mock_chunk1 = MagicMock() + mock_chunk1.text = "According to " + + mock_usage1 = MagicMock() + mock_usage1.prompt_token_count = 30 + mock_usage1.candidates_token_count = 5 + mock_usage1.cached_content_token_count = 0 + mock_usage1.thoughts_token_count = 0 + mock_chunk1.usage_metadata = mock_usage1 + + # Add grounding metadata to first chunk + mock_grounding_chunk = MagicMock() + mock_grounding_chunk.uri = "https://example.com" + + mock_grounding_metadata = MagicMock() + mock_grounding_metadata.grounding_chunks = [mock_grounding_chunk] + + mock_candidate1 = MagicMock() + mock_candidate1.grounding_metadata = mock_grounding_metadata + type(mock_candidate1).grounding_metadata = mock_candidate1.grounding_metadata + + mock_chunk1.candidates = [mock_candidate1] + yield mock_chunk1 + + # Create chunk 2 + mock_chunk2 = MagicMock() + mock_chunk2.text = "search results..." + + mock_usage2 = MagicMock() + mock_usage2.prompt_token_count = 30 + mock_usage2.candidates_token_count = 15 + mock_usage2.cached_content_token_count = 0 + mock_usage2.thoughts_token_count = 0 + mock_chunk2.usage_metadata = mock_usage2 + + mock_candidate2 = MagicMock() + mock_chunk2.candidates = [mock_candidate2] + yield mock_chunk2 + + # Mock the async generate_content_stream method + mock_google_genai_client.aio.models.generate_content_stream = AsyncMock( + return_value=mock_streaming_response() + ) + + client = AsyncClient(api_key="test-key", posthog_client=mock_client) + + response = await client.models.generate_content_stream( + model="gemini-2.5-flash", + contents="What's the latest news?", + posthog_distinct_id="test-id", + ) + + chunks = [] + async for chunk in response: + chunks.append(chunk) + + assert len(chunks) == 2 + assert mock_client.capture.call_count == 1 + + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + + # Verify web search count is detected (binary for grounding) + assert props["$ai_web_search_count"] == 1 + assert props["$ai_input_tokens"] == 30 + assert props["$ai_output_tokens"] == 15 diff --git a/posthog/version.py b/posthog/version.py index 639be2bc..75b03c92 100644 --- a/posthog/version.py +++ b/posthog/version.py @@ -1,4 +1,4 @@ -VERSION = "7.0.1" +VERSION = "7.1.0" if __name__ == "__main__": print(VERSION, end="") # noqa: T201
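
A minimal end-to-end sketch of the async API this patch adds, adapted from the AsyncClient docstring above. The PostHog project key, host, and Gemini API key are placeholders; the model name follows the examples used in the patch.

    import asyncio

    from posthog import Posthog
    from posthog.ai.gemini import AsyncClient

    # Placeholders: substitute your real PostHog project key/host and Gemini key.
    posthog = Posthog("<ph_project_api_key>", host="<ph_host>")


    async def main():
        client = AsyncClient(
            api_key="<gemini_api_key>",
            posthog_client=posthog,
            posthog_distinct_id="default_user",
        )

        # Non-streaming call: one $ai_generation event is captured automatically.
        response = await client.models.generate_content(
            model="gemini-2.0-flash",
            contents=["Tell me a fun fact about hedgehogs"],
        )
        print(response.text)

        # Streaming call: the event is captured once the stream is exhausted.
        stream = await client.models.generate_content_stream(
            model="gemini-2.0-flash",
            contents=["Write a short story"],
        )
        async for chunk in stream:
            print(chunk.text, end="")

        # Flush buffered events before the process exits.
        posthog.flush()


    asyncio.run(main())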