diff --git a/.gitignore b/.gitignore
index cdc85cb..61abe29 100644
--- a/.gitignore
+++ b/.gitignore
@@ -61,4 +61,5 @@
 datu-sqlserver/
 test_schema_cache.json
 .cache/
-site/
\ No newline at end of file
+site/
+.coverage*
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
index 42c8756..14f17dc 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -9,8 +9,8 @@ RUN apt-get update && \
     rm -rf /var/lib/apt/lists/* && \
     pip install --upgrade pip

-RUN curl https://packages.microsoft.com/keys/microsoft.asc | apt-key add -
-RUN curl https://packages.microsoft.com/config/debian/11/prod.list > /etc/apt/sources.list.d/mssql-release.list
+RUN curl https://packages.microsoft.com/keys/microsoft.asc | gpg --dearmor > /etc/apt/trusted.gpg.d/microsoft.gpg \
+&& curl https://packages.microsoft.com/config/debian/11/prod.list -o /etc/apt/sources.list.d/mssql-release.list

 RUN apt-get update
 RUN env ACCEPT_EULA=Y apt-get install -y msodbcsql18
@@ -21,4 +21,4 @@ RUN uv sync --extra postgres --extra sqldb
 ENV PATH="/app/.venv/bin:$PATH"
 # Reset the entrypoint, don't invoke `uv`
 ENTRYPOINT []
-CMD ["uvicorn", "datu.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]
+CMD ["uvicorn", "datu.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]
\ No newline at end of file
diff --git a/changelog.d/+7413cd44.added.md b/changelog.d/+7413cd44.added.md
new file mode 100644
index 0000000..27472ed
--- /dev/null
+++ b/changelog.d/+7413cd44.added.md
@@ -0,0 +1 @@
+Enable product Telemetry
diff --git a/pyproject.toml b/pyproject.toml
index 826b93c..12e616e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,6 +38,7 @@ dependencies = [
     "fastmcp>=2.10.5",
     "mcp-use[search]>=1.3.7",
     "onnxruntime==1.19.2 ; sys_platform == 'darwin' and platform_machine == 'x86_64'",
+    "posthog>=6.5.0",
 ]

 [project.urls]
diff --git a/src/datu/app_config.py b/src/datu/app_config.py
index ef0b2bc..002b373 100644
--- a/src/datu/app_config.py
+++ b/src/datu/app_config.py
@@ -17,6 +17,7 @@
 from datu.integrations.config import IntegrationConfigs
 from datu.mcp.config import MCPConfig
 from datu.services.config import SchemaRAGConfig
+from datu.telemetry.config import TelemetryConfig


 class Environment(Enum):
@@ -56,6 +57,12 @@ class DatuConfig(BaseSettings):
         schema_sample_limit (int): The maximum number of rows to sample from the schema.
         schema_categorical_threshold (int): The threshold for categorical columns in the schema.
         enable_schema_rag (bool): Enable RAG for schema extraction.
+        enable_anonymized_telemetry (bool): Enable anonymized telemetry. Default is True.
+        app_environment (str): The application environment (e.g., "dev", "test", "prod"). Default is "dev".
+        telemetry (TelemetryConfig | None): Configuration settings for telemetry.
+        enable_mcp (bool): Whether to enable MCP integration. Default is False.
+        mcp (MCPConfig | None): Configuration settings for MCP integration.
+        schema_rag (SchemaRAGConfig | None): Configuration settings for schema RAG.

     Attributes:
         host (str): The host address for the application.
@@ -77,7 +84,8 @@
         mcp (MCPConfig | None): Configuration settings for MCP integration.
         enable_schema_rag (bool): Enable RAG for schema extraction.
         schema_rag (SchemaRAGConfig | None): Configuration settings for schema RAG.
-
+        enable_anonymized_telemetry (bool): Enable anonymized telemetry.
+        telemetry (TelemetryConfig | None): Configuration settings for telemetry.
""" @@ -110,7 +119,8 @@ class DatuConfig(BaseSettings): description="Configuration settings for schema RAG (Retrieval-Augmented Generation).", ) enable_anonymization: bool = False - + enable_anonymized_telemetry: bool = True + telemetry: TelemetryConfig | None = Field(default_factory=TelemetryConfig) model_config = SettingsConfigDict( env_prefix="datu_", env_nested_delimiter="__", diff --git a/src/datu/factory/llm_client_factory.py b/src/datu/factory/llm_client_factory.py index 908c8c8..bf5906d 100644 --- a/src/datu/factory/llm_client_factory.py +++ b/src/datu/factory/llm_client_factory.py @@ -6,6 +6,8 @@ from typing import Literal from datu.llm_clients.openai_client import OpenAIClient +from datu.telemetry.product.events import MCPClientEvent +from datu.telemetry.product.posthog import get_posthog_client def get_llm_client(provider: Literal["openai"] | None = None) -> OpenAIClient | None: @@ -20,6 +22,15 @@ def get_llm_client(provider: Literal["openai"] | None = None) -> OpenAIClient | ValueError: If the specified provider is not supported. """ if provider == "openai": - return OpenAIClient() + openai_client = OpenAIClient() + if openai_client.agent: + posthog_client = get_posthog_client() + if openai_client.mcp_client and getattr(openai_client.mcp_client, "config", None): + servers = openai_client.mcp_client.config.get("mcpServers", {}) + server_names = list(servers.keys()) if servers else [] + else: + server_names = [] + posthog_client.capture(MCPClientEvent(server_names=server_names)) + return openai_client else: raise ValueError("Invalid LLM provider specified in configuration.") diff --git a/src/datu/main.py b/src/datu/main.py index 96679e0..46029f5 100644 --- a/src/datu/main.py +++ b/src/datu/main.py @@ -13,6 +13,8 @@ from datu.app_config import get_logger, settings from datu.routers import chat, metadata, transformations from datu.schema_extractor.schema_cache import load_schema_cache +from datu.telemetry.product.events import OpenAIEvent +from datu.telemetry.product.posthog import get_posthog_client logger = get_logger(__name__) @@ -56,6 +58,8 @@ def start_app() -> None: It also sets the logging level based on the configuration settings. 
""" logger.info("Starting the FastAPI application...") + posthog_client = get_posthog_client() + posthog_client.capture(OpenAIEvent()) uvicorn.run(app, host=settings.host, port=settings.port) diff --git a/tests/helpers/__init__.py b/src/datu/telemetry/__init__.py similarity index 100% rename from tests/helpers/__init__.py rename to src/datu/telemetry/__init__.py diff --git a/src/datu/telemetry/config.py b/src/datu/telemetry/config.py new file mode 100644 index 0000000..158a56a --- /dev/null +++ b/src/datu/telemetry/config.py @@ -0,0 +1,19 @@ +"""Telemetry configuration settings.""" + +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class TelemetryConfig(BaseSettings): + """Telemetry configuration settings.""" + + api_key: str = "phc_m74dfR9nLpm2nipvkL2swyFDtNuQNC9o2FL2CSbh6Je" + package_name: str = "datu-core" + + model_config = SettingsConfigDict( + env_prefix="telemetry_", + env_nested_delimiter="__", + ) + + +def get_telemetry_settings() -> TelemetryConfig: + return TelemetryConfig() diff --git a/src/datu/telemetry/product/__init__.py b/src/datu/telemetry/product/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/datu/telemetry/product/events.py b/src/datu/telemetry/product/events.py new file mode 100644 index 0000000..2389666 --- /dev/null +++ b/src/datu/telemetry/product/events.py @@ -0,0 +1,50 @@ +"""Telemetry events for product usage.""" + +from typing import Any, ClassVar, Dict + +from datu.app_config import settings + + +class ProductTelemetryEvent: + """Base class for all telemetry events.""" + + max_batch_size: ClassVar[int] = 1 + + def __init__(self, **kwargs): + self._props = kwargs + self.batch_size = 1 + + @property + def name(self) -> str: + return self.__class__.__name__ + + @property + def properties(self) -> Dict[str, Any]: + return self._props + + @property + def batch_key(self) -> str: + return self.name + + def batch(self, other: "ProductTelemetryEvent") -> "ProductTelemetryEvent": + """Simple batch: append counts together.""" + if self.name != other.name: + raise ValueError("Cannot batch different event types") + self.batch_size += other.batch_size + return self + + +class MCPClientEvent(ProductTelemetryEvent): + """Event for when the MCP client starts.""" + + def __init__(self, server_names: list[str]): + super().__init__() + self._props["mcp_server_names"] = server_names + + +class OpenAIEvent(ProductTelemetryEvent): + """Event for OpenAI-related telemetry.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._props["openai_model"] = settings.openai_model diff --git a/src/datu/telemetry/product/posthog.py b/src/datu/telemetry/product/posthog.py new file mode 100644 index 0000000..8c47c8c --- /dev/null +++ b/src/datu/telemetry/product/posthog.py @@ -0,0 +1,124 @@ +"""PostHog telemetry client for product usage tracking.""" + +import importlib +import logging +import platform +import sys +import uuid +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import posthog + +from datu.app_config import Environment, get_logger, settings +from datu.telemetry.config import TelemetryConfig as TelemetrySettings +from datu.telemetry.product.events import ProductTelemetryEvent + +logger = get_logger(__name__) + +POSTHOG_EVENT_SETTINGS = {"$process_person_profile": False} + + +class PostHogClient: + """Telemetry client with basic batching + config via Pydantic.""" + + UNKNOWN_USER_ID = "UNKNOWN" + USER_ID_PATH = Path.home() / ".cache" / "datu-core" / "telemetry_user_id" + 
+    def __init__(self, telemetry_settings: Optional[TelemetrySettings]) -> None:
+        self.settings = telemetry_settings or TelemetrySettings()
+        self._batched_events: Dict[str, ProductTelemetryEvent] = {}
+        self._user_id: str = ""
+        self._user_id_path: Path = self.USER_ID_PATH
+        self.session_id = str(uuid.uuid4())
+
+        if (
+            not settings.enable_anonymized_telemetry
+            or "pytest" in sys.modules
+            or settings.app_environment in [Environment.TEST.value]
+        ):
+            posthog.disabled = True
+        else:
+            logger.info("Enabled anonymized telemetry. See https://docs.datu.fi for more information.")
+            posthog.api_key = self.settings.api_key
+            posthog_logger = logging.getLogger("posthog")
+            posthog_logger.disabled = True
+
+    @property
+    def user_id(self) -> str:
+        if self._user_id:
+            return self._user_id
+
+        try:
+            if not self._user_id_path.exists():
+                self._user_id_path.parent.mkdir(parents=True, exist_ok=True)
+                new_id = str(uuid.uuid4())
+                self._user_id_path.write_text(new_id)
+                self._user_id = new_id
+            else:
+                self._user_id = self._user_id_path.read_text().strip()
+        except Exception:
+            self._user_id = self.UNKNOWN_USER_ID
+
+        return self._user_id
+
+    def _base_context(self) -> Dict[str, Any]:
+        try:
+            pkg_version = importlib.metadata.version(self.settings.package_name)
+        except importlib.metadata.PackageNotFoundError:
+            pkg_version = "unknown"
+
+        extras_installed: Dict[str, bool] = {}
+        try:
+            dist = importlib.metadata.distribution(self.settings.package_name)
+            extras = dist.metadata.get_all("Provides-Extra") or []
+            for extra in extras:
+                extras_installed[extra] = True
+        except importlib.metadata.PackageNotFoundError:
+            extras_installed = {}
+
+        return {
+            "python_version": sys.version.split()[0],
+            "os": platform.system(),
+            "os_version": platform.release(),
+            "package_version": pkg_version,
+            "extras_installed": extras_installed,
+        }
+
+    def capture(self, event: ProductTelemetryEvent) -> None:
+        """Capture an event (with simple batching)."""
+        if not settings.enable_anonymized_telemetry or not self.settings.api_key:
+            return
+
+        if event.max_batch_size == 1:
+            self._send(event)
+            return
+
+        batch_key = event.batch_key
+        if batch_key not in self._batched_events:
+            self._batched_events[batch_key] = event
+            return
+
+        batched = self._batched_events[batch_key].batch(event)
+        self._batched_events[batch_key] = batched
+
+        if batched.batch_size >= batched.max_batch_size:
+            self._send(batched)
+            del self._batched_events[batch_key]
+
+    def _send(self, event: ProductTelemetryEvent) -> None:
+        try:
+            posthog.capture(
+                distinct_id=self.user_id,
+                event=event.name,
+                properties={**self._base_context(), **POSTHOG_EVENT_SETTINGS, **event.properties},
+            )
+        except Exception:
+            logger.debug("Failed to send telemetry event", exc_info=True)
+
+
+@lru_cache(maxsize=1)
+def get_posthog_client() -> PostHogClient:
+    """Get the PostHog telemetry client."""
+    return PostHogClient(settings.telemetry)
diff --git a/tests/integrations/sql_server/conftest.py b/tests/integrations/sql_server/conftest.py
index a39f948..38de7fe 100644
--- a/tests/integrations/sql_server/conftest.py
+++ b/tests/integrations/sql_server/conftest.py
@@ -1,6 +1,6 @@
 """Common fixtures for tests in integrations sql_server module ."""

-# pylint: disable=redefined-outer-name
+# pylint: disable=redefined-outer-name,unused-argument,import-outside-toplevel

 import pytest
diff --git a/tests/llm_clients/__init__.py b/tests/llm_clients/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/routers/test_chat.py b/tests/routers/test_chat.py
index 804c388..119ebd3 100644
--- a/tests/routers/test_chat.py
+++ b/tests/routers/test_chat.py
@@ -1,4 +1,5 @@
-# tests/routers/test_chat.py
+"""Test suite for the chat router."""
+
 import pytest
 from fastapi import FastAPI
 from fastapi.testclient import TestClient
@@ -8,6 +9,7 @@


 def _client() -> TestClient:
+    """Create a test client for the chat router."""
     app = FastAPI()
     app.include_router(chat.router, prefix="/chat/sql")
     return TestClient(app)
@@ -18,6 +20,8 @@ def _client() -> TestClient:

 @pytest.mark.asyncio
 async def test_happy_path_with_mcp(monkeypatch):
+    """Test that the happy path with MCP enabled works correctly."""
+
     async def fake_generate_response(_msgs, _sys):
         return "Query name: Sales\n```sql\nSELECT 1;\n```"

@@ -36,6 +40,8 @@ async def fake_generate_response(_msgs, _sys):

 @pytest.mark.asyncio
 async def test_error_path_with_mcp(monkeypatch):
+    """Test that the error path with MCP enabled works correctly."""
+
     async def boom(*_a, **_k):
         raise RuntimeError("LLM down")

@@ -55,6 +61,8 @@ async def boom(*_a, **_k):

 @pytest.mark.asyncio
 async def test_passthrough_to_generate_sql_core(monkeypatch):
+    """Test that the passthrough to generate_sql_core works correctly."""
+
     async def fake_generate_sql_core(request: ChatRequest):
         return {"assistant_response": "core path", "queries": []}
diff --git a/tests/services/__init__.py b/tests/services/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/helpers/sample_schemas.py b/tests/services/conftest.py
similarity index 78%
rename from tests/helpers/sample_schemas.py
rename to tests/services/conftest.py
index b384689..91bf9f9 100644
--- a/tests/helpers/sample_schemas.py
+++ b/tests/services/conftest.py
@@ -1,5 +1,7 @@
 """Sample schema fixtures for testing schema extraction and caching."""

+import pytest
+
 from datu.base.base_connector import SchemaInfo, TableInfo
 from datu.schema_extractor.schema_cache import SchemaGlossary

@@ -36,6 +38,7 @@

     @staticmethod
     def raw_schema_dict():
+        """Return a raw schema dictionary."""
         return {
             "profile_name": "test_profile",
             "output_name": "test_output",
@@ -52,3 +55,20 @@
                 }
             ],
         }
+
+
+@pytest.fixture
+def sample_schema():
+    """Return a function that can generate schemas with parameters."""
+
+    # Return a function that can generate schemas with parameters
+    def _factory(timestamp: float = 1234567890.0):
+        return SchemaTestFixtures.sample_schema(timestamp=timestamp)
+
+    return _factory
+
+
+@pytest.fixture
+def raw_schema_dict():
+    """Return a raw schema dictionary."""
+    return SchemaTestFixtures.raw_schema_dict()
diff --git a/tests/services/test_graph_rag.py b/tests/services/test_graph_rag.py
index 5329139..25f441c 100644
--- a/tests/services/test_graph_rag.py
+++ b/tests/services/test_graph_rag.py
@@ -7,10 +7,6 @@
 import networkx as nx
 import pytest

-from datu.services.schema_rag import SchemaGraphBuilder, SchemaRAG, SchemaTripleExtractor
-
-from tests.helpers.sample_schemas import SchemaTestFixtures
-
 TEST_GRAPH_DIR = "test_graph_rag"


@@ -25,18 +21,21 @@ def clean_test_graph_cache():
         shutil.rmtree(TEST_GRAPH_DIR)


-def test_init_with_dict_schema():
+def test_init_with_dict_schema(raw_schema_dict):
     """Test SchemaGraphBuilder initialization with a raw schema dictionary."""
-    schema_dict = SchemaTestFixtures.raw_schema_dict()
-    extractor = SchemaTripleExtractor(schema_dict)
+    from datu.services.schema_rag import SchemaTripleExtractor
+
+    extractor = SchemaTripleExtractor(raw_schema_dict)
     extractor.create_schema_triples()
     assert extractor.timestamp == 1234567890.0
     assert len(extractor.schema_profiles) == 1


-def test_extract_triples_output():
+def test_extract_triples_output(sample_schema):
     """Test that triples are extracted correctly from schema objects."""
-    schema_profiles = SchemaTestFixtures.sample_schema()
+    from datu.services.schema_rag import SchemaTripleExtractor
+
+    schema_profiles = sample_schema()
     extractor = SchemaTripleExtractor(schema_profiles)
     extractor.paths["triples"] = os.path.join(TEST_GRAPH_DIR, "triples.json")
     extractor.paths["meta"] = os.path.join(TEST_GRAPH_DIR, "meta.json")
@@ -46,6 +45,8 @@

 def test_get_attr_dict_vs_object():
     """Test the helper method _get_attr for both dicts and objects."""
+    from datu.services.schema_rag import SchemaTripleExtractor
+
     extractor = SchemaTripleExtractor([])
     obj_dict = {"key1": "val1"}
     obj_class = type("Dummy", (), {"attr1": "key2"})()
@@ -55,15 +56,20 @@

 def test_is_graph_outdated_returns_true_for_missing_files(tmp_path):
     """Test is_graph_outdated returns True when graph or meta files are missing."""
+    from datu.services.schema_rag import SchemaTripleExtractor
+
     extractor = SchemaTripleExtractor({"timestamp": 1234, "schema_info": []})
     extractor.graph_path = str(tmp_path / "missing_triples.json")
     extractor.meta_path = str(tmp_path / "missing_meta.json")
     assert extractor.is_rag_outdated() is True


-def test_initialize_graph_rebuild_and_cache():
+@pytest.mark.parametrize("timestamp", [9999.0])
+def test_initialize_graph_rebuild_and_cache(sample_schema, timestamp):
     """Test graph initialization rebuilds and caches the graph correctly."""
-    schema = SchemaTestFixtures.sample_schema(timestamp=9999.0)
+    from datu.services.schema_rag import SchemaGraphBuilder, SchemaTripleExtractor
+
+    schema = sample_schema(timestamp=timestamp)
     extractor = SchemaTripleExtractor(schema)
     extractor.create_schema_triples()
     builder = SchemaGraphBuilder(triples=extractor.triples, is_rag_outdated=True)
@@ -84,10 +90,11 @@


 @pytest.mark.requires_service
-def test_schema_rag_run_query_returns_filtered_schema_dict():
+def test_schema_rag_run_query_returns_filtered_schema_dict(sample_schema):
     """Test SchemaRAG end-to-end run_query method returns filtered schema."""
-    schema = SchemaTestFixtures.sample_schema()
-    rag = SchemaRAG(schema)
+    from datu.services.schema_rag import SchemaRAG
+
+    rag = SchemaRAG(sample_schema())
     result = rag.run_query(["List all customer orders"])
     assert isinstance(result, dict)
     assert "schema_info" in result
diff --git a/tests/services/test_sql_generator.py b/tests/services/test_sql_generator.py
index f9ccb18..434fd7e 100644
--- a/tests/services/test_sql_generator.py
+++ b/tests/services/test_sql_generator.py
@@ -23,6 +23,8 @@
 from src.datu.services.sql_generator import core
 from src.datu.services.sql_generator.normalizer import normalize_for_preview

+# pylint: disable=redefined-outer-name,unused-argument,import-outside-toplevel
+

 @pytest.fixture(autouse=True)
 def _reset_db_connector(monkeypatch):
diff --git a/tests/telemetry/__init__.py b/tests/telemetry/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/telemetry/conftest.py b/tests/telemetry/conftest.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/telemetry/product/__init__.py b/tests/telemetry/product/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/telemetry/product/conftest.py b/tests/telemetry/product/conftest.py
new file mode 100644
index 0000000..83efaf1
--- /dev/null
+++ b/tests/telemetry/product/conftest.py
@@ -0,0 +1,50 @@
+"""Common fixtures for tests in telemetry product module."""
+
+import pytest
+
+from datu.telemetry.product.events import ProductTelemetryEvent
+
+
+# pylint: disable=import-outside-toplevel
+@pytest.fixture
+def sample_event_data():
+    """Fixture for sample event data."""
+    return {
+        "event_name": "test_event",
+        "user_id": "12345",
+        "timestamp": "2024-06-01T12:00:00Z",
+        "properties": {"plan": "pro", "source": "web"},
+    }
+
+
+@pytest.fixture
+def event(sample_event_data):
+    """Fixture for a ProductTelemetryEvent."""
+    return ProductTelemetryEvent(
+        event_name=sample_event_data["event_name"],
+        user_id=sample_event_data["user_id"],
+        timestamp=sample_event_data["timestamp"],
+        properties=sample_event_data["properties"],
+    )
+
+
+@pytest.fixture
+def sample_event():
+    """Fixture for a sample ProductTelemetryEvent."""
+    return ProductTelemetryEvent(foo="bar")
+
+
+@pytest.fixture
+def telemetry_settings():
+    """Fixture for telemetry settings."""
+    from datu.telemetry.config import TelemetryConfig
+
+    return TelemetryConfig(api_key="dummy_key", package_name="datu-core")
+
+
+@pytest.fixture
+def posthog_client(telemetry_settings):
+    """Fixture for PostHog client."""
+    from datu.telemetry.product.posthog import PostHogClient
+
+    return PostHogClient(telemetry_settings=telemetry_settings)
diff --git a/tests/telemetry/product/test_events.py b/tests/telemetry/product/test_events.py
new file mode 100644
index 0000000..e981490
--- /dev/null
+++ b/tests/telemetry/product/test_events.py
@@ -0,0 +1,85 @@
+"""Tests for telemetry product events."""
+
+import pytest
+
+from datu.telemetry.product.events import MCPClientEvent, OpenAIEvent, ProductTelemetryEvent
+
+
+def test_event_initialization(event, sample_event_data):
+    """Test that event initializes correctly with given data."""
+    assert event.properties["event_name"] == sample_event_data["event_name"]
+    assert event.properties["user_id"] == sample_event_data["user_id"]
+    assert event.properties["timestamp"] == sample_event_data["timestamp"]
+    assert event.properties["properties"] == sample_event_data["properties"]
+
+
+def test_event_name_and_batch_key(event):
+    """Test that event name and batch key are set correctly."""
+    assert event.name == "ProductTelemetryEvent"
+    assert event.batch_key == event.name
+
+
+def test_batching_same_type(event):
+    """Test that batching works for events of the same type."""
+    other = ProductTelemetryEvent(event_name="other_event")
+    batched = event.batch(other)
+
+    # Batch size increments
+    assert batched.batch_size == 2
+    assert batched is event  # batching modifies self
+
+
+def test_batching_different_type_raises():
+    """Test that batching raises an error for events of different types."""
+
+    class AnotherEvent(ProductTelemetryEvent):
+        pass
+
+    e1 = ProductTelemetryEvent()
+    e2 = AnotherEvent()
+    with pytest.raises(ValueError):
+        e1.batch(e2)
+
+
+def test_batch_size_increment(event):
+    """Test that batch size increments correctly."""
+    # Initial batch_size
+    assert event.batch_size == 1
+    event.batch(ProductTelemetryEvent())
+    assert event.batch_size == 2
+    event.batch(ProductTelemetryEvent())
+    assert event.batch_size == 3
+
+
+def test_mcp_client_event_properties():
+    """Test that MCPClientEvent initializes correctly with given data."""
+    servers = ["playwright", "puppeteer"]
+    event = MCPClientEvent(server_names=servers)
+
+    # Check properties
+    assert event.properties["mcp_server_names"] == servers
+    # Name and batch_key should come from base class
+    assert event.name == "MCPClientEvent"
+    assert event.batch_key == event.name
+    assert event.batch_size == 1
+
+
+def test_openai_event_properties():
+    """Test that OpenAIEvent initializes correctly with given data."""
+    from datu.app_config import get_app_settings  # pylint: disable=import-outside-toplevel
+
+    app_settings = get_app_settings()
+    data = {"user_id": "123", "action": "test"}
+    event = OpenAIEvent(**data)
+
+    # Base properties
+    for k, v in data.items():
+        assert event.properties[k] == v
+
+    # Extra property added in subclass
+    assert event.properties["openai_model"] == app_settings.openai_model
+
+    # Name and batch_key
+    assert event.name == "OpenAIEvent"
+    assert event.batch_key == event.name
+    assert event.batch_size == 1
diff --git a/tests/telemetry/product/test_posthog.py b/tests/telemetry/product/test_posthog.py
new file mode 100644
index 0000000..5e83e8b
--- /dev/null
+++ b/tests/telemetry/product/test_posthog.py
@@ -0,0 +1,112 @@
+"""Tests for PostHog telemetry client."""
+
+from pathlib import Path
+from unittest.mock import patch
+
+# pylint: disable=import-outside-toplevel,redefined-outer-name,unused-argument
+
+
+def test_posthog_client_initialization(posthog_client, telemetry_settings):
+    """Test that PostHogClient initializes correctly."""
+    from datu.telemetry.product.posthog import PostHogClient
+
+    assert posthog_client.settings == telemetry_settings
+    assert isinstance(posthog_client._batched_events, dict)
+    assert isinstance(posthog_client.session_id, str)
+    assert posthog_client._user_id == ""
+    assert posthog_client._user_id_path == PostHogClient.USER_ID_PATH
+
+
+def test_user_id_creation(tmp_path):
+    """Test that user ID is created and read correctly."""
+    from datu.telemetry.config import TelemetryConfig
+    from datu.telemetry.product.posthog import PostHogClient
+
+    path = tmp_path / "telemetry_user_id"
+    client = PostHogClient(telemetry_settings=TelemetryConfig())
+    client._user_id_path = path
+
+    # file does not exist yet
+    uid = client.user_id
+    assert uid != PostHogClient.UNKNOWN_USER_ID
+    assert path.read_text().strip() == uid
+
+    # file exists, reads the same
+    uid2 = client.user_id
+    assert uid2 == uid
+
+
+def test_user_id_fallback_patch():
+    """Test that user ID falls back to unknown when file access fails."""
+    from datu.telemetry.config import TelemetryConfig
+    from datu.telemetry.product.posthog import PostHogClient
+
+    client = PostHogClient(telemetry_settings=TelemetryConfig())
+
+    with (
+        patch.object(Path, "exists", side_effect=OSError("fail")),
+        patch.object(Path, "read_text", side_effect=OSError("fail")),
+    ):
+        uid = client.user_id
+
+    assert uid == PostHogClient.UNKNOWN_USER_ID
+
+
+def test_base_context(monkeypatch):
+    """Test that base context is created correctly."""
+    from datu.telemetry.config import TelemetryConfig
+    from datu.telemetry.product.posthog import PostHogClient
+
+    client = PostHogClient(telemetry_settings=TelemetryConfig(package_name="nonexistent_pkg"))
+
+    context = client._base_context()
+    assert "python_version" in context
+    assert "os" in context
+    assert "os_version" in context
+    assert context["package_version"] == "unknown"
+    assert isinstance(context["extras_installed"], dict)
+
+
+def test_capture_single_event(monkeypatch, posthog_client, sample_event):
+    """Test that a single event is captured correctly."""
+    # Ensure _send is called
+    called = {}
+
+    def fake_send(event):
+        called["event_name"] = event.name
+
+    posthog_client._send = fake_send
+
+    posthog_client.capture(sample_event)
+    assert called["event_name"] == sample_event.name
+
+
+def test_capture_batching(monkeypatch):
+    """Test that event batching works correctly."""
+    from datu.telemetry.config import TelemetryConfig
+    from datu.telemetry.product.events import ProductTelemetryEvent
+    from datu.telemetry.product.posthog import PostHogClient
+
+    settings = TelemetryConfig(api_key="dummy")
+    client = PostHogClient(telemetry_settings=settings)
+
+    class BatchEvent(ProductTelemetryEvent):
+        max_batch_size = 2
+
+    e1 = BatchEvent(foo=1)
+    e2 = BatchEvent(foo=2)
+
+    sent = []
+
+    def fake_send(event):
+        sent.append(event)
+
+    client._send = fake_send
+
+    client.capture(e1)
+    assert client._batched_events[e1.batch_key].batch_size == 1
+    assert sent == []
+
+    client.capture(e2)
+    assert sent[0].batch_size == 2
+    assert e1.batch_key not in client._batched_events
diff --git a/tests/telemetry/test_config.py b/tests/telemetry/test_config.py
new file mode 100644
index 0000000..6ca523d
--- /dev/null
+++ b/tests/telemetry/test_config.py
@@ -0,0 +1,10 @@
+"""Tests for telemetry configuration."""
+
+
+def test_config():
+    """Test that telemetry settings are loaded correctly."""
+    from datu.telemetry.config import get_telemetry_settings  # pylint: disable=import-outside-toplevel
+
+    settings = get_telemetry_settings()
+    assert settings.api_key == "phc_m74dfR9nLpm2nipvkL2swyFDtNuQNC9o2FL2CSbh6Je"
+    assert settings.package_name == "datu-core"