Skip to content

Commit

Permalink
[task scheduling] add intermediate Pending state and update policy (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz authored Feb 14, 2024
1 parent 7c7de22 commit c07bd63
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 42 deletions.
19 changes: 1 addition & 18 deletions src/prefect/server/orchestration/core_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class AutonomousTaskPolicy(BaseOrchestrationPolicy):

def priority():
return [
PreventRunningToRunningTransitions,
PreventPendingTransitions,
CacheRetrieval,
HandleTaskTerminalStateTransitions,
SecureTaskConcurrencySlots, # retrieve cached states even if slots are full
Expand Down Expand Up @@ -1004,20 +1004,3 @@ async def before_transition(
state=None,
reason="This run has already made this state transition.",
)


class PreventRunningToRunningTransitions(BaseOrchestrationRule):
"""Prevents transitions from Running to Running states."""

FROM_STATES = [StateType.RUNNING]
TO_STATES = [StateType.RUNNING]

async def before_transition(
self,
initial_state: Optional[states.State],
proposed_state: Optional[states.State],
context: OrchestrationContext,
) -> None:
await self.abort_transition(
reason="Cannot transition from Running to Running.",
)
2 changes: 2 additions & 0 deletions src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
)

import anyio
import greenback
from typing_extensions import Literal

from prefect._internal.concurrency.api import create_call, from_async, from_sync
Expand Down Expand Up @@ -68,4 +69,5 @@ async def submit_autonomous_task_to_engine(
if task.isasync:
return await from_async.wait_for_call_in_loop_thread(begin_run)
else:
await greenback.ensure_portal()
return from_sync.wait_for_call_in_loop_thread(begin_run)
19 changes: 17 additions & 2 deletions src/prefect/task_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
from prefect._internal.concurrency.api import create_call, from_sync
from prefect.client.schemas.objects import TaskRun
from prefect.client.subscriptions import Subscription
from prefect.engine import propose_state
from prefect.logging.loggers import get_logger
from prefect.results import ResultFactory
from prefect.settings import (
PREFECT_EXPERIMENTAL_ENABLE_TASK_SCHEDULING,
PREFECT_TASK_SCHEDULING_DELETE_FAILED_SUBMISSIONS,
)
from prefect.states import Pending
from prefect.task_engine import submit_autonomous_task_to_engine
from prefect.task_runners import BaseTaskRunner, ConcurrentTaskRunner
from prefect.utilities.asyncutils import asyncnullcontext, sync_compatible
Expand Down Expand Up @@ -108,9 +110,9 @@ async def _subscribe_to_task_scheduling(self):
[task.task_key for task in self.tasks],
):
logger.info(f"Received task run: {task_run.id} - {task_run.name}")
await self._submit_pending_task_run(task_run)
await self._submit_scheduled_task_run(task_run)

async def _submit_pending_task_run(self, task_run: TaskRun):
async def _submit_scheduled_task_run(self, task_run: TaskRun):
logger.debug(
f"Found task run: {task_run.name!r} in state: {task_run.state.name!r}"
)
Expand Down Expand Up @@ -152,6 +154,19 @@ async def _submit_pending_task_run(self, task_run: TaskRun):
f"Submitting run {task_run.name!r} of task {task.name!r} to engine"
)

state = await propose_state(
client=get_client(), # TODO prove that we cannot use self._client here
state=Pending(),
task_run_id=task_run.id,
)

if not state.is_pending():
logger.warning(
f"Aborted task run {task_run.id!r} -"
f" server returned a non-pending state {state.type.value!r}."
" Task run may have already begun execution."
)

self._runs_task_group.start_soon(
partial(
submit_autonomous_task_to_engine,
Expand Down
11 changes: 7 additions & 4 deletions tests/server/api/test_task_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,8 +526,11 @@ async def test_set_task_run_state_returns_404_on_missing_flow_run(
)
assert response.status_code == status.HTTP_404_NOT_FOUND

async def test_autonomous_task_run_aborts_if_transitions_to_running_twice(
self, client, session
@pytest.mark.parametrize(
"incoming_state_type", ["PENDING", "RUNNING", "CANCELLED", "CANCELLING"]
)
async def test_autonomous_task_run_aborts_if_enters_pending_from_disallowed_state(
self, client, session, incoming_state_type
):
autonomous_task_run = await models.task_runs.create_task_run(
session=session,
Expand All @@ -543,7 +546,7 @@ async def test_autonomous_task_run_aborts_if_transitions_to_running_twice(

response_1 = await client.post(
f"/task_runs/{autonomous_task_run.id}/set_state",
json=dict(state=dict(type="RUNNING")),
json=dict(state=dict(type=incoming_state_type)),
)

api_response_1 = OrchestrationResult.parse_obj(response_1.json())
Expand All @@ -552,7 +555,7 @@ async def test_autonomous_task_run_aborts_if_transitions_to_running_twice(

response_2 = await client.post(
f"/task_runs/{autonomous_task_run.id}/set_state",
json=dict(state=dict(type="RUNNING")),
json=dict(state=dict(type="PENDING")),
)

api_response_2 = OrchestrationResult.parse_obj(response_2.json())
Expand Down
18 changes: 0 additions & 18 deletions tests/server/orchestration/test_core_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
PreventDuplicateTransitions,
PreventPendingTransitions,
PreventRunningTasksFromStoppedFlows,
PreventRunningToRunningTransitions,
ReleaseTaskConcurrencySlots,
RenameReruns,
RetryFailedFlows,
Expand Down Expand Up @@ -3134,20 +3133,3 @@ async def test_same_transition_id(

# states have the same transition id so the transition should be rejected
assert ctx.response_status == SetStateStatus.REJECT


class TestPreventRunningToRunningTransitions:
async def test_prevents_running_to_running_transitions(
self,
session,
initialize_orchestration,
):
transition = (StateType.RUNNING, StateType.RUNNING)
context = await initialize_orchestration(
session, "flow", *transition, initial_details=None, proposed_details=None
)

async with PreventRunningToRunningTransitions(context, *transition) as ctx:
await ctx.validate_proposed_state()

assert ctx.response_status == SetStateStatus.ABORT

0 comments on commit c07bd63

Please sign in to comment.