Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement blocking pause/resume for flows #7637

Merged
merged 19 commits into from Nov 28, 2022
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions src/prefect/__init__.py
Expand Up @@ -30,6 +30,7 @@
from prefect.manifests import Manifest
from prefect.utilities.annotations import unmapped, allow_failure
from prefect.results import BaseResult
from prefect.engine import pause, resume
from prefect.client.orion import get_client, OrionClient
from prefect.client.cloud import get_cloud_client, CloudClient

Expand Down
74 changes: 74 additions & 0 deletions src/prefect/engine.py
Expand Up @@ -41,8 +41,10 @@
from prefect.deployments import load_flow_from_flow_run
from prefect.exceptions import (
Abort,
FlowPauseTimeout,
MappingLengthMismatch,
MappingMissingIterable,
NotPausedError,
UpstreamTaskError,
)
from prefect.flows import Flow
Expand All @@ -63,6 +65,7 @@
from prefect.results import BaseResult, ResultFactory
from prefect.settings import PREFECT_DEBUG_MODE
from prefect.states import (
Paused,
Pending,
Running,
State,
Expand All @@ -84,6 +87,7 @@
run_async_from_worker_thread,
run_sync_in_interruptible_worker_thread,
run_sync_in_worker_thread,
sync_compatible,
)
from prefect.utilities.callables import parameters_to_args_kwargs
from prefect.utilities.collections import isiterable, visit_collection
Expand Down Expand Up @@ -687,6 +691,76 @@ async def orchestrate_flow_run(
return state


@sync_compatible
async def pause(timeout: int = 300, poll_interval: int = 10):
    """
    Pauses a flow run by stopping execution until resumed.

    When called within a flow run, execution will block and no downstream tasks will
    run until the flow is resumed. Task runs that have already started will continue
    running. A timeout parameter can be passed that will fail the flow run if it has not
    been resumed within the specified time.

    Args:
        timeout: the number of seconds to wait for the flow to be resumed before
            failing. Defaults to 5 minutes (300 seconds). If the pause timeout exceeds
            any configured flow-level timeout, the flow might fail even after resuming.
        poll_interval: The number of seconds between checking whether the flow has been
            resumed. Defaults to 10 seconds.

    Raises:
        RuntimeError: if called from within a task run.
        FlowPauseTimeout: if the flow run was not observed in a running state
            before `timeout` seconds elapsed.
    """

    if TaskRunContext.get():
        raise RuntimeError("Cannot pause task runs.")

    frc = FlowRunContext.get()
    logger = get_run_logger(context=frc)

    logger.info("Pausing flow, execution will continue when this flow run is resumed.")
    client = get_client()
    await client.set_flow_run_state(
        frc.flow_run.id,
        Paused(),
    )

    with anyio.move_on_after(timeout):
        # The first check happens after at most half the timeout so the run state
        # is read at least once even when `poll_interval` exceeds `timeout`, but
        # never later than one poll interval so an early resume is noticed
        # promptly instead of only after `timeout / 2` seconds.
        delay = min(timeout / 2, poll_interval)
        while True:
            await anyio.sleep(delay)
            flow_run = await client.read_flow_run(frc.flow_run.id)
            if flow_run.state.is_running():
                logger.info("Resuming flow run execution!")
                return
            delay = poll_interval

    # Only reached when `move_on_after` cancelled the polling loop above.
    raise FlowPauseTimeout("Flow run was paused and never resumed.")


@sync_compatible
async def resume(flow_run_id):
    """
    Move a paused flow run back into a running state.

    Args:
        flow_run_id: identifier of the flow run to resume

    Raises:
        NotPausedError: if the flow run's current state is not paused.
    """
    orion = get_client()
    current_state = (await orion.read_flow_run(flow_run_id)).state

    if current_state.is_paused():
        # The run is reported as a Running state carrying the name "Resuming".
        resuming_state = Running(name="Resuming")
        await orion.set_flow_run_state(flow_run_id, resuming_state)
        return

    raise NotPausedError("Cannot resume a run that isn't paused!")


def enter_task_run_engine(
task: Task,
parameters: Dict[str, Any],
Expand Down
8 changes: 8 additions & 0 deletions src/prefect/exceptions.py
Expand Up @@ -311,3 +311,11 @@ class ProtectedBlockError(PrefectException):

class InvalidRepositoryURLError(PrefectException):
    """
    Raised when a GitHub filesystem block is given an incorrect URL.
    """


class NotPausedError(PrefectException):
    """
    Raised on an attempt to resume (unpause) a run that is not paused.
    """


class FlowPauseTimeout(PrefectException):
    """
    Raised when a paused flow is not resumed before its pause times out.
    """
16 changes: 16 additions & 0 deletions src/prefect/orion/schemas/states.py
Expand Up @@ -26,6 +26,7 @@ class StateType(AutoEnum):
FAILED = AutoEnum.auto()
CANCELLED = AutoEnum.auto()
CRASHED = AutoEnum.auto()
PAUSED = AutoEnum.auto()


TERMINAL_STATES = {
Expand Down Expand Up @@ -115,6 +116,12 @@ def is_cancelled(self) -> bool:
def is_final(self) -> bool:
    """Return True when this state's type is one of `TERMINAL_STATES`."""
    state_type = self.type
    return state_type in TERMINAL_STATES

def is_paused(self) -> bool:
    """Return True when this state's type is `StateType.PAUSED`."""
    state_type = self.type
    return state_type == StateType.PAUSED

def is_resuming(self) -> bool:
    """
    Return True when this state marks a run that is resuming from a pause.

    There is no dedicated resuming member on `StateType`; a resumed run is
    reported as a running state named "Resuming", so both the type and the
    name are checked.
    """
    # NOTE(review): assumes `StateType.RUNNING` exists and that the state name is
    # stored on `self.name` -- confirm against the full enum and model definition.
    return self.type == StateType.RUNNING and self.name == "Resuming"

def copy(self, *, update: dict = None, reset_fields: bool = False, **kwargs):
"""
Copying API models should return an object that could be inserted into the
Expand Down Expand Up @@ -278,6 +285,15 @@ def Pending(cls: Type[State] = State, **kwargs) -> State:
return cls(type=StateType.PENDING, **kwargs)


def Paused(cls: Type[State] = State, **kwargs) -> State:
    """Build a state of type `PAUSED`.

    Args:
        cls: the state class to instantiate; defaults to `State`.
        **kwargs: additional keyword arguments forwarded to `cls`.

    Returns:
        State: a Paused state
    """
    paused_type = StateType.PAUSED
    return cls(type=paused_type, **kwargs)


def AwaitingRetry(
scheduled_time: datetime.datetime = None, cls: Type[State] = State, **kwargs
) -> State:
Expand Down
9 changes: 9 additions & 0 deletions src/prefect/states.py
Expand Up @@ -475,6 +475,15 @@ def Pending(cls: Type[State] = State, **kwargs) -> State:
return schemas.states.Pending(cls=cls, **kwargs)


def Paused(cls: Type[State] = State, **kwargs) -> State:
    """Build a `Paused` state by delegating to `schemas.states.Paused`.

    Args:
        cls: the state class to instantiate; defaults to `State`.
        **kwargs: additional keyword arguments forwarded to the schema-level
            constructor.

    Returns:
        State: a Paused state
    """
    build_paused = schemas.states.Paused
    return build_paused(cls=cls, **kwargs)


def AwaitingRetry(
cls: Type[State] = State, scheduled_time: datetime.datetime = None, **kwargs
) -> State:
Expand Down
128 changes: 128 additions & 0 deletions tests/test_engine.py
Expand Up @@ -21,11 +21,14 @@
link_state_to_result,
orchestrate_flow_run,
orchestrate_task_run,
pause,
resume,
retrieve_flow_then_begin_flow_run,
)
from prefect.exceptions import (
Abort,
CrashedRun,
FlowPauseTimeout,
ParameterTypeError,
SignatureMismatchError,
)
Expand Down Expand Up @@ -97,6 +100,131 @@ async def _get_flow_run_context():
return _get_flow_run_context


class TestPausingFlows:
    """Tests for `pause` and `resume` called from within flow runs."""

    async def test_tasks_cannot_be_paused(self):
        """Calling `pause` from inside a task run raises a RuntimeError."""

        @task
        def the_little_task_that_pauses():
            pause()
            return True

        @flow
        def the_mountain():
            return the_little_task_that_pauses()

        with pytest.raises(RuntimeError, match="Cannot pause task runs.*"):
            the_mountain()

    async def test_paused_flows_fail_if_not_resumed(self):
        """A pause that is never resumed surfaces as FlowPauseTimeout."""

        @task
        def doesnt_pause():
            return 42

        @flow()
        def pausing_flow():
            x = doesnt_pause.submit()
            pause(timeout=0.1)
            y = doesnt_pause.submit()
            doesnt_pause(wait_for=[x])
            doesnt_pause(wait_for=[y])
            doesnt_pause(wait_for=[x, y])

        with pytest.raises(FlowPauseTimeout):
            pausing_flow()

    async def test_paused_flows_block_execution_in_sync_flows(self, orion_client):
        """Only the tasks submitted before the pause are recorded for a sync flow."""

        @task
        def doesnt_pause():
            return 42

        @flow()
        def pausing_flow():
            x = doesnt_pause.submit()
            y = doesnt_pause.submit()
            pause(timeout=0.1)
            doesnt_pause(wait_for=[x])
            doesnt_pause(wait_for=[y])
            doesnt_pause(wait_for=[x, y])

        flow_run_state = pausing_flow(return_state=True)
        flow_run_id = flow_run_state.state_details.flow_run_id
        task_runs = await orion_client.read_task_runs(
            flow_run_filter=FlowRunFilter(id={"any_": [flow_run_id]})
        )
        assert len(task_runs) == 2, "only two tasks should have completed"

    async def test_paused_flows_block_execution_in_async_flows(self, orion_client):
        """Only the tasks submitted before the pause are recorded for an async flow."""

        @task
        async def doesnt_pause():
            return 42

        @flow()
        async def pausing_flow():
            x = await doesnt_pause.submit()
            y = await doesnt_pause.submit()
            await pause(timeout=0.1)
            await doesnt_pause(wait_for=[x])
            await doesnt_pause(wait_for=[y])
            await doesnt_pause(wait_for=[x, y])

        flow_run_state = await pausing_flow(return_state=True)
        flow_run_id = flow_run_state.state_details.flow_run_id
        task_runs = await orion_client.read_task_runs(
            flow_run_filter=FlowRunFilter(id={"any_": [flow_run_id]})
        )
        assert len(task_runs) == 2, "only two tasks should have completed"

    async def test_paused_flows_can_be_resumed(self, orion_client):
        """A flow resumed before its pause times out runs all of its tasks."""

        @task
        async def doesnt_pause():
            return 42

        @flow()
        async def pausing_flow():
            x = await doesnt_pause.submit()
            y = await doesnt_pause.submit()
            await pause(timeout=10, poll_interval=2)
            await doesnt_pause(wait_for=[x])
            await doesnt_pause(wait_for=[y])
            await doesnt_pause(wait_for=[x, y])

        async def flow_resumer():
            # Wait before resuming so the flow has had time to enter its pause.
            await anyio.sleep(3)
            flow_runs = await orion_client.read_flow_runs(limit=1)
            active_flow_run = flow_runs[0]
            await resume(active_flow_run.id)

        # Run the flow and the resumer concurrently; the resumer returns nothing.
        flow_run_state, _ = await asyncio.gather(
            pausing_flow(return_state=True),
            flow_resumer(),
        )
        flow_run_id = flow_run_state.state_details.flow_run_id
        task_runs = await orion_client.read_task_runs(
            flow_run_filter=FlowRunFilter(id={"any_": [flow_run_id]})
        )
        assert len(task_runs) == 5, "all tasks should finish running"


class TestOrchestrateTaskRun:
async def test_waits_until_scheduled_start_time(
self,
Expand Down