diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 9ec2854be7ded..29d0ad4497405 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -61,7 +61,7 @@ from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value from sqlalchemy.orm.session import Session from sqlalchemy.sql.elements import BooleanClauseList -from sqlalchemy.sql.expression import ColumnOperators +from sqlalchemy.sql.expression import ColumnOperators, case from airflow import settings from airflow.compat.functools import cache @@ -2350,35 +2350,34 @@ def xcom_pull( return default if map_indexes is not None or first.map_index < 0: return XCom.deserialize_value(first) - - return LazyXComAccess.build_from_single_xcom(first, query) + query = query.order_by(None).order_by(XCom.map_index.asc()) + return LazyXComAccess.build_from_xcom_query(query) # At this point either task_ids or map_indexes is explicitly multi-value. - - results = ( - (r.task_id, r.map_index, XCom.deserialize_value(r)) - for r in query.with_entities(XCom.task_id, XCom.map_index, XCom.value) - ) - - if task_ids is None: - task_id_pos: dict[str, int] = defaultdict(int) - elif isinstance(task_ids, str): - task_id_pos = {task_ids: 0} + # Order return values to match task_ids and map_indexes ordering. 
+ query = query.order_by(None) + if task_ids is None or isinstance(task_ids, str): + query = query.order_by(XCom.task_id) else: - task_id_pos = {task_id: i for i, task_id in enumerate(task_ids)} - if map_indexes is None: - map_index_pos: dict[int, int] = defaultdict(int) - elif isinstance(map_indexes, int): - map_index_pos = {map_indexes: 0} + task_id_whens = {tid: i for i, tid in enumerate(task_ids)} + if task_id_whens: + query = query.order_by(case(task_id_whens, value=XCom.task_id)) + else: + query = query.order_by(XCom.task_id) + if map_indexes is None or isinstance(map_indexes, int): + query = query.order_by(XCom.map_index) + elif isinstance(map_indexes, range): + order = XCom.map_index + if map_indexes.step < 0: + order = order.desc() + query = query.order_by(order) else: - map_index_pos = {map_index: i for i, map_index in enumerate(map_indexes)} - - def _arg_pos(item: tuple[str, int, Any]) -> tuple[int, int]: - task_id, map_index, _ = item - return task_id_pos[task_id], map_index_pos[map_index] - - results_sorted_by_arg_pos = sorted(results, key=_arg_pos) - return [value for _, _, value in results_sorted_by_arg_pos] + map_index_whens = {map_index: i for i, map_index in enumerate(map_indexes)} + if map_index_whens: + query = query.order_by(case(map_index_whens, value=XCom.map_index)) + else: + query = query.order_by(XCom.map_index) + return LazyXComAccess.build_from_xcom_query(query) @provide_session def get_num_running_task_instances(self, session: Session) -> int: diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py index 9aea200a2cb50..39bb1221222ee 100644 --- a/airflow/models/xcom.py +++ b/airflow/models/xcom.py @@ -21,6 +21,7 @@ import contextlib import datetime import inspect +import itertools import json import logging import pickle @@ -496,7 +497,9 @@ def get_many( elif dag_ids is not None: query = query.filter(cls.dag_id == dag_ids) - if is_container(map_indexes): + if isinstance(map_indexes, range) and map_indexes.step == 1: + query = 
query.filter(cls.map_index >= map_indexes.start, cls.map_index < map_indexes.stop) + elif is_container(map_indexes): query = query.filter(cls.map_index.in_(map_indexes)) elif map_indexes is not None: query = query.filter(cls.map_index == map_indexes) @@ -697,30 +700,28 @@ class LazyXComAccess(collections.abc.Sequence): Note that since the session bound to the parent query may have died when we actually access the sequence's content, we must create a new session for every function call with ``with_session()``. + + :meta private: """ - dag_id: str - run_id: str - task_id: str - _query: Query = attr.ib(repr=False) - _len: int | None = attr.ib(init=False, repr=False, default=None) + _query: Query + _len: int | None = attr.ib(init=False, default=None) @classmethod - def build_from_single_xcom(cls, first: XCom, query: Query) -> LazyXComAccess: - return cls( - dag_id=first.dag_id, - run_id=first.run_id, - task_id=first.task_id, - query=query.with_entities(XCom.value) - .filter( - XCom.run_id == first.run_id, - XCom.task_id == first.task_id, - XCom.dag_id == first.dag_id, - XCom.map_index >= 0, - ) - .order_by(None) - .order_by(XCom.map_index.asc()), - ) + def build_from_xcom_query(cls, query: Query) -> LazyXComAccess: + return cls(query=query.with_entities(XCom.value)) + + def __repr__(self) -> str: + return f"LazyXComAccess([{len(self)} items])" + + def __str__(self) -> str: + return str(list(self)) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, (list, LazyXComAccess)): + z = itertools.zip_longest(iter(self), iter(other), fillvalue=object()) + return all(x == y for x, y in z) + return NotImplemented def __len__(self): if self._len is None: diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 77b226441424e..8549230c08294 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -47,26 +47,22 @@ UnmappableXComTypePushed, XComForMappingNotPushed, ) -from airflow.models import ( - DAG, - 
Connection, - DagBag, - DagRun, - Pool, - RenderedTaskInstanceFields, - TaskInstance as TI, - TaskReschedule, - Variable, - XCom, -) +from airflow.models.connection import Connection +from airflow.models.dag import DAG +from airflow.models.dagbag import DagBag +from airflow.models.dagrun import DagRun from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent, DatasetModel from airflow.models.expandinput import EXPAND_INPUT_EMPTY, NotFullyPopulated from airflow.models.param import process_params +from airflow.models.pool import Pool +from airflow.models.renderedtifields import RenderedTaskInstanceFields from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskfail import TaskFail -from airflow.models.taskinstance import TaskInstance +from airflow.models.taskinstance import TaskInstance, TaskInstance as TI from airflow.models.taskmap import TaskMap -from airflow.models.xcom import XCOM_RETURN_KEY +from airflow.models.taskreschedule import TaskReschedule +from airflow.models.variable import Variable +from airflow.models.xcom import XCOM_RETURN_KEY, LazyXComAccess, XCom from airflow.operators.bash import BashOperator from airflow.operators.empty import EmptyOperator from airflow.operators.python import PythonOperator @@ -3522,10 +3518,9 @@ def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterable(mock_deserialize_v # Simply pulling the joined XCom value should not deserialize. joined = ti_2.xcom_pull("task_1", session=session) + assert isinstance(joined, LazyXComAccess) assert mock_deserialize_value.call_count == 0 - assert repr(joined) == "LazyXComAccess(dag_id='test_xcom', run_id='test', task_id='task_1')" - # Only when we go through the iterable does deserialization happen. it = iter(joined) assert next(it) == "a"