Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 49 additions & 7 deletions airflow-core/src/airflow/models/backfill.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from collections.abc import Iterable
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import sqlalchemy as sa
import structlog
Expand Down Expand Up @@ -246,6 +246,40 @@ def _get_latest_dag_run_row_query(*, dag_id: str, info: DagRunInfo):
return stmt.limit(1)


def _fetch_latest_dag_runs_for_infos(
*, session: Session, dag_id: str, infos: list[DagRunInfo]
) -> dict[tuple[Any, Any], Any]:
"""
Pre-fetch the latest existing DagRun for each (logical_date, partition_key) in ``infos``.

Replaces the per-info SELECT issued by ``_get_latest_dag_run_row_query`` when iterating
a backfill range. Returns ``{(logical_date, partition_key): DagRun}``; missing keys mean
no existing run for that slot.
"""
from airflow.models import DagRun

if not infos:
return {}

logical_dates = {i.logical_date for i in infos if i.logical_date is not None}
partition_keys = {i.partition_key for i in infos if i.partition_key is not None}

stmt = select(DagRun).where(DagRun.dag_id == dag_id)
if logical_dates:
stmt = stmt.where(DagRun.logical_date.in_(logical_dates))
if partition_keys:
stmt = stmt.where(DagRun.partition_key.in_(partition_keys))
# Match the per-info ORDER BY so the first row encountered for each key is the latest.
stmt = stmt.order_by(DagRun.start_date.is_(None), DagRun.start_date.desc())

out: dict[tuple[Any, Any], Any] = {}
for dr in session.scalars(stmt):
key = (dr.logical_date, dr.partition_key)
if key not in out:
out[key] = dr
return out


def _get_dag_run_no_create_reason(dr, reprocess_behavior: ReprocessBehavior) -> str | None:
non_create_reason = None
if dr.state not in (DagRunState.SUCCESS, DagRunState.FAILED):
Expand Down Expand Up @@ -318,12 +352,11 @@ def _do_dry_run(
to_date=to_date,
reverse=reverse,
)
existing_runs = _fetch_latest_dag_runs_for_infos(session=session, dag_id=dag_id, infos=dagrun_info_list)
for info in dagrun_info_list:
if TYPE_CHECKING:
assert info.logical_date
dr = session.scalar(
statement=_get_latest_dag_run_row_query(dag_id=dag_id, info=info),
)
dr = existing_runs.get((info.logical_date, info.partition_key))
if dr:
non_create_reason = _get_dag_run_no_create_reason(dr, reprocess_behavior)
if not non_create_reason:
Expand All @@ -342,12 +375,13 @@ def _create_backfill_dag_run_non_partitioned(
backfill_sort_ordinal: int,
triggering_user_name: str | None,
run_on_latest_version: bool,
existing_runs: dict[tuple[Any, Any], DagRun],
session: Session,
) -> None:
from airflow.models.dagrun import DagRun

with session.begin_nested() as nested:
dr = session.scalar(_get_latest_dag_run_row_query(dag_id=dag.dag_id, info=info))
dr = existing_runs.get((info.logical_date, info.partition_key))
if dr:
non_create_reason = _get_dag_run_no_create_reason(dr, reprocess_behavior)
if non_create_reason:
Expand Down Expand Up @@ -453,14 +487,14 @@ def _create_backfill_dag_run_partitioned(
dag_run_conf: dict | None,
backfill_sort_ordinal: int,
triggering_user_name: str | None,
existing_runs: dict[tuple[Any, Any], DagRun],
session: Session,
) -> None:
# Partitioned backfills don't currently reprocess existing runs — if a run exists
# for this partition, it's recorded as skipped via exception_reason rather than
# cleared and re-queued. As a result, this function never calls ``_handle_clear_run``
# and therefore doesn't need to forward ``dag_run_conf`` for the reprocess path.
stmt = _get_latest_dag_run_row_query(dag_id=dag.dag_id, info=info)
dr = session.scalar(stmt)
dr = existing_runs.get((info.logical_date, info.partition_key))
if dr:
non_create_reason = _get_dag_run_no_create_reason(dr, reprocess_behavior)
if non_create_reason:
Expand Down Expand Up @@ -679,6 +713,9 @@ def _create_runs_partitioned(
for info in dagrun_info_list:
if not info.partition_key:
raise RuntimeError("Expected all Dag run infos to have partition key and no logical date.")
existing_runs = _fetch_latest_dag_runs_for_infos(
session=session, dag_id=dag.dag_id, infos=dagrun_info_list
)
for backfill_sort_ordinal, info in enumerate(dagrun_info_list, start=1):
_create_backfill_dag_run_partitioned(
dag=dag,
Expand All @@ -688,6 +725,7 @@ def _create_runs_partitioned(
reprocess_behavior=ReprocessBehavior(br.reprocess_behavior),
backfill_sort_ordinal=backfill_sort_ordinal,
triggering_user_name=br.triggering_user_name,
existing_runs=existing_runs,
session=session,
)
log.info(
Expand All @@ -710,6 +748,9 @@ def _create_runs_non_partitioned(
if info.partition_key or not info.logical_date:
raise RuntimeError("Expected all Dag run infos to have logical date and no partition key.")

existing_runs = _fetch_latest_dag_runs_for_infos(
session=session, dag_id=dag.dag_id, infos=dagrun_info_list
)
for backfill_sort_ordinal, info in enumerate(dagrun_info_list, start=1):
_create_backfill_dag_run_non_partitioned(
dag=dag,
Expand All @@ -720,6 +761,7 @@ def _create_runs_non_partitioned(
backfill_sort_ordinal=backfill_sort_ordinal,
triggering_user_name=br.triggering_user_name,
run_on_latest_version=run_on_latest_version,
existing_runs=existing_runs,
session=session,
)
log.info(
Expand Down
26 changes: 26 additions & 0 deletions airflow-core/tests/unit/models/test_backfill.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from airflow.utils.strings import get_random_string
from airflow.utils.types import DagRunTriggeredByType, DagRunType

from tests_common.test_utils.asserts import assert_queries_count
from tests_common.test_utils.db import (
clear_db_backfills,
clear_db_dags,
Expand Down Expand Up @@ -161,6 +162,31 @@ def test_create_backfill_simple(reverse, existing, dag_maker, session):
assert all(x.conf == expected_run_conf for x in dag_runs)


def test_create_backfill_does_not_re_query_each_info(dag_maker, session):
# Regression guard for the N+1 fix in _create_runs_non_partitioned: existence
# checks for each DagRunInfo were previously a per-info SELECT, but are now
# batched into one SELECT in _fetch_latest_dag_runs_for_infos. Backfilling N
# daily runs should not issue a SELECT per info on top of the unavoidable
# per-info INSERTs.
with dag_maker(schedule="@daily") as dag:
PythonOperator(task_id="hi", python_callable=print)
session.commit()

# Backfilling 20 daily runs issues 195 queries with the batched fetch.
# Reverting to a per-info SELECT in the loop pushes the count to 215
# (one extra SELECT per info), which would fail this assertion.
with assert_queries_count(195):
_create_backfill(
dag_id=dag.dag_id,
from_date=pendulum.parse("2021-01-01"),
to_date=pendulum.parse("2021-01-21"), # 20 daily runs
max_active_runs=2,
reverse=False,
triggering_user_name="pytest",
dag_run_conf={},
)


@pytest.mark.parametrize("run_on_latest_version", [True, False])
def test_create_backfill_clear_existing_bundle_version(dag_maker, session, run_on_latest_version):
"""
Expand Down
Loading