diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 4a378ddb18f8d..c35f254492815 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -30,20 +30,9 @@ from datetime import datetime, timedelta from functools import partial from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Collection, - ContextManager, - Generator, - Iterable, - NamedTuple, - Tuple, -) +from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, NamedTuple, Tuple from urllib.parse import quote -import attr import dill import jinja2 import lazy_object_proxy @@ -69,8 +58,6 @@ from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.orm import reconstructor, relationship from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value -from sqlalchemy.orm.exc import NoResultFound -from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session from sqlalchemy.sql.elements import BooleanClauseList from sqlalchemy.sql.expression import ColumnOperators @@ -101,7 +88,7 @@ from airflow.models.taskfail import TaskFail from airflow.models.taskmap import TaskMap from airflow.models.taskreschedule import TaskReschedule -from airflow.models.xcom import XCOM_RETURN_KEY, XCom +from airflow.models.xcom import XCOM_RETURN_KEY, LazyXComAccess, XCom from airflow.plugins_manager import integrate_macros_plugins from airflow.sentry import Sentry from airflow.stats import Stats @@ -291,91 +278,6 @@ def clear_task_instances( session.flush() -class _LazyXComAccessIterator(collections.abc.Iterator): - __slots__ = ['_cm', '_it'] - - def __init__(self, cm: ContextManager[Query]): - self._cm = cm - self._it = None - - def __del__(self): - if self._it: - self._cm.__exit__(None, None, None) - - def __iter__(self): - return self - - def __next__(self): - if not self._it: - self._it = iter(self._cm.__enter__()) - return XCom.deserialize_value(next(self._it)) - - -@attr.define -class _LazyXComAccess(collections.abc.Sequence): - """Wrapper to lazily pull XCom with a sequence-like interface. - - Note that since the session bound to the parent query may have died when we - actually access the sequence's content, we must create a new session - for every function call with ``with_session()``. - """ - - dag_id: str - run_id: str - task_id: str - _query: Query = attr.ib(repr=False) - _len: int | None = attr.ib(init=False, repr=False, default=None) - - @classmethod - def build_from_single_xcom(cls, first: XCom, query: Query) -> _LazyXComAccess: - return cls( - dag_id=first.dag_id, - run_id=first.run_id, - task_id=first.task_id, - query=query.with_entities(XCom.value) - .filter( - XCom.run_id == first.run_id, - XCom.task_id == first.task_id, - XCom.dag_id == first.dag_id, - XCom.map_index >= 0, - ) - .order_by(None) - .order_by(XCom.map_index.asc()), - ) - - def __len__(self): - if self._len is None: - with self._get_bound_query() as query: - self._len = query.count() - return self._len - - def __iter__(self): - return _LazyXComAccessIterator(self._get_bound_query()) - - def __getitem__(self, key): - if not isinstance(key, int): - raise ValueError("only support index access for now") - try: - with self._get_bound_query() as query: - r = query.offset(key).limit(1).one() - except NoResultFound: - raise IndexError(key) from None - return XCom.deserialize_value(r) - - @contextlib.contextmanager - def _get_bound_query(self) -> Generator[Query, None, None]: - # Do we have a valid session already? - if self._query.session and self._query.session.is_active: - yield self._query - return - - session = settings.Session() - try: - yield self._query.with_session(session) - finally: - session.close() - - class TaskInstanceKey(NamedTuple): """Key used to identify task instance.""" @@ -2441,7 +2343,7 @@ def xcom_pull( if map_indexes is not None or first.map_index < 0: return XCom.deserialize_value(first) - return _LazyXComAccess.build_from_single_xcom(first, query) + return LazyXComAccess.build_from_single_xcom(first, query) # At this point either task_ids or map_indexes is explicitly multi-value. diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py index 5108d91a8833c..8b4d721b2049e 100644 --- a/airflow/models/xcom.py +++ b/airflow/models/xcom.py @@ -17,6 +17,8 @@ # under the License. from __future__ import annotations +import collections.abc +import contextlib import datetime import inspect import json @@ -24,8 +26,9 @@ import pickle import warnings from functools import wraps -from typing import TYPE_CHECKING, Any, Iterable, cast, overload +from typing import TYPE_CHECKING, Any, Generator, Iterable, cast, overload +import attr import pendulum from sqlalchemy import ( Column, @@ -41,6 +44,8 @@ from sqlalchemy.orm import Query, Session, reconstructor, relationship from sqlalchemy.orm.exc import NoResultFound +from airflow import settings +from airflow.compat.functools import cached_property from airflow.configuration import conf from airflow.exceptions import RemovedInAirflow3Warning from airflow.models.base import COLLATION_ARGS, ID_LEN, Base @@ -589,8 +594,15 @@ def serialize_value( dag_id: str | None = None, run_id: str | None = None, map_index: int | None = None, - ): - """Serialize XCom value to str or pickled object""" + ) -> Any: + """Serialize XCom value to str or pickled object.""" + # Seamlessly resolve LazyXComAccess to a list. This is intended to work + # as a "lazy list" to avoid pulling a ton of XComs unnecessarily, but if + # it's pushed into XCom, the user should be aware of the performance + # implications, and this avoids leaking the implementation detail. + if isinstance(value, LazyXComAccess): + value = list(value) + if conf.getboolean('core', 'enable_xcom_pickling'): return pickle.dumps(value) try: @@ -632,6 +644,92 @@ def orm_deserialize_value(self) -> Any: return BaseXCom.deserialize_value(self) +class _LazyXComAccessIterator(collections.abc.Iterator): + def __init__(self, cm: contextlib.AbstractContextManager[Query]) -> None: + self._cm = cm + self._entered = False + + def __del__(self) -> None: + if self._entered: + self._cm.__exit__(None, None, None) + + def __iter__(self) -> collections.abc.Iterator: + return self + + def __next__(self) -> Any: + return XCom.deserialize_value(next(self._it)) + + @cached_property + def _it(self) -> collections.abc.Iterator: + self._entered = True + return iter(self._cm.__enter__()) + + +@attr.define(slots=True) +class LazyXComAccess(collections.abc.Sequence): + """Wrapper to lazily pull XCom with a sequence-like interface. + + Note that since the session bound to the parent query may have died when we + actually access the sequence's content, we must create a new session + for every function call with ``with_session()``. + """ + + dag_id: str + run_id: str + task_id: str + _query: Query = attr.ib(repr=False) + _len: int | None = attr.ib(init=False, repr=False, default=None) + + @classmethod + def build_from_single_xcom(cls, first: XCom, query: Query) -> LazyXComAccess: + return cls( + dag_id=first.dag_id, + run_id=first.run_id, + task_id=first.task_id, + query=query.with_entities(XCom.value) + .filter( + XCom.run_id == first.run_id, + XCom.task_id == first.task_id, + XCom.dag_id == first.dag_id, + XCom.map_index >= 0, + ) + .order_by(None) + .order_by(XCom.map_index.asc()), + ) + + def __len__(self): + if self._len is None: + with self._get_bound_query() as query: + self._len = query.count() + return self._len + + def __iter__(self): + return _LazyXComAccessIterator(self._get_bound_query()) + + def __getitem__(self, key): + if not isinstance(key, int): + raise ValueError("only support index access for now") + try: + with self._get_bound_query() as query: + r = query.offset(key).limit(1).one() + except NoResultFound: + raise IndexError(key) from None + return XCom.deserialize_value(r) + + @contextlib.contextmanager + def _get_bound_query(self) -> Generator[Query, None, None]: + # Do we have a valid session already? + if self._query.session and self._query.session.is_active: + yield self._query + return + + session = settings.Session() + try: + yield self._query.with_session(session) + finally: + session.close() + + def _patch_outdated_serializer(clazz: type[BaseXCom], params: Iterable[str]) -> None: """Patch a custom ``serialize_value`` to accept the modern signature. diff --git a/docs/apache-airflow/concepts/dynamic-task-mapping.rst b/docs/apache-airflow/concepts/dynamic-task-mapping.rst index 9f02b92215564..2c181d617e8fa 100644 --- a/docs/apache-airflow/concepts/dynamic-task-mapping.rst +++ b/docs/apache-airflow/concepts/dynamic-task-mapping.rst @@ -68,7 +68,7 @@ The grid view also provides visibility into your mapped tasks in the details pan In the above example, ``values`` received by ``sum_it`` is an aggregation of all values returned by each mapped instance of ``add_one``. However, since it is impossible to know how many instances of ``add_one`` we will have in advance, ``values`` is not a normal list, but a "lazy sequence" that retrieves each individual value only when asked. Therefore, if you run ``print(values)`` directly, you would get something like this:: - _LazyXComAccess(dag_id='simple_mapping', run_id='test_run', task_id='add_one') + LazyXComAccess(dag_id='simple_mapping', run_id='test_run', task_id='add_one') You can use normal sequence syntax on this object (e.g. ``values[0]``), or iterate through it normally with a ``for`` loop. ``list(values)`` will give you a "real" ``list``, but please be aware of the potential performance implications if the list is large. diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index d21203b2d752c..c167e4d5b5c7d 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -3485,7 +3485,7 @@ def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterable(mock_deserialize_v joined = ti_2.xcom_pull("task_1", session=session) assert mock_deserialize_value.call_count == 0 - assert repr(joined) == "_LazyXComAccess(dag_id='test_xcom', run_id='test', task_id='task_1')" + assert repr(joined) == "LazyXComAccess(dag_id='test_xcom', run_id='test', task_id='task_1')" # Only when we go through the iterable does deserialization happen. it = iter(joined)