Skip to content

Commit

Permalink
Implement map() semantic (#25085)
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr committed Jul 25, 2022
1 parent d7777bb commit 877dc89
Show file tree
Hide file tree
Showing 8 changed files with 423 additions and 117 deletions.
12 changes: 0 additions & 12 deletions airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,18 +123,6 @@ def __str__(self) -> str:
return f"unmappable return type {typename!r}"


class UnmappableXComValuePushed(AirflowException):
"""Raise when an invalid value is pushed as a mapped downstream's dependency."""

def __init__(self, value: Any, reason: str) -> None:
super().__init__(value, reason)
self.value = value
self.reason = reason

def __str__(self) -> str:
return f"unmappable return value {self.value!r} ({self.reason})"


class UnmappableXComLengthPushed(AirflowException):
"""Raise when the pushed value is too large to map as a downstream's dependency."""

Expand Down
18 changes: 0 additions & 18 deletions airflow/models/expandinput.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from sqlalchemy.orm import Session

from airflow.compat.functools import cache
from airflow.exceptions import UnmappableXComTypePushed, UnmappableXComValuePushed
from airflow.utils.context import Context

if TYPE_CHECKING:
Expand Down Expand Up @@ -72,11 +71,6 @@ class DictOfListsExpandInput(NamedTuple):

value: dict[str, Mappable]

@staticmethod
def validate_xcom(value: Any) -> None:
if not isinstance(value, collections.abc.Collection) or isinstance(value, (bytes, str)):
raise UnmappableXComTypePushed(value)

def get_unresolved_kwargs(self) -> dict[str, Any]:
"""Get the kwargs dict that can be inferred without resolving."""
return self.value
Expand Down Expand Up @@ -212,18 +206,6 @@ class ListOfDictsExpandInput(NamedTuple):

value: XComArg

@staticmethod
def validate_xcom(value: Any) -> None:
if not isinstance(value, collections.abc.Collection):
raise UnmappableXComTypePushed(value)
if isinstance(value, (str, bytes, collections.abc.Mapping)):
raise UnmappableXComTypePushed(value)
for item in value:
if not isinstance(item, collections.abc.Mapping):
raise UnmappableXComTypePushed(value, item)
if not all(isinstance(k, str) for k in item):
raise UnmappableXComValuePushed(value, reason="dict keys must be str")

def get_unresolved_kwargs(self) -> dict[str, Any]:
"""Get the kwargs dict that can be inferred without resolving.
Expand Down
15 changes: 0 additions & 15 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Collection,
Dict,
Expand Down Expand Up @@ -613,20 +612,6 @@ def _get_specified_expand_input(self) -> ExpandInput:
"""Input received from the expand call on the operator."""
return getattr(self, self._expand_input_attr)

@property
def validate_upstream_return_value(self) -> Callable[[Any], None]:
"""Validate an upstream's return value satisfies this task's needs.
This is implemented as a property (instead of a function calling
``validate_xcom``) so the call site in TaskInstance can de-duplicate
validation functions. If this is an instance method, each
``validate_upstream_return_value`` would be a different object (due to
how Python handles bounded functions), and de-duplication won't work.
:meta private:
"""
return self._get_specified_expand_input().validate_xcom

def expand_mapped_task(self, run_id: str, *, session: Session) -> Tuple[Sequence["TaskInstance"], int]:
"""Create the mapped task instances for mapped task.
Expand Down
13 changes: 8 additions & 5 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
TaskDeferralError,
TaskDeferred,
UnmappableXComLengthPushed,
UnmappableXComTypePushed,
XComForMappingNotPushed,
)
from airflow.models.base import Base, StringID
Expand Down Expand Up @@ -2330,8 +2331,7 @@ def set_duration(self) -> None:
self.log.debug("Task Duration set to %s", self.duration)

def _record_task_map_for_downstreams(self, task: "Operator", value: Any, *, session: Session) -> None:
validators = {m.validate_upstream_return_value for m in task.iter_mapped_dependants()}
if not validators: # No mapped dependants, no need to validate.
if next(task.iter_mapped_dependants(), None) is None: # No mapped dependants, no need to validate.
return
# TODO: We don't push TaskMap for mapped task instances because it's not
# currently possible for a downstream to depend on one individual mapped
Expand All @@ -2341,9 +2341,12 @@ def _record_task_map_for_downstreams(self, task: "Operator", value: Any, *, sess
return
if value is None:
raise XComForMappingNotPushed()
for validator in validators:
validator(value)
assert isinstance(value, collections.abc.Collection) # The validators type-guard this.
if not isinstance(value, (collections.abc.Sequence, dict)):
raise UnmappableXComTypePushed(value)
if isinstance(value, (bytes, str)):
raise UnmappableXComTypePushed(value)
if TYPE_CHECKING: # The isinstance() checks above guard this.
assert isinstance(value, collections.abc.Collection)
task_map = TaskMap.from_task_instance_xcom(self, value)
max_map_length = conf.getint("core", "max_map_length", fallback=1024)
if task_map.length > max_map_length:
Expand Down
Loading

0 comments on commit 877dc89

Please sign in to comment.