Skip to content

Commit

Permalink
Support regional GKE cluster (#18966)
Browse files Browse the repository at this point in the history
  • Loading branch information
mrkm4ntr committed Dec 22, 2021
1 parent 658f406 commit a4622e1
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 29 deletions.
22 changes: 15 additions & 7 deletions airflow/providers/google/cloud/hooks/kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def get_operation(self, operation_name: str, project_id: Optional[str] = None) -
:return: The new, updated operation from Google Cloud
"""
return self.get_conn().get_operation(
project_id=project_id or self.project_id, zone=self.location, operation_id=operation_name
name=f'projects/{project_id or self.project_id}'
+ f'/locations/{self.location}/operations/{operation_name}'
)

@staticmethod
Expand Down Expand Up @@ -170,11 +171,13 @@ def delete_cluster(
:type timeout: float
:return: The full url to the delete operation if successful, else None
"""
self.log.info("Deleting (project_id=%s, zone=%s, cluster_id=%s)", project_id, self.location, name)
self.log.info("Deleting (project_id=%s, location=%s, cluster_id=%s)", project_id, self.location, name)

try:
resource = self.get_conn().delete_cluster(
project_id=project_id, zone=self.location, cluster_id=name, retry=retry, timeout=timeout
name=f'projects/{project_id}/locations/{self.location}/clusters/{name}',
retry=retry,
timeout=timeout,
)
resource = self.wait_for_operation(resource)
# Returns server-defined url for the resource
Expand Down Expand Up @@ -223,11 +226,14 @@ def create_cluster(
self._append_label(cluster, 'airflow-version', 'v' + version.version)

self.log.info(
"Creating (project_id=%s, zone=%s, cluster_name=%s)", project_id, self.location, cluster.name
"Creating (project_id=%s, location=%s, cluster_name=%s)", project_id, self.location, cluster.name
)
try:
resource = self.get_conn().create_cluster(
project_id=project_id, zone=self.location, cluster=cluster, retry=retry, timeout=timeout
parent=f'projects/{project_id}/locations/{self.location}',
cluster=cluster,
retry=retry,
timeout=timeout,
)
resource = self.wait_for_operation(resource)

Expand Down Expand Up @@ -261,7 +267,7 @@ def get_cluster(
:return: google.cloud.container_v1.types.Cluster
"""
self.log.info(
"Fetching cluster (project_id=%s, zone=%s, cluster_name=%s)",
"Fetching cluster (project_id=%s, location=%s, cluster_name=%s)",
project_id or self.project_id,
self.location,
name,
Expand All @@ -270,7 +276,9 @@ def get_cluster(
return (
self.get_conn()
.get_cluster(
project_id=project_id, zone=self.location, cluster_id=name, retry=retry, timeout=timeout
name=f'projects/{project_id}/locations/{self.location}/clusters/{name}',
retry=retry,
timeout=timeout,
)
.self_link
)
18 changes: 13 additions & 5 deletions airflow/providers/google/cloud/operators/kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class GKEDeleteClusterOperator(BaseOperator):
:type project_id: str
:param name: The name of the resource to delete, in this case cluster name
:type name: str
:param location: The name of the Google Compute Engine zone in which the cluster
:param location: The name of the Google Compute Engine zone or region in which the cluster
resides.
:type location: str
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
Expand Down Expand Up @@ -158,7 +158,7 @@ class GKECreateClusterOperator(BaseOperator):
:param project_id: The Google Developers Console [project ID or project number]
:type project_id: str
:param location: The name of the Google Compute Engine zone in which the cluster
:param location: The name of the Google Compute Engine or region in which the cluster
resides.
:type location: str
:param body: The Cluster definition to create, can be protobuf or python dict, if
Expand Down Expand Up @@ -273,13 +273,14 @@ class GKEStartPodOperator(KubernetesPodOperator):
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:GKEStartPodOperator`
:param location: The name of the Google Kubernetes Engine zone in which the
:param location: The name of the Google Kubernetes Engine zone or region in which the
cluster resides, e.g. 'us-central1-a'
:type location: str
:param cluster_name: The name of the Google Kubernetes Engine cluster the pod
should be spawned in
:type cluster_name: str
:param use_internal_ip: Use the internal IP address as the endpoint.
:type use_internal_ip: bool
:param project_id: The Google Developers Console project id
:type project_id: str
:param gcp_conn_id: The google cloud connection id to use. This allows for
Expand All @@ -294,6 +295,8 @@ class GKEStartPodOperator(KubernetesPodOperator):
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:type impersonation_chain: Union[str, Sequence[str]]
:param regional: The location param is region name.
:type regional: bool
"""

template_fields = {'project_id', 'location', 'cluster_name'} | set(KubernetesPodOperator.template_fields)
Expand All @@ -307,6 +310,7 @@ def __init__(
project_id: Optional[str] = None,
gcp_conn_id: str = 'google_cloud_default',
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
regional: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -316,6 +320,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.use_internal_ip = use_internal_ip
self.impersonation_chain = impersonation_chain
self.regional = regional

if self.gcp_conn_id is None:
raise AirflowException(
Expand Down Expand Up @@ -356,8 +361,6 @@ def execute(self, context) -> Optional[str]:
"clusters",
"get-credentials",
self.cluster_name,
"--zone",
self.location,
"--project",
self.project_id,
]
Expand All @@ -377,6 +380,11 @@ def execute(self, context) -> Optional[str]:
impersonation_account,
]
)
if self.regional:
cmd.append('--region')
else:
cmd.append('--zone')
cmd.append(self.location)
if self.use_internal_ip:
cmd.append('--internal-ip')
execute_in_subprocess(cmd)
Expand Down
16 changes: 5 additions & 11 deletions tests/providers/google/cloud/hooks/test_kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@ def test_delete_cluster(self, wait_mock, convert_mock, mock_project_id):
)

client_delete.assert_called_once_with(
project_id=TEST_GCP_PROJECT_ID,
zone=GKE_ZONE,
cluster_id=CLUSTER_NAME,
name=f'projects/{TEST_GCP_PROJECT_ID}/locations/{GKE_ZONE}/clusters/{CLUSTER_NAME}',
retry=retry_mock,
timeout=timeout_mock,
)
Expand Down Expand Up @@ -145,8 +143,7 @@ def test_create_cluster_proto(self, wait_mock, convert_mock, mock_project_id):
)

