diff --git a/airflow/cli/cli_parser.py b/airflow/cli/cli_parser.py index a295bd21ba20c..7e12383cf57c0 100644 --- a/airflow/cli/cli_parser.py +++ b/airflow/cli/cli_parser.py @@ -56,19 +56,21 @@ airflow_commands = core_commands.copy() # make a copy to prevent bad interactions in tests log = logging.getLogger(__name__) -try: - executor, _ = ExecutorLoader.import_default_executor_cls(validate=False) - airflow_commands.extend(executor.get_cli_commands()) -except Exception: - executor_name = ExecutorLoader.get_default_executor_name() - log.exception("Failed to load CLI commands from executor: %s", executor_name) - log.error( - "Ensure all dependencies are met and try again. If using a Celery based executor install " - "a 3.3.0+ version of the Celery provider. If using a Kubernetes executor, install a " - "7.4.0+ version of the CNCF provider" - ) - # Do not re-raise the exception since we want the CLI to still function for - # other commands. + + +for executor_name in ExecutorLoader.get_executor_names(): + try: + executor, _ = ExecutorLoader.import_executor_cls(executor_name) + airflow_commands.extend(executor.get_cli_commands()) + except Exception: + log.exception("Failed to load CLI commands from executor: %s", executor_name) + log.error( + "Ensure all dependencies are met and try again. If using a Celery based executor install " + "a 3.3.0+ version of the Celery provider. If using a Kubernetes executor, install a " + "7.4.0+ version of the CNCF provider" + ) + # Do not re-raise the exception since we want the CLI to still function for + # other commands. try: auth_mgr = get_auth_manager_cls() @@ -90,8 +92,7 @@ dup = {k for k, v in Counter([c.name for c in airflow_commands]).items() if v > 1} raise CliConflictError( f"The following CLI {len(dup)} command(s) are defined more than once: {sorted(dup)}\n" - f"This can be due to the executor '{ExecutorLoader.get_default_executor_name()}' " - f"redefining core airflow CLI commands." + f"This can be due to an Executor or Auth Manager redefining core airflow CLI commands." ) diff --git a/airflow/executors/executor_loader.py b/airflow/executors/executor_loader.py index 83aeb77e50a61..f35398041c9ff 100644 --- a/airflow/executors/executor_loader.py +++ b/airflow/executors/executor_loader.py @@ -153,6 +153,14 @@ def _get_executor_names(cls) -> list[ExecutorName]: return executor_names + @classmethod + def get_executor_names(cls) -> list[ExecutorName]: + """Return the executor names from Airflow configuration. + + :return: List of executor names from Airflow configuration + """ + return cls._get_executor_names() + @classmethod def get_default_executor_name(cls) -> ExecutorName: """Return the default executor name from Airflow configuration. diff --git a/tests/cli/test_cli_parser.py b/tests/cli/test_cli_parser.py index d105b0161459a..9fd5bf48b0ef7 100644 --- a/tests/cli/test_cli_parser.py +++ b/tests/cli/test_cli_parser.py @@ -29,6 +29,7 @@ from importlib import reload from io import StringIO from pathlib import Path +from unittest import mock from unittest.mock import MagicMock, patch import pytest @@ -38,7 +39,10 @@ from airflow.cli.utils import CliConflictError from airflow.configuration import AIRFLOW_HOME from airflow.executors import executor_loader +from airflow.executors.executor_utils import ExecutorName from airflow.executors.local_executor import LocalExecutor +from airflow.providers.amazon.aws.executors.ecs.ecs_executor import AwsEcsExecutor +from airflow.providers.celery.executors.celery_executor import CeleryExecutor from tests.test_utils.config import conf_vars # Can not be `--snake_case` or contain uppercase letter @@ -160,6 +164,108 @@ def test_dynamic_conflict_detection(self, cli_commands_mock: MagicMock): # force re-evaluation of cli commands (done in top level code) reload(cli_parser) + @patch.object(CeleryExecutor, "get_cli_commands") + @patch.object(AwsEcsExecutor, "get_cli_commands") + def test_hybrid_executor_get_cli_commands( + self, ecs_executor_cli_commands_mock, celery_executor_cli_commands_mock + ): + """Test that if multiple executors are configured, then every executor loads its commands.""" + ecs_executor_command = ActionCommand( + name="ecs_command", + help="test command for ecs executor", + func=lambda: None, + args=[], + ) + ecs_executor_cli_commands_mock.return_value = [ecs_executor_command] + + celery_executor_command = ActionCommand( + name="celery_command", + help="test command for celery executor", + func=lambda: None, + args=[], + ) + celery_executor_cli_commands_mock.return_value = [celery_executor_command] + reload(executor_loader) + executor_loader.ExecutorLoader.get_executor_names = mock.Mock( + return_value=[ + ExecutorName("airflow.providers.celery.executors.celery_executor.CeleryExecutor"), + ExecutorName("airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor"), + ] + ) + + reload(cli_parser) + commands = [command.name for command in cli_parser.airflow_commands] + assert celery_executor_command.name in commands + assert ecs_executor_command.name in commands + + @patch.object(CeleryExecutor, "get_cli_commands") + @patch.object(AwsEcsExecutor, "get_cli_commands") + def test_hybrid_executor_get_cli_commands_with_error( + self, ecs_executor_cli_commands_mock, celery_executor_cli_commands_mock, caplog + ): + """Test that if multiple executors are configured, then every executor loads its commands. + If the executor fails to load its commands, the CLI should log the error, and continue loading""" + caplog.set_level("ERROR") + ecs_executor_command = ActionCommand( + name="ecs_command", + help="test command for ecs executor", + func=lambda: None, + args=[], + ) + ecs_executor_cli_commands_mock.side_effect = Exception() + + celery_executor_command = ActionCommand( + name="celery_command", + help="test command for celery executor", + func=lambda: None, + args=[], + ) + celery_executor_cli_commands_mock.return_value = [celery_executor_command] + reload(executor_loader) + executor_loader.ExecutorLoader.get_executor_names = mock.Mock( + return_value=[ + ExecutorName("airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor"), + ExecutorName("airflow.providers.celery.executors.celery_executor.CeleryExecutor"), + ] + ) + + reload(cli_parser) + commands = [command.name for command in cli_parser.airflow_commands] + assert celery_executor_command.name in commands + assert ecs_executor_command.name not in commands + assert ( + "Failed to load CLI commands from executor: airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor" + in caplog.messages[0] + ) + + @patch.object(AwsEcsExecutor, "get_cli_commands") + def test_cli_parser_fail_to_load_executor(self, ecs_executor_cli_commands_mock, caplog): + caplog.set_level("ERROR") + + ecs_executor_command = ActionCommand( + name="ecs_command", + help="test command for ecs executor", + func=lambda: None, + args=[], + ) + ecs_executor_cli_commands_mock.return_value = [ecs_executor_command] + + reload(executor_loader) + executor_loader.ExecutorLoader.get_executor_names = mock.Mock( + return_value=[ + ExecutorName("airflow.providers.incorrect.executor.Executor"), + ExecutorName("airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor"), + ] + ) + + reload(cli_parser) + commands = [command.name for command in cli_parser.airflow_commands] + assert ecs_executor_command.name in commands + assert ( + "Failed to load CLI commands from executor: airflow.providers.incorrect.executor.Executor" + in caplog.messages[0] + ) + def test_falsy_default_value(self): arg = cli_config.Arg(("--test",), default=0, type=int) parser = argparse.ArgumentParser()