Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 104 additions & 20 deletions airflow-ctl/src/airflowctl/api/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

import datetime
import json
from typing import TYPE_CHECKING, Any, TypeVar, get_args
from collections.abc import Mapping
from types import UnionType
from typing import TYPE_CHECKING, Annotated, Any, TypeVar, Union, get_args, get_origin

import httpx
import structlog
Expand Down Expand Up @@ -78,7 +80,7 @@
XComResponseNative,
XComUpdateBody,
)
from airflowctl.exceptions import AirflowCtlConnectionException
from airflowctl.exceptions import AirflowCtlConnectionException, AirflowCtlValidationException

if TYPE_CHECKING:
from airflowctl.api.client import Client
Expand Down Expand Up @@ -145,6 +147,88 @@ def wrapped(self, *args, **kwargs):
}


def _field_label(field_name: str, field_info) -> str:
return field_info.alias or field_name


def _iter_model_annotations(annotation: Any):
origin = get_origin(annotation)
if origin is Annotated:
yield from _iter_model_annotations(get_args(annotation)[0])
return
if origin in (list, tuple, set, frozenset):
for arg in get_args(annotation):
yield from _iter_model_annotations(arg)
return
if origin in (UnionType, Union):
for arg in get_args(annotation):
if arg is not type(None):
yield from _iter_model_annotations(arg)
return
if isinstance(annotation, type) and issubclass(annotation, BaseModel):
yield annotation


def _missing_required_fields(value: Any, model: type[BaseModel] | None = None, prefix: str = "") -> list[str]:
if isinstance(value, BaseModel):
model = type(value)
fields_set = value.model_fields_set
missing = []
for field_name, field_info in model.model_fields.items():
label = _field_label(field_name, field_info)
field_path = f"{prefix}.{label}" if prefix else label
if field_info.is_required() and field_name not in fields_set:
missing.append(field_path)
continue
if hasattr(value, field_name):
missing.extend(_missing_required_fields(getattr(value, field_name), prefix=field_path))
return missing

if isinstance(value, Mapping) and model is not None:
missing = []
for field_name, field_info in model.model_fields.items():
label = _field_label(field_name, field_info)
field_path = f"{prefix}.{label}" if prefix else label
keys = {field_name}
if field_info.alias:
keys.add(field_info.alias)
present_key = next((key for key in keys if key in value), None)
if present_key is None:
if field_info.is_required():
missing.append(field_path)
continue
for nested_model in _iter_model_annotations(field_info.annotation):
missing.extend(_missing_required_fields(value[present_key], nested_model, field_path))
return missing

if isinstance(value, list):
missing = []
for index, item in enumerate(value):
missing.extend(_missing_required_fields(item, model, f"{prefix}[{index}]"))
return missing

return []


def validate_required_fields(
value: Any, model: type[BaseModel] | None = None, name: str | None = None
) -> None:
missing = _missing_required_fields(value, model)
if not missing:
return
field_list = ", ".join(missing)
target = name or (model.__name__ if model else type(value).__name__)
raise AirflowCtlValidationException(
f"Missing required field(s) for {target}: {field_list}. "
"Please provide the missing value(s) before sending the request."
)


def dump_body(body: BaseModel, **kwargs: Any) -> dict[str, Any]:
validate_required_fields(body)
return body.model_dump(**kwargs)


