Skip to content

Commit

Permalink
[AIRFLOW-3958] Support list tasks as upstream in chain
Browse files Browse the repository at this point in the history
helpers.chain only support list as downstream
This PR make list as upstream work, also make
list parallel work, which like below

     / -> t2 -> t4 \
t1 ->               -> t6
     \ -> t3 -> t5 /
  • Loading branch information
zhongjiajie committed May 28, 2019
1 parent a8a4d32 commit 37c27da
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 10 deletions.
46 changes: 39 additions & 7 deletions airflow/utils/helpers.py
Expand Up @@ -24,6 +24,7 @@
from builtins import input
from datetime import datetime
from functools import reduce
from collections import Iterable
import os
import re
import signal
Expand Down Expand Up @@ -150,19 +151,50 @@ def as_flattened_list(iterable):


def chain(*tasks):
"""
r"""
Given a number of tasks, builds a dependency chain.
Support mix airflow.models.BaseOperator and List[airflow.models.BaseOperator].
If you want to chain between two List[airflow.models.BaseOperator], have to
make sure they have same length.
chain(task_1, task_2, task_3, task_4)
chain(t1, [t2, t3], [t4, t5], t6)
is equivalent to
task_1.set_downstream(task_2)
task_2.set_downstream(task_3)
task_3.set_downstream(task_4)
/ -> t2 -> t4 \
t1 -> t6
\ -> t3 -> t5 /
t1.set_downstream(t2)
t1.set_downstream(t3)
t2.set_downstream(t4)
t3.set_downstream(t5)
t4.set_downstream(t6)
t5.set_downstream(t6)
:param tasks: List of tasks or List[airflow.models.BaseOperator] to set dependencies
:type tasks: List[airflow.models.BaseOperator] or airflow.models.BaseOperator
"""
for up_task, down_task in zip(tasks[:-1], tasks[1:]):
up_task.set_downstream(down_task)
from airflow.models import BaseOperator

for index, up_task in enumerate(tasks[:-1]):
down_task = tasks[index + 1]
if isinstance(up_task, BaseOperator):
up_task.set_downstream(down_task)
elif isinstance(down_task, BaseOperator):
down_task.set_upstream(up_task)
else:
if not isinstance(up_task, Iterable) or not isinstance(down_task, Iterable):
raise TypeError(
'Chain not supported between instances of {up_type} and {down_type}'.format(
up_type=type(up_task), down_type=type(down_task)))
elif len(up_task) != len(down_task):
raise AirflowException(
'Chain not supported different length Iterable but get {up_len} and {down_len}'.format(
up_len=len(up_task), down_len=len(down_task)))
else:
for up, down in zip(up_task, down_task):
up.set_downstream(down)


def cross_downstream(from_tasks, to_tasks):
Expand Down
80 changes: 77 additions & 3 deletions docs/concepts.rst
Expand Up @@ -185,6 +185,9 @@ Bitshift Composition

*Added in Airflow 1.8*

We recommend you setting operator relationships with bitshift operators rather than ``set_upstream()``
and ``set_downstream()``.

Traditionally, operator relationships are set with the ``set_upstream()`` and
``set_downstream()`` methods. In Airflow 1.8, this can be done with the Python
bitshift operators ``>>`` and ``<<``. The following four statements are all
Expand Down Expand Up @@ -248,21 +251,92 @@ Bitshift can also be used with lists. For example:

.. code:: python
op1 >> [op2, op3]
op1 >> [op2, op3] >> op4
is equivalent to:

.. code:: python
op1 >> op2
op1 >> op3
op1 >> op2 >> op4
op1 >> op3 >> op4
and equivalent to:

.. code:: python
op1.set_downstream([op2, op3])
Relationship Helper
--------------------

``chain`` and ``cross_downstream`` function provide easier ways to set relationships
between operators in specific situation.

When setting relationships between two list of operators and wish all up list
operators as upstream to all down list operators, we have to split one list
manually using bitshift composition.

.. code:: python
[op1, op2, op3] >> op4
[op1, op2, op3] >> op5
[op1, op2, op3] >> op6
``cross_downstream`` could handle list relationships easier.

.. code:: python
cross_downstream([op1, op2, op3], [op4, op5, op6])
When setting single direction relationships to many operators, we could
concat them with bitshift composition.

.. code:: python
op1 >> op2 >> op3 >> op4 >> op5
use ``chain`` could do that

.. code:: python
chain(op1, op2, op3, op4, op5)
even without operator's name

.. code:: python
chain([DummyOperator(task_id='op' + i, dag=dag) for i in range(1, 6)])
``chain`` could handle list of operators

.. code:: python
chain(op1, [op2, op3], op4)
is equivalent to:

.. code:: python
op1 >> [op2, op3] >> op4
Have to same size when ``chain`` set relationships between two list
of operators.

.. code:: python
chain(op1, [op2, op3], [op4, op5], op6)
is equivalent to:

.. code:: python
op1 >> [op2, op3]
op2 >> op4
op3 >> op5
[op4, op5] >> op6
Tasks
=====

Expand Down
23 changes: 23 additions & 0 deletions tests/utils/test_helpers.py
Expand Up @@ -32,6 +32,7 @@
from airflow.utils import helpers
from airflow.models import TaskInstance
from airflow.operators.dummy_operator import DummyOperator
from airflow.exceptions import AirflowException


class TestHelpers(unittest.TestCase):
Expand Down Expand Up @@ -248,6 +249,28 @@ def test_cross_downstream(self):
for start_task in start_tasks:
six.assertCountEqual(self, start_task.get_direct_relatives(upstream=False), end_tasks)

def test_chain(self):
dag = DAG(dag_id='test_chain', start_date=datetime.now())
[t1, t2, t3, t4, t5, t6] = [DummyOperator(task_id='t{i}'.format(i=i), dag=dag) for i in range(1, 7)]
helpers.chain(t1, [t2, t3], [t4, t5], t6)

self.assertCountEqual([t2, t3], t1.get_direct_relatives(upstream=False))
self.assertEqual([t4], t2.get_direct_relatives(upstream=False))
self.assertEqual([t5], t3.get_direct_relatives(upstream=False))
self.assertCountEqual([t4, t5], t6.get_direct_relatives(upstream=True))

def test_chain_not_support_type(self):
dag = DAG(dag_id='test_chain', start_date=datetime.now())
[t1, t2] = [DummyOperator(task_id='t{i}'.format(i=i), dag=dag) for i in range(1, 3)]
with self.assertRaises(TypeError):
helpers.chain([t1, t2], 1)

def test_chain_different_length_iterable(self):
dag = DAG(dag_id='test_chain', start_date=datetime.now())
[t1, t2, t3, t4, t5] = [DummyOperator(task_id='t{i}'.format(i=i), dag=dag) for i in range(1, 6)]
with self.assertRaises(AirflowException):
helpers.chain([t1, t2], [t3, t4, t5])


if __name__ == '__main__':
unittest.main()

0 comments on commit 37c27da

Please sign in to comment.