Skip to content

Commit

Permalink
Optimize TI.xcom_pull() with explicit task_ids and map_indexes (#27699)
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr committed Nov 21, 2022
1 parent 8a8ad47 commit 72b1c2f
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 63 deletions.
51 changes: 25 additions & 26 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.elements import BooleanClauseList
from sqlalchemy.sql.expression import ColumnOperators
from sqlalchemy.sql.expression import ColumnOperators, case

from airflow import settings
from airflow.compat.functools import cache
Expand Down Expand Up @@ -2350,35 +2350,34 @@ def xcom_pull(
return default
if map_indexes is not None or first.map_index < 0:
return XCom.deserialize_value(first)

return LazyXComAccess.build_from_single_xcom(first, query)
query = query.order_by(None).order_by(XCom.map_index.asc())
return LazyXComAccess.build_from_xcom_query(query)

# At this point either task_ids or map_indexes is explicitly multi-value.

results = (
(r.task_id, r.map_index, XCom.deserialize_value(r))
for r in query.with_entities(XCom.task_id, XCom.map_index, XCom.value)
)

if task_ids is None:
task_id_pos: dict[str, int] = defaultdict(int)
elif isinstance(task_ids, str):
task_id_pos = {task_ids: 0}
# Order return values to match task_ids and map_indexes ordering.
query = query.order_by(None)
if task_ids is None or isinstance(task_ids, str):
query = query.order_by(XCom.task_id)
else:
task_id_pos = {task_id: i for i, task_id in enumerate(task_ids)}
if map_indexes is None:
map_index_pos: dict[int, int] = defaultdict(int)
elif isinstance(map_indexes, int):
map_index_pos = {map_indexes: 0}
task_id_whens = {tid: i for i, tid in enumerate(task_ids)}
if task_id_whens:
query = query.order_by(case(task_id_whens, value=XCom.task_id))
else:
query = query.order_by(XCom.task_id)
if map_indexes is None or isinstance(map_indexes, int):
query = query.order_by(XCom.map_index)
elif isinstance(map_indexes, range):
order = XCom.map_index
if map_indexes.step < 0:
order = order.desc()
query = query.order_by(order)
else:
map_index_pos = {map_index: i for i, map_index in enumerate(map_indexes)}

def _arg_pos(item: tuple[str, int, Any]) -> tuple[int, int]:
task_id, map_index, _ = item
return task_id_pos[task_id], map_index_pos[map_index]

results_sorted_by_arg_pos = sorted(results, key=_arg_pos)
return [value for _, _, value in results_sorted_by_arg_pos]
map_index_whens = {map_index: i for i, map_index in enumerate(map_indexes)}
if map_index_whens:
query = query.order_by(case(map_index_whens, value=XCom.map_index))
else:
query = query.order_by(XCom.map_index)
return LazyXComAccess.build_from_xcom_query(query)

@provide_session
def get_num_running_task_instances(self, session: Session) -> int:
Expand Down
43 changes: 22 additions & 21 deletions airflow/models/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import contextlib
import datetime
import inspect
import itertools
import json
import logging
import pickle
Expand Down Expand Up @@ -496,7 +497,9 @@ def get_many(
elif dag_ids is not None:
query = query.filter(cls.dag_id == dag_ids)

if is_container(map_indexes):
if isinstance(map_indexes, range) and abs(map_indexes.step) == 1:
query = query.filter(cls.map_index >= map_indexes.start, cls.map_index < map_indexes.stop)
elif is_container(map_indexes):
query = query.filter(cls.map_index.in_(map_indexes))
elif map_indexes is not None:
query = query.filter(cls.map_index == map_indexes)
Expand Down Expand Up @@ -697,30 +700,28 @@ class LazyXComAccess(collections.abc.Sequence):
Note that since the session bound to the parent query may have died when we
actually access the sequence's content, we must create a new session
for every function call with ``with_session()``.
:meta private:
"""

dag_id: str
run_id: str
task_id: str
_query: Query = attr.ib(repr=False)
_len: int | None = attr.ib(init=False, repr=False, default=None)
_query: Query
_len: int | None = attr.ib(init=False, default=None)

@classmethod
def build_from_single_xcom(cls, first: XCom, query: Query) -> LazyXComAccess:
return cls(
dag_id=first.dag_id,
run_id=first.run_id,
task_id=first.task_id,
query=query.with_entities(XCom.value)
.filter(
XCom.run_id == first.run_id,
XCom.task_id == first.task_id,
XCom.dag_id == first.dag_id,
XCom.map_index >= 0,
)
.order_by(None)
.order_by(XCom.map_index.asc()),
)
def build_from_xcom_query(cls, query: Query) -> LazyXComAccess:
return cls(query=query.with_entities(XCom.value))

def __repr__(self) -> str:
return f"LazyXComAccess([{len(self)} items])"

def __str__(self) -> str:
return str(list(self))

def __eq__(self, other: Any) -> bool:
if isinstance(other, (list, LazyXComAccess)):
z = itertools.zip_longest(iter(self), iter(other), fillvalue=object())
return all(x == y for x, y in z)
return NotImplemented

def __len__(self):
if self._len is None:
Expand Down
27 changes: 11 additions & 16 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,26 +47,22 @@
UnmappableXComTypePushed,
XComForMappingNotPushed,
)
from airflow.models import (
DAG,
Connection,
DagBag,
DagRun,
Pool,
RenderedTaskInstanceFields,
TaskInstance as TI,
TaskReschedule,
Variable,
XCom,
)
from airflow.models.connection import Connection
from airflow.models.dag import DAG
from airflow.models.dagbag import DagBag
from airflow.models.dagrun import DagRun
from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent, DatasetModel
from airflow.models.expandinput import EXPAND_INPUT_EMPTY, NotFullyPopulated
from airflow.models.param import process_params
from airflow.models.pool import Pool
from airflow.models.renderedtifields import RenderedTaskInstanceFields
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskfail import TaskFail
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstance import TaskInstance, TaskInstance as TI
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.variable import Variable
from airflow.models.xcom import XCOM_RETURN_KEY, LazyXComAccess, XCom
from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator
Expand Down Expand Up @@ -3522,10 +3518,9 @@ def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterable(mock_deserialize_v

# Simply pulling the joined XCom value should not deserialize.
joined = ti_2.xcom_pull("task_1", session=session)
assert isinstance(joined, LazyXComAccess)
assert mock_deserialize_value.call_count == 0

assert repr(joined) == "LazyXComAccess(dag_id='test_xcom', run_id='test', task_id='task_1')"

# Only when we go through the iterable does deserialization happen.
it = iter(joined)
assert next(it) == "a"
Expand Down

0 comments on commit 72b1c2f

Please sign in to comment.