Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 198 additions & 0 deletions datashare-python/datashare_python/interceptors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
import secrets
from collections.abc import Generator, Mapping
from contextlib import contextmanager
from contextvars import ContextVar
from copy import deepcopy
from typing import Annotated, Any, NoReturn, Self, TypeVar

from nexusrpc import InputT, OutputT
from pydantic import Field
from temporalio.api.common.v1 import Payload
from temporalio.converter import DataConverter
from temporalio.worker import (
ActivityInboundInterceptor,
ContinueAsNewInput,
ExecuteActivityInput,
ExecuteWorkflowInput,
HandleQueryInput,
HandleSignalInput,
Interceptor,
SignalChildWorkflowInput,
SignalExternalWorkflowInput,
StartActivityInput,
StartChildWorkflowInput,
StartLocalActivityInput,
StartNexusOperationInput,
WorkflowInboundInterceptor,
WorkflowInterceptorClassInput,
WorkflowOutboundInterceptor,
)
from temporalio.workflow import (
ActivityHandle,
ChildWorkflowHandle,
NexusOperationHandle,
)

from .objects import BaseModel

# W3C trace-context header key used in Temporal headers.
_TRACEPARENT = "traceparent"
# Converter used to (de)serialize the traceparent string to/from a Payload.
_DEFAULT_PAYLOAD_CONVERTER = DataConverter.default.payload_converter


class TraceContext(BaseModel):
    """W3C trace context propagated through Temporal headers.

    Serialized as a ``traceparent`` value of the form
    ``{version}-{trace_id}-{parent_id}-{flags}``
    (see https://www.w3.org/TR/trace-context/).
    """

    # Only version "00" of the spec is supported (enforced in from_traceparent).
    version: Annotated[str, Field(frozen=True)] = "00"
    # Hex-encoded trace id shared by all spans of a trace.
    trace_id: str
    # Hex-encoded id of the current span.
    parent_id: str
    # Mirrors the "sampled" bit of the trace flags.
    sampled: bool = True

    def __hash__(self) -> int:
        # version is frozen and constant, so it's excluded from the hash.
        return hash((self.trace_id, self.parent_id, self.sampled))

    @classmethod
    def next_span(cls, parent: Self | None) -> Self:
        """Create a child span of ``parent``, or start a new trace when ``parent`` is None.

        Uses ``cls(...)`` (not the class name) so subclasses get instances of
        themselves, matching the declared ``Self`` return type.
        """
        new_span_id = secrets.token_hex(8)
        if parent is None:
            # No parent: start a brand new trace with a fresh trace id.
            trace_id = secrets.token_hex(16)
            return cls(trace_id=trace_id, parent_id=new_span_id)
        # Child span: keep the parent's trace id and sampling decision.
        return cls(
            trace_id=parent.trace_id, parent_id=new_span_id, sampled=parent.sampled
        )

    @property
    def traceparent(self) -> str:
        """Serialize to a W3C ``traceparent`` header value."""
        flags = "01" if self.sampled else "00"
        return f"{self.version}-{self.trace_id}-{self.parent_id}-{flags}"

    @classmethod
    def from_traceparent(cls, traceparent: str) -> Self:
        """Parse a W3C ``traceparent`` header value.

        Raises:
            ValueError: if the value is malformed, the version is unsupported,
                or the flags field is not valid hex.
        """
        split = traceparent.split("-")
        if len(split) != 4:
            raise ValueError(f"invalid trace parent: {traceparent}")
        version, trace_id, parent_id, flags = split
        if version != "00":
            msg = (
                f"unsupported trace parent version {version} "
                f"for traceparent {traceparent}"
            )
            raise ValueError(msg)
        # Per the spec, "sampled" is bit 0 of the hex-encoded flags byte; an
        # exact match on "01" would wrongly drop sampling for values like "03".
        try:
            sampled = bool(int(flags, 16) & 0x01)
        except ValueError as e:
            msg = f"invalid trace flags for traceparent {traceparent}"
            raise ValueError(msg) from e
        return cls(trace_id=trace_id, parent_id=parent_id, sampled=sampled)


# Trace context bound to the workflow/activity currently executing in this
# (async) context; None outside any traced execution.
_TRACE_CONTEXT: ContextVar[TraceContext | None] = ContextVar(
    "trace_context", default=None
)


