Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-61: Allow multiple executors to load their CLI commands #39077

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
32 changes: 17 additions & 15 deletions airflow/cli/cli_parser.py
Expand Up @@ -56,19 +56,22 @@
airflow_commands = core_commands.copy() # make a copy to prevent bad interactions in tests

log = logging.getLogger(__name__)
try:
executor, _ = ExecutorLoader.import_default_executor_cls(validate=False)
airflow_commands.extend(executor.get_cli_commands())
except Exception:
executor_name = ExecutorLoader.get_default_executor_name()
log.exception("Failed to load CLI commands from executor: %s", executor_name)
log.error(
"Ensure all dependencies are met and try again. If using a Celery based executor install "
"a 3.3.0+ version of the Celery provider. If using a Kubernetes executor, install a "
"7.4.0+ version of the CNCF provider"
)
# Do not re-raise the exception since we want the CLI to still function for
# other commands.


executors = [executor for executor, _ in ExecutorLoader.import_all_executors()]

for executor in executors:
syedahsn marked this conversation as resolved.
Show resolved Hide resolved
try:
airflow_commands.extend(executor.get_cli_commands())
except Exception:
log.exception("Failed to load CLI commands from executor: %s", executor.__name__)
log.error(
"Ensure all dependencies are met and try again. If using a Celery based executor install "
"a 3.3.0+ version of the Celery provider. If using a Kubernetes executor, install a "
"7.4.0+ version of the CNCF provider"
)
# Do not re-raise the exception since we want the CLI to still function for
# other commands.

try:
auth_mgr = get_auth_manager_cls()
Expand All @@ -90,8 +93,7 @@
dup = {k for k, v in Counter([c.name for c in airflow_commands]).items() if v > 1}
raise CliConflictError(
f"The following CLI {len(dup)} command(s) are defined more than once: {sorted(dup)}\n"
f"This can be due to the executor '{ExecutorLoader.get_default_executor_name()}' "
f"redefining core airflow CLI commands."
f"This can be due to an Executor or Auth Manager redefining core airflow CLI commands."
o-nikolas marked this conversation as resolved.
Show resolved Hide resolved
)


Expand Down
9 changes: 9 additions & 0 deletions airflow/executors/executor_loader.py
Expand Up @@ -278,6 +278,15 @@ def _import_and_validate(path: str) -> type[BaseExecutor]:
)
return _import_and_validate(executor_name.module_path), executor_name.connector_source

@classmethod
def import_all_executors(cls) -> list[tuple[type[BaseExecutor], ConnectorSource]]:
executor_names = cls._get_executor_names()
executor_classes = []
for executor_name in executor_names:
executor_class = cls.import_executor_cls(executor_name, validate=False)
executor_classes.append(executor_class)
return executor_classes

