def _field_label(field_name: str, field_info) -> str:
    """Return the name shown to the user for a field: its alias when one is set, else ``field_name``."""
    return field_info.alias or field_name


def _join_path(prefix: str, label: str) -> str:
    """Append ``label`` to a dotted ``prefix``; an empty prefix yields the bare label."""
    return f"{prefix}.{label}" if prefix else label


def _iter_model_annotations(annotation: Any):
    """
    Yield every pydantic model class referenced by ``annotation``.

    ``Annotated[...]``, the ``list``/``tuple``/``set``/``frozenset`` generics and unions are
    unwrapped recursively (``None`` members of a union are skipped). Any other annotation
    yields nothing.
    """
    origin = get_origin(annotation)
    if origin is Annotated:
        # Only the first argument is the real type; the rest is metadata.
        yield from _iter_model_annotations(get_args(annotation)[0])
    elif origin in (list, tuple, set, frozenset):
        for arg in get_args(annotation):
            yield from _iter_model_annotations(arg)
    elif origin in (UnionType, Union):
        for arg in get_args(annotation):
            if arg is not type(None):
                yield from _iter_model_annotations(arg)
    elif origin is None and isinstance(annotation, type) and issubclass(annotation, BaseModel):
        # ``origin is None`` keeps parametrized generics we do not unwrap (e.g. ``dict[str, X]``)
        # away from ``issubclass``: on older Python versions they can pass ``isinstance(..., type)``
        # and ``issubclass`` then raises ``TypeError``.
        yield annotation


def _missing_in_instance(value: BaseModel, prefix: str) -> list[str]:
    """Collect missing required fields of a model instance, descending into its set attributes."""
    fields_set = value.model_fields_set
    missing: list[str] = []
    for field_name, field_info in type(value).model_fields.items():
        field_path = _join_path(prefix, _field_label(field_name, field_info))
        if field_info.is_required() and field_name not in fields_set:
            missing.append(field_path)
            continue
        if hasattr(value, field_name):
            missing.extend(_missing_required_fields(getattr(value, field_name), prefix=field_path))
    return missing


def _missing_for_any_model(value: Any, models: list[type[BaseModel]], prefix: str) -> list[str]:
    """
    Check ``value`` against alternative ``models`` and return the smallest set of missing paths.

    A value typed as a union only has to satisfy one member, so reporting the misses of every
    member would flag valid payloads. Lists and tuples are resolved element by element so that
    each element may match a different member.
    """
    if not models:
        return []
    if isinstance(value, (list, tuple)):
        return [
            path
            for index, item in enumerate(value)
            for path in _missing_for_any_model(item, models, f"{prefix}[{index}]")
        ]
    return min((_missing_required_fields(value, model, prefix) for model in models), key=len)


def _missing_in_mapping(value: Mapping, model: type[BaseModel], prefix: str) -> list[str]:
    """Collect required fields of ``model`` that are absent from the raw mapping ``value``."""
    missing: list[str] = []
    for field_name, field_info in model.model_fields.items():
        field_path = _join_path(prefix, _field_label(field_name, field_info))
        # Probe the alias first, then the field name, in a fixed order so the result is deterministic.
        candidate_keys = (field_info.alias, field_name) if field_info.alias else (field_name,)
        present_key = next((key for key in candidate_keys if key in value), None)
        if present_key is None:
            if field_info.is_required():
                missing.append(field_path)
            continue
        nested_models = list(_iter_model_annotations(field_info.annotation))
        missing.extend(_missing_for_any_model(value[present_key], nested_models, field_path))
    return missing


def _missing_required_fields(value: Any, model: type[BaseModel] | None = None, prefix: str = "") -> list[str]:
    """
    Return the paths of required fields that ``value`` does not provide.

    :param value: a model instance, a mapping (checked only when ``model`` is given), or a
        list/tuple of those; anything else yields an empty list
    :param model: model class describing ``value`` when it is a raw mapping; ignored for model
        instances, whose own class is used
    :param prefix: path of ``value`` inside the enclosing payload, prepended to reported fields
    :return: paths such as ``actions[0].entities[0].slots``, in field-definition order
    """
    if isinstance(value, BaseModel):
        return _missing_in_instance(value, prefix)
    if isinstance(value, Mapping) and model is not None:
        return _missing_in_mapping(value, model, prefix)
    if isinstance(value, (list, tuple)):
        return [
            path
            for index, item in enumerate(value)
            for path in _missing_required_fields(item, model, f"{prefix}[{index}]")
        ]
    return []


