Skip to content

Commit

Permalink
ECS Executor Health Check (#35412)
Browse files Browse the repository at this point in the history
During startup the Scheduler calls start() on the configured Executor.
Attempt an API call to ECS via the Boto client in this method to test the health of the ECS Executor.
This will test most of the machinery of the executor (credentials, permissions, configuration, etc).
If the check fails and the executor is unhealthy, don't allow the scheduler to continue starting up:
fail hard and clearly tell the user what the issue is.


---------

Co-authored-by: ferruzzi <ferruzzi@amazon.com>
Co-authored-by: Vincent <97131062+vincbeck@users.noreply.github.com>
  • Loading branch information
3 people committed Nov 3, 2023
1 parent 92d1e8c commit ae9a7b8
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 5 deletions.
57 changes: 57 additions & 0 deletions airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
Expand Up @@ -28,7 +28,10 @@
from copy import deepcopy
from typing import TYPE_CHECKING

from botocore.exceptions import ClientError

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.executors.base_executor import BaseExecutor
from airflow.providers.amazon.aws.executors.ecs.boto_schema import BotoDescribeTasksSchema, BotoRunTaskSchema
from airflow.providers.amazon.aws.executors.ecs.utils import (
Expand Down Expand Up @@ -99,6 +102,60 @@ def __init__(self, *args, **kwargs):
self.ecs = EcsHook(aws_conn_id=aws_conn_id, region_name=region_name).conn
self.run_task_kwargs = self._load_run_kwargs()

def start(self):
    """
    Make a test API call to check the health of the ECS Executor.

    The check is skipped when the ``check_health_on_startup`` option is disabled.

    Deliberately use an invalid task ID, some potential outcomes in order:
      1. "AccessDeniedException" is raised if there are insufficient permissions.
      2. "ClusterNotFoundException" is raised if permissions exist but the cluster does not.
      3. The API responds with a failure message if the cluster is found and there
         are permissions, but the cluster itself has issues.
      4. "InvalidParameterException" is raised if the permissions and cluster exist but the task does not.
    The last one is considered a success state for the purposes of this check.

    Any exception other than a ``ClientError`` raised by the test call is also
    treated as a failed health check rather than being allowed to escape after a
    misleading "succeeded" log line.

    :raises AirflowException: if the health check fails for any reason.
    """
    # NOTE(review): the fallback here is False while the declared default for this
    # option is "True" -- confirm which value is intended when the option is unset.
    check_health = conf.getboolean(
        CONFIG_GROUP_NAME, AllEcsConfigKeys.CHECK_HEALTH_ON_STARTUP, fallback=False
    )

    if not check_health:
        return

    self.log.info("Starting ECS Executor and determining health...")

    msg_prefix = "ECS Executor health check has %s"
    # ``failure`` stays None only when the check demonstrably succeeded; every other
    # path must set it explicitly so success is never reported by default.
    failure = None
    cause = None

    try:
        invalid_task_id = "a" * 32
        self.ecs.stop_task(cluster=self.cluster, task=invalid_task_id)

        # If it got this far, something is wrong. stop_task() called with an
        # invalid taskID should have thrown a ClientError. All known reasons are
        # covered in the ``except`` blocks below, and this should never be reached.
        failure = "failed for an unknown reason. "
    except ClientError as ex:
        error = ex.response.get("Error", {})
        error_code = error.get("Code", "")
        error_message = error.get("Message", str(ex))

        if ("InvalidParameterException" in error_code) and ("task was not found" in error_message):
            # This failure is expected, and means we're healthy
            pass
        else:
            # Catch all for unexpected API failures
            failure = f"failed because: {error_message}. "
            cause = ex
    except Exception as ex:
        # Anything that is not a ClientError (the call failed before or outside of
        # an API error response) still means the executor cannot be used.
        failure = f"failed because: {ex}. "
        cause = ex

    if failure is None:
        self.log.info(msg_prefix, "succeeded.")
        return

    msg_error_suffix = (
        "The ECS executor will not be able to run Airflow tasks until the issue is addressed. "
        "Stopping the Airflow Scheduler from starting until the issue is resolved."
    )
    raise AirflowException(msg_prefix % failure + msg_error_suffix) from cause

def sync(self):
try:
self.sync_running_tasks()
Expand Down
2 changes: 2 additions & 0 deletions airflow/providers/amazon/aws/executors/ecs/utils.py
Expand Up @@ -46,6 +46,7 @@
"assign_public_ip": "False",
"launch_type": "FARGATE",
"platform_version": "LATEST",
"check_health_on_startup": "True",
}


Expand Down Expand Up @@ -96,6 +97,7 @@ class AllEcsConfigKeys(RunTaskKwargsConfigKeys):
AWS_CONN_ID = "conn_id"
RUN_TASK_KWARGS = "run_task_kwargs"
REGION_NAME = "region_name"
CHECK_HEALTH_ON_STARTUP = "check_health_on_startup"


class EcsExecutorException(Exception):
Expand Down
7 changes: 7 additions & 0 deletions airflow/providers/amazon/provider.yaml
Expand Up @@ -887,3 +887,10 @@ config:
type: string
example: '{"tags": {"key": "schema", "value": "1.0"}}'
default: ~
check_health_on_startup:
description: |
Whether or not to check the ECS Executor health on startup.
version_added: "8.11"
type: boolean
example: "True"
default: "True"
Expand Up @@ -106,6 +106,8 @@ Optional config options:
- MAX_RUN_TASK_ATTEMPTS - The maximum number of times the Ecs Executor
should attempt to run a task. This refers to instances where the task
fails to start (i.e. ECS API failures, container failures etc.)
- CHECK_HEALTH_ON_STARTUP - Whether or not to check the ECS Executor
health on startup

For a more detailed description of available options, including type
hints and examples, see the ``config_templates`` folder in the Amazon
Expand Down
90 changes: 85 additions & 5 deletions tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
Expand Up @@ -18,6 +18,7 @@

import datetime as dt
import json
import logging
import os
import re
import time
Expand All @@ -27,8 +28,10 @@

import pytest
import yaml
from botocore.exceptions import ClientError
from inflection import camelize

from airflow.exceptions import AirflowException
from airflow.executors.base_executor import BaseExecutor
from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
from airflow.providers.amazon.aws.executors.ecs.boto_schema import BotoTaskSchema
Expand Down Expand Up @@ -98,8 +101,7 @@ def mock_config():


@pytest.fixture
def mock_executor() -> AwsEcsExecutor:
"""Mock ECS to a repeatable starting state.."""
def set_env_vars():
os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.REGION_NAME}".upper()] = "us-west-1"
os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CLUSTER}".upper()] = "some-cluster"
os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CONTAINER_NAME}".upper()] = "container-name"
Expand All @@ -110,6 +112,11 @@ def mock_executor() -> AwsEcsExecutor:
os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.SECURITY_GROUPS}".upper()] = "sg1,sg2"
os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.SUBNETS}".upper()] = "sub1,sub2"
os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS}".upper()] = "3"


