Skip to content

Commit

Permalink
Coerce LazyXComAccess to list when pushed to XCom (#27251)
Browse files Browse the repository at this point in the history
The class is intended to work as a "lazy list" to avoid pulling a ton
of XComs unnecessarily, but if it's pushed into XCom, the user should
be aware of the performance implications, and this avoids leaking the
implementation detail.
  • Loading branch information
uranusjr committed Nov 7, 2022
1 parent 96a5a63 commit 62b7bd6
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 111 deletions.
104 changes: 3 additions & 101 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,9 @@
from datetime import datetime, timedelta
from functools import partial
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
ContextManager,
Generator,
Iterable,
NamedTuple,
Tuple,
)
from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, NamedTuple, Tuple
from urllib.parse import quote

import attr
import dill
import jinja2
import lazy_object_proxy
Expand All @@ -69,8 +58,6 @@
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import reconstructor, relationship
from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.elements import BooleanClauseList
from sqlalchemy.sql.expression import ColumnOperators
Expand Down Expand Up @@ -101,7 +88,7 @@
from airflow.models.taskfail import TaskFail
from airflow.models.taskmap import TaskMap
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.xcom import XCOM_RETURN_KEY, XCom
from airflow.models.xcom import XCOM_RETURN_KEY, LazyXComAccess, XCom
from airflow.plugins_manager import integrate_macros_plugins
from airflow.sentry import Sentry
from airflow.stats import Stats
Expand Down Expand Up @@ -291,91 +278,6 @@ def clear_task_instances(
session.flush()


class _LazyXComAccessIterator(collections.abc.Iterator):
    """Iterator that opens its query lazily and deserializes every row.

    The wrapped context manager is entered on the first ``__next__`` call
    only, and is exited when this iterator object is finalized.
    """

    __slots__ = ['_cm', '_it']

    def __init__(self, cm: ContextManager[Query]):
        # Row iterator; stays ``None`` until the first ``__next__`` call.
        self._it = None
        self._cm = cm

    def __del__(self):
        # Nothing to release if the context manager was never entered.
        if not self._it:
            return
        self._cm.__exit__(None, None, None)

    def __iter__(self):
        return self

    def __next__(self):
        rows = self._it
        if not rows:
            # First access: enter the context manager and start iterating.
            rows = self._it = iter(self._cm.__enter__())
        return XCom.deserialize_value(next(rows))


@attr.define
class _LazyXComAccess(collections.abc.Sequence):
    """Wrapper to lazily pull XCom with a sequence-like interface.

    Note that since the session bound to the parent query may have died when we
    actually access the sequence's content, we must create a new session
    for every function call with ``with_session()``.
    """

    dag_id: str
    run_id: str
    task_id: str
    _query: Query = attr.ib(repr=False)
    # Cached row count; filled in by the first ``len()`` call.
    _len: int | None = attr.ib(init=False, repr=False, default=None)

    @classmethod
    def build_from_single_xcom(cls, first: XCom, query: Query) -> _LazyXComAccess:
        """Build a lazy sequence over the mapped XComs that *first* belongs to.

        :param first: An XCom row identifying the DAG, run and task to read.
        :param query: Query to derive the value query from; it is narrowed to
            rows of the same DAG/run/task with a non-negative ``map_index``,
            ordered by ``map_index``.
        """
        return cls(
            dag_id=first.dag_id,
            run_id=first.run_id,
            task_id=first.task_id,
            query=query.with_entities(XCom.value)
            .filter(
                XCom.run_id == first.run_id,
                XCom.task_id == first.task_id,
                XCom.dag_id == first.dag_id,
                XCom.map_index >= 0,
            )
            # Drop any ordering carried by the incoming query first.
            .order_by(None)
            .order_by(XCom.map_index.asc()),
        )

    def __len__(self):
        """Return the number of rows, counting in the database only once."""
        if self._len is None:
            with self._get_bound_query() as query:
                self._len = query.count()
        return self._len

    def __iter__(self):
        """Return an iterator that deserializes values one row at a time."""
        return _LazyXComAccessIterator(self._get_bound_query())

    def __getitem__(self, key):
        """Return the deserialized value at position *key*.

        Negative positions count from the end, as for built-in sequences.

        :raises ValueError: If *key* is not an integer (slicing is unsupported).
        :raises IndexError: If *key* is out of range.
        """
        if not isinstance(key, int):
            raise ValueError("only support index access for now")
        offset = key
        if offset < 0:
            # SQL OFFSET cannot count from the end; depending on the backend a
            # negative OFFSET is rejected or ignored, so translate it here.
            offset += len(self)
            if offset < 0:
                raise IndexError(key)
        try:
            with self._get_bound_query() as query:
                r = query.offset(offset).limit(1).one()
        except NoResultFound:
            raise IndexError(key) from None
        return XCom.deserialize_value(r)

    @contextlib.contextmanager
    def _get_bound_query(self) -> Generator[Query, None, None]:
        """Yield the query bound to a usable session.

        The query's own session is reused while it is still active; otherwise
        a fresh session is created and closed once the caller is done.
        """
        # Do we have a valid session already?
        if self._query.session and self._query.session.is_active:
            yield self._query
            return

        session = settings.Session()
        try:
            yield self._query.with_session(session)
        finally:
            session.close()


class TaskInstanceKey(NamedTuple):
"""Key used to identify task instance."""

Expand Down Expand Up @@ -2439,7 +2341,7 @@ def xcom_pull(
if map_indexes is not None or first.map_index < 0:
return XCom.deserialize_value(first)

return _LazyXComAccess.build_from_single_xcom(first, query)
return LazyXComAccess.build_from_single_xcom(first, query)

# At this point either task_ids or map_indexes is explicitly multi-value.

Expand Down
118 changes: 115 additions & 3 deletions airflow/models/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,18 @@
# under the License.
from __future__ import annotations

import collections.abc
import contextlib
import datetime
import inspect
import json
import logging
import pickle
import warnings
from functools import wraps
from typing import TYPE_CHECKING, Any, Iterable, cast, overload
from typing import TYPE_CHECKING, Any, Generator, Iterable, cast, overload

import attr
import pendulum
from sqlalchemy import (
Column,
Expand All @@ -41,6 +44,8 @@
from sqlalchemy.orm import Query, Session, reconstructor, relationship
from sqlalchemy.orm.exc import NoResultFound

from airflow import settings
from airflow.compat.functools import cached_property
from airflow.configuration import conf
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
Expand Down Expand Up @@ -203,6 +208,27 @@ def set(
if dag_run_id is None:
raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}")

# Seamlessly resolve LazyXComAccess to a list. This is intended to work
# as a "lazy list" to avoid pulling a ton of XComs unnecessarily, but if
# it's pushed into XCom, the user should be aware of the performance
# implications, and this avoids leaking the implementation detail.
if isinstance(value, LazyXComAccess):
warning_message = (
"Coercing mapped lazy proxy %s from task %s (DAG %s, run %s) "
"to list, which may degrade performance. Review resource "
"requirements for this operation, and call list() to suppress "
"this message. See Dynamic Task Mapping documentation for "
"more information about lazy proxy objects."
)
log.warning(
warning_message,
"return value" if key == XCOM_RETURN_KEY else f"value {key}",
task_id,
dag_id,
run_id or execution_date,
)
value = list(value)

value = cls.serialize_value(
value=value,
key=key,
Expand Down Expand Up @@ -589,8 +615,8 @@ def serialize_value(
dag_id: str | None = None,
run_id: str | None = None,
map_index: int | None = None,
):
"""Serialize XCom value to str or pickled object"""
) -> Any:
"""Serialize XCom value to str or pickled object."""
if conf.getboolean('core', 'enable_xcom_pickling'):
return pickle.dumps(value)
try:
Expand Down Expand Up @@ -632,6 +658,92 @@ def orm_deserialize_value(self) -> Any:
return BaseXCom.deserialize_value(self)


class _LazyXComAccessIterator(collections.abc.Iterator):
    """Iterator that opens its query lazily and deserializes every row.

    The wrapped context manager is entered the first time a row is requested,
    and is exited when this iterator object is finalized.
    """

    def __init__(self, cm: contextlib.AbstractContextManager[Query]) -> None:
        # Records whether ``cm`` has been entered and so needs exiting later.
        self._entered = False
        self._cm = cm

    def __del__(self) -> None:
        # Nothing to release if the context manager was never entered.
        if not self._entered:
            return
        self._cm.__exit__(None, None, None)

    def __iter__(self) -> collections.abc.Iterator:
        return self

    def __next__(self) -> Any:
        row = next(self._it)
        return XCom.deserialize_value(row)

    @cached_property
    def _it(self) -> collections.abc.Iterator:
        # Flag first, so ``__del__`` exits the manager once entering started.
        self._entered = True
        query = self._cm.__enter__()
        return iter(query)


@attr.define(slots=True)
class LazyXComAccess(collections.abc.Sequence):
    """Wrapper to lazily pull XCom with a sequence-like interface.

    Note that since the session bound to the parent query may have died when we
    actually access the sequence's content, we must create a new session
    for every function call with ``with_session()``.
    """

    dag_id: str
    run_id: str
    task_id: str
    _query: Query = attr.ib(repr=False)
    # Cached row count; filled in by the first ``len()`` call.
    _len: int | None = attr.ib(init=False, repr=False, default=None)

    @classmethod
    def build_from_single_xcom(cls, first: XCom, query: Query) -> LazyXComAccess:
        """Build a lazy sequence over the mapped XComs that *first* belongs to.

        :param first: An XCom row identifying the DAG, run and task to read.
        :param query: Query to derive the value query from; it is narrowed to
            rows of the same DAG/run/task with a non-negative ``map_index``,
            ordered by ``map_index``.
        """
        return cls(
            dag_id=first.dag_id,
            run_id=first.run_id,
            task_id=first.task_id,
            query=query.with_entities(XCom.value)
            .filter(
                XCom.run_id == first.run_id,
                XCom.task_id == first.task_id,
                XCom.dag_id == first.dag_id,
                XCom.map_index >= 0,
            )
            # Drop any ordering carried by the incoming query first.
            .order_by(None)
            .order_by(XCom.map_index.asc()),
        )

    def __len__(self):
        """Return the number of rows, counting in the database only once."""
        if self._len is None:
            with self._get_bound_query() as query:
                self._len = query.count()
        return self._len

    def __iter__(self):
        """Return an iterator that deserializes values one row at a time."""
        return _LazyXComAccessIterator(self._get_bound_query())

    def __getitem__(self, key):
        """Return the deserialized value at position *key*.

        Negative positions count from the end, as for built-in sequences.

        :raises ValueError: If *key* is not an integer (slicing is unsupported).
        :raises IndexError: If *key* is out of range.
        """
        if not isinstance(key, int):
            raise ValueError("only support index access for now")
        offset = key
        if offset < 0:
            # SQL OFFSET cannot count from the end; depending on the backend a
            # negative OFFSET is rejected or ignored, so translate it here.
            offset += len(self)
            if offset < 0:
                raise IndexError(key)
        try:
            with self._get_bound_query() as query:
                r = query.offset(offset).limit(1).one()
        except NoResultFound:
            raise IndexError(key) from None
        return XCom.deserialize_value(r)

    @contextlib.contextmanager
    def _get_bound_query(self) -> Generator[Query, None, None]:
        """Yield the query bound to a usable session.

        The query's own session is reused while it is still active; otherwise
        a fresh session is created and closed once the caller is done.
        """
        # Do we have a valid session already?
        if self._query.session and self._query.session.is_active:
            yield self._query
            return

        session = settings.Session()
        try:
            yield self._query.with_session(session)
        finally:
            session.close()


def _patch_outdated_serializer(clazz: type[BaseXCom], params: Iterable[str]) -> None:
"""Patch a custom ``serialize_value`` to accept the modern signature.
Expand Down
24 changes: 18 additions & 6 deletions docs/apache-airflow/concepts/dynamic-task-mapping.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,32 @@ The grid view also provides visibility into your mapped tasks in the details pan

In the above example, ``values`` received by ``sum_it`` is an aggregation of all values returned by each mapped instance of ``add_one``. However, since it is impossible to know how many instances of ``add_one`` we will have in advance, ``values`` is not a normal list, but a "lazy sequence" that retrieves each individual value only when asked. Therefore, if you run ``print(values)`` directly, you would get something like this::

_LazyXComAccess(dag_id='simple_mapping', run_id='test_run', task_id='add_one')
LazyXComAccess(dag_id='simple_mapping', run_id='test_run', task_id='add_one')

You can use normal sequence syntax on this object (e.g. ``values[0]``), or iterate through it normally with a ``for`` loop. ``list(values)`` will give you a "real" ``list``, but please be aware of the potential performance implications if the list is large.
You can use normal sequence syntax on this object (e.g. ``values[0]``), or iterate through it normally with a ``for`` loop. ``list(values)`` will give you a "real" ``list``, but since this would eagerly load values from *all* of the referenced upstream mapped tasks, you must be aware of the potential performance implications if the number of mapped task instances is large.

Note that the same also applies to when you push this proxy object into XCom. This, for example, would not
work with the default XCom backend:
Note that the same also applies when you push this proxy object into XCom. Airflow tries to be smart and coerce the value automatically, but will emit a warning so that you are aware of it. For example:

.. code-block:: python
@task
def forward_values(values):
return values # This is a lazy proxy and can't be pushed!
return values # This is a lazy proxy!
You need to explicitly call ``list(values)`` instead, and accept the performance implications.
will emit a warning like this:

.. code-block:: text
Coercing mapped lazy proxy return value from task forward_values (DAG simple_mapping, run test_run) to list, which may degrade
performance. Review resource requirements for this operation, and call list() to suppress this message. See Dynamic Task Mapping documentation for more information about lazy proxy objects.
The message can be suppressed by modifying the task like this:

.. code-block:: python
@task
def forward_values(values):
return list(values)
.. note:: A reduce task is not required.

Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3485,7 +3485,7 @@ def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterable(mock_deserialize_v
joined = ti_2.xcom_pull("task_1", session=session)
assert mock_deserialize_value.call_count == 0

assert repr(joined) == "_LazyXComAccess(dag_id='test_xcom', run_id='test', task_id='task_1')"
assert repr(joined) == "LazyXComAccess(dag_id='test_xcom', run_id='test', task_id='task_1')"

# Only when we go through the iterable does deserialization happen.
it = iter(joined)
Expand Down

0 comments on commit 62b7bd6

Please sign in to comment.