Skip to content

Commit

Permalink
Do not return success from AWS ECS trigger after max_attempts (#32589)
Browse files Browse the repository at this point in the history
  • Loading branch information
vandonr-amz committed Jul 18, 2023
1 parent 62044f1 commit 7ed791d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
7 changes: 4 additions & 3 deletions airflow/providers/amazon/aws/triggers/ecs.py
Expand Up @@ -22,6 +22,7 @@

from botocore.exceptions import ClientError, WaiterError

from airflow import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
Expand Down Expand Up @@ -170,7 +171,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
await waiter.wait(
cluster=self.cluster, tasks=[self.task_arn], WaiterConfig={"MaxAttempts": 1}
)
break # we reach this point only if the waiter met a success criteria
# we reach this point only if the waiter met a success criteria
yield TriggerEvent({"status": "success", "task_arn": self.task_arn})
except WaiterError as error:
if "terminal failure" in str(error):
raise
Expand All @@ -179,8 +181,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
finally:
if self.log_group and self.log_stream:
logs_token = await self._forward_logs(logs_client, logs_token)

yield TriggerEvent({"status": "success", "task_arn": self.task_arn})
raise AirflowException("Waiter error: max attempts reached")

async def _forward_logs(self, logs_client, next_token: str | None = None) -> str | None:
"""
Expand Down
21 changes: 21 additions & 0 deletions tests/providers/amazon/aws/triggers/test_ecs.py
Expand Up @@ -22,6 +22,7 @@
import pytest
from botocore.exceptions import WaiterError

from airflow import AirflowException
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.triggers.ecs import (
Expand Down Expand Up @@ -56,6 +57,26 @@ async def test_run_until_error(self, _, client_mock):

assert wait_mock.call_count == 3

@pytest.mark.asyncio
@mock.patch.object(EcsHook, "async_conn")
# this mock is only necessary to avoid a "No module named 'aiobotocore'" error in the LatestBoto CI step
@mock.patch.object(AwsLogsHook, "async_conn")
async def test_run_until_timeout(self, _, client_mock):
a_mock = mock.MagicMock()
client_mock.__aenter__.return_value = a_mock
wait_mock = AsyncMock()
wait_mock.side_effect = WaiterError("name", "reason", {"tasks": [{"lastStatus": "my_status"}]})
a_mock.get_waiter().wait = wait_mock

trigger = TaskDoneTrigger("cluster", "task_arn", 0, 10, None, None)

with pytest.raises(AirflowException) as err:
generator = trigger.run()
await generator.asend(None)

assert wait_mock.call_count == 10
assert "max attempts" in str(err.value)

@pytest.mark.asyncio
@mock.patch.object(EcsHook, "async_conn")
# this mock is only necessary to avoid a "No module named 'aiobotocore'" error in the LatestBoto CI step
Expand Down

0 comments on commit 7ed791d

Please sign in to comment.