Skip to content

Commit

Permalink
implement PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
Tawakalt committed Feb 8, 2024
1 parent a397172 commit b12a62a
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 12 deletions.
9 changes: 5 additions & 4 deletions rasa_sdk/tracing/instrumentation/attribute_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def extract_attrs_for_action_executor(
:param action_call: The `ActionCall` argument.
:return: A dictionary containing the attributes.
"""
return {
"next_action": action_call["next_action"],
"sender_id": action_call["sender_id"],
}
attr = {"sender_id": action_call["sender_id"]}
action_name = action_call.get("next_action")
if action_name:
attr["action_name"] = action_name
return attr
35 changes: 34 additions & 1 deletion rasa_sdk/tracing/instrumentation/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,36 @@ async def async_wrapper(self: T, *args: Any, **kwargs: Any) -> S:
return async_wrapper


def traceable(
fn: Callable[[T, Any, Any], S],
tracer: Tracer,
attr_extractor: Optional[Callable[[T, Any, Any], Dict[str, Any]]],
) -> Callable[[T, Any, Any], S]:
"""Wrap a non-`async` function by tracing functionality.
:param fn: The function to be wrapped.
:param tracer: The `Tracer` that shall be used for tracing this function.
:param attr_extractor: A function that is applied to the function's instance and
the function's arguments.
:return: The wrapped function.
"""
should_extract_args = _check_extractor_argument_list(fn, attr_extractor)

@functools.wraps(fn)
def wrapper(self: T, *args: Any, **kwargs: Any) -> S:
attrs = (
attr_extractor(self, *args, **kwargs)
if attr_extractor and should_extract_args
else {}
)
with tracer.start_as_current_span(
f"{self.__class__.__name__}.{fn.__name__}", attributes=attrs
):
return fn(self, *args, **kwargs)

return wrapper


ActionExecutorType = TypeVar("ActionExecutorType", bound=ActionExecutor)


Expand Down Expand Up @@ -111,7 +141,10 @@ def _instrument_method(
attr_extractor: Optional[Callable],
) -> None:
method_to_trace = getattr(instrumented_class, method_name)
traced_method = traceable_async(method_to_trace, tracer, attr_extractor)
if inspect.iscoroutinefunction(method_to_trace):
traced_method = traceable_async(method_to_trace, tracer, attr_extractor)
else:
traced_method = traceable(method_to_trace, tracer, attr_extractor)
setattr(instrumented_class, method_name, traced_method)

logger.debug(f"Instrumented '{instrumented_class.__name__}.{method_name}'.")
Expand Down
19 changes: 12 additions & 7 deletions tests/tracing/instrumentation/test_action_executor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Sequence
from typing import Any, Dict, Sequence, Text, Optional

import pytest
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
Expand All @@ -10,11 +10,20 @@
from rasa_sdk import Tracker


@pytest.mark.parametrize(
"action_name, expected",
[
("check_balance", {"action_name": "check_balance", "sender_id": "test"}),
(None, {"sender_id": "test"}),
],
)
@pytest.mark.asyncio
async def test_tracing_action_executor_run(
tracer_provider: TracerProvider,
span_exporter: InMemorySpanExporter,
previous_num_captured_spans: int,
action_name: Optional[str],
expected: Dict[Text, Any],
) -> None:
component_class = MockActionExecutor

Expand All @@ -26,7 +35,7 @@ async def test_tracing_action_executor_run(
mock_action_executor = component_class()
action_call = ActionCall(
{
"next_action": "check_balance",
"next_action": action_name,
"sender_id": "test",
"tracker": Tracker("test", {}, {}, [], False, None, {}, ""),
"version": "1.0.0",
Expand All @@ -46,8 +55,4 @@ async def test_tracing_action_executor_run(

assert captured_span.name == "MockActionExecutor.run"

expected_attributes = {
"next_action": "check_balance",
"sender_id": "test",
}
assert captured_span.attributes == expected_attributes
assert captured_span.attributes == expected

0 comments on commit b12a62a

Please sign in to comment.