Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/publish_extract_worker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
id-token: write
contents: write
env:
PYTHON_VERSION: 3.12
PYTHON_VERSION: 3.14
ASTRAL_VERSION: 0.11.6
steps:
- uses: actions/checkout@v6
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test_extract_worker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ jobs:
test:
runs-on: ubuntu-latest
env:
PYTHON_VERSION: 3.12
ASTRAL_VERSION: 0.11.6
PYTHON_VERSION: 3.14
ASTRAL_VERSION: 0.11.24
steps:
- uses: actions/checkout@v6
- name: Setup Python project
Expand Down
44 changes: 26 additions & 18 deletions datashare-python/datashare_python/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import asyncio
import shutil
from asyncio import AbstractEventLoop
from collections.abc import AsyncGenerator, Generator, Iterator, Sequence
from collections.abc import AsyncGenerator, Generator, Sequence
from pathlib import Path

import aiohttp
Expand All @@ -10,7 +8,9 @@
from elasticsearch._async.helpers import async_streaming_bulk
from icij_common.es import DOC_ROOT_ID, ES_DOCUMENT_TYPE, ID, ESClient
from icij_common.test_utils import reset_env # noqa: F401
from pytest_asyncio import is_async_test
from temporalio import workflow
from temporalio.service import RPCError, RPCStatusCode

from datashare_python.config import (
DatashareClientConfig,
Expand Down Expand Up @@ -59,6 +59,13 @@
}


def pytest_collection_modifyitems(items: list) -> None:
pytest_asyncio_tests = (item for item in items if is_async_test(item))
session_scope_marker = pytest.mark.asyncio(loop_scope="session")
for async_test in pytest_asyncio_tests:
async_test.add_marker(session_scope_marker, append=False)


@activity_defn(name="mocked-act")
def mocked_act() -> None:
pass
Expand All @@ -81,16 +88,6 @@ def test_deps() -> list[ContextManagerFactory]:
return [set_es_client, set_task_client]


