diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index d15d10c6ed7ac..e2dc6e27f404b 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -60,9 +60,25 @@ ScheduleInterval = Union[str, timedelta, relativedelta] +class BaseOperatorMeta(type): + """ + Base metaclass of BaseOperator. + """ + def __call__(cls, *args, **kwargs): + """ + Called when you call BaseOperator(). In this way we are able to perform an action + after initializing an operator no matter where the ``super().__init__`` is called + (before or after assign of new attributes in a custom operator). + """ + obj = type.__call__(cls, *args, **kwargs) + # Set upstream task defined by XComArgs passed to template fields of an operator + obj._set_xcomargs_dependencies() # pylint: disable=protected-access + return obj + + # pylint: disable=too-many-instance-attributes,too-many-public-methods @functools.total_ordering -class BaseOperator(Operator, LoggingMixin): +class BaseOperator(Operator, LoggingMixin, metaclass=BaseOperatorMeta): """ Abstract base class for all operators. Since operators create objects that become nodes in the dag, BaseOperator contains many recursive methods for @@ -244,6 +260,7 @@ class derived from this one results in the creation of a task object, result :type do_xcom_push: bool """ + # For derived classes to define which fields will get jinjaified template_fields: Iterable[str] = [] # Defines which files extensions to look for in the templated fields @@ -634,6 +651,33 @@ def deps(self) -> Set[BaseTIDep]: NotPreviouslySkippedDep(), } + def _set_xcomargs_dependencies(self) -> None: + """ + Resolves upstream dependencies of a task. In this way passing an ``XComArg` + as value for a template field will result in creating upstream relation between + two tasks. + + **Example**: :: + + with DAG(...): + generate_content = GenerateContentOperator(task_id="generate_content") + send_email = EmailOperator(..., html_content=generate_content.output) + + # This is equivalent to + with DAG(...): + generate_content = GenerateContentOperator(task_id="generate_content") + send_email = EmailOperator( + ..., html_content="{{ task_instance.xcom_pull('generate_content') }}" + ) + generate_content >> send_email + + """ + from airflow.models.xcom_arg import XComArg + for field in self.template_fields: + arg = getattr(self, field) + if isinstance(arg, XComArg): + self.set_upstream(arg.operator) + @property def priority_weight_total(self) -> int: """ @@ -1141,7 +1185,7 @@ def set_upstream(self, task_or_task_list: Union['BaseOperator', List['BaseOperat @property def output(self): - """Returns default XComArg for the operator""" + """Returns reference to XCom pushed by current operator""" from airflow.models.xcom_arg import XComArg return XComArg(operator=self) diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index c9a07f94542b8..273052c71fb0c 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -29,6 +29,7 @@ from airflow.models import DAG from airflow.models.baseoperator import chain, cross_downstream from airflow.operators.dummy_operator import DummyOperator +from airflow.utils.decorators import apply_defaults from tests.models import DEFAULT_DATE from tests.test_utils.mock_operators import MockNamedTuple, MockOperator @@ -262,6 +263,33 @@ def test_email_on_actions(self): assert test_task.email_on_retry is False assert test_task.email_on_failure is True + def test_upstream_is_set_when_template_field_is_xcomarg(self): + class CustomOpSuperBefore(DummyOperator): + template_fields = ("field",) + + @apply_defaults + def __init__(self, field, *args, **kwargs): + super().__init__(*args, **kwargs) + self.field = field + + class CustomOpSuperAfter(DummyOperator): + template_fields = ("field",) + + @apply_defaults + def __init__(self, field, *args, **kwargs): + self.field = field + super().__init__(*args, **kwargs) + + with DAG("test_dag", default_args={"start_date": datetime.today()}): + op1 = DummyOperator(task_id="op1") + op2 = CustomOpSuperBefore(task_id="op2", field=op1.output) + op3 = CustomOpSuperAfter(task_id="op3", field=op1.output) + + assert op1 in op2.upstream_list + assert op1 in op3.upstream_list + assert op2 in op1.downstream_list + assert op3 in op1.downstream_list + class TestBaseOperatorMethods(unittest.TestCase): def test_cross_downstream(self):