class TraceContextInterceptor(Interceptor):
    """Temporal worker interceptor wiring trace-context propagation into both
    workflow and activity execution."""

    def workflow_interceptor_class(
        self,
        input: WorkflowInterceptorClassInput,  # noqa: A002, ARG002
    ) -> type[WorkflowInboundInterceptor] | None:
        # Workflows receive an interceptor *class* (instantiated per workflow
        # by the worker), not an instance.
        return _TraceContextWorkflowInboundInterceptor

    def intercept_activity(
        self,
        next: ActivityInboundInterceptor,  # noqa: A002
    ) -> ActivityInboundInterceptor:
        return _TraceContextActivityInboundInterceptor(next)


class _TraceContextWorkflowInboundInterceptor(WorkflowInboundInterceptor):
    """Reads the ``traceparent`` header of inbound workflow calls and binds it
    as the current trace context for the duration of the call."""

    def init(self, outbound: WorkflowOutboundInterceptor) -> None:
        # Chain the outbound interceptor so outgoing calls carry the header too.
        with_outbound_trace_ctx = _TraceContextWorkflowOutboundInterceptor(outbound)
        super().init(with_outbound_trace_ctx)

    async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any:  # noqa: A002
        with _trace_context(input.headers):
            return await super().execute_workflow(input)

    async def handle_signal(self, input: HandleSignalInput) -> None:  # noqa: A002
        with _trace_context(input.headers):
            return await super().handle_signal(input)

    async def handle_query(self, input: HandleQueryInput) -> Any:  # noqa: A002
        with _trace_context(input.headers):
            return await super().handle_query(input)


class _TraceContextWorkflowOutboundInterceptor(WorkflowOutboundInterceptor):
    """Stamps a child-span ``traceparent`` header (via
    ``_with_trace_context_header``) on every outgoing workflow, activity and
    nexus call, then delegates to the next interceptor."""

    def continue_as_new(self, input: ContinueAsNewInput) -> NoReturn:  # noqa: A002
        super().continue_as_new(_with_trace_context_header(input))

    async def signal_child_workflow(self, input: SignalChildWorkflowInput) -> None:  # noqa: A002
        return await super().signal_child_workflow(_with_trace_context_header(input))

    async def signal_external_workflow(
        self,
        input: SignalExternalWorkflowInput,  # noqa: A002
    ) -> None:
        return await super().signal_external_workflow(_with_trace_context_header(input))

    def start_activity(self, input: StartActivityInput) -> ActivityHandle[Any]:  # noqa: A002
        return super().start_activity(_with_trace_context_header(input))

    async def start_child_workflow(
        self,
        input: StartChildWorkflowInput,  # noqa: A002
    ) -> ChildWorkflowHandle[Any, Any]:
        return await super().start_child_workflow(_with_trace_context_header(input))

    def start_local_activity(
        self,
        input: StartLocalActivityInput,  # noqa: A002
    ) -> ActivityHandle[Any]:
        return super().start_local_activity(_with_trace_context_header(input))

    async def start_nexus_operation(
        self,
        input: StartNexusOperationInput[InputT, OutputT],  # noqa: A002
    ) -> NexusOperationHandle[OutputT]:
        return await super().start_nexus_operation(_with_trace_context_header(input))


class _TraceContextActivityInboundInterceptor(ActivityInboundInterceptor):
    """Reads the ``traceparent`` header of inbound activity executions and
    binds it as the current trace context for the duration of the call."""

    async def execute_activity(self, input: ExecuteActivityInput) -> Any:  # noqa: A002
        with _trace_context(input.headers):
            return await super().execute_activity(input)


def get_trace_context() -> TraceContext | None:
    """Return the trace context bound to the current execution, or None when
    running outside any traced workflow/activity."""
    return _TRACE_CONTEXT.get()


@contextmanager
def _trace_context(headers: Mapping[str, Payload]) -> Generator[None, None, None]:
    """Bind a trace context from Temporal ``headers`` for the duration of the
    ``with`` block.

    When a ``traceparent`` header is present it is decoded and parsed;
    otherwise a brand new trace is started. The previous context is always
    restored on exit.
    """
    # Look the header up once and keep a distinct name per representation
    # (Payload -> str -> TraceContext) instead of reusing one variable.
    payload = headers.get(_TRACEPARENT)
    if payload is not None:
        traceparent = _DEFAULT_PAYLOAD_CONVERTER.from_payloads([payload], None)[0]
        ctx = TraceContext.from_traceparent(traceparent)
    else:
        ctx = TraceContext.next_span(None)
    # ContextVar.set doesn't raise, so it can live before the try block and
    # the token is always valid in the finally clause.
    tok = _TRACE_CONTEXT.set(ctx)
    try:
        yield
    finally:
        # Restore the previous context to avoid leaking trace state.
        _TRACE_CONTEXT.reset(tok)


