diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py b/airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py index fc2bc3da75454..a2a541648a3a2 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py @@ -15,22 +15,32 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""This module contains a Google Cloud Vertex AI hook.""" +""" +This module contains a Google Cloud Vertex AI hook. + +.. spelling:word-list:: + + JobServiceAsyncClient +""" from __future__ import annotations +import asyncio +from functools import lru_cache from typing import TYPE_CHECKING, Sequence from google.api_core.client_options import ClientOptions from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.cloud.aiplatform import CustomJob, HyperparameterTuningJob, gapic, hyperparameter_tuning -from google.cloud.aiplatform_v1 import JobServiceClient, types +from google.cloud.aiplatform_v1 import JobServiceAsyncClient, JobServiceClient, JobState, types from airflow.exceptions import AirflowException +from airflow.providers.google.common.consts import CLIENT_INFO from airflow.providers.google.common.hooks.base_google import GoogleBaseHook if TYPE_CHECKING: from google.api_core.operation import Operation from google.api_core.retry import Retry + from google.api_core.retry_async import AsyncRetry from google.cloud.aiplatform_v1.services.job_service.pagers import ListHyperparameterTuningJobsPager @@ -172,6 +182,7 @@ def create_hyperparameter_tuning_job( tensorboard: str | None = None, sync: bool = True, # END: run param + wait_job_completed: bool = True, ) -> HyperparameterTuningJob: """ Create a HyperparameterTuningJob. 
@@ -256,6 +267,7 @@ def create_hyperparameter_tuning_job( https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training :param sync: Whether to execute this method synchronously. If False, this method will unblock and it will be executed in a concurrent Future. + :param wait_job_completed: Whether to wait for the job to complete. """ custom_job = self.get_custom_job_object( project=project_id, @@ -292,7 +304,11 @@ def create_hyperparameter_tuning_job( tensorboard=tensorboard, sync=sync, ) - self._hyperparameter_tuning_job.wait() + + if wait_job_completed: + self._hyperparameter_tuning_job.wait() + else: + self._hyperparameter_tuning_job._wait_for_resource_creation() return self._hyperparameter_tuning_job @GoogleBaseHook.fallback_to_default_project_id @@ -413,3 +429,104 @@ def delete_hyperparameter_tuning_job( metadata=metadata, ) return result + + +class HyperparameterTuningJobAsyncHook(GoogleBaseHook): + """Async hook for Google Cloud Vertex AI Hyperparameter Tuning Job APIs.""" + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ): + super().__init__( + gcp_conn_id=gcp_conn_id, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + @lru_cache + def get_job_service_client(self, region: str | None = None) -> JobServiceAsyncClient: + """ + Retrieves Vertex AI async client. + + :return: Google Cloud Vertex AI client object. 
+ """ + endpoint = f"{region}-aiplatform.googleapis.com:443" if region and region != "global" else None + return JobServiceAsyncClient( + credentials=self.get_credentials(), + client_info=CLIENT_INFO, + client_options=ClientOptions(api_endpoint=endpoint), + ) + + async def get_hyperparameter_tuning_job( + self, + project_id: str, + location: str, + job_id: str, + retry: AsyncRetry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> types.HyperparameterTuningJob: + """ + Retrieves a hyperparameter tuning job. + + :param project_id: Required. The ID of the Google Cloud project that the job belongs to. + :param location: Required. The ID of the Google Cloud region that the job belongs to. + :param job_id: Required. The hyperparameter tuning job id. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
+ """ + client: JobServiceAsyncClient = self.get_job_service_client(region=location) + job_name = client.hyperparameter_tuning_job_path(project_id, location, job_id) + + result = await client.get_hyperparameter_tuning_job( + request={ + "name": job_name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + return result + + async def wait_hyperparameter_tuning_job( + self, + project_id: str, + location: str, + job_id: str, + retry: AsyncRetry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + poll_interval: int = 10, + ) -> types.HyperparameterTuningJob: + statuses_complete = { + JobState.JOB_STATE_CANCELLED, + JobState.JOB_STATE_FAILED, + JobState.JOB_STATE_PAUSED, + JobState.JOB_STATE_SUCCEEDED, + } + while True: + try: + self.log.info("Requesting hyperparameter tuning job with id %s", job_id) + job: types.HyperparameterTuningJob = await self.get_hyperparameter_tuning_job( + project_id=project_id, + location=location, + job_id=job_id, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + except Exception as ex: + self.log.exception("Exception occurred while requesting job %s", job_id) + raise AirflowException(ex) from ex  # chain explicitly so the original error is kept as __cause__ + + self.log.info("Status of the hyperparameter tuning job %s is %s", job.name, job.state.name) + if job.state in statuses_complete: + return job + + self.log.info("Sleeping for %s seconds.", poll_interval) + await asyncio.sleep(poll_interval) diff --git a/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py b/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py index 060eeff249471..6fac474c4627b 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py @@ -20,12 +20,14 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, Any, Sequence from 
google.api_core.exceptions import NotFound from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.cloud.aiplatform_v1.types import HyperparameterTuningJob +from airflow.configuration import conf +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job import ( HyperparameterTuningJobHook, ) @@ -34,6 +36,7 @@ VertexAITrainingLink, ) from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator +from airflow.providers.google.cloud.triggers.vertex_ai import CreateHyperparameterTuningJobTrigger if TYPE_CHECKING: from google.api_core.retry import Retry @@ -124,7 +127,7 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator): `service_account` is required with provided `tensorboard`. For more information on configuring your service account please visit: https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training - :param sync: Whether to execute this method synchronously. If False, this method will unblock and it + :param sync: Whether to execute this method synchronously. If False, this method will unblock, and it will be executed in a concurrent Future. :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :param impersonation_chain: Optional service account to impersonate using short-term @@ -135,6 +138,9 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). + :param deferrable: Run operator in the deferrable mode. Note that it requires calling the operator + with `sync=False` parameter. + :param poll_interval: Interval size which defines how often job status is checked in deferrable mode. 
""" template_fields = [ @@ -177,6 +183,8 @@ def __init__( # END: run param gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + poll_interval: int = 10, **kwargs, ) -> None: super().__init__(**kwargs) @@ -209,8 +217,17 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain self.hook: HyperparameterTuningJobHook | None = None + self.deferrable = deferrable + self.poll_interval = poll_interval def execute(self, context: Context): + if self.deferrable and self.sync: + raise AirflowException( + "Deferrable mode can be used only with sync=False option. " + "If you are willing to run the operator in deferrable mode, please, set sync=False. " + "Otherwise, disable deferrable mode `deferrable=False`." + ) + self.log.info("Creating Hyperparameter Tuning job") self.hook = HyperparameterTuningJobHook( gcp_conn_id=self.gcp_conn_id, @@ -243,12 +260,26 @@ def execute(self, context: Context): enable_web_access=self.enable_web_access, tensorboard=self.tensorboard, sync=self.sync, + wait_job_completed=not self.deferrable, ) hyperparameter_tuning_job = result.to_dict() hyperparameter_tuning_job_id = self.hook.extract_hyperparameter_tuning_job_id( hyperparameter_tuning_job ) + if self.deferrable: + self.defer( + trigger=CreateHyperparameterTuningJobTrigger( + conn_id=self.gcp_conn_id, + project_id=self.project_id, + location=self.region, + job_id=hyperparameter_tuning_job_id, + poll_interval=self.poll_interval, + impersonation_chain=self.impersonation_chain, + ), + method_name="execute_complete", + ) + self.log.info("Hyperparameter Tuning job was created. 
Job id: %s", hyperparameter_tuning_job_id) self.xcom_push(context, key="hyperparameter_tuning_job_id", value=hyperparameter_tuning_job_id) @@ -262,6 +293,32 @@ def on_kill(self) -> None: if self.hook: self.hook.cancel_hyperparameter_tuning_job() + def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any]: + if event and event["status"] == "error": + raise AirflowException(event["message"]) + job: dict[str, Any] = event["job"] + self.log.info("Hyperparameter tuning job %s created and completed successfully.", job["name"]) + hook = HyperparameterTuningJobHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + job_id = hook.extract_hyperparameter_tuning_job_id(job) + self.xcom_push( + context, + key="hyperparameter_tuning_job_id", + value=job_id, + ) + self.xcom_push( + context, + key="training_conf", + value={ + "training_conf_id": job_id, + "region": self.region, + "project_id": self.project_id, + }, + ) + return event["job"] + class GetHyperparameterTuningJobOperator(GoogleCloudBaseOperator): """ diff --git a/airflow/providers/google/cloud/triggers/vertex_ai.py b/airflow/providers/google/cloud/triggers/vertex_ai.py new file mode 100644 index 0000000000000..b4b121895406f --- /dev/null +++ b/airflow/providers/google/cloud/triggers/vertex_ai.py @@ -0,0 +1,99 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Any, AsyncIterator, Sequence + +from google.cloud.aiplatform_v1 import HyperparameterTuningJob, JobState + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job import ( + HyperparameterTuningJobAsyncHook, +) +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class CreateHyperparameterTuningJobTrigger(BaseTrigger): + """CreateHyperparameterTuningJobTrigger runs on the trigger worker to perform the create operation.""" + + statuses_success = { + JobState.JOB_STATE_PAUSED, + JobState.JOB_STATE_SUCCEEDED, + } + + def __init__( + self, + conn_id: str, + project_id: str, + location: str, + job_id: str, + poll_interval: int, + impersonation_chain: str | Sequence[str] | None = None, + ): + super().__init__() + self.conn_id = conn_id + self.project_id = project_id + self.location = location + self.job_id = job_id + self.poll_interval = poll_interval + self.impersonation_chain = impersonation_chain + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.google.cloud.triggers.vertex_ai.CreateHyperparameterTuningJobTrigger", + { + "conn_id": self.conn_id, + "project_id": self.project_id, + "location": self.location, + "job_id": self.job_id, + "poll_interval": self.poll_interval, + "impersonation_chain": self.impersonation_chain, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + hook = self._get_async_hook() + try: + job = await hook.wait_hyperparameter_tuning_job( 
project_id=self.project_id, + location=self.location, + job_id=self.job_id, + poll_interval=self.poll_interval, + ) + except AirflowException as ex: + yield TriggerEvent( + { + "status": "error", + "message": str(ex), + } + ) + return + + status = "success" if job.state in self.statuses_success else "error" + message = f"Hyperparameter tuning job {job.name} completed with status {job.state.name}" + yield TriggerEvent( + { + "status": status, + "message": message, + "job": HyperparameterTuningJob.to_dict(job), + } + ) + + def _get_async_hook(self) -> HyperparameterTuningJobAsyncHook: + return HyperparameterTuningJobAsyncHook( + gcp_conn_id=self.conn_id, impersonation_chain=self.impersonation_chain + ) diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index cfd14a12c70b4..694acf706158b 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -976,6 +976,9 @@ triggers: - integration-name: Google Cloud python-modules: - airflow.providers.google.cloud.triggers.cloud_batch + - integration-name: Google Vertex AI + python-modules: + - airflow.providers.google.cloud.triggers.vertex_ai transfers: - source-integration-name: Presto diff --git a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst index 5ba490bf55275..c778a624d212b 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst @@ -368,6 +368,15 @@ The operator returns hyperparameter tuning job id in :ref:`XCom ` :start-after: [START how_to_cloud_vertex_ai_create_hyperparameter_tuning_job_operator] :end-before: [END how_to_cloud_vertex_ai_create_hyperparameter_tuning_job_operator] +:class:`~airflow.providers.google.cloud.operators.vertex_ai.hyperparameter_tuning_job.CreateHyperparameterTuningJobOperator` +also supports deferrable mode: + +.. 
exampleinclude:: /../../tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_hyperparameter_tuning_job.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_create_hyperparameter_tuning_job_operator_deferrable] + :end-before: [END how_to_cloud_vertex_ai_create_hyperparameter_tuning_job_operator_deferrable] + To delete hyperparameter tuning job you can use :class:`~airflow.providers.google.cloud.operators.vertex_ai.hyperparameter_tuning_job.DeleteHyperparameterTuningJobOperator`. diff --git a/tests/providers/google/cloud/hooks/vertex_ai/test_hyperparameter_tuning_job.py b/tests/providers/google/cloud/hooks/vertex_ai/test_hyperparameter_tuning_job.py index cb62bb91e7e7e..f99c16d7e3c34 100644 --- a/tests/providers/google/cloud/hooks/vertex_ai/test_hyperparameter_tuning_job.py +++ b/tests/providers/google/cloud/hooks/vertex_ai/test_hyperparameter_tuning_job.py @@ -21,8 +21,12 @@ import pytest from google.api_core.gapic_v1.method import DEFAULT +from google.cloud import aiplatform +from google.cloud.aiplatform_v1 import JobState +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job import ( + HyperparameterTuningJobAsyncHook, HyperparameterTuningJobHook, ) from tests.providers.google.cloud.utils.base_gcp_mock import ( @@ -35,11 +39,86 @@ TEST_PROJECT_ID: str = "test-project-id" TEST_HYPERPARAMETER_TUNING_JOB_ID = "test_hyperparameter_tuning_job_id" TEST_UPDATE_MASK: dict = {} +TEST_DISPLAY_NAME = "Test display name" +TEST_METRIC_SPECS = { + "accuracy": "maximize", +} +TEST_PARAM_SPECS = { + "learning_rate": aiplatform.hyperparameter_tuning.DoubleParameterSpec(min=0.01, max=1, scale="log"), +} +TEST_MAX_TRIAL_COUNT = 3 +TEST_PARALLEL_TRIAL_COUNT = 3 +# CustomJob param +TEST_WORKER_POOL_SPECS = [ + { + "machine_spec": { + "machine_type": "n1-standard-4", + "accelerator_type": "ACCELERATOR_TYPE_UNSPECIFIED", + "accelerator_count": 0, + }, + } +] 
+TEST_BASE_OUTPUT_DIR = None +TEST_CUSTOM_JOB_LABELS = None +TEST_CUSTOM_JOB_ENCRYPTION_SPEC_KEY_NAME = None +TEST_STAGING_BUCKET = None +# CustomJob param +TEST_MAX_FAILED_TRIAL_COUNT = 0 +TEST_SEARCH_ALGORITHM = None +TEST_MEASUREMENT_SELECTION = "best" +TEST_HYPERPARAMETER_TUNING_JOB_LABELS = None +TEST_HYPERPARAMETER_TUNING_JOB_ENCRYPTION_SPEC_KEY_NAME = None +# run param +TEST_SERVICE_ACCOUNT = None +TEST_NETWORK = None +TEST_TIMEOUT = None +TEST_RESTART_JOB_ON_WORKER_RESTART = False +TEST_ENABLE_WEB_ACCESS = False +TEST_TENSORBOARD = None +TEST_SYNC = True +TEST_WAIT_JOB_COMPLETED = True +TEST_CREATE_HYPERPARAMETER_TUNING_JOB_PARAMS = dict( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + display_name=TEST_DISPLAY_NAME, + metric_spec=TEST_METRIC_SPECS, + parameter_spec=TEST_PARAM_SPECS, + max_trial_count=TEST_MAX_TRIAL_COUNT, + parallel_trial_count=TEST_PARALLEL_TRIAL_COUNT, + # START: CustomJob param + worker_pool_specs=TEST_WORKER_POOL_SPECS, + base_output_dir=TEST_BASE_OUTPUT_DIR, + custom_job_labels=TEST_CUSTOM_JOB_LABELS, + custom_job_encryption_spec_key_name=TEST_CUSTOM_JOB_ENCRYPTION_SPEC_KEY_NAME, + staging_bucket=TEST_STAGING_BUCKET, + # END: CustomJob param + max_failed_trial_count=TEST_MAX_FAILED_TRIAL_COUNT, + search_algorithm=TEST_SEARCH_ALGORITHM, + measurement_selection=TEST_MEASUREMENT_SELECTION, + hyperparameter_tuning_job_labels=TEST_HYPERPARAMETER_TUNING_JOB_LABELS, + hyperparameter_tuning_job_encryption_spec_key_name=TEST_HYPERPARAMETER_TUNING_JOB_ENCRYPTION_SPEC_KEY_NAME, + # START: run param + service_account=TEST_SERVICE_ACCOUNT, + network=TEST_NETWORK, + timeout=TEST_TIMEOUT, + restart_job_on_worker_restart=TEST_RESTART_JOB_ON_WORKER_RESTART, + enable_web_access=TEST_ENABLE_WEB_ACCESS, + tensorboard=TEST_TENSORBOARD, + sync=TEST_SYNC, + # END: run param + wait_job_completed=TEST_WAIT_JOB_COMPLETED, +) BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}" -HYPERPARAMETER_TUNING_JOB_STRING = ( 
+HYPERPARAMETER_TUNING_JOB_HOOK_PATH = ( "airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job.{}" ) +HYPERPARAMETER_TUNING_JOB_HOOK_STRING = ( + HYPERPARAMETER_TUNING_JOB_HOOK_PATH.format("HyperparameterTuningJobHook") + ".{}" +) +HYPERPARAMETER_TUNING_JOB_ASYNC_HOOK_STRING = ( + HYPERPARAMETER_TUNING_JOB_HOOK_PATH.format("HyperparameterTuningJobAsyncHook") + ".{}" +) class TestHyperparameterTuningJobWithDefaultProjectIdHook: @@ -53,7 +132,7 @@ def setup_method(self): ): self.hook = HyperparameterTuningJobHook(gcp_conn_id=TEST_GCP_CONN_ID) - @mock.patch(HYPERPARAMETER_TUNING_JOB_STRING.format("HyperparameterTuningJobHook.get_job_service_client")) + @mock.patch(HYPERPARAMETER_TUNING_JOB_HOOK_STRING.format("get_job_service_client")) def test_delete_hyperparameter_tuning_job(self, mock_client) -> None: self.hook.delete_hyperparameter_tuning_job( project_id=TEST_PROJECT_ID, @@ -75,7 +154,7 @@ def test_delete_hyperparameter_tuning_job(self, mock_client) -> None: TEST_HYPERPARAMETER_TUNING_JOB_ID, ) - @mock.patch(HYPERPARAMETER_TUNING_JOB_STRING.format("HyperparameterTuningJobHook.get_job_service_client")) + @mock.patch(HYPERPARAMETER_TUNING_JOB_HOOK_STRING.format("get_job_service_client")) def test_get_hyperparameter_tuning_job(self, mock_client) -> None: self.hook.get_hyperparameter_tuning_job( project_id=TEST_PROJECT_ID, @@ -97,7 +176,7 @@ def test_get_hyperparameter_tuning_job(self, mock_client) -> None: TEST_HYPERPARAMETER_TUNING_JOB_ID, ) - @mock.patch(HYPERPARAMETER_TUNING_JOB_STRING.format("HyperparameterTuningJobHook.get_job_service_client")) + @mock.patch(HYPERPARAMETER_TUNING_JOB_HOOK_STRING.format("get_job_service_client")) def test_list_hyperparameter_tuning_jobs(self, mock_client) -> None: self.hook.list_hyperparameter_tuning_jobs( project_id=TEST_PROJECT_ID, @@ -126,7 +205,7 @@ def setup_method(self): ): self.hook = HyperparameterTuningJobHook(gcp_conn_id=TEST_GCP_CONN_ID) - 
@mock.patch(HYPERPARAMETER_TUNING_JOB_STRING.format("HyperparameterTuningJobHook.get_job_service_client")) + @mock.patch(HYPERPARAMETER_TUNING_JOB_HOOK_STRING.format("get_job_service_client")) def test_delete_hyperparameter_tuning_job(self, mock_client) -> None: self.hook.delete_hyperparameter_tuning_job( project_id=TEST_PROJECT_ID, @@ -148,7 +227,7 @@ def test_delete_hyperparameter_tuning_job(self, mock_client) -> None: TEST_HYPERPARAMETER_TUNING_JOB_ID, ) - @mock.patch(HYPERPARAMETER_TUNING_JOB_STRING.format("HyperparameterTuningJobHook.get_job_service_client")) + @mock.patch(HYPERPARAMETER_TUNING_JOB_HOOK_STRING.format("get_job_service_client")) def test_get_hyperparameter_tuning_job(self, mock_client) -> None: self.hook.get_hyperparameter_tuning_job( project_id=TEST_PROJECT_ID, @@ -170,7 +249,7 @@ def test_get_hyperparameter_tuning_job(self, mock_client) -> None: TEST_HYPERPARAMETER_TUNING_JOB_ID, ) - @mock.patch(HYPERPARAMETER_TUNING_JOB_STRING.format("HyperparameterTuningJobHook.get_job_service_client")) + @mock.patch(HYPERPARAMETER_TUNING_JOB_HOOK_STRING.format("get_job_service_client")) def test_list_hyperparameter_tuning_jobs(self, mock_client) -> None: self.hook.list_hyperparameter_tuning_jobs( project_id=TEST_PROJECT_ID, @@ -190,3 +269,197 @@ def test_list_hyperparameter_tuning_jobs(self, mock_client) -> None: timeout=None, ) mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + +class TestHyperparameterTuningJobHook: + def setup_method(self): + with mock.patch( + BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_default_project_id + ): + self.hook = HyperparameterTuningJobHook(gcp_conn_id=TEST_GCP_CONN_ID) + + @mock.patch(HYPERPARAMETER_TUNING_JOB_HOOK_STRING.format("get_hyperparameter_tuning_job_object")) + @mock.patch(HYPERPARAMETER_TUNING_JOB_HOOK_STRING.format("get_custom_job_object")) + def test_create_hyperparameter_tuning_job( + self, + mock_get_custom_job_object, + 
mock_get_hyperparameter_tuning_job_object, + ): + mock_custom_job = mock_get_custom_job_object.return_value + + result = self.hook.create_hyperparameter_tuning_job(**TEST_CREATE_HYPERPARAMETER_TUNING_JOB_PARAMS) + + mock_get_custom_job_object.assert_called_once_with( + project=TEST_PROJECT_ID, + location=TEST_REGION, + display_name=TEST_DISPLAY_NAME, + worker_pool_specs=TEST_WORKER_POOL_SPECS, + base_output_dir=TEST_BASE_OUTPUT_DIR, + labels=TEST_CUSTOM_JOB_LABELS, + encryption_spec_key_name=TEST_CUSTOM_JOB_ENCRYPTION_SPEC_KEY_NAME, + staging_bucket=TEST_STAGING_BUCKET, + ) + mock_get_hyperparameter_tuning_job_object.assert_called_once_with( + project=TEST_PROJECT_ID, + location=TEST_REGION, + display_name=TEST_DISPLAY_NAME, + custom_job=mock_custom_job, + metric_spec=TEST_METRIC_SPECS, + parameter_spec=TEST_PARAM_SPECS, + max_trial_count=TEST_MAX_TRIAL_COUNT, + parallel_trial_count=TEST_PARALLEL_TRIAL_COUNT, + max_failed_trial_count=TEST_MAX_FAILED_TRIAL_COUNT, + search_algorithm=TEST_SEARCH_ALGORITHM, + measurement_selection=TEST_MEASUREMENT_SELECTION, + labels=TEST_HYPERPARAMETER_TUNING_JOB_LABELS, + encryption_spec_key_name=TEST_HYPERPARAMETER_TUNING_JOB_ENCRYPTION_SPEC_KEY_NAME, + ) + self.hook._hyperparameter_tuning_job.run.assert_called_once_with( + service_account=TEST_SERVICE_ACCOUNT, + network=TEST_NETWORK, + timeout=TEST_TIMEOUT, + restart_job_on_worker_restart=TEST_RESTART_JOB_ON_WORKER_RESTART, + enable_web_access=TEST_ENABLE_WEB_ACCESS, + tensorboard=TEST_TENSORBOARD, + sync=TEST_SYNC, + ) + self.hook._hyperparameter_tuning_job.wait.assert_called_once() + self.hook._hyperparameter_tuning_job._wait_for_resource_creation.assert_not_called() + assert result == self.hook._hyperparameter_tuning_job + + @mock.patch(HYPERPARAMETER_TUNING_JOB_HOOK_STRING.format("get_hyperparameter_tuning_job_object")) + @mock.patch(HYPERPARAMETER_TUNING_JOB_HOOK_STRING.format("get_custom_job_object")) + def test_create_hyperparameter_tuning_job_no_wait(self, _, __): + params 
= dict(TEST_CREATE_HYPERPARAMETER_TUNING_JOB_PARAMS) + params["wait_job_completed"] = False + + result = self.hook.create_hyperparameter_tuning_job(**params) + + self.hook._hyperparameter_tuning_job.wait.assert_not_called() + self.hook._hyperparameter_tuning_job._wait_for_resource_creation.assert_called_once() + assert result == self.hook._hyperparameter_tuning_job + + +class TestHyperparameterTuningJobAsyncHook: + def setup_method(self): + with mock.patch( + BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_default_project_id + ): + self.hook = HyperparameterTuningJobAsyncHook(gcp_conn_id=TEST_GCP_CONN_ID) + + @pytest.mark.asyncio + @mock.patch(HYPERPARAMETER_TUNING_JOB_ASYNC_HOOK_STRING.format("get_job_service_client")) + async def test_get_hyperparameter_tuning_job(self, mock_get_job_service_client): + mock_client = mock_get_job_service_client.return_value + mock_job_name = mock_client.hyperparameter_tuning_job_path.return_value + mock_job = mock.MagicMock() + mock_async_get_hyperparameter_tuning_job = mock.AsyncMock(return_value=mock_job) + mock_client.get_hyperparameter_tuning_job.side_effect = mock_async_get_hyperparameter_tuning_job + + result = await self.hook.get_hyperparameter_tuning_job( + project_id=TEST_PROJECT_ID, + location=TEST_REGION, + job_id=TEST_HYPERPARAMETER_TUNING_JOB_ID, + retry=DEFAULT, + timeout=None, + metadata=(), + ) + + mock_get_job_service_client.assert_called_once_with(region=TEST_REGION) + mock_client.hyperparameter_tuning_job_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_HYPERPARAMETER_TUNING_JOB_ID + ) + mock_async_get_hyperparameter_tuning_job.assert_awaited_once_with( + request={"name": mock_job_name}, + retry=DEFAULT, + timeout=None, + metadata=(), + ) + assert result == mock_job + + @pytest.mark.parametrize( + "state", + [ + JobState.JOB_STATE_CANCELLED, + JobState.JOB_STATE_FAILED, + JobState.JOB_STATE_PAUSED, + JobState.JOB_STATE_SUCCEEDED, + ], + ) + @pytest.mark.asyncio + 
@mock.patch(HYPERPARAMETER_TUNING_JOB_HOOK_PATH.format("asyncio.sleep")) + async def test_wait_hyperparameter_tuning_job(self, mock_sleep, state): + mock_job = mock.MagicMock(state=state) + mock_async_get_hpt_job = mock.AsyncMock(return_value=mock_job) + mock_get_hpt_job = mock.MagicMock(side_effect=mock_async_get_hpt_job) + + await_kwargs = dict( + project_id=TEST_PROJECT_ID, + location=TEST_REGION, + job_id=TEST_HYPERPARAMETER_TUNING_JOB_ID, + retry=DEFAULT, + timeout=None, + metadata=(), + ) + with mock.patch.object(self.hook, "get_hyperparameter_tuning_job", mock_get_hpt_job): + result = await self.hook.wait_hyperparameter_tuning_job(**await_kwargs) + + mock_async_get_hpt_job.assert_awaited_once_with(**await_kwargs) + mock_sleep.assert_not_awaited() + assert result == mock_job + + @pytest.mark.parametrize( + "state", + [ + JobState.JOB_STATE_UNSPECIFIED, + JobState.JOB_STATE_QUEUED, + JobState.JOB_STATE_PENDING, + JobState.JOB_STATE_RUNNING, + JobState.JOB_STATE_CANCELLING, + JobState.JOB_STATE_EXPIRED, + JobState.JOB_STATE_UPDATING, + JobState.JOB_STATE_PARTIALLY_SUCCEEDED, + ], + ) + @pytest.mark.asyncio + @mock.patch(HYPERPARAMETER_TUNING_JOB_HOOK_PATH.format("asyncio.sleep")) + async def test_wait_hyperparameter_tuning_job_waited(self, mock_sleep, state): + mock_job_incomplete = mock.MagicMock(state=state) + mock_job_complete = mock.MagicMock(state=JobState.JOB_STATE_SUCCEEDED) + mock_async_get_ht_job = mock.AsyncMock(side_effect=[mock_job_incomplete, mock_job_complete]) + mock_get_ht_job = mock.MagicMock(side_effect=mock_async_get_ht_job) + + await_kwargs = dict( + project_id=TEST_PROJECT_ID, + location=TEST_REGION, + job_id=TEST_HYPERPARAMETER_TUNING_JOB_ID, + retry=DEFAULT, + timeout=None, + metadata=(), + ) + + with mock.patch.object(self.hook, "get_hyperparameter_tuning_job", mock_get_ht_job): + result = await self.hook.wait_hyperparameter_tuning_job(**await_kwargs) + + mock_async_get_ht_job.assert_has_awaits( + [ + mock.call(**await_kwargs), + 
mock.call(**await_kwargs), + ] + ) + mock_sleep.assert_awaited_once() + assert result == mock_job_complete + + @pytest.mark.asyncio + async def test_wait_hyperparameter_tuning_job_exception(self): + mock_get_ht_job = mock.MagicMock(side_effect=Exception) + with mock.patch.object(self.hook, "get_hyperparameter_tuning_job", mock_get_ht_job): + with pytest.raises(AirflowException): + await self.hook.wait_hyperparameter_tuning_job( + project_id=TEST_PROJECT_ID, + location=TEST_REGION, + job_id=TEST_HYPERPARAMETER_TUNING_JOB_ID, + retry=DEFAULT, + timeout=None, + metadata=(), + ) diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py index 381d582ae7b97..686b7ff7c2320 100644 --- a/tests/providers/google/cloud/operators/test_vertex_ai.py +++ b/tests/providers/google/cloud/operators/test_vertex_ai.py @@ -17,11 +17,13 @@ from __future__ import annotations from unittest import mock -from unittest.mock import MagicMock +from unittest.mock import MagicMock, call +import pytest from google.api_core.gapic_v1.method import DEFAULT from google.api_core.retry import Retry +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import ( CreateAutoMLForecastingTrainingJobOperator, CreateAutoMLImageTrainingJobOperator, @@ -1345,8 +1347,127 @@ def test_execute(self, mock_hook, to_dict_mock): enable_web_access=False, tensorboard=None, sync=False, + wait_job_completed=True, ) + @mock.patch( + VERTEX_AI_PATH.format("hyperparameter_tuning_job.CreateHyperparameterTuningJobOperator.defer") + ) + @mock.patch(VERTEX_AI_PATH.format("hyperparameter_tuning_job.HyperparameterTuningJobHook")) + def test_deferrable(self, mock_hook, mock_defer): + op = CreateHyperparameterTuningJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + staging_bucket=STAGING_BUCKET, + 
display_name=DISPLAY_NAME, + worker_pool_specs=[], + sync=False, + parameter_spec={}, + metric_spec={}, + max_trial_count=15, + parallel_trial_count=3, + deferrable=True, + ) + op.execute(context={"ti": mock.MagicMock()}) + mock_defer.assert_called_once() + + def test_deferrable_sync_error(self): + op = CreateHyperparameterTuningJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + worker_pool_specs=[], + sync=True, + parameter_spec={}, + metric_spec={}, + max_trial_count=15, + parallel_trial_count=3, + deferrable=True, + ) + with pytest.raises(AirflowException): + op.execute(context={"ti": mock.MagicMock()}) + + @mock.patch( + VERTEX_AI_PATH.format("hyperparameter_tuning_job.CreateHyperparameterTuningJobOperator.xcom_push") + ) + @mock.patch(VERTEX_AI_PATH.format("hyperparameter_tuning_job.HyperparameterTuningJobHook")) + def test_execute_complete(self, mock_hook, mock_xcom_push): + test_job_id = "test_job_id" + test_job = {"name": f"test/{test_job_id}"} + event = { + "status": "success", + "message": "test message", + "job": test_job, + } + mock_hook.return_value.extract_hyperparameter_tuning_job_id.return_value = test_job_id + mock_context = mock.MagicMock() + + op = CreateHyperparameterTuningJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + worker_pool_specs=[], + sync=False, + parameter_spec={}, + metric_spec={}, + max_trial_count=15, + parallel_trial_count=3, + ) + + result = op.execute_complete(context=mock_context, event=event) + + mock_xcom_push.assert_has_calls( + [ + call(mock_context, key="hyperparameter_tuning_job_id", value=test_job_id), + call( + mock_context, + key="training_conf", + value={ + "training_conf_id": test_job_id, + 
"region": GCP_LOCATION, + "project_id": GCP_PROJECT, + }, + ), + ] + ) + assert result == test_job + + def test_execute_complete_error(self): + event = { + "status": "error", + "message": "test error message", + } + + op = CreateHyperparameterTuningJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + worker_pool_specs=[], + sync=False, + parameter_spec={}, + metric_spec={}, + max_trial_count=15, + parallel_trial_count=3, + ) + + with pytest.raises(AirflowException): + op.execute_complete(context=mock.MagicMock(), event=event) + class TestVertexAIGetHyperparameterTuningJobOperator: @mock.patch(VERTEX_AI_PATH.format("hyperparameter_tuning_job.HyperparameterTuningJob.to_dict")) diff --git a/tests/providers/google/cloud/triggers/test_vertex_ai.py b/tests/providers/google/cloud/triggers/test_vertex_ai.py new file mode 100644 index 0000000000000..5d68e171cf45a --- /dev/null +++ b/tests/providers/google/cloud/triggers/test_vertex_ai.py @@ -0,0 +1,139 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest import mock +from unittest.mock import patch + +import pytest +from google.cloud.aiplatform_v1 import JobState + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.triggers.vertex_ai import CreateHyperparameterTuningJobTrigger +from airflow.triggers.base import TriggerEvent + +TEST_CONN_ID = "test_connection" +TEST_PROJECT_ID = "test_propject_id" +TEST_LOCATION = "us-central-1" +TEST_HPT_JOB_ID = "test_job_id" +TEST_POLL_INTERVAL = 20 +TEST_IMPERSONATION_CHAIN = "test_chain" +TEST_HPT_JOB_NAME = ( + f"projects/{TEST_PROJECT_ID}/locations/{TEST_LOCATION}/hyperparameterTuningJobs/{TEST_HPT_JOB_ID}" +) +VERTEX_AI_TRIGGER_PATH = "airflow.providers.google.cloud.triggers.vertex_ai.{}" + + +@pytest.fixture +def create_hyperparameter_tuning_job_trigger(): + return CreateHyperparameterTuningJobTrigger( + conn_id=TEST_CONN_ID, + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + job_id=TEST_HPT_JOB_ID, + poll_interval=TEST_POLL_INTERVAL, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + ) + + +class TestCreateHyperparameterTuningJobTrigger: + def test_serialize(self, create_hyperparameter_tuning_job_trigger): + classpath, kwargs = create_hyperparameter_tuning_job_trigger.serialize() + assert ( + classpath + == "airflow.providers.google.cloud.triggers.vertex_ai.CreateHyperparameterTuningJobTrigger" + ) + assert kwargs == dict( + conn_id=TEST_CONN_ID, + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + job_id=TEST_HPT_JOB_ID, + poll_interval=TEST_POLL_INTERVAL, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + ) + + @patch(VERTEX_AI_TRIGGER_PATH.format("HyperparameterTuningJobAsyncHook")) + def test_get_async_hook(self, mock_async_hook, create_hyperparameter_tuning_job_trigger): + hook_expected = mock_async_hook.return_value + + hook_created = create_hyperparameter_tuning_job_trigger._get_async_hook() + + mock_async_hook.assert_called_once_with( + gcp_conn_id=TEST_CONN_ID, 
impersonation_chain=TEST_IMPERSONATION_CHAIN + ) + assert hook_created == hook_expected + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "state, status", + [ + (JobState.JOB_STATE_CANCELLED, "error"), + (JobState.JOB_STATE_FAILED, "error"), + (JobState.JOB_STATE_PAUSED, "success"), + (JobState.JOB_STATE_SUCCEEDED, "success"), + ], + ) + @patch(VERTEX_AI_TRIGGER_PATH.format("HyperparameterTuningJobAsyncHook")) + @patch(VERTEX_AI_TRIGGER_PATH.format("HyperparameterTuningJob")) + async def test_run( + self, mock_hpt_job, mock_async_hook, state, status, create_hyperparameter_tuning_job_trigger + ): + mock_job = mock.MagicMock( + status="success", + state=state, + ) + mock_job.name = TEST_HPT_JOB_NAME + mock_async_wait_hyperparameter_tuning_job = mock.AsyncMock(return_value=mock_job) + mock_async_hook.return_value.wait_hyperparameter_tuning_job.side_effect = mock.MagicMock( + side_effect=mock_async_wait_hyperparameter_tuning_job + ) + mock_dict_job = mock.MagicMock() + mock_hpt_job.to_dict.return_value = mock_dict_job + + generator = create_hyperparameter_tuning_job_trigger.run() + event_actual = await generator.asend(None) # type:ignore[attr-defined] + + mock_async_wait_hyperparameter_tuning_job.assert_awaited_once_with( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + job_id=TEST_HPT_JOB_ID, + poll_interval=TEST_POLL_INTERVAL, + ) + assert event_actual == TriggerEvent( + { + "status": status, + "message": f"Hyperparameter tuning job {TEST_HPT_JOB_NAME} completed with status {state.name}", + "job": mock_dict_job, + } + ) + + @pytest.mark.asyncio + @patch(VERTEX_AI_TRIGGER_PATH.format("HyperparameterTuningJobAsyncHook")) + async def test_run_exception(self, mock_async_hook, create_hyperparameter_tuning_job_trigger): + mock_async_hook.return_value.wait_hyperparameter_tuning_job.side_effect = AirflowException( + "test error" + ) + + generator = create_hyperparameter_tuning_job_trigger.run() + event_actual = await generator.asend(None) # 
type:ignore[attr-defined] + + assert event_actual == TriggerEvent( + { + "status": "error", + "message": "test error", + } + ) diff --git a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_hyperparameter_tuning_job.py b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_hyperparameter_tuning_job.py index 52baa494b287b..b4567590a17a0 100644 --- a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_hyperparameter_tuning_job.py +++ b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_hyperparameter_tuning_job.py @@ -104,6 +104,23 @@ ) # [END how_to_cloud_vertex_ai_create_hyperparameter_tuning_job_operator] + # [START how_to_cloud_vertex_ai_create_hyperparameter_tuning_job_operator_deferrable] + create_hyperparameter_tuning_job_def = CreateHyperparameterTuningJobOperator( + task_id="create_hyperparameter_tuning_job_def", + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + worker_pool_specs=WORKER_POOL_SPECS, + sync=False, + region=REGION, + project_id=PROJECT_ID, + parameter_spec=PARAM_SPECS, + metric_spec=METRIC_SPEC, + max_trial_count=15, + parallel_trial_count=3, + deferrable=True, + ) + # [END how_to_cloud_vertex_ai_create_hyperparameter_tuning_job_operator_deferrable] + # [START how_to_cloud_vertex_ai_get_hyperparameter_tuning_job_operator] get_hyperparameter_tuning_job = GetHyperparameterTuningJobOperator( task_id="get_hyperparameter_tuning_job", @@ -125,6 +142,16 @@ ) # [END how_to_cloud_vertex_ai_delete_hyperparameter_tuning_job_operator] + delete_hyperparameter_tuning_job_def = DeleteHyperparameterTuningJobOperator( + task_id="delete_hyperparameter_tuning_job_def", + project_id=PROJECT_ID, + region=REGION, + hyperparameter_tuning_job_id="{{ task_instance.xcom_pull(" + "task_ids='create_hyperparameter_tuning_job_def', " + "key='hyperparameter_tuning_job_id') }}", + trigger_rule=TriggerRule.ALL_DONE, + ) + # [START how_to_cloud_vertex_ai_list_hyperparameter_tuning_job_operator] 
list_hyperparameter_tuning_job = ListHyperparameterTuningJobOperator( task_id="list_hyperparameter_tuning_job", @@ -143,9 +170,9 @@ # TEST SETUP create_bucket # TEST BODY - >> create_hyperparameter_tuning_job + >> [create_hyperparameter_tuning_job, create_hyperparameter_tuning_job_def] >> get_hyperparameter_tuning_job - >> delete_hyperparameter_tuning_job + >> [delete_hyperparameter_tuning_job, delete_hyperparameter_tuning_job_def] >> list_hyperparameter_tuning_job # TEST TEARDOWN >> delete_bucket