Skip to content

Commit

Permalink
Add a mechanism to warn if executors override existing CLI commands (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
vandonr-amz committed Aug 19, 2023
1 parent d9814eb commit 1945c1a
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 65 deletions.
16 changes: 14 additions & 2 deletions airflow/cli/cli_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from __future__ import annotations

import argparse
import collections
import logging
from argparse import Action
from functools import lru_cache
Expand All @@ -41,11 +42,12 @@
GroupCommand,
core_commands,
)
from airflow.cli.utils import CliConflictError
from airflow.exceptions import AirflowException
from airflow.executors.executor_loader import ExecutorLoader
from airflow.utils.helpers import partition

airflow_commands = core_commands
airflow_commands = core_commands.copy() # make a copy to prevent bad interactions in tests

log = logging.getLogger(__name__)
try:
Expand All @@ -59,13 +61,23 @@
"a 3.3.0+ version of the Celery provider. If using a Kubernetes executor, install a "
"7.4.0+ version of the CNCF provider"
)
# Do no re-raise the exception since we want the CLI to still function for
# Do not re-raise the exception since we want the CLI to still function for
# other commands.


ALL_COMMANDS_DICT: dict[str, CLICommand] = {sp.name: sp for sp in airflow_commands}


# Check if sub-commands are defined twice, which could be an issue.
if len(ALL_COMMANDS_DICT) < len(airflow_commands):
dup = {k for k, v in collections.Counter([c.name for c in airflow_commands]).items() if v > 1}
raise CliConflictError(
f"The following CLI {len(dup)} command(s) are defined more than once: {sorted(dup)}\n"
f"This can be due to the executor '{ExecutorLoader.get_default_executor_name()}' "
f"redefining core airflow CLI commands."
)


