Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,28 @@
if TYPE_CHECKING:
from botocore.waiter import Waiter

# Error codes that signal throttling or a transient service-side failure.
# A waiter poll that fails with one of these codes is retried instead of
# being treated as a hard failure.
# https://docs.aws.amazon.com/general/latest/gr/api-retries.html
# and https://github.com/boto/botocore/blob/develop/botocore/retryhandler.py
RETRIABLE_ERROR_CODES = {
    # Throttling / rate limiting
    "ThrottlingException",
    "Throttling",
    "RequestLimitExceeded",
    "ProvisionedThroughputExceededException",
    "LimitExceededException",
    "RequestThrottled",
    "RequestThrottledException",
    "TooManyRequestsException",
    # Transient server-side or gateway failures
    "ServerException",
    "InternalServerError",
    "InternalFailure",
    "ServiceUnavailable",
    "BadGateway",
    "GatewayTimeout",
    # Request timeouts
    "RequestTimeout",
    "RequestTimeoutException",
}


def wait(
waiter: Waiter,
Expand Down Expand Up @@ -65,9 +87,12 @@ def wait(
status_args = ["Clusters[0].state", "Clusters[0].details"]
"""
log = logging.getLogger(__name__)
for attempt in range(waiter_max_attempts):
if attempt:
first_attempt = True
attempt = 0
while attempt < waiter_max_attempts:
if not first_attempt:
time.sleep(waiter_delay)
first_attempt = False
try:
waiter.wait(**args, WaiterConfig={"MaxAttempts": 1})

Expand All @@ -87,11 +112,23 @@ def wait(
and isinstance(last_response.get("Error"), dict)
and "Code" in last_response.get("Error")
):
raise AirflowException(f"{failure_message}: {error}")
error_code = last_response["Error"]["Code"]
if error_code not in RETRIABLE_ERROR_CODES:
raise AirflowException(f"{failure_message}: {error}")

log.info(
"Waiter encountered retriable error: %s. Retrying (attempt %d/%d)...",
error_code,
attempt + 1,
waiter_max_attempts,
)
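# Retriable errors do not consume an attempt: `continue` skips the `attempt += 1` below.
# NOTE(review): nothing visible here bounds these retries, so a persistently throttled
# call could loop forever; confirm a retry cap or deadline exists.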
continue

log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, last_response))
else:
break
attempt += 1
else:
raise AirflowException("Waiter error: max attempts reached")

Expand All @@ -104,7 +141,7 @@ async def async_wait(
failure_message: str,
status_message: str,
status_args: list[str],
):
) -> None:
"""
Use an async boto waiter to poll an AWS service for the specified state.

Expand All @@ -130,9 +167,12 @@ async def async_wait(
status_args = ["Clusters[0].state", "Clusters[0].details"]
"""
log = logging.getLogger(__name__)
for attempt in range(waiter_max_attempts):
if attempt:
first_attempt = True
attempt = 0
while attempt < waiter_max_attempts:
if not first_attempt:
await asyncio.sleep(waiter_delay)
first_attempt = False
try:
await waiter.wait(**args, WaiterConfig={"MaxAttempts": 1})

Expand All @@ -153,11 +193,23 @@ async def async_wait(
and isinstance(last_response.get("Error"), dict)
and "Code" in last_response.get("Error")
):
raise AirflowException(f"{failure_message}\n{last_response}\n{error}")
error_code = last_response["Error"]["Code"]
if error_code not in RETRIABLE_ERROR_CODES:
raise AirflowException(f"{failure_message}\n{last_response}\n{error}")

log.info(
"Waiter encountered retriable error: %s. Retrying (attempt %d/%d)...",
error_code,
attempt + 1,
waiter_max_attempts,
)
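# Retriable errors do not consume an attempt: `continue` skips the `attempt += 1` below.
# NOTE(review): nothing visible here bounds these retries, so a persistently throttled
# call could loop forever; confirm a retry cap or deadline exists.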
continue

log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, last_response))
else:
break
attempt += 1
else:
raise AirflowException("Waiter error: max attempts reached")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,3 +367,56 @@ def test_status_formatting_not_done_if_higher_log_level(self, status_format_mock
finally:
logger.setLevel(level)
status_format_mock.assert_not_called()

@mock.patch("time.sleep")
def test_wait_with_retriable_throttling_error(self, mock_sleep):
    """A throttled waiter poll is retried, with a delay before each retry, until it succeeds."""
    mock_waiter = mock.MagicMock()
    throttling_error = WaiterError(
        name="test_waiter",
        reason="An error occurred (ThrottlingException) when calling the GetJobRun operation: Rate exceeded",
        last_response={
            "Error": {
                "Message": "Rate exceeded",
                "Code": "ThrottlingException",
            }
        },
    )
    # Two throttled polls followed by a successful one.
    mock_waiter.wait.side_effect = [throttling_error, throttling_error, True]

    wait(
        waiter=mock_waiter,
        waiter_delay=123,
        waiter_max_attempts=10,
        args={"test_arg": "test_value"},
        failure_message="test failure message",
        status_message="test status message",
        status_args=["Status.State"],
    )

    assert mock_waiter.wait.call_count == 3
    # Every poll is a single-attempt waiter call carrying the caller's args.
    mock_waiter.wait.assert_called_with(test_arg="test_value", WaiterConfig={"MaxAttempts": 1})
    # No sleep precedes the first poll, so three polls mean exactly two delays.
    assert mock_sleep.call_args_list == [mock.call(123), mock.call(123)]

@pytest.mark.asyncio
async def test_async_wait_with_retriable_throttling_error(self):
    """A throttled async waiter poll is retried until it succeeds."""
    mock_waiter = mock.MagicMock()
    throttling_error = WaiterError(
        name="test_waiter",
        reason="An error occurred (ThrottlingException) when calling the GetJobRun operation: Rate exceeded",
        last_response={
            "Error": {
                "Message": "Rate exceeded",
                "Code": "ThrottlingException",
            }
        },
    )
    # Two throttled polls followed by a successful one.
    mock_waiter.wait = AsyncMock(side_effect=[throttling_error, throttling_error, True])

    await async_wait(
        waiter=mock_waiter,
        waiter_delay=0,
        waiter_max_attempts=10,
        args={"test_arg": "test_value"},
        failure_message="test failure message",
        status_message="test status message",
        status_args=["Status.State"],
    )

    assert mock_waiter.wait.call_count == 3
    # The coroutine must actually be awaited each time, not merely called.
    assert mock_waiter.wait.await_count == 3
    # Every poll is a single-attempt waiter call carrying the caller's args.
    mock_waiter.wait.assert_awaited_with(test_arg="test_value", WaiterConfig={"MaxAttempts": 1})
Loading