Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce TaskMixin #10930

Merged
merged 9 commits into from
Sep 16, 2020
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 24 additions & 56 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from airflow.models.base import Operator
from airflow.models.pool import Pool
from airflow.models.taskinstance import Context, TaskInstance, clear_task_instances
from airflow.models.taskmixin import TaskMixin
from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
Expand Down Expand Up @@ -84,7 +85,7 @@ def __call__(cls, *args, **kwargs):

# pylint: disable=too-many-instance-attributes,too-many-public-methods
@functools.total_ordering
class BaseOperator(Operator, LoggingMixin, metaclass=BaseOperatorMeta):
class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta):
"""
Abstract base class for all operators. Since operators create objects that
become nodes in the dag, BaseOperator contains many recursive methods for
Expand Down Expand Up @@ -491,38 +492,6 @@ def __hash__(self):
hash_components.append(repr(val))
return hash(tuple(hash_components))

# Composing Operators -----------------------------------------------

def __rshift__(self, other):
"""
Implements Self >> Other == self.set_downstream(other)
"""
self.set_downstream(other)
return other

def __lshift__(self, other):
"""
Implements Self << Other == self.set_upstream(other)
"""
self.set_upstream(other)
return other

def __rrshift__(self, other):
"""
Called for Operator >> [Operator] because list don't have
__rshift__ operators.
"""
self.__lshift__(other)
return self

def __rlshift__(self, other):
"""
Called for Operator << [Operator] because list don't have
__lshift__ operators.
"""
self.__rshift__(other)
return self

# including lineage information
def __or__(self, other):
"""
Expand Down Expand Up @@ -1146,27 +1115,26 @@ def add_only_new(self, item_set: Set[str], item: str) -> None:
else:
item_set.add(item)

def _set_relatives(self,
task_or_task_list: Union['BaseOperator', Sequence['BaseOperator']],
upstream: bool = False) -> None:
@property
def roots(self) -> List["BaseOperator"]:
"""Required by TaskMixin"""
return [self]

def _set_relatives(
self,
task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]],
upstream: bool = False,
) -> None:
"""Sets relatives for the task or task list."""
from airflow.models.xcom_arg import XComArg

if isinstance(task_or_task_list, XComArg):
# otherwise we will start to iterate over xcomarg
# because of the "list" check below
# with current XComArg.__getitem__ implementation
task_list = [task_or_task_list.operator]
if isinstance(task_or_task_list, Sequence):
task_like_object_list = task_or_task_list
else:
try:
task_list = list(task_or_task_list) # type: ignore
except TypeError:
task_list = [task_or_task_list] # type: ignore
task_like_object_list = [task_or_task_list]

task_list = [
t.operator if isinstance(t, XComArg) else t
for t in task_list
]
task_list: List["BaseOperator"] = []
for task_object in task_like_object_list:
task_list.extend(task_object.roots)

for task in task_list:
if not isinstance(task, BaseOperator):
Expand All @@ -1177,8 +1145,8 @@ def _set_relatives(self,
# relationships can only be set if the tasks share a single DAG. Tasks
# without a DAG are assigned to that DAG.
dags = {
task._dag.dag_id: task._dag # type: ignore # pylint: disable=protected-access
for task in [self] + task_list if task.has_dag()}
task._dag.dag_id: task._dag # type: ignore # pylint: disable=protected-access,no-member
for task in self.roots + task_list if task.has_dag()} # pylint: disable=no-member

if len(dags) > 1:
raise AirflowException(
Expand All @@ -1205,17 +1173,17 @@ def _set_relatives(self,
self.add_only_new(self._downstream_task_ids, task.task_id)
task.add_only_new(task.get_direct_relative_ids(upstream=True), self.task_id)

def set_downstream(self, task_or_task_list: Union['BaseOperator', Sequence['BaseOperator']]) -> None:
def set_downstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]) -> None:
"""
Set a task or a task list to be directly downstream from the current
task.
task. Required by TaskMixin.
"""
self._set_relatives(task_or_task_list, upstream=False)

def set_upstream(self, task_or_task_list: Union['BaseOperator', Sequence['BaseOperator']]) -> None:
def set_upstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]) -> None:
"""
Set a task or a task list to be directly upstream from the current
task.
task. Required by TaskMixin.
"""
self._set_relatives(task_or_task_list, upstream=True)