def get_field_default(annotation) -> Any:
args = get_args(annotation)
if args:
Expand Down Expand Up @@ -239,7 +323,7 @@ def login_with_username_and_password(self, login: LoginBody) -> LoginResponse |
"""Login to the API server."""
try:
return LoginResponse.model_validate_json(
self.client.post("/token/cli", json=login.model_dump(mode="json")).content
self.client.post("/token/cli", json=dump_body(login, mode="json")).content
)
except ServerResponseError as e:
raise e
Expand Down Expand Up @@ -282,7 +366,7 @@ def create_event(
if asset_event_body.extra is None:
asset_event_body.extra = {}
self.response = self.client.post(
"assets/events", json=asset_event_body.model_dump(mode="json", exclude_none=True)
"assets/events", json=dump_body(asset_event_body, mode="json", exclude_none=True)
)
return AssetEventResponse.model_validate_json(self.response.content)
except ServerResponseError as e:
Expand Down Expand Up @@ -354,7 +438,7 @@ def create(self, backfill: BackfillPostBody) -> BackfillResponse | ServerRespons
"""Create a backfill."""
try:
self.response = self.client.post(
"backfills", json=backfill.model_dump(mode="json", exclude_none=True)
"backfills", json=dump_body(backfill, mode="json", exclude_none=True)
)
return BackfillResponse.model_validate_json(self.response.content)
except ServerResponseError as e:
Expand All @@ -364,7 +448,7 @@ def create_dry_run(self, backfill: BackfillPostBody) -> BackfillResponse | Serve
"""Create a dry run backfill."""
try:
self.response = self.client.post(
"backfills/dry_run", json=backfill.model_dump(mode="json", exclude_none=True)
"backfills/dry_run", json=dump_body(backfill, mode="json", exclude_none=True)
)
return BackfillResponse.model_validate_json(self.response.content)
except ServerResponseError as e:
Expand Down Expand Up @@ -450,7 +534,7 @@ def create(
"""Create a connection."""
try:
self.response = self.client.post(
"connections", json=connection.model_dump(mode="json", by_alias=True, exclude_none=True)
"connections", json=dump_body(connection, mode="json", by_alias=True, exclude_none=True)
)
return ConnectionResponse.model_validate_json(self.response.content)
except ServerResponseError as e:
Expand All @@ -460,7 +544,7 @@ def bulk(self, connections: BulkBodyConnectionBody) -> BulkResponse | ServerResp
"""CRUD multiple connections."""
try:
self.response = self.client.patch(
"connections", json=connections.model_dump(mode="json", by_alias=True)
"connections", json=dump_body(connections, mode="json", by_alias=True)
)
return BulkResponse.model_validate_json(self.response.content)
except ServerResponseError as e:
Expand Down Expand Up @@ -490,7 +574,7 @@ def update(
try:
self.response = self.client.patch(
f"connections/{connection.connection_id}",
json=connection.model_dump(mode="json", by_alias=True),
json=dump_body(connection, mode="json", by_alias=True),
)
return ConnectionResponse.model_validate_json(self.response.content)
except ServerResponseError as e:
Expand All @@ -503,7 +587,7 @@ def test(
"""Test a connection."""
try:
self.response = self.client.post(
"connections/test", json=connection.model_dump(mode="json", by_alias=True)
"connections/test", json=dump_body(connection, mode="json", by_alias=True)
)
return ConnectionTestResponse.model_validate_json(self.response.content)
except ServerResponseError as e:
Expand Down Expand Up @@ -539,7 +623,7 @@ def list(self) -> DAGCollectionResponse | ServerResponseError:

def update(self, dag_id: str, dag_body: DAGPatchBody) -> DAGResponse | ServerResponseError:
try:
self.response = self.client.patch(f"dags/{dag_id}", json=dag_body.model_dump(mode="json"))
self.response = self.client.patch(f"dags/{dag_id}", json=dump_body(dag_body, mode="json"))
return DAGResponse.model_validate_json(self.response.content)
except ServerResponseError as e:
raise e
Expand Down Expand Up @@ -591,7 +675,7 @@ def trigger(
trigger_dag_run.conf = {}
try:
self.response = self.client.post(
f"dags/{dag_id}/dagRuns", json=trigger_dag_run.model_dump(mode="json")
f"dags/{dag_id}/dagRuns", json=dump_body(trigger_dag_run, mode="json")
)
return DAGRunResponse.model_validate_json(self.response.content)
except ServerResponseError as e:
Expand Down Expand Up @@ -685,15 +769,15 @@ def list(self) -> PoolCollectionResponse | ServerResponseError:
def create(self, pool: PoolBody) -> PoolResponse | ServerResponseError:
"""Create a pool."""
try:
self.response = self.client.post("pools", json=pool.model_dump(mode="json", exclude_none=True))
self.response = self.client.post("pools", json=dump_body(pool, mode="json", exclude_none=True))
return PoolResponse.model_validate_json(self.response.content)
except ServerResponseError as e:
raise e

def bulk(self, pools: BulkBodyPoolBody) -> BulkResponse | ServerResponseError:
"""CRUD multiple pools."""
try:
self.response = self.client.patch("pools", json=pools.model_dump(mode="json"))
self.response = self.client.patch("pools", json=dump_body(pools, mode="json"))
return BulkResponse.model_validate_json(self.response.content)
except ServerResponseError as e:
raise e
Expand All @@ -710,7 +794,7 @@ def update(self, pool_body: PoolPatchBody) -> PoolResponse | ServerResponseError
"""Update a pool."""
try:
self.response = self.client.patch(
f"pools/{pool_body.pool}", json=pool_body.model_dump(mode="json")
f"pools/{pool_body.pool}", json=dump_body(pool_body, mode="json")
)
return PoolResponse.model_validate_json(self.response.content)
except ServerResponseError as e:
Expand Down Expand Up @@ -744,7 +828,7 @@ def create(self, variable: VariableBody) -> VariableResponse | ServerResponseErr
"""Create a variable."""
try:
self.response = self.client.post(
"variables", json=variable.model_dump(mode="json", exclude_none=True)
"variables", json=dump_body(variable, mode="json", exclude_none=True)
)
return VariableResponse.model_validate_json(self.response.content)
except ServerResponseError as e:
Expand All @@ -753,7 +837,7 @@ def create(self, variable: VariableBody) -> VariableResponse | ServerResponseErr
def bulk(self, variables: BulkBodyVariableBody) -> BulkResponse | ServerResponseError:
"""CRUD multiple variables."""
try:
self.response = self.client.patch("variables", json=variables.model_dump(mode="json"))
self.response = self.client.patch("variables", json=dump_body(variables, mode="json"))
return BulkResponse.model_validate_json(self.response.content)
except ServerResponseError as e:
raise e
Expand All @@ -770,7 +854,7 @@ def update(self, variable: VariableBody) -> VariableResponse | ServerResponseErr
"""Update a variable."""
try:
self.response = self.client.patch(
f"variables/{variable.key}", json=variable.model_dump(mode="json")
f"variables/{variable.key}", json=dump_body(variable, mode="json")
)
return VariableResponse.model_validate_json(self.response.content)
except ServerResponseError as e:
Expand Down Expand Up @@ -855,7 +939,7 @@ def add(
try:
self.response = self.client.post(
f"dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries",
json=body.model_dump(mode="json", exclude_unset=True, exclude_none=True),
json=dump_body(body, mode="json", exclude_unset=True, exclude_none=True),
)
return XComResponseNative.model_validate_json(self.response.content)
except ServerResponseError as e:
Expand Down Expand Up @@ -883,7 +967,7 @@ def edit(
try:
self.response = self.client.patch(
f"dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries/{key}",
json=body.model_dump(mode="json", exclude_unset=True, exclude_none=True),
json=dump_body(body, mode="json", exclude_unset=True, exclude_none=True),
)
return XComResponseNative.model_validate_json(self.response.content)
except ServerResponseError as e:
Expand Down
2 changes: 2 additions & 0 deletions airflow-ctl/src/airflowctl/ctl/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
AirflowCtlCredentialNotFoundException,
AirflowCtlKeyringException,
AirflowCtlNotFoundException,
AirflowCtlValidationException,
)
from airflowctl.utils.module_loading import import_string

Expand Down Expand Up @@ -80,6 +81,7 @@ def safe_call_command(function: Callable, args: Iterable[Arg]) -> None:
AirflowCtlConnectionException,
AirflowCtlKeyringException,
AirflowCtlNotFoundException,
AirflowCtlValidationException,
) as e:
rich.print(f"command failed due to {e}")
sys.exit(1)
Expand Down
30 changes: 16 additions & 14 deletions airflow-ctl/src/airflowctl/ctl/commands/connection_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
BulkCreateActionConnectionBody,
ConnectionBody,
)
from airflowctl.api.operations import validate_required_fields


@provide_api_client(kind=ClientKind.CLI)
Expand All @@ -46,20 +47,21 @@ def import_(args, api_client=NEW_API_CLIENT) -> None:
except Exception as e:
raise SystemExit(f"Error reading connections file {args.file}: {e}")
try:
connections_data = {
k: ConnectionBody(
connection_id=k,
conn_type=v.get("conn_type"),
host=v.get("host"),
login=v.get("login"),
password=v.get("password"),
port=v.get("port"),
extra=v.get("extra"),
description=v.get("description", ""),
**({"schema": v["schema"]} if "schema" in v else {}),
)
for k, v in connections_json.items()
}
connections_data = {}
for connection_id, connection_config in connections_json.items():
connection_data = {
"connection_id": connection_id,
**({"conn_type": connection_config["conn_type"]} if "conn_type" in connection_config else {}),
"host": connection_config.get("host"),
"login": connection_config.get("login"),
"password": connection_config.get("password"),
"port": connection_config.get("port"),
"extra": connection_config.get("extra"),
"description": connection_config.get("description", ""),
**({"schema": connection_config["schema"]} if "schema" in connection_config else {}),
}
validate_required_fields(connection_data, ConnectionBody, f"connection {connection_id!r}")
connections_data[connection_id] = ConnectionBody(**connection_data)
connection_create_action = BulkCreateActionConnectionBody(
action="create",
entities=list(connections_data.values()),
Expand Down
15 changes: 5 additions & 10 deletions airflow-ctl/src/airflowctl/ctl/commands/pool_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
BulkCreateActionPoolBody,
PoolBody,
)
from airflowctl.api.operations import validate_required_fields
from airflowctl.ctl.console_formatting import AirflowConsole


Expand Down Expand Up @@ -95,18 +96,12 @@ def _import_helper(api_client: Client, filepath: Path, action_on_existence: Bulk
raise SystemExit("Invalid format: Expected a list of pool objects")

pools_to_update = []
for pool_config in pools_json:
if not isinstance(pool_config, dict) or "name" not in pool_config or "slots" not in pool_config:
for index, pool_config in enumerate(pools_json):
if not isinstance(pool_config, dict):
raise SystemExit(f"Invalid pool configuration: {pool_config}")
validate_required_fields(pool_config, PoolBody, f"pool at index {index}")

pools_to_update.append(
PoolBody(
name=pool_config["name"],
slots=pool_config["slots"],
description=pool_config.get("description", ""),
include_deferred=pool_config.get("include_deferred", False),
)
)
pools_to_update.append(PoolBody.model_validate(pool_config))

bulk_body = BulkBodyPoolBody(
actions=[
Expand Down
15 changes: 4 additions & 11 deletions airflow-ctl/src/airflowctl/ctl/commands/variable_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
BulkCreateActionVariableBody,
VariableBody,
)
from airflowctl.api.operations import validate_required_fields


@provide_api_client(kind=ClientKind.CLI)
Expand All @@ -50,17 +51,9 @@ def import_(args, api_client=NEW_API_CLIENT) -> list[str]:
action_on_existence = BulkActionOnExistence(args.action_on_existing_key)
vars_to_update = []
for k, v in var_json.items():
value, description = v, None
if isinstance(v, dict) and "value" in v:
value, description = v["value"], v.get("description")

vars_to_update.append(
VariableBody(
key=k,
value=value,
description=description,
)
)
variable_data = {"key": k, **v} if isinstance(v, dict) else {"key": k, "value": v}
validate_required_fields(variable_data, VariableBody, f"variable {k!r}")
vars_to_update.append(VariableBody.model_validate(variable_data))

bulk_body = BulkBodyVariableBody(
actions=[
Expand Down
4 changes: 4 additions & 0 deletions airflow-ctl/src/airflowctl/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,7 @@ class AirflowCtlConnectionException(AirflowCtlException):

class AirflowCtlKeyringException(AirflowCtlException):
"""Raise when a keyring error occurs while performing an operation."""


class AirflowCtlValidationException(AirflowCtlException):
"""Raise when user-provided data fails client-side validation."""
Loading
Loading