Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 11 additions & 18 deletions airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
from airflow.models.mappedoperator import (
MappedOperator,
ValidationSource,
create_mocked_kwargs,
ensure_xcomarg_return_value,
get_mappable_types,
prevent_duplicates,
Expand Down Expand Up @@ -411,29 +410,23 @@ def _get_expansion_kwargs(self) -> Dict[str, "Mappable"]:
"""
return self.mapped_op_kwargs

def _create_unmapped_operator(self, *, mapped_kwargs: Dict[str, Any], real: bool) -> "BaseOperator":
assert not isinstance(self.operator_class, str)
def _get_unmap_kwargs(self) -> Dict[str, Any]:
partial_kwargs = self.partial_kwargs.copy()
if real:
mapped_op_kwargs: Dict[str, Any] = self.mapped_op_kwargs
else:
mapped_op_kwargs = create_mocked_kwargs(self.mapped_op_kwargs)
op_kwargs = _merge_kwargs(
partial_kwargs.pop("op_kwargs"),
mapped_op_kwargs,
self.mapped_op_kwargs,
fail_reason="mapping already partial",
)
return self.operator_class(
dag=self.dag,
task_group=self.task_group,
task_id=self.task_id,
op_kwargs=op_kwargs,
multiple_outputs=self.multiple_outputs,
python_callable=self.python_callable,
_airflow_map_validation=not real,
return {
"dag": self.dag,
"task_group": self.task_group,
"task_id": self.task_id,
"op_kwargs": op_kwargs,
"multiple_outputs": self.multiple_outputs,
"python_callable": self.python_callable,
**partial_kwargs,
**mapped_kwargs,
)
**self.mapped_kwargs,
}

def _expand_mapped_field(self, key: str, content: Any, context: Context, *, session: Session) -> Any:
if key != "op_kwargs" or not isinstance(content, collections.abc.Mapping):
Expand Down
47 changes: 36 additions & 11 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,11 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
if len(args) > 0:
raise AirflowException("Use keyword arguments when initializing operators")

mapped_validation_only = kwargs.pop(
"_airflow_mapped_validation_only",
getattr(self, "_BaseOperator__mapped_validation", False),
)

dag: Optional[DAG] = kwargs.get('dag') or DagContext.get_current_dag()
task_group: Optional[TaskGroup] = kwargs.get('task_group')
if dag and not task_group:
Expand Down Expand Up @@ -371,12 +376,14 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:

if not hasattr(self, '_BaseOperator__init_kwargs'):
self._BaseOperator__init_kwargs = {}
self._BaseOperator__mapped_validation = mapped_validation_only

result = func(self, **kwargs, default_args=default_args)

# Store the args passed to init -- we need them to support task.map serialzation!
self._BaseOperator__init_kwargs.update(kwargs) # type: ignore

if not kwargs.get("_airflow_map_validation"):
if not mapped_validation_only:
# Set upstream task defined by XComArgs passed to template fields of the operator.
self.set_xcomargs_dependencies()
# Mark instance as instantiated.
Expand Down Expand Up @@ -656,11 +663,19 @@ class derived from this one results in the creation of a task object,
start_date: Optional[pendulum.DateTime] = None
end_date: Optional[pendulum.DateTime] = None

# How operator-mapping arguments should be validated. If True, a default validation implementation that
# calls the operator's constructor is used. If False, the operator should implement its own validation
# logic (default implementation is 'pass' i.e. no validation whatsoever).
mapped_arguments_validated_by_init: ClassVar[bool] = False

# Set to True for an operator instantiated only for mapping validation.
__mapped_validation = False

def __new__(
cls,
dag: Optional['DAG'] = None,
task_group: Optional["TaskGroup"] = None,
_airflow_map_validation: bool = False, # If True, this is called to validate a MappedOperator.
_airflow_mapped_validation_only: bool = False, # Whether called to validate a MappedOperator.
**kwargs,
):
# If we are creating a new Task _and_ we are in the context of a MappedTaskGroup, then we should only
Expand All @@ -671,7 +686,7 @@ def __new__(
dag = dag or DagContext.get_current_dag()
task_group = task_group or TaskGroupContext.get_current_task_group(dag)

if not _airflow_map_validation and isinstance(task_group, MappedTaskGroup):
if not _airflow_mapped_validation_only and isinstance(task_group, MappedTaskGroup):
return cls.partial(dag=dag, task_group=task_group, **kwargs).apply()
return super().__new__(cls)

Expand Down Expand Up @@ -730,10 +745,6 @@ def __init__(

super().__init__()

# This keyword is used internally to signify whether the operator is
# instantiated to validate a MappedOperator.
kwargs.pop("_airflow_map_validation", None)

if kwargs:
if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'):
raise AirflowException(
Expand All @@ -748,12 +759,17 @@ def __init__(
stacklevel=3,
)
validate_key(task_id)
self.task_id = task_id

dag = dag or DagContext.get_current_dag()
task_group = task_group or TaskGroupContext.get_current_task_group(dag)

if task_group:
self.task_id = task_group.child_id(task_id)
else:
self.task_id = task_id
if not self.__mapped_validation and task_group:
task_group.add(self)

self.owner = owner
self.email = email
self.email_on_retry = email_on_retry
Expand Down Expand Up @@ -972,9 +988,8 @@ def __lt__(self, other):

def __setattr__(self, key, value):
super().__setattr__(key, value)
if self._lock_for_execution:
# Skip any custom behaviour during execute
return
if self.__mapped_validation or self._lock_for_execution:
return # Skip any custom behavior for validation and during execute.
if key in self.__init_kwargs:
self.__init_kwargs[key] = value
if self.__instantiated and key in self.template_fields:
Expand Down Expand Up @@ -1026,6 +1041,9 @@ def dag(self, dag: Optional['DAG']):
raise TypeError(f'Expected DAG; received {dag.__class__.__name__}')
elif self.has_dag() and self.dag is not dag:
raise AirflowException(f"The DAG assigned to {self} can not be changed.")

if self.__mapped_validation:
pass # Don't add task to DAG for validation.
elif self.task_id not in dag.task_dict:
dag.add_task(self)
elif self.task_id in dag.task_dict and dag.task_dict[self.task_id] is not self:
Expand Down Expand Up @@ -1417,6 +1435,7 @@ def get_serialized_fields(cls):
'label',
'_BaseOperator__instantiated',
'_BaseOperator__init_kwargs',
'_BaseOperator__mapped_validation',
}
| { # Class level defaults need to be added to this list
'start_date',
Expand Down Expand Up @@ -1470,6 +1489,12 @@ def defer(
"""
raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout)

