Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Track task run count via Retry states #281

Merged
merged 2 commits into from
Oct 17, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
- Add `is_skipped()` and `is_scheduled()` methods for `State` objects - [#266](https://github.com/PrefectHQ/prefect/pull/266), [#278](https://github.com/PrefectHQ/prefect/pull/278)
- Adds `now()` as a default `start_time` for `Scheduled` states - [#278](https://github.com/PrefectHQ/prefect/pull/278)
- `Signal` classes now pass arguments to underlying `State` objects - [#279](https://github.com/PrefectHQ/prefect/pull/279)
- Run counts are tracked via `Retrying` states - [#281](https://github.com/PrefectHQ/prefect/pull/281)

### Fixes

Expand Down
22 changes: 22 additions & 0 deletions src/prefect/engine/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import datetime
from typing import Any, Dict, Union

import prefect
from prefect.utilities.json import Serializable

MessageType = Union[str, Exception]
Expand Down Expand Up @@ -237,10 +238,31 @@ class Retrying(Scheduled):
- start_time (datetime): time at which the task is scheduled to be retried
- cached_inputs (dict): Defaults to `None`. A dictionary of input
keys to values. Used / set if the Task requires Retries.
- run_count (int): The number of runs that had been attempted at the time of this
Retry. Defaults to the value stored in context under "_task_run_count" or 1,
if that value isn't found.
"""

color = "#FFFF00"

def __init__(
    self,
    result: Any = None,
    message: MessageType = None,
    start_time: datetime.datetime = None,
    cached_inputs: Dict[str, Any] = None,
    run_count: int = None,
) -> None:
    """
    Initialize the state and record the run count on `self.run_count`.

    `result`, `message`, `start_time` and `cached_inputs` are handed to the parent
    initializer unchanged. An explicitly supplied `run_count` is stored as given;
    when it is `None`, the value under "_task_run_count" in `prefect.context` is
    used, falling back to 1 if that key is absent.
    """
    parent_kwargs = dict(
        result=result,
        message=message,
        start_time=start_time,
        cached_inputs=cached_inputs,
    )
    super().__init__(**parent_kwargs)

    # Context is only consulted when the caller did not supply a count.
    self.run_count = (
        run_count
        if run_count is not None
        else prefect.context.get("_task_run_count", 1)
    )


# -------------------------------------------------------------------
# Running States
Expand Down
34 changes: 30 additions & 4 deletions src/prefect/engine/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ def run(
with prefect.context(context, _task_name=self.task.name):

try:
# retrieve the run number and place in context
state = self.get_run_count(state=state)

# check if all upstream tasks have finished
state = self.check_upstream_finished(
state, upstream_states_set=upstream_states_set
Expand Down Expand Up @@ -200,6 +203,26 @@ def run(

return state

@call_state_handlers
def get_run_count(self, state: State) -> State:
    """
    Work out which run this is and store it in context as `_task_run_count`.

    A `Retrying` state carries a `run_count`; the current run is that count plus
    one. For any other state the run count is taken to be 1.

    Args:
        - state (State): the current state of the task

    Returns:
        State: the same `state` object that was passed in
    """
    # Non-retry states contribute zero previous runs, so the result is 1.
    previous_runs = state.run_count if isinstance(state, Retrying) else 0
    prefect.context.update(_task_run_count=previous_runs + 1)
    return state

@call_state_handlers
def check_upstream_finished(
self, state: State, upstream_states_set: Set[State]
Expand Down Expand Up @@ -473,14 +496,17 @@ def check_for_retry(self, state: State, inputs: Dict[str, Any]) -> State:
State: the state of the task after running the check
"""
if state.is_failed():
run_number = prefect.context.get("_task_run_number", 1)
if run_number <= self.task.max_retries or isinstance(state, Retrying):
run_count = prefect.context.get("_task_run_count", 1)
if run_count <= self.task.max_retries:
start_time = datetime.datetime.utcnow() + self.task.retry_delay
msg = "Retrying Task (after attempt {n} of {m})".format(
n=run_number, m=self.task.max_retries + 1
n=run_count, m=self.task.max_retries + 1
)
return Retrying(
start_time=start_time, cached_inputs=inputs, message=msg
start_time=start_time,
cached_inputs=inputs,
message=msg,
run_count=run_count,
)

return state
34 changes: 34 additions & 0 deletions tests/engine/test_flow_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,40 @@ def test_return_failed_includes_retries(self):
assert isinstance(state.result[e], Retrying)


class TestRunCount:
    def test_run_count_tracked_via_retry_states(self):
        """Re-running a flow from its previous task states advances `run_count`."""
        flow = Flow()
        one_retry = ErrorTask(max_retries=1)
        two_retries = ErrorTask(max_retries=2)
        tracked = (one_retry, two_retries)
        for task in tracked:
            flow.add_task(task)

        # first run: both tasks are scheduled for retry with run_count 1
        flow_state = FlowRunner(flow=flow).run(return_tasks=list(tracked))
        assert flow_state.is_pending()
        for task in tracked:
            assert isinstance(flow_state.result[task], Retrying)
            assert flow_state.result[task].run_count == 1

        # second run: the max_retries=1 task is Failed, the other retries again
        flow_state = FlowRunner(flow=flow).run(
            task_states=flow_state.result, return_tasks=list(tracked)
        )
        assert flow_state.is_pending()
        assert isinstance(flow_state.result[one_retry], Failed)
        assert isinstance(flow_state.result[two_retries], Retrying)
        assert flow_state.result[two_retries].run_count == 2

        # third run: both tasks are Failed and so is the flow
        flow_state = FlowRunner(flow=flow).run(
            task_states=flow_state.result, return_tasks=list(tracked)
        )
        assert flow_state.is_failed()
        for task in tracked:
            assert isinstance(flow_state.result[task], Failed)


@pytest.mark.skipif(
sys.version_info < (3, 5), reason="dask.distributed does not support Python 3.4"
)
Expand Down
21 changes: 21 additions & 0 deletions tests/engine/test_signals.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import pytest

import prefect
from prefect.engine.signals import (
FAIL,
RETRY,
Expand Down Expand Up @@ -54,6 +55,26 @@ def test_retry_signals_can_set_retry_time():
assert exc.value.state.start_time == date


def test_retry_signals_accept_run_count():
    """A `run_count` passed to `RETRY` is available on the signal's state."""
    with pytest.raises(PrefectStateSignal) as raised:
        raise RETRY(run_count=5)
    signal_state = raised.value.state
    assert signal_state.run_count == 5


def test_retry_signals_take_run_count_from_context():
    """Without an explicit `run_count`, `RETRY` picks up the value from context."""
    with prefect.context(_task_run_count=5), pytest.raises(
        PrefectStateSignal
    ) as raised:
        raise RETRY()
    signal_state = raised.value.state
    assert signal_state.run_count == 5


def test_retry_signals_prefer_supplied_run_count_to_context():
    """An explicit `run_count` takes precedence over the value in context."""
    with prefect.context(_task_run_count=5), pytest.raises(
        PrefectStateSignal
    ) as raised:
        raise RETRY(run_count=6)
    signal_state = raised.value.state
    assert signal_state.run_count == 6


@pytest.mark.parametrize(
"signal,state",
[
Expand Down
18 changes: 17 additions & 1 deletion tests/engine/test_state.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime

import pytest

import prefect
from prefect.engine.state import (
CachedState,
Failed,
Expand Down Expand Up @@ -87,6 +87,22 @@ def test_timestamp_is_serialized():
assert state.timestamp == deserialized_state.timestamp


def test_retry_stores_run_count():
    """An explicitly supplied `run_count` is kept on the `Retrying` state."""
    retrying = Retrying(run_count=2)
    assert retrying.run_count == 2


def test_retry_stores_default_run_count():
    """With no argument given, `run_count` is expected to be 1."""
    retrying = Retrying()
    assert retrying.run_count == 1


def test_retry_stores_default_run_count_in_context():
    """With no argument given, `run_count` comes from `_task_run_count` in context."""
    with prefect.context(_task_run_count=5):
        retrying = Retrying()
        observed = retrying.run_count
        assert observed == 5


@pytest.mark.parametrize("cls", all_states)
def test_states_have_color(cls):
assert cls.color.startswith("#")
Expand Down
89 changes: 67 additions & 22 deletions tests/engine/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,22 +115,27 @@ def test_task_that_raises_fail_is_marked_fail():
assert not isinstance(task_runner.run(), TriggerFailed)


def test_task_that_fails_gets_retried_up_to_1_time():
def test_task_that_fails_gets_retried_up_to_max_retry_time():
"""
Test that failed tasks are marked for retry if run_number is available
Test that failed tasks are marked for retry if run_count is available
"""
err_task = ErrorTask(max_retries=1)
err_task = ErrorTask(max_retries=2)
task_runner = TaskRunner(task=err_task)

# first run should be retrying
with prefect.context(_task_run_number=1):
state = task_runner.run()
# first run should be retry
state = task_runner.run()
assert isinstance(state, Retrying)
assert isinstance(state.start_time, datetime.datetime)
assert state.run_count == 1

# second run should
with prefect.context(_task_run_number=2):
state = task_runner.run(state=state)
# second run should retry
state = task_runner.run(state=state)
assert isinstance(state, Retrying)
assert isinstance(state.start_time, datetime.datetime)
assert state.run_count == 2

# third run should fail
state = task_runner.run(state=state)
assert isinstance(state, Failed)


Expand Down Expand Up @@ -162,14 +167,14 @@ def test_task_that_raises_retry_gets_retried_even_if_max_retries_is_set():
task_runner = TaskRunner(task=retry_task)

# first run should be retrying
with prefect.context(_task_run_number=1):
with prefect.context(_task_run_count=1):
state = task_runner.run()
assert isinstance(state, Retrying)
assert isinstance(state.start_time, datetime.datetime)

# second run should also be retry because the task raises it explicitly

with prefect.context(_task_run_number=2):
with prefect.context(_task_run_count=2):
state = task_runner.run(state=state)
assert isinstance(state, Retrying)

Expand Down Expand Up @@ -305,6 +310,34 @@ def test_task_runner_handles_secrets():
assert state.result is "my_private_str"


class TestGetRunCount:
    """Checks on what `TaskRunner.get_run_count` writes to context and returns."""

    @pytest.mark.parametrize(
        "state", [Success(), Failed(), Pending(), Scheduled(), Skipped(), CachedState()]
    )
    def test_states_without_run_count(self, state):
        # For these states the context value is expected to be 1.
        with prefect.context() as active_context:
            assert "_task_run_count" not in active_context
            runner = TaskRunner(Task())
            returned = runner.get_run_count(state)
            assert active_context._task_run_count == 1
            assert returned is state

    @pytest.mark.parametrize(
        "state",
        [
            Retrying(),
            Retrying(run_count=1),
            Retrying(run_count=2),
            Retrying(run_count=10),
        ],
    )
    def test_states_with_run_count(self, state):
        # For Retrying states the context value is the stored run_count plus one.
        with prefect.context() as active_context:
            assert "_task_run_count" not in active_context
            runner = TaskRunner(Task())
            returned = runner.get_run_count(state)
            assert active_context._task_run_count == state.run_count + 1
            assert returned is state


class TestCheckUpstreamFinished:
def test_with_empty_set(self):
state = Pending()
Expand Down Expand Up @@ -699,26 +732,29 @@ def fn(x):


class TestCheckRetryStep:
@pytest.mark.parametrize("state", [Success(), Pending(), Running(), Skipped()])
@pytest.mark.parametrize(
"state", [Success(), Pending(), Running(), Retrying(), Skipped()]
)
def test_non_failed_states(self, state):
new_state = TaskRunner(task=Task()).check_for_retry(state=state, inputs={})
assert new_state is state

def test_failed_no_retry(self):
def test_failed_zero_max_retry(self):
state = Failed()
new_state = TaskRunner(task=Task()).check_for_retry(state=state, inputs={})
assert new_state is state

def test_failed_one_retry(self):
def test_failed_one_max_retry(self):
state = Failed()
new_state = TaskRunner(task=Task(max_retries=1)).check_for_retry(
state=state, inputs={}
)
assert isinstance(new_state, Retrying)
assert new_state.run_count == 1

def test_failed_one_retry_second_run(self):
def test_failed_one_max_retry_second_run(self):
state = Failed()
with prefect.context(_task_run_number=2):
with prefect.context(_task_run_count=2):
new_state = TaskRunner(task=Task(max_retries=1)).check_for_retry(
state=state, inputs={}
)
Expand All @@ -732,12 +768,21 @@ def test_failed_retry_caches_inputs(self):
assert isinstance(new_state, Retrying)
assert new_state.cached_inputs == {"x": 1}

def test_retrying_with_start_time(self):
state = Retrying(start_time=datetime.datetime.utcnow())
new_state = TaskRunner(task=Task(max_retries=1)).check_for_retry(
state=state, inputs={}
)
assert new_state is state
def test_retrying_when_run_count_greater_than_max_retries(self):
with prefect.context(_task_run_count=10):
state = Retrying()
new_state = TaskRunner(task=Task(max_retries=1)).check_for_retry(
state=state, inputs={}
)
assert new_state is state

def test_retrying_when_state_has_explicit_run_count_set(self):
with prefect.context(_task_run_count=10):
state = Retrying(run_count=5)
new_state = TaskRunner(task=Task(max_retries=1)).check_for_retry(
state=state, inputs={}
)
assert new_state is state


class TestCacheResultStep:
Expand Down