@pytest.fixture
def mock_executor(set_env_vars) -> AwsEcsExecutor:
"""Mock ECS to a repeatable starting state.."""
executor = AwsEcsExecutor()

# Replace boto3 ECS client with mock.
Expand Down Expand Up @@ -788,9 +795,13 @@ def test_config_defaults_are_applied(self, assign_subnets):
found_keys = {convert_camel_to_snake(key): key for key in task_kwargs.keys()}

for expected_key, expected_value in CONFIG_DEFAULTS.items():
# "conn_id" and max_run_task_attempts are used by the executor, but are not expected to appear
# in the task_kwargs.
if expected_key in [AllEcsConfigKeys.AWS_CONN_ID, AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS]:
# conn_id, max_run_task_attempts, and check_health_on_startup are used by the executor,
# but are not expected to appear in the task_kwargs.
if expected_key in [
AllEcsConfigKeys.AWS_CONN_ID,
AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS,
AllEcsConfigKeys.CHECK_HEALTH_ON_STARTUP,
]:
assert expected_key not in found_keys.keys()
else:
assert expected_key in found_keys.keys()
Expand Down Expand Up @@ -919,3 +930,72 @@ def test_that_provided_kwargs_are_moved_to_correct_nesting(self, assign_subnets)
assert run_task_kwargs_network_config[camelized_key] == "ENABLED"
else:
assert run_task_kwargs_network_config[camelized_key] == value.split(",")