def validate_required_fields(
    value: Any, model: type[BaseModel] | None = None, name: str | None = None
) -> None:
    """
    Raise when ``value`` lacks a field that its model marks as required.

    :param value: model instance, raw mapping, or list/tuple of those
    :param model: model class to check a raw mapping against
    :param name: label used in the error message; defaults to the model's class name, or to the
        type name of ``value`` when no model is given
    :raises AirflowCtlValidationException: listing every missing field path
    """
    # dict.fromkeys drops repeated paths while keeping first-seen order.
    missing = list(dict.fromkeys(_missing_required_fields(value, model)))
    if not missing:
        return
    field_list = ", ".join(missing)
    target = name or (model.__name__ if model else type(value).__name__)
    raise AirflowCtlValidationException(
        f"Missing required field(s) for {target}: {field_list}. "
        "Please provide the missing value(s) before sending the request."
    )


def dump_body(body: BaseModel, **kwargs: Any) -> dict[str, Any]:
    """
    Validate required fields of ``body`` and serialize it.

    :param body: request model to send
    :param kwargs: forwarded unchanged to ``body.model_dump``
    :return: the result of ``body.model_dump(**kwargs)``
    :raises AirflowCtlValidationException: when a required field is missing
    """
    validate_required_fields(body)
    return body.model_dump(**kwargs)
