Skip to content

Commit

Permalink
Merge pull request #1073 from RasaHQ/instrument-ActionExecutor.run-me…
Browse files Browse the repository at this point in the history
…thod

[ATO-2099] Instrument ActionExecutor.run method
  • Loading branch information
Tawakalt committed Feb 8, 2024
2 parents 33b8fa4 + 2366b0c commit 552f409
Show file tree
Hide file tree
Showing 11 changed files with 365 additions and 29 deletions.
1 change: 1 addition & 0 deletions changelog/1073.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Instrument `ActionExecutor.run` method.
30 changes: 23 additions & 7 deletions rasa_sdk/tracing/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,35 @@
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from rasa_sdk.tracing.endpoints import EndpointConfig, read_endpoint_config
from rasa_sdk.tracing.instrumentation import instrumentation
from rasa_sdk.executor import ActionExecutor


TRACING_SERVICE_NAME = os.environ.get("TRACING_SERVICE_NAME", "rasa_sdk")
TRACING_SERVICE_NAME = os.environ.get("RASA_SDK_TRACING_SERVICE_NAME", "rasa_sdk")

ENDPOINTS_TRACING_KEY = "tracing"

logger = logging.getLogger(__name__)


def configure_tracing(tracer_provider: Optional[TracerProvider]) -> None:
"""Configure tracing functionality.
When a tracing backend is defined, this function will
instrument all methods that shall be traced.
If no tracing backend is defined, no tracing is configured.
:param tracer_provider: The `TracingProvider` to be used for tracing
"""
if tracer_provider is None:
return None

instrumentation.instrument(
tracer_provider=tracer_provider,
action_executor_class=ActionExecutor,
)


def get_tracer_provider(endpoints_file: Text) -> Optional[TracerProvider]:
"""Configure tracing backend.
Expand Down Expand Up @@ -90,9 +110,7 @@ def configure_from_endpoint_config(cls, cfg: EndpointConfig) -> TracerProvider:
:return: The configured `TracerProvider`.
"""
provider = TracerProvider(
resource=Resource.create(
{SERVICE_NAME: cfg.kwargs.get("service_name", TRACING_SERVICE_NAME)}
)
resource=Resource.create({SERVICE_NAME: TRACING_SERVICE_NAME})
)

