From 36fc979460f03a2fab3a6cd9e792037ce1f074ad Mon Sep 17 00:00:00 2001 From: Gary <59334078+garrettallen14@users.noreply.github.com> Date: Wed, 1 Apr 2026 12:44:22 -0700 Subject: [PATCH 1/3] feat: cleanup + refactor instrumentation test package --- src/layerlens/instrument/_capture_config.py | 12 + src/layerlens/instrument/_collector.py | 44 +-- src/layerlens/instrument/_context.py | 14 +- src/layerlens/instrument/_decorator.py | 8 +- src/layerlens/instrument/_span.py | 4 +- .../adapters/frameworks/_base_framework.py | 158 +++++--- .../adapters/frameworks/langchain.py | 56 ++- .../adapters/frameworks/langgraph.py | 4 +- .../adapters/providers/_base_provider.py | 189 +++++----- .../adapters/providers/_emit_helpers.py | 100 +++++ .../adapters/providers/anthropic.py | 151 +++----- .../instrument/adapters/providers/litellm.py | 110 ++---- .../instrument/adapters/providers/openai.py | 145 ++------ tests/instrument/adapters/__init__.py | 0 .../adapters/frameworks/__init__.py | 0 .../adapters/frameworks/conftest.py | 30 ++ .../adapters/frameworks/test_langchain.py | 345 ++++++++++++++++++ .../adapters/frameworks/test_langgraph.py | 188 ++++++++++ .../instrument/adapters/providers/__init__.py | 0 .../instrument/adapters/providers/conftest.py | 100 +++++ .../adapters/providers/test_anthropic.py | 242 ++++++++++++ .../adapters/providers/test_litellm.py | 263 +++++++++++++ .../adapters/providers/test_openai.py | 244 +++++++++++++ .../{ => adapters}/test_registry.py | 0 tests/instrument/test_adapters.py | 167 --------- tests/instrument/test_capture_config.py | 4 +- tests/instrument/test_core.py | 9 +- tests/instrument/test_providers.py | 220 ----------- tests/instrument/test_types.py | 2 +- 29 files changed, 1911 insertions(+), 898 deletions(-) create mode 100644 src/layerlens/instrument/adapters/providers/_emit_helpers.py create mode 100644 tests/instrument/adapters/__init__.py create mode 100644 tests/instrument/adapters/frameworks/__init__.py create mode 100644 
tests/instrument/adapters/frameworks/conftest.py create mode 100644 tests/instrument/adapters/frameworks/test_langchain.py create mode 100644 tests/instrument/adapters/frameworks/test_langgraph.py create mode 100644 tests/instrument/adapters/providers/__init__.py create mode 100644 tests/instrument/adapters/providers/conftest.py create mode 100644 tests/instrument/adapters/providers/test_anthropic.py create mode 100644 tests/instrument/adapters/providers/test_litellm.py create mode 100644 tests/instrument/adapters/providers/test_openai.py rename tests/instrument/{ => adapters}/test_registry.py (100%) delete mode 100644 tests/instrument/test_adapters.py delete mode 100644 tests/instrument/test_providers.py diff --git a/src/layerlens/instrument/_capture_config.py b/src/layerlens/instrument/_capture_config.py index 837cc5a..5381123 100644 --- a/src/layerlens/instrument/_capture_config.py +++ b/src/layerlens/instrument/_capture_config.py @@ -89,6 +89,18 @@ class CaptureConfig: # Gates LLM message content (prompts/completions) independently of L-layers capture_content: bool = True + def redact_payload( + self, event_type: str, payload: Dict[str, Any] + ) -> Dict[str, Any]: + """Return a copy of payload with fields removed per config.""" + if not self.capture_content and event_type == "model.invoke": + payload = { + k: v + for k, v in payload.items() + if k not in ("messages", "output_message") + } + return payload + def is_layer_enabled(self, event_type: str) -> bool: """Check if an event type is enabled by this config. 
diff --git a/src/layerlens/instrument/_collector.py b/src/layerlens/instrument/_collector.py index ba97373..031576f 100644 --- a/src/layerlens/instrument/_collector.py +++ b/src/layerlens/instrument/_collector.py @@ -44,13 +44,7 @@ def emit( if not self._config.is_layer_enabled(event_type): return - # Strip LLM message content when capture_content is off - if not self._config.capture_content and event_type == "model.invoke": - payload = { - k: v - for k, v in payload.items() - if k not in ("messages", "output_message") - } + payload = self._config.redact_payload(event_type, payload) self._sequence += 1 event: Dict[str, Any] = { @@ -66,11 +60,8 @@ def emit( self._chain.add_event(event) self._events.append(event) - def flush(self) -> None: - """Build attestation and upload the trace.""" - if not self._events: - return - + def _build_trace_payload(self) -> Dict[str, Any]: + """Build the attestation envelope and trace payload.""" try: trial = self._chain.finalize() attestation: Dict[str, Any] = { @@ -82,34 +73,21 @@ def flush(self) -> None: log.warning("Failed to build attestation chain", exc_info=True) attestation = {"attestation_error": str(exc)} - payload = { + return { "trace_id": self._trace_id, "events": self._events, "capture_config": self._config.to_dict(), "attestation": attestation, } - upload_trace(self._client, payload) + + def flush(self) -> None: + """Build attestation and upload the trace.""" + if not self._events: + return + upload_trace(self._client, self._build_trace_payload()) async def async_flush(self) -> None: """Async version of flush.""" if not self._events: return - - try: - trial = self._chain.finalize() - attestation: Dict[str, Any] = { - "chain": self._chain.to_dict(), - "root_hash": trial.hash, - "schema_version": "1.0", - } - except Exception as exc: - log.warning("Failed to build attestation chain", exc_info=True) - attestation = {"attestation_error": str(exc)} - - payload = { - "trace_id": self._trace_id, - "events": self._events, - 
"capture_config": self._config.to_dict(), - "attestation": attestation, - } - await async_upload_trace(self._client, payload) + await async_upload_trace(self._client, self._build_trace_payload()) diff --git a/src/layerlens/instrument/_context.py b/src/layerlens/instrument/_context.py index 0587a95..dc1f873 100644 --- a/src/layerlens/instrument/_context.py +++ b/src/layerlens/instrument/_context.py @@ -11,24 +11,24 @@ _current_span_name: ContextVar[Optional[str]] = ContextVar("_current_span_name", default=None) -class _SpanTokens(NamedTuple): +class _SpanSnapshot(NamedTuple): span_id: Any parent_span_id: Any span_name: Any -def _push_span(span_id: str, name: Optional[str] = None) -> _SpanTokens: +def _push_span(span_id: str, name: Optional[str] = None) -> _SpanSnapshot: """Push a new span onto the context stack. The current span becomes the parent.""" old_span_id = _current_span_id.get() - return _SpanTokens( + return _SpanSnapshot( span_id=_current_span_id.set(span_id), parent_span_id=_parent_span_id.set(old_span_id), span_name=_current_span_name.set(name), ) -def _pop_span(tokens: _SpanTokens) -> None: +def _pop_span(snapshot: _SpanSnapshot) -> None: """Restore the previous span context.""" - _current_span_name.reset(tokens.span_name) - _parent_span_id.reset(tokens.parent_span_id) - _current_span_id.reset(tokens.span_id) + _current_span_name.reset(snapshot.span_name) + _parent_span_id.reset(snapshot.parent_span_id) + _current_span_id.reset(snapshot.span_id) diff --git a/src/layerlens/instrument/_decorator.py b/src/layerlens/instrument/_decorator.py index bfaf570..b4a118c 100644 --- a/src/layerlens/instrument/_decorator.py +++ b/src/layerlens/instrument/_decorator.py @@ -29,7 +29,7 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: root_span_id = uuid.uuid4().hex[:16] col_token = _current_collector.set(collector) - span_tokens = _push_span(root_span_id, span_name) + span_snapshot = _push_span(root_span_id, span_name) try: collector.emit( "agent.input", 
@@ -58,7 +58,7 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: await collector.async_flush() raise finally: - _pop_span(span_tokens) + _pop_span(span_snapshot) _current_collector.reset(col_token) return async_wrapper @@ -71,7 +71,7 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any: root_span_id = uuid.uuid4().hex[:16] col_token = _current_collector.set(collector) - span_tokens = _push_span(root_span_id, span_name) + span_snapshot = _push_span(root_span_id, span_name) try: collector.emit( "agent.input", @@ -100,7 +100,7 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any: collector.flush() raise finally: - _pop_span(span_tokens) + _pop_span(span_snapshot) _current_collector.reset(col_token) return sync_wrapper diff --git a/src/layerlens/instrument/_span.py b/src/layerlens/instrument/_span.py index 9eb239f..0ea2ecd 100644 --- a/src/layerlens/instrument/_span.py +++ b/src/layerlens/instrument/_span.py @@ -18,8 +18,8 @@ def span(name: str) -> Generator[str, None, None]: Yields the span_id string. """ new_span_id = uuid.uuid4().hex[:16] - tokens = _push_span(new_span_id, name) + snapshot = _push_span(new_span_id, name) try: yield new_span_id finally: - _pop_span(tokens) + _pop_span(snapshot) diff --git a/src/layerlens/instrument/adapters/frameworks/_base_framework.py b/src/layerlens/instrument/adapters/frameworks/_base_framework.py index af082cc..197c65e 100644 --- a/src/layerlens/instrument/adapters/frameworks/_base_framework.py +++ b/src/layerlens/instrument/adapters/frameworks/_base_framework.py @@ -1,79 +1,137 @@ +"""Unified base class for all framework adapters. + +Framework adapters hook into a framework's callback / event / tracing +system and emit LayerLens events. They share a common lifecycle: + + 1. Lazy-init a :class:`TraceCollector` on first event. + 2. Emit events through a thread-safe helper. + 3. Flush the collector when a logical trace ends (root span completes, + agent run finishes, disconnect, etc.). 
+ +Subclasses MUST set ``name`` and implement ``connect()``. +Subclasses SHOULD call ``super().disconnect()`` after unhooking. +""" from __future__ import annotations import uuid -from uuid import UUID -from typing import Any, Dict, Optional, Tuple +import threading +from typing import Any, Dict, Optional from .._base import AdapterInfo, BaseAdapter -from ..._capture_config import CaptureConfig from ..._collector import TraceCollector +from ..._capture_config import CaptureConfig -class FrameworkTracer(BaseAdapter): - """Base class for framework adapters that manage their own collector. - - Framework adapters (LangChain, LangGraph, etc.) receive callbacks - from the framework rather than wrapping SDK methods. They maintain - their own TraceCollector and map framework run_ids to span_ids. - """ +class FrameworkAdapter(BaseAdapter): + """Base for framework adapters with collector lifecycle management.""" - _adapter_name: str = "framework" + name: str # Subclass must set: "crewai", "llamaindex", etc. 
def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: - self._client: Any = None + self._client = client self._config = capture_config or CaptureConfig.standard() + self._lock = threading.Lock() + self._connected = False self._collector: Optional[TraceCollector] = None + self._root_span_id: Optional[str] = None + # Optional run_id → span_id mapping for callback-style frameworks self._span_ids: Dict[str, str] = {} - self._root_run_id: Optional[str] = None - self.connect(client) - def connect(self, target: Any = None, **kwargs: Any) -> Any: # noqa: ARG002 - self._client = target - return target - - def disconnect(self) -> None: - self._span_ids.clear() - self._root_run_id = None - self._collector = None - - def adapter_info(self) -> AdapterInfo: - return AdapterInfo( - name=self._adapter_name, - adapter_type="framework", - connected=self._client is not None, - ) + # ------------------------------------------------------------------ + # Collector lifecycle + # ------------------------------------------------------------------ def _ensure_collector(self) -> TraceCollector: + """Lazily create a collector and root span ID.""" if self._collector is None: self._collector = TraceCollector(self._client, self._config) + self._root_span_id = uuid.uuid4().hex[:16] return self._collector - def _get_or_create_span_id( - self, run_id: UUID, parent_run_id: Optional[UUID] = None - ) -> Tuple[str, Optional[str]]: - rid = str(run_id) - if rid not in self._span_ids: - self._span_ids[rid] = uuid.uuid4().hex[:16] - span_id = self._span_ids[rid] - parent_span_id = self._span_ids.get(str(parent_run_id)) if parent_run_id else None - if self._root_run_id is None: - self._root_run_id = rid - return span_id, parent_span_id + @staticmethod + def _new_span_id() -> str: + return uuid.uuid4().hex[:16] + + # ------------------------------------------------------------------ + # Event emission (thread-safe) + # 
------------------------------------------------------------------ def _emit( self, event_type: str, payload: Dict[str, Any], - run_id: UUID, - parent_run_id: Optional[UUID] = None, + span_id: Optional[str] = None, + parent_span_id: Optional[str] = None, + span_name: Optional[str] = None, ) -> None: - collector = self._ensure_collector() - span_id, parent_span_id = self._get_or_create_span_id(run_id, parent_run_id) - collector.emit(event_type, payload, span_id=span_id, parent_span_id=parent_span_id) + """Thread-safe event emission through the collector.""" + with self._lock: + collector = self._ensure_collector() + sid = span_id or self._new_span_id() + parent = parent_span_id or self._root_span_id + collector.emit( + event_type, payload, + span_id=sid, parent_span_id=parent, span_name=span_name, + ) - def _maybe_flush(self, run_id: UUID) -> None: - if str(run_id) == self._root_run_id and self._collector is not None: - self._collector.flush() - self._span_ids.clear() - self._root_run_id = None + # ------------------------------------------------------------------ + # Run ID → span ID mapping (opt-in for callback-style frameworks) + # ------------------------------------------------------------------ + + def _span_id_for(self, run_id: Any, parent_run_id: Any = None) -> tuple[str, Optional[str]]: + """Map a framework run_id to a span_id, creating one if needed. + + Returns ``(span_id, parent_span_id)``. Useful for frameworks + (LangChain, CrewAI, OpenAI Agents) that assign their own run + identifiers to each step. 
+ """ + rid = str(run_id) + if rid not in self._span_ids: + self._span_ids[rid] = self._new_span_id() + span_id = self._span_ids[rid] + parent_span_id = self._span_ids.get(str(parent_run_id)) if parent_run_id else None + return span_id, parent_span_id + + # ------------------------------------------------------------------ + # Flush + # ------------------------------------------------------------------ + + def _flush_collector(self) -> None: + """Flush the current collector and reset state.""" + with self._lock: + collector = self._collector self._collector = None + self._root_span_id = None + self._span_ids.clear() + if collector is not None: + collector.flush() + + # ------------------------------------------------------------------ + # BaseAdapter interface + # ------------------------------------------------------------------ + + def connect(self, target: Any = None, **kwargs: Any) -> Any: + """Mark the adapter as connected. + + Callback-style adapters (LangChain, LangGraph) are passed directly + to the framework, so ``connect()`` just flips the flag. Adapters + that need registration (CrewAI, LlamaIndex, etc.) should override. + """ + self._connected = True + return target + + def disconnect(self) -> None: + """Flush remaining events and mark as disconnected. + + Subclasses should unhook from the framework first, then call + ``super().disconnect()``. 
+ """ + self._flush_collector() + self._connected = False + + def adapter_info(self) -> AdapterInfo: + return AdapterInfo( + name=self.name, + adapter_type="framework", + connected=self._connected, + ) diff --git a/src/layerlens/instrument/adapters/frameworks/langchain.py b/src/layerlens/instrument/adapters/frameworks/langchain.py index bfa1784..5b14f0e 100644 --- a/src/layerlens/instrument/adapters/frameworks/langchain.py +++ b/src/layerlens/instrument/adapters/frameworks/langchain.py @@ -3,7 +3,7 @@ from uuid import UUID from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence -from ._base_framework import FrameworkTracer +from ._base_framework import FrameworkAdapter if TYPE_CHECKING: from ..._capture_config import CaptureConfig @@ -20,12 +20,32 @@ def __init_subclass__(cls, **kwargs: Any) -> None: ) -class LangChainCallbackHandler(BaseCallbackHandler, FrameworkTracer): - _adapter_name: str = "langchain" +class LangChainCallbackHandler(BaseCallbackHandler, FrameworkAdapter): + name = "langchain" def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: BaseCallbackHandler.__init__(self) - FrameworkTracer.__init__(self, client, capture_config=capture_config) + FrameworkAdapter.__init__(self, client, capture_config=capture_config) + self._root_run_id: Optional[str] = None + + def _emit_for_run( + self, + event_type: str, + payload: Dict[str, Any], + run_id: UUID, + parent_run_id: Optional[UUID] = None, + ) -> None: + """Emit an event, mapping framework run_ids to span_ids.""" + span_id, parent_span_id = self._span_id_for(run_id, parent_run_id) + rid = str(run_id) + if self._root_run_id is None: + self._root_run_id = rid + self._emit(event_type, payload, span_id=span_id, parent_span_id=parent_span_id) + + def _maybe_flush(self, run_id: UUID) -> None: + if str(run_id) == self._root_run_id and self._collector is not None: + self._flush_collector() + self._root_run_id = None # -- Chain -- @@ -40,7 +60,7 @@ def 
on_chain_start( ) -> None: serialized = serialized or {} name = serialized.get("name") or serialized.get("id", ["unknown"])[-1] - self._emit("agent.input", {"name": name, "input": inputs}, run_id, parent_run_id) + self._emit_for_run("agent.input", {"name": name, "input": inputs}, run_id, parent_run_id) def on_chain_end( self, @@ -50,7 +70,7 @@ def on_chain_end( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit("agent.output", {"output": outputs, "status": "ok"}, run_id) + self._emit_for_run("agent.output", {"output": outputs, "status": "ok"}, run_id) self._maybe_flush(run_id) def on_chain_error( @@ -61,7 +81,7 @@ def on_chain_error( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit("agent.error", {"error": str(error), "status": "error"}, run_id) + self._emit_for_run("agent.error", {"error": str(error), "status": "error"}, run_id) self._maybe_flush(run_id) # -- LLM -- @@ -77,7 +97,7 @@ def on_llm_start( ) -> None: serialized = serialized or {} name = serialized.get("name") or serialized.get("id", ["unknown"])[-1] - self._emit("model.invoke", {"name": name, "messages": prompts}, run_id, parent_run_id) + self._emit_for_run("model.invoke", {"name": name, "messages": prompts}, run_id, parent_run_id) def on_chat_model_start( self, @@ -90,7 +110,7 @@ def on_chat_model_start( ) -> None: serialized = serialized or {} name = serialized.get("name") or serialized.get("id", ["unknown"])[-1] - self._emit( + self._emit_for_run( "model.invoke", {"name": name, "messages": [[_serialize_lc_message(m) for m in batch] for batch in messages]}, run_id, @@ -120,7 +140,7 @@ def on_llm_end( model_name = llm_output.get("model_name") if model_name or output: - self._emit( + self._emit_for_run( "model.invoke", {"model": model_name, "output_message": output}, run_id, @@ -129,7 +149,7 @@ def on_llm_end( usage = llm_output.get("token_usage", {}) if usage: - self._emit("cost.record", usage, run_id, parent_run_id) + 
self._emit_for_run("cost.record", usage, run_id, parent_run_id) self._maybe_flush(run_id) @@ -141,7 +161,7 @@ def on_llm_error( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit("agent.error", {"error": str(error), "status": "error"}, run_id) + self._emit_for_run("agent.error", {"error": str(error), "status": "error"}, run_id) self._maybe_flush(run_id) # -- Tool -- @@ -156,7 +176,7 @@ def on_tool_start( **kwargs: Any, ) -> None: name = (serialized or {}).get("name", "tool") - self._emit("tool.call", {"name": name, "input": input_str}, run_id, parent_run_id) + self._emit_for_run("tool.call", {"name": name, "input": input_str}, run_id, parent_run_id) def on_tool_end( self, @@ -166,7 +186,7 @@ def on_tool_end( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit("tool.result", {"output": output}, run_id) + self._emit_for_run("tool.result", {"output": output}, run_id) self._maybe_flush(run_id) def on_tool_error( @@ -177,7 +197,7 @@ def on_tool_error( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit("agent.error", {"error": str(error), "status": "error"}, run_id) + self._emit_for_run("agent.error", {"error": str(error), "status": "error"}, run_id) self._maybe_flush(run_id) # -- Retriever -- @@ -192,7 +212,7 @@ def on_retriever_start( **kwargs: Any, ) -> None: name = (serialized or {}).get("name", "retriever") - self._emit("tool.call", {"name": name, "input": query}, run_id, parent_run_id) + self._emit_for_run("tool.call", {"name": name, "input": query}, run_id, parent_run_id) def on_retriever_end( self, @@ -203,7 +223,7 @@ def on_retriever_end( **kwargs: Any, ) -> None: output = [_serialize_lc_document(d) for d in documents] - self._emit("tool.result", {"output": output}, run_id) + self._emit_for_run("tool.result", {"output": output}, run_id) self._maybe_flush(run_id) def on_retriever_error( @@ -214,7 +234,7 @@ def on_retriever_error( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> 
None: - self._emit("agent.error", {"error": str(error), "status": "error"}, run_id) + self._emit_for_run("agent.error", {"error": str(error), "status": "error"}, run_id) self._maybe_flush(run_id) # -- Text (required by base) -- diff --git a/src/layerlens/instrument/adapters/frameworks/langgraph.py b/src/layerlens/instrument/adapters/frameworks/langgraph.py index 47d2439..f4b666a 100644 --- a/src/layerlens/instrument/adapters/frameworks/langgraph.py +++ b/src/layerlens/instrument/adapters/frameworks/langgraph.py @@ -7,7 +7,7 @@ class LangGraphCallbackHandler(LangChainCallbackHandler): - _adapter_name: str = "langgraph" + name = "langgraph" def on_chain_start( self, @@ -38,4 +38,4 @@ def on_chain_start( if node_name: name = node_name - self._emit("agent.input", {"name": name, "input": inputs}, run_id, parent_run_id) + self._emit_for_run("agent.input", {"name": name, "input": inputs}, run_id, parent_run_id) diff --git a/src/layerlens/instrument/adapters/providers/_base_provider.py b/src/layerlens/instrument/adapters/providers/_base_provider.py index cd02bb2..a109c16 100644 --- a/src/layerlens/instrument/adapters/providers/_base_provider.py +++ b/src/layerlens/instrument/adapters/providers/_base_provider.py @@ -1,96 +1,101 @@ from __future__ import annotations -import uuid -from typing import Any, Dict, Callable - -from ..._context import _current_collector, _current_span_id - - -def emit_llm_events( - name: str, - kwargs: Dict[str, Any], - response: Any, - extract_output: Callable[[Any], Any], - extract_meta: Callable[[Any], Dict[str, Any]], - capture_params: frozenset[str], - latency_ms: float, -) -> None: - """Emit model.invoke + cost.record events for an LLM call. - - Builds the full payload -- the collector handles CaptureConfig gating - (L3 suppresses model.invoke entirely, capture_content strips messages). 
- """ - collector = _current_collector.get() - if collector is None: - return - - parent_span_id = _current_span_id.get() - span_id = uuid.uuid4().hex[:16] - response_meta = extract_meta(response) - - collector.emit( - "model.invoke", - { - "name": name, - "latency_ms": latency_ms, - "parameters": {k: kwargs[k] for k in capture_params if k in kwargs}, - "messages": _extract_messages(kwargs), - "output_message": extract_output(response), - **response_meta, - }, - span_id=span_id, - parent_span_id=parent_span_id, - ) - - usage = response_meta.get("usage", {}) - if usage: - collector.emit( - "cost.record", - { - "provider": name.split(".")[0], - "model": response_meta.get("response_model", kwargs.get("model")), - **usage, - }, - span_id=span_id, - parent_span_id=parent_span_id, - ) +import abc +import time +import logging +from typing import Any, Dict + +from .._base import AdapterInfo, BaseAdapter +from ._emit_helpers import emit_llm_events, emit_llm_error +from ..._context import _current_collector + +log: logging.Logger = logging.getLogger(__name__) + + +class MonkeyPatchProvider(BaseAdapter): + """Base for providers that monkey-patch SDK client or module methods.""" + + name: str + capture_params: frozenset[str] + + def __init__(self) -> None: + self._client: Any = None + self._originals: Dict[str, Any] = {} + + @staticmethod + @abc.abstractmethod + def extract_output(response: Any) -> Any: ... + @staticmethod + @abc.abstractmethod + def extract_meta(response: Any) -> Dict[str, Any]: ... 
-def emit_llm_error( - name: str, - error: Exception, - latency_ms: float, -) -> None: - """Emit agent.error event for a failed LLM call.""" - collector = _current_collector.get() - parent_span_id = _current_span_id.get() - if collector is None: - return - - span_id = uuid.uuid4().hex[:16] - collector.emit( - "agent.error", - {"name": name, "error": str(error), "latency_ms": latency_ms}, - span_id=span_id, - parent_span_id=parent_span_id, - ) - - -def _extract_messages(kwargs: Dict[str, Any]) -> Any: - messages = kwargs.get("messages") - if messages is not None: - return [_serialize_message(m) for m in messages] - for key in ("prompt", "contents", "input"): - val = kwargs.get(key) - if val is not None: - return val - return None - - -def _serialize_message(msg: Any) -> Any: - if isinstance(msg, dict): - return msg - try: - return {"role": msg.role, "content": msg.content} - except AttributeError: - return str(msg) + def _wrap_sync(self, event_name: str, original: Any) -> Any: + extract_output = self.extract_output + extract_meta = self.extract_meta + capture_params = self.capture_params + + def wrapped(*args: Any, **kwargs: Any) -> Any: + if _current_collector.get() is None: + return original(*args, **kwargs) + start = time.time() + try: + response = original(*args, **kwargs) + except Exception as exc: + latency_ms = (time.time() - start) * 1000 + emit_llm_error(event_name, exc, latency_ms) + raise + latency_ms = (time.time() - start) * 1000 + emit_llm_events( + event_name, kwargs, response, + extract_output, extract_meta, capture_params, latency_ms, + ) + return response + + return wrapped + + def _wrap_async(self, event_name: str, original: Any) -> Any: + extract_output = self.extract_output + extract_meta = self.extract_meta + capture_params = self.capture_params + + async def wrapped(*args: Any, **kwargs: Any) -> Any: + if _current_collector.get() is None: + return await original(*args, **kwargs) + start = time.time() + try: + response = await original(*args, 
**kwargs) + except Exception as exc: + latency_ms = (time.time() - start) * 1000 + emit_llm_error(event_name, exc, latency_ms) + raise + latency_ms = (time.time() - start) * 1000 + emit_llm_events( + event_name, kwargs, response, + extract_output, extract_meta, capture_params, latency_ms, + ) + return response + + return wrapped + + def disconnect(self) -> None: + if self._client is None: + return + for key, orig in self._originals.items(): + try: + parts = key.split(".") + obj = self._client + for part in parts[:-1]: + obj = getattr(obj, part) + setattr(obj, parts[-1], orig) + except Exception: + log.warning("Could not restore %s", key) + self._client = None + self._originals.clear() + + def adapter_info(self) -> AdapterInfo: + return AdapterInfo( + name=self.name, + adapter_type="provider", + connected=self._client is not None, + ) diff --git a/src/layerlens/instrument/adapters/providers/_emit_helpers.py b/src/layerlens/instrument/adapters/providers/_emit_helpers.py new file mode 100644 index 0000000..d46a9ed --- /dev/null +++ b/src/layerlens/instrument/adapters/providers/_emit_helpers.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import uuid +from typing import Any, Dict, Callable + +from ..._context import _current_collector, _current_span_id + + +def emit_llm_events( + name: str, + kwargs: Dict[str, Any], + response: Any, + extract_output: Callable[[Any], Any], + extract_meta: Callable[[Any], Dict[str, Any]], + capture_params: frozenset[str], + latency_ms: float, +) -> None: + """Emit model.invoke + cost.record events for an LLM call. + + Builds the full payload -- the collector handles CaptureConfig gating + (L3 suppresses model.invoke entirely, capture_content strips messages). 
+ """ + collector = _current_collector.get() + if collector is None: + return + + parent_span_id = _current_span_id.get() + span_id = uuid.uuid4().hex[:16] + response_meta = extract_meta(response) + + # Resolve model name: prefer response_model (actual model used), fall back to kwargs + model_name = response_meta.get("response_model") or kwargs.get("model") + + collector.emit( + "model.invoke", + { + "name": name, + "model": model_name, + "latency_ms": latency_ms, + "parameters": {k: kwargs[k] for k in capture_params if k in kwargs}, + "messages": _extract_messages(kwargs), + "output_message": extract_output(response), + **response_meta, + }, + span_id=span_id, + parent_span_id=parent_span_id, + ) + + usage = response_meta.get("usage", {}) + if usage: + collector.emit( + "cost.record", + { + "provider": name.split(".")[0], + "model": response_meta.get("response_model", kwargs.get("model")), + **usage, + }, + span_id=span_id, + parent_span_id=parent_span_id, + ) + + +def emit_llm_error( + name: str, + error: Exception, + latency_ms: float, +) -> None: + """Emit agent.error event for a failed LLM call.""" + collector = _current_collector.get() + parent_span_id = _current_span_id.get() + if collector is None: + return + + span_id = uuid.uuid4().hex[:16] + collector.emit( + "agent.error", + {"name": name, "error": str(error), "latency_ms": latency_ms}, + span_id=span_id, + parent_span_id=parent_span_id, + ) + + +def _extract_messages(kwargs: Dict[str, Any]) -> Any: + messages = kwargs.get("messages") + if messages is not None: + return [_serialize_message(m) for m in messages] + for key in ("prompt", "contents", "input"): + val = kwargs.get(key) + if val is not None: + return val + return None + + +def _serialize_message(msg: Any) -> Any: + if isinstance(msg, dict): + return msg + try: + return {"role": msg.role, "content": msg.content} + except AttributeError: + return str(msg) diff --git a/src/layerlens/instrument/adapters/providers/anthropic.py 
b/src/layerlens/instrument/adapters/providers/anthropic.py index 675b558..0a2b17d 100644 --- a/src/layerlens/instrument/adapters/providers/anthropic.py +++ b/src/layerlens/instrument/adapters/providers/anthropic.py @@ -1,14 +1,8 @@ from __future__ import annotations -import time -import logging from typing import Any, Dict -from .._base import AdapterInfo, BaseAdapter -from ._base_provider import emit_llm_events, emit_llm_error -from ..._context import _current_collector - -log: logging.Logger = logging.getLogger(__name__) +from ._base_provider import MonkeyPatchProvider _CAPTURE_PARAMS = frozenset( { @@ -23,10 +17,42 @@ ) -class AnthropicProvider(BaseAdapter): - def __init__(self) -> None: - self._client: Any = None - self._originals: Dict[str, Any] = {} +class AnthropicProvider(MonkeyPatchProvider): + name = "anthropic" + capture_params = _CAPTURE_PARAMS + + @staticmethod + def extract_output(response: Any) -> Any: + try: + content = response.content + if content: + block = content[0] + return {"type": block.type, "text": getattr(block, "text", None)} + except (AttributeError, IndexError): + pass + return None + + @staticmethod + def extract_meta(response: Any) -> Dict[str, Any]: + meta: Dict[str, Any] = {} + try: + usage = response.usage + if usage is not None: + meta["usage"] = { + "input_tokens": usage.input_tokens, + "output_tokens": usage.output_tokens, + } + except AttributeError: + pass + try: + meta["response_model"] = response.model + except AttributeError: + pass + try: + meta["stop_reason"] = response.stop_reason + except AttributeError: + pass + return meta def connect(self, target: Any = None, **kwargs: Any) -> Any: # noqa: ARG002 self._client = target @@ -34,110 +60,19 @@ def connect(self, target: Any = None, **kwargs: Any) -> Any: # noqa: ARG002 if hasattr(target, "messages"): orig = target.messages.create self._originals["messages.create"] = orig - target.messages.create = self._wrap_sync(orig) + target.messages.create = self._wrap_sync( + 
"anthropic.messages.create", orig + ) if hasattr(target.messages, "acreate"): async_orig = target.messages.acreate self._originals["messages.acreate"] = async_orig - target.messages.acreate = self._wrap_async(async_orig) + target.messages.acreate = self._wrap_async( + "anthropic.messages.create", async_orig + ) return target - def disconnect(self) -> None: - if self._client is None: - return - for key, orig in self._originals.items(): - try: - parts = key.split(".") - obj = self._client - for part in parts[:-1]: - obj = getattr(obj, part) - setattr(obj, parts[-1], orig) - except Exception: - log.warning("Could not restore %s", key) - self._client = None - self._originals.clear() - - def adapter_info(self) -> AdapterInfo: - return AdapterInfo( - name="anthropic", - adapter_type="provider", - connected=self._client is not None, - ) - - def _wrap_sync(self, original: Any) -> Any: - def wrapped(*args: Any, **kwargs: Any) -> Any: - if _current_collector.get() is None: - return original(*args, **kwargs) - start = time.time() - try: - response = original(*args, **kwargs) - except Exception as exc: - latency_ms = (time.time() - start) * 1000 - emit_llm_error("anthropic.messages.create", exc, latency_ms) - raise - latency_ms = (time.time() - start) * 1000 - emit_llm_events( - "anthropic.messages.create", kwargs, response, - _extract_output, _extract_response_meta, _CAPTURE_PARAMS, latency_ms, - ) - return response - - return wrapped - - def _wrap_async(self, original: Any) -> Any: - async def wrapped(*args: Any, **kwargs: Any) -> Any: - if _current_collector.get() is None: - return await original(*args, **kwargs) - start = time.time() - try: - response = await original(*args, **kwargs) - except Exception as exc: - latency_ms = (time.time() - start) * 1000 - emit_llm_error("anthropic.messages.create", exc, latency_ms) - raise - latency_ms = (time.time() - start) * 1000 - emit_llm_events( - "anthropic.messages.create", kwargs, response, - _extract_output, 
_extract_response_meta, _CAPTURE_PARAMS, latency_ms, - ) - return response - - return wrapped - - -def _extract_output(response: Any) -> Any: - try: - content = response.content - if content: - block = content[0] - return {"type": block.type, "text": getattr(block, "text", None)} - except (AttributeError, IndexError): - pass - return None - - -def _extract_response_meta(response: Any) -> Dict[str, Any]: - meta: Dict[str, Any] = {} - try: - usage = response.usage - if usage is not None: - meta["usage"] = { - "input_tokens": usage.input_tokens, - "output_tokens": usage.output_tokens, - } - except AttributeError: - pass - try: - meta["response_model"] = response.model - except AttributeError: - pass - try: - meta["stop_reason"] = response.stop_reason - except AttributeError: - pass - return meta - # --- Convenience API --- diff --git a/src/layerlens/instrument/adapters/providers/litellm.py b/src/layerlens/instrument/adapters/providers/litellm.py index e7bbda8..784e7e8 100644 --- a/src/layerlens/instrument/adapters/providers/litellm.py +++ b/src/layerlens/instrument/adapters/providers/litellm.py @@ -1,12 +1,9 @@ from __future__ import annotations -import time -from typing import Any +from typing import Any, Dict -from .._base import AdapterInfo, BaseAdapter -from .openai import _extract_output, _extract_response_meta -from ._base_provider import emit_llm_events, emit_llm_error -from ..._context import _current_collector +from ._base_provider import MonkeyPatchProvider +from .openai import OpenAIProvider _CAPTURE_PARAMS = frozenset( { @@ -21,91 +18,40 @@ ) -class LiteLLMProvider(BaseAdapter): - def __init__(self) -> None: - self._original_completion: Any = None - self._original_acompletion: Any = None - self._connected = False +class LiteLLMProvider(MonkeyPatchProvider): + name = "litellm" + capture_params = _CAPTURE_PARAMS + + @staticmethod + def extract_output(response: Any) -> Any: + return OpenAIProvider.extract_output(response) + + @staticmethod + def 
extract_meta(response: Any) -> Dict[str, Any]: + return OpenAIProvider.extract_meta(response) def connect(self, target: Any = None, **kwargs: Any) -> Any: # noqa: ARG002 try: import litellm except ImportError as err: raise ImportError( - "The 'litellm' package is required for LiteLLM instrumentation. Install it with: pip install litellm" + "The 'litellm' package is required for LiteLLM instrumentation. " + "Install it with: pip install litellm" ) from err - if self._original_completion is None: - self._original_completion = litellm.completion - orig_sync = self._original_completion - - def patched_completion(*args: Any, **kwargs: Any) -> Any: - if _current_collector.get() is None: - return orig_sync(*args, **kwargs) - start = time.time() - try: - response = orig_sync(*args, **kwargs) - except Exception as exc: - latency_ms = (time.time() - start) * 1000 - emit_llm_error("litellm.completion", exc, latency_ms) - raise - latency_ms = (time.time() - start) * 1000 - emit_llm_events( - "litellm.completion", kwargs, response, - _extract_output, _extract_response_meta, _CAPTURE_PARAMS, latency_ms, - ) - return response - - litellm.completion = patched_completion - - if self._original_acompletion is None: - self._original_acompletion = litellm.acompletion - orig_async = self._original_acompletion - - async def patched_acompletion(*args: Any, **kwargs: Any) -> Any: - if _current_collector.get() is None: - return await orig_async(*args, **kwargs) - start = time.time() - try: - response = await orig_async(*args, **kwargs) - except Exception as exc: - latency_ms = (time.time() - start) * 1000 - emit_llm_error("litellm.acompletion", exc, latency_ms) - raise - latency_ms = (time.time() - start) * 1000 - emit_llm_events( - "litellm.acompletion", kwargs, response, - _extract_output, _extract_response_meta, _CAPTURE_PARAMS, latency_ms, - ) - return response - - litellm.acompletion = patched_acompletion - - self._connected = True - return target + self._client = litellm - def 
disconnect(self) -> None: - try: - import litellm - except ImportError: - self._connected = False - return - - if self._original_completion is not None: - litellm.completion = self._original_completion - self._original_completion = None - if self._original_acompletion is not None: - litellm.acompletion = self._original_acompletion - self._original_acompletion = None - - self._connected = False - - def adapter_info(self) -> AdapterInfo: - return AdapterInfo( - name="litellm", - adapter_type="provider", - connected=self._connected, - ) + if "completion" not in self._originals: + orig_sync = litellm.completion + self._originals["completion"] = orig_sync + litellm.completion = self._wrap_sync("litellm.completion", orig_sync) + + if "acompletion" not in self._originals: + orig_async = litellm.acompletion + self._originals["acompletion"] = orig_async + litellm.acompletion = self._wrap_async("litellm.acompletion", orig_async) + + return target # --- Convenience API --- diff --git a/src/layerlens/instrument/adapters/providers/openai.py b/src/layerlens/instrument/adapters/providers/openai.py index bdb8480..d09779f 100644 --- a/src/layerlens/instrument/adapters/providers/openai.py +++ b/src/layerlens/instrument/adapters/providers/openai.py @@ -1,14 +1,8 @@ from __future__ import annotations -import time -import logging from typing import Any, Dict -from .._base import AdapterInfo, BaseAdapter -from ._base_provider import emit_llm_events, emit_llm_error -from ..._context import _current_collector - -log: logging.Logger = logging.getLogger(__name__) +from ._base_provider import MonkeyPatchProvider _CAPTURE_PARAMS = frozenset( { @@ -24,10 +18,39 @@ ) -class OpenAIProvider(BaseAdapter): - def __init__(self) -> None: - self._client: Any = None - self._originals: Dict[str, Any] = {} +class OpenAIProvider(MonkeyPatchProvider): + name = "openai" + capture_params = _CAPTURE_PARAMS + + @staticmethod + def extract_output(response: Any) -> Any: + try: + choices = response.choices + if 
choices: + msg = choices[0].message + return {"role": msg.role, "content": msg.content} + except (AttributeError, IndexError): + pass + return None + + @staticmethod + def extract_meta(response: Any) -> Dict[str, Any]: + meta: Dict[str, Any] = {} + try: + usage = response.usage + if usage is not None: + meta["usage"] = { + "prompt_tokens": usage.prompt_tokens, + "completion_tokens": usage.completion_tokens, + "total_tokens": usage.total_tokens, + } + except AttributeError: + pass + try: + meta["response_model"] = response.model + except AttributeError: + pass + return meta def connect(self, target: Any = None, **kwargs: Any) -> Any: # noqa: ARG002 self._client = target @@ -35,107 +58,19 @@ def connect(self, target: Any = None, **kwargs: Any) -> Any: # noqa: ARG002 if hasattr(target, "chat") and hasattr(target.chat, "completions"): orig = target.chat.completions.create self._originals["chat.completions.create"] = orig - target.chat.completions.create = self._wrap_sync(orig) + target.chat.completions.create = self._wrap_sync( + "openai.chat.completions.create", orig + ) if hasattr(target.chat.completions, "acreate"): async_orig = target.chat.completions.acreate self._originals["chat.completions.acreate"] = async_orig - target.chat.completions.acreate = self._wrap_async(async_orig) + target.chat.completions.acreate = self._wrap_async( + "openai.chat.completions.create", async_orig + ) return target - def disconnect(self) -> None: - if self._client is None: - return - for key, orig in self._originals.items(): - try: - parts = key.split(".") - obj = self._client - for part in parts[:-1]: - obj = getattr(obj, part) - setattr(obj, parts[-1], orig) - except Exception: - log.warning("Could not restore %s", key) - self._client = None - self._originals.clear() - - def adapter_info(self) -> AdapterInfo: - return AdapterInfo( - name="openai", - adapter_type="provider", - connected=self._client is not None, - ) - - def _wrap_sync(self, original: Any) -> Any: - def wrapped(*args: 
Any, **kwargs: Any) -> Any: - if _current_collector.get() is None: - return original(*args, **kwargs) - start = time.time() - try: - response = original(*args, **kwargs) - except Exception as exc: - latency_ms = (time.time() - start) * 1000 - emit_llm_error("openai.chat.completions.create", exc, latency_ms) - raise - latency_ms = (time.time() - start) * 1000 - emit_llm_events( - "openai.chat.completions.create", kwargs, response, - _extract_output, _extract_response_meta, _CAPTURE_PARAMS, latency_ms, - ) - return response - - return wrapped - - def _wrap_async(self, original: Any) -> Any: - async def wrapped(*args: Any, **kwargs: Any) -> Any: - if _current_collector.get() is None: - return await original(*args, **kwargs) - start = time.time() - try: - response = await original(*args, **kwargs) - except Exception as exc: - latency_ms = (time.time() - start) * 1000 - emit_llm_error("openai.chat.completions.create", exc, latency_ms) - raise - latency_ms = (time.time() - start) * 1000 - emit_llm_events( - "openai.chat.completions.create", kwargs, response, - _extract_output, _extract_response_meta, _CAPTURE_PARAMS, latency_ms, - ) - return response - - return wrapped - - -def _extract_output(response: Any) -> Any: - try: - choices = response.choices - if choices: - msg = choices[0].message - return {"role": msg.role, "content": msg.content} - except (AttributeError, IndexError): - pass - return None - - -def _extract_response_meta(response: Any) -> Dict[str, Any]: - meta: Dict[str, Any] = {} - try: - usage = response.usage - if usage is not None: - meta["usage"] = { - "prompt_tokens": usage.prompt_tokens, - "completion_tokens": usage.completion_tokens, - "total_tokens": usage.total_tokens, - } - except AttributeError: - pass - try: - meta["response_model"] = response.model - except AttributeError: - pass - return meta - # --- Convenience API --- diff --git a/tests/instrument/adapters/__init__.py b/tests/instrument/adapters/__init__.py new file mode 100644 index 
0000000..e69de29 diff --git a/tests/instrument/adapters/frameworks/__init__.py b/tests/instrument/adapters/frameworks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/instrument/adapters/frameworks/conftest.py b/tests/instrument/adapters/frameworks/conftest.py new file mode 100644 index 0000000..fb8d90e --- /dev/null +++ b/tests/instrument/adapters/frameworks/conftest.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import json +from typing import Any, Dict +from unittest.mock import Mock + + +# Re-export from root conftest so framework tests can do `from .conftest import ...` +from ...conftest import find_event, find_events # noqa: F401 + + +def capture_framework_trace(mock_client: Mock) -> Dict[str, Any]: + """Capture the uploaded trace payload from a framework adapter. + + Accumulates events across multiple flushes (some adapters use + multiple collectors). + """ + uploaded: Dict[str, Any] = {"events": []} + + def _capture(path: str) -> None: + with open(path) as f: + data = json.load(f) + payload = data[0] + uploaded["trace_id"] = payload.get("trace_id") + uploaded["events"].extend(payload.get("events", [])) + uploaded["capture_config"] = payload.get("capture_config", {}) + uploaded["attestation"] = payload.get("attestation", {}) + + mock_client.traces.upload.side_effect = _capture + return uploaded diff --git a/tests/instrument/adapters/frameworks/test_langchain.py b/tests/instrument/adapters/frameworks/test_langchain.py new file mode 100644 index 0000000..d2a3057 --- /dev/null +++ b/tests/instrument/adapters/frameworks/test_langchain.py @@ -0,0 +1,345 @@ +from __future__ import annotations + +from uuid import uuid4 +from unittest.mock import Mock + +from langchain_core.callbacks import BaseCallbackHandler + +from layerlens.instrument.adapters.frameworks.langchain import LangChainCallbackHandler + +from .conftest import capture_framework_trace, find_event, find_events + + +# 
--------------------------------------------------------------------------- +# Sanity: real base class +# --------------------------------------------------------------------------- + + +class TestBaseClass: + def test_inherits_langchain_base(self): + assert issubclass(LangChainCallbackHandler, BaseCallbackHandler) + + def test_name(self): + handler = LangChainCallbackHandler(Mock()) + assert handler.name == "langchain" + + +# --------------------------------------------------------------------------- +# Emit events +# --------------------------------------------------------------------------- + + +class TestEmitsEvents: + def test_chain_lifecycle(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + handler.on_chain_start( + {"name": "RunnableSequence", "id": ["RunnableSequence"]}, + {"question": "What is AI?"}, + run_id=chain_id, + ) + handler.on_chain_end({"output": "AI is..."}, run_id=chain_id) + + events = uploaded["events"] + agent_input = find_event(events, "agent.input") + assert agent_input["payload"]["name"] == "RunnableSequence" + assert agent_input["payload"]["input"] == {"question": "What is AI?"} + + agent_output = find_event(events, "agent.output") + assert agent_output["payload"]["status"] == "ok" + + def test_llm_lifecycle(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + llm_id = uuid4() + + handler.on_chain_start( + {"name": "Chain"}, {"input": "x"}, run_id=chain_id, + ) + handler.on_llm_start( + {"name": "ChatOpenAI", "id": ["ChatOpenAI"]}, + ["What is AI?"], + run_id=llm_id, + parent_run_id=chain_id, + ) + + llm_response = Mock() + llm_response.generations = [[Mock(text="AI is...")]] + llm_response.llm_output = { + "token_usage": {"total_tokens": 50}, + "model_name": "gpt-4", + } + handler.on_llm_end(llm_response, run_id=llm_id) + handler.on_chain_end({"output": 
"AI is..."}, run_id=chain_id) + + events = uploaded["events"] + + model_invokes = find_events(events, "model.invoke") + assert len(model_invokes) >= 1 + # Start event has name and messages + start_invoke = [m for m in model_invokes if m["payload"].get("name") == "ChatOpenAI"] + assert len(start_invoke) == 1 + # End event has model and output + end_invoke = [m for m in model_invokes if m["payload"].get("model") == "gpt-4"] + assert len(end_invoke) == 1 + assert end_invoke[0]["payload"]["output_message"] == "AI is..." + + cost = find_event(events, "cost.record") + assert cost["payload"]["total_tokens"] == 50 + + def test_chat_model_start(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + chat_id = uuid4() + + handler.on_chain_start({"name": "Chain"}, {}, run_id=chain_id) + msg = Mock() + msg.type = "human" + msg.content = "Hello" + handler.on_chat_model_start( + {"name": "ChatAnthropic"}, + [[msg]], + run_id=chat_id, + parent_run_id=chain_id, + ) + handler.on_chain_end({}, run_id=chain_id) + + events = uploaded["events"] + invoke = find_event(events, "model.invoke") + assert invoke["payload"]["name"] == "ChatAnthropic" + assert invoke["payload"]["messages"] == [[{"type": "human", "content": "Hello"}]] + + +# --------------------------------------------------------------------------- +# Tool and retriever events +# --------------------------------------------------------------------------- + + +class TestToolsAndRetrievers: + def test_tool_lifecycle(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + tool_id = uuid4() + + handler.on_chain_start({"name": "Agent"}, {}, run_id=chain_id) + handler.on_tool_start( + {"name": "search"}, "query text", + run_id=tool_id, parent_run_id=chain_id, + ) + handler.on_tool_end("search results", run_id=tool_id) + handler.on_chain_end({}, 
run_id=chain_id) + + events = uploaded["events"] + tool_call = find_event(events, "tool.call") + assert tool_call["payload"]["name"] == "search" + assert tool_call["payload"]["input"] == "query text" + + tool_result = find_event(events, "tool.result") + assert tool_result["payload"]["output"] == "search results" + + def test_retriever_lifecycle(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + ret_id = uuid4() + + handler.on_chain_start({"name": "Agent"}, {}, run_id=chain_id) + handler.on_retriever_start( + {"name": "vectorstore"}, "query", + run_id=ret_id, parent_run_id=chain_id, + ) + docs = [Mock(page_content="doc text", metadata={"source": "a.txt"})] + handler.on_retriever_end(docs, run_id=ret_id) + handler.on_chain_end({}, run_id=chain_id) + + events = uploaded["events"] + tool_call = find_event(events, "tool.call") + assert tool_call["payload"]["name"] == "vectorstore" + + tool_result = find_event(events, "tool.result") + output = tool_result["payload"]["output"] + assert output[0]["page_content"] == "doc text" + assert output[0]["metadata"] == {"source": "a.txt"} + + def test_combined_tools_and_retrievers(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + tool_id = uuid4() + ret_id = uuid4() + + handler.on_chain_start({"name": "Agent"}, {}, run_id=chain_id) + handler.on_tool_start({"name": "search"}, "q", run_id=tool_id, parent_run_id=chain_id) + handler.on_tool_end("results", run_id=tool_id) + handler.on_retriever_start({"name": "vs"}, "q", run_id=ret_id, parent_run_id=chain_id) + handler.on_retriever_end([Mock(page_content="d", metadata={})], run_id=ret_id) + handler.on_chain_end({}, run_id=chain_id) + + events = uploaded["events"] + assert len(find_events(events, "tool.call")) == 2 + assert len(find_events(events, "tool.result")) == 2 + + +# 
--------------------------------------------------------------------------- +# Error handling +# --------------------------------------------------------------------------- + + +class TestErrors: + def test_chain_error(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + handler.on_chain_start({"name": "FailChain"}, {"input": "x"}, run_id=chain_id) + handler.on_chain_error(ValueError("broke"), run_id=chain_id) + + error = find_event(uploaded["events"], "agent.error") + assert error["payload"]["error"] == "broke" + assert error["payload"]["status"] == "error" + + def test_llm_error(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + llm_id = uuid4() + + handler.on_chain_start({"name": "Chain"}, {}, run_id=chain_id) + handler.on_llm_start({"name": "LLM"}, ["prompt"], run_id=llm_id, parent_run_id=chain_id) + handler.on_llm_error(RuntimeError("timeout"), run_id=llm_id) + handler.on_chain_end({}, run_id=chain_id) + + error = find_event(uploaded["events"], "agent.error") + assert error["payload"]["error"] == "timeout" + + def test_tool_error(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + tool_id = uuid4() + + handler.on_chain_start({"name": "Agent"}, {}, run_id=chain_id) + handler.on_tool_start({"name": "search"}, "q", run_id=tool_id, parent_run_id=chain_id) + handler.on_tool_error(RuntimeError("404"), run_id=tool_id) + handler.on_chain_end({}, run_id=chain_id) + + error = find_event(uploaded["events"], "agent.error") + assert error["payload"]["error"] == "404" + + def test_retriever_error(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + ret_id = uuid4() + + handler.on_chain_start({"name": "Agent"}, 
{}, run_id=chain_id) + handler.on_retriever_start({"name": "vs"}, "q", run_id=ret_id, parent_run_id=chain_id) + handler.on_retriever_error(ConnectionError("down"), run_id=ret_id) + handler.on_chain_end({}, run_id=chain_id) + + error = find_event(uploaded["events"], "agent.error") + assert error["payload"]["error"] == "down" + + +# --------------------------------------------------------------------------- +# Parent-child span relationships +# --------------------------------------------------------------------------- + + +class TestSpanRelationships: + def test_llm_parent_is_chain(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + llm_id = uuid4() + + handler.on_chain_start({"name": "Chain"}, {}, run_id=chain_id) + handler.on_llm_start( + {"name": "LLM"}, ["prompt"], + run_id=llm_id, parent_run_id=chain_id, + ) + llm_response = Mock() + llm_response.generations = [[Mock(text="out")]] + llm_response.llm_output = {} + handler.on_llm_end(llm_response, run_id=llm_id) + handler.on_chain_end({}, run_id=chain_id) + + events = uploaded["events"] + chain_input = find_event(events, "agent.input") + llm_invoke = [e for e in find_events(events, "model.invoke") if e["payload"].get("name") == "LLM"][0] + assert llm_invoke["parent_span_id"] == chain_input["span_id"] + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + def test_null_serialized(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + run_id = uuid4() + handler.on_chain_start(None, {"input": "x"}, run_id=run_id) + handler.on_chain_end({}, run_id=run_id) + + agent_input = find_event(uploaded["events"], "agent.input") + assert agent_input["payload"]["name"] == "unknown" + + def test_empty_serialized_id(self, 
mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + run_id = uuid4() + handler.on_chain_start({"id": ["FallbackName"]}, {}, run_id=run_id) + handler.on_chain_end({}, run_id=run_id) + + agent_input = find_event(uploaded["events"], "agent.input") + assert agent_input["payload"]["name"] == "FallbackName" + + def test_llm_end_no_output(self, mock_client): + """LLM response with no generations should not crash.""" + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + llm_id = uuid4() + + handler.on_chain_start({"name": "Chain"}, {}, run_id=chain_id) + handler.on_llm_start({"name": "LLM"}, ["p"], run_id=llm_id, parent_run_id=chain_id) + + empty_response = Mock() + empty_response.generations = [] + empty_response.llm_output = None + handler.on_llm_end(empty_response, run_id=llm_id) + handler.on_chain_end({}, run_id=chain_id) + + # Should complete without error — no model.invoke end event since no output/model + + +# --------------------------------------------------------------------------- +# adapter_info +# --------------------------------------------------------------------------- + + +class TestAdapterInfo: + def test_info(self): + handler = LangChainCallbackHandler(Mock()) + info = handler.adapter_info() + assert info.name == "langchain" + assert info.adapter_type == "framework" diff --git a/tests/instrument/adapters/frameworks/test_langgraph.py b/tests/instrument/adapters/frameworks/test_langgraph.py new file mode 100644 index 0000000..7ff6e9d --- /dev/null +++ b/tests/instrument/adapters/frameworks/test_langgraph.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +from uuid import uuid4 +from unittest.mock import Mock + +from langchain_core.callbacks import BaseCallbackHandler + +from layerlens.instrument.adapters.frameworks.langgraph import LangGraphCallbackHandler + +from .conftest import capture_framework_trace, 
find_event, find_events + + +# --------------------------------------------------------------------------- +# Sanity: real base class +# --------------------------------------------------------------------------- + + +class TestBaseClass: + def test_inherits_langchain_base(self): + assert issubclass(LangGraphCallbackHandler, BaseCallbackHandler) + + def test_name(self): + handler = LangGraphCallbackHandler(Mock()) + assert handler.name == "langgraph" + + +# --------------------------------------------------------------------------- +# Inherited LangChain behavior +# --------------------------------------------------------------------------- + + +class TestInheritedBehavior: + """LangGraph inherits all LangChain callbacks except on_chain_start.""" + + def test_llm_events_inherited(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangGraphCallbackHandler(mock_client) + + chain_id = uuid4() + llm_id = uuid4() + + handler.on_chain_start({"name": "Graph"}, {}, run_id=chain_id) + handler.on_llm_start( + {"name": "ChatOpenAI"}, ["prompt"], + run_id=llm_id, parent_run_id=chain_id, + ) + llm_response = Mock() + llm_response.generations = [[Mock(text="output")]] + llm_response.llm_output = {"model_name": "gpt-4", "token_usage": {"total_tokens": 10}} + handler.on_llm_end(llm_response, run_id=llm_id) + handler.on_chain_end({}, run_id=chain_id) + + events = uploaded["events"] + assert len(find_events(events, "model.invoke")) >= 1 + assert find_event(events, "cost.record")["payload"]["total_tokens"] == 10 + + def test_tool_events_inherited(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangGraphCallbackHandler(mock_client) + + chain_id = uuid4() + tool_id = uuid4() + + handler.on_chain_start({"name": "Graph"}, {}, run_id=chain_id) + handler.on_tool_start({"name": "search"}, "q", run_id=tool_id, parent_run_id=chain_id) + handler.on_tool_end("results", run_id=tool_id) + handler.on_chain_end({}, run_id=chain_id) + 
+ events = uploaded["events"] + assert find_event(events, "tool.call")["payload"]["name"] == "search" + assert find_event(events, "tool.result")["payload"]["output"] == "results" + + def test_error_handling_inherited(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangGraphCallbackHandler(mock_client) + + chain_id = uuid4() + handler.on_chain_start({"name": "Graph"}, {}, run_id=chain_id) + handler.on_chain_error(RuntimeError("graph failed"), run_id=chain_id) + + error = find_event(uploaded["events"], "agent.error") + assert error["payload"]["error"] == "graph failed" + + +# --------------------------------------------------------------------------- +# LangGraph-specific: on_chain_start node extraction +# --------------------------------------------------------------------------- + + +class TestNodeExtraction: + def test_extracts_node_from_tags(self, mock_client): + """LangGraph passes node names as plain tags (no colon).""" + uploaded = capture_framework_trace(mock_client) + handler = LangGraphCallbackHandler(mock_client) + + chain_id = uuid4() + handler.on_chain_start( + {"name": "RunnableSequence"}, + {"input": "hello"}, + run_id=chain_id, + tags=["graph:step:1", "retriever_node"], + ) + handler.on_chain_end({}, run_id=chain_id) + + agent_input = find_event(uploaded["events"], "agent.input") + assert agent_input["payload"]["name"] == "retriever_node" + + def test_extracts_node_from_metadata(self, mock_client): + """LangGraph puts node name in metadata.langgraph_node.""" + uploaded = capture_framework_trace(mock_client) + handler = LangGraphCallbackHandler(mock_client) + + chain_id = uuid4() + handler.on_chain_start( + {"name": "RunnableSequence"}, + {"input": "hello"}, + run_id=chain_id, + metadata={"langgraph_node": "agent_node"}, + ) + handler.on_chain_end({}, run_id=chain_id) + + agent_input = find_event(uploaded["events"], "agent.input") + assert agent_input["payload"]["name"] == "agent_node" + + def 
test_metadata_overrides_tags(self, mock_client): + """When both tags and metadata provide a node name, metadata wins.""" + uploaded = capture_framework_trace(mock_client) + handler = LangGraphCallbackHandler(mock_client) + + chain_id = uuid4() + handler.on_chain_start( + {"name": "Seq"}, + {}, + run_id=chain_id, + tags=["tag_node"], + metadata={"langgraph_node": "meta_node"}, + ) + handler.on_chain_end({}, run_id=chain_id) + + agent_input = find_event(uploaded["events"], "agent.input") + assert agent_input["payload"]["name"] == "meta_node" + + def test_falls_back_to_serialized_name(self, mock_client): + """Without tags or metadata, falls back to serialized name.""" + uploaded = capture_framework_trace(mock_client) + handler = LangGraphCallbackHandler(mock_client) + + chain_id = uuid4() + handler.on_chain_start( + {"name": "MyCustomChain"}, + {}, + run_id=chain_id, + ) + handler.on_chain_end({}, run_id=chain_id) + + agent_input = find_event(uploaded["events"], "agent.input") + assert agent_input["payload"]["name"] == "MyCustomChain" + + def test_skips_graph_step_tags(self, mock_client): + """Tags starting with 'graph:step:' should be skipped.""" + uploaded = capture_framework_trace(mock_client) + handler = LangGraphCallbackHandler(mock_client) + + chain_id = uuid4() + handler.on_chain_start( + {"name": "Default"}, + {}, + run_id=chain_id, + tags=["graph:step:0", "graph:step:1"], + ) + handler.on_chain_end({}, run_id=chain_id) + + agent_input = find_event(uploaded["events"], "agent.input") + # No usable tags — falls back to serialized name + assert agent_input["payload"]["name"] == "Default" + + +# --------------------------------------------------------------------------- +# adapter_info +# --------------------------------------------------------------------------- + + +class TestAdapterInfo: + def test_info(self): + handler = LangGraphCallbackHandler(Mock()) + info = handler.adapter_info() + assert info.name == "langgraph" + assert info.adapter_type == "framework" 
diff --git a/tests/instrument/adapters/providers/__init__.py b/tests/instrument/adapters/providers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/instrument/adapters/providers/conftest.py b/tests/instrument/adapters/providers/conftest.py new file mode 100644 index 0000000..48aa1e7 --- /dev/null +++ b/tests/instrument/adapters/providers/conftest.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice +from openai.types import CompletionUsage + +from anthropic.types import Message, TextBlock, Usage + + +def make_openai_response( + content: str = "Hello!", + role: str = "assistant", + model: str = "gpt-4", + prompt_tokens: int = 10, + completion_tokens: int = 5, + total_tokens: int = 15, +) -> ChatCompletion: + """Build a real OpenAI ChatCompletion response.""" + return ChatCompletion( + id="chatcmpl-test", + model=model, + object="chat.completion", + created=1700000000, + choices=[ + Choice( + index=0, + finish_reason="stop", + message=ChatCompletionMessage(role=role, content=content), + ) + ], + usage=CompletionUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ), + ) + + +def make_openai_response_no_usage(model: str = "gpt-4") -> ChatCompletion: + """Build an OpenAI response with no usage data.""" + return ChatCompletion( + id="chatcmpl-test", + model=model, + object="chat.completion", + created=1700000000, + choices=[ + Choice( + index=0, + finish_reason="stop", + message=ChatCompletionMessage(role="assistant", content="Hello!"), + ) + ], + usage=None, + ) + + +def make_openai_response_empty_choices(model: str = "gpt-4") -> ChatCompletion: + """Build an OpenAI response with empty choices.""" + return ChatCompletion( + id="chatcmpl-test", + model=model, + object="chat.completion", + created=1700000000, + choices=[], + usage=None, + ) + + +def 
make_anthropic_response( + text: str = "I'm Claude!", + model: str = "claude-3-opus-20240229", + input_tokens: int = 20, + output_tokens: int = 10, + stop_reason: str = "end_turn", +) -> Message: + """Build a real Anthropic Message response.""" + return Message( + id="msg-test", + type="message", + role="assistant", + model=model, + content=[TextBlock(type="text", text=text)], + usage=Usage(input_tokens=input_tokens, output_tokens=output_tokens), + stop_reason=stop_reason, + ) + + +def make_anthropic_response_empty_content( + model: str = "claude-3-opus-20240229", +) -> Message: + """Build an Anthropic response with empty content.""" + return Message( + id="msg-test", + type="message", + role="assistant", + model=model, + content=[], + usage=Usage(input_tokens=0, output_tokens=0), + stop_reason="end_turn", + ) diff --git a/tests/instrument/adapters/providers/test_anthropic.py b/tests/instrument/adapters/providers/test_anthropic.py new file mode 100644 index 0000000..7dcf22f --- /dev/null +++ b/tests/instrument/adapters/providers/test_anthropic.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +from unittest.mock import Mock + +from layerlens.instrument import trace +from layerlens.instrument.adapters.providers.anthropic import ( + AnthropicProvider, + instrument_anthropic, + uninstrument_anthropic, +) + +from ...conftest import find_event +from .conftest import make_anthropic_response, make_anthropic_response_empty_content + + +# --------------------------------------------------------------------------- +# Emit events +# --------------------------------------------------------------------------- + + +class TestEmitsEvents: + def test_model_invoke_and_cost_record(self, mock_client, capture_trace): + anthropic_client = Mock() + anthropic_client.messages.create = Mock(return_value=make_anthropic_response()) + + provider = AnthropicProvider() + provider.connect(anthropic_client) + + @trace(mock_client) + def my_agent(): + r = anthropic_client.messages.create( 
+ model="claude-3-opus-20240229", max_tokens=1024, + messages=[{"role": "user", "content": "Hi"}], + ) + return r.content[0].text + + my_agent() + events = capture_trace["events"] + + model_invoke = find_event(events, "model.invoke") + assert model_invoke["payload"]["name"] == "anthropic.messages.create" + assert model_invoke["payload"]["response_model"] == "claude-3-opus-20240229" + assert model_invoke["payload"]["output_message"]["type"] == "text" + assert model_invoke["payload"]["output_message"]["text"] == "I'm Claude!" + assert model_invoke["payload"]["usage"]["input_tokens"] == 20 + assert model_invoke["payload"]["usage"]["output_tokens"] == 10 + assert model_invoke["payload"]["stop_reason"] == "end_turn" + assert "latency_ms" in model_invoke["payload"] + + cost = find_event(events, "cost.record") + assert cost["payload"]["provider"] == "anthropic" + assert cost["payload"]["input_tokens"] == 20 + assert cost["payload"]["output_tokens"] == 10 + + def test_error_emits_agent_error(self, mock_client, capture_trace): + anthropic_client = Mock() + anthropic_client.messages.create = Mock(side_effect=RuntimeError("overloaded")) + + provider = AnthropicProvider() + provider.connect(anthropic_client) + + @trace(mock_client) + def my_agent(): + try: + anthropic_client.messages.create(model="claude-3-opus-20240229", max_tokens=1024, messages=[]) + except RuntimeError: + pass + return "recovered" + + my_agent() + events = capture_trace["events"] + error = find_event(events, "agent.error") + assert error["payload"]["error"] == "overloaded" + assert "latency_ms" in error["payload"] + + +# --------------------------------------------------------------------------- +# Passthrough / no-op behavior +# --------------------------------------------------------------------------- + + +class TestPassthrough: + def test_no_op_outside_trace(self): + response = make_anthropic_response() + anthropic_client = Mock() + anthropic_client.messages.create = Mock(return_value=response) + + 
provider = AnthropicProvider() + provider.connect(anthropic_client) + + result = anthropic_client.messages.create(model="claude-3-opus-20240229", max_tokens=1024, messages=[]) + assert result.content[0].text == "I'm Claude!" + + +# --------------------------------------------------------------------------- +# Connect / disconnect lifecycle +# --------------------------------------------------------------------------- + + +class TestLifecycle: + def test_disconnect_restores_original(self): + anthropic_client = Mock() + original = anthropic_client.messages.create + + provider = AnthropicProvider() + provider.connect(anthropic_client) + assert anthropic_client.messages.create is not original + + provider.disconnect() + assert anthropic_client.messages.create is original + + def test_disconnect_when_not_connected(self): + provider = AnthropicProvider() + provider.disconnect() # should not raise + + def test_double_connect_replaces_wrapper(self): + anthropic_client = Mock() + provider = AnthropicProvider() + provider.connect(anthropic_client) + first_wrapper = anthropic_client.messages.create + + provider2 = AnthropicProvider() + provider2.connect(anthropic_client) + assert anthropic_client.messages.create is not first_wrapper + + +# --------------------------------------------------------------------------- +# adapter_info +# --------------------------------------------------------------------------- + + +class TestAdapterInfo: + def test_info_before_connect(self): + provider = AnthropicProvider() + info = provider.adapter_info() + assert info.name == "anthropic" + assert info.adapter_type == "provider" + assert info.connected is False + + def test_info_after_connect(self): + provider = AnthropicProvider() + provider.connect(Mock()) + info = provider.adapter_info() + assert info.connected is True + + def test_info_after_disconnect(self): + provider = AnthropicProvider() + provider.connect(Mock()) + provider.disconnect() + assert provider.adapter_info().connected is 
False + + +# --------------------------------------------------------------------------- +# Convenience API +# --------------------------------------------------------------------------- + + +class TestConvenienceAPI: + def test_instrument_and_uninstrument(self): + anthropic_client = Mock() + original = anthropic_client.messages.create + instrument_anthropic(anthropic_client) + assert anthropic_client.messages.create is not original + uninstrument_anthropic() + + +# --------------------------------------------------------------------------- +# capture_params filtering +# --------------------------------------------------------------------------- + + +class TestCaptureParams: + def test_captured_params_included(self, mock_client, capture_trace): + anthropic_client = Mock() + anthropic_client.messages.create = Mock(return_value=make_anthropic_response()) + + provider = AnthropicProvider() + provider.connect(anthropic_client) + + @trace(mock_client) + def my_agent(): + anthropic_client.messages.create( + model="claude-3-opus-20240229", max_tokens=1024, temperature=0.5, top_k=40, + messages=[{"role": "user", "content": "Hi"}], + ) + return "done" + + my_agent() + params = find_event(capture_trace["events"], "model.invoke")["payload"]["parameters"] + assert params["model"] == "claude-3-opus-20240229" + assert params["max_tokens"] == 1024 + assert params["temperature"] == 0.5 + assert params["top_k"] == 40 + + def test_non_captured_params_excluded(self, mock_client, capture_trace): + anthropic_client = Mock() + anthropic_client.messages.create = Mock(return_value=make_anthropic_response()) + + provider = AnthropicProvider() + provider.connect(anthropic_client) + + @trace(mock_client) + def my_agent(): + anthropic_client.messages.create( + model="claude-3-opus-20240229", max_tokens=1024, + messages=[], stream=True, metadata={"user_id": "abc"}, + ) + return "done" + + my_agent() + params = find_event(capture_trace["events"], "model.invoke")["payload"]["parameters"] + 
assert "stream" not in params + assert "metadata" not in params + assert "messages" not in params + + +# --------------------------------------------------------------------------- +# Extractor edge cases (using real SDK types) +# --------------------------------------------------------------------------- + + +class TestExtractors: + def test_extract_output_normal(self): + r = make_anthropic_response(text="Hello world") + output = AnthropicProvider.extract_output(r) + assert output == {"type": "text", "text": "Hello world"} + + def test_extract_output_empty_content(self): + r = make_anthropic_response_empty_content() + assert AnthropicProvider.extract_output(r) is None + + def test_extract_meta_normal(self): + r = make_anthropic_response( + model="claude-3-5-sonnet-20241022", + input_tokens=100, output_tokens=50, + stop_reason="max_tokens", + ) + meta = AnthropicProvider.extract_meta(r) + assert meta["response_model"] == "claude-3-5-sonnet-20241022" + assert meta["usage"]["input_tokens"] == 100 + assert meta["usage"]["output_tokens"] == 50 + assert meta["stop_reason"] == "max_tokens" diff --git a/tests/instrument/adapters/providers/test_litellm.py b/tests/instrument/adapters/providers/test_litellm.py new file mode 100644 index 0000000..50588ce --- /dev/null +++ b/tests/instrument/adapters/providers/test_litellm.py @@ -0,0 +1,263 @@ +from __future__ import annotations + +import sys +import types +from unittest.mock import Mock + +from layerlens.instrument import trace +from layerlens.instrument.adapters.providers.litellm import ( + LiteLLMProvider, + instrument_litellm, + uninstrument_litellm, +) + +from ...conftest import find_event +from .conftest import make_openai_response, make_openai_response_empty_choices, make_openai_response_no_usage + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _install_mock_litellm(response=None): + 
"""Inject a fake litellm module into sys.modules with real OpenAI response types.""" + mock_mod = types.ModuleType("litellm") + mock_mod.completion = Mock(return_value=response or make_openai_response()) + mock_mod.acompletion = Mock() + sys.modules["litellm"] = mock_mod + return mock_mod + + +def _remove_mock_litellm(): + uninstrument_litellm() + for key in list(sys.modules.keys()): + if key.startswith("litellm"): + del sys.modules[key] + + +# --------------------------------------------------------------------------- +# Emit events +# --------------------------------------------------------------------------- + + +class TestEmitsEvents: + def setup_method(self): + self.mock_litellm = _install_mock_litellm() + + def teardown_method(self): + _remove_mock_litellm() + + def test_model_invoke_and_cost_record(self, mock_client, capture_trace): + instrument_litellm() + + @trace(mock_client) + def my_agent(): + import litellm + r = litellm.completion( + model="gpt-4", messages=[{"role": "user", "content": "Hi"}] + ) + return r.choices[0].message.content + + my_agent() + events = capture_trace["events"] + + model_invoke = find_event(events, "model.invoke") + assert model_invoke["payload"]["name"] == "litellm.completion" + assert model_invoke["payload"]["model"] == "gpt-4" + assert model_invoke["payload"]["output_message"]["content"] == "Hello!" 
+ assert model_invoke["payload"]["usage"]["total_tokens"] == 15 + assert "latency_ms" in model_invoke["payload"] + + cost = find_event(events, "cost.record") + assert cost["payload"]["provider"] == "litellm" + assert cost["payload"]["total_tokens"] == 15 + + def test_error_emits_agent_error(self, mock_client, capture_trace): + self.mock_litellm.completion = Mock(side_effect=RuntimeError("rate limited")) + instrument_litellm() + + @trace(mock_client) + def my_agent(): + import litellm + try: + litellm.completion(model="gpt-4", messages=[]) + except RuntimeError: + pass + return "recovered" + + my_agent() + events = capture_trace["events"] + error = find_event(events, "agent.error") + assert error["payload"]["error"] == "rate limited" + assert "latency_ms" in error["payload"] + + +# --------------------------------------------------------------------------- +# Passthrough / no-op behavior +# --------------------------------------------------------------------------- + + +class TestPassthrough: + def setup_method(self): + self.mock_litellm = _install_mock_litellm() + + def teardown_method(self): + _remove_mock_litellm() + + def test_no_op_outside_trace(self): + instrument_litellm() + import litellm + result = litellm.completion(model="gpt-4", messages=[]) + assert result.choices[0].message.content == "Hello!" 
+ + +# --------------------------------------------------------------------------- +# Connect / disconnect lifecycle +# --------------------------------------------------------------------------- + + +class TestLifecycle: + def setup_method(self): + self.mock_litellm = _install_mock_litellm() + + def teardown_method(self): + _remove_mock_litellm() + + def test_uninstrument_restores_original(self): + original = self.mock_litellm.completion + instrument_litellm() + assert self.mock_litellm.completion is not original + uninstrument_litellm() + assert self.mock_litellm.completion is original + + def test_disconnect_when_not_connected(self): + provider = LiteLLMProvider() + provider.disconnect() # should not raise + + +# --------------------------------------------------------------------------- +# adapter_info +# --------------------------------------------------------------------------- + + +class TestAdapterInfo: + def setup_method(self): + self.mock_litellm = _install_mock_litellm() + + def teardown_method(self): + _remove_mock_litellm() + + def test_info_before_connect(self): + provider = LiteLLMProvider() + info = provider.adapter_info() + assert info.name == "litellm" + assert info.adapter_type == "provider" + assert info.connected is False + + def test_info_after_connect(self): + provider = LiteLLMProvider() + provider.connect() + info = provider.adapter_info() + assert info.connected is True + + def test_info_after_disconnect(self): + provider = LiteLLMProvider() + provider.connect() + provider.disconnect() + assert provider.adapter_info().connected is False + + +# --------------------------------------------------------------------------- +# Convenience API +# --------------------------------------------------------------------------- + + +class TestConvenienceAPI: + def setup_method(self): + self.mock_litellm = _install_mock_litellm() + + def teardown_method(self): + _remove_mock_litellm() + + def test_instrument_and_uninstrument(self): + original = 
self.mock_litellm.completion + instrument_litellm() + assert self.mock_litellm.completion is not original + uninstrument_litellm() + assert self.mock_litellm.completion is original + + +# --------------------------------------------------------------------------- +# capture_params filtering +# --------------------------------------------------------------------------- + + +class TestCaptureParams: + def setup_method(self): + self.mock_litellm = _install_mock_litellm() + + def teardown_method(self): + _remove_mock_litellm() + + def test_captured_params_included(self, mock_client, capture_trace): + instrument_litellm() + + @trace(mock_client) + def my_agent(): + import litellm + litellm.completion( + model="gpt-4", temperature=0.7, top_p=0.9, + messages=[{"role": "user", "content": "Hi"}], + ) + return "done" + + my_agent() + params = find_event(capture_trace["events"], "model.invoke")["payload"]["parameters"] + assert params["model"] == "gpt-4" + assert params["temperature"] == 0.7 + assert params["top_p"] == 0.9 + + def test_non_captured_params_excluded(self, mock_client, capture_trace): + instrument_litellm() + + @trace(mock_client) + def my_agent(): + import litellm + litellm.completion( + model="gpt-4", messages=[], stream=True, api_key="sk-123", + ) + return "done" + + my_agent() + params = find_event(capture_trace["events"], "model.invoke")["payload"]["parameters"] + assert "stream" not in params + assert "api_key" not in params + assert "messages" not in params + + +# --------------------------------------------------------------------------- +# Extractor edge cases (LiteLLM reuses OpenAI extractors, real types) +# --------------------------------------------------------------------------- + + +class TestExtractors: + def test_extract_output_normal(self): + r = make_openai_response(content="LiteLLM response") + output = LiteLLMProvider.extract_output(r) + assert output == {"role": "assistant", "content": "LiteLLM response"} + + def 
test_extract_output_empty_choices(self): + r = make_openai_response_empty_choices() + assert LiteLLMProvider.extract_output(r) is None + + def test_extract_meta_normal(self): + r = make_openai_response(model="gpt-4o", prompt_tokens=5, completion_tokens=3, total_tokens=8) + meta = LiteLLMProvider.extract_meta(r) + assert meta["response_model"] == "gpt-4o" + assert meta["usage"]["total_tokens"] == 8 + + def test_extract_meta_no_usage(self): + r = make_openai_response_no_usage() + meta = LiteLLMProvider.extract_meta(r) + assert "usage" not in meta diff --git a/tests/instrument/adapters/providers/test_openai.py b/tests/instrument/adapters/providers/test_openai.py new file mode 100644 index 0000000..42641b4 --- /dev/null +++ b/tests/instrument/adapters/providers/test_openai.py @@ -0,0 +1,244 @@ +from __future__ import annotations + +from unittest.mock import Mock + +from layerlens.instrument import trace +from layerlens.instrument.adapters.providers.openai import ( + OpenAIProvider, + instrument_openai, + uninstrument_openai, +) + +from ...conftest import find_event +from .conftest import ( + make_openai_response, + make_openai_response_no_usage, + make_openai_response_empty_choices, +) + + +# --------------------------------------------------------------------------- +# Emit events +# --------------------------------------------------------------------------- + + +class TestEmitsEvents: + def test_model_invoke_and_cost_record(self, mock_client, capture_trace): + openai_client = Mock() + openai_client.chat.completions.create = Mock(return_value=make_openai_response()) + + provider = OpenAIProvider() + provider.connect(openai_client) + + @trace(mock_client) + def my_agent(): + r = openai_client.chat.completions.create( + model="gpt-4", messages=[{"role": "user", "content": "Hi"}] + ) + return r.choices[0].message.content + + my_agent() + events = capture_trace["events"] + + model_invoke = find_event(events, "model.invoke") + assert model_invoke["payload"]["name"] == 
"openai.chat.completions.create" + assert model_invoke["payload"]["model"] == "gpt-4" + assert model_invoke["payload"]["output_message"]["role"] == "assistant" + assert model_invoke["payload"]["output_message"]["content"] == "Hello!" + assert model_invoke["payload"]["usage"]["prompt_tokens"] == 10 + assert model_invoke["payload"]["usage"]["completion_tokens"] == 5 + assert model_invoke["payload"]["usage"]["total_tokens"] == 15 + assert "latency_ms" in model_invoke["payload"] + + cost = find_event(events, "cost.record") + assert cost["payload"]["provider"] == "openai" + assert cost["payload"]["total_tokens"] == 15 + + def test_error_emits_agent_error(self, mock_client, capture_trace): + openai_client = Mock() + openai_client.chat.completions.create = Mock(side_effect=RuntimeError("API error")) + + provider = OpenAIProvider() + provider.connect(openai_client) + + @trace(mock_client) + def my_agent(): + try: + openai_client.chat.completions.create(model="gpt-4", messages=[]) + except RuntimeError: + pass + return "recovered" + + my_agent() + events = capture_trace["events"] + error = find_event(events, "agent.error") + assert error["payload"]["error"] == "API error" + assert "latency_ms" in error["payload"] + + +# --------------------------------------------------------------------------- +# Passthrough / no-op behavior +# --------------------------------------------------------------------------- + + +class TestPassthrough: + def test_no_op_outside_trace(self): + response = make_openai_response() + openai_client = Mock() + openai_client.chat.completions.create = Mock(return_value=response) + + provider = OpenAIProvider() + provider.connect(openai_client) + + result = openai_client.chat.completions.create(model="gpt-4", messages=[]) + assert result.choices[0].message.content == "Hello!" 
+ + +# --------------------------------------------------------------------------- +# Connect / disconnect lifecycle +# --------------------------------------------------------------------------- + + +class TestLifecycle: + def test_disconnect_restores_original(self): + openai_client = Mock() + original = openai_client.chat.completions.create + + provider = OpenAIProvider() + provider.connect(openai_client) + assert openai_client.chat.completions.create is not original + + provider.disconnect() + assert openai_client.chat.completions.create is original + + def test_disconnect_when_not_connected(self): + provider = OpenAIProvider() + provider.disconnect() # should not raise + + def test_double_connect_replaces_wrapper(self): + openai_client = Mock() + provider = OpenAIProvider() + provider.connect(openai_client) + first_wrapper = openai_client.chat.completions.create + + provider2 = OpenAIProvider() + provider2.connect(openai_client) + assert openai_client.chat.completions.create is not first_wrapper + + +# --------------------------------------------------------------------------- +# adapter_info +# --------------------------------------------------------------------------- + + +class TestAdapterInfo: + def test_info_before_connect(self): + provider = OpenAIProvider() + info = provider.adapter_info() + assert info.name == "openai" + assert info.adapter_type == "provider" + assert info.connected is False + + def test_info_after_connect(self): + provider = OpenAIProvider() + provider.connect(Mock()) + info = provider.adapter_info() + assert info.connected is True + + def test_info_after_disconnect(self): + provider = OpenAIProvider() + provider.connect(Mock()) + provider.disconnect() + assert provider.adapter_info().connected is False + + +# --------------------------------------------------------------------------- +# Convenience API +# --------------------------------------------------------------------------- + + +class TestConvenienceAPI: + def 
test_instrument_and_uninstrument(self): + openai_client = Mock() + original = openai_client.chat.completions.create + instrument_openai(openai_client) + assert openai_client.chat.completions.create is not original + uninstrument_openai() + + +# --------------------------------------------------------------------------- +# capture_params filtering +# --------------------------------------------------------------------------- + + +class TestCaptureParams: + def test_captured_params_included(self, mock_client, capture_trace): + openai_client = Mock() + openai_client.chat.completions.create = Mock(return_value=make_openai_response()) + + provider = OpenAIProvider() + provider.connect(openai_client) + + @trace(mock_client) + def my_agent(): + openai_client.chat.completions.create( + model="gpt-4", temperature=0.7, top_p=0.9, + messages=[{"role": "user", "content": "Hi"}], + ) + return "done" + + my_agent() + params = find_event(capture_trace["events"], "model.invoke")["payload"]["parameters"] + assert params["model"] == "gpt-4" + assert params["temperature"] == 0.7 + assert params["top_p"] == 0.9 + + def test_non_captured_params_excluded(self, mock_client, capture_trace): + openai_client = Mock() + openai_client.chat.completions.create = Mock(return_value=make_openai_response()) + + provider = OpenAIProvider() + provider.connect(openai_client) + + @trace(mock_client) + def my_agent(): + openai_client.chat.completions.create( + model="gpt-4", messages=[], stream=True, user="test-user", + ) + return "done" + + my_agent() + params = find_event(capture_trace["events"], "model.invoke")["payload"]["parameters"] + assert "stream" not in params + assert "user" not in params + assert "messages" not in params + + +# --------------------------------------------------------------------------- +# Extractor edge cases (using real SDK types) +# --------------------------------------------------------------------------- + + +class TestExtractors: + def test_extract_output_normal(self): 
+ r = make_openai_response(content="Hi there", role="assistant") + output = OpenAIProvider.extract_output(r) + assert output == {"role": "assistant", "content": "Hi there"} + + def test_extract_output_empty_choices(self): + r = make_openai_response_empty_choices() + assert OpenAIProvider.extract_output(r) is None + + def test_extract_meta_normal(self): + r = make_openai_response(model="gpt-4o", prompt_tokens=5, completion_tokens=3, total_tokens=8) + meta = OpenAIProvider.extract_meta(r) + assert meta["response_model"] == "gpt-4o" + assert meta["usage"]["prompt_tokens"] == 5 + assert meta["usage"]["completion_tokens"] == 3 + assert meta["usage"]["total_tokens"] == 8 + + def test_extract_meta_no_usage(self): + r = make_openai_response_no_usage(model="gpt-4") + meta = OpenAIProvider.extract_meta(r) + assert "usage" not in meta + assert meta["response_model"] == "gpt-4" diff --git a/tests/instrument/test_registry.py b/tests/instrument/adapters/test_registry.py similarity index 100% rename from tests/instrument/test_registry.py rename to tests/instrument/adapters/test_registry.py diff --git a/tests/instrument/test_adapters.py b/tests/instrument/test_adapters.py deleted file mode 100644 index 11752c2..0000000 --- a/tests/instrument/test_adapters.py +++ /dev/null @@ -1,167 +0,0 @@ -from __future__ import annotations - -import json -import sys -import types -import importlib -from uuid import uuid4 -from unittest.mock import Mock - -from .conftest import find_events, find_event - - -def _capture_framework_trace(mock_client): - """Helper to capture uploaded trace from framework adapters (which manage their own collector).""" - uploaded = {} - - def _capture(path): - with open(path) as f: - data = json.load(f) - payload = data[0] - uploaded["trace_id"] = payload.get("trace_id") - uploaded["events"] = payload.get("events", []) - uploaded["capture_config"] = payload.get("capture_config", {}) - uploaded["attestation"] = payload.get("attestation", {}) - - 
mock_client.traces.upload.side_effect = _capture - return uploaded - - -class TestLangChainAdapter: - def _setup_langchain_mock(self): - mock_lc_core = types.ModuleType("langchain_core") - mock_lc_callbacks = types.ModuleType("langchain_core.callbacks") - - class FakeBaseCallbackHandler: - def __init__(self): - pass - - mock_lc_callbacks.BaseCallbackHandler = FakeBaseCallbackHandler - mock_lc_core.callbacks = mock_lc_callbacks - - sys.modules["langchain_core"] = mock_lc_core - sys.modules["langchain_core.callbacks"] = mock_lc_callbacks - - def _teardown_langchain_mock(self): - for key in list(sys.modules.keys()): - if key.startswith("langchain_core"): - del sys.modules[key] - - def _get_handler(self, mock_client): - from layerlens.instrument.adapters.frameworks import langchain as lc_mod - - importlib.reload(lc_mod) - return lc_mod.LangChainCallbackHandler(mock_client) - - def test_emits_flat_events(self, mock_client): - self._setup_langchain_mock() - try: - uploaded = _capture_framework_trace(mock_client) - handler = self._get_handler(mock_client) - - chain_run_id = uuid4() - llm_run_id = uuid4() - - handler.on_chain_start( - {"name": "RunnableSequence", "id": ["RunnableSequence"]}, - {"question": "What is AI?"}, - run_id=chain_run_id, - ) - handler.on_llm_start( - {"name": "ChatOpenAI", "id": ["ChatOpenAI"]}, - ["What is AI?"], - run_id=llm_run_id, - parent_run_id=chain_run_id, - ) - - llm_response = Mock() - llm_response.generations = [[Mock(text="AI is...")]] - llm_response.llm_output = {"token_usage": {"total_tokens": 50}, "model_name": "gpt-4"} - handler.on_llm_end(llm_response, run_id=llm_run_id) - handler.on_chain_end({"output": "AI is..."}, run_id=chain_run_id) - - events = uploaded["events"] - # Should have: agent.input, model.invoke (start), model.invoke (end), cost.record, agent.output - agent_input = find_event(events, "agent.input") - assert agent_input["payload"]["name"] == "RunnableSequence" - assert agent_input["payload"]["input"] == {"question": 
"What is AI?"} - - model_invokes = find_events(events, "model.invoke") - assert len(model_invokes) >= 1 - # The end event has model name and output - end_invoke = [m for m in model_invokes if m["payload"].get("model") == "gpt-4"] - assert len(end_invoke) == 1 - assert end_invoke[0]["payload"]["output_message"] == "AI is..." - - cost = find_event(events, "cost.record") - assert cost["payload"]["total_tokens"] == 50 - - agent_output = find_event(events, "agent.output") - assert agent_output["payload"]["status"] == "ok" - - # Parent-child: LLM events should reference chain's span_id as parent - chain_span_id = agent_input["span_id"] - llm_start = [m for m in model_invokes if m["payload"].get("name") == "ChatOpenAI"][0] - assert llm_start["parent_span_id"] == chain_span_id - finally: - self._teardown_langchain_mock() - - def test_tracks_tools_and_retrievers(self, mock_client): - self._setup_langchain_mock() - try: - uploaded = _capture_framework_trace(mock_client) - handler = self._get_handler(mock_client) - - chain_id = uuid4() - tool_id = uuid4() - retriever_id = uuid4() - - handler.on_chain_start({"name": "Agent"}, {"input": "test"}, run_id=chain_id) - handler.on_tool_start({"name": "search"}, "query", run_id=tool_id, parent_run_id=chain_id) - handler.on_tool_end("results", run_id=tool_id) - handler.on_retriever_start({"name": "vectorstore"}, "query", run_id=retriever_id, parent_run_id=chain_id) - - docs = [Mock(page_content="doc1", metadata={"source": "a"})] - handler.on_retriever_end(docs, run_id=retriever_id) - handler.on_chain_end({"output": "done"}, run_id=chain_id) - - events = uploaded["events"] - tool_calls = find_events(events, "tool.call") - assert len(tool_calls) == 2 # tool + retriever both emit tool.call - tool_results = find_events(events, "tool.result") - assert len(tool_results) == 2 - finally: - self._teardown_langchain_mock() - - def test_error_on_chain(self, mock_client): - self._setup_langchain_mock() - try: - uploaded = 
_capture_framework_trace(mock_client) - handler = self._get_handler(mock_client) - - chain_id = uuid4() - handler.on_chain_start({"name": "FailChain"}, {"input": "x"}, run_id=chain_id) - handler.on_chain_error(ValueError("broke"), run_id=chain_id) - - events = uploaded["events"] - error = find_event(events, "agent.error") - assert error["payload"]["error"] == "broke" - assert error["payload"]["status"] == "error" - finally: - self._teardown_langchain_mock() - - def test_null_serialized_handled(self, mock_client): - self._setup_langchain_mock() - try: - uploaded = _capture_framework_trace(mock_client) - handler = self._get_handler(mock_client) - - run_id = uuid4() - handler.on_chain_start(None, {"input": "x"}, run_id=run_id) - handler.on_chain_end({"output": "done"}, run_id=run_id) - - events = uploaded["events"] - agent_input = find_event(events, "agent.input") - assert agent_input["payload"]["name"] == "unknown" - finally: - self._teardown_langchain_mock() diff --git a/tests/instrument/test_capture_config.py b/tests/instrument/test_capture_config.py index a70dfb4..5b00390 100644 --- a/tests/instrument/test_capture_config.py +++ b/tests/instrument/test_capture_config.py @@ -5,9 +5,9 @@ import pytest -from layerlens.instrument import trace, CaptureConfig -from .conftest import find_events, find_event +from layerlens.instrument import CaptureConfig, trace +from .conftest import find_event, find_events # --------------------------------------------------------------------------- # CaptureConfig unit tests diff --git a/tests/instrument/test_core.py b/tests/instrument/test_core.py index 89ed459..d16277d 100644 --- a/tests/instrument/test_core.py +++ b/tests/instrument/test_core.py @@ -1,12 +1,11 @@ from __future__ import annotations -import os - import pytest -from layerlens.instrument import span, emit, trace -from layerlens.instrument._context import _current_collector, _current_span_id -from .conftest import find_events, find_event +from layerlens.instrument import 
emit, span, trace +from layerlens.instrument._context import _current_span_id, _current_collector + +from .conftest import find_event class TestTraceDecorator: diff --git a/tests/instrument/test_providers.py b/tests/instrument/test_providers.py deleted file mode 100644 index ede576f..0000000 --- a/tests/instrument/test_providers.py +++ /dev/null @@ -1,220 +0,0 @@ -from __future__ import annotations - -import sys -import types -from unittest.mock import Mock - -from layerlens.instrument import trace -from .conftest import find_events, find_event - - -def _openai_response(): - r = Mock() - r.choices = [Mock()] - r.choices[0].message = Mock() - r.choices[0].message.role = "assistant" - r.choices[0].message.content = "Hello!" - r.usage = Mock() - r.usage.prompt_tokens = 10 - r.usage.completion_tokens = 5 - r.usage.total_tokens = 15 - r.model = "gpt-4" - return r - - -def _anthropic_response(): - r = Mock() - block = Mock() - block.type = "text" - block.text = "I'm Claude!" - r.content = [block] - r.usage = Mock() - r.usage.input_tokens = 20 - r.usage.output_tokens = 10 - r.model = "claude-3-opus" - r.stop_reason = "end_turn" - return r - - -class TestOpenAIProvider: - def test_instrument_emits_events(self, mock_client, capture_trace): - from layerlens.instrument.adapters.providers.openai import OpenAIProvider - - openai_client = Mock() - openai_client.chat.completions.create = Mock(return_value=_openai_response()) - - provider = OpenAIProvider() - provider.connect(openai_client) - - @trace(mock_client) - def my_agent(): - return ( - openai_client.chat.completions.create(model="gpt-4", messages=[{"role": "user", "content": "Hi"}]) - .choices[0] - .message.content - ) - - my_agent() - events = capture_trace["events"] - model_invoke = find_event(events, "model.invoke") - assert model_invoke["payload"]["name"] == "openai.chat.completions.create" - assert model_invoke["payload"]["parameters"]["model"] == "gpt-4" - assert model_invoke["payload"]["usage"]["total_tokens"] == 
15 - assert model_invoke["payload"]["output_message"]["content"] == "Hello!" - - cost = find_event(events, "cost.record") - assert cost["payload"]["provider"] == "openai" - assert cost["payload"]["total_tokens"] == 15 - - def test_passthrough_without_trace(self): - from layerlens.instrument.adapters.providers.openai import OpenAIProvider - - openai_client = Mock() - openai_client.chat.completions.create = Mock(return_value=_openai_response()) - - provider = OpenAIProvider() - provider.connect(openai_client) - - result = openai_client.chat.completions.create(model="gpt-4", messages=[]) - assert result.choices[0].message.content == "Hello!" - - def test_disconnect_restores(self): - from layerlens.instrument.adapters.providers.openai import OpenAIProvider - - openai_client = Mock() - original = openai_client.chat.completions.create - - provider = OpenAIProvider() - provider.connect(openai_client) - assert openai_client.chat.completions.create is not original - - provider.disconnect() - assert openai_client.chat.completions.create is original - - def test_instrument_convenience_function(self): - from layerlens.instrument.adapters.providers.openai import instrument_openai, uninstrument_openai - - openai_client = Mock() - original = openai_client.chat.completions.create - instrument_openai(openai_client) - assert openai_client.chat.completions.create is not original - uninstrument_openai() - - -class TestAnthropicProvider: - def test_instrument_emits_events(self, mock_client, capture_trace): - from layerlens.instrument.adapters.providers.anthropic import AnthropicProvider - - anthropic_client = Mock() - anthropic_client.messages.create = Mock(return_value=_anthropic_response()) - - provider = AnthropicProvider() - provider.connect(anthropic_client) - - @trace(mock_client) - def my_agent(): - return ( - anthropic_client.messages.create( - model="claude-3-opus", max_tokens=1024, messages=[{"role": "user", "content": "Hi"}] - ) - .content[0] - .text - ) - - my_agent() - 
events = capture_trace["events"] - model_invoke = find_event(events, "model.invoke") - assert model_invoke["payload"]["output_message"]["text"] == "I'm Claude!" - assert model_invoke["payload"]["usage"]["input_tokens"] == 20 - assert model_invoke["payload"]["response_model"] == "claude-3-opus" - assert model_invoke["payload"]["stop_reason"] == "end_turn" - - def test_disconnect_restores(self): - from layerlens.instrument.adapters.providers.anthropic import AnthropicProvider - - anthropic_client = Mock() - original = anthropic_client.messages.create - - provider = AnthropicProvider() - provider.connect(anthropic_client) - provider.disconnect() - assert anthropic_client.messages.create is original - - -class TestLiteLLMProvider: - def setup_method(self): - self.mock_litellm = types.ModuleType("litellm") - self.mock_litellm.completion = Mock(return_value=_openai_response()) - self.mock_litellm.acompletion = Mock() - sys.modules["litellm"] = self.mock_litellm - - def teardown_method(self): - from layerlens.instrument.adapters.providers.litellm import uninstrument_litellm - - uninstrument_litellm() - for key in list(sys.modules.keys()): - if key.startswith("litellm"): - del sys.modules[key] - - def test_instrument_emits_events(self, mock_client, capture_trace): - from layerlens.instrument.adapters.providers.litellm import instrument_litellm - - instrument_litellm() - - @trace(mock_client) - def my_agent(): - import litellm - - return ( - litellm.completion(model="gpt-4", messages=[{"role": "user", "content": "Hi"}]) - .choices[0] - .message.content - ) - - my_agent() - events = capture_trace["events"] - model_invoke = find_event(events, "model.invoke") - assert model_invoke["payload"]["name"] == "litellm.completion" - assert model_invoke["payload"]["parameters"]["model"] == "gpt-4" - - def test_passthrough_without_trace(self): - from layerlens.instrument.adapters.providers.litellm import instrument_litellm - - instrument_litellm() - import litellm - - result = 
litellm.completion(model="gpt-4", messages=[]) - assert result.choices[0].message.content == "Hello!" - - def test_uninstrument(self): - from layerlens.instrument.adapters.providers.litellm import instrument_litellm, uninstrument_litellm - - original = self.mock_litellm.completion - instrument_litellm() - assert self.mock_litellm.completion is not original - uninstrument_litellm() - assert self.mock_litellm.completion is original - - -class TestProviderErrorHandling: - def test_error_emits_event(self, mock_client, capture_trace): - from layerlens.instrument.adapters.providers.openai import OpenAIProvider - - openai_client = Mock() - openai_client.chat.completions.create = Mock(side_effect=RuntimeError("API error")) - - provider = OpenAIProvider() - provider.connect(openai_client) - - @trace(mock_client) - def my_agent(): - try: - openai_client.chat.completions.create(model="gpt-4", messages=[]) - except RuntimeError: - pass - return "recovered" - - my_agent() - events = capture_trace["events"] - error = find_event(events, "agent.error") - assert error["payload"]["error"] == "API error" diff --git a/tests/instrument/test_types.py b/tests/instrument/test_types.py index 63927e0..618ebd0 100644 --- a/tests/instrument/test_types.py +++ b/tests/instrument/test_types.py @@ -1,7 +1,7 @@ from __future__ import annotations from layerlens.instrument._span import span -from layerlens.instrument._context import _current_span_id, _parent_span_id, _current_span_name +from layerlens.instrument._context import _parent_span_id, _current_span_id, _current_span_name class TestSpan: From ca2834ed049f4f3805cc5c048add476d6c52866c Mon Sep 17 00:00:00 2001 From: Gary <59334078+garrettallen14@users.noreply.github.com> Date: Wed, 1 Apr 2026 18:51:50 -0700 Subject: [PATCH 2/3] feat: context propagation and upload circuit breaker --- src/layerlens/instrument/__init__.py | 3 + .../instrument/_context_propagation.py | 93 +++ src/layerlens/instrument/_upload.py | 68 +- 
.../adapters/frameworks/_base_framework.py | 107 ++- tests/instrument/test_trace_context.py | 642 ++++++++++++++++++ 5 files changed, 882 insertions(+), 31 deletions(-) create mode 100644 src/layerlens/instrument/_context_propagation.py create mode 100644 tests/instrument/test_trace_context.py diff --git a/src/layerlens/instrument/__init__.py b/src/layerlens/instrument/__init__.py index a7237a0..04e4667 100644 --- a/src/layerlens/instrument/__init__.py +++ b/src/layerlens/instrument/__init__.py @@ -5,6 +5,7 @@ from ._capture_config import CaptureConfig from ._collector import TraceCollector from ._decorator import trace +from ._context_propagation import trace_context, get_trace_context from .adapters._base import AdapterInfo, BaseAdapter __all__ = [ @@ -13,6 +14,8 @@ "CaptureConfig", "TraceCollector", "emit", + "get_trace_context", "span", "trace", + "trace_context", ] diff --git a/src/layerlens/instrument/_context_propagation.py b/src/layerlens/instrument/_context_propagation.py new file mode 100644 index 0000000..f1ced8d --- /dev/null +++ b/src/layerlens/instrument/_context_propagation.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import uuid +from typing import Any, Dict, Generator, Optional +from contextlib import contextmanager + +from ._collector import TraceCollector +from ._capture_config import CaptureConfig +from ._context import ( + _current_collector, + _current_span_id, + _parent_span_id, + _push_span, + _pop_span, +) + + +@contextmanager +def trace_context( + client: Any, + *, + capture_config: Optional[CaptureConfig] = None, + from_context: Optional[Dict[str, Any]] = None, +) -> Generator[TraceCollector, None, None]: + """Establish a shared trace context for multiple adapters. + + Creates a :class:`TraceCollector` and sets it as the active collector + in ``contextvars`` so that any adapter emitting events inside the + block will use the same ``trace_id`` and span hierarchy. 
+ + When *from_context* is provided (a dict from :func:`get_trace_context`), + the new collector reuses the original ``trace_id`` so events on both + sides of a boundary belong to the same trace. + + The collector is flushed automatically when the context exits. + + Args: + client: A :class:`~layerlens.Stratix` (or compatible) client used + for uploading the trace on flush. + capture_config: Optional capture configuration. Falls back to + :meth:`CaptureConfig.standard` if not provided. + from_context: Optional dict produced by :func:`get_trace_context`. + When supplied the collector inherits the original trace_id. + + Yields: + The shared :class:`TraceCollector`. + """ + config = capture_config or CaptureConfig.standard() + collector = TraceCollector(client, config) + + if from_context is not None: + collector._trace_id = from_context["trace_id"] # noqa: SLF001 + + root_span_id = uuid.uuid4().hex[:16] + + col_token = _current_collector.set(collector) + span_snapshot = _push_span(root_span_id, "trace_context") + try: + yield collector + finally: + _pop_span(span_snapshot) + _current_collector.reset(col_token) + collector.flush() + + +def get_trace_context() -> Optional[Dict[str, Any]]: + """Snapshot the current trace context as a plain dict. + + Returns ``None`` when called outside a ``@trace`` / ``trace_context`` + block. The returned dict is safe to serialise (JSON, headers, message + queues, etc.) and restore via ``trace_context(client, from_context=ctx)``. 
+ + Keys: + + * ``trace_id`` — 16-char hex trace identifier + * ``span_id`` — current span (becomes the parent in the remote scope) + * ``parent_span_id`` — optional grandparent for reference + * ``version`` — format version for forward compatibility + """ + collector = _current_collector.get() + if collector is None: + return None + + span_id = _current_span_id.get() + if span_id is None: + return None + + return { + "trace_id": collector.trace_id, + "span_id": span_id, + "parent_span_id": _parent_span_id.get(), + "version": 1, + } diff --git a/src/layerlens/instrument/_upload.py b/src/layerlens/instrument/_upload.py index c594d29..ae42048 100644 --- a/src/layerlens/instrument/_upload.py +++ b/src/layerlens/instrument/_upload.py @@ -2,16 +2,70 @@ import os import json +import time import asyncio import logging import tempfile +import threading from typing import Any, Dict log: logging.Logger = logging.getLogger(__name__) +# --------------------------------------------------------------------------- +# Circuit breaker +# --------------------------------------------------------------------------- + +_lock = threading.Lock() +_error_count = 0 +_circuit_open = False +_opened_at: float = 0.0 + +_THRESHOLD = 10 +_COOLDOWN_S = 60.0 + + +def _allow() -> bool: + global _circuit_open, _error_count + with _lock: + if not _circuit_open: + return True + if time.monotonic() - _opened_at >= _COOLDOWN_S: + _circuit_open = False + _error_count = 0 + log.info("layerlens: upload circuit breaker half-open, retrying") + return True + return False + + +def _on_success() -> None: + global _error_count, _circuit_open + with _lock: + if _error_count > 0: + _error_count = 0 + _circuit_open = False + + +def _on_failure() -> None: + global _error_count, _circuit_open, _opened_at + with _lock: + _error_count += 1 + if _error_count >= _THRESHOLD and not _circuit_open: + _circuit_open = True + _opened_at = time.monotonic() + log.warning( + "layerlens: upload circuit breaker OPEN after %d errors 
(cooldown %.0fs)", + _error_count, + _COOLDOWN_S, + ) + + +# --------------------------------------------------------------------------- +# Upload +# --------------------------------------------------------------------------- + + def _write_trace_file(payload: Dict[str, Any]) -> str: - """Write trace payload to a temp file and return its path.""" fd, path = tempfile.mkstemp(suffix=".json", prefix="layerlens_trace_") with os.fdopen(fd, "w") as f: json.dump([payload], f, default=str) @@ -19,9 +73,15 @@ def _write_trace_file(payload: Dict[str, Any]) -> str: def upload_trace(client: Any, payload: Dict[str, Any]) -> None: + if not _allow(): + return path = _write_trace_file(payload) try: client.traces.upload(path) + _on_success() + except Exception: + _on_failure() + log.warning("layerlens: trace upload failed", exc_info=True) finally: try: os.unlink(path) @@ -30,9 +90,15 @@ def upload_trace(client: Any, payload: Dict[str, Any]) -> None: async def async_upload_trace(client: Any, payload: Dict[str, Any]) -> None: + if not _allow(): + return path = await asyncio.to_thread(_write_trace_file, payload) try: await client.traces.upload(path) + _on_success() + except Exception: + _on_failure() + log.warning("layerlens: async trace upload failed", exc_info=True) finally: try: os.unlink(path) diff --git a/src/layerlens/instrument/adapters/frameworks/_base_framework.py b/src/layerlens/instrument/adapters/frameworks/_base_framework.py index 197c65e..20ddbdb 100644 --- a/src/layerlens/instrument/adapters/frameworks/_base_framework.py +++ b/src/layerlens/instrument/adapters/frameworks/_base_framework.py @@ -1,12 +1,4 @@ -"""Unified base class for all framework adapters. - -Framework adapters hook into a framework's callback / event / tracing -system and emit LayerLens events. They share a common lifecycle: - - 1. Lazy-init a :class:`TraceCollector` on first event. - 2. Emit events through a thread-safe helper. - 3. 
Flush the collector when a logical trace ends (root span completes, - agent run finishes, disconnect, etc.). +"""Base class for framework adapters. Subclasses MUST set ``name`` and implement ``connect()``. Subclasses SHOULD call ``super().disconnect()`` after unhooking. @@ -14,12 +6,17 @@ from __future__ import annotations import uuid +import logging import threading -from typing import Any, Dict, Optional +from typing import Any, Dict, Generator, Optional +from contextlib import contextmanager from .._base import AdapterInfo, BaseAdapter from ..._collector import TraceCollector from ..._capture_config import CaptureConfig +from ..._context import _current_collector, _current_span_id, _push_span, _pop_span + +log = logging.getLogger(__name__) class FrameworkAdapter(BaseAdapter): @@ -34,16 +31,27 @@ def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) self._connected = False self._collector: Optional[TraceCollector] = None self._root_span_id: Optional[str] = None + self._using_shared_collector = False # Optional run_id → span_id mapping for callback-style frameworks self._span_ids: Dict[str, str] = {} + # Subclasses populate during connect() for adapter_info() metadata + self._metadata: Dict[str, Any] = {} # ------------------------------------------------------------------ # Collector lifecycle # ------------------------------------------------------------------ def _ensure_collector(self) -> TraceCollector: - """Lazily create a collector and root span ID.""" + """Return the shared collector from ContextVars, or create a private one.""" + shared = _current_collector.get() + if shared is not None: + self._using_shared_collector = True + if self._root_span_id is None: + self._root_span_id = _current_span_id.get() + return shared + if self._collector is None: + self._using_shared_collector = False self._collector = TraceCollector(self._client, self._config) self._root_span_id = uuid.uuid4().hex[:16] return self._collector @@ -52,6 +60,55 @@ 
def _ensure_collector(self) -> TraceCollector: def _new_span_id() -> str: return uuid.uuid4().hex[:16] + # ------------------------------------------------------------------ + # Callback scope — bridges framework callbacks to ContextVars + # ------------------------------------------------------------------ + + @contextmanager + def _callback_scope( + self, + span_name: Optional[str] = None, + ) -> Generator[str, None, None]: + """Push collector + new span into ContextVars; yields the span_id.""" + collector = self._ensure_collector() + span_id = self._new_span_id() + + # Only set the collector ContextVar if no shared one exists already + needs_collector_push = _current_collector.get() is None + col_token = None + if needs_collector_push: + col_token = _current_collector.set(collector) + + snapshot = _push_span(span_id, span_name) + try: + yield span_id + finally: + _pop_span(snapshot) + if col_token is not None: + _current_collector.reset(col_token) + + def _traced_call( + self, + original: Any, + *args: Any, + _span_name: Optional[str] = None, + **kwargs: Any, + ) -> Any: + """Call *original* inside a _callback_scope so providers see this collector.""" + with self._callback_scope(_span_name): + return original(*args, **kwargs) + + async def _async_traced_call( + self, + original: Any, + *args: Any, + _span_name: Optional[str] = None, + **kwargs: Any, + ) -> Any: + """Async version of _traced_call.""" + with self._callback_scope(_span_name): + return await original(*args, **kwargs) + # ------------------------------------------------------------------ # Event emission (thread-safe) # ------------------------------------------------------------------ @@ -79,12 +136,7 @@ def _emit( # ------------------------------------------------------------------ def _span_id_for(self, run_id: Any, parent_run_id: Any = None) -> tuple[str, Optional[str]]: - """Map a framework run_id to a span_id, creating one if needed. - - Returns ``(span_id, parent_span_id)``. 
Useful for frameworks - (LangChain, CrewAI, OpenAI Agents) that assign their own run - identifiers to each step. - """ + """Map a framework run_id to a (span_id, parent_span_id) pair.""" rid = str(run_id) if rid not in self._span_ids: self._span_ids[rid] = self._new_span_id() @@ -97,13 +149,15 @@ def _span_id_for(self, run_id: Any, parent_run_id: Any = None) -> tuple[str, Opt # ------------------------------------------------------------------ def _flush_collector(self) -> None: - """Flush the current collector and reset state.""" + """Flush private collector (no-op for shared collectors).""" with self._lock: collector = self._collector + is_shared = self._using_shared_collector self._collector = None self._root_span_id = None + self._using_shared_collector = False self._span_ids.clear() - if collector is not None: + if collector is not None and not is_shared: collector.flush() # ------------------------------------------------------------------ @@ -111,27 +165,20 @@ def _flush_collector(self) -> None: # ------------------------------------------------------------------ def connect(self, target: Any = None, **kwargs: Any) -> Any: - """Mark the adapter as connected. - - Callback-style adapters (LangChain, LangGraph) are passed directly - to the framework, so ``connect()`` just flips the flag. Adapters - that need registration (CrewAI, LlamaIndex, etc.) should override. - """ + """Mark as connected. Subclasses override for framework registration.""" self._connected = True return target def disconnect(self) -> None: - """Flush remaining events and mark as disconnected. - - Subclasses should unhook from the framework first, then call - ``super().disconnect()``. 
- """ + """Flush remaining events and mark as disconnected.""" self._flush_collector() self._connected = False + self._metadata.clear() def adapter_info(self) -> AdapterInfo: return AdapterInfo( name=self.name, adapter_type="framework", connected=self._connected, + metadata=self._metadata, ) diff --git a/tests/instrument/test_trace_context.py b/tests/instrument/test_trace_context.py new file mode 100644 index 0000000..03a09be --- /dev/null +++ b/tests/instrument/test_trace_context.py @@ -0,0 +1,642 @@ +"""Tests for trace context: shared collectors, context propagation, +callback scope, and upload circuit breaker. +""" +from __future__ import annotations + +import json +from typing import Any, Dict, List, Optional +from unittest.mock import Mock + +import pytest + +from layerlens.instrument import ( + trace, + trace_context, + emit, + span, + get_trace_context, + CaptureConfig, +) +from layerlens.instrument._context import _current_collector, _current_span_id +from layerlens.instrument._collector import TraceCollector +from layerlens.instrument import _upload +from layerlens.instrument.adapters.frameworks._base_framework import FrameworkAdapter + +from .conftest import find_event, find_events + + +# --------------------------------------------------------------------------- +# Minimal concrete adapter for testing +# --------------------------------------------------------------------------- + +class StubAdapter(FrameworkAdapter): + name = "stub" + + def connect(self, target: Any = None, **kwargs: Any) -> Any: + self._connected = True + return target + + def fire_event(self, event_type: str, payload: Dict[str, Any], + span_id: Optional[str] = None, + parent_span_id: Optional[str] = None) -> None: + self._emit(event_type, payload, span_id=span_id, + parent_span_id=parent_span_id, span_name=event_type) + + +# --------------------------------------------------------------------------- +# Fixtures +# 
--------------------------------------------------------------------------- + +@pytest.fixture +def mock_client(): + client = Mock() + client.traces = Mock() + client.traces.upload = Mock() + return client + + +@pytest.fixture +def capture_trace(mock_client): + """Capture uploaded trace payloads. Supports multiple uploads.""" + uploads: List[Dict[str, Any]] = [] + + def _capture(path: str) -> None: + with open(path) as f: + data = json.load(f) + uploads.append(data[0]) + + mock_client.traces.upload.side_effect = _capture + return uploads + + +@pytest.fixture(autouse=True) +def reset_circuit_breaker(): + """Reset the upload circuit breaker between tests.""" + _upload._error_count = 0 + _upload._circuit_open = False + _upload._opened_at = 0.0 + yield + _upload._error_count = 0 + _upload._circuit_open = False + _upload._opened_at = 0.0 + + +# =================================================================== +# 1. Shared trace_id via @trace +# =================================================================== + +class TestSharedCollectorViaTrace: + + def test_framework_adapter_shares_trace_id_with_trace_decorator( + self, mock_client, capture_trace, + ): + adapter = StubAdapter(mock_client) + adapter.connect() + + @trace(mock_client) + def agent_run(): + adapter.fire_event("agent.lifecycle", {"action": "crew.start"}) + return "done" + + agent_run() + + assert len(capture_trace) == 1 + events = capture_trace[0]["events"] + lifecycle = find_event(events, "agent.lifecycle") + agent_input = find_event(events, "agent.input") + assert lifecycle["trace_id"] == agent_input["trace_id"] + + def test_multiple_adapters_share_same_trace( + self, mock_client, capture_trace, + ): + adapter_a = StubAdapter(mock_client) + adapter_b = StubAdapter(mock_client) + adapter_a.connect() + adapter_b.connect() + + @trace(mock_client) + def agent_run(): + adapter_a.fire_event("agent.lifecycle", {"source": "A"}) + adapter_b.fire_event("agent.lifecycle", {"source": "B"}) + return "done" + + 
agent_run() + + assert len(capture_trace) == 1 + events = capture_trace[0]["events"] + lifecycles = find_events(events, "agent.lifecycle") + assert len(lifecycles) == 2 + assert lifecycles[0]["trace_id"] == lifecycles[1]["trace_id"] + + def test_framework_adapter_standalone_creates_own_trace( + self, mock_client, capture_trace, + ): + adapter = StubAdapter(mock_client) + adapter.connect() + adapter.fire_event("agent.lifecycle", {"action": "standalone"}) + adapter.disconnect() + + assert len(capture_trace) == 1 + events = capture_trace[0]["events"] + assert len(events) == 1 + assert events[0]["event_type"] == "agent.lifecycle" + + +# =================================================================== +# 2. Cross-adapter parent-child spans +# =================================================================== + +class TestCrossAdapterSpanHierarchy: + + def test_framework_events_parent_to_trace_root_span( + self, mock_client, capture_trace, + ): + adapter = StubAdapter(mock_client) + adapter.connect() + + @trace(mock_client) + def agent_run(): + adapter.fire_event("agent.lifecycle", {"action": "start"}) + return "done" + + agent_run() + + events = capture_trace[0]["events"] + agent_input = find_event(events, "agent.input") + lifecycle = find_event(events, "agent.lifecycle") + root_span = agent_input["span_id"] + assert lifecycle["parent_span_id"] == root_span + + def test_framework_events_parent_to_active_span( + self, mock_client, capture_trace, + ): + adapter = StubAdapter(mock_client) + adapter.connect() + + @trace(mock_client) + def agent_run(): + with span("retrieval"): + adapter.fire_event("tool.call", {"name": "search", "input": "q"}) + return "done" + + agent_run() + + events = capture_trace[0]["events"] + agent_input = find_event(events, "agent.input") + tool_call = find_event(events, "tool.call") + assert tool_call["parent_span_id"] is not None + assert tool_call["trace_id"] == agent_input["trace_id"] + + def 
test_adapter_with_explicit_parent_overrides_default( + self, mock_client, capture_trace, + ): + adapter = StubAdapter(mock_client) + adapter.connect() + explicit_parent = "custom_parent_id" + + @trace(mock_client) + def agent_run(): + adapter.fire_event( + "agent.lifecycle", {"action": "step"}, + parent_span_id=explicit_parent, + ) + return "done" + + agent_run() + + events = capture_trace[0]["events"] + lifecycle = find_event(events, "agent.lifecycle") + assert lifecycle["parent_span_id"] == explicit_parent + + +# =================================================================== +# 3. trace_context() +# =================================================================== + +class TestTraceContext: + + def test_creates_shared_collector(self, mock_client, capture_trace): + adapter_a = StubAdapter(mock_client) + adapter_b = StubAdapter(mock_client) + adapter_a.connect() + adapter_b.connect() + + with trace_context(mock_client): + adapter_a.fire_event("agent.lifecycle", {"source": "A"}) + adapter_b.fire_event("agent.lifecycle", {"source": "B"}) + + assert len(capture_trace) == 1 + events = capture_trace[0]["events"] + assert len(events) == 2 + assert events[0]["trace_id"] == events[1]["trace_id"] + + def test_flushes_on_exit(self, mock_client, capture_trace): + with trace_context(mock_client): + emit("tool.call", {"name": "test", "input": "x"}) + + assert len(capture_trace) == 1 + + def test_cleans_up_on_exit(self, mock_client): + with trace_context(mock_client): + assert _current_collector.get() is not None + + assert _current_collector.get() is None + assert _current_span_id.get() is None + + def test_cleans_up_on_error(self, mock_client): + with pytest.raises(RuntimeError): + with trace_context(mock_client): + raise RuntimeError("boom") + + assert _current_collector.get() is None + assert _current_span_id.get() is None + + def test_yields_collector(self, mock_client): + with trace_context(mock_client) as collector: + assert isinstance(collector, TraceCollector) + 
assert len(collector.trace_id) == 16 + + def test_with_custom_capture_config(self, mock_client, capture_trace): + config = CaptureConfig.standard() + + with trace_context(mock_client, capture_config=config): + emit("tool.call", {"name": "test", "input": "x"}) + + assert len(capture_trace) == 1 + assert capture_trace[0]["capture_config"] == config.to_dict() + + +# =================================================================== +# 4. Context serialisation (get_trace_context / from_context) +# =================================================================== + +class TestGetTraceContext: + + def test_returns_none_outside_trace(self): + assert get_trace_context() is None + + def test_returns_dict_inside_trace(self, mock_client, capture_trace): + @trace(mock_client) + def run(): + ctx = get_trace_context() + assert ctx is not None + assert "trace_id" in ctx + assert "span_id" in ctx + assert "parent_span_id" in ctx + assert ctx["version"] == 1 + return ctx + + ctx = run() + assert len(ctx["trace_id"]) == 16 + assert len(ctx["span_id"]) == 16 + + def test_returns_dict_inside_trace_context(self, mock_client, capture_trace): + with trace_context(mock_client): + ctx = get_trace_context() + assert ctx is not None + assert len(ctx["trace_id"]) == 16 + + def test_span_id_updates_inside_child_span(self, mock_client, capture_trace): + @trace(mock_client) + def run(): + ctx_outer = get_trace_context() + with span("inner"): + ctx_inner = get_trace_context() + return ctx_outer, ctx_inner + + outer, inner = run() + assert outer["trace_id"] == inner["trace_id"] + assert outer["span_id"] != inner["span_id"] + + +class TestTraceContextFromContext: + + def test_restores_trace_id(self, mock_client, capture_trace): + with trace_context(mock_client): + original_ctx = get_trace_context() + emit("tool.call", {"name": "origin", "input": "x"}) + + original_trace_id = original_ctx["trace_id"] + + with trace_context(mock_client, from_context=original_ctx) as restored: + assert 
restored.trace_id == original_trace_id + emit("tool.call", {"name": "remote", "input": "y"}) + + assert len(capture_trace) == 2 + assert capture_trace[0]["trace_id"] == capture_trace[1]["trace_id"] + + def test_creates_child_span(self, mock_client, capture_trace): + with trace_context(mock_client): + original_ctx = get_trace_context() + emit("tool.call", {"name": "origin", "input": "x"}) + + with trace_context(mock_client, from_context=original_ctx): + ctx_inside = get_trace_context() + + assert ctx_inside["span_id"] != original_ctx["span_id"] + assert ctx_inside["trace_id"] == original_ctx["trace_id"] + + +# =================================================================== +# 5. Flush semantics +# =================================================================== + +class TestFlushSemantics: + + def test_adapter_disconnect_does_not_flush_shared_collector( + self, mock_client, capture_trace, + ): + adapter = StubAdapter(mock_client) + adapter.connect() + + @trace(mock_client) + def agent_run(): + adapter.fire_event("agent.lifecycle", {"action": "start"}) + adapter.disconnect() + emit("tool.call", {"name": "post_disconnect", "input": "x"}) + return "done" + + agent_run() + + assert len(capture_trace) == 1 + events = capture_trace[0]["events"] + types = [e["event_type"] for e in events] + assert "agent.lifecycle" in types + assert "tool.call" in types + assert "agent.output" in types + + def test_adapter_disconnect_flushes_own_collector_when_standalone( + self, mock_client, capture_trace, + ): + adapter = StubAdapter(mock_client) + adapter.connect() + adapter.fire_event("agent.lifecycle", {"action": "standalone"}) + adapter.disconnect() + + assert len(capture_trace) == 1 + + def test_multiple_adapters_disconnect_independently_under_shared_context( + self, mock_client, capture_trace, + ): + adapter_a = StubAdapter(mock_client) + adapter_b = StubAdapter(mock_client) + adapter_a.connect() + adapter_b.connect() + + with trace_context(mock_client): + 
adapter_a.fire_event("agent.lifecycle", {"source": "A"}) + adapter_a.disconnect() + adapter_b.fire_event("agent.lifecycle", {"source": "B"}) + adapter_b.disconnect() + + assert len(capture_trace) == 1 + events = capture_trace[0]["events"] + sources = [e["payload"]["source"] for e in events] + assert "A" in sources + assert "B" in sources + + +# =================================================================== +# 6. Callback scope + _traced_call +# =================================================================== + +class TestCallbackScope: + + def test_pushes_collector_when_standalone(self, mock_client, capture_trace): + adapter = StubAdapter(mock_client) + adapter.connect() + + assert _current_collector.get() is None + with adapter._callback_scope("test_scope") as scope_span_id: + assert _current_collector.get() is not None + assert _current_span_id.get() == scope_span_id + emit("tool.call", {"name": "test", "input": "x"}) + + assert _current_collector.get() is None + + def test_preserves_shared_collector(self, mock_client, capture_trace): + adapter = StubAdapter(mock_client) + adapter.connect() + + @trace(mock_client) + def run(): + shared_collector = _current_collector.get() + with adapter._callback_scope("inner") as scope_span: + assert _current_collector.get() is shared_collector + assert _current_span_id.get() == scope_span + emit("tool.call", {"name": "inner_tool", "input": "x"}) + return "done" + + run() + + assert len(capture_trace) == 1 + events = capture_trace[0]["events"] + tool_call = find_event(events, "tool.call") + assert tool_call["payload"]["name"] == "inner_tool" + + def test_creates_child_span(self, mock_client, capture_trace): + adapter = StubAdapter(mock_client) + adapter.connect() + + @trace(mock_client) + def run(): + root_span = _current_span_id.get() + with adapter._callback_scope("child"): + child_span = _current_span_id.get() + assert child_span != root_span + emit("tool.call", {"name": "scoped", "input": "x"}) + assert 
_current_span_id.get() == root_span + return "done" + + run() + + def test_cleans_up_on_error(self, mock_client): + adapter = StubAdapter(mock_client) + adapter.connect() + + with pytest.raises(RuntimeError): + with adapter._callback_scope("failing"): + raise RuntimeError("boom") + + assert _current_collector.get() is None + assert _current_span_id.get() is None + + def test_traced_call_makes_providers_visible(self, mock_client, capture_trace): + adapter = StubAdapter(mock_client) + adapter.connect() + + def fake_agent_run(prompt): + assert _current_collector.get() is not None + emit("model.invoke", {"model": "gpt-4", "input": prompt}) + return "result" + + assert _current_collector.get() is None + result = adapter._traced_call(fake_agent_run, "hello", _span_name="agent.run") + assert result == "result" + assert _current_collector.get() is None + + adapter.disconnect() + assert len(capture_trace) == 1 + events = capture_trace[0]["events"] + model_event = find_event(events, "model.invoke") + assert model_event["payload"]["model"] == "gpt-4" + + def test_traced_call_under_shared_context(self, mock_client, capture_trace): + adapter = StubAdapter(mock_client) + adapter.connect() + + def fake_agent_run(prompt): + emit("model.invoke", {"model": "gpt-4", "input": prompt}) + return "result" + + @trace(mock_client) + def run(): + return adapter._traced_call(fake_agent_run, "hello", _span_name="agent.run") + + run() + assert len(capture_trace) == 1 + events = capture_trace[0]["events"] + assert find_event(events, "model.invoke") + assert find_event(events, "agent.input") + + +# =================================================================== +# 7. 
Upload circuit breaker +# =================================================================== + +class TestUploadCircuitBreaker: + + def test_successful_upload(self, mock_client, capture_trace): + with trace_context(mock_client): + emit("tool.call", {"name": "test", "input": "x"}) + + assert len(capture_trace) == 1 + assert _upload._error_count == 0 + + def test_upload_failure_records_error(self, mock_client): + mock_client.traces.upload.side_effect = RuntimeError("network error") + + with trace_context(mock_client): + emit("tool.call", {"name": "test", "input": "x"}) + + assert _upload._error_count == 1 + assert not _upload._circuit_open + + def test_circuit_opens_after_threshold(self, mock_client): + mock_client.traces.upload.side_effect = RuntimeError("network error") + + for _ in range(_upload._THRESHOLD): + with trace_context(mock_client): + emit("tool.call", {"name": "test", "input": "x"}) + + assert _upload._circuit_open + assert _upload._error_count == _upload._THRESHOLD + + def test_open_circuit_skips_upload(self, mock_client): + _upload._circuit_open = True + _upload._opened_at = __import__("time").monotonic() + + with trace_context(mock_client): + emit("tool.call", {"name": "test", "input": "x"}) + + mock_client.traces.upload.assert_not_called() + + def test_circuit_resets_after_cooldown(self, mock_client, capture_trace): + _upload._circuit_open = True + _upload._error_count = _upload._THRESHOLD + _upload._opened_at = ( + __import__("time").monotonic() - _upload._COOLDOWN_S - 1 + ) + + with trace_context(mock_client): + emit("tool.call", {"name": "test", "input": "x"}) + + assert len(capture_trace) == 1 + assert not _upload._circuit_open + assert _upload._error_count == 0 + + def test_success_after_failures_resets_count(self, mock_client, capture_trace): + _upload._error_count = 5 + + with trace_context(mock_client): + emit("tool.call", {"name": "test", "input": "x"}) + + assert _upload._error_count == 0 + + def test_protects_trace_decorator(self, 
mock_client): + _upload._circuit_open = True + _upload._opened_at = __import__("time").monotonic() + + @trace(mock_client) + def run(): + emit("tool.call", {"name": "test", "input": "x"}) + return "done" + + run() + mock_client.traces.upload.assert_not_called() + + def test_protects_framework_adapter(self, mock_client): + adapter = StubAdapter(mock_client) + adapter.connect() + + _upload._circuit_open = True + _upload._opened_at = __import__("time").monotonic() + + adapter.fire_event("tool.call", {"name": "test", "input": "x"}) + adapter.disconnect() + + mock_client.traces.upload.assert_not_called() + + +# =================================================================== +# 8. Edge cases +# =================================================================== + +class TestEdgeCases: + + def test_adapter_used_across_multiple_traces( + self, mock_client, capture_trace, + ): + adapter = StubAdapter(mock_client) + adapter.connect() + + @trace(mock_client) + def run_1(): + adapter.fire_event("agent.lifecycle", {"run": 1}) + return "done" + + @trace(mock_client) + def run_2(): + adapter.fire_event("agent.lifecycle", {"run": 2}) + return "done" + + run_1() + run_2() + + assert len(capture_trace) == 2 + assert capture_trace[0]["trace_id"] != capture_trace[1]["trace_id"] + + def test_no_events_means_no_upload(self, mock_client): + with trace_context(mock_client): + pass + + mock_client.traces.upload.assert_not_called() + + def test_standalone_adapter_unaffected_by_previous_shared_context( + self, mock_client, capture_trace, + ): + adapter = StubAdapter(mock_client) + adapter.connect() + + with trace_context(mock_client): + adapter.fire_event("agent.lifecycle", {"phase": "shared"}) + + adapter.disconnect() + + adapter = StubAdapter(mock_client) + adapter.connect() + adapter.fire_event("agent.lifecycle", {"phase": "standalone"}) + adapter.disconnect() + + assert len(capture_trace) == 2 + assert capture_trace[0]["trace_id"] != capture_trace[1]["trace_id"] From 
3d969ec73b72a91e3864cfe11b8e17ef50c45a25 Mon Sep 17 00:00:00 2001 From: Garrett Allen <59334078+garrettallen14@users.noreply.github.com> Date: Fri, 3 Apr 2026 23:09:08 -0700 Subject: [PATCH 3/3] feat: updates + new adapters (#82) --- src/layerlens/instrument/_context.py | 23 +- .../adapters/frameworks/_base_framework.py | 269 +++++- .../instrument/adapters/frameworks/_utils.py | 69 ++ .../instrument/adapters/frameworks/crewai.py | 475 ++++++++++ .../adapters/frameworks/langchain.py | 232 +++-- .../adapters/frameworks/langgraph.py | 7 +- .../adapters/frameworks/openai_agents.py | 306 +++++++ .../adapters/frameworks/pydantic_ai.py | 350 ++++++++ .../adapters/frameworks/semantic_kernel.py | 389 +++++++++ .../adapters/frameworks/test_crewai.py | 808 +++++++++++++++++ .../adapters/frameworks/test_langchain.py | 347 ++++++-- .../adapters/frameworks/test_langgraph.py | 2 +- .../adapters/frameworks/test_openai_agents.py | 823 ++++++++++++++++++ .../adapters/frameworks/test_pydantic_ai.py | 471 ++++++++++ .../frameworks/test_semantic_kernel.py | 753 ++++++++++++++++ 15 files changed, 5158 insertions(+), 166 deletions(-) create mode 100644 src/layerlens/instrument/adapters/frameworks/_utils.py create mode 100644 src/layerlens/instrument/adapters/frameworks/crewai.py create mode 100644 src/layerlens/instrument/adapters/frameworks/openai_agents.py create mode 100644 src/layerlens/instrument/adapters/frameworks/pydantic_ai.py create mode 100644 src/layerlens/instrument/adapters/frameworks/semantic_kernel.py create mode 100644 tests/instrument/adapters/frameworks/test_crewai.py create mode 100644 tests/instrument/adapters/frameworks/test_openai_agents.py create mode 100644 tests/instrument/adapters/frameworks/test_pydantic_ai.py create mode 100644 tests/instrument/adapters/frameworks/test_semantic_kernel.py diff --git a/src/layerlens/instrument/_context.py b/src/layerlens/instrument/_context.py index dc1f873..fce18d2 100644 --- a/src/layerlens/instrument/_context.py +++ 
b/src/layerlens/instrument/_context.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Any, Optional, NamedTuple +from dataclasses import dataclass, field +from typing import Any, Dict, Optional, NamedTuple from contextvars import ContextVar from ._collector import TraceCollector @@ -11,6 +12,26 @@ _current_span_name: ContextVar[Optional[str]] = ContextVar("_current_span_name", default=None) +@dataclass +class RunState: + """Per-run state isolated via ContextVar. + + Each concurrent run (agent invocation, crew kickoff, etc.) gets its own + RunState stored in ``_current_run``. This isolates the collector, root span, + timers, and any adapter-specific data so concurrent runs on the same adapter + instance don't clobber each other. + """ + + collector: TraceCollector + root_span_id: str + timers: Dict[str, int] = field(default_factory=dict) + data: Dict[str, Any] = field(default_factory=dict) + _token: Any = field(default=None, repr=False) + + +_current_run: ContextVar[Optional[RunState]] = ContextVar("_current_run", default=None) + + class _SpanSnapshot(NamedTuple): span_id: Any parent_span_id: Any diff --git a/src/layerlens/instrument/adapters/frameworks/_base_framework.py b/src/layerlens/instrument/adapters/frameworks/_base_framework.py index 20ddbdb..8190510 100644 --- a/src/layerlens/instrument/adapters/frameworks/_base_framework.py +++ b/src/layerlens/instrument/adapters/frameworks/_base_framework.py @@ -5,6 +5,7 @@ """ from __future__ import annotations +import time import uuid import logging import threading @@ -14,15 +15,27 @@ from .._base import AdapterInfo, BaseAdapter from ..._collector import TraceCollector from ..._capture_config import CaptureConfig -from ..._context import _current_collector, _current_span_id, _push_span, _pop_span +from ..._context import _current_collector, _current_span_id, _push_span, _pop_span, _current_run, RunState log = logging.getLogger(__name__) +_UNSET: Any = object() # sentinel: distinguish "not 
passed" from explicit None + class FrameworkAdapter(BaseAdapter): """Base for framework adapters with collector lifecycle management.""" name: str # Subclass must set: "crewai", "llamaindex", etc. + package: str = "" # pip extra name, e.g. "crewai" → pip install layerlens[crewai] + + def _check_dependency(self, available: bool) -> None: + """Raise ImportError with a helpful install message if the dependency is missing.""" + if not available: + pkg = self.package or self.name + raise ImportError( + "The '%s' package is required for %s instrumentation. " + "Install it with: pip install layerlens[%s]" % (pkg, self.name, pkg) + ) def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: self._client = client @@ -34,15 +47,68 @@ def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) self._using_shared_collector = False # Optional run_id → span_id mapping for callback-style frameworks self._span_ids: Dict[str, str] = {} + # Root run tracking for auto-flush on outermost callback completion + self._root_run_id: Optional[str] = None + # Timing: key → start_ns for _start_timer / _stop_timer + self._timers: Dict[str, int] = {} # Subclasses populate during connect() for adapter_info() metadata self._metadata: Dict[str, Any] = {} + # ------------------------------------------------------------------ + # Per-run state (ContextVar-based isolation for concurrent runs) + # ------------------------------------------------------------------ + + def _begin_run(self) -> RunState: + """Start a new run with its own collector, root span, and timers. + + Stores the RunState in a ContextVar so all subsequent calls to + ``_ensure_collector``, ``_start_timer``, ``_stop_timer``, and + ``_get_root_span`` use per-run state instead of instance state. + + ContextVars are automatically isolated per ``asyncio.Task``, so + concurrent runs on the same adapter get independent state. 
+ """ + run = RunState( + collector=TraceCollector(self._client, self._config), + root_span_id=uuid.uuid4().hex[:16], + ) + run._token = _current_run.set(run) + return run + + def _end_run(self) -> None: + """Flush the current run's collector and restore the previous ContextVar state.""" + run = _current_run.get() + if run is None: + return + if run._token is not None: + try: + _current_run.reset(run._token) + except ValueError: + # Token created in a different Context (e.g. framework copies + # contexts between hook callbacks). Fall back to plain set. + _current_run.set(None) + else: + _current_run.set(None) + run.collector.flush() + + def _get_run(self) -> Optional[RunState]: + """Return the current RunState, or None if not inside a ``_begin_run`` scope.""" + return _current_run.get() + # ------------------------------------------------------------------ # Collector lifecycle # ------------------------------------------------------------------ def _ensure_collector(self) -> TraceCollector: - """Return the shared collector from ContextVars, or create a private one.""" + """Return the collector for the current context. + + Checks (in order): active RunState, shared collector from ContextVars, + then creates a private instance-level collector as fallback. + """ + run = _current_run.get() + if run is not None: + return run.collector + shared = _current_collector.get() if shared is not None: self._using_shared_collector = True @@ -60,32 +126,141 @@ def _ensure_collector(self) -> TraceCollector: def _new_span_id() -> str: return uuid.uuid4().hex[:16] + # ------------------------------------------------------------------ + # Shared helpers — payload, timing, tokens, content gating + # ------------------------------------------------------------------ + + def _payload(self, **extra: Any) -> Dict[str, Any]: + """Start a payload dict with ``framework: self.name``. 
+ + Usage:: + + payload = self._payload(agent_name="foo", status="ok") + """ + p: Dict[str, Any] = {"framework": self.name} + if extra: + p.update(extra) + return p + + def _get_root_span(self) -> str: + """Return the root span ID for the current run. + + Checks RunState first, then falls back to instance-level ``_root_span_id``. + If neither is set, generates a new one. + """ + run = _current_run.get() + if run is not None: + return run.root_span_id + + with self._lock: + sid = self._root_span_id + if sid is not None: + return sid + sid = self._new_span_id() + with self._lock: + self._root_span_id = sid + return sid + + def _start_timer(self, key: str) -> None: + """Record a start timestamp (nanoseconds) under *key*.""" + run = _current_run.get() + if run is not None: + run.timers[key] = time.time_ns() + return + with self._lock: + self._timers[key] = time.time_ns() + + def _stop_timer(self, key: str) -> Optional[float]: + """Pop the start time for *key* and return elapsed ``latency_ms``, or ``None``.""" + run = _current_run.get() + if run is not None: + start_ns = run.timers.pop(key, 0) + else: + with self._lock: + start_ns = self._timers.pop(key, 0) + if not start_ns: + return None + return (time.time_ns() - start_ns) / 1_000_000 + + @staticmethod + def _normalize_tokens(usage: Any) -> Dict[str, Any]: + """Extract token counts from any usage object or dict. + + Handles field-name variants across providers: + ``prompt_tokens`` / ``input_tokens`` → ``tokens_prompt`` + ``completion_tokens`` / ``output_tokens`` → ``tokens_completion`` + + Returns a dict with ``tokens_prompt``, ``tokens_completion``, + ``tokens_total`` — only keys that have non-zero values. 
+ """ + tokens: Dict[str, Any] = {} + if usage is None: + return tokens + + if isinstance(usage, dict): + prompt = usage.get("prompt_tokens") or usage.get("input_tokens") + completion = usage.get("completion_tokens") or usage.get("output_tokens") + total = usage.get("total_tokens") + else: + prompt = ( + getattr(usage, "prompt_tokens", None) + or getattr(usage, "input_tokens", None) + ) + completion = ( + getattr(usage, "completion_tokens", None) + or getattr(usage, "output_tokens", None) + ) + total = getattr(usage, "total_tokens", None) + + if prompt is not None: + tokens["tokens_prompt"] = int(prompt) + if completion is not None: + tokens["tokens_completion"] = int(completion) + if prompt is not None and completion is not None: + tokens["tokens_total"] = int(prompt) + int(completion) + elif total is not None: + tokens["tokens_total"] = int(total) + return tokens + + def _set_if_capturing(self, payload: Dict[str, Any], key: str, value: Any) -> None: + """Set ``payload[key] = value`` only if ``capture_content`` is enabled.""" + if self._config.capture_content and value is not None: + payload[key] = value + # ------------------------------------------------------------------ # Callback scope — bridges framework callbacks to ContextVars # ------------------------------------------------------------------ + def _push_context(self, span_id: str, span_name: Optional[str] = None) -> Any: + """Push collector + span into ContextVars. 
Returns an opaque token for ``_pop_context``.""" + with self._lock: + collector = self._ensure_collector() + needs_collector_push = _current_collector.get() is None + col_token = _current_collector.set(collector) if needs_collector_push else None + snapshot = _push_span(span_id, span_name) + return (snapshot, col_token) + + def _pop_context(self, token: Any) -> None: + """Restore ContextVars from a token returned by ``_push_context``.""" + if token is None: + return + snapshot, col_token = token + _pop_span(snapshot) + if col_token is not None: + _current_collector.reset(col_token) + @contextmanager def _callback_scope( self, span_name: Optional[str] = None, ) -> Generator[str, None, None]: """Push collector + new span into ContextVars; yields the span_id.""" - collector = self._ensure_collector() span_id = self._new_span_id() - - # Only set the collector ContextVar if no shared one exists already - needs_collector_push = _current_collector.get() is None - col_token = None - if needs_collector_push: - col_token = _current_collector.set(collector) - - snapshot = _push_span(span_id, span_name) + token = self._push_context(span_id, span_name) try: yield span_id finally: - _pop_span(snapshot) - if col_token is not None: - _current_collector.reset(col_token) + self._pop_context(token) def _traced_call( self, @@ -118,14 +293,43 @@ def _emit( event_type: str, payload: Dict[str, Any], span_id: Optional[str] = None, - parent_span_id: Optional[str] = None, + parent_span_id: Any = _UNSET, span_name: Optional[str] = None, + run_id: Any = None, + parent_run_id: Any = None, ) -> None: - """Thread-safe event emission through the collector.""" + """Thread-safe event emission through the collector. + + When *run_id* is provided, it is translated to a span_id via + ``_span_id_for`` and the first run_id seen is tracked as the root + (for flush-on-completion in callback-style frameworks). + + When *parent_span_id* is omitted, falls back to ``_root_span_id``. 
+ Pass ``parent_span_id=None`` explicitly to emit with no parent + (for adapters that manage their own span hierarchy). + """ + # RunState path: per-run isolation, no lock needed + run = _current_run.get() + if run is not None: + if run_id is not None: + span_id, parent_span_id = self._span_id_for(run_id, parent_run_id) + sid = span_id or self._new_span_id() + parent = run.root_span_id if parent_span_id is _UNSET else parent_span_id + run.collector.emit( + event_type, payload, + span_id=sid, parent_span_id=parent, span_name=span_name, + ) + return + + # Legacy path: instance-level state with lock + if run_id is not None: + span_id, parent_span_id = self._span_id_for(run_id, parent_run_id) + if self._root_run_id is None: + self._root_run_id = str(run_id) with self._lock: collector = self._ensure_collector() sid = span_id or self._new_span_id() - parent = parent_span_id or self._root_span_id + parent = self._root_span_id if parent_span_id is _UNSET else parent_span_id collector.emit( event_type, payload, span_id=sid, parent_span_id=parent, span_name=span_name, @@ -136,12 +340,19 @@ def _emit( # ------------------------------------------------------------------ def _span_id_for(self, run_id: Any, parent_run_id: Any = None) -> tuple[str, Optional[str]]: - """Map a framework run_id to a (span_id, parent_span_id) pair.""" + """Map a framework run_id to a (span_id, parent_span_id) pair. + + When a RunState is active, span_ids are stored per-run in + ``run.data["span_ids"]`` for concurrent-run isolation. + Falls back to instance-level ``_span_ids`` otherwise. 
+ """ + run = _current_run.get() + span_ids = run.data.setdefault("span_ids", {}) if run is not None else self._span_ids rid = str(run_id) - if rid not in self._span_ids: - self._span_ids[rid] = self._new_span_id() - span_id = self._span_ids[rid] - parent_span_id = self._span_ids.get(str(parent_run_id)) if parent_run_id else None + if rid not in span_ids: + span_ids[rid] = self._new_span_id() + span_id = span_ids[rid] + parent_span_id = span_ids.get(str(parent_run_id)) if parent_run_id else None return span_id, parent_span_id # ------------------------------------------------------------------ @@ -165,16 +376,26 @@ def _flush_collector(self) -> None: # ------------------------------------------------------------------ def connect(self, target: Any = None, **kwargs: Any) -> Any: - """Mark as connected. Subclasses override for framework registration.""" + """Check dependencies, run framework-specific setup, and mark as connected.""" + self._on_connect(target, **kwargs) self._connected = True return target + def _on_connect(self, target: Any = None, **kwargs: Any) -> None: + """Override to set up framework-specific resources (subscribe, wrap, etc.).""" + pass + def disconnect(self) -> None: - """Flush remaining events and mark as disconnected.""" + """Clean up framework resources, flush events, and mark as disconnected.""" + self._on_disconnect() self._flush_collector() self._connected = False self._metadata.clear() + def _on_disconnect(self) -> None: + """Override to clean up framework-specific resources (unsubscribe, restore, etc.).""" + pass + def adapter_info(self) -> AdapterInfo: return AdapterInfo( name=self.name, diff --git a/src/layerlens/instrument/adapters/frameworks/_utils.py b/src/layerlens/instrument/adapters/frameworks/_utils.py new file mode 100644 index 0000000..fdd66be --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/_utils.py @@ -0,0 +1,69 @@ +"""Shared utilities for framework adapters. 
+ +Centralises helpers that were previously copy-pasted across adapter +files: serialisation, span ID generation, and text truncation. +""" +from __future__ import annotations + +import uuid +from typing import Any + +# --------------------------------------------------------------------------- +# Span IDs +# --------------------------------------------------------------------------- + + +def new_span_id() -> str: + """Generate a short random span identifier.""" + return uuid.uuid4().hex[:16] + + +# --------------------------------------------------------------------------- +# Serialisation +# --------------------------------------------------------------------------- + + +def safe_serialize(value: Any) -> Any: + """Best-effort conversion of *value* into a JSON-friendly form. + + Handles Pydantic models (``model_dump``), objects with ``to_dict``, + dicts, lists/tuples, and falls back to ``str()``. + """ + if value is None: + return None + if isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, (list, tuple)): + return [safe_serialize(v) for v in value] + if hasattr(value, "model_dump"): + try: + return value.model_dump() + except Exception: + pass + if hasattr(value, "to_dict"): + try: + return value.to_dict() + except Exception: + pass + if isinstance(value, dict): + return {str(k): safe_serialize(v) for k, v in value.items()} + return str(value) + + +# --------------------------------------------------------------------------- +# Text truncation +# --------------------------------------------------------------------------- + + +def truncate(text: Any, max_len: int = 2000) -> Any: + """Truncate *text* to *max_len* characters, appending ``'...'``. + + Returns *None* unchanged. Non-string values are stringified first. + """ + if text is None: + return None + if not isinstance(text, str): + text = str(text) + if len(text) <= max_len: + return text + return text[:max_len] + "..." 
diff --git a/src/layerlens/instrument/adapters/frameworks/crewai.py b/src/layerlens/instrument/adapters/frameworks/crewai.py new file mode 100644 index 0000000..b922748 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/crewai.py @@ -0,0 +1,475 @@ +from __future__ import annotations + +import logging +from typing import Any, Dict, Optional + +from ._base_framework import FrameworkAdapter +from ._utils import safe_serialize +from ..._capture_config import CaptureConfig + +log = logging.getLogger(__name__) + +try: + from crewai.events import BaseEventListener as _BaseEventListener # pyright: ignore[reportMissingImports] +except (ImportError, TypeError): + _BaseEventListener = None + + +class CrewAIAdapter(FrameworkAdapter): + """CrewAI adapter using the typed event bus API (crewai >= 1.0). + + Subscribes to CrewAI's event bus to capture crew lifecycle, agent + execution, LLM calls, tool usage, flows, and MCP tool events as + flat layerlens events. + + Usage:: + + adapter = CrewAIAdapter(client) + adapter.connect() + crew.kickoff() # events flow automatically via event bus + adapter.disconnect() + """ + + name = "crewai" + + def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: + super().__init__(client, capture_config) + self._registered_handlers: list = [] + + # Span tracking: crew/flow → task → agent → leaf hierarchy + self._crew_span_id: Optional[str] = None + self._task_span_ids: Dict[str, str] = {} # task name → span_id + self._current_task_span_id: Optional[str] = None + self._agent_span_ids: Dict[str, str] = {} # agent_role → span_id + self._current_agent_span_id: Optional[str] = None + # tool.call span IDs keyed by tool_name+id for pairing start/end + self._tool_span_ids: Dict[str, str] = {} + + # Event name → handler method name; resolved to real classes at subscribe time. 
+ _EVENT_MAP = [ + ("CrewKickoffStartedEvent", "_on_crew_started"), + ("CrewKickoffCompletedEvent", "_on_crew_completed"), + ("CrewKickoffFailedEvent", "_on_crew_failed"), + ("TaskStartedEvent", "_on_task_started"), + ("TaskCompletedEvent", "_on_task_completed"), + ("TaskFailedEvent", "_on_task_failed"), + ("AgentExecutionStartedEvent", "_on_agent_execution_started"), + ("AgentExecutionCompletedEvent", "_on_agent_execution_completed"), + ("AgentExecutionErrorEvent", "_on_agent_execution_error"), + ("LLMCallStartedEvent", "_on_llm_started"), + ("LLMCallCompletedEvent", "_on_llm_completed"), + ("LLMCallFailedEvent", "_on_llm_failed"), + ("ToolUsageStartedEvent", "_on_tool_started"), + ("ToolUsageFinishedEvent", "_on_tool_finished"), + ("ToolUsageErrorEvent", "_on_tool_error"), + ("FlowStartedEvent", "_on_flow_started"), + ("FlowFinishedEvent", "_on_flow_finished"), + ("MCPToolExecutionCompletedEvent", "_on_mcp_tool_completed"), + ("MCPToolExecutionFailedEvent", "_on_mcp_tool_failed"), + ] + + def _on_connect(self, target: Any = None, **kwargs: Any) -> None: + self._check_dependency(_BaseEventListener is not None) + self._subscribe() + + def _on_disconnect(self) -> None: + self._unsubscribe() + self._registered_handlers.clear() + self._reset_spans() + + # ------------------------------------------------------------------ + # Event bus wiring + # ------------------------------------------------------------------ + + def _subscribe(self) -> None: + """Register all event handlers on the CrewAI bus.""" + import crewai.events as ev # pyright: ignore[reportMissingImports] + + for event_name, method_name in self._EVENT_MAP: + event_cls = getattr(ev, event_name) + method = getattr(self, method_name) + + def _handler(source: Any, event: Any, _m: Any = method) -> None: + try: + _m(source, event) + except Exception: + log.warning("layerlens: error in CrewAI event handler", exc_info=True) + + ev.crewai_event_bus.on(event_cls)(_handler) + 
self._registered_handlers.append((event_cls, _handler)) + + def _unsubscribe(self) -> None: + """Remove all previously registered handlers from the CrewAI bus.""" + try: + from crewai.events import crewai_event_bus # pyright: ignore[reportMissingImports] + except ImportError: + return + for event_cls, handler in self._registered_handlers: + try: + crewai_event_bus.off(event_cls, handler) + except Exception: + log.debug("layerlens: could not unregister %s handler", event_cls.__name__, exc_info=True) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @staticmethod + def _get_name(obj: Any) -> str: + return getattr(obj, "name", None) or type(obj).__name__ + + @staticmethod + def _get_task_name(event: Any) -> str: + """Extract task name from a CrewAI event.""" + name = getattr(event, "task_name", None) + if name: + return str(name) + task = getattr(event, "task", None) + if task: + return str(getattr(task, "description", None) or getattr(task, "name", ""))[:200] + return "" + + @staticmethod + def _tool_event_key(event: Any) -> str: + """Build a key to correlate ToolUsageStarted with ToolUsageFinished.""" + tool_name = getattr(event, "tool_name", None) or "" + agent_key = getattr(event, "agent_key", None) or "" + return f"{tool_name}:{agent_key}" + + def _leaf_parent_span_id(self) -> Optional[str]: + """Return the innermost active parent span for leaf events (LLM, tool).""" + with self._lock: + return self._current_agent_span_id or self._current_task_span_id or self._crew_span_id + + def _reset_spans(self) -> None: + """Clear all span tracking state.""" + with self._lock: + self._crew_span_id = None + self._task_span_ids.clear() + self._current_task_span_id = None + self._agent_span_ids.clear() + self._current_agent_span_id = None + self._tool_span_ids.clear() + + def _end_trace(self) -> None: + """Reset spans and flush — called when a crew/flow run 
completes.""" + self._reset_spans() + self._flush_collector() + + # ------------------------------------------------------------------ + # Crew lifecycle + # ------------------------------------------------------------------ + + def _on_crew_started(self, source: Any, event: Any) -> None: + span_id = self._new_span_id() + with self._lock: + self._crew_span_id = span_id + self._start_timer("crew") + crew_name = getattr(event, "crew_name", None) or self._get_name(source) + payload = self._payload(crew_name=crew_name) + self._set_if_capturing(payload, "input", safe_serialize(getattr(event, "inputs", None))) + self._emit("agent.input", payload, span_id=span_id, parent_span_id=None, span_name=crew_name) + + def _on_crew_completed(self, source: Any, event: Any) -> None: + latency_ms = self._stop_timer("crew") + crew_name = getattr(event, "crew_name", None) or self._get_name(source) + payload = self._payload(crew_name=crew_name) + if latency_ms is not None: + payload["duration_ns"] = int(latency_ms * 1_000_000) + self._set_if_capturing(payload, "output", safe_serialize(getattr(event, "output", None))) + total_tokens = getattr(event, "total_tokens", None) + if total_tokens is not None: + payload["tokens_total"] = total_tokens + self._emit( + "agent.output", payload, + span_id=self._crew_span_id or self._new_span_id(), + parent_span_id=None, span_name=crew_name, + ) + if total_tokens: + self._emit( + "cost.record", + self._payload(tokens_total=total_tokens), + span_id=self._crew_span_id or self._new_span_id(), + parent_span_id=None, + ) + self._end_trace() + + def _on_crew_failed(self, source: Any, event: Any) -> None: + error = str(getattr(event, "error", "unknown error")) + crew_name = getattr(event, "crew_name", None) or self._get_name(source) + self._emit( + "agent.error", + self._payload(crew_name=crew_name, error=error), + span_id=self._crew_span_id or self._new_span_id(), + parent_span_id=None, span_name=crew_name, + ) + self._end_trace() + + # 
------------------------------------------------------------------ + # Task lifecycle + # ------------------------------------------------------------------ + + def _on_task_started(self, source: Any, event: Any) -> None: + task_name = self._get_task_name(event) + span_id = self._new_span_id() + with self._lock: + self._task_span_ids[task_name] = span_id + self._current_task_span_id = span_id + parent = self._crew_span_id + agent_role = getattr(event, "agent_role", None) + payload = self._payload(task_name=task_name) + if agent_role: + payload["agent_role"] = agent_role + if self._config.capture_content: + context = getattr(event, "context", None) + if context: + payload["context"] = str(context)[:500] + self._emit( + "agent.input", payload, + span_id=span_id, parent_span_id=parent, + span_name=f"task:{task_name[:60]}", + ) + + def _on_task_completed(self, source: Any, event: Any) -> None: + task_name = self._get_task_name(event) + with self._lock: + span_id = self._task_span_ids.pop(task_name, self._current_task_span_id or self._new_span_id()) + parent = self._crew_span_id + payload = self._payload(task_name=task_name) + self._set_if_capturing(payload, "output", safe_serialize(getattr(event, "output", None))) + self._emit( + "agent.output", payload, + span_id=span_id, parent_span_id=parent, + span_name=f"task:{task_name[:60]}", + ) + + def _on_task_failed(self, source: Any, event: Any) -> None: + task_name = self._get_task_name(event) + with self._lock: + span_id = self._task_span_ids.pop(task_name, self._current_task_span_id or self._new_span_id()) + parent = self._crew_span_id + error = str(getattr(event, "error", "unknown error")) + self._emit( + "agent.error", + self._payload(task_name=task_name, error=error), + span_id=span_id, parent_span_id=parent, + ) + + # ------------------------------------------------------------------ + # Agent execution lifecycle + # ------------------------------------------------------------------ + + def 
_on_agent_execution_started(self, source: Any, event: Any) -> None: + agent = getattr(event, "agent", None) + agent_role = getattr(event, "agent_role", None) or ( + getattr(agent, "role", None) if agent else None + ) or "unknown" + span_id = self._new_span_id() + with self._lock: + self._agent_span_ids[agent_role] = span_id + self._current_agent_span_id = span_id + parent = self._current_task_span_id or self._crew_span_id + + payload = self._payload(agent_role=agent_role) + tools = getattr(event, "tools", None) + if tools: + payload["tools"] = [getattr(t, "name", str(t)) for t in tools] + if self._config.capture_content: + task_prompt = getattr(event, "task_prompt", None) + if task_prompt: + payload["task_prompt"] = str(task_prompt)[:500] + self._emit( + "agent.input", payload, + span_id=span_id, parent_span_id=parent, + span_name=f"agent:{agent_role[:60]}", + ) + + def _on_agent_execution_completed(self, source: Any, event: Any) -> None: + agent = getattr(event, "agent", None) + agent_role = getattr(event, "agent_role", None) or ( + getattr(agent, "role", None) if agent else None + ) or "unknown" + with self._lock: + span_id = self._agent_span_ids.pop(agent_role, self._current_agent_span_id or self._new_span_id()) + parent = self._current_task_span_id or self._crew_span_id + if self._current_agent_span_id == span_id: + self._current_agent_span_id = None + + payload = self._payload(agent_role=agent_role, status="ok") + self._set_if_capturing(payload, "output", safe_serialize(getattr(event, "output", None))) + self._emit( + "agent.output", payload, + span_id=span_id, parent_span_id=parent, + span_name=f"agent:{agent_role[:60]}", + ) + + def _on_agent_execution_error(self, source: Any, event: Any) -> None: + agent = getattr(event, "agent", None) + agent_role = getattr(event, "agent_role", None) or ( + getattr(agent, "role", None) if agent else None + ) or "unknown" + error = str(getattr(event, "error", "unknown error")) + with self._lock: + span_id = 
self._agent_span_ids.pop(agent_role, self._current_agent_span_id or self._new_span_id()) + parent = self._current_task_span_id or self._crew_span_id + if self._current_agent_span_id == span_id: + self._current_agent_span_id = None + + self._emit( + "agent.error", + self._payload(agent_role=agent_role, error=error), + span_id=span_id, parent_span_id=parent, + span_name=f"agent:{agent_role[:60]}", + ) + + # ------------------------------------------------------------------ + # LLM calls + # ------------------------------------------------------------------ + + def _on_llm_started(self, source: Any, event: Any) -> None: + call_id = getattr(event, "call_id", None) + if call_id: + self._start_timer(f"llm:{call_id}") + + def _on_llm_completed(self, source: Any, event: Any) -> None: + model = getattr(event, "model", None) + response = getattr(event, "response", None) + # Unwrap .usage from the response before normalizing + usage = getattr(response, "usage", None) if response and not isinstance(response, dict) else ( + response.get("usage") if isinstance(response, dict) else None + ) + tokens = self._normalize_tokens(usage) + payload = self._payload() + if model: + payload["model"] = model + call_id = getattr(event, "call_id", None) + if call_id: + latency_ms = self._stop_timer(f"llm:{call_id}") + if latency_ms is not None: + payload["latency_ms"] = latency_ms + payload.update(tokens) + parent = self._leaf_parent_span_id() + span_id = self._new_span_id() + self._emit("model.invoke", payload, span_id=span_id, parent_span_id=parent) + if tokens: + self._emit( + "cost.record", + self._payload(model=model, **tokens), + span_id=span_id, parent_span_id=parent, + ) + + def _on_llm_failed(self, source: Any, event: Any) -> None: + error = str(getattr(event, "error", "unknown error")) + model = getattr(event, "model", None) + payload = self._payload(error=error) + if model: + payload["model"] = model + parent = self._leaf_parent_span_id() + self._emit("agent.error", payload, 
parent_span_id=parent) + + # ------------------------------------------------------------------ + # Tool usage — split into tool.call (start) and tool.result (end) + # ------------------------------------------------------------------ + + def _on_tool_started(self, source: Any, event: Any) -> None: + tool_name = getattr(event, "tool_name", None) or "unknown" + span_id = self._new_span_id() + tool_key = self._tool_event_key(event) + with self._lock: + self._tool_span_ids[tool_key] = span_id + payload = self._payload(tool_name=tool_name) + self._set_if_capturing(payload, "input", safe_serialize(getattr(event, "tool_args", None))) + parent = self._leaf_parent_span_id() + self._emit("tool.call", payload, span_id=span_id, parent_span_id=parent) + + def _on_tool_finished(self, source: Any, event: Any) -> None: + tool_name = getattr(event, "tool_name", None) or "unknown" + tool_key = self._tool_event_key(event) + with self._lock: + span_id = self._tool_span_ids.pop(tool_key, None) + if span_id is None: + span_id = self._new_span_id() + + payload = self._payload(tool_name=tool_name) + self._set_if_capturing(payload, "output", safe_serialize(getattr(event, "output", None))) + # Compute latency from started_at/finished_at + started_at = getattr(event, "started_at", None) + finished_at = getattr(event, "finished_at", None) + if started_at is not None and finished_at is not None: + try: + payload["latency_ms"] = (finished_at - started_at).total_seconds() * 1000 + except Exception: + pass + from_cache = getattr(event, "from_cache", None) + if from_cache: + payload["from_cache"] = True + parent = self._leaf_parent_span_id() + self._emit("tool.result", payload, span_id=span_id, parent_span_id=parent) + + def _on_tool_error(self, source: Any, event: Any) -> None: + tool_name = getattr(event, "tool_name", None) or "unknown" + error = str(getattr(event, "error", "unknown error")) + tool_key = self._tool_event_key(event) + with self._lock: + self._tool_span_ids.pop(tool_key, None) + 
parent = self._leaf_parent_span_id() + self._emit( + "agent.error", + self._payload(tool_name=tool_name, error=error), + parent_span_id=parent, + ) + + # ------------------------------------------------------------------ + # Flow events + # ------------------------------------------------------------------ + + def _on_flow_started(self, source: Any, event: Any) -> None: + span_id = self._new_span_id() + with self._lock: + self._crew_span_id = span_id + self._start_timer("crew") + flow_name = getattr(event, "flow_name", None) or self._get_name(source) + payload = self._payload(flow_name=flow_name) + self._set_if_capturing(payload, "input", safe_serialize(getattr(event, "inputs", None))) + self._emit("agent.input", payload, span_id=span_id, parent_span_id=None, span_name=f"flow:{flow_name}") + + def _on_flow_finished(self, source: Any, event: Any) -> None: + latency_ms = self._stop_timer("crew") + flow_name = getattr(event, "flow_name", None) or self._get_name(source) + payload = self._payload(flow_name=flow_name) + if latency_ms is not None: + payload["duration_ns"] = int(latency_ms * 1_000_000) + self._set_if_capturing(payload, "output", safe_serialize(getattr(event, "result", None))) + self._emit( + "agent.output", payload, + span_id=self._crew_span_id or self._new_span_id(), + parent_span_id=None, span_name=f"flow:{flow_name}", + ) + self._end_trace() + + # ------------------------------------------------------------------ + # MCP tool events + # ------------------------------------------------------------------ + + def _on_mcp_tool_completed(self, source: Any, event: Any) -> None: + tool_name = getattr(event, "tool_name", None) or "unknown" + server_name = getattr(event, "server_name", None) + latency_ms = getattr(event, "execution_duration_ms", None) + payload = self._payload(tool_name=tool_name) + self._set_if_capturing(payload, "output", safe_serialize(getattr(event, "result", None))) + if server_name: + payload["mcp_server"] = server_name + if latency_ms is 
not None: + payload["latency_ms"] = latency_ms + parent = self._leaf_parent_span_id() + self._emit("tool.call", payload, parent_span_id=parent) + + def _on_mcp_tool_failed(self, source: Any, event: Any) -> None: + tool_name = getattr(event, "tool_name", None) or "unknown" + error = str(getattr(event, "error", "unknown error")) + server_name = getattr(event, "server_name", None) + payload = self._payload(tool_name=tool_name, error=error) + if server_name: + payload["mcp_server"] = server_name + parent = self._leaf_parent_span_id() + self._emit("agent.error", payload, parent_span_id=parent) diff --git a/src/layerlens/instrument/adapters/frameworks/langchain.py b/src/layerlens/instrument/adapters/frameworks/langchain.py index 5b14f0e..a69a7d6 100644 --- a/src/layerlens/instrument/adapters/frameworks/langchain.py +++ b/src/layerlens/instrument/adapters/frameworks/langchain.py @@ -1,12 +1,27 @@ from __future__ import annotations +import functools from uuid import UUID -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence +from typing import Any, Dict, List, Optional, Sequence from ._base_framework import FrameworkAdapter +from ..._capture_config import CaptureConfig + + +def _auto_flush(fn): # type: ignore[type-arg] + """Decorator: after the callback returns, flush if this was the outermost run.""" + @functools.wraps(fn) + def wrapper(self, *args, run_id, **kwargs): # type: ignore[no-untyped-def] + fn(self, *args, run_id=run_id, **kwargs) + run = self._get_run() + if run is not None: + if str(run_id) == run.data.get("root_run_id"): + self._end_run() + elif str(run_id) == self._root_run_id and self._collector is not None: + self._flush_collector() + self._root_run_id = None + return wrapper -if TYPE_CHECKING: - from ..._capture_config import CaptureConfig try: from langchain_core.callbacks import BaseCallbackHandler # pyright: ignore[reportAssignmentType] @@ -26,28 +41,14 @@ class LangChainCallbackHandler(BaseCallbackHandler, FrameworkAdapter): def 
__init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: BaseCallbackHandler.__init__(self) FrameworkAdapter.__init__(self, client, capture_config=capture_config) - self._root_run_id: Optional[str] = None + # Pending LLM runs: run_id -> {name, messages, parent_run_id} + self._pending_llm: Dict[str, Dict[str, Any]] = {} + # Context tokens for span propagation: run_id -> token from _push_context + self._run_contexts: Dict[str, Any] = {} - def _emit_for_run( - self, - event_type: str, - payload: Dict[str, Any], - run_id: UUID, - parent_run_id: Optional[UUID] = None, - ) -> None: - """Emit an event, mapping framework run_ids to span_ids.""" - span_id, parent_span_id = self._span_id_for(run_id, parent_run_id) - rid = str(run_id) - if self._root_run_id is None: - self._root_run_id = rid - self._emit(event_type, payload, span_id=span_id, parent_span_id=parent_span_id) - - def _maybe_flush(self, run_id: UUID) -> None: - if str(run_id) == self._root_run_id and self._collector is not None: - self._flush_collector() - self._root_run_id = None - - # -- Chain -- + # ------------------------------------------------------------------ + # Chain callbacks + # ------------------------------------------------------------------ def on_chain_start( self, @@ -58,10 +59,16 @@ def on_chain_start( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: + if parent_run_id is None: + run = self._begin_run() + run.data["root_run_id"] = str(run_id) serialized = serialized or {} name = serialized.get("name") or serialized.get("id", ["unknown"])[-1] - self._emit_for_run("agent.input", {"name": name, "input": inputs}, run_id, parent_run_id) + payload = self._payload(name=name) + self._set_if_capturing(payload, "input", inputs) + self._emit("agent.input", payload, run_id=run_id, parent_run_id=parent_run_id) + @_auto_flush def on_chain_end( self, outputs: Dict[str, Any], @@ -70,9 +77,11 @@ def on_chain_end( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) 
-> None: - self._emit_for_run("agent.output", {"output": outputs, "status": "ok"}, run_id) - self._maybe_flush(run_id) + payload = self._payload(status="ok") + self._set_if_capturing(payload, "output", outputs) + self._emit("agent.output", payload, run_id=run_id) + @_auto_flush def on_chain_error( self, error: BaseException, @@ -81,10 +90,11 @@ def on_chain_error( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit_for_run("agent.error", {"error": str(error), "status": "error"}, run_id) - self._maybe_flush(run_id) + self._emit("agent.error", self._payload(error=str(error), status="error"), run_id=run_id) - # -- LLM -- + # ------------------------------------------------------------------ + # LLM callbacks — merged into single model.invoke on end + # ------------------------------------------------------------------ def on_llm_start( self, @@ -97,7 +107,15 @@ def on_llm_start( ) -> None: serialized = serialized or {} name = serialized.get("name") or serialized.get("id", ["unknown"])[-1] - self._emit_for_run("model.invoke", {"name": name, "messages": prompts}, run_id, parent_run_id) + self._start_timer(str(run_id)) + pending: Dict[str, Any] = { + "name": name, + "parent_run_id": parent_run_id, + } + self._set_if_capturing(pending, "messages", prompts) + self._pending_llm[str(run_id)] = pending + span_id, _ = self._span_id_for(run_id) + self._run_contexts[str(run_id)] = self._push_context(span_id) def on_chat_model_start( self, @@ -110,13 +128,20 @@ def on_chat_model_start( ) -> None: serialized = serialized or {} name = serialized.get("name") or serialized.get("id", ["unknown"])[-1] - self._emit_for_run( - "model.invoke", - {"name": name, "messages": [[_serialize_lc_message(m) for m in batch] for batch in messages]}, - run_id, - parent_run_id, + self._start_timer(str(run_id)) + pending: Dict[str, Any] = { + "name": name, + "parent_run_id": parent_run_id, + } + self._set_if_capturing( + pending, "messages", + [[_serialize_lc_message(m) for m 
in batch] for batch in messages], ) + self._pending_llm[str(run_id)] = pending + span_id, _ = self._span_id_for(run_id) + self._run_contexts[str(run_id)] = self._push_context(span_id) + @_auto_flush def on_llm_end( self, response: Any, @@ -125,6 +150,10 @@ def on_llm_end( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: + self._pop_context(self._run_contexts.pop(str(run_id), None)) + pending = self._pending_llm.pop(str(run_id), {}) + + # Extract response data output = None try: generations = response.generations @@ -139,20 +168,40 @@ def on_llm_end( llm_output = {} model_name = llm_output.get("model_name") - if model_name or output: - self._emit_for_run( - "model.invoke", - {"model": model_name, "output_message": output}, - run_id, - parent_run_id, - ) - usage = llm_output.get("token_usage", {}) - if usage: - self._emit_for_run("cost.record", usage, run_id, parent_run_id) + # Build single merged model.invoke event + payload = self._payload() + if pending.get("name"): + payload["name"] = pending["name"] + if model_name: + payload["model"] = model_name + self._set_if_capturing(payload, "messages", pending.get("messages")) + self._set_if_capturing(payload, "output_message", output) + + # Latency + latency_ms = self._stop_timer(str(run_id)) + if latency_ms is not None: + payload["latency_ms"] = latency_ms + + # Tokens + usage = llm_output.get("token_usage") or llm_output.get("usage_metadata") + tokens = self._normalize_tokens(usage) + payload.update(tokens) + + self._emit( + "model.invoke", payload, + run_id=run_id, parent_run_id=pending.get("parent_run_id"), + ) - self._maybe_flush(run_id) + # Separate cost.record if we have token data + if tokens: + cost_payload = self._payload() + if model_name: + cost_payload["model"] = model_name + cost_payload.update(tokens) + self._emit("cost.record", cost_payload, run_id=run_id, parent_run_id=pending.get("parent_run_id")) + @_auto_flush def on_llm_error( self, error: BaseException, @@ -161,10 +210,22 @@ def 
on_llm_error( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit_for_run("agent.error", {"error": str(error), "status": "error"}, run_id) - self._maybe_flush(run_id) + self._pop_context(self._run_contexts.pop(str(run_id), None)) + pending = self._pending_llm.pop(str(run_id), {}) - # -- Tool -- + payload = self._payload(error=str(error)) + if pending.get("name"): + payload["name"] = pending["name"] + latency_ms = self._stop_timer(str(run_id)) + if latency_ms is not None: + payload["latency_ms"] = latency_ms + self._emit("model.invoke", payload, run_id=run_id, parent_run_id=pending.get("parent_run_id")) + + self._emit("agent.error", self._payload(error=str(error), status="error"), run_id=run_id) + + # ------------------------------------------------------------------ + # Tool callbacks + # ------------------------------------------------------------------ def on_tool_start( self, @@ -176,8 +237,11 @@ def on_tool_start( **kwargs: Any, ) -> None: name = (serialized or {}).get("name", "tool") - self._emit_for_run("tool.call", {"name": name, "input": input_str}, run_id, parent_run_id) + payload = self._payload(name=name) + self._set_if_capturing(payload, "input", input_str) + self._emit("tool.call", payload, run_id=run_id, parent_run_id=parent_run_id) + @_auto_flush def on_tool_end( self, output: str, @@ -186,9 +250,11 @@ def on_tool_end( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit_for_run("tool.result", {"output": output}, run_id) - self._maybe_flush(run_id) + payload = self._payload() + self._set_if_capturing(payload, "output", output) + self._emit("tool.result", payload, run_id=run_id) + @_auto_flush def on_tool_error( self, error: BaseException, @@ -197,10 +263,11 @@ def on_tool_error( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit_for_run("agent.error", {"error": str(error), "status": "error"}, run_id) - self._maybe_flush(run_id) + self._emit("agent.error", 
self._payload(error=str(error), status="error"), run_id=run_id) - # -- Retriever -- + # ------------------------------------------------------------------ + # Retriever callbacks + # ------------------------------------------------------------------ def on_retriever_start( self, @@ -212,8 +279,11 @@ def on_retriever_start( **kwargs: Any, ) -> None: name = (serialized or {}).get("name", "retriever") - self._emit_for_run("tool.call", {"name": name, "input": query}, run_id, parent_run_id) + payload = self._payload(name=name) + self._set_if_capturing(payload, "input", query) + self._emit("tool.call", payload, run_id=run_id, parent_run_id=parent_run_id) + @_auto_flush def on_retriever_end( self, documents: Sequence[Any], @@ -222,10 +292,14 @@ def on_retriever_end( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - output = [_serialize_lc_document(d) for d in documents] - self._emit_for_run("tool.result", {"output": output}, run_id) - self._maybe_flush(run_id) + payload = self._payload() + self._set_if_capturing( + payload, "output", + [_serialize_lc_document(d) for d in documents], + ) + self._emit("tool.result", payload, run_id=run_id) + @_auto_flush def on_retriever_error( self, error: BaseException, @@ -234,10 +308,42 @@ def on_retriever_error( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit_for_run("agent.error", {"error": str(error), "status": "error"}, run_id) - self._maybe_flush(run_id) + self._emit("agent.error", self._payload(error=str(error), status="error"), run_id=run_id) - # -- Text (required by base) -- + # ------------------------------------------------------------------ + # Agent callbacks + # ------------------------------------------------------------------ + + def on_agent_action( + self, + action: Any, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + payload = self._payload(tool=getattr(action, "tool", "unknown")) + self._set_if_capturing(payload, "tool_input", 
getattr(action, "tool_input", None)) + self._set_if_capturing(payload, "log", getattr(action, "log", None) or None) + self._emit("agent.input", payload, run_id=run_id, parent_run_id=parent_run_id) + + @_auto_flush + def on_agent_finish( + self, + finish: Any, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + payload = self._payload(status="ok") + self._set_if_capturing(payload, "output", getattr(finish, "return_values", None)) + self._set_if_capturing(payload, "log", getattr(finish, "log", None) or None) + self._emit("agent.output", payload, run_id=run_id, parent_run_id=parent_run_id) + + # ------------------------------------------------------------------ + # No-ops (required by base) + # ------------------------------------------------------------------ def on_text(self, text: str, **kwargs: Any) -> None: pass diff --git a/src/layerlens/instrument/adapters/frameworks/langgraph.py b/src/layerlens/instrument/adapters/frameworks/langgraph.py index f4b666a..35de3c4 100644 --- a/src/layerlens/instrument/adapters/frameworks/langgraph.py +++ b/src/layerlens/instrument/adapters/frameworks/langgraph.py @@ -19,6 +19,9 @@ def on_chain_start( tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: + if parent_run_id is None: + run = self._begin_run() + run.data["root_run_id"] = str(run_id) serialized = serialized or {} name = serialized.get("name") or serialized.get("id", ["unknown"])[-1] @@ -38,4 +41,6 @@ def on_chain_start( if node_name: name = node_name - self._emit_for_run("agent.input", {"name": name, "input": inputs}, run_id, parent_run_id) + payload = self._payload(name=name) + self._set_if_capturing(payload, "input", inputs) + self._emit("agent.input", payload, run_id=run_id, parent_run_id=parent_run_id) diff --git a/src/layerlens/instrument/adapters/frameworks/openai_agents.py b/src/layerlens/instrument/adapters/frameworks/openai_agents.py new file mode 100644 index 0000000..e175c34 --- /dev/null +++ 
b/src/layerlens/instrument/adapters/frameworks/openai_agents.py @@ -0,0 +1,306 @@ +from __future__ import annotations + +import logging +from datetime import datetime +from typing import Any, Dict, Optional + +from ._base_framework import FrameworkAdapter +from ._utils import safe_serialize +from ..._collector import TraceCollector +from ..._capture_config import CaptureConfig + +log = logging.getLogger(__name__) + +_HAS_OPENAI_AGENTS = False +try: + from agents.tracing import TracingProcessor # pyright: ignore[reportMissingImports] + + _HAS_OPENAI_AGENTS = True +except (ImportError, Exception): + TracingProcessor = None # type: ignore[assignment,misc] + +# Real TracingProcessor when installed, plain object otherwise. +_Base: Any = TracingProcessor if _HAS_OPENAI_AGENTS else object + + +class OpenAIAgentsAdapter(_Base, FrameworkAdapter): + """OpenAI Agents SDK adapter using the TracingProcessor API. + + The adapter *is* the trace processor — it registers itself globally + to receive all span lifecycle events, then maps agent, generation, + function, handoff, and guardrail spans to flat layerlens events. + + Unlike other adapters that use a single collector, this adapter manages + per-trace collectors because the SDK can run multiple concurrent traces + through the same global processor. 
+ + Usage:: + + adapter = OpenAIAgentsAdapter(client) + adapter.connect() + result = await Runner.run(agent, "hello") + adapter.disconnect() + """ + + name = "openai-agents" + package = "openai-agents" + + _SPAN_HANDLERS = { + "agent": "_handle_agent_span", + "generation": "_handle_generation_span", + "function": "_handle_function_span", + "handoff": "_handle_handoff_span", + "guardrail": "_handle_guardrail_span", + "response": "_handle_response_span", + } + + def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: + FrameworkAdapter.__init__(self, client, capture_config) + self._collectors: Dict[str, TraceCollector] = {} + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def _on_connect(self, target: Any = None, **kwargs: Any) -> None: + self._check_dependency(_HAS_OPENAI_AGENTS) + from agents import add_trace_processor # pyright: ignore[reportMissingImports] + + add_trace_processor(self) # type: ignore[arg-type] + + def _on_disconnect(self) -> None: + from agents import set_trace_processors # pyright: ignore[reportMissingImports] + + set_trace_processors([]) + with self._lock: + self._collectors.clear() + + # ------------------------------------------------------------------ + # TracingProcessor interface + # ------------------------------------------------------------------ + + def on_trace_start(self, trace: Any) -> None: + try: + self._get_collector(trace.trace_id) + except Exception: + log.warning("layerlens: error in on_trace_start", exc_info=True) + + def on_trace_end(self, trace: Any) -> None: + try: + with self._lock: + collector = self._collectors.pop(trace.trace_id, None) + if collector is not None: + collector.flush() + except Exception: + log.warning("layerlens: error in on_trace_end", exc_info=True) + + def on_span_start(self, span: Any) -> None: + pass + + def on_span_end(self, span: Any) -> None: + try: + 
span_type = getattr(span.span_data, "type", None) or "" + handler_name = self._SPAN_HANDLERS.get(span_type) + if handler_name is not None: + getattr(self, handler_name)(span) + except Exception: + log.warning("layerlens: error handling OpenAI Agents span", exc_info=True) + + def shutdown(self) -> None: + pass + + def force_flush(self) -> None: + pass + + # ------------------------------------------------------------------ + # Per-trace collector + # ------------------------------------------------------------------ + + def _get_collector(self, trace_id: str) -> TraceCollector: + with self._lock: + if trace_id not in self._collectors: + self._collectors[trace_id] = TraceCollector(self._client, self._config) + return self._collectors[trace_id] + + # ------------------------------------------------------------------ + # Span handlers + # ------------------------------------------------------------------ + + def _handle_agent_span(self, span: Any) -> None: + data = span.span_data + collector = self._get_collector(span.trace_id) + agent_name = getattr(data, "name", "unknown") + span_id = span.span_id or self._new_span_id() + parent_id = span.parent_id + + input_payload = self._payload(agent_name=agent_name) + for key in ("tools", "handoffs", "output_type"): + val = getattr(data, key, None) + if val: + input_payload[key] = val + + collector.emit( + "agent.input", input_payload, + span_id=span_id, parent_span_id=parent_id, + span_name=f"agent:{agent_name}", + ) + + event_type = "agent.error" if span.error else "agent.output" + out_payload = self._payload( + agent_name=agent_name, + status="error" if span.error else "ok", + ) + duration_ms = _compute_duration_ms(span) + if duration_ms is not None: + out_payload["duration_ms"] = duration_ms + if span.error: + out_payload["error"] = safe_serialize(span.error) + + collector.emit( + event_type, out_payload, + span_id=span_id, parent_span_id=parent_id, + span_name=f"agent:{agent_name}", + ) + + def _handle_generation_span(self, 
span: Any) -> None: + data = span.span_data + collector = self._get_collector(span.trace_id) + model = getattr(data, "model", None) or "unknown" + span_id = span.span_id or self._new_span_id() + parent_id = span.parent_id + + payload = self._payload(model=model) + tokens = self._normalize_tokens(getattr(data, "usage", None)) + payload.update(tokens) + + duration_ms = _compute_duration_ms(span) + if duration_ms is not None: + payload["latency_ms"] = duration_ms + + model_config = getattr(data, "model_config", None) + if model_config: + payload["model_config"] = safe_serialize(model_config) + + self._set_if_capturing(payload, "messages", safe_serialize(getattr(data, "input", None))) + self._set_if_capturing(payload, "output_message", safe_serialize(getattr(data, "output", None))) + + if span.error: + payload["error"] = safe_serialize(span.error) + collector.emit("agent.error", payload, span_id=span_id, parent_span_id=parent_id) + else: + collector.emit("model.invoke", payload, span_id=span_id, parent_span_id=parent_id) + + if tokens: + cost_payload = self._payload(model=model) + cost_payload.update(tokens) + collector.emit("cost.record", cost_payload, span_id=span_id, parent_span_id=parent_id) + + def _handle_function_span(self, span: Any) -> None: + data = span.span_data + collector = self._get_collector(span.trace_id) + tool_name = getattr(data, "name", "unknown") + span_id = span.span_id or self._new_span_id() + parent_id = span.parent_id + + payload = self._payload(tool_name=tool_name) + self._set_if_capturing(payload, "input", safe_serialize(getattr(data, "input", None))) + self._set_if_capturing(payload, "output", safe_serialize(getattr(data, "output", None))) + + duration_ms = _compute_duration_ms(span) + if duration_ms is not None: + payload["latency_ms"] = duration_ms + + mcp_data = getattr(data, "mcp_data", None) + if mcp_data: + payload["mcp_data"] = safe_serialize(mcp_data) + + if span.error: + payload["error"] = safe_serialize(span.error) + 
collector.emit("agent.error", payload, span_id=span_id, parent_span_id=parent_id) + else: + collector.emit("tool.call", payload, span_id=span_id, parent_span_id=parent_id) + + def _handle_handoff_span(self, span: Any) -> None: + data = span.span_data + self._get_collector(span.trace_id).emit( + "agent.handoff", + self._payload( + from_agent=getattr(data, "from_agent", None) or "unknown", + to_agent=getattr(data, "to_agent", None) or "unknown", + ), + span_id=span.span_id or self._new_span_id(), + parent_span_id=span.parent_id, + ) + + def _handle_guardrail_span(self, span: Any) -> None: + data = span.span_data + self._get_collector(span.trace_id).emit( + "evaluation.result", + self._payload( + guardrail_name=getattr(data, "name", "unknown"), + triggered=getattr(data, "triggered", False), + ), + span_id=span.span_id or self._new_span_id(), + parent_span_id=span.parent_id, + ) + + def _handle_response_span(self, span: Any) -> None: + data = span.span_data + response = getattr(data, "response", None) + if response is None: + return + + collector = self._get_collector(span.trace_id) + span_id = span.span_id or self._new_span_id() + parent_id = span.parent_id + payload = self._payload() + + model = getattr(response, "model", None) + if model: + payload["model"] = model + + usage = getattr(response, "usage", None) + tokens = self._normalize_tokens(usage) + # OpenAI-specific detailed token breakdowns + if usage is not None: + input_details = getattr(usage, "input_tokens_details", None) + if input_details: + cached = getattr(input_details, "cached_tokens", 0) or 0 + if cached: + tokens["cached_tokens"] = cached + output_details = getattr(usage, "output_tokens_details", None) + if output_details: + reasoning = getattr(output_details, "reasoning_tokens", 0) or 0 + if reasoning: + tokens["reasoning_tokens"] = reasoning + payload.update(tokens) + + duration_ms = _compute_duration_ms(span) + if duration_ms is not None: + payload["latency_ms"] = duration_ms + + if span.error: + 
payload["error"] = safe_serialize(span.error) + collector.emit("agent.error", payload, span_id=span_id, parent_span_id=parent_id) + else: + collector.emit("model.invoke", payload, span_id=span_id, parent_span_id=parent_id) + + +# ------------------------------------------------------------------ +# Module-level helpers +# ------------------------------------------------------------------ + + +def _compute_duration_ms(span: Any) -> Optional[float]: + started = getattr(span, "started_at", None) + ended = getattr(span, "ended_at", None) + if started is None or ended is None: + return None + try: + if isinstance(started, str): + started = datetime.fromisoformat(started) + if isinstance(ended, str): + ended = datetime.fromisoformat(ended) + return (ended - started).total_seconds() * 1000 + except Exception: + return None diff --git a/src/layerlens/instrument/adapters/frameworks/pydantic_ai.py b/src/layerlens/instrument/adapters/frameworks/pydantic_ai.py new file mode 100644 index 0000000..b5ae173 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/pydantic_ai.py @@ -0,0 +1,350 @@ +from __future__ import annotations + +import logging +from typing import Any, Dict, Optional + +from ._base_framework import FrameworkAdapter +from ._utils import safe_serialize +from ..._capture_config import CaptureConfig + +log = logging.getLogger(__name__) + +try: + from pydantic_ai import Agent as _AgentCheck # pyright: ignore[reportMissingImports] # noqa: F401 + + _HAS_PYDANTIC_AI = True + del _AgentCheck +except ImportError: + _HAS_PYDANTIC_AI = False + + +class PydanticAIAdapter(FrameworkAdapter): + """PydanticAI adapter using the native Hooks capability API. + + Injects a ``Hooks`` capability into the target agent to receive + real-time lifecycle callbacks for run start/end, per-model-call, + and per-tool-execution events — with precise per-step timing. 
+ + Concurrent runs on the same agent are safe: each run gets its own + RunState via ContextVar, so collectors, timers, and tool spans + are fully isolated per ``asyncio.Task``. + + Usage:: + + adapter = PydanticAIAdapter(client) + adapter.connect(target=agent) # injects hooks capability + result = agent.run_sync("hello") + adapter.disconnect() # removes hooks capability + """ + + name = "pydantic-ai" + + def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: + super().__init__(client, capture_config) + self._target: Any = None + self._hooks: Any = None + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def _on_connect(self, target: Any = None, **kwargs: Any) -> None: + self._check_dependency(_HAS_PYDANTIC_AI) + if target is None: + raise ValueError("PydanticAIAdapter requires a target agent: adapter.connect(target=agent)") + + from pydantic_ai.capabilities.hooks import Hooks # pyright: ignore[reportMissingImports] + + self._target = target + self._hooks = Hooks() + self._register_hooks(self._hooks) + target._root_capability.capabilities.append(self._hooks) + + def _on_disconnect(self) -> None: + if self._target is not None and self._hooks is not None: + try: + caps = self._target._root_capability.capabilities + if self._hooks in caps: + caps.remove(self._hooks) + except Exception: + log.warning("Could not remove PydanticAI hooks capability") + self._hooks = None + self._target = None + + # ------------------------------------------------------------------ + # Hook registration + # ------------------------------------------------------------------ + + def _register_hooks(self, hooks: Any) -> None: + hooks.on.before_run(self._on_before_run) + hooks.on.after_run(self._on_after_run) + hooks.on.run_error(self._on_run_error) + hooks.on.after_model_request(self._on_after_model_request) + 
hooks.on.model_request_error(self._on_model_request_error) + hooks.on.before_tool_execute(self._on_before_tool_execute) + hooks.on.after_tool_execute(self._on_after_tool_execute) + hooks.on.tool_execute_error(self._on_tool_execute_error) + + # ------------------------------------------------------------------ + # Run lifecycle hooks + # ------------------------------------------------------------------ + + def _on_before_run(self, ctx: Any) -> None: + run = self._begin_run() + agent_name = self._get_agent_name(ctx) + model_name = self._get_model_name(ctx) + + payload = self._payload(agent_name=agent_name) + if model_name: + payload["model"] = model_name + self._set_if_capturing(payload, "input", safe_serialize(ctx.prompt)) + + run.collector.emit( + "agent.input", payload, + span_id=run.root_span_id, parent_span_id=None, + span_name=f"pydantic_ai:{agent_name}", + ) + self._start_timer("run") + + def _on_after_run(self, ctx: Any, *, result: Any) -> Any: + latency_ms = self._stop_timer("run") + agent_name = self._get_agent_name(ctx) + model_name = self._get_model_name(ctx) + root_span = self._get_root_span() + collector = self._ensure_collector() + + output = self._extract_output(result) + usage = self._extract_usage(result) + + payload = self._payload(agent_name=agent_name, status="ok") + if model_name: + payload["model"] = model_name + if latency_ms is not None: + payload["latency_ms"] = latency_ms + self._set_if_capturing(payload, "output", output) + payload.update(usage) + collector.emit( + "agent.output", payload, + span_id=root_span, parent_span_id=None, + span_name=f"pydantic_ai:{agent_name}", + ) + + if usage: + cost_payload = self._payload() + if model_name: + cost_payload["model"] = model_name + cost_payload.update(usage) + collector.emit( + "cost.record", cost_payload, + span_id=self._new_span_id(), parent_span_id=root_span, + ) + + self._end_run() + return result + + def _on_run_error(self, ctx: Any, *, error: BaseException) -> None: + latency_ms = 
self._stop_timer("run") + agent_name = self._get_agent_name(ctx) + root_span = self._get_root_span() + collector = self._ensure_collector() + + payload = self._payload( + agent_name=agent_name, + error=str(error), + error_type=type(error).__name__, + ) + if latency_ms is not None: + payload["latency_ms"] = latency_ms + collector.emit( + "agent.error", payload, + span_id=root_span, parent_span_id=None, + span_name=f"pydantic_ai:{agent_name}", + ) + + self._end_run() + raise error + + # ------------------------------------------------------------------ + # Model request hooks + # ------------------------------------------------------------------ + + def _on_after_model_request( + self, ctx: Any, *, request_context: Any, response: Any, + ) -> Any: + root_span = self._get_root_span() + collector = self._ensure_collector() + + model_name = getattr(response, "model_name", None) + usage = getattr(response, "usage", None) + tokens = self._normalize_tokens(usage) + + payload = self._payload() + if model_name: + payload["model"] = model_name + payload.update(tokens) + + model_span = self._new_span_id() + collector.emit( + "model.invoke", payload, + span_id=model_span, parent_span_id=root_span, + ) + + parts = getattr(response, "parts", None) or [] + for part in parts: + if type(part).__name__ == "ToolCallPart": + tool_name = getattr(part, "tool_name", "unknown") + tool_payload = self._payload(tool_name=tool_name) + self._set_if_capturing( + tool_payload, "input", + safe_serialize(getattr(part, "args", None)), + ) + collector.emit( + "tool.call", tool_payload, + span_id=self._new_span_id(), parent_span_id=root_span, + ) + + return response + + def _on_model_request_error( + self, ctx: Any, *, request_context: Any, error: Exception, + ) -> None: + root_span = self._get_root_span() + collector = self._ensure_collector() + + payload = self._payload( + error=str(error), + error_type=type(error).__name__, + ) + collector.emit( + "agent.error", payload, + 
span_id=self._new_span_id(), parent_span_id=root_span, + ) + raise error + + # ------------------------------------------------------------------ + # Tool execution hooks + # ------------------------------------------------------------------ + + def _on_before_tool_execute( + self, ctx: Any, *, call: Any, tool_def: Any, args: Any, + ) -> Any: + tool_name = getattr(call, "tool_name", "unknown") + span_id = self._new_span_id() + run = self._get_run() + if run is not None: + run.data.setdefault("tool_spans", {})[tool_name] = span_id + self._start_timer(f"tool:{tool_name}") + return args + + def _on_after_tool_execute( + self, ctx: Any, *, call: Any, tool_def: Any, args: Any, result: Any, + ) -> Any: + tool_name = getattr(call, "tool_name", "unknown") + latency_ms = self._stop_timer(f"tool:{tool_name}") + + run = self._get_run() + tool_spans = run.data.get("tool_spans", {}) if run is not None else {} + span_id = tool_spans.pop(tool_name, self._new_span_id()) + + root_span = self._get_root_span() + collector = self._ensure_collector() + + payload = self._payload(tool_name=tool_name) + self._set_if_capturing(payload, "output", safe_serialize(result)) + if latency_ms is not None: + payload["latency_ms"] = latency_ms + collector.emit( + "tool.result", payload, + span_id=span_id, parent_span_id=root_span, + ) + return result + + def _on_tool_execute_error( + self, ctx: Any, *, call: Any, tool_def: Any, args: Any, error: Exception, + ) -> None: + tool_name = getattr(call, "tool_name", "unknown") + self._stop_timer(f"tool:{tool_name}") + + run = self._get_run() + if run is not None: + run.data.get("tool_spans", {}).pop(tool_name, None) + + root_span = self._get_root_span() + collector = self._ensure_collector() + + payload = self._payload( + tool_name=tool_name, + error=str(error), + error_type=type(error).__name__, + ) + collector.emit( + "agent.error", payload, + span_id=self._new_span_id(), parent_span_id=root_span, + ) + raise error + + # 
------------------------------------------------------------------ + # Static helpers + # ------------------------------------------------------------------ + + @staticmethod + def _get_agent_name(ctx: Any) -> str: + agent = getattr(ctx, "agent", None) + if agent is not None: + name = getattr(agent, "name", None) + if name: + return str(name) + return PydanticAIAdapter._get_model_name(ctx) or "pydantic_ai_agent" + + @staticmethod + def _get_model_name(ctx: Any) -> Optional[str]: + model = getattr(ctx, "model", None) + if model is None: + agent = getattr(ctx, "agent", None) + model = getattr(agent, "model", None) if agent else None + if model is None: + return None + if isinstance(model, str): + return model + name = getattr(model, "model_name", None) + if name: + return str(name) + return str(model) + + @staticmethod + def _extract_output(result: Any) -> Any: + if result is None: + return None + output = getattr(result, "output", None) + if output is not None: + return safe_serialize(output) + return None + + @staticmethod + def _extract_usage(result: Any) -> Dict[str, Any]: + tokens: Dict[str, Any] = {} + usage = getattr(result, "usage", None) + if usage is None: + return tokens + + if callable(usage): + try: + usage = usage() + except Exception: + return tokens + + input_t = getattr(usage, "input_tokens", 0) or 0 + output_t = getattr(usage, "output_tokens", 0) or 0 + + if input_t: + tokens["tokens_prompt"] = input_t + if output_t: + tokens["tokens_completion"] = output_t + if input_t or output_t: + tokens["tokens_total"] = input_t + output_t + + requests = getattr(usage, "requests", 0) or 0 + if requests: + tokens["model_requests"] = requests + + return tokens diff --git a/src/layerlens/instrument/adapters/frameworks/semantic_kernel.py b/src/layerlens/instrument/adapters/frameworks/semantic_kernel.py new file mode 100644 index 0000000..f02fecd --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/semantic_kernel.py @@ -0,0 +1,389 @@ +from __future__ 
import annotations + +import logging +from typing import Any, Dict, List, Optional + +from ._base_framework import FrameworkAdapter +from ._utils import safe_serialize, truncate +from ..._capture_config import CaptureConfig + +log = logging.getLogger(__name__) + +try: + import semantic_kernel as _sk # pyright: ignore[reportMissingImports] # noqa: F401 + + _HAS_SEMANTIC_KERNEL = True +except ImportError: + _HAS_SEMANTIC_KERNEL = False + + +class SemanticKernelAdapter(FrameworkAdapter): + """Semantic Kernel adapter using the SK filter API (semantic-kernel >= 1.0). + + Registers function invocation, prompt rendering, and auto-function + invocation filters on a Kernel instance to capture plugin calls, + prompt templates, and LLM-initiated function calls as flat events. + + Usage:: + + adapter = SemanticKernelAdapter(client) + adapter.connect(target=kernel) + result = await kernel.invoke(my_function, arg1=val1) + adapter.disconnect() + """ + + name = "semantic_kernel" + package = "semantic-kernel" + + def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: + super().__init__(client, capture_config) + self._kernel: Any = None + self._filter_ids: List[tuple] = [] # (FilterTypes, filter_id) for removal + self._seen_plugins: set = set() + self._patched_services: Dict[str, Any] = {} # service_id → original method + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def _on_connect(self, target: Any = None, **kwargs: Any) -> None: + self._check_dependency(_HAS_SEMANTIC_KERNEL) + if target is None: + raise ValueError("SemanticKernelAdapter requires a target kernel: adapter.connect(target=kernel)") + + from semantic_kernel.filters.filter_types import FilterTypes # pyright: ignore[reportMissingImports] + + self._kernel = target + + filters = [ + (FilterTypes.FUNCTION_INVOCATION, self._function_invocation_filter), + 
(FilterTypes.PROMPT_RENDERING, self._prompt_rendering_filter), + (FilterTypes.AUTO_FUNCTION_INVOCATION, self._auto_function_invocation_filter), + ] + for filter_type, handler in filters: + target.add_filter(filter_type, handler) + filter_list = _get_filter_list(target, filter_type) + if filter_list: + self._filter_ids.append((filter_type, filter_list[-1][0])) + + # Wrap LLM calls on registered chat services + self._patch_chat_services(target) + + # Discover existing plugins + self._discover_plugins(target) + + def _on_disconnect(self) -> None: + if self._kernel is not None: + for filter_type, filter_id in self._filter_ids: + try: + self._kernel.remove_filter(filter_type, filter_id=filter_id) + except Exception: + log.debug("layerlens: could not remove SK filter %s/%s", filter_type, filter_id) + self._unpatch_chat_services() + self._filter_ids.clear() + self._seen_plugins.clear() + self._kernel = None + + # ------------------------------------------------------------------ + # LLM call wrapping + # ------------------------------------------------------------------ + + def _patch_chat_services(self, kernel: Any) -> None: + """Wrap _inner_get_chat_message_contents on all registered chat services.""" + services = getattr(kernel, "services", None) + if not services or not isinstance(services, dict): + return + + for service_id, service in services.items(): + if not hasattr(service, "_inner_get_chat_message_contents"): + continue + original = service._inner_get_chat_message_contents + adapter = self + + async def _traced_inner(chat_history: Any, settings: Any, _orig: Any = original, _svc: Any = service) -> Any: + span_id = adapter._new_span_id() + root_span = adapter._get_root_span() + adapter._start_timer(span_id) + collector = adapter._ensure_collector() + + model_name = getattr(_svc, "ai_model_id", None) + + try: + result = await _orig(chat_history, settings) + except Exception as exc: + latency_ms = adapter._stop_timer(span_id) + payload = adapter._payload( + 
error=str(exc), + error_type=type(exc).__name__, + ) + if model_name: + payload["model"] = model_name + if latency_ms is not None: + payload["latency_ms"] = latency_ms + collector.emit( + "agent.error", payload, + span_id=span_id, parent_span_id=root_span, + ) + raise + + latency_ms = adapter._stop_timer(span_id) + tokens = adapter._extract_usage_from_response(result) + + payload = adapter._payload() + if model_name: + payload["model"] = model_name + if latency_ms is not None: + payload["latency_ms"] = latency_ms + payload.update(tokens) + collector.emit( + "model.invoke", payload, + span_id=span_id, parent_span_id=root_span, + ) + + if tokens: + cost_payload = adapter._payload() + if model_name: + cost_payload["model"] = model_name + cost_payload.update(tokens) + collector.emit( + "cost.record", cost_payload, + span_id=span_id, parent_span_id=root_span, + ) + + return result + + service._inner_get_chat_message_contents = _traced_inner + self._patched_services[service_id] = original + + def _unpatch_chat_services(self) -> None: + """Restore original _inner_get_chat_message_contents on all patched services.""" + if self._kernel is not None: + services = getattr(self._kernel, "services", {}) + for service_id, original in self._patched_services.items(): + service = services.get(service_id) + if service is not None: + try: + service._inner_get_chat_message_contents = original + except Exception: + log.debug("layerlens: could not restore SK chat service %s", service_id) + self._patched_services.clear() + + def _extract_usage_from_response(self, result: Any) -> Dict[str, Any]: + """Extract token usage from ChatMessageContent list returned by _inner_get_chat_message_contents.""" + if not result: + return {} + msg = result[0] if isinstance(result, list) else result + metadata = getattr(msg, "metadata", None) + if not metadata or not isinstance(metadata, dict): + return {} + return self._normalize_tokens(metadata.get("usage")) + + # 
------------------------------------------------------------------ + # Plugin discovery + # ------------------------------------------------------------------ + + def _discover_plugins(self, kernel: Any) -> None: + try: + plugins = getattr(kernel, "plugins", None) + if plugins is None: + return + names = list(plugins.keys()) if hasattr(plugins, "keys") else [str(p) for p in plugins] + collector = self._ensure_collector() + for name in names: + if name not in self._seen_plugins: + self._seen_plugins.add(name) + collector.emit( + "environment.config", + self._payload(plugin_name=name, event_subtype="plugin_registered"), + span_id=self._new_span_id(), + parent_span_id=self._get_root_span(), + ) + except Exception: + log.debug("layerlens: error discovering SK plugins", exc_info=True) + + def _maybe_discover_plugin(self, plugin_name: str) -> None: + if not plugin_name or plugin_name in self._seen_plugins: + return + with self._lock: + if plugin_name in self._seen_plugins: + return + self._seen_plugins.add(plugin_name) + collector = self._ensure_collector() + collector.emit( + "environment.config", + self._payload(plugin_name=plugin_name, event_subtype="plugin_registered"), + span_id=self._new_span_id(), + parent_span_id=self._get_root_span(), + ) + + # ------------------------------------------------------------------ + # Shared filter logic + # ------------------------------------------------------------------ + + async def _wrap_invocation( + self, + context: Any, + next: Any, + *, + auto_invoked: bool = False, + ) -> None: + """Shared wrap-and-emit logic for function and auto-function filters. + + Emits tool.call on start, tool.result on success (or agent.error on failure), + with timing. The ``auto_invoked`` flag adds LLM-specific metadata. 
+ """ + plugin_name = _extract_plugin_name(context) + function_name = _extract_function_name(context) + tool_name = f"{plugin_name}.{function_name}" if plugin_name else function_name + + self._maybe_discover_plugin(plugin_name) + + span_id = self._new_span_id() + root_span = self._get_root_span() + self._start_timer(span_id) + collector = self._ensure_collector() + + # -- Emit tool.call (start) -- + call_payload = self._payload( + tool_name=tool_name, + plugin_name=plugin_name, + function_name=function_name, + ) + if auto_invoked: + call_payload["auto_invoked"] = True + call_payload["request_sequence_index"] = getattr(context, "request_sequence_index", 0) + call_payload["function_sequence_index"] = getattr(context, "function_sequence_index", 0) + # Auto-invoked: args come from the LLM's function_call_content + call_content = getattr(context, "function_call_content", None) + if call_content: + self._set_if_capturing( + call_payload, "input", + safe_serialize(getattr(call_content, "arguments", None)), + ) + else: + # User-invoked: args come from context.arguments + self._set_if_capturing( + call_payload, "input", + safe_serialize(_extract_arguments(context)), + ) + + collector.emit( + "tool.call", call_payload, + span_id=span_id, parent_span_id=root_span, + span_name=f"sk:{tool_name}", + ) + + # -- Execute -- + error = None + try: + await next(context) + except Exception as exc: + error = exc + raise + finally: + latency_ms = self._stop_timer(span_id) + + if error: + err_payload = self._payload( + tool_name=tool_name, + error=str(error), + error_type=type(error).__name__, + ) + if auto_invoked: + err_payload["auto_invoked"] = True + if latency_ms is not None: + err_payload["latency_ms"] = latency_ms + collector.emit( + "agent.error", err_payload, + span_id=span_id, parent_span_id=root_span, + ) + else: + # Extract result from the appropriate field + if auto_invoked: + func_result = getattr(context, "function_result", None) + else: + func_result = getattr(context, 
"result", None) + result_value = getattr(func_result, "value", None) if func_result else None + + result_payload = self._payload( + tool_name=tool_name, + status="ok", + ) + if auto_invoked: + result_payload["auto_invoked"] = True + if latency_ms is not None: + result_payload["latency_ms"] = latency_ms + self._set_if_capturing(result_payload, "output", safe_serialize(result_value)) + collector.emit( + "tool.result", result_payload, + span_id=span_id, parent_span_id=root_span, + span_name=f"sk:{tool_name}", + ) + + # ------------------------------------------------------------------ + # Filters + # ------------------------------------------------------------------ + + async def _function_invocation_filter(self, context: Any, next: Any) -> None: + await self._wrap_invocation(context, next, auto_invoked=False) + + async def _prompt_rendering_filter(self, context: Any, next: Any) -> None: + await next(context) + + function_name = _extract_function_name(context) + rendered = getattr(context, "rendered_prompt", None) + + payload = self._payload(event_subtype="prompt_render") + if function_name: + payload["function_name"] = function_name + if rendered and self._config.capture_content: + payload["rendered_prompt"] = truncate(str(rendered), 2000) + + collector = self._ensure_collector() + collector.emit( + "agent.code", payload, + span_id=self._new_span_id(), parent_span_id=self._get_root_span(), + ) + + async def _auto_function_invocation_filter(self, context: Any, next: Any) -> None: + await self._wrap_invocation(context, next, auto_invoked=True) + + +# ------------------------------------------------------------------ +# Module-level helpers +# ------------------------------------------------------------------ + + +def _get_filter_list(kernel: Any, filter_type: Any) -> list: + name = filter_type.value if hasattr(filter_type, "value") else str(filter_type) + attr_map = { + "function_invocation": "function_invocation_filters", + "prompt_rendering": 
"prompt_rendering_filters", + "auto_function_invocation": "auto_function_invocation_filters", + } + return getattr(kernel, attr_map.get(name, ""), []) + + +def _extract_plugin_name(context: Any) -> str: + fn = getattr(context, "function", None) + if fn is not None: + return getattr(fn, "plugin_name", "") or "" + return getattr(context, "plugin_name", "") or "" + + +def _extract_function_name(context: Any) -> str: + fn = getattr(context, "function", None) + if fn is not None: + return getattr(fn, "name", "") or "" + return getattr(context, "function_name", "") or "" + + +def _extract_arguments(context: Any) -> Optional[Dict[str, Any]]: + args = getattr(context, "arguments", None) + if args is None: + return None + if isinstance(args, dict): + return args + if hasattr(args, "items"): + return dict(args.items()) + return None diff --git a/tests/instrument/adapters/frameworks/test_crewai.py b/tests/instrument/adapters/frameworks/test_crewai.py new file mode 100644 index 0000000..3b914a5 --- /dev/null +++ b/tests/instrument/adapters/frameworks/test_crewai.py @@ -0,0 +1,808 @@ +"""Tests for CrewAI adapter using real CrewAI event bus. + +These tests exercise the real crewai.events module — no mocking of CrewAI +internals. Events are constructed and emitted on the real event bus, and +we verify the correct layerlens events come out. + +Requires crewai >= 1.0.0 (Python >= 3.10). +""" + +from __future__ import annotations + +import datetime + +import pytest + +from .conftest import capture_framework_trace, find_event, find_events + +# Skip entire module if crewai is not importable (Python < 3.10 or not installed). +# crewai uses `type | None` syntax which causes TypeError on Python < 3.10, +# and importorskip only catches ImportError, so we guard explicitly. 
+import sys +if sys.version_info < (3, 10): + pytest.skip("crewai requires Python >= 3.10", allow_module_level=True) +try: + import crewai # noqa: F401 +except (ImportError, TypeError): + pytest.skip("crewai not installed or incompatible", allow_module_level=True) + +from crewai.events import ( # noqa: E402 + TaskFailedEvent, + TaskStartedEvent, + LLMCallFailedEvent, + TaskCompletedEvent, + ToolUsageErrorEvent, + ToolUsageStartedEvent, + LLMCallCompletedEvent, + CrewKickoffFailedEvent, + ToolUsageFinishedEvent, + CrewKickoffStartedEvent, + CrewKickoffCompletedEvent, + AgentExecutionErrorEvent, + AgentExecutionStartedEvent, + AgentExecutionCompletedEvent, + crewai_event_bus, # noqa: E402 +) +from crewai.tasks.task_output import TaskOutput # noqa: E402 + +from layerlens.instrument.adapters.frameworks.crewai import CrewAIAdapter # noqa: E402 + + +@pytest.fixture +def adapter_and_trace(mock_client): + """Create a connected CrewAI adapter with trace capture.""" + uploaded = capture_framework_trace(mock_client) + adapter = CrewAIAdapter(mock_client) + with crewai_event_bus.scoped_handlers(): + adapter.connect() + yield adapter, uploaded + adapter.disconnect() + + +class TestCrewAIAdapterLifecycle: + def test_connect_sets_connected(self, mock_client): + adapter = CrewAIAdapter(mock_client) + assert not adapter.is_connected + with crewai_event_bus.scoped_handlers(): + adapter.connect() + assert adapter.is_connected + adapter.disconnect() + assert not adapter.is_connected + + def test_adapter_info(self, mock_client): + adapter = CrewAIAdapter(mock_client) + with crewai_event_bus.scoped_handlers(): + adapter.connect() + info = adapter.adapter_info() + assert info.name == "crewai" + assert info.adapter_type == "framework" + assert info.connected is True + adapter.disconnect() + + def test_disconnect_clears_state(self, mock_client): + adapter = CrewAIAdapter(mock_client) + with crewai_event_bus.scoped_handlers(): + adapter.connect() + adapter.disconnect() + assert 
adapter._collector is None + assert adapter._crew_span_id is None + assert adapter._task_span_ids == {} + + +class TestCrewKickoff: + def test_crew_start_emits_agent_input(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + evt = CrewKickoffStartedEvent(crew_name="Research Crew", inputs={"topic": "AI"}) + adapter._on_crew_started(None, evt) + # Crew completed triggers flush + to = TaskOutput(description="test", raw="done", agent="R") + completed = CrewKickoffCompletedEvent(crew_name="Research Crew", output=to) + adapter._on_crew_completed(None, completed) + + events = uploaded["events"] + agent_in = find_event(events, "agent.input") + assert agent_in["payload"]["crew_name"] == "Research Crew" + assert agent_in["payload"]["input"] == {"topic": "AI"} + assert agent_in["payload"]["framework"] == "crewai" + + def test_crew_completed_emits_agent_output(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + start = CrewKickoffStartedEvent(crew_name="MyCrew", inputs={}) + adapter._on_crew_started(None, start) + + to = TaskOutput(description="test", raw="final answer", agent="R") + completed = CrewKickoffCompletedEvent(crew_name="MyCrew", output=to, total_tokens=500) + adapter._on_crew_completed(None, completed) + + events = uploaded["events"] + agent_out = find_event(events, "agent.output") + assert agent_out["payload"]["crew_name"] == "MyCrew" + assert agent_out["payload"]["duration_ns"] > 0 + assert agent_out["payload"]["tokens_total"] == 500 + + # Should also emit cost.record for total_tokens + cost = find_event(events, "cost.record") + assert cost["payload"]["tokens_total"] == 500 + + def test_crew_failed_emits_agent_error(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + start = CrewKickoffStartedEvent(crew_name="FailCrew", inputs={}) + adapter._on_crew_started(None, start) + + failed = CrewKickoffFailedEvent(crew_name="FailCrew", error="LLM rate limit exceeded") + adapter._on_crew_failed(None, failed) + + events = 
uploaded["events"] + error = find_event(events, "agent.error") + assert error["payload"]["error"] == "LLM rate limit exceeded" + assert error["payload"]["crew_name"] == "FailCrew" + + def test_crew_lifecycle_flushes_trace(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + start = CrewKickoffStartedEvent(crew_name="FlushCrew", inputs={}) + adapter._on_crew_started(None, start) + + to = TaskOutput(description="t", raw="ok", agent="R") + completed = CrewKickoffCompletedEvent(crew_name="FlushCrew", output=to) + adapter._on_crew_completed(None, completed) + + assert uploaded["trace_id"] is not None + assert len(uploaded["events"]) >= 2 + assert uploaded["attestation"] is not None + # Collector should be reset after flush + assert adapter._collector is None + + +class TestTaskEvents: + def test_task_start_and_complete(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + # Start crew + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + # Task lifecycle + adapter._on_task_started( + None, TaskStartedEvent(context="research context", task_name="Research Task", agent_role="Researcher") + ) + to = TaskOutput(description="Research Task", raw="found it", agent="Researcher") + adapter._on_task_completed(None, TaskCompletedEvent(output=to, task_name="Research Task")) + + # Flush + to2 = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to2)) + + events = uploaded["events"] + # Should have crew agent.input, task agent.input, task agent.output, crew agent.output + agent_inputs = find_events(events, "agent.input") + assert len(agent_inputs) == 2 # crew + task + task_input = [e for e in agent_inputs if e["payload"].get("task_name")] + assert len(task_input) == 1 + assert task_input[0]["payload"]["task_name"] == "Research Task" + assert task_input[0]["payload"]["agent_role"] == "Researcher" + + # Task events should be children of crew 
span + crew_span_id = agent_inputs[0]["span_id"] + assert task_input[0]["parent_span_id"] == crew_span_id + + def test_task_failed(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + adapter._on_task_started(None, TaskStartedEvent(context="ctx", task_name="Bad Task")) + adapter._on_task_failed(None, TaskFailedEvent(error="task timeout", task_name="Bad Task")) + + adapter._on_crew_failed(None, CrewKickoffFailedEvent(crew_name="C", error="task failed")) + + events = uploaded["events"] + errors = find_events(events, "agent.error") + task_error = [e for e in errors if e["payload"].get("task_name")] + assert len(task_error) == 1 + assert task_error[0]["payload"]["error"] == "task timeout" + + +class TestLLMEvents: + def test_llm_completed_emits_model_invoke(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + # LLM call with token usage in response + response = {"content": "hello", "usage": {"prompt_tokens": 100, "completion_tokens": 50}} + evt = LLMCallCompletedEvent(model="gpt-4o", call_id="call_1", call_type="llm_call", response=response) + adapter._on_llm_completed(None, evt) + + to = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to)) + + events = uploaded["events"] + model_invoke = find_event(events, "model.invoke") + assert model_invoke["payload"]["model"] == "gpt-4o" + assert model_invoke["payload"]["tokens_prompt"] == 100 + assert model_invoke["payload"]["tokens_completion"] == 50 + assert model_invoke["payload"]["tokens_total"] == 150 + + cost = find_event(events, "cost.record") + assert cost["payload"]["tokens_total"] == 150 + + def test_llm_failed_emits_agent_error(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, 
CrewKickoffStartedEvent(crew_name="C", inputs={})) + + evt = LLMCallFailedEvent(model="gpt-4o", call_id="call_1", error="rate limit exceeded") + adapter._on_llm_failed(None, evt) + + adapter._on_crew_failed(None, CrewKickoffFailedEvent(crew_name="C", error="llm fail")) + + events = uploaded["events"] + errors = find_events(events, "agent.error") + llm_error = [e for e in errors if e["payload"].get("model")] + assert len(llm_error) == 1 + assert llm_error[0]["payload"]["error"] == "rate limit exceeded" + assert llm_error[0]["payload"]["model"] == "gpt-4o" + + +class TestToolEvents: + def test_tool_started_emits_tool_call(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + started_evt = ToolUsageStartedEvent( + tool_name="web_search", + tool_args="AI safety research", + agent_key="researcher_1", + ) + adapter._on_tool_started(None, started_evt) + + to = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to)) + + events = uploaded["events"] + tool_call = find_event(events, "tool.call") + assert tool_call["payload"]["tool_name"] == "web_search" + assert tool_call["payload"]["input"] == "AI safety research" + + def test_tool_finished_emits_tool_result(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + now = datetime.datetime.now() + later = now + datetime.timedelta(milliseconds=150) + evt = ToolUsageFinishedEvent( + tool_name="web_search", + tool_args="AI safety research", + started_at=now, + finished_at=later, + output="Found 10 results about AI safety", + ) + adapter._on_tool_finished(None, evt) + + to = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to)) + + events = uploaded["events"] + tool_result = 
find_event(events, "tool.result") + assert tool_result["payload"]["tool_name"] == "web_search" + assert tool_result["payload"]["output"] == "Found 10 results about AI safety" + assert tool_result["payload"]["latency_ms"] == pytest.approx(150, abs=5) + + def test_tool_start_end_share_span_id(self, adapter_and_trace): + """tool.call and tool.result for the same tool use share a span_id.""" + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + started_evt = ToolUsageStartedEvent( + tool_name="calculator", + tool_args="2+2", + agent_key="math_agent_1", + ) + adapter._on_tool_started(None, started_evt) + + now = datetime.datetime.now() + finished_evt = ToolUsageFinishedEvent( + tool_name="calculator", + tool_args="2+2", + agent_key="math_agent_1", + started_at=now, + finished_at=now, + output="4", + ) + adapter._on_tool_finished(None, finished_evt) + + to = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to)) + + events = uploaded["events"] + tool_call = find_event(events, "tool.call") + tool_result = find_event(events, "tool.result") + assert tool_call["span_id"] == tool_result["span_id"] + + def test_tool_from_cache(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + now = datetime.datetime.now() + evt = ToolUsageFinishedEvent( + tool_name="cached_tool", + tool_args="query", + started_at=now, + finished_at=now, + output="cached result", + from_cache=True, + ) + adapter._on_tool_finished(None, evt) + + to = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to)) + + events = uploaded["events"] + tool_result = find_event(events, "tool.result") + assert tool_result["payload"]["from_cache"] is True + + def 
test_tool_error_emits_agent_error(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + evt = ToolUsageErrorEvent(tool_name="calculator", tool_args="1/0", error="division by zero") + adapter._on_tool_error(None, evt) + + adapter._on_crew_failed(None, CrewKickoffFailedEvent(crew_name="C", error="tool fail")) + + events = uploaded["events"] + errors = find_events(events, "agent.error") + tool_error = [e for e in errors if e["payload"].get("tool_name")] + assert len(tool_error) == 1 + assert tool_error[0]["payload"]["tool_name"] == "calculator" + assert tool_error[0]["payload"]["error"] == "division by zero" + + +class TestFullCrewLifecycle: + """End-to-end test simulating a complete crew run with multiple tasks.""" + + def test_full_crew_with_tasks_and_llm(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + # 1. Crew starts + adapter._on_crew_started( + None, CrewKickoffStartedEvent(crew_name="Analysis Crew", inputs={"topic": "quantum computing"}) + ) + + # 2. Task 1: Research + adapter._on_task_started( + None, TaskStartedEvent(context="research quantum computing", task_name="Research", agent_role="Researcher") + ) + + # 2a. Agent execution starts within task 1 + adapter._on_agent_execution_started( + None, AgentExecutionStartedEvent.model_construct(agent_role="Researcher", task_prompt="Research quantum computing") + ) + + # 3. LLM call within task 1 + response = {"content": "Quantum computing uses qubits...", "usage": {"prompt_tokens": 200, "completion_tokens": 100}} + adapter._on_llm_completed( + None, LLMCallCompletedEvent(model="claude-3-opus", call_id="c1", call_type="llm_call", response=response) + ) + + # 4. 
Tool use within task 1 (start + finish) + now = datetime.datetime.now() + adapter._on_tool_started( + None, + ToolUsageStartedEvent(tool_name="arxiv_search", tool_args="quantum computing 2024", agent_key="researcher_1"), + ) + adapter._on_tool_finished( + None, + ToolUsageFinishedEvent( + tool_name="arxiv_search", + tool_args="quantum computing 2024", + agent_key="researcher_1", + started_at=now, + finished_at=now, + output="3 papers found", + ), + ) + + # 4a. Agent execution completes + adapter._on_agent_execution_completed( + None, AgentExecutionCompletedEvent.model_construct(agent_role="Researcher", output="Research complete") + ) + + # 5. Task 1 completes + to1 = TaskOutput(description="Research", raw="Research complete", agent="Researcher") + adapter._on_task_completed(None, TaskCompletedEvent(output=to1, task_name="Research")) + + # 6. Task 2: Writing + adapter._on_task_started( + None, TaskStartedEvent(context="write about quantum computing", task_name="Write Report", agent_role="Writer") + ) + + # 6a. Agent execution starts within task 2 + adapter._on_agent_execution_started( + None, AgentExecutionStartedEvent.model_construct(agent_role="Writer", task_prompt="Write the report") + ) + + # 7. Another LLM call + response2 = {"content": "Final report..."} + adapter._on_llm_completed( + None, LLMCallCompletedEvent(model="gpt-4o", call_id="c2", call_type="llm_call", response=response2) + ) + + # 7a. Agent execution completes + adapter._on_agent_execution_completed( + None, AgentExecutionCompletedEvent.model_construct(agent_role="Writer", output="Report written") + ) + + # 8. Task 2 completes + to2 = TaskOutput(description="Write Report", raw="Report written", agent="Writer") + adapter._on_task_completed(None, TaskCompletedEvent(output=to2, task_name="Write Report")) + + # 9. 
Crew completes + final = TaskOutput(description="final", raw="All done", agent="Writer") + adapter._on_crew_completed( + None, CrewKickoffCompletedEvent(crew_name="Analysis Crew", output=final, total_tokens=1500) + ) + + # Verify full event trace + events = uploaded["events"] + assert uploaded["trace_id"] is not None + + # Count event types + agent_inputs = find_events(events, "agent.input") + agent_outputs = find_events(events, "agent.output") + model_invokes = find_events(events, "model.invoke") + tool_calls = find_events(events, "tool.call") + tool_results = find_events(events, "tool.result") + cost_records = find_events(events, "cost.record") + + # crew + 2 tasks + 2 agent executions = 5 agent.input events + assert len(agent_inputs) == 5 + # crew + 2 tasks + 2 agent executions = 5 agent.output events + assert len(agent_outputs) == 5 + assert len(model_invokes) == 2 # 2 LLM calls + assert len(tool_calls) == 1 # 1 tool.call (started) + assert len(tool_results) == 1 # 1 tool.result (finished) + assert len(cost_records) >= 1 # at least crew total_tokens + + # Verify span hierarchy: tasks are children of crew + crew_span = agent_inputs[0]["span_id"] + task_inputs = [e for e in agent_inputs if e["payload"].get("task_name")] + for task_event in task_inputs: + assert task_event["parent_span_id"] == crew_span + + # Verify all events share the same trace_id + trace_ids = {e["trace_id"] for e in events} + assert len(trace_ids) == 1 + + # Verify sequence ordering + sequence_ids = [e["sequence_id"] for e in events] + assert sequence_ids == sorted(sequence_ids) + + # Verify attestation was built + assert uploaded["attestation"].get("root_hash") is not None + + +class TestEventBusIntegration: + """Test that the adapter actually receives events through the real CrewAI event bus.""" + + def test_events_flow_through_bus(self, mock_client): + uploaded = capture_framework_trace(mock_client) + adapter = CrewAIAdapter(mock_client) + + with crewai_event_bus.scoped_handlers(): + 
adapter.connect() + + # Emit events on the real bus — adapter should pick them up. + # Flush between events so the async started-handler completes + # before completed-handler triggers _flush() (which resets state). + crewai_event_bus.emit(None, event=CrewKickoffStartedEvent(crew_name="BusCrew", inputs={"x": 1})) + crewai_event_bus.flush(timeout=5.0) + + to = TaskOutput(description="t", raw="bus result", agent="A") + crewai_event_bus.emit(None, event=CrewKickoffCompletedEvent(crew_name="BusCrew", output=to)) + crewai_event_bus.flush(timeout=5.0) + + events = uploaded["events"] + assert len(events) >= 2 + + agent_in = find_event(events, "agent.input") + assert agent_in["payload"]["crew_name"] == "BusCrew" + + agent_out = find_event(events, "agent.output") + assert agent_out["payload"]["crew_name"] == "BusCrew" + + def test_scoped_handlers_cleanup(self, mock_client): + """Verify that scoped_handlers prevents handler leaks between tests.""" + uploaded = capture_framework_trace(mock_client) + adapter = CrewAIAdapter(mock_client) + + with crewai_event_bus.scoped_handlers(): + adapter.connect() + + # Events emitted AFTER scope should NOT be captured + crewai_event_bus.emit(None, event=CrewKickoffStartedEvent(crew_name="Ghost", inputs={})) + crewai_event_bus.flush(timeout=2.0) + + # Nothing should have been captured (no flush happened either) + assert uploaded.get("events") is None or len(uploaded.get("events", [])) == 0 + + +class TestCaptureConfigGating: + """Verify CaptureConfig correctly gates event types.""" + + def test_minimal_config_skips_model_and_tool(self, mock_client): + from layerlens.instrument._capture_config import CaptureConfig + + uploaded = capture_framework_trace(mock_client) + config = CaptureConfig.minimal() # l3_model_metadata=False, l5a_tool_calls=False + adapter = CrewAIAdapter(mock_client, capture_config=config) + + with crewai_event_bus.scoped_handlers(): + adapter.connect() + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", 
inputs={})) + + # These should be filtered by CaptureConfig + response = {"content": "hi", "usage": {"prompt_tokens": 10, "completion_tokens": 5}} + adapter._on_llm_completed( + None, LLMCallCompletedEvent(model="gpt-4o", call_id="c1", call_type="llm_call", response=response) + ) + now = datetime.datetime.now() + adapter._on_tool_started( + None, ToolUsageStartedEvent(tool_name="x", tool_args="y", agent_key="a1") + ) + adapter._on_tool_finished( + None, ToolUsageFinishedEvent(tool_name="x", tool_args="y", agent_key="a1", started_at=now, finished_at=now, output="z") + ) + + to = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to)) + + events = uploaded["events"] + # model.invoke should be filtered out + assert len(find_events(events, "model.invoke")) == 0 + # tool.call and tool.result should be filtered out + assert len(find_events(events, "tool.call")) == 0 + assert len(find_events(events, "tool.result")) == 0 + # agent.input and agent.output should still be there (L1 is enabled) + assert len(find_events(events, "agent.input")) >= 1 + assert len(find_events(events, "agent.output")) >= 1 + # cost.record IS always-enabled, so if tokens were extracted it should be there + cost_events = find_events(events, "cost.record") + assert len(cost_events) >= 1 # cost.record bypasses CaptureConfig + + +class TestFlowEvents: + """Test CrewAI Flow lifecycle event handling.""" + + def test_flow_start_and_finish(self, adapter_and_trace): + from crewai.events import FlowStartedEvent, FlowFinishedEvent + + adapter, uploaded = adapter_and_trace + adapter._on_flow_started(None, FlowStartedEvent(flow_name="AnalysisFlow", inputs={"topic": "AI"})) + adapter._on_flow_finished(None, FlowFinishedEvent(flow_name="AnalysisFlow", result="done", state={})) + + events = uploaded["events"] + flow_in = find_event(events, "agent.input") + assert flow_in["payload"]["flow_name"] == "AnalysisFlow" + assert 
flow_in["payload"]["input"] == {"topic": "AI"} + assert flow_in["span_name"] == "flow:AnalysisFlow" + + flow_out = find_event(events, "agent.output") + assert flow_out["payload"]["flow_name"] == "AnalysisFlow" + assert flow_out["payload"]["duration_ns"] > 0 + + +class TestMCPToolEvents: + """Test MCP tool execution event handling.""" + + def test_mcp_tool_completed(self, adapter_and_trace): + from crewai.events import MCPToolExecutionCompletedEvent + + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + now = datetime.datetime.now() + adapter._on_mcp_tool_completed( + None, + MCPToolExecutionCompletedEvent( + tool_name="read_file", + tool_args={"path": "/etc/hosts"}, + server_name="filesystem", + server_url="stdio://mcp-fs", + transport_type="stdio", + result="127.0.0.1 localhost", + started_at=now, + completed_at=now, + execution_duration_ms=42, + ), + ) + + to = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to)) + + events = uploaded["events"] + tool_call = find_event(events, "tool.call") + assert tool_call["payload"]["tool_name"] == "read_file" + assert tool_call["payload"]["mcp_server"] == "filesystem" + assert tool_call["payload"]["latency_ms"] == 42 + assert tool_call["payload"]["output"] == "127.0.0.1 localhost" + + def test_mcp_tool_failed(self, adapter_and_trace): + from crewai.events import MCPToolExecutionFailedEvent + + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + adapter._on_mcp_tool_failed( + None, + MCPToolExecutionFailedEvent( + tool_name="exec_sql", + tool_args={"query": "DROP TABLE users"}, + server_name="db-server", + server_url="http://localhost:3000", + transport_type="http", + error="permission denied", + ), + ) + + adapter._on_crew_failed(None, CrewKickoffFailedEvent(crew_name="C", error="mcp fail")) + + 
events = uploaded["events"] + errors = find_events(events, "agent.error") + mcp_error = [e for e in errors if e["payload"].get("mcp_server")] + assert len(mcp_error) == 1 + assert mcp_error[0]["payload"]["tool_name"] == "exec_sql" + assert mcp_error[0]["payload"]["mcp_server"] == "db-server" + + +class TestLLMLatencyTracking: + """Test LLM call latency computation from start→complete events.""" + + def test_latency_computed_from_started_event(self, adapter_and_trace): + from crewai.events import LLMCallStartedEvent + + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + # Start event stores timestamp + adapter._on_llm_started(None, LLMCallStartedEvent( + model="gpt-4o", call_id="latency_test", messages=[], call_type="llm_call", + )) + + # Small delay to get measurable latency + import time + time.sleep(0.01) + + # Complete event computes latency + response = {"content": "hi", "usage": {"prompt_tokens": 5, "completion_tokens": 3}} + adapter._on_llm_completed(None, LLMCallCompletedEvent( + model="gpt-4o", call_id="latency_test", call_type="llm_call", response=response, + )) + + to = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to)) + + events = uploaded["events"] + model_invoke = find_event(events, "model.invoke") + assert "latency_ms" in model_invoke["payload"] + assert model_invoke["payload"]["latency_ms"] >= 5 # at least 5ms from the sleep + + +class TestAgentExecutionLifecycle: + """Test agent execution start/complete/error events.""" + + def test_agent_execution_started(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + adapter._on_task_started(None, TaskStartedEvent(context="ctx", task_name="T", agent_role="Researcher")) + + adapter._on_agent_execution_started( + None, 
AgentExecutionStartedEvent.model_construct( + agent_role="Researcher", task_prompt="Find AI papers", tools=[] + ) + ) + + to = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to)) + + events = uploaded["events"] + agent_inputs = find_events(events, "agent.input") + # Filter for agent execution events (have agent_role but NOT task_name) + agent_exec = [e for e in agent_inputs if e["payload"].get("agent_role") == "Researcher" and "task_name" not in e["payload"]] + assert len(agent_exec) == 1 + assert agent_exec[0]["payload"]["framework"] == "crewai" + assert agent_exec[0]["payload"]["task_prompt"] == "Find AI papers" + + def test_agent_execution_completed(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + adapter._on_agent_execution_started( + None, AgentExecutionStartedEvent.model_construct(agent_role="Writer") + ) + adapter._on_agent_execution_completed( + None, AgentExecutionCompletedEvent.model_construct(agent_role="Writer", output="Final draft") + ) + + to = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to)) + + events = uploaded["events"] + agent_outputs = find_events(events, "agent.output") + agent_out = [e for e in agent_outputs if e["payload"].get("agent_role") == "Writer"] + assert len(agent_out) == 1 + assert agent_out[0]["payload"]["status"] == "ok" + assert agent_out[0]["payload"]["output"] == "Final draft" + + def test_agent_execution_error(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + adapter._on_agent_execution_started( + None, AgentExecutionStartedEvent.model_construct(agent_role="Researcher") + ) + adapter._on_agent_execution_error( + None, 
AgentExecutionErrorEvent.model_construct(agent_role="Researcher", error="agent crashed") + ) + + adapter._on_crew_failed(None, CrewKickoffFailedEvent(crew_name="C", error="agent fail")) + + events = uploaded["events"] + errors = find_events(events, "agent.error") + agent_err = [e for e in errors if e["payload"].get("agent_role") == "Researcher"] + assert len(agent_err) == 1 + assert agent_err[0]["payload"]["error"] == "agent crashed" + + def test_agent_span_hierarchy(self, adapter_and_trace): + """Agent execution events are children of the current task span.""" + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + adapter._on_task_started(None, TaskStartedEvent(context="ctx", task_name="T1", agent_role="R")) + + adapter._on_agent_execution_started( + None, AgentExecutionStartedEvent.model_construct(agent_role="R") + ) + adapter._on_agent_execution_completed( + None, AgentExecutionCompletedEvent.model_construct(agent_role="R", output="done") + ) + + to = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to)) + + events = uploaded["events"] + # Find the task span_id + task_inputs = [e for e in find_events(events, "agent.input") if e["payload"].get("task_name") == "T1"] + assert len(task_inputs) == 1 + task_span = task_inputs[0]["span_id"] + + # Agent execution should be parented to task (filter out task event which also has agent_role) + agent_exec_inputs = [e for e in find_events(events, "agent.input") if e["payload"].get("agent_role") == "R" and "task_name" not in e["payload"]] + assert len(agent_exec_inputs) == 1 + assert agent_exec_inputs[0]["parent_span_id"] == task_span + + def test_llm_parented_to_agent(self, adapter_and_trace): + """LLM events should be children of the current agent execution span.""" + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, 
CrewKickoffStartedEvent(crew_name="C", inputs={})) + adapter._on_task_started(None, TaskStartedEvent(context="ctx", task_name="T1", agent_role="R")) + + adapter._on_agent_execution_started( + None, AgentExecutionStartedEvent.model_construct(agent_role="R") + ) + + response = {"content": "hi", "usage": {"prompt_tokens": 5, "completion_tokens": 3}} + adapter._on_llm_completed( + None, LLMCallCompletedEvent(model="gpt-4o", call_id="c1", call_type="llm_call", response=response) + ) + + adapter._on_agent_execution_completed( + None, AgentExecutionCompletedEvent.model_construct(agent_role="R", output="done") + ) + + to = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to)) + + events = uploaded["events"] + # Find the agent execution span_id (not the task event which also has agent_role) + agent_exec_inputs = [e for e in find_events(events, "agent.input") if e["payload"].get("agent_role") == "R" and "task_name" not in e["payload"]] + assert len(agent_exec_inputs) == 1 + agent_span = agent_exec_inputs[0]["span_id"] + + # LLM event should be parented to agent execution + model_invoke = find_event(events, "model.invoke") + assert model_invoke["parent_span_id"] == agent_span diff --git a/tests/instrument/adapters/frameworks/test_langchain.py b/tests/instrument/adapters/frameworks/test_langchain.py index d2a3057..db82d0d 100644 --- a/tests/instrument/adapters/frameworks/test_langchain.py +++ b/tests/instrument/adapters/frameworks/test_langchain.py @@ -5,6 +5,7 @@ from langchain_core.callbacks import BaseCallbackHandler +from layerlens.instrument._capture_config import CaptureConfig from layerlens.instrument.adapters.frameworks.langchain import LangChainCallbackHandler from .conftest import capture_framework_trace, find_event, find_events @@ -23,14 +24,21 @@ def test_name(self): handler = LangChainCallbackHandler(Mock()) assert handler.name == "langchain" + def test_adapter_info(self): + handler 
= LangChainCallbackHandler(Mock()) + info = handler.adapter_info() + assert info.name == "langchain" + assert info.adapter_type == "framework" + assert info.connected is False + # --------------------------------------------------------------------------- -# Emit events +# Chain lifecycle # --------------------------------------------------------------------------- -class TestEmitsEvents: - def test_chain_lifecycle(self, mock_client): +class TestChainLifecycle: + def test_chain_emits_input_and_output(self, mock_client): uploaded = capture_framework_trace(mock_client) handler = LangChainCallbackHandler(mock_client) @@ -49,49 +57,99 @@ def test_chain_lifecycle(self, mock_client): agent_output = find_event(events, "agent.output") assert agent_output["payload"]["status"] == "ok" + assert agent_output["payload"]["output"] == {"output": "AI is..."} + + def test_chain_error(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + handler.on_chain_start({"name": "FailChain"}, {"input": "x"}, run_id=chain_id) + handler.on_chain_error(ValueError("broke"), run_id=chain_id) + + error = find_event(uploaded["events"], "agent.error") + assert error["payload"]["error"] == "broke" + assert error["payload"]["status"] == "error" + + +# --------------------------------------------------------------------------- +# LLM lifecycle — single merged model.invoke +# --------------------------------------------------------------------------- + - def test_llm_lifecycle(self, mock_client): +def _make_llm_response( + text: str = "AI is...", + model_name: str = "gpt-4", + prompt_tokens: int = 100, + completion_tokens: int = 50, +) -> Mock: + resp = Mock() + resp.generations = [[Mock(text=text)]] + resp.llm_output = { + "model_name": model_name, + "token_usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + } + return 
resp + + +class TestLLMLifecycle: + def test_single_model_invoke_with_merged_data(self, mock_client): uploaded = capture_framework_trace(mock_client) handler = LangChainCallbackHandler(mock_client) chain_id = uuid4() llm_id = uuid4() - handler.on_chain_start( - {"name": "Chain"}, {"input": "x"}, run_id=chain_id, - ) + handler.on_chain_start({"name": "Chain"}, {"input": "x"}, run_id=chain_id) handler.on_llm_start( - {"name": "ChatOpenAI", "id": ["ChatOpenAI"]}, + {"name": "ChatOpenAI"}, ["What is AI?"], - run_id=llm_id, - parent_run_id=chain_id, + run_id=llm_id, parent_run_id=chain_id, ) - - llm_response = Mock() - llm_response.generations = [[Mock(text="AI is...")]] - llm_response.llm_output = { - "token_usage": {"total_tokens": 50}, - "model_name": "gpt-4", - } - handler.on_llm_end(llm_response, run_id=llm_id) + handler.on_llm_end(_make_llm_response(), run_id=llm_id) handler.on_chain_end({"output": "AI is..."}, run_id=chain_id) events = uploaded["events"] - model_invokes = find_events(events, "model.invoke") - assert len(model_invokes) >= 1 - # Start event has name and messages - start_invoke = [m for m in model_invokes if m["payload"].get("name") == "ChatOpenAI"] - assert len(start_invoke) == 1 - # End event has model and output - end_invoke = [m for m in model_invokes if m["payload"].get("model") == "gpt-4"] - assert len(end_invoke) == 1 - assert end_invoke[0]["payload"]["output_message"] == "AI is..." + # Single event, not two + assert len(model_invokes) == 1 + + invoke = model_invokes[0] + assert invoke["payload"]["name"] == "ChatOpenAI" + assert invoke["payload"]["model"] == "gpt-4" + assert invoke["payload"]["messages"] == ["What is AI?"] + assert invoke["payload"]["output_message"] == "AI is..." 
+ assert invoke["payload"]["latency_ms"] >= 0 + + def test_normalized_token_fields(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + llm_id = uuid4() + + handler.on_chain_start({"name": "Chain"}, {}, run_id=chain_id) + handler.on_llm_start({"name": "LLM"}, ["p"], run_id=llm_id, parent_run_id=chain_id) + handler.on_llm_end(_make_llm_response(), run_id=llm_id) + handler.on_chain_end({}, run_id=chain_id) + + events = uploaded["events"] + invoke = find_event(events, "model.invoke") + assert invoke["payload"]["tokens_prompt"] == 100 + assert invoke["payload"]["tokens_completion"] == 50 + assert invoke["payload"]["tokens_total"] == 150 cost = find_event(events, "cost.record") - assert cost["payload"]["total_tokens"] == 50 + assert cost["payload"]["tokens_prompt"] == 100 + assert cost["payload"]["tokens_completion"] == 50 + assert cost["payload"]["tokens_total"] == 150 + assert cost["payload"]["model"] == "gpt-4" - def test_chat_model_start(self, mock_client): + def test_chat_model_start_serializes_messages(self, mock_client): uploaded = capture_framework_trace(mock_client) handler = LangChainCallbackHandler(mock_client) @@ -105,8 +163,11 @@ def test_chat_model_start(self, mock_client): handler.on_chat_model_start( {"name": "ChatAnthropic"}, [[msg]], + run_id=chat_id, parent_run_id=chain_id, + ) + handler.on_llm_end( + _make_llm_response(text="Hi!", model_name="claude-3"), run_id=chat_id, - parent_run_id=chain_id, ) handler.on_chain_end({}, run_id=chain_id) @@ -114,6 +175,101 @@ def test_chat_model_start(self, mock_client): invoke = find_event(events, "model.invoke") assert invoke["payload"]["name"] == "ChatAnthropic" assert invoke["payload"]["messages"] == [[{"type": "human", "content": "Hello"}]] + assert invoke["payload"]["output_message"] == "Hi!" 
+ + def test_llm_error_emits_model_invoke_with_error(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + llm_id = uuid4() + + handler.on_chain_start({"name": "Chain"}, {}, run_id=chain_id) + handler.on_llm_start({"name": "LLM"}, ["prompt"], run_id=llm_id, parent_run_id=chain_id) + handler.on_llm_error(RuntimeError("timeout"), run_id=llm_id) + handler.on_chain_end({}, run_id=chain_id) + + events = uploaded["events"] + invoke = find_event(events, "model.invoke") + assert invoke["payload"]["error"] == "timeout" + assert invoke["payload"]["latency_ms"] >= 0 + + error = find_event(events, "agent.error") + assert error["payload"]["error"] == "timeout" + + +# --------------------------------------------------------------------------- +# CaptureConfig content gating +# --------------------------------------------------------------------------- + + +class TestCaptureConfig: + def test_capture_content_false_strips_inputs_and_messages(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client, capture_config=CaptureConfig(capture_content=False)) + + chain_id = uuid4() + llm_id = uuid4() + + handler.on_chain_start({"name": "Chain"}, {"secret": "data"}, run_id=chain_id) + handler.on_llm_start({"name": "LLM"}, ["secret prompt"], run_id=llm_id, parent_run_id=chain_id) + handler.on_llm_end(_make_llm_response(text="secret reply"), run_id=llm_id) + handler.on_chain_end({"output": "secret"}, run_id=chain_id) + + events = uploaded["events"] + + # Chain events should not contain content + agent_input = find_event(events, "agent.input") + assert "input" not in agent_input["payload"] + agent_output = find_event(events, "agent.output") + assert "output" not in agent_output["payload"] + + # Model invoke should not contain messages or output + invoke = find_event(events, "model.invoke") + assert "messages" not in invoke["payload"] + assert 
"output_message" not in invoke["payload"] + # But structural fields are still present + assert invoke["payload"]["name"] == "LLM" + + def test_capture_content_false_strips_tool_io(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client, capture_config=CaptureConfig(capture_content=False)) + + chain_id = uuid4() + tool_id = uuid4() + + handler.on_chain_start({"name": "Agent"}, {}, run_id=chain_id) + handler.on_tool_start({"name": "search"}, "secret query", run_id=tool_id, parent_run_id=chain_id) + handler.on_tool_end("secret results", run_id=tool_id) + handler.on_chain_end({}, run_id=chain_id) + + events = uploaded["events"] + tool_call = find_event(events, "tool.call") + assert "input" not in tool_call["payload"] + assert tool_call["payload"]["name"] == "search" + + tool_result = find_event(events, "tool.result") + assert "output" not in tool_result["payload"] + + def test_capture_content_false_strips_retriever_io(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client, capture_config=CaptureConfig(capture_content=False)) + + chain_id = uuid4() + ret_id = uuid4() + + handler.on_chain_start({"name": "Agent"}, {}, run_id=chain_id) + handler.on_retriever_start({"name": "vs"}, "secret query", run_id=ret_id, parent_run_id=chain_id) + docs = [Mock(page_content="secret doc", metadata={"source": "a.txt"})] + handler.on_retriever_end(docs, run_id=ret_id) + handler.on_chain_end({}, run_id=chain_id) + + events = uploaded["events"] + tool_call = find_event(events, "tool.call") + assert "input" not in tool_call["payload"] + + tool_result = find_event(events, "tool.result") + assert "output" not in tool_result["payload"] # --------------------------------------------------------------------------- @@ -189,69 +345,105 @@ def test_combined_tools_and_retrievers(self, mock_client): assert len(find_events(events, "tool.call")) == 2 assert 
len(find_events(events, "tool.result")) == 2 + def test_tool_error(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) -# --------------------------------------------------------------------------- -# Error handling -# --------------------------------------------------------------------------- + chain_id = uuid4() + tool_id = uuid4() + handler.on_chain_start({"name": "Agent"}, {}, run_id=chain_id) + handler.on_tool_start({"name": "search"}, "q", run_id=tool_id, parent_run_id=chain_id) + handler.on_tool_error(RuntimeError("404"), run_id=tool_id) + handler.on_chain_end({}, run_id=chain_id) -class TestErrors: - def test_chain_error(self, mock_client): + error = find_event(uploaded["events"], "agent.error") + assert error["payload"]["error"] == "404" + + def test_retriever_error(self, mock_client): uploaded = capture_framework_trace(mock_client) handler = LangChainCallbackHandler(mock_client) chain_id = uuid4() - handler.on_chain_start({"name": "FailChain"}, {"input": "x"}, run_id=chain_id) - handler.on_chain_error(ValueError("broke"), run_id=chain_id) + ret_id = uuid4() + + handler.on_chain_start({"name": "Agent"}, {}, run_id=chain_id) + handler.on_retriever_start({"name": "vs"}, "q", run_id=ret_id, parent_run_id=chain_id) + handler.on_retriever_error(ConnectionError("down"), run_id=ret_id) + handler.on_chain_end({}, run_id=chain_id) error = find_event(uploaded["events"], "agent.error") - assert error["payload"]["error"] == "broke" - assert error["payload"]["status"] == "error" + assert error["payload"]["error"] == "down" - def test_llm_error(self, mock_client): + +# --------------------------------------------------------------------------- +# Agent action / finish callbacks +# --------------------------------------------------------------------------- + + +class TestAgentCallbacks: + def test_agent_action_emits_input(self, mock_client): uploaded = capture_framework_trace(mock_client) handler = 
LangChainCallbackHandler(mock_client) chain_id = uuid4() - llm_id = uuid4() + agent_id = uuid4() - handler.on_chain_start({"name": "Chain"}, {}, run_id=chain_id) - handler.on_llm_start({"name": "LLM"}, ["prompt"], run_id=llm_id, parent_run_id=chain_id) - handler.on_llm_error(RuntimeError("timeout"), run_id=llm_id) + handler.on_chain_start({"name": "AgentExecutor"}, {}, run_id=chain_id) + + action = Mock() + action.tool = "search" + action.tool_input = "what is AI" + action.log = "Thought: I need to search" + handler.on_agent_action(action, run_id=agent_id, parent_run_id=chain_id) handler.on_chain_end({}, run_id=chain_id) - error = find_event(uploaded["events"], "agent.error") - assert error["payload"]["error"] == "timeout" + events = uploaded["events"] + inputs = [e for e in find_events(events, "agent.input") if e["payload"].get("tool") == "search"] + assert len(inputs) == 1 + assert inputs[0]["payload"]["tool_input"] == "what is AI" + assert inputs[0]["payload"]["log"] == "Thought: I need to search" - def test_tool_error(self, mock_client): + def test_agent_finish_emits_output(self, mock_client): uploaded = capture_framework_trace(mock_client) handler = LangChainCallbackHandler(mock_client) chain_id = uuid4() - tool_id = uuid4() + agent_id = uuid4() - handler.on_chain_start({"name": "Agent"}, {}, run_id=chain_id) - handler.on_tool_start({"name": "search"}, "q", run_id=tool_id, parent_run_id=chain_id) - handler.on_tool_error(RuntimeError("404"), run_id=tool_id) + handler.on_chain_start({"name": "AgentExecutor"}, {}, run_id=chain_id) + + finish = Mock() + finish.return_values = {"output": "AI is artificial intelligence"} + finish.log = "Final Answer: AI is artificial intelligence" + handler.on_agent_finish(finish, run_id=agent_id, parent_run_id=chain_id) handler.on_chain_end({}, run_id=chain_id) - error = find_event(uploaded["events"], "agent.error") - assert error["payload"]["error"] == "404" + events = uploaded["events"] + outputs = [e for e in find_events(events, 
"agent.output") if e["payload"].get("log")] + assert len(outputs) == 1 + assert outputs[0]["payload"]["output"] == {"output": "AI is artificial intelligence"} - def test_retriever_error(self, mock_client): + def test_agent_action_respects_capture_content(self, mock_client): uploaded = capture_framework_trace(mock_client) - handler = LangChainCallbackHandler(mock_client) + handler = LangChainCallbackHandler(mock_client, capture_config=CaptureConfig(capture_content=False)) chain_id = uuid4() - ret_id = uuid4() + agent_id = uuid4() handler.on_chain_start({"name": "Agent"}, {}, run_id=chain_id) - handler.on_retriever_start({"name": "vs"}, "q", run_id=ret_id, parent_run_id=chain_id) - handler.on_retriever_error(ConnectionError("down"), run_id=ret_id) + action = Mock() + action.tool = "secret_tool" + action.tool_input = "secret input" + action.log = "secret reasoning" + handler.on_agent_action(action, run_id=agent_id, parent_run_id=chain_id) handler.on_chain_end({}, run_id=chain_id) - error = find_event(uploaded["events"], "agent.error") - assert error["payload"]["error"] == "down" + events = uploaded["events"] + inputs = [e for e in find_events(events, "agent.input") if e["payload"].get("tool") == "secret_tool"] + assert len(inputs) == 1 + assert "tool_input" not in inputs[0]["payload"] + assert "log" not in inputs[0]["payload"] # --------------------------------------------------------------------------- @@ -272,16 +464,13 @@ def test_llm_parent_is_chain(self, mock_client): {"name": "LLM"}, ["prompt"], run_id=llm_id, parent_run_id=chain_id, ) - llm_response = Mock() - llm_response.generations = [[Mock(text="out")]] - llm_response.llm_output = {} - handler.on_llm_end(llm_response, run_id=llm_id) + handler.on_llm_end(_make_llm_response(), run_id=llm_id) handler.on_chain_end({}, run_id=chain_id) events = uploaded["events"] chain_input = find_event(events, "agent.input") - llm_invoke = [e for e in find_events(events, "model.invoke") if e["payload"].get("name") == "LLM"][0] 
- assert llm_invoke["parent_span_id"] == chain_input["span_id"] + invoke = find_event(events, "model.invoke") + assert invoke["parent_span_id"] == chain_input["span_id"] # --------------------------------------------------------------------------- @@ -329,17 +518,23 @@ def test_llm_end_no_output(self, mock_client): handler.on_llm_end(empty_response, run_id=llm_id) handler.on_chain_end({}, run_id=chain_id) - # Should complete without error — no model.invoke end event since no output/model + # Should emit model.invoke with name but no output_message + invoke = find_event(uploaded["events"], "model.invoke") + assert invoke["payload"]["name"] == "LLM" + assert "output_message" not in invoke["payload"] + def test_llm_end_without_start(self, mock_client): + """on_llm_end without a preceding on_llm_start should not crash.""" + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) -# --------------------------------------------------------------------------- -# adapter_info -# --------------------------------------------------------------------------- + chain_id = uuid4() + llm_id = uuid4() + handler.on_chain_start({"name": "Chain"}, {}, run_id=chain_id) + handler.on_llm_end(_make_llm_response(), run_id=llm_id) + handler.on_chain_end({}, run_id=chain_id) -class TestAdapterInfo: - def test_info(self): - handler = LangChainCallbackHandler(Mock()) - info = handler.adapter_info() - assert info.name == "langchain" - assert info.adapter_type == "framework" + # Should still emit model.invoke from the response data + invoke = find_event(uploaded["events"], "model.invoke") + assert invoke["payload"]["model"] == "gpt-4" diff --git a/tests/instrument/adapters/frameworks/test_langgraph.py b/tests/instrument/adapters/frameworks/test_langgraph.py index 7ff6e9d..87097ad 100644 --- a/tests/instrument/adapters/frameworks/test_langgraph.py +++ b/tests/instrument/adapters/frameworks/test_langgraph.py @@ -52,7 +52,7 @@ def 
test_llm_events_inherited(self, mock_client): events = uploaded["events"] assert len(find_events(events, "model.invoke")) >= 1 - assert find_event(events, "cost.record")["payload"]["total_tokens"] == 10 + assert find_event(events, "cost.record")["payload"]["tokens_total"] == 10 def test_tool_events_inherited(self, mock_client): uploaded = capture_framework_trace(mock_client) diff --git a/tests/instrument/adapters/frameworks/test_openai_agents.py b/tests/instrument/adapters/frameworks/test_openai_agents.py new file mode 100644 index 0000000..111be7d --- /dev/null +++ b/tests/instrument/adapters/frameworks/test_openai_agents.py @@ -0,0 +1,823 @@ +"""Tests for the OpenAI Agents SDK adapter using real SDK types. + +Uses real TracingProcessor, SpanImpl, Trace, and span data types. +No mocking of Agents SDK internals — only our mock_client for upload capture. +""" +from __future__ import annotations + +import json +from typing import Any, Dict, List +from unittest.mock import MagicMock + +import pytest + +import sys +if sys.version_info < (3, 10): + pytest.skip("openai-agents requires Python >= 3.10", allow_module_level=True) +try: + import agents # noqa: F401 +except Exception: # deliberately broad: covers ImportError and incompatible installs that fail at import time + pytest.skip("openai-agents not installed or incompatible", allow_module_level=True) + +from agents.tracing import TracingProcessor, set_trace_processors # noqa: E402 +from agents.tracing.spans import SpanImpl # noqa: E402 +from agents.tracing.traces import TraceImpl # noqa: E402 +from agents.tracing.span_data import ( # noqa: E402 + AgentSpanData, + HandoffSpanData, + FunctionSpanData, + GuardrailSpanData, + GenerationSpanData, +) + +from layerlens.instrument._capture_config import CaptureConfig # noqa: E402 +from layerlens.instrument.adapters.frameworks.openai_agents import OpenAIAgentsAdapter # noqa: E402 + +from .conftest import capture_framework_trace, find_event, find_events # noqa: E402 + +# -- Helpers -- + + +class _NoOpProcessor(TracingProcessor): + """Minimal processor
that does nothing — used to reset global state.""" + + def on_trace_start(self, trace): + pass + + def on_trace_end(self, trace): + pass + + def on_span_start(self, span): + pass + + def on_span_end(self, span): + pass + + def shutdown(self): + pass + + def force_flush(self): + pass + + +_noop = _NoOpProcessor() + + +def _make_span( + _adapter: Any, + trace_id: str, + span_id: str, + span_data: Any, + parent_id: str | None = None, +) -> SpanImpl: + """Create a real SpanImpl for testing. + + Uses a NoOpProcessor internally so span.start()/finish() don't + double-trigger our adapter. Tests call adapter.on_span_end() manually. + The _adapter param is accepted for call-site readability but unused. + """ + return SpanImpl( + trace_id=trace_id, + span_id=span_id, + parent_id=parent_id, + processor=_noop, + span_data=span_data, + tracing_api_key=None, + ) + + +def _make_trace(name: str = "test_trace", trace_id: str = "trace_001", processor: Any = None) -> TraceImpl: + """Create a real TraceImpl for testing. + + If processor is None, uses a no-op processor. In actual tests, + pass the adapter's processor so trace lifecycle events route correctly. + """ + proc = processor or _NoOpProcessor() + return TraceImpl(name=name, trace_id=trace_id, group_id=None, metadata=None, processor=proc) + + +# -- Fixtures -- + + +@pytest.fixture +def adapter_and_trace(mock_client): + """Create adapter, connect, yield (adapter, uploaded_dict), then clean up. + + The adapter IS the TracingProcessor, so tests call adapter.on_span_end() etc. + directly — no separate processor object. 
+ """ + uploaded = capture_framework_trace(mock_client) + adapter = OpenAIAgentsAdapter(mock_client) + adapter.connect() + yield adapter, uploaded + adapter.disconnect() + set_trace_processors([]) # ensure clean slate + + +@pytest.fixture(autouse=True) +def clean_processors(): + """Reset global trace processors after each test.""" + yield + set_trace_processors([]) + + +# -- Tests -- + + +class TestOpenAIAgentsAdapterLifecycle: + def test_connect_sets_connected(self, mock_client): + adapter = OpenAIAgentsAdapter(mock_client) + adapter.connect() + assert adapter.is_connected + info = adapter.adapter_info() + assert info.name == "openai-agents" + assert info.adapter_type == "framework" + adapter.disconnect() + + def test_disconnect_clears_state(self, mock_client): + adapter = OpenAIAgentsAdapter(mock_client) + adapter.connect() + adapter.disconnect() + assert not adapter.is_connected + + def test_connect_without_agents_raises(self, mock_client, monkeypatch): + import layerlens.instrument.adapters.frameworks.openai_agents as mod + + monkeypatch.setattr(mod, "_HAS_OPENAI_AGENTS", False) + adapter = OpenAIAgentsAdapter(mock_client) + with pytest.raises(ImportError, match="openai-agents"): + adapter.connect() + + +class TestAgentSpans: + """Test agent span handling with real AgentSpanData.""" + + def test_agent_span_emits_input_and_output(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t1") + + # Simulate trace + agent span lifecycle + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t1", "s_agent", + AgentSpanData(name="research_agent", tools=["search", "browse"], handoffs=["writer"]), + ) + span.start() + adapter.on_span_start(span) + span.finish() + adapter.on_span_end(span) + + adapter.on_trace_end(trace) + + events = uploaded["events"] + assert len(events) >= 2 + + inp = find_event(events, "agent.input") + assert inp["payload"]["agent_name"] == "research_agent" + assert inp["payload"]["tools"] == 
["search", "browse"] + assert inp["payload"]["handoffs"] == ["writer"] + assert inp["payload"]["framework"] == "openai-agents" + assert inp["span_id"] == "s_agent" + + out = find_event(events, "agent.output") + assert out["payload"]["agent_name"] == "research_agent" + assert out["payload"]["status"] == "ok" + assert out["span_id"] == "s_agent" + + def test_agent_span_with_error(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_err") + + adapter.on_trace_start(trace) + + span = _make_span(adapter,"t_err", "s_err", AgentSpanData(name="buggy_agent")) + span.start() + adapter.on_span_start(span) + span.set_error({"message": "Agent crashed", "data": {"step": 3}}) + span.finish() + adapter.on_span_end(span) + + adapter.on_trace_end(trace) + + events = uploaded["events"] + err = find_event(events, "agent.error") + assert err["payload"]["agent_name"] == "buggy_agent" + assert err["payload"]["status"] == "error" + assert "Agent crashed" in str(err["payload"]["error"]) + + def test_nested_agent_spans(self, adapter_and_trace): + """Multi-agent: parent agent delegates to child agent.""" + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_nested") + + adapter.on_trace_start(trace) + + # Parent agent + parent = _make_span(adapter,"t_nested", "s_parent", AgentSpanData(name="orchestrator")) + parent.start() + adapter.on_span_start(parent) + + # Child agent + child = _make_span(adapter,"t_nested", "s_child", AgentSpanData(name="researcher"), parent_id="s_parent") + child.start() + adapter.on_span_start(child) + child.finish() + adapter.on_span_end(child) + + parent.finish() + adapter.on_span_end(parent) + + adapter.on_trace_end(trace) + + events = uploaded["events"] + agent_inputs = find_events(events, "agent.input") + assert len(agent_inputs) == 2 + + # Child should have parent_span_id pointing to parent + child_input = [e for e in agent_inputs if e["payload"]["agent_name"] == "researcher"][0] + assert 
child_input["parent_span_id"] == "s_parent" + + +class TestGenerationSpans: + """Test LLM generation span handling.""" + + def test_generation_emits_model_invoke(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_gen") + + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t_gen", "s_gen", + GenerationSpanData( + input=[{"role": "user", "content": "What is 2+2?"}], + output=[{"role": "assistant", "content": "4"}], + model="gpt-4o", + model_config={"temperature": 0.7}, + usage={"input_tokens": 50, "output_tokens": 10}, + ), + parent_id="s_agent", + ) + span.start() + adapter.on_span_start(span) + span.finish() + adapter.on_span_end(span) + + adapter.on_trace_end(trace) + + events = uploaded["events"] + me = find_event(events, "model.invoke") + assert me["payload"]["model"] == "gpt-4o" + assert me["payload"]["tokens_prompt"] == 50 + assert me["payload"]["tokens_completion"] == 10 + assert me["payload"]["tokens_total"] == 60 + assert me["payload"]["latency_ms"] >= 0 + assert me["payload"]["messages"] == [{"role": "user", "content": "What is 2+2?"}] + assert me["payload"]["output_message"] == [{"role": "assistant", "content": "4"}] + assert me["span_id"] == "s_gen" + assert me["parent_span_id"] == "s_agent" + + def test_generation_emits_cost_record(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_cost") + + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t_cost", "s_cost", + GenerationSpanData( + input=[], output=[], model="gpt-4o-mini", + model_config={}, + usage={"input_tokens": 100, "output_tokens": 25}, + ), + ) + span.start() + span.finish() + adapter.on_span_end(span) + adapter.on_trace_end(trace) + + events = uploaded["events"] + cost = find_event(events, "cost.record") + assert cost["payload"]["model"] == "gpt-4o-mini" + assert cost["payload"]["tokens_prompt"] == 100 + assert cost["payload"]["tokens_completion"] == 25 + assert 
cost["payload"]["tokens_total"] == 125 + + def test_generation_error(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_gen_err") + + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t_gen_err", "s_gen_err", + GenerationSpanData( + input=[{"role": "user", "content": "fail"}], + output=[], model="gpt-4o", + model_config={}, usage={}, + ), + ) + span.start() + span.set_error({"message": "Rate limit exceeded"}) + span.finish() + adapter.on_span_end(span) + adapter.on_trace_end(trace) + + events = uploaded["events"] + err = find_event(events, "agent.error") + assert "Rate limit" in str(err["payload"]["error"]) + + def test_multiple_generations(self, adapter_and_trace): + """Agent makes multiple LLM calls (e.g. tool use loop).""" + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_multi_gen") + + adapter.on_trace_start(trace) + + for i, (inp_tok, out_tok) in enumerate([(50, 15), (80, 20)]): + span = _make_span( + adapter,"t_multi_gen", f"s_gen_{i}", + GenerationSpanData( + input=[], output=[], model="gpt-4o", + model_config={}, + usage={"input_tokens": inp_tok, "output_tokens": out_tok}, + ), + parent_id="s_agent", + ) + span.start() + span.finish() + adapter.on_span_end(span) + + adapter.on_trace_end(trace) + + events = uploaded["events"] + gens = find_events(events, "model.invoke") + assert len(gens) == 2 + assert gens[0]["span_id"] != gens[1]["span_id"] + + +class TestFunctionSpans: + """Test tool/function span handling.""" + + def test_function_span_emits_tool_call(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_func") + + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t_func", "s_func", + FunctionSpanData(name="get_weather", input='{"city":"NYC"}', output='{"temp":72}'), + parent_id="s_agent", + ) + span.start() + span.finish() + adapter.on_span_end(span) + adapter.on_trace_end(trace) + + events = 
uploaded["events"] + tc = find_event(events, "tool.call") + assert tc["payload"]["tool_name"] == "get_weather" + assert tc["payload"]["input"] == '{"city":"NYC"}' + assert tc["payload"]["output"] == '{"temp":72}' + assert tc["payload"]["latency_ms"] >= 0 + assert tc["parent_span_id"] == "s_agent" + + def test_function_span_with_error(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_func_err") + + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t_func_err", "s_func_err", + FunctionSpanData(name="dangerous_tool", input="delete all", output=None), + ) + span.start() + span.set_error({"message": "Permission denied"}) + span.finish() + adapter.on_span_end(span) + adapter.on_trace_end(trace) + + events = uploaded["events"] + err = find_event(events, "agent.error") + assert err["payload"]["tool_name"] == "dangerous_tool" + assert "Permission denied" in str(err["payload"]["error"]) + + def test_function_span_with_mcp(self, adapter_and_trace): + """Function spans can include MCP data.""" + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_mcp") + + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t_mcp", "s_mcp", + FunctionSpanData(name="mcp_tool", input="query", output="result"), + ) + # Set mcp_data manually + span.span_data.mcp_data = {"server": "my-mcp-server", "tool": "query_db"} + span.start() + span.finish() + adapter.on_span_end(span) + adapter.on_trace_end(trace) + + events = uploaded["events"] + tc = find_event(events, "tool.call") + assert tc["payload"]["mcp_data"]["server"] == "my-mcp-server" + + +class TestHandoffSpans: + """Test handoff span handling.""" + + def test_handoff_emits_event(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_handoff") + + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t_handoff", "s_handoff", + HandoffSpanData(from_agent="triage", to_agent="specialist"), + 
parent_id="s_agent", + ) + span.start() + span.finish() + adapter.on_span_end(span) + adapter.on_trace_end(trace) + + events = uploaded["events"] + ho = find_event(events, "agent.handoff") + assert ho["payload"]["from_agent"] == "triage" + assert ho["payload"]["to_agent"] == "specialist" + assert ho["parent_span_id"] == "s_agent" + + +class TestGuardrailSpans: + """Test guardrail span handling.""" + + def test_guardrail_emits_evaluation_result(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_guard") + + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t_guard", "s_guard", + GuardrailSpanData(name="content_filter", triggered=True), + ) + span.start() + span.finish() + adapter.on_span_end(span) + adapter.on_trace_end(trace) + + events = uploaded["events"] + ev = find_event(events, "evaluation.result") + assert ev["payload"]["guardrail_name"] == "content_filter" + assert ev["payload"]["triggered"] is True + + def test_guardrail_not_triggered(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_guard2") + + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t_guard2", "s_guard2", + GuardrailSpanData(name="pii_detector", triggered=False), + ) + span.start() + span.finish() + adapter.on_span_end(span) + adapter.on_trace_end(trace) + + events = uploaded["events"] + ev = find_event(events, "evaluation.result") + assert ev["payload"]["triggered"] is False + + +class TestFullAgentFlow: + """End-to-end test simulating a complete agent run with tools and handoff.""" + + def test_complete_flow(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_flow", name="customer_support") + + adapter.on_trace_start(trace) + + # Agent span + agent = _make_span(adapter,"t_flow", "s_agent", AgentSpanData(name="triage", tools=["classify"])) + agent.start() + adapter.on_span_start(agent) + + # LLM call + gen = 
_make_span( + adapter,"t_flow", "s_gen", + GenerationSpanData( + input=[{"role": "user", "content": "I need help"}], + output=[{"role": "assistant", "content": "Let me classify this"}], + model="gpt-4o-mini", + model_config={}, + usage={"input_tokens": 30, "output_tokens": 10}, + ), + parent_id="s_agent", + ) + gen.start() + gen.finish() + adapter.on_span_end(gen) + + # Tool call + tool = _make_span( + adapter,"t_flow", "s_tool", + FunctionSpanData(name="classify", input="I need help", output="billing"), + parent_id="s_agent", + ) + tool.start() + tool.finish() + adapter.on_span_end(tool) + + # Guardrail + guard = _make_span( + adapter,"t_flow", "s_guard", + GuardrailSpanData(name="safety_check", triggered=False), + parent_id="s_agent", + ) + guard.start() + guard.finish() + adapter.on_span_end(guard) + + # Handoff + handoff = _make_span( + adapter,"t_flow", "s_handoff", + HandoffSpanData(from_agent="triage", to_agent="billing_agent"), + parent_id="s_agent", + ) + handoff.start() + handoff.finish() + adapter.on_span_end(handoff) + + agent.finish() + adapter.on_span_end(agent) + + adapter.on_trace_end(trace) + + events = uploaded["events"] + types = [e["event_type"] for e in events] + + assert "agent.input" in types + assert "agent.output" in types + assert "model.invoke" in types + assert "cost.record" in types + assert "tool.call" in types + assert "evaluation.result" in types + assert "agent.handoff" in types + + # Verify ordering + seq_ids = [e["sequence_id"] for e in events] + assert seq_ids == sorted(seq_ids) + assert len(set(seq_ids)) == len(seq_ids) + + # Verify parent-child relationships + me = find_event(events, "model.invoke") + assert me["parent_span_id"] == "s_agent" + + tc = find_event(events, "tool.call") + assert tc["parent_span_id"] == "s_agent" + + +class TestCaptureConfigGating: + """Test that CaptureConfig gates events properly.""" + + def test_minimal_config(self, mock_client): + uploaded = capture_framework_trace(mock_client) + config = 
CaptureConfig.minimal() + adapter = OpenAIAgentsAdapter(mock_client, capture_config=config) + adapter.connect() + + + trace = _make_trace(trace_id="t_min") + + adapter.on_trace_start(trace) + + # Agent span (L1 — should be captured) + agent = _make_span(adapter,"t_min", "s_agent", AgentSpanData(name="test")) + agent.start() + agent.finish() + adapter.on_span_end(agent) + + # Generation span (L3 — should be skipped) + gen = _make_span( + adapter,"t_min", "s_gen", + GenerationSpanData( + input=[], output=[], model="gpt-4o", + model_config={}, usage={"input_tokens": 10, "output_tokens": 5}, + ), + ) + gen.start() + gen.finish() + adapter.on_span_end(gen) + + # Tool span (L5a — should be skipped) + tool = _make_span( + adapter,"t_min", "s_tool", + FunctionSpanData(name="search", input="q", output="r"), + ) + tool.start() + tool.finish() + adapter.on_span_end(tool) + + adapter.on_trace_end(trace) + + events = uploaded.get("events", []) + types = [e["event_type"] for e in events] + + assert "agent.input" in types + assert "agent.output" in types + assert "model.invoke" not in types + assert "tool.call" not in types + # cost.record is always enabled + assert "cost.record" in types + + adapter.disconnect() + + +class TestConcurrentTraces: + """Test that multiple concurrent traces are isolated.""" + + def test_parallel_traces_isolated(self, mock_client): + all_uploads: List[Dict[str, Any]] = [] + + def _capture(path: str) -> None: + with open(path) as f: + data = json.load(f) + all_uploads.append(data[0]) + + mock_client.traces.upload = MagicMock(side_effect=_capture) + + adapter = OpenAIAgentsAdapter(mock_client) + adapter.connect() + + + # Two concurrent traces + t1 = _make_trace(trace_id="t_par_1") + t2 = _make_trace(trace_id="t_par_2") + + adapter.on_trace_start(t1) + adapter.on_trace_start(t2) + + # Agent in trace 1 + s1 = _make_span(adapter,"t_par_1", "s1", AgentSpanData(name="agent_1")) + s1.start() + s1.finish() + adapter.on_span_end(s1) + + # Agent in trace 2 + s2 
= _make_span(adapter,"t_par_2", "s2", AgentSpanData(name="agent_2")) + s2.start() + s2.finish() + adapter.on_span_end(s2) + + adapter.on_trace_end(t1) + adapter.on_trace_end(t2) + + assert len(all_uploads) == 2 + + # Each trace should have its own events + names = set() + for upload in all_uploads: + for e in upload["events"]: + if e["event_type"] == "agent.input": + names.add(e["payload"]["agent_name"]) + + assert names == {"agent_1", "agent_2"} + + adapter.disconnect() + + +class TestErrorIsolation: + """Verify hooks never crash the SDK.""" + + def test_broken_collector_does_not_crash(self, mock_client): + adapter = OpenAIAgentsAdapter(mock_client) + adapter.connect() + + + trace = _make_trace(trace_id="t_safe") + adapter.on_trace_start(trace) + + # Break the collector + adapter._collectors["t_safe"] = None # type: ignore[assignment] + + # This should not raise + span = _make_span(adapter,"t_safe", "s_safe", AgentSpanData(name="test")) + span.start() + span.finish() + adapter.on_span_end(span) # Should log warning, not crash + + # Trace end should not crash either + adapter.on_trace_end(trace) + + adapter.disconnect() + + +class TestEdgeCases: + def test_empty_usage(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_empty") + + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t_empty", "s_empty", + GenerationSpanData(input=[], output=[], model="gpt-4o", model_config={}, usage={}), + ) + span.start() + span.finish() + adapter.on_span_end(span) + adapter.on_trace_end(trace) + + events = uploaded["events"] + me = find_event(events, "model.invoke") + assert "tokens_prompt" not in me["payload"] + assert "tokens_completion" not in me["payload"] + + def test_none_values_in_span_data(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_none") + + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t_none", "s_none", + 
AgentSpanData(name="minimal_agent"), # no tools, no handoffs + ) + span.start() + span.finish() + adapter.on_span_end(span) + adapter.on_trace_end(trace) + + events = uploaded["events"] + inp = find_event(events, "agent.input") + assert inp["payload"]["agent_name"] == "minimal_agent" + assert "tools" not in inp["payload"] + assert "handoffs" not in inp["payload"] + + def test_function_span_with_none_output(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_none_out") + + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t_none_out", "s_func", + FunctionSpanData(name="void_tool", input="run", output=None), + ) + span.start() + span.finish() + adapter.on_span_end(span) + adapter.on_trace_end(trace) + + events = uploaded["events"] + tc = find_event(events, "tool.call") + assert tc["payload"]["tool_name"] == "void_tool" + # output should not be in payload since it was None + assert "output" not in tc["payload"] + + def test_span_duration_tracking(self, adapter_and_trace): + """Verify duration_ms is computed from span timing.""" + import time as _time + + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_dur") + + adapter.on_trace_start(trace) + + span = _make_span(adapter,"t_dur", "s_dur", AgentSpanData(name="slow_agent")) + span.start() + _time.sleep(0.02) # 20ms + span.finish() + adapter.on_span_end(span) + adapter.on_trace_end(trace) + + events = uploaded["events"] + out = find_event(events, "agent.output") + assert out["payload"]["duration_ms"] >= 15 # allow tolerance diff --git a/tests/instrument/adapters/frameworks/test_pydantic_ai.py b/tests/instrument/adapters/frameworks/test_pydantic_ai.py new file mode 100644 index 0000000..c60ae7a --- /dev/null +++ b/tests/instrument/adapters/frameworks/test_pydantic_ai.py @@ -0,0 +1,471 @@ +"""Tests for the PydanticAI adapter using the native Hooks capability API. 
+ +Tests use PydanticAI's TestModel to exercise the real agent loop with +hooks firing at each lifecycle point — no monkey-patching or mocking of +PydanticAI internals. +""" +from __future__ import annotations + +import asyncio +from typing import Optional + +import pytest + +pydantic_ai = pytest.importorskip("pydantic_ai") + +from pydantic_ai import Agent # noqa: E402 +from pydantic_ai.models.test import TestModel # noqa: E402 + +from layerlens.instrument._capture_config import CaptureConfig # noqa: E402 +from layerlens.instrument.adapters.frameworks.pydantic_ai import PydanticAIAdapter # noqa: E402 + +from .conftest import capture_framework_trace, find_event, find_events # noqa: E402 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_agent( + name: Optional[str] = None, + output_text: str = "Hello!", + model_name: str = "test", + tools: Optional[list] = None, +) -> Agent: + """Create a PydanticAI Agent with TestModel for deterministic testing.""" + agent = Agent( + model=TestModel(custom_output_text=output_text, model_name=model_name), + name=name, + ) + if tools: + for tool_fn in tools: + agent.tool_plain(tool_fn) + return agent + + +def get_weather(city: str) -> str: + """Get weather for a city.""" + return f"72F in {city}" + + +# --------------------------------------------------------------------------- +# Lifecycle +# --------------------------------------------------------------------------- + + +class TestPydanticAIAdapterLifecycle: + def test_connect_injects_hooks(self, mock_client): + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent() + + caps_before = len(agent._root_capability.capabilities) + adapter.connect(target=agent) + + assert adapter.is_connected + assert len(agent._root_capability.capabilities) == caps_before + 1 + info = adapter.adapter_info() + assert info.name == "pydantic-ai" + assert 
info.adapter_type == "framework" + assert info.connected is True + + adapter.disconnect() + + def test_disconnect_removes_hooks(self, mock_client): + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent() + caps_before = len(agent._root_capability.capabilities) + + adapter.connect(target=agent) + adapter.disconnect() + + assert not adapter.is_connected + assert len(agent._root_capability.capabilities) == caps_before + + def test_connect_without_target_raises(self, mock_client): + adapter = PydanticAIAdapter(mock_client) + with pytest.raises(ValueError, match="requires a target agent"): + adapter.connect() + + def test_connect_without_pydantic_ai_raises(self, mock_client, monkeypatch): + import layerlens.instrument.adapters.frameworks.pydantic_ai as mod + + monkeypatch.setattr(mod, "_HAS_PYDANTIC_AI", False) + adapter = PydanticAIAdapter(mock_client) + with pytest.raises(ImportError, match="pydantic-ai"): + adapter.connect(target=_make_agent()) + + +# --------------------------------------------------------------------------- +# run_sync +# --------------------------------------------------------------------------- + + +class TestRunSync: + def test_basic_run(self, mock_client): + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(output_text="The weather is sunny") + + adapter.connect(target=agent) + result = agent.run_sync("What is the weather?") + adapter.disconnect() + + assert result.output == "The weather is sunny" + events = uploaded["events"] + + inp = find_event(events, "agent.input") + assert inp["payload"]["framework"] == "pydantic-ai" + assert inp["payload"]["input"] == "What is the weather?" 
+ + out = find_event(events, "agent.output") + assert out["payload"]["status"] == "ok" + assert out["payload"]["output"] == "The weather is sunny" + assert out["payload"]["latency_ms"] >= 0 + assert out["payload"]["tokens_prompt"] > 0 + assert out["payload"]["tokens_completion"] > 0 + + cost = find_event(events, "cost.record") + assert cost["payload"]["tokens_total"] > 0 + + def test_named_agent(self, mock_client): + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(name="my_agent", output_text="ok") + + adapter.connect(target=agent) + agent.run_sync("test") + adapter.disconnect() + + inp = find_event(uploaded["events"], "agent.input") + assert inp["payload"]["agent_name"] == "my_agent" + + +# --------------------------------------------------------------------------- +# async run +# --------------------------------------------------------------------------- + + +class TestRunAsync: + def test_async_run(self, mock_client): + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(name="async_agent", output_text="Async result") + + adapter.connect(target=agent) + result = asyncio.get_event_loop().run_until_complete(agent.run("async test")) + adapter.disconnect() + + assert result.output == "Async result" + + inp = find_event(uploaded["events"], "agent.input") + assert inp["payload"]["agent_name"] == "async_agent" + assert inp["payload"]["input"] == "async test" + + out = find_event(uploaded["events"], "agent.output") + assert out["payload"]["status"] == "ok" + + +# --------------------------------------------------------------------------- +# Model invocation events +# --------------------------------------------------------------------------- + + +class TestModelInvocation: + def test_model_invoke_emitted(self, mock_client): + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = 
_make_agent(output_text="hello", model_name="gpt-4o-test") + + adapter.connect(target=agent) + agent.run_sync("hi") + adapter.disconnect() + + model_invokes = find_events(uploaded["events"], "model.invoke") + assert len(model_invokes) >= 1 + assert model_invokes[0]["payload"]["model"] == "gpt-4o-test" + assert model_invokes[0]["payload"]["tokens_prompt"] > 0 + + def test_model_invoke_with_tools_has_two_calls(self, mock_client): + """When a tool is called, TestModel makes 2 model requests: + first to call the tool, then to produce the final text.""" + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(output_text="Done", tools=[get_weather]) + + adapter.connect(target=agent) + agent.run_sync("weather NYC") + adapter.disconnect() + + model_invokes = find_events(uploaded["events"], "model.invoke") + assert len(model_invokes) == 2 + + +# --------------------------------------------------------------------------- +# Tool events +# --------------------------------------------------------------------------- + + +class TestToolEvents: + def test_tool_call_and_result(self, mock_client): + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(output_text="Done", tools=[get_weather]) + + adapter.connect(target=agent) + agent.run_sync("weather NYC") + adapter.disconnect() + + events = uploaded["events"] + + tool_calls = find_events(events, "tool.call") + assert len(tool_calls) == 1 + assert tool_calls[0]["payload"]["tool_name"] == "get_weather" + + tool_results = find_events(events, "tool.result") + assert len(tool_results) == 1 + assert tool_results[0]["payload"]["tool_name"] == "get_weather" + + def test_tool_result_has_output(self, mock_client): + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client, capture_config=CaptureConfig.full()) + agent = _make_agent(output_text="Done", tools=[get_weather]) + + 
adapter.connect(target=agent) + agent.run_sync("weather NYC") + adapter.disconnect() + + tool_results = find_events(uploaded["events"], "tool.result") + assert len(tool_results) == 1 + # The output should contain the tool's return value + assert "72F" in str(tool_results[0]["payload"]["output"]) + + def test_tool_result_has_latency(self, mock_client): + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(output_text="Done", tools=[get_weather]) + + adapter.connect(target=agent) + agent.run_sync("weather") + adapter.disconnect() + + tool_results = find_events(uploaded["events"], "tool.result") + assert len(tool_results) == 1 + assert tool_results[0]["payload"]["latency_ms"] >= 0 + + +# --------------------------------------------------------------------------- +# Span hierarchy +# --------------------------------------------------------------------------- + + +class TestSpanHierarchy: + def test_per_step_events_parented_to_root(self, mock_client): + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(output_text="Done", tools=[get_weather]) + + adapter.connect(target=agent) + agent.run_sync("weather") + adapter.disconnect() + + events = uploaded["events"] + root = find_event(events, "agent.input") + root_span = root["span_id"] + + for evt in find_events(events, "model.invoke"): + assert evt["parent_span_id"] == root_span + for evt in find_events(events, "tool.call"): + assert evt["parent_span_id"] == root_span + for evt in find_events(events, "tool.result"): + assert evt["parent_span_id"] == root_span + + +# --------------------------------------------------------------------------- +# CaptureConfig gating +# --------------------------------------------------------------------------- + + +class TestCaptureConfigGating: + def test_no_content_capture_omits_io(self, mock_client): + uploaded = capture_framework_trace(mock_client) + config = 
CaptureConfig(capture_content=False) + adapter = PydanticAIAdapter(mock_client, capture_config=config) + agent = _make_agent(output_text="done", tools=[get_weather]) + + adapter.connect(target=agent) + agent.run_sync("secret prompt") + adapter.disconnect() + + events = uploaded["events"] + + inp = find_event(events, "agent.input") + assert "input" not in inp["payload"] + + tool_calls = find_events(events, "tool.call") + assert len(tool_calls) >= 1 + assert "input" not in tool_calls[0]["payload"] + + tool_results = find_events(events, "tool.result") + assert len(tool_results) >= 1 + assert "output" not in tool_results[0]["payload"] + + # cost.record should still exist + assert len(find_events(events, "cost.record")) == 1 + + def test_full_config_includes_content(self, mock_client): + uploaded = capture_framework_trace(mock_client) + config = CaptureConfig.full() + adapter = PydanticAIAdapter(mock_client, capture_config=config) + agent = _make_agent(output_text="Hi Alice", tools=[get_weather]) + + adapter.connect(target=agent) + agent.run_sync("greet Alice") + adapter.disconnect() + + events = uploaded["events"] + + inp = find_event(events, "agent.input") + assert inp["payload"]["input"] == "greet Alice" + + out = find_event(events, "agent.output") + assert out["payload"]["output"] == "Hi Alice" + + tool_calls = find_events(events, "tool.call") + assert "input" in tool_calls[0]["payload"] + + +# --------------------------------------------------------------------------- +# Multiple runs +# --------------------------------------------------------------------------- + + +class TestMultipleRuns: + def test_sequential_runs_separate_traces(self, mock_client): + import json + + all_uploads: list = [] + + def _capture(path: str) -> None: + with open(path) as f: + data = json.load(f) + all_uploads.append(data[0]) + + mock_client.traces.upload.side_effect = _capture + + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(output_text="ok") + + 
adapter.connect(target=agent) + agent.run_sync("first") + agent.run_sync("second") + adapter.disconnect() + + assert len(all_uploads) == 2 + trace_ids = {u["trace_id"] for u in all_uploads} + assert len(trace_ids) == 2 + + +# --------------------------------------------------------------------------- +# Event structure +# --------------------------------------------------------------------------- + + +class TestEventStructure: + def test_event_fields(self, mock_client): + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(name="test_agent", output_text="ok") + + adapter.connect(target=agent) + agent.run_sync("hello") + adapter.disconnect() + + events = uploaded["events"] + for event in events: + assert "event_type" in event + assert "trace_id" in event + assert "span_id" in event + assert "sequence_id" in event + assert "timestamp_ns" in event + assert "payload" in event + + seq_ids = [e["sequence_id"] for e in events] + assert seq_ids == sorted(seq_ids) + assert len(set(seq_ids)) == len(seq_ids) + + trace_ids = {e["trace_id"] for e in events} + assert len(trace_ids) == 1 + + def test_attestation_present(self, mock_client): + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(output_text="ok") + + adapter.connect(target=agent) + agent.run_sync("test") + adapter.disconnect() + + assert uploaded.get("trace_id") is not None + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + def test_empty_prompt(self, mock_client): + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(output_text="ok") + + adapter.connect(target=agent) + agent.run_sync("") + adapter.disconnect() + + inp = find_event(uploaded["events"], "agent.input") + assert 
inp["payload"]["framework"] == "pydantic-ai" + + def test_pydantic_model_output(self, mock_client): + from pydantic import BaseModel + + class CityInfo(BaseModel): + city: str + temp: int + + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = Agent( + model=TestModel(custom_output_args={"city": "NYC", "temp": 72}), + output_type=CityInfo, + ) + + adapter.connect(target=agent) + result = agent.run_sync("weather") + adapter.disconnect() + + out = find_event(uploaded["events"], "agent.output") + assert out["payload"]["output"] == {"city": "NYC", "temp": 72} + + def test_zero_token_usage_still_has_tokens(self, mock_client): + """TestModel always produces some tokens, so we verify they're present.""" + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(output_text="ok") + + adapter.connect(target=agent) + agent.run_sync("test") + adapter.disconnect() + + out = find_event(uploaded["events"], "agent.output") + # TestModel always has some token usage + assert "tokens_prompt" in out["payload"] + assert len(find_events(uploaded["events"], "cost.record")) == 1 + + def test_disconnect_idempotent(self, mock_client): + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent() + adapter.connect(target=agent) + adapter.disconnect() + adapter.disconnect() # should not raise diff --git a/tests/instrument/adapters/frameworks/test_semantic_kernel.py b/tests/instrument/adapters/frameworks/test_semantic_kernel.py new file mode 100644 index 0000000..9ae833a --- /dev/null +++ b/tests/instrument/adapters/frameworks/test_semantic_kernel.py @@ -0,0 +1,753 @@ +"""Tests for the Semantic Kernel adapter using the SK filter API. + +Tests use real Kernel objects and KernelFunctions. Filters are exercised +either through actual kernel.invoke() calls or by directly invoking the +filter callables with mock contexts. 
+""" +from __future__ import annotations + +import asyncio +from typing import Any, Optional +from unittest.mock import MagicMock + +import pytest + +sk = pytest.importorskip("semantic_kernel") + +from semantic_kernel import Kernel # noqa: E402 +from semantic_kernel.functions import kernel_function # noqa: E402 +from semantic_kernel.filters.filter_types import FilterTypes # noqa: E402 + +from layerlens.instrument._capture_config import CaptureConfig # noqa: E402 +from layerlens.instrument.adapters.frameworks.semantic_kernel import ( # noqa: E402 + SemanticKernelAdapter, + _extract_arguments, + _extract_function_name, + _extract_plugin_name, +) + +from .conftest import capture_framework_trace, find_event, find_events # noqa: E402 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class MathPlugin: + @kernel_function(name="add", description="Add two numbers") + def add(self, a: int, b: int) -> int: + return a + b + + @kernel_function(name="divide", description="Divide a by b") + def divide(self, a: int, b: int) -> float: + return a / b + + +class TextPlugin: + @kernel_function(name="upper", description="Uppercase text") + def upper(self, text: str) -> str: + return text.upper() + + +class MockFunction: + def __init__(self, name: str = "test_func", plugin_name: str = "TestPlugin"): + self.name = name + self.plugin_name = plugin_name + + +class MockContext: + def __init__( + self, + function: Any = None, + arguments: Any = None, + result: Any = None, + rendered_prompt: Optional[str] = None, + function_call_content: Any = None, + function_result: Any = None, + request_sequence_index: int = 0, + function_sequence_index: int = 0, + ): + self.function = function or MockFunction() + self.arguments = arguments + self.result = result + self.rendered_prompt = rendered_prompt + self.function_call_content = function_call_content + self.function_result = 
function_result + self.request_sequence_index = request_sequence_index + self.function_sequence_index = function_sequence_index + + +class MockFunctionCallContent: + def __init__(self, arguments: Any = None): + self.arguments = arguments + + +class MockFunctionResult: + def __init__(self, value: Any = None): + self.value = value + + +def _run(coro: Any) -> Any: + return asyncio.get_event_loop().run_until_complete(coro) + + +# --------------------------------------------------------------------------- +# Lifecycle +# --------------------------------------------------------------------------- + + +class TestLifecycle: + def test_connect_registers_filters(self, mock_client): + kernel = Kernel() + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + assert adapter.is_connected + assert len(kernel.function_invocation_filters) == 1 + assert len(kernel.prompt_rendering_filters) == 1 + assert len(kernel.auto_function_invocation_filters) == 1 + + info = adapter.adapter_info() + assert info.name == "semantic_kernel" + assert info.adapter_type == "framework" + assert info.connected is True + + adapter.disconnect() + + def test_disconnect_removes_filters(self, mock_client): + kernel = Kernel() + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + adapter.disconnect() + + assert not adapter.is_connected + assert len(kernel.function_invocation_filters) == 0 + assert len(kernel.prompt_rendering_filters) == 0 + assert len(kernel.auto_function_invocation_filters) == 0 + + def test_connect_without_target_raises(self, mock_client): + adapter = SemanticKernelAdapter(mock_client) + with pytest.raises(ValueError, match="requires a target kernel"): + adapter.connect() + + def test_connect_without_sk_raises(self, mock_client, monkeypatch): + import layerlens.instrument.adapters.frameworks.semantic_kernel as mod + + monkeypatch.setattr(mod, "_HAS_SEMANTIC_KERNEL", False) + adapter = SemanticKernelAdapter(mock_client) + with 
pytest.raises(ImportError, match="semantic_kernel"): + adapter.connect(target=Kernel()) + + def test_disconnect_idempotent(self, mock_client): + kernel = Kernel() + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + adapter.disconnect() + adapter.disconnect() # should not raise + + +# --------------------------------------------------------------------------- +# Function invocation via real kernel.invoke() +# --------------------------------------------------------------------------- + + +class TestFunctionInvocation: + def test_invoke_emits_tool_call(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + kernel.add_plugin(MathPlugin(), "MathPlugin") + + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + result = _run(kernel.invoke(plugin_name="MathPlugin", function_name="add", a=2, b=3)) + assert str(result) == "5" + + adapter.disconnect() + + events = uploaded["events"] + tool_calls = find_events(events, "tool.call") + assert len(tool_calls) >= 1 + assert tool_calls[0]["payload"]["tool_name"] == "MathPlugin.add" + assert tool_calls[0]["payload"]["plugin_name"] == "MathPlugin" + assert tool_calls[0]["payload"]["function_name"] == "add" + + tool_results = find_events(events, "tool.result") + assert len(tool_results) >= 1 + assert tool_results[0]["payload"]["status"] == "ok" + assert tool_results[0]["payload"]["latency_ms"] >= 0 + + def test_invoke_captures_output(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + kernel.add_plugin(MathPlugin(), "MathPlugin") + + adapter = SemanticKernelAdapter(mock_client, capture_config=CaptureConfig.full()) + adapter.connect(target=kernel) + + _run(kernel.invoke(plugin_name="MathPlugin", function_name="add", a=10, b=20)) + adapter.disconnect() + + events = uploaded["events"] + tool_result = find_event(events, "tool.result") + assert tool_result["payload"]["output"] == 30 + + def 
test_invoke_error_emits_agent_error(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + kernel.add_plugin(MathPlugin(), "MathPlugin") + + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + with pytest.raises(Exception): + _run(kernel.invoke(plugin_name="MathPlugin", function_name="divide", a=1, b=0)) + + adapter.disconnect() + + events = uploaded["events"] + err = find_event(events, "agent.error") + assert "division by zero" in err["payload"]["error"] + assert err["payload"]["error_type"] == "ZeroDivisionError" + assert err["payload"]["tool_name"] == "MathPlugin.divide" + + def test_sequential_invocations(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + kernel.add_plugin(MathPlugin(), "MathPlugin") + + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + _run(kernel.invoke(plugin_name="MathPlugin", function_name="add", a=1, b=2)) + _run(kernel.invoke(plugin_name="MathPlugin", function_name="add", a=3, b=4)) + adapter.disconnect() + + events = uploaded["events"] + assert len(find_events(events, "tool.call")) == 2 + assert len(find_events(events, "tool.result")) == 2 + + +# --------------------------------------------------------------------------- +# Function invocation filter via direct call +# --------------------------------------------------------------------------- + + +class TestFunctionInvocationFilter: + def test_filter_calls_next_and_emits(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + ctx = MockContext( + function=MockFunction("greet", "HelloPlugin"), + ) + + async def mock_next(context): + context.result = MockFunctionResult("Hi") + + _run(adapter._function_invocation_filter(ctx, mock_next)) + adapter.disconnect() + + events = uploaded["events"] + tool_call = find_event(events, "tool.call") 
+ assert tool_call["payload"]["plugin_name"] == "HelloPlugin" + assert tool_call["payload"]["function_name"] == "greet" + + tool_result = find_event(events, "tool.result") + assert tool_result["payload"]["status"] == "ok" + + def test_filter_propagates_exception(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + ctx = MockContext() + + async def failing_next(context): + raise RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + _run(adapter._function_invocation_filter(ctx, failing_next)) + + adapter.disconnect() + + events = uploaded["events"] + err = find_event(events, "agent.error") + assert err["payload"]["error"] == "boom" + assert err["payload"]["error_type"] == "RuntimeError" + + +# --------------------------------------------------------------------------- +# Prompt rendering +# --------------------------------------------------------------------------- + + +class TestPromptRendering: + def test_prompt_render_emits_agent_code(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + adapter = SemanticKernelAdapter(mock_client, capture_config=CaptureConfig.full()) + adapter.connect(target=kernel) + + ctx = MockContext( + function=MockFunction("summarize", "TextPlugin"), + rendered_prompt="Summarize: Hello world", + ) + + async def mock_next(context): + pass + + _run(adapter._prompt_rendering_filter(ctx, mock_next)) + adapter.disconnect() + + events = uploaded["events"] + ev = find_event(events, "agent.code") + assert ev["payload"]["event_subtype"] == "prompt_render" + assert ev["payload"]["function_name"] == "summarize" + assert "Summarize" in ev["payload"]["rendered_prompt"] + + def test_prompt_render_no_content_when_disabled(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + config = CaptureConfig(l2_agent_code=True, capture_content=False) + 
adapter = SemanticKernelAdapter(mock_client, capture_config=config) + adapter.connect(target=kernel) + + ctx = MockContext( + function=MockFunction("summarize", "TextPlugin"), + rendered_prompt="secret prompt", + ) + + async def mock_next(context): + pass + + _run(adapter._prompt_rendering_filter(ctx, mock_next)) + adapter.disconnect() + + events = uploaded["events"] + ev = find_event(events, "agent.code") + assert "rendered_prompt" not in ev["payload"] + + +# --------------------------------------------------------------------------- +# Auto function invocation (LLM-initiated tool calls) +# --------------------------------------------------------------------------- + + +class TestAutoFunctionInvocation: + def test_auto_function_emits_tool_call_and_result(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + adapter = SemanticKernelAdapter(mock_client, capture_config=CaptureConfig.full()) + adapter.connect(target=kernel) + + ctx = MockContext( + function=MockFunction("web_search", "SearchPlugin"), + function_call_content=MockFunctionCallContent(arguments={"query": "test"}), + function_result=MockFunctionResult("found it"), + request_sequence_index=1, + function_sequence_index=0, + ) + + async def mock_next(context): + pass + + _run(adapter._auto_function_invocation_filter(ctx, mock_next)) + adapter.disconnect() + + events = uploaded["events"] + + tool_call = find_event(events, "tool.call") + assert tool_call["payload"]["auto_invoked"] is True + assert tool_call["payload"]["tool_name"] == "SearchPlugin.web_search" + assert tool_call["payload"]["input"] == {"query": "test"} + assert tool_call["payload"]["request_sequence_index"] == 1 + + tool_results = find_events(events, "tool.result") + assert len(tool_results) == 1 + assert tool_results[0]["payload"]["auto_invoked"] is True + assert tool_results[0]["payload"]["output"] == "found it" + assert tool_results[0]["payload"]["latency_ms"] >= 0 + + def test_auto_function_error(self, 
mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + ctx = MockContext( + function=MockFunction("fail_tool", "ToolPlugin"), + ) + + async def failing_next(context): + raise ValueError("tool exploded") + + with pytest.raises(ValueError, match="tool exploded"): + _run(adapter._auto_function_invocation_filter(ctx, failing_next)) + + adapter.disconnect() + + events = uploaded["events"] + # tool.call should still be emitted (before the error) + tool_call = find_event(events, "tool.call") + assert tool_call["payload"]["auto_invoked"] is True + + err = find_event(events, "agent.error") + assert err["payload"]["error"] == "tool exploded" + assert err["payload"]["auto_invoked"] is True + + +# --------------------------------------------------------------------------- +# Plugin discovery +# --------------------------------------------------------------------------- + + +class TestPluginDiscovery: + def test_discover_plugins_on_connect(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + kernel.add_plugin(MathPlugin(), "MathPlugin") + kernel.add_plugin(TextPlugin(), "TextPlugin") + + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + adapter.disconnect() + + events = uploaded["events"] + config_events = find_events(events, "environment.config") + plugin_names = {e["payload"]["plugin_name"] for e in config_events} + assert "MathPlugin" in plugin_names + assert "TextPlugin" in plugin_names + + def test_new_plugin_discovered_on_first_call(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + # Invoke filter directly with a plugin not yet seen + ctx = MockContext(function=MockFunction("do_stuff", "NewPlugin")) + + async def mock_next(context): + context.result = 
MockFunctionResult("ok") + + _run(adapter._function_invocation_filter(ctx, mock_next)) + adapter.disconnect() + + events = uploaded["events"] + config_events = find_events(events, "environment.config") + names = {e["payload"]["plugin_name"] for e in config_events} + assert "NewPlugin" in names + + def test_duplicate_plugin_not_rediscovered(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + ctx1 = MockContext(function=MockFunction("f1", "SamePlugin")) + ctx2 = MockContext(function=MockFunction("f2", "SamePlugin")) + + async def mock_next(context): + context.result = MockFunctionResult("ok") + + _run(adapter._function_invocation_filter(ctx1, mock_next)) + _run(adapter._function_invocation_filter(ctx2, mock_next)) + adapter.disconnect() + + events = uploaded["events"] + config_events = find_events(events, "environment.config") + same_plugin = [e for e in config_events if e["payload"]["plugin_name"] == "SamePlugin"] + assert len(same_plugin) == 1 + + +# --------------------------------------------------------------------------- +# CaptureConfig gating +# --------------------------------------------------------------------------- + + +class TestCaptureConfigGating: + def test_no_content_strips_io(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + adapter = SemanticKernelAdapter(mock_client, capture_config=CaptureConfig(capture_content=False)) + adapter.connect(target=kernel) + + ctx = MockContext( + function=MockFunction("search", "Plugin"), + arguments={"secret": "key"}, + ) + + async def mock_next(context): + context.result = MockFunctionResult("classified") + + _run(adapter._function_invocation_filter(ctx, mock_next)) + adapter.disconnect() + + events = uploaded["events"] + tool_call = find_event(events, "tool.call") + assert "input" not in tool_call["payload"] + + tool_result = find_event(events, 
"tool.result") + assert "output" not in tool_result["payload"] + + def test_full_config_includes_io(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + adapter = SemanticKernelAdapter(mock_client, capture_config=CaptureConfig.full()) + adapter.connect(target=kernel) + + ctx = MockContext( + function=MockFunction("search", "Plugin"), + arguments={"query": "test"}, + ) + + async def mock_next(context): + context.result = MockFunctionResult("results") + + _run(adapter._function_invocation_filter(ctx, mock_next)) + adapter.disconnect() + + events = uploaded["events"] + tool_call = find_event(events, "tool.call") + assert tool_call["payload"]["input"] == {"query": "test"} + + tool_result = find_event(events, "tool.result") + assert tool_result["payload"]["output"] == "results" + + +# --------------------------------------------------------------------------- +# LLM call wrapping +# --------------------------------------------------------------------------- + + +class MockUsage: + def __init__(self, prompt_tokens: int = 0, completion_tokens: int = 0): + self.prompt_tokens = prompt_tokens + self.completion_tokens = completion_tokens + + +class MockChatMessage: + def __init__(self, text: str = "Hello!", model_id: str = "gpt-4o", usage: Any = None): + self.content = text + self.ai_model_id = model_id + self.metadata = {"usage": usage} if usage else {} + + +class MockChatService: + """Minimal mock that looks like a ChatCompletionClientBase to the adapter.""" + + def __init__(self, response_text: str = "Hello!", model_id: str = "gpt-4o", + prompt_tokens: int = 100, completion_tokens: int = 50): + self.ai_model_id = model_id + self._response = MockChatMessage( + text=response_text, + model_id=model_id, + usage=MockUsage(prompt_tokens, completion_tokens), + ) + + async def _inner_get_chat_message_contents(self, chat_history: Any, settings: Any) -> list: + return [self._response] + + +class TestLLMCallWrapping: + def 
_register_mock_service(self, kernel, service): + """Register a mock service directly on the kernel.""" + kernel.services["mock"] = service + + def test_model_invoke_emitted(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + service = MockChatService(prompt_tokens=100, completion_tokens=50) + self._register_mock_service(kernel, service) + + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + # Call the wrapped method directly + _run(service._inner_get_chat_message_contents(None, None)) + + adapter.disconnect() + + events = uploaded["events"] + model_invoke = find_event(events, "model.invoke") + assert model_invoke["payload"]["model"] == "gpt-4o" + assert model_invoke["payload"]["tokens_prompt"] == 100 + assert model_invoke["payload"]["tokens_completion"] == 50 + assert model_invoke["payload"]["latency_ms"] >= 0 + + def test_cost_record_emitted(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + service = MockChatService(prompt_tokens=200, completion_tokens=100) + self._register_mock_service(kernel, service) + + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + _run(service._inner_get_chat_message_contents(None, None)) + adapter.disconnect() + + events = uploaded["events"] + cost = find_event(events, "cost.record") + assert cost["payload"]["tokens_total"] == 300 + assert cost["payload"]["model"] == "gpt-4o" + + def test_no_cost_record_without_tokens(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + service = MockChatService(prompt_tokens=0, completion_tokens=0) + self._register_mock_service(kernel, service) + + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + _run(service._inner_get_chat_message_contents(None, None)) + adapter.disconnect() + + events = uploaded["events"] + cost_events = find_events(events, "cost.record") + assert len(cost_events) == 0 + + 
def test_llm_error_emits_agent_error(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + service = MockChatService() + self._register_mock_service(kernel, service) + + # Replace inner method with one that fails + original = service._inner_get_chat_message_contents + + async def failing_inner(chat_history, settings): + raise RuntimeError("API timeout") + + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + # The adapter wrapped the original, so replace the original call path + # We need to set up the service to fail BEFORE connect wraps it + # Let's test by reconnecting + adapter.disconnect() + + service._inner_get_chat_message_contents = failing_inner + adapter.connect(target=kernel) + + with pytest.raises(RuntimeError, match="API timeout"): + _run(service._inner_get_chat_message_contents(None, None)) + + adapter.disconnect() + + events = uploaded["events"] + err = find_event(events, "agent.error") + assert err["payload"]["error"] == "API timeout" + assert err["payload"]["model"] == "gpt-4o" + + def test_disconnect_restores_original(self, mock_client): + kernel = Kernel() + service = MockChatService() + self._register_mock_service(kernel, service) + + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + # After connect, the method is our wrapper (an instance attribute, not the class method) + assert "_traced_inner" in service._inner_get_chat_message_contents.__name__ + + adapter.disconnect() + # After disconnect, the instance override is removed and the class method is accessible again + assert "_traced_inner" not in service._inner_get_chat_message_contents.__name__ + + +# --------------------------------------------------------------------------- +# Span hierarchy +# --------------------------------------------------------------------------- + + +class TestSpanHierarchy: + def test_events_share_root_span(self, mock_client): + uploaded = 
capture_framework_trace(mock_client) + kernel = Kernel() + kernel.add_plugin(MathPlugin(), "MathPlugin") + + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + _run(kernel.invoke(plugin_name="MathPlugin", function_name="add", a=1, b=2)) + adapter.disconnect() + + events = uploaded["events"] + # All events should share the same root span (via parent_span_id) + parent_spans = {e.get("parent_span_id") for e in events if e.get("parent_span_id")} + # There should be at most one root + assert len(parent_spans) <= 2 # root_span_id from _ensure_collector + our root + + +# --------------------------------------------------------------------------- +# Event structure +# --------------------------------------------------------------------------- + + +class TestEventStructure: + def test_event_fields(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + kernel.add_plugin(MathPlugin(), "MathPlugin") + + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + _run(kernel.invoke(plugin_name="MathPlugin", function_name="add", a=1, b=2)) + adapter.disconnect() + + events = uploaded["events"] + for event in events: + assert "event_type" in event + assert "trace_id" in event + assert "span_id" in event + assert "sequence_id" in event + assert "timestamp_ns" in event + assert "payload" in event + assert event["payload"]["framework"] == "semantic_kernel" + + seq_ids = [e["sequence_id"] for e in events] + assert seq_ids == sorted(seq_ids) + + +# --------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------- + + +class TestHelpers: + def test_extract_plugin_name_from_function(self): + ctx = MockContext(function=MockFunction(plugin_name="MyPlugin")) + assert _extract_plugin_name(ctx) == "MyPlugin" + + def test_extract_plugin_name_fallback(self): + class Ctx: + function = None + 
plugin_name = "FallbackPlugin" + + assert _extract_plugin_name(Ctx()) == "FallbackPlugin" + + def test_extract_function_name(self): + ctx = MockContext(function=MockFunction(name="my_func")) + assert _extract_function_name(ctx) == "my_func" + + def test_extract_arguments_dict(self): + ctx = MockContext(arguments={"x": 1, "y": 2}) + assert _extract_arguments(ctx) == {"x": 1, "y": 2} + + def test_extract_arguments_none(self): + ctx = MockContext(arguments=None) + assert _extract_arguments(ctx) is None + + def test_extract_arguments_mapping(self): + """SK KernelArguments has .items() but isn't a dict.""" + class FakeArgs: + def items(self): + return [("a", 1)] + + ctx = MockContext(arguments=FakeArgs()) + assert _extract_arguments(ctx) == {"a": 1}