Skip to content

Commit 230da3e

Browse files
authored
EcsRunTaskOperator fails when no containers are provided in the response (#51692)
* Added logic to safely access the container name only if its required and poll multiple times in case the task is not yet active * Adujusted the test case to better reflect actual ecs response
1 parent a1c9180 commit 230da3e

File tree

2 files changed

+84
-3
lines changed
  • providers/amazon
    • src/airflow/providers/amazon/aws/operators
    • tests/unit/amazon/aws/operators

2 files changed

+84
-3
lines changed

providers/amazon/src/airflow/providers/amazon/aws/operators/ecs.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from collections.abc import Sequence
2222
from datetime import timedelta
2323
from functools import cached_property
24+
from time import sleep
2425
from typing import TYPE_CHECKING, Any
2526

2627
from airflow.configuration import conf
@@ -629,10 +630,22 @@ def _start_task(self):
629630
self.log.info("ECS Task started: %s", response)
630631

631632
self.arn = response["tasks"][0]["taskArn"]
632-
if not self.container_name:
633-
self.container_name = response["tasks"][0]["containers"][0]["name"]
634633
self.log.info("ECS task ID is: %s", self._get_ecs_task_id(self.arn))
635634

635+
if not self.container_name and (self.awslogs_group and self.awslogs_stream_prefix):
636+
backoff_schedule = [10, 30]
637+
for delay in backoff_schedule:
638+
sleep(delay)
639+
response = self.client.describe_tasks(cluster=self.cluster, tasks=[self.arn])
640+
containers = response["tasks"][0].get("containers", [])
641+
if containers:
642+
self.container_name = containers[0]["name"]
643+
if self.container_name:
644+
break
645+
646+
if not self.container_name:
647+
self.log.info("Could not find container name, required for the log stream after 2 tries")
648+
636649
def _try_reattach_task(self, started_by: str):
637650
if not started_by:
638651
raise AirflowException("`started_by` should not be empty or None")
@@ -666,7 +679,13 @@ def _aws_logs_enabled(self):
666679
return self.awslogs_group and self.awslogs_stream_prefix
667680

668681
def _get_logs_stream_name(self) -> str:
669-
if (
682+
if not self.container_name and self.awslogs_stream_prefix and "/" not in self.awslogs_stream_prefix:
683+
self.log.warning(
684+
"Container name could not be inferred and awslogs_stream_prefix '%s' does not contain '/'. "
685+
"This may cause issues when extracting logs from Cloudwatch.",
686+
self.awslogs_stream_prefix,
687+
)
688+
elif (
670689
self.awslogs_stream_prefix
671690
and self.container_name
672691
and not self.awslogs_stream_prefix.endswith(f"/{self.container_name}")

providers/amazon/tests/unit/amazon/aws/operators/test_ecs.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,20 @@
8181
}
8282
],
8383
}
84+
RESPONSE_WITHOUT_NAME = {
85+
"failures": [],
86+
"tasks": [
87+
{
88+
"containers": [],
89+
"desiredStatus": "RUNNING",
90+
"lastStatus": "PENDING",
91+
"taskArn": f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
92+
"taskDefinitionArn": "arn:aws:ecs:us-east-1:012345678910:task-definition/hello_world:11",
93+
}
94+
],
95+
}
96+
97+
8498
WAITERS_TEST_CASES = [
8599
pytest.param(None, None, id="default-values"),
86100
pytest.param(3.14, None, id="set-delay-only"),
@@ -788,6 +802,54 @@ def test_container_name_in_log_stream(self, client_mock, log_fetcher_mock):
788802

789803
assert self.ecs._get_logs_stream_name().startswith(f"{prefix}/{container_name}/")
790804

805+
@mock.patch.object(EcsBaseOperator, "client")
806+
@mock.patch("airflow.providers.amazon.aws.operators.ecs.sleep", return_value=None)
807+
def test_container_name_not_set(self, sleep_mock, client_mock):
808+
self.set_up_operator(
809+
awslogs_group="awslogs-group",
810+
awslogs_stream_prefix="prefix",
811+
container_name=None,
812+
)
813+
client_mock.run_task.return_value = RESPONSE_WITHOUT_NAME
814+
client_mock.describe_tasks.side_effect = [
815+
{"tasks": [{"containers": []}]},
816+
{"tasks": [{"containers": [{"name": "resolved-container"}]}]},
817+
]
818+
self.ecs._start_task()
819+
assert client_mock.describe_tasks.call_count == 2
820+
assert self.ecs.container_name == "resolved-container"
821+
822+
@mock.patch.object(EcsBaseOperator, "client")
823+
@mock.patch.object(EcsBaseOperator, "log")
824+
@mock.patch("airflow.providers.amazon.aws.operators.ecs.sleep", return_value=None)
825+
def test_container_name_resolution_fails_logs_message(self, sleep_mock, log_mock, client_mock):
826+
self.set_up_operator(
827+
awslogs_group="test-group",
828+
awslogs_stream_prefix="prefix",
829+
container_name=None,
830+
)
831+
client_mock.run_task.return_value = RESPONSE_WITHOUT_NAME
832+
client_mock.describe_tasks.return_value = {"tasks": [{"containers": [{"name": None}]}]}
833+
834+
self.ecs._start_task()
835+
836+
assert client_mock.describe_tasks.call_count == 2
837+
assert self.ecs.container_name is None
838+
log_mock.info.assert_called_with(
839+
"Could not find container name, required for the log stream after 2 tries"
840+
)
841+
842+
@mock.patch.object(EcsBaseOperator, "client")
843+
def test_container_name_not_polled(self, client_mock):
844+
self.set_up_operator(
845+
awslogs_group=None,
846+
awslogs_stream_prefix=None,
847+
container_name=None,
848+
)
849+
client_mock.run_task.return_value = RESPONSE_WITHOUT_NAME
850+
self.ecs._start_task()
851+
assert client_mock.describe_tasks.call_count == 0
852+
791853

792854
class TestEcsCreateClusterOperator(EcsBaseTestCase):
793855
@pytest.mark.parametrize("waiter_delay, waiter_max_attempts", WAITERS_TEST_CASES)

0 commit comments

Comments
 (0)