Skip to content

Commit

Permalink
implement PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
Tawakalt committed Feb 13, 2024
1 parent 40275d1 commit c28a929
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 31 deletions.
4 changes: 2 additions & 2 deletions rasa_sdk/tracing/instrumentation/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from rasa_sdk.executor import ActionExecutor
from rasa_sdk.forms import ValidationAction
from rasa_sdk.tracing.instrumentation import attribute_extractors
from rasa_sdk.tracing.trace_provider import TraceProvider
from rasa_sdk.tracing.tracer_register import ActionExecutorTracerRegister

# The `TypeVar` representing the return type for a function to be wrapped.
S = TypeVar("S")
Expand Down Expand Up @@ -141,7 +141,7 @@ def instrument(
attribute_extractors.extract_attrs_for_action_executor,
)
mark_class_as_instrumented(action_executor_class)
TraceProvider().register_tracer(tracer)
ActionExecutorTracerRegister().register_tracer(tracer)

if validation_action_class is not None and not class_is_instrumented(
validation_action_class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@
from opentelemetry.trace import Tracer


class TraceProvider(metaclass=Singleton):
"""Represents a provider for tracer."""
class ActionExecutorTracerRegister(metaclass=Singleton):
"""Represents a provider for ActionExecutor tracer."""

tracer: Optional[Tracer] = None

def register_tracer(self, tracer: Tracer) -> None:
"""Register a tracer.
"""Register an ActionExecutor tracer.
Args:
trace: The tracer to register.
"""
self.tracer = tracer

def get_tracer(self) -> Optional[Tracer]:
"""Get the tracer.
"""Get the ActionExecutor tracer.
Returns:
The tracer.
"""
Expand Down
13 changes: 7 additions & 6 deletions tests/tracing/instrumentation/test_action_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from tests.tracing.instrumentation.conftest import MockActionExecutor
from rasa_sdk.types import ActionCall
from rasa_sdk import Tracker
from rasa_sdk.tracing.trace_provider import TraceProvider
from rasa_sdk.tracing.tracer_register import ActionExecutorTracerRegister


@pytest.mark.parametrize(
Expand Down Expand Up @@ -62,8 +62,7 @@ async def test_tracing_action_executor_run(
assert captured_span.attributes == expected


@pytest.mark.asyncio
async def test_instrument_action_executor_run_registers_tracer(
def test_instrument_action_executor_run_registers_tracer(
tracer_provider: TracerProvider, monkeypatch: MonkeyPatch
) -> None:
component_class = MockActionExecutor
Expand All @@ -73,16 +72,18 @@ async def test_instrument_action_executor_run_registers_tracer(
register_tracer_mock = Mock()
get_tracer_mock = Mock(return_value=mock_tracer)

monkeypatch.setattr(TraceProvider, "register_tracer", register_tracer_mock())
monkeypatch.setattr(TraceProvider, "get_tracer", get_tracer_mock)
monkeypatch.setattr(
ActionExecutorTracerRegister, "register_tracer", register_tracer_mock()
)
monkeypatch.setattr(ActionExecutorTracerRegister, "get_tracer", get_tracer_mock)

instrumentation.instrument(
tracer_provider,
action_executor_class=component_class,
)
register_tracer_mock.assert_called_once()

provider = TraceProvider()
provider = ActionExecutorTracerRegister()
tracer = provider.get_tracer()

assert tracer is not None
Expand Down
19 changes: 0 additions & 19 deletions tests/tracing/test_trace_provider.py

This file was deleted.

21 changes: 21 additions & 0 deletions tests/tracing/test_tracer_register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from rasa_sdk.tracing.tracer_register import ActionExecutorTracerRegister
from opentelemetry import trace


def test_tracer_register_is_singleton() -> None:
tracer_register_1 = ActionExecutorTracerRegister()
tracer_register_2 = ActionExecutorTracerRegister()

assert tracer_register_1 is tracer_register_2
assert tracer_register_1.tracer is tracer_register_2.tracer


def test_trace_register() -> None:
tracer_register = ActionExecutorTracerRegister()
assert tracer_register.get_tracer() is None

tracer = trace.get_tracer(__name__)
tracer_register.register_tracer(tracer)

assert tracer_register.tracer == tracer
assert tracer_register.get_tracer() == tracer

0 comments on commit c28a929

Please sign in to comment.