Skip to content

Commit

Permalink
Generalize input validation and format output
Browse files Browse the repository at this point in the history
  • Loading branch information
shahar1 committed Mar 19, 2024
1 parent 60e59a1 commit afa2bb6
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 32 deletions.
18 changes: 14 additions & 4 deletions airflow/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,15 +1098,25 @@ class GroupCommand(NamedTuple):
),
ActionCommand(
name="pause",
help="Pause a DAG",
help="Pause DAG(s)",
description=(
"Pause one or more DAGs. This command allows to halt the execution of specified DAGs, "
"disabling further task scheduling. Use `--treat-dag-as-regex` to target multiple DAGs by "
"treating the `--dag-id` as a regex pattern."
),
func=lazy_load_command("airflow.cli.commands.dag_command.dag_pause"),
args=(ARG_DAG_ID, ARG_SUBDIR, ARG_TREAT_DAG_AS_REGEX, ARG_YES, ARG_VERBOSE),
args=(ARG_DAG_ID, ARG_SUBDIR, ARG_TREAT_DAG_AS_REGEX, ARG_YES, ARG_OUTPUT, ARG_VERBOSE),
),
ActionCommand(
name="unpause",
help="Resume a paused DAG",
help="Resume paused DAG(s)",
description=(
"Resume one or more DAGs. This command allows to restore the execution of specified "
"DAGs, enabling further task scheduling. Use `--treat-dag-as-regex` to target multiple DAGs "
"treating the `--dag-id` as a regex pattern."
),
func=lazy_load_command("airflow.cli.commands.dag_command.dag_unpause"),
args=(ARG_DAG_ID, ARG_SUBDIR, ARG_TREAT_DAG_AS_REGEX, ARG_YES, ARG_VERBOSE),
args=(ARG_DAG_ID, ARG_SUBDIR, ARG_TREAT_DAG_AS_REGEX, ARG_YES, ARG_OUTPUT, ARG_VERBOSE),
),
ActionCommand(
name="trigger",
Expand Down
45 changes: 25 additions & 20 deletions airflow/cli/commands/dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,35 +216,40 @@ def dag_unpause(args) -> None:
def set_is_paused(is_paused: bool, args) -> None:
"""Set is_paused for DAG by a given dag_id."""
should_apply = True
dangerous_inputs = [
".",
".?",
".*",
".*?",
"^.",
"^.*",
"^.*?",
"^.*$",
r"[^\n]*",
"(?s:.*)",
r"[^\n]?",
"(?s:.*?)",
r"[\s\S]*",
r"[\w\W]*",
dags = [
dag
for dag in get_dags(args.subdir, dag_id=args.dag_id, use_regex=args.treat_dag_as_regex)
if is_paused != dag.get_is_paused()
]
if not args.yes and args.treat_dag_as_regex and args.dag_id in dangerous_inputs:
question = f"You are about to {'un' if not is_paused else ''}pause all DAGs.\n\nAre you sure? [y/n]"

if not dags:
raise AirflowException(f"No {'un' if is_paused else ''}paused DAGs were found")

if not args.yes and args.treat_dag_as_regex:
dags_ids = [dag.dag_id for dag in dags]
question = (
f"You are about to {'un' if not is_paused else ''}pause {len(dags_ids)} DAGs:\n"
f"{','.join(dags_ids)}"
f"\n\nAre you sure? [y/n]"
)
should_apply = ask_yesno(question)

if should_apply:
dags = get_dags(args.subdir, dag_id=args.dag_id, use_regex=args.treat_dag_as_regex)
dags_models = [DagModel.get_dagmodel(dag.dag_id) for dag in dags]
for dag_model in dags_models:
if dag_model is not None:
dag_model.set_is_paused(is_paused=is_paused)
print(f"Dag: {dag_model.dag_id}, paused: {is_paused}")

AirflowConsole().print_as(
data=[
{"dag_id": dag.dag_id, "is_paused": dag.get_is_paused()}
for dag in dags_models
if dag is not None
],
output=args.output,
)
else:
print("Operation cancelled")
print("Operation cancelled by user")


@providers_configuration_loaded
Expand Down
24 changes: 16 additions & 8 deletions tests/cli/commands/test_dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,41 +651,49 @@ def test_pause(self):
args = self.parser.parse_args(["dags", "pause", "example_bash_operator"])
dag_command.dag_pause(args)
assert self.dagbag.dags["example_bash_operator"].get_is_paused()

args = self.parser.parse_args(["dags", "unpause", "example_bash_operator"])
dag_command.dag_unpause(args)
assert not self.dagbag.dags["example_bash_operator"].get_is_paused()

def test_pause_regex(self):
@mock.patch("airflow.cli.commands.dag_command.ask_yesno")
def test_pause_regex(self, mock_yesno):
args = self.parser.parse_args(["dags", "pause", "^example_.*$", "--treat-dag-as-regex"])
dag_command.dag_pause(args)
mock_yesno.assert_called_once()
assert self.dagbag.dags["example_bash_decorator"].get_is_paused()
assert self.dagbag.dags["example_kubernetes_executor"].get_is_paused()
assert self.dagbag.dags["example_xcom_args"].get_is_paused()

args = self.parser.parse_args(["dags", "pause", "^example_.*$", "--treat-dag-as-regex"])
args = self.parser.parse_args(["dags", "unpause", "^example_.*$", "--treat-dag-as-regex"])
dag_command.dag_unpause(args)
assert not self.dagbag.dags["example_bash_decorator"].get_is_paused()
assert not self.dagbag.dags["example_kubernetes_executor"].get_is_paused()
assert not self.dagbag.dags["example_xcom_args"].get_is_paused()

@mock.patch("airflow.cli.commands.dag_command.ask_yesno")
def test_pause_regex_all_dags_confirmation(self, mock_yesno):
args = self.parser.parse_args(["dags", "pause", ".*", "--treat-dag-as-regex"])
def test_pause_regex_operation_cancelled(self, ask_yesno, capsys):
args = self.parser.parse_args(["dags", "pause", "example_bash_operator", "--treat-dag-as-regex"])
ask_yesno.return_value = False
dag_command.dag_pause(args)
mock_yesno.assert_called_once()
stdout = capsys.readouterr().out
assert "Operation cancelled by user" in stdout

@mock.patch("airflow.cli.commands.dag_command.ask_yesno")
def test_pause_regex_all_dags_yes(self, mock_yesno):
def test_pause_regex_yes(self, mock_yesno):
args = self.parser.parse_args(["dags", "pause", ".*", "--treat-dag-as-regex", "--yes"])
dag_command.dag_pause(args)
mock_yesno.assert_not_called()
dag_command.dag_unpause(args)

def test_pause_non_existing_dag_error(self):
args = self.parser.parse_args(["dags", "pause", "non_existing_dag"])
with pytest.raises(AirflowException):
dag_command.dag_pause(args)

def test_unpause_already_unpaused_dag_error(self):
args = self.parser.parse_args(["dags", "unpause", "example_bash_operator", "--yes"])
with pytest.raises(AirflowException, match="No paused DAGs were found"):
dag_command.dag_unpause(args)

def test_trigger_dag(self):
dag_command.dag_trigger(
self.parser.parse_args(
Expand Down

0 comments on commit afa2bb6

Please sign in to comment.