def test_start_failure_with_invalid_permissions(self, set_env_vars):
    """Check that start() raises AirflowException carrying the API message on AccessDeniedException."""
    executor = AwsEcsExecutor()

    # Replace boto3 ECS client with mock.
    ecs_mock = mock.Mock(spec=executor.ecs)
    mock_resp = {
        "Error": {
            "Code": "AccessDeniedException",
            "Message": "no identity-based policy allows the ecs:StopTask action",
        }
    }
    ecs_mock.stop_task.side_effect = ClientError(mock_resp, "StopTask")

    executor.ecs = ecs_mock

    # ``match`` is interpreted as a regular expression, so escape the message to
    # compare it literally.
    with pytest.raises(AirflowException, match=re.escape(mock_resp["Error"]["Message"])):
        executor.start()

def test_start_failure_with_invalid_cluster_name(self, set_env_vars):
    """Check that start() raises AirflowException carrying the API message on ClusterNotFoundException."""
    executor = AwsEcsExecutor()

    # Replace boto3 ECS client with mock.
    ecs_mock = mock.Mock(spec=executor.ecs)
    mock_resp = {"Error": {"Code": "ClusterNotFoundException", "Message": "Cluster not found."}}
    ecs_mock.stop_task.side_effect = ClientError(mock_resp, "StopTask")

    executor.ecs = ecs_mock

    # ``match`` is interpreted as a regular expression and the message ends with
    # ".", which would otherwise match any character; escape it to compare literally.
    with pytest.raises(AirflowException, match=re.escape(mock_resp["Error"]["Message"])):
        executor.start()

def test_start_success(self, set_env_vars, caplog):
    """Check that a "task was not found" InvalidParameterException is logged as a successful check."""
    expected_error = {
        "Error": {"Code": "InvalidParameterException", "Message": "The referenced task was not found."}
    }

    ecs_executor = AwsEcsExecutor()

    # Swap the real boto3 ECS client for a mock whose stop_task raises the expected error.
    fake_client = mock.Mock(spec=ecs_executor.ecs)
    fake_client.stop_task.side_effect = ClientError(expected_error, "StopTask")
    ecs_executor.ecs = fake_client

    # Capture everything from DEBUG upwards so the info-level success message is recorded.
    caplog.set_level(logging.DEBUG)
    ecs_executor.start()

    assert "succeeded" in caplog.text

def test_start_health_check_config(self, set_env_vars):
    """Check that start() makes no ECS API call when the startup health check is disabled."""
    executor = AwsEcsExecutor()

    # Replace boto3 ECS client with mock.
    ecs_mock = mock.Mock(spec=executor.ecs)
    mock_resp = {
        "Error": {"Code": "InvalidParameterException", "Message": "The referenced task was not found."}
    }
    ecs_mock.stop_task.side_effect = ClientError(mock_resp, "StopTask")

    executor.ecs = ecs_mock

    env_key = f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CHECK_HEALTH_ON_STARTUP}".upper()

    # patch.dict restores os.environ on exit, so the disabled flag cannot leak into
    # other tests and silently turn their health checks off.
    with mock.patch.dict(os.environ, {env_key: "False"}):
        executor.start()

    ecs_mock.stop_task.assert_not_called()

0 comments on commit ae9a7b8

Please sign in to comment.