jaeger_exporter = JaegerExporter(
Expand Down Expand Up @@ -132,9 +150,7 @@ def configure_from_endpoint_config(cls, cfg: EndpointConfig) -> TracerProvider:
:return: The configured `TracerProvider`.
"""
provider = TracerProvider(
resource=Resource.create(
{SERVICE_NAME: cfg.kwargs.get("service_name", TRACING_SERVICE_NAME)}
)
resource=Resource.create({SERVICE_NAME: TRACING_SERVICE_NAME})
)

insecure = cfg.kwargs.get("insecure")
Expand Down
30 changes: 30 additions & 0 deletions rasa_sdk/tracing/instrumentation/attribute_extractors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Any, Dict, Text
from rasa_sdk.executor import ActionExecutor
from rasa_sdk.types import ActionCall


# This file contains all attribute extractors for tracing instrumentation.
# These are functions that are applied to the arguments of the wrapped function to be
# traced to extract the attributes that we want to forward to our tracing backend.
# Note that we always mirror the argument lists of the wrapped functions, as our
# wrapping mechanism always passes in the original arguments unchanged for further
# processing.


def extract_attrs_for_action_executor(
self: ActionExecutor,
action_call: ActionCall,
) -> Dict[Text, Any]:
"""Extract the attributes for `ActionExecutor.run`.
:param self: The `ActionExecutor` on which `run` is called.
:param action_call: The `ActionCall` argument.
:return: A dictionary containing the attributes.
"""
attributes = {"sender_id": action_call["sender_id"]}
action_name = action_call.get("next_action")

if action_name:
attributes["action_name"] = action_name

return attributes
175 changes: 175 additions & 0 deletions rasa_sdk/tracing/instrumentation/instrumentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import functools
import inspect
import logging
from typing import (
Any,
Awaitable,
Callable,
Dict,
Optional,
Text,
Type,
TypeVar,
)

from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace import Tracer
from rasa_sdk.executor import ActionExecutor
from rasa_sdk.tracing.instrumentation import attribute_extractors

# The `TypeVar` representing the return type for a function to be wrapped.
S = TypeVar("S")
# The `TypeVar` representing the type of the argument passed to the function to be
# wrapped.
T = TypeVar("T")

logger = logging.getLogger(__name__)
INSTRUMENTED_BOOLEAN_ATTRIBUTE_NAME = "class_has_been_instrumented"


def _check_extractor_argument_list(
fn: Callable[[T, Any, Any], S],
attr_extractor: Optional[Callable[[T, Any, Any], Dict[str, Any]]],
) -> bool:
if attr_extractor is None:
return False

fn_args = inspect.signature(fn)
attr_args = inspect.signature(attr_extractor)

are_arglists_congruent = fn_args.parameters.keys() == attr_args.parameters.keys()

if not are_arglists_congruent:
logger.warning(
f"Argument lists for {fn.__name__} and {attr_extractor.__name__}"
f" do not match up. {fn.__name__} will be traced without attributes."
)

return are_arglists_congruent


def traceable_async(
fn: Callable[[T, Any, Any], Awaitable[S]],
tracer: Tracer,
attr_extractor: Optional[Callable[[T, Any, Any], Dict[str, Any]]],
) -> Callable[[T, Any, Any], Awaitable[S]]:
"""Wrap an `async` function by tracing functionality.
:param fn: The function to be wrapped.
:param tracer: The `Tracer` that shall be used for tracing this function.
:param attr_extractor: A function that is applied to the function's instance and
the function's arguments.
:return: The wrapped function.
"""
should_extract_args = _check_extractor_argument_list(fn, attr_extractor)

@functools.wraps(fn)
async def async_wrapper(self: T, *args: Any, **kwargs: Any) -> S:
attrs = (
attr_extractor(self, *args, **kwargs)
if attr_extractor and should_extract_args
else {}
)
with tracer.start_as_current_span(
f"{self.__class__.__name__}.{fn.__name__}", attributes=attrs
):
return await fn(self, *args, **kwargs)

return async_wrapper


def traceable(
fn: Callable[[T, Any, Any], S],
tracer: Tracer,
attr_extractor: Optional[Callable[[T, Any, Any], Dict[str, Any]]],
) -> Callable[[T, Any, Any], S]:
"""Wrap a non-`async` function by tracing functionality.
:param fn: The function to be wrapped.
:param tracer: The `Tracer` that shall be used for tracing this function.
:param attr_extractor: A function that is applied to the function's instance and
the function's arguments.
:return: The wrapped function.
"""
should_extract_args = _check_extractor_argument_list(fn, attr_extractor)

@functools.wraps(fn)
def wrapper(self: T, *args: Any, **kwargs: Any) -> S:
attrs = (
attr_extractor(self, *args, **kwargs)
if attr_extractor and should_extract_args
else {}
)
with tracer.start_as_current_span(
f"{self.__class__.__name__}.{fn.__name__}", attributes=attrs
):
return fn(self, *args, **kwargs)

return wrapper


ActionExecutorType = TypeVar("ActionExecutorType", bound=ActionExecutor)


def instrument(
tracer_provider: TracerProvider,
action_executor_class: Optional[Type[ActionExecutorType]] = None,
) -> None:
"""Substitute methods to be traced by their traced counterparts.
:param tracer_provider: The `TracerProvider` to be used for configuring tracing
on the substituted methods.
:param action_executor_class: The `ActionExecutor` to be instrumented. If `None`
is given, no `ActionExecutor` will be instrumented.
"""
if action_executor_class is not None and not class_is_instrumented(
action_executor_class
):
_instrument_method(
tracer_provider.get_tracer(action_executor_class.__module__),
action_executor_class,
"run",
attribute_extractors.extract_attrs_for_action_executor,
)
mark_class_as_instrumented(action_executor_class)


def _instrument_method(
tracer: Tracer,
instrumented_class: Type,
method_name: Text,
attr_extractor: Optional[Callable],
) -> None:
method_to_trace = getattr(instrumented_class, method_name)
if inspect.iscoroutinefunction(method_to_trace):
traced_method = traceable_async(method_to_trace, tracer, attr_extractor)
else:
traced_method = traceable(method_to_trace, tracer, attr_extractor)
setattr(instrumented_class, method_name, traced_method)

logger.debug(f"Instrumented '{instrumented_class.__name__}.{method_name}'.")


def _mangled_instrumented_boolean_attribute_name(instrumented_class: Type) -> Text:
# see https://peps.python.org/pep-0008/#method-names-and-instance-variables
# and https://stackoverflow.com/a/50401073
return f"_{instrumented_class.__name__}__{INSTRUMENTED_BOOLEAN_ATTRIBUTE_NAME}"


def class_is_instrumented(instrumented_class: Type) -> bool:
"""Check if a class has already been instrumented."""
return getattr(
instrumented_class,
_mangled_instrumented_boolean_attribute_name(instrumented_class),
False,
)


def mark_class_as_instrumented(instrumented_class: Type) -> None:
"""Mark a class as instrumented if it isn't already marked."""
if not class_is_instrumented(instrumented_class):
setattr(
instrumented_class,
_mangled_instrumented_boolean_attribute_name(instrumented_class),
True,
)
3 changes: 2 additions & 1 deletion rasa_sdk/tracing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ def get_tracer_provider(

if endpoints_file is not None:
tracer_provider = config.get_tracer_provider(endpoints_file)
config.configure_tracing(tracer_provider)
return tracer_provider


def get_tracer_and_context(
tracer_provider: Optional[TracerProvider], request: Request
) -> Tuple[Any, Any, Text]:
"""Gets tracer and context"""
span_name = "rasa_sdk.create_app.webhook"
span_name = "create_app.webhook"
if tracer_provider is None:
tracer = trace.get_tracer(span_name)
context = None
Expand Down
Empty file.
16 changes: 16 additions & 0 deletions tests/tracing/instrumentation/action_fixtures/dummy_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import Any, Dict
from rasa_sdk import Action, Tracker
from rasa_sdk.executor import CollectingDispatcher


class DummyAction(Action):
def name(self) -> str:
return "dummy_action"

def run(
self,
dispatcher: CollectingDispatcher,
tracker: Tracker,
domain: Dict[str, Any],
):
return ""
23 changes: 23 additions & 0 deletions tests/tracing/instrumentation/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import pytest
from typing import Text

from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from rasa_sdk.executor import ActionExecutor
from rasa_sdk.types import ActionCall


@pytest.fixture(scope="session")
def tracer_provider() -> TracerProvider:
Expand All @@ -21,3 +25,22 @@ def span_exporter(tracer_provider: TracerProvider) -> InMemorySpanExporter:
def previous_num_captured_spans(span_exporter: InMemorySpanExporter) -> int:
captured_spans = span_exporter.get_finished_spans() # type: ignore
return len(captured_spans)


class MockActionExecutor(ActionExecutor):
def __init__(self) -> None:
self.fail_if_undefined("run")

def fail_if_undefined(self, method_name: Text) -> None:
if not (
hasattr(self.__class__.__base__, method_name)
and callable(getattr(self.__class__.__base__, method_name))
):
pytest.fail(
f"method '{method_name}' not found in {self.__class__.__base__}. "
f"This likely means the method was renamed, which means the "
f"instrumentation needs to be adapted!"
)

async def run(self, action_call: ActionCall) -> None:
pass
58 changes: 58 additions & 0 deletions tests/tracing/instrumentation/test_action_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Any, Dict, Sequence, Text, Optional

import pytest
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from rasa_sdk.tracing.instrumentation import instrumentation
from tests.tracing.instrumentation.conftest import MockActionExecutor
from rasa_sdk.types import ActionCall
from rasa_sdk import Tracker


@pytest.mark.parametrize(
"action_name, expected",
[
("check_balance", {"action_name": "check_balance", "sender_id": "test"}),
(None, {"sender_id": "test"}),
],
)
@pytest.mark.asyncio
async def test_tracing_action_executor_run(
tracer_provider: TracerProvider,
span_exporter: InMemorySpanExporter,
previous_num_captured_spans: int,
action_name: Optional[str],
expected: Dict[Text, Any],
) -> None:
component_class = MockActionExecutor

instrumentation.instrument(
tracer_provider,
action_executor_class=component_class,
)

mock_action_executor = component_class()
action_call = ActionCall(
{
"next_action": action_name,
"sender_id": "test",
"tracker": Tracker("test", {}, {}, [], False, None, {}, ""),
"version": "1.0.0",
"domain": {},
}
)
await mock_action_executor.run(action_call)

captured_spans: Sequence[
ReadableSpan
] = span_exporter.get_finished_spans() # type: ignore

num_captured_spans = len(captured_spans) - previous_num_captured_spans
assert num_captured_spans == 1

captured_span = captured_spans[-1]

assert captured_span.name == "MockActionExecutor.run"

assert captured_span.attributes == expected

0 comments on commit 552f409

Please sign in to comment.