InputWithHeaders = TypeVar("InputWithHeaders")


def _with_trace_context_header(
    input_with_headers: InputWithHeaders,
) -> InputWithHeaders:
    """Return a deep copy of ``input_with_headers`` whose headers carry a
    child span's ``traceparent``; return the input unchanged when no trace
    context is currently bound."""
    current = get_trace_context()
    if current is None:
        # Nothing to propagate.
        return input_with_headers
    child = TraceContext.next_span(current)
    payload = _DEFAULT_PAYLOAD_CONVERTER.to_payload(child.traceparent)
    # Deep-copy so the caller's original input object is left untouched.
    copied = deepcopy(input_with_headers)
    copied.headers[_TRACEPARENT] = payload
    return copied
25 changes: 13 additions & 12 deletions datashare-python/datashare_python/logging_.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,17 @@
from temporalio import activity, workflow

from .config import LogLevel
from .interceptors import get_trace_context

_ACT_LOGGER_ATTRS = [
"activity_type",
"activity_id",
"activity_run_id",
]

_WF_LOGGED_ATTRS = [
"workflow_type",
"workflow_id",
"workflow_run_id",
]
_ACT_LOGGER_ATTRS = ["activity_type", "activity_id", "activity_run_id"]
_WF_LOGGED_ATTRS = ["workflow_type", "workflow_id", "workflow_run_id"]
_TRACE_CONTEXT_ATTRS = ["trace_id", "parent_id", "traceparent"]
_LOGGED_ATTRIBUTES = (
copy(RESERVED_ATTRS) + _WF_LOGGED_ATTRS + _ACT_LOGGER_ATTRS + ["worker_id"]
copy(RESERVED_ATTRS)
+ _WF_LOGGED_ATTRS
+ _ACT_LOGGER_ATTRS
+ _TRACE_CONTEXT_ATTRS
+ ["worker_id"]
)


Expand Down Expand Up @@ -75,6 +72,10 @@ def filter(self, record: logging.LogRecord) -> bool:
act_info = activity.info()
for attr in _ACT_LOGGER_ATTRS:
setattr(record, attr, getattr(act_info, attr))
trace_context = get_trace_context()
if trace_context is not None:
for attr in _TRACE_CONTEXT_ATTRS:
setattr(record, attr, getattr(trace_context, attr))
return True


