From 6003e90e8ccbe24005be2c4e0e3cae01d3135af6 Mon Sep 17 00:00:00 2001 From: Gary <59334078+garrettallen14@users.noreply.github.com> Date: Wed, 1 Apr 2026 18:51:50 -0700 Subject: [PATCH 1/4] feat: context propagation and upload circuit breaker --- src/layerlens/instrument/__init__.py | 3 + .../instrument/_context_propagation.py | 93 +++ src/layerlens/instrument/_upload.py | 68 +- .../adapters/frameworks/_base_framework.py | 107 ++- tests/instrument/test_trace_context.py | 642 ++++++++++++++++++ 5 files changed, 882 insertions(+), 31 deletions(-) create mode 100644 src/layerlens/instrument/_context_propagation.py create mode 100644 tests/instrument/test_trace_context.py diff --git a/src/layerlens/instrument/__init__.py b/src/layerlens/instrument/__init__.py index a7237a0..04e4667 100644 --- a/src/layerlens/instrument/__init__.py +++ b/src/layerlens/instrument/__init__.py @@ -5,6 +5,7 @@ from ._capture_config import CaptureConfig from ._collector import TraceCollector from ._decorator import trace +from ._context_propagation import trace_context, get_trace_context from .adapters._base import AdapterInfo, BaseAdapter __all__ = [ @@ -13,6 +14,8 @@ "CaptureConfig", "TraceCollector", "emit", + "get_trace_context", "span", "trace", + "trace_context", ] diff --git a/src/layerlens/instrument/_context_propagation.py b/src/layerlens/instrument/_context_propagation.py new file mode 100644 index 0000000..f1ced8d --- /dev/null +++ b/src/layerlens/instrument/_context_propagation.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import uuid +from typing import Any, Dict, Generator, Optional +from contextlib import contextmanager + +from ._collector import TraceCollector +from ._capture_config import CaptureConfig +from ._context import ( + _current_collector, + _current_span_id, + _parent_span_id, + _push_span, + _pop_span, +) + + +@contextmanager +def trace_context( + client: Any, + *, + capture_config: Optional[CaptureConfig] = None, + from_context: Optional[Dict[str, Any]] = None, +) -> Generator[TraceCollector, None, None]: + """Establish a shared trace context for multiple adapters. + + Creates a :class:`TraceCollector` and sets it as the active collector + in ``contextvars`` so that any adapter emitting events inside the + block will use the same ``trace_id`` and span hierarchy. + + When *from_context* is provided (a dict from :func:`get_trace_context`), + the new collector reuses the original ``trace_id`` so events on both + sides of a boundary belong to the same trace. + + The collector is flushed automatically when the context exits. + + Args: + client: A :class:`~layerlens.Stratix` (or compatible) client used + for uploading the trace on flush. + capture_config: Optional capture configuration. Falls back to + :meth:`CaptureConfig.standard` if not provided. + from_context: Optional dict produced by :func:`get_trace_context`. + When supplied the collector inherits the original trace_id. + + Yields: + The shared :class:`TraceCollector`. + """ + config = capture_config or CaptureConfig.standard() + collector = TraceCollector(client, config) + + if from_context is not None: + collector._trace_id = from_context["trace_id"] # noqa: SLF001 + + root_span_id = uuid.uuid4().hex[:16] + + col_token = _current_collector.set(collector) + span_snapshot = _push_span(root_span_id, "trace_context") + try: + yield collector + finally: + _pop_span(span_snapshot) + _current_collector.reset(col_token) + collector.flush() + + +def get_trace_context() -> Optional[Dict[str, Any]]: + """Snapshot the current trace context as a plain dict. + + Returns ``None`` when called outside a ``@trace`` / ``trace_context`` + block. The returned dict is safe to serialise (JSON, headers, message + queues, etc.) and restore via ``trace_context(client, from_context=ctx)``. + + Keys: + + * ``trace_id`` — 16-char hex trace identifier + * ``span_id`` — current span (becomes the parent in the remote scope) + * ``parent_span_id`` — optional grandparent for reference + * ``version`` — format version for forward compatibility + """ + collector = _current_collector.get() + if collector is None: + return None + + span_id = _current_span_id.get() + if span_id is None: + return None + + return { + "trace_id": collector.trace_id, + "span_id": span_id, + "parent_span_id": _parent_span_id.get(), + "version": 1, + } diff --git a/src/layerlens/instrument/_upload.py b/src/layerlens/instrument/_upload.py index c594d29..ae42048 100644 --- a/src/layerlens/instrument/_upload.py +++ b/src/layerlens/instrument/_upload.py @@ -2,16 +2,70 @@ import os import json +import time import asyncio import logging import tempfile +import threading from typing import Any, Dict log: logging.Logger = logging.getLogger(__name__) +# --------------------------------------------------------------------------- +# Circuit breaker +# --------------------------------------------------------------------------- + +_lock = threading.Lock() +_error_count = 0 +_circuit_open = False +_opened_at: float = 0.0 + +_THRESHOLD = 10 +_COOLDOWN_S = 60.0 + + +def _allow() -> bool: + global _circuit_open, _error_count + with _lock: + if not _circuit_open: + return True + if time.monotonic() - _opened_at >= _COOLDOWN_S: + _circuit_open = False + _error_count = 0 + log.info("layerlens: upload circuit breaker half-open, retrying") + return True + return False + + +def _on_success() -> None: + global _error_count, _circuit_open + with _lock: + if _error_count > 0: + _error_count = 0 + _circuit_open = False + + +def _on_failure() -> None: + global _error_count, _circuit_open, _opened_at + with _lock: + _error_count += 1 + if _error_count >= _THRESHOLD and not _circuit_open: + _circuit_open = True + _opened_at = time.monotonic() + log.warning( + "layerlens: upload circuit breaker OPEN after %d errors (cooldown %.0fs)", + _error_count, + _COOLDOWN_S, + ) + + +# --------------------------------------------------------------------------- +# Upload +# --------------------------------------------------------------------------- + + def _write_trace_file(payload: Dict[str, Any]) -> str: - """Write trace payload to a temp file and return its path.""" fd, path = tempfile.mkstemp(suffix=".json", prefix="layerlens_trace_") with os.fdopen(fd, "w") as f: json.dump([payload], f, default=str) @@ -19,9 +73,15 @@ def _write_trace_file(payload: Dict[str, Any]) -> str: def upload_trace(client: Any, payload: Dict[str, Any]) -> None: + if not _allow(): + return path = _write_trace_file(payload) try: client.traces.upload(path) + _on_success() + except Exception: + _on_failure() + log.warning("layerlens: trace upload failed", exc_info=True) finally: try: os.unlink(path) @@ -30,9 +90,15 @@ def upload_trace(client: Any, payload: Dict[str, Any]) -> None: async def async_upload_trace(client: Any, payload: Dict[str, Any]) -> None: + if not _allow(): + return path = await asyncio.to_thread(_write_trace_file, payload) try: await client.traces.upload(path) + _on_success() + except Exception: + _on_failure() + log.warning("layerlens: async trace upload failed", exc_info=True) finally: try: os.unlink(path) diff --git a/src/layerlens/instrument/adapters/frameworks/_base_framework.py b/src/layerlens/instrument/adapters/frameworks/_base_framework.py index 197c65e..20ddbdb 100644 --- a/src/layerlens/instrument/adapters/frameworks/_base_framework.py +++ b/src/layerlens/instrument/adapters/frameworks/_base_framework.py @@ -1,12 +1,4 @@ -"""Unified base class for all framework adapters. - -Framework adapters hook into a framework's callback / event / tracing -system and emit LayerLens events. They share a common lifecycle: - - 1. Lazy-init a :class:`TraceCollector` on first event. - 2. Emit events through a thread-safe helper. - 3. Flush the collector when a logical trace ends (root span completes, - agent run finishes, disconnect, etc.). +"""Base class for framework adapters. Subclasses MUST set ``name`` and implement ``connect()``. Subclasses SHOULD call ``super().disconnect()`` after unhooking. @@ -14,12 +6,17 @@ from __future__ import annotations import uuid +import logging import threading -from typing import Any, Dict, Optional +from typing import Any, Dict, Generator, Optional +from contextlib import contextmanager from .._base import AdapterInfo, BaseAdapter from ..._collector import TraceCollector from ..._capture_config import CaptureConfig +from ..._context import _current_collector, _current_span_id, _push_span, _pop_span + +log = logging.getLogger(__name__) class FrameworkAdapter(BaseAdapter): @@ -34,16 +31,27 @@ def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) self._connected = False self._collector: Optional[TraceCollector] = None self._root_span_id: Optional[str] = None + self._using_shared_collector = False # Optional run_id → span_id mapping for callback-style frameworks self._span_ids: Dict[str, str] = {} + # Subclasses populate during connect() for adapter_info() metadata + self._metadata: Dict[str, Any] = {} # ------------------------------------------------------------------ # Collector lifecycle # ------------------------------------------------------------------ def _ensure_collector(self) -> TraceCollector: - """Lazily create a collector and root span ID.""" + """Return the shared collector from ContextVars, or create a private one.""" + shared = _current_collector.get() + if shared is not None: + self._using_shared_collector = True + if self._root_span_id is None: + self._root_span_id = _current_span_id.get() + return shared + if self._collector is None: + self._using_shared_collector = False self._collector = TraceCollector(self._client, self._config) self._root_span_id = uuid.uuid4().hex[:16] return self._collector @@ -52,6 +60,55 @@ def _ensure_collector(self) -> TraceCollector: def _new_span_id() -> str: return uuid.uuid4().hex[:16] + # ------------------------------------------------------------------ + # Callback scope — bridges framework callbacks to ContextVars + # ------------------------------------------------------------------ + + @contextmanager + def _callback_scope( + self, + span_name: Optional[str] = None, + ) -> Generator[str, None, None]: + """Push collector + new span into ContextVars; yields the span_id.""" + collector = self._ensure_collector() + span_id = self._new_span_id() + + # Only set the collector ContextVar if no shared one exists already + needs_collector_push = _current_collector.get() is None + col_token = None + if needs_collector_push: + col_token = _current_collector.set(collector) + + snapshot = _push_span(span_id, span_name) + try: + yield span_id + finally: + _pop_span(snapshot) + if col_token is not None: + _current_collector.reset(col_token) + + def _traced_call( + self, + original: Any, + *args: Any, + _span_name: Optional[str] = None, + **kwargs: Any, + ) -> Any: + """Call *original* inside a _callback_scope so providers see this collector.""" + with self._callback_scope(_span_name): + return original(*args, **kwargs) + + async def _async_traced_call( + self, + original: Any, + *args: Any, + _span_name: Optional[str] = None, + **kwargs: Any, + ) -> Any: + """Async version of _traced_call.""" + with self._callback_scope(_span_name): + return await original(*args, **kwargs) + # ------------------------------------------------------------------ # Event emission (thread-safe) # ------------------------------------------------------------------ @@ -79,12 +136,7 @@ def _emit( # ------------------------------------------------------------------ def _span_id_for(self, run_id: Any, parent_run_id: Any = None) -> tuple[str, Optional[str]]: - """Map a framework run_id to a span_id, creating one if needed. - - Returns ``(span_id, parent_span_id)``. Useful for frameworks - (LangChain, CrewAI, OpenAI Agents) that assign their own run - identifiers to each step. - """ + """Map a framework run_id to a (span_id, parent_span_id) pair.""" rid = str(run_id) if rid not in self._span_ids: self._span_ids[rid] = self._new_span_id() @@ -97,13 +149,15 @@ def _span_id_for(self, run_id: Any, parent_run_id: Any = None) -> tuple[str, Opt # ------------------------------------------------------------------ def _flush_collector(self) -> None: - """Flush the current collector and reset state.""" + """Flush private collector (no-op for shared collectors).""" with self._lock: collector = self._collector + is_shared = self._using_shared_collector self._collector = None self._root_span_id = None + self._using_shared_collector = False self._span_ids.clear() - if collector is not None: + if collector is not None and not is_shared: collector.flush() # ------------------------------------------------------------------ @@ -111,27 +165,20 @@ def _flush_collector(self) -> None: # ------------------------------------------------------------------ def connect(self, target: Any = None, **kwargs: Any) -> Any: - """Mark the adapter as connected. - - Callback-style adapters (LangChain, LangGraph) are passed directly - to the framework, so ``connect()`` just flips the flag. Adapters - that need registration (CrewAI, LlamaIndex, etc.) should override. - """ + """Mark as connected. Subclasses override for framework registration.""" self._connected = True return target def disconnect(self) -> None: - """Flush remaining events and mark as disconnected. - - Subclasses should unhook from the framework first, then call - ``super().disconnect()``. - """ + """Flush remaining events and mark as disconnected.""" self._flush_collector() self._connected = False + self._metadata.clear() def adapter_info(self) -> AdapterInfo: return AdapterInfo( name=self.name, adapter_type="framework", connected=self._connected, + metadata=self._metadata, ) diff --git a/tests/instrument/test_trace_context.py b/tests/instrument/test_trace_context.py new file mode 100644 index 0000000..03a09be --- /dev/null +++ b/tests/instrument/test_trace_context.py @@ -0,0 +1,642 @@ +"""Tests for trace context: shared collectors, context propagation, +callback scope, and upload circuit breaker. +""" +from __future__ import annotations + +import json +from typing import Any, Dict, List, Optional +from unittest.mock import Mock + +import pytest + +from layerlens.instrument import ( + trace, + trace_context, + emit, + span, + get_trace_context, + CaptureConfig, +) +from layerlens.instrument._context import _current_collector, _current_span_id +from layerlens.instrument._collector import TraceCollector +from layerlens.instrument import _upload +from layerlens.instrument.adapters.frameworks._base_framework import FrameworkAdapter + +from .conftest import find_event, find_events + + +# --------------------------------------------------------------------------- +# Minimal concrete adapter for testing +# --------------------------------------------------------------------------- + +class StubAdapter(FrameworkAdapter): + name = "stub" + + def connect(self, target: Any = None, **kwargs: Any) -> Any: + self._connected = True + return target + + def fire_event(self, event_type: str, payload: Dict[str, Any], + span_id: Optional[str] = None, + parent_span_id: Optional[str] = None) -> None: + self._emit(event_type, payload, span_id=span_id, + parent_span_id=parent_span_id, span_name=event_type) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def mock_client(): + client = Mock() + client.traces = Mock() + client.traces.upload = Mock() + return client + + +@pytest.fixture +def capture_trace(mock_client): + """Capture uploaded trace payloads. Supports multiple uploads.""" + uploads: List[Dict[str, Any]] = [] + + def _capture(path: str) -> None: + with open(path) as f: + data = json.load(f) + uploads.append(data[0]) + + mock_client.traces.upload.side_effect = _capture + return uploads + + +@pytest.fixture(autouse=True) +def reset_circuit_breaker(): + """Reset the upload circuit breaker between tests.""" + _upload._error_count = 0 + _upload._circuit_open = False + _upload._opened_at = 0.0 + yield + _upload._error_count = 0 + _upload._circuit_open = False + _upload._opened_at = 0.0 + + +# =================================================================== +# 1. Shared trace_id via @trace +# =================================================================== + +class TestSharedCollectorViaTrace: + + def test_framework_adapter_shares_trace_id_with_trace_decorator( + self, mock_client, capture_trace, + ): + adapter = StubAdapter(mock_client) + adapter.connect() + + @trace(mock_client) + def agent_run(): + adapter.fire_event("agent.lifecycle", {"action": "crew.start"}) + return "done" + + agent_run() + + assert len(capture_trace) == 1 + events = capture_trace[0]["events"] + lifecycle = find_event(events, "agent.lifecycle") + agent_input = find_event(events, "agent.input") + assert lifecycle["trace_id"] == agent_input["trace_id"] + + def test_multiple_adapters_share_same_trace( + self, mock_client, capture_trace, + ): + adapter_a = StubAdapter(mock_client) + adapter_b = StubAdapter(mock_client) + adapter_a.connect() + adapter_b.connect() + + @trace(mock_client) + def agent_run(): + adapter_a.fire_event("agent.lifecycle", {"source": "A"}) + adapter_b.fire_event("agent.lifecycle", {"source": "B"}) + return "done" + + agent_run() + + assert len(capture_trace) == 1 + events = capture_trace[0]["events"] + lifecycles = find_events(events, "agent.lifecycle") + assert len(lifecycles) == 2 + assert lifecycles[0]["trace_id"] == lifecycles[1]["trace_id"] + + def test_framework_adapter_standalone_creates_own_trace( + self, mock_client, capture_trace, + ): + adapter = StubAdapter(mock_client) + adapter.connect() + adapter.fire_event("agent.lifecycle", {"action": "standalone"}) + adapter.disconnect() + + assert len(capture_trace) == 1 + events = capture_trace[0]["events"] + assert len(events) == 1 + assert events[0]["event_type"] == "agent.lifecycle" + + +# =================================================================== +# 2. Cross-adapter parent-child spans +# =================================================================== + +class TestCrossAdapterSpanHierarchy: + + def test_framework_events_parent_to_trace_root_span( + self, mock_client, capture_trace, + ): + adapter = StubAdapter(mock_client) + adapter.connect() + + @trace(mock_client) + def agent_run(): + adapter.fire_event("agent.lifecycle", {"action": "start"}) + return "done" + + agent_run() + + events = capture_trace[0]["events"] + agent_input = find_event(events, "agent.input") + lifecycle = find_event(events, "agent.lifecycle") + root_span = agent_input["span_id"] + assert lifecycle["parent_span_id"] == root_span + + def test_framework_events_parent_to_active_span( + self, mock_client, capture_trace, + ): + adapter = StubAdapter(mock_client) + adapter.connect() + + @trace(mock_client) + def agent_run(): + with span("retrieval"): + adapter.fire_event("tool.call", {"name": "search", "input": "q"}) + return "done" + + agent_run() + + events = capture_trace[0]["events"] + agent_input = find_event(events, "agent.input") + tool_call = find_event(events, "tool.call") + assert tool_call["parent_span_id"] is not None + assert tool_call["trace_id"] == agent_input["trace_id"] + + def test_adapter_with_explicit_parent_overrides_default( + self, mock_client, capture_trace, + ): + adapter = StubAdapter(mock_client) + adapter.connect() + explicit_parent = "custom_parent_id" + + @trace(mock_client) + def agent_run(): + adapter.fire_event( + "agent.lifecycle", {"action": "step"}, + parent_span_id=explicit_parent, + ) + return "done" + + agent_run() + + events = capture_trace[0]["events"] + lifecycle = find_event(events, "agent.lifecycle") + assert lifecycle["parent_span_id"] == explicit_parent + + +# =================================================================== +# 3. trace_context() +# =================================================================== + +class TestTraceContext: + + def test_creates_shared_collector(self, mock_client, capture_trace): + adapter_a = StubAdapter(mock_client) + adapter_b = StubAdapter(mock_client) + adapter_a.connect() + adapter_b.connect() + + with trace_context(mock_client): + adapter_a.fire_event("agent.lifecycle", {"source": "A"}) + adapter_b.fire_event("agent.lifecycle", {"source": "B"}) + + assert len(capture_trace) == 1 + events = capture_trace[0]["events"] + assert len(events) == 2 + assert events[0]["trace_id"] == events[1]["trace_id"] + + def test_flushes_on_exit(self, mock_client, capture_trace): + with trace_context(mock_client): + emit("tool.call", {"name": "test", "input": "x"}) + + assert len(capture_trace) == 1 + + def test_cleans_up_on_exit(self, mock_client): + with trace_context(mock_client): + assert _current_collector.get() is not None + + assert _current_collector.get() is None + assert _current_span_id.get() is None + + def test_cleans_up_on_error(self, mock_client): + with pytest.raises(RuntimeError): + with trace_context(mock_client): + raise RuntimeError("boom") + + assert _current_collector.get() is None + assert _current_span_id.get() is None + + def test_yields_collector(self, mock_client): + with trace_context(mock_client) as collector: + assert isinstance(collector, TraceCollector) + assert len(collector.trace_id) == 16 + + def test_with_custom_capture_config(self, mock_client, capture_trace): + config = CaptureConfig.standard() + + with trace_context(mock_client, capture_config=config): + emit("tool.call", {"name": "test", "input": "x"}) + + assert len(capture_trace) == 1 + assert capture_trace[0]["capture_config"] == config.to_dict() + + +# =================================================================== +# 4. Context serialisation (get_trace_context / from_context) +# =================================================================== + +class TestGetTraceContext: + + def test_returns_none_outside_trace(self): + assert get_trace_context() is None + + def test_returns_dict_inside_trace(self, mock_client, capture_trace): + @trace(mock_client) + def run(): + ctx = get_trace_context() + assert ctx is not None + assert "trace_id" in ctx + assert "span_id" in ctx + assert "parent_span_id" in ctx + assert ctx["version"] == 1 + return ctx + + ctx = run() + assert len(ctx["trace_id"]) == 16 + assert len(ctx["span_id"]) == 16 + + def test_returns_dict_inside_trace_context(self, mock_client, capture_trace): + with trace_context(mock_client): + ctx = get_trace_context() + assert ctx is not None + assert len(ctx["trace_id"]) == 16 + + def test_span_id_updates_inside_child_span(self, mock_client, capture_trace): + @trace(mock_client) + def run(): + ctx_outer = get_trace_context() + with span("inner"): + ctx_inner = get_trace_context() + return ctx_outer, ctx_inner + + outer, inner = run() + assert outer["trace_id"] == inner["trace_id"] + assert outer["span_id"] != inner["span_id"] + + +class TestTraceContextFromContext: + + def test_restores_trace_id(self, mock_client, capture_trace): + with trace_context(mock_client): + original_ctx = get_trace_context() + emit("tool.call", {"name": "origin", "input": "x"}) + + original_trace_id = original_ctx["trace_id"] + + with trace_context(mock_client, from_context=original_ctx) as restored: + assert restored.trace_id == original_trace_id + emit("tool.call", {"name": "remote", "input": "y"}) + + assert len(capture_trace) == 2 + assert capture_trace[0]["trace_id"] == capture_trace[1]["trace_id"] + + def test_creates_child_span(self, mock_client, capture_trace): + with trace_context(mock_client): + original_ctx = get_trace_context() + emit("tool.call", {"name": "origin", "input": "x"}) + + with trace_context(mock_client, from_context=original_ctx): + ctx_inside = get_trace_context() + + assert ctx_inside["span_id"] != original_ctx["span_id"] + assert ctx_inside["trace_id"] == original_ctx["trace_id"] + + +# =================================================================== +# 5. Flush semantics +# =================================================================== + +class TestFlushSemantics: + + def test_adapter_disconnect_does_not_flush_shared_collector( + self, mock_client, capture_trace, + ): + adapter = StubAdapter(mock_client) + adapter.connect() + + @trace(mock_client) + def agent_run(): + adapter.fire_event("agent.lifecycle", {"action": "start"}) + adapter.disconnect() + emit("tool.call", {"name": "post_disconnect", "input": "x"}) + return "done" + + agent_run() + + assert len(capture_trace) == 1 + events = capture_trace[0]["events"] + types = [e["event_type"] for e in events] + assert "agent.lifecycle" in types + assert "tool.call" in types + assert "agent.output" in types + + def test_adapter_disconnect_flushes_own_collector_when_standalone( + self, mock_client, capture_trace, + ): + adapter = StubAdapter(mock_client) + adapter.connect() + adapter.fire_event("agent.lifecycle", {"action": "standalone"}) + adapter.disconnect() + + assert len(capture_trace) == 1 + + def test_multiple_adapters_disconnect_independently_under_shared_context( + self, mock_client, capture_trace, + ): + adapter_a = StubAdapter(mock_client) + adapter_b = StubAdapter(mock_client) + adapter_a.connect() + adapter_b.connect() + + with trace_context(mock_client): + adapter_a.fire_event("agent.lifecycle", {"source": "A"}) + adapter_a.disconnect() + adapter_b.fire_event("agent.lifecycle", {"source": "B"}) + adapter_b.disconnect() + + assert len(capture_trace) == 1 + events = capture_trace[0]["events"] + sources = [e["payload"]["source"] for e in events] + assert "A" in sources + assert "B" in sources + + +# =================================================================== +# 6. Callback scope + _traced_call +# =================================================================== + +class TestCallbackScope: + + def test_pushes_collector_when_standalone(self, mock_client, capture_trace): + adapter = StubAdapter(mock_client) + adapter.connect() + + assert _current_collector.get() is None + with adapter._callback_scope("test_scope") as scope_span_id: + assert _current_collector.get() is not None + assert _current_span_id.get() == scope_span_id + emit("tool.call", {"name": "test", "input": "x"}) + + assert _current_collector.get() is None + + def test_preserves_shared_collector(self, mock_client, capture_trace): + adapter = StubAdapter(mock_client) + adapter.connect() + + @trace(mock_client) + def run(): + shared_collector = _current_collector.get() + with adapter._callback_scope("inner") as scope_span: + assert _current_collector.get() is shared_collector + assert _current_span_id.get() == scope_span + emit("tool.call", {"name": "inner_tool", "input": "x"}) + return "done" + + run() + + assert len(capture_trace) == 1 + events = capture_trace[0]["events"] + tool_call = find_event(events, "tool.call") + assert tool_call["payload"]["name"] == "inner_tool" + + def test_creates_child_span(self, mock_client, capture_trace): + adapter = StubAdapter(mock_client) + adapter.connect() + + @trace(mock_client) + def run(): + root_span = _current_span_id.get() + with adapter._callback_scope("child"): + child_span = _current_span_id.get() + assert child_span != root_span + emit("tool.call", {"name": "scoped", "input": "x"}) + assert _current_span_id.get() == root_span + return "done" + + run() + + def test_cleans_up_on_error(self, mock_client): + adapter = StubAdapter(mock_client) + adapter.connect() + + with pytest.raises(RuntimeError): + with adapter._callback_scope("failing"): + raise RuntimeError("boom") + + assert _current_collector.get() is None + assert _current_span_id.get() is None + + def test_traced_call_makes_providers_visible(self, mock_client, capture_trace): + adapter = StubAdapter(mock_client) + adapter.connect() + + def fake_agent_run(prompt): + assert _current_collector.get() is not None + emit("model.invoke", {"model": "gpt-4", "input": prompt}) + return "result" + + assert _current_collector.get() is None + result = adapter._traced_call(fake_agent_run, "hello", _span_name="agent.run") + assert result == "result" + assert _current_collector.get() is None + + adapter.disconnect() + assert len(capture_trace) == 1 + events = capture_trace[0]["events"] + model_event = find_event(events, "model.invoke") + assert model_event["payload"]["model"] == "gpt-4" + + def test_traced_call_under_shared_context(self, mock_client, capture_trace): + adapter = StubAdapter(mock_client) + adapter.connect() + + def fake_agent_run(prompt): + emit("model.invoke", {"model": "gpt-4", "input": prompt}) + return "result" + + @trace(mock_client) + def run(): + return adapter._traced_call(fake_agent_run, "hello", _span_name="agent.run") + + run() + assert len(capture_trace) == 1 + events = capture_trace[0]["events"] + assert find_event(events, "model.invoke") + assert find_event(events, "agent.input") + + +# =================================================================== +# 7. Upload circuit breaker +# =================================================================== + +class TestUploadCircuitBreaker: + + def test_successful_upload(self, mock_client, capture_trace): + with trace_context(mock_client): + emit("tool.call", {"name": "test", "input": "x"}) + + assert len(capture_trace) == 1 + assert _upload._error_count == 0 + + def test_upload_failure_records_error(self, mock_client): + mock_client.traces.upload.side_effect = RuntimeError("network error") + + with trace_context(mock_client): + emit("tool.call", {"name": "test", "input": "x"}) + + assert _upload._error_count == 1 + assert not _upload._circuit_open + + def test_circuit_opens_after_threshold(self, mock_client): + mock_client.traces.upload.side_effect = RuntimeError("network error") + + for _ in range(_upload._THRESHOLD): + with trace_context(mock_client): + emit("tool.call", {"name": "test", "input": "x"}) + + assert _upload._circuit_open + assert _upload._error_count == _upload._THRESHOLD + + def test_open_circuit_skips_upload(self, mock_client): + _upload._circuit_open = True + _upload._opened_at = __import__("time").monotonic() + + with trace_context(mock_client): + emit("tool.call", {"name": "test", "input": "x"}) + + mock_client.traces.upload.assert_not_called() + + def test_circuit_resets_after_cooldown(self, mock_client, capture_trace): + _upload._circuit_open = True + _upload._error_count = _upload._THRESHOLD + _upload._opened_at = ( + __import__("time").monotonic() - _upload._COOLDOWN_S - 1 + ) + + with trace_context(mock_client): + emit("tool.call", {"name": "test", "input": "x"}) + + assert len(capture_trace) == 1 + assert not _upload._circuit_open + assert _upload._error_count == 0 + + def test_success_after_failures_resets_count(self, mock_client, capture_trace): + _upload._error_count = 5 + + with trace_context(mock_client): + emit("tool.call", {"name": "test", "input": "x"}) + + assert _upload._error_count == 0 + + def test_protects_trace_decorator(self, mock_client): + _upload._circuit_open = True + _upload._opened_at = __import__("time").monotonic() + + @trace(mock_client) + def run(): + emit("tool.call", {"name": "test", "input": "x"}) + return "done" + + run() + mock_client.traces.upload.assert_not_called() + + def test_protects_framework_adapter(self, mock_client): + adapter = StubAdapter(mock_client) + adapter.connect() + + _upload._circuit_open = True + _upload._opened_at = __import__("time").monotonic() + + adapter.fire_event("tool.call", {"name": "test", "input": "x"}) + adapter.disconnect() + + mock_client.traces.upload.assert_not_called() + + +# =================================================================== +# 8. Edge cases +# =================================================================== + +class TestEdgeCases: + + def test_adapter_used_across_multiple_traces( + self, mock_client, capture_trace, + ): + adapter = StubAdapter(mock_client) + adapter.connect() + + @trace(mock_client) + def run_1(): + adapter.fire_event("agent.lifecycle", {"run": 1}) + return "done" + + @trace(mock_client) + def run_2(): + adapter.fire_event("agent.lifecycle", {"run": 2}) + return "done" + + run_1() + run_2() + + assert len(capture_trace) == 2 + assert capture_trace[0]["trace_id"] != capture_trace[1]["trace_id"] + + def test_no_events_means_no_upload(self, mock_client): + with trace_context(mock_client): + pass + + mock_client.traces.upload.assert_not_called() + + def test_standalone_adapter_unaffected_by_previous_shared_context( + self, mock_client, capture_trace, + ): + adapter = StubAdapter(mock_client) + adapter.connect() + + with trace_context(mock_client): + adapter.fire_event("agent.lifecycle", {"phase": "shared"}) + + adapter.disconnect() + + adapter = StubAdapter(mock_client) + adapter.connect() + adapter.fire_event("agent.lifecycle", {"phase": "standalone"}) + adapter.disconnect() + + assert len(capture_trace) == 2 + assert capture_trace[0]["trace_id"] != capture_trace[1]["trace_id"] From 313ba105b6b82c89fdbbc65e6749bd89a640244c Mon Sep 17 00:00:00 2001 From: Gary <59334078+garrettallen14@users.noreply.github.com> Date: Thu, 2 Apr 2026 14:46:27 -0700 Subject: [PATCH 2/4] feat: updates + new adapters --- src/layerlens/instrument/_context.py | 23 +- .../adapters/frameworks/_base_framework.py | 269 +++++- .../instrument/adapters/frameworks/_utils.py | 69 ++ .../instrument/adapters/frameworks/crewai.py | 475 ++++++++++ .../adapters/frameworks/langchain.py | 232 +++-- .../adapters/frameworks/langgraph.py | 7 +- .../adapters/frameworks/openai_agents.py | 306 +++++++ .../adapters/frameworks/pydantic_ai.py | 350 ++++++++ .../adapters/frameworks/semantic_kernel.py | 389 +++++++++ .../adapters/frameworks/test_crewai.py | 808 +++++++++++++++++ .../adapters/frameworks/test_langchain.py | 347 ++++++-- .../adapters/frameworks/test_langgraph.py | 2 +- .../adapters/frameworks/test_openai_agents.py | 823 ++++++++++++++++++ .../adapters/frameworks/test_pydantic_ai.py | 471 ++++++++++ .../frameworks/test_semantic_kernel.py | 753 ++++++++++++++++ 15 files changed, 5158 insertions(+), 166 deletions(-) create mode 100644 src/layerlens/instrument/adapters/frameworks/_utils.py create mode 100644 src/layerlens/instrument/adapters/frameworks/crewai.py create mode 100644 src/layerlens/instrument/adapters/frameworks/openai_agents.py create mode 100644 src/layerlens/instrument/adapters/frameworks/pydantic_ai.py create mode 100644 src/layerlens/instrument/adapters/frameworks/semantic_kernel.py create mode 100644 tests/instrument/adapters/frameworks/test_crewai.py create mode 100644 tests/instrument/adapters/frameworks/test_openai_agents.py create mode 100644 tests/instrument/adapters/frameworks/test_pydantic_ai.py create mode 100644 tests/instrument/adapters/frameworks/test_semantic_kernel.py diff --git a/src/layerlens/instrument/_context.py b/src/layerlens/instrument/_context.py index dc1f873..fce18d2 100644 --- a/src/layerlens/instrument/_context.py +++ b/src/layerlens/instrument/_context.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Any, Optional, NamedTuple +from dataclasses import dataclass, field +from typing import Any, Dict, Optional, NamedTuple from contextvars import ContextVar from ._collector import TraceCollector @@ -11,6 +12,26 @@ _current_span_name: ContextVar[Optional[str]] = ContextVar("_current_span_name", default=None) +@dataclass +class RunState: + """Per-run state isolated via ContextVar. + + Each concurrent run (agent invocation, crew kickoff, etc.) gets its own + RunState stored in ``_current_run``. This isolates the collector, root span, + timers, and any adapter-specific data so concurrent runs on the same adapter + instance don't clobber each other. + """ + + collector: TraceCollector + root_span_id: str + timers: Dict[str, int] = field(default_factory=dict) + data: Dict[str, Any] = field(default_factory=dict) + _token: Any = field(default=None, repr=False) + + +_current_run: ContextVar[Optional[RunState]] = ContextVar("_current_run", default=None) + + class _SpanSnapshot(NamedTuple): span_id: Any parent_span_id: Any diff --git a/src/layerlens/instrument/adapters/frameworks/_base_framework.py b/src/layerlens/instrument/adapters/frameworks/_base_framework.py index 20ddbdb..8190510 100644 --- a/src/layerlens/instrument/adapters/frameworks/_base_framework.py +++ b/src/layerlens/instrument/adapters/frameworks/_base_framework.py @@ -5,6 +5,7 @@ """ from __future__ import annotations +import time import uuid import logging import threading @@ -14,15 +15,27 @@ from .._base import AdapterInfo, BaseAdapter from ..._collector import TraceCollector from ..._capture_config import CaptureConfig -from ..._context import _current_collector, _current_span_id, _push_span, _pop_span +from ..._context import _current_collector, _current_span_id, _push_span, _pop_span, _current_run, RunState log = logging.getLogger(__name__) +_UNSET: Any = object() # sentinel: distinguish "not passed" from explicit None + class FrameworkAdapter(BaseAdapter): """Base for framework adapters with collector lifecycle management.""" name: str # Subclass must set: "crewai", "llamaindex", etc. + package: str = "" # pip extra name, e.g. "crewai" → pip install layerlens[crewai] + + def _check_dependency(self, available: bool) -> None: + """Raise ImportError with a helpful install message if the dependency is missing.""" + if not available: + pkg = self.package or self.name + raise ImportError( + "The '%s' package is required for %s instrumentation. " + "Install it with: pip install layerlens[%s]" % (pkg, self.name, pkg) + ) def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: self._client = client @@ -34,15 +47,68 @@ def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) self._using_shared_collector = False # Optional run_id → span_id mapping for callback-style frameworks self._span_ids: Dict[str, str] = {} + # Root run tracking for auto-flush on outermost callback completion + self._root_run_id: Optional[str] = None + # Timing: key → start_ns for _start_timer / _stop_timer + self._timers: Dict[str, int] = {} # Subclasses populate during connect() for adapter_info() metadata self._metadata: Dict[str, Any] = {} + # ------------------------------------------------------------------ + # Per-run state (ContextVar-based isolation for concurrent runs) + # ------------------------------------------------------------------ + + def _begin_run(self) -> RunState: + """Start a new run with its own collector, root span, and timers. + + Stores the RunState in a ContextVar so all subsequent calls to + ``_ensure_collector``, ``_start_timer``, ``_stop_timer``, and + ``_get_root_span`` use per-run state instead of instance state. + + ContextVars are automatically isolated per ``asyncio.Task``, so + concurrent runs on the same adapter get independent state. + """ + run = RunState( + collector=TraceCollector(self._client, self._config), + root_span_id=uuid.uuid4().hex[:16], + ) + run._token = _current_run.set(run) + return run + + def _end_run(self) -> None: + """Flush the current run's collector and restore the previous ContextVar state.""" + run = _current_run.get() + if run is None: + return + if run._token is not None: + try: + _current_run.reset(run._token) + except ValueError: + # Token created in a different Context (e.g. framework copies + # contexts between hook callbacks). Fall back to plain set. + _current_run.set(None) + else: + _current_run.set(None) + run.collector.flush() + + def _get_run(self) -> Optional[RunState]: + """Return the current RunState, or None if not inside a ``_begin_run`` scope.""" + return _current_run.get() + # ------------------------------------------------------------------ # Collector lifecycle # ------------------------------------------------------------------ def _ensure_collector(self) -> TraceCollector: - """Return the shared collector from ContextVars, or create a private one.""" + """Return the collector for the current context. + + Checks (in order): active RunState, shared collector from ContextVars, + then creates a private instance-level collector as fallback. + """ + run = _current_run.get() + if run is not None: + return run.collector + shared = _current_collector.get() if shared is not None: self._using_shared_collector = True @@ -60,32 +126,141 @@ def _ensure_collector(self) -> TraceCollector: def _new_span_id() -> str: return uuid.uuid4().hex[:16] + # ------------------------------------------------------------------ + # Shared helpers — payload, timing, tokens, content gating + # ------------------------------------------------------------------ + + def _payload(self, **extra: Any) -> Dict[str, Any]: + """Start a payload dict with ``framework: self.name``. + + Usage:: + + payload = self._payload(agent_name="foo", status="ok") + """ + p: Dict[str, Any] = {"framework": self.name} + if extra: + p.update(extra) + return p + + def _get_root_span(self) -> str: + """Return the root span ID for the current run. + + Checks RunState first, then falls back to instance-level ``_root_span_id``. + If neither is set, generates a new one. + """ + run = _current_run.get() + if run is not None: + return run.root_span_id + + with self._lock: + sid = self._root_span_id + if sid is not None: + return sid + sid = self._new_span_id() + with self._lock: + self._root_span_id = sid + return sid + + def _start_timer(self, key: str) -> None: + """Record a start timestamp (nanoseconds) under *key*.""" + run = _current_run.get() + if run is not None: + run.timers[key] = time.time_ns() + return + with self._lock: + self._timers[key] = time.time_ns() + + def _stop_timer(self, key: str) -> Optional[float]: + """Pop the start time for *key* and return elapsed ``latency_ms``, or ``None``.""" + run = _current_run.get() + if run is not None: + start_ns = run.timers.pop(key, 0) + else: + with self._lock: + start_ns = self._timers.pop(key, 0) + if not start_ns: + return None + return (time.time_ns() - start_ns) / 1_000_000 + + @staticmethod + def _normalize_tokens(usage: Any) -> Dict[str, Any]: + """Extract token counts from any usage object or dict. + + Handles field-name variants across providers: + ``prompt_tokens`` / ``input_tokens`` → ``tokens_prompt`` + ``completion_tokens`` / ``output_tokens`` → ``tokens_completion`` + + Returns a dict with ``tokens_prompt``, ``tokens_completion``, + ``tokens_total`` — only keys that have non-zero values. + """ + tokens: Dict[str, Any] = {} + if usage is None: + return tokens + + if isinstance(usage, dict): + prompt = usage.get("prompt_tokens") or usage.get("input_tokens") + completion = usage.get("completion_tokens") or usage.get("output_tokens") + total = usage.get("total_tokens") + else: + prompt = ( + getattr(usage, "prompt_tokens", None) + or getattr(usage, "input_tokens", None) + ) + completion = ( + getattr(usage, "completion_tokens", None) + or getattr(usage, "output_tokens", None) + ) + total = getattr(usage, "total_tokens", None) + + if prompt is not None: + tokens["tokens_prompt"] = int(prompt) + if completion is not None: + tokens["tokens_completion"] = int(completion) + if prompt is not None and completion is not None: + tokens["tokens_total"] = int(prompt) + int(completion) + elif total is not None: + tokens["tokens_total"] = int(total) + return tokens + + def _set_if_capturing(self, payload: Dict[str, Any], key: str, value: Any) -> None: + """Set ``payload[key] = value`` only if ``capture_content`` is enabled.""" + if self._config.capture_content and value is not None: + payload[key] = value + # ------------------------------------------------------------------ # Callback scope — bridges framework callbacks to ContextVars # ------------------------------------------------------------------ + def _push_context(self, span_id: str, span_name: Optional[str] = None) -> Any: + """Push collector + span into ContextVars. Returns an opaque token for ``_pop_context``.""" + with self._lock: + collector = self._ensure_collector() + needs_collector_push = _current_collector.get() is None + col_token = _current_collector.set(collector) if needs_collector_push else None + snapshot = _push_span(span_id, span_name) + return (snapshot, col_token) + + def _pop_context(self, token: Any) -> None: + """Restore ContextVars from a token returned by ``_push_context``.""" + if token is None: + return + snapshot, col_token = token + _pop_span(snapshot) + if col_token is not None: + _current_collector.reset(col_token) + @contextmanager def _callback_scope( self, span_name: Optional[str] = None, ) -> Generator[str, None, None]: """Push collector + new span into ContextVars; yields the span_id.""" - collector = self._ensure_collector() span_id = self._new_span_id() - - # Only set the collector ContextVar if no shared one exists already - needs_collector_push = _current_collector.get() is None - col_token = None - if needs_collector_push: - col_token = _current_collector.set(collector) - - snapshot = _push_span(span_id, span_name) + token = self._push_context(span_id, span_name) try: yield span_id finally: - _pop_span(snapshot) - if col_token is not None: - _current_collector.reset(col_token) + self._pop_context(token) def _traced_call( self, @@ -118,14 +293,43 @@ def _emit( event_type: str, payload: Dict[str, Any], span_id: Optional[str] = None, - parent_span_id: Optional[str] = None, + parent_span_id: Any = _UNSET, span_name: Optional[str] = None, + run_id: Any = None, + parent_run_id: Any = None, ) -> None: - """Thread-safe event emission through the collector.""" + """Thread-safe event emission through the collector. + + When *run_id* is provided, it is translated to a span_id via + ``_span_id_for`` and the first run_id seen is tracked as the root + (for flush-on-completion in callback-style frameworks). + + When *parent_span_id* is omitted, falls back to ``_root_span_id``. + Pass ``parent_span_id=None`` explicitly to emit with no parent + (for adapters that manage their own span hierarchy). + """ + # RunState path: per-run isolation, no lock needed + run = _current_run.get() + if run is not None: + if run_id is not None: + span_id, parent_span_id = self._span_id_for(run_id, parent_run_id) + sid = span_id or self._new_span_id() + parent = run.root_span_id if parent_span_id is _UNSET else parent_span_id + run.collector.emit( + event_type, payload, + span_id=sid, parent_span_id=parent, span_name=span_name, + ) + return + + # Legacy path: instance-level state with lock + if run_id is not None: + span_id, parent_span_id = self._span_id_for(run_id, parent_run_id) + if self._root_run_id is None: + self._root_run_id = str(run_id) with self._lock: collector = self._ensure_collector() sid = span_id or self._new_span_id() - parent = parent_span_id or self._root_span_id + parent = self._root_span_id if parent_span_id is _UNSET else parent_span_id collector.emit( event_type, payload, span_id=sid, parent_span_id=parent, span_name=span_name, @@ -136,12 +340,19 @@ def _emit( # ------------------------------------------------------------------ def _span_id_for(self, run_id: Any, parent_run_id: Any = None) -> tuple[str, Optional[str]]: - """Map a framework run_id to a (span_id, parent_span_id) pair.""" + """Map a framework run_id to a (span_id, parent_span_id) pair. + + When a RunState is active, span_ids are stored per-run in + ``run.data["span_ids"]`` for concurrent-run isolation. + Falls back to instance-level ``_span_ids`` otherwise. + """ + run = _current_run.get() + span_ids = run.data.setdefault("span_ids", {}) if run is not None else self._span_ids rid = str(run_id) - if rid not in self._span_ids: - self._span_ids[rid] = self._new_span_id() - span_id = self._span_ids[rid] - parent_span_id = self._span_ids.get(str(parent_run_id)) if parent_run_id else None + if rid not in span_ids: + span_ids[rid] = self._new_span_id() + span_id = span_ids[rid] + parent_span_id = span_ids.get(str(parent_run_id)) if parent_run_id else None return span_id, parent_span_id # ------------------------------------------------------------------ @@ -165,16 +376,26 @@ def _flush_collector(self) -> None: # ------------------------------------------------------------------ def connect(self, target: Any = None, **kwargs: Any) -> Any: - """Mark as connected. Subclasses override for framework registration.""" + """Check dependencies, run framework-specific setup, and mark as connected.""" + self._on_connect(target, **kwargs) self._connected = True return target + def _on_connect(self, target: Any = None, **kwargs: Any) -> None: + """Override to set up framework-specific resources (subscribe, wrap, etc.).""" + pass + def disconnect(self) -> None: - """Flush remaining events and mark as disconnected.""" + """Clean up framework resources, flush events, and mark as disconnected.""" + self._on_disconnect() self._flush_collector() self._connected = False self._metadata.clear() + def _on_disconnect(self) -> None: + """Override to clean up framework-specific resources (unsubscribe, restore, etc.).""" + pass + def adapter_info(self) -> AdapterInfo: return AdapterInfo( name=self.name, diff --git a/src/layerlens/instrument/adapters/frameworks/_utils.py b/src/layerlens/instrument/adapters/frameworks/_utils.py new file mode 100644 index 0000000..fdd66be --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/_utils.py @@ -0,0 +1,69 @@ +"""Shared utilities for framework adapters. + +Centralises helpers that were previously copy-pasted across adapter +files: serialisation, span ID generation, and text truncation. +""" +from __future__ import annotations + +import uuid +from typing import Any + +# --------------------------------------------------------------------------- +# Span IDs +# --------------------------------------------------------------------------- + + +def new_span_id() -> str: + """Generate a short random span identifier.""" + return uuid.uuid4().hex[:16] + + +# --------------------------------------------------------------------------- +# Serialisation +# --------------------------------------------------------------------------- + + +def safe_serialize(value: Any) -> Any: + """Best-effort conversion of *value* into a JSON-friendly form. + + Handles Pydantic models (``model_dump``), objects with ``to_dict``, + dicts, lists/tuples, and falls back to ``str()``. + """ + if value is None: + return None + if isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, (list, tuple)): + return [safe_serialize(v) for v in value] + if hasattr(value, "model_dump"): + try: + return value.model_dump() + except Exception: + pass + if hasattr(value, "to_dict"): + try: + return value.to_dict() + except Exception: + pass + if isinstance(value, dict): + return {str(k): safe_serialize(v) for k, v in value.items()} + return str(value) + + +# --------------------------------------------------------------------------- +# Text truncation +# --------------------------------------------------------------------------- + + +def truncate(text: Any, max_len: int = 2000) -> Any: + """Truncate *text* to *max_len* characters, appending ``'...'``. + + Returns *None* unchanged. Non-string values are stringified first. + """ + if text is None: + return None + if not isinstance(text, str): + text = str(text) + if len(text) <= max_len: + return text + return text[:max_len] + "..." diff --git a/src/layerlens/instrument/adapters/frameworks/crewai.py b/src/layerlens/instrument/adapters/frameworks/crewai.py new file mode 100644 index 0000000..b922748 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/crewai.py @@ -0,0 +1,475 @@ +from __future__ import annotations + +import logging +from typing import Any, Dict, Optional + +from ._base_framework import FrameworkAdapter +from ._utils import safe_serialize +from ..._capture_config import CaptureConfig + +log = logging.getLogger(__name__) + +try: + from crewai.events import BaseEventListener as _BaseEventListener # pyright: ignore[reportMissingImports] +except (ImportError, TypeError): + _BaseEventListener = None + + +class CrewAIAdapter(FrameworkAdapter): + """CrewAI adapter using the typed event bus API (crewai >= 1.0). + + Subscribes to CrewAI's event bus to capture crew lifecycle, agent + execution, LLM calls, tool usage, flows, and MCP tool events as + flat layerlens events. + + Usage:: + + adapter = CrewAIAdapter(client) + adapter.connect() + crew.kickoff() # events flow automatically via event bus + adapter.disconnect() + """ + + name = "crewai" + + def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: + super().__init__(client, capture_config) + self._registered_handlers: list = [] + + # Span tracking: crew/flow → task → agent → leaf hierarchy + self._crew_span_id: Optional[str] = None + self._task_span_ids: Dict[str, str] = {} # task name → span_id + self._current_task_span_id: Optional[str] = None + self._agent_span_ids: Dict[str, str] = {} # agent_role → span_id + self._current_agent_span_id: Optional[str] = None + # tool.call span IDs keyed by tool_name+id for pairing start/end + self._tool_span_ids: Dict[str, str] = {} + + # Event name → handler method name; resolved to real classes at subscribe time. + _EVENT_MAP = [ + ("CrewKickoffStartedEvent", "_on_crew_started"), + ("CrewKickoffCompletedEvent", "_on_crew_completed"), + ("CrewKickoffFailedEvent", "_on_crew_failed"), + ("TaskStartedEvent", "_on_task_started"), + ("TaskCompletedEvent", "_on_task_completed"), + ("TaskFailedEvent", "_on_task_failed"), + ("AgentExecutionStartedEvent", "_on_agent_execution_started"), + ("AgentExecutionCompletedEvent", "_on_agent_execution_completed"), + ("AgentExecutionErrorEvent", "_on_agent_execution_error"), + ("LLMCallStartedEvent", "_on_llm_started"), + ("LLMCallCompletedEvent", "_on_llm_completed"), + ("LLMCallFailedEvent", "_on_llm_failed"), + ("ToolUsageStartedEvent", "_on_tool_started"), + ("ToolUsageFinishedEvent", "_on_tool_finished"), + ("ToolUsageErrorEvent", "_on_tool_error"), + ("FlowStartedEvent", "_on_flow_started"), + ("FlowFinishedEvent", "_on_flow_finished"), + ("MCPToolExecutionCompletedEvent", "_on_mcp_tool_completed"), + ("MCPToolExecutionFailedEvent", "_on_mcp_tool_failed"), + ] + + def _on_connect(self, target: Any = None, **kwargs: Any) -> None: + self._check_dependency(_BaseEventListener is not None) + self._subscribe() + + def _on_disconnect(self) -> None: + self._unsubscribe() + self._registered_handlers.clear() + self._reset_spans() + + # ------------------------------------------------------------------ + # Event bus wiring + # ------------------------------------------------------------------ + + def _subscribe(self) -> None: + """Register all event handlers on the CrewAI bus.""" + import crewai.events as ev # pyright: ignore[reportMissingImports] + + for event_name, method_name in self._EVENT_MAP: + event_cls = getattr(ev, event_name) + method = getattr(self, method_name) + + def _handler(source: Any, event: Any, _m: Any = method) -> None: + try: + _m(source, event) + except Exception: + log.warning("layerlens: error in CrewAI event handler", exc_info=True) + + ev.crewai_event_bus.on(event_cls)(_handler) + self._registered_handlers.append((event_cls, _handler)) + + def _unsubscribe(self) -> None: + """Remove all previously registered handlers from the CrewAI bus.""" + try: + from crewai.events import crewai_event_bus # pyright: ignore[reportMissingImports] + except ImportError: + return + for event_cls, handler in self._registered_handlers: + try: + crewai_event_bus.off(event_cls, handler) + except Exception: + log.debug("layerlens: could not unregister %s handler", event_cls.__name__, exc_info=True) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @staticmethod + def _get_name(obj: Any) -> str: + return getattr(obj, "name", None) or type(obj).__name__ + + @staticmethod + def _get_task_name(event: Any) -> str: + """Extract task name from a CrewAI event.""" + name = getattr(event, "task_name", None) + if name: + return str(name) + task = getattr(event, "task", None) + if task: + return str(getattr(task, "description", None) or getattr(task, "name", ""))[:200] + return "" + + @staticmethod + def _tool_event_key(event: Any) -> str: + """Build a key to correlate ToolUsageStarted with ToolUsageFinished.""" + tool_name = getattr(event, "tool_name", None) or "" + agent_key = getattr(event, "agent_key", None) or "" + return f"{tool_name}:{agent_key}" + + def _leaf_parent_span_id(self) -> Optional[str]: + """Return the innermost active parent span for leaf events (LLM, tool).""" + with self._lock: + return self._current_agent_span_id or self._current_task_span_id or self._crew_span_id + + def _reset_spans(self) -> None: + """Clear all span tracking state.""" + with self._lock: + self._crew_span_id = None + self._task_span_ids.clear() + self._current_task_span_id = None + self._agent_span_ids.clear() + self._current_agent_span_id = None + self._tool_span_ids.clear() + + def _end_trace(self) -> None: + """Reset spans and flush — called when a crew/flow run completes.""" + self._reset_spans() + self._flush_collector() + + # ------------------------------------------------------------------ + # Crew lifecycle + # ------------------------------------------------------------------ + + def _on_crew_started(self, source: Any, event: Any) -> None: + span_id = self._new_span_id() + with self._lock: + self._crew_span_id = span_id + self._start_timer("crew") + crew_name = getattr(event, "crew_name", None) or self._get_name(source) + payload = self._payload(crew_name=crew_name) + self._set_if_capturing(payload, "input", safe_serialize(getattr(event, "inputs", None))) + self._emit("agent.input", payload, span_id=span_id, parent_span_id=None, span_name=crew_name) + + def _on_crew_completed(self, source: Any, event: Any) -> None: + latency_ms = self._stop_timer("crew") + crew_name = getattr(event, "crew_name", None) or self._get_name(source) + payload = self._payload(crew_name=crew_name) + if latency_ms is not None: + payload["duration_ns"] = int(latency_ms * 1_000_000) + self._set_if_capturing(payload, "output", safe_serialize(getattr(event, "output", None))) + total_tokens = getattr(event, "total_tokens", None) + if total_tokens is not None: + payload["tokens_total"] = total_tokens + self._emit( + "agent.output", payload, + span_id=self._crew_span_id or self._new_span_id(), + parent_span_id=None, span_name=crew_name, + ) + if total_tokens: + self._emit( + "cost.record", + self._payload(tokens_total=total_tokens), + span_id=self._crew_span_id or self._new_span_id(), + parent_span_id=None, + ) + self._end_trace() + + def _on_crew_failed(self, source: Any, event: Any) -> None: + error = str(getattr(event, "error", "unknown error")) + crew_name = getattr(event, "crew_name", None) or self._get_name(source) + self._emit( + "agent.error", + self._payload(crew_name=crew_name, error=error), + span_id=self._crew_span_id or self._new_span_id(), + parent_span_id=None, span_name=crew_name, + ) + self._end_trace() + + # ------------------------------------------------------------------ + # Task lifecycle + # ------------------------------------------------------------------ + + def _on_task_started(self, source: Any, event: Any) -> None: + task_name = self._get_task_name(event) + span_id = self._new_span_id() + with self._lock: + self._task_span_ids[task_name] = span_id + self._current_task_span_id = span_id + parent = self._crew_span_id + agent_role = getattr(event, "agent_role", None) + payload = self._payload(task_name=task_name) + if agent_role: + payload["agent_role"] = agent_role + if self._config.capture_content: + context = getattr(event, "context", None) + if context: + payload["context"] = str(context)[:500] + self._emit( + "agent.input", payload, + span_id=span_id, parent_span_id=parent, + span_name=f"task:{task_name[:60]}", + ) + + def _on_task_completed(self, source: Any, event: Any) -> None: + task_name = self._get_task_name(event) + with self._lock: + span_id = self._task_span_ids.pop(task_name, self._current_task_span_id or self._new_span_id()) + parent = self._crew_span_id + payload = self._payload(task_name=task_name) + self._set_if_capturing(payload, "output", safe_serialize(getattr(event, "output", None))) + self._emit( + "agent.output", payload, + span_id=span_id, parent_span_id=parent, + span_name=f"task:{task_name[:60]}", + ) + + def _on_task_failed(self, source: Any, event: Any) -> None: + task_name = self._get_task_name(event) + with self._lock: + span_id = self._task_span_ids.pop(task_name, self._current_task_span_id or self._new_span_id()) + parent = self._crew_span_id + error = str(getattr(event, "error", "unknown error")) + self._emit( + "agent.error", + self._payload(task_name=task_name, error=error), + span_id=span_id, parent_span_id=parent, + ) + + # ------------------------------------------------------------------ + # Agent execution lifecycle + # ------------------------------------------------------------------ + + def _on_agent_execution_started(self, source: Any, event: Any) -> None: + agent = getattr(event, "agent", None) + agent_role = getattr(event, "agent_role", None) or ( + getattr(agent, "role", None) if agent else None + ) or "unknown" + span_id = self._new_span_id() + with self._lock: + self._agent_span_ids[agent_role] = span_id + self._current_agent_span_id = span_id + parent = self._current_task_span_id or self._crew_span_id + + payload = self._payload(agent_role=agent_role) + tools = getattr(event, "tools", None) + if tools: + payload["tools"] = [getattr(t, "name", str(t)) for t in tools] + if self._config.capture_content: + task_prompt = getattr(event, "task_prompt", None) + if task_prompt: + payload["task_prompt"] = str(task_prompt)[:500] + self._emit( + "agent.input", payload, + span_id=span_id, parent_span_id=parent, + span_name=f"agent:{agent_role[:60]}", + ) + + def _on_agent_execution_completed(self, source: Any, event: Any) -> None: + agent = getattr(event, "agent", None) + agent_role = getattr(event, "agent_role", None) or ( + getattr(agent, "role", None) if agent else None + ) or "unknown" + with self._lock: + span_id = self._agent_span_ids.pop(agent_role, self._current_agent_span_id or self._new_span_id()) + parent = self._current_task_span_id or self._crew_span_id + if self._current_agent_span_id == span_id: + self._current_agent_span_id = None + + payload = self._payload(agent_role=agent_role, status="ok") + self._set_if_capturing(payload, "output", safe_serialize(getattr(event, "output", None))) + self._emit( + "agent.output", payload, + span_id=span_id, parent_span_id=parent, + span_name=f"agent:{agent_role[:60]}", + ) + + def _on_agent_execution_error(self, source: Any, event: Any) -> None: + agent = getattr(event, "agent", None) + agent_role = getattr(event, "agent_role", None) or ( + getattr(agent, "role", None) if agent else None + ) or "unknown" + error = str(getattr(event, "error", "unknown error")) + with self._lock: + span_id = self._agent_span_ids.pop(agent_role, self._current_agent_span_id or self._new_span_id()) + parent = self._current_task_span_id or self._crew_span_id + if self._current_agent_span_id == span_id: + self._current_agent_span_id = None + + self._emit( + "agent.error", + self._payload(agent_role=agent_role, error=error), + span_id=span_id, parent_span_id=parent, + span_name=f"agent:{agent_role[:60]}", + ) + + # ------------------------------------------------------------------ + # LLM calls + # ------------------------------------------------------------------ + + def _on_llm_started(self, source: Any, event: Any) -> None: + call_id = getattr(event, "call_id", None) + if call_id: + self._start_timer(f"llm:{call_id}") + + def _on_llm_completed(self, source: Any, event: Any) -> None: + model = getattr(event, "model", None) + response = getattr(event, "response", None) + # Unwrap .usage from the response before normalizing + usage = getattr(response, "usage", None) if response and not isinstance(response, dict) else ( + response.get("usage") if isinstance(response, dict) else None + ) + tokens = self._normalize_tokens(usage) + payload = self._payload() + if model: + payload["model"] = model + call_id = getattr(event, "call_id", None) + if call_id: + latency_ms = self._stop_timer(f"llm:{call_id}") + if latency_ms is not None: + payload["latency_ms"] = latency_ms + payload.update(tokens) + parent = self._leaf_parent_span_id() + span_id = self._new_span_id() + self._emit("model.invoke", payload, span_id=span_id, parent_span_id=parent) + if tokens: + self._emit( + "cost.record", + self._payload(model=model, **tokens), + span_id=span_id, parent_span_id=parent, + ) + + def _on_llm_failed(self, source: Any, event: Any) -> None: + error = str(getattr(event, "error", "unknown error")) + model = getattr(event, "model", None) + payload = self._payload(error=error) + if model: + payload["model"] = model + parent = self._leaf_parent_span_id() + self._emit("agent.error", payload, parent_span_id=parent) + + # ------------------------------------------------------------------ + # Tool usage — split into tool.call (start) and tool.result (end) + # ------------------------------------------------------------------ + + def _on_tool_started(self, source: Any, event: Any) -> None: + tool_name = getattr(event, "tool_name", None) or "unknown" + span_id = self._new_span_id() + tool_key = self._tool_event_key(event) + with self._lock: + self._tool_span_ids[tool_key] = span_id + payload = self._payload(tool_name=tool_name) + self._set_if_capturing(payload, "input", safe_serialize(getattr(event, "tool_args", None))) + parent = self._leaf_parent_span_id() + self._emit("tool.call", payload, span_id=span_id, parent_span_id=parent) + + def _on_tool_finished(self, source: Any, event: Any) -> None: + tool_name = getattr(event, "tool_name", None) or "unknown" + tool_key = self._tool_event_key(event) + with self._lock: + span_id = self._tool_span_ids.pop(tool_key, None) + if span_id is None: + span_id = self._new_span_id() + + payload = self._payload(tool_name=tool_name) + self._set_if_capturing(payload, "output", safe_serialize(getattr(event, "output", None))) + # Compute latency from started_at/finished_at + started_at = getattr(event, "started_at", None) + finished_at = getattr(event, "finished_at", None) + if started_at is not None and finished_at is not None: + try: + payload["latency_ms"] = (finished_at - started_at).total_seconds() * 1000 + except Exception: + pass + from_cache = getattr(event, "from_cache", None) + if from_cache: + payload["from_cache"] = True + parent = self._leaf_parent_span_id() + self._emit("tool.result", payload, span_id=span_id, parent_span_id=parent) + + def _on_tool_error(self, source: Any, event: Any) -> None: + tool_name = getattr(event, "tool_name", None) or "unknown" + error = str(getattr(event, "error", "unknown error")) + tool_key = self._tool_event_key(event) + with self._lock: + self._tool_span_ids.pop(tool_key, None) + parent = self._leaf_parent_span_id() + self._emit( + "agent.error", + self._payload(tool_name=tool_name, error=error), + parent_span_id=parent, + ) + + # ------------------------------------------------------------------ + # Flow events + # ------------------------------------------------------------------ + + def _on_flow_started(self, source: Any, event: Any) -> None: + span_id = self._new_span_id() + with self._lock: + self._crew_span_id = span_id + self._start_timer("crew") + flow_name = getattr(event, "flow_name", None) or self._get_name(source) + payload = self._payload(flow_name=flow_name) + self._set_if_capturing(payload, "input", safe_serialize(getattr(event, "inputs", None))) + self._emit("agent.input", payload, span_id=span_id, parent_span_id=None, span_name=f"flow:{flow_name}") + + def _on_flow_finished(self, source: Any, event: Any) -> None: + latency_ms = self._stop_timer("crew") + flow_name = getattr(event, "flow_name", None) or self._get_name(source) + payload = self._payload(flow_name=flow_name) + if latency_ms is not None: + payload["duration_ns"] = int(latency_ms * 1_000_000) + self._set_if_capturing(payload, "output", safe_serialize(getattr(event, "result", None))) + self._emit( + "agent.output", payload, + span_id=self._crew_span_id or self._new_span_id(), + parent_span_id=None, span_name=f"flow:{flow_name}", + ) + self._end_trace() + + # ------------------------------------------------------------------ + # MCP tool events + # ------------------------------------------------------------------ + + def _on_mcp_tool_completed(self, source: Any, event: Any) -> None: + tool_name = getattr(event, "tool_name", None) or "unknown" + server_name = getattr(event, "server_name", None) + latency_ms = getattr(event, "execution_duration_ms", None) + payload = self._payload(tool_name=tool_name) + self._set_if_capturing(payload, "output", safe_serialize(getattr(event, "result", None))) + if server_name: + payload["mcp_server"] = server_name + if latency_ms is not None: + payload["latency_ms"] = latency_ms + parent = self._leaf_parent_span_id() + self._emit("tool.call", payload, parent_span_id=parent) + + def _on_mcp_tool_failed(self, source: Any, event: Any) -> None: + tool_name = getattr(event, "tool_name", None) or "unknown" + error = str(getattr(event, "error", "unknown error")) + server_name = getattr(event, "server_name", None) + payload = self._payload(tool_name=tool_name, error=error) + if server_name: + payload["mcp_server"] = server_name + parent = self._leaf_parent_span_id() + self._emit("agent.error", payload, parent_span_id=parent) diff --git a/src/layerlens/instrument/adapters/frameworks/langchain.py b/src/layerlens/instrument/adapters/frameworks/langchain.py index 5b14f0e..a69a7d6 100644 --- a/src/layerlens/instrument/adapters/frameworks/langchain.py +++ b/src/layerlens/instrument/adapters/frameworks/langchain.py @@ -1,12 +1,27 @@ from __future__ import annotations +import functools from uuid import UUID -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence +from typing import Any, Dict, List, Optional, Sequence from ._base_framework import FrameworkAdapter +from ..._capture_config import CaptureConfig + + +def _auto_flush(fn): # type: ignore[type-arg] + """Decorator: after the callback returns, flush if this was the outermost run.""" + @functools.wraps(fn) + def wrapper(self, *args, run_id, **kwargs): # type: ignore[no-untyped-def] + fn(self, *args, run_id=run_id, **kwargs) + run = self._get_run() + if run is not None: + if str(run_id) == run.data.get("root_run_id"): + self._end_run() + elif str(run_id) == self._root_run_id and self._collector is not None: + self._flush_collector() + self._root_run_id = None + return wrapper -if TYPE_CHECKING: - from ..._capture_config import CaptureConfig try: from langchain_core.callbacks import BaseCallbackHandler # pyright: ignore[reportAssignmentType] @@ -26,28 +41,14 @@ class LangChainCallbackHandler(BaseCallbackHandler, FrameworkAdapter): def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: BaseCallbackHandler.__init__(self) FrameworkAdapter.__init__(self, client, capture_config=capture_config) - self._root_run_id: Optional[str] = None + # Pending LLM runs: run_id -> {name, messages, parent_run_id} + self._pending_llm: Dict[str, Dict[str, Any]] = {} + # Context tokens for span propagation: run_id -> token from _push_context + self._run_contexts: Dict[str, Any] = {} - def _emit_for_run( - self, - event_type: str, - payload: Dict[str, Any], - run_id: UUID, - parent_run_id: Optional[UUID] = None, - ) -> None: - """Emit an event, mapping framework run_ids to span_ids.""" - span_id, parent_span_id = self._span_id_for(run_id, parent_run_id) - rid = str(run_id) - if self._root_run_id is None: - self._root_run_id = rid - self._emit(event_type, payload, span_id=span_id, parent_span_id=parent_span_id) - - def _maybe_flush(self, run_id: UUID) -> None: - if str(run_id) == self._root_run_id and self._collector is not None: - self._flush_collector() - self._root_run_id = None - - # -- Chain -- + # ------------------------------------------------------------------ + # Chain callbacks + # ------------------------------------------------------------------ def on_chain_start( self, @@ -58,10 +59,16 @@ def on_chain_start( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: + if parent_run_id is None: + run = self._begin_run() + run.data["root_run_id"] = str(run_id) serialized = serialized or {} name = serialized.get("name") or serialized.get("id", ["unknown"])[-1] - self._emit_for_run("agent.input", {"name": name, "input": inputs}, run_id, parent_run_id) + payload = self._payload(name=name) + self._set_if_capturing(payload, "input", inputs) + self._emit("agent.input", payload, run_id=run_id, parent_run_id=parent_run_id) + @_auto_flush def on_chain_end( self, outputs: Dict[str, Any], @@ -70,9 +77,11 @@ def on_chain_end( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit_for_run("agent.output", {"output": outputs, "status": "ok"}, run_id) - self._maybe_flush(run_id) + payload = self._payload(status="ok") + self._set_if_capturing(payload, "output", outputs) + self._emit("agent.output", payload, run_id=run_id) + @_auto_flush def on_chain_error( self, error: BaseException, @@ -81,10 +90,11 @@ def on_chain_error( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit_for_run("agent.error", {"error": str(error), "status": "error"}, run_id) - self._maybe_flush(run_id) + self._emit("agent.error", self._payload(error=str(error), status="error"), run_id=run_id) - # -- LLM -- + # ------------------------------------------------------------------ + # LLM callbacks — merged into single model.invoke on end + # ------------------------------------------------------------------ def on_llm_start( self, @@ -97,7 +107,15 @@ def on_llm_start( ) -> None: serialized = serialized or {} name = serialized.get("name") or serialized.get("id", ["unknown"])[-1] - self._emit_for_run("model.invoke", {"name": name, "messages": prompts}, run_id, parent_run_id) + self._start_timer(str(run_id)) + pending: Dict[str, Any] = { + "name": name, + "parent_run_id": parent_run_id, + } + self._set_if_capturing(pending, "messages", prompts) + self._pending_llm[str(run_id)] = pending + span_id, _ = self._span_id_for(run_id) + self._run_contexts[str(run_id)] = self._push_context(span_id) def on_chat_model_start( self, @@ -110,13 +128,20 @@ def on_chat_model_start( ) -> None: serialized = serialized or {} name = serialized.get("name") or serialized.get("id", ["unknown"])[-1] - self._emit_for_run( - "model.invoke", - {"name": name, "messages": [[_serialize_lc_message(m) for m in batch] for batch in messages]}, - run_id, - parent_run_id, + self._start_timer(str(run_id)) + pending: Dict[str, Any] = { + "name": name, + "parent_run_id": parent_run_id, + } + self._set_if_capturing( + pending, "messages", + [[_serialize_lc_message(m) for m in batch] for batch in messages], ) + self._pending_llm[str(run_id)] = pending + span_id, _ = self._span_id_for(run_id) + self._run_contexts[str(run_id)] = self._push_context(span_id) + @_auto_flush def on_llm_end( self, response: Any, @@ -125,6 +150,10 @@ def on_llm_end( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: + self._pop_context(self._run_contexts.pop(str(run_id), None)) + pending = self._pending_llm.pop(str(run_id), {}) + + # Extract response data output = None try: generations = response.generations @@ -139,20 +168,40 @@ def on_llm_end( llm_output = {} model_name = llm_output.get("model_name") - if model_name or output: - self._emit_for_run( - "model.invoke", - {"model": model_name, "output_message": output}, - run_id, - parent_run_id, - ) - usage = llm_output.get("token_usage", {}) - if usage: - self._emit_for_run("cost.record", usage, run_id, parent_run_id) + # Build single merged model.invoke event + payload = self._payload() + if pending.get("name"): + payload["name"] = pending["name"] + if model_name: + payload["model"] = model_name + self._set_if_capturing(payload, "messages", pending.get("messages")) + self._set_if_capturing(payload, "output_message", output) + + # Latency + latency_ms = self._stop_timer(str(run_id)) + if latency_ms is not None: + payload["latency_ms"] = latency_ms + + # Tokens + usage = llm_output.get("token_usage") or llm_output.get("usage_metadata") + tokens = self._normalize_tokens(usage) + payload.update(tokens) + + self._emit( + "model.invoke", payload, + run_id=run_id, parent_run_id=pending.get("parent_run_id"), + ) - self._maybe_flush(run_id) + # Separate cost.record if we have token data + if tokens: + cost_payload = self._payload() + if model_name: + cost_payload["model"] = model_name + cost_payload.update(tokens) + self._emit("cost.record", cost_payload, run_id=run_id, parent_run_id=pending.get("parent_run_id")) + @_auto_flush def on_llm_error( self, error: BaseException, @@ -161,10 +210,22 @@ def on_llm_error( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit_for_run("agent.error", {"error": str(error), "status": "error"}, run_id) - self._maybe_flush(run_id) + self._pop_context(self._run_contexts.pop(str(run_id), None)) + pending = self._pending_llm.pop(str(run_id), {}) - # -- Tool -- + payload = self._payload(error=str(error)) + if pending.get("name"): + payload["name"] = pending["name"] + latency_ms = self._stop_timer(str(run_id)) + if latency_ms is not None: + payload["latency_ms"] = latency_ms + self._emit("model.invoke", payload, run_id=run_id, parent_run_id=pending.get("parent_run_id")) + + self._emit("agent.error", self._payload(error=str(error), status="error"), run_id=run_id) + + # ------------------------------------------------------------------ + # Tool callbacks + # ------------------------------------------------------------------ def on_tool_start( self, @@ -176,8 +237,11 @@ def on_tool_start( **kwargs: Any, ) -> None: name = (serialized or {}).get("name", "tool") - self._emit_for_run("tool.call", {"name": name, "input": input_str}, run_id, parent_run_id) + payload = self._payload(name=name) + self._set_if_capturing(payload, "input", input_str) + self._emit("tool.call", payload, run_id=run_id, parent_run_id=parent_run_id) + @_auto_flush def on_tool_end( self, output: str, @@ -186,9 +250,11 @@ def on_tool_end( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit_for_run("tool.result", {"output": output}, run_id) - self._maybe_flush(run_id) + payload = self._payload() + self._set_if_capturing(payload, "output", output) + self._emit("tool.result", payload, run_id=run_id) + @_auto_flush def on_tool_error( self, error: BaseException, @@ -197,10 +263,11 @@ def on_tool_error( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit_for_run("agent.error", {"error": str(error), "status": "error"}, run_id) - self._maybe_flush(run_id) + self._emit("agent.error", self._payload(error=str(error), status="error"), run_id=run_id) - # -- Retriever -- + # ------------------------------------------------------------------ + # Retriever callbacks + # ------------------------------------------------------------------ def on_retriever_start( self, @@ -212,8 +279,11 @@ def on_retriever_start( **kwargs: Any, ) -> None: name = (serialized or {}).get("name", "retriever") - self._emit_for_run("tool.call", {"name": name, "input": query}, run_id, parent_run_id) + payload = self._payload(name=name) + self._set_if_capturing(payload, "input", query) + self._emit("tool.call", payload, run_id=run_id, parent_run_id=parent_run_id) + @_auto_flush def on_retriever_end( self, documents: Sequence[Any], @@ -222,10 +292,14 @@ def on_retriever_end( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - output = [_serialize_lc_document(d) for d in documents] - self._emit_for_run("tool.result", {"output": output}, run_id) - self._maybe_flush(run_id) + payload = self._payload() + self._set_if_capturing( + payload, "output", + [_serialize_lc_document(d) for d in documents], + ) + self._emit("tool.result", payload, run_id=run_id) + @_auto_flush def on_retriever_error( self, error: BaseException, @@ -234,10 +308,42 @@ def on_retriever_error( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit_for_run("agent.error", {"error": str(error), "status": "error"}, run_id) - self._maybe_flush(run_id) + self._emit("agent.error", self._payload(error=str(error), status="error"), run_id=run_id) - # -- Text (required by base) -- + # ------------------------------------------------------------------ + # Agent callbacks + # ------------------------------------------------------------------ + + def on_agent_action( + self, + action: Any, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + payload = self._payload(tool=getattr(action, "tool", "unknown")) + self._set_if_capturing(payload, "tool_input", getattr(action, "tool_input", None)) + self._set_if_capturing(payload, "log", getattr(action, "log", None) or None) + self._emit("agent.input", payload, run_id=run_id, parent_run_id=parent_run_id) + + @_auto_flush + def on_agent_finish( + self, + finish: Any, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + payload = self._payload(status="ok") + self._set_if_capturing(payload, "output", getattr(finish, "return_values", None)) + self._set_if_capturing(payload, "log", getattr(finish, "log", None) or None) + self._emit("agent.output", payload, run_id=run_id, parent_run_id=parent_run_id) + + # ------------------------------------------------------------------ + # No-ops (required by base) + # ------------------------------------------------------------------ def on_text(self, text: str, **kwargs: Any) -> None: pass diff --git a/src/layerlens/instrument/adapters/frameworks/langgraph.py b/src/layerlens/instrument/adapters/frameworks/langgraph.py index f4b666a..35de3c4 100644 --- a/src/layerlens/instrument/adapters/frameworks/langgraph.py +++ b/src/layerlens/instrument/adapters/frameworks/langgraph.py @@ -19,6 +19,9 @@ def on_chain_start( tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: + if parent_run_id is None: + run = self._begin_run() + run.data["root_run_id"] = str(run_id) serialized = serialized or {} name = serialized.get("name") or serialized.get("id", ["unknown"])[-1] @@ -38,4 +41,6 @@ def on_chain_start( if node_name: name = node_name - self._emit_for_run("agent.input", {"name": name, "input": inputs}, run_id, parent_run_id) + payload = self._payload(name=name) + self._set_if_capturing(payload, "input", inputs) + self._emit("agent.input", payload, run_id=run_id, parent_run_id=parent_run_id) diff --git a/src/layerlens/instrument/adapters/frameworks/openai_agents.py b/src/layerlens/instrument/adapters/frameworks/openai_agents.py new file mode 100644 index 0000000..e175c34 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/openai_agents.py @@ -0,0 +1,306 @@ +from __future__ import annotations + +import logging +from datetime import datetime +from typing import Any, Dict, Optional + +from ._base_framework import FrameworkAdapter +from ._utils import safe_serialize +from ..._collector import TraceCollector +from ..._capture_config import CaptureConfig + +log = logging.getLogger(__name__) + +_HAS_OPENAI_AGENTS = False +try: + from agents.tracing import TracingProcessor # pyright: ignore[reportMissingImports] + + _HAS_OPENAI_AGENTS = True +except (ImportError, Exception): + TracingProcessor = None # type: ignore[assignment,misc] + +# Real TracingProcessor when installed, plain object otherwise. +_Base: Any = TracingProcessor if _HAS_OPENAI_AGENTS else object + + +class OpenAIAgentsAdapter(_Base, FrameworkAdapter): + """OpenAI Agents SDK adapter using the TracingProcessor API. + + The adapter *is* the trace processor — it registers itself globally + to receive all span lifecycle events, then maps agent, generation, + function, handoff, and guardrail spans to flat layerlens events. + + Unlike other adapters that use a single collector, this adapter manages + per-trace collectors because the SDK can run multiple concurrent traces + through the same global processor. + + Usage:: + + adapter = OpenAIAgentsAdapter(client) + adapter.connect() + result = await Runner.run(agent, "hello") + adapter.disconnect() + """ + + name = "openai-agents" + package = "openai-agents" + + _SPAN_HANDLERS = { + "agent": "_handle_agent_span", + "generation": "_handle_generation_span", + "function": "_handle_function_span", + "handoff": "_handle_handoff_span", + "guardrail": "_handle_guardrail_span", + "response": "_handle_response_span", + } + + def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: + FrameworkAdapter.__init__(self, client, capture_config) + self._collectors: Dict[str, TraceCollector] = {} + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def _on_connect(self, target: Any = None, **kwargs: Any) -> None: + self._check_dependency(_HAS_OPENAI_AGENTS) + from agents import add_trace_processor # pyright: ignore[reportMissingImports] + + add_trace_processor(self) # type: ignore[arg-type] + + def _on_disconnect(self) -> None: + from agents import set_trace_processors # pyright: ignore[reportMissingImports] + + set_trace_processors([]) + with self._lock: + self._collectors.clear() + + # ------------------------------------------------------------------ + # TracingProcessor interface + # ------------------------------------------------------------------ + + def on_trace_start(self, trace: Any) -> None: + try: + self._get_collector(trace.trace_id) + except Exception: + log.warning("layerlens: error in on_trace_start", exc_info=True) + + def on_trace_end(self, trace: Any) -> None: + try: + with self._lock: + collector = self._collectors.pop(trace.trace_id, None) + if collector is not None: + collector.flush() + except Exception: + log.warning("layerlens: error in on_trace_end", exc_info=True) + + def on_span_start(self, span: Any) -> None: + pass + + def on_span_end(self, span: Any) -> None: + try: + span_type = getattr(span.span_data, "type", None) or "" + handler_name = self._SPAN_HANDLERS.get(span_type) + if handler_name is not None: + getattr(self, handler_name)(span) + except Exception: + log.warning("layerlens: error handling OpenAI Agents span", exc_info=True) + + def shutdown(self) -> None: + pass + + def force_flush(self) -> None: + pass + + # ------------------------------------------------------------------ + # Per-trace collector + # ------------------------------------------------------------------ + + def _get_collector(self, trace_id: str) -> TraceCollector: + with self._lock: + if trace_id not in self._collectors: + self._collectors[trace_id] = TraceCollector(self._client, self._config) + return self._collectors[trace_id] + + # ------------------------------------------------------------------ + # Span handlers + # ------------------------------------------------------------------ + + def _handle_agent_span(self, span: Any) -> None: + data = span.span_data + collector = self._get_collector(span.trace_id) + agent_name = getattr(data, "name", "unknown") + span_id = span.span_id or self._new_span_id() + parent_id = span.parent_id + + input_payload = self._payload(agent_name=agent_name) + for key in ("tools", "handoffs", "output_type"): + val = getattr(data, key, None) + if val: + input_payload[key] = val + + collector.emit( + "agent.input", input_payload, + span_id=span_id, parent_span_id=parent_id, + span_name=f"agent:{agent_name}", + ) + + event_type = "agent.error" if span.error else "agent.output" + out_payload = self._payload( + agent_name=agent_name, + status="error" if span.error else "ok", + ) + duration_ms = _compute_duration_ms(span) + if duration_ms is not None: + out_payload["duration_ms"] = duration_ms + if span.error: + out_payload["error"] = safe_serialize(span.error) + + collector.emit( + event_type, out_payload, + span_id=span_id, parent_span_id=parent_id, + span_name=f"agent:{agent_name}", + ) + + def _handle_generation_span(self, span: Any) -> None: + data = span.span_data + collector = self._get_collector(span.trace_id) + model = getattr(data, "model", None) or "unknown" + span_id = span.span_id or self._new_span_id() + parent_id = span.parent_id + + payload = self._payload(model=model) + tokens = self._normalize_tokens(getattr(data, "usage", None)) + payload.update(tokens) + + duration_ms = _compute_duration_ms(span) + if duration_ms is not None: + payload["latency_ms"] = duration_ms + + model_config = getattr(data, "model_config", None) + if model_config: + payload["model_config"] = safe_serialize(model_config) + + self._set_if_capturing(payload, "messages", safe_serialize(getattr(data, "input", None))) + self._set_if_capturing(payload, "output_message", safe_serialize(getattr(data, "output", None))) + + if span.error: + payload["error"] = safe_serialize(span.error) + collector.emit("agent.error", payload, span_id=span_id, parent_span_id=parent_id) + else: + collector.emit("model.invoke", payload, span_id=span_id, parent_span_id=parent_id) + + if tokens: + cost_payload = self._payload(model=model) + cost_payload.update(tokens) + collector.emit("cost.record", cost_payload, span_id=span_id, parent_span_id=parent_id) + + def _handle_function_span(self, span: Any) -> None: + data = span.span_data + collector = self._get_collector(span.trace_id) + tool_name = getattr(data, "name", "unknown") + span_id = span.span_id or self._new_span_id() + parent_id = span.parent_id + + payload = self._payload(tool_name=tool_name) + self._set_if_capturing(payload, "input", safe_serialize(getattr(data, "input", None))) + self._set_if_capturing(payload, "output", safe_serialize(getattr(data, "output", None))) + + duration_ms = _compute_duration_ms(span) + if duration_ms is not None: + payload["latency_ms"] = duration_ms + + mcp_data = getattr(data, "mcp_data", None) + if mcp_data: + payload["mcp_data"] = safe_serialize(mcp_data) + + if span.error: + payload["error"] = safe_serialize(span.error) + collector.emit("agent.error", payload, span_id=span_id, parent_span_id=parent_id) + else: + collector.emit("tool.call", payload, span_id=span_id, parent_span_id=parent_id) + + def _handle_handoff_span(self, span: Any) -> None: + data = span.span_data + self._get_collector(span.trace_id).emit( + "agent.handoff", + self._payload( + from_agent=getattr(data, "from_agent", None) or "unknown", + to_agent=getattr(data, "to_agent", None) or "unknown", + ), + span_id=span.span_id or self._new_span_id(), + parent_span_id=span.parent_id, + ) + + def _handle_guardrail_span(self, span: Any) -> None: + data = span.span_data + self._get_collector(span.trace_id).emit( + "evaluation.result", + self._payload( + guardrail_name=getattr(data, "name", "unknown"), + triggered=getattr(data, "triggered", False), + ), + span_id=span.span_id or self._new_span_id(), + parent_span_id=span.parent_id, + ) + + def _handle_response_span(self, span: Any) -> None: + data = span.span_data + response = getattr(data, "response", None) + if response is None: + return + + collector = self._get_collector(span.trace_id) + span_id = span.span_id or self._new_span_id() + parent_id = span.parent_id + payload = self._payload() + + model = getattr(response, "model", None) + if model: + payload["model"] = model + + usage = getattr(response, "usage", None) + tokens = self._normalize_tokens(usage) + # OpenAI-specific detailed token breakdowns + if usage is not None: + input_details = getattr(usage, "input_tokens_details", None) + if input_details: + cached = getattr(input_details, "cached_tokens", 0) or 0 + if cached: + tokens["cached_tokens"] = cached + output_details = getattr(usage, "output_tokens_details", None) + if output_details: + reasoning = getattr(output_details, "reasoning_tokens", 0) or 0 + if reasoning: + tokens["reasoning_tokens"] = reasoning + payload.update(tokens) + + duration_ms = _compute_duration_ms(span) + if duration_ms is not None: + payload["latency_ms"] = duration_ms + + if span.error: + payload["error"] = safe_serialize(span.error) + collector.emit("agent.error", payload, span_id=span_id, parent_span_id=parent_id) + else: + collector.emit("model.invoke", payload, span_id=span_id, parent_span_id=parent_id) + + +# ------------------------------------------------------------------ +# Module-level helpers +# ------------------------------------------------------------------ + + +def _compute_duration_ms(span: Any) -> Optional[float]: + started = getattr(span, "started_at", None) + ended = getattr(span, "ended_at", None) + if started is None or ended is None: + return None + try: + if isinstance(started, str): + started = datetime.fromisoformat(started) + if isinstance(ended, str): + ended = datetime.fromisoformat(ended) + return (ended - started).total_seconds() * 1000 + except Exception: + return None diff --git a/src/layerlens/instrument/adapters/frameworks/pydantic_ai.py b/src/layerlens/instrument/adapters/frameworks/pydantic_ai.py new file mode 100644 index 0000000..b5ae173 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/pydantic_ai.py @@ -0,0 +1,350 @@ +from __future__ import annotations + +import logging +from typing import Any, Dict, Optional + +from ._base_framework import FrameworkAdapter +from ._utils import safe_serialize +from ..._capture_config import CaptureConfig + +log = logging.getLogger(__name__) + +try: + from pydantic_ai import Agent as _AgentCheck # pyright: ignore[reportMissingImports] # noqa: F401 + + _HAS_PYDANTIC_AI = True + del _AgentCheck +except ImportError: + _HAS_PYDANTIC_AI = False + + +class PydanticAIAdapter(FrameworkAdapter): + """PydanticAI adapter using the native Hooks capability API. + + Injects a ``Hooks`` capability into the target agent to receive + real-time lifecycle callbacks for run start/end, per-model-call, + and per-tool-execution events — with precise per-step timing. + + Concurrent runs on the same agent are safe: each run gets its own + RunState via ContextVar, so collectors, timers, and tool spans + are fully isolated per ``asyncio.Task``. + + Usage:: + + adapter = PydanticAIAdapter(client) + adapter.connect(target=agent) # injects hooks capability + result = agent.run_sync("hello") + adapter.disconnect() # removes hooks capability + """ + + name = "pydantic-ai" + + def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: + super().__init__(client, capture_config) + self._target: Any = None + self._hooks: Any = None + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def _on_connect(self, target: Any = None, **kwargs: Any) -> None: + self._check_dependency(_HAS_PYDANTIC_AI) + if target is None: + raise ValueError("PydanticAIAdapter requires a target agent: adapter.connect(target=agent)") + + from pydantic_ai.capabilities.hooks import Hooks # pyright: ignore[reportMissingImports] + + self._target = target + self._hooks = Hooks() + self._register_hooks(self._hooks) + target._root_capability.capabilities.append(self._hooks) + + def _on_disconnect(self) -> None: + if self._target is not None and self._hooks is not None: + try: + caps = self._target._root_capability.capabilities + if self._hooks in caps: + caps.remove(self._hooks) + except Exception: + log.warning("Could not remove PydanticAI hooks capability") + self._hooks = None + self._target = None + + # ------------------------------------------------------------------ + # Hook registration + # ------------------------------------------------------------------ + + def _register_hooks(self, hooks: Any) -> None: + hooks.on.before_run(self._on_before_run) + hooks.on.after_run(self._on_after_run) + hooks.on.run_error(self._on_run_error) + hooks.on.after_model_request(self._on_after_model_request) + hooks.on.model_request_error(self._on_model_request_error) + hooks.on.before_tool_execute(self._on_before_tool_execute) + hooks.on.after_tool_execute(self._on_after_tool_execute) + hooks.on.tool_execute_error(self._on_tool_execute_error) + + # ------------------------------------------------------------------ + # Run lifecycle hooks + # ------------------------------------------------------------------ + + def _on_before_run(self, ctx: Any) -> None: + run = self._begin_run() + agent_name = self._get_agent_name(ctx) + model_name = self._get_model_name(ctx) + + payload = self._payload(agent_name=agent_name) + if model_name: + payload["model"] = model_name + self._set_if_capturing(payload, "input", safe_serialize(ctx.prompt)) + + run.collector.emit( + "agent.input", payload, + span_id=run.root_span_id, parent_span_id=None, + span_name=f"pydantic_ai:{agent_name}", + ) + self._start_timer("run") + + def _on_after_run(self, ctx: Any, *, result: Any) -> Any: + latency_ms = self._stop_timer("run") + agent_name = self._get_agent_name(ctx) + model_name = self._get_model_name(ctx) + root_span = self._get_root_span() + collector = self._ensure_collector() + + output = self._extract_output(result) + usage = self._extract_usage(result) + + payload = self._payload(agent_name=agent_name, status="ok") + if model_name: + payload["model"] = model_name + if latency_ms is not None: + payload["latency_ms"] = latency_ms + self._set_if_capturing(payload, "output", output) + payload.update(usage) + collector.emit( + "agent.output", payload, + span_id=root_span, parent_span_id=None, + span_name=f"pydantic_ai:{agent_name}", + ) + + if usage: + cost_payload = self._payload() + if model_name: + cost_payload["model"] = model_name + cost_payload.update(usage) + collector.emit( + "cost.record", cost_payload, + span_id=self._new_span_id(), parent_span_id=root_span, + ) + + self._end_run() + return result + + def _on_run_error(self, ctx: Any, *, error: BaseException) -> None: + latency_ms = self._stop_timer("run") + agent_name = self._get_agent_name(ctx) + root_span = self._get_root_span() + collector = self._ensure_collector() + + payload = self._payload( + agent_name=agent_name, + error=str(error), + error_type=type(error).__name__, + ) + if latency_ms is not None: + payload["latency_ms"] = latency_ms + collector.emit( + "agent.error", payload, + span_id=root_span, parent_span_id=None, + span_name=f"pydantic_ai:{agent_name}", + ) + + self._end_run() + raise error + + # ------------------------------------------------------------------ + # Model request hooks + # ------------------------------------------------------------------ + + def _on_after_model_request( + self, ctx: Any, *, request_context: Any, response: Any, + ) -> Any: + root_span = self._get_root_span() + collector = self._ensure_collector() + + model_name = getattr(response, "model_name", None) + usage = getattr(response, "usage", None) + tokens = self._normalize_tokens(usage) + + payload = self._payload() + if model_name: + payload["model"] = model_name + payload.update(tokens) + + model_span = self._new_span_id() + collector.emit( + "model.invoke", payload, + span_id=model_span, parent_span_id=root_span, + ) + + parts = getattr(response, "parts", None) or [] + for part in parts: + if type(part).__name__ == "ToolCallPart": + tool_name = getattr(part, "tool_name", "unknown") + tool_payload = self._payload(tool_name=tool_name) + self._set_if_capturing( + tool_payload, "input", + safe_serialize(getattr(part, "args", None)), + ) + collector.emit( + "tool.call", tool_payload, + span_id=self._new_span_id(), parent_span_id=root_span, + ) + + return response + + def _on_model_request_error( + self, ctx: Any, *, request_context: Any, error: Exception, + ) -> None: + root_span = self._get_root_span() + collector = self._ensure_collector() + + payload = self._payload( + error=str(error), + error_type=type(error).__name__, + ) + collector.emit( + "agent.error", payload, + span_id=self._new_span_id(), parent_span_id=root_span, + ) + raise error + + # ------------------------------------------------------------------ + # Tool execution hooks + # ------------------------------------------------------------------ + + def _on_before_tool_execute( + self, ctx: Any, *, call: Any, tool_def: Any, args: Any, + ) -> Any: + tool_name = getattr(call, "tool_name", "unknown") + span_id = self._new_span_id() + run = self._get_run() + if run is not None: + run.data.setdefault("tool_spans", {})[tool_name] = span_id + self._start_timer(f"tool:{tool_name}") + return args + + def _on_after_tool_execute( + self, ctx: Any, *, call: Any, tool_def: Any, args: Any, result: Any, + ) -> Any: + tool_name = getattr(call, "tool_name", "unknown") + latency_ms = self._stop_timer(f"tool:{tool_name}") + + run = self._get_run() + tool_spans = run.data.get("tool_spans", {}) if run is not None else {} + span_id = tool_spans.pop(tool_name, self._new_span_id()) + + root_span = self._get_root_span() + collector = self._ensure_collector() + + payload = self._payload(tool_name=tool_name) + self._set_if_capturing(payload, "output", safe_serialize(result)) + if latency_ms is not None: + payload["latency_ms"] = latency_ms + collector.emit( + "tool.result", payload, + span_id=span_id, parent_span_id=root_span, + ) + return result + + def _on_tool_execute_error( + self, ctx: Any, *, call: Any, tool_def: Any, args: Any, error: Exception, + ) -> None: + tool_name = getattr(call, "tool_name", "unknown") + self._stop_timer(f"tool:{tool_name}") + + run = self._get_run() + if run is not None: + run.data.get("tool_spans", {}).pop(tool_name, None) + + root_span = self._get_root_span() + collector = self._ensure_collector() + + payload = self._payload( + tool_name=tool_name, + error=str(error), + error_type=type(error).__name__, + ) + collector.emit( + "agent.error", payload, + span_id=self._new_span_id(), parent_span_id=root_span, + ) + raise error + + # ------------------------------------------------------------------ + # Static helpers + # ------------------------------------------------------------------ + + @staticmethod + def _get_agent_name(ctx: Any) -> str: + agent = getattr(ctx, "agent", None) + if agent is not None: + name = getattr(agent, "name", None) + if name: + return str(name) + return PydanticAIAdapter._get_model_name(ctx) or "pydantic_ai_agent" + + @staticmethod + def _get_model_name(ctx: Any) -> Optional[str]: + model = getattr(ctx, "model", None) + if model is None: + agent = getattr(ctx, "agent", None) + model = getattr(agent, "model", None) if agent else None + if model is None: + return None + if isinstance(model, str): + return model + name = getattr(model, "model_name", None) + if name: + return str(name) + return str(model) + + @staticmethod + def _extract_output(result: Any) -> Any: + if result is None: + return None + output = getattr(result, "output", None) + if output is not None: + return safe_serialize(output) + return None + + @staticmethod + def _extract_usage(result: Any) -> Dict[str, Any]: + tokens: Dict[str, Any] = {} + usage = getattr(result, "usage", None) + if usage is None: + return tokens + + if callable(usage): + try: + usage = usage() + except Exception: + return tokens + + input_t = getattr(usage, "input_tokens", 0) or 0 + output_t = getattr(usage, "output_tokens", 0) or 0 + + if input_t: + tokens["tokens_prompt"] = input_t + if output_t: + tokens["tokens_completion"] = output_t + if input_t or output_t: + tokens["tokens_total"] = input_t + output_t + + requests = getattr(usage, "requests", 0) or 0 + if requests: + tokens["model_requests"] = requests + + return tokens diff --git a/src/layerlens/instrument/adapters/frameworks/semantic_kernel.py b/src/layerlens/instrument/adapters/frameworks/semantic_kernel.py new file mode 100644 index 0000000..f02fecd --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/semantic_kernel.py @@ -0,0 +1,389 @@ +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Optional + +from ._base_framework import FrameworkAdapter +from ._utils import safe_serialize, truncate +from ..._capture_config import CaptureConfig + +log = logging.getLogger(__name__) + +try: + import semantic_kernel as _sk # pyright: ignore[reportMissingImports] # noqa: F401 + + _HAS_SEMANTIC_KERNEL = True +except ImportError: + _HAS_SEMANTIC_KERNEL = False + + +class SemanticKernelAdapter(FrameworkAdapter): + """Semantic Kernel adapter using the SK filter API (semantic-kernel >= 1.0). + + Registers function invocation, prompt rendering, and auto-function + invocation filters on a Kernel instance to capture plugin calls, + prompt templates, and LLM-initiated function calls as flat events. + + Usage:: + + adapter = SemanticKernelAdapter(client) + adapter.connect(target=kernel) + result = await kernel.invoke(my_function, arg1=val1) + adapter.disconnect() + """ + + name = "semantic_kernel" + package = "semantic-kernel" + + def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: + super().__init__(client, capture_config) + self._kernel: Any = None + self._filter_ids: List[tuple] = [] # (FilterTypes, filter_id) for removal + self._seen_plugins: set = set() + self._patched_services: Dict[str, Any] = {} # service_id → original method + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def _on_connect(self, target: Any = None, **kwargs: Any) -> None: + self._check_dependency(_HAS_SEMANTIC_KERNEL) + if target is None: + raise ValueError("SemanticKernelAdapter requires a target kernel: adapter.connect(target=kernel)") + + from semantic_kernel.filters.filter_types import FilterTypes # pyright: ignore[reportMissingImports] + + self._kernel = target + + filters = [ + (FilterTypes.FUNCTION_INVOCATION, self._function_invocation_filter), + (FilterTypes.PROMPT_RENDERING, self._prompt_rendering_filter), + (FilterTypes.AUTO_FUNCTION_INVOCATION, self._auto_function_invocation_filter), + ] + for filter_type, handler in filters: + target.add_filter(filter_type, handler) + filter_list = _get_filter_list(target, filter_type) + if filter_list: + self._filter_ids.append((filter_type, filter_list[-1][0])) + + # Wrap LLM calls on registered chat services + self._patch_chat_services(target) + + # Discover existing plugins + self._discover_plugins(target) + + def _on_disconnect(self) -> None: + if self._kernel is not None: + for filter_type, filter_id in self._filter_ids: + try: + self._kernel.remove_filter(filter_type, filter_id=filter_id) + except Exception: + log.debug("layerlens: could not remove SK filter %s/%s", filter_type, filter_id) + self._unpatch_chat_services() + self._filter_ids.clear() + self._seen_plugins.clear() + self._kernel = None + + # ------------------------------------------------------------------ + # LLM call wrapping + # ------------------------------------------------------------------ + + def _patch_chat_services(self, kernel: Any) -> None: + """Wrap _inner_get_chat_message_contents on all registered chat services.""" + services = getattr(kernel, "services", None) + if not services or not isinstance(services, dict): + return + + for service_id, service in services.items(): + if not hasattr(service, "_inner_get_chat_message_contents"): + continue + original = service._inner_get_chat_message_contents + adapter = self + + async def _traced_inner(chat_history: Any, settings: Any, _orig: Any = original, _svc: Any = service) -> Any: + span_id = adapter._new_span_id() + root_span = adapter._get_root_span() + adapter._start_timer(span_id) + collector = adapter._ensure_collector() + + model_name = getattr(_svc, "ai_model_id", None) + + try: + result = await _orig(chat_history, settings) + except Exception as exc: + latency_ms = adapter._stop_timer(span_id) + payload = adapter._payload( + error=str(exc), + error_type=type(exc).__name__, + ) + if model_name: + payload["model"] = model_name + if latency_ms is not None: + payload["latency_ms"] = latency_ms + collector.emit( + "agent.error", payload, + span_id=span_id, parent_span_id=root_span, + ) + raise + + latency_ms = adapter._stop_timer(span_id) + tokens = adapter._extract_usage_from_response(result) + + payload = adapter._payload() + if model_name: + payload["model"] = model_name + if latency_ms is not None: + payload["latency_ms"] = latency_ms + payload.update(tokens) + collector.emit( + "model.invoke", payload, + span_id=span_id, parent_span_id=root_span, + ) + + if tokens: + cost_payload = adapter._payload() + if model_name: + cost_payload["model"] = model_name + cost_payload.update(tokens) + collector.emit( + "cost.record", cost_payload, + span_id=span_id, parent_span_id=root_span, + ) + + return result + + service._inner_get_chat_message_contents = _traced_inner + self._patched_services[service_id] = original + + def _unpatch_chat_services(self) -> None: + """Restore original _inner_get_chat_message_contents on all patched services.""" + if self._kernel is not None: + services = getattr(self._kernel, "services", {}) + for service_id, original in self._patched_services.items(): + service = services.get(service_id) + if service is not None: + try: + service._inner_get_chat_message_contents = original + except Exception: + log.debug("layerlens: could not restore SK chat service %s", service_id) + self._patched_services.clear() + + def _extract_usage_from_response(self, result: Any) -> Dict[str, Any]: + """Extract token usage from ChatMessageContent list returned by _inner_get_chat_message_contents.""" + if not result: + return {} + msg = result[0] if isinstance(result, list) else result + metadata = getattr(msg, "metadata", None) + if not metadata or not isinstance(metadata, dict): + return {} + return self._normalize_tokens(metadata.get("usage")) + + # ------------------------------------------------------------------ + # Plugin discovery + # ------------------------------------------------------------------ + + def _discover_plugins(self, kernel: Any) -> None: + try: + plugins = getattr(kernel, "plugins", None) + if plugins is None: + return + names = list(plugins.keys()) if hasattr(plugins, "keys") else [str(p) for p in plugins] + collector = self._ensure_collector() + for name in names: + if name not in self._seen_plugins: + self._seen_plugins.add(name) + collector.emit( + "environment.config", + self._payload(plugin_name=name, event_subtype="plugin_registered"), + span_id=self._new_span_id(), + parent_span_id=self._get_root_span(), + ) + except Exception: + log.debug("layerlens: error discovering SK plugins", exc_info=True) + + def _maybe_discover_plugin(self, plugin_name: str) -> None: + if not plugin_name or plugin_name in self._seen_plugins: + return + with self._lock: + if plugin_name in self._seen_plugins: + return + self._seen_plugins.add(plugin_name) + collector = self._ensure_collector() + collector.emit( + "environment.config", + self._payload(plugin_name=plugin_name, event_subtype="plugin_registered"), + span_id=self._new_span_id(), + parent_span_id=self._get_root_span(), + ) + + # ------------------------------------------------------------------ + # Shared filter logic + # ------------------------------------------------------------------ + + async def _wrap_invocation( + self, + context: Any, + next: Any, + *, + auto_invoked: bool = False, + ) -> None: + """Shared wrap-and-emit logic for function and auto-function filters. + + Emits tool.call on start, tool.result on success (or agent.error on failure), + with timing. The ``auto_invoked`` flag adds LLM-specific metadata. + """ + plugin_name = _extract_plugin_name(context) + function_name = _extract_function_name(context) + tool_name = f"{plugin_name}.{function_name}" if plugin_name else function_name + + self._maybe_discover_plugin(plugin_name) + + span_id = self._new_span_id() + root_span = self._get_root_span() + self._start_timer(span_id) + collector = self._ensure_collector() + + # -- Emit tool.call (start) -- + call_payload = self._payload( + tool_name=tool_name, + plugin_name=plugin_name, + function_name=function_name, + ) + if auto_invoked: + call_payload["auto_invoked"] = True + call_payload["request_sequence_index"] = getattr(context, "request_sequence_index", 0) + call_payload["function_sequence_index"] = getattr(context, "function_sequence_index", 0) + # Auto-invoked: args come from the LLM's function_call_content + call_content = getattr(context, "function_call_content", None) + if call_content: + self._set_if_capturing( + call_payload, "input", + safe_serialize(getattr(call_content, "arguments", None)), + ) + else: + # User-invoked: args come from context.arguments + self._set_if_capturing( + call_payload, "input", + safe_serialize(_extract_arguments(context)), + ) + + collector.emit( + "tool.call", call_payload, + span_id=span_id, parent_span_id=root_span, + span_name=f"sk:{tool_name}", + ) + + # -- Execute -- + error = None + try: + await next(context) + except Exception as exc: + error = exc + raise + finally: + latency_ms = self._stop_timer(span_id) + + if error: + err_payload = self._payload( + tool_name=tool_name, + error=str(error), + error_type=type(error).__name__, + ) + if auto_invoked: + err_payload["auto_invoked"] = True + if latency_ms is not None: + err_payload["latency_ms"] = latency_ms + collector.emit( + "agent.error", err_payload, + span_id=span_id, parent_span_id=root_span, + ) + else: + # Extract result from the appropriate field + if auto_invoked: + func_result = getattr(context, "function_result", None) + else: + func_result = getattr(context, "result", None) + result_value = getattr(func_result, "value", None) if func_result else None + + result_payload = self._payload( + tool_name=tool_name, + status="ok", + ) + if auto_invoked: + result_payload["auto_invoked"] = True + if latency_ms is not None: + result_payload["latency_ms"] = latency_ms + self._set_if_capturing(result_payload, "output", safe_serialize(result_value)) + collector.emit( + "tool.result", result_payload, + span_id=span_id, parent_span_id=root_span, + span_name=f"sk:{tool_name}", + ) + + # ------------------------------------------------------------------ + # Filters + # ------------------------------------------------------------------ + + async def _function_invocation_filter(self, context: Any, next: Any) -> None: + await self._wrap_invocation(context, next, auto_invoked=False) + + async def _prompt_rendering_filter(self, context: Any, next: Any) -> None: + await next(context) + + function_name = _extract_function_name(context) + rendered = getattr(context, "rendered_prompt", None) + + payload = self._payload(event_subtype="prompt_render") + if function_name: + payload["function_name"] = function_name + if rendered and self._config.capture_content: + payload["rendered_prompt"] = truncate(str(rendered), 2000) + + collector = self._ensure_collector() + collector.emit( + "agent.code", payload, + span_id=self._new_span_id(), parent_span_id=self._get_root_span(), + ) + + async def _auto_function_invocation_filter(self, context: Any, next: Any) -> None: + await self._wrap_invocation(context, next, auto_invoked=True) + + +# ------------------------------------------------------------------ +# Module-level helpers +# ------------------------------------------------------------------ + + +def _get_filter_list(kernel: Any, filter_type: Any) -> list: + name = filter_type.value if hasattr(filter_type, "value") else str(filter_type) + attr_map = { + "function_invocation": "function_invocation_filters", + "prompt_rendering": "prompt_rendering_filters", + "auto_function_invocation": "auto_function_invocation_filters", + } + return getattr(kernel, attr_map.get(name, ""), []) + + +def _extract_plugin_name(context: Any) -> str: + fn = getattr(context, "function", None) + if fn is not None: + return getattr(fn, "plugin_name", "") or "" + return getattr(context, "plugin_name", "") or "" + + +def _extract_function_name(context: Any) -> str: + fn = getattr(context, "function", None) + if fn is not None: + return getattr(fn, "name", "") or "" + return getattr(context, "function_name", "") or "" + + +def _extract_arguments(context: Any) -> Optional[Dict[str, Any]]: + args = getattr(context, "arguments", None) + if args is None: + return None + if isinstance(args, dict): + return args + if hasattr(args, "items"): + return dict(args.items()) + return None diff --git a/tests/instrument/adapters/frameworks/test_crewai.py b/tests/instrument/adapters/frameworks/test_crewai.py new file mode 100644 index 0000000..3b914a5 --- /dev/null +++ b/tests/instrument/adapters/frameworks/test_crewai.py @@ -0,0 +1,808 @@ +"""Tests for CrewAI adapter using real CrewAI event bus. + +These tests exercise the real crewai.events module — no mocking of CrewAI +internals. Events are constructed and emitted on the real event bus, and +we verify the correct layerlens events come out. + +Requires crewai >= 1.0.0 (Python >= 3.10). +""" + +from __future__ import annotations + +import datetime + +import pytest + +from .conftest import capture_framework_trace, find_event, find_events + +# Skip entire module if crewai is not importable (Python < 3.10 or not installed). +# crewai uses `type | None` syntax which causes TypeError on Python < 3.10, +# and importorskip only catches ImportError, so we guard explicitly. +import sys +if sys.version_info < (3, 10): + pytest.skip("crewai requires Python >= 3.10", allow_module_level=True) +try: + import crewai # noqa: F401 +except (ImportError, TypeError): + pytest.skip("crewai not installed or incompatible", allow_module_level=True) + +from crewai.events import ( # noqa: E402 + TaskFailedEvent, + TaskStartedEvent, + LLMCallFailedEvent, + TaskCompletedEvent, + ToolUsageErrorEvent, + ToolUsageStartedEvent, + LLMCallCompletedEvent, + CrewKickoffFailedEvent, + ToolUsageFinishedEvent, + CrewKickoffStartedEvent, + CrewKickoffCompletedEvent, + AgentExecutionErrorEvent, + AgentExecutionStartedEvent, + AgentExecutionCompletedEvent, + crewai_event_bus, # noqa: E402 +) +from crewai.tasks.task_output import TaskOutput # noqa: E402 + +from layerlens.instrument.adapters.frameworks.crewai import CrewAIAdapter # noqa: E402 + + +@pytest.fixture +def adapter_and_trace(mock_client): + """Create a connected CrewAI adapter with trace capture.""" + uploaded = capture_framework_trace(mock_client) + adapter = CrewAIAdapter(mock_client) + with crewai_event_bus.scoped_handlers(): + adapter.connect() + yield adapter, uploaded + adapter.disconnect() + + +class TestCrewAIAdapterLifecycle: + def test_connect_sets_connected(self, mock_client): + adapter = CrewAIAdapter(mock_client) + assert not adapter.is_connected + with crewai_event_bus.scoped_handlers(): + adapter.connect() + assert adapter.is_connected + adapter.disconnect() + assert not adapter.is_connected + + def test_adapter_info(self, mock_client): + adapter = CrewAIAdapter(mock_client) + with crewai_event_bus.scoped_handlers(): + adapter.connect() + info = adapter.adapter_info() + assert info.name == "crewai" + assert info.adapter_type == "framework" + assert info.connected is True + adapter.disconnect() + + def test_disconnect_clears_state(self, mock_client): + adapter = CrewAIAdapter(mock_client) + with crewai_event_bus.scoped_handlers(): + adapter.connect() + adapter.disconnect() + assert adapter._collector is None + assert adapter._crew_span_id is None + assert adapter._task_span_ids == {} + + +class TestCrewKickoff: + def test_crew_start_emits_agent_input(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + evt = CrewKickoffStartedEvent(crew_name="Research Crew", inputs={"topic": "AI"}) + adapter._on_crew_started(None, evt) + # Crew completed triggers flush + to = TaskOutput(description="test", raw="done", agent="R") + completed = CrewKickoffCompletedEvent(crew_name="Research Crew", output=to) + adapter._on_crew_completed(None, completed) + + events = uploaded["events"] + agent_in = find_event(events, "agent.input") + assert agent_in["payload"]["crew_name"] == "Research Crew" + assert agent_in["payload"]["input"] == {"topic": "AI"} + assert agent_in["payload"]["framework"] == "crewai" + + def test_crew_completed_emits_agent_output(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + start = CrewKickoffStartedEvent(crew_name="MyCrew", inputs={}) + adapter._on_crew_started(None, start) + + to = TaskOutput(description="test", raw="final answer", agent="R") + completed = CrewKickoffCompletedEvent(crew_name="MyCrew", output=to, total_tokens=500) + adapter._on_crew_completed(None, completed) + + events = uploaded["events"] + agent_out = find_event(events, "agent.output") + assert agent_out["payload"]["crew_name"] == "MyCrew" + assert agent_out["payload"]["duration_ns"] > 0 + assert agent_out["payload"]["tokens_total"] == 500 + + # Should also emit cost.record for total_tokens + cost = find_event(events, "cost.record") + assert cost["payload"]["tokens_total"] == 500 + + def test_crew_failed_emits_agent_error(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + start = CrewKickoffStartedEvent(crew_name="FailCrew", inputs={}) + adapter._on_crew_started(None, start) + + failed = CrewKickoffFailedEvent(crew_name="FailCrew", error="LLM rate limit exceeded") + adapter._on_crew_failed(None, failed) + + events = uploaded["events"] + error = find_event(events, "agent.error") + assert error["payload"]["error"] == "LLM rate limit exceeded" + assert error["payload"]["crew_name"] == "FailCrew" + + def test_crew_lifecycle_flushes_trace(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + start = CrewKickoffStartedEvent(crew_name="FlushCrew", inputs={}) + adapter._on_crew_started(None, start) + + to = TaskOutput(description="t", raw="ok", agent="R") + completed = CrewKickoffCompletedEvent(crew_name="FlushCrew", output=to) + adapter._on_crew_completed(None, completed) + + assert uploaded["trace_id"] is not None + assert len(uploaded["events"]) >= 2 + assert uploaded["attestation"] is not None + # Collector should be reset after flush + assert adapter._collector is None + + +class TestTaskEvents: + def test_task_start_and_complete(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + # Start crew + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + # Task lifecycle + adapter._on_task_started( + None, TaskStartedEvent(context="research context", task_name="Research Task", agent_role="Researcher") + ) + to = TaskOutput(description="Research Task", raw="found it", agent="Researcher") + adapter._on_task_completed(None, TaskCompletedEvent(output=to, task_name="Research Task")) + + # Flush + to2 = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to2)) + + events = uploaded["events"] + # Should have crew agent.input, task agent.input, task agent.output, crew agent.output + agent_inputs = find_events(events, "agent.input") + assert len(agent_inputs) == 2 # crew + task + task_input = [e for e in agent_inputs if e["payload"].get("task_name")] + assert len(task_input) == 1 + assert task_input[0]["payload"]["task_name"] == "Research Task" + assert task_input[0]["payload"]["agent_role"] == "Researcher" + + # Task events should be children of crew span + crew_span_id = agent_inputs[0]["span_id"] + assert task_input[0]["parent_span_id"] == crew_span_id + + def test_task_failed(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + adapter._on_task_started(None, TaskStartedEvent(context="ctx", task_name="Bad Task")) + adapter._on_task_failed(None, TaskFailedEvent(error="task timeout", task_name="Bad Task")) + + adapter._on_crew_failed(None, CrewKickoffFailedEvent(crew_name="C", error="task failed")) + + events = uploaded["events"] + errors = find_events(events, "agent.error") + task_error = [e for e in errors if e["payload"].get("task_name")] + assert len(task_error) == 1 + assert task_error[0]["payload"]["error"] == "task timeout" + + +class TestLLMEvents: + def test_llm_completed_emits_model_invoke(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + # LLM call with token usage in response + response = {"content": "hello", "usage": {"prompt_tokens": 100, "completion_tokens": 50}} + evt = LLMCallCompletedEvent(model="gpt-4o", call_id="call_1", call_type="llm_call", response=response) + adapter._on_llm_completed(None, evt) + + to = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to)) + + events = uploaded["events"] + model_invoke = find_event(events, "model.invoke") + assert model_invoke["payload"]["model"] == "gpt-4o" + assert model_invoke["payload"]["tokens_prompt"] == 100 + assert model_invoke["payload"]["tokens_completion"] == 50 + assert model_invoke["payload"]["tokens_total"] == 150 + + cost = find_event(events, "cost.record") + assert cost["payload"]["tokens_total"] == 150 + + def test_llm_failed_emits_agent_error(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + evt = LLMCallFailedEvent(model="gpt-4o", call_id="call_1", error="rate limit exceeded") + adapter._on_llm_failed(None, evt) + + adapter._on_crew_failed(None, CrewKickoffFailedEvent(crew_name="C", error="llm fail")) + + events = uploaded["events"] + errors = find_events(events, "agent.error") + llm_error = [e for e in errors if e["payload"].get("model")] + assert len(llm_error) == 1 + assert llm_error[0]["payload"]["error"] == "rate limit exceeded" + assert llm_error[0]["payload"]["model"] == "gpt-4o" + + +class TestToolEvents: + def test_tool_started_emits_tool_call(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + started_evt = ToolUsageStartedEvent( + tool_name="web_search", + tool_args="AI safety research", + agent_key="researcher_1", + ) + adapter._on_tool_started(None, started_evt) + + to = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to)) + + events = uploaded["events"] + tool_call = find_event(events, "tool.call") + assert tool_call["payload"]["tool_name"] == "web_search" + assert tool_call["payload"]["input"] == "AI safety research" + + def test_tool_finished_emits_tool_result(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + now = datetime.datetime.now() + later = now + datetime.timedelta(milliseconds=150) + evt = ToolUsageFinishedEvent( + tool_name="web_search", + tool_args="AI safety research", + started_at=now, + finished_at=later, + output="Found 10 results about AI safety", + ) + adapter._on_tool_finished(None, evt) + + to = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to)) + + events = uploaded["events"] + tool_result = find_event(events, "tool.result") + assert tool_result["payload"]["tool_name"] == "web_search" + assert tool_result["payload"]["output"] == "Found 10 results about AI safety" + assert tool_result["payload"]["latency_ms"] == pytest.approx(150, abs=5) + + def test_tool_start_end_share_span_id(self, adapter_and_trace): + """tool.call and tool.result for the same tool use share a span_id.""" + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + started_evt = ToolUsageStartedEvent( + tool_name="calculator", + tool_args="2+2", + agent_key="math_agent_1", + ) + adapter._on_tool_started(None, started_evt) + + now = datetime.datetime.now() + finished_evt = ToolUsageFinishedEvent( + tool_name="calculator", + tool_args="2+2", + agent_key="math_agent_1", + started_at=now, + finished_at=now, + output="4", + ) + adapter._on_tool_finished(None, finished_evt) + + to = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to)) + + events = uploaded["events"] + tool_call = find_event(events, "tool.call") + tool_result = find_event(events, "tool.result") + assert tool_call["span_id"] == tool_result["span_id"] + + def test_tool_from_cache(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + now = datetime.datetime.now() + evt = ToolUsageFinishedEvent( + tool_name="cached_tool", + tool_args="query", + started_at=now, + finished_at=now, + output="cached result", + from_cache=True, + ) + adapter._on_tool_finished(None, evt) + + to = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to)) + + events = uploaded["events"] + tool_result = find_event(events, "tool.result") + assert tool_result["payload"]["from_cache"] is True + + def test_tool_error_emits_agent_error(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + evt = ToolUsageErrorEvent(tool_name="calculator", tool_args="1/0", error="division by zero") + adapter._on_tool_error(None, evt) + + adapter._on_crew_failed(None, CrewKickoffFailedEvent(crew_name="C", error="tool fail")) + + events = uploaded["events"] + errors = find_events(events, "agent.error") + tool_error = [e for e in errors if e["payload"].get("tool_name")] + assert len(tool_error) == 1 + assert tool_error[0]["payload"]["tool_name"] == "calculator" + assert tool_error[0]["payload"]["error"] == "division by zero" + + +class TestFullCrewLifecycle: + """End-to-end test simulating a complete crew run with multiple tasks.""" + + def test_full_crew_with_tasks_and_llm(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + # 1. Crew starts + adapter._on_crew_started( + None, CrewKickoffStartedEvent(crew_name="Analysis Crew", inputs={"topic": "quantum computing"}) + ) + + # 2. Task 1: Research + adapter._on_task_started( + None, TaskStartedEvent(context="research quantum computing", task_name="Research", agent_role="Researcher") + ) + + # 2a. Agent execution starts within task 1 + adapter._on_agent_execution_started( + None, AgentExecutionStartedEvent.model_construct(agent_role="Researcher", task_prompt="Research quantum computing") + ) + + # 3. LLM call within task 1 + response = {"content": "Quantum computing uses qubits...", "usage": {"prompt_tokens": 200, "completion_tokens": 100}} + adapter._on_llm_completed( + None, LLMCallCompletedEvent(model="claude-3-opus", call_id="c1", call_type="llm_call", response=response) + ) + + # 4. Tool use within task 1 (start + finish) + now = datetime.datetime.now() + adapter._on_tool_started( + None, + ToolUsageStartedEvent(tool_name="arxiv_search", tool_args="quantum computing 2024", agent_key="researcher_1"), + ) + adapter._on_tool_finished( + None, + ToolUsageFinishedEvent( + tool_name="arxiv_search", + tool_args="quantum computing 2024", + agent_key="researcher_1", + started_at=now, + finished_at=now, + output="3 papers found", + ), + ) + + # 4a. Agent execution completes + adapter._on_agent_execution_completed( + None, AgentExecutionCompletedEvent.model_construct(agent_role="Researcher", output="Research complete") + ) + + # 5. Task 1 completes + to1 = TaskOutput(description="Research", raw="Research complete", agent="Researcher") + adapter._on_task_completed(None, TaskCompletedEvent(output=to1, task_name="Research")) + + # 6. Task 2: Writing + adapter._on_task_started( + None, TaskStartedEvent(context="write about quantum computing", task_name="Write Report", agent_role="Writer") + ) + + # 6a. Agent execution starts within task 2 + adapter._on_agent_execution_started( + None, AgentExecutionStartedEvent.model_construct(agent_role="Writer", task_prompt="Write the report") + ) + + # 7. Another LLM call + response2 = {"content": "Final report..."} + adapter._on_llm_completed( + None, LLMCallCompletedEvent(model="gpt-4o", call_id="c2", call_type="llm_call", response=response2) + ) + + # 7a. Agent execution completes + adapter._on_agent_execution_completed( + None, AgentExecutionCompletedEvent.model_construct(agent_role="Writer", output="Report written") + ) + + # 8. Task 2 completes + to2 = TaskOutput(description="Write Report", raw="Report written", agent="Writer") + adapter._on_task_completed(None, TaskCompletedEvent(output=to2, task_name="Write Report")) + + # 9. Crew completes + final = TaskOutput(description="final", raw="All done", agent="Writer") + adapter._on_crew_completed( + None, CrewKickoffCompletedEvent(crew_name="Analysis Crew", output=final, total_tokens=1500) + ) + + # Verify full event trace + events = uploaded["events"] + assert uploaded["trace_id"] is not None + + # Count event types + agent_inputs = find_events(events, "agent.input") + agent_outputs = find_events(events, "agent.output") + model_invokes = find_events(events, "model.invoke") + tool_calls = find_events(events, "tool.call") + tool_results = find_events(events, "tool.result") + cost_records = find_events(events, "cost.record") + + # crew + 2 tasks + 2 agent executions = 5 agent.input events + assert len(agent_inputs) == 5 + # crew + 2 tasks + 2 agent executions = 5 agent.output events + assert len(agent_outputs) == 5 + assert len(model_invokes) == 2 # 2 LLM calls + assert len(tool_calls) == 1 # 1 tool.call (started) + assert len(tool_results) == 1 # 1 tool.result (finished) + assert len(cost_records) >= 1 # at least crew total_tokens + + # Verify span hierarchy: tasks are children of crew + crew_span = agent_inputs[0]["span_id"] + task_inputs = [e for e in agent_inputs if e["payload"].get("task_name")] + for task_event in task_inputs: + assert task_event["parent_span_id"] == crew_span + + # Verify all events share the same trace_id + trace_ids = {e["trace_id"] for e in events} + assert len(trace_ids) == 1 + + # Verify sequence ordering + sequence_ids = [e["sequence_id"] for e in events] + assert sequence_ids == sorted(sequence_ids) + + # Verify attestation was built + assert uploaded["attestation"].get("root_hash") is not None + + +class TestEventBusIntegration: + """Test that the adapter actually receives events through the real CrewAI event bus.""" + + def test_events_flow_through_bus(self, mock_client): + uploaded = capture_framework_trace(mock_client) + adapter = CrewAIAdapter(mock_client) + + with crewai_event_bus.scoped_handlers(): + adapter.connect() + + # Emit events on the real bus — adapter should pick them up. + # Flush between events so the async started-handler completes + # before completed-handler triggers _flush() (which resets state). + crewai_event_bus.emit(None, event=CrewKickoffStartedEvent(crew_name="BusCrew", inputs={"x": 1})) + crewai_event_bus.flush(timeout=5.0) + + to = TaskOutput(description="t", raw="bus result", agent="A") + crewai_event_bus.emit(None, event=CrewKickoffCompletedEvent(crew_name="BusCrew", output=to)) + crewai_event_bus.flush(timeout=5.0) + + events = uploaded["events"] + assert len(events) >= 2 + + agent_in = find_event(events, "agent.input") + assert agent_in["payload"]["crew_name"] == "BusCrew" + + agent_out = find_event(events, "agent.output") + assert agent_out["payload"]["crew_name"] == "BusCrew" + + def test_scoped_handlers_cleanup(self, mock_client): + """Verify that scoped_handlers prevents handler leaks between tests.""" + uploaded = capture_framework_trace(mock_client) + adapter = CrewAIAdapter(mock_client) + + with crewai_event_bus.scoped_handlers(): + adapter.connect() + + # Events emitted AFTER scope should NOT be captured + crewai_event_bus.emit(None, event=CrewKickoffStartedEvent(crew_name="Ghost", inputs={})) + crewai_event_bus.flush(timeout=2.0) + + # Nothing should have been captured (no flush happened either) + assert uploaded.get("events") is None or len(uploaded.get("events", [])) == 0 + + +class TestCaptureConfigGating: + """Verify CaptureConfig correctly gates event types.""" + + def test_minimal_config_skips_model_and_tool(self, mock_client): + from layerlens.instrument._capture_config import CaptureConfig + + uploaded = capture_framework_trace(mock_client) + config = CaptureConfig.minimal() # l3_model_metadata=False, l5a_tool_calls=False + adapter = CrewAIAdapter(mock_client, capture_config=config) + + with crewai_event_bus.scoped_handlers(): + adapter.connect() + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + # These should be filtered by CaptureConfig + response = {"content": "hi", "usage": {"prompt_tokens": 10, "completion_tokens": 5}} + adapter._on_llm_completed( + None, LLMCallCompletedEvent(model="gpt-4o", call_id="c1", call_type="llm_call", response=response) + ) + now = datetime.datetime.now() + adapter._on_tool_started( + None, ToolUsageStartedEvent(tool_name="x", tool_args="y", agent_key="a1") + ) + adapter._on_tool_finished( + None, ToolUsageFinishedEvent(tool_name="x", tool_args="y", agent_key="a1", started_at=now, finished_at=now, output="z") + ) + + to = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to)) + + events = uploaded["events"] + # model.invoke should be filtered out + assert len(find_events(events, "model.invoke")) == 0 + # tool.call and tool.result should be filtered out + assert len(find_events(events, "tool.call")) == 0 + assert len(find_events(events, "tool.result")) == 0 + # agent.input and agent.output should still be there (L1 is enabled) + assert len(find_events(events, "agent.input")) >= 1 + assert len(find_events(events, "agent.output")) >= 1 + # cost.record IS always-enabled, so if tokens were extracted it should be there + cost_events = find_events(events, "cost.record") + assert len(cost_events) >= 1 # cost.record bypasses CaptureConfig + + +class TestFlowEvents: + """Test CrewAI Flow lifecycle event handling.""" + + def test_flow_start_and_finish(self, adapter_and_trace): + from crewai.events import FlowStartedEvent, FlowFinishedEvent + + adapter, uploaded = adapter_and_trace + adapter._on_flow_started(None, FlowStartedEvent(flow_name="AnalysisFlow", inputs={"topic": "AI"})) + adapter._on_flow_finished(None, FlowFinishedEvent(flow_name="AnalysisFlow", result="done", state={})) + + events = uploaded["events"] + flow_in = find_event(events, "agent.input") + assert flow_in["payload"]["flow_name"] == "AnalysisFlow" + assert flow_in["payload"]["input"] == {"topic": "AI"} + assert flow_in["span_name"] == "flow:AnalysisFlow" + + flow_out = find_event(events, "agent.output") + assert flow_out["payload"]["flow_name"] == "AnalysisFlow" + assert flow_out["payload"]["duration_ns"] > 0 + + +class TestMCPToolEvents: + """Test MCP tool execution event handling.""" + + def test_mcp_tool_completed(self, adapter_and_trace): + from crewai.events import MCPToolExecutionCompletedEvent + + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + now = datetime.datetime.now() + adapter._on_mcp_tool_completed( + None, + MCPToolExecutionCompletedEvent( + tool_name="read_file", + tool_args={"path": "/etc/hosts"}, + server_name="filesystem", + server_url="stdio://mcp-fs", + transport_type="stdio", + result="127.0.0.1 localhost", + started_at=now, + completed_at=now, + execution_duration_ms=42, + ), + ) + + to = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to)) + + events = uploaded["events"] + tool_call = find_event(events, "tool.call") + assert tool_call["payload"]["tool_name"] == "read_file" + assert tool_call["payload"]["mcp_server"] == "filesystem" + assert tool_call["payload"]["latency_ms"] == 42 + assert tool_call["payload"]["output"] == "127.0.0.1 localhost" + + def test_mcp_tool_failed(self, adapter_and_trace): + from crewai.events import MCPToolExecutionFailedEvent + + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + adapter._on_mcp_tool_failed( + None, + MCPToolExecutionFailedEvent( + tool_name="exec_sql", + tool_args={"query": "DROP TABLE users"}, + server_name="db-server", + server_url="http://localhost:3000", + transport_type="http", + error="permission denied", + ), + ) + + adapter._on_crew_failed(None, CrewKickoffFailedEvent(crew_name="C", error="mcp fail")) + + events = uploaded["events"] + errors = find_events(events, "agent.error") + mcp_error = [e for e in errors if e["payload"].get("mcp_server")] + assert len(mcp_error) == 1 + assert mcp_error[0]["payload"]["tool_name"] == "exec_sql" + assert mcp_error[0]["payload"]["mcp_server"] == "db-server" + + +class TestLLMLatencyTracking: + """Test LLM call latency computation from start→complete events.""" + + def test_latency_computed_from_started_event(self, adapter_and_trace): + from crewai.events import LLMCallStartedEvent + + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + # Start event stores timestamp + adapter._on_llm_started(None, LLMCallStartedEvent( + model="gpt-4o", call_id="latency_test", messages=[], call_type="llm_call", + )) + + # Small delay to get measurable latency + import time + time.sleep(0.01) + + # Complete event computes latency + response = {"content": "hi", "usage": {"prompt_tokens": 5, "completion_tokens": 3}} + adapter._on_llm_completed(None, LLMCallCompletedEvent( + model="gpt-4o", call_id="latency_test", call_type="llm_call", response=response, + )) + + to = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to)) + + events = uploaded["events"] + model_invoke = find_event(events, "model.invoke") + assert "latency_ms" in model_invoke["payload"] + assert model_invoke["payload"]["latency_ms"] >= 5 # at least 5ms from the sleep + + +class TestAgentExecutionLifecycle: + """Test agent execution start/complete/error events.""" + + def test_agent_execution_started(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + adapter._on_task_started(None, TaskStartedEvent(context="ctx", task_name="T", agent_role="Researcher")) + + adapter._on_agent_execution_started( + None, AgentExecutionStartedEvent.model_construct( + agent_role="Researcher", task_prompt="Find AI papers", tools=[] + ) + ) + + to = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to)) + + events = uploaded["events"] + agent_inputs = find_events(events, "agent.input") + # Filter for agent execution events (have agent_role but NOT task_name) + agent_exec = [e for e in agent_inputs if e["payload"].get("agent_role") == "Researcher" and "task_name" not in e["payload"]] + assert len(agent_exec) == 1 + assert agent_exec[0]["payload"]["framework"] == "crewai" + assert agent_exec[0]["payload"]["task_prompt"] == "Find AI papers" + + def test_agent_execution_completed(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + adapter._on_agent_execution_started( + None, AgentExecutionStartedEvent.model_construct(agent_role="Writer") + ) + adapter._on_agent_execution_completed( + None, AgentExecutionCompletedEvent.model_construct(agent_role="Writer", output="Final draft") + ) + + to = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to)) + + events = uploaded["events"] + agent_outputs = find_events(events, "agent.output") + agent_out = [e for e in agent_outputs if e["payload"].get("agent_role") == "Writer"] + assert len(agent_out) == 1 + assert agent_out[0]["payload"]["status"] == "ok" + assert agent_out[0]["payload"]["output"] == "Final draft" + + def test_agent_execution_error(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + + adapter._on_agent_execution_started( + None, AgentExecutionStartedEvent.model_construct(agent_role="Researcher") + ) + adapter._on_agent_execution_error( + None, AgentExecutionErrorEvent.model_construct(agent_role="Researcher", error="agent crashed") + ) + + adapter._on_crew_failed(None, CrewKickoffFailedEvent(crew_name="C", error="agent fail")) + + events = uploaded["events"] + errors = find_events(events, "agent.error") + agent_err = [e for e in errors if e["payload"].get("agent_role") == "Researcher"] + assert len(agent_err) == 1 + assert agent_err[0]["payload"]["error"] == "agent crashed" + + def test_agent_span_hierarchy(self, adapter_and_trace): + """Agent execution events are children of the current task span.""" + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + adapter._on_task_started(None, TaskStartedEvent(context="ctx", task_name="T1", agent_role="R")) + + adapter._on_agent_execution_started( + None, AgentExecutionStartedEvent.model_construct(agent_role="R") + ) + adapter._on_agent_execution_completed( + None, AgentExecutionCompletedEvent.model_construct(agent_role="R", output="done") + ) + + to = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to)) + + events = uploaded["events"] + # Find the task span_id + task_inputs = [e for e in find_events(events, "agent.input") if e["payload"].get("task_name") == "T1"] + assert len(task_inputs) == 1 + task_span = task_inputs[0]["span_id"] + + # Agent execution should be parented to task (filter out task event which also has agent_role) + agent_exec_inputs = [e for e in find_events(events, "agent.input") if e["payload"].get("agent_role") == "R" and "task_name" not in e["payload"]] + assert len(agent_exec_inputs) == 1 + assert agent_exec_inputs[0]["parent_span_id"] == task_span + + def test_llm_parented_to_agent(self, adapter_and_trace): + """LLM events should be children of the current agent execution span.""" + adapter, uploaded = adapter_and_trace + adapter._on_crew_started(None, CrewKickoffStartedEvent(crew_name="C", inputs={})) + adapter._on_task_started(None, TaskStartedEvent(context="ctx", task_name="T1", agent_role="R")) + + adapter._on_agent_execution_started( + None, AgentExecutionStartedEvent.model_construct(agent_role="R") + ) + + response = {"content": "hi", "usage": {"prompt_tokens": 5, "completion_tokens": 3}} + adapter._on_llm_completed( + None, LLMCallCompletedEvent(model="gpt-4o", call_id="c1", call_type="llm_call", response=response) + ) + + adapter._on_agent_execution_completed( + None, AgentExecutionCompletedEvent.model_construct(agent_role="R", output="done") + ) + + to = TaskOutput(description="t", raw="ok", agent="R") + adapter._on_crew_completed(None, CrewKickoffCompletedEvent(crew_name="C", output=to)) + + events = uploaded["events"] + # Find the agent execution span_id (not the task event which also has agent_role) + agent_exec_inputs = [e for e in find_events(events, "agent.input") if e["payload"].get("agent_role") == "R" and "task_name" not in e["payload"]] + assert len(agent_exec_inputs) == 1 + agent_span = agent_exec_inputs[0]["span_id"] + + # LLM event should be parented to agent execution + model_invoke = find_event(events, "model.invoke") + assert model_invoke["parent_span_id"] == agent_span diff --git a/tests/instrument/adapters/frameworks/test_langchain.py b/tests/instrument/adapters/frameworks/test_langchain.py index d2a3057..db82d0d 100644 --- a/tests/instrument/adapters/frameworks/test_langchain.py +++ b/tests/instrument/adapters/frameworks/test_langchain.py @@ -5,6 +5,7 @@ from langchain_core.callbacks import BaseCallbackHandler +from layerlens.instrument._capture_config import CaptureConfig from layerlens.instrument.adapters.frameworks.langchain import LangChainCallbackHandler from .conftest import capture_framework_trace, find_event, find_events @@ -23,14 +24,21 @@ def test_name(self): handler = LangChainCallbackHandler(Mock()) assert handler.name == "langchain" + def test_adapter_info(self): + handler = LangChainCallbackHandler(Mock()) + info = handler.adapter_info() + assert info.name == "langchain" + assert info.adapter_type == "framework" + assert info.connected is False + # --------------------------------------------------------------------------- -# Emit events +# Chain lifecycle # --------------------------------------------------------------------------- -class TestEmitsEvents: - def test_chain_lifecycle(self, mock_client): +class TestChainLifecycle: + def test_chain_emits_input_and_output(self, mock_client): uploaded = capture_framework_trace(mock_client) handler = LangChainCallbackHandler(mock_client) @@ -49,49 +57,99 @@ def test_chain_lifecycle(self, mock_client): agent_output = find_event(events, "agent.output") assert agent_output["payload"]["status"] == "ok" + assert agent_output["payload"]["output"] == {"output": "AI is..."} + + def test_chain_error(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + handler.on_chain_start({"name": "FailChain"}, {"input": "x"}, run_id=chain_id) + handler.on_chain_error(ValueError("broke"), run_id=chain_id) + + error = find_event(uploaded["events"], "agent.error") + assert error["payload"]["error"] == "broke" + assert error["payload"]["status"] == "error" + + +# --------------------------------------------------------------------------- +# LLM lifecycle — single merged model.invoke +# --------------------------------------------------------------------------- + - def test_llm_lifecycle(self, mock_client): +def _make_llm_response( + text: str = "AI is...", + model_name: str = "gpt-4", + prompt_tokens: int = 100, + completion_tokens: int = 50, +) -> Mock: + resp = Mock() + resp.generations = [[Mock(text=text)]] + resp.llm_output = { + "model_name": model_name, + "token_usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + } + return resp + + +class TestLLMLifecycle: + def test_single_model_invoke_with_merged_data(self, mock_client): uploaded = capture_framework_trace(mock_client) handler = LangChainCallbackHandler(mock_client) chain_id = uuid4() llm_id = uuid4() - handler.on_chain_start( - {"name": "Chain"}, {"input": "x"}, run_id=chain_id, - ) + handler.on_chain_start({"name": "Chain"}, {"input": "x"}, run_id=chain_id) handler.on_llm_start( - {"name": "ChatOpenAI", "id": ["ChatOpenAI"]}, + {"name": "ChatOpenAI"}, ["What is AI?"], - run_id=llm_id, - parent_run_id=chain_id, + run_id=llm_id, parent_run_id=chain_id, ) - - llm_response = Mock() - llm_response.generations = [[Mock(text="AI is...")]] - llm_response.llm_output = { - "token_usage": {"total_tokens": 50}, - "model_name": "gpt-4", - } - handler.on_llm_end(llm_response, run_id=llm_id) + handler.on_llm_end(_make_llm_response(), run_id=llm_id) handler.on_chain_end({"output": "AI is..."}, run_id=chain_id) events = uploaded["events"] - model_invokes = find_events(events, "model.invoke") - assert len(model_invokes) >= 1 - # Start event has name and messages - start_invoke = [m for m in model_invokes if m["payload"].get("name") == "ChatOpenAI"] - assert len(start_invoke) == 1 - # End event has model and output - end_invoke = [m for m in model_invokes if m["payload"].get("model") == "gpt-4"] - assert len(end_invoke) == 1 - assert end_invoke[0]["payload"]["output_message"] == "AI is..." + # Single event, not two + assert len(model_invokes) == 1 + + invoke = model_invokes[0] + assert invoke["payload"]["name"] == "ChatOpenAI" + assert invoke["payload"]["model"] == "gpt-4" + assert invoke["payload"]["messages"] == ["What is AI?"] + assert invoke["payload"]["output_message"] == "AI is..." + assert invoke["payload"]["latency_ms"] >= 0 + + def test_normalized_token_fields(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + llm_id = uuid4() + + handler.on_chain_start({"name": "Chain"}, {}, run_id=chain_id) + handler.on_llm_start({"name": "LLM"}, ["p"], run_id=llm_id, parent_run_id=chain_id) + handler.on_llm_end(_make_llm_response(), run_id=llm_id) + handler.on_chain_end({}, run_id=chain_id) + + events = uploaded["events"] + invoke = find_event(events, "model.invoke") + assert invoke["payload"]["tokens_prompt"] == 100 + assert invoke["payload"]["tokens_completion"] == 50 + assert invoke["payload"]["tokens_total"] == 150 cost = find_event(events, "cost.record") - assert cost["payload"]["total_tokens"] == 50 + assert cost["payload"]["tokens_prompt"] == 100 + assert cost["payload"]["tokens_completion"] == 50 + assert cost["payload"]["tokens_total"] == 150 + assert cost["payload"]["model"] == "gpt-4" - def test_chat_model_start(self, mock_client): + def test_chat_model_start_serializes_messages(self, mock_client): uploaded = capture_framework_trace(mock_client) handler = LangChainCallbackHandler(mock_client) @@ -105,8 +163,11 @@ def test_chat_model_start(self, mock_client): handler.on_chat_model_start( {"name": "ChatAnthropic"}, [[msg]], + run_id=chat_id, parent_run_id=chain_id, + ) + handler.on_llm_end( + _make_llm_response(text="Hi!", model_name="claude-3"), run_id=chat_id, - parent_run_id=chain_id, ) handler.on_chain_end({}, run_id=chain_id) @@ -114,6 +175,101 @@ def test_chat_model_start(self, mock_client): invoke = find_event(events, "model.invoke") assert invoke["payload"]["name"] == "ChatAnthropic" assert invoke["payload"]["messages"] == [[{"type": "human", "content": "Hello"}]] + assert invoke["payload"]["output_message"] == "Hi!" + + def test_llm_error_emits_model_invoke_with_error(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + llm_id = uuid4() + + handler.on_chain_start({"name": "Chain"}, {}, run_id=chain_id) + handler.on_llm_start({"name": "LLM"}, ["prompt"], run_id=llm_id, parent_run_id=chain_id) + handler.on_llm_error(RuntimeError("timeout"), run_id=llm_id) + handler.on_chain_end({}, run_id=chain_id) + + events = uploaded["events"] + invoke = find_event(events, "model.invoke") + assert invoke["payload"]["error"] == "timeout" + assert invoke["payload"]["latency_ms"] >= 0 + + error = find_event(events, "agent.error") + assert error["payload"]["error"] == "timeout" + + +# --------------------------------------------------------------------------- +# CaptureConfig content gating +# --------------------------------------------------------------------------- + + +class TestCaptureConfig: + def test_capture_content_false_strips_inputs_and_messages(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client, capture_config=CaptureConfig(capture_content=False)) + + chain_id = uuid4() + llm_id = uuid4() + + handler.on_chain_start({"name": "Chain"}, {"secret": "data"}, run_id=chain_id) + handler.on_llm_start({"name": "LLM"}, ["secret prompt"], run_id=llm_id, parent_run_id=chain_id) + handler.on_llm_end(_make_llm_response(text="secret reply"), run_id=llm_id) + handler.on_chain_end({"output": "secret"}, run_id=chain_id) + + events = uploaded["events"] + + # Chain events should not contain content + agent_input = find_event(events, "agent.input") + assert "input" not in agent_input["payload"] + agent_output = find_event(events, "agent.output") + assert "output" not in agent_output["payload"] + + # Model invoke should not contain messages or output + invoke = find_event(events, "model.invoke") + assert "messages" not in invoke["payload"] + assert "output_message" not in invoke["payload"] + # But structural fields are still present + assert invoke["payload"]["name"] == "LLM" + + def test_capture_content_false_strips_tool_io(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client, capture_config=CaptureConfig(capture_content=False)) + + chain_id = uuid4() + tool_id = uuid4() + + handler.on_chain_start({"name": "Agent"}, {}, run_id=chain_id) + handler.on_tool_start({"name": "search"}, "secret query", run_id=tool_id, parent_run_id=chain_id) + handler.on_tool_end("secret results", run_id=tool_id) + handler.on_chain_end({}, run_id=chain_id) + + events = uploaded["events"] + tool_call = find_event(events, "tool.call") + assert "input" not in tool_call["payload"] + assert tool_call["payload"]["name"] == "search" + + tool_result = find_event(events, "tool.result") + assert "output" not in tool_result["payload"] + + def test_capture_content_false_strips_retriever_io(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client, capture_config=CaptureConfig(capture_content=False)) + + chain_id = uuid4() + ret_id = uuid4() + + handler.on_chain_start({"name": "Agent"}, {}, run_id=chain_id) + handler.on_retriever_start({"name": "vs"}, "secret query", run_id=ret_id, parent_run_id=chain_id) + docs = [Mock(page_content="secret doc", metadata={"source": "a.txt"})] + handler.on_retriever_end(docs, run_id=ret_id) + handler.on_chain_end({}, run_id=chain_id) + + events = uploaded["events"] + tool_call = find_event(events, "tool.call") + assert "input" not in tool_call["payload"] + + tool_result = find_event(events, "tool.result") + assert "output" not in tool_result["payload"] # --------------------------------------------------------------------------- @@ -189,69 +345,105 @@ def test_combined_tools_and_retrievers(self, mock_client): assert len(find_events(events, "tool.call")) == 2 assert len(find_events(events, "tool.result")) == 2 + def test_tool_error(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) -# --------------------------------------------------------------------------- -# Error handling -# --------------------------------------------------------------------------- + chain_id = uuid4() + tool_id = uuid4() + handler.on_chain_start({"name": "Agent"}, {}, run_id=chain_id) + handler.on_tool_start({"name": "search"}, "q", run_id=tool_id, parent_run_id=chain_id) + handler.on_tool_error(RuntimeError("404"), run_id=tool_id) + handler.on_chain_end({}, run_id=chain_id) -class TestErrors: - def test_chain_error(self, mock_client): + error = find_event(uploaded["events"], "agent.error") + assert error["payload"]["error"] == "404" + + def test_retriever_error(self, mock_client): uploaded = capture_framework_trace(mock_client) handler = LangChainCallbackHandler(mock_client) chain_id = uuid4() - handler.on_chain_start({"name": "FailChain"}, {"input": "x"}, run_id=chain_id) - handler.on_chain_error(ValueError("broke"), run_id=chain_id) + ret_id = uuid4() + + handler.on_chain_start({"name": "Agent"}, {}, run_id=chain_id) + handler.on_retriever_start({"name": "vs"}, "q", run_id=ret_id, parent_run_id=chain_id) + handler.on_retriever_error(ConnectionError("down"), run_id=ret_id) + handler.on_chain_end({}, run_id=chain_id) error = find_event(uploaded["events"], "agent.error") - assert error["payload"]["error"] == "broke" - assert error["payload"]["status"] == "error" + assert error["payload"]["error"] == "down" - def test_llm_error(self, mock_client): + +# --------------------------------------------------------------------------- +# Agent action / finish callbacks +# --------------------------------------------------------------------------- + + +class TestAgentCallbacks: + def test_agent_action_emits_input(self, mock_client): uploaded = capture_framework_trace(mock_client) handler = LangChainCallbackHandler(mock_client) chain_id = uuid4() - llm_id = uuid4() + agent_id = uuid4() - handler.on_chain_start({"name": "Chain"}, {}, run_id=chain_id) - handler.on_llm_start({"name": "LLM"}, ["prompt"], run_id=llm_id, parent_run_id=chain_id) - handler.on_llm_error(RuntimeError("timeout"), run_id=llm_id) + handler.on_chain_start({"name": "AgentExecutor"}, {}, run_id=chain_id) + + action = Mock() + action.tool = "search" + action.tool_input = "what is AI" + action.log = "Thought: I need to search" + handler.on_agent_action(action, run_id=agent_id, parent_run_id=chain_id) handler.on_chain_end({}, run_id=chain_id) - error = find_event(uploaded["events"], "agent.error") - assert error["payload"]["error"] == "timeout" + events = uploaded["events"] + inputs = [e for e in find_events(events, "agent.input") if e["payload"].get("tool") == "search"] + assert len(inputs) == 1 + assert inputs[0]["payload"]["tool_input"] == "what is AI" + assert inputs[0]["payload"]["log"] == "Thought: I need to search" - def test_tool_error(self, mock_client): + def test_agent_finish_emits_output(self, mock_client): uploaded = capture_framework_trace(mock_client) handler = LangChainCallbackHandler(mock_client) chain_id = uuid4() - tool_id = uuid4() + agent_id = uuid4() - handler.on_chain_start({"name": "Agent"}, {}, run_id=chain_id) - handler.on_tool_start({"name": "search"}, "q", run_id=tool_id, parent_run_id=chain_id) - handler.on_tool_error(RuntimeError("404"), run_id=tool_id) + handler.on_chain_start({"name": "AgentExecutor"}, {}, run_id=chain_id) + + finish = Mock() + finish.return_values = {"output": "AI is artificial intelligence"} + finish.log = "Final Answer: AI is artificial intelligence" + handler.on_agent_finish(finish, run_id=agent_id, parent_run_id=chain_id) handler.on_chain_end({}, run_id=chain_id) - error = find_event(uploaded["events"], "agent.error") - assert error["payload"]["error"] == "404" + events = uploaded["events"] + outputs = [e for e in find_events(events, "agent.output") if e["payload"].get("log")] + assert len(outputs) == 1 + assert outputs[0]["payload"]["output"] == {"output": "AI is artificial intelligence"} - def test_retriever_error(self, mock_client): + def test_agent_action_respects_capture_content(self, mock_client): uploaded = capture_framework_trace(mock_client) - handler = LangChainCallbackHandler(mock_client) + handler = LangChainCallbackHandler(mock_client, capture_config=CaptureConfig(capture_content=False)) chain_id = uuid4() - ret_id = uuid4() + agent_id = uuid4() handler.on_chain_start({"name": "Agent"}, {}, run_id=chain_id) - handler.on_retriever_start({"name": "vs"}, "q", run_id=ret_id, parent_run_id=chain_id) - handler.on_retriever_error(ConnectionError("down"), run_id=ret_id) + action = Mock() + action.tool = "secret_tool" + action.tool_input = "secret input" + action.log = "secret reasoning" + handler.on_agent_action(action, run_id=agent_id, parent_run_id=chain_id) handler.on_chain_end({}, run_id=chain_id) - error = find_event(uploaded["events"], "agent.error") - assert error["payload"]["error"] == "down" + events = uploaded["events"] + inputs = [e for e in find_events(events, "agent.input") if e["payload"].get("tool") == "secret_tool"] + assert len(inputs) == 1 + assert "tool_input" not in inputs[0]["payload"] + assert "log" not in inputs[0]["payload"] # --------------------------------------------------------------------------- @@ -272,16 +464,13 @@ def test_llm_parent_is_chain(self, mock_client): {"name": "LLM"}, ["prompt"], run_id=llm_id, parent_run_id=chain_id, ) - llm_response = Mock() - llm_response.generations = [[Mock(text="out")]] - llm_response.llm_output = {} - handler.on_llm_end(llm_response, run_id=llm_id) + handler.on_llm_end(_make_llm_response(), run_id=llm_id) handler.on_chain_end({}, run_id=chain_id) events = uploaded["events"] chain_input = find_event(events, "agent.input") - llm_invoke = [e for e in find_events(events, "model.invoke") if e["payload"].get("name") == "LLM"][0] - assert llm_invoke["parent_span_id"] == chain_input["span_id"] + invoke = find_event(events, "model.invoke") + assert invoke["parent_span_id"] == chain_input["span_id"] # --------------------------------------------------------------------------- @@ -329,17 +518,23 @@ def test_llm_end_no_output(self, mock_client): handler.on_llm_end(empty_response, run_id=llm_id) handler.on_chain_end({}, run_id=chain_id) - # Should complete without error — no model.invoke end event since no output/model + # Should emit model.invoke with name but no output_message + invoke = find_event(uploaded["events"], "model.invoke") + assert invoke["payload"]["name"] == "LLM" + assert "output_message" not in invoke["payload"] + def test_llm_end_without_start(self, mock_client): + """on_llm_end without a preceding on_llm_start should not crash.""" + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) -# --------------------------------------------------------------------------- -# adapter_info -# --------------------------------------------------------------------------- + chain_id = uuid4() + llm_id = uuid4() + handler.on_chain_start({"name": "Chain"}, {}, run_id=chain_id) + handler.on_llm_end(_make_llm_response(), run_id=llm_id) + handler.on_chain_end({}, run_id=chain_id) -class TestAdapterInfo: - def test_info(self): - handler = LangChainCallbackHandler(Mock()) - info = handler.adapter_info() - assert info.name == "langchain" - assert info.adapter_type == "framework" + # Should still emit model.invoke from the response data + invoke = find_event(uploaded["events"], "model.invoke") + assert invoke["payload"]["model"] == "gpt-4" diff --git a/tests/instrument/adapters/frameworks/test_langgraph.py b/tests/instrument/adapters/frameworks/test_langgraph.py index 7ff6e9d..87097ad 100644 --- a/tests/instrument/adapters/frameworks/test_langgraph.py +++ b/tests/instrument/adapters/frameworks/test_langgraph.py @@ -52,7 +52,7 @@ def test_llm_events_inherited(self, mock_client): events = uploaded["events"] assert len(find_events(events, "model.invoke")) >= 1 - assert find_event(events, "cost.record")["payload"]["total_tokens"] == 10 + assert find_event(events, "cost.record")["payload"]["tokens_total"] == 10 def test_tool_events_inherited(self, mock_client): uploaded = capture_framework_trace(mock_client) diff --git a/tests/instrument/adapters/frameworks/test_openai_agents.py b/tests/instrument/adapters/frameworks/test_openai_agents.py new file mode 100644 index 0000000..111be7d --- /dev/null +++ b/tests/instrument/adapters/frameworks/test_openai_agents.py @@ -0,0 +1,823 @@ +"""Tests for the OpenAI Agents SDK adapter using real SDK types. + +Uses real TracingProcessor, SpanImpl, Trace, and span data types. +No mocking of Agents SDK internals — only our mock_client for upload capture. +""" +from __future__ import annotations + +import json +from typing import Any, Dict, List +from unittest.mock import MagicMock + +import pytest + +import sys +if sys.version_info < (3, 10): + pytest.skip("openai-agents requires Python >= 3.10", allow_module_level=True) +try: + import agents # noqa: F401 +except (ImportError, Exception): + pytest.skip("openai-agents not installed or incompatible", allow_module_level=True) + +from agents.tracing import TracingProcessor, set_trace_processors # noqa: E402 +from agents.tracing.spans import SpanImpl # noqa: E402 +from agents.tracing.traces import TraceImpl # noqa: E402 +from agents.tracing.span_data import ( # noqa: E402 + AgentSpanData, + HandoffSpanData, + FunctionSpanData, + GuardrailSpanData, + GenerationSpanData, +) + +from layerlens.instrument._capture_config import CaptureConfig # noqa: E402 +from layerlens.instrument.adapters.frameworks.openai_agents import OpenAIAgentsAdapter # noqa: E402 + +from .conftest import capture_framework_trace, find_event, find_events # noqa: E402 + +# -- Helpers -- + + +class _NoOpProcessor(TracingProcessor): + """Minimal processor that does nothing — used to reset global state.""" + + def on_trace_start(self, trace): + pass + + def on_trace_end(self, trace): + pass + + def on_span_start(self, span): + pass + + def on_span_end(self, span): + pass + + def shutdown(self): + pass + + def force_flush(self): + pass + + +_noop = _NoOpProcessor() + + +def _make_span( + _adapter: Any, + trace_id: str, + span_id: str, + span_data: Any, + parent_id: str | None = None, +) -> SpanImpl: + """Create a real SpanImpl for testing. + + Uses a NoOpProcessor internally so span.start()/finish() don't + double-trigger our adapter. Tests call adapter.on_span_end() manually. + The _adapter param is accepted for call-site readability but unused. + """ + return SpanImpl( + trace_id=trace_id, + span_id=span_id, + parent_id=parent_id, + processor=_noop, + span_data=span_data, + tracing_api_key=None, + ) + + +def _make_trace(name: str = "test_trace", trace_id: str = "trace_001", processor: Any = None) -> TraceImpl: + """Create a real TraceImpl for testing. + + If processor is None, uses a no-op processor. In actual tests, + pass the adapter's processor so trace lifecycle events route correctly. + """ + proc = processor or _NoOpProcessor() + return TraceImpl(name=name, trace_id=trace_id, group_id=None, metadata=None, processor=proc) + + +# -- Fixtures -- + + +@pytest.fixture +def adapter_and_trace(mock_client): + """Create adapter, connect, yield (adapter, uploaded_dict), then clean up. + + The adapter IS the TracingProcessor, so tests call adapter.on_span_end() etc. + directly — no separate processor object. + """ + uploaded = capture_framework_trace(mock_client) + adapter = OpenAIAgentsAdapter(mock_client) + adapter.connect() + yield adapter, uploaded + adapter.disconnect() + set_trace_processors([]) # ensure clean slate + + +@pytest.fixture(autouse=True) +def clean_processors(): + """Reset global trace processors after each test.""" + yield + set_trace_processors([]) + + +# -- Tests -- + + +class TestOpenAIAgentsAdapterLifecycle: + def test_connect_sets_connected(self, mock_client): + adapter = OpenAIAgentsAdapter(mock_client) + adapter.connect() + assert adapter.is_connected + info = adapter.adapter_info() + assert info.name == "openai-agents" + assert info.adapter_type == "framework" + adapter.disconnect() + + def test_disconnect_clears_state(self, mock_client): + adapter = OpenAIAgentsAdapter(mock_client) + adapter.connect() + adapter.disconnect() + assert not adapter.is_connected + + def test_connect_without_agents_raises(self, mock_client, monkeypatch): + import layerlens.instrument.adapters.frameworks.openai_agents as mod + + monkeypatch.setattr(mod, "_HAS_OPENAI_AGENTS", False) + adapter = OpenAIAgentsAdapter(mock_client) + with pytest.raises(ImportError, match="openai-agents"): + adapter.connect() + + +class TestAgentSpans: + """Test agent span handling with real AgentSpanData.""" + + def test_agent_span_emits_input_and_output(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t1") + + # Simulate trace + agent span lifecycle + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t1", "s_agent", + AgentSpanData(name="research_agent", tools=["search", "browse"], handoffs=["writer"]), + ) + span.start() + adapter.on_span_start(span) + span.finish() + adapter.on_span_end(span) + + adapter.on_trace_end(trace) + + events = uploaded["events"] + assert len(events) >= 2 + + inp = find_event(events, "agent.input") + assert inp["payload"]["agent_name"] == "research_agent" + assert inp["payload"]["tools"] == ["search", "browse"] + assert inp["payload"]["handoffs"] == ["writer"] + assert inp["payload"]["framework"] == "openai-agents" + assert inp["span_id"] == "s_agent" + + out = find_event(events, "agent.output") + assert out["payload"]["agent_name"] == "research_agent" + assert out["payload"]["status"] == "ok" + assert out["span_id"] == "s_agent" + + def test_agent_span_with_error(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_err") + + adapter.on_trace_start(trace) + + span = _make_span(adapter,"t_err", "s_err", AgentSpanData(name="buggy_agent")) + span.start() + adapter.on_span_start(span) + span.set_error({"message": "Agent crashed", "data": {"step": 3}}) + span.finish() + adapter.on_span_end(span) + + adapter.on_trace_end(trace) + + events = uploaded["events"] + err = find_event(events, "agent.error") + assert err["payload"]["agent_name"] == "buggy_agent" + assert err["payload"]["status"] == "error" + assert "Agent crashed" in str(err["payload"]["error"]) + + def test_nested_agent_spans(self, adapter_and_trace): + """Multi-agent: parent agent delegates to child agent.""" + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_nested") + + adapter.on_trace_start(trace) + + # Parent agent + parent = _make_span(adapter,"t_nested", "s_parent", AgentSpanData(name="orchestrator")) + parent.start() + adapter.on_span_start(parent) + + # Child agent + child = _make_span(adapter,"t_nested", "s_child", AgentSpanData(name="researcher"), parent_id="s_parent") + child.start() + adapter.on_span_start(child) + child.finish() + adapter.on_span_end(child) + + parent.finish() + adapter.on_span_end(parent) + + adapter.on_trace_end(trace) + + events = uploaded["events"] + agent_inputs = find_events(events, "agent.input") + assert len(agent_inputs) == 2 + + # Child should have parent_span_id pointing to parent + child_input = [e for e in agent_inputs if e["payload"]["agent_name"] == "researcher"][0] + assert child_input["parent_span_id"] == "s_parent" + + +class TestGenerationSpans: + """Test LLM generation span handling.""" + + def test_generation_emits_model_invoke(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_gen") + + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t_gen", "s_gen", + GenerationSpanData( + input=[{"role": "user", "content": "What is 2+2?"}], + output=[{"role": "assistant", "content": "4"}], + model="gpt-4o", + model_config={"temperature": 0.7}, + usage={"input_tokens": 50, "output_tokens": 10}, + ), + parent_id="s_agent", + ) + span.start() + adapter.on_span_start(span) + span.finish() + adapter.on_span_end(span) + + adapter.on_trace_end(trace) + + events = uploaded["events"] + me = find_event(events, "model.invoke") + assert me["payload"]["model"] == "gpt-4o" + assert me["payload"]["tokens_prompt"] == 50 + assert me["payload"]["tokens_completion"] == 10 + assert me["payload"]["tokens_total"] == 60 + assert me["payload"]["latency_ms"] >= 0 + assert me["payload"]["messages"] == [{"role": "user", "content": "What is 2+2?"}] + assert me["payload"]["output_message"] == [{"role": "assistant", "content": "4"}] + assert me["span_id"] == "s_gen" + assert me["parent_span_id"] == "s_agent" + + def test_generation_emits_cost_record(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_cost") + + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t_cost", "s_cost", + GenerationSpanData( + input=[], output=[], model="gpt-4o-mini", + model_config={}, + usage={"input_tokens": 100, "output_tokens": 25}, + ), + ) + span.start() + span.finish() + adapter.on_span_end(span) + adapter.on_trace_end(trace) + + events = uploaded["events"] + cost = find_event(events, "cost.record") + assert cost["payload"]["model"] == "gpt-4o-mini" + assert cost["payload"]["tokens_prompt"] == 100 + assert cost["payload"]["tokens_completion"] == 25 + assert cost["payload"]["tokens_total"] == 125 + + def test_generation_error(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_gen_err") + + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t_gen_err", "s_gen_err", + GenerationSpanData( + input=[{"role": "user", "content": "fail"}], + output=[], model="gpt-4o", + model_config={}, usage={}, + ), + ) + span.start() + span.set_error({"message": "Rate limit exceeded"}) + span.finish() + adapter.on_span_end(span) + adapter.on_trace_end(trace) + + events = uploaded["events"] + err = find_event(events, "agent.error") + assert "Rate limit" in str(err["payload"]["error"]) + + def test_multiple_generations(self, adapter_and_trace): + """Agent makes multiple LLM calls (e.g. tool use loop).""" + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_multi_gen") + + adapter.on_trace_start(trace) + + for i, (inp_tok, out_tok) in enumerate([(50, 15), (80, 20)]): + span = _make_span( + adapter,"t_multi_gen", f"s_gen_{i}", + GenerationSpanData( + input=[], output=[], model="gpt-4o", + model_config={}, + usage={"input_tokens": inp_tok, "output_tokens": out_tok}, + ), + parent_id="s_agent", + ) + span.start() + span.finish() + adapter.on_span_end(span) + + adapter.on_trace_end(trace) + + events = uploaded["events"] + gens = find_events(events, "model.invoke") + assert len(gens) == 2 + assert gens[0]["span_id"] != gens[1]["span_id"] + + +class TestFunctionSpans: + """Test tool/function span handling.""" + + def test_function_span_emits_tool_call(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_func") + + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t_func", "s_func", + FunctionSpanData(name="get_weather", input='{"city":"NYC"}', output='{"temp":72}'), + parent_id="s_agent", + ) + span.start() + span.finish() + adapter.on_span_end(span) + adapter.on_trace_end(trace) + + events = uploaded["events"] + tc = find_event(events, "tool.call") + assert tc["payload"]["tool_name"] == "get_weather" + assert tc["payload"]["input"] == '{"city":"NYC"}' + assert tc["payload"]["output"] == '{"temp":72}' + assert tc["payload"]["latency_ms"] >= 0 + assert tc["parent_span_id"] == "s_agent" + + def test_function_span_with_error(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_func_err") + + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t_func_err", "s_func_err", + FunctionSpanData(name="dangerous_tool", input="delete all", output=None), + ) + span.start() + span.set_error({"message": "Permission denied"}) + span.finish() + adapter.on_span_end(span) + adapter.on_trace_end(trace) + + events = uploaded["events"] + err = find_event(events, "agent.error") + assert err["payload"]["tool_name"] == "dangerous_tool" + assert "Permission denied" in str(err["payload"]["error"]) + + def test_function_span_with_mcp(self, adapter_and_trace): + """Function spans can include MCP data.""" + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_mcp") + + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t_mcp", "s_mcp", + FunctionSpanData(name="mcp_tool", input="query", output="result"), + ) + # Set mcp_data manually + span.span_data.mcp_data = {"server": "my-mcp-server", "tool": "query_db"} + span.start() + span.finish() + adapter.on_span_end(span) + adapter.on_trace_end(trace) + + events = uploaded["events"] + tc = find_event(events, "tool.call") + assert tc["payload"]["mcp_data"]["server"] == "my-mcp-server" + + +class TestHandoffSpans: + """Test handoff span handling.""" + + def test_handoff_emits_event(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_handoff") + + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t_handoff", "s_handoff", + HandoffSpanData(from_agent="triage", to_agent="specialist"), + parent_id="s_agent", + ) + span.start() + span.finish() + adapter.on_span_end(span) + adapter.on_trace_end(trace) + + events = uploaded["events"] + ho = find_event(events, "agent.handoff") + assert ho["payload"]["from_agent"] == "triage" + assert ho["payload"]["to_agent"] == "specialist" + assert ho["parent_span_id"] == "s_agent" + + +class TestGuardrailSpans: + """Test guardrail span handling.""" + + def test_guardrail_emits_evaluation_result(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_guard") + + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t_guard", "s_guard", + GuardrailSpanData(name="content_filter", triggered=True), + ) + span.start() + span.finish() + adapter.on_span_end(span) + adapter.on_trace_end(trace) + + events = uploaded["events"] + ev = find_event(events, "evaluation.result") + assert ev["payload"]["guardrail_name"] == "content_filter" + assert ev["payload"]["triggered"] is True + + def test_guardrail_not_triggered(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_guard2") + + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t_guard2", "s_guard2", + GuardrailSpanData(name="pii_detector", triggered=False), + ) + span.start() + span.finish() + adapter.on_span_end(span) + adapter.on_trace_end(trace) + + events = uploaded["events"] + ev = find_event(events, "evaluation.result") + assert ev["payload"]["triggered"] is False + + +class TestFullAgentFlow: + """End-to-end test simulating a complete agent run with tools and handoff.""" + + def test_complete_flow(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_flow", name="customer_support") + + adapter.on_trace_start(trace) + + # Agent span + agent = _make_span(adapter,"t_flow", "s_agent", AgentSpanData(name="triage", tools=["classify"])) + agent.start() + adapter.on_span_start(agent) + + # LLM call + gen = _make_span( + adapter,"t_flow", "s_gen", + GenerationSpanData( + input=[{"role": "user", "content": "I need help"}], + output=[{"role": "assistant", "content": "Let me classify this"}], + model="gpt-4o-mini", + model_config={}, + usage={"input_tokens": 30, "output_tokens": 10}, + ), + parent_id="s_agent", + ) + gen.start() + gen.finish() + adapter.on_span_end(gen) + + # Tool call + tool = _make_span( + adapter,"t_flow", "s_tool", + FunctionSpanData(name="classify", input="I need help", output="billing"), + parent_id="s_agent", + ) + tool.start() + tool.finish() + adapter.on_span_end(tool) + + # Guardrail + guard = _make_span( + adapter,"t_flow", "s_guard", + GuardrailSpanData(name="safety_check", triggered=False), + parent_id="s_agent", + ) + guard.start() + guard.finish() + adapter.on_span_end(guard) + + # Handoff + handoff = _make_span( + adapter,"t_flow", "s_handoff", + HandoffSpanData(from_agent="triage", to_agent="billing_agent"), + parent_id="s_agent", + ) + handoff.start() + handoff.finish() + adapter.on_span_end(handoff) + + agent.finish() + adapter.on_span_end(agent) + + adapter.on_trace_end(trace) + + events = uploaded["events"] + types = [e["event_type"] for e in events] + + assert "agent.input" in types + assert "agent.output" in types + assert "model.invoke" in types + assert "cost.record" in types + assert "tool.call" in types + assert "evaluation.result" in types + assert "agent.handoff" in types + + # Verify ordering + seq_ids = [e["sequence_id"] for e in events] + assert seq_ids == sorted(seq_ids) + assert len(set(seq_ids)) == len(seq_ids) + + # Verify parent-child relationships + me = find_event(events, "model.invoke") + assert me["parent_span_id"] == "s_agent" + + tc = find_event(events, "tool.call") + assert tc["parent_span_id"] == "s_agent" + + +class TestCaptureConfigGating: + """Test that CaptureConfig gates events properly.""" + + def test_minimal_config(self, mock_client): + uploaded = capture_framework_trace(mock_client) + config = CaptureConfig.minimal() + adapter = OpenAIAgentsAdapter(mock_client, capture_config=config) + adapter.connect() + + + trace = _make_trace(trace_id="t_min") + + adapter.on_trace_start(trace) + + # Agent span (L1 — should be captured) + agent = _make_span(adapter,"t_min", "s_agent", AgentSpanData(name="test")) + agent.start() + agent.finish() + adapter.on_span_end(agent) + + # Generation span (L3 — should be skipped) + gen = _make_span( + adapter,"t_min", "s_gen", + GenerationSpanData( + input=[], output=[], model="gpt-4o", + model_config={}, usage={"input_tokens": 10, "output_tokens": 5}, + ), + ) + gen.start() + gen.finish() + adapter.on_span_end(gen) + + # Tool span (L5a — should be skipped) + tool = _make_span( + adapter,"t_min", "s_tool", + FunctionSpanData(name="search", input="q", output="r"), + ) + tool.start() + tool.finish() + adapter.on_span_end(tool) + + adapter.on_trace_end(trace) + + events = uploaded.get("events", []) + types = [e["event_type"] for e in events] + + assert "agent.input" in types + assert "agent.output" in types + assert "model.invoke" not in types + assert "tool.call" not in types + # cost.record is always enabled + assert "cost.record" in types + + adapter.disconnect() + + +class TestConcurrentTraces: + """Test that multiple concurrent traces are isolated.""" + + def test_parallel_traces_isolated(self, mock_client): + all_uploads: List[Dict[str, Any]] = [] + + def _capture(path: str) -> None: + with open(path) as f: + data = json.load(f) + all_uploads.append(data[0]) + + mock_client.traces.upload = MagicMock(side_effect=_capture) + + adapter = OpenAIAgentsAdapter(mock_client) + adapter.connect() + + + # Two concurrent traces + t1 = _make_trace(trace_id="t_par_1") + t2 = _make_trace(trace_id="t_par_2") + + adapter.on_trace_start(t1) + adapter.on_trace_start(t2) + + # Agent in trace 1 + s1 = _make_span(adapter,"t_par_1", "s1", AgentSpanData(name="agent_1")) + s1.start() + s1.finish() + adapter.on_span_end(s1) + + # Agent in trace 2 + s2 = _make_span(adapter,"t_par_2", "s2", AgentSpanData(name="agent_2")) + s2.start() + s2.finish() + adapter.on_span_end(s2) + + adapter.on_trace_end(t1) + adapter.on_trace_end(t2) + + assert len(all_uploads) == 2 + + # Each trace should have its own events + names = set() + for upload in all_uploads: + for e in upload["events"]: + if e["event_type"] == "agent.input": + names.add(e["payload"]["agent_name"]) + + assert names == {"agent_1", "agent_2"} + + adapter.disconnect() + + +class TestErrorIsolation: + """Verify hooks never crash the SDK.""" + + def test_broken_collector_does_not_crash(self, mock_client): + adapter = OpenAIAgentsAdapter(mock_client) + adapter.connect() + + + trace = _make_trace(trace_id="t_safe") + adapter.on_trace_start(trace) + + # Break the collector + adapter._collectors["t_safe"] = None # type: ignore[assignment] + + # This should not raise + span = _make_span(adapter,"t_safe", "s_safe", AgentSpanData(name="test")) + span.start() + span.finish() + adapter.on_span_end(span) # Should log warning, not crash + + # Trace end should not crash either + adapter.on_trace_end(trace) + + adapter.disconnect() + + +class TestEdgeCases: + def test_empty_usage(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_empty") + + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t_empty", "s_empty", + GenerationSpanData(input=[], output=[], model="gpt-4o", model_config={}, usage={}), + ) + span.start() + span.finish() + adapter.on_span_end(span) + adapter.on_trace_end(trace) + + events = uploaded["events"] + me = find_event(events, "model.invoke") + assert "tokens_prompt" not in me["payload"] + assert "tokens_completion" not in me["payload"] + + def test_none_values_in_span_data(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_none") + + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t_none", "s_none", + AgentSpanData(name="minimal_agent"), # no tools, no handoffs + ) + span.start() + span.finish() + adapter.on_span_end(span) + adapter.on_trace_end(trace) + + events = uploaded["events"] + inp = find_event(events, "agent.input") + assert inp["payload"]["agent_name"] == "minimal_agent" + assert "tools" not in inp["payload"] + assert "handoffs" not in inp["payload"] + + def test_function_span_with_none_output(self, adapter_and_trace): + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_none_out") + + adapter.on_trace_start(trace) + + span = _make_span( + adapter,"t_none_out", "s_func", + FunctionSpanData(name="void_tool", input="run", output=None), + ) + span.start() + span.finish() + adapter.on_span_end(span) + adapter.on_trace_end(trace) + + events = uploaded["events"] + tc = find_event(events, "tool.call") + assert tc["payload"]["tool_name"] == "void_tool" + # output should not be in payload since it was None + assert "output" not in tc["payload"] + + def test_span_duration_tracking(self, adapter_and_trace): + """Verify duration_ms is computed from span timing.""" + import time as _time + + adapter, uploaded = adapter_and_trace + + trace = _make_trace(trace_id="t_dur") + + adapter.on_trace_start(trace) + + span = _make_span(adapter,"t_dur", "s_dur", AgentSpanData(name="slow_agent")) + span.start() + _time.sleep(0.02) # 20ms + span.finish() + adapter.on_span_end(span) + adapter.on_trace_end(trace) + + events = uploaded["events"] + out = find_event(events, "agent.output") + assert out["payload"]["duration_ms"] >= 15 # allow tolerance diff --git a/tests/instrument/adapters/frameworks/test_pydantic_ai.py b/tests/instrument/adapters/frameworks/test_pydantic_ai.py new file mode 100644 index 0000000..c60ae7a --- /dev/null +++ b/tests/instrument/adapters/frameworks/test_pydantic_ai.py @@ -0,0 +1,471 @@ +"""Tests for the PydanticAI adapter using the native Hooks capability API. + +Tests use PydanticAI's TestModel to exercise the real agent loop with +hooks firing at each lifecycle point — no monkey-patching or mocking of +PydanticAI internals. +""" +from __future__ import annotations + +import asyncio +from typing import Optional + +import pytest + +pydantic_ai = pytest.importorskip("pydantic_ai") + +from pydantic_ai import Agent # noqa: E402 +from pydantic_ai.models.test import TestModel # noqa: E402 + +from layerlens.instrument._capture_config import CaptureConfig # noqa: E402 +from layerlens.instrument.adapters.frameworks.pydantic_ai import PydanticAIAdapter # noqa: E402 + +from .conftest import capture_framework_trace, find_event, find_events # noqa: E402 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_agent( + name: Optional[str] = None, + output_text: str = "Hello!", + model_name: str = "test", + tools: Optional[list] = None, +) -> Agent: + """Create a PydanticAI Agent with TestModel for deterministic testing.""" + agent = Agent( + model=TestModel(custom_output_text=output_text, model_name=model_name), + name=name, + ) + if tools: + for tool_fn in tools: + agent.tool_plain(tool_fn) + return agent + + +def get_weather(city: str) -> str: + """Get weather for a city.""" + return f"72F in {city}" + + +# --------------------------------------------------------------------------- +# Lifecycle +# --------------------------------------------------------------------------- + + +class TestPydanticAIAdapterLifecycle: + def test_connect_injects_hooks(self, mock_client): + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent() + + caps_before = len(agent._root_capability.capabilities) + adapter.connect(target=agent) + + assert adapter.is_connected + assert len(agent._root_capability.capabilities) == caps_before + 1 + info = adapter.adapter_info() + assert info.name == "pydantic-ai" + assert info.adapter_type == "framework" + assert info.connected is True + + adapter.disconnect() + + def test_disconnect_removes_hooks(self, mock_client): + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent() + caps_before = len(agent._root_capability.capabilities) + + adapter.connect(target=agent) + adapter.disconnect() + + assert not adapter.is_connected + assert len(agent._root_capability.capabilities) == caps_before + + def test_connect_without_target_raises(self, mock_client): + adapter = PydanticAIAdapter(mock_client) + with pytest.raises(ValueError, match="requires a target agent"): + adapter.connect() + + def test_connect_without_pydantic_ai_raises(self, mock_client, monkeypatch): + import layerlens.instrument.adapters.frameworks.pydantic_ai as mod + + monkeypatch.setattr(mod, "_HAS_PYDANTIC_AI", False) + adapter = PydanticAIAdapter(mock_client) + with pytest.raises(ImportError, match="pydantic-ai"): + adapter.connect(target=_make_agent()) + + +# --------------------------------------------------------------------------- +# run_sync +# --------------------------------------------------------------------------- + + +class TestRunSync: + def test_basic_run(self, mock_client): + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(output_text="The weather is sunny") + + adapter.connect(target=agent) + result = agent.run_sync("What is the weather?") + adapter.disconnect() + + assert result.output == "The weather is sunny" + events = uploaded["events"] + + inp = find_event(events, "agent.input") + assert inp["payload"]["framework"] == "pydantic-ai" + assert inp["payload"]["input"] == "What is the weather?" + + out = find_event(events, "agent.output") + assert out["payload"]["status"] == "ok" + assert out["payload"]["output"] == "The weather is sunny" + assert out["payload"]["latency_ms"] >= 0 + assert out["payload"]["tokens_prompt"] > 0 + assert out["payload"]["tokens_completion"] > 0 + + cost = find_event(events, "cost.record") + assert cost["payload"]["tokens_total"] > 0 + + def test_named_agent(self, mock_client): + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(name="my_agent", output_text="ok") + + adapter.connect(target=agent) + agent.run_sync("test") + adapter.disconnect() + + inp = find_event(uploaded["events"], "agent.input") + assert inp["payload"]["agent_name"] == "my_agent" + + +# --------------------------------------------------------------------------- +# async run +# --------------------------------------------------------------------------- + + +class TestRunAsync: + def test_async_run(self, mock_client): + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(name="async_agent", output_text="Async result") + + adapter.connect(target=agent) + result = asyncio.get_event_loop().run_until_complete(agent.run("async test")) + adapter.disconnect() + + assert result.output == "Async result" + + inp = find_event(uploaded["events"], "agent.input") + assert inp["payload"]["agent_name"] == "async_agent" + assert inp["payload"]["input"] == "async test" + + out = find_event(uploaded["events"], "agent.output") + assert out["payload"]["status"] == "ok" + + +# --------------------------------------------------------------------------- +# Model invocation events +# --------------------------------------------------------------------------- + + +class TestModelInvocation: + def test_model_invoke_emitted(self, mock_client): + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(output_text="hello", model_name="gpt-4o-test") + + adapter.connect(target=agent) + agent.run_sync("hi") + adapter.disconnect() + + model_invokes = find_events(uploaded["events"], "model.invoke") + assert len(model_invokes) >= 1 + assert model_invokes[0]["payload"]["model"] == "gpt-4o-test" + assert model_invokes[0]["payload"]["tokens_prompt"] > 0 + + def test_model_invoke_with_tools_has_two_calls(self, mock_client): + """When a tool is called, TestModel makes 2 model requests: + first to call the tool, then to produce the final text.""" + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(output_text="Done", tools=[get_weather]) + + adapter.connect(target=agent) + agent.run_sync("weather NYC") + adapter.disconnect() + + model_invokes = find_events(uploaded["events"], "model.invoke") + assert len(model_invokes) == 2 + + +# --------------------------------------------------------------------------- +# Tool events +# --------------------------------------------------------------------------- + + +class TestToolEvents: + def test_tool_call_and_result(self, mock_client): + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(output_text="Done", tools=[get_weather]) + + adapter.connect(target=agent) + agent.run_sync("weather NYC") + adapter.disconnect() + + events = uploaded["events"] + + tool_calls = find_events(events, "tool.call") + assert len(tool_calls) == 1 + assert tool_calls[0]["payload"]["tool_name"] == "get_weather" + + tool_results = find_events(events, "tool.result") + assert len(tool_results) == 1 + assert tool_results[0]["payload"]["tool_name"] == "get_weather" + + def test_tool_result_has_output(self, mock_client): + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client, capture_config=CaptureConfig.full()) + agent = _make_agent(output_text="Done", tools=[get_weather]) + + adapter.connect(target=agent) + agent.run_sync("weather NYC") + adapter.disconnect() + + tool_results = find_events(uploaded["events"], "tool.result") + assert len(tool_results) == 1 + # The output should contain the tool's return value + assert "72F" in str(tool_results[0]["payload"]["output"]) + + def test_tool_result_has_latency(self, mock_client): + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(output_text="Done", tools=[get_weather]) + + adapter.connect(target=agent) + agent.run_sync("weather") + adapter.disconnect() + + tool_results = find_events(uploaded["events"], "tool.result") + assert len(tool_results) == 1 + assert tool_results[0]["payload"]["latency_ms"] >= 0 + + +# --------------------------------------------------------------------------- +# Span hierarchy +# --------------------------------------------------------------------------- + + +class TestSpanHierarchy: + def test_per_step_events_parented_to_root(self, mock_client): + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(output_text="Done", tools=[get_weather]) + + adapter.connect(target=agent) + agent.run_sync("weather") + adapter.disconnect() + + events = uploaded["events"] + root = find_event(events, "agent.input") + root_span = root["span_id"] + + for evt in find_events(events, "model.invoke"): + assert evt["parent_span_id"] == root_span + for evt in find_events(events, "tool.call"): + assert evt["parent_span_id"] == root_span + for evt in find_events(events, "tool.result"): + assert evt["parent_span_id"] == root_span + + +# --------------------------------------------------------------------------- +# CaptureConfig gating +# --------------------------------------------------------------------------- + + +class TestCaptureConfigGating: + def test_no_content_capture_omits_io(self, mock_client): + uploaded = capture_framework_trace(mock_client) + config = CaptureConfig(capture_content=False) + adapter = PydanticAIAdapter(mock_client, capture_config=config) + agent = _make_agent(output_text="done", tools=[get_weather]) + + adapter.connect(target=agent) + agent.run_sync("secret prompt") + adapter.disconnect() + + events = uploaded["events"] + + inp = find_event(events, "agent.input") + assert "input" not in inp["payload"] + + tool_calls = find_events(events, "tool.call") + assert len(tool_calls) >= 1 + assert "input" not in tool_calls[0]["payload"] + + tool_results = find_events(events, "tool.result") + assert len(tool_results) >= 1 + assert "output" not in tool_results[0]["payload"] + + # cost.record should still exist + assert len(find_events(events, "cost.record")) == 1 + + def test_full_config_includes_content(self, mock_client): + uploaded = capture_framework_trace(mock_client) + config = CaptureConfig.full() + adapter = PydanticAIAdapter(mock_client, capture_config=config) + agent = _make_agent(output_text="Hi Alice", tools=[get_weather]) + + adapter.connect(target=agent) + agent.run_sync("greet Alice") + adapter.disconnect() + + events = uploaded["events"] + + inp = find_event(events, "agent.input") + assert inp["payload"]["input"] == "greet Alice" + + out = find_event(events, "agent.output") + assert out["payload"]["output"] == "Hi Alice" + + tool_calls = find_events(events, "tool.call") + assert "input" in tool_calls[0]["payload"] + + +# --------------------------------------------------------------------------- +# Multiple runs +# --------------------------------------------------------------------------- + + +class TestMultipleRuns: + def test_sequential_runs_separate_traces(self, mock_client): + import json + + all_uploads: list = [] + + def _capture(path: str) -> None: + with open(path) as f: + data = json.load(f) + all_uploads.append(data[0]) + + mock_client.traces.upload.side_effect = _capture + + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(output_text="ok") + + adapter.connect(target=agent) + agent.run_sync("first") + agent.run_sync("second") + adapter.disconnect() + + assert len(all_uploads) == 2 + trace_ids = {u["trace_id"] for u in all_uploads} + assert len(trace_ids) == 2 + + +# --------------------------------------------------------------------------- +# Event structure +# --------------------------------------------------------------------------- + + +class TestEventStructure: + def test_event_fields(self, mock_client): + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(name="test_agent", output_text="ok") + + adapter.connect(target=agent) + agent.run_sync("hello") + adapter.disconnect() + + events = uploaded["events"] + for event in events: + assert "event_type" in event + assert "trace_id" in event + assert "span_id" in event + assert "sequence_id" in event + assert "timestamp_ns" in event + assert "payload" in event + + seq_ids = [e["sequence_id"] for e in events] + assert seq_ids == sorted(seq_ids) + assert len(set(seq_ids)) == len(seq_ids) + + trace_ids = {e["trace_id"] for e in events} + assert len(trace_ids) == 1 + + def test_attestation_present(self, mock_client): + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(output_text="ok") + + adapter.connect(target=agent) + agent.run_sync("test") + adapter.disconnect() + + assert uploaded.get("trace_id") is not None + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + def test_empty_prompt(self, mock_client): + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(output_text="ok") + + adapter.connect(target=agent) + agent.run_sync("") + adapter.disconnect() + + inp = find_event(uploaded["events"], "agent.input") + assert inp["payload"]["framework"] == "pydantic-ai" + + def test_pydantic_model_output(self, mock_client): + from pydantic import BaseModel + + class CityInfo(BaseModel): + city: str + temp: int + + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = Agent( + model=TestModel(custom_output_args={"city": "NYC", "temp": 72}), + output_type=CityInfo, + ) + + adapter.connect(target=agent) + result = agent.run_sync("weather") + adapter.disconnect() + + out = find_event(uploaded["events"], "agent.output") + assert out["payload"]["output"] == {"city": "NYC", "temp": 72} + + def test_zero_token_usage_still_has_tokens(self, mock_client): + """TestModel always produces some tokens, so we verify they're present.""" + uploaded = capture_framework_trace(mock_client) + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent(output_text="ok") + + adapter.connect(target=agent) + agent.run_sync("test") + adapter.disconnect() + + out = find_event(uploaded["events"], "agent.output") + # TestModel always has some token usage + assert "tokens_prompt" in out["payload"] + assert len(find_events(uploaded["events"], "cost.record")) == 1 + + def test_disconnect_idempotent(self, mock_client): + adapter = PydanticAIAdapter(mock_client) + agent = _make_agent() + adapter.connect(target=agent) + adapter.disconnect() + adapter.disconnect() # should not raise diff --git a/tests/instrument/adapters/frameworks/test_semantic_kernel.py b/tests/instrument/adapters/frameworks/test_semantic_kernel.py new file mode 100644 index 0000000..9ae833a --- /dev/null +++ b/tests/instrument/adapters/frameworks/test_semantic_kernel.py @@ -0,0 +1,753 @@ +"""Tests for the Semantic Kernel adapter using the SK filter API. + +Tests use real Kernel objects and KernelFunctions. Filters are exercised +either through actual kernel.invoke() calls or by directly invoking the +filter callables with mock contexts. +""" +from __future__ import annotations + +import asyncio +from typing import Any, Optional +from unittest.mock import MagicMock + +import pytest + +sk = pytest.importorskip("semantic_kernel") + +from semantic_kernel import Kernel # noqa: E402 +from semantic_kernel.functions import kernel_function # noqa: E402 +from semantic_kernel.filters.filter_types import FilterTypes # noqa: E402 + +from layerlens.instrument._capture_config import CaptureConfig # noqa: E402 +from layerlens.instrument.adapters.frameworks.semantic_kernel import ( # noqa: E402 + SemanticKernelAdapter, + _extract_arguments, + _extract_function_name, + _extract_plugin_name, +) + +from .conftest import capture_framework_trace, find_event, find_events # noqa: E402 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class MathPlugin: + @kernel_function(name="add", description="Add two numbers") + def add(self, a: int, b: int) -> int: + return a + b + + @kernel_function(name="divide", description="Divide a by b") + def divide(self, a: int, b: int) -> float: + return a / b + + +class TextPlugin: + @kernel_function(name="upper", description="Uppercase text") + def upper(self, text: str) -> str: + return text.upper() + + +class MockFunction: + def __init__(self, name: str = "test_func", plugin_name: str = "TestPlugin"): + self.name = name + self.plugin_name = plugin_name + + +class MockContext: + def __init__( + self, + function: Any = None, + arguments: Any = None, + result: Any = None, + rendered_prompt: Optional[str] = None, + function_call_content: Any = None, + function_result: Any = None, + request_sequence_index: int = 0, + function_sequence_index: int = 0, + ): + self.function = function or MockFunction() + self.arguments = arguments + self.result = result + self.rendered_prompt = rendered_prompt + self.function_call_content = function_call_content + self.function_result = function_result + self.request_sequence_index = request_sequence_index + self.function_sequence_index = function_sequence_index + + +class MockFunctionCallContent: + def __init__(self, arguments: Any = None): + self.arguments = arguments + + +class MockFunctionResult: + def __init__(self, value: Any = None): + self.value = value + + +def _run(coro: Any) -> Any: + return asyncio.get_event_loop().run_until_complete(coro) + + +# --------------------------------------------------------------------------- +# Lifecycle +# --------------------------------------------------------------------------- + + +class TestLifecycle: + def test_connect_registers_filters(self, mock_client): + kernel = Kernel() + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + assert adapter.is_connected + assert len(kernel.function_invocation_filters) == 1 + assert len(kernel.prompt_rendering_filters) == 1 + assert len(kernel.auto_function_invocation_filters) == 1 + + info = adapter.adapter_info() + assert info.name == "semantic_kernel" + assert info.adapter_type == "framework" + assert info.connected is True + + adapter.disconnect() + + def test_disconnect_removes_filters(self, mock_client): + kernel = Kernel() + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + adapter.disconnect() + + assert not adapter.is_connected + assert len(kernel.function_invocation_filters) == 0 + assert len(kernel.prompt_rendering_filters) == 0 + assert len(kernel.auto_function_invocation_filters) == 0 + + def test_connect_without_target_raises(self, mock_client): + adapter = SemanticKernelAdapter(mock_client) + with pytest.raises(ValueError, match="requires a target kernel"): + adapter.connect() + + def test_connect_without_sk_raises(self, mock_client, monkeypatch): + import layerlens.instrument.adapters.frameworks.semantic_kernel as mod + + monkeypatch.setattr(mod, "_HAS_SEMANTIC_KERNEL", False) + adapter = SemanticKernelAdapter(mock_client) + with pytest.raises(ImportError, match="semantic_kernel"): + adapter.connect(target=Kernel()) + + def test_disconnect_idempotent(self, mock_client): + kernel = Kernel() + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + adapter.disconnect() + adapter.disconnect() # should not raise + + +# --------------------------------------------------------------------------- +# Function invocation via real kernel.invoke() +# --------------------------------------------------------------------------- + + +class TestFunctionInvocation: + def test_invoke_emits_tool_call(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + kernel.add_plugin(MathPlugin(), "MathPlugin") + + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + result = _run(kernel.invoke(plugin_name="MathPlugin", function_name="add", a=2, b=3)) + assert str(result) == "5" + + adapter.disconnect() + + events = uploaded["events"] + tool_calls = find_events(events, "tool.call") + assert len(tool_calls) >= 1 + assert tool_calls[0]["payload"]["tool_name"] == "MathPlugin.add" + assert tool_calls[0]["payload"]["plugin_name"] == "MathPlugin" + assert tool_calls[0]["payload"]["function_name"] == "add" + + tool_results = find_events(events, "tool.result") + assert len(tool_results) >= 1 + assert tool_results[0]["payload"]["status"] == "ok" + assert tool_results[0]["payload"]["latency_ms"] >= 0 + + def test_invoke_captures_output(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + kernel.add_plugin(MathPlugin(), "MathPlugin") + + adapter = SemanticKernelAdapter(mock_client, capture_config=CaptureConfig.full()) + adapter.connect(target=kernel) + + _run(kernel.invoke(plugin_name="MathPlugin", function_name="add", a=10, b=20)) + adapter.disconnect() + + events = uploaded["events"] + tool_result = find_event(events, "tool.result") + assert tool_result["payload"]["output"] == 30 + + def test_invoke_error_emits_agent_error(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + kernel.add_plugin(MathPlugin(), "MathPlugin") + + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + with pytest.raises(Exception): + _run(kernel.invoke(plugin_name="MathPlugin", function_name="divide", a=1, b=0)) + + adapter.disconnect() + + events = uploaded["events"] + err = find_event(events, "agent.error") + assert "division by zero" in err["payload"]["error"] + assert err["payload"]["error_type"] == "ZeroDivisionError" + assert err["payload"]["tool_name"] == "MathPlugin.divide" + + def test_sequential_invocations(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + kernel.add_plugin(MathPlugin(), "MathPlugin") + + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + _run(kernel.invoke(plugin_name="MathPlugin", function_name="add", a=1, b=2)) + _run(kernel.invoke(plugin_name="MathPlugin", function_name="add", a=3, b=4)) + adapter.disconnect() + + events = uploaded["events"] + assert len(find_events(events, "tool.call")) == 2 + assert len(find_events(events, "tool.result")) == 2 + + +# --------------------------------------------------------------------------- +# Function invocation filter via direct call +# --------------------------------------------------------------------------- + + +class TestFunctionInvocationFilter: + def test_filter_calls_next_and_emits(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + ctx = MockContext( + function=MockFunction("greet", "HelloPlugin"), + ) + + async def mock_next(context): + context.result = MockFunctionResult("Hi") + + _run(adapter._function_invocation_filter(ctx, mock_next)) + adapter.disconnect() + + events = uploaded["events"] + tool_call = find_event(events, "tool.call") + assert tool_call["payload"]["plugin_name"] == "HelloPlugin" + assert tool_call["payload"]["function_name"] == "greet" + + tool_result = find_event(events, "tool.result") + assert tool_result["payload"]["status"] == "ok" + + def test_filter_propagates_exception(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + ctx = MockContext() + + async def failing_next(context): + raise RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + _run(adapter._function_invocation_filter(ctx, failing_next)) + + adapter.disconnect() + + events = uploaded["events"] + err = find_event(events, "agent.error") + assert err["payload"]["error"] == "boom" + assert err["payload"]["error_type"] == "RuntimeError" + + +# --------------------------------------------------------------------------- +# Prompt rendering +# --------------------------------------------------------------------------- + + +class TestPromptRendering: + def test_prompt_render_emits_agent_code(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + adapter = SemanticKernelAdapter(mock_client, capture_config=CaptureConfig.full()) + adapter.connect(target=kernel) + + ctx = MockContext( + function=MockFunction("summarize", "TextPlugin"), + rendered_prompt="Summarize: Hello world", + ) + + async def mock_next(context): + pass + + _run(adapter._prompt_rendering_filter(ctx, mock_next)) + adapter.disconnect() + + events = uploaded["events"] + ev = find_event(events, "agent.code") + assert ev["payload"]["event_subtype"] == "prompt_render" + assert ev["payload"]["function_name"] == "summarize" + assert "Summarize" in ev["payload"]["rendered_prompt"] + + def test_prompt_render_no_content_when_disabled(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + config = CaptureConfig(l2_agent_code=True, capture_content=False) + adapter = SemanticKernelAdapter(mock_client, capture_config=config) + adapter.connect(target=kernel) + + ctx = MockContext( + function=MockFunction("summarize", "TextPlugin"), + rendered_prompt="secret prompt", + ) + + async def mock_next(context): + pass + + _run(adapter._prompt_rendering_filter(ctx, mock_next)) + adapter.disconnect() + + events = uploaded["events"] + ev = find_event(events, "agent.code") + assert "rendered_prompt" not in ev["payload"] + + +# --------------------------------------------------------------------------- +# Auto function invocation (LLM-initiated tool calls) +# --------------------------------------------------------------------------- + + +class TestAutoFunctionInvocation: + def test_auto_function_emits_tool_call_and_result(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + adapter = SemanticKernelAdapter(mock_client, capture_config=CaptureConfig.full()) + adapter.connect(target=kernel) + + ctx = MockContext( + function=MockFunction("web_search", "SearchPlugin"), + function_call_content=MockFunctionCallContent(arguments={"query": "test"}), + function_result=MockFunctionResult("found it"), + request_sequence_index=1, + function_sequence_index=0, + ) + + async def mock_next(context): + pass + + _run(adapter._auto_function_invocation_filter(ctx, mock_next)) + adapter.disconnect() + + events = uploaded["events"] + + tool_call = find_event(events, "tool.call") + assert tool_call["payload"]["auto_invoked"] is True + assert tool_call["payload"]["tool_name"] == "SearchPlugin.web_search" + assert tool_call["payload"]["input"] == {"query": "test"} + assert tool_call["payload"]["request_sequence_index"] == 1 + + tool_results = find_events(events, "tool.result") + assert len(tool_results) == 1 + assert tool_results[0]["payload"]["auto_invoked"] is True + assert tool_results[0]["payload"]["output"] == "found it" + assert tool_results[0]["payload"]["latency_ms"] >= 0 + + def test_auto_function_error(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + ctx = MockContext( + function=MockFunction("fail_tool", "ToolPlugin"), + ) + + async def failing_next(context): + raise ValueError("tool exploded") + + with pytest.raises(ValueError, match="tool exploded"): + _run(adapter._auto_function_invocation_filter(ctx, failing_next)) + + adapter.disconnect() + + events = uploaded["events"] + # tool.call should still be emitted (before the error) + tool_call = find_event(events, "tool.call") + assert tool_call["payload"]["auto_invoked"] is True + + err = find_event(events, "agent.error") + assert err["payload"]["error"] == "tool exploded" + assert err["payload"]["auto_invoked"] is True + + +# --------------------------------------------------------------------------- +# Plugin discovery +# --------------------------------------------------------------------------- + + +class TestPluginDiscovery: + def test_discover_plugins_on_connect(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + kernel.add_plugin(MathPlugin(), "MathPlugin") + kernel.add_plugin(TextPlugin(), "TextPlugin") + + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + adapter.disconnect() + + events = uploaded["events"] + config_events = find_events(events, "environment.config") + plugin_names = {e["payload"]["plugin_name"] for e in config_events} + assert "MathPlugin" in plugin_names + assert "TextPlugin" in plugin_names + + def test_new_plugin_discovered_on_first_call(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + # Invoke filter directly with a plugin not yet seen + ctx = MockContext(function=MockFunction("do_stuff", "NewPlugin")) + + async def mock_next(context): + context.result = MockFunctionResult("ok") + + _run(adapter._function_invocation_filter(ctx, mock_next)) + adapter.disconnect() + + events = uploaded["events"] + config_events = find_events(events, "environment.config") + names = {e["payload"]["plugin_name"] for e in config_events} + assert "NewPlugin" in names + + def test_duplicate_plugin_not_rediscovered(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + ctx1 = MockContext(function=MockFunction("f1", "SamePlugin")) + ctx2 = MockContext(function=MockFunction("f2", "SamePlugin")) + + async def mock_next(context): + context.result = MockFunctionResult("ok") + + _run(adapter._function_invocation_filter(ctx1, mock_next)) + _run(adapter._function_invocation_filter(ctx2, mock_next)) + adapter.disconnect() + + events = uploaded["events"] + config_events = find_events(events, "environment.config") + same_plugin = [e for e in config_events if e["payload"]["plugin_name"] == "SamePlugin"] + assert len(same_plugin) == 1 + + +# --------------------------------------------------------------------------- +# CaptureConfig gating +# --------------------------------------------------------------------------- + + +class TestCaptureConfigGating: + def test_no_content_strips_io(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + adapter = SemanticKernelAdapter(mock_client, capture_config=CaptureConfig(capture_content=False)) + adapter.connect(target=kernel) + + ctx = MockContext( + function=MockFunction("search", "Plugin"), + arguments={"secret": "key"}, + ) + + async def mock_next(context): + context.result = MockFunctionResult("classified") + + _run(adapter._function_invocation_filter(ctx, mock_next)) + adapter.disconnect() + + events = uploaded["events"] + tool_call = find_event(events, "tool.call") + assert "input" not in tool_call["payload"] + + tool_result = find_event(events, "tool.result") + assert "output" not in tool_result["payload"] + + def test_full_config_includes_io(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + adapter = SemanticKernelAdapter(mock_client, capture_config=CaptureConfig.full()) + adapter.connect(target=kernel) + + ctx = MockContext( + function=MockFunction("search", "Plugin"), + arguments={"query": "test"}, + ) + + async def mock_next(context): + context.result = MockFunctionResult("results") + + _run(adapter._function_invocation_filter(ctx, mock_next)) + adapter.disconnect() + + events = uploaded["events"] + tool_call = find_event(events, "tool.call") + assert tool_call["payload"]["input"] == {"query": "test"} + + tool_result = find_event(events, "tool.result") + assert tool_result["payload"]["output"] == "results" + + +# --------------------------------------------------------------------------- +# LLM call wrapping +# --------------------------------------------------------------------------- + + +class MockUsage: + def __init__(self, prompt_tokens: int = 0, completion_tokens: int = 0): + self.prompt_tokens = prompt_tokens + self.completion_tokens = completion_tokens + + +class MockChatMessage: + def __init__(self, text: str = "Hello!", model_id: str = "gpt-4o", usage: Any = None): + self.content = text + self.ai_model_id = model_id + self.metadata = {"usage": usage} if usage else {} + + +class MockChatService: + """Minimal mock that looks like a ChatCompletionClientBase to the adapter.""" + + def __init__(self, response_text: str = "Hello!", model_id: str = "gpt-4o", + prompt_tokens: int = 100, completion_tokens: int = 50): + self.ai_model_id = model_id + self._response = MockChatMessage( + text=response_text, + model_id=model_id, + usage=MockUsage(prompt_tokens, completion_tokens), + ) + + async def _inner_get_chat_message_contents(self, chat_history: Any, settings: Any) -> list: + return [self._response] + + +class TestLLMCallWrapping: + def _register_mock_service(self, kernel, service): + """Register a mock service directly on the kernel.""" + kernel.services["mock"] = service + + def test_model_invoke_emitted(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + service = MockChatService(prompt_tokens=100, completion_tokens=50) + self._register_mock_service(kernel, service) + + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + # Call the wrapped method directly + _run(service._inner_get_chat_message_contents(None, None)) + + adapter.disconnect() + + events = uploaded["events"] + model_invoke = find_event(events, "model.invoke") + assert model_invoke["payload"]["model"] == "gpt-4o" + assert model_invoke["payload"]["tokens_prompt"] == 100 + assert model_invoke["payload"]["tokens_completion"] == 50 + assert model_invoke["payload"]["latency_ms"] >= 0 + + def test_cost_record_emitted(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + service = MockChatService(prompt_tokens=200, completion_tokens=100) + self._register_mock_service(kernel, service) + + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + _run(service._inner_get_chat_message_contents(None, None)) + adapter.disconnect() + + events = uploaded["events"] + cost = find_event(events, "cost.record") + assert cost["payload"]["tokens_total"] == 300 + assert cost["payload"]["model"] == "gpt-4o" + + def test_no_cost_record_without_tokens(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + service = MockChatService(prompt_tokens=0, completion_tokens=0) + self._register_mock_service(kernel, service) + + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + _run(service._inner_get_chat_message_contents(None, None)) + adapter.disconnect() + + events = uploaded["events"] + cost_events = find_events(events, "cost.record") + assert len(cost_events) == 0 + + def test_llm_error_emits_agent_error(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + service = MockChatService() + self._register_mock_service(kernel, service) + + # Replace inner method with one that fails + original = service._inner_get_chat_message_contents + + async def failing_inner(chat_history, settings): + raise RuntimeError("API timeout") + + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + # The adapter wrapped the original, so replace the original call path + # We need to set up the service to fail BEFORE connect wraps it + # Let's test by reconnecting + adapter.disconnect() + + service._inner_get_chat_message_contents = failing_inner + adapter.connect(target=kernel) + + with pytest.raises(RuntimeError, match="API timeout"): + _run(service._inner_get_chat_message_contents(None, None)) + + adapter.disconnect() + + events = uploaded["events"] + err = find_event(events, "agent.error") + assert err["payload"]["error"] == "API timeout" + assert err["payload"]["model"] == "gpt-4o" + + def test_disconnect_restores_original(self, mock_client): + kernel = Kernel() + service = MockChatService() + self._register_mock_service(kernel, service) + + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + # After connect, the method is our wrapper (an instance attribute, not the class method) + assert "_traced_inner" in service._inner_get_chat_message_contents.__name__ + + adapter.disconnect() + # After disconnect, the instance override is removed and the class method is accessible again + assert "_traced_inner" not in service._inner_get_chat_message_contents.__name__ + + +# --------------------------------------------------------------------------- +# Span hierarchy +# --------------------------------------------------------------------------- + + +class TestSpanHierarchy: + def test_events_share_root_span(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + kernel.add_plugin(MathPlugin(), "MathPlugin") + + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + _run(kernel.invoke(plugin_name="MathPlugin", function_name="add", a=1, b=2)) + adapter.disconnect() + + events = uploaded["events"] + # All events should share the same root span (via parent_span_id) + parent_spans = {e.get("parent_span_id") for e in events if e.get("parent_span_id")} + # There should be at most one root + assert len(parent_spans) <= 2 # root_span_id from _ensure_collector + our root + + +# --------------------------------------------------------------------------- +# Event structure +# --------------------------------------------------------------------------- + + +class TestEventStructure: + def test_event_fields(self, mock_client): + uploaded = capture_framework_trace(mock_client) + kernel = Kernel() + kernel.add_plugin(MathPlugin(), "MathPlugin") + + adapter = SemanticKernelAdapter(mock_client) + adapter.connect(target=kernel) + + _run(kernel.invoke(plugin_name="MathPlugin", function_name="add", a=1, b=2)) + adapter.disconnect() + + events = uploaded["events"] + for event in events: + assert "event_type" in event + assert "trace_id" in event + assert "span_id" in event + assert "sequence_id" in event + assert "timestamp_ns" in event + assert "payload" in event + assert event["payload"]["framework"] == "semantic_kernel" + + seq_ids = [e["sequence_id"] for e in events] + assert seq_ids == sorted(seq_ids) + + +# --------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------- + + +class TestHelpers: + def test_extract_plugin_name_from_function(self): + ctx = MockContext(function=MockFunction(plugin_name="MyPlugin")) + assert _extract_plugin_name(ctx) == "MyPlugin" + + def test_extract_plugin_name_fallback(self): + class Ctx: + function = None + plugin_name = "FallbackPlugin" + + assert _extract_plugin_name(Ctx()) == "FallbackPlugin" + + def test_extract_function_name(self): + ctx = MockContext(function=MockFunction(name="my_func")) + assert _extract_function_name(ctx) == "my_func" + + def test_extract_arguments_dict(self): + ctx = MockContext(arguments={"x": 1, "y": 2}) + assert _extract_arguments(ctx) == {"x": 1, "y": 2} + + def test_extract_arguments_none(self): + ctx = MockContext(arguments=None) + assert _extract_arguments(ctx) is None + + def test_extract_arguments_mapping(self): + """SK KernelArguments has .items() but isn't a dict.""" + class FakeArgs: + def items(self): + return [("a", 1)] + + ctx = MockContext(arguments=FakeArgs()) + assert _extract_arguments(ctx) == {"a": 1} From 96ada84868e5b4eae1bd3d90ff4a1578d64dcc59 Mon Sep 17 00:00:00 2001 From: Gary <59334078+garrettallen14@users.noreply.github.com> Date: Mon, 6 Apr 2026 13:23:29 -0700 Subject: [PATCH 3/4] feat: unify context model, per-client uploads, and adapter hardening --- pyproject.toml | 35 +- src/layerlens/instrument/_collector.py | 75 +++-- src/layerlens/instrument/_context.py | 2 + src/layerlens/instrument/_decorator.py | 4 +- src/layerlens/instrument/_upload.py | 259 +++++++++++---- .../adapters/frameworks/_base_framework.py | 314 +++++++----------- .../instrument/adapters/frameworks/crewai.py | 6 +- .../adapters/frameworks/langchain.py | 30 +- .../adapters/frameworks/openai_agents.py | 110 +++--- .../adapters/frameworks/pydantic_ai.py | 79 ++--- .../adapters/frameworks/semantic_kernel.py | 114 ++++--- .../adapters/providers/_base_provider.py | 2 + tests/conftest.py | 10 + .../adapters/frameworks/test_concurrency.py | 93 ++++++ .../adapters/frameworks/test_crewai.py | 4 +- .../adapters/frameworks/test_openai_agents.py | 12 +- .../frameworks/test_semantic_kernel.py | 16 +- tests/instrument/test_trace_context.py | 158 +++++---- 18 files changed, 745 insertions(+), 578 deletions(-) create mode 100644 tests/instrument/adapters/frameworks/test_concurrency.py diff --git a/pyproject.toml b/pyproject.toml index d0fabba..54be8cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,9 @@ openai = ["openai>=1.0.0"] anthropic = ["anthropic>=0.18.0"] langchain = ["langchain-core>=0.1.0"] litellm = ["litellm>=1.0.0"] +pydantic-ai = ["pydantic-ai>=0.2.0"] +openai-agents = ["openai-agents>=0.1.0"] +semantic-kernel = ["semantic-kernel>=1.0.0"] [project.urls] Homepage = "https://github.com/LayerLens/stratix-python" @@ -50,14 +53,15 @@ stratix = "layerlens.cli:main" managed = true # version pins are in requirements-dev.lock dev-dependencies = [ - "mypy", - "pytest", - "pyright==1.1.399", - "pytest-cov>=6.2.1", - "ruff", - "build", - "twine==6.1.0", - "click>=8.0.0", + "mypy", + "pytest", + "pyright==1.1.399", + "pytest-cov>=6.2.1", + "ruff", + "build", + "twine==6.1.0", + "click>=8.0.0", + "crewai>=0.5.0", ] [tool.rye.scripts] @@ -146,6 +150,21 @@ known-first-party = ["openai", "tests"] "src/layerlens/cli/**" = ["T201", "T203"] "src/layerlens/instrument/adapters/frameworks/langchain.py" = ["ARG002"] "src/layerlens/instrument/adapters/frameworks/langgraph.py" = ["ARG002"] +"src/layerlens/instrument/adapters/frameworks/crewai.py" = ["ARG002"] +"src/layerlens/instrument/adapters/frameworks/pydantic_ai.py" = ["ARG002"] +"src/layerlens/instrument/adapters/frameworks/openai_agents.py" = ["ARG002"] +"src/layerlens/instrument/adapters/frameworks/autogen.py" = ["ARG002"] +"src/layerlens/instrument/adapters/frameworks/llamaindex.py" = ["ARG002"] +"src/layerlens/instrument/adapters/frameworks/semantic_kernel.py" = ["ARG002"] +"src/layerlens/instrument/adapters/frameworks/smolagents.py" = ["ARG002"] +"src/layerlens/instrument/adapters/frameworks/google_adk.py" = ["ARG002"] +"src/layerlens/instrument/adapters/frameworks/agno.py" = ["ARG002"] +"src/layerlens/instrument/adapters/frameworks/strands.py" = ["ARG002"] +"src/layerlens/instrument/adapters/frameworks/bedrock_agents.py" = ["ARG002"] +"src/layerlens/instrument/adapters/frameworks/ms_agent_framework.py" = ["ARG002"] +"src/layerlens/instrument/adapters/frameworks/haystack.py" = ["ARG002"] +"src/layerlens/instrument/adapters/frameworks/langfuse.py" = ["ARG002"] +"src/layerlens/instrument/adapters/frameworks/agentforce.py" = ["ARG002"] [tool.pyright] include = ["src", "tests"] diff --git a/src/layerlens/instrument/_collector.py b/src/layerlens/instrument/_collector.py index 031576f..beb9964 100644 --- a/src/layerlens/instrument/_collector.py +++ b/src/layerlens/instrument/_collector.py @@ -3,18 +3,25 @@ import time import uuid import logging +import threading from typing import Any, Dict, List, Optional from layerlens.attestation import HashChain from ._capture_config import CaptureConfig -from ._upload import upload_trace, async_upload_trace +from ._upload import enqueue_upload log: logging.Logger = logging.getLogger(__name__) class TraceCollector: - """Collects flat events for a single trace, with CaptureConfig gating and attestation.""" + """Collects flat events for a single trace, with CaptureConfig gating and attestation. + + Thread-safe: all mutations go through ``self._lock``. + Once ``flush()`` is called the collector is sealed — further ``emit()`` calls are no-ops. + """ + + MAX_EVENTS = 10_000 def __init__(self, client: Any, config: CaptureConfig) -> None: self._client = client @@ -23,6 +30,9 @@ def __init__(self, client: Any, config: CaptureConfig) -> None: self._events: List[Dict[str, Any]] = [] self._sequence: int = 0 self._chain = HashChain() + self._capped = False + self._sealed = False + self._lock = threading.Lock() @property def trace_id(self) -> str: @@ -46,19 +56,32 @@ def emit( payload = self._config.redact_payload(event_type, payload) - self._sequence += 1 - event: Dict[str, Any] = { - "event_type": event_type, - "trace_id": self._trace_id, - "span_id": span_id, - "parent_span_id": parent_span_id, - "span_name": span_name, - "sequence_id": self._sequence, - "timestamp_ns": time.time_ns(), - "payload": payload, - } - self._chain.add_event(event) - self._events.append(event) + with self._lock: + if self._sealed: + return + + if len(self._events) >= self.MAX_EVENTS: + if not self._capped: + self._capped = True + log.warning( + "layerlens: trace %s hit %d event limit, further events dropped", + self._trace_id, self.MAX_EVENTS, + ) + return + + self._sequence += 1 + event: Dict[str, Any] = { + "event_type": event_type, + "trace_id": self._trace_id, + "span_id": span_id, + "parent_span_id": parent_span_id, + "span_name": span_name, + "sequence_id": self._sequence, + "timestamp_ns": time.time_ns(), + "payload": payload, + } + self._chain.add_event(event) + self._events.append(event) def _build_trace_payload(self) -> Dict[str, Any]: """Build the attestation envelope and trace payload.""" @@ -73,21 +96,23 @@ def _build_trace_payload(self) -> Dict[str, Any]: log.warning("Failed to build attestation chain", exc_info=True) attestation = {"attestation_error": str(exc)} - return { + trace_payload: Dict[str, Any] = { "trace_id": self._trace_id, "events": self._events, "capture_config": self._config.to_dict(), "attestation": attestation, } + if self._capped: + trace_payload["truncated"] = True + trace_payload["max_events"] = self.MAX_EVENTS + return trace_payload def flush(self) -> None: - """Build attestation and upload the trace.""" - if not self._events: - return - upload_trace(self._client, self._build_trace_payload()) + """Seal the collector, build attestation, and enqueue the trace for background upload.""" + with self._lock: + if self._sealed or not self._events: + return + self._sealed = True + payload = self._build_trace_payload() + enqueue_upload(self._client, payload) - async def async_flush(self) -> None: - """Async version of flush.""" - if not self._events: - return - await async_upload_trace(self._client, self._build_trace_payload()) diff --git a/src/layerlens/instrument/_context.py b/src/layerlens/instrument/_context.py index fce18d2..98716cf 100644 --- a/src/layerlens/instrument/_context.py +++ b/src/layerlens/instrument/_context.py @@ -27,6 +27,8 @@ class RunState: timers: Dict[str, int] = field(default_factory=dict) data: Dict[str, Any] = field(default_factory=dict) _token: Any = field(default=None, repr=False) + _col_token: Any = field(default=None, repr=False) + _span_snapshot: Any = field(default=None, repr=False) _current_run: ContextVar[Optional[RunState]] = ContextVar("_current_run", default=None) diff --git a/src/layerlens/instrument/_decorator.py b/src/layerlens/instrument/_decorator.py index b4a118c..6f76f37 100644 --- a/src/layerlens/instrument/_decorator.py +++ b/src/layerlens/instrument/_decorator.py @@ -46,7 +46,7 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: span_id=root_span_id, span_name=span_name, ) - await collector.async_flush() + collector.flush() return result except Exception as exc: collector.emit( @@ -55,7 +55,7 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: span_id=root_span_id, span_name=span_name, ) - await collector.async_flush() + collector.flush() raise finally: _pop_span(span_snapshot) diff --git a/src/layerlens/instrument/_upload.py b/src/layerlens/instrument/_upload.py index ae42048..ff471a8 100644 --- a/src/layerlens/instrument/_upload.py +++ b/src/layerlens/instrument/_upload.py @@ -1,67 +1,209 @@ from __future__ import annotations +import atexit import os import json +import queue import time -import asyncio import logging import tempfile import threading -from typing import Any, Dict +from typing import Any, Dict, Optional, Tuple log: logging.Logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- -# Circuit breaker +# Per-client upload channel # --------------------------------------------------------------------------- -_lock = threading.Lock() -_error_count = 0 -_circuit_open = False -_opened_at: float = 0.0 - -_THRESHOLD = 10 -_COOLDOWN_S = 60.0 +class UploadChannel: + """Per-client upload state: circuit breaker + background worker + queue. + + Each ``client`` gets its own channel so that a failing backend A + doesn't trip the breaker for a healthy backend B. + """ + + _THRESHOLD = 10 + _COOLDOWN_S = 60.0 + + def __init__(self) -> None: + self._lock = threading.Lock() + self._error_count = 0 + self._circuit_open = False + self._opened_at: float = 0.0 + self._queue: queue.Queue[Optional[Tuple[Any, Dict[str, Any]]]] = queue.Queue(maxsize=64) + self._worker: Optional[threading.Thread] = None + + # -- Circuit breaker -- + + def _allow(self) -> bool: + with self._lock: + if not self._circuit_open: + return True + if time.monotonic() - self._opened_at >= self._COOLDOWN_S: + self._circuit_open = False + self._error_count = 0 + log.info("layerlens: upload circuit breaker half-open, retrying") + return True + return False + + def _on_success(self) -> None: + with self._lock: + if self._error_count > 0: + self._error_count = 0 + self._circuit_open = False + + def _on_failure(self) -> None: + with self._lock: + self._error_count += 1 + if self._error_count >= self._THRESHOLD and not self._circuit_open: + self._circuit_open = True + self._opened_at = time.monotonic() + log.warning( + "layerlens: upload circuit breaker OPEN after %d errors (cooldown %.0fs)", + self._error_count, + self._COOLDOWN_S, + ) + + # -- Worker thread -- + + def _worker_loop(self) -> None: + while True: + item = self._queue.get() + if item is None: + break + client, payload = item + if not self._allow(): + continue + path = _write_trace_file(payload) + try: + client.traces.upload(path) + self._on_success() + except Exception: + self._on_failure() + log.warning("layerlens: background trace upload failed", exc_info=True) + finally: + try: + os.unlink(path) + except OSError: + log.debug("Failed to remove temp trace file: %s", path) + + def _ensure_worker(self) -> None: + if self._worker is not None and self._worker.is_alive(): + return + with self._lock: + if self._worker is not None and self._worker.is_alive(): + return + self._worker = threading.Thread( + target=self._worker_loop, daemon=True, name="layerlens-upload", + ) + self._worker.start() -def _allow() -> bool: - global _circuit_open, _error_count - with _lock: - if not _circuit_open: + def enqueue(self, client: Any, payload: Dict[str, Any]) -> bool: + """Enqueue a trace for background upload. Returns False if dropped.""" + if _sync_mode: + self._upload_sync(client, payload) return True - if time.monotonic() - _opened_at >= _COOLDOWN_S: - _circuit_open = False - _error_count = 0 - log.info("layerlens: upload circuit breaker half-open, retrying") + if not self._allow(): + return False + self._ensure_worker() + try: + self._queue.put_nowait((client, payload)) return True - return False - - -def _on_success() -> None: - global _error_count, _circuit_open - with _lock: - if _error_count > 0: - _error_count = 0 - _circuit_open = False - - -def _on_failure() -> None: - global _error_count, _circuit_open, _opened_at - with _lock: - _error_count += 1 - if _error_count >= _THRESHOLD and not _circuit_open: - _circuit_open = True - _opened_at = time.monotonic() - log.warning( - "layerlens: upload circuit breaker OPEN after %d errors (cooldown %.0fs)", - _error_count, - _COOLDOWN_S, - ) + except queue.Full: + log.warning("layerlens: upload queue full, dropping trace %s", payload.get("trace_id", "?")) + return False + + def _upload_sync(self, client: Any, payload: Dict[str, Any]) -> None: + """Synchronous upload (used in tests).""" + if not self._allow(): + return + path = _write_trace_file(payload) + try: + client.traces.upload(path) + self._on_success() + except Exception: + self._on_failure() + log.warning("layerlens: trace upload failed", exc_info=True) + finally: + try: + os.unlink(path) + except OSError: + log.debug("Failed to remove temp trace file: %s", path) + + def shutdown(self, timeout: float = 5.0) -> None: + """Drain the queue and stop the worker thread.""" + if self._worker is None or not self._worker.is_alive(): + return + try: + self._queue.put_nowait(None) + except queue.Full: + pass + self._worker.join(timeout) + self._worker = None # --------------------------------------------------------------------------- -# Upload +# Channel registry (one per client) +# --------------------------------------------------------------------------- + +_ATTR = "_layerlens_upload_channel" +_channels: list[UploadChannel] = [] # keeps refs for shutdown_uploads +_registry_lock = threading.Lock() + + +def _get_channel(client: Any) -> UploadChannel: + """Return (or create) the UploadChannel for *client*. + + The channel is stored directly on the client object so that identity + is tied to the object's lifetime, not its ``id()`` (which can be + reused after garbage collection). + """ + ch = getattr(client, _ATTR, None) + if isinstance(ch, UploadChannel): + return ch + with _registry_lock: + # Double-check under lock + ch = getattr(client, _ATTR, None) + if isinstance(ch, UploadChannel): + return ch + ch = UploadChannel() + try: + object.__setattr__(client, _ATTR, ch) + except (AttributeError, TypeError): + # Frozen / slotted objects — fall back to a side dict + pass + _channels.append(ch) + return ch + + +# --------------------------------------------------------------------------- +# Public API (used by TraceCollector) +# --------------------------------------------------------------------------- + +_sync_mode = False + + +def enqueue_upload(client: Any, payload: Dict[str, Any]) -> bool: + """Enqueue a trace for background upload via the client's channel.""" + return _get_channel(client).enqueue(client, payload) + + +def shutdown_uploads(timeout: float = 5.0) -> None: + """Shut down all upload channels.""" + with _registry_lock: + channels = list(_channels) + for ch in channels: + ch.shutdown(timeout) + + +atexit.register(shutdown_uploads) + + +# --------------------------------------------------------------------------- +# Helpers # --------------------------------------------------------------------------- @@ -73,34 +215,5 @@ def _write_trace_file(payload: Dict[str, Any]) -> str: def upload_trace(client: Any, payload: Dict[str, Any]) -> None: - if not _allow(): - return - path = _write_trace_file(payload) - try: - client.traces.upload(path) - _on_success() - except Exception: - _on_failure() - log.warning("layerlens: trace upload failed", exc_info=True) - finally: - try: - os.unlink(path) - except OSError: - log.debug("Failed to remove temp trace file: %s", path) - - -async def async_upload_trace(client: Any, payload: Dict[str, Any]) -> None: - if not _allow(): - return - path = await asyncio.to_thread(_write_trace_file, payload) - try: - await client.traces.upload(path) - _on_success() - except Exception: - _on_failure() - log.warning("layerlens: async trace upload failed", exc_info=True) - finally: - try: - os.unlink(path) - except OSError: - log.debug("Failed to remove temp trace file: %s", path) + """Synchronous upload (testing convenience).""" + _get_channel(client)._upload_sync(client, payload) diff --git a/src/layerlens/instrument/adapters/frameworks/_base_framework.py b/src/layerlens/instrument/adapters/frameworks/_base_framework.py index 8190510..f933d1c 100644 --- a/src/layerlens/instrument/adapters/frameworks/_base_framework.py +++ b/src/layerlens/instrument/adapters/frameworks/_base_framework.py @@ -9,24 +9,34 @@ import uuid import logging import threading -from typing import Any, Dict, Generator, Optional -from contextlib import contextmanager +from typing import Any, Dict, Optional from .._base import AdapterInfo, BaseAdapter from ..._collector import TraceCollector from ..._capture_config import CaptureConfig -from ..._context import _current_collector, _current_span_id, _push_span, _pop_span, _current_run, RunState +from ..._context import ( + _current_collector, + _current_span_id, + _push_span, + _pop_span, + _current_run, + RunState, +) log = logging.getLogger(__name__) -_UNSET: Any = object() # sentinel: distinguish "not passed" from explicit None - class FrameworkAdapter(BaseAdapter): - """Base for framework adapters with collector lifecycle management.""" + """Base for framework adapters with collector lifecycle management. + + Every adapter call that produces events MUST be inside a + ``_begin_run`` / ``_end_run`` pair. ``_begin_run`` pushes the + collector and root span into ContextVars so provider adapters + can see it automatically. + """ - name: str # Subclass must set: "crewai", "llamaindex", etc. - package: str = "" # pip extra name, e.g. "crewai" → pip install layerlens[crewai] + name: str # Subclass must set: "langchain", "pydantic-ai", etc. + package: str = "" # pip extra name, e.g. "semantic-kernel" def _check_dependency(self, available: bool) -> None: """Raise ImportError with a helpful install message if the dependency is missing.""" @@ -42,15 +52,6 @@ def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) self._config = capture_config or CaptureConfig.standard() self._lock = threading.Lock() self._connected = False - self._collector: Optional[TraceCollector] = None - self._root_span_id: Optional[str] = None - self._using_shared_collector = False - # Optional run_id → span_id mapping for callback-style frameworks - self._span_ids: Dict[str, str] = {} - # Root run tracking for auto-flush on outermost callback completion - self._root_run_id: Optional[str] = None - # Timing: key → start_ns for _start_timer / _stop_timer - self._timers: Dict[str, int] = {} # Subclasses populate during connect() for adapter_info() metadata self._metadata: Dict[str, Any] = {} @@ -61,67 +62,70 @@ def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) def _begin_run(self) -> RunState: """Start a new run with its own collector, root span, and timers. - Stores the RunState in a ContextVar so all subsequent calls to - ``_ensure_collector``, ``_start_timer``, ``_stop_timer``, and - ``_get_root_span`` use per-run state instead of instance state. + Pushes the collector and root span into ContextVars so that: + - Subsequent ``_emit`` calls route to this run's collector + - Provider adapters see the collector via ``_current_collector`` + - ContextVars are automatically isolated per ``asyncio.Task`` - ContextVars are automatically isolated per ``asyncio.Task``, so - concurrent runs on the same adapter get independent state. + If called inside an existing ``trace_context()``, reuses the + shared collector instead of creating a new one. """ + existing = _current_collector.get() + if existing is not None: + collector = existing + col_token = None + else: + collector = TraceCollector(self._client, self._config) + col_token = _current_collector.set(collector) + + root_span_id = uuid.uuid4().hex[:16] + span_snapshot = _push_span(root_span_id, f"{self.name}:root") + run = RunState( - collector=TraceCollector(self._client, self._config), - root_span_id=uuid.uuid4().hex[:16], + collector=collector, + root_span_id=root_span_id, + _token=None, + _col_token=col_token, + _span_snapshot=span_snapshot, ) run._token = _current_run.set(run) return run def _end_run(self) -> None: - """Flush the current run's collector and restore the previous ContextVar state.""" + """Pop ContextVars and flush the collector.""" run = _current_run.get() if run is None: return + + # Restore ContextVars — use try/except for each because + # frameworks like PydanticAI can copy contexts between hook + # callbacks, making tokens invalid in the current Context. + if run._span_snapshot is not None: + try: + _pop_span(run._span_snapshot) + except ValueError: + pass + if run._col_token is not None: + try: + _current_collector.reset(run._col_token) + except ValueError: + _current_collector.set(None) if run._token is not None: try: _current_run.reset(run._token) except ValueError: - # Token created in a different Context (e.g. framework copies - # contexts between hook callbacks). Fall back to plain set. _current_run.set(None) else: _current_run.set(None) - run.collector.flush() + + # Only flush if we own the collector (not shared from trace_context) + if run._col_token is not None: + run.collector.flush() def _get_run(self) -> Optional[RunState]: """Return the current RunState, or None if not inside a ``_begin_run`` scope.""" return _current_run.get() - # ------------------------------------------------------------------ - # Collector lifecycle - # ------------------------------------------------------------------ - - def _ensure_collector(self) -> TraceCollector: - """Return the collector for the current context. - - Checks (in order): active RunState, shared collector from ContextVars, - then creates a private instance-level collector as fallback. - """ - run = _current_run.get() - if run is not None: - return run.collector - - shared = _current_collector.get() - if shared is not None: - self._using_shared_collector = True - if self._root_span_id is None: - self._root_span_id = _current_span_id.get() - return shared - - if self._collector is None: - self._using_shared_collector = False - self._collector = TraceCollector(self._client, self._config) - self._root_span_id = uuid.uuid4().hex[:16] - return self._collector - @staticmethod def _new_span_id() -> str: return uuid.uuid4().hex[:16] @@ -131,12 +135,7 @@ def _new_span_id() -> str: # ------------------------------------------------------------------ def _payload(self, **extra: Any) -> Dict[str, Any]: - """Start a payload dict with ``framework: self.name``. - - Usage:: - - payload = self._payload(agent_name="foo", status="ok") - """ + """Start a payload dict with ``framework: self.name``.""" p: Dict[str, Any] = {"framework": self.name} if extra: p.update(extra) @@ -145,30 +144,20 @@ def _payload(self, **extra: Any) -> Dict[str, Any]: def _get_root_span(self) -> str: """Return the root span ID for the current run. - Checks RunState first, then falls back to instance-level ``_root_span_id``. - If neither is set, generates a new one. + Returns a new random span ID if no run is active — callers should + only call this inside a ``_begin_run`` scope. """ run = _current_run.get() if run is not None: return run.root_span_id - - with self._lock: - sid = self._root_span_id - if sid is not None: - return sid - sid = self._new_span_id() - with self._lock: - self._root_span_id = sid - return sid + log.debug("layerlens: _get_root_span called outside _begin_run scope") + return self._new_span_id() def _start_timer(self, key: str) -> None: """Record a start timestamp (nanoseconds) under *key*.""" run = _current_run.get() if run is not None: run.timers[key] = time.time_ns() - return - with self._lock: - self._timers[key] = time.time_ns() def _stop_timer(self, key: str) -> Optional[float]: """Pop the start time for *key* and return elapsed ``latency_ms``, or ``None``.""" @@ -176,8 +165,7 @@ def _stop_timer(self, key: str) -> Optional[float]: if run is not None: start_ns = run.timers.pop(key, 0) else: - with self._lock: - start_ns = self._timers.pop(key, 0) + start_ns = 0 if not start_ns: return None return (time.time_ns() - start_ns) / 1_000_000 @@ -187,29 +175,32 @@ def _normalize_tokens(usage: Any) -> Dict[str, Any]: """Extract token counts from any usage object or dict. Handles field-name variants across providers: - ``prompt_tokens`` / ``input_tokens`` → ``tokens_prompt`` - ``completion_tokens`` / ``output_tokens`` → ``tokens_completion`` + ``prompt_tokens`` / ``input_tokens`` -> ``tokens_prompt`` + ``completion_tokens`` / ``output_tokens`` -> ``tokens_completion`` Returns a dict with ``tokens_prompt``, ``tokens_completion``, - ``tokens_total`` — only keys that have non-zero values. + ``tokens_total`` -- only keys that have non-zero values. + Returns empty dict when all values are zero. """ tokens: Dict[str, Any] = {} if usage is None: return tokens if isinstance(usage, dict): - prompt = usage.get("prompt_tokens") or usage.get("input_tokens") - completion = usage.get("completion_tokens") or usage.get("output_tokens") + prompt = usage.get("prompt_tokens") + if prompt is None: + prompt = usage.get("input_tokens") + completion = usage.get("completion_tokens") + if completion is None: + completion = usage.get("output_tokens") total = usage.get("total_tokens") else: - prompt = ( - getattr(usage, "prompt_tokens", None) - or getattr(usage, "input_tokens", None) - ) - completion = ( - getattr(usage, "completion_tokens", None) - or getattr(usage, "output_tokens", None) - ) + prompt = getattr(usage, "prompt_tokens", None) + if prompt is None: + prompt = getattr(usage, "input_tokens", None) + completion = getattr(usage, "completion_tokens", None) + if completion is None: + completion = getattr(usage, "output_tokens", None) total = getattr(usage, "total_tokens", None) if prompt is not None: @@ -220,6 +211,10 @@ def _normalize_tokens(usage: Any) -> Dict[str, Any]: tokens["tokens_total"] = int(prompt) + int(completion) elif total is not None: tokens["tokens_total"] = int(total) + + # Strip all-zero results so callers can use ``if tokens:`` + if tokens and not any(tokens.values()): + return {} return tokens def _set_if_capturing(self, payload: Dict[str, Any], key: str, value: Any) -> None: @@ -228,64 +223,7 @@ def _set_if_capturing(self, payload: Dict[str, Any], key: str, value: Any) -> No payload[key] = value # ------------------------------------------------------------------ - # Callback scope — bridges framework callbacks to ContextVars - # ------------------------------------------------------------------ - - def _push_context(self, span_id: str, span_name: Optional[str] = None) -> Any: - """Push collector + span into ContextVars. Returns an opaque token for ``_pop_context``.""" - with self._lock: - collector = self._ensure_collector() - needs_collector_push = _current_collector.get() is None - col_token = _current_collector.set(collector) if needs_collector_push else None - snapshot = _push_span(span_id, span_name) - return (snapshot, col_token) - - def _pop_context(self, token: Any) -> None: - """Restore ContextVars from a token returned by ``_push_context``.""" - if token is None: - return - snapshot, col_token = token - _pop_span(snapshot) - if col_token is not None: - _current_collector.reset(col_token) - - @contextmanager - def _callback_scope( - self, - span_name: Optional[str] = None, - ) -> Generator[str, None, None]: - """Push collector + new span into ContextVars; yields the span_id.""" - span_id = self._new_span_id() - token = self._push_context(span_id, span_name) - try: - yield span_id - finally: - self._pop_context(token) - - def _traced_call( - self, - original: Any, - *args: Any, - _span_name: Optional[str] = None, - **kwargs: Any, - ) -> Any: - """Call *original* inside a _callback_scope so providers see this collector.""" - with self._callback_scope(_span_name): - return original(*args, **kwargs) - - async def _async_traced_call( - self, - original: Any, - *args: Any, - _span_name: Optional[str] = None, - **kwargs: Any, - ) -> Any: - """Async version of _traced_call.""" - with self._callback_scope(_span_name): - return await original(*args, **kwargs) - - # ------------------------------------------------------------------ - # Event emission (thread-safe) + # Event emission # ------------------------------------------------------------------ def _emit( @@ -293,61 +231,48 @@ def _emit( event_type: str, payload: Dict[str, Any], span_id: Optional[str] = None, - parent_span_id: Any = _UNSET, + parent_span_id: Optional[str] = None, span_name: Optional[str] = None, run_id: Any = None, parent_run_id: Any = None, ) -> None: - """Thread-safe event emission through the collector. - - When *run_id* is provided, it is translated to a span_id via - ``_span_id_for`` and the first run_id seen is tracked as the root - (for flush-on-completion in callback-style frameworks). + """Emit an event into the active collector. - When *parent_span_id* is omitted, falls back to ``_root_span_id``. - Pass ``parent_span_id=None`` explicitly to emit with no parent - (for adapters that manage their own span hierarchy). + Single path: reads ``_current_collector``. If there's also a + RunState, uses it for run_id mapping and root_span_id fallback. + No-op when no collector is active. """ - # RunState path: per-run isolation, no lock needed - run = _current_run.get() - if run is not None: - if run_id is not None: - span_id, parent_span_id = self._span_id_for(run_id, parent_run_id) - sid = span_id or self._new_span_id() - parent = run.root_span_id if parent_span_id is _UNSET else parent_span_id - run.collector.emit( - event_type, payload, - span_id=sid, parent_span_id=parent, span_name=span_name, - ) + collector = _current_collector.get() + if collector is None: return - # Legacy path: instance-level state with lock - if run_id is not None: + run = _current_run.get() + + if run_id is not None and run is not None: span_id, parent_span_id = self._span_id_for(run_id, parent_run_id) - if self._root_run_id is None: - self._root_run_id = str(run_id) - with self._lock: - collector = self._ensure_collector() - sid = span_id or self._new_span_id() - parent = self._root_span_id if parent_span_id is _UNSET else parent_span_id - collector.emit( - event_type, payload, - span_id=sid, parent_span_id=parent, span_name=span_name, - ) + + sid = span_id or self._new_span_id() + if parent_span_id is None: + parent_span_id = run.root_span_id if run is not None else _current_span_id.get() + + collector.emit( + event_type, payload, + span_id=sid, parent_span_id=parent_span_id, span_name=span_name, + ) # ------------------------------------------------------------------ - # Run ID → span ID mapping (opt-in for callback-style frameworks) + # Run ID -> span ID mapping (for callback-style frameworks) # ------------------------------------------------------------------ def _span_id_for(self, run_id: Any, parent_run_id: Any = None) -> tuple[str, Optional[str]]: """Map a framework run_id to a (span_id, parent_span_id) pair. - When a RunState is active, span_ids are stored per-run in - ``run.data["span_ids"]`` for concurrent-run isolation. - Falls back to instance-level ``_span_ids`` otherwise. + Span IDs are stored per-run in ``run.data["span_ids"]``. """ run = _current_run.get() - span_ids = run.data.setdefault("span_ids", {}) if run is not None else self._span_ids + if run is None: + return self._new_span_id(), None + span_ids = run.data.setdefault("span_ids", {}) rid = str(run_id) if rid not in span_ids: span_ids[rid] = self._new_span_id() @@ -355,22 +280,6 @@ def _span_id_for(self, run_id: Any, parent_run_id: Any = None) -> tuple[str, Opt parent_span_id = span_ids.get(str(parent_run_id)) if parent_run_id else None return span_id, parent_span_id - # ------------------------------------------------------------------ - # Flush - # ------------------------------------------------------------------ - - def _flush_collector(self) -> None: - """Flush private collector (no-op for shared collectors).""" - with self._lock: - collector = self._collector - is_shared = self._using_shared_collector - self._collector = None - self._root_span_id = None - self._using_shared_collector = False - self._span_ids.clear() - if collector is not None and not is_shared: - collector.flush() - # ------------------------------------------------------------------ # BaseAdapter interface # ------------------------------------------------------------------ @@ -386,9 +295,8 @@ def _on_connect(self, target: Any = None, **kwargs: Any) -> None: pass def disconnect(self) -> None: - """Clean up framework resources, flush events, and mark as disconnected.""" + """Clean up framework resources and mark as disconnected.""" self._on_disconnect() - self._flush_collector() self._connected = False self._metadata.clear() diff --git a/src/layerlens/instrument/adapters/frameworks/crewai.py b/src/layerlens/instrument/adapters/frameworks/crewai.py index b922748..f80f70c 100644 --- a/src/layerlens/instrument/adapters/frameworks/crewai.py +++ b/src/layerlens/instrument/adapters/frameworks/crewai.py @@ -3,9 +3,9 @@ import logging from typing import Any, Dict, Optional -from ._base_framework import FrameworkAdapter -from ._utils import safe_serialize -from ..._capture_config import CaptureConfig +from .._base_framework import FrameworkAdapter +from .._utils import safe_serialize +from ...._capture_config import CaptureConfig log = logging.getLogger(__name__) diff --git a/src/layerlens/instrument/adapters/frameworks/langchain.py b/src/layerlens/instrument/adapters/frameworks/langchain.py index a69a7d6..79f3990 100644 --- a/src/layerlens/instrument/adapters/frameworks/langchain.py +++ b/src/layerlens/instrument/adapters/frameworks/langchain.py @@ -14,12 +14,8 @@ def _auto_flush(fn): # type: ignore[type-arg] def wrapper(self, *args, run_id, **kwargs): # type: ignore[no-untyped-def] fn(self, *args, run_id=run_id, **kwargs) run = self._get_run() - if run is not None: - if str(run_id) == run.data.get("root_run_id"): - self._end_run() - elif str(run_id) == self._root_run_id and self._collector is not None: - self._flush_collector() - self._root_run_id = None + if run is not None and str(run_id) == run.data.get("root_run_id"): + self._end_run() return wrapper @@ -43,8 +39,6 @@ def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) FrameworkAdapter.__init__(self, client, capture_config=capture_config) # Pending LLM runs: run_id -> {name, messages, parent_run_id} self._pending_llm: Dict[str, Dict[str, Any]] = {} - # Context tokens for span propagation: run_id -> token from _push_context - self._run_contexts: Dict[str, Any] = {} # ------------------------------------------------------------------ # Chain callbacks @@ -79,7 +73,7 @@ def on_chain_end( ) -> None: payload = self._payload(status="ok") self._set_if_capturing(payload, "output", outputs) - self._emit("agent.output", payload, run_id=run_id) + self._emit("agent.output", payload, run_id=run_id, parent_run_id=parent_run_id) @_auto_flush def on_chain_error( @@ -90,7 +84,7 @@ def on_chain_error( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit("agent.error", self._payload(error=str(error), status="error"), run_id=run_id) + self._emit("agent.error", self._payload(error=str(error), status="error"), run_id=run_id, parent_run_id=parent_run_id) # ------------------------------------------------------------------ # LLM callbacks — merged into single model.invoke on end @@ -114,8 +108,6 @@ def on_llm_start( } self._set_if_capturing(pending, "messages", prompts) self._pending_llm[str(run_id)] = pending - span_id, _ = self._span_id_for(run_id) - self._run_contexts[str(run_id)] = self._push_context(span_id) def on_chat_model_start( self, @@ -138,8 +130,6 @@ def on_chat_model_start( [[_serialize_lc_message(m) for m in batch] for batch in messages], ) self._pending_llm[str(run_id)] = pending - span_id, _ = self._span_id_for(run_id) - self._run_contexts[str(run_id)] = self._push_context(span_id) @_auto_flush def on_llm_end( @@ -150,7 +140,6 @@ def on_llm_end( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._pop_context(self._run_contexts.pop(str(run_id), None)) pending = self._pending_llm.pop(str(run_id), {}) # Extract response data @@ -210,7 +199,6 @@ def on_llm_error( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._pop_context(self._run_contexts.pop(str(run_id), None)) pending = self._pending_llm.pop(str(run_id), {}) payload = self._payload(error=str(error)) @@ -221,7 +209,7 @@ def on_llm_error( payload["latency_ms"] = latency_ms self._emit("model.invoke", payload, run_id=run_id, parent_run_id=pending.get("parent_run_id")) - self._emit("agent.error", self._payload(error=str(error), status="error"), run_id=run_id) + self._emit("agent.error", self._payload(error=str(error), status="error"), run_id=run_id, parent_run_id=pending.get("parent_run_id")) # ------------------------------------------------------------------ # Tool callbacks @@ -252,7 +240,7 @@ def on_tool_end( ) -> None: payload = self._payload() self._set_if_capturing(payload, "output", output) - self._emit("tool.result", payload, run_id=run_id) + self._emit("tool.result", payload, run_id=run_id, parent_run_id=parent_run_id) @_auto_flush def on_tool_error( @@ -263,7 +251,7 @@ def on_tool_error( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit("agent.error", self._payload(error=str(error), status="error"), run_id=run_id) + self._emit("agent.error", self._payload(error=str(error), status="error"), run_id=run_id, parent_run_id=parent_run_id) # ------------------------------------------------------------------ # Retriever callbacks @@ -297,7 +285,7 @@ def on_retriever_end( payload, "output", [_serialize_lc_document(d) for d in documents], ) - self._emit("tool.result", payload, run_id=run_id) + self._emit("tool.result", payload, run_id=run_id, parent_run_id=parent_run_id) @_auto_flush def on_retriever_error( @@ -308,7 +296,7 @@ def on_retriever_error( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit("agent.error", self._payload(error=str(error), status="error"), run_id=run_id) + self._emit("agent.error", self._payload(error=str(error), status="error"), run_id=run_id, parent_run_id=parent_run_id) # ------------------------------------------------------------------ # Agent callbacks diff --git a/src/layerlens/instrument/adapters/frameworks/openai_agents.py b/src/layerlens/instrument/adapters/frameworks/openai_agents.py index e175c34..73f28f0 100644 --- a/src/layerlens/instrument/adapters/frameworks/openai_agents.py +++ b/src/layerlens/instrument/adapters/frameworks/openai_agents.py @@ -6,8 +6,9 @@ from ._base_framework import FrameworkAdapter from ._utils import safe_serialize -from ..._collector import TraceCollector from ..._capture_config import CaptureConfig +from ..._collector import TraceCollector +from ..._context import _current_collector, _current_run, RunState log = logging.getLogger(__name__) @@ -30,9 +31,9 @@ class OpenAIAgentsAdapter(_Base, FrameworkAdapter): to receive all span lifecycle events, then maps agent, generation, function, handoff, and guardrail spans to flat layerlens events. - Unlike other adapters that use a single collector, this adapter manages - per-trace collectors because the SDK can run multiple concurrent traces - through the same global processor. + Each trace gets its own RunState created directly (bypassing + ``_begin_run``, which would pollute ContextVars for other traces), + stored per-trace in ``_trace_runs`` keyed by trace_id. Usage:: @@ -56,7 +57,8 @@ class OpenAIAgentsAdapter(_Base, FrameworkAdapter): def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: FrameworkAdapter.__init__(self, client, capture_config) - self._collectors: Dict[str, TraceCollector] = {} + # trace_id -> RunState for concurrent trace isolation + self._trace_runs: Dict[str, Any] = {} # ------------------------------------------------------------------ # Lifecycle @@ -73,7 +75,7 @@ def _on_disconnect(self) -> None: set_trace_processors([]) with self._lock: - self._collectors.clear() + self._trace_runs.clear() # ------------------------------------------------------------------ # TracingProcessor interface @@ -81,16 +83,24 @@ def _on_disconnect(self) -> None: def on_trace_start(self, trace: Any) -> None: try: - self._get_collector(trace.trace_id) + # OA manages multiple concurrent traces from one processor, + # so we create RunState directly instead of using _begin_run + # (which would pollute ContextVars for the next trace). + run = RunState( + collector=TraceCollector(self._client, self._config), + root_span_id=self._new_span_id(), + ) + with self._lock: + self._trace_runs[trace.trace_id] = run except Exception: log.warning("layerlens: error in on_trace_start", exc_info=True) def on_trace_end(self, trace: Any) -> None: try: with self._lock: - collector = self._collectors.pop(trace.trace_id, None) - if collector is not None: - collector.flush() + run = self._trace_runs.pop(trace.trace_id, None) + if run is not None: + run.collector.flush() except Exception: log.warning("layerlens: error in on_trace_end", exc_info=True) @@ -99,10 +109,22 @@ def on_span_start(self, span: Any) -> None: def on_span_end(self, span: Any) -> None: try: - span_type = getattr(span.span_data, "type", None) or "" - handler_name = self._SPAN_HANDLERS.get(span_type) - if handler_name is not None: - getattr(self, handler_name)(span) + with self._lock: + run = self._trace_runs.get(span.trace_id) + if run is None: + return + + # Temporarily set both ContextVars so _emit and providers work. + run_token = _current_run.set(run) + col_token = _current_collector.set(run.collector) + try: + span_type = getattr(span.span_data, "type", None) or "" + handler_name = self._SPAN_HANDLERS.get(span_type) + if handler_name is not None: + getattr(self, handler_name)(span) + finally: + _current_collector.reset(col_token) + _current_run.reset(run_token) except Exception: log.warning("layerlens: error handling OpenAI Agents span", exc_info=True) @@ -112,23 +134,12 @@ def shutdown(self) -> None: def force_flush(self) -> None: pass - # ------------------------------------------------------------------ - # Per-trace collector - # ------------------------------------------------------------------ - - def _get_collector(self, trace_id: str) -> TraceCollector: - with self._lock: - if trace_id not in self._collectors: - self._collectors[trace_id] = TraceCollector(self._client, self._config) - return self._collectors[trace_id] - # ------------------------------------------------------------------ # Span handlers # ------------------------------------------------------------------ def _handle_agent_span(self, span: Any) -> None: data = span.span_data - collector = self._get_collector(span.trace_id) agent_name = getattr(data, "name", "unknown") span_id = span.span_id or self._new_span_id() parent_id = span.parent_id @@ -139,7 +150,7 @@ def _handle_agent_span(self, span: Any) -> None: if val: input_payload[key] = val - collector.emit( + self._emit( "agent.input", input_payload, span_id=span_id, parent_span_id=parent_id, span_name=f"agent:{agent_name}", @@ -156,7 +167,7 @@ def _handle_agent_span(self, span: Any) -> None: if span.error: out_payload["error"] = safe_serialize(span.error) - collector.emit( + self._emit( event_type, out_payload, span_id=span_id, parent_span_id=parent_id, span_name=f"agent:{agent_name}", @@ -164,7 +175,6 @@ def _handle_agent_span(self, span: Any) -> None: def _handle_generation_span(self, span: Any) -> None: data = span.span_data - collector = self._get_collector(span.trace_id) model = getattr(data, "model", None) or "unknown" span_id = span.span_id or self._new_span_id() parent_id = span.parent_id @@ -186,43 +196,46 @@ def _handle_generation_span(self, span: Any) -> None: if span.error: payload["error"] = safe_serialize(span.error) - collector.emit("agent.error", payload, span_id=span_id, parent_span_id=parent_id) + self._emit("agent.error", payload, span_id=span_id, parent_span_id=parent_id) else: - collector.emit("model.invoke", payload, span_id=span_id, parent_span_id=parent_id) + self._emit("model.invoke", payload, span_id=span_id, parent_span_id=parent_id) if tokens: cost_payload = self._payload(model=model) cost_payload.update(tokens) - collector.emit("cost.record", cost_payload, span_id=span_id, parent_span_id=parent_id) + self._emit("cost.record", cost_payload, span_id=span_id, parent_span_id=parent_id) def _handle_function_span(self, span: Any) -> None: data = span.span_data - collector = self._get_collector(span.trace_id) tool_name = getattr(data, "name", "unknown") span_id = span.span_id or self._new_span_id() parent_id = span.parent_id - payload = self._payload(tool_name=tool_name) - self._set_if_capturing(payload, "input", safe_serialize(getattr(data, "input", None))) - self._set_if_capturing(payload, "output", safe_serialize(getattr(data, "output", None))) - - duration_ms = _compute_duration_ms(span) - if duration_ms is not None: - payload["latency_ms"] = duration_ms - + # Emit tool.call with input + call_payload = self._payload(tool_name=tool_name) + self._set_if_capturing(call_payload, "input", safe_serialize(getattr(data, "input", None))) mcp_data = getattr(data, "mcp_data", None) if mcp_data: - payload["mcp_data"] = safe_serialize(mcp_data) + call_payload["mcp_data"] = safe_serialize(mcp_data) + self._emit("tool.call", call_payload, span_id=span_id, parent_span_id=parent_id) + # Emit tool.result or agent.error + duration_ms = _compute_duration_ms(span) if span.error: - payload["error"] = safe_serialize(span.error) - collector.emit("agent.error", payload, span_id=span_id, parent_span_id=parent_id) + err_payload = self._payload(tool_name=tool_name, error=safe_serialize(span.error)) + if duration_ms is not None: + err_payload["latency_ms"] = duration_ms + self._emit("agent.error", err_payload, span_id=span_id, parent_span_id=parent_id) else: - collector.emit("tool.call", payload, span_id=span_id, parent_span_id=parent_id) + result_payload = self._payload(tool_name=tool_name, status="ok") + self._set_if_capturing(result_payload, "output", safe_serialize(getattr(data, "output", None))) + if duration_ms is not None: + result_payload["latency_ms"] = duration_ms + self._emit("tool.result", result_payload, span_id=span_id, parent_span_id=parent_id) def _handle_handoff_span(self, span: Any) -> None: data = span.span_data - self._get_collector(span.trace_id).emit( + self._emit( "agent.handoff", self._payload( from_agent=getattr(data, "from_agent", None) or "unknown", @@ -234,7 +247,7 @@ def _handle_handoff_span(self, span: Any) -> None: def _handle_guardrail_span(self, span: Any) -> None: data = span.span_data - self._get_collector(span.trace_id).emit( + self._emit( "evaluation.result", self._payload( guardrail_name=getattr(data, "name", "unknown"), @@ -250,7 +263,6 @@ def _handle_response_span(self, span: Any) -> None: if response is None: return - collector = self._get_collector(span.trace_id) span_id = span.span_id or self._new_span_id() parent_id = span.parent_id payload = self._payload() @@ -281,9 +293,9 @@ def _handle_response_span(self, span: Any) -> None: if span.error: payload["error"] = safe_serialize(span.error) - collector.emit("agent.error", payload, span_id=span_id, parent_span_id=parent_id) + self._emit("agent.error", payload, span_id=span_id, parent_span_id=parent_id) else: - collector.emit("model.invoke", payload, span_id=span_id, parent_span_id=parent_id) + self._emit("model.invoke", payload, span_id=span_id, parent_span_id=parent_id) # ------------------------------------------------------------------ diff --git a/src/layerlens/instrument/adapters/frameworks/pydantic_ai.py b/src/layerlens/instrument/adapters/frameworks/pydantic_ai.py index b5ae173..04e35f5 100644 --- a/src/layerlens/instrument/adapters/frameworks/pydantic_ai.py +++ b/src/layerlens/instrument/adapters/frameworks/pydantic_ai.py @@ -90,7 +90,8 @@ def _register_hooks(self, hooks: Any) -> None: # ------------------------------------------------------------------ def _on_before_run(self, ctx: Any) -> None: - run = self._begin_run() + self._begin_run() + root = self._get_root_span() agent_name = self._get_agent_name(ctx) model_name = self._get_model_name(ctx) @@ -99,19 +100,18 @@ def _on_before_run(self, ctx: Any) -> None: payload["model"] = model_name self._set_if_capturing(payload, "input", safe_serialize(ctx.prompt)) - run.collector.emit( + self._emit( "agent.input", payload, - span_id=run.root_span_id, parent_span_id=None, + span_id=root, parent_span_id=None, span_name=f"pydantic_ai:{agent_name}", ) self._start_timer("run") def _on_after_run(self, ctx: Any, *, result: Any) -> Any: latency_ms = self._stop_timer("run") + root = self._get_root_span() agent_name = self._get_agent_name(ctx) model_name = self._get_model_name(ctx) - root_span = self._get_root_span() - collector = self._ensure_collector() output = self._extract_output(result) usage = self._extract_usage(result) @@ -123,9 +123,9 @@ def _on_after_run(self, ctx: Any, *, result: Any) -> Any: payload["latency_ms"] = latency_ms self._set_if_capturing(payload, "output", output) payload.update(usage) - collector.emit( + self._emit( "agent.output", payload, - span_id=root_span, parent_span_id=None, + span_id=root, parent_span_id=None, span_name=f"pydantic_ai:{agent_name}", ) @@ -134,19 +134,15 @@ def _on_after_run(self, ctx: Any, *, result: Any) -> Any: if model_name: cost_payload["model"] = model_name cost_payload.update(usage) - collector.emit( - "cost.record", cost_payload, - span_id=self._new_span_id(), parent_span_id=root_span, - ) + self._emit("cost.record", cost_payload) self._end_run() return result def _on_run_error(self, ctx: Any, *, error: BaseException) -> None: latency_ms = self._stop_timer("run") + root = self._get_root_span() agent_name = self._get_agent_name(ctx) - root_span = self._get_root_span() - collector = self._ensure_collector() payload = self._payload( agent_name=agent_name, @@ -155,9 +151,9 @@ def _on_run_error(self, ctx: Any, *, error: BaseException) -> None: ) if latency_ms is not None: payload["latency_ms"] = latency_ms - collector.emit( + self._emit( "agent.error", payload, - span_id=root_span, parent_span_id=None, + span_id=root, parent_span_id=None, span_name=f"pydantic_ai:{agent_name}", ) @@ -171,9 +167,6 @@ def _on_run_error(self, ctx: Any, *, error: BaseException) -> None: def _on_after_model_request( self, ctx: Any, *, request_context: Any, response: Any, ) -> Any: - root_span = self._get_root_span() - collector = self._ensure_collector() - model_name = getattr(response, "model_name", None) usage = getattr(response, "usage", None) tokens = self._normalize_tokens(usage) @@ -183,11 +176,7 @@ def _on_after_model_request( payload["model"] = model_name payload.update(tokens) - model_span = self._new_span_id() - collector.emit( - "model.invoke", payload, - span_id=model_span, parent_span_id=root_span, - ) + self._emit("model.invoke", payload) parts = getattr(response, "parts", None) or [] for part in parts: @@ -198,27 +187,18 @@ def _on_after_model_request( tool_payload, "input", safe_serialize(getattr(part, "args", None)), ) - collector.emit( - "tool.call", tool_payload, - span_id=self._new_span_id(), parent_span_id=root_span, - ) + self._emit("tool.call", tool_payload) return response def _on_model_request_error( self, ctx: Any, *, request_context: Any, error: Exception, ) -> None: - root_span = self._get_root_span() - collector = self._ensure_collector() - payload = self._payload( error=str(error), error_type=type(error).__name__, ) - collector.emit( - "agent.error", payload, - span_id=self._new_span_id(), parent_span_id=root_span, - ) + self._emit("agent.error", payload) raise error # ------------------------------------------------------------------ @@ -229,58 +209,49 @@ def _on_before_tool_execute( self, ctx: Any, *, call: Any, tool_def: Any, args: Any, ) -> Any: tool_name = getattr(call, "tool_name", "unknown") + call_id = getattr(call, "id", None) or tool_name span_id = self._new_span_id() run = self._get_run() if run is not None: - run.data.setdefault("tool_spans", {})[tool_name] = span_id - self._start_timer(f"tool:{tool_name}") + run.data.setdefault("tool_spans", {})[call_id] = span_id + self._start_timer(f"tool:{call_id}") return args def _on_after_tool_execute( self, ctx: Any, *, call: Any, tool_def: Any, args: Any, result: Any, ) -> Any: tool_name = getattr(call, "tool_name", "unknown") - latency_ms = self._stop_timer(f"tool:{tool_name}") + call_id = getattr(call, "id", None) or tool_name + latency_ms = self._stop_timer(f"tool:{call_id}") run = self._get_run() tool_spans = run.data.get("tool_spans", {}) if run is not None else {} - span_id = tool_spans.pop(tool_name, self._new_span_id()) - - root_span = self._get_root_span() - collector = self._ensure_collector() + span_id = tool_spans.pop(call_id, self._new_span_id()) payload = self._payload(tool_name=tool_name) self._set_if_capturing(payload, "output", safe_serialize(result)) if latency_ms is not None: payload["latency_ms"] = latency_ms - collector.emit( - "tool.result", payload, - span_id=span_id, parent_span_id=root_span, - ) + self._emit("tool.result", payload, span_id=span_id) return result def _on_tool_execute_error( self, ctx: Any, *, call: Any, tool_def: Any, args: Any, error: Exception, ) -> None: tool_name = getattr(call, "tool_name", "unknown") - self._stop_timer(f"tool:{tool_name}") + call_id = getattr(call, "id", None) or tool_name + self._stop_timer(f"tool:{call_id}") run = self._get_run() if run is not None: - run.data.get("tool_spans", {}).pop(tool_name, None) - - root_span = self._get_root_span() - collector = self._ensure_collector() + run.data.get("tool_spans", {}).pop(call_id, None) payload = self._payload( tool_name=tool_name, error=str(error), error_type=type(error).__name__, ) - collector.emit( - "agent.error", payload, - span_id=self._new_span_id(), parent_span_id=root_span, - ) + self._emit("agent.error", payload) raise error # ------------------------------------------------------------------ diff --git a/src/layerlens/instrument/adapters/frameworks/semantic_kernel.py b/src/layerlens/instrument/adapters/frameworks/semantic_kernel.py index f02fecd..dd474b0 100644 --- a/src/layerlens/instrument/adapters/frameworks/semantic_kernel.py +++ b/src/layerlens/instrument/adapters/frameworks/semantic_kernel.py @@ -24,6 +24,11 @@ class SemanticKernelAdapter(FrameworkAdapter): invocation filters on a Kernel instance to capture plugin calls, prompt templates, and LLM-initiated function calls as flat events. + Uses a nesting depth counter to detect run boundaries: ``_begin_run`` + when the first (outermost) function invocation starts, ``_end_run`` + when it completes. Concurrent invocations on different asyncio tasks + are isolated via ContextVar-based RunState. + Usage:: adapter = SemanticKernelAdapter(client) @@ -40,7 +45,7 @@ def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) self._kernel: Any = None self._filter_ids: List[tuple] = [] # (FilterTypes, filter_id) for removal self._seen_plugins: set = set() - self._patched_services: Dict[str, Any] = {} # service_id → original method + self._patched_services: Dict[str, Any] = {} # service_id -> original method # ------------------------------------------------------------------ # Lifecycle @@ -84,6 +89,29 @@ def _on_disconnect(self) -> None: self._seen_plugins.clear() self._kernel = None + # ------------------------------------------------------------------ + # Run boundary tracking via nesting depth + # ------------------------------------------------------------------ + + def _enter_invocation(self) -> None: + """Increment depth; _begin_run on 0->1 transition.""" + run = self._get_run() + if run is None: + run = self._begin_run() + run.data["depth"] = 1 + else: + run.data["depth"] = run.data.get("depth", 0) + 1 + + def _leave_invocation(self) -> None: + """Decrement depth; _end_run on 1->0 transition.""" + run = self._get_run() + if run is None: + return + depth = run.data.get("depth", 1) - 1 + run.data["depth"] = depth + if depth <= 0: + self._end_run() + # ------------------------------------------------------------------ # LLM call wrapping # ------------------------------------------------------------------ @@ -102,9 +130,7 @@ def _patch_chat_services(self, kernel: Any) -> None: async def _traced_inner(chat_history: Any, settings: Any, _orig: Any = original, _svc: Any = service) -> Any: span_id = adapter._new_span_id() - root_span = adapter._get_root_span() adapter._start_timer(span_id) - collector = adapter._ensure_collector() model_name = getattr(_svc, "ai_model_id", None) @@ -120,10 +146,7 @@ async def _traced_inner(chat_history: Any, settings: Any, _orig: Any = original, payload["model"] = model_name if latency_ms is not None: payload["latency_ms"] = latency_ms - collector.emit( - "agent.error", payload, - span_id=span_id, parent_span_id=root_span, - ) + adapter._emit("agent.error", payload, span_id=span_id) raise latency_ms = adapter._stop_timer(span_id) @@ -135,20 +158,14 @@ async def _traced_inner(chat_history: Any, settings: Any, _orig: Any = original, if latency_ms is not None: payload["latency_ms"] = latency_ms payload.update(tokens) - collector.emit( - "model.invoke", payload, - span_id=span_id, parent_span_id=root_span, - ) + adapter._emit("model.invoke", payload, span_id=span_id) if tokens: cost_payload = adapter._payload() if model_name: cost_payload["model"] = model_name cost_payload.update(tokens) - collector.emit( - "cost.record", cost_payload, - span_id=span_id, parent_span_id=root_span, - ) + adapter._emit("cost.record", cost_payload, span_id=span_id) return result @@ -187,17 +204,23 @@ def _discover_plugins(self, kernel: Any) -> None: plugins = getattr(kernel, "plugins", None) if plugins is None: return - names = list(plugins.keys()) if hasattr(plugins, "keys") else [str(p) for p in plugins] - collector = self._ensure_collector() - for name in names: - if name not in self._seen_plugins: - self._seen_plugins.add(name) - collector.emit( - "environment.config", - self._payload(plugin_name=name, event_subtype="plugin_registered"), - span_id=self._new_span_id(), - parent_span_id=self._get_root_span(), - ) + # Need a run to emit events — start one temporarily if needed + owned_run = False + if self._get_run() is None: + self._begin_run() + owned_run = True + try: + names = list(plugins.keys()) if hasattr(plugins, "keys") else [str(p) for p in plugins] + for name in names: + if name not in self._seen_plugins: + self._seen_plugins.add(name) + self._emit( + "environment.config", + self._payload(plugin_name=name, event_subtype="plugin_registered"), + ) + finally: + if owned_run: + self._end_run() except Exception: log.debug("layerlens: error discovering SK plugins", exc_info=True) @@ -208,12 +231,9 @@ def _maybe_discover_plugin(self, plugin_name: str) -> None: if plugin_name in self._seen_plugins: return self._seen_plugins.add(plugin_name) - collector = self._ensure_collector() - collector.emit( + self._emit( "environment.config", self._payload(plugin_name=plugin_name, event_subtype="plugin_registered"), - span_id=self._new_span_id(), - parent_span_id=self._get_root_span(), ) # ------------------------------------------------------------------ @@ -229,9 +249,11 @@ async def _wrap_invocation( ) -> None: """Shared wrap-and-emit logic for function and auto-function filters. - Emits tool.call on start, tool.result on success (or agent.error on failure), - with timing. The ``auto_invoked`` flag adds LLM-specific metadata. + Manages run boundaries via depth counting: ``_begin_run`` on the + outermost invocation, ``_end_run`` when it completes. """ + self._enter_invocation() + plugin_name = _extract_plugin_name(context) function_name = _extract_function_name(context) tool_name = f"{plugin_name}.{function_name}" if plugin_name else function_name @@ -239,9 +261,7 @@ async def _wrap_invocation( self._maybe_discover_plugin(plugin_name) span_id = self._new_span_id() - root_span = self._get_root_span() self._start_timer(span_id) - collector = self._ensure_collector() # -- Emit tool.call (start) -- call_payload = self._payload( @@ -253,7 +273,6 @@ async def _wrap_invocation( call_payload["auto_invoked"] = True call_payload["request_sequence_index"] = getattr(context, "request_sequence_index", 0) call_payload["function_sequence_index"] = getattr(context, "function_sequence_index", 0) - # Auto-invoked: args come from the LLM's function_call_content call_content = getattr(context, "function_call_content", None) if call_content: self._set_if_capturing( @@ -261,16 +280,14 @@ async def _wrap_invocation( safe_serialize(getattr(call_content, "arguments", None)), ) else: - # User-invoked: args come from context.arguments self._set_if_capturing( call_payload, "input", safe_serialize(_extract_arguments(context)), ) - collector.emit( + self._emit( "tool.call", call_payload, - span_id=span_id, parent_span_id=root_span, - span_name=f"sk:{tool_name}", + span_id=span_id, span_name=f"sk:{tool_name}", ) # -- Execute -- @@ -293,12 +310,8 @@ async def _wrap_invocation( err_payload["auto_invoked"] = True if latency_ms is not None: err_payload["latency_ms"] = latency_ms - collector.emit( - "agent.error", err_payload, - span_id=span_id, parent_span_id=root_span, - ) + self._emit("agent.error", err_payload, span_id=span_id) else: - # Extract result from the appropriate field if auto_invoked: func_result = getattr(context, "function_result", None) else: @@ -314,12 +327,13 @@ async def _wrap_invocation( if latency_ms is not None: result_payload["latency_ms"] = latency_ms self._set_if_capturing(result_payload, "output", safe_serialize(result_value)) - collector.emit( + self._emit( "tool.result", result_payload, - span_id=span_id, parent_span_id=root_span, - span_name=f"sk:{tool_name}", + span_id=span_id, span_name=f"sk:{tool_name}", ) + self._leave_invocation() + # ------------------------------------------------------------------ # Filters # ------------------------------------------------------------------ @@ -339,11 +353,7 @@ async def _prompt_rendering_filter(self, context: Any, next: Any) -> None: if rendered and self._config.capture_content: payload["rendered_prompt"] = truncate(str(rendered), 2000) - collector = self._ensure_collector() - collector.emit( - "agent.code", payload, - span_id=self._new_span_id(), parent_span_id=self._get_root_span(), - ) + self._emit("agent.code", payload) async def _auto_function_invocation_filter(self, context: Any, next: Any) -> None: await self._wrap_invocation(context, next, auto_invoked=True) diff --git a/src/layerlens/instrument/adapters/providers/_base_provider.py b/src/layerlens/instrument/adapters/providers/_base_provider.py index a109c16..2d3c065 100644 --- a/src/layerlens/instrument/adapters/providers/_base_provider.py +++ b/src/layerlens/instrument/adapters/providers/_base_provider.py @@ -37,6 +37,7 @@ def _wrap_sync(self, event_name: str, original: Any) -> Any: def wrapped(*args: Any, **kwargs: Any) -> Any: if _current_collector.get() is None: + log.debug("layerlens.%s: no active trace context, passing through", event_name) return original(*args, **kwargs) start = time.time() try: @@ -61,6 +62,7 @@ def _wrap_async(self, event_name: str, original: Any) -> Any: async def wrapped(*args: Any, **kwargs: Any) -> Any: if _current_collector.get() is None: + log.debug("layerlens.%s: no active trace context, passing through", event_name) return await original(*args, **kwargs) start = time.time() try: diff --git a/tests/conftest.py b/tests/conftest.py index 16b8ed4..59c3a31 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,16 @@ import pytest +from layerlens.instrument import _upload + + +@pytest.fixture(autouse=True) +def _upload_sync_mode(): + """Force synchronous uploads in all tests so assertions don't race the worker thread.""" + _upload._sync_mode = True + yield + _upload._sync_mode = False + @pytest.fixture def env_vars(): diff --git a/tests/instrument/adapters/frameworks/test_concurrency.py b/tests/instrument/adapters/frameworks/test_concurrency.py new file mode 100644 index 0000000..29690e0 --- /dev/null +++ b/tests/instrument/adapters/frameworks/test_concurrency.py @@ -0,0 +1,93 @@ +"""Concurrency test: prove that RunState gives per-task isolation. + +Two asyncio.gather runs on the same PydanticAI adapter must produce +two separate traces with independent events and distinct trace_ids. +""" +from __future__ import annotations + +import asyncio +import json +from typing import Any, Dict, List + +import pytest + +pydantic_ai = pytest.importorskip("pydantic_ai") + +from pydantic_ai import Agent # noqa: E402 +from pydantic_ai.models.test import TestModel # noqa: E402 + +from layerlens.instrument.adapters.frameworks.pydantic_ai import PydanticAIAdapter # noqa: E402 + + +def _make_agent(output_text: str = "Hello!", tools: list | None = None) -> Agent: + agent = Agent( + model=TestModel(custom_output_text=output_text, model_name="test-model"), + name="test_agent", + ) + if tools: + for fn in tools: + agent.tool_plain(fn) + return agent + + +def _collect_traces(mock_client: Any) -> List[Dict[str, Any]]: + """Set up mock_client to accumulate individual trace payloads.""" + traces: List[Dict[str, Any]] = [] + + def _capture(path: str) -> None: + with open(path) as f: + data = json.load(f) + traces.append(data[0]) + + mock_client.traces.upload.side_effect = _capture + return traces + + +class TestConcurrentRunIsolation: + def test_concurrent_runs_produce_separate_traces(self, mock_client: Any) -> None: + """Two asyncio.gather runs on the same adapter → two distinct traces.""" + traces = _collect_traces(mock_client) + + def get_weather(city: str) -> str: + """Get weather for a city.""" + return f"72F in {city}" + + agent = _make_agent(output_text="done", tools=[get_weather]) + adapter = PydanticAIAdapter(mock_client) + adapter.connect(target=agent) + + async def run_both() -> None: + await asyncio.gather( + agent.run("question A"), + agent.run("question B"), + ) + + asyncio.run(run_both()) + + adapter.disconnect() + + # Two runs → two traces + assert len(traces) == 2, f"Expected 2 traces, got {len(traces)}" + + # Distinct trace_ids + trace_ids = {t["trace_id"] for t in traces} + assert len(trace_ids) == 2, f"Traces must have different trace_ids, got {trace_ids}" + + for trace in traces: + events = trace["events"] + event_types = [e["event_type"] for e in events] + + # Each trace has the core lifecycle events + assert "agent.input" in event_types, f"Missing agent.input in {event_types}" + assert "agent.output" in event_types, f"Missing agent.output in {event_types}" + assert "model.invoke" in event_types, f"Missing model.invoke in {event_types}" + + # All events in a single trace share the same trace_id + assert all( + e["trace_id"] == trace["trace_id"] for e in events + ), "Events within a trace must share trace_id" + + # agent.output has status ok + output_events = [e for e in events if e["event_type"] == "agent.output"] + assert len(output_events) == 1 + assert output_events[0]["payload"]["status"] == "ok" diff --git a/tests/instrument/adapters/frameworks/test_crewai.py b/tests/instrument/adapters/frameworks/test_crewai.py index 3b914a5..e995f8e 100644 --- a/tests/instrument/adapters/frameworks/test_crewai.py +++ b/tests/instrument/adapters/frameworks/test_crewai.py @@ -13,7 +13,7 @@ import pytest -from .conftest import capture_framework_trace, find_event, find_events +from ..conftest import capture_framework_trace, find_event, find_events # Skip entire module if crewai is not importable (Python < 3.10 or not installed). # crewai uses `type | None` syntax which causes TypeError on Python < 3.10, @@ -45,7 +45,7 @@ ) from crewai.tasks.task_output import TaskOutput # noqa: E402 -from layerlens.instrument.adapters.frameworks.crewai import CrewAIAdapter # noqa: E402 +from layerlens.instrument.adapters.frameworks._staging.crewai import CrewAIAdapter # noqa: E402 @pytest.fixture diff --git a/tests/instrument/adapters/frameworks/test_openai_agents.py b/tests/instrument/adapters/frameworks/test_openai_agents.py index 111be7d..b9ac1af 100644 --- a/tests/instrument/adapters/frameworks/test_openai_agents.py +++ b/tests/instrument/adapters/frameworks/test_openai_agents.py @@ -388,10 +388,14 @@ def test_function_span_emits_tool_call(self, adapter_and_trace): tc = find_event(events, "tool.call") assert tc["payload"]["tool_name"] == "get_weather" assert tc["payload"]["input"] == '{"city":"NYC"}' - assert tc["payload"]["output"] == '{"temp":72}' - assert tc["payload"]["latency_ms"] >= 0 assert tc["parent_span_id"] == "s_agent" + tr = find_event(events, "tool.result") + assert tr["payload"]["tool_name"] == "get_weather" + assert tr["payload"]["output"] == '{"temp":72}' + assert tr["payload"]["latency_ms"] >= 0 + assert tr["parent_span_id"] == "s_agent" + def test_function_span_with_error(self, adapter_and_trace): adapter, uploaded = adapter_and_trace @@ -720,8 +724,8 @@ def test_broken_collector_does_not_crash(self, mock_client): trace = _make_trace(trace_id="t_safe") adapter.on_trace_start(trace) - # Break the collector - adapter._collectors["t_safe"] = None # type: ignore[assignment] + # Break the run's collector + adapter._trace_runs["t_safe"] = None # type: ignore[assignment] # This should not raise span = _make_span(adapter,"t_safe", "s_safe", AgentSpanData(name="test")) diff --git a/tests/instrument/adapters/frameworks/test_semantic_kernel.py b/tests/instrument/adapters/frameworks/test_semantic_kernel.py index 9ae833a..be70508 100644 --- a/tests/instrument/adapters/frameworks/test_semantic_kernel.py +++ b/tests/instrument/adapters/frameworks/test_semantic_kernel.py @@ -301,7 +301,11 @@ def test_prompt_render_emits_agent_code(self, mock_client): async def mock_next(context): pass + # Prompt rendering only fires inside a function invocation, + # so we need an active RunState. + adapter._begin_run() _run(adapter._prompt_rendering_filter(ctx, mock_next)) + adapter._end_run() adapter.disconnect() events = uploaded["events"] @@ -325,7 +329,9 @@ def test_prompt_render_no_content_when_disabled(self, mock_client): async def mock_next(context): pass + adapter._begin_run() _run(adapter._prompt_rendering_filter(ctx, mock_next)) + adapter._end_run() adapter.disconnect() events = uploaded["events"] @@ -568,8 +574,10 @@ def test_model_invoke_emitted(self, mock_client): adapter = SemanticKernelAdapter(mock_client) adapter.connect(target=kernel) - # Call the wrapped method directly + # In real usage, LLM calls happen inside a function invocation filter. + adapter._begin_run() _run(service._inner_get_chat_message_contents(None, None)) + adapter._end_run() adapter.disconnect() @@ -588,7 +596,9 @@ def test_cost_record_emitted(self, mock_client): adapter = SemanticKernelAdapter(mock_client) adapter.connect(target=kernel) + adapter._begin_run() _run(service._inner_get_chat_message_contents(None, None)) + adapter._end_run() adapter.disconnect() events = uploaded["events"] @@ -604,7 +614,9 @@ def test_no_cost_record_without_tokens(self, mock_client): adapter = SemanticKernelAdapter(mock_client) adapter.connect(target=kernel) + adapter._begin_run() _run(service._inner_get_chat_message_contents(None, None)) + adapter._end_run() adapter.disconnect() events = uploaded["events"] @@ -634,8 +646,10 @@ async def failing_inner(chat_history, settings): service._inner_get_chat_message_contents = failing_inner adapter.connect(target=kernel) + adapter._begin_run() with pytest.raises(RuntimeError, match="API timeout"): _run(service._inner_get_chat_message_contents(None, None)) + adapter._end_run() adapter.disconnect() diff --git a/tests/instrument/test_trace_context.py b/tests/instrument/test_trace_context.py index 03a09be..04e4f9c 100644 --- a/tests/instrument/test_trace_context.py +++ b/tests/instrument/test_trace_context.py @@ -32,15 +32,15 @@ class StubAdapter(FrameworkAdapter): name = "stub" - def connect(self, target: Any = None, **kwargs: Any) -> Any: - self._connected = True - return target - def fire_event(self, event_type: str, payload: Dict[str, Any], span_id: Optional[str] = None, parent_span_id: Optional[str] = None) -> None: - self._emit(event_type, payload, span_id=span_id, - parent_span_id=parent_span_id, span_name=event_type) + kwargs: Dict[str, Any] = {"span_name": event_type} + if span_id is not None: + kwargs["span_id"] = span_id + if parent_span_id is not None: + kwargs["parent_span_id"] = parent_span_id + self._emit(event_type, payload, **kwargs) # --------------------------------------------------------------------------- @@ -70,15 +70,11 @@ def _capture(path: str) -> None: @pytest.fixture(autouse=True) -def reset_circuit_breaker(): - """Reset the upload circuit breaker between tests.""" - _upload._error_count = 0 - _upload._circuit_open = False - _upload._opened_at = 0.0 +def reset_upload_channels(): + """Clear all upload channels between tests.""" + _upload._channels.clear() yield - _upload._error_count = 0 - _upload._circuit_open = False - _upload._opened_at = 0.0 + _upload._channels.clear() # =================================================================== @@ -133,7 +129,9 @@ def test_framework_adapter_standalone_creates_own_trace( ): adapter = StubAdapter(mock_client) adapter.connect() + adapter._begin_run() adapter.fire_event("agent.lifecycle", {"action": "standalone"}) + adapter._end_run() adapter.disconnect() assert len(capture_trace) == 1 @@ -365,12 +363,14 @@ def agent_run(): assert "tool.call" in types assert "agent.output" in types - def test_adapter_disconnect_flushes_own_collector_when_standalone( + def test_adapter_begin_end_run_flushes_collector( self, mock_client, capture_trace, ): adapter = StubAdapter(mock_client) adapter.connect() + adapter._begin_run() adapter.fire_event("agent.lifecycle", {"action": "standalone"}) + adapter._end_run() adapter.disconnect() assert len(capture_trace) == 1 @@ -397,34 +397,35 @@ def test_multiple_adapters_disconnect_independently_under_shared_context( # =================================================================== -# 6. Callback scope + _traced_call +# 6. Run lifecycle (_begin_run / _end_run) # =================================================================== -class TestCallbackScope: +class TestRunLifecycle: - def test_pushes_collector_when_standalone(self, mock_client, capture_trace): + def test_begin_run_pushes_collector_standalone(self, mock_client, capture_trace): adapter = StubAdapter(mock_client) adapter.connect() assert _current_collector.get() is None - with adapter._callback_scope("test_scope") as scope_span_id: - assert _current_collector.get() is not None - assert _current_span_id.get() == scope_span_id - emit("tool.call", {"name": "test", "input": "x"}) + run = adapter._begin_run() + assert _current_collector.get() is not None + assert _current_span_id.get() == run.root_span_id + emit("tool.call", {"name": "test", "input": "x"}) + adapter._end_run() - assert _current_collector.get() is None + assert len(capture_trace) == 1 - def test_preserves_shared_collector(self, mock_client, capture_trace): + def test_begin_run_preserves_shared_collector(self, mock_client, capture_trace): adapter = StubAdapter(mock_client) adapter.connect() @trace(mock_client) def run(): shared_collector = _current_collector.get() - with adapter._callback_scope("inner") as scope_span: - assert _current_collector.get() is shared_collector - assert _current_span_id.get() == scope_span - emit("tool.call", {"name": "inner_tool", "input": "x"}) + adapter_run = adapter._begin_run() + assert adapter_run.collector is shared_collector + emit("tool.call", {"name": "inner_tool", "input": "x"}) + adapter._end_run() return "done" run() @@ -434,34 +435,16 @@ def run(): tool_call = find_event(events, "tool.call") assert tool_call["payload"]["name"] == "inner_tool" - def test_creates_child_span(self, mock_client, capture_trace): + def test_end_run_cleans_up_on_error(self, mock_client): adapter = StubAdapter(mock_client) adapter.connect() - @trace(mock_client) - def run(): - root_span = _current_span_id.get() - with adapter._callback_scope("child"): - child_span = _current_span_id.get() - assert child_span != root_span - emit("tool.call", {"name": "scoped", "input": "x"}) - assert _current_span_id.get() == root_span - return "done" - - run() - - def test_cleans_up_on_error(self, mock_client): - adapter = StubAdapter(mock_client) - adapter.connect() - - with pytest.raises(RuntimeError): - with adapter._callback_scope("failing"): - raise RuntimeError("boom") - + adapter._begin_run() + assert _current_collector.get() is not None + adapter._end_run() assert _current_collector.get() is None - assert _current_span_id.get() is None - def test_traced_call_makes_providers_visible(self, mock_client, capture_trace): + def test_begin_run_makes_providers_visible(self, mock_client, capture_trace): adapter = StubAdapter(mock_client) adapter.connect() @@ -471,27 +454,27 @@ def fake_agent_run(prompt): return "result" assert _current_collector.get() is None - result = adapter._traced_call(fake_agent_run, "hello", _span_name="agent.run") + adapter._begin_run() + result = fake_agent_run("hello") + adapter._end_run() assert result == "result" assert _current_collector.get() is None - adapter.disconnect() assert len(capture_trace) == 1 events = capture_trace[0]["events"] model_event = find_event(events, "model.invoke") assert model_event["payload"]["model"] == "gpt-4" - def test_traced_call_under_shared_context(self, mock_client, capture_trace): + def test_begin_run_under_shared_context(self, mock_client, capture_trace): adapter = StubAdapter(mock_client) adapter.connect() - def fake_agent_run(prompt): - emit("model.invoke", {"model": "gpt-4", "input": prompt}) - return "result" - @trace(mock_client) def run(): - return adapter._traced_call(fake_agent_run, "hello", _span_name="agent.run") + adapter._begin_run() + emit("model.invoke", {"model": "gpt-4", "input": "hello"}) + adapter._end_run() + return "done" run() assert len(capture_trace) == 1 @@ -506,12 +489,16 @@ def run(): class TestUploadCircuitBreaker: + def _channel(self, mock_client): + """Get or create the upload channel for mock_client.""" + return _upload._get_channel(mock_client) + def test_successful_upload(self, mock_client, capture_trace): with trace_context(mock_client): emit("tool.call", {"name": "test", "input": "x"}) assert len(capture_trace) == 1 - assert _upload._error_count == 0 + assert self._channel(mock_client)._error_count == 0 def test_upload_failure_records_error(self, mock_client): mock_client.traces.upload.side_effect = RuntimeError("network error") @@ -519,22 +506,25 @@ def test_upload_failure_records_error(self, mock_client): with trace_context(mock_client): emit("tool.call", {"name": "test", "input": "x"}) - assert _upload._error_count == 1 - assert not _upload._circuit_open + ch = self._channel(mock_client) + assert ch._error_count == 1 + assert not ch._circuit_open def test_circuit_opens_after_threshold(self, mock_client): mock_client.traces.upload.side_effect = RuntimeError("network error") - for _ in range(_upload._THRESHOLD): + for _ in range(_upload.UploadChannel._THRESHOLD): with trace_context(mock_client): emit("tool.call", {"name": "test", "input": "x"}) - assert _upload._circuit_open - assert _upload._error_count == _upload._THRESHOLD + ch = self._channel(mock_client) + assert ch._circuit_open + assert ch._error_count == _upload.UploadChannel._THRESHOLD def test_open_circuit_skips_upload(self, mock_client): - _upload._circuit_open = True - _upload._opened_at = __import__("time").monotonic() + ch = self._channel(mock_client) + ch._circuit_open = True + ch._opened_at = __import__("time").monotonic() with trace_context(mock_client): emit("tool.call", {"name": "test", "input": "x"}) @@ -542,30 +532,33 @@ def test_open_circuit_skips_upload(self, mock_client): mock_client.traces.upload.assert_not_called() def test_circuit_resets_after_cooldown(self, mock_client, capture_trace): - _upload._circuit_open = True - _upload._error_count = _upload._THRESHOLD - _upload._opened_at = ( - __import__("time").monotonic() - _upload._COOLDOWN_S - 1 + ch = self._channel(mock_client) + ch._circuit_open = True + ch._error_count = _upload.UploadChannel._THRESHOLD + ch._opened_at = ( + __import__("time").monotonic() - _upload.UploadChannel._COOLDOWN_S - 1 ) with trace_context(mock_client): emit("tool.call", {"name": "test", "input": "x"}) assert len(capture_trace) == 1 - assert not _upload._circuit_open - assert _upload._error_count == 0 + assert not ch._circuit_open + assert ch._error_count == 0 def test_success_after_failures_resets_count(self, mock_client, capture_trace): - _upload._error_count = 5 + ch = self._channel(mock_client) + ch._error_count = 5 with trace_context(mock_client): emit("tool.call", {"name": "test", "input": "x"}) - assert _upload._error_count == 0 + assert ch._error_count == 0 def test_protects_trace_decorator(self, mock_client): - _upload._circuit_open = True - _upload._opened_at = __import__("time").monotonic() + ch = self._channel(mock_client) + ch._circuit_open = True + ch._opened_at = __import__("time").monotonic() @trace(mock_client) def run(): @@ -579,11 +572,12 @@ def test_protects_framework_adapter(self, mock_client): adapter = StubAdapter(mock_client) adapter.connect() - _upload._circuit_open = True - _upload._opened_at = __import__("time").monotonic() + ch = self._channel(mock_client) + ch._circuit_open = True + ch._opened_at = __import__("time").monotonic() - adapter.fire_event("tool.call", {"name": "test", "input": "x"}) - adapter.disconnect() + with trace_context(mock_client): + adapter.fire_event("tool.call", {"name": "test", "input": "x"}) mock_client.traces.upload.assert_not_called() @@ -635,7 +629,9 @@ def test_standalone_adapter_unaffected_by_previous_shared_context( adapter = StubAdapter(mock_client) adapter.connect() + adapter._begin_run() adapter.fire_event("agent.lifecycle", {"phase": "standalone"}) + adapter._end_run() adapter.disconnect() assert len(capture_trace) == 2 From 974e4dca70f3bbcb4b9ea6ddb40d04ae4f158945 Mon Sep 17 00:00:00 2001 From: Gary <59334078+garrettallen14@users.noreply.github.com> Date: Mon, 6 Apr 2026 14:25:40 -0700 Subject: [PATCH 4/4] fix: update crewai --- .../instrument/adapters/frameworks/crewai.py | 267 ++++++++---------- .../adapters/frameworks/test_crewai.py | 4 +- 2 files changed, 115 insertions(+), 156 deletions(-) diff --git a/src/layerlens/instrument/adapters/frameworks/crewai.py b/src/layerlens/instrument/adapters/frameworks/crewai.py index f80f70c..96edf3f 100644 --- a/src/layerlens/instrument/adapters/frameworks/crewai.py +++ b/src/layerlens/instrument/adapters/frameworks/crewai.py @@ -1,11 +1,13 @@ from __future__ import annotations +import time import logging from typing import Any, Dict, Optional -from .._base_framework import FrameworkAdapter -from .._utils import safe_serialize -from ...._capture_config import CaptureConfig +from ._base_framework import FrameworkAdapter +from ._utils import safe_serialize +from ..._collector import TraceCollector +from ..._capture_config import CaptureConfig log = logging.getLogger(__name__) @@ -18,15 +20,15 @@ class CrewAIAdapter(FrameworkAdapter): """CrewAI adapter using the typed event bus API (crewai >= 1.0). - Subscribes to CrewAI's event bus to capture crew lifecycle, agent - execution, LLM calls, tool usage, flows, and MCP tool events as - flat layerlens events. + CrewAI's event bus dispatches handlers across threads, so this + adapter manages its own collector and span state on the instance + rather than using ContextVar-based RunState. Usage:: adapter = CrewAIAdapter(client) adapter.connect() - crew.kickoff() # events flow automatically via event bus + crew.kickoff() adapter.disconnect() """ @@ -35,17 +37,15 @@ class CrewAIAdapter(FrameworkAdapter): def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: super().__init__(client, capture_config) self._registered_handlers: list = [] - - # Span tracking: crew/flow → task → agent → leaf hierarchy + self._collector: Optional[TraceCollector] = None self._crew_span_id: Optional[str] = None - self._task_span_ids: Dict[str, str] = {} # task name → span_id + self._task_span_ids: Dict[str, str] = {} self._current_task_span_id: Optional[str] = None - self._agent_span_ids: Dict[str, str] = {} # agent_role → span_id + self._agent_span_ids: Dict[str, str] = {} self._current_agent_span_id: Optional[str] = None - # tool.call span IDs keyed by tool_name+id for pairing start/end self._tool_span_ids: Dict[str, str] = {} + self._timers: Dict[str, int] = {} - # Event name → handler method name; resolved to real classes at subscribe time. _EVENT_MAP = [ ("CrewKickoffStartedEvent", "_on_crew_started"), ("CrewKickoffCompletedEvent", "_on_crew_completed"), @@ -68,6 +68,10 @@ def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) ("MCPToolExecutionFailedEvent", "_on_mcp_tool_failed"), ] + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + def _on_connect(self, target: Any = None, **kwargs: Any) -> None: self._check_dependency(_BaseEventListener is not None) self._subscribe() @@ -75,14 +79,9 @@ def _on_connect(self, target: Any = None, **kwargs: Any) -> None: def _on_disconnect(self) -> None: self._unsubscribe() self._registered_handlers.clear() - self._reset_spans() - - # ------------------------------------------------------------------ - # Event bus wiring - # ------------------------------------------------------------------ + self._end_trace() def _subscribe(self) -> None: - """Register all event handlers on the CrewAI bus.""" import crewai.events as ev # pyright: ignore[reportMissingImports] for event_name, method_name in self._EVENT_MAP: @@ -99,7 +98,6 @@ def _handler(source: Any, event: Any, _m: Any = method) -> None: self._registered_handlers.append((event_cls, _handler)) def _unsubscribe(self) -> None: - """Remove all previously registered handlers from the CrewAI bus.""" try: from crewai.events import crewai_event_bus # pyright: ignore[reportMissingImports] except ImportError: @@ -111,7 +109,56 @@ def _unsubscribe(self) -> None: log.debug("layerlens: could not unregister %s handler", event_cls.__name__, exc_info=True) # ------------------------------------------------------------------ - # Internal helpers + # Collector + state management + # ------------------------------------------------------------------ + + def _fire( + self, + event_type: str, + payload: Dict[str, Any], + span_id: Optional[str] = None, + parent_span_id: Optional[str] = None, + span_name: Optional[str] = None, + ) -> None: + """Emit directly to the instance collector.""" + c = self._collector + if c is None: + return + c.emit( + event_type, payload, + span_id=span_id or self._new_span_id(), + parent_span_id=parent_span_id, + span_name=span_name, + ) + + def _leaf_parent(self) -> Optional[str]: + return self._current_agent_span_id or self._current_task_span_id or self._crew_span_id + + def _tick(self, key: str) -> None: + self._timers[key] = time.time_ns() + + def _tock(self, key: str) -> Optional[float]: + start = self._timers.pop(key, 0) + if not start: + return None + return (time.time_ns() - start) / 1_000_000 + + def _end_trace(self) -> None: + with self._lock: + collector = self._collector + self._collector = None + self._crew_span_id = None + self._task_span_ids.clear() + self._current_task_span_id = None + self._agent_span_ids.clear() + self._current_agent_span_id = None + self._tool_span_ids.clear() + self._timers.clear() + if collector is not None: + collector.flush() + + # ------------------------------------------------------------------ + # Helpers # ------------------------------------------------------------------ @staticmethod @@ -120,7 +167,6 @@ def _get_name(obj: Any) -> str: @staticmethod def _get_task_name(event: Any) -> str: - """Extract task name from a CrewAI event.""" name = getattr(event, "task_name", None) if name: return str(name) @@ -130,32 +176,11 @@ def _get_task_name(event: Any) -> str: return "" @staticmethod - def _tool_event_key(event: Any) -> str: - """Build a key to correlate ToolUsageStarted with ToolUsageFinished.""" + def _tool_key(event: Any) -> str: tool_name = getattr(event, "tool_name", None) or "" agent_key = getattr(event, "agent_key", None) or "" return f"{tool_name}:{agent_key}" - def _leaf_parent_span_id(self) -> Optional[str]: - """Return the innermost active parent span for leaf events (LLM, tool).""" - with self._lock: - return self._current_agent_span_id or self._current_task_span_id or self._crew_span_id - - def _reset_spans(self) -> None: - """Clear all span tracking state.""" - with self._lock: - self._crew_span_id = None - self._task_span_ids.clear() - self._current_task_span_id = None - self._agent_span_ids.clear() - self._current_agent_span_id = None - self._tool_span_ids.clear() - - def _end_trace(self) -> None: - """Reset spans and flush — called when a crew/flow run completes.""" - self._reset_spans() - self._flush_collector() - # ------------------------------------------------------------------ # Crew lifecycle # ------------------------------------------------------------------ @@ -163,16 +188,18 @@ def _end_trace(self) -> None: def _on_crew_started(self, source: Any, event: Any) -> None: span_id = self._new_span_id() with self._lock: + self._collector = TraceCollector(self._client, self._config) self._crew_span_id = span_id - self._start_timer("crew") + self._tick("crew") crew_name = getattr(event, "crew_name", None) or self._get_name(source) payload = self._payload(crew_name=crew_name) self._set_if_capturing(payload, "input", safe_serialize(getattr(event, "inputs", None))) - self._emit("agent.input", payload, span_id=span_id, parent_span_id=None, span_name=crew_name) + self._fire("agent.input", payload, span_id=span_id, parent_span_id=None, span_name=crew_name) def _on_crew_completed(self, source: Any, event: Any) -> None: - latency_ms = self._stop_timer("crew") + latency_ms = self._tock("crew") crew_name = getattr(event, "crew_name", None) or self._get_name(source) + span_id = self._crew_span_id or self._new_span_id() payload = self._payload(crew_name=crew_name) if latency_ms is not None: payload["duration_ns"] = int(latency_ms * 1_000_000) @@ -180,29 +207,16 @@ def _on_crew_completed(self, source: Any, event: Any) -> None: total_tokens = getattr(event, "total_tokens", None) if total_tokens is not None: payload["tokens_total"] = total_tokens - self._emit( - "agent.output", payload, - span_id=self._crew_span_id or self._new_span_id(), - parent_span_id=None, span_name=crew_name, - ) + self._fire("agent.output", payload, span_id=span_id, parent_span_id=None, span_name=crew_name) if total_tokens: - self._emit( - "cost.record", - self._payload(tokens_total=total_tokens), - span_id=self._crew_span_id or self._new_span_id(), - parent_span_id=None, - ) + self._fire("cost.record", self._payload(tokens_total=total_tokens), span_id=span_id, parent_span_id=None) self._end_trace() def _on_crew_failed(self, source: Any, event: Any) -> None: error = str(getattr(event, "error", "unknown error")) crew_name = getattr(event, "crew_name", None) or self._get_name(source) - self._emit( - "agent.error", - self._payload(crew_name=crew_name, error=error), - span_id=self._crew_span_id or self._new_span_id(), - parent_span_id=None, span_name=crew_name, - ) + span_id = self._crew_span_id or self._new_span_id() + self._fire("agent.error", self._payload(crew_name=crew_name, error=error), span_id=span_id, parent_span_id=None, span_name=crew_name) self._end_trace() # ------------------------------------------------------------------ @@ -224,11 +238,7 @@ def _on_task_started(self, source: Any, event: Any) -> None: context = getattr(event, "context", None) if context: payload["context"] = str(context)[:500] - self._emit( - "agent.input", payload, - span_id=span_id, parent_span_id=parent, - span_name=f"task:{task_name[:60]}", - ) + self._fire("agent.input", payload, span_id=span_id, parent_span_id=parent, span_name=f"task:{task_name[:60]}") def _on_task_completed(self, source: Any, event: Any) -> None: task_name = self._get_task_name(event) @@ -237,39 +247,27 @@ def _on_task_completed(self, source: Any, event: Any) -> None: parent = self._crew_span_id payload = self._payload(task_name=task_name) self._set_if_capturing(payload, "output", safe_serialize(getattr(event, "output", None))) - self._emit( - "agent.output", payload, - span_id=span_id, parent_span_id=parent, - span_name=f"task:{task_name[:60]}", - ) + self._fire("agent.output", payload, span_id=span_id, parent_span_id=parent, span_name=f"task:{task_name[:60]}") def _on_task_failed(self, source: Any, event: Any) -> None: task_name = self._get_task_name(event) with self._lock: span_id = self._task_span_ids.pop(task_name, self._current_task_span_id or self._new_span_id()) parent = self._crew_span_id - error = str(getattr(event, "error", "unknown error")) - self._emit( - "agent.error", - self._payload(task_name=task_name, error=error), - span_id=span_id, parent_span_id=parent, - ) + self._fire("agent.error", self._payload(task_name=task_name, error=str(getattr(event, "error", "unknown error"))), span_id=span_id, parent_span_id=parent) # ------------------------------------------------------------------ - # Agent execution lifecycle + # Agent execution # ------------------------------------------------------------------ def _on_agent_execution_started(self, source: Any, event: Any) -> None: agent = getattr(event, "agent", None) - agent_role = getattr(event, "agent_role", None) or ( - getattr(agent, "role", None) if agent else None - ) or "unknown" + agent_role = getattr(event, "agent_role", None) or (getattr(agent, "role", None) if agent else None) or "unknown" span_id = self._new_span_id() with self._lock: self._agent_span_ids[agent_role] = span_id self._current_agent_span_id = span_id parent = self._current_task_span_id or self._crew_span_id - payload = self._payload(agent_role=agent_role) tools = getattr(event, "tools", None) if tools: @@ -278,49 +276,30 @@ def _on_agent_execution_started(self, source: Any, event: Any) -> None: task_prompt = getattr(event, "task_prompt", None) if task_prompt: payload["task_prompt"] = str(task_prompt)[:500] - self._emit( - "agent.input", payload, - span_id=span_id, parent_span_id=parent, - span_name=f"agent:{agent_role[:60]}", - ) + self._fire("agent.input", payload, span_id=span_id, parent_span_id=parent, span_name=f"agent:{agent_role[:60]}") def _on_agent_execution_completed(self, source: Any, event: Any) -> None: agent = getattr(event, "agent", None) - agent_role = getattr(event, "agent_role", None) or ( - getattr(agent, "role", None) if agent else None - ) or "unknown" + agent_role = getattr(event, "agent_role", None) or (getattr(agent, "role", None) if agent else None) or "unknown" with self._lock: span_id = self._agent_span_ids.pop(agent_role, self._current_agent_span_id or self._new_span_id()) parent = self._current_task_span_id or self._crew_span_id if self._current_agent_span_id == span_id: self._current_agent_span_id = None - payload = self._payload(agent_role=agent_role, status="ok") self._set_if_capturing(payload, "output", safe_serialize(getattr(event, "output", None))) - self._emit( - "agent.output", payload, - span_id=span_id, parent_span_id=parent, - span_name=f"agent:{agent_role[:60]}", - ) + self._fire("agent.output", payload, span_id=span_id, parent_span_id=parent, span_name=f"agent:{agent_role[:60]}") def _on_agent_execution_error(self, source: Any, event: Any) -> None: agent = getattr(event, "agent", None) - agent_role = getattr(event, "agent_role", None) or ( - getattr(agent, "role", None) if agent else None - ) or "unknown" + agent_role = getattr(event, "agent_role", None) or (getattr(agent, "role", None) if agent else None) or "unknown" error = str(getattr(event, "error", "unknown error")) with self._lock: span_id = self._agent_span_ids.pop(agent_role, self._current_agent_span_id or self._new_span_id()) parent = self._current_task_span_id or self._crew_span_id if self._current_agent_span_id == span_id: self._current_agent_span_id = None - - self._emit( - "agent.error", - self._payload(agent_role=agent_role, error=error), - span_id=span_id, parent_span_id=parent, - span_name=f"agent:{agent_role[:60]}", - ) + self._fire("agent.error", self._payload(agent_role=agent_role, error=error), span_id=span_id, parent_span_id=parent, span_name=f"agent:{agent_role[:60]}") # ------------------------------------------------------------------ # LLM calls @@ -329,12 +308,11 @@ def _on_agent_execution_error(self, source: Any, event: Any) -> None: def _on_llm_started(self, source: Any, event: Any) -> None: call_id = getattr(event, "call_id", None) if call_id: - self._start_timer(f"llm:{call_id}") + self._tick(f"llm:{call_id}") def _on_llm_completed(self, source: Any, event: Any) -> None: model = getattr(event, "model", None) response = getattr(event, "response", None) - # Unwrap .usage from the response before normalizing usage = getattr(response, "usage", None) if response and not isinstance(response, dict) else ( response.get("usage") if isinstance(response, dict) else None ) @@ -344,19 +322,15 @@ def _on_llm_completed(self, source: Any, event: Any) -> None: payload["model"] = model call_id = getattr(event, "call_id", None) if call_id: - latency_ms = self._stop_timer(f"llm:{call_id}") + latency_ms = self._tock(f"llm:{call_id}") if latency_ms is not None: payload["latency_ms"] = latency_ms payload.update(tokens) - parent = self._leaf_parent_span_id() + parent = self._leaf_parent() span_id = self._new_span_id() - self._emit("model.invoke", payload, span_id=span_id, parent_span_id=parent) + self._fire("model.invoke", payload, span_id=span_id, parent_span_id=parent) if tokens: - self._emit( - "cost.record", - self._payload(model=model, **tokens), - span_id=span_id, parent_span_id=parent, - ) + self._fire("cost.record", self._payload(model=model, **tokens), span_id=span_id, parent_span_id=parent) def _on_llm_failed(self, source: Any, event: Any) -> None: error = str(getattr(event, "error", "unknown error")) @@ -364,35 +338,31 @@ def _on_llm_failed(self, source: Any, event: Any) -> None: payload = self._payload(error=error) if model: payload["model"] = model - parent = self._leaf_parent_span_id() - self._emit("agent.error", payload, parent_span_id=parent) + self._fire("agent.error", payload, parent_span_id=self._leaf_parent()) # ------------------------------------------------------------------ - # Tool usage — split into tool.call (start) and tool.result (end) + # Tool usage # ------------------------------------------------------------------ def _on_tool_started(self, source: Any, event: Any) -> None: tool_name = getattr(event, "tool_name", None) or "unknown" span_id = self._new_span_id() - tool_key = self._tool_event_key(event) + key = self._tool_key(event) with self._lock: - self._tool_span_ids[tool_key] = span_id + self._tool_span_ids[key] = span_id payload = self._payload(tool_name=tool_name) self._set_if_capturing(payload, "input", safe_serialize(getattr(event, "tool_args", None))) - parent = self._leaf_parent_span_id() - self._emit("tool.call", payload, span_id=span_id, parent_span_id=parent) + self._fire("tool.call", payload, span_id=span_id, parent_span_id=self._leaf_parent()) def _on_tool_finished(self, source: Any, event: Any) -> None: tool_name = getattr(event, "tool_name", None) or "unknown" - tool_key = self._tool_event_key(event) + key = self._tool_key(event) with self._lock: - span_id = self._tool_span_ids.pop(tool_key, None) + span_id = self._tool_span_ids.pop(key, None) if span_id is None: span_id = self._new_span_id() - payload = self._payload(tool_name=tool_name) self._set_if_capturing(payload, "output", safe_serialize(getattr(event, "output", None))) - # Compute latency from started_at/finished_at started_at = getattr(event, "started_at", None) finished_at = getattr(event, "finished_at", None) if started_at is not None and finished_at is not None: @@ -400,24 +370,17 @@ def _on_tool_finished(self, source: Any, event: Any) -> None: payload["latency_ms"] = (finished_at - started_at).total_seconds() * 1000 except Exception: pass - from_cache = getattr(event, "from_cache", None) - if from_cache: + if getattr(event, "from_cache", None): payload["from_cache"] = True - parent = self._leaf_parent_span_id() - self._emit("tool.result", payload, span_id=span_id, parent_span_id=parent) + self._fire("tool.result", payload, span_id=span_id, parent_span_id=self._leaf_parent()) def _on_tool_error(self, source: Any, event: Any) -> None: tool_name = getattr(event, "tool_name", None) or "unknown" error = str(getattr(event, "error", "unknown error")) - tool_key = self._tool_event_key(event) + key = self._tool_key(event) with self._lock: - self._tool_span_ids.pop(tool_key, None) - parent = self._leaf_parent_span_id() - self._emit( - "agent.error", - self._payload(tool_name=tool_name, error=error), - parent_span_id=parent, - ) + self._tool_span_ids.pop(key, None) + self._fire("agent.error", self._payload(tool_name=tool_name, error=error), parent_span_id=self._leaf_parent()) # ------------------------------------------------------------------ # Flow events @@ -426,25 +389,23 @@ def _on_tool_error(self, source: Any, event: Any) -> None: def _on_flow_started(self, source: Any, event: Any) -> None: span_id = self._new_span_id() with self._lock: + self._collector = TraceCollector(self._client, self._config) self._crew_span_id = span_id - self._start_timer("crew") + self._tick("crew") flow_name = getattr(event, "flow_name", None) or self._get_name(source) payload = self._payload(flow_name=flow_name) self._set_if_capturing(payload, "input", safe_serialize(getattr(event, "inputs", None))) - self._emit("agent.input", payload, span_id=span_id, parent_span_id=None, span_name=f"flow:{flow_name}") + self._fire("agent.input", payload, span_id=span_id, parent_span_id=None, span_name=f"flow:{flow_name}") def _on_flow_finished(self, source: Any, event: Any) -> None: - latency_ms = self._stop_timer("crew") + latency_ms = self._tock("crew") flow_name = getattr(event, "flow_name", None) or self._get_name(source) + span_id = self._crew_span_id or self._new_span_id() payload = self._payload(flow_name=flow_name) if latency_ms is not None: payload["duration_ns"] = int(latency_ms * 1_000_000) self._set_if_capturing(payload, "output", safe_serialize(getattr(event, "result", None))) - self._emit( - "agent.output", payload, - span_id=self._crew_span_id or self._new_span_id(), - parent_span_id=None, span_name=f"flow:{flow_name}", - ) + self._fire("agent.output", payload, span_id=span_id, parent_span_id=None, span_name=f"flow:{flow_name}") self._end_trace() # ------------------------------------------------------------------ @@ -461,8 +422,7 @@ def _on_mcp_tool_completed(self, source: Any, event: Any) -> None: payload["mcp_server"] = server_name if latency_ms is not None: payload["latency_ms"] = latency_ms - parent = self._leaf_parent_span_id() - self._emit("tool.call", payload, parent_span_id=parent) + self._fire("tool.call", payload, parent_span_id=self._leaf_parent()) def _on_mcp_tool_failed(self, source: Any, event: Any) -> None: tool_name = getattr(event, "tool_name", None) or "unknown" @@ -471,5 +431,4 @@ def _on_mcp_tool_failed(self, source: Any, event: Any) -> None: payload = self._payload(tool_name=tool_name, error=error) if server_name: payload["mcp_server"] = server_name - parent = self._leaf_parent_span_id() - self._emit("agent.error", payload, parent_span_id=parent) + self._fire("agent.error", payload, parent_span_id=self._leaf_parent()) diff --git a/tests/instrument/adapters/frameworks/test_crewai.py b/tests/instrument/adapters/frameworks/test_crewai.py index e995f8e..3b914a5 100644 --- a/tests/instrument/adapters/frameworks/test_crewai.py +++ b/tests/instrument/adapters/frameworks/test_crewai.py @@ -13,7 +13,7 @@ import pytest -from ..conftest import capture_framework_trace, find_event, find_events +from .conftest import capture_framework_trace, find_event, find_events # Skip entire module if crewai is not importable (Python < 3.10 or not installed). # crewai uses `type | None` syntax which causes TypeError on Python < 3.10, @@ -45,7 +45,7 @@ ) from crewai.tasks.task_output import TaskOutput # noqa: E402 -from layerlens.instrument.adapters.frameworks._staging.crewai import CrewAIAdapter # noqa: E402 +from layerlens.instrument.adapters.frameworks.crewai import CrewAIAdapter # noqa: E402 @pytest.fixture