Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow MapXComArg to resolve after serialization #26591

Merged
merged 2 commits into from
Sep 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 41 additions & 7 deletions airflow/models/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import contextlib
import inspect
from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence, overload
from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence, Union, overload

from sqlalchemy import func
from sqlalchemy.orm import Session
Expand All @@ -35,6 +37,11 @@
from airflow.models.dag import DAG
from airflow.models.operator import Operator

# Callable objects contained by MapXComArg. We only accept callables from
# the user, but deserialize them into strings in a serialized XComArg for
# safety (those callables are arbitrary user code).
MapCallables = Sequence[Union[Callable[[Any], Any], str]]


class XComArg(DependencyMixin):
"""Reference to an XCom value pushed from another operator.
Expand Down Expand Up @@ -322,15 +329,39 @@ def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
raise XComNotFound(context["ti"].dag_id, task_id, self.key)


def _get_callable_name(f: Callable | str) -> str:
"""Try to "describe" a callable by getting its name."""
if callable(f):
return f.__name__
# Parse the source to find whatever is behind "def". For safety, we don't
# want to evaluate the code in any meaningful way!
with contextlib.suppress(Exception):
kw, name, _ = f.lstrip().split(None, 2)
if kw == "def":
return name
return "<function>"


class _MapResult(Sequence):
def __init__(self, value: Sequence | dict, callables: Sequence[Callable[[Any], Any]]) -> None:
def __init__(self, value: Sequence | dict, callables: MapCallables) -> None:
self.value = value
self.callables = callables

def __getitem__(self, index: Any) -> Any:
value = self.value[index]
for f in self.callables:
value = f(value)

# In the worker, we can access all actual callables. Call them.
callables = [f for f in self.callables if callable(f)]
if len(callables) == len(self.callables):
for f in callables:
value = f(value)
return value

# In the scheduler, we don't have access to the actual callables, nor do
# we want to run it since it's arbitrary code. This builds a string to
# represent the call chain in the UI or logs instead.
for v in self.callables:
value = f"{_get_callable_name(v)}({value})"
return value

def __len__(self) -> int:
Expand All @@ -342,22 +373,25 @@ class MapXComArg(XComArg):

This is based on an XComArg, but also applies a series of "transforms" that
convert the pulled XCom value.

:meta private:
"""

def __init__(self, arg: XComArg, callables: Sequence[Callable[[Any], Any]]) -> None:
def __init__(self, arg: XComArg, callables: MapCallables) -> None:
for c in callables:
if getattr(c, "_airflow_is_task_decorator", False):
raise ValueError("map() argument must be a plain function, not a @task operator")
self.arg = arg
self.callables = callables

def __repr__(self) -> str:
return f"{self.arg!r}.map([{len(self.callables)} functions])"
map_calls = "".join(f".map({_get_callable_name(f)})" for f in self.callables)
return f"{self.arg!r}{map_calls}"

def _serialize(self) -> dict[str, Any]:
return {
"arg": serialize_xcom_arg(self.arg),
"callables": [inspect.getsource(c) for c in self.callables],
"callables": [inspect.getsource(c) if callable(c) else c for c in self.callables],
}

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,14 +211,14 @@ def pull(value):

# Run "push_letters" and "push_numbers".
decision = dr.task_instance_scheduling_decisions(session=session)
assert decision.schedulable_tis and all(ti.task_id.startswith("push_") for ti in decision.schedulable_tis)
assert sorted(ti.task_id for ti in decision.schedulable_tis) == ["push_letters", "push_numbers"]
for ti in decision.schedulable_tis:
ti.run(session=session)
session.commit()

# Run "pull".
decision = dr.task_instance_scheduling_decisions(session=session)
assert decision.schedulable_tis and all(ti.task_id == "pull" for ti in decision.schedulable_tis)
assert sorted(ti.task_id for ti in decision.schedulable_tis) == ["pull"] * len(expected_results)
for ti in decision.schedulable_tis:
ti.run(session=session)

Expand Down