diff --git a/airflow/providers/google/cloud/operators/kubernetes_engine.py b/airflow/providers/google/cloud/operators/kubernetes_engine.py index f8c9b7875c967..546eba1070635 100644 --- a/airflow/providers/google/cloud/operators/kubernetes_engine.py +++ b/airflow/providers/google/cloud/operators/kubernetes_engine.py @@ -414,7 +414,7 @@ def __init__( *, location: str, cluster_name: str, - use_internal_ip: bool | None = None, + use_internal_ip: bool = False, project_id: str | None = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, @@ -433,15 +433,6 @@ def __init__( ) is_delete_operator_pod = False - if use_internal_ip is not None: - warnings.warn( - f"You have set parameter use_internal_ip in class {self.__class__.__name__}. " - "In current implementation of the operator the parameter is not used and will " - "be deleted in future.", - AirflowProviderDeprecationWarning, - stacklevel=2, - ) - if regional is not None: warnings.warn( f"You have set parameter regional in class {self.__class__.__name__}. " @@ -457,6 +448,7 @@ def __init__( self.cluster_name = cluster_name self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain + self.use_internal_ip = use_internal_ip self.pod: V1Pod | None = None self._ssl_ca_cert: str | None = None @@ -516,7 +508,10 @@ def fetch_cluster_info(self) -> tuple[str, str | None]: project_id=self.project_id, ) - self._cluster_url = f"https://{cluster.endpoint}" + if not self.use_internal_ip: + self._cluster_url = f"https://{cluster.endpoint}" + else: + self._cluster_url = f"https://{cluster.private_cluster_config.private_endpoint}" self._ssl_ca_cert = cluster.master_auth.cluster_ca_certificate return self._cluster_url, self._ssl_ca_cert diff --git a/tests/providers/google/cloud/operators/test_kubernetes_engine.py b/tests/providers/google/cloud/operators/test_kubernetes_engine.py index 6fa0184e88bf6..5d6acb196ad25 100644 --- a/tests/providers/google/cloud/operators/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/operators/test_kubernetes_engine.py @@ -66,6 +66,7 @@ TEMP_FILE = "tempfile.NamedTemporaryFile" GKE_OP_PATH = "airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator" CLUSTER_URL = "https://test-host" +CLUSTER_PRIVATE_URL = "https://test-private-host" SSL_CA_CERT = "TEST_SSL_CA_CERT_CONTENT" @@ -293,6 +294,31 @@ def test_execute_with_impersonation_service_chain_one_element( fetch_cluster_info_mock.assert_called_once() + @pytest.mark.parametrize("use_internal_ip", [True, False]) + @mock.patch(f"{GKE_HOOK_PATH}.get_cluster") + def test_cluster_info(self, get_cluster_mock, use_internal_ip): + get_cluster_mock.return_value = mock.MagicMock( + **{ + "endpoint": "test-host", + "private_cluster_config.private_endpoint": "test-private-host", + "master_auth.cluster_ca_certificate": SSL_CA_CERT, + } + ) + gke_op = GKEStartPodOperator( + project_id=TEST_GCP_PROJECT_ID, + location=PROJECT_LOCATION, + cluster_name=CLUSTER_NAME, + task_id=PROJECT_TASK_ID, + name=TASK_NAME, + namespace=NAMESPACE, + image=IMAGE, + use_internal_ip=use_internal_ip, + ) + cluster_url, ssl_ca_cert = gke_op.fetch_cluster_info() + + assert cluster_url == CLUSTER_PRIVATE_URL if use_internal_ip else CLUSTER_URL + assert ssl_ca_cert == SSL_CA_CERT + class TestGKEPodOperatorAsync: def setup_method(self):