Expand Down
76 changes: 76 additions & 0 deletions airflow/models/taskmixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from abc import abstractmethod
from typing import Sequence, Union


class TaskMixin:
"""
Mixing implementing common chain methods like >> and <<.

In the following functions we use:
Task = Union[BaseOperator, XComArg]
No type annotations due to cyclic imports.
"""

@property
def roots(self):
"""Should return list of root operator List[BaseOperator]"""
raise NotImplementedError()

@abstractmethod
def set_upstream(self, other: Union["TaskMixin", Sequence["TaskMixin"]]):
"""
Set a task or a task list to be directly upstream from the current task.
"""
raise NotImplementedError()

@abstractmethod
def set_downstream(self, other: Union["TaskMixin", Sequence["TaskMixin"]]):
"""
Set a task or a task list to be directly downstream from the current task.
"""
raise NotImplementedError()

def __lshift__(self, other: Union["TaskMixin", Sequence["TaskMixin"]]):
"""
Implements Task << Task
"""
self.set_upstream(other)
return other

def __rshift__(self, other: Union["TaskMixin", Sequence["TaskMixin"]]):
"""
Implements Task >> Task
"""
self.set_downstream(other)
return other

def __rrshift__(self, other: Union["TaskMixin", Sequence["TaskMixin"]]):
"""
Called for Task >> [Task] because list don't have __rshift__ operators.
"""
self.__lshift__(other)
return self

def __rlshift__(self, other: Union["TaskMixin", Sequence["TaskMixin"]]):
"""
Called for Task << [Task] because list don't have __lshift__ operators.
"""
self.__rshift__(other)
return self
52 changes: 13 additions & 39 deletions airflow/models/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
# specific language governing permissions and limitations
# under the License.

from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Sequence, Union

from airflow.exceptions import AirflowException
from airflow.models.baseoperator import BaseOperator # pylint: disable=R0401
from airflow.models.taskmixin import TaskMixin
from airflow.models.xcom import XCOM_RETURN_KEY


class XComArg:
class XComArg(TaskMixin):
"""
Class that represents a XCom push from a previous operator.
Defaults to "return_value" as only key.
Expand Down Expand Up @@ -65,36 +66,6 @@ def __eq__(self, other):
return (self.operator == other.operator
and self.key == other.key)

def __lshift__(self, other):
"""
Implements XComArg << op
"""
self.set_upstream(other)
return other

def __rshift__(self, other):
"""
Implements XComArg >> op
"""
self.set_downstream(other)
return other

def __rrshift__(self, other):
"""
Called for XComArg >> [XComArg] because list don't have
__rshift__ operators.
"""
self.__lshift__(other)
return self

def __rlshift__(self, other):
"""
Called for XComArg >> [XComArg] because list don't have
__lshift__ operators.
"""
self.__rshift__(other)
return self

def __getitem__(self, item):
"""
Implements xcomresult['some_result_key']
Expand Down Expand Up @@ -123,25 +94,28 @@ def __str__(self):

@property
def operator(self) -> BaseOperator:
"""Returns operator of this XComArg"""
"""Returns operator of this XComArg."""
return self._operator

@property
def roots(self) -> List[BaseOperator]:
"""Required by TaskMixin"""
return [self._operator]

@property
def key(self) -> str:
"""Returns keys of this XComArg"""
return self._key

def set_upstream(self, task_or_task_list: Union[BaseOperator, List[BaseOperator]]):
def set_upstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]):
"""
Proxy to underlying operator set_upstream method
Proxy to underlying operator set_upstream method. Required by TaskMixin.
"""
self.operator.set_upstream(task_or_task_list)

def set_downstream(
self, task_or_task_list: Union[BaseOperator, List[BaseOperator]]
):
def set_downstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]):
"""
Proxy to underlying operator set_downstream method
Proxy to underlying operator set_downstream method. Required by TaskMixin.
"""
self.operator.set_downstream(task_or_task_list)

Expand Down
2 changes: 1 addition & 1 deletion pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ max-branches=22
max-locals=24

# Maximum number of parents for a class (see R0901).
max-parents=7
max-parents=8

# Maximum number of public methods for a class (see R0904).
# BasPH: choose 27 because this was 50% of the sorted list of 30 number of public methods above 20 (Pylint default)
Expand Down