ECS Executor - Add backoff on failed task retry (#37109)
* ECS Executor - Add backoff on failed task retry
ferruzzi committed Feb 5, 2024
1 parent d1aea8e commit 41ebf28
Showing 5 changed files with 77 additions and 12 deletions.
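In short: when an ECS run_task call fails, the executor no longer retries the task on the very next sync; it re-queues it with a next_attempt_time computed by the new calculate_next_attempt_delay helper and skips it until that time arrives. A minimal sketch of the schedule the helper produces with the defaults visible in the diff below (exponent base 4, cap of 60 * 2 = 120 seconds):

# Illustration (not part of the diff): delay before each successive retry, same formula as the new helper.
for attempt in range(1, 6):
    delay_seconds = min(4**attempt, 120)  # exponent_base ** attempt_number, capped at max_delay
    print(f"retry #{attempt}: wait {delay_seconds} s")
# Prints 4 s, 16 s, 64 s, then 120 s for every later attempt (the cap kicks in at attempt 4).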
23 changes: 20 additions & 3 deletions airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -42,7 +42,10 @@
EcsQueuedTask,
EcsTaskCollection,
)
from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import exponential_backoff_retry
from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import (
calculate_next_attempt_delay,
exponential_backoff_retry,
)
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.utils import timezone
from airflow.utils.state import State
@@ -300,7 +303,14 @@ def __handle_failed_task(self, task_arn: str, reason: str):
)
self.active_workers.increment_failure_count(task_key)
self.pending_tasks.appendleft(
EcsQueuedTask(task_key, task_cmd, queue, exec_info, failure_count + 1)
EcsQueuedTask(
task_key,
task_cmd,
queue,
exec_info,
failure_count + 1,
timezone.utcnow() + calculate_next_attempt_delay(failure_count),
)
)
else:
self.log.error(
@@ -331,6 +341,8 @@ def attempt_task_runs(self):
exec_config = ecs_task.executor_config
attempt_number = ecs_task.attempt_number
_failure_reasons = []
if timezone.utcnow() < ecs_task.next_attempt_time:
continue
try:
run_task_response = self._run_task(task_key, cmd, queue, exec_config)
except NoCredentialsError:
@@ -361,6 +373,9 @@
# Make sure the number of attempts does not exceed MAX_RUN_TASK_ATTEMPTS
if int(attempt_number) <= int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
ecs_task.attempt_number += 1
ecs_task.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay(
attempt_number
)
self.pending_tasks.appendleft(ecs_task)
else:
self.log.error(
@@ -422,7 +437,9 @@ def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None,
"""Save the task to be executed in the next sync by inserting the commands into a queue."""
if executor_config and ("name" in executor_config or "command" in executor_config):
raise ValueError('Executor Config should never override "name" or "command"')
self.pending_tasks.append(EcsQueuedTask(key, command, queue, executor_config or {}, 1))
self.pending_tasks.append(
EcsQueuedTask(key, command, queue, executor_config or {}, 1, timezone.utcnow())
)

def end(self, heartbeat_interval=10):
"""Waits for all currently running tasks to end, and doesn't launch any tasks."""
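Taken together, the ecs_executor.py changes above do three things: execute_async stamps new tasks with next_attempt_time set to "now", attempt_task_runs skips any pending task whose next_attempt_time is still in the future, and a failed run_task re-queues the task with next_attempt_time pushed out by calculate_next_attempt_delay. A rough, self-contained sketch of that control flow (illustrative names, simplified; not the executor's actual code):

# Illustration (not part of the diff): simplified shape of the new retry/backoff flow.
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone


def utcnow() -> datetime:
    return datetime.now(timezone.utc)


@dataclass
class QueuedTask:  # toy stand-in for EcsQueuedTask
    key: str
    attempt_number: int = 1
    next_attempt_time: datetime = field(default_factory=utcnow)


def next_delay(attempt_number: int, max_delay: int = 120, base: int = 4) -> timedelta:
    return timedelta(seconds=min(base**attempt_number, max_delay))


def sync_once(pending: deque, run_task, max_attempts: int = 3) -> None:
    """One sync pass: run tasks that are due, re-queue failures with a future next_attempt_time."""
    for _ in range(len(pending)):
        task = pending.popleft()
        if utcnow() < task.next_attempt_time:
            pending.append(task)  # still backing off; revisit on a later sync
            continue
        try:
            run_task(task)
        except Exception:
            if task.attempt_number <= max_attempts:
                task.attempt_number += 1
                task.next_attempt_time = utcnow() + next_delay(task.attempt_number)
                pending.appendleft(task)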
2 changes: 2 additions & 0 deletions airflow/providers/amazon/aws/executors/ecs/utils.py
@@ -23,6 +23,7 @@

from __future__ import annotations

import datetime
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List
@@ -58,6 +59,7 @@ class EcsQueuedTask:
queue: str
executor_config: ExecutorConfigType
attempt_number: int
next_attempt_time: datetime.datetime


@dataclass
airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py
@@ -25,6 +25,21 @@
log = logging.getLogger(__name__)


def calculate_next_attempt_delay(
attempt_number: int,
max_delay: int = 60 * 2,
exponent_base: int = 4,
) -> timedelta:
"""
Calculate the exponential backoff (in seconds) until the next attempt.

:param attempt_number: Number of attempts since last success.
:param max_delay: Maximum delay in seconds between retries. Default 120.
:param exponent_base: Exponent base to calculate delay. Default 4.
"""
return timedelta(seconds=min((exponent_base**attempt_number), max_delay))


def exponential_backoff_retry(
last_attempt_time: datetime,
attempts_since_last_successful: int,
@@ -34,7 +49,7 @@ def exponential_backoff_retry(
exponent_base: int = 4,
) -> None:
"""
Retries a callable function with exponential backoff between attempts if it raises an exception.
Retry a callable function with exponential backoff between attempts if it raises an exception.

:param last_attempt_time: Timestamp of last attempt call.
:param attempts_since_last_successful: Number of attempts since last success.
@@ -47,14 +62,18 @@ def exponential_backoff_retry(
log.error("Max attempts reached. Exiting.")
return

delay = min((exponent_base**attempts_since_last_successful), max_delay)
next_retry_time = last_attempt_time + timedelta(seconds=delay)
next_retry_time = last_attempt_time + calculate_next_attempt_delay(
attempt_number=attempts_since_last_successful, max_delay=max_delay, exponent_base=exponent_base
)

current_time = timezone.utcnow()

if current_time >= next_retry_time:
try:
callable_function()
except Exception:
log.exception("Error calling %r", callable_function.__name__)
next_delay = min((exponent_base ** (attempts_since_last_successful + 1)), max_delay)
next_delay = calculate_next_attempt_delay(
attempts_since_last_successful + 1, max_delay, exponent_base
)
log.info("Waiting for %s seconds before retrying.", next_delay)
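The net effect inside exponential_backoff_retry is a refactor rather than a behaviour change: the inline min(exponent_base**attempts, max_delay) expressions are replaced with calls to the new helper. A quick check of that equivalence (the import path is the one this commit introduces):

# Illustration (not part of the diff): the removed inline expression and the new helper agree.
from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import (
    calculate_next_attempt_delay,
)

for attempts, max_delay, base in [(1, 120, 4), (3, 120, 4), (5, 120, 4), (2, 10, 3)]:
    old_delay = min(base**attempts, max_delay)  # expression formerly inlined
    new_delay = calculate_next_attempt_delay(attempts, max_delay, base).seconds
    assert old_delay == new_delay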
17 changes: 13 additions & 4 deletions tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
@@ -25,6 +25,7 @@
from functools import partial
from typing import Callable
from unittest import mock
from unittest.mock import MagicMock

import pytest
import yaml
@@ -33,7 +34,7 @@

from airflow.exceptions import AirflowException
from airflow.executors.base_executor import BaseExecutor
from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
from airflow.providers.amazon.aws.executors.ecs import ecs_executor, ecs_executor_config
from airflow.providers.amazon.aws.executors.ecs.boto_schema import BotoTaskSchema
from airflow.providers.amazon.aws.executors.ecs.ecs_executor import (
CONFIG_GROUP_NAME,
@@ -50,6 +51,7 @@
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.utils.helpers import convert_camel_to_snake
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.timezone import utcnow

pytestmark = pytest.mark.db_test

@@ -365,7 +367,8 @@ def test_execute(self, mock_airflow_key, mock_executor):
assert 1 == len(mock_executor.active_workers)
assert ARN1 in mock_executor.active_workers.task_by_key(airflow_key).task_arn

def test_success_execute_api_exception(self, mock_executor):
@mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
def test_success_execute_api_exception(self, mock_backoff, mock_executor):
"""Test what happens when ECS throws an exception, but ultimately runs the task."""
run_task_exception = Exception("Test exception")
run_task_success = {
@@ -381,9 +384,10 @@ def test_success_execute_api_exception(self, mock_executor):
}
mock_executor.ecs.run_task.side_effect = [run_task_exception, run_task_exception, run_task_success]
mock_executor.execute_async(mock_airflow_key, mock_cmd)
expected_retry_count = 2

# Fail 2 times
for _ in range(2):
for _ in range(expected_retry_count):
mock_executor.attempt_task_runs()
# Task is not stored in active workers.
assert len(mock_executor.active_workers) == 0
@@ -392,6 +396,9 @@
mock_executor.attempt_task_runs()
assert len(mock_executor.pending_tasks) == 0
assert ARN1 in mock_executor.active_workers.get_all_arns()
assert mock_backoff.call_count == expected_retry_count
for attempt_number in range(1, expected_retry_count):
mock_backoff.assert_has_calls([mock.call(attempt_number)])

def test_failed_execute_api_exception(self, mock_executor):
"""Test what happens when ECS refuses to execute a task and throws an exception"""
@@ -479,7 +486,8 @@ def test_removed_sync(self, fail_mock, success_mock, mock_executor):

@mock.patch.object(BaseExecutor, "fail")
@mock.patch.object(BaseExecutor, "success")
def test_failed_sync_cumulative_fail(self, success_mock, fail_mock, mock_airflow_key, mock_executor):
@mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
def test_failed_sync_cumulative_fail(self, _, success_mock, fail_mock, mock_airflow_key, mock_executor):
"""Test that failure_count/attempt_number is cumulative for pending tasks and active workers."""
AwsEcsExecutor.MAX_RUN_TASK_ATTEMPTS = "5"
mock_executor.ecs.run_task.return_value = {
@@ -488,6 +496,7 @@ def test_failed_sync_cumulative_fail(self, success_mock, fail_mock, mock_airflow
{"arn": ARN1, "reason": "Sample Failure", "detail": "UnitTest Failure - Please ignore"}
],
}
mock_executor._calculate_next_attempt_time = MagicMock(return_value=utcnow())
task_key = mock_airflow_key()
mock_executor.execute_async(task_key, mock_cmd)
for _ in range(2):
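The test changes above handle the new delay by patching calculate_next_attempt_delay where the executor module looks it up, so re-queued tasks are immediately eligible on the next attempt_task_runs() call. The pattern, roughly:

# Illustration (not part of the diff): neutralise the backoff so retries run immediately in tests.
import datetime as dt
from unittest import mock

from airflow.providers.amazon.aws.executors.ecs import ecs_executor

with mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0)):
    ...  # drive the executor under test; failed tasks are retried without waiting out the backoff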
tests/providers/amazon/aws/executors/utils/test_exponential_backoff_retry.py
@@ -21,7 +21,10 @@

import pytest

from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import exponential_backoff_retry
from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import (
calculate_next_attempt_delay,
exponential_backoff_retry,
)


class TestExponentialBackoffRetry:
@@ -279,3 +282,18 @@ def test_exponential_backoff_retry_exponent_base_parameterized(
exponent_base=3,
)
assert mock_callable_function.call_count == expected_calls

def test_calculate_next_attempt_delay(self):
exponent_base: int = 4
num_loops: int = 3
# Setting max_delay this way means that three loops will run to test: one returning a value
# under max_delay, one equal to max_delay, and one over.
max_delay: int = exponent_base**num_loops - 1

for attempt_number in range(1, num_loops):
returned_delay = calculate_next_attempt_delay(attempt_number, max_delay, exponent_base).seconds

if (expected_delay := exponent_base**attempt_number) <= max_delay:
assert returned_delay == expected_delay
else:
assert returned_delay == max_delay