Expand Down
2 changes: 2 additions & 0 deletions datashare-python/datashare_python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,8 @@ def activity_defn(
retriables: set[type[Exception]] = None,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
def decorator(activity_fn: Callable[P, T]) -> Callable[P, T]:
# TODO: some of these could probably be reimplemented more elegantly using
# temporal interceptors: https://docs.temporal.io/develop/python/workers/interceptors
activity_fn = positional_args_only(activity_fn)
activity_fn = with_retriables(retriables)(activity_fn)
if supports_progress(activity_fn):
Expand Down
4 changes: 3 additions & 1 deletion datashare-python/datashare_python/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .config import WorkerConfig
from .dependencies import with_dependencies
from .discovery import Activity
from .interceptors import TraceContextInterceptor
from .types_ import ContextManagerFactory, TemporalClient

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -84,9 +85,10 @@ def datashare_worker(
max_concurrent_activities = 1
if workflows:
logger.warning(_SEPARATE_IO_AND_CPU_WORKERS)

interceptors = [TraceContextInterceptor()]
return DatashareWorker(
client,
interceptors=interceptors,
identity=worker_id,
workflows=workflows,
activities=activities,
Expand Down
123 changes: 123 additions & 0 deletions datashare-python/tests/test_interceptors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import uuid
from collections.abc import AsyncGenerator
from datetime import timedelta
from typing import Any

import pytest
import temporalio

with temporalio.workflow.unsafe.imports_passed_through():
from datashare_python.config import PYDANTIC_DATA_CONVERTER, WorkerConfig
from datashare_python.interceptors import (
TraceContext,
TraceContextInterceptor,
get_trace_context,
)
from datashare_python.types_ import TemporalClient
from temporalio import activity, workflow
from temporalio.client import (
Interceptor,
OutboundInterceptor,
StartWorkflowInput,
WorkflowHandle,
)
from temporalio.converter import DataConverter
from temporalio.worker import Worker

# Dedicated task queue so the test worker only picks up this test's tasks.
_TEST_CTX_QUEUE = "test.ctx.queue"

# Fixed trace context injected by the mock client interceptor below.
_DUMMY_TRACE_CTX = TraceContext(version="00", trace_id="trace_id", parent_id="trace_id")
_DEFAULT_PAYLOAD_CONVERTER = DataConverter.default.payload_converter


class _MockOutboundInterceptor(OutboundInterceptor):
    """Client outbound interceptor injecting a fixed ``traceparent`` header
    into every started workflow, simulating an upstream traced caller."""

    async def start_workflow(
        self,
        input: StartWorkflowInput,  # noqa: A002
    ) -> WorkflowHandle[Any, Any]:
        input.headers["traceparent"] = _DEFAULT_PAYLOAD_CONVERTER.to_payload(
            _DUMMY_TRACE_CTX.traceparent
        )
        return await self.next.start_workflow(input)


class _MockTraceContextHeaderInterceptor(Interceptor):
    """Client interceptor installing ``_MockOutboundInterceptor`` into the
    outbound chain."""

    def intercept_client(self, next: OutboundInterceptor) -> OutboundInterceptor:  # noqa: A002
        # NOTE(review): relies on the base ``intercept_client`` being a
        # pass-through so the mock wrapping ``next`` is what's returned —
        # confirm against the temporalio.client.Interceptor docs.
        return super().intercept_client(_MockOutboundInterceptor(next))


_TIMEOUT = timedelta(seconds=10)


@workflow.defn
class _TestTraceContentWorkflow:
    """Workflow recording the trace context it sees, followed by the contexts
    seen by two successive activity executions."""

    # NOTE(review): class name reads "Content" — likely a typo for "Context".
    # Renaming would touch the worker fixture and the test below, so it is
    # only flagged here.

    @workflow.run
    async def run(self) -> list[TraceContext]:
        # Context bound by the workflow inbound interceptor (from the header
        # the mock client injected).
        current_ctx = get_trace_context()
        ctx_log = [current_ctx]
        ctx_log = await workflow.execute_activity(
            ctx_test_act,
            ctx_log,
            task_queue=_TEST_CTX_QUEUE,
            start_to_close_timeout=_TIMEOUT,
        )
        ctx_log = await workflow.execute_activity(
            ctx_test_act,
            ctx_log,
            task_queue=_TEST_CTX_QUEUE,
            start_to_close_timeout=_TIMEOUT,
        )
        return ctx_log


@activity.defn
async def ctx_test_act(previous: list[TraceContext]) -> list[TraceContext]:
    """Append the trace context seen by this activity to the running log."""
    current = get_trace_context()
    previous.append(current)
    return previous


@pytest.fixture(scope="session")
async def test_interceptor_worker(
    test_temporal_client_session: TemporalClient,
) -> AsyncGenerator[None, None]:
    """Run a worker with ``TraceContextInterceptor`` installed for the whole
    test session; yields nothing, only keeps the worker alive."""
    client = test_temporal_client_session
    # Unique identity so concurrent sessions don't collide.
    worker_id = f"test-interceptor-worker-{uuid.uuid4()}"
    interceptors = [TraceContextInterceptor()]
    worker = Worker(
        client,
        identity=worker_id,
        activities=[ctx_test_act],
        workflows=[_TestTraceContentWorkflow],
        interceptors=interceptors,
        task_queue=_TEST_CTX_QUEUE,
    )
    async with worker:
        yield


async def test_trace_context_interceptor(
    test_interceptor_worker,  # noqa: ANN001, ARG001
    test_worker_config: WorkerConfig,
) -> None:
    """End-to-end check that the trace context injected by the client is
    propagated to the workflow and on to each activity as child spans."""
    # Given: a client whose interceptor stamps _DUMMY_TRACE_CTX on start.
    temporal_config = test_worker_config.temporal
    client = await TemporalClient.connect(
        target_host=temporal_config.host,
        namespace=temporal_config.namespace,
        data_converter=PYDANTIC_DATA_CONVERTER,
        interceptors=[_MockTraceContextHeaderInterceptor()],
    )
    wf_id = f"wf-test-interceptor-{uuid.uuid4()}"
    # When
    res = await client.execute_workflow(
        _TestTraceContentWorkflow, id=wf_id, task_queue=_TEST_CTX_QUEUE
    )
    # Then: one context from the workflow itself + one per activity call.
    assert len(res) == 3
    first = res[0]
    assert first == _DUMMY_TRACE_CTX
    remaining = res[1:]
    for trace_ctx in remaining:
        # Activities share the trace id but get fresh span ids
        # (secrets.token_hex(8) -> 16 hex chars) and inherit sampling.
        assert trace_ctx.trace_id == _DUMMY_TRACE_CTX.trace_id
        assert len(trace_ctx.parent_id) == 16
        assert trace_ctx.sampled
Loading