Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Tawakalt committed Feb 13, 2024
1 parent b98cc06 commit 89ffffc
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
32 changes: 32 additions & 0 deletions tests/tracing/instrumentation/test_action_executor.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from typing import Any, Dict, Sequence, Text, Optional

import pytest
from unittest.mock import Mock
from pytest import MonkeyPatch
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from opentelemetry import trace


from rasa_sdk.tracing.instrumentation import instrumentation
from tests.tracing.instrumentation.conftest import MockActionExecutor
from rasa_sdk.types import ActionCall
from rasa_sdk import Tracker
from rasa_sdk.tracing.trace_provider import TraceProvider


@pytest.mark.parametrize(
Expand Down Expand Up @@ -56,3 +61,30 @@ async def test_tracing_action_executor_run(
assert captured_span.name == "MockActionExecutor.run"

assert captured_span.attributes == expected


@pytest.mark.asyncio
async def test_instrument_action_executor_run_registers_tracer(
tracer_provider: TracerProvider, monkeypatch: MonkeyPatch
) -> None:
component_class = MockActionExecutor

mock_tracer = trace.get_tracer(__name__)

register_tracer_mock = Mock()
get_tracer_mock = Mock(return_value=mock_tracer)

monkeypatch.setattr(TraceProvider, "register_tracer", register_tracer_mock())
monkeypatch.setattr(TraceProvider, "get_tracer", get_tracer_mock)

instrumentation.instrument(
tracer_provider,
action_executor_class=component_class,
)
register_tracer_mock.assert_called_once()

provider = TraceProvider()
tracer = provider.get_tracer()

assert tracer is not None
assert tracer == mock_tracer
19 changes: 19 additions & 0 deletions tests/tracing/test_trace_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from rasa_sdk.tracing.trace_provider import TraceProvider
from opentelemetry import trace


def test_anonymization_pipeline_provider_is_singleton() -> None:
trace_provider_1 = TraceProvider()
trace_provider_2 = TraceProvider()

assert trace_provider_1 is trace_provider_2
assert trace_provider_1.tracer is trace_provider_2.tracer


def test_trace_provider() -> None:
trace_provider = TraceProvider()
tracer = trace.get_tracer(__name__)
trace_provider.register_tracer(tracer)

assert trace_provider.tracer == tracer
assert trace_provider.get_tracer() == tracer

0 comments on commit 89ffffc

Please sign in to comment.