Skip to content

Commit

Permalink
Merge branch 'main' into issue-11792
Browse files Browse the repository at this point in the history
  • Loading branch information
bunchesofdonald committed Jan 31, 2024
2 parents 02a3230 + 1bb9c73 commit a569b6a
Show file tree
Hide file tree
Showing 20 changed files with 760 additions and 60 deletions.
2 changes: 1 addition & 1 deletion src/prefect/client/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1961,7 +1961,7 @@ async def set_task_run_name(self, task_run_id: UUID, name: str):
async def create_task_run(
self,
task: "TaskObject",
flow_run_id: UUID,
flow_run_id: Optional[UUID],
dynamic_key: str,
name: str = None,
extra_tags: Iterable[str] = None,
Expand Down
1 change: 1 addition & 0 deletions src/prefect/client/schemas/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class StateDetails(PrefectBaseModel):
refresh_cache: bool = None
retriable: bool = None
transition_id: Optional[UUID] = None
task_parameters_id: Optional[UUID] = None


class State(ObjectBaseModel, Generic[R]):
Expand Down
82 changes: 82 additions & 0 deletions src/prefect/client/subscriptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import asyncio
from typing import Generic, Type, TypeVar

import orjson
import websockets
import websockets.exceptions
from starlette.status import WS_1008_POLICY_VIOLATION
from typing_extensions import Self

from prefect._internal.schemas.bases import IDBaseModel
from prefect.settings import PREFECT_API_KEY, PREFECT_API_URL

S = TypeVar("S", bound=IDBaseModel)


class Subscription(Generic[S]):
def __init__(self, model: Type[S], path: str):
self.model = model

base_url = PREFECT_API_URL.value().replace("http", "ws", 1)
self.subscription_url = f"{base_url}{path}"

self._connect = websockets.connect(
self.subscription_url,
subprotocols=["prefect"],
)
self._websocket = None

def __aiter__(self) -> Self:
return self

async def __anext__(self) -> S:
while True:
try:
await self._ensure_connected()
message = await self._websocket.recv()

message_data = orjson.loads(message)

if message_data.get("type") == "ping":
await self._websocket.send(orjson.dumps({"type": "pong"}).decode())
continue

return self.model.parse_raw(message)
except (
ConnectionRefusedError,
websockets.exceptions.ConnectionClosedError,
):
self._websocket = None
if hasattr(self._connect, "protocol"):
await self._connect.__aexit__(None, None, None)
await asyncio.sleep(0.5)

async def _ensure_connected(self):
if self._websocket:
return

websocket = await self._connect.__aenter__()

await websocket.send(
orjson.dumps({"type": "auth", "token": PREFECT_API_KEY.value()}).decode()
)

try:
auth = orjson.loads(await websocket.recv())
assert auth["type"] == "auth_success"
except (
AssertionError,
websockets.exceptions.ConnectionClosedError,
) as e:
if isinstance(e, AssertionError) or e.code == WS_1008_POLICY_VIOLATION:
raise Exception(
"Unable to authenticate to the subscription. Please "
"ensure the provided `PREFECT_API_KEY` you are using is "
"valid for this environment."
) from e
raise
else:
self._websocket = websocket

def __repr__(self) -> str:
return f"{type(self).__name__}[{self.model.__name__}]"
6 changes: 5 additions & 1 deletion src/prefect/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ class RunContext(ContextModel):
client: PrefectClient


class FlowRunContext(RunContext):
class EngineContext(RunContext):
"""
The context for a flow run. Data in this context is only available from within a
flow run function.
Expand All @@ -233,6 +233,7 @@ class FlowRunContext(RunContext):

flow: Optional["Flow"] = None
flow_run: Optional[FlowRun] = None
autonomous_task_run: Optional[TaskRun] = None
task_runner: BaseTaskRunner
log_prints: bool = False
parameters: Dict[str, Any]
Expand Down Expand Up @@ -266,6 +267,9 @@ class FlowRunContext(RunContext):
__var__ = ContextVar("flow_run")


FlowRunContext = EngineContext # for backwards compatibility


class TaskRunContext(RunContext):
"""
The context for a task run. Data in this context is only available from within a
Expand Down
97 changes: 75 additions & 22 deletions src/prefect/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@
from prefect.results import BaseResult, ResultFactory, UnknownResult
from prefect.settings import (
PREFECT_DEBUG_MODE,
PREFECT_EXPERIMENTAL_ENABLE_TASK_SCHEDULING,
PREFECT_LOGGING_LOG_PRINTS,
PREFECT_TASK_INTROSPECTION_WARN_THRESHOLD,
PREFECT_TASKS_REFRESH_CACHE,
Expand All @@ -179,6 +180,7 @@
Paused,
Pending,
Running,
Scheduled,
State,
Suspended,
exception_to_crashed_state,
Expand Down Expand Up @@ -213,6 +215,7 @@
T = TypeVar("T")
EngineReturnType = Literal["future", "state", "result"]

NUM_CHARS_DYNAMIC_KEY = 8

API_HEALTHCHECKS = {}
UNTRACKABLE_TYPES = {bool, type(None), type(...), type(NotImplemented)}
Expand Down Expand Up @@ -1359,16 +1362,19 @@ def enter_task_run_engine(
return_type: EngineReturnType,
task_runner: Optional[BaseTaskRunner],
mapped: bool,
) -> Union[PrefectFuture, Awaitable[PrefectFuture]]:
"""
Sync entrypoint for task calls
"""
) -> Union[PrefectFuture, Awaitable[PrefectFuture], TaskRun]:
"""Sync entrypoint for task calls"""

flow_run_context = FlowRunContext.get()

if not flow_run_context:
if PREFECT_EXPERIMENTAL_ENABLE_TASK_SCHEDULING.value():
return _create_autonomous_task_run(task=task, parameters=parameters)

raise RuntimeError(
"Tasks cannot be run outside of a flow. To call the underlying task"
" function outside of a flow use `task.fn()`."
"Tasks cannot be run outside of a flow"
" - if you meant to submit an autonomous task, you need to set"
" `prefect config set PREFECT_EXPERIMENTAL_ENABLE_TASK_SCHEDULING=true`"
)

if TaskRunContext.get():
Expand Down Expand Up @@ -1580,14 +1586,22 @@ async def create_task_run_future(

# Generate a name for the future
dynamic_key = _dynamic_key_for_task_run(flow_run_context, task)
task_run_name = f"{task.name}-{dynamic_key}"
task_run_name = (
f"{task.name}-{dynamic_key}"
if flow_run_context and flow_run_context.flow_run
else f"{task.name}-{dynamic_key[:NUM_CHARS_DYNAMIC_KEY]}" # autonomous task run
)

# Generate a future
future = PrefectFuture(
name=task_run_name,
key=uuid4(),
task_runner=task_runner,
asynchronous=task.isasync and flow_run_context.flow.isasync,
asynchronous=(
task.isasync and flow_run_context.flow.isasync
if flow_run_context and flow_run_context.flow
else task.isasync
),
)

# Create and submit the task run in the background
Expand Down Expand Up @@ -1627,14 +1641,18 @@ async def create_task_run_then_submit(
task_runner: BaseTaskRunner,
extra_task_inputs: Dict[str, Set[TaskRunInput]],
) -> None:
task_run = await create_task_run(
task=task,
name=task_run_name,
flow_run_context=flow_run_context,
parameters=parameters,
dynamic_key=task_run_dynamic_key,
wait_for=wait_for,
extra_task_inputs=extra_task_inputs,
task_run = (
await create_task_run(
task=task,
name=task_run_name,
flow_run_context=flow_run_context,
parameters=parameters,
dynamic_key=task_run_dynamic_key,
wait_for=wait_for,
extra_task_inputs=extra_task_inputs,
)
if not flow_run_context.autonomous_task_run
else flow_run_context.autonomous_task_run
)

# Attach the task run to the future to support `get_state` operations
Expand Down Expand Up @@ -1675,7 +1693,7 @@ async def create_task_run(
task_run = await flow_run_context.client.create_task_run(
task=task,
name=name,
flow_run_id=flow_run_context.flow_run.id,
flow_run_id=flow_run_context.flow_run.id if flow_run_context.flow_run else None,
dynamic_key=dynamic_key,
state=Pending(),
extra_tags=TagsContext.get().current_tags,
Expand All @@ -1698,7 +1716,10 @@ async def submit_task_run(
) -> PrefectFuture:
logger = get_run_logger(flow_run_context)

if task_runner.concurrency_type == TaskConcurrencyType.SEQUENTIAL:
if (
task_runner.concurrency_type == TaskConcurrencyType.SEQUENTIAL
and not flow_run_context.autonomous_task_run
):
logger.info(f"Executing {task_run.name!r} immediately...")

future = await task_runner.submit(
Expand Down Expand Up @@ -1776,7 +1797,7 @@ async def begin_task_run(
# worker, the flow run timeout will not be raised in the worker process.
interruptible = maybe_flow_run_context.timeout_scope is not None
else:
# Otherwise, retrieve a new client
# Otherwise, retrieve a new clien`t
client = await stack.enter_async_context(get_client())
interruptible = False
await stack.enter_async_context(anyio.create_task_group())
Expand Down Expand Up @@ -2130,7 +2151,6 @@ async def tick():
await _check_task_failure_retriable(task, task_run, terminal_state)
)
state = await propose_state(client, terminal_state, task_run_id=task_run.id)

last_event = _emit_task_run_state_change_event(
task_run=task_run,
initial_state=last_state,
Expand Down Expand Up @@ -2180,7 +2200,7 @@ async def tick():
level=logging.INFO if state.is_completed() else logging.ERROR,
msg=f"Finished in state {display_state}",
)

logger.warning(f"Task run {task_run.name!r} finished in state {display_state}")
return state


Expand Down Expand Up @@ -2549,7 +2569,12 @@ async def set_state_and_handle_waits(set_state_func) -> OrchestrationResult:


def _dynamic_key_for_task_run(context: FlowRunContext, task: Task) -> int:
if task.task_key not in context.task_run_dynamic_keys:
if context.flow_run is None: # this is an autonomous task run
context.task_run_dynamic_keys[task.task_key] = getattr(
task, "dynamic_key", str(uuid4())
)

elif task.task_key not in context.task_run_dynamic_keys:
context.task_run_dynamic_keys[task.task_key] = 0
else:
context.task_run_dynamic_keys[task.task_key] += 1
Expand Down Expand Up @@ -2889,6 +2914,34 @@ def _emit_task_run_state_change_event(
)


@sync_compatible
async def _create_autonomous_task_run(
task: Task, parameters: Dict[str, Any]
) -> TaskRun:
async with get_client() as client:
scheduled = Scheduled()
if parameters:
parameters_id = uuid4()
scheduled.state_details.task_parameters_id = parameters_id

# TODO: We want to use result storage for parameters, but we'll need
# a better way to use it than this.
task.persist_result = True
factory = await ResultFactory.from_task(task, client=client)
await factory.store_parameters(parameters_id, parameters)

task_run = await client.create_task_run(
task=task,
flow_run_id=None,
dynamic_key=f"{task.task_key}-{str(uuid4())[:NUM_CHARS_DYNAMIC_KEY]}",
state=scheduled,
)

engine_logger.debug(f"Submitted run of task {task.name!r} for execution")

return task_run


if __name__ == "__main__":
try:
flow_run_id = UUID(
Expand Down
25 changes: 24 additions & 1 deletion src/prefect/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
TYPE_CHECKING,
Any,
Callable,
Dict,
Generic,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from uuid import UUID

from typing_extensions import Self

Expand Down Expand Up @@ -51,7 +53,7 @@

ResultStorage = Union[WritableFileSystem, str]
ResultSerializer = Union[Serializer, str]
LITERAL_TYPES = {type(None), bool}
LITERAL_TYPES = {type(None), bool, UUID}


def DEFAULT_STORAGE_KEY_FN():
Expand Down Expand Up @@ -383,6 +385,27 @@ async def create_result(self, obj: R) -> Union[R, "BaseResult[R]"]:
cache_object=should_cache_object,
)

@sync_compatible
async def store_parameters(self, identifier: UUID, parameters: Dict[str, Any]):
assert (
self.storage_block_id is not None
), "Unexpected storage block ID. Was it persisted?"
data = self.serializer.dumps(parameters)
blob = PersistedResultBlob(serializer=self.serializer, data=data)
await self.storage_block.write_path(
f"parameters/{identifier}", content=blob.to_bytes()
)

@sync_compatible
async def read_parameters(self, identifier: UUID) -> Dict[str, Any]:
assert (
self.storage_block_id is not None
), "Unexpected storage block ID. Was it persisted?"
blob = PersistedResultBlob.parse_raw(
await self.storage_block.read_path(f"parameters/{identifier}")
)
return self.serializer.loads(blob.data)


@add_type_dispatch
class BaseResult(pydantic.BaseModel, abc.ABC, Generic[R]):
Expand Down

0 comments on commit a569b6a

Please sign in to comment.