@classmethod
def import_default_executor_cls(cls, validate: bool = True) -> tuple[type[BaseExecutor], ConnectorSource]:
"""
Expand Down
68 changes: 68 additions & 0 deletions tests/cli/test_cli_parser.py
Expand Up @@ -29,6 +29,7 @@
from importlib import reload
from io import StringIO
from pathlib import Path
from unittest import mock
from unittest.mock import MagicMock, patch

import pytest
Expand All @@ -39,6 +40,8 @@
from airflow.configuration import AIRFLOW_HOME
from airflow.executors import executor_loader
from airflow.executors.local_executor import LocalExecutor
from airflow.providers.amazon.aws.executors.ecs.ecs_executor import AwsEcsExecutor
from airflow.providers.celery.executors.celery_executor import CeleryExecutor
from tests.test_utils.config import conf_vars

# Can not be `--snake_case` or contain uppercase letter
Expand Down Expand Up @@ -160,6 +163,71 @@ def test_dynamic_conflict_detection(self, cli_commands_mock: MagicMock):
# force re-evaluation of cli commands (done in top level code)
reload(cli_parser)

@patch.object(CeleryExecutor, "get_cli_commands")
@patch.object(AwsEcsExecutor, "get_cli_commands")
def test_hybrid_executor_get_cli_commands(
self, ecs_executor_cli_commands_mock, celery_executor_cli_commands_mock
):
"""Test that if multiple executors are configured, then every executor loads its commands."""
ecs_executor_command = ActionCommand(
name="ecs_command",
help="test command for ecs executor",
func=lambda: None,
args=[],
)
ecs_executor_cli_commands_mock.return_value = [ecs_executor_command]

celery_executor_command = ActionCommand(
name="celery_command",
help="test command for celery executor",
func=lambda: None,
args=[],
)
celery_executor_cli_commands_mock.return_value = [celery_executor_command]
reload(executor_loader)
executor_loader.ExecutorLoader.import_all_executors = mock.Mock(
return_value=[(AwsEcsExecutor, ""), (CeleryExecutor, "")]
)

reload(cli_parser)
commands = [command.name for command in cli_parser.airflow_commands]
assert celery_executor_command.name in commands
assert ecs_executor_command.name in commands

@patch.object(CeleryExecutor, "get_cli_commands")
@patch.object(AwsEcsExecutor, "get_cli_commands")
def test_hybrid_executor_get_cli_commands_with_error(
self, ecs_executor_cli_commands_mock, celery_executor_cli_commands_mock, caplog
):
"""Test that if multiple executors are configured, then every executor loads its commands.
If the executor fails to load its commands, the CLI should log the error, and continue loading"""
caplog.set_level("ERROR")
ecs_executor_command = ActionCommand(
name="ecs_command",
help="test command for ecs executor",
func=lambda: None,
args=[],
)
ecs_executor_cli_commands_mock.side_effect = Exception()

celery_executor_command = ActionCommand(
name="celery_command",
help="test command for celery executor",
func=lambda: None,
args=[],
)
celery_executor_cli_commands_mock.return_value = [celery_executor_command]
reload(executor_loader)
executor_loader.ExecutorLoader.import_all_executors = mock.Mock(
return_value=[(AwsEcsExecutor, ""), (CeleryExecutor, "")]
)

reload(cli_parser)
commands = [command.name for command in cli_parser.airflow_commands]
assert celery_executor_command.name in commands
assert ecs_executor_command.name not in commands
assert "Failed to load CLI commands from executor: AwsEcsExecutor" in caplog.messages[0]

def test_falsy_default_value(self):
arg = cli_config.Arg(("--test",), default=0, type=int)
parser = argparse.ArgumentParser()
Expand Down
40 changes: 40 additions & 0 deletions tests/executors/test_executor_loader.py
Expand Up @@ -26,6 +26,7 @@
from airflow.exceptions import AirflowConfigException
from airflow.executors import executor_loader
from airflow.executors.executor_loader import ConnectorSource, ExecutorLoader, ExecutorName
from airflow.providers.amazon.aws.executors.ecs.ecs_executor import AwsEcsExecutor
from airflow.providers.celery.executors.celery_executor import CeleryExecutor
from tests.test_utils.config import conf_vars

Expand Down Expand Up @@ -98,6 +99,45 @@ def test_should_support_custom_path(self):
assert executor.name == ExecutorName("tests.executors.test_executor_loader.FakeExecutor")
assert executor.name.connector_source == ConnectorSource.CUSTOM_PATH

@pytest.mark.parametrize(
("executor_config", "expected_executor_classes", "expected_connector_sources"),
[
# Just one executor
(
"CeleryExecutor",
[
CeleryExecutor,
],
[
ConnectorSource.CORE,
],
),
# Multiple Executors,
(
"CeleryExecutor, airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor",
[
CeleryExecutor,
AwsEcsExecutor,
],
[
ConnectorSource.CORE,
ConnectorSource.CUSTOM_PATH,
],
),
],
)
def test_import_all_executors(
self, executor_config, expected_executor_classes, expected_connector_sources
):
ExecutorLoader.block_use_of_hybrid_exec = mock.Mock()
with conf_vars({("core", "executor"): executor_config}):
executors = [executor for executor, _ in ExecutorLoader.import_all_executors()]
connector_sources = [
connector_source for _, connector_source in ExecutorLoader.import_all_executors()
]
assert executors == expected_executor_classes
assert connector_sources == expected_connector_sources

@pytest.mark.parametrize(
("executor_config", "expected_executors_list"),
[
Expand Down