@classmethod
def validate_mapped_arguments(cls, **kwargs: Any) -> None:
"""Validate arguments when this operator is being mapped."""
if cls.mapped_arguments_validated_by_init:
cls(**kwargs, _airflow_mapped_validation_only=True)

def unmap(self) -> "BaseOperator":
""":meta private:"""
return self
Expand Down
62 changes: 15 additions & 47 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import datetime
import functools
import operator
import unittest.mock
import warnings
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -152,24 +151,6 @@ def ensure_xcomarg_return_value(arg: Any) -> None:
ensure_xcomarg_return_value(v)


def create_mocked_kwargs(kwargs: Dict[str, "Mappable"]) -> Dict[str, unittest.mock.MagicMock]:
"""Create a mapping of mocks for given map arguments.

When a mapped operator is created, we want to perform basic validation on
the map arguments, especially the count of arguments. However, most of this
kind of logic lives directly on an operator class's ``__init__``, and
there's no good way to validate the arguments except to actually try to
create an operator instance.

Since the map arguments are yet to be populated when the mapped operator is
being parsed, we need to "invent" some mocked values for this validation
purpose. The :class:`~unittest.mock.MagicMock` class is a good fit for this
since it not only provide good run-time properties, but also enjoy special
treatments in Mypy.
"""
return {k: unittest.mock.MagicMock(name=k) for k in kwargs}


@attr.define(kw_only=True, repr=False)
class OperatorPartial:
"""An "intermediate state" returned by ``BaseOperator.partial()``.
Expand Down Expand Up @@ -321,11 +302,7 @@ def _validate_argument_count(self) -> None:
"""
if isinstance(self.operator_class, str):
return # No need to validate deserialized operator.
mocked_mapped_kwargs = create_mocked_kwargs(self.mapped_kwargs)
op = self._create_unmapped_operator(mapped_kwargs=mocked_mapped_kwargs, real=False)
dag = op.get_dag()
if dag:
dag._remove_task(op.task_id)
self.operator_class.validate_mapped_arguments(**self._get_unmap_kwargs())

@property
def task_type(self) -> str:
Expand Down Expand Up @@ -455,36 +432,27 @@ def serialize_for_task_group(self) -> Tuple[DagAttributeTypes, Any]:
"""Implementing DAGNode."""
return DagAttributeTypes.OP, self.task_id

def _create_unmapped_operator(self, *, mapped_kwargs: Dict[str, Any], real: bool) -> "BaseOperator":
"""Create a task of the underlying class based on this mapped operator.

:param mapped_kwargs: Mapped keyword arguments to be used to create the
task. Do not use ``self.mapped_kwargs``.
:param real: Whether the task should be created "for real" (i.e. *False*
means the operator is only created for validation purposes and not
going to be added to the actual DAG). This is simply forwarded to
the operator's ``_airflow_map_validation`` argument.
"""
assert not isinstance(self.operator_class, str)
return self.operator_class(
task_id=self.task_id,
dag=self.dag,
task_group=self.task_group,
params=self.params,
start_date=self.start_date,
end_date=self.end_date,
_airflow_map_validation=not real,
def _get_unmap_kwargs(self) -> Dict[str, Any]:
return {
"task_id": self.task_id,
"dag": self.dag,
"task_group": self.task_group,
"params": self.params,
"start_date": self.start_date,
"end_date": self.end_date,
**self.partial_kwargs,
**mapped_kwargs,
)
**self.mapped_kwargs,
}

def unmap(self) -> "BaseOperator":
"""Get the "normal" Operator after applying the current mapping"""
"""Get the "normal" Operator after applying the current mapping."""
dag = self.dag
if not dag:
raise RuntimeError("Cannot unmap a task without a DAG")
dag._remove_task(self.task_id)
return self._create_unmapped_operator(mapped_kwargs=self.mapped_kwargs, real=True)
if isinstance(self.operator_class, str):
raise RuntimeError("Cannot unmap a deserialized operator")
return self.operator_class(**self._get_unmap_kwargs())

def _get_expansion_kwargs(self) -> Dict[str, "Mappable"]:
"""The kwargs to calculate expansion length against.
Expand Down
2 changes: 2 additions & 0 deletions airflow/operators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ def my_python_callable(**kwargs):
'op_kwargs',
)

mapped_arguments_validated_by_init = True

def __init__(
self,
*,
Expand Down