Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize TI.xcom_pull() with explicit task_ids and map_indexes #27699

Merged
merged 5 commits into from
Nov 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 25 additions & 26 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.elements import BooleanClauseList
from sqlalchemy.sql.expression import ColumnOperators
from sqlalchemy.sql.expression import ColumnOperators, case

from airflow import settings
from airflow.compat.functools import cache
Expand Down Expand Up @@ -2350,35 +2350,34 @@ def xcom_pull(
return default
if map_indexes is not None or first.map_index < 0:
return XCom.deserialize_value(first)

return LazyXComAccess.build_from_single_xcom(first, query)
query = query.order_by(None).order_by(XCom.map_index.asc())
return LazyXComAccess.build_from_xcom_query(query)

# At this point either task_ids or map_indexes is explicitly multi-value.

results = (
(r.task_id, r.map_index, XCom.deserialize_value(r))
for r in query.with_entities(XCom.task_id, XCom.map_index, XCom.value)
)

if task_ids is None:
task_id_pos: dict[str, int] = defaultdict(int)
elif isinstance(task_ids, str):
task_id_pos = {task_ids: 0}
# Order return values to match task_ids and map_indexes ordering.
query = query.order_by(None)
if task_ids is None or isinstance(task_ids, str):
query = query.order_by(XCom.task_id)
else:
task_id_pos = {task_id: i for i, task_id in enumerate(task_ids)}
if map_indexes is None:
map_index_pos: dict[int, int] = defaultdict(int)
elif isinstance(map_indexes, int):
map_index_pos = {map_indexes: 0}
task_id_whens = {tid: i for i, tid in enumerate(task_ids)}
if task_id_whens:
query = query.order_by(case(task_id_whens, value=XCom.task_id))
else:
query = query.order_by(XCom.task_id)
if map_indexes is None or isinstance(map_indexes, int):
query = query.order_by(XCom.map_index)
elif isinstance(map_indexes, range):
order = XCom.map_index
if map_indexes.step < 0:
order = order.desc()
query = query.order_by(order)
else:
map_index_pos = {map_index: i for i, map_index in enumerate(map_indexes)}

def _arg_pos(item: tuple[str, int, Any]) -> tuple[int, int]:
task_id, map_index, _ = item
return task_id_pos[task_id], map_index_pos[map_index]

results_sorted_by_arg_pos = sorted(results, key=_arg_pos)
return [value for _, _, value in results_sorted_by_arg_pos]
map_index_whens = {map_index: i for i, map_index in enumerate(map_indexes)}
if map_index_whens:
query = query.order_by(case(map_index_whens, value=XCom.map_index))
else:
query = query.order_by(XCom.map_index)
return LazyXComAccess.build_from_xcom_query(query)

@provide_session
def get_num_running_task_instances(self, session: Session) -> int:
Expand Down
43 changes: 22 additions & 21 deletions airflow/models/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import contextlib
import datetime
import inspect
import itertools
import json
import logging
import pickle
Expand Down Expand Up @@ -496,7 +497,9 @@ def get_many(
elif dag_ids is not None:
query = query.filter(cls.dag_id == dag_ids)

if is_container(map_indexes):
if isinstance(map_indexes, range) and abs(map_indexes.step) == 1:
query = query.filter(cls.map_index >= map_indexes.start, cls.map_index < map_indexes.stop)
elif is_container(map_indexes):
query = query.filter(cls.map_index.in_(map_indexes))
elif map_indexes is not None:
query = query.filter(cls.map_index == map_indexes)
Expand Down Expand Up @@ -697,30 +700,28 @@ class LazyXComAccess(collections.abc.Sequence):
Note that since the session bound to the parent query may have died when we
actually access the sequence's content, we must create a new session
for every function call with ``with_session()``.

:meta private:
"""

dag_id: str
run_id: str
task_id: str
_query: Query = attr.ib(repr=False)
_len: int | None = attr.ib(init=False, repr=False, default=None)
_query: Query
_len: int | None = attr.ib(init=False, default=None)

@classmethod
def build_from_single_xcom(cls, first: XCom, query: Query) -> LazyXComAccess:
return cls(
dag_id=first.dag_id,
run_id=first.run_id,
task_id=first.task_id,
query=query.with_entities(XCom.value)
.filter(
XCom.run_id == first.run_id,
XCom.task_id == first.task_id,
XCom.dag_id == first.dag_id,
XCom.map_index >= 0,
)
.order_by(None)
.order_by(XCom.map_index.asc()),
)
def build_from_xcom_query(cls, query: Query) -> LazyXComAccess:
return cls(query=query.with_entities(XCom.value))

def __repr__(self) -> str:
return f"LazyXComAccess([{len(self)} items])"

def __str__(self) -> str:
return str(list(self))

def __eq__(self, other: Any) -> bool:
if isinstance(other, (list, LazyXComAccess)):
z = itertools.zip_longest(iter(self), iter(other), fillvalue=object())
return all(x == y for x, y in z)
return NotImplemented

def __len__(self):
if self._len is None:
Expand Down
27 changes: 11 additions & 16 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,26 +47,22 @@
UnmappableXComTypePushed,
XComForMappingNotPushed,
)
from airflow.models import (
DAG,
Connection,
DagBag,
DagRun,
Pool,
RenderedTaskInstanceFields,
TaskInstance as TI,
TaskReschedule,
Variable,
XCom,
)
from airflow.models.connection import Connection
from airflow.models.dag import DAG
from airflow.models.dagbag import DagBag
from airflow.models.dagrun import DagRun
from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent, DatasetModel
from airflow.models.expandinput import EXPAND_INPUT_EMPTY, NotFullyPopulated
from airflow.models.param import process_params
from airflow.models.pool import Pool
from airflow.models.renderedtifields import RenderedTaskInstanceFields
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskfail import TaskFail
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstance import TaskInstance, TaskInstance as TI
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.variable import Variable
from airflow.models.xcom import XCOM_RETURN_KEY, LazyXComAccess, XCom
from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator
Expand Down Expand Up @@ -3522,10 +3518,9 @@ def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterable(mock_deserialize_v

# Simply pulling the joined XCom value should not deserialize.
joined = ti_2.xcom_pull("task_1", session=session)
assert isinstance(joined, LazyXComAccess)
assert mock_deserialize_value.call_count == 0

assert repr(joined) == "LazyXComAccess(dag_id='test_xcom', run_id='test', task_id='task_1')"

# Only when we go through the iterable does deserialization happen.
it = iter(joined)
assert next(it) == "a"
Expand Down