Skip to content

Commit

Permalink
Refactor GKECreateClusterOperator's body validation (#31923)
Browse files Browse the repository at this point in the history
  • Loading branch information
moiseenkov committed Jun 29, 2023
1 parent 4a525e8 commit f3f69bf
Showing 1 changed file with 60 additions and 30 deletions.
90 changes: 60 additions & 30 deletions airflow/providers/google/cloud/operators/kubernetes_engine.py
Expand Up @@ -20,7 +20,7 @@

import warnings
from functools import cached_property
from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Any, Sequence

from google.api_core.exceptions import AlreadyExists
from google.cloud.container_v1.types import Cluster
Expand Down Expand Up @@ -268,41 +268,71 @@ def __init__(
self.impersonation_chain = impersonation_chain
self.poll_interval = poll_interval
self.deferrable = deferrable
self._check_input()
self._validate_input()

self._hook: GKEHook | None = None

def _check_input(self) -> None:
if (
not all([self.project_id, self.location, self.body])
or (isinstance(self.body, dict) and "name" not in self.body)
or (
isinstance(self.body, dict)
and ("initial_node_count" not in self.body and "node_pools" not in self.body)
)
or (not (isinstance(self.body, dict)) and not (getattr(self.body, "name", None)))
or (
not (isinstance(self.body, dict))
and (
not (getattr(self.body, "initial_node_count", None))
and not (getattr(self.body, "node_pools", None))
def _validate_input(self) -> None:
"""Primary validation of the input body."""
self._alert_deprecated_body_fields()

error_messages: list[str] = []
if not self._body_field("name"):
error_messages.append("Field body['name'] is missing or incorrect")

if self._body_field("initial_node_count"):
if self._body_field("node_pools"):
error_messages.append(
"Do not use filed body['initial_node_count'] and body['node_pools'] at the same time."
)
)
):
self.log.error(
"One of (project_id, location, body, body['name'], "
"body['initial_node_count']), body['node_pools'] is missing or incorrect"
)
raise AirflowException("Operator has incorrect or missing input.")
elif (
isinstance(self.body, dict) and ("initial_node_count" in self.body and "node_pools" in self.body)
) or (
not (isinstance(self.body, dict))
and (getattr(self.body, "initial_node_count", None) and getattr(self.body, "node_pools", None))
):
self.log.error("Only one of body['initial_node_count']) and body['node_pools'] may be specified")

if self._body_field("node_config"):
if self._body_field("node_pools"):
error_messages.append(
"Do not use filed body['node_config'] and body['node_pools'] at the same time."
)

if self._body_field("node_pools"):
if any([self._body_field("node_config"), self._body_field("initial_node_count")]):
error_messages.append(
"The field body['node_pools'] should not be set if "
"body['node_config'] or body['initial_code_count'] are specified."
)

if not any([self._body_field("node_config"), self._body_field("initial_node_count")]):
if not self._body_field("node_pools"):
error_messages.append(
"Field body['node_pools'] is required if none of fields "
"body['initial_node_count'] or body['node_pools'] are specified."
)

for message in error_messages:
self.log.error(message)

if error_messages:
raise AirflowException("Operator has incorrect or missing input.")

def _body_field(self, field_name: str, default_value: Any = None) -> Any:
"""Extracts the value of the given field name."""
if isinstance(self.body, dict):
return self.body.get(field_name, default_value)
else:
return getattr(self.body, field_name, default_value)

def _alert_deprecated_body_fields(self) -> None:
"""Generates warning messages if deprecated fields were used in the body."""
deprecated_body_fields_with_replacement = [
("initial_node_count", "node_pool.initial_node_count"),
("node_config", "node_pool.config"),
("zone", "location"),
("instance_group_urls", "node_pools.instance_group_urls"),
]
for deprecated_field, replacement in deprecated_body_fields_with_replacement:
if self._body_field(deprecated_field):
warnings.warn(
f"The body field '{deprecated_field}' is deprecated. Use '{replacement}' instead."
)

def execute(self, context: Context) -> str:
hook = self._get_hook()
try:
Expand Down

0 comments on commit f3f69bf

Please sign in to comment.