From a729e84e9d4562f0dfe718948dcae52523bdf752 Mon Sep 17 00:00:00 2001 From: Tomek Urbaszek Date: Mon, 14 Sep 2020 13:29:10 +0200 Subject: [PATCH 1/9] Introduce TaskMixin Both BaseOperator and XComArgs implement bit shift operators used to chain tasks in DAGs. By extracting this logic to new mixin we reduce code duplication and make it easier to implement it in future. Closes: #10926 --- airflow/models/baseoperator.py | 35 +---------------- airflow/models/taskmixin.py | 72 ++++++++++++++++++++++++++++++++++ airflow/models/xcom_arg.py | 33 +--------------- 3 files changed, 76 insertions(+), 64 deletions(-) create mode 100644 airflow/models/taskmixin.py diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 27013f25a769a..6efa41477b59c 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -42,6 +42,7 @@ from airflow.models.base import Operator from airflow.models.pool import Pool from airflow.models.taskinstance import Context, TaskInstance, clear_task_instances +from airflow.models.taskmixin import TaskMixin from airflow.models.xcom import XCOM_RETURN_KEY from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep @@ -84,7 +85,7 @@ def __call__(cls, *args, **kwargs): # pylint: disable=too-many-instance-attributes,too-many-public-methods @functools.total_ordering -class BaseOperator(Operator, LoggingMixin, metaclass=BaseOperatorMeta): +class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta): """ Abstract base class for all operators. Since operators create objects that become nodes in the dag, BaseOperator contains many recursive methods for @@ -491,38 +492,6 @@ def __hash__(self): hash_components.append(repr(val)) return hash(tuple(hash_components)) - # Composing Operators ----------------------------------------------- - - def __rshift__(self, other): - """ - Implements Self >> Other == self.set_downstream(other) - """ - self.set_downstream(other) - return other - - def __lshift__(self, other): - """ - Implements Self << Other == self.set_upstream(other) - """ - self.set_upstream(other) - return other - - def __rrshift__(self, other): - """ - Called for Operator >> [Operator] because list don't have - __rshift__ operators. - """ - self.__lshift__(other) - return self - - def __rlshift__(self, other): - """ - Called for Operator << [Operator] because list don't have - __lshift__ operators. - """ - self.__rshift__(other) - return self - # including lineage information def __or__(self, other): """ diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py new file mode 100644 index 0000000000000..6772368b2425e --- /dev/null +++ b/airflow/models/taskmixin.py @@ -0,0 +1,72 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from abc import abstractmethod + + +class TaskMixin: + """ + Mixing implementing common chain methods like >> and <<. + + In the following functions we use: + Task = Union[BaseOperator, XComArg] + No type annotations due to cyclic imports. + """ + + @abstractmethod + def set_upstream(self, other): + """ + Set a task or a task list to be directly upstream from the current task. + """ + raise NotImplementedError() + + @abstractmethod + def set_downstream(self, other): + """ + Set a task or a task list to be directly downstream from the current task. + """ + raise NotImplementedError() + + def __lshift__(self, other): + """ + Implements Task << Task + """ + self.set_upstream(other) + return other + + def __rshift__(self, other): + """ + Implements Task >> Task + """ + self.set_downstream(other) + return other + + def __rrshift__(self, other): + """ + Called for Task >> [Task] because list don't have + __rshift__ operators. + """ + self.__lshift__(other) + return self + + def __rlshift__(self, other): + """ + Called for Task >> [Task] because list don't have + __lshift__ operators. + """ + self.__rshift__(other) + return self diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index b85d638fc14c3..8ae3a6cbb53a5 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -19,10 +19,11 @@ from airflow.exceptions import AirflowException from airflow.models.baseoperator import BaseOperator # pylint: disable=R0401 +from airflow.models.taskmixin import TaskMixin from airflow.models.xcom import XCOM_RETURN_KEY -class XComArg: +class XComArg(TaskMixin): """ Class that represents a XCom push from a previous operator. Defaults to "return_value" as only key. @@ -65,36 +66,6 @@ def __eq__(self, other): return (self.operator == other.operator and self.key == other.key) - def __lshift__(self, other): - """ - Implements XComArg << op - """ - self.set_upstream(other) - return other - - def __rshift__(self, other): - """ - Implements XComArg >> op - """ - self.set_downstream(other) - return other - - def __rrshift__(self, other): - """ - Called for XComArg >> [XComArg] because list don't have - __rshift__ operators. - """ - self.__lshift__(other) - return self - - def __rlshift__(self, other): - """ - Called for XComArg >> [XComArg] because list don't have - __lshift__ operators. - """ - self.__rshift__(other) - return self - def __getitem__(self, item): """ Implements xcomresult['some_result_key'] From 0b091fd970d0dce2c394c942232a09224c725ad2 Mon Sep 17 00:00:00 2001 From: Tomek Urbaszek Date: Mon, 14 Sep 2020 15:03:15 +0200 Subject: [PATCH 2/9] Adjust pylintrc --- pylintrc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pylintrc b/pylintrc index c9c7996bf0b47..67a4c626ec2a5 100644 --- a/pylintrc +++ b/pylintrc @@ -570,7 +570,7 @@ max-branches=22 max-locals=24 # Maximum number of parents for a class (see R0901). -max-parents=7 +max-parents=8 # Maximum number of public methods for a class (see R0904). # BasPH: choose 27 because this was 50% of the sorted list of 30 number of public methods above 20 (Pylint default) From 7885ddbd43103461376550a185158f46788c160e Mon Sep 17 00:00:00 2001 From: Tomek Urbaszek Date: Mon, 14 Sep 2020 18:18:14 +0200 Subject: [PATCH 3/9] fixup! Adjust pylintrc --- airflow/models/taskmixin.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py index 6772368b2425e..50d4bf77ccbc3 100644 --- a/airflow/models/taskmixin.py +++ b/airflow/models/taskmixin.py @@ -16,6 +16,7 @@ # under the License. from abc import abstractmethod +from typing import Sequence, Union class TaskMixin: @@ -28,45 +29,43 @@ class TaskMixin: """ @abstractmethod - def set_upstream(self, other): + def set_upstream(self, other: Union["TaskMixin", Sequence["TaskMixin"]]): """ Set a task or a task list to be directly upstream from the current task. """ raise NotImplementedError() @abstractmethod - def set_downstream(self, other): + def set_downstream(self, other: Union["TaskMixin", Sequence["TaskMixin"]]): """ Set a task or a task list to be directly downstream from the current task. """ raise NotImplementedError() - def __lshift__(self, other): + def __lshift__(self, other: Union["TaskMixin", Sequence["TaskMixin"]]): """ Implements Task << Task """ self.set_upstream(other) return other - def __rshift__(self, other): + def __rshift__(self, other: Union["TaskMixin", Sequence["TaskMixin"]]): """ Implements Task >> Task """ self.set_downstream(other) return other - def __rrshift__(self, other): + def __rrshift__(self, other: Union["TaskMixin", Sequence["TaskMixin"]]): """ - Called for Task >> [Task] because list don't have - __rshift__ operators. + Called for Task >> [Task] because list don't have __rshift__ operators. """ self.__lshift__(other) return self - def __rlshift__(self, other): + def __rlshift__(self, other: Union["TaskMixin", Sequence["TaskMixin"]]): """ - Called for Task >> [Task] because list don't have - __lshift__ operators. + Called for Task << [Task] because list don't have __lshift__ operators. """ self.__rshift__(other) return self From 7fba960164014265cfddd891bbb28d75328c5573 Mon Sep 17 00:00:00 2001 From: Tomek Urbaszek Date: Tue, 15 Sep 2020 10:50:08 +0200 Subject: [PATCH 4/9] Add operator property to TaskMixin --- airflow/models/baseoperator.py | 35 +++++++++++++++------------------- airflow/models/taskmixin.py | 7 +++++++ airflow/models/xcom_arg.py | 14 ++++++-------- 3 files changed, 28 insertions(+), 28 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 6efa41477b59c..34de76aec6242 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1115,27 +1115,22 @@ def add_only_new(self, item_set: Set[str], item: str) -> None: else: item_set.add(item) + @property + def operator(self) -> "BaseOperator": + """Required by TaskMixin""" + return self + def _set_relatives(self, - task_or_task_list: Union['BaseOperator', Sequence['BaseOperator']], + task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]], upstream: bool = False) -> None: """Sets relatives for the task or task list.""" - from airflow.models.xcom_arg import XComArg - if isinstance(task_or_task_list, XComArg): - # otherwise we will start to iterate over xcomarg - # because of the "list" check below - # with current XComArg.__getitem__ implementation - task_list = [task_or_task_list.operator] - else: - try: - task_list = list(task_or_task_list) # type: ignore - except TypeError: - task_list = [task_or_task_list] # type: ignore + try: + task_list = len(task_or_task_list) # type: ignore + except TypeError: + task_list = [task_or_task_list] # type: ignore - task_list = [ - t.operator if isinstance(t, XComArg) else t - for t in task_list - ] + task_list: List["BaseOperator"] = [t.operator for t in task_list] for task in task_list: if not isinstance(task, BaseOperator): @@ -1174,17 +1169,17 @@ def _set_relatives(self, self.add_only_new(self._downstream_task_ids, task.task_id) task.add_only_new(task.get_direct_relative_ids(upstream=True), self.task_id) - def set_downstream(self, task_or_task_list: Union['BaseOperator', Sequence['BaseOperator']]) -> None: + def set_downstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]) -> None: """ Set a task or a task list to be directly downstream from the current - task. + task. Required by TaskMixin. """ self._set_relatives(task_or_task_list, upstream=False) - def set_upstream(self, task_or_task_list: Union['BaseOperator', Sequence['BaseOperator']]) -> None: + def set_upstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]) -> None: """ Set a task or a task list to be directly upstream from the current - task. + task. Required by TaskMixin. """ self._set_relatives(task_or_task_list, upstream=True) diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py index 50d4bf77ccbc3..face470efba9b 100644 --- a/airflow/models/taskmixin.py +++ b/airflow/models/taskmixin.py @@ -28,6 +28,13 @@ class TaskMixin: No type annotations due to cyclic imports. """ + @property + def operator(self): + """ + Returns underlying operator + """ + raise NotImplementedError() + @abstractmethod def set_upstream(self, other: Union["TaskMixin", Sequence["TaskMixin"]]): """ diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index 8ae3a6cbb53a5..3bf41ac7ec3f3 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Union, Sequence from airflow.exceptions import AirflowException from airflow.models.baseoperator import BaseOperator # pylint: disable=R0401 @@ -94,7 +94,7 @@ def __str__(self): @property def operator(self) -> BaseOperator: - """Returns operator of this XComArg""" + """Returns operator of this XComArg. Required by TaskMixin""" return self._operator @property @@ -102,17 +102,15 @@ def key(self) -> str: """Returns keys of this XComArg""" return self._key - def set_upstream(self, task_or_task_list: Union[BaseOperator, List[BaseOperator]]): + def set_upstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]): """ - Proxy to underlying operator set_upstream method + Proxy to underlying operator set_upstream method. Required by TaskMixin. """ self.operator.set_upstream(task_or_task_list) - def set_downstream( - self, task_or_task_list: Union[BaseOperator, List[BaseOperator]] - ): + def set_downstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]): """ - Proxy to underlying operator set_downstream method + Proxy to underlying operator set_downstream method. Required by TaskMixin. """ self.operator.set_downstream(task_or_task_list) From 163a95c25c4303fbd419cade5905c08bd8dc4ff1 Mon Sep 17 00:00:00 2001 From: Tomek Urbaszek Date: Tue, 15 Sep 2020 13:29:02 +0200 Subject: [PATCH 5/9] Fix types --- airflow/models/baseoperator.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 34de76aec6242..220ad78593ecb 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1126,11 +1126,13 @@ def _set_relatives(self, """Sets relatives for the task or task list.""" try: - task_list = len(task_or_task_list) # type: ignore + # Check if this is sequence + len(task_or_task_list) + task_like_object_list = task_or_task_list except TypeError: - task_list = [task_or_task_list] # type: ignore + task_like_object_list = [task_or_task_list] - task_list: List["BaseOperator"] = [t.operator for t in task_list] + task_list: List["BaseOperator"] = [t.operator for t in task_like_object_list] for task in task_list: if not isinstance(task, BaseOperator): From db38495ec3e4329ef64f00c085db16e204e65968 Mon Sep 17 00:00:00 2001 From: Tomek Urbaszek Date: Tue, 15 Sep 2020 16:36:34 +0200 Subject: [PATCH 6/9] Fix types --- airflow/models/baseoperator.py | 6 ++---- airflow/models/xcom_arg.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 220ad78593ecb..01651b98f8e31 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1125,11 +1125,9 @@ def _set_relatives(self, upstream: bool = False) -> None: """Sets relatives for the task or task list.""" - try: - # Check if this is sequence - len(task_or_task_list) + if isinstance(task_or_task_list, Sequence): task_like_object_list = task_or_task_list - except TypeError: + else: task_like_object_list = [task_or_task_list] task_list: List["BaseOperator"] = [t.operator for t in task_like_object_list] diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index 3bf41ac7ec3f3..3a1773bb8c8db 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Union, Sequence +from typing import Any, Dict, Sequence, Union from airflow.exceptions import AirflowException from airflow.models.baseoperator import BaseOperator # pylint: disable=R0401 From 47944c42f544685fea5cee066b596b6318e02102 Mon Sep 17 00:00:00 2001 From: Tomek Urbaszek Date: Wed, 16 Sep 2020 10:39:57 +0200 Subject: [PATCH 7/9] Rename operator to roots --- airflow/models/baseoperator.py | 20 ++++++++++++-------- airflow/models/taskmixin.py | 6 ++---- airflow/models/xcom_arg.py | 9 +++++++-- 3 files changed, 21 insertions(+), 14 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 01651b98f8e31..4058f055b54e7 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1116,13 +1116,15 @@ def add_only_new(self, item_set: Set[str], item: str) -> None: item_set.add(item) @property - def operator(self) -> "BaseOperator": + def roots(self) -> List["BaseOperator"]: """Required by TaskMixin""" - return self + return [self] - def _set_relatives(self, - task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]], - upstream: bool = False) -> None: + def _set_relatives( + self, + task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]], + upstream: bool = False, + ) -> None: """Sets relatives for the task or task list.""" if isinstance(task_or_task_list, Sequence): @@ -1130,7 +1132,9 @@ def _set_relatives(self, else: task_like_object_list = [task_or_task_list] - task_list: List["BaseOperator"] = [t.operator for t in task_like_object_list] + task_list: List["BaseOperator"] = [] + for task_object in task_like_object_list: + task_list.extend(task_object.roots) for task in task_list: if not isinstance(task, BaseOperator): @@ -1141,8 +1145,8 @@ def _set_relatives(self, # relationships can only be set if the tasks share a single DAG. Tasks # without a DAG are assigned to that DAG. dags = { - task._dag.dag_id: task._dag # type: ignore # pylint: disable=protected-access - for task in [self] + task_list if task.has_dag()} + task._dag.dag_id: task._dag # type: ignore # pylint: disable=protected-access,no-member + for task in self.roots + task_list if task.has_dag()} # pylint: disable=no-member if len(dags) > 1: raise AirflowException( diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py index face470efba9b..a3d42242ca4f7 100644 --- a/airflow/models/taskmixin.py +++ b/airflow/models/taskmixin.py @@ -29,10 +29,8 @@ class TaskMixin: """ @property - def operator(self): - """ - Returns underlying operator - """ + def roots(self): + """Should return list of root operator List[BaseOperator]""" raise NotImplementedError() @abstractmethod diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index 3a1773bb8c8db..0f647bf4f86f2 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, Sequence, Union +from typing import Any, Dict, List, Sequence, Union from airflow.exceptions import AirflowException from airflow.models.baseoperator import BaseOperator # pylint: disable=R0401 @@ -94,9 +94,14 @@ def __str__(self): @property def operator(self) -> BaseOperator: - """Returns operator of this XComArg. Required by TaskMixin""" + """Returns operator of this XComArg.""" return self._operator + @property + def roots(self) -> List[BaseOperator]: + """Required by TaskMixin""" + return [self._operator] + @property def key(self) -> str: """Returns keys of this XComArg""" From 06a408057a019399994a8d3c5742ef4866167cc8 Mon Sep 17 00:00:00 2001 From: Tomek Urbaszek Date: Wed, 16 Sep 2020 11:04:16 +0200 Subject: [PATCH 8/9] Typing in example --- airflow/example_dags/example_xcomargs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/example_dags/example_xcomargs.py b/airflow/example_dags/example_xcomargs.py index 24df48ed8f9a6..9165c97ead183 100644 --- a/airflow/example_dags/example_xcomargs.py +++ b/airflow/example_dags/example_xcomargs.py @@ -66,4 +66,4 @@ def print_value(value): xcom_args_a = print_value("first!") # type: ignore xcom_args_b = print_value("second!") # type: ignore - bash_op1 >> xcom_args_a >> xcom_args_b >> bash_op2 + bash_op1 >> xcom_args_a >> xcom_args_b >> bash_op2 # type: ignore From c5ee4a86fe4f38232dbb5c4f8aa5842620f04e0a Mon Sep 17 00:00:00 2001 From: Tomek Urbaszek Date: Wed, 16 Sep 2020 12:33:52 +0200 Subject: [PATCH 9/9] Typing in example --- airflow/example_dags/example_xcomargs.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/airflow/example_dags/example_xcomargs.py b/airflow/example_dags/example_xcomargs.py index 9165c97ead183..7ec89a7617bf5 100644 --- a/airflow/example_dags/example_xcomargs.py +++ b/airflow/example_dags/example_xcomargs.py @@ -32,7 +32,7 @@ def generate_value(): return "Bring me a shrubbery!" -@task +@task() def print_value(value): """Dummy function""" ctx = get_current_context() @@ -63,7 +63,7 @@ def print_value(value): ) as dag2: bash_op1 = BashOperator(task_id="c", bash_command="echo c") bash_op2 = BashOperator(task_id="d", bash_command="echo c") - xcom_args_a = print_value("first!") # type: ignore - xcom_args_b = print_value("second!") # type: ignore + xcom_args_a = print_value("first!") + xcom_args_b = print_value("second!") - bash_op1 >> xcom_args_a >> xcom_args_b >> bash_op2 # type: ignore + bash_op1 >> xcom_args_a >> xcom_args_b >> bash_op2