Skip to content

Commit

Permalink
Merge pull request #1648 from PrefectHQ/dask-future-opt
Browse files Browse the repository at this point in the history
Optimize Dask task runner future resolution
  • Loading branch information
zanieb committed Apr 27, 2022
2 parents 727cf31 + 73f6c8d commit c0fbbab
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 21 deletions.
36 changes: 21 additions & 15 deletions src/prefect/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
run_sync_in_worker_thread,
)
from prefect.utilities.callables import parameters_to_args_kwargs
from prefect.utilities.collections import visit_collection
from prefect.utilities.collections import Quote, visit_collection

R = TypeVar("R")
engine_logger = get_logger("engine")
Expand Down Expand Up @@ -739,9 +739,9 @@ async def orchestrate_task_run(

try:
# Resolve futures in parameters into data
resolved_parameters = await resolve_upstream_task_futures(parameters)
resolved_parameters = await resolve_inputs(parameters)
# Resolve futures in any non-data dependencies to ensure they are ready
await resolve_upstream_task_futures(wait_for, return_data=False)
await resolve_inputs(wait_for, return_data=False)
except UpstreamTaskError as upstream_exc:
return await client.propose_state(
Pending(name="NotReady", message=str(upstream_exc)),
Expand Down Expand Up @@ -909,11 +909,12 @@ async def report_flow_run_crashes(flow_run: FlowRun, client: OrionClient):
raise exc from None


async def resolve_upstream_task_futures(
async def resolve_inputs(
parameters: Dict[str, Any], return_data: bool = True
) -> Dict[str, Any]:
"""
Resolve any `PrefectFuture` types nested in parameters into data.
Resolve any `Quote`, `PrefectFuture`, or `State` types nested in parameters into
data.
Returns:
A copy of the parameters with resolved data
Expand All @@ -923,20 +924,25 @@ async def resolve_upstream_task_futures(
"""

async def visit_fn(expr):
# Resolves futures into data, raising if they are not completed after `wait` is
# called.
if isinstance(expr, PrefectFuture):
state = None

if isinstance(expr, Quote):
return expr.unquote()
elif isinstance(expr, PrefectFuture):
state = await expr._wait()
if not state.is_completed():
raise UpstreamTaskError(
f"Upstream task run '{state.state_details.task_run_id}' did not reach a 'COMPLETED' state."
)
# Only load the state data if requested
if return_data:
return state.result()
elif isinstance(expr, State):
state = expr
else:
return expr

if not state.is_completed():
raise UpstreamTaskError(
f"Upstream task run '{state.state_details.task_run_id}' did not reach a 'COMPLETED' state."
)

# Only retrieve the result if requested as it may be expensive
return state.result() if return_data else None

return await visit_collection(
parameters,
visit_fn=visit_fn,
Expand Down
16 changes: 16 additions & 0 deletions src/prefect/task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
from prefect.orion.schemas.states import State
from prefect.states import exception_to_crashed_state
from prefect.utilities.asyncio import A, sync_compatible
from prefect.utilities.collections import visit_collection
from prefect.utilities.hashing import to_qualified_name
from prefect.utilities.importtools import import_object

Expand Down Expand Up @@ -361,6 +362,10 @@ async def submit(
"The task runner must be started before submitting work."
)

# Cast Prefect futures to Dask futures where possible to optimize Dask task
# scheduling
run_kwargs = await self._optimize_futures(run_kwargs)

self._dask_futures[task_run.id] = self._client.submit(
run_fn,
# Dask displays the text up to the first '-' as the name, include the
Expand All @@ -386,6 +391,17 @@ def _get_dask_future(self, prefect_future: PrefectFuture) -> "distributed.Future
"""
return self._dask_futures[prefect_future.run_id]

async def _optimize_futures(self, expr):
async def visit_fn(expr):
if isinstance(expr, PrefectFuture):
dask_future = self._dask_futures.get(expr.run_id)
if dask_future is not None:
return dask_future
# Fallback to return the expression unaltered
return expr

return await visit_collection(expr, visit_fn=visit_fn, return_data=True)

async def wait(
self,
prefect_future: PrefectFuture,
Expand Down
6 changes: 3 additions & 3 deletions src/prefect/utilities/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,18 @@ def batched_iterable(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]
yield batch


@dataclass
class Quote(Generic[T]):
"""
Simple wrapper to mark an expression as a different type so it will not be coerced
by Prefect. For example, if you want to return a state from a flow without having
the flow assume that state.
"""

expr: T
def __init__(self, data: T) -> None:
self.data = data

def unquote(self) -> T:
return self.expr
return self.data


def quote(expr: T) -> Quote[T]:
Expand Down
40 changes: 38 additions & 2 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from prefect.task_runners import SequentialTaskRunner
from prefect.testing.utilities import AsyncMock, exceptions_equal
from prefect.utilities.collections import quote


class TestOrchestrateTaskRun:
Expand Down Expand Up @@ -241,10 +242,45 @@ def my_task(x):
== f"Upstream task run '{upstream_task_run.id}' did not reach a 'COMPLETED' state."
)

async def test_quoted_parameters_are_resolved(
self, orion_client, flow_run, local_storage_block
):
# Define a mock to ensure the task was not run
mock = MagicMock()

@task
def my_task(x):
mock(x)

# Create a task run to test
task_run = await orion_client.create_task_run(
task=my_task,
flow_run_id=flow_run.id,
state=Pending(),
dynamic_key="downstream",
)

# Actually run the task
state = await orchestrate_task_run(
task=my_task,
task_run=task_run,
# Quote some data
parameters={"x": quote(1)},
wait_for=None,
result_storage=local_storage_block,
client=orion_client,
)

# The task ran with the unqoted data
mock.assert_called_once_with(1)

# Check that the state completed happily
assert state.is_completed()

@pytest.mark.parametrize(
"upstream_task_state", [Pending(), Running(), Cancelled(), Failed()]
)
async def test_states_in_parameters_can_be_incomplete(
async def test_states_in_parameters_can_be_incomplete_if_quoted(
self, orion_client, flow_run, upstream_task_state, local_storage_block
):
# Define a mock to ensure the task was not run
Expand All @@ -267,7 +303,7 @@ def my_task(x):
task=my_task,
task_run=task_run,
# Nest the future in a collection to ensure that it is found
parameters={"x": upstream_task_state},
parameters={"x": quote(upstream_task_state)},
wait_for=None,
result_storage=local_storage_block,
client=orion_client,
Expand Down
40 changes: 39 additions & 1 deletion tests/test_task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import tests
from prefect import flow, task
from prefect.orion.schemas.core import TaskRun
from prefect.orion.schemas.states import State
from prefect.orion.schemas.states import DataDocument, State, StateType
from prefect.task_runners import (
ConcurrentTaskRunner,
DaskTaskRunner,
Expand Down Expand Up @@ -445,3 +445,41 @@ def child_flow(a):
task_state, subflow_state = parent_flow().result()
assert task_state.result() == "a"
assert subflow_state.result() == "a"

@pytest.mark.service("dask")
async def test_converts_prefect_futures_to_dask_futures(self):
task_run_1 = TaskRun(flow_run_id=uuid4(), task_key="foo", dynamic_key="1")
task_run_2 = TaskRun(flow_run_id=uuid4(), task_key="foo", dynamic_key="2")

async def fake_orchestrate_task_run(example_kwarg):
return State(
type=StateType.COMPLETED,
data=DataDocument.encode("cloudpickle", example_kwarg),
)

async with DaskTaskRunner().start() as task_runner:
fut_1 = await task_runner.submit(
task_run=task_run_1,
run_fn=fake_orchestrate_task_run,
run_kwargs=dict(example_kwarg=1),
)

original_submit = task_runner._client.submit
mock = task_runner._client.submit = MagicMock(side_effect=original_submit)

fut_2 = await task_runner.submit(
task_run=task_run_2,
run_fn=fake_orchestrate_task_run,
run_kwargs=dict(example_kwarg=fut_1),
)

called_with = mock.call_args[1].get("example_kwarg")
assert isinstance(
called_with, distributed.Future
), "Prefect future converted to Dask future"
assert called_with == task_runner._get_dask_future(fut_1)

state_1 = await task_runner.wait(fut_1, 5)

state_2 = await task_runner.wait(fut_2, 5)
assert state_2.result() == state_1, "Dask converted the future to the state"

0 comments on commit c0fbbab

Please sign in to comment.