Skip to content

Commit

Permalink
Coerce LazyXComAccess to list when pushed to XCom
Browse files Browse the repository at this point in the history
The class is intended to work as a "lazy list" to avoid pulling a ton
of XComs unnecessarily, but if it's pushed into XCom, the user should
be aware of the performance implications, and this avoids leaking the
implementation detail.
  • Loading branch information
uranusjr committed Oct 25, 2022
1 parent 2dc78b7 commit e3ab730
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 106 deletions.
104 changes: 3 additions & 101 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,9 @@
from datetime import datetime, timedelta
from functools import partial
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
ContextManager,
Generator,
Iterable,
NamedTuple,
Tuple,
)
from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, NamedTuple, Tuple
from urllib.parse import quote

import attr
import dill
import jinja2
import lazy_object_proxy
Expand All @@ -69,8 +58,6 @@
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import reconstructor, relationship
from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.elements import BooleanClauseList
from sqlalchemy.sql.expression import ColumnOperators
Expand Down Expand Up @@ -101,7 +88,7 @@
from airflow.models.taskfail import TaskFail
from airflow.models.taskmap import TaskMap
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.xcom import XCOM_RETURN_KEY, XCom
from airflow.models.xcom import XCOM_RETURN_KEY, LazyXComAccess, XCom
from airflow.plugins_manager import integrate_macros_plugins
from airflow.sentry import Sentry
from airflow.stats import Stats
Expand Down Expand Up @@ -291,91 +278,6 @@ def clear_task_instances(
session.flush()


class _LazyXComAccessIterator(collections.abc.Iterator):
__slots__ = ['_cm', '_it']

def __init__(self, cm: ContextManager[Query]):
self._cm = cm
self._it = None

def __del__(self):
if self._it:
self._cm.__exit__(None, None, None)

def __iter__(self):
return self

def __next__(self):
if not self._it:
self._it = iter(self._cm.__enter__())
return XCom.deserialize_value(next(self._it))


@attr.define
class _LazyXComAccess(collections.abc.Sequence):
"""Wrapper to lazily pull XCom with a sequence-like interface.
Note that since the session bound to the parent query may have died when we
actually access the sequence's content, we must create a new session
for every function call with ``with_session()``.
"""

dag_id: str
run_id: str
task_id: str
_query: Query = attr.ib(repr=False)
_len: int | None = attr.ib(init=False, repr=False, default=None)

@classmethod
def build_from_single_xcom(cls, first: XCom, query: Query) -> _LazyXComAccess:
return cls(
dag_id=first.dag_id,
run_id=first.run_id,
task_id=first.task_id,
query=query.with_entities(XCom.value)
.filter(
XCom.run_id == first.run_id,
XCom.task_id == first.task_id,
XCom.dag_id == first.dag_id,
XCom.map_index >= 0,
)
.order_by(None)
.order_by(XCom.map_index.asc()),
)

def __len__(self):
if self._len is None:
with self._get_bound_query() as query:
self._len = query.count()
return self._len

def __iter__(self):
return _LazyXComAccessIterator(self._get_bound_query())

def __getitem__(self, key):
if not isinstance(key, int):
raise ValueError("only support index access for now")
try:
with self._get_bound_query() as query:
r = query.offset(key).limit(1).one()
except NoResultFound:
raise IndexError(key) from None
return XCom.deserialize_value(r)

@contextlib.contextmanager
def _get_bound_query(self) -> Generator[Query, None, None]:
# Do we have a valid session already?
if self._query.session and self._query.session.is_active:
yield self._query
return

session = settings.Session()
try:
yield self._query.with_session(session)
finally:
session.close()


class TaskInstanceKey(NamedTuple):
"""Key used to identify task instance."""

Expand Down Expand Up @@ -2441,7 +2343,7 @@ def xcom_pull(
if map_indexes is not None or first.map_index < 0:
return XCom.deserialize_value(first)

return _LazyXComAccess.build_from_single_xcom(first, query)
return LazyXComAccess.build_from_single_xcom(first, query)

# At this point either task_ids or map_indexes is explicitly multi-value.

Expand Down
104 changes: 101 additions & 3 deletions airflow/models/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,18 @@
# under the License.
from __future__ import annotations

import collections.abc
import contextlib
import datetime
import inspect
import json
import logging
import pickle
import warnings
from functools import wraps
from typing import TYPE_CHECKING, Any, Iterable, cast, overload
from typing import TYPE_CHECKING, Any, Generator, Iterable, cast, overload

import attr
import pendulum
from sqlalchemy import (
Column,
Expand All @@ -41,6 +44,8 @@
from sqlalchemy.orm import Query, Session, reconstructor, relationship
from sqlalchemy.orm.exc import NoResultFound

from airflow import settings
from airflow.compat.functools import cached_property
from airflow.configuration import conf
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
Expand Down Expand Up @@ -589,8 +594,15 @@ def serialize_value(
dag_id: str | None = None,
run_id: str | None = None,
map_index: int | None = None,
):
"""Serialize XCom value to str or pickled object"""
) -> Any:
"""Serialize XCom value to str or pickled object."""
# Seamlessly resolve LazyXComAccess to a list. This is intended to work
# as a "lazy list" to avoid pulling a ton of XComs unnecessarily, but if
# it's pushed into XCom, the user should be aware of the performance
# implications, and this avoids leaking the implementation detail.
if isinstance(value, LazyXComAccess):
value = list(value)

if conf.getboolean('core', 'enable_xcom_pickling'):
return pickle.dumps(value)
try:
Expand Down Expand Up @@ -632,6 +644,92 @@ def orm_deserialize_value(self) -> Any:
return BaseXCom.deserialize_value(self)


class _LazyXComAccessIterator(collections.abc.Iterator):
def __init__(self, cm: contextlib.AbstractContextManager[Query]) -> None:
self._cm = cm
self._entered = False

def __del__(self) -> None:
if self._entered:
self._cm.__exit__(None, None, None)

def __iter__(self) -> collections.abc.Iterator:
return self

def __next__(self) -> Any:
return XCom.deserialize_value(next(self._it))

@cached_property
def _it(self) -> collections.abc.Iterator:
self._entered = True
return iter(self._cm.__enter__())


@attr.define(slots=True)
class LazyXComAccess(collections.abc.Sequence):
"""Wrapper to lazily pull XCom with a sequence-like interface.
Note that since the session bound to the parent query may have died when we
actually access the sequence's content, we must create a new session
for every function call with ``with_session()``.
"""

dag_id: str
run_id: str
task_id: str
_query: Query = attr.ib(repr=False)
_len: int | None = attr.ib(init=False, repr=False, default=None)

@classmethod
def build_from_single_xcom(cls, first: XCom, query: Query) -> LazyXComAccess:
return cls(
dag_id=first.dag_id,
run_id=first.run_id,
task_id=first.task_id,
query=query.with_entities(XCom.value)
.filter(
XCom.run_id == first.run_id,
XCom.task_id == first.task_id,
XCom.dag_id == first.dag_id,
XCom.map_index >= 0,
)
.order_by(None)
.order_by(XCom.map_index.asc()),
)

def __len__(self):
if self._len is None:
with self._get_bound_query() as query:
self._len = query.count()
return self._len

def __iter__(self):
return _LazyXComAccessIterator(self._get_bound_query())

def __getitem__(self, key):
if not isinstance(key, int):
raise ValueError("only support index access for now")
try:
with self._get_bound_query() as query:
r = query.offset(key).limit(1).one()
except NoResultFound:
raise IndexError(key) from None
return XCom.deserialize_value(r)

@contextlib.contextmanager
def _get_bound_query(self) -> Generator[Query, None, None]:
# Do we have a valid session already?
if self._query.session and self._query.session.is_active:
yield self._query
return

session = settings.Session()
try:
yield self._query.with_session(session)
finally:
session.close()


def _patch_outdated_serializer(clazz: type[BaseXCom], params: Iterable[str]) -> None:
"""Patch a custom ``serialize_value`` to accept the modern signature.
Expand Down
2 changes: 1 addition & 1 deletion docs/apache-airflow/concepts/dynamic-task-mapping.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ The grid view also provides visibility into your mapped tasks in the details pan

In the above example, ``values`` received by ``sum_it`` is an aggregation of all values returned by each mapped instance of ``add_one``. However, since it is impossible to know how many instances of ``add_one`` we will have in advance, ``values`` is not a normal list, but a "lazy sequence" that retrieves each individual value only when asked. Therefore, if you run ``print(values)`` directly, you would get something like this::

_LazyXComAccess(dag_id='simple_mapping', run_id='test_run', task_id='add_one')
LazyXComAccess(dag_id='simple_mapping', run_id='test_run', task_id='add_one')

You can use normal sequence syntax on this object (e.g. ``values[0]``), or iterate through it normally with a ``for`` loop. ``list(values)`` will give you a "real" ``list``, but please be aware of the potential performance implications if the list is large.

Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3485,7 +3485,7 @@ def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterable(mock_deserialize_v
joined = ti_2.xcom_pull("task_1", session=session)
assert mock_deserialize_value.call_count == 0

assert repr(joined) == "_LazyXComAccess(dag_id='test_xcom', run_id='test', task_id='task_1')"
assert repr(joined) == "LazyXComAccess(dag_id='test_xcom', run_id='test', task_id='task_1')"

# Only when we go through the iterable does deserialization happen.
it = iter(joined)
Expand Down

0 comments on commit e3ab730

Please sign in to comment.