Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions airflow-core/src/airflow/dag_processing/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from airflow.models.dagrun import DagRun
from airflow.models.dagwarning import DagWarningType
from airflow.models.errors import ParseImportError
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.trigger import Trigger
from airflow.serialization.definitions.assets import (
SerializedAsset,
Expand All @@ -75,6 +76,7 @@
from sqlalchemy.sql import Select

from airflow.models.dagwarning import DagWarning
from airflow.models.serialized_dag import DagWriteMetadata
from airflow.typing_compat import Self, Unpack

AssetT = TypeVar("AssetT", SerializedAsset, SerializedAssetAlias)
Expand Down Expand Up @@ -256,15 +258,18 @@ def _update_dag_owner_links(dag_owner_links: dict[str, str], dm: DagModel, *, se


def _serialize_dag_capturing_errors(
dag: LazyDeserializedDAG, bundle_name, session: Session, bundle_version: str | None
dag: LazyDeserializedDAG,
bundle_name,
session: Session,
bundle_version: str | None,
_prefetched: DagWriteMetadata | None = None,
):
"""
Try to serialize the dag to the DB, but make a note of any errors.

We can't place them directly in import_errors, as this may be retried, and work the next time
"""
from airflow.models.dagcode import DagCode
from airflow.models.serialized_dag import SerializedDagModel

# Updating serialized DAG can not be faster than a minimum interval to reduce database write rate.
MIN_SERIALIZED_DAG_UPDATE_INTERVAL = conf.getint(
Expand All @@ -279,10 +284,11 @@ def _serialize_dag_capturing_errors(
bundle_version=bundle_version,
min_update_interval=MIN_SERIALIZED_DAG_UPDATE_INTERVAL,
session=session,
_prefetched=_prefetched,
)
if not dag_was_updated:
# Check and update DagCode
DagCode.update_source_code(dag.dag_id, dag.fileloc)
DagCode.update_source_code(dag.dag_id, dag.fileloc, session=session)
if "FabAuthManager" in conf.get("core", "auth_manager"):
_sync_dag_perms(dag, session=session)

Expand Down Expand Up @@ -473,6 +479,13 @@ def update_dag_parsing_results_in_db(
SerializedDAG.bulk_write_to_db(
bundle_name, bundle_version, dags, parse_duration, session=session
)
# Bulk prefetch metadata for all DAGs to avoid the standard per-DAG
# metadata lookups in write_dag. This replaces the update-interval,
# hash, and version queries with 2 bulk queries total; DAGs with
# deadlines may still do an additional lookup for deadline UUID reuse.
prefetched_metadata = SerializedDagModel._prefetch_dag_write_metadata(
[dag.dag_id for dag in dags], session=session
)
# Write Serialized DAGs to DB, capturing errors
for dag in dags:
serialize_errors.extend(
Expand All @@ -481,6 +494,7 @@ def update_dag_parsing_results_in_db(
bundle_name=bundle_name,
bundle_version=bundle_version,
session=session,
_prefetched=prefetched_metadata.get(dag.dag_id),
)
)
except OperationalError:
Expand Down Expand Up @@ -526,6 +540,7 @@ def find_orm_dags(self, *, session: Session) -> dict[str, DagModel]:
.options(joinedload(DagModel.schedule_asset_references))
.options(joinedload(DagModel.schedule_asset_alias_references))
.options(joinedload(DagModel.task_outlet_asset_references))
.options(joinedload(DagModel.dag_owner_links))
),
of=DagModel,
session=session,
Expand Down
102 changes: 85 additions & 17 deletions airflow-core/src/airflow/models/serialized_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import zlib
from collections.abc import Callable, Iterable, Iterator, Sequence
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, NamedTuple
from uuid import UUID

import uuid6
Expand Down Expand Up @@ -70,6 +70,14 @@
_COMPRESS_SERIALIZED_DAGS = conf.getboolean("core", "compress_serialized_dags", fallback=False)


class DagWriteMetadata(NamedTuple):
    """Metadata bundle that lets ``write_dag`` skip its per-DAG lookup queries.

    Built in bulk by ``SerializedDagModel._prefetch_dag_write_metadata``; a
    ``None`` field means no existing row was found for that DAG.
    """

    # ``last_updated`` of the newest serialized_dag row; None if never serialized.
    last_updated: datetime | None
    # ``dag_hash`` of the newest serialized_dag row; None if never serialized.
    dag_hash: str | None
    # Most recent DagVersion row for the DAG, or None if there is none yet.
    dag_version: DagVersion | None


class _DagDependenciesResolver:
"""Resolver that resolves dag dependencies to include asset id and assets link to asset aliases."""

Expand Down Expand Up @@ -508,6 +516,70 @@ def _create_deadline_alert_records(
)
serialized_dag.deadline_alerts.append(alert)

@classmethod
def _prefetch_dag_write_metadata(
    cls, dag_ids: Iterable[str], *, session: Session
) -> dict[str, DagWriteMetadata]:
    """
    Bulk-fetch the metadata ``write_dag`` needs for many DAGs in two queries.

    ``write_dag`` otherwise issues three SELECTs per DAG (update-interval
    check, hash comparison, latest-version fetch); this gathers the same
    data for all requested DAGs up front.

    :param dag_ids: DAG IDs to prefetch metadata for (duplicates are collapsed)
    :param session: ORM Session
    :returns: dict mapping every requested dag_id to its DagWriteMetadata;
        fields are ``None`` for DAGs with no existing rows
    """
    dag_id_list = list(set(dag_ids))
    if not dag_id_list:
        return {}

    # Latest serialized_dag row (last_updated, dag_hash) per dag_id, picked
    # with a row_number() window ordered by created_at DESC so the choice
    # matches the ordering write_dag used for its per-DAG queries.
    sd_subq = (
        select(
            cls.dag_id.label("dag_id"),
            cls.last_updated.label("last_updated"),
            cls.dag_hash.label("dag_hash"),
            func.row_number().over(partition_by=cls.dag_id, order_by=cls.created_at.desc()).label("rn"),
        )
        .where(cls.dag_id.in_(dag_id_list))
        .subquery()
    )
    sd_rows = session.execute(
        select(sd_subq.c.dag_id, sd_subq.c.last_updated, sd_subq.c.dag_hash).where(sd_subq.c.rn == 1)
    ).all()
    sd_by_dag_id: dict[str, tuple[datetime, str]] = {
        row.dag_id: (row.last_updated, row.dag_hash) for row in sd_rows
    }

    # Latest DagVersion per dag_id via the same window-function technique,
    # matching the original write_dag ordering (ORDER BY created_at DESC).
    dv_subq = (
        select(
            DagVersion.id.label("id"),
            DagVersion.dag_id.label("dag_id"),
            func.row_number()
            .over(partition_by=DagVersion.dag_id, order_by=DagVersion.created_at.desc())
            .label("rn"),
        )
        .where(DagVersion.dag_id.in_(dag_id_list))
        .subquery()
    )
    dag_versions = session.scalars(
        select(DagVersion).join(dv_subq, DagVersion.id == dv_subq.c.id).where(dv_subq.c.rn == 1)
    ).all()
    dv_by_dag_id: dict[str, DagVersion] = {dv.dag_id: dv for dv in dag_versions}

    # One .get() per dag_id instead of repeated membership tests + subscripts.
    result: dict[str, DagWriteMetadata] = {}
    for dag_id in dag_id_list:
        last_updated, dag_hash = sd_by_dag_id.get(dag_id, (None, None))
        result[dag_id] = DagWriteMetadata(
            last_updated=last_updated,
            dag_hash=dag_hash,
            dag_version=dv_by_dag_id.get(dag_id),
        )
    return result

@classmethod
@provide_session
def write_dag(
Expand All @@ -517,6 +589,7 @@ def write_dag(
bundle_version: str | None = None,
min_update_interval: int | None = None,
session: Session = NEW_SESSION,
_prefetched: DagWriteMetadata | None = None,
) -> bool:
"""
Serialize a DAG and writes it into database.
Expand All @@ -529,33 +602,28 @@ def write_dag(
:param bundle_version: bundle version of the DAG
:param min_update_interval: minimal interval in seconds to update serialized DAG
:param session: ORM Session
:param _prefetched: Pre-fetched metadata to skip per-DAG queries; used by bulk callers

:returns: Boolean indicating if the DAG was written to the DB
"""
if _prefetched is None:
_prefetched = cls._prefetch_dag_write_metadata([dag.dag_id], session=session).get(
dag.dag_id, DagWriteMetadata(last_updated=None, dag_hash=None, dag_version=None)
)

# Checks if (Current Time - Time when the DAG was written to DB) < min_update_interval
# If Yes, does nothing
# If No or the DAG does not exists, updates / writes Serialized DAG to DB
if min_update_interval is not None:
if session.scalar(
select(literal(True))
.where(
cls.dag_id == dag.dag_id,
(timezone.utcnow() - timedelta(seconds=min_update_interval)) < cls.last_updated,
)
.select_from(cls)
if (
_prefetched.last_updated is not None
and (timezone.utcnow() - timedelta(seconds=min_update_interval)) < _prefetched.last_updated
):
return False

log.debug("Checking if DAG (%s) changed", dag.dag_id)
serialized_dag_hash = session.scalars(
select(cls.dag_hash).where(cls.dag_id == dag.dag_id).order_by(cls.created_at.desc())
).first()
dag_version = session.scalar(
select(DagVersion)
.where(DagVersion.dag_id == dag.dag_id)
.order_by(DagVersion.created_at.desc())
.limit(1)
)
serialized_dag_hash = _prefetched.dag_hash
dag_version = _prefetched.dag_version

if dag.data.get("dag", {}).get("deadline"):
# Try to reuse existing deadline UUIDs if the deadline definitions haven't changed.
Expand Down
1 change: 1 addition & 0 deletions airflow-core/tests/unit/dag_processing/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,7 @@ def test_sync_to_db_is_retried(
bundle_version=None,
min_update_interval=mock.ANY,
session=mock_session,
_prefetched=mock.ANY,
),
]
)
Expand Down
42 changes: 42 additions & 0 deletions airflow-core/tests/unit/models/test_serialized_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,48 @@ def test_min_update_interval_is_respected(self, provide_interval, new_task, shou
)
assert did_write is should_write

def test_prefetch_dag_write_metadata_multiple_dags(self, dag_maker, session):
    """_prefetch_dag_write_metadata returns complete metadata for every requested DAG."""
    dag_ids = ("prefetch_multi_dag1", "prefetch_multi_dag2")
    for dag_id in dag_ids:
        with dag_maker(dag_id):
            EmptyOperator(task_id="task1")

    result = SDM._prefetch_dag_write_metadata(list(dag_ids), session=session)

    # Every requested DAG is present, with all three metadata fields populated.
    assert set(result) == set(dag_ids)
    for dag_id, metadata in result.items():
        assert metadata.last_updated is not None
        assert metadata.dag_hash is not None
        assert metadata.dag_version is not None
        assert metadata.dag_version.dag_id == dag_id

def test_prefetch_dag_write_metadata_returns_latest_version(self, dag_maker, session):
    """_prefetch_dag_write_metadata picks the most recent DagVersion when several exist."""
    with dag_maker("prefetch_version_dag") as dag:
        PythonOperator(task_id="task1", python_callable=lambda: None)
    # Create a dagrun so that writing a changed DAG produces a new version.
    dag_maker.create_dagrun(run_id="run1", logical_date=pendulum.datetime(2025, 1, 1))

    # Add a task and write again, which should create version 2.
    PythonOperator(task_id="task2", python_callable=lambda: None, dag=dag)
    SDM.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name="dag_maker")

    version_count = session.scalar(
        select(func.count()).select_from(DagVersion).where(DagVersion.dag_id == dag.dag_id)
    )
    assert version_count == 2

    metadata = SDM._prefetch_dag_write_metadata([dag.dag_id], session=session)[dag.dag_id]
    assert metadata.dag_version is not None
    assert metadata.dag_version.version_number == 2

def test_new_dag_version_created_when_bundle_name_changes_and_hash_unchanged(self, dag_maker, session):
"""Test that new dag_version is created if bundle_name changes but DAG is unchanged."""
# Create and write initial DAG
Expand Down
Loading