@pytest.fixture(scope="session")
def event_loop(
request: pytest.FixtureRequest, # noqa: ARG001
) -> Iterator[asyncio.AbstractEventLoop]:
"""Create an instance of the default event loop for each test case."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()


@pytest.fixture(scope="session")
def test_worker_config() -> WorkerConfig:
logging_config = LoggingConfig(
Expand All @@ -117,9 +114,7 @@ def test_worker_config_path(test_worker_config: WorkerConfig, tmpdir: Path) -> P

@pytest.fixture(scope="session")
async def worker_lifetime_deps(
event_loop: AbstractEventLoop,
test_deps: list[ContextManagerFactory],
test_worker_config: WorkerConfig,
test_deps: list[ContextManagerFactory], test_worker_config: WorkerConfig
) -> AsyncGenerator[None, None]:
worker_id = "test-worker-id"
ctx = "test application"
Expand All @@ -128,7 +123,6 @@ async def worker_lifetime_deps(
ctx=ctx,
worker_id=worker_id,
worker_config=test_worker_config,
event_loop=event_loop,
):
yield

Expand Down Expand Up @@ -174,11 +168,25 @@ async def test_task_client(
@pytest.fixture(scope="session")
async def test_temporal_client_session(
test_worker_config: WorkerConfig,
event_loop: AbstractEventLoop, # noqa: ARG001
) -> TemporalClient: # noqa: ANN001
return await test_worker_config.to_temporal_client()


@pytest.fixture
async def test_temporal_client(
test_temporal_client_session: TemporalClient,
) -> TemporalClient: # noqa: ANN001
client = test_temporal_client_session
async for wf in client.list_workflows():
try:
await client.get_workflow_handle(wf.id).terminate()
except RPCError as e:
if e.status != RPCStatusCode.NOT_FOUND:
raise

return client


@pytest.fixture
async def populate_es(
test_es_client: ESClient,
Expand Down
45 changes: 37 additions & 8 deletions datashare-python/datashare_python/interceptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from contextlib import contextmanager
from contextvars import ContextVar
from copy import deepcopy
from functools import partial
from functools import partial, wraps
from inspect import signature
from types import UnionType
from typing import (
Expand Down Expand Up @@ -52,7 +52,12 @@
)

from .objects import BaseModel
from .types_ import ProgressRateHandler, Weight
from .types_ import (
AsyncProgressRateHandler,
ProgressRateHandler,
SyncProgressRateHandler,
Weight,
)
from .utils import (
PROGRESS_HANDLER_ARG,
PYDANTIC_DATA_CONVERTER,
Expand All @@ -62,6 +67,11 @@

_TRACEPARENT = "traceparent"
_DEFAULT_PAYLOAD_CONVERTER = DataConverter.default.payload_converter
_PROGRESS_TYPES = {
ProgressRateHandler,
AsyncProgressRateHandler,
SyncProgressRateHandler,
}


class TraceContext(BaseModel):
Expand Down Expand Up @@ -292,11 +302,13 @@ def _get_progress_handler(act_fn: Callable) -> ProgressRateHandler:


def _is_progress(t: type) -> bool:
if t is ProgressRateHandler:
if any(t is p_cls for p_cls in _PROGRESS_TYPES):
return True
return bool(
isinstance(t, UnionType)
and any(sub_t is ProgressRateHandler for sub_t in get_args(t))
and any(
any(sub_t is p_cls for p_cls in _PROGRESS_TYPES) for sub_t in get_args(t)
)
)


Expand All @@ -315,16 +327,21 @@ async def execute_activity(self, input: ExecuteActivityInput) -> Any: # noqa: A
# https://github.com/temporalio/sdk-python/blob/631ebaf0e20fb214b16589b45627b358048a5d77/temporalio/worker/_activity.py#L600
# we have to force it here again
progress_handler = _get_progress_handler(input.fn)
new_args = []
act_definition = _Definition.must_from_callable(input.fn)
if input.args:
data_converter = PYDANTIC_DATA_CONVERTER
arg_types = _Definition.must_from_callable(input.fn).arg_types
arg_types = act_definition.arg_types
arg_types = _without_progress(arg_types)
arg_types = arg_types[: len(input.args)]
encoded = await data_converter.encode(input.args)
new_args = await data_converter.decode(encoded, type_hints=arg_types)
new_args.append(progress_handler)
else:
new_args = [progress_handler]
injected_progress = (
progress_handler
if act_definition.is_async
else _sync_progress(progress_handler)
)
new_args.append(injected_progress)
new_input = dataclasses.replace(input, args=new_args)
await progress_handler(0.0)
res = await super().execute_activity(new_input)
Expand Down Expand Up @@ -374,3 +391,15 @@ async def execute_activity(self, input: ExecuteActivityInput) -> Any: # noqa: A
if heartbeat_task:
heartbeat_task.cancel()
await asyncio.wait([heartbeat_task])


def _sync_progress(
progress_handler: AsyncProgressRateHandler,
) -> SyncProgressRateHandler:
@wraps(progress_handler)
def p(progress: float, event_loop: asyncio.AbstractEventLoop) -> None:
asyncio.run_coroutine_threadsafe(
progress_handler(progress), event_loop
).result()

return p
24 changes: 20 additions & 4 deletions datashare-python/datashare_python/types_.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from collections.abc import Coroutine
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from dataclasses import dataclass
Expand All @@ -8,19 +9,34 @@
TemporalClient = Client


class ProgressRateHandler(Protocol):
class AsyncProgressRateHandler(Protocol):
async def __call__(self, progress_rate: float) -> None:
pass


class SyncProgressRateHandler(Protocol):
def __call__(
self, progress_rate: float, event_loop: asyncio.AbstractEventLoop
) -> None:
pass


ProgressRateHandler = SyncProgressRateHandler | AsyncProgressRateHandler


@dataclass
class Weight:
value: float


class RawProgressHandler(Protocol):
async def __call__(self, iteration: int) -> None:
pass
class RawAsyncProgressHandler(Protocol):
async def __call__(self, iteration: int) -> None: ...


class RawSyncProgressHandler(Protocol):
async def __call__(
self, iteration: int, event_loop: asyncio.AbstractEventLoop
) -> None: ...


FactoryReturnType = (
Expand Down
42 changes: 31 additions & 11 deletions datashare-python/datashare_python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from typing import Any, ParamSpec, TypeVar
from uuid import uuid4

import nest_asyncio
import temporalio
from pydantic import ValidationError
from temporalio import activity, workflow
Expand All @@ -31,9 +30,15 @@
)
from temporalio.exceptions import ApplicationError

from datashare_python.types_ import (
AsyncProgressRateHandler,
RawSyncProgressHandler,
SyncProgressRateHandler,
)

from .constants import METADATA_JSON
from .objects import DocArtifact, DocumentLocation, FilesystemDocument
from .types_ import ProgressRateHandler, RawProgressHandler
from .types_ import RawAsyncProgressHandler

DependencyLabel = str | None
DependencySetup = Callable[..., None]
Expand Down Expand Up @@ -76,10 +81,13 @@ def to_progress(self) -> Progress:


class ActivityWithProgress:
def __init__(self, temporal_client: Client, event_loop: asyncio.AbstractEventLoop):
def __init__(
self,
temporal_client: Client,
event_loop: asyncio.AbstractEventLoop | None = None,
):
self._temporal_client = temporal_client
nest_asyncio.apply()
self._event_loop = event_loop
self._event_loop = event_loop or asyncio.get_event_loop()


class WorkflowWithProgress:
Expand Down Expand Up @@ -276,9 +284,9 @@ def fatal_error_from_exception(exc: Exception) -> ApplicationError:
return ApplicationError(str(exc), details, type=exc_type, non_retryable=True)


def to_raw_progress(
progress: ProgressRateHandler, max_progress: int
) -> RawProgressHandler:
def to_raw_async_progress(
progress: AsyncProgressRateHandler, max_progress: int
) -> RawAsyncProgressHandler:
if not max_progress > 0:
raise ValueError("max_progress must be > 0")

Expand All @@ -288,9 +296,21 @@ async def raw(p: int) -> None:
return raw


def to_scaled_progress(
progress: ProgressRateHandler, *, start: float = 0.0, end: float = 1.0
) -> ProgressRateHandler:
def to_raw_sync_progress(
progress: SyncProgressRateHandler, max_progress: int
) -> RawSyncProgressHandler:
if not max_progress > 0:
raise ValueError("max_progress must be > 0")

def raw(iteration: int, event_loop: asyncio.AbstractEventLoop) -> None:
progress(iteration / max_progress, event_loop)

return raw


def to_scaled_async_progress(
progress: AsyncProgressRateHandler, *, start: float = 0.0, end: float = 1.0
) -> AsyncProgressRateHandler:
if not 0 <= start < end:
raise ValueError("start must be [0, end[")
if not start < end <= 1.0:
Expand Down
5 changes: 4 additions & 1 deletion datashare-python/datashare_python/worker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import inspect
import logging
import os
Expand Down Expand Up @@ -157,7 +158,7 @@ async def worker_context(
workflows: list[type] | None = None,
worker_config: WorkerConfig,
client: TemporalClient,
event_loop: AbstractEventLoop,
event_loop: AbstractEventLoop | None = None,
task_queue: str,
dependencies: list[ContextManagerFactory] | None = None,
sandboxed: bool = True,
Expand All @@ -169,6 +170,8 @@ async def worker_context(
discovered.extend(workflows)
if dependencies is not None:
discovered.extend(dependencies)
if event_loop is None:
event_loop = asyncio.get_event_loop()
discovered.append(worker_config)
loggers = copy(worker_config.logging.loggers)
discovered_loggers = {_get_object_package(o).__name__ for o in discovered}
Expand Down
2 changes: 1 addition & 1 deletion datashare-python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ dependencies = [
"hatchling~=1.27",
"icij-common[elasticsearch]~=0.8.2",
"langcodes~=3.5",
"nest-asyncio~=1.6",
"orjson~=3.11",
"python-json-logger~=4.0",
"pyyaml~=6.0",
Expand Down Expand Up @@ -76,6 +75,7 @@ dev = [

[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_debug = true
asyncio_default_fixture_loop_scope = "session"
markers = [
"integration",
Expand Down
Loading
Loading