Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 39 additions & 11 deletions airflow-core/src/airflow/models/dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ def __init__(
:param cache_ttl: Time-to-live for cache entries in seconds. If None or 0, no TTL (LRU only).
"""
self.load_op_links = load_op_links
self._dags: MutableMapping[UUID | str, SerializedDAG] = {}
# Cache value is (deserialized_dag, dag_hash). The hash is checked against the DB
# on every lookup to detect in-place updates of the SerializedDagModel under the
# same dag_version_id (see ``SerializedDagModel.write_dag`` fast-path for versions
# with no associated task instances). See issue #65696.
self._dags: MutableMapping[UUID | str, tuple[SerializedDAG, str]] = {}
self._use_cache = False

# Initialize bounded cache if cache_size is provided and > 0
Expand All @@ -90,21 +94,44 @@ def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None:
if not dag:
return None
with self._lock:
self._dags[serdag.dag_version_id] = dag
self._dags[serdag.dag_version_id] = (dag, serdag.dag_hash)
cache_size = len(self._dags)
if self._use_cache:
Stats.gauge("api_server.dag_bag.cache_size", cache_size, rate=0.1)
return dag

def _current_db_hash(self, version_id: UUID | str, session: Session) -> str | None:
"""
Fetch the current ``dag_hash`` for a ``dag_version_id`` from the DB.

Cheap single-column lookup used to detect stale cache entries when
``SerializedDagModel`` rows are updated in-place under the same
``dag_version_id`` (issue #65696).
"""
from airflow.models.serialized_dag import SerializedDagModel

return session.scalar(
select(SerializedDagModel.dag_hash).where(SerializedDagModel.dag_version_id == version_id)
)

def _get_dag(self, version_id: UUID | str, session: Session) -> SerializedDAG | None:
# Check cache first
with self._lock:
dag = self._dags.get(version_id)

if dag:
if self._use_cache:
Stats.incr("api_server.dag_bag.cache_hit")
return dag
cached = self._dags.get(version_id)

if cached is not None:
cached_dag, cached_hash = cached
# Verify the cache entry is still fresh. ``SerializedDagModel.write_dag``
# may update the row in-place under the same dag_version_id, leaving our
# cached deserialized DAG stale (issue #65696).
current_hash = self._current_db_hash(version_id, session)
if current_hash is not None and current_hash == cached_hash:
if self._use_cache:
Stats.incr("api_server.dag_bag.cache_hit")
return cached_dag
# Stale or removed: drop the entry and fall through to a fresh load.
with self._lock:
self._dags.pop(version_id, None)

dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)])
if not dag_version:
Expand All @@ -117,9 +144,10 @@ def _get_dag(self, version_id: UUID | str, session: Session) -> SerializedDAG |
# counting a single lookup as both a miss and a hit.
if self._use_cache:
with self._lock:
if dag := self._dags.get(version_id):
Stats.incr("api_server.dag_bag.cache_hit")
return dag
cached = self._dags.get(version_id)
if cached is not None and cached[1] == serdag.dag_hash:
Stats.incr("api_server.dag_bag.cache_hit")
return cached[0]
Stats.incr("api_server.dag_bag.cache_miss")
return self._read_dag(serdag)

Expand Down
67 changes: 62 additions & 5 deletions airflow-core/tests/unit/models/test_dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,17 @@ def setup_method(self):
self.session = MagicMock()

def test__read_dag_stores_and_returns_dag(self):
"""It should store the SerializedDAG in _dags and return it."""
"""It should store the (SerializedDAG, dag_hash) tuple in _dags and return the dag."""
mock_dag = MagicMock(spec=SerializedDAG)
mock_serdag = MagicMock(spec=SerializedDagModel)
mock_serdag.dag = mock_dag
mock_serdag.dag_version_id = "v1"
mock_serdag.dag_hash = "h1"

result = self.db_dag_bag._read_dag(mock_serdag)

assert result == mock_dag
assert self.db_dag_bag._dags["v1"] == mock_dag
assert self.db_dag_bag._dags["v1"] == (mock_dag, "h1")
assert mock_serdag.load_op_links is True

def test__read_dag_returns_none_when_no_dag(self):
Expand Down Expand Up @@ -83,9 +84,14 @@ def test_get_dag_fetches_from_db_on_miss(self):
assert result == mock_dag

def test_get_dag_returns_cached_on_hit(self):
"""It should return cached DAG without querying DB."""
"""It should return cached DAG without re-loading the full row from DB.

A cheap ``dag_hash`` lookup is still issued to detect in-place updates
(issue #65696), but the heavy ``session.get(DagVersion, ...)`` is skipped.
"""
mock_dag = MagicMock(spec=SerializedDAG)
self.db_dag_bag._dags["v1"] = mock_dag
self.db_dag_bag._dags["v1"] = (mock_dag, "h1")
self.session.scalar.return_value = "h1"

result = self.db_dag_bag.get_dag("v1", session=self.session)

Expand All @@ -100,6 +106,54 @@ def test_get_dag_returns_none_when_not_found(self):

assert result is None

def test_get_dag_invalidates_cache_when_dag_hash_changes_in_place(self):
"""Regression test for issue #65696.

``SerializedDagModel.write_dag`` updates the serialized DAG in-place under the
same ``dag_version_id`` when the version has no associated task instances. Long-lived
``DBDagBag`` instances (e.g. the scheduler's ``self.scheduler_dag_bag``) cache by
``dag_version_id`` and previously had no staleness check, so the scheduler kept
returning the stale cached ``SerializedDAG`` until restart.

The fix: cache the ``dag_hash`` alongside the deserialized DAG and compare against
the current DB ``dag_hash`` on each cache lookup. A mismatch triggers re-deserialization.
"""
from airflow.models.serialized_dag import SerializedDagModel

# First read: cache the original DAG with its hash
original_dag = MagicMock(spec=SerializedDAG)
original_serdag = MagicMock(spec=SerializedDagModel)
original_serdag.dag = original_dag
original_serdag.dag_version_id = "v1"
original_serdag.dag_hash = "hash_original"
original_dag_version = MagicMock()
original_dag_version.serialized_dag = original_serdag
self.session.get.return_value = original_dag_version

first = self.db_dag_bag.get_dag("v1", session=self.session)
assert first is original_dag

# Simulate write_dag in-place update: same dag_version_id, new content + hash
updated_dag = MagicMock(spec=SerializedDAG)
updated_serdag = MagicMock(spec=SerializedDagModel)
updated_serdag.dag = updated_dag
updated_serdag.dag_version_id = "v1"
updated_serdag.dag_hash = "hash_updated"
updated_dag_version = MagicMock()
updated_dag_version.serialized_dag = updated_serdag

# Configure session: scalar() returns the new hash; get() returns the new full row.
self.session.scalar.return_value = "hash_updated"
self.session.get.return_value = updated_dag_version

second = self.db_dag_bag.get_dag("v1", session=self.session)

# Should NOT be the stale cached original
assert second is updated_dag, (
"DBDagBag returned stale cached SerializedDAG after in-place update of dag_version "
"(see issue #65696)"
)


class TestDBDagBagCache:
"""Tests for DBDagBag optional caching behavior."""
Expand Down Expand Up @@ -257,7 +311,10 @@ def test_cache_hit_metric_emitted(self, mock_stats):
"""Test that cache hit metric is emitted when caching is enabled."""
dag_bag = DBDagBag(cache_size=10, cache_ttl=60)
mock_session = MagicMock()
dag_bag._dags["test_version"] = MagicMock()
mock_dag = MagicMock()
dag_bag._dags["test_version"] = (mock_dag, "h")
# Hash matches so the cache entry is considered fresh.
mock_session.scalar.return_value = "h"

dag_bag._get_dag("test_version", mock_session)

Expand Down
Loading