Skip to content

Commit

Permalink
Resolve upstream tasks when template field is XComArg
Browse files Browse the repository at this point in the history
closes: #8054
  • Loading branch information
turbaszek committed May 10, 2020
1 parent 280f1f0 commit d40c9d9
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 2 deletions.
48 changes: 46 additions & 2 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,25 @@
ScheduleInterval = Union[str, timedelta, relativedelta]


class BaseOperatorMeta(type):
"""
Base metaclass of BaseOperator.
"""
def __call__(cls, *args, **kwargs):
"""
Called when you call BaseOperator(). In this way we are able to perform an action
after initializing an operator no matter where the ``super().__init__`` is called
(before or after assign of new attributes in a custom operator).
"""
obj = type.__call__(cls, *args, **kwargs)
# Set upstream task defined by XComArgs passed to template fields of an operator
obj._set_xcomargs_dependencies() # pylint: disable=protected-access
return obj


# pylint: disable=too-many-instance-attributes,too-many-public-methods
@functools.total_ordering
class BaseOperator(Operator, LoggingMixin):
class BaseOperator(Operator, LoggingMixin, metaclass=BaseOperatorMeta):
"""
Abstract base class for all operators. Since operators create objects that
become nodes in the dag, BaseOperator contains many recursive methods for
Expand Down Expand Up @@ -244,6 +260,7 @@ class derived from this one results in the creation of a task object,
result
:type do_xcom_push: bool
"""

# For derived classes to define which fields will get jinjaified
template_fields: Iterable[str] = []
# Defines which files extensions to look for in the templated fields
Expand Down Expand Up @@ -634,6 +651,33 @@ def deps(self) -> Set[BaseTIDep]:
NotPreviouslySkippedDep(),
}

def _set_xcomargs_dependencies(self) -> None:
"""
Resolves upstream dependencies of a task. In this way passing an ``XComArg`
as value for a template field will result in creating upstream relation between
two tasks.
**Example**: ::
with DAG(...):
generate_content = GenerateContentOperator(task_id="generate_content")
send_email = EmailOperator(..., html_content=generate_content.output)
# This is equivalent to
with DAG(...):
generate_content = GenerateContentOperator(task_id="generate_content")
send_email = EmailOperator(
..., html_content="{{ task_instance.xcom_pull('generate_content') }}"
)
generate_content >> send_email
"""
from airflow.models.xcom_arg import XComArg
for field in self.template_fields:
arg = getattr(self, field)
if isinstance(arg, XComArg):
self.set_upstream(arg.operator)

@property
def priority_weight_total(self) -> int:
"""
Expand Down Expand Up @@ -1141,7 +1185,7 @@ def set_upstream(self, task_or_task_list: Union['BaseOperator', List['BaseOperat

@property
def output(self):
"""Returns default XComArg for the operator"""
"""Returns reference to XCom pushed by current operator"""
from airflow.models.xcom_arg import XComArg
return XComArg(operator=self)

Expand Down
28 changes: 28 additions & 0 deletions tests/models/test_baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from airflow.models import DAG
from airflow.models.baseoperator import chain, cross_downstream
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.decorators import apply_defaults
from tests.models import DEFAULT_DATE
from tests.test_utils.mock_operators import MockNamedTuple, MockOperator

Expand Down Expand Up @@ -262,6 +263,33 @@ def test_email_on_actions(self):
assert test_task.email_on_retry is False
assert test_task.email_on_failure is True

def test_upstream_is_set_when_template_field_is_xcomarg(self):
class CustomOpSuperBefore(DummyOperator):
template_fields = ("field",)

@apply_defaults
def __init__(self, field, *args, **kwargs):
super().__init__(*args, **kwargs)
self.field = field

class CustomOpSuperAfter(DummyOperator):
template_fields = ("field",)

@apply_defaults
def __init__(self, field, *args, **kwargs):
self.field = field
super().__init__(*args, **kwargs)

with DAG("test_dag", default_args={"start_date": datetime.today()}):
op1 = DummyOperator(task_id="op1")
op2 = CustomOpSuperBefore(task_id="op2", field=op1.output)
op3 = CustomOpSuperAfter(task_id="op3", field=op1.output)

assert op1 in op2.upstream_list
assert op1 in op3.upstream_list
assert op2 in op1.downstream_list
assert op3 in op1.downstream_list


class TestBaseOperatorMethods(unittest.TestCase):
def test_cross_downstream(self):
Expand Down

0 comments on commit d40c9d9

Please sign in to comment.