Skip to content

Commit

Permalink
AIP-61: Allow multiple executors to load their CLI commands (#39077)
Browse files Browse the repository at this point in the history
* Load executors inside of a Try/except block to prevent failures if an Executor fails to load
Add unit test to test failure to load Executor case
  • Loading branch information
syedahsn committed Apr 22, 2024
1 parent acc75c9 commit c8f34f5
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 15 deletions.
31 changes: 16 additions & 15 deletions airflow/cli/cli_parser.py
Expand Up @@ -56,19 +56,21 @@
airflow_commands = core_commands.copy() # make a copy to prevent bad interactions in tests

log = logging.getLogger(__name__)
try:
executor, _ = ExecutorLoader.import_default_executor_cls(validate=False)
airflow_commands.extend(executor.get_cli_commands())
except Exception:
executor_name = ExecutorLoader.get_default_executor_name()
log.exception("Failed to load CLI commands from executor: %s", executor_name)
log.error(
"Ensure all dependencies are met and try again. If using a Celery based executor install "
"a 3.3.0+ version of the Celery provider. If using a Kubernetes executor, install a "
"7.4.0+ version of the CNCF provider"
)
# Do not re-raise the exception since we want the CLI to still function for
# other commands.


for executor_name in ExecutorLoader.get_executor_names():
try:
executor, _ = ExecutorLoader.import_executor_cls(executor_name)
airflow_commands.extend(executor.get_cli_commands())
except Exception:
log.exception("Failed to load CLI commands from executor: %s", executor_name)
log.error(
"Ensure all dependencies are met and try again. If using a Celery based executor install "
"a 3.3.0+ version of the Celery provider. If using a Kubernetes executor, install a "
"7.4.0+ version of the CNCF provider"
)
# Do not re-raise the exception since we want the CLI to still function for
# other commands.

try:
auth_mgr = get_auth_manager_cls()
Expand All @@ -90,8 +92,7 @@
dup = {k for k, v in Counter([c.name for c in airflow_commands]).items() if v > 1}
raise CliConflictError(
f"The following CLI {len(dup)} command(s) are defined more than once: {sorted(dup)}\n"
f"This can be due to the executor '{ExecutorLoader.get_default_executor_name()}' "
f"redefining core airflow CLI commands."
f"This can be due to an Executor or Auth Manager redefining core airflow CLI commands."
)


Expand Down
8 changes: 8 additions & 0 deletions airflow/executors/executor_loader.py
Expand Up @@ -153,6 +153,14 @@ def _get_executor_names(cls) -> list[ExecutorName]:

return executor_names

@classmethod
def get_executor_names(cls) -> list[ExecutorName]:
"""Return the executor names from Airflow configuration.
:return: List of executor names from Airflow configuration
"""
return cls._get_executor_names()

@classmethod
def get_default_executor_name(cls) -> ExecutorName:
"""Return the default executor name from Airflow configuration.
Expand Down
106 changes: 106 additions & 0 deletions tests/cli/test_cli_parser.py
Expand Up @@ -29,6 +29,7 @@
from importlib import reload
from io import StringIO
from pathlib import Path
from unittest import mock
from unittest.mock import MagicMock, patch

import pytest
Expand All @@ -38,7 +39,10 @@
from airflow.cli.utils import CliConflictError
from airflow.configuration import AIRFLOW_HOME
from airflow.executors import executor_loader
from airflow.executors.executor_utils import ExecutorName
from airflow.executors.local_executor import LocalExecutor
from airflow.providers.amazon.aws.executors.ecs.ecs_executor import AwsEcsExecutor
from airflow.providers.celery.executors.celery_executor import CeleryExecutor
from tests.test_utils.config import conf_vars

# Can not be `--snake_case` or contain uppercase letter
Expand Down Expand Up @@ -160,6 +164,108 @@ def test_dynamic_conflict_detection(self, cli_commands_mock: MagicMock):
# force re-evaluation of cli commands (done in top level code)
reload(cli_parser)

@patch.object(CeleryExecutor, "get_cli_commands")
@patch.object(AwsEcsExecutor, "get_cli_commands")
def test_hybrid_executor_get_cli_commands(
self, ecs_executor_cli_commands_mock, celery_executor_cli_commands_mock
):
"""Test that if multiple executors are configured, then every executor loads its commands."""
ecs_executor_command = ActionCommand(
name="ecs_command",
help="test command for ecs executor",
func=lambda: None,
args=[],
)
ecs_executor_cli_commands_mock.return_value = [ecs_executor_command]

celery_executor_command = ActionCommand(
name="celery_command",
help="test command for celery executor",
func=lambda: None,
args=[],
)
celery_executor_cli_commands_mock.return_value = [celery_executor_command]
reload(executor_loader)
executor_loader.ExecutorLoader.get_executor_names = mock.Mock(
return_value=[
ExecutorName("airflow.providers.celery.executors.celery_executor.CeleryExecutor"),
ExecutorName("airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor"),
]
)

reload(cli_parser)
commands = [command.name for command in cli_parser.airflow_commands]
assert celery_executor_command.name in commands
assert ecs_executor_command.name in commands

@patch.object(CeleryExecutor, "get_cli_commands")
@patch.object(AwsEcsExecutor, "get_cli_commands")
def test_hybrid_executor_get_cli_commands_with_error(
self, ecs_executor_cli_commands_mock, celery_executor_cli_commands_mock, caplog
):
"""Test that if multiple executors are configured, then every executor loads its commands.
If the executor fails to load its commands, the CLI should log the error, and continue loading"""
caplog.set_level("ERROR")
ecs_executor_command = ActionCommand(
name="ecs_command",
help="test command for ecs executor",
func=lambda: None,
args=[],
)
ecs_executor_cli_commands_mock.side_effect = Exception()

celery_executor_command = ActionCommand(
name="celery_command",
help="test command for celery executor",
func=lambda: None,
args=[],
)
celery_executor_cli_commands_mock.return_value = [celery_executor_command]
reload(executor_loader)
executor_loader.ExecutorLoader.get_executor_names = mock.Mock(
return_value=[
ExecutorName("airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor"),
ExecutorName("airflow.providers.celery.executors.celery_executor.CeleryExecutor"),
]
)

reload(cli_parser)
commands = [command.name for command in cli_parser.airflow_commands]
assert celery_executor_command.name in commands
assert ecs_executor_command.name not in commands
assert (
"Failed to load CLI commands from executor: airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor"
in caplog.messages[0]
)

@patch.object(AwsEcsExecutor, "get_cli_commands")
def test_cli_parser_fail_to_load_executor(self, ecs_executor_cli_commands_mock, caplog):
caplog.set_level("ERROR")

ecs_executor_command = ActionCommand(
name="ecs_command",
help="test command for ecs executor",
func=lambda: None,
args=[],
)
ecs_executor_cli_commands_mock.return_value = [ecs_executor_command]

reload(executor_loader)
executor_loader.ExecutorLoader.get_executor_names = mock.Mock(
return_value=[
ExecutorName("airflow.providers.incorrect.executor.Executor"),
ExecutorName("airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor"),
]
)

reload(cli_parser)
commands = [command.name for command in cli_parser.airflow_commands]
assert ecs_executor_command.name in commands
assert (
"Failed to load CLI commands from executor: airflow.providers.incorrect.executor.Executor"
in caplog.messages[0]
)

def test_falsy_default_value(self):
arg = cli_config.Arg(("--test",), default=0, type=int)
parser = argparse.ArgumentParser()
Expand Down

0 comments on commit c8f34f5

Please sign in to comment.