Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 13 additions & 14 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,23 +973,10 @@ def set_xcomargs_dependencies(self) -> None:
"""
from airflow.models.xcom_arg import XComArg

def apply_set_upstream(arg: Any):
if isinstance(arg, XComArg):
self.set_upstream(arg.operator)
elif isinstance(arg, (tuple, set, list)):
for elem in arg:
apply_set_upstream(elem)
elif isinstance(arg, dict):
for elem in arg.values():
apply_set_upstream(elem)
elif hasattr(arg, "template_fields"):
for elem in arg.template_fields:
apply_set_upstream(elem)

for field in self.template_fields:
if hasattr(self, field):
arg = getattr(self, field)
apply_set_upstream(arg)
XComArg.apply_upstream_relationship(self, arg)

@property
def priority_weight_total(self) -> int:
Expand Down Expand Up @@ -1734,6 +1721,8 @@ def __repr__(self) -> str:
params: Union[ParamsDict, dict] = attr.ib(factory=ParamsDict)
template_fields: Iterable[str] = attr.ib()

subdag: None = attr.ib(init=False)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this used?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"badly" as part of the serialization code

if serializable_task.subdag is not None:
setattr(serializable_task.subdag, 'parent_dag', dag)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be fixed/not needed by your #20945?


@_is_dummy.default
def _is_dummy_default(self):
from airflow.operators.dummy import DummyOperator
Expand Down Expand Up @@ -1795,12 +1784,17 @@ def from_decorator(
return operator

def __attrs_post_init__(self):
    """
    Finish wiring up the operator once attrs has populated its fields.

    Registers the operator with its task group (re-deriving ``task_id`` through the group) and
    with its DAG when either is set, then makes every operator referenced by an XComArg in
    ``mapped_kwargs`` upstream of this one.
    """
    # Imported locally, as in the rest of this module, to avoid a circular import.
    from airflow.models.xcom_arg import XComArg

    group = self.task_group
    if group:
        # The group decides the final task_id before the operator is registered with it.
        self.task_id = group.child_id(self.task_id)
        group.add(self)

    owning_dag = self.dag
    if owning_dag:
        owning_dag.add_task(self)

    for mapped_value in self.mapped_kwargs.values():
        XComArg.apply_upstream_relationship(self, mapped_value)

@task_type.default
def _default_task_type(self):
# Can be a string if we are de-serialized
Expand Down Expand Up @@ -1829,8 +1823,12 @@ def map(self, **kwargs) -> "MappedOperator":

:return: ``self`` for easier method chaining
"""
from airflow.models.xcom_arg import XComArg

if self.mapped_kwargs:
raise RuntimeError("Already a mapped task")
for arg in kwargs.values():
XComArg.apply_upstream_relationship(self, arg)
return attr.evolve(self, mapped_kwargs=kwargs)

