From e92af4b15780d707a0bd8d05c09759b3169ce2c1 Mon Sep 17 00:00:00 2001 From: Chris Guidry Date: Wed, 31 Jan 2024 14:51:17 -0500 Subject: [PATCH] Introducing a subscription API for autonomous task scheduling (#11779) In earlier work, we've introduced autonomous task scheduling, where tasks outside a flow run are created as scheduled and picked up by one or more processes running `Task.serve`. In our initial implementation, we used a polling approach where each `TaskSever` would make requests from the API to look for any tasks that were currently `Scheduled`, and then move them to `Running` as they entered the task engine. This work introduces a new mechanism for `TaskServer`s to get work from their Prefect Server: a long-lived websocket connection subscribed to a queue of `TaskRun`s to be worked. Because the Prefect Server is a singleton, it can govern a queue in-memory that will be distributed out among each of the `TaskServer`s to make a simple task brokering system. The websocket implementation is modeled on the `events/in` and `events/out` websockets in Prefect Cloud, and it's expected that we'd negotiate authentication in a common way across all websockets. Note: this does not address issues of resiliency, like what happens if the Prefect Server is restarted (in-flight tasks would be lost), or if there are no `TaskServer`s draining the Queue (the Prefect Server would eventually run out of memory), or if a `TaskServer` died before transitioning a task to `Running` (the task would remain `Scheduled` and never get picked up). These are some of the items I'd like to address in future work if we like this direction. Co-authored-by: Nathan Nowack Co-authored-by: Andrew Brookins --- src/prefect/client/orchestration.py | 2 +- src/prefect/client/schemas/objects.py | 1 + src/prefect/client/subscriptions.py | 82 ++++++++ src/prefect/context.py | 6 +- src/prefect/engine.py | 97 ++++++--- src/prefect/results.py | 25 ++- src/prefect/server/api/task_runs.py | 80 +++++++- src/prefect/server/schemas/states.py | 1 + src/prefect/server/utilities/server.py | 12 +- src/prefect/server/utilities/subscriptions.py | 49 +++++ src/prefect/settings.py | 12 ++ src/prefect/task_engine.py | 70 +++++++ src/prefect/task_server.py | 186 ++++++++++++++++++ src/prefect/tasks.py | 25 ++- tests/test_autonomous_tasks.py | 107 ++++++++++ 15 files changed, 718 insertions(+), 37 deletions(-) create mode 100644 src/prefect/client/subscriptions.py create mode 100644 src/prefect/server/utilities/subscriptions.py create mode 100644 src/prefect/task_engine.py create mode 100644 src/prefect/task_server.py create mode 100644 tests/test_autonomous_tasks.py diff --git a/src/prefect/client/orchestration.py b/src/prefect/client/orchestration.py index 41c05c0b9b0f..dc4aedf67440 100644 --- a/src/prefect/client/orchestration.py +++ b/src/prefect/client/orchestration.py @@ -1961,7 +1961,7 @@ async def set_task_run_name(self, task_run_id: UUID, name: str): async def create_task_run( self, task: "TaskObject", - flow_run_id: UUID, + flow_run_id: Optional[UUID], dynamic_key: str, name: str = None, extra_tags: Iterable[str] = None, diff --git a/src/prefect/client/schemas/objects.py b/src/prefect/client/schemas/objects.py index 556961a65320..035f5b8760d1 100644 --- a/src/prefect/client/schemas/objects.py +++ b/src/prefect/client/schemas/objects.py @@ -113,6 +113,7 @@ class StateDetails(PrefectBaseModel): refresh_cache: bool = None retriable: bool = None transition_id: Optional[UUID] = None + task_parameters_id: Optional[UUID] = None class State(ObjectBaseModel, Generic[R]): diff --git a/src/prefect/client/subscriptions.py b/src/prefect/client/subscriptions.py new file mode 100644 index 000000000000..533347c1433b --- /dev/null +++ b/src/prefect/client/subscriptions.py @@ -0,0 +1,82 @@ +import asyncio +from typing import Generic, Type, TypeVar + +import orjson +import websockets +import websockets.exceptions +from starlette.status import WS_1008_POLICY_VIOLATION +from typing_extensions import Self + +from prefect._internal.schemas.bases import IDBaseModel +from prefect.settings import PREFECT_API_KEY, PREFECT_API_URL + +S = TypeVar("S", bound=IDBaseModel) + + +class Subscription(Generic[S]): + def __init__(self, model: Type[S], path: str): + self.model = model + + base_url = PREFECT_API_URL.value().replace("http", "ws", 1) + self.subscription_url = f"{base_url}{path}" + + self._connect = websockets.connect( + self.subscription_url, + subprotocols=["prefect"], + ) + self._websocket = None + + def __aiter__(self) -> Self: + return self + + async def __anext__(self) -> S: + while True: + try: + await self._ensure_connected() + message = await self._websocket.recv() + + message_data = orjson.loads(message) + + if message_data.get("type") == "ping": + await self._websocket.send(orjson.dumps({"type": "pong"}).decode()) + continue + + return self.model.parse_raw(message) + except ( + ConnectionRefusedError, + websockets.exceptions.ConnectionClosedError, + ): + self._websocket = None + if hasattr(self._connect, "protocol"): + await self._connect.__aexit__(None, None, None) + await asyncio.sleep(0.5) + + async def _ensure_connected(self): + if self._websocket: + return + + websocket = await self._connect.__aenter__() + + await websocket.send( + orjson.dumps({"type": "auth", "token": PREFECT_API_KEY.value()}).decode() + ) + + try: + auth = orjson.loads(await websocket.recv()) + assert auth["type"] == "auth_success" + except ( + AssertionError, + websockets.exceptions.ConnectionClosedError, + ) as e: + if isinstance(e, AssertionError) or e.code == WS_1008_POLICY_VIOLATION: + raise Exception( + "Unable to authenticate to the subscription. Please " + "ensure the provided `PREFECT_API_KEY` you are using is " + "valid for this environment." + ) from e + raise + else: + self._websocket = websocket + + def __repr__(self) -> str: + return f"{type(self).__name__}[{self.model.__name__}]" diff --git a/src/prefect/context.py b/src/prefect/context.py index 7bc6b757afa9..3bbfce938fdd 100644 --- a/src/prefect/context.py +++ b/src/prefect/context.py @@ -214,7 +214,7 @@ class RunContext(ContextModel): client: PrefectClient -class FlowRunContext(RunContext): +class EngineContext(RunContext): """ The context for a flow run. Data in this context is only available from within a flow run function. @@ -233,6 +233,7 @@ class FlowRunContext(RunContext): flow: Optional["Flow"] = None flow_run: Optional[FlowRun] = None + autonomous_task_run: Optional[TaskRun] = None task_runner: BaseTaskRunner log_prints: bool = False parameters: Dict[str, Any] @@ -266,6 +267,9 @@ class FlowRunContext(RunContext): __var__ = ContextVar("flow_run") +FlowRunContext = EngineContext # for backwards compatibility + + class TaskRunContext(RunContext): """ The context for a task run. Data in this context is only available from within a diff --git a/src/prefect/engine.py b/src/prefect/engine.py index 4982e99aecf5..3830d0c53d01 100644 --- a/src/prefect/engine.py +++ b/src/prefect/engine.py @@ -169,6 +169,7 @@ from prefect.results import BaseResult, ResultFactory, UnknownResult from prefect.settings import ( PREFECT_DEBUG_MODE, + PREFECT_EXPERIMENTAL_ENABLE_TASK_SCHEDULING, PREFECT_LOGGING_LOG_PRINTS, PREFECT_TASK_INTROSPECTION_WARN_THRESHOLD, PREFECT_TASKS_REFRESH_CACHE, @@ -179,6 +180,7 @@ Paused, Pending, Running, + Scheduled, State, Suspended, exception_to_crashed_state, @@ -213,6 +215,7 @@ T = TypeVar("T", bound=RunInput) EngineReturnType = Literal["future", "state", "result"] +NUM_CHARS_DYNAMIC_KEY = 8 API_HEALTHCHECKS = {} UNTRACKABLE_TYPES = {bool, type(None), type(...), type(NotImplemented)} @@ -1382,16 +1385,19 @@ def enter_task_run_engine( return_type: EngineReturnType, task_runner: Optional[BaseTaskRunner], mapped: bool, -) -> Union[PrefectFuture, Awaitable[PrefectFuture]]: - """ - Sync entrypoint for task calls - """ +) -> Union[PrefectFuture, Awaitable[PrefectFuture], TaskRun]: + """Sync entrypoint for task calls""" flow_run_context = FlowRunContext.get() + if not flow_run_context: + if PREFECT_EXPERIMENTAL_ENABLE_TASK_SCHEDULING.value(): + return _create_autonomous_task_run(task=task, parameters=parameters) + raise RuntimeError( - "Tasks cannot be run outside of a flow. To call the underlying task" - " function outside of a flow use `task.fn()`." + "Tasks cannot be run outside of a flow" + " - if you meant to submit an autonomous task, you need to set" + " `prefect config set PREFECT_EXPERIMENTAL_ENABLE_TASK_SCHEDULING=true`" ) if TaskRunContext.get(): @@ -1603,14 +1609,22 @@ async def create_task_run_future( # Generate a name for the future dynamic_key = _dynamic_key_for_task_run(flow_run_context, task) - task_run_name = f"{task.name}-{dynamic_key}" + task_run_name = ( + f"{task.name}-{dynamic_key}" + if flow_run_context and flow_run_context.flow_run + else f"{task.name}-{dynamic_key[:NUM_CHARS_DYNAMIC_KEY]}" # autonomous task run + ) # Generate a future future = PrefectFuture( name=task_run_name, key=uuid4(), task_runner=task_runner, - asynchronous=task.isasync and flow_run_context.flow.isasync, + asynchronous=( + task.isasync and flow_run_context.flow.isasync + if flow_run_context and flow_run_context.flow + else task.isasync + ), ) # Create and submit the task run in the background @@ -1650,14 +1664,18 @@ async def create_task_run_then_submit( task_runner: BaseTaskRunner, extra_task_inputs: Dict[str, Set[TaskRunInput]], ) -> None: - task_run = await create_task_run( - task=task, - name=task_run_name, - flow_run_context=flow_run_context, - parameters=parameters, - dynamic_key=task_run_dynamic_key, - wait_for=wait_for, - extra_task_inputs=extra_task_inputs, + task_run = ( + await create_task_run( + task=task, + name=task_run_name, + flow_run_context=flow_run_context, + parameters=parameters, + dynamic_key=task_run_dynamic_key, + wait_for=wait_for, + extra_task_inputs=extra_task_inputs, + ) + if not flow_run_context.autonomous_task_run + else flow_run_context.autonomous_task_run ) # Attach the task run to the future to support `get_state` operations @@ -1698,7 +1716,7 @@ async def create_task_run( task_run = await flow_run_context.client.create_task_run( task=task, name=name, - flow_run_id=flow_run_context.flow_run.id, + flow_run_id=flow_run_context.flow_run.id if flow_run_context.flow_run else None, dynamic_key=dynamic_key, state=Pending(), extra_tags=TagsContext.get().current_tags, @@ -1721,7 +1739,10 @@ async def submit_task_run( ) -> PrefectFuture: logger = get_run_logger(flow_run_context) - if task_runner.concurrency_type == TaskConcurrencyType.SEQUENTIAL: + if ( + task_runner.concurrency_type == TaskConcurrencyType.SEQUENTIAL + and not flow_run_context.autonomous_task_run + ): logger.info(f"Executing {task_run.name!r} immediately...") future = await task_runner.submit( @@ -1799,7 +1820,7 @@ async def begin_task_run( # worker, the flow run timeout will not be raised in the worker process. interruptible = maybe_flow_run_context.timeout_scope is not None else: - # Otherwise, retrieve a new client + # Otherwise, retrieve a new clien`t client = await stack.enter_async_context(get_client()) interruptible = False await stack.enter_async_context(anyio.create_task_group()) @@ -2153,7 +2174,6 @@ async def tick(): await _check_task_failure_retriable(task, task_run, terminal_state) ) state = await propose_state(client, terminal_state, task_run_id=task_run.id) - last_event = _emit_task_run_state_change_event( task_run=task_run, initial_state=last_state, @@ -2203,7 +2223,7 @@ async def tick(): level=logging.INFO if state.is_completed() else logging.ERROR, msg=f"Finished in state {display_state}", ) - + logger.warning(f"Task run {task_run.name!r} finished in state {display_state}") return state @@ -2572,7 +2592,12 @@ async def set_state_and_handle_waits(set_state_func) -> OrchestrationResult: def _dynamic_key_for_task_run(context: FlowRunContext, task: Task) -> int: - if task.task_key not in context.task_run_dynamic_keys: + if context.flow_run is None: # this is an autonomous task run + context.task_run_dynamic_keys[task.task_key] = getattr( + task, "dynamic_key", str(uuid4()) + ) + + elif task.task_key not in context.task_run_dynamic_keys: context.task_run_dynamic_keys[task.task_key] = 0 else: context.task_run_dynamic_keys[task.task_key] += 1 @@ -2912,6 +2937,34 @@ def _emit_task_run_state_change_event( ) +@sync_compatible +async def _create_autonomous_task_run( + task: Task, parameters: Dict[str, Any] +) -> TaskRun: + async with get_client() as client: + scheduled = Scheduled() + if parameters: + parameters_id = uuid4() + scheduled.state_details.task_parameters_id = parameters_id + + # TODO: We want to use result storage for parameters, but we'll need + # a better way to use it than this. + task.persist_result = True + factory = await ResultFactory.from_task(task, client=client) + await factory.store_parameters(parameters_id, parameters) + + task_run = await client.create_task_run( + task=task, + flow_run_id=None, + dynamic_key=f"{task.task_key}-{str(uuid4())[:NUM_CHARS_DYNAMIC_KEY]}", + state=scheduled, + ) + + engine_logger.debug(f"Submitted run of task {task.name!r} for execution") + + return task_run + + if __name__ == "__main__": try: flow_run_id = UUID( diff --git a/src/prefect/results.py b/src/prefect/results.py index ff772e6395d9..9b0dcfcc90ac 100644 --- a/src/prefect/results.py +++ b/src/prefect/results.py @@ -5,6 +5,7 @@ TYPE_CHECKING, Any, Callable, + Dict, Generic, Optional, Tuple, @@ -12,6 +13,7 @@ TypeVar, Union, ) +from uuid import UUID from typing_extensions import Self @@ -51,7 +53,7 @@ ResultStorage = Union[WritableFileSystem, str] ResultSerializer = Union[Serializer, str] -LITERAL_TYPES = {type(None), bool} +LITERAL_TYPES = {type(None), bool, UUID} def DEFAULT_STORAGE_KEY_FN(): @@ -383,6 +385,27 @@ async def create_result(self, obj: R) -> Union[R, "BaseResult[R]"]: cache_object=should_cache_object, ) + @sync_compatible + async def store_parameters(self, identifier: UUID, parameters: Dict[str, Any]): + assert ( + self.storage_block_id is not None + ), "Unexpected storage block ID. Was it persisted?" + data = self.serializer.dumps(parameters) + blob = PersistedResultBlob(serializer=self.serializer, data=data) + await self.storage_block.write_path( + f"parameters/{identifier}", content=blob.to_bytes() + ) + + @sync_compatible + async def read_parameters(self, identifier: UUID) -> Dict[str, Any]: + assert ( + self.storage_block_id is not None + ), "Unexpected storage block ID. Was it persisted?" + blob = PersistedResultBlob.parse_raw( + await self.storage_block.read_path(f"parameters/{identifier}") + ) + return self.serializer.loads(blob.data) + @add_type_dispatch class BaseResult(pydantic.BaseModel, abc.ABC, Generic[R]): diff --git a/src/prefect/server/api/task_runs.py b/src/prefect/server/api/task_runs.py index c78a0d221df6..fca381ad6a8b 100644 --- a/src/prefect/server/api/task_runs.py +++ b/src/prefect/server/api/task_runs.py @@ -2,28 +2,60 @@ Routes for interacting with task run objects. """ +import asyncio import datetime -from typing import List +from typing import Dict, List from uuid import UUID import pendulum -from prefect._vendor.fastapi import Body, Depends, HTTPException, Path, Response, status +from prefect._vendor.fastapi import ( + Body, + Depends, + HTTPException, + Path, + Response, + WebSocket, + status, +) import prefect.server.api.dependencies as dependencies import prefect.server.models as models import prefect.server.schemas as schemas +from prefect.logging import get_logger from prefect.server.api.run_history import run_history from prefect.server.database.dependencies import provide_database_interface from prefect.server.database.interface import PrefectDBInterface from prefect.server.orchestration import dependencies as orchestration_dependencies from prefect.server.orchestration.policies import BaseOrchestrationPolicy from prefect.server.schemas.responses import OrchestrationResult +from prefect.server.utilities import subscriptions from prefect.server.utilities.schemas import DateTimeTZ from prefect.server.utilities.server import PrefectRouter +from prefect.settings import PREFECT_EXPERIMENTAL_ENABLE_TASK_SCHEDULING + +logger = get_logger("server.api") router = PrefectRouter(prefix="/task_runs", tags=["Task Runs"]) +_scheduled_task_runs_queues: Dict[asyncio.AbstractEventLoop, asyncio.Queue] = {} +_retry_task_runs_queues: Dict[asyncio.AbstractEventLoop, asyncio.Queue] = {} + + +def scheduled_task_runs_queue() -> asyncio.Queue: + loop = asyncio.get_event_loop() + if loop not in _scheduled_task_runs_queues: + _scheduled_task_runs_queues[loop] = asyncio.Queue() + return _scheduled_task_runs_queues[loop] + + +def retry_task_runs_queue() -> asyncio.Queue: + loop = asyncio.get_event_loop() + if loop not in _retry_task_runs_queues: + _retry_task_runs_queues[loop] = asyncio.Queue() + return _retry_task_runs_queues[loop] + + @router.post("/") async def create_task_run( task_run: schemas.actions.TaskRunCreate, @@ -57,7 +89,19 @@ async def create_task_run( if model.created >= now: response.status_code = status.HTTP_201_CREATED - return model + + new_task_run: schemas.core.TaskRun = schemas.core.TaskRun.from_orm(model) + + # Place autonomously scheduled task runs onto a notification queue for the websocket + if ( + PREFECT_EXPERIMENTAL_ENABLE_TASK_SCHEDULING.value() + and new_task_run.flow_run_id is None + and new_task_run.state + and new_task_run.state.is_scheduled() + ): + await scheduled_task_runs_queue().put(new_task_run) + + return new_task_run @router.patch("/{id}", status_code=status.HTTP_204_NO_CONTENT) @@ -244,3 +288,33 @@ async def set_task_run_state( response.status_code = status.HTTP_200_OK return orchestration_result + + +@router.websocket("/subscriptions/scheduled") +async def scheduled_task_subscription(websocket: WebSocket): + websocket = await subscriptions.accept_prefect_socket(websocket) + if not websocket: + return + + scheduled_queue = scheduled_task_runs_queue() + retry_queue = retry_task_runs_queue() + + while True: + task_run: schemas.core.TaskRun = None + # First, check if there's anything in the retry queue + if not retry_queue.empty(): + task_run = await retry_queue.get() + else: + task_run = await scheduled_queue.get() + + try: + await websocket.send_json(task_run.dict(json_compatible=True)) + + await subscriptions.ping_pong(websocket) + + logger.debug(f"Sent task run {task_run.id!r} to websocket") + + except subscriptions.NORMAL_DISCONNECT_EXCEPTIONS: + # If sending fails or pong fails, put the task back into the retry queue + await retry_queue.put(task_run) + break diff --git a/src/prefect/server/schemas/states.py b/src/prefect/server/schemas/states.py index 0a62fef6e59b..a3526537ad0b 100644 --- a/src/prefect/server/schemas/states.py +++ b/src/prefect/server/schemas/states.py @@ -67,6 +67,7 @@ class StateDetails(PrefectBaseModel): refresh_cache: bool = None retriable: bool = None transition_id: Optional[UUID] = None + task_parameters_id: Optional[UUID] = None class StateBaseModel(IDBaseModel): diff --git a/src/prefect/server/utilities/server.py b/src/prefect/server/utilities/server.py index 90521222c0cf..a1457a25bf73 100644 --- a/src/prefect/server/utilities/server.py +++ b/src/prefect/server/utilities/server.py @@ -4,13 +4,14 @@ import functools import inspect from contextlib import AsyncExitStack, asynccontextmanager -from typing import Any, Callable, Coroutine, Iterable, Set, get_type_hints +from typing import Any, Callable, Coroutine, Sequence, Set, get_type_hints from prefect._vendor.fastapi import APIRouter, Request, Response, status -from prefect._vendor.fastapi.routing import APIRoute +from prefect._vendor.fastapi.routing import APIRoute, BaseRoute +from starlette.routing import Route as StarletteRoute -def method_paths_from_routes(routes: Iterable[APIRoute]) -> Set[str]: +def method_paths_from_routes(routes: Sequence[BaseRoute]) -> Set[str]: """ Generate a set of strings describing the given routes in the format: @@ -18,8 +19,9 @@ def method_paths_from_routes(routes: Iterable[APIRoute]) -> Set[str]: """ method_paths = set() for route in routes: - for method in route.methods: - method_paths.add(f"{method} {route.path}") + if isinstance(route, (APIRoute, StarletteRoute)): + for method in route.methods: + method_paths.add(f"{method} {route.path}") return method_paths diff --git a/src/prefect/server/utilities/subscriptions.py b/src/prefect/server/utilities/subscriptions.py new file mode 100644 index 000000000000..751fa43219f1 --- /dev/null +++ b/src/prefect/server/utilities/subscriptions.py @@ -0,0 +1,49 @@ +from typing import Optional + +from prefect._vendor.fastapi import ( + WebSocket, +) +from starlette.status import WS_1002_PROTOCOL_ERROR, WS_1008_POLICY_VIOLATION +from websockets.exceptions import ConnectionClosed + +NORMAL_DISCONNECT_EXCEPTIONS = (IOError, ConnectionClosed) + + +async def ping_pong(websocket: WebSocket): + try: + await websocket.send_json({"type": "ping"}) + + response = await websocket.receive_json() + if response.get("type") == "pong": + return True + else: + return False + except Exception: + return False + + +async def accept_prefect_socket(websocket: WebSocket) -> Optional[WebSocket]: + subprotocols = websocket.headers.get("Sec-WebSocket-Protocol", "").split(",") + if "prefect" not in subprotocols: + return await websocket.close(WS_1002_PROTOCOL_ERROR) + + await websocket.accept(subprotocol="prefect") + + try: + # Websocket connections are authenticated via messages. The first + # message is expected to be an auth message, and if any other type of + # message is received then the connection will be closed. + # + # There is no authentication in Prefect Server, but the protocol requires + # that we receive and return the auth message for compatibility with Prefect + # Cloud. + message = await websocket.receive_json() + if message["type"] != "auth": + return await websocket.close(WS_1008_POLICY_VIOLATION) + + await websocket.send_json({"type": "auth_success"}) + return websocket + + except NORMAL_DISCONNECT_EXCEPTIONS: + # it's fine if a client disconnects either normally or abnormally + return None diff --git a/src/prefect/settings.py b/src/prefect/settings.py index bf7396825f7e..14d236b70108 100644 --- a/src/prefect/settings.py +++ b/src/prefect/settings.py @@ -1408,6 +1408,13 @@ def default_cloud_ui_url(settings, value): """ The port the worker's webserver should bind to. """ +PREFECT_TASK_SCHEDULING_DELETE_FAILED_SUBMISSIONS = Setting( + bool, + default=True, +) +""" +Whether or not to delete failed task submissions from the database. +""" PREFECT_EXPERIMENTAL_ENABLE_EXTRA_RUNNER_ENDPOINTS = Setting(bool, default=False) """ @@ -1435,6 +1442,11 @@ def default_cloud_ui_url(settings, value): Whether or not to warn when the experimental workspace dashboard is enabled. """ +PREFECT_EXPERIMENTAL_ENABLE_TASK_SCHEDULING = Setting(bool, default=False) +""" +Whether or not to enable experimental task scheduling. +""" + # Defaults ----------------------------------------------------------------------------- PREFECT_DEFAULT_RESULT_STORAGE_BLOCK = Setting( diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py new file mode 100644 index 000000000000..7b6fdb014230 --- /dev/null +++ b/src/prefect/task_engine.py @@ -0,0 +1,70 @@ +from contextlib import AsyncExitStack +from typing import ( + Any, + Dict, + Iterable, + Optional, + Type, +) + +import anyio +from anyio import start_blocking_portal +from typing_extensions import Literal + +from prefect._internal.concurrency.api import create_call, from_async, from_sync +from prefect.client.orchestration import get_client +from prefect.client.schemas.objects import TaskRun +from prefect.context import EngineContext +from prefect.engine import ( + begin_task_map, + get_task_call_return_value, +) +from prefect.futures import PrefectFuture +from prefect.results import ResultFactory +from prefect.task_runners import BaseTaskRunner, SequentialTaskRunner +from prefect.tasks import Task +from prefect.utilities.asyncutils import sync_compatible + +EngineReturnType = Literal["future", "state", "result"] + + +@sync_compatible +async def submit_autonomous_task_to_engine( + task: Task, + task_run: TaskRun, + parameters: Optional[Dict] = None, + wait_for: Optional[Iterable[PrefectFuture]] = None, + mapped: bool = False, + return_type: EngineReturnType = "future", + task_runner: Optional[Type[BaseTaskRunner]] = None, +) -> Any: + parameters = parameters or {} + async with AsyncExitStack() as stack: + with EngineContext( + flow=None, + flow_run=None, + autonomous_task_run=task_run, + task_runner=await stack.enter_async_context( + (task_runner if task_runner else SequentialTaskRunner()).start() + ), + client=await stack.enter_async_context(get_client()), + parameters=parameters, + result_factory=await ResultFactory.from_task(task), + background_tasks=await stack.enter_async_context(anyio.create_task_group()), + sync_portal=( + stack.enter_context(start_blocking_portal()) if task.isasync else None + ), + ) as flow_run_context: + begin_run = create_call( + begin_task_map if mapped else get_task_call_return_value, + task=task, + flow_run_context=flow_run_context, + parameters=parameters, + wait_for=wait_for, + return_type=return_type, + task_runner=task_runner, + ) + if task.isasync: + return await from_async.wait_for_call_in_loop_thread(begin_run) + else: + return from_sync.wait_for_call_in_loop_thread(begin_run) diff --git a/src/prefect/task_server.py b/src/prefect/task_server.py new file mode 100644 index 000000000000..35c77fad4a91 --- /dev/null +++ b/src/prefect/task_server.py @@ -0,0 +1,186 @@ +import asyncio +import signal +import sys +from functools import partial +from typing import Iterable, Optional + +import anyio +import anyio.abc +import pendulum + +from prefect import Task, get_client +from prefect._internal.concurrency.api import create_call, from_sync +from prefect.client.schemas.objects import TaskRun +from prefect.client.subscriptions import Subscription +from prefect.logging.loggers import get_logger +from prefect.results import ResultFactory +from prefect.settings import PREFECT_TASK_SCHEDULING_DELETE_FAILED_SUBMISSIONS +from prefect.task_engine import submit_autonomous_task_to_engine +from prefect.utilities.asyncutils import asyncnullcontext, sync_compatible +from prefect.utilities.processutils import _register_signal + +logger = get_logger("task_server") + + +class TaskServer: + """This class is responsible for serving tasks that may be executed autonomously + (i.e., without a parent flow run). + + When `start()` is called, the task server will subscribe to the task run scheduling + topic and poll for scheduled task runs. When a scheduled task run is found, it + will submit the task run to the engine for execution, using `submit_autonomous_task_to_engine` + to construct a minimal `EngineContext` for the task run. + + Args: + - tasks: A list of tasks to serve. These tasks will be submitted to the engine + when a scheduled task run is found. + - tags: A list of tags to apply to the task server. Defaults to `["autonomous"]`. + """ + + def __init__( + self, + *tasks: Task, + tags: Optional[Iterable[str]] = None, + ): + self.tasks: list[Task] = tasks + self.tags: Iterable[str] = tags or ["autonomous"] + self.last_polled: Optional[pendulum.DateTime] = None + self.started = False + self.stopping = False + + self._client = get_client() + + self._runs_task_group: anyio.abc.TaskGroup = anyio.create_task_group() + self._loops_task_group: anyio.abc.TaskGroup = anyio.create_task_group() + + def handle_sigterm(self, signum, frame): + """ + Shuts down the task server when a SIGTERM is received. + """ + logger.info("SIGTERM received, initiating graceful shutdown...") + from_sync.call_in_loop_thread(create_call(self.stop)) + + sys.exit(0) + + @sync_compatible + async def start(self) -> None: + """ + Starts a task server, which runs the tasks provided in the constructor. + """ + _register_signal(signal.SIGTERM, self.handle_sigterm) + + async with asyncnullcontext() if self.started else self: + async with self._loops_task_group as tg: + tg.start_soon(self._subscribe_to_task_scheduling) + + @sync_compatible + async def stop(self): + """Stops the task server's polling cycle.""" + if not self.started: + raise RuntimeError( + "Task server has not yet started. Please start the task server by" + " calling .start()" + ) + + logger.info("Stopping task server...") + self.started = False + self.stopping = True + try: + self._loops_task_group.cancel_scope.cancel() + except Exception: + logger.exception("Exception encountered while shutting down", exc_info=True) + + async def run_once(self): + """Runs one iteration of the task server's polling cycle (used for testing)""" + async with self._runs_task_group: + await self._get_and_submit_task_runs() + + async def _subscribe_to_task_scheduling(self): + subscription = Subscription(TaskRun, "/task_runs/subscriptions/scheduled") + logger.debug(f"Created: {subscription}") + async for task_run in subscription: + logger.info(f"Received task run: {task_run.id} - {task_run.name}") + await self._submit_pending_task_run(task_run) + + async def _submit_pending_task_run(self, task_run: TaskRun): + logger.debug( + f"Found task run: {task_run.name!r} in state: {task_run.state.name!r}" + ) + + task = next((t for t in self.tasks if t.name in task_run.task_key), None) + + if not task: + if PREFECT_TASK_SCHEDULING_DELETE_FAILED_SUBMISSIONS.value(): + logger.warning( + f"Task {task_run.name!r} not found in task server registry." + ) + await self._client._client.delete(f"/task_runs/{task_run.id}") + + return + + # The ID of the parameters for this run are stored in the Scheduled state's + # state_details. If there is no parameters_id, then the task was created + # without parameters. + parameters = {} + if hasattr(task_run.state.state_details, "task_parameters_id"): + parameters_id = task_run.state.state_details.task_parameters_id + task.persist_result = True + factory = await ResultFactory.from_task(task) + try: + parameters = await factory.read_parameters(parameters_id) + except Exception as exc: + logger.exception( + f"Failed to read parameters for task run {task_run.id!r}", + exc_info=exc, + ) + if PREFECT_TASK_SCHEDULING_DELETE_FAILED_SUBMISSIONS.value(): + logger.info( + f"Deleting task run {task_run.id!r} because it failed to submit" + ) + await self._client._client.delete(f"/task_runs/{task_run.id}") + return + + logger.debug( + f"Submitting run {task_run.name!r} of task {task.name!r} to engine" + ) + + self._runs_task_group.start_soon( + partial( + submit_autonomous_task_to_engine, + task=task, + task_run=task_run, + parameters=parameters, + ) + ) + + async def __aenter__(self): + logger.debug("Starting task server...") + self._client = get_client() + await self._client.__aenter__() + await self._runs_task_group.__aenter__() + + self.started = True + return self + + async def __aexit__(self, *exc_info): + logger.debug("Stopping task server...") + self.started = False + if self._runs_task_group: + await self._runs_task_group.__aexit__(*exc_info) + if self._client: + await self._client.__aexit__(*exc_info) + + +def serve( + *tasks: Task, + tags: Optional[Iterable[str]] = None, + run_once: bool = False, +): + async def run_server(): + task_server = TaskServer(*tasks, tags=tags) + if run_once: + await task_server.run_once() + else: + await task_server.start() + + asyncio.run(run_server()) diff --git a/src/prefect/tasks.py b/src/prefect/tasks.py index 9164f92ce2f7..13f53e3630a1 100644 --- a/src/prefect/tasks.py +++ b/src/prefect/tasks.py @@ -660,18 +660,35 @@ def submit( ) -> State[T]: ... + @overload + def submit( + self: "Task[P, T]", + *args: P.args, + **kwargs: P.kwargs, + ) -> TaskRun: + ... + + @overload + def submit( + self: "Task[P, Coroutine[Any, Any, T]]", + *args: P.args, + **kwargs: P.kwargs, + ) -> Awaitable[TaskRun]: + ... + def submit( self, *args: Any, return_state: bool = False, wait_for: Optional[Iterable[PrefectFuture]] = None, **kwargs: Any, - ) -> Union[PrefectFuture, Awaitable[PrefectFuture]]: + ) -> Union[PrefectFuture, Awaitable[PrefectFuture], TaskRun, Awaitable[TaskRun]]: """ - Submit a run of the task to a worker. + Submit a run of the task to the engine. + + If writing an async task, this call must be awaited. - Must be called within a flow function. If writing an async task, this call must - be awaited. + If called from within a flow function, Will create a new task run in the backing API and submit the task to the flow's task runner. This call only blocks execution while the task is being submitted, diff --git a/tests/test_autonomous_tasks.py b/tests/test_autonomous_tasks.py new file mode 100644 index 000000000000..ea01fa6103fe --- /dev/null +++ b/tests/test_autonomous_tasks.py @@ -0,0 +1,107 @@ +import pytest + +from prefect import task +from prefect.filesystems import LocalFileSystem +from prefect.results import ResultFactory +from prefect.settings import ( + PREFECT_EXPERIMENTAL_ENABLE_TASK_SCHEDULING, + temporary_settings, +) +from prefect.utilities.asyncutils import sync_compatible + + +@sync_compatible +async def result_factory_from_task(task): + return await ResultFactory.from_task(task) + + +@pytest.fixture +def local_filesystem(): + block = LocalFileSystem(basepath="~/.prefect/storage/test") + block.save("test-fs", overwrite=True) + return block + + +@pytest.fixture(autouse=True) +def allow_experimental_task_scheduling(): + with temporary_settings( + { + PREFECT_EXPERIMENTAL_ENABLE_TASK_SCHEDULING: True, + # PREFECT_DEFAULT_RESULT_STORAGE_BLOCK: "local-filesystem/test-fs", + } + ): + yield + + +@pytest.fixture +def foo_task(): + @task + def foo(x: int) -> int: + print(x) + return x + + return foo + + +@pytest.fixture +def async_foo_task(): + @task + async def async_foo(x: int) -> int: + print(x) + return x + + return async_foo + + +@pytest.fixture +def foo_task_with_result_storage(foo_task, local_filesystem): + return foo_task.with_options(result_storage=local_filesystem) + + +@pytest.fixture +def async_foo_task_with_result_storage(async_foo_task, local_filesystem): + return async_foo_task.with_options(result_storage=local_filesystem) + + +def test_task_submission_fails_when_experimental_flag_off(foo_task): + with temporary_settings({PREFECT_EXPERIMENTAL_ENABLE_TASK_SCHEDULING: False}): + with pytest.raises(RuntimeError, match="Tasks cannot be run outside of a flow"): + foo_task.submit(42) + + +def test_task_submission_with_parameters_fails_without_result_storage(foo_task): + foo_task_without_result_storage = foo_task.with_options(result_storage=None) + task_run = foo_task_without_result_storage.submit(42) + + result_factory = result_factory_from_task(foo_task) + + with pytest.raises(AssertionError, match="Was it persisted?"): + result_factory.read_parameters(task_run.state.state_details.task_parameters_id) + + +def test_task_submission_creates_a_scheduled_task_run(foo_task_with_result_storage): + task_run = foo_task_with_result_storage.submit(42) + assert task_run.state.is_scheduled() + + result_factory = result_factory_from_task(foo_task_with_result_storage) + + parameters = result_factory.read_parameters( + task_run.state.state_details.task_parameters_id + ) + + assert parameters == dict(x=42) + + +async def test_async_task_submission_creates_a_scheduled_task_run( + async_foo_task_with_result_storage, +): + task_run = await async_foo_task_with_result_storage.submit(42) + assert task_run.state.is_scheduled() + + result_factory = await result_factory_from_task(async_foo_task_with_result_storage) + + parameters = await result_factory.read_parameters( + task_run.state.state_details.task_parameters_id + ) + + assert parameters == dict(x=42)