class AirflowHelpFormatter(RichHelpFormatter):
"""
Custom help formatter to display help message.
Expand Down
6 changes: 6 additions & 0 deletions airflow/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
import sys


class CliConflictError(Exception):
"""Error for when CLI commands are defined twice by different sources."""

pass


def is_stdout(fileio: io.IOBase) -> bool:
"""Check whether a file IO is stdout.
Expand Down
1 change: 1 addition & 0 deletions airflow/executors/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ def get_cli_commands() -> list[GroupCommand]:
Override this method to expose commands via Airflow CLI to manage this executor. This can
be commands to setup/teardown the executor, inspect state, etc.
Make sure to choose unique names for those commands, to avoid collisions.
"""
return []

Expand Down
24 changes: 13 additions & 11 deletions tests/cli/commands/test_celery_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import importlib
from argparse import Namespace
from tempfile import NamedTemporaryFile
from unittest import mock
Expand Down Expand Up @@ -65,11 +66,12 @@ def test_validate_session_dbapi_exception(self, mock_session):
class TestCeleryStopCommand:
@classmethod
def setup_class(cls):
cls.parser = cli_parser.get_parser()
with conf_vars({("core", "executor"): "CeleryExecutor"}):
importlib.reload(cli_parser)
cls.parser = cli_parser.get_parser()

@mock.patch("airflow.cli.commands.celery_command.setup_locations")
@mock.patch("airflow.cli.commands.celery_command.psutil.Process")
@conf_vars({("core", "executor"): "CeleryExecutor"})
def test_if_right_pid_is_read(self, mock_process, mock_setup_locations):
args = self.parser.parse_args(["celery", "stop"])
pid = "123"
Expand All @@ -90,7 +92,6 @@ def test_if_right_pid_is_read(self, mock_process, mock_setup_locations):
@mock.patch("airflow.cli.commands.celery_command.read_pid_from_pidfile")
@mock.patch("airflow.providers.celery.executors.celery_executor.app")
@mock.patch("airflow.cli.commands.celery_command.setup_locations")
@conf_vars({("core", "executor"): "CeleryExecutor"})
def test_same_pid_file_is_used_in_start_and_stop(
self, mock_setup_locations, mock_celery_app, mock_read_pid_from_pidfile
):
Expand All @@ -116,7 +117,6 @@ def test_same_pid_file_is_used_in_start_and_stop(
@mock.patch("airflow.providers.celery.executors.celery_executor.app")
@mock.patch("airflow.cli.commands.celery_command.psutil.Process")
@mock.patch("airflow.cli.commands.celery_command.setup_locations")
@conf_vars({("core", "executor"): "CeleryExecutor"})
def test_custom_pid_file_is_used_in_start_and_stop(
self,
mock_setup_locations,
Expand Down Expand Up @@ -147,12 +147,13 @@ def test_custom_pid_file_is_used_in_start_and_stop(
class TestWorkerStart:
@classmethod
def setup_class(cls):
cls.parser = cli_parser.get_parser()
with conf_vars({("core", "executor"): "CeleryExecutor"}):
importlib.reload(cli_parser)
cls.parser = cli_parser.get_parser()

@mock.patch("airflow.cli.commands.celery_command.setup_locations")
@mock.patch("airflow.cli.commands.celery_command.Process")
@mock.patch("airflow.providers.celery.executors.celery_executor.app")
@conf_vars({("core", "executor"): "CeleryExecutor"})
def test_worker_started_with_required_arguments(self, mock_celery_app, mock_popen, mock_locations):
pid_file = "pid_file"
mock_locations.return_value = (pid_file, None, None, None)
Expand Down Expand Up @@ -208,11 +209,12 @@ def test_worker_started_with_required_arguments(self, mock_celery_app, mock_pope
class TestWorkerFailure:
@classmethod
def setup_class(cls):
cls.parser = cli_parser.get_parser()
with conf_vars({("core", "executor"): "CeleryExecutor"}):
importlib.reload(cli_parser)
cls.parser = cli_parser.get_parser()

@mock.patch("airflow.cli.commands.celery_command.Process")
@mock.patch("airflow.providers.celery.executors.celery_executor.app")
@conf_vars({("core", "executor"): "CeleryExecutor"})
def test_worker_failure_gracefull_shutdown(self, mock_celery_app, mock_popen):
args = self.parser.parse_args(["celery", "worker"])
mock_celery_app.run.side_effect = Exception("Mock exception to trigger runtime error")
Expand All @@ -226,10 +228,11 @@ def test_worker_failure_gracefull_shutdown(self, mock_celery_app, mock_popen):
class TestFlowerCommand:
@classmethod
def setup_class(cls):
cls.parser = cli_parser.get_parser()
with conf_vars({("core", "executor"): "CeleryExecutor"}):
importlib.reload(cli_parser)
cls.parser = cli_parser.get_parser()

@mock.patch("airflow.providers.celery.executors.celery_executor.app")
@conf_vars({("core", "executor"): "CeleryExecutor"})
def test_run_command(self, mock_celery_app):
args = self.parser.parse_args(
[
Expand Down Expand Up @@ -268,7 +271,6 @@ def test_run_command(self, mock_celery_app):
@mock.patch("airflow.cli.commands.celery_command.setup_locations")
@mock.patch("airflow.cli.commands.celery_command.daemon")
@mock.patch("airflow.providers.celery.executors.celery_executor.app")
@conf_vars({("core", "executor"): "CeleryExecutor"})
def test_run_command_daemon(self, mock_celery_app, mock_daemon, mock_setup_locations, mock_pid_file):
mock_setup_locations.return_value = (
mock.MagicMock(name="pidfile"),
Expand Down
10 changes: 8 additions & 2 deletions tests/cli/commands/test_kubernetes_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import importlib
import os
import tempfile
from unittest import mock
Expand All @@ -26,12 +27,15 @@

from airflow.cli import cli_parser
from airflow.cli.commands import kubernetes_command
from tests.test_utils.config import conf_vars


class TestGenerateDagYamlCommand:
@classmethod
def setup_class(cls):
cls.parser = cli_parser.get_parser()
with conf_vars({("core", "executor"): "KubernetesExecutor"}):
importlib.reload(cli_parser)
cls.parser = cli_parser.get_parser()

def test_generate_dag_yaml(self):
with tempfile.TemporaryDirectory("airflow_dry_run_test/") as directory:
Expand Down Expand Up @@ -61,7 +65,9 @@ class TestCleanUpPodsCommand:

@classmethod
def setup_class(cls):
cls.parser = cli_parser.get_parser()
with conf_vars({("core", "executor"): "KubernetesExecutor"}):
importlib.reload(cli_parser)
cls.parser = cli_parser.get_parser()

@mock.patch("kubernetes.client.CoreV1Api.delete_namespaced_pod")
@mock.patch("airflow.providers.cncf.kubernetes.kube_client.config.load_incluster_config")
Expand Down
17 changes: 8 additions & 9 deletions tests/cli/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@

from airflow import models
from airflow.cli import cli_parser
from airflow.executors import local_executor
from airflow.providers.celery.executors import celery_executor, celery_kubernetes_executor
from airflow.providers.cncf.kubernetes.executors import kubernetes_executor, local_kubernetes_executor
from tests.test_utils.config import conf_vars

# Create custom executors here because conftest is imported first
Expand All @@ -34,17 +36,14 @@
custom_executor_module.CustomCeleryKubernetesExecutor = type( # type: ignore
"CustomCeleryKubernetesExecutor", (celery_kubernetes_executor.CeleryKubernetesExecutor,), {}
)
custom_executor_module.CustomCeleryExecutor = type( # type: ignore
"CustomLocalExecutor", (celery_executor.CeleryExecutor,), {}
)
custom_executor_module.CustomCeleryKubernetesExecutor = type( # type: ignore
"CustomLocalKubernetesExecutor", (celery_kubernetes_executor.CeleryKubernetesExecutor,), {}
custom_executor_module.CustomLocalExecutor = type( # type: ignore
"CustomLocalExecutor", (local_executor.LocalExecutor,), {}
)
custom_executor_module.CustomCeleryExecutor = type( # type: ignore
"CustomKubernetesExecutor", (celery_executor.CeleryExecutor,), {}
custom_executor_module.CustomLocalKubernetesExecutor = type( # type: ignore
"CustomLocalKubernetesExecutor", (local_kubernetes_executor.LocalKubernetesExecutor,), {}
)
custom_executor_module.CustomCeleryKubernetesExecutor = type( # type: ignore
"CustomCeleryKubernetesExecutor", (celery_kubernetes_executor.CeleryKubernetesExecutor,), {}
custom_executor_module.CustomKubernetesExecutor = type( # type: ignore
"CustomKubernetesExecutor", (kubernetes_executor.KubernetesExecutor,), {}
)
sys.modules["custom_executor"] = custom_executor_module

Expand Down
79 changes: 38 additions & 41 deletions tests/cli/test_cli_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@
from collections import Counter
from importlib import reload
from pathlib import Path
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pytest

from airflow.cli import cli_config, cli_parser
from airflow.cli.cli_config import ActionCommand, lazy_load_command
from airflow.cli.cli_config import ActionCommand, core_commands, lazy_load_command
from airflow.cli.utils import CliConflictError
from airflow.configuration import AIRFLOW_HOME
from airflow.executors.local_executor import LocalExecutor
from tests.test_utils.config import conf_vars

# Can not be `--snake_case` or contain uppercase letter
Expand Down Expand Up @@ -133,6 +135,28 @@ def test_subcommand_arg_flag_conflict(self):
f"short option flags {conflict_short_option}"
)

@patch.object(LocalExecutor, "get_cli_commands")
def test_dynamic_conflict_detection(self, cli_commands_mock: MagicMock):
core_commands.append(
ActionCommand(
name="test_command",
help="does nothing",
func=lambda: None,
args=[],
)
)
cli_commands_mock.return_value = [
ActionCommand(
name="test_command",
help="just a command that'll conflict with one defined in core",
func=lambda: None,
args=[],
)
]
with pytest.raises(CliConflictError, match="test_command"):
# force re-evaluation of cli commands (done in top level code)
reload(cli_parser)

def test_falsy_default_value(self):
arg = cli_parser.Arg(("--test",), default=0, type=int)
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -205,57 +229,38 @@ def test_positive_int(self):
cli_config.positive_int(allow_zero=True)("-1")

@pytest.mark.parametrize(
"executor",
"command",
[
"celery",
"kubernetes",
],
)
def test_dag_parser_require_celery_executor(self, executor):
def test_executor_specific_commands_not_accessible(self, command):
with conf_vars({("core", "executor"): "SequentialExecutor"}), contextlib.redirect_stderr(
io.StringIO()
) as stderr:
reload(cli_parser)
parser = cli_parser.get_parser()
with pytest.raises(SystemExit):
parser.parse_args([executor])
stderr = stderr.getvalue()
assert (f"airflow command error: argument GROUP_OR_COMMAND: invalid choice: '{executor}'") in stderr

@pytest.mark.parametrize(
"executor",
[
"CeleryExecutor",
"CeleryKubernetesExecutor",
"custom_executor.CustomCeleryExecutor",
"custom_executor.CustomCeleryKubernetesExecutor",
],
)
def test_dag_parser_celery_command_accept_celery_executor(self, executor):
with conf_vars({("core", "executor"): executor}), contextlib.redirect_stderr(io.StringIO()) as stderr:
reload(cli_parser)
parser = cli_parser.get_parser()
with pytest.raises(SystemExit):
parser.parse_args(["celery"])
parser.parse_args([command])
stderr = stderr.getvalue()
assert (
"airflow celery command error: the following arguments are required: COMMAND, see help above."
) in stderr
assert (f"airflow command error: argument GROUP_OR_COMMAND: invalid choice: '{command}'") in stderr

@pytest.mark.parametrize(
"executor,expected_args",
[
("CeleryExecutor", ["celery"]),
("CeleryKubernetesExecutor", ["celery", "kubernetes"]),
("custom_executor.CustomCeleryExecutor", ["celery"]),
("custom_executor.CustomCeleryKubernetesExecutor", ["celery", "kubernetes"]),
("KubernetesExecutor", ["kubernetes"]),
("custom_executor.KubernetesExecutor", ["kubernetes"]),
("LocalExecutor", []),
("LocalKubernetesExecutor", ["kubernetes"]),
("custom_executor.LocalExecutor", []),
("custom_executor.LocalKubernetesExecutor", ["kubernetes"]),
("SequentialExecutor", []),
# custom executors are mapped to the regular ones in `conftest.py`
("custom_executor.CustomLocalExecutor", []),
("custom_executor.CustomLocalKubernetesExecutor", ["kubernetes"]),
("custom_executor.CustomCeleryExecutor", ["celery"]),
("custom_executor.CustomCeleryKubernetesExecutor", ["celery", "kubernetes"]),
("custom_executor.CustomKubernetesExecutor", ["kubernetes"]),
],
)
def test_cli_parser_executors(self, executor, expected_args):
Expand All @@ -266,20 +271,12 @@ def test_cli_parser_executors(self, executor, expected_args):
) as stderr:
reload(cli_parser)
parser = cli_parser.get_parser()
with pytest.raises(SystemExit) as e:
with pytest.raises(SystemExit) as e: # running the help command exits, so we prevent that
parser.parse_args([expected_arg, "--help"])
assert e.value.code == 0
assert e.value.code == 0, stderr.getvalue() # return code 0 == no problem
stderr = stderr.getvalue()
assert "airflow command error" not in stderr

def test_dag_parser_config_command_dont_required_celery_executor(self):
with conf_vars({("core", "executor"): "CeleryExecutor"}), contextlib.redirect_stderr(
io.StringIO()
) as stdout:
parser = cli_parser.get_parser()
parser.parse_args(["config", "get-value", "celery", "broker-url"])
assert stdout is not None

def test_non_existing_directory_raises_when_metavar_is_dir_for_db_export_cleaned(self):
"""Test that the error message is correct when the directory does not exist."""
with contextlib.redirect_stderr(io.StringIO()) as stderr:
Expand Down

0 comments on commit 1945c1a

Please sign in to comment.