@property
Expand Down Expand Up @@ -1865,6 +1863,7 @@ def get_serialized_fields(cls):
cls.__serialized_fields = frozenset(
fields_dict.keys()
- {
'dag',
'deps',
'inherits_from_dummy_operator',
'operator_extra_links',
Expand Down
44 changes: 28 additions & 16 deletions airflow/models/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union

from airflow.exceptions import AirflowException
from airflow.models.baseoperator import BaseOperator, MappedOperator
from airflow.models.taskmixin import DAGNode, DependencyMixin
from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.utils.context import Context
from airflow.utils.edgemodifier import EdgeModifier

if TYPE_CHECKING:
from airflow.models.baseoperator import BaseOperator, MappedOperator


class XComArg(DependencyMixin):
"""
Expand Down Expand Up @@ -59,9 +61,9 @@ class XComArg(DependencyMixin):
:type key: str
"""

def __init__(self, operator: Union[BaseOperator, MappedOperator], key: str = XCOM_RETURN_KEY):
self._operator = operator
self._key = key
def __init__(self, operator: "Union[BaseOperator, MappedOperator]", key: str = XCOM_RETURN_KEY):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def __init__(self, operator: "Union[BaseOperator, MappedOperator]", key: str = XCOM_RETURN_KEY):
def __init__(self, operator: Union[BaseOperator, MappedOperator], key: str = XCOM_RETURN_KEY):

Any reason for this change?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid unnecessary imports and reduce the chance of import cycles.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After #20945 we can probably change this to Operator, which would not cause import cycles.

self.operator = operator
self.key = key

def __eq__(self, other):
return self.operator == other.operator and self.key == other.key
Expand Down Expand Up @@ -92,25 +94,15 @@ def __str__(self):
xcom_pull = f"{{{{ task_instance.xcom_pull({xcom_pull_kwargs}) }}}}"
return xcom_pull

@property
def operator(self) -> Union[BaseOperator, MappedOperator]:
"""Returns operator of this XComArg."""
return self._operator

@property
def roots(self) -> List[DAGNode]:
"""Required by TaskMixin"""
return [self._operator]
return [self.operator]

@property
def leaves(self) -> List[DAGNode]:
"""Required by TaskMixin"""
return [self._operator]

@property
def key(self) -> str:
"""Returns keys of this XComArg"""
return self._key
return [self.operator]

def set_upstream(
self,
Expand Down Expand Up @@ -144,3 +136,23 @@ def resolve(self, context: Context) -> Any:
resolved_value = resolved_value[0]

return resolved_value

@staticmethod
def apply_upstream_relationship(op: "Union[BaseOperator, MappedOperator]", arg: Any):
    """
    Set dependency for XComArgs.

    This looks for XComArg objects in ``arg`` "deeply" (looking inside lists, dicts and classes decorated
    with "template_fields") and sets the relationship to ``op`` on any found.

    :param op: the operator consuming ``arg``; the operator of every XComArg found in ``arg`` is set
        as upstream of it.
    :param arg: the value to search. Anything that is not an XComArg, a tuple/set/list, a dict or an
        object exposing ``template_fields`` is ignored.
    """
    if isinstance(arg, XComArg):
        op.set_upstream(arg.operator)
    elif isinstance(arg, (tuple, set, list)):
        for elem in arg:
            XComArg.apply_upstream_relationship(op, elem)
    elif isinstance(arg, dict):
        for elem in arg.values():
            XComArg.apply_upstream_relationship(op, elem)
    elif hasattr(arg, "template_fields"):
        # ``template_fields`` lists attribute *names*. Recurse into the attribute values;
        # recursing on the name strings themselves would never find an XComArg.
        for field_name in arg.template_fields:
            if hasattr(arg, field_name):
                XComArg.apply_upstream_relationship(op, getattr(arg, field_name))
1 change: 1 addition & 0 deletions airflow/serialization/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,4 @@ class DagAttributeTypes(str, Enum):
TASK_GROUP = 'taskgroup'
EDGE_INFO = 'edgeinfo'
PARAM = 'param'
XCOM_REF = 'xcomref'
37 changes: 36 additions & 1 deletion airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import logging
from dataclasses import dataclass
from inspect import Parameter, signature
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Type, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Set, Type, Union

import cattr
import pendulum
Expand All @@ -36,6 +36,7 @@
from airflow.models.dag import DAG, create_timetable
from airflow.models.param import Param, ParamsDict
from airflow.models.taskmixin import DAGNode
from airflow.models.xcom_arg import XComArg
from airflow.providers_manager import ProvidersManager
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.helpers import serialize_template_field
Expand Down Expand Up @@ -168,6 +169,18 @@ def _decode_timetable(var: Dict[str, Any]) -> Timetable:
return timetable_class.deserialize(var[Encoding.VAR])


class _XcomRef(NamedTuple):
    """
    Used to store info needed to create XComArg when deserializing MappedOperator.

    We can't turn it into an XComArg until we've loaded _all_ the tasks, so when deserializing an operator
    we need to create _something_, and then post-process it in deserialize_dag.
    """

    # ID of the task whose XCom is referenced.
    task_id: str
    # XCom key to read from that task.
    key: str


class BaseSerialization:
"""BaseSerialization provides utils for serialization."""

Expand Down Expand Up @@ -331,6 +344,8 @@ def _serialize(cls, var: Any) -> Any: # Unfortunately there is no support for r
return SerializedTaskGroup.serialize_task_group(var)
elif isinstance(var, Param):
return cls._encode(cls._serialize_param(var), type_=DAT.PARAM)
elif isinstance(var, XComArg):
return cls._encode(cls._serialize_xcomarg(var), type_=DAT.XCOM_REF)
else:
log.debug('Cast type %s to str in serialization.', type(var))
return str(var)
Expand Down Expand Up @@ -374,6 +389,8 @@ def _deserialize(cls, encoded_var: Any) -> Any:
return tuple(cls._deserialize(v) for v in var)
elif type_ == DAT.PARAM:
return cls._deserialize_param(var)
elif type_ == DAT.XCOM_REF:
return cls._deserialize_xcomref(var)
else:
raise TypeError(f'Invalid type {type_!s} in deserialization.')

Expand Down Expand Up @@ -476,6 +493,14 @@ def _deserialize_params_dict(cls, encoded_params: Dict) -> ParamsDict:

return ParamsDict(op_params)

@classmethod
def _serialize_xcomarg(cls, arg: XComArg) -> dict:
    """Reduce an XComArg to its XCom key and the ``task_id`` of its operator."""
    xcom_key = arg.key
    producer_task_id = arg.operator.task_id
    return {"key": xcom_key, "task_id": producer_task_id}

@classmethod
def _deserialize_xcomref(cls, encoded: dict) -> _XcomRef:
    """Build the placeholder ``_XcomRef`` from the ``key`` and ``task_id`` entries of ``encoded``."""
    xcom_key = encoded['key']
    producer_task_id = encoded['task_id']
    return _XcomRef(task_id=producer_task_id, key=xcom_key)


class DependencyDetector:
"""Detects dependencies between DAGs."""
Expand Down Expand Up @@ -687,6 +712,8 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> Union[BaseOperator,
v = cls._deserialize_deps(v)
elif k == "params":
v = cls._deserialize_params_dict(v)
elif k in ("mapped_kwargs", "partial_kwargs"):
v = {arg: cls._deserialize(value) for arg, value in v.items()}
elif k in cls._decorated_fields or k not in op.get_serialized_fields():
v = cls._deserialize(v)
# else use v as it is
Expand Down Expand Up @@ -970,6 +997,14 @@ def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG':
if serializable_task.subdag is not None:
setattr(serializable_task.subdag, 'parent_dag', dag)

if isinstance(task, MappedOperator):
for d in (task.mapped_kwargs, task.partial_kwargs):
for k, v in d.items():
if not isinstance(v, _XcomRef):
continue

d[k] = XComArg(operator=dag.get_task(v.task_id), key=v.key)

for task_id in serializable_task.downstream_task_ids:
# Bypass set_upstream etc here - it does more than we want
dag.task_dict[task_id].upstream_task_ids.add(serializable_task.task_id)
Expand Down
15 changes: 15 additions & 0 deletions tests/models/test_baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,21 @@ def test_map_unknown_arg_raises():
BaseOperator(task_id='a').map(file=[1, 2, {'a': 'b'}])


def test_map_xcom_arg():
"""Test that dependencies are correct when mapping with an XComArg"""
from airflow.models.xcom_arg import XComArg

with DAG("test-dag", start_date=DEFAULT_DATE):
task1 = BaseOperator(task_id="op1")
xcomarg = XComArg(task1, "test_key")
mapped = MockOperator(task_id='task_2').map(arg2=xcomarg)
finish = MockOperator(task_id="finish")

mapped >> finish

assert task1.downstream_list == [mapped]


def test_partial_on_instance() -> None:
"""`.partial` on an instance should fail -- it's only designed to be called on classes"""
with pytest.raises(TypeError):
Expand Down
55 changes: 51 additions & 4 deletions tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from airflow.hooks.base import BaseHook
from airflow.kubernetes.pod_generator import PodGenerator
from airflow.models import DAG, Connection, DagBag
from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
from airflow.models.baseoperator import BaseOperator, BaseOperatorLink, MappedOperator
from airflow.models.param import Param, ParamsDict
from airflow.models.xcom import XCom
from airflow.operators.bash import BashOperator
Expand All @@ -52,7 +52,7 @@
from airflow.utils import timezone
from airflow.utils.context import Context
from airflow.utils.task_group import TaskGroup
from tests.test_utils.mock_operators import CustomOperator, CustomOpLink, GoogleLink
from tests.test_utils.mock_operators import CustomOperator, CustomOpLink, GoogleLink, MockOperator
from tests.test_utils.timetables import CustomSerializationTimetable, cron_timetable, delta_timetable

executor_config_pod = k8s.V1Pod(
Expand Down Expand Up @@ -1573,7 +1573,10 @@ def mock__import__(name, globals_=None, locals_=None, fromlist=(), level=0):


def test_mapped_operator_serde():
real_op = BashOperator.partial(task_id='a').map(bash_command=[1, 2, {'a': 'b'}])
literal = [1, 2, {'a': 'b'}]
real_op = BashOperator.partial(task_id='a', executor_config={'dict': {'sub': 'value'}}).map(
bash_command=literal
)

serialized = SerializedBaseOperator._serialize(real_op)

Expand All @@ -1590,14 +1593,58 @@ def test_mapped_operator_serde():
{"__type": "dict", "__var": {'a': 'b'}},
]
},
'partial_kwargs': {},
'partial_kwargs': {
'executor_config': {
'__type': 'dict',
'__var': {
'dict': {"__type": "dict", "__var": {'sub': 'value'}},
},
},
},
'task_id': 'a',
'template_fields': ['bash_command', 'env'],
}

op = SerializedBaseOperator.deserialize_operator(serialized)
assert isinstance(op, MappedOperator)

assert op.operator_class == "airflow.operators.bash.BashOperator"
assert op.mapped_kwargs['bash_command'] == literal
assert op.partial_kwargs['executor_config'] == {'dict': {'sub': 'value'}}


def test_mapped_operator_xcomarg_serde():
from airflow.models.xcom_arg import XComArg

with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
task1 = BaseOperator(task_id="op1")
xcomarg = XComArg(task1, "test_key")
mapped = MockOperator(task_id='task_2').map(arg2=xcomarg)

serialized = SerializedBaseOperator._serialize(mapped)
assert serialized == {
'_is_dummy': False,
'_is_mapped': True,
'_task_module': 'tests.test_utils.mock_operators',
'_task_type': 'MockOperator',
'downstream_task_ids': [],
'mapped_kwargs': {'arg2': {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'test_key'}}},
'partial_kwargs': {},
'task_id': 'task_2',
'template_fields': ['arg1', 'arg2'],
}

op = SerializedBaseOperator.deserialize_operator(serialized)

arg = op.mapped_kwargs['arg2']
assert arg.task_id == 'op1'
assert arg.key == 'test_key'

serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))

xcom_arg = serialized_dag.task_dict['task_2'].mapped_kwargs['arg2']
assert isinstance(xcom_arg, XComArg)
assert xcom_arg.operator is serialized_dag.task_dict['op1']


def test_mapped_task_group_serde():
Expand Down