datashare-python/datashare_python/logging_.py (29 additions, 27 deletions)
@@ -2,11 +2,7 @@
 import sys
 from copy import copy
 
-from icij_common.logging_utils import (
-    DATE_FMT,
-    STREAM_HANDLER_FMT,
-    STREAM_HANDLER_FMT_WITH_WORKER_ID,
-)
+from icij_common.logging_utils import DATE_FMT, STREAM_HANDLER_FMT
 from pythonjsonlogger.core import RESERVED_ATTRS, BaseJsonFormatter
 from pythonjsonlogger.orjson import OrjsonFormatter
 from temporalio import activity, workflow
@@ -26,6 +22,11 @@
 )
 
 
+_STREAM_HANDLER_FMT_WITH_WORKER_ID = (
+    "[%(levelname)s][%(asctime)s.%(msecs)03d][%(worker_id)s][%(name)s]: %(message)s"
+)
+
+
 def setup_worker_loggers(
     loggers: dict[str, LogLevel], *, worker_id: str | None, in_json: bool
 ) -> None:
@@ -35,35 +36,18 @@ def setup_worker_loggers(
         logger = logging.getLogger(logger_name)
         logger.setLevel(level)
         logger.handlers = []
-        for handler in _get_worker_handlers(level, worker_id, in_json=in_json):
+        for handler in _get_worker_handlers(level, worker_filter, in_json=in_json):
             logger.addHandler(handler)
-        logger.addFilter(worker_filter)
 
 
-def _get_worker_handlers(
-    level: int, worker_id: str | None, *, in_json: bool
-) -> list[logging.Handler]:
-    stream_handler = logging.StreamHandler(sys.stderr)
-    if in_json:
-        fmt = _json_formatter(datefmt=DATE_FMT)
-    else:
-        if worker_id is not None:
-            fmt = STREAM_HANDLER_FMT_WITH_WORKER_ID
-        else:
-            fmt = STREAM_HANDLER_FMT
-        fmt = logging.Formatter(fmt, DATE_FMT)
-    stream_handler.setFormatter(fmt)
-    stream_handler.setLevel(level)
-    return [stream_handler]
-
-
 class WorkerFilter(logging.Filter):
-    def __init__(self, worker_id: str) -> None:
+    def __init__(self, worker_id: str | None) -> None:
         super().__init__()
-        self._worker_id = worker_id
+        self.worker_id = worker_id
 
     def filter(self, record: logging.LogRecord) -> bool:
-        record.worker_id = self._worker_id
+        if self.worker_id is not None:
+            record.worker_id = self.worker_id
         if workflow.in_workflow():
             wf_info = workflow.info()
             for attr in _WF_LOGGED_ATTRS:
@@ -79,6 +63,24 @@ def filter(self, record: logging.LogRecord) -> bool:
         return True
 
 
+def _get_worker_handlers(
+    level: int, worker_filter: WorkerFilter, *, in_json: bool
+) -> list[logging.Handler]:
+    stream_handler = logging.StreamHandler(sys.stderr)
+    if in_json:
+        fmt = _json_formatter(datefmt=DATE_FMT)
+    else:
+        if worker_filter.worker_id is not None:
+            fmt = _STREAM_HANDLER_FMT_WITH_WORKER_ID
+        else:
+            fmt = STREAM_HANDLER_FMT
+        fmt = logging.Formatter(fmt, DATE_FMT)
+    stream_handler.setFormatter(fmt)
+    stream_handler.setLevel(level)
+    stream_handler.addFilter(worker_filter)
+    return [stream_handler]
+
+
 def _json_formatter(datefmt: str) -> BaseJsonFormatter:
     fmt = OrjsonFormatter(  # let's keep logging as fast as possible
         _LOGGED_ATTRIBUTES, datefmt=datefmt
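
Note on the new wiring: WorkerFilter now hangs off the handler instead of each logger, so the worker id is stamped on every record the handler emits and the format string can interpolate %(worker_id)s. A minimal stdlib-only sketch of the same mechanism (the format string and worker id below are illustrative):

import logging
import sys

class DemoWorkerFilter(logging.Filter):
    # Simplified stand-in for the WorkerFilter above: stamp each record so
    # the formatter can interpolate %(worker_id)s.
    def __init__(self, worker_id: str | None) -> None:
        super().__init__()
        self.worker_id = worker_id

    def filter(self, record: logging.LogRecord) -> bool:
        if self.worker_id is not None:
            record.worker_id = self.worker_id
        return True

handler = logging.StreamHandler(sys.stderr)
fmt = "[%(levelname)s][%(worker_id)s][%(name)s]: %(message)s"
handler.setFormatter(logging.Formatter(fmt))
handler.addFilter(DemoWorkerFilter("worker-0"))

logger = logging.getLogger("demo")
logger.addHandler(handler)
logger.warning("hello")  # prints: [WARNING][worker-0][demo]: hello
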
datashare-python/datashare_python/utils.py (86 additions, 68 deletions)
@@ -1,12 +1,10 @@
 import asyncio
+import contextlib
+import contextvars
 import inspect
-import json
 import logging
-import sys
-from collections.abc import (
-    Callable,
-    Coroutine,
-)
+import threading
+from collections.abc import Awaitable, Callable, Coroutine
 from copy import deepcopy
 from dataclasses import dataclass
 from datetime import timedelta
@@ -20,15 +18,6 @@
 
 import nest_asyncio
 import temporalio
-from icij_common.logging_utils import (
-    DATE_FMT,
-    STREAM_HANDLER_FMT,
-    STREAM_HANDLER_FMT_WITH_WORKER_ID,
-    WorkerIdFilter,
-)
-from icij_common.pydantic_utils import get_field_default_value
-from pydantic.fields import FieldInfo
-from pythonjsonlogger.json import JsonFormatter
 from temporalio import activity, workflow
 from temporalio.client import Client, WorkflowHandle
 from temporalio.common import RetryPolicy, SearchAttributeKey
@@ -123,6 +112,7 @@ async def execute_activity(
     *,
     args: list | None = None,
     start_to_close_timeout: timedelta | None = None,
+    heartbeat_timeout: timedelta = timedelta(minutes=1),
     retry_policy: temporalio.common.RetryPolicy | None = None,
 ) -> Any:
     if args is None:
@@ -135,6 +125,7 @@
         start_to_close_timeout=start_to_close_timeout,
         task_queue=task_queue,
         retry_policy=retry_policy,
+        heartbeat_timeout=heartbeat_timeout,
     )
 
 
@@ -150,6 +141,8 @@ async def progress_handler(
         activity_id=activity_id, run_id=run_id, progress=progress, weight=weight
     )
     await handle.signal("update_progress", signal)
+    with contextlib.suppress(RuntimeError, asyncio.TimeoutError):
+        activity.heartbeat()
 
 
 def get_activity_progress_handler_async(
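
The suppress guard matters because temporalio's activity.heartbeat() raises RuntimeError when there is no activity context, for instance when the progress handler is exercised directly in a test. The guarded call in isolation, as a sketch:

import asyncio
import contextlib

from temporalio import activity

# Outside a running activity this call would raise RuntimeError; with the
# guard it degrades to a no-op, so progress reporting never fails on it.
with contextlib.suppress(RuntimeError, asyncio.TimeoutError):
    activity.heartbeat()
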
@@ -229,6 +222,74 @@ def wrapper(self: ActivityWithProgress, *args: P.args) -> T:
     return decorator
 
 
+def with_async_heartbeat(
+    activity_fn: Callable[P, Awaitable[T]], n_missed_before_timeout: int
+) -> Callable[P, Awaitable[T]]:
+    # Copied from
+    # https://github.com/temporalio/samples-python/blob/main/custom_decorator/activity_utils.py
+    @wraps(activity_fn)
+    async def wrapper(*args, **kwargs) -> T:
+        heartbeat_timeout = activity.info().heartbeat_timeout
+        heartbeat_task = None
+        if heartbeat_timeout:
+            period = heartbeat_timeout.total_seconds() / n_missed_before_timeout
+            heartbeat_task = asyncio.create_task(_async_heartbeat_every(period))
+        try:
+            activity.heartbeat()
+            return await activity_fn(*args, **kwargs)
+        finally:
+            if heartbeat_task:
+                heartbeat_task.cancel()
+                await asyncio.wait([heartbeat_task])
+
+    return wrapper
+
+
+async def _async_heartbeat_every(period: float, *details: Any) -> None:
+    with contextlib.suppress(RuntimeError, asyncio.TimeoutError):
+        activity.heartbeat(*details)
+    while True:
+        await asyncio.sleep(period)
+        with contextlib.suppress(RuntimeError, asyncio.TimeoutError):
+            activity.heartbeat(*details)
+
+
+def with_sync_heartbeat(
+    activity_fn: Callable[P, T], n_missed_before_timeout: int
+) -> Callable[P, T]:
+    @wraps(activity_fn)
+    def wrapper(*args, **kwargs) -> T:
+        heartbeat_timeout = activity.info().heartbeat_timeout
+        heartbeat_thread, stop_event = None, None
+        if heartbeat_timeout:
+            period = heartbeat_timeout.total_seconds() / n_missed_before_timeout
+            ctx = contextvars.copy_context()
+            run_args = (_sync_heartbeat_every, period, threading.Event())
+            heartbeat_thread, stop_event = (
+                threading.Thread(target=ctx.run, args=run_args),
+                run_args[-1],
+            )
+            heartbeat_thread.start()
+        try:
+            return activity_fn(*args, **kwargs)
+        finally:
+            if heartbeat_thread:
+                stop_event.set()
+                heartbeat_thread.join()
+
+    return wrapper
+
+
+def _sync_heartbeat_every(
+    period: float, stop_event: threading.Event, *details: Any
+) -> None:
+    with contextlib.suppress(RuntimeError, asyncio.TimeoutError):
+        activity.heartbeat(*details)
+    while not stop_event.wait(period):
+        with contextlib.suppress(RuntimeError, asyncio.TimeoutError):
+            activity.heartbeat(*details)
+
+
 def positional_args_only(activity_fn: Callable[P, T]) -> Callable[P, T]:
     sig = inspect.signature(activity_fn)

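
These decorators only start a background heartbeat when the activity was scheduled with a heartbeat timeout, and they derive the beat period so that n_missed_before_timeout beats fit into one timeout window. The same pattern stripped of Temporal, as a runnable sketch (the 3-second timeout and the fake workload are made up):

import threading
import time

def heartbeat() -> None:
    # Stand-in for temporalio's activity.heartbeat()
    print("beat")

def beat_every(period: float, stop_event: threading.Event) -> None:
    heartbeat()
    # Event.wait(period) doubles as the sleep and the stop check, exactly
    # like _sync_heartbeat_every above
    while not stop_event.wait(period):
        heartbeat()

heartbeat_timeout = 3.0  # would come from activity.info().heartbeat_timeout
period = heartbeat_timeout / 5  # n_missed_before_timeout = 5

stop = threading.Event()
thread = threading.Thread(target=beat_every, args=(period, stop))
thread.start()
try:
    time.sleep(2)  # pretend this is the blocking activity body
finally:
    stop.set()  # ask the heartbeat thread to exit, then wait for it
    thread.join()
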
@@ -336,6 +397,7 @@ def activity_defn(
     name: str,
     progress_weight: float = 1.0,
     retriables: set[type[Exception]] = None,
+    n_missed_heartbeats_before_timeout: int = 5,
 ) -> Callable[[Callable[P, T]], Callable[P, T]]:
     def decorator(activity_fn: Callable[P, T]) -> Callable[P, T]:
         # TODO: some of these could probably be reimplemented more elegantly using
@@ -344,6 +406,15 @@ def decorator(activity_fn: Callable[P, T]) -> Callable[P, T]:
         activity_fn = with_retriables(retriables)(activity_fn)
         if supports_progress(activity_fn):
             activity_fn = with_progress(progress_weight)(activity_fn)
+        is_async = asyncio.iscoroutinefunction(activity_fn)
+        if is_async:
+            activity_fn = with_async_heartbeat(
+                activity_fn, n_missed_heartbeats_before_timeout
+            )
+        else:
+            activity_fn = with_sync_heartbeat(
+                activity_fn, n_missed_heartbeats_before_timeout
+            )
         activity_fn = activity.defn(activity_fn, name=name)
         return activity_fn
 
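
With this, every activity declared through activity_defn is heartbeat-enabled automatically, and the async or sync wrapper is picked from the function type. A hedged usage sketch (the activity name and body are invented):

from datashare_python.utils import activity_defn

@activity_defn(name="translate-docs", n_missed_heartbeats_before_timeout=5)
async def translate_docs(doc_ids: list[str]) -> int:
    # Long-running work goes here; since this is a coroutine the decorator
    # wraps it with with_async_heartbeat, which beats roughly 5 times per
    # heartbeat-timeout window in the background.
    return len(doc_ids)
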
@@ -382,59 +453,6 @@ async def _scaled(p: float) -> None:
     return _scaled
 
 
-class LogWithWorkerIDMixin:
-    def setup_loggers(self, worker_id: str | None = None) -> None:
-        # Ugly work around the Pydantic V1 limitations...
-        all_loggers = self.loggers
-        if isinstance(all_loggers, FieldInfo):
-            all_loggers = get_field_default_value(all_loggers)
-        all_loggers.append(__name__)
-        loggers = sorted(set(all_loggers))
-        log_level = self.log_level
-        if isinstance(log_level, FieldInfo):
-            log_level = get_field_default_value(log_level)
-        force_warning = getattr(self, "force_warning_loggers", [])
-        if isinstance(force_warning, FieldInfo):
-            force_warning = get_field_default_value(force_warning)
-        force_warning = set(force_warning)
-        worker_id_filter = None
-        if worker_id is not None:
-            worker_id_filter = WorkerIdFilter(worker_id)
-        handlers = self._handlers(worker_id_filter, log_level)
-        for logger_ in loggers:
-            logger_ = logging.getLogger(logger_)  # noqa: PLW2901
-            level = getattr(logging, log_level)
-            if logger_.name in force_warning:
-                level = max(logging.WARNING, level)
-            logger_.setLevel(level)
-            logger_.handlers = []
-            for handler in handlers:
-                logger_.addHandler(handler)
-
-    def _handlers(
-        self, worker_id_filter: logging.Filter | None, log_level: int
-    ) -> list[logging.Handler]:
-        stream_handler = logging.StreamHandler(sys.stderr)
-        if worker_id_filter is not None:
-            fmt = STREAM_HANDLER_FMT_WITH_WORKER_ID
-        else:
-            fmt = STREAM_HANDLER_FMT
-        log_in_json = getattr(self, "log_in_json", False)
-        if isinstance(log_in_json, FieldInfo):
-            log_in_json = get_field_default_value(log_in_json)
-        if log_in_json:
-            fmt = JsonFormatter(fmt, DATE_FMT)
-        else:
-            fmt = logging.Formatter(fmt, DATE_FMT)
-        stream_handler.setFormatter(fmt)
-        handlers = [stream_handler]
-        for handler in handlers:
-            if worker_id_filter is not None:
-                handler.addFilter(worker_id_filter)
-            handler.setLevel(log_level)
-        return handlers
-
-
 def safe_dir(doc_id: str) -> Path:
     if len(doc_id) < 4:
         raise ValueError(f"expected doc_id to be at least 4, found {doc_id}")
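
Since execute_activity now defaults heartbeat_timeout to one minute, activities scheduled through it are failed by the server (and retried per their retry policy) once they stop heartbeating for longer than that window. A call-site sketch; the leading positional arguments and the queue name are assumptions, only the keyword arguments come from this diff:

from datetime import timedelta

from datashare_python.utils import execute_activity

async def schedule_translation() -> None:
    # Hypothetical call: a 30s heartbeat window overrides the 1-minute
    # default, so the wrapped activity beats every 30 / 5 = 6 seconds.
    await execute_activity(
        "translate-docs",  # assumed first positional parameter
        task_queue="datashare-python",  # assumed queue name
        args=[["doc-1", "doc-2"]],
        start_to_close_timeout=timedelta(minutes=10),
        heartbeat_timeout=timedelta(seconds=30),
    )
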
datashare-python/datashare_python/worker.py (11 additions, 1 deletion)
@@ -11,7 +11,12 @@
 from copy import copy
 from typing import Any
 
-from temporalio.worker import PollerBehaviorSimpleMaximum, Worker
+from temporalio.worker import (
+    PollerBehaviorSimpleMaximum,
+    UnsandboxedWorkflowRunner,
+    Worker,
+)
+from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
 
 from .config import WorkerConfig
 from .dependencies import with_dependencies
@@ -62,6 +67,7 @@ def datashare_worker(
     # Scale horizontally by default for activities, each worker processes one activity
     # at a time
     max_concurrent_io_activities: int = 10,
+    sandboxed: bool = True,
 ) -> DatashareWorker:
     if workflows is None:
         workflows = []
@@ -86,6 +92,7 @@
     if workflows:
         logger.warning(_SEPARATE_IO_AND_CPU_WORKERS)
     interceptors = [TraceContextInterceptor()]
+    wf_runner = SandboxedWorkflowRunner() if sandboxed else UnsandboxedWorkflowRunner()
     return DatashareWorker(
         client,
         interceptors=interceptors,
Expand All @@ -101,6 +108,7 @@ def datashare_worker(
# Workflow tasks are assumed to be very lightweight and fast we can reserve
# several of them
workflow_task_poller_behavior=PollerBehaviorSimpleMaximum(5),
workflow_runner=wf_runner,
)


@@ -144,6 +152,7 @@ async def worker_context(
     event_loop: AbstractEventLoop,
     task_queue: str,
     dependencies: list[ContextManagerFactory] | None = None,
+    sandboxed: bool = True,
 ) -> AsyncGenerator[DatashareWorker, None]:
     discovered = []
     if activities is not None:
@@ -185,6 +194,7 @@
         activities=acts,
         task_queue=task_queue,
         max_concurrent_io_activities=worker_config.max_concurrent_io_activities,
+        sandboxed=sandboxed,
     )
     async with worker:
         yield worker
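
Disabling the sandbox trades Temporal's workflow determinism checks for compatibility with code the sandbox rejects. The new toggle in isolation (both imports are verbatim from this diff; everything else is a sketch):

from temporalio.worker import UnsandboxedWorkflowRunner
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner

def pick_runner(sandboxed: bool):
    # Mirrors the wf_runner line above: True keeps Temporal's default
    # sandboxed runner, False opts the whole worker out of the sandbox.
    return SandboxedWorkflowRunner() if sandboxed else UnsandboxedWorkflowRunner()

print(type(pick_runner(False)).__name__)  # UnsandboxedWorkflowRunner

Callers opt out per worker, e.g. datashare_worker(..., sandboxed=False) or worker_context(..., sandboxed=False).
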
datashare-python/pyproject.toml (1 addition, 1 deletion)
@@ -1,7 +1,7 @@
 [project]
 name = "datashare-python"
 version = "0.7.2"
-description = "Manage Pythoœn tasks and local resources in Datashare"
+description = "Manage Python tasks and local resources in Datashare"
 authors = [
     { name = "Clément Doumouro", email = "cdoumouro@icij.org" },
     { name = "Clément Doumouro", email = "clement.doumouro@gmail.com" },