json=connection.model_dump(mode="json", by_alias=True, exclude_none=True) + "connections", json=dump_body(connection, mode="json", by_alias=True, exclude_none=True) ) return ConnectionResponse.model_validate_json(self.response.content) except ServerResponseError as e: @@ -460,7 +544,7 @@ def bulk(self, connections: BulkBodyConnectionBody) -> BulkResponse | ServerResp """CRUD multiple connections.""" try: self.response = self.client.patch( - "connections", json=connections.model_dump(mode="json", by_alias=True) + "connections", json=dump_body(connections, mode="json", by_alias=True) ) return BulkResponse.model_validate_json(self.response.content) except ServerResponseError as e: @@ -490,7 +574,7 @@ def update( try: self.response = self.client.patch( f"connections/{connection.connection_id}", - json=connection.model_dump(mode="json", by_alias=True), + json=dump_body(connection, mode="json", by_alias=True), ) return ConnectionResponse.model_validate_json(self.response.content) except ServerResponseError as e: @@ -503,7 +587,7 @@ def test( """Test a connection.""" try: self.response = self.client.post( - "connections/test", json=connection.model_dump(mode="json", by_alias=True) + "connections/test", json=dump_body(connection, mode="json", by_alias=True) ) return ConnectionTestResponse.model_validate_json(self.response.content) except ServerResponseError as e: @@ -539,7 +623,7 @@ def list(self) -> DAGCollectionResponse | ServerResponseError: def update(self, dag_id: str, dag_body: DAGPatchBody) -> DAGResponse | ServerResponseError: try: - self.response = self.client.patch(f"dags/{dag_id}", json=dag_body.model_dump(mode="json")) + self.response = self.client.patch(f"dags/{dag_id}", json=dump_body(dag_body, mode="json")) return DAGResponse.model_validate_json(self.response.content) except ServerResponseError as e: raise e @@ -591,7 +675,7 @@ def trigger( trigger_dag_run.conf = {} try: self.response = self.client.post( - f"dags/{dag_id}/dagRuns", 
json=trigger_dag_run.model_dump(mode="json") + f"dags/{dag_id}/dagRuns", json=dump_body(trigger_dag_run, mode="json") ) return DAGRunResponse.model_validate_json(self.response.content) except ServerResponseError as e: @@ -685,7 +769,7 @@ def list(self) -> PoolCollectionResponse | ServerResponseError: def create(self, pool: PoolBody) -> PoolResponse | ServerResponseError: """Create a pool.""" try: - self.response = self.client.post("pools", json=pool.model_dump(mode="json", exclude_none=True)) + self.response = self.client.post("pools", json=dump_body(pool, mode="json", exclude_none=True)) return PoolResponse.model_validate_json(self.response.content) except ServerResponseError as e: raise e @@ -693,7 +777,7 @@ def create(self, pool: PoolBody) -> PoolResponse | ServerResponseError: def bulk(self, pools: BulkBodyPoolBody) -> BulkResponse | ServerResponseError: """CRUD multiple pools.""" try: - self.response = self.client.patch("pools", json=pools.model_dump(mode="json")) + self.response = self.client.patch("pools", json=dump_body(pools, mode="json")) return BulkResponse.model_validate_json(self.response.content) except ServerResponseError as e: raise e @@ -710,7 +794,7 @@ def update(self, pool_body: PoolPatchBody) -> PoolResponse | ServerResponseError """Update a pool.""" try: self.response = self.client.patch( - f"pools/{pool_body.pool}", json=pool_body.model_dump(mode="json") + f"pools/{pool_body.pool}", json=dump_body(pool_body, mode="json") ) return PoolResponse.model_validate_json(self.response.content) except ServerResponseError as e: @@ -744,7 +828,7 @@ def create(self, variable: VariableBody) -> VariableResponse | ServerResponseErr """Create a variable.""" try: self.response = self.client.post( - "variables", json=variable.model_dump(mode="json", exclude_none=True) + "variables", json=dump_body(variable, mode="json", exclude_none=True) ) return VariableResponse.model_validate_json(self.response.content) except ServerResponseError as e: @@ -753,7 +837,7 @@ def 
create(self, variable: VariableBody) -> VariableResponse | ServerResponseErr def bulk(self, variables: BulkBodyVariableBody) -> BulkResponse | ServerResponseError: """CRUD multiple variables.""" try: - self.response = self.client.patch("variables", json=variables.model_dump(mode="json")) + self.response = self.client.patch("variables", json=dump_body(variables, mode="json")) return BulkResponse.model_validate_json(self.response.content) except ServerResponseError as e: raise e @@ -770,7 +854,7 @@ def update(self, variable: VariableBody) -> VariableResponse | ServerResponseErr """Update a variable.""" try: self.response = self.client.patch( - f"variables/{variable.key}", json=variable.model_dump(mode="json") + f"variables/{variable.key}", json=dump_body(variable, mode="json") ) return VariableResponse.model_validate_json(self.response.content) except ServerResponseError as e: @@ -855,7 +939,7 @@ def add( try: self.response = self.client.post( f"dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries", - json=body.model_dump(mode="json", exclude_unset=True, exclude_none=True), + json=dump_body(body, mode="json", exclude_unset=True, exclude_none=True), ) return XComResponseNative.model_validate_json(self.response.content) except ServerResponseError as e: @@ -883,7 +967,7 @@ def edit( try: self.response = self.client.patch( f"dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries/{key}", - json=body.model_dump(mode="json", exclude_unset=True, exclude_none=True), + json=dump_body(body, mode="json", exclude_unset=True, exclude_none=True), ) return XComResponseNative.model_validate_json(self.response.content) except ServerResponseError as e: diff --git a/airflow-ctl/src/airflowctl/ctl/cli_config.py b/airflow-ctl/src/airflowctl/ctl/cli_config.py index 11ff4542e01ef..8836036823022 100755 --- a/airflow-ctl/src/airflowctl/ctl/cli_config.py +++ b/airflow-ctl/src/airflowctl/ctl/cli_config.py @@ -46,6 +46,7 @@ AirflowCtlCredentialNotFoundException, 
AirflowCtlKeyringException, AirflowCtlNotFoundException, + AirflowCtlValidationException, ) from airflowctl.utils.module_loading import import_string @@ -80,6 +81,7 @@ def safe_call_command(function: Callable, args: Iterable[Arg]) -> None: AirflowCtlConnectionException, AirflowCtlKeyringException, AirflowCtlNotFoundException, + AirflowCtlValidationException, ) as e: rich.print(f"command failed due to {e}") sys.exit(1) diff --git a/airflow-ctl/src/airflowctl/ctl/commands/connection_command.py b/airflow-ctl/src/airflowctl/ctl/commands/connection_command.py index 02958740a3693..e8533d5cc1b4d 100644 --- a/airflow-ctl/src/airflowctl/ctl/commands/connection_command.py +++ b/airflow-ctl/src/airflowctl/ctl/commands/connection_command.py @@ -29,6 +29,7 @@ BulkCreateActionConnectionBody, ConnectionBody, ) +from airflowctl.api.operations import validate_required_fields @provide_api_client(kind=ClientKind.CLI) @@ -46,20 +47,21 @@ def import_(args, api_client=NEW_API_CLIENT) -> None: except Exception as e: raise SystemExit(f"Error reading connections file {args.file}: {e}") try: - connections_data = { - k: ConnectionBody( - connection_id=k, - conn_type=v.get("conn_type"), - host=v.get("host"), - login=v.get("login"), - password=v.get("password"), - port=v.get("port"), - extra=v.get("extra"), - description=v.get("description", ""), - **({"schema": v["schema"]} if "schema" in v else {}), - ) - for k, v in connections_json.items() - } + connections_data = {} + for connection_id, connection_config in connections_json.items(): + connection_data = { + "connection_id": connection_id, + **({"conn_type": connection_config["conn_type"]} if "conn_type" in connection_config else {}), + "host": connection_config.get("host"), + "login": connection_config.get("login"), + "password": connection_config.get("password"), + "port": connection_config.get("port"), + "extra": connection_config.get("extra"), + "description": connection_config.get("description", ""), + **({"schema": 
connection_config["schema"]} if "schema" in connection_config else {}), + } + validate_required_fields(connection_data, ConnectionBody, f"connection {connection_id!r}") + connections_data[connection_id] = ConnectionBody(**connection_data) connection_create_action = BulkCreateActionConnectionBody( action="create", entities=list(connections_data.values()), diff --git a/airflow-ctl/src/airflowctl/ctl/commands/pool_command.py b/airflow-ctl/src/airflowctl/ctl/commands/pool_command.py index 08e56eed87b0b..b35885ef164a2 100644 --- a/airflow-ctl/src/airflowctl/ctl/commands/pool_command.py +++ b/airflow-ctl/src/airflowctl/ctl/commands/pool_command.py @@ -31,6 +31,7 @@ BulkCreateActionPoolBody, PoolBody, ) +from airflowctl.api.operations import validate_required_fields from airflowctl.ctl.console_formatting import AirflowConsole @@ -95,18 +96,12 @@ def _import_helper(api_client: Client, filepath: Path, action_on_existence: Bulk raise SystemExit("Invalid format: Expected a list of pool objects") pools_to_update = [] - for pool_config in pools_json: - if not isinstance(pool_config, dict) or "name" not in pool_config or "slots" not in pool_config: + for index, pool_config in enumerate(pools_json): + if not isinstance(pool_config, dict): raise SystemExit(f"Invalid pool configuration: {pool_config}") + validate_required_fields(pool_config, PoolBody, f"pool at index {index}") - pools_to_update.append( - PoolBody( - name=pool_config["name"], - slots=pool_config["slots"], - description=pool_config.get("description", ""), - include_deferred=pool_config.get("include_deferred", False), - ) - ) + pools_to_update.append(PoolBody.model_validate(pool_config)) bulk_body = BulkBodyPoolBody( actions=[ diff --git a/airflow-ctl/src/airflowctl/ctl/commands/variable_command.py b/airflow-ctl/src/airflowctl/ctl/commands/variable_command.py index 88bf33a0f0197..88d3334ad2878 100644 --- a/airflow-ctl/src/airflowctl/ctl/commands/variable_command.py +++ 
b/airflow-ctl/src/airflowctl/ctl/commands/variable_command.py @@ -29,6 +29,7 @@ BulkCreateActionVariableBody, VariableBody, ) +from airflowctl.api.operations import validate_required_fields @provide_api_client(kind=ClientKind.CLI) @@ -50,17 +51,9 @@ def import_(args, api_client=NEW_API_CLIENT) -> list[str]: action_on_existence = BulkActionOnExistence(args.action_on_existing_key) vars_to_update = [] for k, v in var_json.items(): - value, description = v, None - if isinstance(v, dict) and "value" in v: - value, description = v["value"], v.get("description") - - vars_to_update.append( - VariableBody( - key=k, - value=value, - description=description, - ) - ) + variable_data = {"key": k, **v} if isinstance(v, dict) else {"key": k, "value": v} + validate_required_fields(variable_data, VariableBody, f"variable {k!r}") + vars_to_update.append(VariableBody.model_validate(variable_data)) bulk_body = BulkBodyVariableBody( actions=[ diff --git a/airflow-ctl/src/airflowctl/exceptions.py b/airflow-ctl/src/airflowctl/exceptions.py index 0af8b9fa593cc..1635f4ba97660 100644 --- a/airflow-ctl/src/airflowctl/exceptions.py +++ b/airflow-ctl/src/airflowctl/exceptions.py @@ -44,3 +44,7 @@ class AirflowCtlConnectionException(AirflowCtlException): class AirflowCtlKeyringException(AirflowCtlException): """Raise when a keyring error occurs while performing an operation.""" + + +class AirflowCtlValidationException(AirflowCtlException): + """Raise when user-provided data fails client-side validation.""" diff --git a/airflow-ctl/tests/airflow_ctl/api/test_operations.py b/airflow-ctl/tests/airflow_ctl/api/test_operations.py index 52faecee73ea0..30763926aa5e3 100644 --- a/airflow-ctl/tests/airflow_ctl/api/test_operations.py +++ b/airflow-ctl/tests/airflow_ctl/api/test_operations.py @@ -101,7 +101,7 @@ XComResponseNative, ) from airflowctl.api.operations import BaseOperations -from airflowctl.exceptions import AirflowCtlConnectionException +from airflowctl.exceptions import 
AirflowCtlConnectionException, AirflowCtlValidationException if TYPE_CHECKING: from pydantic import NonNegativeInt @@ -1404,6 +1404,19 @@ def handle_request(request: httpx.Request) -> httpx.Response: response = client.pools.create(pool=self.pool) assert response == self.pool_response + def test_create_rejects_missing_required_field_before_request(self): + def handle_request(request: httpx.Request) -> httpx.Response: + raise AssertionError("request should not be sent") + + client = make_api_client(transport=httpx.MockTransport(handle_request)) + pool = PoolBody.model_construct(name=self.pool_name) + + with pytest.raises( + AirflowCtlValidationException, + match="Missing required field\\(s\\) for PoolBody: slots", + ): + client.pools.create(pool=pool) + def test_bulk(self): def handle_request(request: httpx.Request) -> httpx.Response: assert request.url.path == "/api/v2/pools" @@ -1413,6 +1426,28 @@ def handle_request(request: httpx.Request) -> httpx.Response: response = client.pools.bulk(pools=self.pools_bulk_body) assert response == self.pool_bulk_response + def test_bulk_rejects_nested_missing_required_field_before_request(self): + def handle_request(request: httpx.Request) -> httpx.Response: + raise AssertionError("request should not be sent") + + client = make_api_client(transport=httpx.MockTransport(handle_request)) + pool = PoolBody.model_construct(name=self.pool_name) + pools = BulkBodyPoolBody.model_construct( + actions=[ + BulkCreateActionPoolBody.model_construct( + action="create", + entities=[pool], + action_on_existence=BulkActionOnExistence.FAIL, + ) + ] + ) + + with pytest.raises( + AirflowCtlValidationException, + match=r"actions\[0\]\.entities\[0\]\.slots", + ): + client.pools.bulk(pools=pools) + def test_delete(self): def handle_request(request: httpx.Request) -> httpx.Response: assert request.url.path == f"/api/v2/pools/{self.pool_name}" diff --git a/airflow-ctl/tests/airflow_ctl/ctl/commands/test_connections_command.py 
b/airflow-ctl/tests/airflow_ctl/ctl/commands/test_connections_command.py index ba803ddd8e0cd..d0fff9b9ea60a 100644 --- a/airflow-ctl/tests/airflow_ctl/ctl/commands/test_connections_command.py +++ b/airflow-ctl/tests/airflow_ctl/ctl/commands/test_connections_command.py @@ -132,6 +132,30 @@ def test_import_error(self, api_client_maker, tmp_path, monkeypatch): ) assert exc_info.value.code == 1 + def test_import_missing_required_conn_type(self, api_client_maker, tmp_path, monkeypatch, capsys): + api_client = api_client_maker( + path="/api/v2/connections", + response_json=self.bulk_response_success.model_dump(), + expected_http_status_code=200, + kind=ClientKind.CLI, + ) + + monkeypatch.chdir(tmp_path) + json_path = tmp_path / self.export_file_name + json_path.write_text(json.dumps({self.connection_id: {"host": "test_host"}})) + + with pytest.raises(SystemExit) as exc_info: + connection_command.import_( + self.parser.parse_args(["connections", "import", json_path.as_posix()]), + api_client=api_client, + ) + + assert exc_info.value.code == 1 + output = capsys.readouterr().out + assert "Missing required field(s)" in output + assert "test_connection" in output + assert "conn_type" in output + def test_import_without_extra_field(self, api_client_maker, tmp_path, monkeypatch): """Import succeeds when JSON omits the ``extra`` field (#62653). 
diff --git a/airflow-ctl/tests/airflow_ctl/ctl/commands/test_pool_command.py b/airflow-ctl/tests/airflow_ctl/ctl/commands/test_pool_command.py index 0bc2438929454..6d06a7608f2ee 100644 --- a/airflow-ctl/tests/airflow_ctl/ctl/commands/test_pool_command.py +++ b/airflow-ctl/tests/airflow_ctl/ctl/commands/test_pool_command.py @@ -30,6 +30,7 @@ BulkCreateActionPoolBody, ) from airflowctl.ctl.commands import pool_command +from airflowctl.exceptions import AirflowCtlValidationException @pytest.fixture @@ -60,10 +61,23 @@ def test_import_invalid_json(self, mock_client, tmp_path): def test_import_invalid_pool_config(self, mock_client, tmp_path): """Test import with invalid pool configuration.""" invalid_pool = tmp_path / "invalid_pool.json" - invalid_pool.write_text(json.dumps([{"invalid": "config"}])) - with pytest.raises(SystemExit, match="Invalid pool configuration: {'invalid': 'config'}"): + invalid_pool.write_text(json.dumps(["invalid"])) + with pytest.raises(SystemExit, match="Invalid pool configuration: invalid"): pool_command.import_(mock.MagicMock(file=invalid_pool, action_on_existing_key="fail")) + def test_import_missing_required_pool_field(self, mock_client, tmp_path): + """Test import with a missing field required by the API schema.""" + invalid_pool = tmp_path / "invalid_pool.json" + invalid_pool.write_text(json.dumps([{"name": "test_pool"}])) + + with pytest.raises( + AirflowCtlValidationException, + match="Missing required field\\(s\\) for pool at index 0: slots", + ): + pool_command.import_(mock.MagicMock(file=invalid_pool, action_on_existing_key="fail")) + + mock_client.pools.bulk.assert_not_called() + def test_import_success(self, mock_client, tmp_path, capsys): """Test successful pool import.""" pools_file = tmp_path / "pools.json" diff --git a/airflow-ctl/tests/airflow_ctl/ctl/commands/test_variable_command.py b/airflow-ctl/tests/airflow_ctl/ctl/commands/test_variable_command.py index a0598d03459f8..90a50ddc0cc5c 100644 --- 
a/airflow-ctl/tests/airflow_ctl/ctl/commands/test_variable_command.py +++ b/airflow-ctl/tests/airflow_ctl/ctl/commands/test_variable_command.py @@ -29,6 +29,7 @@ ) from airflowctl.ctl import cli_parser from airflowctl.ctl.commands import variable_command +from airflowctl.exceptions import AirflowCtlValidationException class TestCliVariableCommands: @@ -134,3 +135,24 @@ def test_import_error(self, api_client_maker, tmp_path, monkeypatch): self.parser.parse_args(["variables", "import", expected_json_path.as_posix()]), api_client=api_client, ) + + def test_import_missing_required_value(self, api_client_maker, tmp_path, monkeypatch): + api_client = api_client_maker( + path="/api/v2/variables", + response_json=self.bulk_response_success.model_dump(), + expected_http_status_code=200, + kind=ClientKind.CLI, + ) + + monkeypatch.chdir(tmp_path) + variable_path = tmp_path / self.export_file_name + variable_path.write_text(json.dumps({self.key: {"description": "missing value"}})) + + with pytest.raises( + AirflowCtlValidationException, + match="Missing required field\\(s\\) for variable 'key': value", + ): + variable_command.import_( + self.parser.parse_args(["variables", "import", variable_path.as_posix()]), + api_client=api_client, + )