From fc8f2a2c2a352b16d6768d2de74ef873f936db4e Mon Sep 17 00:00:00 2001 From: Pradeep Kalluri <128097794+kalluripradeep@users.noreply.github.com> Date: Mon, 11 May 2026 19:48:23 +0100 Subject: [PATCH] [v3-2-test] fix(scheduler): catch StaleDataError in verify_integrity to prevent scheduler crash (#64503) Closes #63926 StaleDataError raised by SQLAlchemy's optimistic locking when a concurrent session modifies the same row can cause the scheduler to crash during verify_integrity. Fix by catching StaleDataError alongside IntegrityError in dagrun.verify_integrity() and adding it to the retry exceptions in run_with_db_retries()/retry_db_transaction() so the operation is retried automatically. (cherry picked from commit dcfa2715632de7f665c3eba1b42d2e3084f08361) Co-authored-by: Pradeep Kalluri <128097794+kalluripradeep@users.noreply.github.com> --- airflow-core/newsfragments/64503.bugfix.rst | 1 + airflow-core/src/airflow/models/dagrun.py | 8 +++-- airflow-core/src/airflow/utils/retries.py | 5 ++-- airflow-core/tests/unit/models/test_dagrun.py | 24 +++++++++++++++ airflow-core/tests/unit/utils/test_retries.py | 29 +++++++++++++------ 5 files changed, 54 insertions(+), 13 deletions(-) create mode 100644 airflow-core/newsfragments/64503.bugfix.rst diff --git a/airflow-core/newsfragments/64503.bugfix.rst b/airflow-core/newsfragments/64503.bugfix.rst new file mode 100644 index 0000000000000..0358708ea1f64 --- /dev/null +++ b/airflow-core/newsfragments/64503.bugfix.rst @@ -0,0 +1 @@ +Fix scheduler crashing with ``StaleDataError`` when a task instance is completed or removed by another session between ``verify_integrity`` loading task instances and ``session.flush()`` persisting them. The error is now caught and the session is rolled back, matching the existing ``IntegrityError`` handling. 
diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index 7eabadd73cfb6..afe73a43b96ef 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -57,6 +57,7 @@ from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.orm import Mapped, declared_attr, joinedload, mapped_column, relationship, synonym, validates +from sqlalchemy.orm.exc import StaleDataError from sqlalchemy.sql.expression import false, select from sqlalchemy.sql.functions import coalesce @@ -1873,14 +1874,17 @@ def _create_task_instances( extra_tags={"task_type": task_type}, ) session.flush() - except IntegrityError: + except (IntegrityError, StaleDataError) as exc: self.log.info( - "Hit IntegrityError while creating the TIs for %s- %s", + "Hit %s while creating the TIs for %s- %s", + type(exc).__name__, dag_id, run_id, exc_info=True, ) self.log.info("Doing session rollback.") + # Catching StaleDataError and rolling back is sufficient here because + # the next scheduler loop will re-read the latest state from the DB. # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive. 
session.rollback() diff --git a/airflow-core/src/airflow/utils/retries.py b/airflow-core/src/airflow/utils/retries.py index a30d676685321..69b71046acb44 100644 --- a/airflow-core/src/airflow/utils/retries.py +++ b/airflow-core/src/airflow/utils/retries.py @@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, TypeVar, overload from sqlalchemy.exc import DBAPIError +from sqlalchemy.orm.exc import StaleDataError from airflow.configuration import conf @@ -40,7 +41,7 @@ def run_with_db_retries(max_retries: int = MAX_DB_RETRIES, logger: Logger | None # Default kwargs retry_kwargs = dict( - retry=tenacity.retry_if_exception_type(exception_types=(DBAPIError)), + retry=tenacity.retry_if_exception_type(exception_types=(DBAPIError, StaleDataError)), wait=tenacity.wait_random_exponential(multiplier=0.5, max=5), stop=tenacity.stop_after_attempt(max_retries), reraise=True, @@ -104,7 +105,7 @@ def wrapped_function(*args, **kwargs): ) try: return func(*args, **kwargs) - except DBAPIError: + except (DBAPIError, StaleDataError): session.rollback() raise diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py index 93bf2dcbdf428..b259b62552e22 100644 --- a/airflow-core/tests/unit/models/test_dagrun.py +++ b/airflow-core/tests/unit/models/test_dagrun.py @@ -39,6 +39,7 @@ update, ) from sqlalchemy.orm import joinedload +from sqlalchemy.orm.exc import StaleDataError from airflow import settings from airflow._shared.observability.metrics.stats import Stats @@ -1443,6 +1444,29 @@ def mynameis(arg): assert indices == [0, 1, 2, 3] +def test_verify_integrity_handles_stale_data_error(dag_maker, session): + """Test that StaleDataError during _create_task_instances is caught and session is rolled back.""" + with dag_maker("test_stale_data_error_dag", session=session) as dag: + task = EmptyOperator(task_id="task1") + + dr = dag_maker.create_dagrun() + dag_version_id = DagVersion.get_latest_version(dag.dag_id, session=session).id + + with 
mock.patch.object(session, "flush", side_effect=StaleDataError()): + with mock.patch.object(session, "rollback") as mock_rollback: + # Should not raise — StaleDataError must be caught gracefully. + # Call _create_task_instances directly with a non-empty task list so the + # test exercises the session.flush() → StaleDataError → session.rollback() path. + dr._create_task_instances( + dag_id=dag.dag_id, + tasks=[TI(task=task, run_id=dr.run_id, dag_version_id=dag_version_id)], + created_counts={"EmptyOperator": 1}, + hook_is_noop=False, + session=session, + ) + mock_rollback.assert_called_once() + + def test_mapped_literal_verify_integrity(dag_maker, session): """Test that when the length of a mapped literal changes we remove extra TIs""" diff --git a/airflow-core/tests/unit/utils/test_retries.py b/airflow-core/tests/unit/utils/test_retries.py index 1f44ee9ebf8be..f0976d0e3589f 100644 --- a/airflow-core/tests/unit/utils/test_retries.py +++ b/airflow-core/tests/unit/utils/test_retries.py @@ -18,17 +18,14 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING from unittest import mock import pytest from sqlalchemy.exc import InternalError, OperationalError +from sqlalchemy.orm.exc import StaleDataError from airflow.utils.retries import retry_db_transaction -if TYPE_CHECKING: - from sqlalchemy.exc import DBAPIError - class TestRetries: def test_retry_db_transaction_with_passing_retries(self): @@ -48,15 +45,29 @@ def test_function(session): assert mock_obj.call_count == 2 - @pytest.mark.db_test - @pytest.mark.parametrize("excection_type", [OperationalError, InternalError]) - def test_retry_db_transaction_with_default_retries(self, caplog, excection_type: type[DBAPIError]): + @pytest.mark.parametrize( + ("exception_type", "exception_kwargs"), + [ + pytest.param( + InternalError, + {"statement": mock.ANY, "params": mock.ANY, "orig": mock.ANY}, + id="dbapi-internal", + ), + pytest.param( + OperationalError, + {"statement": mock.ANY, "params": 
mock.ANY, "orig": mock.ANY}, + id="dbapi-operational", + ), + pytest.param(StaleDataError, {}, id="stale-data"), + ], + ) + def test_retry_db_transaction_with_default_retries(self, caplog, exception_type, exception_kwargs): """Test that by default 3 retries will be carried out""" mock_obj = mock.MagicMock() mock_session = mock.MagicMock() mock_rollback = mock.MagicMock() mock_session.rollback = mock_rollback - db_error = excection_type(statement=mock.ANY, params=mock.ANY, orig=mock.ANY) + db_error = exception_type(**exception_kwargs) @retry_db_transaction def test_function(session): @@ -66,7 +77,7 @@ def test_function(session): caplog.set_level(logging.DEBUG) caplog.clear() - with pytest.raises(excection_type): + with pytest.raises(exception_type): test_function(session=mock_session) for try_no in range(1, 4):