Skip to content

Commit

Permalink
fix(google): update only the necessary column in config file for impe…
Browse files Browse the repository at this point in the history
…rsonation
  • Loading branch information
Lee-W committed Aug 2, 2023
1 parent 7ed3550 commit 7c7181c
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 29 deletions.
40 changes: 17 additions & 23 deletions astronomer/providers/google/cloud/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from airflow.exceptions import AirflowException
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from airflow.utils.process_utils import execute_in_subprocess, patch_environ
from google.auth import default, impersonated_credentials
from google.auth import impersonated_credentials
from google.cloud.container_v1 import ClusterManagerClient

KUBE_CONFIG_ENV_VAR = "KUBECONFIG"
Expand Down Expand Up @@ -55,6 +55,7 @@ def _get_gke_config_file(
"--project",
project_id,
]
impersonation_account = None
if impersonation_chain:
if isinstance(impersonation_chain, str):
impersonation_account = impersonation_chain
Expand All @@ -71,42 +72,35 @@ def _get_gke_config_file(
impersonation_account,
]
)
creds, _ = default()

if regional:
cmd.append("--region")
else:
cmd.append("--zone")
cmd.append(location)
if use_internal_ip:
cmd.append("--internal-ip")
execute_in_subprocess(cmd)

if impersonation_account:
creds = hook.get_credentials()
impersonated_creds = impersonated_credentials.Credentials(
source_credentials=creds,
target_principal=impersonation_account,
target_scopes=["https://www.googleapis.com/auth/cloud-platform"],
)

try:
client = ClusterManagerClient(credentials=impersonated_creds)
name = f"projects/{project_id}/locations/{location}/clusters/{cluster_name}"
cluster = client.get_cluster(name=name)
if not use_internal_ip:
cluster_url = f"https://{cluster.endpoint}"
else:
cluster_url = f"https://{cluster.private_cluster_config.private_endpoint}"
ssl_ca_cert = cluster.master_auth.cluster_ca_certificate
client.get_cluster(name=name)
except Exception as e:
raise Exception(f"Error while creating impersonated creds: {e}")

with open(conf_file.name, "r") as input_file:
with open(conf_file.name) as input_file:
config_content = yaml.safe_load(input_file.read())

config_content["users"][0]["user"]["token"] = impersonated_creds.token
config_content["clusters"][0]["server"] = cluster_url
config_content["clusters"][0]["certificate-authority-data"] = ssl_ca_cert
with open(conf_file.name, "w") as output_file:
yaml.dump(config_content, output_file, default_flow_style=False)

if regional:
cmd.append("--region")
else:
cmd.append("--zone")
cmd.append(location)
if use_internal_ip:
cmd.append("--internal-ip")
execute_in_subprocess(cmd)

yaml.dump(config_content, output_file)
# Tell `KubernetesPodOperator` where the config file is located
yield os.environ[KUBE_CONFIG_ENV_VAR]
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""This module contains Google GKE operators."""
from __future__ import annotations

from typing import Any, Sequence, Union
from typing import Any, Sequence

from airflow.exceptions import AirflowException
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
Expand Down Expand Up @@ -71,7 +71,7 @@ def __init__(
use_internal_ip: bool = False,
project_id: str | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: Union[str, Sequence[str]] | None = None,
impersonation_chain: str | Sequence[str] | None = None,
regional: bool = False,
poll_interval: float = 5,
logging_interval: int | None = None,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, AsyncIterator, Sequence, Union
from typing import Any, AsyncIterator, Sequence

from airflow.providers.cncf.kubernetes.utils.pod_manager import PodPhase
from airflow.triggers.base import TriggerEvent
Expand All @@ -13,7 +13,6 @@
from astronomer.providers.google.cloud import _get_gke_config_file



class GKEStartPodTrigger(WaitContainerTrigger):
"""
Fetch GKE cluster config and wait for pod to start up.
Expand Down Expand Up @@ -46,7 +45,7 @@ def __init__(
use_internal_ip: bool = False,
project_id: str | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: Union[str, Sequence[str]] | None = None,
impersonation_chain: str | Sequence[str] | None = None,
regional: bool = False,
cluster_context: str | None = None,
in_cluster: bool | None = None,
Expand Down Expand Up @@ -97,7 +96,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
},
)

async def run(self) -> AsyncIterator["TriggerEvent"]:
async def run(self) -> AsyncIterator[TriggerEvent]:
"""Wait for pod to reach terminal state"""
try:
with _get_gke_config_file(
Expand Down

0 comments on commit 7c7181c

Please sign in to comment.