From 18b8eaefd0ca9c98f146a71e42f6af4c9bd4bbc5 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Mon, 13 Apr 2026 08:18:58 +0100 Subject: [PATCH] Reduce per-DAG queries during DAG serialization with bulk prefetch (#64929) * Reduce per-DAG queries during DAG serialization with bulk prefetch Replaces 3 SELECTs per DAG in write_dag (update interval check, hash comparison, version fetch) with 2 bulk queries via a new _prefetch_dag_write_metadata classmethod. Also fixes DagCode.update_source_code to reuse the caller's session and eagerly loads dag_owner_links to prevent N+1 queries. * fixup! Reduce per-DAG queries during DAG serialization with bulk prefetch * fixup! fixup! Reduce per-DAG queries during DAG serialization with bulk prefetch (cherry picked from commit ef0004035edb27507c6899b11bd24166ce3a08c0) --- .../src/airflow/dag_processing/collection.py | 21 +++- .../src/airflow/models/serialized_dag.py | 102 +++++++++++++++--- .../unit/dag_processing/test_collection.py | 1 + .../tests/unit/models/test_serialized_dag.py | 42 ++++++++ 4 files changed, 146 insertions(+), 20 deletions(-) diff --git a/airflow-core/src/airflow/dag_processing/collection.py b/airflow-core/src/airflow/dag_processing/collection.py index 96f3c89f8623a..06e4900d816e5 100644 --- a/airflow-core/src/airflow/dag_processing/collection.py +++ b/airflow-core/src/airflow/dag_processing/collection.py @@ -53,6 +53,7 @@ from airflow.models.dagrun import DagRun from airflow.models.dagwarning import DagWarningType from airflow.models.errors import ParseImportError +from airflow.models.serialized_dag import SerializedDagModel from airflow.models.trigger import Trigger from airflow.serialization.definitions.assets import ( SerializedAsset, @@ -75,6 +76,7 @@ from sqlalchemy.sql import Select from airflow.models.dagwarning import DagWarning + from airflow.models.serialized_dag import DagWriteMetadata from airflow.typing_compat import Self, Unpack AssetT = TypeVar("AssetT", SerializedAsset, SerializedAssetAlias) @@ -256,7 +258,11 @@ def _update_dag_owner_links(dag_owner_links: dict[str, str], dm: DagModel, *, se def _serialize_dag_capturing_errors( - dag: LazyDeserializedDAG, bundle_name, session: Session, bundle_version: str | None + dag: LazyDeserializedDAG, + bundle_name, + session: Session, + bundle_version: str | None, + _prefetched: DagWriteMetadata | None = None, ): """ Try to serialize the dag to the DB, but make a note of any errors. @@ -264,7 +270,6 @@ def _serialize_dag_capturing_errors( We can't place them directly in import_errors, as this may be retried, and work the next time """ from airflow.models.dagcode import DagCode - from airflow.models.serialized_dag import SerializedDagModel # Updating serialized DAG can not be faster than a minimum interval to reduce database write rate. 
MIN_SERIALIZED_DAG_UPDATE_INTERVAL = conf.getint( @@ -279,10 +284,11 @@ def _serialize_dag_capturing_errors( bundle_version=bundle_version, min_update_interval=MIN_SERIALIZED_DAG_UPDATE_INTERVAL, session=session, + _prefetched=_prefetched, ) if not dag_was_updated: # Check and update DagCode - DagCode.update_source_code(dag.dag_id, dag.fileloc) + DagCode.update_source_code(dag.dag_id, dag.fileloc, session=session) if "FabAuthManager" in conf.get("core", "auth_manager"): _sync_dag_perms(dag, session=session) @@ -473,6 +479,13 @@ def update_dag_parsing_results_in_db( SerializedDAG.bulk_write_to_db( bundle_name, bundle_version, dags, parse_duration, session=session ) + # Bulk prefetch metadata for all DAGs to avoid the standard per-DAG + # metadata lookups in write_dag. This replaces the update-interval, + # hash, and version queries with 2 bulk queries total; DAGs with + # deadlines may still do an additional lookup for deadline UUID reuse. + prefetched_metadata = SerializedDagModel._prefetch_dag_write_metadata( + [dag.dag_id for dag in dags], session=session + ) # Write Serialized DAGs to DB, capturing errors for dag in dags: serialize_errors.extend( @@ -481,6 +494,7 @@ def update_dag_parsing_results_in_db( bundle_name=bundle_name, bundle_version=bundle_version, session=session, + _prefetched=prefetched_metadata.get(dag.dag_id), ) ) except OperationalError: @@ -526,6 +540,7 @@ def find_orm_dags(self, *, session: Session) -> dict[str, DagModel]: .options(joinedload(DagModel.schedule_asset_references)) .options(joinedload(DagModel.schedule_asset_alias_references)) .options(joinedload(DagModel.task_outlet_asset_references)) + .options(joinedload(DagModel.dag_owner_links)) ), of=DagModel, session=session, diff --git a/airflow-core/src/airflow/models/serialized_dag.py b/airflow-core/src/airflow/models/serialized_dag.py index 23b93f2dae5aa..0aaf295a58ff4 100644 --- a/airflow-core/src/airflow/models/serialized_dag.py +++ b/airflow-core/src/airflow/models/serialized_dag.py @@ -23,7 +23,7 @@ import zlib from collections.abc import Callable, Iterable, Iterator, Sequence from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, NamedTuple from uuid import UUID import uuid6 @@ -70,6 +70,14 @@ _COMPRESS_SERIALIZED_DAGS = conf.getboolean("core", "compress_serialized_dags", fallback=False) +class DagWriteMetadata(NamedTuple): + """Pre-fetched metadata for write_dag to avoid per-DAG queries.""" + + last_updated: datetime | None + dag_hash: str | None + dag_version: DagVersion | None + + class _DagDependenciesResolver: """Resolver that resolves dag dependencies to include asset id and assets link to asset aliases.""" @@ -508,6 +516,70 @@ def _create_deadline_alert_records( ) serialized_dag.deadline_alerts.append(alert) + @classmethod + def _prefetch_dag_write_metadata( + cls, dag_ids: Iterable[str], *, session: Session + ) -> dict[str, DagWriteMetadata]: + """ + Bulk-fetch metadata needed by write_dag for multiple DAGs in two queries. + + Instead of running 3 SELECTs per DAG in write_dag (update interval check, + hash comparison, version fetch), this fetches all needed data upfront. + + :param dag_ids: DAG IDs to prefetch metadata for + :param session: ORM Session + :returns: dict mapping dag_id to DagWriteMetadata + """ + dag_id_list = list(set(dag_ids)) + if not dag_id_list: + return {} + + # Fetch latest serialized_dag (last_updated, dag_hash) per dag_id + # using a window function to pick the most recent row. 
+ sd_subq = ( + select( + cls.dag_id.label("dag_id"), + cls.last_updated.label("last_updated"), + cls.dag_hash.label("dag_hash"), + func.row_number().over(partition_by=cls.dag_id, order_by=cls.created_at.desc()).label("rn"), + ) + .where(cls.dag_id.in_(dag_id_list)) + .subquery() + ) + sd_rows = session.execute( + select(sd_subq.c.dag_id, sd_subq.c.last_updated, sd_subq.c.dag_hash).where(sd_subq.c.rn == 1) + ).all() + sd_by_dag_id: dict[str, tuple[datetime, str]] = { + row.dag_id: (row.last_updated, row.dag_hash) for row in sd_rows + } + + # Fetch latest DagVersion per dag_id using a window function, + # matching the original write_dag ordering (ORDER BY created_at DESC). + dv_subq = ( + select( + DagVersion.id.label("id"), + DagVersion.dag_id.label("dag_id"), + func.row_number() + .over(partition_by=DagVersion.dag_id, order_by=DagVersion.created_at.desc()) + .label("rn"), + ) + .where(DagVersion.dag_id.in_(dag_id_list)) + .subquery() + ) + dag_versions = session.scalars( + select(DagVersion).join(dv_subq, DagVersion.id == dv_subq.c.id).where(dv_subq.c.rn == 1) + ).all() + dv_by_dag_id: dict[str, DagVersion] = {dv.dag_id: dv for dv in dag_versions} + + return { + dag_id: DagWriteMetadata( + last_updated=sd_by_dag_id[dag_id][0] if dag_id in sd_by_dag_id else None, + dag_hash=sd_by_dag_id[dag_id][1] if dag_id in sd_by_dag_id else None, + dag_version=dv_by_dag_id.get(dag_id), + ) + for dag_id in dag_id_list + } + @classmethod @provide_session def write_dag( @@ -517,6 +589,7 @@ def write_dag( bundle_version: str | None = None, min_update_interval: int | None = None, session: Session = NEW_SESSION, + _prefetched: DagWriteMetadata | None = None, ) -> bool: """ Serialize a DAG and writes it into database. @@ -529,33 +602,28 @@ def write_dag( :param bundle_version: bundle version of the DAG :param min_update_interval: minimal interval in seconds to update serialized DAG :param session: ORM Session + :param _prefetched: Pre-fetched metadata to skip per-DAG queries; used by bulk callers :returns: Boolean indicating if the DAG was written to the DB """ + if _prefetched is None: + _prefetched = cls._prefetch_dag_write_metadata([dag.dag_id], session=session).get( + dag.dag_id, DagWriteMetadata(last_updated=None, dag_hash=None, dag_version=None) + ) + # Checks if (Current Time - Time when the DAG was written to DB) < min_update_interval # If Yes, does nothing # If No or the DAG does not exists, updates / writes Serialized DAG to DB if min_update_interval is not None: - if session.scalar( - select(literal(True)) - .where( - cls.dag_id == dag.dag_id, - (timezone.utcnow() - timedelta(seconds=min_update_interval)) < cls.last_updated, - ) - .select_from(cls) + if ( + _prefetched.last_updated is not None + and (timezone.utcnow() - timedelta(seconds=min_update_interval)) < _prefetched.last_updated ): return False log.debug("Checking if DAG (%s) changed", dag.dag_id) - serialized_dag_hash = session.scalars( - select(cls.dag_hash).where(cls.dag_id == dag.dag_id).order_by(cls.created_at.desc()) - ).first() - dag_version = session.scalar( - select(DagVersion) - .where(DagVersion.dag_id == dag.dag_id) - .order_by(DagVersion.created_at.desc()) - .limit(1) - ) + serialized_dag_hash = _prefetched.dag_hash + dag_version = _prefetched.dag_version if dag.data.get("dag", {}).get("deadline"): # Try to reuse existing deadline UUIDs if the deadline definitions haven't changed. 
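
Both subqueries above use the same "latest row per group" window-function shape: rank the rows within each dag_id by created_at descending, then keep only rank 1. The following is a minimal, self-contained sketch of that pattern against a toy SQLite database (SQLAlchemy 2.x, SQLite >= 3.25 for window functions); Event, group_id, and latest_per_group are hypothetical names used purely for illustration and are not part of this patch or of Airflow:

from datetime import datetime

from sqlalchemy import DateTime, String, create_engine, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Event(Base):
    __tablename__ = "event"

    id: Mapped[int] = mapped_column(primary_key=True)
    group_id: Mapped[str] = mapped_column(String(50))
    created_at: Mapped[datetime] = mapped_column(DateTime)


def latest_per_group(session: Session, group_ids: list[str]) -> dict[str, Event]:
    """Return the newest Event per group_id using one windowed query."""
    # Rank rows inside each group, newest first, mirroring the sd_subq /
    # dv_subq subqueries in the hunk above.
    ranked = (
        select(
            Event.id.label("id"),
            func.row_number()
            .over(partition_by=Event.group_id, order_by=Event.created_at.desc())
            .label("rn"),
        )
        .where(Event.group_id.in_(group_ids))
        .subquery()
    )
    # Join back to the mapped class and keep only the rank-1 rows.
    events = session.scalars(
        select(Event).join(ranked, Event.id == ranked.c.id).where(ranked.c.rn == 1)
    ).all()
    return {event.group_id: event for event in events}


if __name__ == "__main__":
    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        session.add_all(
            [
                Event(group_id="a", created_at=datetime(2024, 1, 1)),
                Event(group_id="a", created_at=datetime(2025, 6, 1)),
                Event(group_id="b", created_at=datetime(2025, 6, 1)),
            ]
        )
        session.commit()
        latest = latest_per_group(session, ["a", "b"])
        assert latest["a"].created_at == datetime(2025, 6, 1)

The design choice here matches the patch: one query per table regardless of how many groups (DAGs) are requested, versus one ordered-and-limited SELECT per group in the pre-patch write_dag.
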
diff --git a/airflow-core/tests/unit/dag_processing/test_collection.py b/airflow-core/tests/unit/dag_processing/test_collection.py index 6a0aef00eaa8e..1792c9b38c0ae 100644 --- a/airflow-core/tests/unit/dag_processing/test_collection.py +++ b/airflow-core/tests/unit/dag_processing/test_collection.py @@ -537,6 +537,7 @@ def test_sync_to_db_is_retried( bundle_version=None, min_update_interval=mock.ANY, session=mock_session, + _prefetched=mock.ANY, ), ] ) diff --git a/airflow-core/tests/unit/models/test_serialized_dag.py b/airflow-core/tests/unit/models/test_serialized_dag.py index de0464dea8c09..2185635590eff 100644 --- a/airflow-core/tests/unit/models/test_serialized_dag.py +++ b/airflow-core/tests/unit/models/test_serialized_dag.py @@ -524,6 +524,48 @@ def test_min_update_interval_is_respected(self, provide_interval, new_task, shou ) assert did_write is should_write + def test_prefetch_dag_write_metadata_multiple_dags(self, dag_maker, session): + """Test that _prefetch_dag_write_metadata returns correct metadata for multiple DAGs.""" + with dag_maker("prefetch_multi_dag1"): + EmptyOperator(task_id="task1") + with dag_maker("prefetch_multi_dag2"): + EmptyOperator(task_id="task1") + + result = SDM._prefetch_dag_write_metadata( + ["prefetch_multi_dag1", "prefetch_multi_dag2"], session=session + ) + + assert len(result) == 2 + for dag_id in ("prefetch_multi_dag1", "prefetch_multi_dag2"): + metadata = result[dag_id] + assert metadata.last_updated is not None + assert metadata.dag_hash is not None + assert metadata.dag_version is not None + assert metadata.dag_version.dag_id == dag_id + + def test_prefetch_dag_write_metadata_returns_latest_version(self, dag_maker, session): + """Test that _prefetch_dag_write_metadata returns the latest DagVersion.""" + with dag_maker("prefetch_version_dag") as dag: + PythonOperator(task_id="task1", python_callable=lambda: None) + # Create a dagrun so that writing a changed DAG creates a new version + dag_maker.create_dagrun(run_id="run1", logical_date=pendulum.datetime(2025, 1, 1)) + + # Modify the DAG (add a task) and write again to create version 2 + PythonOperator(task_id="task2", python_callable=lambda: None, dag=dag) + SDM.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name="dag_maker") + + assert ( + session.scalar( + select(func.count()).select_from(DagVersion).where(DagVersion.dag_id == dag.dag_id) + ) + == 2 + ) + + result = SDM._prefetch_dag_write_metadata([dag.dag_id], session=session) + metadata = result[dag.dag_id] + assert metadata.dag_version is not None + assert metadata.dag_version.version_number == 2 + def test_new_dag_version_created_when_bundle_name_changes_and_hash_unchanged(self, dag_maker, session): """Test that new dag_version is created if bundle_name changes but DAG is unchanged.""" # Create and write initial DAG
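
A note on the other half of the change: the one-line joinedload(DagModel.dag_owner_links) addition in find_orm_dags exists because iterating ORM rows and then touching a lazily loaded relationship issues one extra SELECT per row. The sketch below reproduces that N+1 shape and the eager-loaded fix in isolation; Parent and Child are hypothetical stand-ins for DagModel and dag_owner_links, and this is an illustration of the general SQLAlchemy pattern rather than Airflow's actual models:

from __future__ import annotations

from sqlalchemy import ForeignKey, create_engine, select
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Session,
    joinedload,
    mapped_column,
    relationship,
)


class Base(DeclarativeBase):
    pass


class Parent(Base):
    __tablename__ = "parent"

    id: Mapped[int] = mapped_column(primary_key=True)
    children: Mapped[list[Child]] = relationship(back_populates="parent")


class Child(Base):
    __tablename__ = "child"

    id: Mapped[int] = mapped_column(primary_key=True)
    parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id"))
    parent: Mapped[Parent] = relationship(back_populates="children")


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add_all([Parent(children=[Child(), Child()]) for _ in range(3)])
    session.commit()

with Session(engine) as session:
    # Lazy loading: 1 query for the parents, plus 1 query per parent the
    # first time .children is touched -- the N+1 shape.
    for parent in session.scalars(select(Parent)):
        _ = parent.children

with Session(engine) as session:
    # Eager loading: children arrive in the same joined query, which is
    # what .options(joinedload(DagModel.dag_owner_links)) achieves in the
    # collection.py hunk. .unique() is required when joined-eager-loading
    # a collection, since the join duplicates parent rows.
    query = select(Parent).options(joinedload(Parent.children))
    for parent in session.scalars(query).unique():
        _ = parent.children
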