Skip to content
Permalink
Browse files

Merge pull request #1356 from PrefectHQ/looping-take-2

TASK LOOPING (in both Core and Cloud)
  • Loading branch information...
cicdw committed Aug 13, 2019
2 parents 714a861 + d997f21 commit 3d2a26f1f147ce8a613493dd60de4b4e451300d3
@@ -7,6 +7,7 @@ These changes are available in the [master branch](https://github.com/PrefectHQ/
### Features

- Added Local, Kubernetes, and Nomad agents - [#1341](https://github.com/PrefectHQ/prefect/pull/1341)
- Add the ability for Tasks to sequentially loop - [#1356](https://github.com/PrefectHQ/prefect/pull/1356)

### Enhancements

@@ -5,7 +5,7 @@ Author: Jeremiah Lowin

# Status

Proposed
Accepted

# Context

@@ -95,7 +95,7 @@ def call_runner_target_handlers(self, old_state: State, new_state: State) -> Sta
cache_for=self.task.cache_for,
)
except Exception as exc:
self.logger.debug(
self.logger.error(
"Failed to set task state with error: {}".format(repr(exc))
)
raise ENDRUN(state=ClientFailed(state=new_state))
@@ -56,6 +56,27 @@ class FAIL(PrefectStateSignal):
_state_cls = state.Failed


class LOOP(PrefectStateSignal):
"""
Indicates that a task should loop.
Args:
- message (Any, optional): Defaults to `None`. A message about the signal.
- *args (Any, optional): additional arguments to pass to this Signal's
associated state constructor
- **kwargs (Any, optional): additional keyword arguments to pass to this Signal's
associated state constructor
"""

_state_cls = state.Looped

def __init__(self, message: str = None, *args, **kwargs): # type: ignore
kwargs.setdefault(
"result", repr(self)
) # looped results are always result handled
super().__init__(message, *args, **kwargs) # type: ignore


class TRIGGERFAIL(FAIL):
"""
Indicates that a task trigger failed.
@@ -125,6 +125,15 @@ def is_finished(self) -> bool:
"""
return isinstance(self, Finished)

def is_looped(self) -> bool:
"""
Checks if the state is currently in a looped state
Returns:
- bool: `True` if the state is looped, `False` otherwise
"""
return isinstance(self, Looped)

def is_scheduled(self) -> bool:
"""
Checks if the state is currently in a scheduled state, which includes retrying.
@@ -467,6 +476,32 @@ class Finished(State):
color = "#003ccb"


class Looped(Finished):
"""
Finished state indicating one successful run of a looped task - if a Task is in this state, it will
run the next iteration of the loop immediately after.
Args:
- message (str or Exception, optional): Defaults to `None`. A message about the
state, which could be an `Exception` (or [`Signal`](signals.html)) that caused it.
- result (Any, optional): Defaults to `None`. A data payload for the state.
- loop_count (int): The iteration number of the looping task.
Defaults to the value stored in context under "task_loop_count" or 1,
if that value isn't found.
"""

color = "#003ccb"

def __init__(
self, message: str = None, result: Any = NoResult, loop_count: int = None
):
super().__init__(result=result, message=message)
if loop_count is None:
loop_count = prefect.context.get("task_loop_count", 1)
assert loop_count is not None # mypy assert
self.loop_count = loop_count # type: int


class Success(Finished):
"""
Finished state indicating success.
@@ -25,10 +25,12 @@
from prefect.core import Edge, Task
from prefect.engine import signals
from prefect.engine.result import NoResult, Result
from prefect.engine.result_handlers import JSONResultHandler
from prefect.engine.runner import ENDRUN, Runner, call_state_handlers
from prefect.engine.state import (
Cached,
Failed,
Looped,
Mapped,
Paused,
Pending,
@@ -147,6 +149,22 @@ def initialize_run( # type: ignore
if isinstance(state, Resume):
context.update(resume=True)

if hasattr(state, "cached_inputs"):
if "_loop_count" in (state.cached_inputs or {}): # type: ignore
loop_context = {
"task_loop_count": state.cached_inputs.pop( # type: ignore
"_loop_count"
) # type: ignore
.to_result()
.value,
"task_loop_result": state.cached_inputs.pop( # type: ignore
"_loop_result"
) # type: ignore
.to_result()
.value,
}
context.update(loop_context)

context.update(
task_run_count=run_count, task_name=self.task.name, task_tags=self.task.tags
)
@@ -198,9 +216,12 @@ def run(
mapped = any([e.mapped for e in upstream_states]) and map_index is None
task_inputs = {} # type: Dict[str, Any]

self.logger.info(
"Task '{name}': Starting task run...".format(name=context["task_full_name"])
)
if context.get("task_loop_count") is None:
self.logger.info(
"Task '{name}': Starting task run...".format(
name=context["task_full_name"]
)
)

try:
# initialize the run
@@ -270,6 +291,14 @@ def run(
# check if the task needs to be retried
state = self.check_for_retry(state, inputs=task_inputs)

state = self.check_task_is_looping(
state,
inputs=task_inputs,
upstream_states=upstream_states,
context=context,
executor=executor,
)

# for pending signals, including retries and pauses we need to make sure the
# task_inputs are set
except (ENDRUN, signals.PrefectStateSignal) as exc:
@@ -290,11 +319,15 @@ def run(
if prefect.context.get("raise_on_exception"):
raise exc

self.logger.info(
"Task '{name}': finished task run for task with final state: '{state}'".format(
name=context["task_full_name"], state=type(state).__name__
if prefect.context.get("task_loop_count") is None:
# to prevent excessive repetition of this log
# since looping relies on recursively calling self.run
# TODO: figure out a way to only log this one single time instead of twice
self.logger.info(
"Task '{name}': finished task run for task with final state: '{state}'".format(
name=context["task_full_name"], state=type(state).__name__
)
)
)

return state

@@ -836,6 +869,17 @@ def get_task_run_state(
)
return state

except signals.LOOP as exc:
new_state = exc.state
assert isinstance(new_state, Looped)
new_state.result = Result(
value=new_state.result, result_handler=self.result_handler
)
new_state.message = exc.state.message or "Task is looping ({})".format(
new_state.loop_count
)
return new_state

result = Result(value=result, result_handler=self.result_handler)
state = Success(result=result, message="Task run succeeded.")

@@ -900,6 +944,18 @@ def check_for_retry(self, state: State, inputs: Dict[str, Result]) -> State:
"""
if state.is_failed():
run_count = prefect.context.get("task_run_count", 1)
if prefect.context.get("task_loop_count") is not None:
loop_context = {
"_loop_count": Result(
value=prefect.context["task_loop_count"],
result_handler=JSONResultHandler(),
),
"_loop_result": Result(
value=prefect.context.get("task_loop_result"),
result_handler=self.result_handler,
),
}
inputs.update(loop_context)
if run_count <= self.task.max_retries:
start_time = pendulum.now("utc") + self.task.retry_delay
msg = "Retrying Task (after attempt {n} of {m})".format(
@@ -914,3 +970,50 @@ def check_for_retry(self, state: State, inputs: Dict[str, Result]) -> State:
return retry_state

return state

def check_task_is_looping(
self,
state: State,
inputs: Dict[str, Result] = None,
upstream_states: Dict[Edge, State] = None,
context: Dict[str, Any] = None,
executor: "prefect.engine.executors.Executor" = None,
) -> State:
"""
Checks to see if the task is in a `Looped` state and if so, rerun the pipeline with an incremeneted `loop_count`.
Args:
- state (State, optional): initial `State` to begin task run from;
defaults to `Pending()`
- inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
to the task's `run()` arguments.
- upstream_states (Dict[Edge, State]): a dictionary
representing the states of any tasks upstream of this one. The keys of the
dictionary should correspond to the edges leading to the task.
- context (dict, optional): prefect Context to use for execution
- executor (Executor, optional): executor to use when performing
computation; defaults to the executor specified in your prefect configuration
Returns:
- `State` object representing the final post-run state of the Task
"""
if state.is_looped():
assert isinstance(state, Looped) # mypy assert
assert isinstance(context, dict) # mypy assert
msg = "Looping task (on loop index {})".format(state.loop_count)
context.update(
{
"task_loop_result": state.result,
"task_loop_count": state.loop_count + 1,
}
)
context.update(task_run_version=prefect.context.get("task_run_version"))
new_state = Pending(message=msg)
return self.run(
new_state,
upstream_states=upstream_states,
context=context,
executor=executor,
)

return state
@@ -105,6 +105,13 @@ class Meta:
object_class = state.Finished


class LoopedSchema(BaseStateSchema):
class Meta:
object_class = state.Looped

loop_count = fields.Int(allow_none=False)


class SuccessSchema(FinishedSchema):
class Meta:
object_class = state.Success
@@ -192,6 +199,7 @@ class StateSchema(OneOfSchema):
"ClientFailed": ClientFailedSchema,
"Failed": FailedSchema,
"Finished": FinishedSchema,
"Looped": LoopedSchema,
"Mapped": MappedSchema,
"Paused": PausedSchema,
"Pending": PendingSchema,
@@ -16,7 +16,7 @@
from prefect.core.task import Parameter, Task
from prefect.engine.cache_validators import all_inputs, partial_inputs_only
from prefect.engine.result_handlers import LocalResultHandler, ResultHandler
from prefect.engine.signals import PrefectError
from prefect.engine.signals import PrefectError, LOOP
from prefect.engine.state import (
Failed,
Finished,
@@ -2094,3 +2094,75 @@ def test_flow_run_handles_error_states_when_initial_state_is_provided():
res = AddTask()("5", 5)
state = f.run(state=Pending())
assert state.is_failed()


def test_looping_works_in_a_flow():
@task
def looper(x):
if prefect.context.get("task_loop_count", 1) < 20:
raise LOOP(result=prefect.context.get("task_loop_result", 0) + x)
return prefect.context.get("task_loop_result") + x

@task
def downstream(l):
return l ** 2

with Flow(name="looping") as f:
inter = looper(10)
final = downstream(inter)

flow_state = f.run()

assert flow_state.is_successful()
assert flow_state.result[inter].result == 200
assert flow_state.result[final].result == 200 ** 2


def test_looping_with_retries_works_in_a_flow():
@task(max_retries=1, retry_delay=datetime.timedelta(seconds=0))
def looper(x):
if (
prefect.context.get("task_loop_count") == 2
and prefect.context.get("task_run_count", 1) == 1
):
raise ValueError("err")

if prefect.context.get("task_loop_count", 1) < 20:
raise LOOP(result=prefect.context.get("task_loop_result", 0) + x)
return prefect.context.get("task_loop_result") + x

@task
def downstream(l):
return l ** 2

with Flow(name="looping") as f:
inter = looper(10)
final = downstream(inter)

flow_state = f.run()

assert flow_state.is_successful()
assert flow_state.result[inter].result == 200
assert flow_state.result[final].result == 200 ** 2


def test_starting_at_arbitrary_loop_index():
@task
def looper(x):
if prefect.context.get("task_loop_count", 1) < 20:
raise LOOP(result=prefect.context.get("task_loop_result", 0) + x)
return prefect.context.get("task_loop_result", 0) + x

@task
def downstream(l):
return l ** 2

with Flow(name="looping") as f:
inter = looper(10)
final = downstream(inter)

flow_state = f.run(context={"task_loop_count": 20})

assert flow_state.is_successful()
assert flow_state.result[inter].result == 10
assert flow_state.result[final].result == 100

0 comments on commit 3d2a26f

Please sign in to comment.
You can’t perform that action at this time.