diff --git a/src/layerlens/instrument/_capture_config.py b/src/layerlens/instrument/_capture_config.py index 837cc5a..5381123 100644 --- a/src/layerlens/instrument/_capture_config.py +++ b/src/layerlens/instrument/_capture_config.py @@ -89,6 +89,18 @@ class CaptureConfig: # Gates LLM message content (prompts/completions) independently of L-layers capture_content: bool = True + def redact_payload( + self, event_type: str, payload: Dict[str, Any] + ) -> Dict[str, Any]: + """Return a copy of payload with fields removed per config.""" + if not self.capture_content and event_type == "model.invoke": + payload = { + k: v + for k, v in payload.items() + if k not in ("messages", "output_message") + } + return payload + def is_layer_enabled(self, event_type: str) -> bool: """Check if an event type is enabled by this config. diff --git a/src/layerlens/instrument/_collector.py b/src/layerlens/instrument/_collector.py index ba97373..031576f 100644 --- a/src/layerlens/instrument/_collector.py +++ b/src/layerlens/instrument/_collector.py @@ -44,13 +44,7 @@ def emit( if not self._config.is_layer_enabled(event_type): return - # Strip LLM message content when capture_content is off - if not self._config.capture_content and event_type == "model.invoke": - payload = { - k: v - for k, v in payload.items() - if k not in ("messages", "output_message") - } + payload = self._config.redact_payload(event_type, payload) self._sequence += 1 event: Dict[str, Any] = { @@ -66,11 +60,8 @@ def emit( self._chain.add_event(event) self._events.append(event) - def flush(self) -> None: - """Build attestation and upload the trace.""" - if not self._events: - return - + def _build_trace_payload(self) -> Dict[str, Any]: + """Build the attestation envelope and trace payload.""" try: trial = self._chain.finalize() attestation: Dict[str, Any] = { @@ -82,34 +73,21 @@ def flush(self) -> None: log.warning("Failed to build attestation chain", exc_info=True) attestation = {"attestation_error": str(exc)} - payload = { + return { "trace_id": self._trace_id, "events": self._events, "capture_config": self._config.to_dict(), "attestation": attestation, } - upload_trace(self._client, payload) + + def flush(self) -> None: + """Build attestation and upload the trace.""" + if not self._events: + return + upload_trace(self._client, self._build_trace_payload()) async def async_flush(self) -> None: """Async version of flush.""" if not self._events: return - - try: - trial = self._chain.finalize() - attestation: Dict[str, Any] = { - "chain": self._chain.to_dict(), - "root_hash": trial.hash, - "schema_version": "1.0", - } - except Exception as exc: - log.warning("Failed to build attestation chain", exc_info=True) - attestation = {"attestation_error": str(exc)} - - payload = { - "trace_id": self._trace_id, - "events": self._events, - "capture_config": self._config.to_dict(), - "attestation": attestation, - } - await async_upload_trace(self._client, payload) + await async_upload_trace(self._client, self._build_trace_payload()) diff --git a/src/layerlens/instrument/_context.py b/src/layerlens/instrument/_context.py index 0587a95..dc1f873 100644 --- a/src/layerlens/instrument/_context.py +++ b/src/layerlens/instrument/_context.py @@ -11,24 +11,24 @@ _current_span_name: ContextVar[Optional[str]] = ContextVar("_current_span_name", default=None) -class _SpanTokens(NamedTuple): +class _SpanSnapshot(NamedTuple): span_id: Any parent_span_id: Any span_name: Any -def _push_span(span_id: str, name: Optional[str] = None) -> _SpanTokens: +def _push_span(span_id: str, name: Optional[str] = None) -> _SpanSnapshot: """Push a new span onto the context stack. The current span becomes the parent.""" old_span_id = _current_span_id.get() - return _SpanTokens( + return _SpanSnapshot( span_id=_current_span_id.set(span_id), parent_span_id=_parent_span_id.set(old_span_id), span_name=_current_span_name.set(name), ) -def _pop_span(tokens: _SpanTokens) -> None: +def _pop_span(snapshot: _SpanSnapshot) -> None: """Restore the previous span context.""" - _current_span_name.reset(tokens.span_name) - _parent_span_id.reset(tokens.parent_span_id) - _current_span_id.reset(tokens.span_id) + _current_span_name.reset(snapshot.span_name) + _parent_span_id.reset(snapshot.parent_span_id) + _current_span_id.reset(snapshot.span_id) diff --git a/src/layerlens/instrument/_decorator.py b/src/layerlens/instrument/_decorator.py index bfaf570..b4a118c 100644 --- a/src/layerlens/instrument/_decorator.py +++ b/src/layerlens/instrument/_decorator.py @@ -29,7 +29,7 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: root_span_id = uuid.uuid4().hex[:16] col_token = _current_collector.set(collector) - span_tokens = _push_span(root_span_id, span_name) + span_snapshot = _push_span(root_span_id, span_name) try: collector.emit( "agent.input", @@ -58,7 +58,7 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: await collector.async_flush() raise finally: - _pop_span(span_tokens) + _pop_span(span_snapshot) _current_collector.reset(col_token) return async_wrapper @@ -71,7 +71,7 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any: root_span_id = uuid.uuid4().hex[:16] col_token = _current_collector.set(collector) - span_tokens = _push_span(root_span_id, span_name) + span_snapshot = _push_span(root_span_id, span_name) try: collector.emit( "agent.input", @@ -100,7 +100,7 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any: collector.flush() raise finally: - _pop_span(span_tokens) + _pop_span(span_snapshot) _current_collector.reset(col_token) return sync_wrapper diff --git a/src/layerlens/instrument/_span.py b/src/layerlens/instrument/_span.py index 9eb239f..0ea2ecd 100644 --- a/src/layerlens/instrument/_span.py +++ b/src/layerlens/instrument/_span.py @@ -18,8 +18,8 @@ def span(name: str) -> Generator[str, None, None]: Yields the span_id string. """ new_span_id = uuid.uuid4().hex[:16] - tokens = _push_span(new_span_id, name) + snapshot = _push_span(new_span_id, name) try: yield new_span_id finally: - _pop_span(tokens) + _pop_span(snapshot) diff --git a/src/layerlens/instrument/adapters/frameworks/_base_framework.py b/src/layerlens/instrument/adapters/frameworks/_base_framework.py index af082cc..197c65e 100644 --- a/src/layerlens/instrument/adapters/frameworks/_base_framework.py +++ b/src/layerlens/instrument/adapters/frameworks/_base_framework.py @@ -1,79 +1,137 @@ +"""Unified base class for all framework adapters. + +Framework adapters hook into a framework's callback / event / tracing +system and emit LayerLens events. They share a common lifecycle: + + 1. Lazy-init a :class:`TraceCollector` on first event. + 2. Emit events through a thread-safe helper. + 3. Flush the collector when a logical trace ends (root span completes, + agent run finishes, disconnect, etc.). + +Subclasses MUST set ``name`` and implement ``connect()``. +Subclasses SHOULD call ``super().disconnect()`` after unhooking. +""" from __future__ import annotations import uuid -from uuid import UUID -from typing import Any, Dict, Optional, Tuple +import threading +from typing import Any, Dict, Optional from .._base import AdapterInfo, BaseAdapter -from ..._capture_config import CaptureConfig from ..._collector import TraceCollector +from ..._capture_config import CaptureConfig -class FrameworkTracer(BaseAdapter): - """Base class for framework adapters that manage their own collector. - - Framework adapters (LangChain, LangGraph, etc.) receive callbacks - from the framework rather than wrapping SDK methods. They maintain - their own TraceCollector and map framework run_ids to span_ids. - """ +class FrameworkAdapter(BaseAdapter): + """Base for framework adapters with collector lifecycle management.""" - _adapter_name: str = "framework" + name: str # Subclass must set: "crewai", "llamaindex", etc. def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: - self._client: Any = None + self._client = client self._config = capture_config or CaptureConfig.standard() + self._lock = threading.Lock() + self._connected = False self._collector: Optional[TraceCollector] = None + self._root_span_id: Optional[str] = None + # Optional run_id → span_id mapping for callback-style frameworks self._span_ids: Dict[str, str] = {} - self._root_run_id: Optional[str] = None - self.connect(client) - def connect(self, target: Any = None, **kwargs: Any) -> Any: # noqa: ARG002 - self._client = target - return target - - def disconnect(self) -> None: - self._span_ids.clear() - self._root_run_id = None - self._collector = None - - def adapter_info(self) -> AdapterInfo: - return AdapterInfo( - name=self._adapter_name, - adapter_type="framework", - connected=self._client is not None, - ) + # ------------------------------------------------------------------ + # Collector lifecycle + # ------------------------------------------------------------------ def _ensure_collector(self) -> TraceCollector: + """Lazily create a collector and root span ID.""" if self._collector is None: self._collector = TraceCollector(self._client, self._config) + self._root_span_id = uuid.uuid4().hex[:16] return self._collector - def _get_or_create_span_id( - self, run_id: UUID, parent_run_id: Optional[UUID] = None - ) -> Tuple[str, Optional[str]]: - rid = str(run_id) - if rid not in self._span_ids: - self._span_ids[rid] = uuid.uuid4().hex[:16] - span_id = self._span_ids[rid] - parent_span_id = self._span_ids.get(str(parent_run_id)) if parent_run_id else None - if self._root_run_id is None: - self._root_run_id = rid - return span_id, parent_span_id + @staticmethod + def _new_span_id() -> str: + return uuid.uuid4().hex[:16] + + # ------------------------------------------------------------------ + # Event emission (thread-safe) + # ------------------------------------------------------------------ def _emit( self, event_type: str, payload: Dict[str, Any], - run_id: UUID, - parent_run_id: Optional[UUID] = None, + span_id: Optional[str] = None, + parent_span_id: Optional[str] = None, + span_name: Optional[str] = None, ) -> None: - collector = self._ensure_collector() - span_id, parent_span_id = self._get_or_create_span_id(run_id, parent_run_id) - collector.emit(event_type, payload, span_id=span_id, parent_span_id=parent_span_id) + """Thread-safe event emission through the collector.""" + with self._lock: + collector = self._ensure_collector() + sid = span_id or self._new_span_id() + parent = parent_span_id or self._root_span_id + collector.emit( + event_type, payload, + span_id=sid, parent_span_id=parent, span_name=span_name, + ) - def _maybe_flush(self, run_id: UUID) -> None: - if str(run_id) == self._root_run_id and self._collector is not None: - self._collector.flush() - self._span_ids.clear() - self._root_run_id = None + # ------------------------------------------------------------------ + # Run ID → span ID mapping (opt-in for callback-style frameworks) + # ------------------------------------------------------------------ + + def _span_id_for(self, run_id: Any, parent_run_id: Any = None) -> tuple[str, Optional[str]]: + """Map a framework run_id to a span_id, creating one if needed. + + Returns ``(span_id, parent_span_id)``. Useful for frameworks + (LangChain, CrewAI, OpenAI Agents) that assign their own run + identifiers to each step. + """ + rid = str(run_id) + if rid not in self._span_ids: + self._span_ids[rid] = self._new_span_id() + span_id = self._span_ids[rid] + parent_span_id = self._span_ids.get(str(parent_run_id)) if parent_run_id else None + return span_id, parent_span_id + + # ------------------------------------------------------------------ + # Flush + # ------------------------------------------------------------------ + + def _flush_collector(self) -> None: + """Flush the current collector and reset state.""" + with self._lock: + collector = self._collector self._collector = None + self._root_span_id = None + self._span_ids.clear() + if collector is not None: + collector.flush() + + # ------------------------------------------------------------------ + # BaseAdapter interface + # ------------------------------------------------------------------ + + def connect(self, target: Any = None, **kwargs: Any) -> Any: + """Mark the adapter as connected. + + Callback-style adapters (LangChain, LangGraph) are passed directly + to the framework, so ``connect()`` just flips the flag. Adapters + that need registration (CrewAI, LlamaIndex, etc.) should override. + """ + self._connected = True + return target + + def disconnect(self) -> None: + """Flush remaining events and mark as disconnected. + + Subclasses should unhook from the framework first, then call + ``super().disconnect()``. + """ + self._flush_collector() + self._connected = False + + def adapter_info(self) -> AdapterInfo: + return AdapterInfo( + name=self.name, + adapter_type="framework", + connected=self._connected, + ) diff --git a/src/layerlens/instrument/adapters/frameworks/langchain.py b/src/layerlens/instrument/adapters/frameworks/langchain.py index bfa1784..5b14f0e 100644 --- a/src/layerlens/instrument/adapters/frameworks/langchain.py +++ b/src/layerlens/instrument/adapters/frameworks/langchain.py @@ -3,7 +3,7 @@ from uuid import UUID from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence -from ._base_framework import FrameworkTracer +from ._base_framework import FrameworkAdapter if TYPE_CHECKING: from ..._capture_config import CaptureConfig @@ -20,12 +20,32 @@ def __init_subclass__(cls, **kwargs: Any) -> None: ) -class LangChainCallbackHandler(BaseCallbackHandler, FrameworkTracer): - _adapter_name: str = "langchain" +class LangChainCallbackHandler(BaseCallbackHandler, FrameworkAdapter): + name = "langchain" def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: BaseCallbackHandler.__init__(self) - FrameworkTracer.__init__(self, client, capture_config=capture_config) + FrameworkAdapter.__init__(self, client, capture_config=capture_config) + self._root_run_id: Optional[str] = None + + def _emit_for_run( + self, + event_type: str, + payload: Dict[str, Any], + run_id: UUID, + parent_run_id: Optional[UUID] = None, + ) -> None: + """Emit an event, mapping framework run_ids to span_ids.""" + span_id, parent_span_id = self._span_id_for(run_id, parent_run_id) + rid = str(run_id) + if self._root_run_id is None: + self._root_run_id = rid + self._emit(event_type, payload, span_id=span_id, parent_span_id=parent_span_id) + + def _maybe_flush(self, run_id: UUID) -> None: + if str(run_id) == self._root_run_id and self._collector is not None: + self._flush_collector() + self._root_run_id = None # -- Chain -- @@ -40,7 +60,7 @@ def on_chain_start( ) -> None: serialized = serialized or {} name = serialized.get("name") or serialized.get("id", ["unknown"])[-1] - self._emit("agent.input", {"name": name, "input": inputs}, run_id, parent_run_id) + self._emit_for_run("agent.input", {"name": name, "input": inputs}, run_id, parent_run_id) def on_chain_end( self, @@ -50,7 +70,7 @@ def on_chain_end( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit("agent.output", {"output": outputs, "status": "ok"}, run_id) + self._emit_for_run("agent.output", {"output": outputs, "status": "ok"}, run_id) self._maybe_flush(run_id) def on_chain_error( @@ -61,7 +81,7 @@ def on_chain_error( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit("agent.error", {"error": str(error), "status": "error"}, run_id) + self._emit_for_run("agent.error", {"error": str(error), "status": "error"}, run_id) self._maybe_flush(run_id) # -- LLM -- @@ -77,7 +97,7 @@ def on_llm_start( ) -> None: serialized = serialized or {} name = serialized.get("name") or serialized.get("id", ["unknown"])[-1] - self._emit("model.invoke", {"name": name, "messages": prompts}, run_id, parent_run_id) + self._emit_for_run("model.invoke", {"name": name, "messages": prompts}, run_id, parent_run_id) def on_chat_model_start( self, @@ -90,7 +110,7 @@ def on_chat_model_start( ) -> None: serialized = serialized or {} name = serialized.get("name") or serialized.get("id", ["unknown"])[-1] - self._emit( + self._emit_for_run( "model.invoke", {"name": name, "messages": [[_serialize_lc_message(m) for m in batch] for batch in messages]}, run_id, @@ -120,7 +140,7 @@ def on_llm_end( model_name = llm_output.get("model_name") if model_name or output: - self._emit( + self._emit_for_run( "model.invoke", {"model": model_name, "output_message": output}, run_id, @@ -129,7 +149,7 @@ def on_llm_end( usage = llm_output.get("token_usage", {}) if usage: - self._emit("cost.record", usage, run_id, parent_run_id) + self._emit_for_run("cost.record", usage, run_id, parent_run_id) self._maybe_flush(run_id) @@ -141,7 +161,7 @@ def on_llm_error( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit("agent.error", {"error": str(error), "status": "error"}, run_id) + self._emit_for_run("agent.error", {"error": str(error), "status": "error"}, run_id) self._maybe_flush(run_id) # -- Tool -- @@ -156,7 +176,7 @@ def on_tool_start( **kwargs: Any, ) -> None: name = (serialized or {}).get("name", "tool") - self._emit("tool.call", {"name": name, "input": input_str}, run_id, parent_run_id) + self._emit_for_run("tool.call", {"name": name, "input": input_str}, run_id, parent_run_id) def on_tool_end( self, @@ -166,7 +186,7 @@ def on_tool_end( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit("tool.result", {"output": output}, run_id) + self._emit_for_run("tool.result", {"output": output}, run_id) self._maybe_flush(run_id) def on_tool_error( @@ -177,7 +197,7 @@ def on_tool_error( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit("agent.error", {"error": str(error), "status": "error"}, run_id) + self._emit_for_run("agent.error", {"error": str(error), "status": "error"}, run_id) self._maybe_flush(run_id) # -- Retriever -- @@ -192,7 +212,7 @@ def on_retriever_start( **kwargs: Any, ) -> None: name = (serialized or {}).get("name", "retriever") - self._emit("tool.call", {"name": name, "input": query}, run_id, parent_run_id) + self._emit_for_run("tool.call", {"name": name, "input": query}, run_id, parent_run_id) def on_retriever_end( self, @@ -203,7 +223,7 @@ def on_retriever_end( **kwargs: Any, ) -> None: output = [_serialize_lc_document(d) for d in documents] - self._emit("tool.result", {"output": output}, run_id) + self._emit_for_run("tool.result", {"output": output}, run_id) self._maybe_flush(run_id) def on_retriever_error( @@ -214,7 +234,7 @@ def on_retriever_error( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - self._emit("agent.error", {"error": str(error), "status": "error"}, run_id) + self._emit_for_run("agent.error", {"error": str(error), "status": "error"}, run_id) self._maybe_flush(run_id) # -- Text (required by base) -- diff --git a/src/layerlens/instrument/adapters/frameworks/langgraph.py b/src/layerlens/instrument/adapters/frameworks/langgraph.py index 47d2439..f4b666a 100644 --- a/src/layerlens/instrument/adapters/frameworks/langgraph.py +++ b/src/layerlens/instrument/adapters/frameworks/langgraph.py @@ -7,7 +7,7 @@ class LangGraphCallbackHandler(LangChainCallbackHandler): - _adapter_name: str = "langgraph" + name = "langgraph" def on_chain_start( self, @@ -38,4 +38,4 @@ def on_chain_start( if node_name: name = node_name - self._emit("agent.input", {"name": name, "input": inputs}, run_id, parent_run_id) + self._emit_for_run("agent.input", {"name": name, "input": inputs}, run_id, parent_run_id) diff --git a/src/layerlens/instrument/adapters/providers/_base_provider.py b/src/layerlens/instrument/adapters/providers/_base_provider.py index cd02bb2..a109c16 100644 --- a/src/layerlens/instrument/adapters/providers/_base_provider.py +++ b/src/layerlens/instrument/adapters/providers/_base_provider.py @@ -1,96 +1,101 @@ from __future__ import annotations -import uuid -from typing import Any, Dict, Callable - -from ..._context import _current_collector, _current_span_id - - -def emit_llm_events( - name: str, - kwargs: Dict[str, Any], - response: Any, - extract_output: Callable[[Any], Any], - extract_meta: Callable[[Any], Dict[str, Any]], - capture_params: frozenset[str], - latency_ms: float, -) -> None: - """Emit model.invoke + cost.record events for an LLM call. - - Builds the full payload -- the collector handles CaptureConfig gating - (L3 suppresses model.invoke entirely, capture_content strips messages). - """ - collector = _current_collector.get() - if collector is None: - return - - parent_span_id = _current_span_id.get() - span_id = uuid.uuid4().hex[:16] - response_meta = extract_meta(response) - - collector.emit( - "model.invoke", - { - "name": name, - "latency_ms": latency_ms, - "parameters": {k: kwargs[k] for k in capture_params if k in kwargs}, - "messages": _extract_messages(kwargs), - "output_message": extract_output(response), - **response_meta, - }, - span_id=span_id, - parent_span_id=parent_span_id, - ) - - usage = response_meta.get("usage", {}) - if usage: - collector.emit( - "cost.record", - { - "provider": name.split(".")[0], - "model": response_meta.get("response_model", kwargs.get("model")), - **usage, - }, - span_id=span_id, - parent_span_id=parent_span_id, - ) +import abc +import time +import logging +from typing import Any, Dict + +from .._base import AdapterInfo, BaseAdapter +from ._emit_helpers import emit_llm_events, emit_llm_error +from ..._context import _current_collector + +log: logging.Logger = logging.getLogger(__name__) + + +class MonkeyPatchProvider(BaseAdapter): + """Base for providers that monkey-patch SDK client or module methods.""" + + name: str + capture_params: frozenset[str] + + def __init__(self) -> None: + self._client: Any = None + self._originals: Dict[str, Any] = {} + + @staticmethod + @abc.abstractmethod + def extract_output(response: Any) -> Any: ... + @staticmethod + @abc.abstractmethod + def extract_meta(response: Any) -> Dict[str, Any]: ... -def emit_llm_error( - name: str, - error: Exception, - latency_ms: float, -) -> None: - """Emit agent.error event for a failed LLM call.""" - collector = _current_collector.get() - parent_span_id = _current_span_id.get() - if collector is None: - return - - span_id = uuid.uuid4().hex[:16] - collector.emit( - "agent.error", - {"name": name, "error": str(error), "latency_ms": latency_ms}, - span_id=span_id, - parent_span_id=parent_span_id, - ) - - -def _extract_messages(kwargs: Dict[str, Any]) -> Any: - messages = kwargs.get("messages") - if messages is not None: - return [_serialize_message(m) for m in messages] - for key in ("prompt", "contents", "input"): - val = kwargs.get(key) - if val is not None: - return val - return None - - -def _serialize_message(msg: Any) -> Any: - if isinstance(msg, dict): - return msg - try: - return {"role": msg.role, "content": msg.content} - except AttributeError: - return str(msg) + def _wrap_sync(self, event_name: str, original: Any) -> Any: + extract_output = self.extract_output + extract_meta = self.extract_meta + capture_params = self.capture_params + + def wrapped(*args: Any, **kwargs: Any) -> Any: + if _current_collector.get() is None: + return original(*args, **kwargs) + start = time.time() + try: + response = original(*args, **kwargs) + except Exception as exc: + latency_ms = (time.time() - start) * 1000 + emit_llm_error(event_name, exc, latency_ms) + raise + latency_ms = (time.time() - start) * 1000 + emit_llm_events( + event_name, kwargs, response, + extract_output, extract_meta, capture_params, latency_ms, + ) + return response + + return wrapped + + def _wrap_async(self, event_name: str, original: Any) -> Any: + extract_output = self.extract_output + extract_meta = self.extract_meta + capture_params = self.capture_params + + async def wrapped(*args: Any, **kwargs: Any) -> Any: + if _current_collector.get() is None: + return await original(*args, **kwargs) + start = time.time() + try: + response = await original(*args, **kwargs) + except Exception as exc: + latency_ms = (time.time() - start) * 1000 + emit_llm_error(event_name, exc, latency_ms) + raise + latency_ms = (time.time() - start) * 1000 + emit_llm_events( + event_name, kwargs, response, + extract_output, extract_meta, capture_params, latency_ms, + ) + return response + + return wrapped + + def disconnect(self) -> None: + if self._client is None: + return + for key, orig in self._originals.items(): + try: + parts = key.split(".") + obj = self._client + for part in parts[:-1]: + obj = getattr(obj, part) + setattr(obj, parts[-1], orig) + except Exception: + log.warning("Could not restore %s", key) + self._client = None + self._originals.clear() + + def adapter_info(self) -> AdapterInfo: + return AdapterInfo( + name=self.name, + adapter_type="provider", + connected=self._client is not None, + ) diff --git a/src/layerlens/instrument/adapters/providers/_emit_helpers.py b/src/layerlens/instrument/adapters/providers/_emit_helpers.py new file mode 100644 index 0000000..d46a9ed --- /dev/null +++ b/src/layerlens/instrument/adapters/providers/_emit_helpers.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import uuid +from typing import Any, Dict, Callable + +from ..._context import _current_collector, _current_span_id + + +def emit_llm_events( + name: str, + kwargs: Dict[str, Any], + response: Any, + extract_output: Callable[[Any], Any], + extract_meta: Callable[[Any], Dict[str, Any]], + capture_params: frozenset[str], + latency_ms: float, +) -> None: + """Emit model.invoke + cost.record events for an LLM call. + + Builds the full payload -- the collector handles CaptureConfig gating + (L3 suppresses model.invoke entirely, capture_content strips messages). + """ + collector = _current_collector.get() + if collector is None: + return + + parent_span_id = _current_span_id.get() + span_id = uuid.uuid4().hex[:16] + response_meta = extract_meta(response) + + # Resolve model name: prefer response_model (actual model used), fall back to kwargs + model_name = response_meta.get("response_model") or kwargs.get("model") + + collector.emit( + "model.invoke", + { + "name": name, + "model": model_name, + "latency_ms": latency_ms, + "parameters": {k: kwargs[k] for k in capture_params if k in kwargs}, + "messages": _extract_messages(kwargs), + "output_message": extract_output(response), + **response_meta, + }, + span_id=span_id, + parent_span_id=parent_span_id, + ) + + usage = response_meta.get("usage", {}) + if usage: + collector.emit( + "cost.record", + { + "provider": name.split(".")[0], + "model": response_meta.get("response_model", kwargs.get("model")), + **usage, + }, + span_id=span_id, + parent_span_id=parent_span_id, + ) + + +def emit_llm_error( + name: str, + error: Exception, + latency_ms: float, +) -> None: + """Emit agent.error event for a failed LLM call.""" + collector = _current_collector.get() + parent_span_id = _current_span_id.get() + if collector is None: + return + + span_id = uuid.uuid4().hex[:16] + collector.emit( + "agent.error", + {"name": name, "error": str(error), "latency_ms": latency_ms}, + span_id=span_id, + parent_span_id=parent_span_id, + ) + + +def _extract_messages(kwargs: Dict[str, Any]) -> Any: + messages = kwargs.get("messages") + if messages is not None: + return [_serialize_message(m) for m in messages] + for key in ("prompt", "contents", "input"): + val = kwargs.get(key) + if val is not None: + return val + return None + + +def _serialize_message(msg: Any) -> Any: + if isinstance(msg, dict): + return msg + try: + return {"role": msg.role, "content": msg.content} + except AttributeError: + return str(msg) diff --git a/src/layerlens/instrument/adapters/providers/anthropic.py b/src/layerlens/instrument/adapters/providers/anthropic.py index 675b558..0a2b17d 100644 --- a/src/layerlens/instrument/adapters/providers/anthropic.py +++ b/src/layerlens/instrument/adapters/providers/anthropic.py @@ -1,14 +1,8 @@ from __future__ import annotations -import time -import logging from typing import Any, Dict -from .._base import AdapterInfo, BaseAdapter -from ._base_provider import emit_llm_events, emit_llm_error -from ..._context import _current_collector - -log: logging.Logger = logging.getLogger(__name__) +from ._base_provider import MonkeyPatchProvider _CAPTURE_PARAMS = frozenset( { @@ -23,10 +17,42 @@ ) -class AnthropicProvider(BaseAdapter): - def __init__(self) -> None: - self._client: Any = None - self._originals: Dict[str, Any] = {} +class AnthropicProvider(MonkeyPatchProvider): + name = "anthropic" + capture_params = _CAPTURE_PARAMS + + @staticmethod + def extract_output(response: Any) -> Any: + try: + content = response.content + if content: + block = content[0] + return {"type": block.type, "text": getattr(block, "text", None)} + except (AttributeError, IndexError): + pass + return None + + @staticmethod + def extract_meta(response: Any) -> Dict[str, Any]: + meta: Dict[str, Any] = {} + try: + usage = response.usage + if usage is not None: + meta["usage"] = { + "input_tokens": usage.input_tokens, + "output_tokens": usage.output_tokens, + } + except AttributeError: + pass + try: + meta["response_model"] = response.model + except AttributeError: + pass + try: + meta["stop_reason"] = response.stop_reason + except AttributeError: + pass + return meta def connect(self, target: Any = None, **kwargs: Any) -> Any: # noqa: ARG002 self._client = target @@ -34,110 +60,19 @@ def connect(self, target: Any = None, **kwargs: Any) -> Any: # noqa: ARG002 if hasattr(target, "messages"): orig = target.messages.create self._originals["messages.create"] = orig - target.messages.create = self._wrap_sync(orig) + target.messages.create = self._wrap_sync( + "anthropic.messages.create", orig + ) if hasattr(target.messages, "acreate"): async_orig = target.messages.acreate self._originals["messages.acreate"] = async_orig - target.messages.acreate = self._wrap_async(async_orig) + target.messages.acreate = self._wrap_async( + "anthropic.messages.create", async_orig + ) return target - def disconnect(self) -> None: - if self._client is None: - return - for key, orig in self._originals.items(): - try: - parts = key.split(".") - obj = self._client - for part in parts[:-1]: - obj = getattr(obj, part) - setattr(obj, parts[-1], orig) - except Exception: - log.warning("Could not restore %s", key) - self._client = None - self._originals.clear() - - def adapter_info(self) -> AdapterInfo: - return AdapterInfo( - name="anthropic", - adapter_type="provider", - connected=self._client is not None, - ) - - def _wrap_sync(self, original: Any) -> Any: - def wrapped(*args: Any, **kwargs: Any) -> Any: - if _current_collector.get() is None: - return original(*args, **kwargs) - start = time.time() - try: - response = original(*args, **kwargs) - except Exception as exc: - latency_ms = (time.time() - start) * 1000 - emit_llm_error("anthropic.messages.create", exc, latency_ms) - raise - latency_ms = (time.time() - start) * 1000 - emit_llm_events( - "anthropic.messages.create", kwargs, response, - _extract_output, _extract_response_meta, _CAPTURE_PARAMS, latency_ms, - ) - return response - - return wrapped - - def _wrap_async(self, original: Any) -> Any: - async def wrapped(*args: Any, **kwargs: Any) -> Any: - if _current_collector.get() is None: - return await original(*args, **kwargs) - start = time.time() - try: - response = await original(*args, **kwargs) - except Exception as exc: - latency_ms = (time.time() - start) * 1000 - emit_llm_error("anthropic.messages.create", exc, latency_ms) - raise - latency_ms = (time.time() - start) * 1000 - emit_llm_events( - "anthropic.messages.create", kwargs, response, - _extract_output, _extract_response_meta, _CAPTURE_PARAMS, latency_ms, - ) - return response - - return wrapped - - -def _extract_output(response: Any) -> Any: - try: - content = response.content - if content: - block = content[0] - return {"type": block.type, "text": getattr(block, "text", None)} - except (AttributeError, IndexError): - pass - return None - - -def _extract_response_meta(response: Any) -> Dict[str, Any]: - meta: Dict[str, Any] = {} - try: - usage = response.usage - if usage is not None: - meta["usage"] = { - "input_tokens": usage.input_tokens, - "output_tokens": usage.output_tokens, - } - except AttributeError: - pass - try: - meta["response_model"] = response.model - except AttributeError: - pass - try: - meta["stop_reason"] = response.stop_reason - except AttributeError: - pass - return meta - # --- Convenience API --- diff --git a/src/layerlens/instrument/adapters/providers/litellm.py b/src/layerlens/instrument/adapters/providers/litellm.py index e7bbda8..784e7e8 100644 --- a/src/layerlens/instrument/adapters/providers/litellm.py +++ b/src/layerlens/instrument/adapters/providers/litellm.py @@ -1,12 +1,9 @@ from __future__ import annotations -import time -from typing import Any +from typing import Any, Dict -from .._base import AdapterInfo, BaseAdapter -from .openai import _extract_output, _extract_response_meta -from ._base_provider import emit_llm_events, emit_llm_error -from ..._context import _current_collector +from ._base_provider import MonkeyPatchProvider +from .openai import OpenAIProvider _CAPTURE_PARAMS = frozenset( { @@ -21,91 +18,40 @@ ) -class LiteLLMProvider(BaseAdapter): - def __init__(self) -> None: - self._original_completion: Any = None - self._original_acompletion: Any = None - self._connected = False +class LiteLLMProvider(MonkeyPatchProvider): + name = "litellm" + capture_params = _CAPTURE_PARAMS + + @staticmethod + def extract_output(response: Any) -> Any: + return OpenAIProvider.extract_output(response) + + @staticmethod + def extract_meta(response: Any) -> Dict[str, Any]: + return OpenAIProvider.extract_meta(response) def connect(self, target: Any = None, **kwargs: Any) -> Any: # noqa: ARG002 try: import litellm except ImportError as err: raise ImportError( - "The 'litellm' package is required for LiteLLM instrumentation. Install it with: pip install litellm" + "The 'litellm' package is required for LiteLLM instrumentation. " + "Install it with: pip install litellm" ) from err - if self._original_completion is None: - self._original_completion = litellm.completion - orig_sync = self._original_completion - - def patched_completion(*args: Any, **kwargs: Any) -> Any: - if _current_collector.get() is None: - return orig_sync(*args, **kwargs) - start = time.time() - try: - response = orig_sync(*args, **kwargs) - except Exception as exc: - latency_ms = (time.time() - start) * 1000 - emit_llm_error("litellm.completion", exc, latency_ms) - raise - latency_ms = (time.time() - start) * 1000 - emit_llm_events( - "litellm.completion", kwargs, response, - _extract_output, _extract_response_meta, _CAPTURE_PARAMS, latency_ms, - ) - return response - - litellm.completion = patched_completion - - if self._original_acompletion is None: - self._original_acompletion = litellm.acompletion - orig_async = self._original_acompletion - - async def patched_acompletion(*args: Any, **kwargs: Any) -> Any: - if _current_collector.get() is None: - return await orig_async(*args, **kwargs) - start = time.time() - try: - response = await orig_async(*args, **kwargs) - except Exception as exc: - latency_ms = (time.time() - start) * 1000 - emit_llm_error("litellm.acompletion", exc, latency_ms) - raise - latency_ms = (time.time() - start) * 1000 - emit_llm_events( - "litellm.acompletion", kwargs, response, - _extract_output, _extract_response_meta, _CAPTURE_PARAMS, latency_ms, - ) - return response - - litellm.acompletion = patched_acompletion - - self._connected = True - return target + self._client = litellm - def disconnect(self) -> None: - try: - import litellm - except ImportError: - self._connected = False - return - - if self._original_completion is not None: - litellm.completion = self._original_completion - self._original_completion = None - if self._original_acompletion is not None: - litellm.acompletion = self._original_acompletion - self._original_acompletion = None - - self._connected = False - - def adapter_info(self) -> AdapterInfo: - return AdapterInfo( - name="litellm", - adapter_type="provider", - connected=self._connected, - ) + if "completion" not in self._originals: + orig_sync = litellm.completion + self._originals["completion"] = orig_sync + litellm.completion = self._wrap_sync("litellm.completion", orig_sync) + + if "acompletion" not in self._originals: + orig_async = litellm.acompletion + self._originals["acompletion"] = orig_async + litellm.acompletion = self._wrap_async("litellm.acompletion", orig_async) + + return target # --- Convenience API --- diff --git a/src/layerlens/instrument/adapters/providers/openai.py b/src/layerlens/instrument/adapters/providers/openai.py index bdb8480..d09779f 100644 --- a/src/layerlens/instrument/adapters/providers/openai.py +++ b/src/layerlens/instrument/adapters/providers/openai.py @@ -1,14 +1,8 @@ from __future__ import annotations -import time -import logging from typing import Any, Dict -from .._base import AdapterInfo, BaseAdapter -from ._base_provider import emit_llm_events, emit_llm_error -from ..._context import _current_collector - -log: logging.Logger = logging.getLogger(__name__) +from ._base_provider import MonkeyPatchProvider _CAPTURE_PARAMS = frozenset( { @@ -24,10 +18,39 @@ ) -class OpenAIProvider(BaseAdapter): - def __init__(self) -> None: - self._client: Any = None - self._originals: Dict[str, Any] = {} +class OpenAIProvider(MonkeyPatchProvider): + name = "openai" + capture_params = _CAPTURE_PARAMS + + @staticmethod + def extract_output(response: Any) -> Any: + try: + choices = response.choices + if choices: + msg = choices[0].message + return {"role": msg.role, "content": msg.content} + except (AttributeError, IndexError): + pass + return None + + @staticmethod + def extract_meta(response: Any) -> Dict[str, Any]: + meta: Dict[str, Any] = {} + try: + usage = response.usage + if usage is not None: + meta["usage"] = { + "prompt_tokens": usage.prompt_tokens, + "completion_tokens": usage.completion_tokens, + "total_tokens": usage.total_tokens, + } + except AttributeError: + pass + try: + meta["response_model"] = response.model + except AttributeError: + pass + return meta def connect(self, target: Any = None, **kwargs: Any) -> Any: # noqa: ARG002 self._client = target @@ -35,107 +58,19 @@ def connect(self, target: Any = None, **kwargs: Any) -> Any: # noqa: ARG002 if hasattr(target, "chat") and hasattr(target.chat, "completions"): orig = target.chat.completions.create self._originals["chat.completions.create"] = orig - target.chat.completions.create = self._wrap_sync(orig) + target.chat.completions.create = self._wrap_sync( + "openai.chat.completions.create", orig + ) if hasattr(target.chat.completions, "acreate"): async_orig = target.chat.completions.acreate self._originals["chat.completions.acreate"] = async_orig - target.chat.completions.acreate = self._wrap_async(async_orig) + target.chat.completions.acreate = self._wrap_async( + "openai.chat.completions.create", async_orig + ) return target - def disconnect(self) -> None: - if self._client is None: - return - for key, orig in self._originals.items(): - try: - parts = key.split(".") - obj = self._client - for part in parts[:-1]: - obj = getattr(obj, part) - setattr(obj, parts[-1], orig) - except Exception: - log.warning("Could not restore %s", key) - self._client = None - self._originals.clear() - - def adapter_info(self) -> AdapterInfo: - return AdapterInfo( - name="openai", - adapter_type="provider", - connected=self._client is not None, - ) - - def _wrap_sync(self, original: Any) -> Any: - def wrapped(*args: Any, **kwargs: Any) -> Any: - if _current_collector.get() is None: - return original(*args, **kwargs) - start = time.time() - try: - response = original(*args, **kwargs) - except Exception as exc: - latency_ms = (time.time() - start) * 1000 - emit_llm_error("openai.chat.completions.create", exc, latency_ms) - raise - latency_ms = (time.time() - start) * 1000 - emit_llm_events( - "openai.chat.completions.create", kwargs, response, - _extract_output, _extract_response_meta, _CAPTURE_PARAMS, latency_ms, - ) - return response - - return wrapped - - def _wrap_async(self, original: Any) -> Any: - async def wrapped(*args: Any, **kwargs: Any) -> Any: - if _current_collector.get() is None: - return await original(*args, **kwargs) - start = time.time() - try: - response = await original(*args, **kwargs) - except Exception as exc: - latency_ms = (time.time() - start) * 1000 - emit_llm_error("openai.chat.completions.create", exc, latency_ms) - raise - latency_ms = (time.time() - start) * 1000 - emit_llm_events( - "openai.chat.completions.create", kwargs, response, - _extract_output, _extract_response_meta, _CAPTURE_PARAMS, latency_ms, - ) - return response - - return wrapped - - -def _extract_output(response: Any) -> Any: - try: - choices = response.choices - if choices: - msg = choices[0].message - return {"role": msg.role, "content": msg.content} - except (AttributeError, IndexError): - pass - return None - - -def _extract_response_meta(response: Any) -> Dict[str, Any]: - meta: Dict[str, Any] = {} - try: - usage = response.usage - if usage is not None: - meta["usage"] = { - "prompt_tokens": usage.prompt_tokens, - "completion_tokens": usage.completion_tokens, - "total_tokens": usage.total_tokens, - } - except AttributeError: - pass - try: - meta["response_model"] = response.model - except AttributeError: - pass - return meta - # --- Convenience API --- diff --git a/tests/instrument/adapters/__init__.py b/tests/instrument/adapters/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/instrument/adapters/frameworks/__init__.py b/tests/instrument/adapters/frameworks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/instrument/adapters/frameworks/conftest.py b/tests/instrument/adapters/frameworks/conftest.py new file mode 100644 index 0000000..fb8d90e --- /dev/null +++ b/tests/instrument/adapters/frameworks/conftest.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import json +from typing import Any, Dict +from unittest.mock import Mock + + +# Re-export from root conftest so framework tests can do `from .conftest import ...` +from ...conftest import find_event, find_events # noqa: F401 + + +def capture_framework_trace(mock_client: Mock) -> Dict[str, Any]: + """Capture the uploaded trace payload from a framework adapter. + + Accumulates events across multiple flushes (some adapters use + multiple collectors). + """ + uploaded: Dict[str, Any] = {"events": []} + + def _capture(path: str) -> None: + with open(path) as f: + data = json.load(f) + payload = data[0] + uploaded["trace_id"] = payload.get("trace_id") + uploaded["events"].extend(payload.get("events", [])) + uploaded["capture_config"] = payload.get("capture_config", {}) + uploaded["attestation"] = payload.get("attestation", {}) + + mock_client.traces.upload.side_effect = _capture + return uploaded diff --git a/tests/instrument/adapters/frameworks/test_langchain.py b/tests/instrument/adapters/frameworks/test_langchain.py new file mode 100644 index 0000000..d2a3057 --- /dev/null +++ b/tests/instrument/adapters/frameworks/test_langchain.py @@ -0,0 +1,345 @@ +from __future__ import annotations + +from uuid import uuid4 +from unittest.mock import Mock + +from langchain_core.callbacks import BaseCallbackHandler + +from layerlens.instrument.adapters.frameworks.langchain import LangChainCallbackHandler + +from .conftest import capture_framework_trace, find_event, find_events + + +# --------------------------------------------------------------------------- +# Sanity: real base class +# --------------------------------------------------------------------------- + + +class TestBaseClass: + def test_inherits_langchain_base(self): + assert issubclass(LangChainCallbackHandler, BaseCallbackHandler) + + def test_name(self): + handler = LangChainCallbackHandler(Mock()) + assert handler.name == "langchain" + + +# --------------------------------------------------------------------------- +# Emit events +# --------------------------------------------------------------------------- + + +class TestEmitsEvents: + def test_chain_lifecycle(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + handler.on_chain_start( + {"name": "RunnableSequence", "id": ["RunnableSequence"]}, + {"question": "What is AI?"}, + run_id=chain_id, + ) + handler.on_chain_end({"output": "AI is..."}, run_id=chain_id) + + events = uploaded["events"] + agent_input = find_event(events, "agent.input") + assert agent_input["payload"]["name"] == "RunnableSequence" + assert agent_input["payload"]["input"] == {"question": "What is AI?"} + + agent_output = find_event(events, "agent.output") + assert agent_output["payload"]["status"] == "ok" + + def test_llm_lifecycle(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + llm_id = uuid4() + + handler.on_chain_start( + {"name": "Chain"}, {"input": "x"}, run_id=chain_id, + ) + handler.on_llm_start( + {"name": "ChatOpenAI", "id": ["ChatOpenAI"]}, + ["What is AI?"], + run_id=llm_id, + parent_run_id=chain_id, + ) + + llm_response = Mock() + llm_response.generations = [[Mock(text="AI is...")]] + llm_response.llm_output = { + "token_usage": {"total_tokens": 50}, + "model_name": "gpt-4", + } + handler.on_llm_end(llm_response, run_id=llm_id) + handler.on_chain_end({"output": "AI is..."}, run_id=chain_id) + + events = uploaded["events"] + + model_invokes = find_events(events, "model.invoke") + assert len(model_invokes) >= 1 + # Start event has name and messages + start_invoke = [m for m in model_invokes if m["payload"].get("name") == "ChatOpenAI"] + assert len(start_invoke) == 1 + # End event has model and output + end_invoke = [m for m in model_invokes if m["payload"].get("model") == "gpt-4"] + assert len(end_invoke) == 1 + assert end_invoke[0]["payload"]["output_message"] == "AI is..." + + cost = find_event(events, "cost.record") + assert cost["payload"]["total_tokens"] == 50 + + def test_chat_model_start(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + chat_id = uuid4() + + handler.on_chain_start({"name": "Chain"}, {}, run_id=chain_id) + msg = Mock() + msg.type = "human" + msg.content = "Hello" + handler.on_chat_model_start( + {"name": "ChatAnthropic"}, + [[msg]], + run_id=chat_id, + parent_run_id=chain_id, + ) + handler.on_chain_end({}, run_id=chain_id) + + events = uploaded["events"] + invoke = find_event(events, "model.invoke") + assert invoke["payload"]["name"] == "ChatAnthropic" + assert invoke["payload"]["messages"] == [[{"type": "human", "content": "Hello"}]] + + +# --------------------------------------------------------------------------- +# Tool and retriever events +# --------------------------------------------------------------------------- + + +class TestToolsAndRetrievers: + def test_tool_lifecycle(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + tool_id = uuid4() + + handler.on_chain_start({"name": "Agent"}, {}, run_id=chain_id) + handler.on_tool_start( + {"name": "search"}, "query text", + run_id=tool_id, parent_run_id=chain_id, + ) + handler.on_tool_end("search results", run_id=tool_id) + handler.on_chain_end({}, run_id=chain_id) + + events = uploaded["events"] + tool_call = find_event(events, "tool.call") + assert tool_call["payload"]["name"] == "search" + assert tool_call["payload"]["input"] == "query text" + + tool_result = find_event(events, "tool.result") + assert tool_result["payload"]["output"] == "search results" + + def test_retriever_lifecycle(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + ret_id = uuid4() + + handler.on_chain_start({"name": "Agent"}, {}, run_id=chain_id) + handler.on_retriever_start( + {"name": "vectorstore"}, "query", + run_id=ret_id, parent_run_id=chain_id, + ) + docs = [Mock(page_content="doc text", metadata={"source": "a.txt"})] + handler.on_retriever_end(docs, run_id=ret_id) + handler.on_chain_end({}, run_id=chain_id) + + events = uploaded["events"] + tool_call = find_event(events, "tool.call") + assert tool_call["payload"]["name"] == "vectorstore" + + tool_result = find_event(events, "tool.result") + output = tool_result["payload"]["output"] + assert output[0]["page_content"] == "doc text" + assert output[0]["metadata"] == {"source": "a.txt"} + + def test_combined_tools_and_retrievers(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + tool_id = uuid4() + ret_id = uuid4() + + handler.on_chain_start({"name": "Agent"}, {}, run_id=chain_id) + handler.on_tool_start({"name": "search"}, "q", run_id=tool_id, parent_run_id=chain_id) + handler.on_tool_end("results", run_id=tool_id) + handler.on_retriever_start({"name": "vs"}, "q", run_id=ret_id, parent_run_id=chain_id) + handler.on_retriever_end([Mock(page_content="d", metadata={})], run_id=ret_id) + handler.on_chain_end({}, run_id=chain_id) + + events = uploaded["events"] + assert len(find_events(events, "tool.call")) == 2 + assert len(find_events(events, "tool.result")) == 2 + + +# --------------------------------------------------------------------------- +# Error handling +# --------------------------------------------------------------------------- + + +class TestErrors: + def test_chain_error(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + handler.on_chain_start({"name": "FailChain"}, {"input": "x"}, run_id=chain_id) + handler.on_chain_error(ValueError("broke"), run_id=chain_id) + + error = find_event(uploaded["events"], "agent.error") + assert error["payload"]["error"] == "broke" + assert error["payload"]["status"] == "error" + + def test_llm_error(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + llm_id = uuid4() + + handler.on_chain_start({"name": "Chain"}, {}, run_id=chain_id) + handler.on_llm_start({"name": "LLM"}, ["prompt"], run_id=llm_id, parent_run_id=chain_id) + handler.on_llm_error(RuntimeError("timeout"), run_id=llm_id) + handler.on_chain_end({}, run_id=chain_id) + + error = find_event(uploaded["events"], "agent.error") + assert error["payload"]["error"] == "timeout" + + def test_tool_error(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + tool_id = uuid4() + + handler.on_chain_start({"name": "Agent"}, {}, run_id=chain_id) + handler.on_tool_start({"name": "search"}, "q", run_id=tool_id, parent_run_id=chain_id) + handler.on_tool_error(RuntimeError("404"), run_id=tool_id) + handler.on_chain_end({}, run_id=chain_id) + + error = find_event(uploaded["events"], "agent.error") + assert error["payload"]["error"] == "404" + + def test_retriever_error(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + ret_id = uuid4() + + handler.on_chain_start({"name": "Agent"}, {}, run_id=chain_id) + handler.on_retriever_start({"name": "vs"}, "q", run_id=ret_id, parent_run_id=chain_id) + handler.on_retriever_error(ConnectionError("down"), run_id=ret_id) + handler.on_chain_end({}, run_id=chain_id) + + error = find_event(uploaded["events"], "agent.error") + assert error["payload"]["error"] == "down" + + +# --------------------------------------------------------------------------- +# Parent-child span relationships +# --------------------------------------------------------------------------- + + +class TestSpanRelationships: + def test_llm_parent_is_chain(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + llm_id = uuid4() + + handler.on_chain_start({"name": "Chain"}, {}, run_id=chain_id) + handler.on_llm_start( + {"name": "LLM"}, ["prompt"], + run_id=llm_id, parent_run_id=chain_id, + ) + llm_response = Mock() + llm_response.generations = [[Mock(text="out")]] + llm_response.llm_output = {} + handler.on_llm_end(llm_response, run_id=llm_id) + handler.on_chain_end({}, run_id=chain_id) + + events = uploaded["events"] + chain_input = find_event(events, "agent.input") + llm_invoke = [e for e in find_events(events, "model.invoke") if e["payload"].get("name") == "LLM"][0] + assert llm_invoke["parent_span_id"] == chain_input["span_id"] + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + def test_null_serialized(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + run_id = uuid4() + handler.on_chain_start(None, {"input": "x"}, run_id=run_id) + handler.on_chain_end({}, run_id=run_id) + + agent_input = find_event(uploaded["events"], "agent.input") + assert agent_input["payload"]["name"] == "unknown" + + def test_empty_serialized_id(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + run_id = uuid4() + handler.on_chain_start({"id": ["FallbackName"]}, {}, run_id=run_id) + handler.on_chain_end({}, run_id=run_id) + + agent_input = find_event(uploaded["events"], "agent.input") + assert agent_input["payload"]["name"] == "FallbackName" + + def test_llm_end_no_output(self, mock_client): + """LLM response with no generations should not crash.""" + uploaded = capture_framework_trace(mock_client) + handler = LangChainCallbackHandler(mock_client) + + chain_id = uuid4() + llm_id = uuid4() + + handler.on_chain_start({"name": "Chain"}, {}, run_id=chain_id) + handler.on_llm_start({"name": "LLM"}, ["p"], run_id=llm_id, parent_run_id=chain_id) + + empty_response = Mock() + empty_response.generations = [] + empty_response.llm_output = None + handler.on_llm_end(empty_response, run_id=llm_id) + handler.on_chain_end({}, run_id=chain_id) + + # Should complete without error — no model.invoke end event since no output/model + + +# --------------------------------------------------------------------------- +# adapter_info +# --------------------------------------------------------------------------- + + +class TestAdapterInfo: + def test_info(self): + handler = LangChainCallbackHandler(Mock()) + info = handler.adapter_info() + assert info.name == "langchain" + assert info.adapter_type == "framework" diff --git a/tests/instrument/adapters/frameworks/test_langgraph.py b/tests/instrument/adapters/frameworks/test_langgraph.py new file mode 100644 index 0000000..7ff6e9d --- /dev/null +++ b/tests/instrument/adapters/frameworks/test_langgraph.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +from uuid import uuid4 +from unittest.mock import Mock + +from langchain_core.callbacks import BaseCallbackHandler + +from layerlens.instrument.adapters.frameworks.langgraph import LangGraphCallbackHandler + +from .conftest import capture_framework_trace, find_event, find_events + + +# --------------------------------------------------------------------------- +# Sanity: real base class +# --------------------------------------------------------------------------- + + +class TestBaseClass: + def test_inherits_langchain_base(self): + assert issubclass(LangGraphCallbackHandler, BaseCallbackHandler) + + def test_name(self): + handler = LangGraphCallbackHandler(Mock()) + assert handler.name == "langgraph" + + +# --------------------------------------------------------------------------- +# Inherited LangChain behavior +# --------------------------------------------------------------------------- + + +class TestInheritedBehavior: + """LangGraph inherits all LangChain callbacks except on_chain_start.""" + + def test_llm_events_inherited(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangGraphCallbackHandler(mock_client) + + chain_id = uuid4() + llm_id = uuid4() + + handler.on_chain_start({"name": "Graph"}, {}, run_id=chain_id) + handler.on_llm_start( + {"name": "ChatOpenAI"}, ["prompt"], + run_id=llm_id, parent_run_id=chain_id, + ) + llm_response = Mock() + llm_response.generations = [[Mock(text="output")]] + llm_response.llm_output = {"model_name": "gpt-4", "token_usage": {"total_tokens": 10}} + handler.on_llm_end(llm_response, run_id=llm_id) + handler.on_chain_end({}, run_id=chain_id) + + events = uploaded["events"] + assert len(find_events(events, "model.invoke")) >= 1 + assert find_event(events, "cost.record")["payload"]["total_tokens"] == 10 + + def test_tool_events_inherited(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangGraphCallbackHandler(mock_client) + + chain_id = uuid4() + tool_id = uuid4() + + handler.on_chain_start({"name": "Graph"}, {}, run_id=chain_id) + handler.on_tool_start({"name": "search"}, "q", run_id=tool_id, parent_run_id=chain_id) + handler.on_tool_end("results", run_id=tool_id) + handler.on_chain_end({}, run_id=chain_id) + + events = uploaded["events"] + assert find_event(events, "tool.call")["payload"]["name"] == "search" + assert find_event(events, "tool.result")["payload"]["output"] == "results" + + def test_error_handling_inherited(self, mock_client): + uploaded = capture_framework_trace(mock_client) + handler = LangGraphCallbackHandler(mock_client) + + chain_id = uuid4() + handler.on_chain_start({"name": "Graph"}, {}, run_id=chain_id) + handler.on_chain_error(RuntimeError("graph failed"), run_id=chain_id) + + error = find_event(uploaded["events"], "agent.error") + assert error["payload"]["error"] == "graph failed" + + +# --------------------------------------------------------------------------- +# LangGraph-specific: on_chain_start node extraction +# --------------------------------------------------------------------------- + + +class TestNodeExtraction: + def test_extracts_node_from_tags(self, mock_client): + """LangGraph passes node names as plain tags (no colon).""" + uploaded = capture_framework_trace(mock_client) + handler = LangGraphCallbackHandler(mock_client) + + chain_id = uuid4() + handler.on_chain_start( + {"name": "RunnableSequence"}, + {"input": "hello"}, + run_id=chain_id, + tags=["graph:step:1", "retriever_node"], + ) + handler.on_chain_end({}, run_id=chain_id) + + agent_input = find_event(uploaded["events"], "agent.input") + assert agent_input["payload"]["name"] == "retriever_node" + + def test_extracts_node_from_metadata(self, mock_client): + """LangGraph puts node name in metadata.langgraph_node.""" + uploaded = capture_framework_trace(mock_client) + handler = LangGraphCallbackHandler(mock_client) + + chain_id = uuid4() + handler.on_chain_start( + {"name": "RunnableSequence"}, + {"input": "hello"}, + run_id=chain_id, + metadata={"langgraph_node": "agent_node"}, + ) + handler.on_chain_end({}, run_id=chain_id) + + agent_input = find_event(uploaded["events"], "agent.input") + assert agent_input["payload"]["name"] == "agent_node" + + def test_metadata_overrides_tags(self, mock_client): + """When both tags and metadata provide a node name, metadata wins.""" + uploaded = capture_framework_trace(mock_client) + handler = LangGraphCallbackHandler(mock_client) + + chain_id = uuid4() + handler.on_chain_start( + {"name": "Seq"}, + {}, + run_id=chain_id, + tags=["tag_node"], + metadata={"langgraph_node": "meta_node"}, + ) + handler.on_chain_end({}, run_id=chain_id) + + agent_input = find_event(uploaded["events"], "agent.input") + assert agent_input["payload"]["name"] == "meta_node" + + def test_falls_back_to_serialized_name(self, mock_client): + """Without tags or metadata, falls back to serialized name.""" + uploaded = capture_framework_trace(mock_client) + handler = LangGraphCallbackHandler(mock_client) + + chain_id = uuid4() + handler.on_chain_start( + {"name": "MyCustomChain"}, + {}, + run_id=chain_id, + ) + handler.on_chain_end({}, run_id=chain_id) + + agent_input = find_event(uploaded["events"], "agent.input") + assert agent_input["payload"]["name"] == "MyCustomChain" + + def test_skips_graph_step_tags(self, mock_client): + """Tags starting with 'graph:step:' should be skipped.""" + uploaded = capture_framework_trace(mock_client) + handler = LangGraphCallbackHandler(mock_client) + + chain_id = uuid4() + handler.on_chain_start( + {"name": "Default"}, + {}, + run_id=chain_id, + tags=["graph:step:0", "graph:step:1"], + ) + handler.on_chain_end({}, run_id=chain_id) + + agent_input = find_event(uploaded["events"], "agent.input") + # No usable tags — falls back to serialized name + assert agent_input["payload"]["name"] == "Default" + + +# --------------------------------------------------------------------------- +# adapter_info +# --------------------------------------------------------------------------- + + +class TestAdapterInfo: + def test_info(self): + handler = LangGraphCallbackHandler(Mock()) + info = handler.adapter_info() + assert info.name == "langgraph" + assert info.adapter_type == "framework" diff --git a/tests/instrument/adapters/providers/__init__.py b/tests/instrument/adapters/providers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/instrument/adapters/providers/conftest.py b/tests/instrument/adapters/providers/conftest.py new file mode 100644 index 0000000..48aa1e7 --- /dev/null +++ b/tests/instrument/adapters/providers/conftest.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice +from openai.types import CompletionUsage + +from anthropic.types import Message, TextBlock, Usage + + +def make_openai_response( + content: str = "Hello!", + role: str = "assistant", + model: str = "gpt-4", + prompt_tokens: int = 10, + completion_tokens: int = 5, + total_tokens: int = 15, +) -> ChatCompletion: + """Build a real OpenAI ChatCompletion response.""" + return ChatCompletion( + id="chatcmpl-test", + model=model, + object="chat.completion", + created=1700000000, + choices=[ + Choice( + index=0, + finish_reason="stop", + message=ChatCompletionMessage(role=role, content=content), + ) + ], + usage=CompletionUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ), + ) + + +def make_openai_response_no_usage(model: str = "gpt-4") -> ChatCompletion: + """Build an OpenAI response with no usage data.""" + return ChatCompletion( + id="chatcmpl-test", + model=model, + object="chat.completion", + created=1700000000, + choices=[ + Choice( + index=0, + finish_reason="stop", + message=ChatCompletionMessage(role="assistant", content="Hello!"), + ) + ], + usage=None, + ) + + +def make_openai_response_empty_choices(model: str = "gpt-4") -> ChatCompletion: + """Build an OpenAI response with empty choices.""" + return ChatCompletion( + id="chatcmpl-test", + model=model, + object="chat.completion", + created=1700000000, + choices=[], + usage=None, + ) + + +def make_anthropic_response( + text: str = "I'm Claude!", + model: str = "claude-3-opus-20240229", + input_tokens: int = 20, + output_tokens: int = 10, + stop_reason: str = "end_turn", +) -> Message: + """Build a real Anthropic Message response.""" + return Message( + id="msg-test", + type="message", + role="assistant", + model=model, + content=[TextBlock(type="text", text=text)], + usage=Usage(input_tokens=input_tokens, output_tokens=output_tokens), + stop_reason=stop_reason, + ) + + +def make_anthropic_response_empty_content( + model: str = "claude-3-opus-20240229", +) -> Message: + """Build an Anthropic response with empty content.""" + return Message( + id="msg-test", + type="message", + role="assistant", + model=model, + content=[], + usage=Usage(input_tokens=0, output_tokens=0), + stop_reason="end_turn", + ) diff --git a/tests/instrument/adapters/providers/test_anthropic.py b/tests/instrument/adapters/providers/test_anthropic.py new file mode 100644 index 0000000..7dcf22f --- /dev/null +++ b/tests/instrument/adapters/providers/test_anthropic.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +from unittest.mock import Mock + +from layerlens.instrument import trace +from layerlens.instrument.adapters.providers.anthropic import ( + AnthropicProvider, + instrument_anthropic, + uninstrument_anthropic, +) + +from ...conftest import find_event +from .conftest import make_anthropic_response, make_anthropic_response_empty_content + + +# --------------------------------------------------------------------------- +# Emit events +# --------------------------------------------------------------------------- + + +class TestEmitsEvents: + def test_model_invoke_and_cost_record(self, mock_client, capture_trace): + anthropic_client = Mock() + anthropic_client.messages.create = Mock(return_value=make_anthropic_response()) + + provider = AnthropicProvider() + provider.connect(anthropic_client) + + @trace(mock_client) + def my_agent(): + r = anthropic_client.messages.create( + model="claude-3-opus-20240229", max_tokens=1024, + messages=[{"role": "user", "content": "Hi"}], + ) + return r.content[0].text + + my_agent() + events = capture_trace["events"] + + model_invoke = find_event(events, "model.invoke") + assert model_invoke["payload"]["name"] == "anthropic.messages.create" + assert model_invoke["payload"]["response_model"] == "claude-3-opus-20240229" + assert model_invoke["payload"]["output_message"]["type"] == "text" + assert model_invoke["payload"]["output_message"]["text"] == "I'm Claude!" + assert model_invoke["payload"]["usage"]["input_tokens"] == 20 + assert model_invoke["payload"]["usage"]["output_tokens"] == 10 + assert model_invoke["payload"]["stop_reason"] == "end_turn" + assert "latency_ms" in model_invoke["payload"] + + cost = find_event(events, "cost.record") + assert cost["payload"]["provider"] == "anthropic" + assert cost["payload"]["input_tokens"] == 20 + assert cost["payload"]["output_tokens"] == 10 + + def test_error_emits_agent_error(self, mock_client, capture_trace): + anthropic_client = Mock() + anthropic_client.messages.create = Mock(side_effect=RuntimeError("overloaded")) + + provider = AnthropicProvider() + provider.connect(anthropic_client) + + @trace(mock_client) + def my_agent(): + try: + anthropic_client.messages.create(model="claude-3-opus-20240229", max_tokens=1024, messages=[]) + except RuntimeError: + pass + return "recovered" + + my_agent() + events = capture_trace["events"] + error = find_event(events, "agent.error") + assert error["payload"]["error"] == "overloaded" + assert "latency_ms" in error["payload"] + + +# --------------------------------------------------------------------------- +# Passthrough / no-op behavior +# --------------------------------------------------------------------------- + + +class TestPassthrough: + def test_no_op_outside_trace(self): + response = make_anthropic_response() + anthropic_client = Mock() + anthropic_client.messages.create = Mock(return_value=response) + + provider = AnthropicProvider() + provider.connect(anthropic_client) + + result = anthropic_client.messages.create(model="claude-3-opus-20240229", max_tokens=1024, messages=[]) + assert result.content[0].text == "I'm Claude!" + + +# --------------------------------------------------------------------------- +# Connect / disconnect lifecycle +# --------------------------------------------------------------------------- + + +class TestLifecycle: + def test_disconnect_restores_original(self): + anthropic_client = Mock() + original = anthropic_client.messages.create + + provider = AnthropicProvider() + provider.connect(anthropic_client) + assert anthropic_client.messages.create is not original + + provider.disconnect() + assert anthropic_client.messages.create is original + + def test_disconnect_when_not_connected(self): + provider = AnthropicProvider() + provider.disconnect() # should not raise + + def test_double_connect_replaces_wrapper(self): + anthropic_client = Mock() + provider = AnthropicProvider() + provider.connect(anthropic_client) + first_wrapper = anthropic_client.messages.create + + provider2 = AnthropicProvider() + provider2.connect(anthropic_client) + assert anthropic_client.messages.create is not first_wrapper + + +# --------------------------------------------------------------------------- +# adapter_info +# --------------------------------------------------------------------------- + + +class TestAdapterInfo: + def test_info_before_connect(self): + provider = AnthropicProvider() + info = provider.adapter_info() + assert info.name == "anthropic" + assert info.adapter_type == "provider" + assert info.connected is False + + def test_info_after_connect(self): + provider = AnthropicProvider() + provider.connect(Mock()) + info = provider.adapter_info() + assert info.connected is True + + def test_info_after_disconnect(self): + provider = AnthropicProvider() + provider.connect(Mock()) + provider.disconnect() + assert provider.adapter_info().connected is False + + +# --------------------------------------------------------------------------- +# Convenience API +# --------------------------------------------------------------------------- + + +class TestConvenienceAPI: + def test_instrument_and_uninstrument(self): + anthropic_client = Mock() + original = anthropic_client.messages.create + instrument_anthropic(anthropic_client) + assert anthropic_client.messages.create is not original + uninstrument_anthropic() + + +# --------------------------------------------------------------------------- +# capture_params filtering +# --------------------------------------------------------------------------- + + +class TestCaptureParams: + def test_captured_params_included(self, mock_client, capture_trace): + anthropic_client = Mock() + anthropic_client.messages.create = Mock(return_value=make_anthropic_response()) + + provider = AnthropicProvider() + provider.connect(anthropic_client) + + @trace(mock_client) + def my_agent(): + anthropic_client.messages.create( + model="claude-3-opus-20240229", max_tokens=1024, temperature=0.5, top_k=40, + messages=[{"role": "user", "content": "Hi"}], + ) + return "done" + + my_agent() + params = find_event(capture_trace["events"], "model.invoke")["payload"]["parameters"] + assert params["model"] == "claude-3-opus-20240229" + assert params["max_tokens"] == 1024 + assert params["temperature"] == 0.5 + assert params["top_k"] == 40 + + def test_non_captured_params_excluded(self, mock_client, capture_trace): + anthropic_client = Mock() + anthropic_client.messages.create = Mock(return_value=make_anthropic_response()) + + provider = AnthropicProvider() + provider.connect(anthropic_client) + + @trace(mock_client) + def my_agent(): + anthropic_client.messages.create( + model="claude-3-opus-20240229", max_tokens=1024, + messages=[], stream=True, metadata={"user_id": "abc"}, + ) + return "done" + + my_agent() + params = find_event(capture_trace["events"], "model.invoke")["payload"]["parameters"] + assert "stream" not in params + assert "metadata" not in params + assert "messages" not in params + + +# --------------------------------------------------------------------------- +# Extractor edge cases (using real SDK types) +# --------------------------------------------------------------------------- + + +class TestExtractors: + def test_extract_output_normal(self): + r = make_anthropic_response(text="Hello world") + output = AnthropicProvider.extract_output(r) + assert output == {"type": "text", "text": "Hello world"} + + def test_extract_output_empty_content(self): + r = make_anthropic_response_empty_content() + assert AnthropicProvider.extract_output(r) is None + + def test_extract_meta_normal(self): + r = make_anthropic_response( + model="claude-3-5-sonnet-20241022", + input_tokens=100, output_tokens=50, + stop_reason="max_tokens", + ) + meta = AnthropicProvider.extract_meta(r) + assert meta["response_model"] == "claude-3-5-sonnet-20241022" + assert meta["usage"]["input_tokens"] == 100 + assert meta["usage"]["output_tokens"] == 50 + assert meta["stop_reason"] == "max_tokens" diff --git a/tests/instrument/adapters/providers/test_litellm.py b/tests/instrument/adapters/providers/test_litellm.py new file mode 100644 index 0000000..50588ce --- /dev/null +++ b/tests/instrument/adapters/providers/test_litellm.py @@ -0,0 +1,263 @@ +from __future__ import annotations + +import sys +import types +from unittest.mock import Mock + +from layerlens.instrument import trace +from layerlens.instrument.adapters.providers.litellm import ( + LiteLLMProvider, + instrument_litellm, + uninstrument_litellm, +) + +from ...conftest import find_event +from .conftest import make_openai_response, make_openai_response_empty_choices, make_openai_response_no_usage + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _install_mock_litellm(response=None): + """Inject a fake litellm module into sys.modules with real OpenAI response types.""" + mock_mod = types.ModuleType("litellm") + mock_mod.completion = Mock(return_value=response or make_openai_response()) + mock_mod.acompletion = Mock() + sys.modules["litellm"] = mock_mod + return mock_mod + + +def _remove_mock_litellm(): + uninstrument_litellm() + for key in list(sys.modules.keys()): + if key.startswith("litellm"): + del sys.modules[key] + + +# --------------------------------------------------------------------------- +# Emit events +# --------------------------------------------------------------------------- + + +class TestEmitsEvents: + def setup_method(self): + self.mock_litellm = _install_mock_litellm() + + def teardown_method(self): + _remove_mock_litellm() + + def test_model_invoke_and_cost_record(self, mock_client, capture_trace): + instrument_litellm() + + @trace(mock_client) + def my_agent(): + import litellm + r = litellm.completion( + model="gpt-4", messages=[{"role": "user", "content": "Hi"}] + ) + return r.choices[0].message.content + + my_agent() + events = capture_trace["events"] + + model_invoke = find_event(events, "model.invoke") + assert model_invoke["payload"]["name"] == "litellm.completion" + assert model_invoke["payload"]["model"] == "gpt-4" + assert model_invoke["payload"]["output_message"]["content"] == "Hello!" + assert model_invoke["payload"]["usage"]["total_tokens"] == 15 + assert "latency_ms" in model_invoke["payload"] + + cost = find_event(events, "cost.record") + assert cost["payload"]["provider"] == "litellm" + assert cost["payload"]["total_tokens"] == 15 + + def test_error_emits_agent_error(self, mock_client, capture_trace): + self.mock_litellm.completion = Mock(side_effect=RuntimeError("rate limited")) + instrument_litellm() + + @trace(mock_client) + def my_agent(): + import litellm + try: + litellm.completion(model="gpt-4", messages=[]) + except RuntimeError: + pass + return "recovered" + + my_agent() + events = capture_trace["events"] + error = find_event(events, "agent.error") + assert error["payload"]["error"] == "rate limited" + assert "latency_ms" in error["payload"] + + +# --------------------------------------------------------------------------- +# Passthrough / no-op behavior +# --------------------------------------------------------------------------- + + +class TestPassthrough: + def setup_method(self): + self.mock_litellm = _install_mock_litellm() + + def teardown_method(self): + _remove_mock_litellm() + + def test_no_op_outside_trace(self): + instrument_litellm() + import litellm + result = litellm.completion(model="gpt-4", messages=[]) + assert result.choices[0].message.content == "Hello!" + + +# --------------------------------------------------------------------------- +# Connect / disconnect lifecycle +# --------------------------------------------------------------------------- + + +class TestLifecycle: + def setup_method(self): + self.mock_litellm = _install_mock_litellm() + + def teardown_method(self): + _remove_mock_litellm() + + def test_uninstrument_restores_original(self): + original = self.mock_litellm.completion + instrument_litellm() + assert self.mock_litellm.completion is not original + uninstrument_litellm() + assert self.mock_litellm.completion is original + + def test_disconnect_when_not_connected(self): + provider = LiteLLMProvider() + provider.disconnect() # should not raise + + +# --------------------------------------------------------------------------- +# adapter_info +# --------------------------------------------------------------------------- + + +class TestAdapterInfo: + def setup_method(self): + self.mock_litellm = _install_mock_litellm() + + def teardown_method(self): + _remove_mock_litellm() + + def test_info_before_connect(self): + provider = LiteLLMProvider() + info = provider.adapter_info() + assert info.name == "litellm" + assert info.adapter_type == "provider" + assert info.connected is False + + def test_info_after_connect(self): + provider = LiteLLMProvider() + provider.connect() + info = provider.adapter_info() + assert info.connected is True + + def test_info_after_disconnect(self): + provider = LiteLLMProvider() + provider.connect() + provider.disconnect() + assert provider.adapter_info().connected is False + + +# --------------------------------------------------------------------------- +# Convenience API +# --------------------------------------------------------------------------- + + +class TestConvenienceAPI: + def setup_method(self): + self.mock_litellm = _install_mock_litellm() + + def teardown_method(self): + _remove_mock_litellm() + + def test_instrument_and_uninstrument(self): + original = self.mock_litellm.completion + instrument_litellm() + assert self.mock_litellm.completion is not original + uninstrument_litellm() + assert self.mock_litellm.completion is original + + +# --------------------------------------------------------------------------- +# capture_params filtering +# --------------------------------------------------------------------------- + + +class TestCaptureParams: + def setup_method(self): + self.mock_litellm = _install_mock_litellm() + + def teardown_method(self): + _remove_mock_litellm() + + def test_captured_params_included(self, mock_client, capture_trace): + instrument_litellm() + + @trace(mock_client) + def my_agent(): + import litellm + litellm.completion( + model="gpt-4", temperature=0.7, top_p=0.9, + messages=[{"role": "user", "content": "Hi"}], + ) + return "done" + + my_agent() + params = find_event(capture_trace["events"], "model.invoke")["payload"]["parameters"] + assert params["model"] == "gpt-4" + assert params["temperature"] == 0.7 + assert params["top_p"] == 0.9 + + def test_non_captured_params_excluded(self, mock_client, capture_trace): + instrument_litellm() + + @trace(mock_client) + def my_agent(): + import litellm + litellm.completion( + model="gpt-4", messages=[], stream=True, api_key="sk-123", + ) + return "done" + + my_agent() + params = find_event(capture_trace["events"], "model.invoke")["payload"]["parameters"] + assert "stream" not in params + assert "api_key" not in params + assert "messages" not in params + + +# --------------------------------------------------------------------------- +# Extractor edge cases (LiteLLM reuses OpenAI extractors, real types) +# --------------------------------------------------------------------------- + + +class TestExtractors: + def test_extract_output_normal(self): + r = make_openai_response(content="LiteLLM response") + output = LiteLLMProvider.extract_output(r) + assert output == {"role": "assistant", "content": "LiteLLM response"} + + def test_extract_output_empty_choices(self): + r = make_openai_response_empty_choices() + assert LiteLLMProvider.extract_output(r) is None + + def test_extract_meta_normal(self): + r = make_openai_response(model="gpt-4o", prompt_tokens=5, completion_tokens=3, total_tokens=8) + meta = LiteLLMProvider.extract_meta(r) + assert meta["response_model"] == "gpt-4o" + assert meta["usage"]["total_tokens"] == 8 + + def test_extract_meta_no_usage(self): + r = make_openai_response_no_usage() + meta = LiteLLMProvider.extract_meta(r) + assert "usage" not in meta diff --git a/tests/instrument/adapters/providers/test_openai.py b/tests/instrument/adapters/providers/test_openai.py new file mode 100644 index 0000000..42641b4 --- /dev/null +++ b/tests/instrument/adapters/providers/test_openai.py @@ -0,0 +1,244 @@ +from __future__ import annotations + +from unittest.mock import Mock + +from layerlens.instrument import trace +from layerlens.instrument.adapters.providers.openai import ( + OpenAIProvider, + instrument_openai, + uninstrument_openai, +) + +from ...conftest import find_event +from .conftest import ( + make_openai_response, + make_openai_response_no_usage, + make_openai_response_empty_choices, +) + + +# --------------------------------------------------------------------------- +# Emit events +# --------------------------------------------------------------------------- + + +class TestEmitsEvents: + def test_model_invoke_and_cost_record(self, mock_client, capture_trace): + openai_client = Mock() + openai_client.chat.completions.create = Mock(return_value=make_openai_response()) + + provider = OpenAIProvider() + provider.connect(openai_client) + + @trace(mock_client) + def my_agent(): + r = openai_client.chat.completions.create( + model="gpt-4", messages=[{"role": "user", "content": "Hi"}] + ) + return r.choices[0].message.content + + my_agent() + events = capture_trace["events"] + + model_invoke = find_event(events, "model.invoke") + assert model_invoke["payload"]["name"] == "openai.chat.completions.create" + assert model_invoke["payload"]["model"] == "gpt-4" + assert model_invoke["payload"]["output_message"]["role"] == "assistant" + assert model_invoke["payload"]["output_message"]["content"] == "Hello!" + assert model_invoke["payload"]["usage"]["prompt_tokens"] == 10 + assert model_invoke["payload"]["usage"]["completion_tokens"] == 5 + assert model_invoke["payload"]["usage"]["total_tokens"] == 15 + assert "latency_ms" in model_invoke["payload"] + + cost = find_event(events, "cost.record") + assert cost["payload"]["provider"] == "openai" + assert cost["payload"]["total_tokens"] == 15 + + def test_error_emits_agent_error(self, mock_client, capture_trace): + openai_client = Mock() + openai_client.chat.completions.create = Mock(side_effect=RuntimeError("API error")) + + provider = OpenAIProvider() + provider.connect(openai_client) + + @trace(mock_client) + def my_agent(): + try: + openai_client.chat.completions.create(model="gpt-4", messages=[]) + except RuntimeError: + pass + return "recovered" + + my_agent() + events = capture_trace["events"] + error = find_event(events, "agent.error") + assert error["payload"]["error"] == "API error" + assert "latency_ms" in error["payload"] + + +# --------------------------------------------------------------------------- +# Passthrough / no-op behavior +# --------------------------------------------------------------------------- + + +class TestPassthrough: + def test_no_op_outside_trace(self): + response = make_openai_response() + openai_client = Mock() + openai_client.chat.completions.create = Mock(return_value=response) + + provider = OpenAIProvider() + provider.connect(openai_client) + + result = openai_client.chat.completions.create(model="gpt-4", messages=[]) + assert result.choices[0].message.content == "Hello!" + + +# --------------------------------------------------------------------------- +# Connect / disconnect lifecycle +# --------------------------------------------------------------------------- + + +class TestLifecycle: + def test_disconnect_restores_original(self): + openai_client = Mock() + original = openai_client.chat.completions.create + + provider = OpenAIProvider() + provider.connect(openai_client) + assert openai_client.chat.completions.create is not original + + provider.disconnect() + assert openai_client.chat.completions.create is original + + def test_disconnect_when_not_connected(self): + provider = OpenAIProvider() + provider.disconnect() # should not raise + + def test_double_connect_replaces_wrapper(self): + openai_client = Mock() + provider = OpenAIProvider() + provider.connect(openai_client) + first_wrapper = openai_client.chat.completions.create + + provider2 = OpenAIProvider() + provider2.connect(openai_client) + assert openai_client.chat.completions.create is not first_wrapper + + +# --------------------------------------------------------------------------- +# adapter_info +# --------------------------------------------------------------------------- + + +class TestAdapterInfo: + def test_info_before_connect(self): + provider = OpenAIProvider() + info = provider.adapter_info() + assert info.name == "openai" + assert info.adapter_type == "provider" + assert info.connected is False + + def test_info_after_connect(self): + provider = OpenAIProvider() + provider.connect(Mock()) + info = provider.adapter_info() + assert info.connected is True + + def test_info_after_disconnect(self): + provider = OpenAIProvider() + provider.connect(Mock()) + provider.disconnect() + assert provider.adapter_info().connected is False + + +# --------------------------------------------------------------------------- +# Convenience API +# --------------------------------------------------------------------------- + + +class TestConvenienceAPI: + def test_instrument_and_uninstrument(self): + openai_client = Mock() + original = openai_client.chat.completions.create + instrument_openai(openai_client) + assert openai_client.chat.completions.create is not original + uninstrument_openai() + + +# --------------------------------------------------------------------------- +# capture_params filtering +# --------------------------------------------------------------------------- + + +class TestCaptureParams: + def test_captured_params_included(self, mock_client, capture_trace): + openai_client = Mock() + openai_client.chat.completions.create = Mock(return_value=make_openai_response()) + + provider = OpenAIProvider() + provider.connect(openai_client) + + @trace(mock_client) + def my_agent(): + openai_client.chat.completions.create( + model="gpt-4", temperature=0.7, top_p=0.9, + messages=[{"role": "user", "content": "Hi"}], + ) + return "done" + + my_agent() + params = find_event(capture_trace["events"], "model.invoke")["payload"]["parameters"] + assert params["model"] == "gpt-4" + assert params["temperature"] == 0.7 + assert params["top_p"] == 0.9 + + def test_non_captured_params_excluded(self, mock_client, capture_trace): + openai_client = Mock() + openai_client.chat.completions.create = Mock(return_value=make_openai_response()) + + provider = OpenAIProvider() + provider.connect(openai_client) + + @trace(mock_client) + def my_agent(): + openai_client.chat.completions.create( + model="gpt-4", messages=[], stream=True, user="test-user", + ) + return "done" + + my_agent() + params = find_event(capture_trace["events"], "model.invoke")["payload"]["parameters"] + assert "stream" not in params + assert "user" not in params + assert "messages" not in params + + +# --------------------------------------------------------------------------- +# Extractor edge cases (using real SDK types) +# --------------------------------------------------------------------------- + + +class TestExtractors: + def test_extract_output_normal(self): + r = make_openai_response(content="Hi there", role="assistant") + output = OpenAIProvider.extract_output(r) + assert output == {"role": "assistant", "content": "Hi there"} + + def test_extract_output_empty_choices(self): + r = make_openai_response_empty_choices() + assert OpenAIProvider.extract_output(r) is None + + def test_extract_meta_normal(self): + r = make_openai_response(model="gpt-4o", prompt_tokens=5, completion_tokens=3, total_tokens=8) + meta = OpenAIProvider.extract_meta(r) + assert meta["response_model"] == "gpt-4o" + assert meta["usage"]["prompt_tokens"] == 5 + assert meta["usage"]["completion_tokens"] == 3 + assert meta["usage"]["total_tokens"] == 8 + + def test_extract_meta_no_usage(self): + r = make_openai_response_no_usage(model="gpt-4") + meta = OpenAIProvider.extract_meta(r) + assert "usage" not in meta + assert meta["response_model"] == "gpt-4" diff --git a/tests/instrument/test_registry.py b/tests/instrument/adapters/test_registry.py similarity index 100% rename from tests/instrument/test_registry.py rename to tests/instrument/adapters/test_registry.py diff --git a/tests/instrument/test_adapters.py b/tests/instrument/test_adapters.py deleted file mode 100644 index 11752c2..0000000 --- a/tests/instrument/test_adapters.py +++ /dev/null @@ -1,167 +0,0 @@ -from __future__ import annotations - -import json -import sys -import types -import importlib -from uuid import uuid4 -from unittest.mock import Mock - -from .conftest import find_events, find_event - - -def _capture_framework_trace(mock_client): - """Helper to capture uploaded trace from framework adapters (which manage their own collector).""" - uploaded = {} - - def _capture(path): - with open(path) as f: - data = json.load(f) - payload = data[0] - uploaded["trace_id"] = payload.get("trace_id") - uploaded["events"] = payload.get("events", []) - uploaded["capture_config"] = payload.get("capture_config", {}) - uploaded["attestation"] = payload.get("attestation", {}) - - mock_client.traces.upload.side_effect = _capture - return uploaded - - -class TestLangChainAdapter: - def _setup_langchain_mock(self): - mock_lc_core = types.ModuleType("langchain_core") - mock_lc_callbacks = types.ModuleType("langchain_core.callbacks") - - class FakeBaseCallbackHandler: - def __init__(self): - pass - - mock_lc_callbacks.BaseCallbackHandler = FakeBaseCallbackHandler - mock_lc_core.callbacks = mock_lc_callbacks - - sys.modules["langchain_core"] = mock_lc_core - sys.modules["langchain_core.callbacks"] = mock_lc_callbacks - - def _teardown_langchain_mock(self): - for key in list(sys.modules.keys()): - if key.startswith("langchain_core"): - del sys.modules[key] - - def _get_handler(self, mock_client): - from layerlens.instrument.adapters.frameworks import langchain as lc_mod - - importlib.reload(lc_mod) - return lc_mod.LangChainCallbackHandler(mock_client) - - def test_emits_flat_events(self, mock_client): - self._setup_langchain_mock() - try: - uploaded = _capture_framework_trace(mock_client) - handler = self._get_handler(mock_client) - - chain_run_id = uuid4() - llm_run_id = uuid4() - - handler.on_chain_start( - {"name": "RunnableSequence", "id": ["RunnableSequence"]}, - {"question": "What is AI?"}, - run_id=chain_run_id, - ) - handler.on_llm_start( - {"name": "ChatOpenAI", "id": ["ChatOpenAI"]}, - ["What is AI?"], - run_id=llm_run_id, - parent_run_id=chain_run_id, - ) - - llm_response = Mock() - llm_response.generations = [[Mock(text="AI is...")]] - llm_response.llm_output = {"token_usage": {"total_tokens": 50}, "model_name": "gpt-4"} - handler.on_llm_end(llm_response, run_id=llm_run_id) - handler.on_chain_end({"output": "AI is..."}, run_id=chain_run_id) - - events = uploaded["events"] - # Should have: agent.input, model.invoke (start), model.invoke (end), cost.record, agent.output - agent_input = find_event(events, "agent.input") - assert agent_input["payload"]["name"] == "RunnableSequence" - assert agent_input["payload"]["input"] == {"question": "What is AI?"} - - model_invokes = find_events(events, "model.invoke") - assert len(model_invokes) >= 1 - # The end event has model name and output - end_invoke = [m for m in model_invokes if m["payload"].get("model") == "gpt-4"] - assert len(end_invoke) == 1 - assert end_invoke[0]["payload"]["output_message"] == "AI is..." - - cost = find_event(events, "cost.record") - assert cost["payload"]["total_tokens"] == 50 - - agent_output = find_event(events, "agent.output") - assert agent_output["payload"]["status"] == "ok" - - # Parent-child: LLM events should reference chain's span_id as parent - chain_span_id = agent_input["span_id"] - llm_start = [m for m in model_invokes if m["payload"].get("name") == "ChatOpenAI"][0] - assert llm_start["parent_span_id"] == chain_span_id - finally: - self._teardown_langchain_mock() - - def test_tracks_tools_and_retrievers(self, mock_client): - self._setup_langchain_mock() - try: - uploaded = _capture_framework_trace(mock_client) - handler = self._get_handler(mock_client) - - chain_id = uuid4() - tool_id = uuid4() - retriever_id = uuid4() - - handler.on_chain_start({"name": "Agent"}, {"input": "test"}, run_id=chain_id) - handler.on_tool_start({"name": "search"}, "query", run_id=tool_id, parent_run_id=chain_id) - handler.on_tool_end("results", run_id=tool_id) - handler.on_retriever_start({"name": "vectorstore"}, "query", run_id=retriever_id, parent_run_id=chain_id) - - docs = [Mock(page_content="doc1", metadata={"source": "a"})] - handler.on_retriever_end(docs, run_id=retriever_id) - handler.on_chain_end({"output": "done"}, run_id=chain_id) - - events = uploaded["events"] - tool_calls = find_events(events, "tool.call") - assert len(tool_calls) == 2 # tool + retriever both emit tool.call - tool_results = find_events(events, "tool.result") - assert len(tool_results) == 2 - finally: - self._teardown_langchain_mock() - - def test_error_on_chain(self, mock_client): - self._setup_langchain_mock() - try: - uploaded = _capture_framework_trace(mock_client) - handler = self._get_handler(mock_client) - - chain_id = uuid4() - handler.on_chain_start({"name": "FailChain"}, {"input": "x"}, run_id=chain_id) - handler.on_chain_error(ValueError("broke"), run_id=chain_id) - - events = uploaded["events"] - error = find_event(events, "agent.error") - assert error["payload"]["error"] == "broke" - assert error["payload"]["status"] == "error" - finally: - self._teardown_langchain_mock() - - def test_null_serialized_handled(self, mock_client): - self._setup_langchain_mock() - try: - uploaded = _capture_framework_trace(mock_client) - handler = self._get_handler(mock_client) - - run_id = uuid4() - handler.on_chain_start(None, {"input": "x"}, run_id=run_id) - handler.on_chain_end({"output": "done"}, run_id=run_id) - - events = uploaded["events"] - agent_input = find_event(events, "agent.input") - assert agent_input["payload"]["name"] == "unknown" - finally: - self._teardown_langchain_mock() diff --git a/tests/instrument/test_capture_config.py b/tests/instrument/test_capture_config.py index a70dfb4..5b00390 100644 --- a/tests/instrument/test_capture_config.py +++ b/tests/instrument/test_capture_config.py @@ -5,9 +5,9 @@ import pytest -from layerlens.instrument import trace, CaptureConfig -from .conftest import find_events, find_event +from layerlens.instrument import CaptureConfig, trace +from .conftest import find_event, find_events # --------------------------------------------------------------------------- # CaptureConfig unit tests diff --git a/tests/instrument/test_core.py b/tests/instrument/test_core.py index 89ed459..d16277d 100644 --- a/tests/instrument/test_core.py +++ b/tests/instrument/test_core.py @@ -1,12 +1,11 @@ from __future__ import annotations -import os - import pytest -from layerlens.instrument import span, emit, trace -from layerlens.instrument._context import _current_collector, _current_span_id -from .conftest import find_events, find_event +from layerlens.instrument import emit, span, trace +from layerlens.instrument._context import _current_span_id, _current_collector + +from .conftest import find_event class TestTraceDecorator: diff --git a/tests/instrument/test_providers.py b/tests/instrument/test_providers.py deleted file mode 100644 index ede576f..0000000 --- a/tests/instrument/test_providers.py +++ /dev/null @@ -1,220 +0,0 @@ -from __future__ import annotations - -import sys -import types -from unittest.mock import Mock - -from layerlens.instrument import trace -from .conftest import find_events, find_event - - -def _openai_response(): - r = Mock() - r.choices = [Mock()] - r.choices[0].message = Mock() - r.choices[0].message.role = "assistant" - r.choices[0].message.content = "Hello!" - r.usage = Mock() - r.usage.prompt_tokens = 10 - r.usage.completion_tokens = 5 - r.usage.total_tokens = 15 - r.model = "gpt-4" - return r - - -def _anthropic_response(): - r = Mock() - block = Mock() - block.type = "text" - block.text = "I'm Claude!" - r.content = [block] - r.usage = Mock() - r.usage.input_tokens = 20 - r.usage.output_tokens = 10 - r.model = "claude-3-opus" - r.stop_reason = "end_turn" - return r - - -class TestOpenAIProvider: - def test_instrument_emits_events(self, mock_client, capture_trace): - from layerlens.instrument.adapters.providers.openai import OpenAIProvider - - openai_client = Mock() - openai_client.chat.completions.create = Mock(return_value=_openai_response()) - - provider = OpenAIProvider() - provider.connect(openai_client) - - @trace(mock_client) - def my_agent(): - return ( - openai_client.chat.completions.create(model="gpt-4", messages=[{"role": "user", "content": "Hi"}]) - .choices[0] - .message.content - ) - - my_agent() - events = capture_trace["events"] - model_invoke = find_event(events, "model.invoke") - assert model_invoke["payload"]["name"] == "openai.chat.completions.create" - assert model_invoke["payload"]["parameters"]["model"] == "gpt-4" - assert model_invoke["payload"]["usage"]["total_tokens"] == 15 - assert model_invoke["payload"]["output_message"]["content"] == "Hello!" - - cost = find_event(events, "cost.record") - assert cost["payload"]["provider"] == "openai" - assert cost["payload"]["total_tokens"] == 15 - - def test_passthrough_without_trace(self): - from layerlens.instrument.adapters.providers.openai import OpenAIProvider - - openai_client = Mock() - openai_client.chat.completions.create = Mock(return_value=_openai_response()) - - provider = OpenAIProvider() - provider.connect(openai_client) - - result = openai_client.chat.completions.create(model="gpt-4", messages=[]) - assert result.choices[0].message.content == "Hello!" - - def test_disconnect_restores(self): - from layerlens.instrument.adapters.providers.openai import OpenAIProvider - - openai_client = Mock() - original = openai_client.chat.completions.create - - provider = OpenAIProvider() - provider.connect(openai_client) - assert openai_client.chat.completions.create is not original - - provider.disconnect() - assert openai_client.chat.completions.create is original - - def test_instrument_convenience_function(self): - from layerlens.instrument.adapters.providers.openai import instrument_openai, uninstrument_openai - - openai_client = Mock() - original = openai_client.chat.completions.create - instrument_openai(openai_client) - assert openai_client.chat.completions.create is not original - uninstrument_openai() - - -class TestAnthropicProvider: - def test_instrument_emits_events(self, mock_client, capture_trace): - from layerlens.instrument.adapters.providers.anthropic import AnthropicProvider - - anthropic_client = Mock() - anthropic_client.messages.create = Mock(return_value=_anthropic_response()) - - provider = AnthropicProvider() - provider.connect(anthropic_client) - - @trace(mock_client) - def my_agent(): - return ( - anthropic_client.messages.create( - model="claude-3-opus", max_tokens=1024, messages=[{"role": "user", "content": "Hi"}] - ) - .content[0] - .text - ) - - my_agent() - events = capture_trace["events"] - model_invoke = find_event(events, "model.invoke") - assert model_invoke["payload"]["output_message"]["text"] == "I'm Claude!" - assert model_invoke["payload"]["usage"]["input_tokens"] == 20 - assert model_invoke["payload"]["response_model"] == "claude-3-opus" - assert model_invoke["payload"]["stop_reason"] == "end_turn" - - def test_disconnect_restores(self): - from layerlens.instrument.adapters.providers.anthropic import AnthropicProvider - - anthropic_client = Mock() - original = anthropic_client.messages.create - - provider = AnthropicProvider() - provider.connect(anthropic_client) - provider.disconnect() - assert anthropic_client.messages.create is original - - -class TestLiteLLMProvider: - def setup_method(self): - self.mock_litellm = types.ModuleType("litellm") - self.mock_litellm.completion = Mock(return_value=_openai_response()) - self.mock_litellm.acompletion = Mock() - sys.modules["litellm"] = self.mock_litellm - - def teardown_method(self): - from layerlens.instrument.adapters.providers.litellm import uninstrument_litellm - - uninstrument_litellm() - for key in list(sys.modules.keys()): - if key.startswith("litellm"): - del sys.modules[key] - - def test_instrument_emits_events(self, mock_client, capture_trace): - from layerlens.instrument.adapters.providers.litellm import instrument_litellm - - instrument_litellm() - - @trace(mock_client) - def my_agent(): - import litellm - - return ( - litellm.completion(model="gpt-4", messages=[{"role": "user", "content": "Hi"}]) - .choices[0] - .message.content - ) - - my_agent() - events = capture_trace["events"] - model_invoke = find_event(events, "model.invoke") - assert model_invoke["payload"]["name"] == "litellm.completion" - assert model_invoke["payload"]["parameters"]["model"] == "gpt-4" - - def test_passthrough_without_trace(self): - from layerlens.instrument.adapters.providers.litellm import instrument_litellm - - instrument_litellm() - import litellm - - result = litellm.completion(model="gpt-4", messages=[]) - assert result.choices[0].message.content == "Hello!" - - def test_uninstrument(self): - from layerlens.instrument.adapters.providers.litellm import instrument_litellm, uninstrument_litellm - - original = self.mock_litellm.completion - instrument_litellm() - assert self.mock_litellm.completion is not original - uninstrument_litellm() - assert self.mock_litellm.completion is original - - -class TestProviderErrorHandling: - def test_error_emits_event(self, mock_client, capture_trace): - from layerlens.instrument.adapters.providers.openai import OpenAIProvider - - openai_client = Mock() - openai_client.chat.completions.create = Mock(side_effect=RuntimeError("API error")) - - provider = OpenAIProvider() - provider.connect(openai_client) - - @trace(mock_client) - def my_agent(): - try: - openai_client.chat.completions.create(model="gpt-4", messages=[]) - except RuntimeError: - pass - return "recovered" - - my_agent() - events = capture_trace["events"] - error = find_event(events, "agent.error") - assert error["payload"]["error"] == "API error" diff --git a/tests/instrument/test_types.py b/tests/instrument/test_types.py index 63927e0..618ebd0 100644 --- a/tests/instrument/test_types.py +++ b/tests/instrument/test_types.py @@ -1,7 +1,7 @@ from __future__ import annotations from layerlens.instrument._span import span -from layerlens.instrument._context import _current_span_id, _parent_span_id, _current_span_name +from layerlens.instrument._context import _parent_span_id, _current_span_id, _current_span_name class TestSpan: