From ebe931b92df12063b2cd8fdb733ad1b3bee8b039 Mon Sep 17 00:00:00 2001 From: Subham Sangwan Date: Thu, 21 May 2026 08:59:11 +0530 Subject: [PATCH 1/2] Fix EMR Serverless task failure on transient AWS throttling errors --- .../amazon/aws/utils/waiter_with_logging.py | 36 +++++++++++-- .../aws/utils/test_waiter_with_logging.py | 53 +++++++++++++++++++ 2 files changed, 86 insertions(+), 3 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py b/providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py index 615523a704752..dcfa7d8107c06 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py @@ -30,6 +30,28 @@ if TYPE_CHECKING: from botocore.waiter import Waiter +# Standard throttling and transient error codes to retry on +# https://docs.aws.amazon.com/general/latest/gr/api-retries.html +# and https://github.com/boto/botocore/blob/develop/botocore/retryhandler.py +RETRIABLE_ERROR_CODES = { + "ThrottlingException", + "Throttling", + "RequestLimitExceeded", + "ProvisionedThroughputExceededException", + "LimitExceededException", + "RequestThrottled", + "RequestThrottledException", + "TooManyRequestsException", + "ServerException", + "InternalServerError", + "InternalFailure", + "ServiceUnavailable", + "BadGateway", + "GatewayTimeout", + "RequestTimeout", + "RequestTimeoutException", +} + def wait( waiter: Waiter, @@ -87,7 +109,11 @@ def wait( and isinstance(last_response.get("Error"), dict) and "Code" in last_response.get("Error") ): - raise AirflowException(f"{failure_message}: {error}") + error_code = last_response["Error"]["Code"] + if error_code not in RETRIABLE_ERROR_CODES: + raise AirflowException(f"{failure_message}: {error}") + + log.info("Waiter encountered retriable error: %s. Retrying...", error_code) log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, last_response)) else: @@ -104,7 +130,7 @@ async def async_wait( failure_message: str, status_message: str, status_args: list[str], -): +) -> None: """ Use an async boto waiter to poll an AWS service for the specified state. @@ -153,7 +179,11 @@ async def async_wait( and isinstance(last_response.get("Error"), dict) and "Code" in last_response.get("Error") ): - raise AirflowException(f"{failure_message}\n{last_response}\n{error}") + error_code = last_response["Error"]["Code"] + if error_code not in RETRIABLE_ERROR_CODES: + raise AirflowException(f"{failure_message}\n{last_response}\n{error}") + + log.info("Waiter encountered retriable error: %s. Retrying...", error_code) log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, last_response)) else: diff --git a/providers/amazon/tests/unit/amazon/aws/utils/test_waiter_with_logging.py b/providers/amazon/tests/unit/amazon/aws/utils/test_waiter_with_logging.py index 716af7fbdd1af..3414403578e6d 100644 --- a/providers/amazon/tests/unit/amazon/aws/utils/test_waiter_with_logging.py +++ b/providers/amazon/tests/unit/amazon/aws/utils/test_waiter_with_logging.py @@ -367,3 +367,56 @@ def test_status_formatting_not_done_if_higher_log_level(self, status_format_mock finally: logger.setLevel(level) status_format_mock.assert_not_called() + + @mock.patch("time.sleep") + def test_wait_with_retriable_throttling_error(self, mock_sleep): + mock_sleep.return_value = True + mock_waiter = mock.MagicMock() + throttling_error = WaiterError( + name="test_waiter", + reason="An error occurred (ThrottlingException) when calling the GetJobRun operation: Rate exceeded", + last_response={ + "Error": { + "Message": "Rate exceeded", + "Code": "ThrottlingException", + } + }, + ) + mock_waiter.wait.side_effect = [throttling_error, throttling_error, True] + wait( + waiter=mock_waiter, + waiter_delay=123, + waiter_max_attempts=10, + args={"test_arg": "test_value"}, + failure_message="test failure message", + status_message="test status message", + status_args=["Status.State"], + ) + assert mock_waiter.wait.call_count == 3 + mock_sleep.assert_called_with(123) + + @pytest.mark.asyncio + async def test_async_wait_with_retriable_throttling_error(self): + mock_waiter = mock.MagicMock() + throttling_error = WaiterError( + name="test_waiter", + reason="An error occurred (ThrottlingException) when calling the GetJobRun operation: Rate exceeded", + last_response={ + "Error": { + "Message": "Rate exceeded", + "Code": "ThrottlingException", + } + }, + ) + mock_waiter.wait = AsyncMock() + mock_waiter.wait.side_effect = [throttling_error, throttling_error, True] + await async_wait( + waiter=mock_waiter, + waiter_delay=0, + waiter_max_attempts=10, + args={"test_arg": "test_value"}, + failure_message="test failure message", + status_message="test status message", + status_args=["Status.State"], + ) + assert mock_waiter.wait.call_count == 3 From 5047a65acdeb60eb2ec65bd1fcf174fb21786543 Mon Sep 17 00:00:00 2001 From: Subham Sangwan Date: Thu, 21 May 2026 14:29:16 +0530 Subject: [PATCH 2/2] fix: Don't count throttling retries against waiter_max_attempts --- .../amazon/aws/utils/waiter_with_logging.py | 34 +++++++++++++++---- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py b/providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py index dcfa7d8107c06..60e20bbf2cddb 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py @@ -87,9 +87,12 @@ def wait( status_args = ["Clusters[0].state", "Clusters[0].details"] """ log = logging.getLogger(__name__) - for attempt in range(waiter_max_attempts): - if attempt: + first_attempt = True + attempt = 0 + while attempt < waiter_max_attempts: + if not first_attempt: time.sleep(waiter_delay) + first_attempt = False try: waiter.wait(**args, WaiterConfig={"MaxAttempts": 1}) @@ -113,11 +116,19 @@ def wait( if error_code not in RETRIABLE_ERROR_CODES: raise AirflowException(f"{failure_message}: {error}") - log.info("Waiter encountered retriable error: %s. Retrying...", error_code) + log.info( + "Waiter encountered retriable error: %s. Retrying (attempt %d/%d)...", + error_code, + attempt + 1, + waiter_max_attempts, + ) + # Don't increment attempt counter for retriable errors; continue looping + continue log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, last_response)) else: break + attempt += 1 else: raise AirflowException("Waiter error: max attempts reached") @@ -156,9 +167,12 @@ async def async_wait( status_args = ["Clusters[0].state", "Clusters[0].details"] """ log = logging.getLogger(__name__) - for attempt in range(waiter_max_attempts): - if attempt: + first_attempt = True + attempt = 0 + while attempt < waiter_max_attempts: + if not first_attempt: await asyncio.sleep(waiter_delay) + first_attempt = False try: await waiter.wait(**args, WaiterConfig={"MaxAttempts": 1}) @@ -183,11 +197,19 @@ async def async_wait( if error_code not in RETRIABLE_ERROR_CODES: raise AirflowException(f"{failure_message}\n{last_response}\n{error}") - log.info("Waiter encountered retriable error: %s. Retrying...", error_code) + log.info( + "Waiter encountered retriable error: %s. Retrying (attempt %d/%d)...", + error_code, + attempt + 1, + waiter_max_attempts, + ) + # Don't increment attempt counter for retriable errors; continue looping + continue log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, last_response)) else: break + attempt += 1 else: raise AirflowException("Waiter error: max attempts reached")