Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEATURE] google provider - split GkeStartPodOperator execute #23518

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 41 additions & 16 deletions airflow/providers/google/cloud/operators/kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
import os
import tempfile
import warnings
from typing import TYPE_CHECKING, Dict, Optional, Sequence, Union
from contextlib import contextmanager
from typing import TYPE_CHECKING, Dict, Generator, Optional, Sequence, Union

from google.cloud.container_v1.types import Cluster

Expand Down Expand Up @@ -336,11 +337,22 @@ def __init__(
if self.config_file:
raise AirflowException("config_file is not an allowed parameter for the GKEStartPodOperator.")

def execute(self, context: 'Context') -> Optional[str]:
hook = GoogleBaseHook(gcp_conn_id=self.gcp_conn_id)
self.project_id = self.project_id or hook.project_id
@staticmethod
@contextmanager
def get_gke_config_file(
gcp_conn_id,
project_id: Optional[str],
cluster_name: str,
impersonation_chain: Optional[Union[str, Sequence[str]]],
regional: bool,
location: str,
use_internal_ip: bool,
) -> Generator[str, None, None]:

if not self.project_id:
hook = GoogleBaseHook(gcp_conn_id=gcp_conn_id)
project_id = project_id or hook.project_id

if not project_id:
raise AirflowException(
"The project id must be passed either as "
"keyword project_id parameter or as project_id extra "
Expand All @@ -363,15 +375,15 @@ def execute(self, context: 'Context') -> Optional[str]:
"container",
"clusters",
"get-credentials",
self.cluster_name,
cluster_name,
"--project",
self.project_id,
project_id,
]
if self.impersonation_chain:
if isinstance(self.impersonation_chain, str):
impersonation_account = self.impersonation_chain
elif len(self.impersonation_chain) == 1:
impersonation_account = self.impersonation_chain[0]
if impersonation_chain:
if isinstance(impersonation_chain, str):
impersonation_account = impersonation_chain
elif len(impersonation_chain) == 1:
impersonation_account = impersonation_chain[0]
else:
raise AirflowException(
"Chained list of accounts is not supported, please specify only one service account"
Expand All @@ -383,15 +395,28 @@ def execute(self, context: 'Context') -> Optional[str]:
impersonation_account,
]
)
if self.regional:
if regional:
cmd.append('--region')
else:
cmd.append('--zone')
cmd.append(self.location)
if self.use_internal_ip:
cmd.append(location)
if use_internal_ip:
cmd.append('--internal-ip')
execute_in_subprocess(cmd)

# Tell `KubernetesPodOperator` where the config file is located
self.config_file = os.environ[KUBE_CONFIG_ENV_VAR]
yield os.environ[KUBE_CONFIG_ENV_VAR]

def execute(self, context: 'Context') -> Optional[str]:

with GKEStartPodOperator.get_gke_config_file(
gcp_conn_id=self.gcp_conn_id,
project_id=self.project_id,
cluster_name=self.cluster_name,
impersonation_chain=self.impersonation_chain,
regional=self.regional,
location=self.location,
use_internal_ip=self.use_internal_ip,
) as config_file:
self.config_file = config_file
return super().execute(context)