Skip to content

Commit

Permalink
add async wait method to the "with logging" aws utils (#32055)
Browse files Browse the repository at this point in the history
Also changed the status formatting in the logs so that it is skipped when the log level does not include INFO.
  • Loading branch information
vandonr-amz committed Jun 22, 2023
1 parent c48f744 commit 4797192
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 16 deletions.
90 changes: 76 additions & 14 deletions airflow/providers/amazon/aws/utils/waiter_with_logging.py
Expand Up @@ -17,8 +17,10 @@

from __future__ import annotations

import asyncio
import logging
import time
from typing import Any

import jmespath
from botocore.exceptions import WaiterError
Expand All @@ -31,10 +33,10 @@ def wait(
waiter: Waiter,
waiter_delay: int,
max_attempts: int,
args: dict,
args: dict[str, Any],
failure_message: str,
status_message: str,
status_args: list,
status_args: list[str],
) -> None:
"""
Use a boto waiter to poll an AWS service for the specified state. Although this function
Expand All @@ -47,7 +49,7 @@ def wait(
:param args: The arguments to pass to the waiter.
:param failure_message: The message to log if a failure state is reached.
:param status_message: The message logged when printing the status of the service.
:param status_args: A list containing the arguments to retrieve status information from
:param status_args: A list containing the JMESPath queries to retrieve status information from
the waiter response.
e.g.
response = {"Cluster": {"state": "CREATING"}}
Expand All @@ -68,23 +70,83 @@ def wait(
except WaiterError as error:
if "terminal failure" in str(error):
raise AirflowException(f"{failure_message}: {error}")
status_string = _format_status_string(status_args, error.last_response)
log.info("%s: %s", status_message, status_string)

log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, error.last_response))
if attempt >= max_attempts:
raise AirflowException("Waiter error: max attempts reached")

time.sleep(waiter_delay)


async def async_wait(
    waiter: Waiter,
    waiter_delay: int,
    max_attempts: int,
    args: dict[str, Any],
    failure_message: str,
    status_message: str,
    status_args: list[str],
) -> None:
    """
    Use an async boto waiter to poll an AWS service for the specified state.

    Although this function uses boto waiters to poll the state of the service, it logs the
    response of the service after every attempt, which is not currently supported by boto waiters.

    :param waiter: The boto waiter to use.
    :param waiter_delay: The amount of time in seconds to wait between attempts.
    :param max_attempts: The maximum number of attempts to be made.
    :param args: The arguments to pass to the waiter.
    :param failure_message: The message to log if a failure state is reached.
    :param status_message: The message logged when printing the status of the service.
    :param status_args: A list containing the JMESPath queries to retrieve status information from
        the waiter response.
        e.g.
        response = {"Cluster": {"state": "CREATING"}}
        status_args = ["Cluster.state"]

        response = {
            "Clusters": [{"state": "CREATING", "details": "User initiated."},]
        }
        status_args = ["Clusters[0].state", "Clusters[0].details"]
    :raises AirflowException: If the waiter reports a terminal failure, or if the target state
        is still not reached after ``max_attempts`` attempts.
    """
    log = logging.getLogger(__name__)
    attempt = 0
    while True:
        attempt += 1
        try:
            # The waiter is limited to a single attempt per call so that the retry loop,
            # the status logging and the delay between attempts are all driven from here.
            await waiter.wait(**args, WaiterConfig={"MaxAttempts": 1})
            break
        except WaiterError as error:
            if "terminal failure" in str(error):
                # Chain the original error so the waiter's traceback is kept as the cause.
                raise AirflowException(f"{failure_message}: {error}") from error

            # The formatter is passed as a lazy %-argument: the status string is only
            # built if the INFO record is actually emitted.
            log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, error.last_response))
            if attempt >= max_attempts:
                raise AirflowException("Waiter error: max attempts reached") from error

            await asyncio.sleep(waiter_delay)


class _LazyStatusFormatter:
    """
    A wrapper containing the info necessary to extract the status from a response.

    The status string is only computed when the instance is converted with ``str()``.
    Used to avoid computations if the logs are disabled at the given level.
    """

    def __init__(self, jmespath_queries: list[str], response: dict[str, Any]):
        # Only references are stored here; no JMESPath evaluation happens until __str__.
        self.jmespath_queries = jmespath_queries
        self.response = response

    def __str__(self) -> str:
        """
        Build the status string from the stored queries and response.

        Each JMESPath query is evaluated against the response; results that are ``None``
        or an empty string are skipped, and the remaining values are joined with " - ".
        """
        results = (jmespath.search(query, self.response) for query in self.jmespath_queries)
        return " - ".join(str(value) for value in results if value is not None and value != "")


def _format_status_string(args, response):
    """
    Eagerly build the status string for the given JMESPath queries and waiter response.

    Kept as a thin helper for code that needs the string immediately; it delegates to
    ``_LazyStatusFormatter`` so both paths produce the same output.

    :param args: A list of JMESPath queries to evaluate against the response.
    :param response: The waiter response to extract the values from.
    :return: The non-empty query results, joined with " - ".
    """
    return str(_LazyStatusFormatter(args, response))
59 changes: 57 additions & 2 deletions tests/providers/amazon/aws/utils/test_waiter_with_logging.py
Expand Up @@ -20,12 +20,13 @@
import logging
from typing import Any
from unittest import mock
from unittest.mock import AsyncMock

import pytest
from botocore.exceptions import WaiterError

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
from airflow.providers.amazon.aws.utils.waiter_with_logging import _LazyStatusFormatter, async_wait, wait


def generate_response(state: str) -> dict[str, Any]:
Expand Down Expand Up @@ -63,7 +64,7 @@ def test_wait(self, mock_sleep, caplog):
"MaxAttempts": 1,
},
)
mock_waiter.wait.call_count == 3
assert mock_waiter.wait.call_count == 3
mock_sleep.assert_called_with(123)
assert (
caplog.record_tuples
Expand All @@ -77,6 +78,36 @@ def test_wait(self, mock_sleep, caplog):
* 2
)

@pytest.mark.asyncio
async def test_async_wait(self, caplog):
    """async_wait keeps polling until the waiter succeeds, logging the status after each failure."""
    pending_error = WaiterError(
        name="test_waiter",
        reason="test_reason",
        last_response=generate_response("Pending"),
    )
    # Two failed polls reporting "Pending", then a successful one.
    async_waiter = mock.MagicMock()
    async_waiter.wait = AsyncMock(side_effect=[pending_error, pending_error, True])

    await async_wait(
        waiter=async_waiter,
        waiter_delay=0,
        max_attempts=456,
        args={"test_arg": "test_value"},
        failure_message="test failure message",
        status_message="test status message",
        status_args=["Status.State"],
    )

    async_waiter.wait.assert_called_with(
        test_arg="test_value",
        WaiterConfig={"MaxAttempts": 1},
    )
    assert async_waiter.wait.call_count == 3
    # One status line per failed attempt.
    assert caplog.messages == ["test status message: Pending"] * 2

@mock.patch("time.sleep")
def test_wait_max_attempts_exceeded(self, mock_sleep, caplog):
mock_sleep.return_value = True
Expand Down Expand Up @@ -302,3 +333,27 @@ def test_wait_with_multiple_args(self, mock_sleep, caplog):
]
* 2
)

@mock.patch.object(_LazyStatusFormatter, "__str__")
def test_status_formatting_not_done_if_higher_log_level(self, status_format_mock: mock.MagicMock, caplog):
    """When the log level is above INFO, the status is neither logged nor formatted."""
    pending_error = WaiterError(
        name="test_waiter",
        reason="test_reason",
        last_response=generate_response("Pending"),
    )
    # Two failed polls, then success: two status lines would be logged at INFO level.
    sync_waiter = mock.MagicMock()
    sync_waiter.wait.side_effect = [pending_error, pending_error, True]

    with caplog.at_level(level=logging.WARNING):
        wait(
            waiter=sync_waiter,
            waiter_delay=0,
            max_attempts=456,
            args={"test_arg": "test_value"},
            failure_message="test failure message",
            status_message="test status message",
            status_args=["Status.State"],
        )

    assert caplog.messages == []
    # The patched __str__ was never invoked, so no formatting work was done.
    status_format_mock.assert_not_called()

0 comments on commit 4797192

Please sign in to comment.