diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py index d7fbc438fd955..f5a53d5585b90 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py @@ -33,7 +33,7 @@ from airflow.exceptions import AirflowException from airflow.providers.google.common.consts import CLIENT_INFO -from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook if TYPE_CHECKING: from google.api_core.operation import Operation @@ -409,15 +409,18 @@ def wait_command_execution_result( time.sleep(poll_interval) -class CloudComposerAsyncHook(GoogleBaseHook): +class CloudComposerAsyncHook(GoogleBaseAsyncHook): """Hook for Google Cloud Composer async APIs.""" + sync_hook_class = CloudComposerHook + client_options = ClientOptions(api_endpoint="composer.googleapis.com:443") - def get_environment_client(self) -> EnvironmentsAsyncClient: + async def get_environment_client(self) -> EnvironmentsAsyncClient: """Retrieve client library object that allow access Environments service.""" + sync_hook = await self.get_sync_hook() return EnvironmentsAsyncClient( - credentials=self.get_credentials(), + credentials=sync_hook.get_credentials(), client_info=CLIENT_INFO, client_options=self.client_options, ) @@ -429,9 +432,8 @@ def get_parent(self, project_id, region): return f"projects/{project_id}/locations/{region}" async def get_operation(self, operation_name): - return await self.get_environment_client().transport.operations_client.get_operation( - name=operation_name - ) + client = await self.get_environment_client() + return await client.transport.operations_client.get_operation(name=operation_name) @GoogleBaseHook.fallback_to_default_project_id async def create_environment( @@ -556,7 +558,7 @@ async def execute_airflow_command( :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. """ - client = self.get_environment_client() + client = await self.get_environment_client() return await client.execute_airflow_command( request={ @@ -598,7 +600,7 @@ async def poll_airflow_command( :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. """ - client = self.get_environment_client() + client = await self.get_environment_client() return await client.poll_airflow_command( request={ diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py index 96748bc2a7c49..62f63c00b930a 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py @@ -52,10 +52,7 @@ def __init__( self.impersonation_chain = impersonation_chain self.pooling_period_seconds = pooling_period_seconds - self.gcp_hook = CloudComposerAsyncHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + self.gcp_hook: CloudComposerAsyncHook | None = None def serialize(self) -> tuple[str, dict[str, Any]]: return ( @@ -70,7 +67,14 @@ def serialize(self) -> tuple[str, dict[str, Any]]: }, ) + def _get_async_hook(self): + return CloudComposerAsyncHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + async def run(self): + self.gcp_hook = self._get_async_hook() while True: operation = await self.gcp_hook.get_operation(operation_name=self.operation_name) if operation.done: @@ -107,11 +111,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain self.poll_interval = poll_interval - - self.gcp_hook = CloudComposerAsyncHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + self.gcp_hook: CloudComposerAsyncHook | None = None def serialize(self) -> tuple[str, dict[str, Any]]: return ( @@ -127,7 +127,14 @@ def serialize(self) -> tuple[str, dict[str, Any]]: }, ) + def _get_async_hook(self): + return CloudComposerAsyncHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + async def run(self): + self.gcp_hook = self._get_async_hook() try: result = await self.gcp_hook.wait_command_execution_result( project_id=self.project_id, @@ -184,10 +191,7 @@ def __init__( self.poll_interval = poll_interval self.composer_airflow_version = composer_airflow_version - self.gcp_hook = CloudComposerAsyncHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + self.gcp_hook: CloudComposerAsyncHook | None = None def serialize(self) -> tuple[str, dict[str, Any]]: return ( @@ -248,7 +252,14 @@ def _check_dag_runs_states( return False return True + def _get_async_hook(self): + return CloudComposerAsyncHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + async def run(self): + self.gcp_hook = self._get_async_hook() try: while True: if datetime.now(self.end_date.tzinfo).timestamp() > self.end_date.timestamp():