Retry Dagbag.sync_to_db to avoid Deadlocks (#12046)
Previously we added retry logic in DagFileProcessor.process_file to
retry dagbag.sync_to_db. However, this meant that anyone calling
dagbag.sync_to_db separately also had to manage the retries
themselves. This caused failures in CI for MySQL.

resolves #11543
kaxil committed Nov 2, 2020
1 parent a1a1fc9 commit 2192010
Showing 4 changed files with 64 additions and 41 deletions.
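
For context, the pattern this commit moves is the usual tenacity retry loop around a database write. The sketch below is a minimal, self-contained illustration of that pattern, not the actual Airflow code; sync_with_retries and the hard-coded MAX_DB_RETRIES are illustrative stand-ins (Airflow reads the real value from settings.MAX_DB_RETRIES):

import logging

import tenacity
from sqlalchemy.exc import OperationalError

log = logging.getLogger(__name__)
MAX_DB_RETRIES = 3  # illustrative stand-in for settings.MAX_DB_RETRIES


def sync_with_retries(sync, session):
    # Run sync(session), retrying on transient database errors such as deadlocks.
    for attempt in tenacity.Retrying(
        retry=tenacity.retry_if_exception_type(OperationalError),
        wait=tenacity.wait_random_exponential(multiplier=0.5, max=5),
        stop=tenacity.stop_after_attempt(MAX_DB_RETRIES),
        before_sleep=tenacity.before_sleep_log(log, logging.DEBUG),
        reraise=True,
    ):
        with attempt:
            try:
                sync(session)
            except OperationalError:
                session.rollback()  # start the next attempt from a clean session
                raise

In the diff below, this loop is deleted from DagFileProcessor.process_file and re-created inside DagBag.sync_to_db, so every caller of sync_to_db gets the retries without writing its own loop.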
19 changes: 1 addition & 18 deletions airflow/jobs/scheduler_job.py
@@ -32,7 +32,6 @@
from multiprocessing.connection import Connection as MultiprocessingConnection
from typing import Any, DefaultDict, Dict, Iterable, List, Optional, Set, Tuple

import tenacity
from setproctitle import setproctitle
from sqlalchemy import and_, func, not_, or_
from sqlalchemy.exc import OperationalError
@@ -678,23 +677,7 @@ def process_file(

# Save individual DAGs in the ORM
dagbag.read_dags_from_db = True

# Retry 'dagbag.sync_to_db()' in case of any Operational Errors
# In case of failures, provide_session handles rollback
for attempt in tenacity.Retrying(
retry=tenacity.retry_if_exception_type(exception_types=OperationalError),
wait=tenacity.wait_random_exponential(multiplier=0.5, max=5),
stop=tenacity.stop_after_attempt(settings.MAX_DB_RETRIES),
before_sleep=tenacity.before_sleep_log(self.log, logging.DEBUG),
reraise=True
):
with attempt:
self.log.debug(
"Running dagbag.sync_to_db with retries. Try %d of %d",
attempt.retry_state.attempt_number,
settings.MAX_DB_RETRIES
)
dagbag.sync_to_db()
dagbag.sync_to_db()

if pickle_dags:
paused_dag_ids = DagModel.get_paused_dag_ids(dag_ids=dagbag.dag_ids)
35 changes: 30 additions & 5 deletions airflow/models/dagbag.py
@@ -20,6 +20,7 @@
import importlib
import importlib.machinery
import importlib.util
import logging
import os
import sys
import textwrap
@@ -29,7 +30,9 @@
from datetime import datetime, timedelta
from typing import Dict, List, NamedTuple, Optional

import tenacity
from croniter import CroniterBadCronError, CroniterBadDateError, CroniterNotAlphaError, croniter
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session
from tabulate import tabulate

@@ -522,8 +525,30 @@ def sync_to_db(self, session: Optional[Session] = None):
# To avoid circular import - airflow.models.dagbag -> airflow.models.dag -> airflow.models.dagbag
from airflow.models.dag import DAG
from airflow.models.serialized_dag import SerializedDagModel
self.log.debug("Calling the DAG.bulk_sync_to_db method")
DAG.bulk_write_to_db(self.dags.values(), session=session)
# Write Serialized DAGs to DB
self.log.debug("Calling the SerializedDagModel.bulk_sync_to_db method")
SerializedDagModel.bulk_sync_to_db(self.dags.values(), session=session)

# Retry 'DAG.bulk_write_to_db' & 'SerializedDagModel.bulk_sync_to_db' in case
# of any Operational Errors
# In case of failures, provide_session handles rollback
for attempt in tenacity.Retrying(
retry=tenacity.retry_if_exception_type(exception_types=OperationalError),
wait=tenacity.wait_random_exponential(multiplier=0.5, max=5),
stop=tenacity.stop_after_attempt(settings.MAX_DB_RETRIES),
before_sleep=tenacity.before_sleep_log(self.log, logging.DEBUG),
reraise=True
):
with attempt:
self.log.debug(
"Running dagbag.sync_to_db with retries. Try %d of %d",
attempt.retry_state.attempt_number,
settings.MAX_DB_RETRIES
)
self.log.debug("Calling the DAG.bulk_sync_to_db method")
try:
DAG.bulk_write_to_db(self.dags.values(), session=session)

# Write Serialized DAGs to DB
self.log.debug("Calling the SerializedDagModel.bulk_sync_to_db method")
SerializedDagModel.bulk_sync_to_db(self.dags.values(), session=session)
except OperationalError:
session.rollback()
raise
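
The try/except around the two bulk calls above matters: an OperationalError generally leaves the SQLAlchemy session in a failed transaction, so the session is rolled back before the exception is re-raised and tenacity runs the next attempt against a clean session. As a toy illustration of how tenacity's attempt context manager behaves (plain Python, no Airflow or database involved):

import tenacity


class FlakyWriter:
    # Stand-in for a database call that fails twice and then succeeds.
    def __init__(self):
        self.calls = 0

    def write(self):
        self.calls += 1
        if self.calls < 3:
            raise RuntimeError("simulated transient error")
        return "ok"


writer = FlakyWriter()
for attempt in tenacity.Retrying(
    retry=tenacity.retry_if_exception_type(RuntimeError),
    stop=tenacity.stop_after_attempt(3),
    reraise=True,
):
    with attempt:
        result = writer.write()

assert result == "ok" and writer.calls == 3  # two failed attempts, then success

With reraise=True, exhausting all attempts surfaces the original exception to the caller rather than a tenacity.RetryError, which is the same behaviour the old loop in process_file had.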
18 changes: 0 additions & 18 deletions tests/jobs/test_scheduler_job.py
@@ -31,7 +31,6 @@
import pytest
from parameterized import parameterized
from sqlalchemy import func
from sqlalchemy.exc import OperationalError

import airflow.example_dags
import airflow.smart_sensor_dags
@@ -709,23 +708,6 @@ def test_process_file_should_failure_callback(self):
self.assertEqual("Callback fired", content)
os.remove(callback_file.name)

@mock.patch("airflow.jobs.scheduler_job.DagBag")
def test_process_file_should_retry_sync_to_db(self, mock_dagbag):
"""Test that dagbag.sync_to_db is retried on OperationalError"""
dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())

mock_dagbag.return_value.dags = {'example_dag': mock.ANY}
op_error = OperationalError(statement=mock.ANY, params=mock.ANY, orig=mock.ANY)

# Mock error for the first 2 tries and a successful third try
side_effect = [op_error, op_error, mock.ANY]

mock_sync_to_db = mock.Mock(side_effect=side_effect)
mock_dagbag.return_value.sync_to_db = mock_sync_to_db

dag_file_processor.process_file("/dev/null", callback_requests=mock.MagicMock())
mock_sync_to_db.assert_has_calls([mock.call(), mock.call(), mock.call()])

def test_should_mark_dummy_task_as_success(self):
dag_file = os.path.join(
os.path.dirname(os.path.realpath(__file__)), '../dags/test_only_dummy_tasks.py'
33 changes: 33 additions & 0 deletions tests/models/test_dagbag.py
@@ -21,10 +21,12 @@
import unittest
from datetime import datetime, timezone
from tempfile import NamedTemporaryFile, mkdtemp
from unittest import mock
from unittest.mock import patch

from freezegun import freeze_time
from sqlalchemy import func
from sqlalchemy.exc import OperationalError

import airflow.example_dags
from airflow import models
@@ -661,6 +663,37 @@ def test_serialized_dags_are_written_to_db_on_sync(self):
new_serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar()
self.assertEqual(new_serialized_dags_count, 1)

@patch("airflow.models.dagbag.DagBag.collect_dags")
@patch("airflow.models.serialized_dag.SerializedDagModel.bulk_sync_to_db")
@patch("airflow.models.dag.DAG.bulk_write_to_db")
def test_sync_to_db_is_retried(self, mock_bulk_write_to_db, mock_sdag_sync_to_db, mock_collect_dags):
"""Test that dagbag.sync_to_db is retried on OperationalError"""

dagbag = DagBag("/dev/null")

op_error = OperationalError(statement=mock.ANY, params=mock.ANY, orig=mock.ANY)

# Mock error for the first 2 tries and a successful third try
side_effect = [op_error, op_error, mock.ANY]

mock_bulk_write_to_db.side_effect = side_effect

mock_session = mock.MagicMock()
dagbag.sync_to_db(session=mock_session)

# Test that 3 attempts were made to run 'DAG.bulk_write_to_db' successfully
mock_bulk_write_to_db.assert_has_calls([
mock.call(mock.ANY, session=mock.ANY),
mock.call(mock.ANY, session=mock.ANY),
mock.call(mock.ANY, session=mock.ANY),
])
# Assert that rollback is called twice (i.e. whenever OperationalError occurs)
mock_session.rollback.assert_has_calls([mock.call(), mock.call()])
# Check that 'SerializedDagModel.bulk_sync_to_db' is also called
# Only called once since the other two times the 'DAG.bulk_write_to_db' error'd
# and the session was roll-backed before even reaching 'SerializedDagModel.bulk_sync_to_db'
mock_sdag_sync_to_db.assert_has_calls([mock.call(mock.ANY, session=mock.ANY)])

@patch("airflow.models.dagbag.settings.MIN_SERIALIZED_DAG_UPDATE_INTERVAL", 5)
@patch("airflow.models.dagbag.settings.MIN_SERIALIZED_DAG_FETCH_INTERVAL", 5)
def test_get_dag_with_dag_serialization(self):
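The side_effect list in the new test is the standard unittest.mock way to script successive outcomes for a call: exceptions in the list are raised, other values are returned in order. A minimal standalone illustration of the technique (independent of Airflow):

from unittest import mock

flaky = mock.Mock(side_effect=[ValueError("boom"), ValueError("boom"), "ok"])

results = []
for _ in range(3):
    try:
        results.append(flaky())
    except ValueError as err:
        results.append(str(err))

assert results == ["boom", "boom", "ok"]
assert flaky.call_count == 3

Assuming a working Airflow test environment (for example Breeze), the new case can be run on its own with something like: pytest tests/models/test_dagbag.py -k test_sync_to_db_is_retried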