client_create.assert_called_once_with(
project_id=TEST_GCP_PROJECT_ID,
zone=GKE_ZONE,
parent=f'projects/{TEST_GCP_PROJECT_ID}/locations/{GKE_ZONE}',
cluster=mock_cluster_proto,
retry=retry_mock,
timeout=timeout_mock,
Expand All @@ -173,8 +170,7 @@ def test_create_cluster_dict(self, wait_mock, convert_mock, mock_project_id):
)

client_create.assert_called_once_with(
project_id=TEST_GCP_PROJECT_ID,
zone=GKE_ZONE,
parent=f'projects/{TEST_GCP_PROJECT_ID}/locations/{GKE_ZONE}',
cluster=proto_mock,
retry=retry_mock,
timeout=timeout_mock,
Expand Down Expand Up @@ -228,9 +224,7 @@ def test_get_cluster(self):
)

client_get.assert_called_once_with(
project_id=TEST_GCP_PROJECT_ID,
zone=GKE_ZONE,
cluster_id=CLUSTER_NAME,
name=f'projects/{TEST_GCP_PROJECT_ID}/locations/{GKE_ZONE}/clusters/{CLUSTER_NAME}',
retry=retry_mock,
timeout=timeout_mock,
)
Expand All @@ -256,7 +250,7 @@ def test_get_operation(self):
self.gke_hook._client.get_operation = mock.Mock()
self.gke_hook.get_operation('TEST_OP', project_id=TEST_GCP_PROJECT_ID)
self.gke_hook._client.get_operation.assert_called_once_with(
project_id=TEST_GCP_PROJECT_ID, zone=GKE_ZONE, operation_id='TEST_OP'
name=f'projects/{TEST_GCP_PROJECT_ID}/locations/{GKE_ZONE}/operations/TEST_OP'
)

def test_append_label(self):
Expand Down
55 changes: 49 additions & 6 deletions tests/providers/google/cloud/operators/test_kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,53 @@ def test_execute(self, file_mock, mock_execute_in_subprocess, mock_gcp_hook, exe
'clusters',
'get-credentials',
CLUSTER_NAME,
'--project',
TEST_GCP_PROJECT_ID,
'--zone',
PROJECT_LOCATION,
]
)

assert self.gke_op.config_file == FILE_NAME

@mock.patch.dict(os.environ, {})
@mock.patch(
"airflow.hooks.base.BaseHook.get_connections",
return_value=[
Connection(
extra=json.dumps(
{"extra__google_cloud_platform__keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}
)
)
],
)
@mock.patch('airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesPodOperator.execute')
@mock.patch('airflow.providers.google.cloud.operators.kubernetes_engine.GoogleBaseHook')
@mock.patch('airflow.providers.google.cloud.operators.kubernetes_engine.execute_in_subprocess')
@mock.patch('tempfile.NamedTemporaryFile')
def test_execute_regional(
self, file_mock, mock_execute_in_subprocess, mock_gcp_hook, exec_mock, get_con_mock
):
self.gke_op.regional = True
type(file_mock.return_value.__enter__.return_value).name = PropertyMock(
side_effect=[FILE_NAME, '/path/to/new-file']
)

self.gke_op.execute(None)

mock_gcp_hook.return_value.provide_authorized_gcloud.assert_called_once()

mock_execute_in_subprocess.assert_called_once_with(
[
'gcloud',
'container',
'clusters',
'get-credentials',
CLUSTER_NAME,
'--project',
TEST_GCP_PROJECT_ID,
'--region',
PROJECT_LOCATION,
]
)

Expand Down Expand Up @@ -282,10 +325,10 @@ def test_execute_with_internal_ip(
'clusters',
'get-credentials',
CLUSTER_NAME,
'--zone',
PROJECT_LOCATION,
'--project',
TEST_GCP_PROJECT_ID,
'--zone',
PROJECT_LOCATION,
'--internal-ip',
]
)
Expand Down Expand Up @@ -325,12 +368,12 @@ def test_execute_with_impersonation_service_account(
'clusters',
'get-credentials',
CLUSTER_NAME,
'--zone',
PROJECT_LOCATION,
'--project',
TEST_GCP_PROJECT_ID,
'--impersonate-service-account',
'test_account@example.com',
'--zone',
PROJECT_LOCATION,
]
)

Expand Down Expand Up @@ -369,12 +412,12 @@ def test_execute_with_impersonation_service_chain_one_element(
'clusters',
'get-credentials',
CLUSTER_NAME,
'--zone',
PROJECT_LOCATION,
'--project',
TEST_GCP_PROJECT_ID,
'--impersonate-service-account',
'test_account@example.com',
'--zone',
PROJECT_LOCATION,
]
)

Expand Down

0 comments on commit a4622e1

Please sign in to comment.