Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New conditional api #2443

Merged
merged 9 commits into from
May 5, 2020
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ These changes are available in the [master branch](https://github.com/PrefectHQ/

### Task Library

- None
- Add new `case` control-flow construct, for nicer management of conditional tasks - [#2443](https://github.com/PrefectHQ/prefect/pull/2443)

### Fixes

Expand Down
39 changes: 32 additions & 7 deletions docs/core/task_library/control_flow.md
Original file line number Diff line number Diff line change
@@ -1,35 +1,60 @@
# Control Flow

Tasks and utilities for implementing control flow constructs like branching and rejoining flows.
Tasks and utilities for implementing control flow constructs like branching and
rejoining flows.

## Case <Badge text="fn"/>

A conditional block in a flow.

Used as a context-manager, ``case(task, value)`` creates a block of tasks that
are only run if the result of ``task`` is equal to ``value``.

[API Reference](/api/latest/tasks/control_flow.html#prefect-tasks-control-flow-case-case)

## If/Else <Badge text="fn"/>

Builds a conditional branch into a workflow.

If the condition evaluates True(ish), the true_task will run. If it evaluates False(ish), the false_task will run. The task doesn't run is Skipped, as are all downstream tasks that don't set `skip_on_upstream_skip=False`.
If the condition evaluates True(ish), the ``true_task`` will run. If it
evaluates False(ish), the ``false_task`` will run. The task that doesn't run is
Skipped, as are all downstream tasks that don't set
`skip_on_upstream_skip=False`.

[API Reference](/api/latest/tasks/control_flow.html#prefect-tasks-control-flow-conditional-ifelse)


## Switch <Badge text="task"/>
## Switch <Badge text="fn"/>

Adds a SWITCH to a workflow.

The condition task is evaluated and the result is compared to the keys of the cases dictionary. The task corresponding to the matching key is run; all other tasks are skipped. Any tasks downstream of the skipped tasks are also skipped unless they set `skip_on_upstream_skip=False`.
The condition task is evaluated and the result is compared to the keys of the
cases dictionary. The task corresponding to the matching key is run; all other
tasks are skipped. Any tasks downstream of the skipped tasks are also skipped
unless they set `skip_on_upstream_skip=False`.

[API Reference](/api/latest/tasks/control_flow.html#prefect-tasks-control-flow-conditional-switch)


## Merge <Badge text="task"/>

Merges conditional branches back together.

A conditional branch in a flow results in one or more tasks proceeding and one or more tasks skipping. It is often convenient to merge those branches back into a single result. This function is a simple way to achieve that goal.
A conditional branch in a flow results in one or more tasks proceeding and one
or more tasks skipping. It is often convenient to merge those branches back
into a single result. This function is a simple way to achieve that goal.

The merge will return the first real result it encounters, or `None`. If multiple tasks might return a result, group them with a list.
The merge will return the first real result it encounters, or `None`. If
multiple tasks might return a result, group them with a list.

[API Reference](/api/latest/tasks/control_flow.html#prefect-tasks-control-flow-conditional-merge)

## FilterTask <Badge text="task"/>

Task for filtering lists of results.

The default filter is to filter out `NoResults` and `Exceptions` for filtering out mapped results. Note that this task has a default trigger of `all_finished` and `skip_on_upstream_skip=False`.
The default filter is to filter out `NoResults` and `Exceptions` for filtering
out mapped results. Note that this task has a default trigger of `all_finished`
and `skip_on_upstream_skip=False`.

[API Reference](/api/latest/tasks/control_flow.html#prefect-tasks-control-flow-filter-filtertask)
2 changes: 1 addition & 1 deletion docs/outline.toml
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ classes = ["DaskKubernetesEnvironment", "DaskCloudProviderEnvironment", "Fargate
[pages.tasks.control_flow]
title = "Control Flow Tasks"
module = "prefect.tasks.control_flow"
classes = ["FilterTask"]
classes = ["FilterTask", "case"]
functions = ["switch", "ifelse", "merge"]

[pages.tasks.airtable]
Expand Down
4 changes: 4 additions & 0 deletions src/prefect/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,10 @@ def bind(
if not flow:
raise ValueError("Could not infer an active Flow context.")

case = prefect.context.get("case", None)
if case is not None:
case.add_task(self)

self.set_dependencies(
flow=flow,
upstream_tasks=upstream_tasks,
Expand Down
1 change: 1 addition & 0 deletions src/prefect/tasks/control_flow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from prefect.tasks.control_flow.conditional import ifelse, switch, merge
from prefect.tasks.control_flow.filter import FilterTask
from prefect.tasks.control_flow.case import case
97 changes: 97 additions & 0 deletions src/prefect/tasks/control_flow/case.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from typing import Any

import prefect
from prefect import Task

from .conditional import CompareValue


__all__ = ("case",)


class case(object):
"""A conditional block in a flow definition.

Used as a context-manager, ``case`` creates a block of tasks that are only
run if the result of ``task`` is equal to ``value``.

Args:
- task (Task): The task to use in the comparison
- value (Any): A constant the result of ``task`` will be compared with

Example:

A ``case`` block is similar to Python's if-blocks. It delimits a block
of tasks that will only be run if the result of ``task`` is equal to
``value``:

```python
# Standard python code
if task == value:
res = run_if_task_equals_value()
other_task(res)

# Equivalent prefect code
with case(task, value):
# Tasks created in this block are only run if the
# result of ``task`` is equal to ``value``
res = run_if_task_equals_value()
other_task(run)
```

The ``value`` argument can be any non-task object. Here we branch on a
string result:

```python
with Flow("example") as flow:
cond = condition()

with case(cond, "a"):
run_if_cond_is_a()

with case(cond, "b"):
run_if_cond_is_b()
```
"""

def __init__(self, task: Task, value: Any):
if isinstance(value, Task):
raise TypeError("`value` cannot be a task")

self.task = task
self.value = value
self._tasks = set()

def add_task(self, task: Task) -> None:
"""Add a new task under the case statement.

Args:
- task (Task): the task to add
"""
self._tasks.add(task)

def __enter__(self):
self.__prev_case = prefect.context.get("case")
prefect.context.update(case=self)

def __exit__(self, *args):
if self.__prev_case is None:
prefect.context.pop("case", None)
else:
prefect.context.update(case=self.__prev_case)

if self._tasks:

flow = prefect.context.get("flow", None)
if not flow:
raise ValueError("Could not infer an active Flow context.")

cond = CompareValue(self.value, name=f"case({self.value})",).bind(
value=self.task
)

for child in self._tasks:
# If a task has no upstream tasks created in this case block,
# the case conditional should be set as an upstream task.
if not self._tasks.intersection(flow.upstream_tasks(child)):
child.set_upstream(cond)
142 changes: 141 additions & 1 deletion tests/tasks/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from prefect import Flow, Task, task
from prefect.engine.result import NoResult
from prefect.engine.state import Skipped, Success
from prefect.tasks.control_flow import FilterTask, ifelse, merge, switch
from prefect.tasks.control_flow import FilterTask, ifelse, merge, switch, case
from prefect.tasks.control_flow.conditional import CompareValue
from prefect.tasks.core.constants import Constant


Expand Down Expand Up @@ -324,3 +325,142 @@ def test_filter_func_can_be_changed(self):
res = task.run([NoResult, NoResult, 0, 1, 5, "", exc])
assert len(res) == 6
assert res == [NoResult, NoResult, 0, 1, "", exc]


@task
def identity(x):
return x


@task
def inc(x):
return x + 1


class TestCase:
def test_case_errors(self):
with Flow("test"):
with pytest.raises(TypeError, match="`value` cannot be a task"):
with case(identity(True), identity(False)):
pass

def test_empty_case_block_no_tasks(self):
with Flow("test") as flow:
cond = identity(True)
with case(cond, True):
pass

# No tasks added if case block is empty
assert flow.tasks == {cond}

def test_case_sets_and_clears_context(self):
with Flow("test"):
c1 = case(identity(True), True)
c2 = case(identity(False), True)
assert "case" not in prefect.context
with c1:
assert prefect.context["case"] is c1
with c2:
assert prefect.context["case"] is c2
assert prefect.context["case"] is c1
assert "case" not in prefect.context

def test_case_upstream_tasks(self):
with Flow("test") as flow:
a = identity(0)
with case(identity(True), True):
b = inc(a)
c = inc(b)
d = identity(1)
e = inc(d)
f = inc(e)

compare_values = [t for t in flow.tasks if isinstance(t, CompareValue)]
assert len(compare_values) == 1
comp = compare_values[0]

assert flow.upstream_tasks(a) == set()
assert flow.upstream_tasks(b) == {comp, a}
assert flow.upstream_tasks(c) == {b}
assert flow.upstream_tasks(d) == {comp}
assert flow.upstream_tasks(e) == {d}
assert flow.upstream_tasks(f) == {e}

@pytest.mark.parametrize("branch", ["a", "b", "c"])
def test_case_execution(self, branch):
with Flow("test") as flow:
cond = identity(branch)
with case(cond, "a"):
a = identity(1)
b = inc(a)

with case(cond, "b"):
c = identity(3)
d = inc(c)

e = merge(b, d)

state = flow.run()

if branch == "a":
assert state.result[a].result == 1
assert state.result[b].result == 2
assert state.result[c].is_skipped()
assert state.result[d].is_skipped()
assert state.result[e].result == 2
elif branch == "b":
assert state.result[a].is_skipped()
assert state.result[b].is_skipped()
assert state.result[c].result == 3
assert state.result[d].result == 4
assert state.result[e].result == 4
elif branch == "c":
for t in [a, b, c, d, e]:
assert state.result[t].is_skipped()

@pytest.mark.parametrize("branch1", [True, False])
@pytest.mark.parametrize("branch2", [True, False])
def test_nested_case_execution(self, branch1, branch2):
with Flow("test") as flow:
cond1 = identity(branch1)

a = identity(0)
with case(cond1, True):
cond2 = identity(branch2)
b = identity(10)
with case(cond2, True):
c = inc(a)
d = inc(c)
with case(cond2, False):
e = inc(b)
f = inc(e)

g = merge(d, f)

with case(cond1, False):
h = identity(3)
i = inc(h)

j = merge(g, i)

state = flow.run()

sol = {a: 0, cond1: branch1}
if branch1:
sol[cond2] = branch2
sol[b] = 10
if branch2:
sol[c] = 1
sol[d] = sol[g] = sol[j] = 2
else:
sol[e] = 11
sol[f] = sol[g] = sol[j] = 12
else:
sol[h] = 3
sol[i] = sol[j] = 4

for t in [cond1, cond2, a, b, c, d, e, f, g, h, i, j]:
if t in sol:
assert state.result[t].result == sol[t]
else:
assert state.result[t].is_skipped()