From 4cc8dd078a83411e80223e224f3f57d3c121ddfb Mon Sep 17 00:00:00 2001 From: alvinttang Date: Sat, 25 Apr 2026 22:36:33 +0800 Subject: [PATCH] fix(scheduler): invalidate DBDagBag cache when SerializedDagModel updates in-place SerializedDagModel.write_dag updates the serialized DAG in-place under the same dag_version_id when the version has no associated task instances (added in #45524). Long-lived DBDagBag instances, such as the scheduler's self.scheduler_dag_bag, cache deserialized SerializedDAG objects keyed only by dag_version_id, with no staleness check. Once an in-place update happens, the scheduler keeps returning the stale cached DAG until the process is restarted: newly added tasks are marked "removed" on every scheduling tick, and removed tasks keep getting scheduled. Cache the dag_hash alongside the deserialized DAG and re-check it against the DB on every cache hit via a single-column lookup. On hash mismatch, drop the cache entry and reload the full row. The extra query is a tiny indexed lookup on the unique dag_version_id; it is far cheaper than the JSON deserialization that a true cache hit still avoids. Closes: #65696 --- airflow-core/src/airflow/models/dagbag.py | 50 +++++++++++--- airflow-core/tests/unit/models/test_dagbag.py | 67 +++++++++++++++++-- 2 files changed, 101 insertions(+), 16 deletions(-) diff --git a/airflow-core/src/airflow/models/dagbag.py b/airflow-core/src/airflow/models/dagbag.py index 63059884a9502..bc1af5586b669 100644 --- a/airflow-core/src/airflow/models/dagbag.py +++ b/airflow-core/src/airflow/models/dagbag.py @@ -67,7 +67,11 @@ def __init__( :param cache_ttl: Time-to-live for cache entries in seconds. If None or 0, no TTL (LRU only). """ self.load_op_links = load_op_links - self._dags: MutableMapping[UUID | str, SerializedDAG] = {} + # Cache value is (deserialized_dag, dag_hash). 
The hash is checked against the DB + # on every lookup to detect in-place updates of the SerializedDagModel under the + # same dag_version_id (see ``SerializedDagModel.write_dag`` fast-path for versions + # with no associated task instances). See issue #65696. + self._dags: MutableMapping[UUID | str, tuple[SerializedDAG, str]] = {} self._use_cache = False # Initialize bounded cache if cache_size is provided and > 0 @@ -90,21 +94,44 @@ def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None: if not dag: return None with self._lock: - self._dags[serdag.dag_version_id] = dag + self._dags[serdag.dag_version_id] = (dag, serdag.dag_hash) cache_size = len(self._dags) if self._use_cache: Stats.gauge("api_server.dag_bag.cache_size", cache_size, rate=0.1) return dag + def _current_db_hash(self, version_id: UUID | str, session: Session) -> str | None: + """ + Fetch the current ``dag_hash`` for a ``dag_version_id`` from the DB. + + Cheap single-column lookup used to detect stale cache entries when + ``SerializedDagModel`` rows are updated in-place under the same + ``dag_version_id`` (issue #65696). + """ + from airflow.models.serialized_dag import SerializedDagModel + + return session.scalar( + select(SerializedDagModel.dag_hash).where(SerializedDagModel.dag_version_id == version_id) + ) + def _get_dag(self, version_id: UUID | str, session: Session) -> SerializedDAG | None: # Check cache first with self._lock: - dag = self._dags.get(version_id) - - if dag: - if self._use_cache: - Stats.incr("api_server.dag_bag.cache_hit") - return dag + cached = self._dags.get(version_id) + + if cached is not None: + cached_dag, cached_hash = cached + # Verify the cache entry is still fresh. ``SerializedDagModel.write_dag`` + # may update the row in-place under the same dag_version_id, leaving our + # cached deserialized DAG stale (issue #65696). 
+ current_hash = self._current_db_hash(version_id, session) + if current_hash is not None and current_hash == cached_hash: + if self._use_cache: + Stats.incr("api_server.dag_bag.cache_hit") + return cached_dag + # Stale or removed: drop the entry and fall through to a fresh load. + with self._lock: + self._dags.pop(version_id, None) dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)]) if not dag_version: @@ -117,9 +144,10 @@ def _get_dag(self, version_id: UUID | str, session: Session) -> SerializedDAG | # counting a single lookup as both a miss and a hit. if self._use_cache: with self._lock: - if dag := self._dags.get(version_id): - Stats.incr("api_server.dag_bag.cache_hit") - return dag + cached = self._dags.get(version_id) + if cached is not None and cached[1] == serdag.dag_hash: + Stats.incr("api_server.dag_bag.cache_hit") + return cached[0] Stats.incr("api_server.dag_bag.cache_miss") return self._read_dag(serdag) diff --git a/airflow-core/tests/unit/models/test_dagbag.py b/airflow-core/tests/unit/models/test_dagbag.py index 397560d327cea..32b495e59f2ab 100644 --- a/airflow-core/tests/unit/models/test_dagbag.py +++ b/airflow-core/tests/unit/models/test_dagbag.py @@ -44,16 +44,17 @@ def setup_method(self): self.session = MagicMock() def test__read_dag_stores_and_returns_dag(self): - """It should store the SerializedDAG in _dags and return it.""" + """It should store the (SerializedDAG, dag_hash) tuple in _dags and return the dag.""" mock_dag = MagicMock(spec=SerializedDAG) mock_serdag = MagicMock(spec=SerializedDagModel) mock_serdag.dag = mock_dag mock_serdag.dag_version_id = "v1" + mock_serdag.dag_hash = "h1" result = self.db_dag_bag._read_dag(mock_serdag) assert result == mock_dag - assert self.db_dag_bag._dags["v1"] == mock_dag + assert self.db_dag_bag._dags["v1"] == (mock_dag, "h1") assert mock_serdag.load_op_links is True def test__read_dag_returns_none_when_no_dag(self): @@ -83,9 +84,14 @@ def 
test_get_dag_fetches_from_db_on_miss(self): assert result == mock_dag def test_get_dag_returns_cached_on_hit(self): - """It should return cached DAG without querying DB.""" + """It should return cached DAG without re-loading the full row from DB. + + A cheap ``dag_hash`` lookup is still issued to detect in-place updates + (issue #65696), but the heavy ``session.get(DagVersion, ...)`` is skipped. + """ mock_dag = MagicMock(spec=SerializedDAG) - self.db_dag_bag._dags["v1"] = mock_dag + self.db_dag_bag._dags["v1"] = (mock_dag, "h1") + self.session.scalar.return_value = "h1" result = self.db_dag_bag.get_dag("v1", session=self.session) @@ -100,6 +106,54 @@ def test_get_dag_returns_none_when_not_found(self): assert result is None + def test_get_dag_invalidates_cache_when_dag_hash_changes_in_place(self): + """Regression test for issue #65696. + + ``SerializedDagModel.write_dag`` updates the serialized DAG in-place under the + same ``dag_version_id`` when the version has no associated task instances. Long-lived + ``DBDagBag`` instances (e.g. the scheduler's ``self.scheduler_dag_bag``) cache by + ``dag_version_id`` and previously had no staleness check, so the scheduler kept + returning the stale cached ``SerializedDAG`` until restart. + + The fix: cache the ``dag_hash`` alongside the deserialized DAG and compare against + the current DB ``dag_hash`` on each cache lookup. A mismatch triggers re-deserialization. 
+ """ + from airflow.models.serialized_dag import SerializedDagModel + + # First read: cache the original DAG with its hash + original_dag = MagicMock(spec=SerializedDAG) + original_serdag = MagicMock(spec=SerializedDagModel) + original_serdag.dag = original_dag + original_serdag.dag_version_id = "v1" + original_serdag.dag_hash = "hash_original" + original_dag_version = MagicMock() + original_dag_version.serialized_dag = original_serdag + self.session.get.return_value = original_dag_version + + first = self.db_dag_bag.get_dag("v1", session=self.session) + assert first is original_dag + + # Simulate write_dag in-place update: same dag_version_id, new content + hash + updated_dag = MagicMock(spec=SerializedDAG) + updated_serdag = MagicMock(spec=SerializedDagModel) + updated_serdag.dag = updated_dag + updated_serdag.dag_version_id = "v1" + updated_serdag.dag_hash = "hash_updated" + updated_dag_version = MagicMock() + updated_dag_version.serialized_dag = updated_serdag + + # Configure session: scalar() returns the new hash; get() returns the new full row. + self.session.scalar.return_value = "hash_updated" + self.session.get.return_value = updated_dag_version + + second = self.db_dag_bag.get_dag("v1", session=self.session) + + # Should NOT be the stale cached original + assert second is updated_dag, ( + "DBDagBag returned stale cached SerializedDAG after in-place update of dag_version " + "(see issue #65696)" + ) + class TestDBDagBagCache: """Tests for DBDagBag optional caching behavior.""" @@ -257,7 +311,10 @@ def test_cache_hit_metric_emitted(self, mock_stats): """Test that cache hit metric is emitted when caching is enabled.""" dag_bag = DBDagBag(cache_size=10, cache_ttl=60) mock_session = MagicMock() - dag_bag._dags["test_version"] = MagicMock() + mock_dag = MagicMock() + dag_bag._dags["test_version"] = (mock_dag, "h") + # Hash matches so the cache entry is considered fresh. + mock_session.scalar.return_value = "h" dag_bag._get_dag("test_version", mock_session)