From 2a54e634df3b704433e4b8e800e27cd4c04e3d02 Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Thu, 16 Dec 2021 17:15:44 +0000
Subject: [PATCH 01/15] Basic serialization support for MappedOperators

---
 airflow/models/baseoperator.py                | 82 +++++++++++++++++--
 airflow/serialization/schema.json             |  8 +-
 airflow/serialization/serialized_objects.py   | 38 +++++++--
 tests/serialization/test_dag_serialization.py | 28 +++++++
 4 files changed, 143 insertions(+), 13 deletions(-)

diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index a3a2b48faa117..9c6cbad2d0b28 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -209,6 +209,7 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
            result = func(self, **kwargs, default_args=default_args)

            # Store the args passed to init -- we need them to support task.map serialization!
+           kwargs.pop('task_id', None)
            self._BaseOperator__init_kwargs.update(kwargs)  # type: ignore

            # Here we set upstream task defined by XComArgs passed to template fields of the operator
@@ -1673,7 +1674,7 @@ def _walk_group(group: TaskGroup) -> Iterable[Tuple[str, DAGNode]]:

 def _validate_kwarg_names_for_mapping(cls: Type[BaseOperator], func_name: str, value: Dict[str, Any]):
-    if isinstance(str, cls):
+    if isinstance(cls, str):
         # Serialized version -- would have been validated at parse time
         return

@@ -1702,11 +1703,31 @@ def _validate_kwarg_names_for_mapping(cls: Type[BaseOperator], func_name: str, v
     raise TypeError(f'{cls.__name__}.{func_name} got unexpected keyword arguments {names}')


-@attr.define(kw_only=True)
+def _MappedOperator_minimal_repr(cls, fields):
+    results = []
+    fields = iter(fields)
+    for field in fields:
+        results.append(field)
+        if field.name == "dag":
+            # Everything after the 'dag' attribute is excluded from repr
+            break
+
+    for field in fields:
+        results.append(field.evolve(repr=False))
+    return results
+
+
+@attr.define(kw_only=True, field_transformer=_MappedOperator_minimal_repr)
 class MappedOperator(DAGNode):
     """Object representing a mapped operator in a DAG"""

-    operator_class: Type[BaseOperator] = attr.ib(repr=lambda c: c.__name__)
+    def _operator_class_repr(val):
+        # Can be a string if we are de-serialized
+        if isinstance(val, str):
+            return val.rsplit('.', 1)[-1]
+        return val.__name__
+
+    operator_class: Type[BaseOperator] = attr.ib(repr=_operator_class_repr)
     task_type: str = attr.ib()
     task_id: str
     partial_kwargs: Dict[str, Any]
@@ -1714,17 +1735,36 @@ class MappedOperator(DAGNode):
         validator=lambda self, _, v: _validate_kwarg_names_for_mapping(self.operator_class, "map", v)
     )
     dag: Optional["DAG"] = None
-    upstream_task_ids: Set[str] = attr.ib(factory=set, repr=False)
-    downstream_task_ids: Set[str] = attr.ib(factory=set, repr=False)
-
-    task_group: Optional["TaskGroup"] = attr.ib(repr=False)
+    upstream_task_ids: Set[str] = attr.ib(factory=set)
+    downstream_task_ids: Set[str] = attr.ib(factory=set)
+    task_group: Optional["TaskGroup"] = attr.ib()

     # BaseOperator-like interface -- needed so we can add ourselves to the dag.tasks
     start_date: Optional[pendulum.DateTime] = attr.ib(repr=False, default=None)
     end_date: Optional[pendulum.DateTime] = attr.ib(repr=False, default=None)
     owner: str = attr.ib(repr=False, default=conf.get("operators", "DEFAULT_OWNER"))
     max_active_tis_per_dag: Optional[int] = attr.ib(default=None)

+    # Needed for SerializedBaseOperator
+    _is_dummy: bool = attr.ib()
+
+    deps: Iterable[BaseTIDep] = attr.ib()
+    operator_extra_links: Iterable['BaseOperatorLink'] = ()
+ 
params: ParamsDict = None + template_fields: Iterable[str] = attr.ib() + + del _operator_class_repr + + @_is_dummy.default + def _is_dummy_default(self): + from airflow.operators.dummy import DummyOperator + + return issubclass(self.operator_class, DummyOperator) + + @deps.default + def _deps_from_class(self): + return self.operator_class.deps + @classmethod def from_operator(cls, operator: BaseOperator, mapped_kwargs: Dict[str, Any]) -> "MappedOperator": dag: Optional["DAG"] = getattr(operator, '_dag', None) @@ -1744,6 +1784,8 @@ def from_operator(cls, operator: BaseOperator, mapped_kwargs: Dict[str, Any]) -> mapped_kwargs=mapped_kwargs, owner=operator.owner, max_active_tis_per_dag=operator.max_active_tis_per_dag, + deps=operator.deps, + params=operator.params, ) @classmethod @@ -1789,6 +1831,10 @@ def _default_task_group(self): return TaskGroupContext.get_current_task_group(self.dag) + @template_fields.default + def _template_fields_default(self): + return self.operator_class.template_fields + @property def node_id(self): return self.task_id @@ -1820,6 +1866,28 @@ def serialize_for_task_group(self) -> Tuple[DagAttributeTypes, Any]: """Required by DAGNode.""" return DagAttributeTypes.OP, self.task_id + @property + def inherits_from_dummy_operator(self): + """Used to determine if an Operator is inherited from DummyOperator""" + # This looks like `isinstance(self, DummyOperator) would work, but this also + # needs to cope when `self` is a Serialized instance of a DummyOperator or one + # of its sub-classes (which don't inherit from anything but BaseOperator). + return getattr(self, '_is_dummy', False) + + # The _serialized_fields are lazily loaded when get_serialized_fields() method is called + __serialized_fields: ClassVar[Optional[FrozenSet[str]]] = None + + @classmethod + def get_serialized_fields(cls): + if cls.__serialized_fields is None: + fields_dict = attr.fields_dict(cls) + cls.__serialized_fields = frozenset( + fields_dict.keys() + - {'deps', 'inherits_from_dummy_operator', 'operator_extra_links', 'upstream_task_ids'} + | {'template_fields'} + ) + return cls.__serialized_fields + # TODO: Deprecate for Airflow 3.0 Chainable = Union[DependencyMixin, Sequence[DependencyMixin]] diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json index 6d25c1ee3972e..d7e949dfc5f48 100644 --- a/airflow/serialization/schema.json +++ b/airflow/serialization/schema.json @@ -211,7 +211,13 @@ "doc_md": { "type": "string" }, "doc_json": { "type": "string" }, "doc_yaml": { "type": "string" }, - "doc_rst": { "type": "string" } + "doc_rst": { "type": "string" }, + "mapped_kwargs": { "type": "object" }, + "partial_kwargs": { "type": "object" } + }, + "dependencies": { + "mapped_kwargs":[ "partial_kwargs"], + "partial_kwargs":[ "mapped_kwargs"] }, "additionalProperties": true }, diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 19136f4034480..ddab24d45eac4 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -31,7 +31,7 @@ from airflow.compat.functools import cache from airflow.configuration import conf from airflow.exceptions import AirflowException, SerializationError -from airflow.models.baseoperator import BaseOperator, BaseOperatorLink +from airflow.models.baseoperator import BaseOperator, BaseOperatorLink, MappedOperator from airflow.models.connection import Connection from airflow.models.dag import DAG, create_timetable from airflow.models.param import Param, 
ParamsDict @@ -303,6 +303,8 @@ def _serialize(cls, var: Any) -> Any: # Unfortunately there is no support for r return cls._encode(json_pod, type_=DAT.POD) elif isinstance(var, DAG): return SerializedDAG.serialize_dag(var) + elif isinstance(var, MappedOperator): + return SerializedBaseOperator.serialize_mapped_operator(var) elif isinstance(var, BaseOperator): return SerializedBaseOperator.serialize_operator(var) elif isinstance(var, cls._datetime_types): @@ -537,6 +539,14 @@ def task_type(self) -> str: def task_type(self, task_type: str): self._task_type = task_type + @classmethod + def serialize_mapped_operator(cls, op: MappedOperator) -> Dict[str, Any]: + serialize_op = cls.serialize_operator(op) + serialize_op['_task_type'] = op.operator_class.__name__ + serialize_op['_task_module'] = op.operator_class.__module__ + serialize_op['_is_mapped'] = True + return serialize_op + @classmethod def serialize_operator(cls, op: BaseOperator) -> Dict[str, Any]: """Serializes operator into a JSON object.""" @@ -588,7 +598,21 @@ def serialize_operator(cls, op: BaseOperator) -> Dict[str, Any]: @classmethod def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> BaseOperator: """Deserializes an operator from a JSON object.""" - op = SerializedBaseOperator(task_id=encoded_op['task_id']) + # Check if it's a mapped operator + if "mapped_kwargs" in encoded_op: + op = MappedOperator( + task_id=encoded_op['task_id'], + dag=None, + operator_class='.'.join(filter(None, (encoded_op['_task_module'], encoded_op['_task_type']))), + # These are all re-set later + partial_kwargs={}, + mapped_kwargs={}, + deps=tuple(), + is_dummy=False, + template_fields=(), + ) + else: + op = SerializedBaseOperator(task_id=encoded_op['task_id']) if "label" not in encoded_op: # Handle deserialization of old data before the introduction of TaskGroup @@ -625,7 +649,7 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> BaseOperator: if k == "label": # Label shouldn't be set anymore -- it's computed from task_id now continue - if k == "_downstream_task_ids": + if k in {"_downstream_task_ids", "downstream_task_ids"}: v = set(v) elif k == "subdag": v = SerializedDAG.deserialize_dag(v) @@ -655,10 +679,14 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> BaseOperator: v = cls._deserialize(v) # else use v as it is - setattr(op, k, v) + if hasattr(op, k) and isinstance(v, set): + getattr(op, k).update(v) + else: + setattr(op, k, v) for k in op.get_serialized_fields() - encoded_op.keys() - cls._CONSTRUCTOR_PARAMS.keys(): - setattr(op, k, None) + if not hasattr(op, k): + setattr(op, k, None) # Set all the template_field to None that were not present in Serialized JSON for field in op.template_fields: diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 4833906a91a01..3207b7767a49d 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -1549,3 +1549,31 @@ def mock__import__(name, globals_=None, locals_=None, fromlist=(), level=0): # basic serialization should succeed module.SerializedDAG.to_dict(make_simple_dag()["simple_dag"]) + + +def test_mapped_operator_serde(): + real_op = BashOperator.partial(task_id='a').map(bash_command=[1, 2, {'a': 'b'}]) + + serialized = SerializedBaseOperator._serialize(real_op) + + assert serialized == { + '_is_dummy': False, + '_is_mapped': True, + '_task_module': 'airflow.operators.bash', + '_task_type': 'BashOperator', + 'downstream_task_ids': [], + 'mapped_kwargs': { + 
'bash_command': [
+                1,
+                2,
+                {"__type": "dict", "__var": {'a': 'b'}},
+            ]
+        },
+        'partial_kwargs': {},
+        'task_id': 'a',
+        'template_fields': ['bash_command', 'env'],
+    }
+
+    op = SerializedBaseOperator.deserialize_operator(serialized)
+
+    assert op.operator_class == "airflow.operators.bash.BashOperator"

From 89512685f54b951278f7aac7fa0b9750f995d3b0 Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Fri, 17 Dec 2021 22:28:43 +0000
Subject: [PATCH 02/15] Serialization support for TaskGroups

---
 airflow/models/baseoperator.py                |  2 +-
 airflow/serialization/serialized_objects.py   | 12 +++++-
 airflow/utils/task_group.py                   |  2 +-
 tests/serialization/test_dag_serialization.py | 41 +++++++++++++++++--
 4 files changed, 50 insertions(+), 7 deletions(-)

diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 9c6cbad2d0b28..eb9f066c1f430 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -1727,7 +1727,7 @@ def _operator_class_repr(val):
             return val.rsplit('.', 1)[-1]
         return val.__name__

-    operator_class: Type[BaseOperator] = attr.ib(repr=_operator_class_repr)
+    operator_class: Union[Type[BaseOperator], str] = attr.ib(repr=_operator_class_repr)
     task_type: str = attr.ib()
     task_id: str
     partial_kwargs: Dict[str, Any]
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index ddab24d45eac4..05eec26440422 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -43,7 +43,7 @@ from airflow.timetables.base import Timetable
 from airflow.utils.code_utils import get_python_source
 from airflow.utils.module_loading import as_importable_string, import_string
-from airflow.utils.task_group import TaskGroup
+from airflow.utils.task_group import MappedTaskGroup, TaskGroup

 try:
     # isort: off
@@ -596,7 +596,7 @@ def serialize_operator(cls, op: BaseOperator) -> Dict[str, Any]:
         return serialize_op

     @classmethod
-    def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> BaseOperator:
+    def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> Union[BaseOperator, MappedOperator]:
         """Deserializes an operator from a JSON object."""
         # Check if it's a mapped operator
         if "mapped_kwargs" in encoded_op:
@@ -1011,6 +1011,14 @@ def serialize_task_group(cls, task_group: TaskGroup) -> Optional[Dict[str, Any]]
             "downstream_task_ids": cls._serialize(sorted(task_group.downstream_task_ids)),
         }

+        if isinstance(task_group, MappedTaskGroup):
+            if task_group.mapped_arg:
+                serialize_group['mapped_arg'] = cls._serialize(task_group.mapped_arg)
+            if task_group.mapped_kwargs:
+                serialize_group['mapped_kwargs'] = cls._serialize(task_group.mapped_kwargs)
+            if task_group.partial_kwargs:
+                serialize_group['partial_kwargs'] = cls._serialize(task_group.partial_kwargs)
+
         return serialize_group

     @classmethod
diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index 629de76deee8b..16445debb6783 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -400,7 +400,7 @@ def map(self, arg: Iterable) -> "MappedTaskGroup":
             raise RuntimeError("Cannot map a TaskGroup before it has a group_id")
         if self._parent_group:
             self._parent_group._remove(self)
-        return MappedTaskGroup(group_id=self._group_id, mapped_arg=arg)
+        return MappedTaskGroup(group_id=self._group_id, dag=self.dag, mapped_arg=arg)


 class MappedTaskGroup(TaskGroup):
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 3207b7767a49d..2291cfdacfacc
100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -43,10 +43,15 @@ from airflow.operators.bash import BashOperator from airflow.security import permissions from airflow.serialization.json_schema import load_dag_schema_dict -from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG +from airflow.serialization.serialized_objects import ( + SerializedBaseOperator, + SerializedDAG, + SerializedTaskGroup, +) from airflow.timetables.simple import NullTimetable, OnceTimetable from airflow.utils import timezone from airflow.utils.context import Context +from airflow.utils.task_group import TaskGroup from tests.test_utils.mock_operators import CustomOperator, CustomOpLink, GoogleLink from tests.test_utils.timetables import CustomSerializationTimetable, cron_timetable, delta_timetable @@ -1153,7 +1158,6 @@ def test_task_group_serialization(self): Test TaskGroup serialization/deserialization. """ from airflow.operators.dummy import DummyOperator - from airflow.utils.task_group import TaskGroup execution_date = datetime(2020, 1, 1) with DAG("test_task_group_serialization", start_date=execution_date) as dag: @@ -1229,7 +1233,6 @@ def test_task_group_sorted(self): """ from airflow.operators.dummy import DummyOperator from airflow.serialization.serialized_objects import SerializedTaskGroup - from airflow.utils.task_group import TaskGroup """ start @@ -1577,3 +1580,35 @@ def test_mapped_operator_serde(): op = SerializedBaseOperator.deserialize_operator(serialized) assert op.operator_class == "airflow.operators.bash.BashOperator" + + +def test_mapped_task_group_serde(): + execution_date = datetime(2020, 1, 1) + + literal = [1, 2, {'a': 'b'}] + with DAG("test", start_date=execution_date) as dag: + with TaskGroup("process_one", dag=dag).map(literal) as process_one: + BaseOperator(task_id='one') + + serialized = SerializedTaskGroup.serialize_task_group(process_one) + + assert serialized == { + '_group_id': 'process_one', + 'children': {'process_one.one': ('operator', 'process_one.one')}, + 'downstream_group_ids': [], + 'downstream_task_ids': [], + 'prefix_group_id': True, + 'tooltip': '', + 'ui_color': 'CornflowerBlue', + 'ui_fgcolor': '#000', + 'upstream_group_ids': [], + 'upstream_task_ids': [], + 'mapped_arg': [ + 1, + 2, + {"__type": "dict", "__var": {'a': 'b'}}, + ], + } + + with DAG("test", start_date=execution_date): + SerializedTaskGroup.deserialize_task_group(serialized, None, dag.task_dict) From 592b8eb364b03547a6940841f6ce265a052a468a Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Tue, 11 Jan 2022 16:15:05 +0800 Subject: [PATCH 03/15] Tweak type hints to reflect reality While keeping Mypy happy. 
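For context on why the hints must admit plain strings: once a
MappedOperator has been round-tripped through serialization, its
operator_class attribute holds a dotted-path string rather than a real
class, and the validator already short-circuits on that case. A minimal
standalone sketch of the pattern (illustrative only -- `validate_kwargs`
is a made-up name standing in for _validate_kwarg_names_for_mapping, and
the signature check is simplified):

    import inspect
    from typing import Any, Dict, Type, Union

    def validate_kwargs(cls: Union[str, Type], func_name: str, value: Dict[str, Any]) -> None:
        if isinstance(cls, str):
            # De-serialized form: kwargs were already validated at parse time.
            return
        # Only a real class has an inspectable signature to check against.
        unknown = set(value) - set(inspect.signature(cls.__init__).parameters)
        if unknown:
            raise TypeError(f"{cls.__name__}.{func_name} got unexpected keyword arguments {sorted(unknown)}")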
---
 airflow/models/baseoperator.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index eb9f066c1f430..1313524aff213 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -1673,7 +1673,11 @@ def _walk_group(group: TaskGroup) -> Iterable[Tuple[str, DAGNode]]:
     return False


-def _validate_kwarg_names_for_mapping(cls: Type[BaseOperator], func_name: str, value: Dict[str, Any]):
+def _validate_kwarg_names_for_mapping(
+    cls: Union[str, Type[BaseOperator]],
+    func_name: str,
+    value: Dict[str, Any],
+) -> None:
     if isinstance(cls, str):
         # Serialized version -- would have been validated at parse time
         return
@@ -1750,7 +1754,7 @@ def _operator_class_repr(val):
     deps: Iterable[BaseTIDep] = attr.ib()
     operator_extra_links: Iterable['BaseOperatorLink'] = ()
-    params: ParamsDict = None
+    params: Union[ParamsDict, dict] = attr.ib(factory=ParamsDict)
     template_fields: Iterable[str] = attr.ib()

From 4e3aac7634d8a30d1061a528b62e03c659361f60 Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Wed, 12 Jan 2022 19:52:09 +0000
Subject: [PATCH 04/15] fixup! Basic serialization support for MappedOperators

---
 airflow/models/baseoperator.py              | 28 ++++++++------------
 airflow/serialization/serialized_objects.py | 26 +++++++++++++------
 2 files changed, 28 insertions(+), 26 deletions(-)

diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 1313524aff213..ac45e52ec7ee3 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -1707,31 +1707,25 @@ def _validate_kwarg_names_for_mapping(
     raise TypeError(f'{cls.__name__}.{func_name} got unexpected keyword arguments {names}')


-def _MappedOperator_minimal_repr(cls, fields):
-    results = []
-    fields = iter(fields)
-    for field in fields:
-        results.append(field)
-        if field.name == "dag":
-            # Everything after the 'dag' attribute is excluded from repr
-            break
-
-    for field in fields:
-        results.append(field.evolve(repr=False))
-    return results
-
-
-@attr.define(kw_only=True, field_transformer=_MappedOperator_minimal_repr)
+@attr.define(kw_only=True)
 class MappedOperator(DAGNode):
     """Object representing a mapped operator in a DAG"""

+    @staticmethod
     def _operator_class_repr(val):
         # Can be a string if we are de-serialized
         if isinstance(val, str):
             return val.rsplit('.', 1)[-1]
         return val.__name__

-    operator_class: Union[Type[BaseOperator], str] = attr.ib(repr=_operator_class_repr)
+    def __repr__(self) -> str:
+        return (
+            'MappedOperator(operator_class={self._operator_class_repr(self.operator_class)}, '
+            + 'task_id={self.task_id!r}, partial_kwargs={self.partial_kwargs!r}, '
+            + 'mapped_kwargs={self.mapped_kwargs!r}, dag={self.dag})'
+        )
+
+    operator_class: Union[Type[BaseOperator], str]
     task_type: str = attr.ib()
     task_id: str
     partial_kwargs: Dict[str, Any]
@@ -1757,8 +1751,6 @@ def _operator_class_repr(val):
     params: Union[ParamsDict, dict] = attr.ib(factory=ParamsDict)
     template_fields: Iterable[str] = attr.ib()

-    del _operator_class_repr
-
     @_is_dummy.default
     def _is_dummy_default(self):
         from airflow.operators.dummy import DummyOperator
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 05eec26440422..fbf252a137694 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -35,6 +35,7 @@ from airflow.models.connection import Connection
 from airflow.models.dag import DAG,
create_timetable from airflow.models.param import Param, ParamsDict +from airflow.models.taskmixin import DAGNode from airflow.providers_manager import ProvidersManager from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.helpers import serialize_template_field @@ -255,7 +256,7 @@ def _is_excluded(cls, var: Any, attrname: str, instance: Any) -> bool: @classmethod def serialize_to_json( - cls, object_to_serialize: Union[BaseOperator, DAG], decorated_fields: Set + cls, object_to_serialize: Union["BaseOperator", "MappedOperator", DAG], decorated_fields: Set ) -> Dict[str, Any]: """Serializes an object to json""" serialized_object: Dict[str, Any] = {} @@ -541,7 +542,9 @@ def task_type(self, task_type: str): @classmethod def serialize_mapped_operator(cls, op: MappedOperator) -> Dict[str, Any]: - serialize_op = cls.serialize_operator(op) + serialize_op = cls._serialize_node(op) + # It must be a class at this point for it to work, not a string + assert isinstance(op.operator_class, type) serialize_op['_task_type'] = op.operator_class.__name__ serialize_op['_task_module'] = op.operator_class.__module__ serialize_op['_is_mapped'] = True @@ -549,10 +552,14 @@ def serialize_mapped_operator(cls, op: MappedOperator) -> Dict[str, Any]: @classmethod def serialize_operator(cls, op: BaseOperator) -> Dict[str, Any]: + return cls._serialize_node(op) + + @classmethod + def _serialize_node(cls, op: Union[BaseOperator, MappedOperator]) -> Dict[str, Any]: """Serializes operator into a JSON object.""" serialize_op = cls.serialize_to_json(op, cls._decorated_fields) - serialize_op['_task_type'] = op.__class__.__name__ - serialize_op['_task_module'] = op.__class__.__module__ + serialize_op['_task_type'] = type(op).__name__ + serialize_op['_task_module'] = type(op).__module__ # Used to determine if an Operator is inherited from DummyOperator serialize_op['_is_dummy'] = op.inherits_from_dummy_operator @@ -571,6 +578,7 @@ def serialize_operator(cls, op: BaseOperator) -> Dict[str, Any]: klass = type(dep) module_name = klass.__module__ if not module_name.startswith("airflow.ti_deps.deps."): + assert op.dag # for type checking raise SerializationError( f"Cannot serialize {(op.dag.dag_id + '.' 
+ op.task_id)!r} with `deps` from non-core " f"module {module_name!r}" @@ -598,6 +606,7 @@ def serialize_operator(cls, op: BaseOperator) -> Dict[str, Any]: @classmethod def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> Union[BaseOperator, MappedOperator]: """Deserializes an operator from a JSON object.""" + op: Union[BaseOperator, MappedOperator] # Check if it's a mapped operator if "mapped_kwargs" in encoded_op: op = MappedOperator( @@ -679,10 +688,11 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> Union[BaseOperator, v = cls._deserialize(v) # else use v as it is - if hasattr(op, k) and isinstance(v, set): - getattr(op, k).update(v) - else: + try: setattr(op, k, v) + except AttributeError: + # Likely read-only attribute, try updating it in place + getattr(op, k).update(v) for k in op.get_serialized_fields() - encoded_op.keys() - cls._CONSTRUCTOR_PARAMS.keys(): if not hasattr(op, k): @@ -704,7 +714,7 @@ def detect_dependencies(cls, op: BaseOperator) -> Optional['DagDependency']: return cls.dependency_detector.detect_task_dependencies(op) @classmethod - def _is_excluded(cls, var: Any, attrname: str, op: BaseOperator): + def _is_excluded(cls, var: Any, attrname: str, op: "DAGNode"): if var is not None and op.has_dag() and attrname.endswith("_date"): # If this date is the same as the matching field in the dag, then # don't store it again at the task level. From 08c42369f14ca7349314229b59c9f0f27c68a360 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Fri, 14 Jan 2022 13:43:48 +0000 Subject: [PATCH 05/15] fixup! Basic serialization support for MappedOperators --- airflow/models/baseoperator.py | 6 +++--- airflow/serialization/schema.json | 6 ++++-- airflow/serialization/serialized_objects.py | 4 ++-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index ac45e52ec7ee3..0d34db101dba8 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1720,9 +1720,9 @@ def _operator_class_repr(val): def __repr__(self) -> str: return ( - 'MappedOperator(operator_class={self._operator_class_repr(self.operator_class)}, ' - + 'task_id={self.task_id!r}, partial_kwargs={self.partial_kwargs!r}, ' - + 'mapped_kwargs={self.mapped_kwargs!r}, dag={self.dag})' + f'MappedOperator(operator_class={self._operator_class_repr(self.operator_class)}, ' + + f'task_id={self.task_id!r}, partial_kwargs={self.partial_kwargs!r}, ' + + f'mapped_kwargs={self.mapped_kwargs!r}, dag={self.dag})' ) operator_class: Union[Type[BaseOperator], str] diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json index d7e949dfc5f48..75da32140ba2e 100644 --- a/airflow/serialization/schema.json +++ b/airflow/serialization/schema.json @@ -212,12 +212,14 @@ "doc_json": { "type": "string" }, "doc_yaml": { "type": "string" }, "doc_rst": { "type": "string" }, + "_is_mapped": { "const": true, "$comment": "only present when True" }, "mapped_kwargs": { "type": "object" }, "partial_kwargs": { "type": "object" } }, "dependencies": { - "mapped_kwargs":[ "partial_kwargs"], - "partial_kwargs":[ "mapped_kwargs"] + "mapped_kwargs": ["partial_kwargs", "_is_mapped"], + "partial_kwargs": ["mapped_kwargs", "_is_mapped"], + "_is_mapped": ["mapped_kwargs", "partial_kwargs"] }, "additionalProperties": true }, diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index fbf252a137694..55cfcae325a57 100644 --- a/airflow/serialization/serialized_objects.py +++ 
b/airflow/serialization/serialized_objects.py
@@ -608,11 +608,11 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> Union[BaseOperator,
         """Deserializes an operator from a JSON object."""
         op: Union[BaseOperator, MappedOperator]
         # Check if it's a mapped operator
-        if "mapped_kwargs" in encoded_op:
+        if encoded_op.get("_is_mapped", False):
             op = MappedOperator(
                 task_id=encoded_op['task_id'],
                 dag=None,
-                operator_class='.'.join(filter(None, (encoded_op['_task_module'], encoded_op['_task_type']))),
+                operator_class=f"{encoded_op['_task_module']}.{encoded_op['_task_type']}",
                 # These are all re-set later
                 partial_kwargs={},
                 mapped_kwargs={},

From d8239581ad9a7723b711920928a0acc8cf83ea36 Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Fri, 14 Jan 2022 14:18:23 +0000
Subject: [PATCH 06/15] Adjust to changes on main

---
 airflow/models/baseoperator.py | 32 ++++++++++++++++----------------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 0d34db101dba8..50df50c465d61 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -1711,16 +1711,9 @@ def _validate_kwarg_names_for_mapping(
 class MappedOperator(DAGNode):
     """Object representing a mapped operator in a DAG"""

-    @staticmethod
-    def _operator_class_repr(val):
-        # Can be a string if we are de-serialized
-        if isinstance(val, str):
-            return val.rsplit('.', 1)[-1]
-        return val.__name__
-
     def __repr__(self) -> str:
         return (
-            f'MappedOperator(operator_class={self._operator_class_repr(self.operator_class)}, '
+            f'MappedOperator(task_type={self.task_type}, '
             + f'task_id={self.task_id!r}, partial_kwargs={self.partial_kwargs!r}, '
             + f'mapped_kwargs={self.mapped_kwargs!r}, dag={self.dag})'
         )

     operator_class: Union[Type[BaseOperator], str]
     task_type: str = attr.ib()
     task_id: str
     partial_kwargs: Dict[str, Any]
@@ -1738,8 +1731,8 @@ def __repr__(self) -> str:
     task_group: Optional["TaskGroup"] = attr.ib()

     # BaseOperator-like interface -- needed so we can add ourselves to the dag.tasks
-    start_date: Optional[pendulum.DateTime] = attr.ib(repr=False, default=None)
-    end_date: Optional[pendulum.DateTime] = attr.ib(repr=False, default=None)
+    start_date: Optional[pendulum.DateTime] = attr.ib(default=None)
+    end_date: Optional[pendulum.DateTime] = attr.ib(default=None)
     owner: str = attr.ib(repr=False, default=conf.get("operators", "DEFAULT_OWNER"))
     max_active_tis_per_dag: Optional[int] = attr.ib(default=None)

@@ -1819,7 +1812,11 @@ def __attrs_post_init__(self):

     @task_type.default
     def _default_task_type(self):
-        return self.operator_class.__name__
+        # Can be a string if we are de-serialized
+        val = self.operator_class
+        if isinstance(val, str):
+            return val.rsplit('.', 1)[-1]
+        return val.__name__

     @task_group.default
     def _default_task_group(self):
@@ -1865,10 +1862,7 @@ def serialize_for_task_group(self) -> Tuple[DagAttributeTypes, Any]:
     @property
     def inherits_from_dummy_operator(self):
         """Used to determine if an Operator is inherited from DummyOperator"""
-        # This looks like `isinstance(self, DummyOperator) would work, but this also
-        # needs to cope when `self` is a Serialized instance of a DummyOperator or one
-        # of its sub-classes (which don't inherit from anything but BaseOperator).
- return getattr(self, '_is_dummy', False) + return self._is_dummy # The _serialized_fields are lazily loaded when get_serialized_fields() method is called __serialized_fields: ClassVar[Optional[FrozenSet[str]]] = None @@ -1879,7 +1873,13 @@ def get_serialized_fields(cls): fields_dict = attr.fields_dict(cls) cls.__serialized_fields = frozenset( fields_dict.keys() - - {'deps', 'inherits_from_dummy_operator', 'operator_extra_links', 'upstream_task_ids'} + - { + 'deps', + 'inherits_from_dummy_operator', + 'operator_extra_links', + 'upstream_task_ids', + 'task_type', + } | {'template_fields'} ) return cls.__serialized_fields From c7662ada7c6c4dddf4094a0cd75f1d1dbe7758e1 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Mon, 17 Jan 2022 16:01:52 +0000 Subject: [PATCH 07/15] Make downstream_task_ids a normal writeable property It simplifies a few things. We also deal with (and test) the old name when deserializing --- airflow/models/baseoperator.py | 21 ++++------------ airflow/models/taskmixin.py | 12 ++-------- airflow/serialization/schema.json | 2 +- airflow/serialization/serialized_objects.py | 14 +++++------ airflow/utils/task_group.py | 13 ++-------- tests/serialization/test_dag_serialization.py | 24 ++++++++++++++++--- 6 files changed, 37 insertions(+), 49 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 50df50c465d61..752b8a6cfd6e9 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -747,9 +747,8 @@ def __init__( self.doc_rst = doc_rst self.doc = doc - # Private attributes - self._upstream_task_ids: Set[str] = set() - self._downstream_task_ids: Set[str] = set() + self.upstream_task_ids: Set[str] = set() + self.downstream_task_ids: Set[str] = set() if dag: self.dag = dag @@ -1261,16 +1260,6 @@ def resolve_template_files(self) -> None: self.log.exception(e) self.prepare_template() - @property - def upstream_task_ids(self) -> Set[str]: - """@property: set of ids of tasks directly upstream""" - return self._upstream_task_ids - - @property - def downstream_task_ids(self) -> Set[str]: - """@property: set of ids of tasks directly downstream""" - return self._downstream_task_ids - @provide_session def clear( self, @@ -1430,9 +1419,9 @@ def get_direct_relative_ids(self, upstream: bool = False) -> Set[str]: downstream. 
""" if upstream: - return self._upstream_task_ids + return self.upstream_task_ids else: - return self._downstream_task_ids + return self.downstream_task_ids def get_direct_relatives(self, upstream: bool = False) -> Iterable["DAGNode"]: """ @@ -1578,7 +1567,7 @@ def get_serialized_fields(cls): - { 'inlets', 'outlets', - '_upstream_task_ids', + 'upstream_task_ids', 'default_args', 'dag', '_dag', diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py index 24796b8a4f518..4fc95665e79fb 100644 --- a/airflow/models/taskmixin.py +++ b/airflow/models/taskmixin.py @@ -128,20 +128,12 @@ def label(self) -> Optional[str]: start_date: Optional[pendulum.DateTime] end_date: Optional[pendulum.DateTime] + upstream_task_ids: Set[str] + downstream_task_ids: Set[str] def has_dag(self) -> bool: return self.dag is not None - @property - @abstractmethod - def upstream_task_ids(self) -> Set[str]: - raise NotImplementedError() - - @property - @abstractmethod - def downstream_task_ids(self) -> Set[str]: - raise NotImplementedError() - @property def log(self) -> "Logger": raise NotImplementedError() diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json index 75da32140ba2e..6eab39b92f8d8 100644 --- a/airflow/serialization/schema.json +++ b/airflow/serialization/schema.json @@ -196,7 +196,7 @@ "items": { "type": "string" } }, "subdag": { "$ref": "#/definitions/dag" }, - "_downstream_task_ids": { + "downstream_task_ids": { "type": "array", "items": { "type": "string" } }, diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 55cfcae325a57..971192075d792 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -655,10 +655,13 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> Union[BaseOperator, setattr(op, "operator_extra_links", list(op_extra_links_from_plugin.values())) for k, v in encoded_op.items(): + if k == "_downstream_task_ids": + # Upgrade from old format/name + k = "downstream_task_ids" if k == "label": # Label shouldn't be set anymore -- it's computed from task_id now continue - if k in {"_downstream_task_ids", "downstream_task_ids"}: + elif k == "downstream_task_ids": v = set(v) elif k == "subdag": v = SerializedDAG.deserialize_dag(v) @@ -688,11 +691,7 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> Union[BaseOperator, v = cls._deserialize(v) # else use v as it is - try: - setattr(op, k, v) - except AttributeError: - # Likely read-only attribute, try updating it in place - getattr(op, k).update(v) + setattr(op, k, v) for k in op.get_serialized_fields() - encoded_op.keys() - cls._CONSTRUCTOR_PARAMS.keys(): if not hasattr(op, k): @@ -971,8 +970,7 @@ def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG': for task_id in serializable_task.downstream_task_ids: # Bypass set_upstream etc here - it does more than we want - - dag.task_dict[task_id]._upstream_task_ids.add(serializable_task.task_id) + dag.task_dict[task_id].upstream_task_ids.add(serializable_task.task_id) return dag diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index 16445debb6783..3e47e09dafc73 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -148,9 +148,8 @@ def __init__( # so that we can optimize the number of edges when entire TaskGroups depend on each other. 
self.upstream_group_ids: Set[Optional[str]] = set()
         self.downstream_group_ids: Set[Optional[str]] = set()
-        # Since the parent class defines these as read-only properties, we can't just do `self.x = ...`
-        self.__dict__['upstream_task_ids'] = set()
-        self.__dict__['downstream_task_ids'] = set()
+        self.upstream_task_ids = set()
+        self.downstream_task_ids = set()

     def _check_for_group_id_collisions(self, add_suffix_on_collision: bool):
         if self._group_id is None:
@@ -185,14 +184,6 @@ def is_root(self) -> bool:
         """Returns True if this TaskGroup is the root TaskGroup. Otherwise False"""
         return not self.group_id

-    @property
-    def upstream_task_ids(self) -> Set[str]:
-        return self.__dict__['upstream_task_ids']
-
-    @property
-    def downstream_task_ids(self) -> Set[str]:
-        return self.__dict__['downstream_task_ids']
-
     def __iter__(self):
         for child in self.children.values():
             if isinstance(child, TaskGroup):
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 2291cfdacfacc..acde989e3f682 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -102,7 +102,7 @@
             "retry_delay": 300.0,
             "max_retry_delay": 600.0,
             "sla": 100.0,
-            "_downstream_task_ids": [],
+            "downstream_task_ids": [],
             "_inlets": [],
             "_is_dummy": False,
             "_outlets": [],
@@ -132,7 +132,7 @@
             "retry_delay": 300.0,
             "max_retry_delay": 600.0,
             "sla": 100.0,
-            "_downstream_task_ids": [],
+            "downstream_task_ids": [],
             "_inlets": [],
             "_is_dummy": False,
             "_outlets": [],
@@ -1099,13 +1099,13 @@ def test_no_new_fields_added_to_base_operator(self):
         base_operator = BaseOperator(task_id="10")
         fields = {k: v for (k, v) in vars(base_operator).items() if k in BaseOperator.get_serialized_fields()}
         assert fields == {
-            '_downstream_task_ids': set(),
             '_inlets': [],
             '_log': base_operator.log,
             '_outlets': [],
             '_pre_execute_hook': None,
             '_post_execute_hook': None,
             'depends_on_past': False,
+            'downstream_task_ids': set(),
             'do_xcom_push': True,
             'doc': None,
             'doc_json': None,
@@ -1153,6 +1153,24 @@ def test_no_new_fields_added_to_base_operator(self):
         !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
         """

+    def test_operator_deserialize_old_names(self):
+        blob = {
+            "task_id": "custom_task",
+            "_downstream_task_ids": ['foo'],
+            "template_ext": [],
+            "template_fields": ['bash_command'],
+            "template_fields_renderers": {},
+            "_task_type": "CustomOperator",
+            "_task_module": "tests.test_utils.mock_operators",
+            "pool": "default_pool",
+            "ui_color": "#fff",
+            "ui_fgcolor": "#000",
+        }
+
+        SerializedDAG._json_schema.validate(blob, _schema=load_dag_schema_dict()['definitions']['operator'])
+        serialized_op = SerializedBaseOperator.deserialize_operator(blob)
+        assert serialized_op.downstream_task_ids == {'foo'}
+
     def test_task_group_serialization(self):
         """
         Test TaskGroup serialization/deserialization.

From 9439378183cb4da70e33e9c6f04d2d52924709a5 Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Mon, 17 Jan 2022 17:07:04 +0000
Subject: [PATCH 08/15] Handle `task_id` specially in apply_defaults

Since `task_id` is handled specially in the serialization of
MappedOperators, we don't want it duplicated into the partial_kwargs.
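To make the effect concrete, a hedged sketch of the serialized shape
(the field values here are made up for illustration; the real layout is
produced by serialize_mapped_operator, and most fields are omitted):

    # Before: task_id leaked into partial_kwargs via __init_kwargs
    {"task_id": "a", "partial_kwargs": {"task_id": "a", "retries": 2}, "mapped_kwargs": {}}
    # After: the id appears only once, as the top-level field
    {"task_id": "a", "partial_kwargs": {"retries": 2}, "mapped_kwargs": {}}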
---
 airflow/models/baseoperator.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 752b8a6cfd6e9..71170320d6477 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -136,6 +136,7 @@ def _apply_defaults(cls, func: T) -> T:
     non_optional_args = {
         name for (name, param) in non_variadic_params.items() if param.default == param.empty
     }
+    non_optional_args -= {'task_id'}

     class autostacklevel_warn:
         def __init__(self):
@@ -158,7 +159,7 @@ def warn(self, message, category=None, stacklevel=1, source=None):
     func.__globals__['warnings'] = autostacklevel_warn()

     @functools.wraps(func)
-    def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
+    def apply_defaults(self: "BaseOperator", *args: Any, task_id: str, **kwargs: Any) -> Any:
         from airflow.models.dag import DagContext
         from airflow.utils.task_group import TaskGroupContext
@@ -201,15 +202,15 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:

         hook = getattr(self, '_hook_apply_defaults', None)
         if hook:
-            args, kwargs = hook(**kwargs, default_args=default_args)
+            args, kwargs = hook(task_id=task_id, **kwargs, default_args=default_args)
+            task_id = kwargs.pop('task_id')
             default_args = kwargs.pop('default_args', {})

         if not hasattr(self, '_BaseOperator__init_kwargs'):
             self._BaseOperator__init_kwargs = {}

-        result = func(self, **kwargs, default_args=default_args)
+        result = func(self, **kwargs, task_id=task_id, default_args=default_args)

         # Store the args passed to init -- we need them to support task.map serialization!
-        kwargs.pop('task_id', None)
         self._BaseOperator__init_kwargs.update(kwargs)  # type: ignore

From 5225ad0339a652b279525cd0aa9c637fc227c494 Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Mon, 17 Jan 2022 17:25:10 +0000
Subject: [PATCH 09/15] Update airflow/serialization/serialized_objects.py

---
 airflow/serialization/serialized_objects.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 971192075d792..0da885c11616c 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -694,6 +694,7 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> Union[BaseOperator,
             setattr(op, k, v)

         for k in op.get_serialized_fields() - encoded_op.keys() - cls._CONSTRUCTOR_PARAMS.keys():
+            # TODO: refactor deserialization of BaseOperator and MappedOperator (split it out), then check could go away.
if not hasattr(op, k):
                 setattr(op, k, None)

From 5a890ee3b335d231dafcc5841efc2f50db6cfc5c Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Mon, 17 Jan 2022 17:26:23 +0000
Subject: [PATCH 10/15] Update airflow/serialization/serialized_objects.py

---
 airflow/serialization/serialized_objects.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 0da885c11616c..785b9524adaa6 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -694,7 +694,8 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> Union[BaseOperator,
             setattr(op, k, v)

         for k in op.get_serialized_fields() - encoded_op.keys() - cls._CONSTRUCTOR_PARAMS.keys():
-            # TODO: refactor deserialization of BaseOperator and MappedOperator (split it out), then check could go away.
+            # TODO: refactor deserialization of BaseOperator and MappedOperator (split it out), then check
+            # could go away.
             if not hasattr(op, k):
                 setattr(op, k, None)

From 57028118495007ecc10156b219fd25a3a56a6e29 Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Mon, 17 Jan 2022 17:30:55 +0000
Subject: [PATCH 11/15] Update airflow/models/baseoperator.py

Co-authored-by: Tzu-ping Chung
---
 airflow/models/baseoperator.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 71170320d6477..e341bfe799ad4 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -134,9 +134,11 @@ def _apply_defaults(cls, func: T) -> T:
         if param.name != 'self' and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
     }
     non_optional_args = {
-        name for (name, param) in non_variadic_params.items() if param.default == param.empty
+        name
+        for name, param in non_variadic_params.items()
+        if param.default == param.empty
+        and name != "task_id"
     }
-    non_optional_args -= {'task_id'}

     class autostacklevel_warn:
         def __init__(self):

From 27371425ee72f2abbdc80de743f2073148336737 Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Mon, 17 Jan 2022 18:12:43 +0000
Subject: [PATCH 12/15] Update airflow/models/baseoperator.py

Co-authored-by: Tzu-ping Chung
---
 airflow/models/baseoperator.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index e341bfe799ad4..46486db416cd4 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -1706,8 +1706,8 @@ class MappedOperator(DAGNode):
     def __repr__(self) -> str:
         return (
             f'MappedOperator(task_type={self.task_type}, '
-            + f'task_id={self.task_id!r}, partial_kwargs={self.partial_kwargs!r}, '
-            + f'mapped_kwargs={self.mapped_kwargs!r}, dag={self.dag})'
+            f'task_id={self.task_id!r}, partial_kwargs={self.partial_kwargs!r}, '
+            f'mapped_kwargs={self.mapped_kwargs!r}, dag={self.dag})'
         )

     operator_class: Union[Type[BaseOperator], str]

From 4d1c222c94706d17e15e95fafe6e98fa5c1337a9 Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Mon, 17 Jan 2022 18:58:37 +0000
Subject: [PATCH 13/15] fixup!
Update airflow/models/baseoperator.py --- airflow/models/baseoperator.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 46486db416cd4..c5a646df1f743 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -136,8 +136,7 @@ def _apply_defaults(cls, func: T) -> T: non_optional_args = { name for name, param in non_variadic_params.items() - if param.default == param.empty - and name != "task_id" + if param.default == param.empty and name != "task_id" } class autostacklevel_warn: From 26ce6d36a81f6016ede7febe0d01679f7533c392 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Tue, 18 Jan 2022 13:07:32 +0800 Subject: [PATCH 14/15] Fix usages of the old _[up|down]stream_task_ids --- airflow/models/dag.py | 4 ++-- tests/ti_deps/deps/test_trigger_rule_dep.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index d912d3e896f97..9f1690202367c 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2102,8 +2102,8 @@ def filter_task_group(group, parent_group): for t in dag.tasks: # Removing upstream/downstream references to tasks that did not # make the cut - t._upstream_task_ids = t.upstream_task_ids.intersection(dag.task_dict.keys()) - t._downstream_task_ids = t.downstream_task_ids.intersection(dag.task_dict.keys()) + t.upstream_task_ids.intersection_update(dag.task_dict) + t.downstream_task_ids.intersection_update(dag.task_dict) if len(dag.tasks) < len(self.tasks): dag.partial = True diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py index bbdb84679cb1b..f7a47349baefa 100644 --- a/tests/ti_deps/deps/test_trigger_rule_dep.py +++ b/tests/ti_deps/deps/test_trigger_rule_dep.py @@ -41,7 +41,7 @@ def _get_task_instance(trigger_rule=TriggerRule.ALL_SUCCESS, state=None, upstrea task_id='test_task', trigger_rule=trigger_rule, start_date=datetime(2015, 1, 1) ) if upstream_task_ids: - task._upstream_task_ids.update(upstream_task_ids) + task.upstream_task_ids.update(upstream_task_ids) dr = dag_maker.create_dagrun() ti = dr.task_instances[0] ti.task = task From e70062dd8d2f2a039b5053e0367465ee7022b3c9 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Tue, 18 Jan 2022 13:45:26 +0800 Subject: [PATCH 15/15] Remove task_id from kwargs only in MappedOperator --- airflow/models/baseoperator.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index c5a646df1f743..458bd313bd42a 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -160,7 +160,7 @@ def warn(self, message, category=None, stacklevel=1, source=None): func.__globals__['warnings'] = autostacklevel_warn() @functools.wraps(func) - def apply_defaults(self: "BaseOperator", *args: Any, task_id: str, **kwargs: Any) -> Any: + def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any: from airflow.models.dag import DagContext from airflow.utils.task_group import TaskGroupContext @@ -203,14 +203,13 @@ def apply_defaults(self: "BaseOperator", *args: Any, task_id: str, **kwargs: Any hook = getattr(self, '_hook_apply_defaults', None) if hook: - args, kwargs = hook(task_id=task_id, **kwargs, default_args=default_args) - task_id = kwargs.pop('task_id') + args, kwargs = hook(**kwargs, default_args=default_args) default_args = kwargs.pop('default_args', {}) if not hasattr(self, 
'_BaseOperator__init_kwargs'):
             self._BaseOperator__init_kwargs = {}

-        result = func(self, **kwargs, task_id=task_id, default_args=default_args)
+        result = func(self, **kwargs, default_args=default_args)

         # Store the args passed to init -- we need them to support task.map serialization!
         self._BaseOperator__init_kwargs.update(kwargs)  # type: ignore
@@ -1753,6 +1752,7 @@ def from_operator(cls, operator: BaseOperator, mapped_kwargs: Dict[str, Any]) ->
         # are mapped, we want to _remove_ that task from the dag
         dag._remove_task(operator.task_id)

+    operator_init_kwargs: dict = operator._BaseOperator__init_kwargs  # type: ignore
     return MappedOperator(
         operator_class=type(operator),
         task_id=operator.task_id,
@@ -1760,7 +1760,7 @@ def from_operator(cls, operator: BaseOperator, mapped_kwargs: Dict[str, Any]) ->
         dag=getattr(operator, '_dag', None),
         start_date=operator.start_date,
         end_date=operator.end_date,
-        partial_kwargs=operator._BaseOperator__init_kwargs,  # type: ignore
+        partial_kwargs={k: v for k, v in operator_init_kwargs.items() if k != "task_id"},
        mapped_kwargs=mapped_kwargs,
        owner=operator.owner,
        max_active_tis_per_dag=operator.max_active_tis_per_dag,