diff --git a/airflow-ctl-tests/tests/airflowctl_tests/test_airflowctl_commands.py b/airflow-ctl-tests/tests/airflowctl_tests/test_airflowctl_commands.py index de07d0d3751a9..7bb609ca77c05 100644 --- a/airflow-ctl-tests/tests/airflowctl_tests/test_airflowctl_commands.py +++ b/airflow-ctl-tests/tests/airflowctl_tests/test_airflowctl_commands.py @@ -54,75 +54,75 @@ def date_param(): "auth list-envs", # Assets commands "assets list", - "assets get --asset-id=1", - "assets create-event --asset-id=1", + "assets create-event 1", + "assets get 1", # Backfill commands - "backfill list", + "backfill list example_bash_operator", # Config commands - "config get --section core --option executor", + "config get core executor", "config list", "config lint", # Connections commands - "connections create --connection-id=test_con --conn-type=mysql --password=TEST_PASS -o json", + "connections create test_con mysql --password=TEST_PASS -o json", "connections list", "connections list -o yaml", "connections list -o table", - "connections get --conn-id=test_con", - "connections get --conn-id=test_con -o json", - "connections update --connection-id=test_con --conn-type=postgres", + "connections get test_con", + "connections get test_con -o json", + "connections update test_con postgres", "connections import tests/airflowctl_tests/fixtures/test_connections.json", - "connections delete --conn-id=test_con", - "connections delete --conn-id=test_import_conn", + "connections delete test_con", + "connections delete test_import_conn", # DAGs commands "dags list", - "dags get --dag-id=example_bash_operator", - "dags get-details --dag-id=example_bash_operator", - "dags get-stats --dag-ids=example_bash_operator", - "dags get-version --dag-id=example_bash_operator --version-number=1", + "dags get example_bash_operator", + "dags get-details example_bash_operator", + "dags get-stats example_bash_operator", + "dags get-version example_bash_operator 1", "dags list-import-errors", - "dags list-version --dag-id=example_bash_operator", + "dags list-version example_bash_operator", "dags list-warning", # Order of trigger and pause/unpause is important for test stability because state checked - "dags trigger --dag-id=example_bash_operator --logical-date={date_param} --run-after={date_param}", + "dags trigger example_bash_operator --logical-date={date_param} --run-after={date_param}", # Test trigger without logical-date (should default to now) - "dags trigger --dag-id=example_bash_operator", + "dags trigger example_bash_operator", "dags pause example_bash_operator", "dags unpause example_bash_operator", # DAG Run commands - 'dagrun get --dag-id=example_bash_operator --dag-run-id="manual__{date_param}"', - "dags update --dag-id=example_bash_operator --no-is-paused", + 'dagrun get example_bash_operator "manual__{date_param}"', + "dags update example_bash_operator --no-is-paused", # DAG Run commands - "dagrun list --dag-id example_bash_operator --state success --limit=1", + "dagrun list --dag-id=example_bash_operator --state success --limit=1", # XCom commands - need a DAG run with completed tasks - 'xcom add --dag-id=example_bash_operator --dag-run-id="manual__{date_param}" --task-id=runme_0 --key={xcom_key} --value=\'{{"test": "value"}}\'', - 'xcom get --dag-id=example_bash_operator --dag-run-id="manual__{date_param}" --task-id=runme_0 --key={xcom_key}', - 'xcom list --dag-id=example_bash_operator --dag-run-id="manual__{date_param}" --task-id=runme_0', - 'xcom edit --dag-id=example_bash_operator --dag-run-id="manual__{date_param}" --task-id=runme_0 --key={xcom_key} --value=\'{{"updated": "value"}}\'', - 'xcom delete --dag-id=example_bash_operator --dag-run-id="manual__{date_param}" --task-id=runme_0 --key={xcom_key}', + 'xcom add example_bash_operator "manual__{date_param}" runme_0 {xcom_key} \'{{"test": "value"}}\'', + 'xcom get example_bash_operator "manual__{date_param}" runme_0 {xcom_key}', + 'xcom list example_bash_operator "manual__{date_param}" runme_0', + 'xcom edit example_bash_operator "manual__{date_param}" runme_0 {xcom_key} \'{{"updated": "value"}}\'', + 'xcom delete example_bash_operator "manual__{date_param}" runme_0 {xcom_key}', # Jobs commands - "jobs list", + "jobs list --job-type=SchedulerJob --hostname=localhost", # Pools commands - "pools create --name=test_pool --slots=5", + "pools create test_pool 5", "pools list", - "pools get --pool-name=test_pool", - "pools get --pool-name=test_pool -o yaml", + "pools get test_pool", + "pools get test_pool -o yaml", "pools update --pool=test_pool --slots=10", "pools import tests/airflowctl_tests/fixtures/test_pools.json", "pools export tests/airflowctl_tests/fixtures/pools_export.json --output=json", - "pools delete --pool=test_pool", - "pools delete --pool=test_import_pool", + "pools delete test_pool", + "pools delete test_import_pool", # Providers commands "providers list", # Variables commands - "variables create --key=test_key --value=test_value", + "variables create test_key test_value", "variables list", - "variables get --variable-key=test_key", - "variables get --variable-key=test_key -o table", - "variables update --key=test_key --value=updated_value", + "variables get test_key", + "variables get test_key -o table", + "variables update test_key updated_value", "variables import tests/airflowctl_tests/fixtures/test_variables.json", - "variables delete --variable-key=test_key", - "variables delete --variable-key=test_import_var", - "variables delete --variable-key=test_import_var_with_desc", + "variables delete test_key", + "variables delete test_import_var", + "variables delete test_import_var_with_desc", # Version command "version --remote", # Plugins command diff --git a/airflow-ctl/src/airflowctl/ctl/cli_config.py b/airflow-ctl/src/airflowctl/ctl/cli_config.py index 466ee671b61b2..3d6463643fd50 100755 --- a/airflow-ctl/src/airflowctl/ctl/cli_config.py +++ b/airflow-ctl/src/airflowctl/ctl/cli_config.py @@ -423,11 +423,16 @@ def get_function_details(node: ast.FunctionDef, parent_node: ast.ClassDef) -> di args = [] return_annotation: str = "" - for arg in node.args.args: + num_args = len(node.args.args) + num_defaults = len(node.args.defaults) + first_default_index = num_args - num_defaults + + for idx, arg in enumerate(node.args.args): arg_name = arg.arg arg_type = ast.unparse(arg.annotation) if arg.annotation else "Any" if arg_name != "self": - args.append({arg_name: arg_type}) + has_default = idx >= first_default_index + args.append({arg_name: arg_type, "has_default": has_default}) if node.returns: return_annotation = [ @@ -519,9 +524,9 @@ def _create_arg( arg_flags: tuple, arg_type: type | Callable, arg_help: str, - arg_action: argparse.BooleanOptionalAction | None, - arg_dest: str | None = None, - arg_default: Any | None = None, + arg_action: type[argparse.BooleanOptionalAction] | None, + arg_dest=_UNSET, + arg_default=_UNSET, ) -> Arg: return Arg( flags=arg_flags, @@ -545,15 +550,21 @@ def _create_arg_for_non_primitive_type( for field, field_type in parameter_type_map.model_fields.items(): if field in self.excluded_parameters: continue + + is_required = field_type.is_required() + sanitized_field = self._sanitize_arg_parameter_key(field) self.datamodels_extended_map[parameter_type].append(field) + if type(field_type.annotation) is type: + is_bool = field_type.annotation is bool + arg_flags = (field,) if is_required and not is_bool else ("--" + sanitized_field,) commands.append( self._create_arg( - arg_flags=("--" + self._sanitize_arg_parameter_key(field),), + arg_flags=arg_flags, arg_type=self._python_type_from_string(field_type.annotation), - arg_action=argparse.BooleanOptionalAction if field_type.annotation is bool else None, # type: ignore + arg_action=argparse.BooleanOptionalAction if is_bool else None, arg_help=f"{field} for {parameter_key} operation", - arg_default=False if field_type.annotation is bool else None, + arg_default=False if is_bool else None, ) ) else: @@ -562,13 +573,15 @@ def _create_arg_for_non_primitive_type( except AttributeError: annotation = field_type.annotation + is_bool = annotation is bool + arg_flags = (field,) if is_required and not is_bool else ("--" + sanitized_field,) commands.append( self._create_arg( - arg_flags=("--" + self._sanitize_arg_parameter_key(field),), + arg_flags=arg_flags, arg_type=self._python_type_from_string(annotation), - arg_action=argparse.BooleanOptionalAction if annotation is bool else None, # type: ignore + arg_action=argparse.BooleanOptionalAction if is_bool else None, arg_help=f"{field} for {parameter_key} operation", - arg_default=False if annotation is bool else None, + arg_default=False if is_bool else None, ) ) return commands @@ -578,17 +591,25 @@ def _create_args_map_from_operation(self): for operation in self.operations: args = [] for parameter in operation.get("parameters"): - for parameter_key, parameter_type in parameter.items(): + for parameter_key, parameter_value in parameter.items(): + if parameter_key == "has_default": + continue + parameter_type = parameter_value + has_default = parameter.get("has_default", False) + if self._is_primitive_type(type_name=parameter_type): base_parameter_type = parameter_type.replace(" | None", "").strip() is_bool = base_parameter_type == "bool" + is_optional = "| None" in parameter_type + sanitized_key = self._sanitize_arg_parameter_key(parameter_key) + arg_flags = ("--" + sanitized_key,) if is_optional or is_bool else (sanitized_key,) args.append( self._create_arg( - arg_flags=("--" + self._sanitize_arg_parameter_key(parameter_key),), + arg_flags=arg_flags, arg_type=self._python_type_from_string(parameter_type), arg_action=argparse.BooleanOptionalAction if is_bool else None, arg_help=f"{parameter_key} for {operation.get('name')} operation in {operation.get('parent').name}", - arg_default=None, + arg_default=False if parameter_type == "bool" else None, ) ) else: @@ -647,6 +668,8 @@ def _get_func(args: Namespace, api_operation: dict, api_client: Client = NEW_API args_dict = vars(args) for parameter in api_operation["parameters"]: for parameter_key, parameter_type in parameter.items(): + if parameter_key == "has_default": + continue if self._is_primitive_type(type_name=parameter_type): method_params[self._sanitize_method_param_key(parameter_key)] = args_dict[ parameter_key diff --git a/airflow-ctl/tests/airflow_ctl/ctl/test_cli_config.py b/airflow-ctl/tests/airflow_ctl/ctl/test_cli_config.py index e0278cd7c5348..188d0565f17cd 100644 --- a/airflow-ctl/tests/airflow_ctl/ctl/test_cli_config.py +++ b/airflow-ctl/tests/airflow_ctl/ctl/test_cli_config.py @@ -59,7 +59,7 @@ def no_op(): def test_args_create(): return [ ( - "--dag-id", + "dag_id", { "help": "dag_id for backfill operation", "action": None, @@ -69,7 +69,7 @@ def test_args_create(): }, ), ( - "--from-date", + "from_date", { "help": "from_date for backfill operation", "action": None, @@ -79,7 +79,7 @@ def test_args_create(): }, ), ( - "--to-date", + "to_date", { "help": "to_date for backfill operation", "action": None, @@ -158,7 +158,7 @@ def test_args_list(): def test_args_get(): return [ ( - "--backfill-id", + "backfill-id", { "help": "backfill_id for get operation in BackfillsOperations", "default": None, @@ -182,7 +182,7 @@ def test_args_get(): def test_args_delete(): return [ ( - "--backfill-id", + "backfill-id", { "help": "backfill_id for delete operation in BackfillsOperations", "default": None, @@ -275,7 +275,8 @@ def delete(self, backfill_id: str) -> ServerResponseError | None: assert arg.kwargs["action"] == test_arg[1]["action"] assert arg.kwargs["default"] == test_arg[1]["default"] assert arg.kwargs["type"] == test_arg[1]["type"] - assert arg.kwargs["dest"] == test_arg[1]["dest"] + if "dest" in test_arg[1]: + assert arg.kwargs.get("dest") == test_arg[1]["dest"] print(arg.flags) elif sub_command.name == "list": for arg, test_arg in zip(sub_command.args, test_args_list): @@ -669,3 +670,23 @@ def test_help_texts_used_for_auto_generated_commands(self, group_name, subcomman "Help message should match the help_text.yaml" ) return + + def test_positional_args(self): + """Test that required parameters are created as positional arguments.""" + command_factory = CommandFactory(file_path="") + + positional_arg = command_factory._create_arg( + arg_flags=("connection_id",), + arg_type=str, + arg_help="Connection ID", + arg_action=None, + ) + assert positional_arg.flags[0] == "connection_id" + + optional_arg = command_factory._create_arg( + arg_flags=("--description",), + arg_type=str, + arg_help="Description", + arg_action=None, + ) + assert optional_arg.flags[0] == "--description"