Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introducing a subscription API for autonomous task scheduling #11779

Merged
merged 31 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
083fe2d
init task engine
zzstoatzz Jan 22, 2024
bdc2a1d
Merge branch 'main' of https://github.com/PrefectHQ/prefect into init…
zzstoatzz Jan 23, 2024
768ed79
exploring task engine
zzstoatzz Jan 23, 2024
52e5bf8
change name
zzstoatzz Jan 24, 2024
69f8bf6
add _not_ task runner
zzstoatzz Jan 24, 2024
57eaefb
background
zzstoatzz Jan 24, 2024
0fc1d12
rm comment
zzstoatzz Jan 24, 2024
f2cd571
Merge branch 'main' of https://github.com/PrefectHQ/prefect into init…
zzstoatzz Jan 24, 2024
26214c0
init init task server
zzstoatzz Jan 25, 2024
e32dbae
wip
zzstoatzz Jan 25, 2024
dee5d4e
rm breakpoint
zzstoatzz Jan 25, 2024
ea6d5b3
Merge branch 'main' of https://github.com/PrefectHQ/prefect into init…
zzstoatzz Jan 25, 2024
77a55df
task parameters in storage (#11736)
chrisguidry Jan 25, 2024
5fd8ffb
run pre-commits
zzstoatzz Jan 25, 2024
9c4ef1b
track task run params in state details to avoid quantum entanglement …
abrookins Jan 26, 2024
e74e3c1
Giving TaskServer a `run_once` testing method
chrisguidry Jan 26, 2024
d854a69
update logging
zzstoatzz Jan 26, 2024
2bc96f4
Merge branch 'init-task-engine' of https://github.com/PrefectHQ/prefe…
zzstoatzz Jan 26, 2024
633f5fb
add some basic tests
zzstoatzz Jan 29, 2024
62786f3
add docstring
zzstoatzz Jan 29, 2024
1684731
add `TaskRunFilterFlowRunId` client and server filters
zzstoatzz Jan 29, 2024
596c150
Merge branch 'allow-filter-null-flow-run-ids' of https://github.com/P…
zzstoatzz Jan 29, 2024
7aae491
pull in new filter and update setting name and mask flow run log
zzstoatzz Jan 29, 2024
3303016
merge conflicts + new filter
zzstoatzz Jan 30, 2024
d678199
Introducing a subscription API for autonomous task scheduling
chrisguidry Jan 30, 2024
6a13bbe
[task scheduling] task server tweaks (#11785)
zzstoatzz Jan 31, 2024
366ef46
Make sure that the scheduled task runs queue is created on an event loop
chrisguidry Jan 31, 2024
254a89f
Task subscription api retry queue (#11789)
zzstoatzz Jan 31, 2024
4f7f581
Merge branch 'main' into task-subscription-api
chrisguidry Jan 31, 2024
6ee16c0
Using the experimental flag for the enqueuing of automated tasks
chrisguidry Jan 31, 2024
d7035c9
Removing earlier sketch
chrisguidry Jan 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/prefect/client/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1961,7 +1961,7 @@ async def set_task_run_name(self, task_run_id: UUID, name: str):
async def create_task_run(
self,
task: "TaskObject",
flow_run_id: UUID,
flow_run_id: Optional[UUID],
dynamic_key: str,
name: str = None,
extra_tags: Iterable[str] = None,
Expand Down
1 change: 1 addition & 0 deletions src/prefect/client/schemas/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class StateDetails(PrefectBaseModel):
refresh_cache: bool = None
retriable: bool = None
transition_id: Optional[UUID] = None
task_parameters_id: Optional[UUID] = None


class State(ObjectBaseModel, Generic[R]):
Expand Down
82 changes: 82 additions & 0 deletions src/prefect/client/subscriptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import asyncio
from typing import Generic, Type, TypeVar

import orjson
import websockets
import websockets.exceptions
from starlette.status import WS_1008_POLICY_VIOLATION
from typing_extensions import Self

from prefect._internal.schemas.bases import IDBaseModel
from prefect.settings import PREFECT_API_KEY, PREFECT_API_URL

S = TypeVar("S", bound=IDBaseModel)


class Subscription(Generic[S]):
    """
    Async iterator over objects delivered by the Prefect API on a websocket.

    The subscription connects lazily on first iteration, authenticates with
    the configured `PREFECT_API_KEY`, answers server `ping` messages with
    `pong`, and parses every other message into an instance of `model`.
    If the connection is refused or closes with an error, it pauses briefly
    and reconnects.

    Args:
        model: The model class each received message is parsed into.
        path: The path appended to the websocket form of `PREFECT_API_URL`.
    """

    def __init__(self, model: Type[S], path: str):
        self.model = model

        # Only the first occurrence of "http" is replaced, so the scheme is
        # converted ("http" -> "ws", "https" -> "wss") and the rest of the URL
        # is left untouched.
        base_url = PREFECT_API_URL.value().replace("http", "ws", 1)
        self.subscription_url = f"{base_url}{path}"

        # The connector is created up front but not entered until the first
        # call to `_ensure_connected`.
        self._connect = websockets.connect(
            self.subscription_url,
            subprotocols=["prefect"],
        )
        self._websocket = None

    def __aiter__(self) -> Self:
        return self

    async def __anext__(self) -> S:
        """
        Wait for and return the next object from the subscription.

        Connection-refused and connection-closed errors are handled by
        discarding the current websocket, sleeping for half a second, and
        trying again; this method loops until a message is received.
        """
        while True:
            try:
                await self._ensure_connected()
                message = await self._websocket.recv()

                message_data = orjson.loads(message)

                # Keep-alive from the server: reply and wait for a real message
                if message_data.get("type") == "ping":
                    await self._websocket.send(orjson.dumps({"type": "pong"}).decode())
                    continue

                return self.model.parse_raw(message)
            except (
                ConnectionRefusedError,
                websockets.exceptions.ConnectionClosedError,
            ):
                self._websocket = None
                # NOTE(review): `protocol` is presumably only present on the
                # connector once a connection has been opened -- confirm
                # against the installed `websockets` version.
                if hasattr(self._connect, "protocol"):
                    await self._connect.__aexit__(None, None, None)
                await asyncio.sleep(0.5)

    async def _ensure_connected(self):
        """
        Open and authenticate the websocket if one is not already active.

        Raises:
            Exception: If the server does not acknowledge the auth message
                with `auth_success`, or closes the connection with a policy
                violation code while authenticating.
            websockets.exceptions.ConnectionClosedError: If the connection
                closes for any other reason while authenticating.
        """
        if self._websocket:
            return

        websocket = await self._connect.__aenter__()

        await websocket.send(
            orjson.dumps({"type": "auth", "token": PREFECT_API_KEY.value()}).decode()
        )

        auth_failure_message = (
            "Unable to authenticate to the subscription. Please "
            "ensure the provided `PREFECT_API_KEY` you are using is "
            "valid for this environment."
        )

        try:
            auth = orjson.loads(await websocket.recv())
        except websockets.exceptions.ConnectionClosedError as e:
            if e.code == WS_1008_POLICY_VIOLATION:
                raise Exception(auth_failure_message) from e
            raise

        # Checked explicitly rather than with `assert`, which is removed when
        # Python runs with `-O` and would let a rejected auth pass unnoticed.
        if auth["type"] != "auth_success":
            raise Exception(auth_failure_message)

        self._websocket = websocket

    def __repr__(self) -> str:
        return f"{type(self).__name__}[{self.model.__name__}]"
6 changes: 5 additions & 1 deletion src/prefect/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ class RunContext(ContextModel):
client: PrefectClient


class FlowRunContext(RunContext):
class EngineContext(RunContext):
"""
The context for a flow run. Data in this context is only available from within a
flow run function.
Expand All @@ -233,6 +233,7 @@ class FlowRunContext(RunContext):

flow: Optional["Flow"] = None
flow_run: Optional[FlowRun] = None
autonomous_task_run: Optional[TaskRun] = None
task_runner: BaseTaskRunner
log_prints: bool = False
parameters: Dict[str, Any]
Expand Down Expand Up @@ -266,6 +267,9 @@ class FlowRunContext(RunContext):
__var__ = ContextVar("flow_run")


FlowRunContext = EngineContext # for backwards compatibility


class TaskRunContext(RunContext):
"""
The context for a task run. Data in this context is only available from within a
Expand Down
97 changes: 75 additions & 22 deletions src/prefect/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@
from prefect.results import BaseResult, ResultFactory, UnknownResult
from prefect.settings import (
PREFECT_DEBUG_MODE,
PREFECT_EXPERIMENTAL_ENABLE_TASK_SCHEDULING,
PREFECT_LOGGING_LOG_PRINTS,
PREFECT_TASK_INTROSPECTION_WARN_THRESHOLD,
PREFECT_TASKS_REFRESH_CACHE,
Expand All @@ -179,6 +180,7 @@
Paused,
Pending,
Running,
Scheduled,
State,
Suspended,
exception_to_crashed_state,
Expand Down Expand Up @@ -213,6 +215,7 @@
T = TypeVar("T", bound=RunInput)
EngineReturnType = Literal["future", "state", "result"]

NUM_CHARS_DYNAMIC_KEY = 8

API_HEALTHCHECKS = {}
UNTRACKABLE_TYPES = {bool, type(None), type(...), type(NotImplemented)}
Expand Down Expand Up @@ -1382,16 +1385,19 @@ def enter_task_run_engine(
return_type: EngineReturnType,
task_runner: Optional[BaseTaskRunner],
mapped: bool,
) -> Union[PrefectFuture, Awaitable[PrefectFuture]]:
"""
Sync entrypoint for task calls
"""
) -> Union[PrefectFuture, Awaitable[PrefectFuture], TaskRun]:
"""Sync entrypoint for task calls"""

flow_run_context = FlowRunContext.get()

if not flow_run_context:
if PREFECT_EXPERIMENTAL_ENABLE_TASK_SCHEDULING.value():
return _create_autonomous_task_run(task=task, parameters=parameters)

raise RuntimeError(
"Tasks cannot be run outside of a flow. To call the underlying task"
" function outside of a flow use `task.fn()`."
"Tasks cannot be run outside of a flow"
" - if you meant to submit an autonomous task, you need to set"
" `prefect config set PREFECT_EXPERIMENTAL_ENABLE_TASK_SCHEDULING=true`"
)

if TaskRunContext.get():
Expand Down Expand Up @@ -1603,14 +1609,22 @@ async def create_task_run_future(

# Generate a name for the future
dynamic_key = _dynamic_key_for_task_run(flow_run_context, task)
task_run_name = f"{task.name}-{dynamic_key}"
task_run_name = (
f"{task.name}-{dynamic_key}"
if flow_run_context and flow_run_context.flow_run
else f"{task.name}-{dynamic_key[:NUM_CHARS_DYNAMIC_KEY]}" # autonomous task run
)

# Generate a future
future = PrefectFuture(
name=task_run_name,
key=uuid4(),
task_runner=task_runner,
asynchronous=task.isasync and flow_run_context.flow.isasync,
asynchronous=(
task.isasync and flow_run_context.flow.isasync
if flow_run_context and flow_run_context.flow
else task.isasync
),
)

# Create and submit the task run in the background
Expand Down Expand Up @@ -1650,14 +1664,18 @@ async def create_task_run_then_submit(
task_runner: BaseTaskRunner,
extra_task_inputs: Dict[str, Set[TaskRunInput]],
) -> None:
task_run = await create_task_run(
task=task,
name=task_run_name,
flow_run_context=flow_run_context,
parameters=parameters,
dynamic_key=task_run_dynamic_key,
wait_for=wait_for,
extra_task_inputs=extra_task_inputs,
task_run = (
await create_task_run(
task=task,
name=task_run_name,
flow_run_context=flow_run_context,
parameters=parameters,
dynamic_key=task_run_dynamic_key,
wait_for=wait_for,
extra_task_inputs=extra_task_inputs,
)
if not flow_run_context.autonomous_task_run
else flow_run_context.autonomous_task_run
)

# Attach the task run to the future to support `get_state` operations
Expand Down Expand Up @@ -1698,7 +1716,7 @@ async def create_task_run(
task_run = await flow_run_context.client.create_task_run(
task=task,
name=name,
flow_run_id=flow_run_context.flow_run.id,
flow_run_id=flow_run_context.flow_run.id if flow_run_context.flow_run else None,
dynamic_key=dynamic_key,
state=Pending(),
extra_tags=TagsContext.get().current_tags,
Expand All @@ -1721,7 +1739,10 @@ async def submit_task_run(
) -> PrefectFuture:
logger = get_run_logger(flow_run_context)

if task_runner.concurrency_type == TaskConcurrencyType.SEQUENTIAL:
if (
task_runner.concurrency_type == TaskConcurrencyType.SEQUENTIAL
and not flow_run_context.autonomous_task_run
):
logger.info(f"Executing {task_run.name!r} immediately...")

future = await task_runner.submit(
Expand Down Expand Up @@ -1799,7 +1820,7 @@ async def begin_task_run(
# worker, the flow run timeout will not be raised in the worker process.
interruptible = maybe_flow_run_context.timeout_scope is not None
else:
# Otherwise, retrieve a new client
# Otherwise, retrieve a new client
client = await stack.enter_async_context(get_client())
interruptible = False
await stack.enter_async_context(anyio.create_task_group())
Expand Down Expand Up @@ -2153,7 +2174,6 @@ async def tick():
await _check_task_failure_retriable(task, task_run, terminal_state)
)
state = await propose_state(client, terminal_state, task_run_id=task_run.id)

last_event = _emit_task_run_state_change_event(
task_run=task_run,
initial_state=last_state,
Expand Down Expand Up @@ -2203,7 +2223,7 @@ async def tick():
level=logging.INFO if state.is_completed() else logging.ERROR,
msg=f"Finished in state {display_state}",
)

logger.warning(f"Task run {task_run.name!r} finished in state {display_state}")
return state


Expand Down Expand Up @@ -2572,7 +2592,12 @@ async def set_state_and_handle_waits(set_state_func) -> OrchestrationResult:


def _dynamic_key_for_task_run(context: FlowRunContext, task: Task) -> int:
if task.task_key not in context.task_run_dynamic_keys:
if context.flow_run is None: # this is an autonomous task run
context.task_run_dynamic_keys[task.task_key] = getattr(
task, "dynamic_key", str(uuid4())
)

elif task.task_key not in context.task_run_dynamic_keys:
context.task_run_dynamic_keys[task.task_key] = 0
else:
context.task_run_dynamic_keys[task.task_key] += 1
Expand Down Expand Up @@ -2912,6 +2937,34 @@ def _emit_task_run_state_change_event(
)


@sync_compatible
async def _create_autonomous_task_run(
    task: Task, parameters: Dict[str, Any]
) -> TaskRun:
    """
    Create a task run in a `Scheduled` state that is not attached to a flow run.

    When `parameters` is non-empty, they are written to the task's result
    storage under a freshly generated ID, and that ID is recorded on the
    scheduled state's details as `task_parameters_id`.

    Args:
        task: The task to create a run for.
        parameters: The parameters the task was called with.

    Returns:
        The task run created through the client.
    """
    async with get_client() as client:
        initial_state = Scheduled()

        if parameters:
            storage_id = uuid4()
            initial_state.state_details.task_parameters_id = storage_id

            # TODO: We want to use result storage for parameters, but we'll need
            # a better way to use it than this.
            # Note that this mutates the caller's task object.
            task.persist_result = True
            result_factory = await ResultFactory.from_task(task, client=client)
            await result_factory.store_parameters(storage_id, parameters)

        # The key is the task key plus a short random suffix
        run_key = f"{task.task_key}-{str(uuid4())[:NUM_CHARS_DYNAMIC_KEY]}"
        task_run = await client.create_task_run(
            task=task,
            flow_run_id=None,
            dynamic_key=run_key,
            state=initial_state,
        )

        engine_logger.debug(f"Submitted run of task {task.name!r} for execution")

    return task_run


if __name__ == "__main__":
try:
flow_run_id = UUID(
Expand Down
25 changes: 24 additions & 1 deletion src/prefect/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
TYPE_CHECKING,
Any,
Callable,
Dict,
Generic,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from uuid import UUID

from typing_extensions import Self

Expand Down Expand Up @@ -51,7 +53,7 @@

ResultStorage = Union[WritableFileSystem, str]
ResultSerializer = Union[Serializer, str]
LITERAL_TYPES = {type(None), bool}
LITERAL_TYPES = {type(None), bool, UUID}


def DEFAULT_STORAGE_KEY_FN():
Expand Down Expand Up @@ -383,6 +385,27 @@ async def create_result(self, obj: R) -> Union[R, "BaseResult[R]"]:
cache_object=should_cache_object,
)

@sync_compatible
async def store_parameters(self, identifier: UUID, parameters: Dict[str, Any]):
    """
    Serialize `parameters` and write them to the storage block.

    The payload is written to the path `parameters/<identifier>` as a
    `PersistedResultBlob` built with this factory's serializer.

    Args:
        identifier: ID used to build the storage path.
        parameters: The parameters to serialize and store.
    """
    assert self.storage_block_id is not None, (
        "Unexpected storage block ID. Was it persisted?"
    )
    serialized = self.serializer.dumps(parameters)
    envelope = PersistedResultBlob(serializer=self.serializer, data=serialized)
    await self.storage_block.write_path(
        f"parameters/{identifier}",
        content=envelope.to_bytes(),
    )

@sync_compatible
async def read_parameters(self, identifier: UUID) -> Dict[str, Any]:
    """
    Read and deserialize the parameters stored under `identifier`.

    The blob is read from the path `parameters/<identifier>` on the storage
    block and its data is decoded with this factory's serializer.

    Args:
        identifier: ID used to build the storage path.

    Returns:
        The deserialized parameters.
    """
    assert self.storage_block_id is not None, (
        "Unexpected storage block ID. Was it persisted?"
    )
    envelope = PersistedResultBlob.parse_raw(
        await self.storage_block.read_path(f"parameters/{identifier}")
    )
    parameters = self.serializer.loads(envelope.data)
    return parameters


@add_type_dispatch
class BaseResult(pydantic.BaseModel, abc.ABC, Generic[R]):
Expand Down
Loading
Loading