Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions airflow-core/newsfragments/64503.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix scheduler crashing with ``StaleDataError`` when a task instance is completed or removed by another session between ``verify_integrity`` loading task instances and ``session.flush()`` persisting them. Now caught and rolled back like the existing ``IntegrityError`` path.
8 changes: 6 additions & 2 deletions airflow-core/src/airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import Mapped, declared_attr, joinedload, mapped_column, relationship, synonym, validates
from sqlalchemy.orm.exc import StaleDataError
from sqlalchemy.sql.expression import false, select
from sqlalchemy.sql.functions import coalesce

Expand Down Expand Up @@ -1873,14 +1874,17 @@ def _create_task_instances(
extra_tags={"task_type": task_type},
)
session.flush()
except IntegrityError:
except (IntegrityError, StaleDataError) as exc:
self.log.info(
"Hit IntegrityError while creating the TIs for %s- %s",
"Hit %s while creating the TIs for %s- %s",
type(exc).__name__,
dag_id,
run_id,
exc_info=True,
)
self.log.info("Doing session rollback.")
# Catching StaleDataError and rolling back is sufficient here because
# the next scheduler loop will re-read the latest state from the DB.
# TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
session.rollback()

Expand Down
5 changes: 3 additions & 2 deletions airflow-core/src/airflow/utils/retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import TYPE_CHECKING, TypeVar, overload

from sqlalchemy.exc import DBAPIError
from sqlalchemy.orm.exc import StaleDataError

from airflow.configuration import conf

Expand All @@ -40,7 +41,7 @@ def run_with_db_retries(max_retries: int = MAX_DB_RETRIES, logger: Logger | None

# Default kwargs
retry_kwargs = dict(
retry=tenacity.retry_if_exception_type(exception_types=(DBAPIError)),
retry=tenacity.retry_if_exception_type(exception_types=(DBAPIError, StaleDataError)),
wait=tenacity.wait_random_exponential(multiplier=0.5, max=5),
stop=tenacity.stop_after_attempt(max_retries),
reraise=True,
Expand Down Expand Up @@ -104,7 +105,7 @@ def wrapped_function(*args, **kwargs):
)
try:
return func(*args, **kwargs)
except DBAPIError:
except (DBAPIError, StaleDataError):
session.rollback()
raise

Expand Down
24 changes: 24 additions & 0 deletions airflow-core/tests/unit/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
update,
)
from sqlalchemy.orm import joinedload
from sqlalchemy.orm.exc import StaleDataError

from airflow import settings
from airflow._shared.observability.metrics.stats import Stats
Expand Down Expand Up @@ -1443,6 +1444,29 @@ def mynameis(arg):
assert indices == [0, 1, 2, 3]


def test_verify_integrity_handles_stale_data_error(dag_maker, session):
"""Test that StaleDataError during _create_task_instances is caught and session is rolled back."""
with dag_maker("test_stale_data_error_dag", session=session) as dag:
task = EmptyOperator(task_id="task1")

dr = dag_maker.create_dagrun()
dag_version_id = DagVersion.get_latest_version(dag.dag_id, session=session).id

with mock.patch.object(session, "flush", side_effect=StaleDataError()):
with mock.patch.object(session, "rollback") as mock_rollback:
# Should not raise — StaleDataError must be caught gracefully.
# Call _create_task_instances directly with a non-empty task list so the
# test exercises the session.flush() → StaleDataError → session.rollback() path.
dr._create_task_instances(
dag_id=dag.dag_id,
tasks=[TI(task=task, run_id=dr.run_id, dag_version_id=dag_version_id)],
created_counts={"EmptyOperator": 1},
hook_is_noop=False,
session=session,
)
mock_rollback.assert_called_once()


def test_mapped_literal_verify_integrity(dag_maker, session):
"""Test that when the length of a mapped literal changes we remove extra TIs"""

Expand Down
29 changes: 20 additions & 9 deletions airflow-core/tests/unit/utils/test_retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,14 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING
from unittest import mock

import pytest
from sqlalchemy.exc import InternalError, OperationalError
from sqlalchemy.orm.exc import StaleDataError

from airflow.utils.retries import retry_db_transaction

if TYPE_CHECKING:
from sqlalchemy.exc import DBAPIError


class TestRetries:
def test_retry_db_transaction_with_passing_retries(self):
Expand All @@ -48,15 +45,29 @@ def test_function(session):

assert mock_obj.call_count == 2

@pytest.mark.db_test
@pytest.mark.parametrize("excection_type", [OperationalError, InternalError])
def test_retry_db_transaction_with_default_retries(self, caplog, excection_type: type[DBAPIError]):
@pytest.mark.parametrize(
("exception_type", "exception_kwargs"),
[
pytest.param(
InternalError,
{"statement": mock.ANY, "params": mock.ANY, "orig": mock.ANY},
id="dbapi-internal",
),
pytest.param(
OperationalError,
{"statement": mock.ANY, "params": mock.ANY, "orig": mock.ANY},
id="dbapi-operational",
),
pytest.param(StaleDataError, {}, id="stale-data"),
],
)
def test_retry_db_transaction_with_default_retries(self, caplog, exception_type, exception_kwargs):
"""Test that by default 3 retries will be carried out"""
mock_obj = mock.MagicMock()
mock_session = mock.MagicMock()
mock_rollback = mock.MagicMock()
mock_session.rollback = mock_rollback
db_error = excection_type(statement=mock.ANY, params=mock.ANY, orig=mock.ANY)
db_error = exception_type(**exception_kwargs)

@retry_db_transaction
def test_function(session):
Expand All @@ -66,7 +77,7 @@ def test_function(session):

caplog.set_level(logging.DEBUG)
caplog.clear()
with pytest.raises(excection_type):
with pytest.raises(exception_type):
test_function(session=mock_session)

for try_no in range(1, 4):
Expand Down
Loading