From cee0edb9e4913a2ea58bb53be85798461fcf16ba Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Tue, 1 Mar 2022 20:39:37 +0800 Subject: [PATCH 1/2] More explicit mapped argument validation Instead of always using MagicMock to validate mapped arguments, this implements a more sophisticated protocol that allows an operator to implement a 'validate_mapped_arguments' to provide custom validation logic. If an operator just wants to use __init__ for validation, however, they can set a flag 'mapped_arguments_validated_by_init' to get the behavior easily. (This does *not* use MagicMock, however, since any custom validation logic should be able to handle those on its own). The 'mapped_arguments_validated_by_init' flag is currently only set on PythonOperator. It can likely be used on a lot more operators down the road. --- airflow/decorators/base.py | 29 ++++++--------- airflow/models/baseoperator.py | 30 ++++++++++++---- airflow/models/mappedoperator.py | 62 ++++++++------------------------ airflow/operators/python.py | 2 ++ 4 files changed, 51 insertions(+), 72 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index c3c020688545b..3ba6bbb494cc2 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -49,7 +49,6 @@ from airflow.models.mappedoperator import ( MappedOperator, ValidationSource, - create_mocked_kwargs, ensure_xcomarg_return_value, get_mappable_types, prevent_duplicates, @@ -411,29 +410,23 @@ def _get_expansion_kwargs(self) -> Dict[str, "Mappable"]: """ return self.mapped_op_kwargs - def _create_unmapped_operator(self, *, mapped_kwargs: Dict[str, Any], real: bool) -> "BaseOperator": - assert not isinstance(self.operator_class, str) + def _get_unmap_kwargs(self) -> Dict[str, Any]: partial_kwargs = self.partial_kwargs.copy() - if real: - mapped_op_kwargs: Dict[str, Any] = self.mapped_op_kwargs - else: - mapped_op_kwargs = create_mocked_kwargs(self.mapped_op_kwargs) op_kwargs = _merge_kwargs( 
partial_kwargs.pop("op_kwargs"), - mapped_op_kwargs, + self.mapped_op_kwargs, fail_reason="mapping already partial", ) - return self.operator_class( - dag=self.dag, - task_group=self.task_group, - task_id=self.task_id, - op_kwargs=op_kwargs, - multiple_outputs=self.multiple_outputs, - python_callable=self.python_callable, - _airflow_map_validation=not real, + return { + "dag": self.dag, + "task_group": self.task_group, + "task_id": self.task_id, + "op_kwargs": op_kwargs, + "multiple_outputs": self.multiple_outputs, + "python_callable": self.python_callable, **partial_kwargs, - **mapped_kwargs, - ) + **self.mapped_kwargs, + } def _expand_mapped_field(self, key: str, content: Any, context: Context, *, session: Session) -> Any: if key != "op_kwargs" or not isinstance(content, collections.abc.Mapping): diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 91b165f8a60aa..296fa298936f8 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -372,11 +372,13 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any: if not hasattr(self, '_BaseOperator__init_kwargs'): self._BaseOperator__init_kwargs = {} + mapped_validation_only = kwargs.pop("_airflow_mapped_validation_only", False) result = func(self, **kwargs, default_args=default_args) + # Store the args passed to init -- we need them to support task.map serialzation! self._BaseOperator__init_kwargs.update(kwargs) # type: ignore - if not kwargs.get("_airflow_map_validation"): + if not mapped_validation_only: # Set upstream task defined by XComArgs passed to template fields of the operator. self.set_xcomargs_dependencies() # Mark instance as instantiated. @@ -656,11 +658,16 @@ class derived from this one results in the creation of a task object, start_date: Optional[pendulum.DateTime] = None end_date: Optional[pendulum.DateTime] = None + # How operator-mapping arguments should be validated. 
If True, a default validation implementation that + # calls the operator's constructor is used. If False, the operator should implement its own validation + # logic (default implementation is 'pass' i.e. no validation whatsoever). + mapped_arguments_validated_by_init: ClassVar[bool] = False + def __new__( cls, dag: Optional['DAG'] = None, task_group: Optional["TaskGroup"] = None, - _airflow_map_validation: bool = False, # If True, this is called to validate a MappedOperator. + _airflow_mapped_validation_only: bool = False, # Whether called to validate a MappedOperator. **kwargs, ): # If we are creating a new Task _and_ we are in the context of a MappedTaskGroup, then we should only @@ -671,7 +678,7 @@ def __new__( dag = dag or DagContext.get_current_dag() task_group = task_group or TaskGroupContext.get_current_task_group(dag) - if not _airflow_map_validation and isinstance(task_group, MappedTaskGroup): + if not _airflow_mapped_validation_only and isinstance(task_group, MappedTaskGroup): return cls.partial(dag=dag, task_group=task_group, **kwargs).apply() return super().__new__(cls) @@ -730,10 +737,6 @@ def __init__( super().__init__() - # This keyword is used internally to signify whether the operator is - # instantiated to validate a MappedOperator. 
- kwargs.pop("_airflow_map_validation", None) - if kwargs: if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'): raise AirflowException( @@ -1470,6 +1473,19 @@ def defer( """ raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) + @classmethod + def _validate_mapped_arguments_by_init(cls, **kwargs: Any) -> None: + """Mapping argument validation by actually creating the operator.""" + operator = cls(**kwargs, _airflow_mapped_validation_only=True) + if operator.dag: + operator.dag._remove_task(operator.task_id) + + @classmethod + def validate_mapped_arguments(cls, **kwargs: Any) -> None: + """Validate arguments when this operator is being mapped.""" + if cls.mapped_arguments_validated_by_init: + cls._validate_mapped_arguments_by_init(**kwargs) + def unmap(self) -> "BaseOperator": """:meta private:""" return self diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 044660b839257..6340d695724c3 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -21,7 +21,6 @@ import datetime import functools import operator -import unittest.mock import warnings from typing import ( TYPE_CHECKING, @@ -152,24 +151,6 @@ def ensure_xcomarg_return_value(arg: Any) -> None: ensure_xcomarg_return_value(v) -def create_mocked_kwargs(kwargs: Dict[str, "Mappable"]) -> Dict[str, unittest.mock.MagicMock]: - """Create a mapping of mocks for given map arguments. - - When a mapped operator is created, we want to perform basic validation on - the map arguments, especially the count of arguments. However, most of this - kind of logic lives directly on an operator class's ``__init__``, and - there's no good way to validate the arguments except to actually try to - create an operator instance. - - Since the map arguments are yet to be populated when the mapped operator is - being parsed, we need to "invent" some mocked values for this validation - purpose. 
The :class:`~unittest.mock.MagicMock` class is a good fit for this - since it not only provide good run-time properties, but also enjoy special - treatments in Mypy. - """ - return {k: unittest.mock.MagicMock(name=k) for k in kwargs} - - @attr.define(kw_only=True, repr=False) class OperatorPartial: """An "intermediate state" returned by ``BaseOperator.partial()``. @@ -321,11 +302,7 @@ def _validate_argument_count(self) -> None: """ if isinstance(self.operator_class, str): return # No need to validate deserialized operator. - mocked_mapped_kwargs = create_mocked_kwargs(self.mapped_kwargs) - op = self._create_unmapped_operator(mapped_kwargs=mocked_mapped_kwargs, real=False) - dag = op.get_dag() - if dag: - dag._remove_task(op.task_id) + self.operator_class.validate_mapped_arguments(**self._get_unmap_kwargs()) @property def task_type(self) -> str: @@ -455,36 +432,27 @@ def serialize_for_task_group(self) -> Tuple[DagAttributeTypes, Any]: """Implementing DAGNode.""" return DagAttributeTypes.OP, self.task_id - def _create_unmapped_operator(self, *, mapped_kwargs: Dict[str, Any], real: bool) -> "BaseOperator": - """Create a task of the underlying class based on this mapped operator. - - :param mapped_kwargs: Mapped keyword arguments to be used to create the - task. Do not use ``self.mapped_kwargs``. - :param real: Whether the task should be created "for real" (i.e. *False* - means the operator is only created for validation purposes and not - going to be added to the actual DAG). This is simply forwarded to - the operator's ``_airflow_map_validation`` argument. 
- """ - assert not isinstance(self.operator_class, str) - return self.operator_class( - task_id=self.task_id, - dag=self.dag, - task_group=self.task_group, - params=self.params, - start_date=self.start_date, - end_date=self.end_date, - _airflow_map_validation=not real, + def _get_unmap_kwargs(self) -> Dict[str, Any]: + return { + "task_id": self.task_id, + "dag": self.dag, + "task_group": self.task_group, + "params": self.params, + "start_date": self.start_date, + "end_date": self.end_date, **self.partial_kwargs, - **mapped_kwargs, - ) + **self.mapped_kwargs, + } def unmap(self) -> "BaseOperator": - """Get the "normal" Operator after applying the current mapping""" + """Get the "normal" Operator after applying the current mapping.""" dag = self.dag if not dag: raise RuntimeError("Cannot unmap a task without a DAG") dag._remove_task(self.task_id) - return self._create_unmapped_operator(mapped_kwargs=self.mapped_kwargs, real=True) + if isinstance(self.operator_class, str): + raise RuntimeError("Cannot unmap a deserialized operator") + return self.operator_class(**self._get_unmap_kwargs()) def _get_expansion_kwargs(self) -> Dict[str, "Mappable"]: """The kwargs to calculate expansion length against. diff --git a/airflow/operators/python.py b/airflow/operators/python.py index acdaae9468ae1..fb46abca0e331 100644 --- a/airflow/operators/python.py +++ b/airflow/operators/python.py @@ -133,6 +133,8 @@ def my_python_callable(**kwargs): 'op_kwargs', ) + mapped_arguments_validated_by_init = True + def __init__( self, *, From 3cccb207984372e9b4f870cc012f9319959d0775 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 2 Mar 2022 07:51:03 +0800 Subject: [PATCH 2/2] Add flag to distinguish a validation-only init There's just too much magic during a task's initialization that tries to add it into the dependency graph. This flag is needed to work around all that, I think. 
--- airflow/models/baseoperator.py | 35 +++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 296fa298936f8..292b3ec58fc22 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -337,6 +337,11 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any: if len(args) > 0: raise AirflowException("Use keyword arguments when initializing operators") + mapped_validation_only = kwargs.pop( + "_airflow_mapped_validation_only", + getattr(self, "_BaseOperator__mapped_validation", False), + ) + dag: Optional[DAG] = kwargs.get('dag') or DagContext.get_current_dag() task_group: Optional[TaskGroup] = kwargs.get('task_group') if dag and not task_group: @@ -371,8 +376,8 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any: if not hasattr(self, '_BaseOperator__init_kwargs'): self._BaseOperator__init_kwargs = {} + self._BaseOperator__mapped_validation = mapped_validation_only - mapped_validation_only = kwargs.pop("_airflow_mapped_validation_only", False) result = func(self, **kwargs, default_args=default_args) # Store the args passed to init -- we need them to support task.map serialzation! @@ -663,6 +668,9 @@ class derived from this one results in the creation of a task object, # logic (default implementation is 'pass' i.e. no validation whatsoever). mapped_arguments_validated_by_init: ClassVar[bool] = False + # Set to True for an operator instantiated only for mapping validation. 
+ __mapped_validation = False + def __new__( cls, dag: Optional['DAG'] = None, @@ -751,12 +759,17 @@ def __init__( stacklevel=3, ) validate_key(task_id) - self.task_id = task_id + dag = dag or DagContext.get_current_dag() task_group = task_group or TaskGroupContext.get_current_task_group(dag) + if task_group: self.task_id = task_group.child_id(task_id) + else: + self.task_id = task_id + if not self.__mapped_validation and task_group: task_group.add(self) + self.owner = owner self.email = email self.email_on_retry = email_on_retry @@ -975,9 +988,8 @@ def __lt__(self, other): def __setattr__(self, key, value): super().__setattr__(key, value) - if self._lock_for_execution: - # Skip any custom behaviour during execute - return + if self.__mapped_validation or self._lock_for_execution: + return # Skip any custom behavior for validation and during execute. if key in self.__init_kwargs: self.__init_kwargs[key] = value if self.__instantiated and key in self.template_fields: @@ -1029,6 +1041,9 @@ def dag(self, dag: Optional['DAG']): raise TypeError(f'Expected DAG; received {dag.__class__.__name__}') elif self.has_dag() and self.dag is not dag: raise AirflowException(f"The DAG assigned to {self} can not be changed.") + + if self.__mapped_validation: + pass # Don't add task to DAG for validation. 
elif self.task_id not in dag.task_dict: dag.add_task(self) elif self.task_id in dag.task_dict and dag.task_dict[self.task_id] is not self: @@ -1420,6 +1435,7 @@ def get_serialized_fields(cls): 'label', '_BaseOperator__instantiated', '_BaseOperator__init_kwargs', + '_BaseOperator__mapped_validation', } | { # Class level defaults need to be added to this list 'start_date', @@ -1473,18 +1489,11 @@ def defer( """ raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) - @classmethod - def _validate_mapped_arguments_by_init(cls, **kwargs: Any) -> None: - """Mapping argument validation by actually creating the operator.""" - operator = cls(**kwargs, _airflow_mapped_validation_only=True) - if operator.dag: - operator.dag._remove_task(operator.task_id) - @classmethod def validate_mapped_arguments(cls, **kwargs: Any) -> None: """Validate arguments when this operator is being mapped.""" if cls.mapped_arguments_validated_by_init: - cls._validate_mapped_arguments_by_init(**kwargs) + cls(**kwargs, _airflow_mapped_validation_only=True) def unmap(self) -> "BaseOperator": """:meta private:"""