From 39a61c8088ad7ef599334d2a0bec8583ccefbeed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cle=CC=81ment=20Doumouro?= Date: Wed, 22 Apr 2026 17:53:07 +0200 Subject: [PATCH 1/3] feature(datashare-python): implement a `TraceContextInterceptor` worker interceptor to handle trace context --- .../datashare_python/interceptors.py | 198 ++++++++++++++++++ datashare-python/datashare_python/utils.py | 2 + datashare-python/tests/test_interceptors.py | 123 +++++++++++ 3 files changed, 323 insertions(+) create mode 100644 datashare-python/datashare_python/interceptors.py create mode 100644 datashare-python/tests/test_interceptors.py diff --git a/datashare-python/datashare_python/interceptors.py b/datashare-python/datashare_python/interceptors.py new file mode 100644 index 0000000..8b42f42 --- /dev/null +++ b/datashare-python/datashare_python/interceptors.py @@ -0,0 +1,198 @@ +import secrets +from collections.abc import Generator, Mapping +from contextlib import contextmanager +from contextvars import ContextVar +from copy import deepcopy +from typing import Annotated, Any, NoReturn, Self, TypeVar + +from nexusrpc import InputT, OutputT +from pydantic import Field +from temporalio.api.common.v1 import Payload +from temporalio.converter import DataConverter +from temporalio.worker import ( + ActivityInboundInterceptor, + ContinueAsNewInput, + ExecuteActivityInput, + ExecuteWorkflowInput, + HandleQueryInput, + HandleSignalInput, + Interceptor, + SignalChildWorkflowInput, + SignalExternalWorkflowInput, + StartActivityInput, + StartChildWorkflowInput, + StartLocalActivityInput, + StartNexusOperationInput, + WorkflowInboundInterceptor, + WorkflowInterceptorClassInput, + WorkflowOutboundInterceptor, +) +from temporalio.workflow import ( + ActivityHandle, + ChildWorkflowHandle, + NexusOperationHandle, +) + +from .objects import BaseModel + +_TRACEPARENT = "traceparent" +_DEFAULT_PAYLOAD_CONVERTER = DataConverter.default.payload_converter + + +class TraceContext(BaseModel): + # https://www.w3.org/TR/trace-context/ + version: Annotated[str, Field(frozen=True)] = "00" + trace_id: str + parent_id: str + sampled: bool = True + + def __hash__(self) -> int: + return hash((self.trace_id, self.parent_id, self.sampled)) + + @classmethod + def next_span(cls, parent: Self | None) -> Self: + new_span_id = secrets.token_hex(8) + if parent is None: + trace_id = secrets.token_hex(16) + return TraceContext(trace_id=trace_id, parent_id=new_span_id) + return TraceContext( + trace_id=parent.trace_id, parent_id=new_span_id, sampled=parent.sampled + ) + + @property + def traceparent(self) -> str: + flags = "01" if self.sampled else "00" + return f"{self.version}-{self.trace_id}-{self.parent_id}-{flags}" + + @classmethod + def from_traceparent(cls, traceparent: str) -> Self: + split = traceparent.split("-") + if len(split) != 4: + raise ValueError(f"invalid trace parent: {traceparent}") + version, trace_id, parent_id, flags = split + if version != "00": + msg = ( + f"unsupported trace parent version {version} " + f"for traceparent {traceparent}" + ) + raise ValueError(msg) + sampled = flags == "01" + return cls(trace_id=trace_id, parent_id=parent_id, sampled=sampled) + + +_TRACE_CONTEXT: ContextVar[TraceContext | None] = ContextVar( + "trace_context", default=None +) + + +class TraceContextInterceptor(Interceptor): + def workflow_interceptor_class( + self, + input: WorkflowInterceptorClassInput, # noqa: A002, ARG002 + ) -> type[WorkflowInboundInterceptor] | None: + return _TraceContextWorkflowInboundInterceptor + + def intercept_activity( + self, + next: ActivityInboundInterceptor, # noqa: A002 + ) -> ActivityInboundInterceptor: + return _TraceContextActivityInboundInterceptor(next) + + +class _TraceContextWorkflowInboundInterceptor(WorkflowInboundInterceptor): + def init(self, outbound: WorkflowOutboundInterceptor) -> None: + with_outbound_trace_ctx = _TraceContextWorkflowOutboundInterceptor(outbound) + super().init(with_outbound_trace_ctx) + + async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any: # noqa: A002 + with _trace_context(input.headers): + return await super().execute_workflow(input) + + async def handle_signal(self, input: HandleSignalInput) -> None: # noqa: A002 + with _trace_context(input.headers): + return await super().handle_signal(input) + + async def handle_query(self, input: HandleQueryInput) -> Any: # noqa: A002 + with _trace_context(input.headers): + return await super().handle_query(input) + + +class _TraceContextWorkflowOutboundInterceptor(WorkflowOutboundInterceptor): + def continue_as_new(self, input: ContinueAsNewInput) -> NoReturn: # noqa: A002 + super().continue_as_new(_with_trace_context_header(input)) + + async def signal_child_workflow(self, input: SignalChildWorkflowInput) -> None: # noqa: A002 + return await super().signal_child_workflow(_with_trace_context_header(input)) + + async def signal_external_workflow( + self, + input: SignalExternalWorkflowInput, # noqa: A002 + ) -> None: + return await super().signal_external_workflow(_with_trace_context_header(input)) + + def start_activity(self, input: StartActivityInput) -> ActivityHandle[Any]: # noqa: A002 + return super().start_activity(_with_trace_context_header(input)) + + async def start_child_workflow( + self, + input: StartChildWorkflowInput, # noqa: A002 + ) -> ChildWorkflowHandle[Any, Any]: + return await super().start_child_workflow(_with_trace_context_header(input)) + + def start_local_activity( + self, + input: StartLocalActivityInput, # noqa: A002 + ) -> ActivityHandle[Any]: + return super().start_local_activity(_with_trace_context_header(input)) + + async def start_nexus_operation( + self, + input: StartNexusOperationInput[InputT, OutputT], # noqa: A002 + ) -> NexusOperationHandle[OutputT]: + return await super().start_nexus_operation(_with_trace_context_header(input)) + + +class _TraceContextActivityInboundInterceptor(ActivityInboundInterceptor): + async def execute_activity(self, input: ExecuteActivityInput) -> Any: # noqa: A002 + with _trace_context(input.headers): + return await super().execute_activity(input) + + +def get_trace_context() -> TraceContext | None: + return _TRACE_CONTEXT.get() + + +@contextmanager +def _trace_context(headers: Mapping[str, Payload]) -> Generator[None, None, None]: + ctx = headers.get(_TRACEPARENT) + if ctx is not None: + ctx = _DEFAULT_PAYLOAD_CONVERTER.from_payloads( + [headers.get(_TRACEPARENT)], None + )[0] + ctx = TraceContext.from_traceparent(ctx) + else: + ctx = TraceContext.next_span(None) + tok = None + try: + tok = _TRACE_CONTEXT.set(ctx) + yield + finally: + if tok is not None: + _TRACE_CONTEXT.reset(tok) + + +InputWithHeaders = TypeVar("InputWithHeaders") + + +def _with_trace_context_header( + input_with_headers: InputWithHeaders, +) -> InputWithHeaders: + ctx = get_trace_context() + if ctx is None: + return input_with_headers + new_obj = deepcopy(input_with_headers) + next_ctx = TraceContext.next_span(ctx) + new_obj.headers[_TRACEPARENT] = _DEFAULT_PAYLOAD_CONVERTER.to_payload( + next_ctx.traceparent + ) + return new_obj diff --git a/datashare-python/datashare_python/utils.py b/datashare-python/datashare_python/utils.py index b346676..751fb59 100644 --- a/datashare-python/datashare_python/utils.py +++ b/datashare-python/datashare_python/utils.py @@ -338,6 +338,8 @@ def activity_defn( retriables: set[type[Exception]] = None, ) -> Callable[[Callable[P, T]], Callable[P, T]]: def decorator(activity_fn: Callable[P, T]) -> Callable[P, T]: + # TODO: some of these could probably be reimplemented more elegantly using + # temporal interceptors: https://docs.temporal.io/develop/python/workers/interceptors activity_fn = positional_args_only(activity_fn) activity_fn = with_retriables(retriables)(activity_fn) if supports_progress(activity_fn): diff --git a/datashare-python/tests/test_interceptors.py b/datashare-python/tests/test_interceptors.py new file mode 100644 index 0000000..f948af1 --- /dev/null +++ b/datashare-python/tests/test_interceptors.py @@ -0,0 +1,123 @@ +import uuid +from collections.abc import AsyncGenerator +from datetime import timedelta +from typing import Any + +import pytest +import temporalio + +with temporalio.workflow.unsafe.imports_passed_through(): + from datashare_python.config import PYDANTIC_DATA_CONVERTER, WorkerConfig + from datashare_python.interceptors import ( + TraceContext, + TraceContextInterceptor, + get_trace_context, + ) + from datashare_python.types_ import TemporalClient + from temporalio import activity, workflow + from temporalio.client import ( + Interceptor, + OutboundInterceptor, + StartWorkflowInput, + WorkflowHandle, + ) + from temporalio.converter import DataConverter + from temporalio.worker import Worker + +_TEST_CTX_QUEUE = "test.ctx.queue" + +_DUMMY_TRACE_CTX = TraceContext(version="00", trace_id="trace_id", parent_id="trace_id") +_DEFAULT_PAYLOAD_CONVERTER = DataConverter.default.payload_converter + + +class _MockOutboundInterceptor(OutboundInterceptor): + async def start_workflow( + self, + input: StartWorkflowInput, # noqa: A002 + ) -> WorkflowHandle[Any, Any]: + input.headers["traceparent"] = _DEFAULT_PAYLOAD_CONVERTER.to_payload( + _DUMMY_TRACE_CTX.traceparent + ) + return await self.next.start_workflow(input) + + +class _MockTraceContextHeaderInterceptor(Interceptor): + def intercept_client(self, next: OutboundInterceptor) -> OutboundInterceptor: # noqa: A002 + return super().intercept_client(_MockOutboundInterceptor(next)) + + +_TIMEOUT = timedelta(seconds=10) + + +@workflow.defn +class _TestTraceContentWorkflow: + @workflow.run + async def run(self) -> list[TraceContext]: + current_ctx = get_trace_context() + ctx_log = [current_ctx] + ctx_log = await workflow.execute_activity( + ctx_test_act, + ctx_log, + task_queue=_TEST_CTX_QUEUE, + start_to_close_timeout=_TIMEOUT, + ) + ctx_log = await workflow.execute_activity( + ctx_test_act, + ctx_log, + task_queue=_TEST_CTX_QUEUE, + start_to_close_timeout=_TIMEOUT, + ) + return ctx_log + + +@activity.defn +async def ctx_test_act(previous: list[TraceContext]) -> list[TraceContext]: + previous.append(get_trace_context()) + return previous + + +@pytest.fixture(scope="session") +async def test_interceptor_worker( + test_temporal_client_session: TemporalClient, +) -> AsyncGenerator[None, None]: + client = test_temporal_client_session + worker_id = f"test-interceptor-worker-{uuid.uuid4()}" + interceptors = [TraceContextInterceptor()] + worker = Worker( + client, + identity=worker_id, + activities=[ctx_test_act], + workflows=[_TestTraceContentWorkflow], + interceptors=interceptors, + task_queue=_TEST_CTX_QUEUE, + ) + async with worker: + yield + + +async def test_trace_context_interceptor( + test_interceptor_worker, # noqa: ANN001, ARG001 + test_worker_config: WorkerConfig, +) -> None: + # Given + temporal_config = test_worker_config.temporal + client = await TemporalClient.connect( + target_host=temporal_config.host, + namespace=temporal_config.namespace, + data_converter=PYDANTIC_DATA_CONVERTER, + interceptors=[_MockTraceContextHeaderInterceptor()], + ) + wf_id = f"wf-test-interceptor-{uuid.uuid4()}" + # When + res = await client.execute_workflow( + _TestTraceContentWorkflow, id=wf_id, task_queue=_TEST_CTX_QUEUE + ) + # Then + assert len(res) == 3 + first = res[0] + assert first == _DUMMY_TRACE_CTX + remaining = res[1:] + for trace_ctx in remaining: + assert trace_ctx.trace_id == _DUMMY_TRACE_CTX.trace_id + assert len(trace_ctx.parent_id) == 16 + assert trace_ctx.sampled From 30bbc9dc28685ccb726f01ab0f2761796492cb58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cle=CC=81ment=20Doumouro?= Date: Wed, 22 Apr 2026 17:58:21 +0200 Subject: [PATCH 2/3] feature(datashare-python): add trace context information to the `WorkerFilter` --- datashare-python/datashare_python/logging_.py | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/datashare-python/datashare_python/logging_.py b/datashare-python/datashare_python/logging_.py index 944a7ad..97a837e 100644 --- a/datashare-python/datashare_python/logging_.py +++ b/datashare-python/datashare_python/logging_.py @@ -12,20 +12,17 @@ from temporalio import activity, workflow from .config import LogLevel +from .interceptors import get_trace_context -_ACT_LOGGER_ATTRS = [ - "activity_type", - "activity_id", - "activity_run_id", -] - -_WF_LOGGED_ATTRS = [ - "workflow_type", - "workflow_id", - "workflow_run_id", -] +_ACT_LOGGER_ATTRS = ["activity_type", "activity_id", "activity_run_id"] +_WF_LOGGED_ATTRS = ["workflow_type", "workflow_id", "workflow_run_id"] +_TRACE_CONTEXT_ATTRS = ["trace_id", "parent_id", "traceparent"] _LOGGED_ATTRIBUTES = ( - copy(RESERVED_ATTRS) + _WF_LOGGED_ATTRS + _ACT_LOGGER_ATTRS + ["worker_id"] + copy(RESERVED_ATTRS) + + _WF_LOGGED_ATTRS + + _ACT_LOGGER_ATTRS + + _TRACE_CONTEXT_ATTRS + + ["worker_id"] ) @@ -75,6 +72,10 @@ def filter(self, record: logging.LogRecord) -> bool: act_info = activity.info() for attr in _ACT_LOGGER_ATTRS: setattr(record, attr, getattr(act_info, attr)) + trace_context = get_trace_context() + if trace_context is not None: + for attr in _TRACE_CONTEXT_ATTRS: + setattr(record, attr, getattr(trace_context, attr)) return True From eabb8ca20bf4c70c2c5f6d8ea6ef858e9eed954a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cle=CC=81ment=20Doumouro?= Date: Wed, 22 Apr 2026 18:05:51 +0200 Subject: [PATCH 3/3] feature(datashare-python): add `TraceContextInterceptor` to `datashare_worker` utility function --- datashare-python/datashare_python/worker.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datashare-python/datashare_python/worker.py b/datashare-python/datashare_python/worker.py index ee4452b..92fb17e 100644 --- a/datashare-python/datashare_python/worker.py +++ b/datashare-python/datashare_python/worker.py @@ -16,6 +16,7 @@ from .config import WorkerConfig from .dependencies import with_dependencies from .discovery import Activity +from .interceptors import TraceContextInterceptor from .types_ import ContextManagerFactory, TemporalClient logger = logging.getLogger(__name__) @@ -84,9 +85,10 @@ def datashare_worker( max_concurrent_activities = 1 if workflows: logger.warning(_SEPARATE_IO_AND_CPU_WORKERS) - + interceptors = [TraceContextInterceptor()] return DatashareWorker( client, + interceptors=interceptors, identity=worker_id, workflows=workflows, activities=activities,