Skip to content

Commit

Permalink
Merge pull request #2443 from jcrist/new-conditional-api
Browse files Browse the repository at this point in the history
New conditional api
  • Loading branch information
jcrist committed May 5, 2020
2 parents 254c8c0 + a106524 commit 55fd1ea
Show file tree
Hide file tree
Showing 7 changed files with 277 additions and 10 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Expand Up @@ -14,7 +14,7 @@ These changes are available in the [master branch](https://github.com/PrefectHQ/

### Task Library

- None
- Add new `case` control-flow construct, for nicer management of conditional tasks - [#2443](https://github.com/PrefectHQ/prefect/pull/2443)

### Fixes

Expand Down
39 changes: 32 additions & 7 deletions docs/core/task_library/control_flow.md
@@ -1,35 +1,60 @@
# Control Flow

Tasks and utilities for implementing control flow constructs like branching and rejoining flows.
Tasks and utilities for implementing control flow constructs like branching and
rejoining flows.

## Case <Badge text="fn"/>

A conditional block in a flow.

Used as a context-manager, ``case(task, value)`` creates a block of tasks that
are only run if the result of ``task`` is equal to ``value``.

[API Reference](/api/latest/tasks/control_flow.html#prefect-tasks-control-flow-case-case)

## If/Else <Badge text="fn"/>

Builds a conditional branch into a workflow.

If the condition evaluates True(ish), the true_task will run. If it evaluates False(ish), the false_task will run. The task doesn't run is Skipped, as are all downstream tasks that don't set `skip_on_upstream_skip=False`.
If the condition evaluates True(ish), the ``true_task`` will run. If it
evaluates False(ish), the ``false_task`` will run. The task that doesn't run is
Skipped, as are all downstream tasks that don't set
`skip_on_upstream_skip=False`.

[API Reference](/api/latest/tasks/control_flow.html#prefect-tasks-control-flow-conditional-ifelse)


## Switch <Badge text="task"/>
## Switch <Badge text="fn"/>

Adds a SWITCH to a workflow.

The condition task is evaluated and the result is compared to the keys of the cases dictionary. The task corresponding to the matching key is run; all other tasks are skipped. Any tasks downstream of the skipped tasks are also skipped unless they set `skip_on_upstream_skip=False`.
The condition task is evaluated and the result is compared to the keys of the
cases dictionary. The task corresponding to the matching key is run; all other
tasks are skipped. Any tasks downstream of the skipped tasks are also skipped
unless they set `skip_on_upstream_skip=False`.

[API Reference](/api/latest/tasks/control_flow.html#prefect-tasks-control-flow-conditional-switch)


## Merge <Badge text="task"/>

Merges conditional branches back together.

A conditional branch in a flow results in one or more tasks proceeding and one or more tasks skipping. It is often convenient to merge those branches back into a single result. This function is a simple way to achieve that goal.
A conditional branch in a flow results in one or more tasks proceeding and one
or more tasks skipping. It is often convenient to merge those branches back
into a single result. This function is a simple way to achieve that goal.

The merge will return the first real result it encounters, or `None`. If multiple tasks might return a result, group them with a list.
The merge will return the first real result it encounters, or `None`. If
multiple tasks might return a result, group them with a list.

[API Reference](/api/latest/tasks/control_flow.html#prefect-tasks-control-flow-conditional-merge)

## FilterTask <Badge text="task"/>

Task for filtering lists of results.

The default filter is to filter out `NoResults` and `Exceptions` for filtering out mapped results. Note that this task has a default trigger of `all_finished` and `skip_on_upstream_skip=False`.
The default filter is to filter out `NoResults` and `Exceptions` for filtering
out mapped results. Note that this task has a default trigger of `all_finished`
and `skip_on_upstream_skip=False`.

[API Reference](/api/latest/tasks/control_flow.html#prefect-tasks-control-flow-filter-filtertask)
2 changes: 1 addition & 1 deletion docs/outline.toml
Expand Up @@ -179,7 +179,7 @@ classes = ["DaskKubernetesEnvironment", "DaskCloudProviderEnvironment", "Fargate
[pages.tasks.control_flow]
title = "Control Flow Tasks"
module = "prefect.tasks.control_flow"
classes = ["FilterTask"]
classes = ["FilterTask", "case"]
functions = ["switch", "ifelse", "merge"]

[pages.tasks.airtable]
Expand Down
4 changes: 4 additions & 0 deletions src/prefect/core/task.py
Expand Up @@ -436,6 +436,10 @@ def bind(
if not flow:
raise ValueError("Could not infer an active Flow context.")

case = prefect.context.get("case", None)
if case is not None:
case.add_task(self)

self.set_dependencies(
flow=flow,
upstream_tasks=upstream_tasks,
Expand Down
1 change: 1 addition & 0 deletions src/prefect/tasks/control_flow/__init__.py
@@ -1,2 +1,3 @@
from prefect.tasks.control_flow.conditional import ifelse, switch, merge
from prefect.tasks.control_flow.filter import FilterTask
from prefect.tasks.control_flow.case import case
97 changes: 97 additions & 0 deletions src/prefect/tasks/control_flow/case.py
@@ -0,0 +1,97 @@
from typing import Any

import prefect
from prefect import Task

from .conditional import CompareValue


__all__ = ("case",)


class case(object):
"""A conditional block in a flow definition.
Used as a context-manager, ``case`` creates a block of tasks that are only
run if the result of ``task`` is equal to ``value``.
Args:
- task (Task): The task to use in the comparison
- value (Any): A constant the result of ``task`` will be compared with
Example:
A ``case`` block is similar to Python's if-blocks. It delimits a block
of tasks that will only be run if the result of ``task`` is equal to
``value``:
```python
# Standard python code
if task == value:
res = run_if_task_equals_value()
other_task(res)
# Equivalent prefect code
with case(task, value):
# Tasks created in this block are only run if the
# result of ``task`` is equal to ``value``
res = run_if_task_equals_value()
other_task(run)
```
The ``value`` argument can be any non-task object. Here we branch on a
string result:
```python
with Flow("example") as flow:
cond = condition()
with case(cond, "a"):
run_if_cond_is_a()
with case(cond, "b"):
run_if_cond_is_b()
```
"""

def __init__(self, task: Task, value: Any):
if isinstance(value, Task):
raise TypeError("`value` cannot be a task")

self.task = task
self.value = value
self._tasks = set()

def add_task(self, task: Task) -> None:
"""Add a new task under the case statement.
Args:
- task (Task): the task to add
"""
self._tasks.add(task)

def __enter__(self):
self.__prev_case = prefect.context.get("case")
prefect.context.update(case=self)

def __exit__(self, *args):
if self.__prev_case is None:
prefect.context.pop("case", None)
else:
prefect.context.update(case=self.__prev_case)

if self._tasks:

flow = prefect.context.get("flow", None)
if not flow:
raise ValueError("Could not infer an active Flow context.")

cond = CompareValue(self.value, name=f"case({self.value})",).bind(
value=self.task
)

for child in self._tasks:
# If a task has no upstream tasks created in this case block,
# the case conditional should be set as an upstream task.
if not self._tasks.intersection(flow.upstream_tasks(child)):
child.set_upstream(cond)
142 changes: 141 additions & 1 deletion tests/tasks/test_control_flow.py
Expand Up @@ -4,7 +4,8 @@
from prefect import Flow, Task, task
from prefect.engine.result import NoResult
from prefect.engine.state import Skipped, Success
from prefect.tasks.control_flow import FilterTask, ifelse, merge, switch
from prefect.tasks.control_flow import FilterTask, ifelse, merge, switch, case
from prefect.tasks.control_flow.conditional import CompareValue
from prefect.tasks.core.constants import Constant


Expand Down Expand Up @@ -324,3 +325,142 @@ def test_filter_func_can_be_changed(self):
res = task.run([NoResult, NoResult, 0, 1, 5, "", exc])
assert len(res) == 6
assert res == [NoResult, NoResult, 0, 1, "", exc]


@task
def identity(x):
return x


@task
def inc(x):
return x + 1


class TestCase:
def test_case_errors(self):
with Flow("test"):
with pytest.raises(TypeError, match="`value` cannot be a task"):
with case(identity(True), identity(False)):
pass

def test_empty_case_block_no_tasks(self):
with Flow("test") as flow:
cond = identity(True)
with case(cond, True):
pass

# No tasks added if case block is empty
assert flow.tasks == {cond}

def test_case_sets_and_clears_context(self):
with Flow("test"):
c1 = case(identity(True), True)
c2 = case(identity(False), True)
assert "case" not in prefect.context
with c1:
assert prefect.context["case"] is c1
with c2:
assert prefect.context["case"] is c2
assert prefect.context["case"] is c1
assert "case" not in prefect.context

def test_case_upstream_tasks(self):
with Flow("test") as flow:
a = identity(0)
with case(identity(True), True):
b = inc(a)
c = inc(b)
d = identity(1)
e = inc(d)
f = inc(e)

compare_values = [t for t in flow.tasks if isinstance(t, CompareValue)]
assert len(compare_values) == 1
comp = compare_values[0]

assert flow.upstream_tasks(a) == set()
assert flow.upstream_tasks(b) == {comp, a}
assert flow.upstream_tasks(c) == {b}
assert flow.upstream_tasks(d) == {comp}
assert flow.upstream_tasks(e) == {d}
assert flow.upstream_tasks(f) == {e}

@pytest.mark.parametrize("branch", ["a", "b", "c"])
def test_case_execution(self, branch):
with Flow("test") as flow:
cond = identity(branch)
with case(cond, "a"):
a = identity(1)
b = inc(a)

with case(cond, "b"):
c = identity(3)
d = inc(c)

e = merge(b, d)

state = flow.run()

if branch == "a":
assert state.result[a].result == 1
assert state.result[b].result == 2
assert state.result[c].is_skipped()
assert state.result[d].is_skipped()
assert state.result[e].result == 2
elif branch == "b":
assert state.result[a].is_skipped()
assert state.result[b].is_skipped()
assert state.result[c].result == 3
assert state.result[d].result == 4
assert state.result[e].result == 4
elif branch == "c":
for t in [a, b, c, d, e]:
assert state.result[t].is_skipped()

@pytest.mark.parametrize("branch1", [True, False])
@pytest.mark.parametrize("branch2", [True, False])
def test_nested_case_execution(self, branch1, branch2):
with Flow("test") as flow:
cond1 = identity(branch1)

a = identity(0)
with case(cond1, True):
cond2 = identity(branch2)
b = identity(10)
with case(cond2, True):
c = inc(a)
d = inc(c)
with case(cond2, False):
e = inc(b)
f = inc(e)

g = merge(d, f)

with case(cond1, False):
h = identity(3)
i = inc(h)

j = merge(g, i)

state = flow.run()

sol = {a: 0, cond1: branch1}
if branch1:
sol[cond2] = branch2
sol[b] = 10
if branch2:
sol[c] = 1
sol[d] = sol[g] = sol[j] = 2
else:
sol[e] = 11
sol[f] = sol[g] = sol[j] = 12
else:
sol[h] = 3
sol[i] = sol[j] = 4

for t in [cond1, cond2, a, b, c, d, e, f, g, h, i, j]:
if t in sol:
assert state.result[t].result == sol[t]
else:
assert state.result[t].is_skipped()

0 comments on commit 55fd1ea

Please sign in to comment.