diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py index ae8141e8749cf..dc6bd313f37d9 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py @@ -18,6 +18,7 @@ """This module contains a Google Cloud Vertex AI hook.""" from __future__ import annotations +import warnings from typing import TYPE_CHECKING, Sequence from google.api_core.client_options import ClientOptions @@ -31,7 +32,7 @@ ) from google.cloud.aiplatform_v1 import JobServiceClient, PipelineServiceClient -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.google.common.consts import CLIENT_INFO from airflow.providers.google.common.hooks.base_google import GoogleBaseHook @@ -378,6 +379,8 @@ def cancel_pipeline_job( [google.rpc.Status.code][google.rpc.Status.code] of 1, corresponding to ``Code.CANCELLED``, and [PipelineJob.state][google.cloud.aiplatform.v1.PipelineJob.state] is set to ``CANCELLED``. + This method is deprecated, please use `PipelineJobHook.cancel_pipeline_job` method. + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param region: Required. The ID of the Google Cloud region that the service belongs to. :param pipeline_job: The name of the PipelineJob to cancel. @@ -385,6 +388,10 @@ def cancel_pipeline_job( :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. """ + warnings.warn( + "This method is deprecated, please use `PipelineJobHook.cancel_pipeline_job` method.", + AirflowProviderDeprecationWarning, + ) client = self.get_pipeline_service_client(region) name = client.pipeline_job_path(project_id, region, pipeline_job) @@ -493,6 +500,8 @@ def create_pipeline_job( """ Creates a PipelineJob. A PipelineJob will run immediately when created. + This method is deprecated, please use `PipelineJobHook.create_pipeline_job` method. + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param region: Required. The ID of the Google Cloud region that the service belongs to. :param pipeline_job: Required. The PipelineJob to create. @@ -504,6 +513,10 @@ def create_pipeline_job( :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. """ + warnings.warn( + "This method is deprecated, please use `PipelineJobHook.create_pipeline_job` method.", + AirflowProviderDeprecationWarning, + ) client = self.get_pipeline_service_client(region) parent = client.common_location_path(project_id, region) @@ -1752,6 +1765,8 @@ def delete_pipeline_job( """ Deletes a PipelineJob. + This method is deprecated, please use `PipelineJobHook.delete_pipeline_job` method. + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param region: Required. The ID of the Google Cloud region that the service belongs to. :param pipeline_job: Required. The name of the PipelineJob resource to be deleted. @@ -1759,6 +1774,10 @@ def delete_pipeline_job( :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. 
""" + warnings.warn( + "This method is deprecated, please use `PipelineJobHook.delete_pipeline_job` method.", + AirflowProviderDeprecationWarning, + ) client = self.get_pipeline_service_client(region) name = client.pipeline_job_path(project_id, region, pipeline_job) @@ -1851,6 +1870,8 @@ def get_pipeline_job( """ Gets a PipelineJob. + This method is deprecated, please use `PipelineJobHook.get_pipeline_job` method. + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param region: Required. The ID of the Google Cloud region that the service belongs to. :param pipeline_job: Required. The name of the PipelineJob resource. @@ -1858,6 +1879,10 @@ def get_pipeline_job( :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. """ + warnings.warn( + "This method is deprecated, please use `PipelineJobHook.get_pipeline_job` method.", + AirflowProviderDeprecationWarning, + ) client = self.get_pipeline_service_client(region) name = client.pipeline_job_path(project_id, region, pipeline_job) @@ -1953,6 +1978,8 @@ def list_pipeline_jobs( """ Lists PipelineJobs in a Location. + This method is deprecated, please use `PipelineJobHook.list_pipeline_jobs` method. + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param region: Required. The ID of the Google Cloud region that the service belongs to. :param filter: Optional. Lists the PipelineJobs that match the filter expression. The @@ -2008,6 +2035,10 @@ def list_pipeline_jobs( :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. """ + warnings.warn( + "This method is deprecated, please use `PipelineJobHook.list_pipeline_jobs` method.", + AirflowProviderDeprecationWarning, + ) client = self.get_pipeline_service_client(region) parent = client.common_location_path(project_id, region) diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py b/airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py new file mode 100644 index 0000000000000..f50b2c717fa9c --- /dev/null +++ b/airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py @@ -0,0 +1,409 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Vertex AI hook. + +.. 
spelling:word-list:: + + aiplatform +""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Sequence + +from google.api_core.client_options import ClientOptions +from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault +from google.cloud.aiplatform import PipelineJob +from google.cloud.aiplatform_v1 import PipelineServiceClient + +from airflow.exceptions import AirflowException +from airflow.providers.google.common.consts import CLIENT_INFO +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook + +if TYPE_CHECKING: + from google.api_core.operation import Operation + from google.api_core.retry import Retry + from google.cloud.aiplatform.metadata import experiment_resources + from google.cloud.aiplatform_v1.services.pipeline_service.pagers import ListPipelineJobsPager + + +class PipelineJobHook(GoogleBaseHook): + """Hook for Google Cloud Vertex AI Pipeline Job APIs.""" + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + impersonation_chain=impersonation_chain, + ) + self._pipeline_job: PipelineJob | None = None + + def get_pipeline_service_client( + self, + region: str | None = None, + ) -> PipelineServiceClient: + """Returns PipelineServiceClient.""" + if region and region != "global": + client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443") + else: + client_options = ClientOptions() + return PipelineServiceClient( + credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options + ) + + def get_pipeline_job_object( + self, + display_name: str, + template_path: str, + job_id: str | None = None, + pipeline_root: str | None = None, + parameter_values: dict[str, Any] | None = None, + input_artifacts: dict[str, str] | None = None, + enable_caching: bool | None = None, + encryption_spec_key_name: str | None = None, + labels: dict[str, str] | None = None, + project: str | None = None, + location: str | None = None, + failure_policy: str | None = None, + ) -> PipelineJob: + """Returns PipelineJob object.""" + return PipelineJob( + display_name=display_name, + template_path=template_path, + job_id=job_id, + pipeline_root=pipeline_root, + parameter_values=parameter_values, + input_artifacts=input_artifacts, + enable_caching=enable_caching, + encryption_spec_key_name=encryption_spec_key_name, + labels=labels, + credentials=self.get_credentials(), + project=project, + location=location, + failure_policy=failure_policy, + ) + + @staticmethod + def extract_pipeline_job_id(obj: dict) -> str: + """Returns unique id of the pipeline_job.""" + return obj["name"].rpartition("/")[-1] + + def wait_for_operation(self, operation: Operation, timeout: float | None = None): + """Waits for long-lasting operation to complete.""" + try: + return operation.result(timeout=timeout) + except Exception: + error = operation.exception(timeout=timeout) + raise AirflowException(error) + + def cancel_pipeline_job(self) -> None: + """Cancel PipelineJob.""" + if self._pipeline_job: + self._pipeline_job.cancel() + + @GoogleBaseHook.fallback_to_default_project_id + def create_pipeline_job( + self, + project_id: str, + region: str, + pipeline_job: PipelineJob, + pipeline_job_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> PipelineJob: + """ + Creates a PipelineJob. 
A PipelineJob will run immediately when created. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param pipeline_job: Required. The PipelineJob to create. + :param pipeline_job_id: The ID to use for the PipelineJob, which will become the final component of + the PipelineJob name. If not provided, an ID will be automatically generated. + + This value should be less than 128 characters, and valid characters are /[a-z][0-9]-/. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_pipeline_service_client(region) + parent = client.common_location_path(project_id, region) + + result = client.create_pipeline_job( + request={ + "parent": parent, + "pipeline_job": pipeline_job, + "pipeline_job_id": pipeline_job_id, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def run_pipeline_job( + self, + project_id: str, + region: str, + display_name: str, + template_path: str, + job_id: str | None = None, + pipeline_root: str | None = None, + parameter_values: dict[str, Any] | None = None, + input_artifacts: dict[str, str] | None = None, + enable_caching: bool | None = None, + encryption_spec_key_name: str | None = None, + labels: dict[str, str] | None = None, + failure_policy: str | None = None, + # START: run param + service_account: str | None = None, + network: str | None = None, + create_request_timeout: float | None = None, + experiment: str | experiment_resources.Experiment | None = None, + # END: run param + ) -> PipelineJob: + """ + Run PipelineJob and monitor the job until completion. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param display_name: Required. The user-defined name of this Pipeline. + :param template_path: Required. The path of PipelineJob or PipelineSpec JSON or YAML file. It can be + a local path, a Google Cloud Storage URI (e.g. "gs://project.name"), an Artifact Registry URI + (e.g. "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"), or an HTTPS URI. + :param job_id: Optional. The unique ID of the job run. If not specified, pipeline name + timestamp + will be used. + :param pipeline_root: Optional. The root of the pipeline outputs. If not set, the staging bucket set + in aiplatform.init will be used. If that's not set a pipeline-specific artifacts bucket will be + used. + :param parameter_values: Optional. The mapping from runtime parameter names to its values that + control the pipeline run. + :param input_artifacts: Optional. The mapping from the runtime parameter name for this artifact to + its resource id. For example: "vertex_model":"456". Note: full resource name + ("projects/123/locations/us-central1/metadataStores/default/artifacts/456") cannot be used. + :param enable_caching: Optional. Whether to turn on caching for the run. + If this is not set, defaults to the compile time settings, which are True for all tasks by + default, while users may specify different caching options for individual tasks. + If this is set, the setting applies to all tasks in the pipeline. Overrides the compile time + settings. 
+ :param encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer managed + encryption key used to protect the job. Has the form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute resource is created. If this is set, + then all resources created by the PipelineJob will be encrypted with the provided encryption key. + Overrides encryption_spec_key_name set in aiplatform.init. + :param labels: Optional. The user defined metadata to organize PipelineJob. + :param failure_policy: Optional. The failure policy - "slow" or "fast". Currently, the default of a + pipeline is that the pipeline will continue to run until no more tasks can be executed, also + known as PIPELINE_FAILURE_POLICY_FAIL_SLOW (corresponds to "slow"). However, if a pipeline is set + to PIPELINE_FAILURE_POLICY_FAIL_FAST (corresponds to "fast"), it will stop scheduling any new + tasks when a task has failed. Any scheduled tasks will continue to completion. + :param service_account: Optional. Specifies the service account for workload run-as account. Users + submitting jobs must have act-as permission on this run-as account. + :param network: Optional. The full name of the Compute Engine network to which the job should be + peered. For example, projects/12345/global/networks/myVPC. + Private services access must already be configured for the network. If left unspecified, the + network set in aiplatform.init will be used. Otherwise, the job is not peered with any network. + :param create_request_timeout: Optional. The timeout for the create request in seconds. + :param experiment: Optional. The Vertex AI experiment name or instance to associate to this + PipelineJob. Metrics produced by the PipelineJob as system.Metric Artifacts will be associated as + metrics to the current Experiment Run. Pipeline parameters will be associated as parameters to + the current Experiment Run. + """ + self._pipeline_job = self.get_pipeline_job_object( + display_name=display_name, + template_path=template_path, + job_id=job_id, + pipeline_root=pipeline_root, + parameter_values=parameter_values, + input_artifacts=input_artifacts, + enable_caching=enable_caching, + encryption_spec_key_name=encryption_spec_key_name, + labels=labels, + project=project_id, + location=region, + failure_policy=failure_policy, + ) + + self._pipeline_job.submit( + service_account=service_account, + network=network, + create_request_timeout=create_request_timeout, + experiment=experiment, + ) + + self._pipeline_job.wait() + return self._pipeline_job + + @GoogleBaseHook.fallback_to_default_project_id + def get_pipeline_job( + self, + project_id: str, + region: str, + pipeline_job_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> PipelineJob: + """ + Gets a PipelineJob. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param pipeline_job_id: Required. The ID of the PipelineJob resource. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
+ """ + client = self.get_pipeline_service_client(region) + name = client.pipeline_job_path(project_id, region, pipeline_job_id) + + result = client.get_pipeline_job( + request={ + "name": name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def list_pipeline_jobs( + self, + project_id: str, + region: str, + page_size: int | None = None, + page_token: str | None = None, + filter: str | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> ListPipelineJobsPager: + """ + Lists PipelineJobs in a Location. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param filter: Optional. Lists the PipelineJobs that match the filter expression. The + following fields are supported: + + - ``pipeline_name``: Supports ``=`` and ``!=`` comparisons. + - ``display_name``: Supports ``=``, ``!=`` comparisons, and + ``:`` wildcard. + - ``pipeline_job_user_id``: Supports ``=``, ``!=`` + comparisons, and ``:`` wildcard. for example, can check + if pipeline's display_name contains *step* by doing + display_name:"*step*" + - ``create_time``: Supports ``=``, ``!=``, ``<``, ``>``, + ``<=``, and ``>=`` comparisons. Values must be in RFC + 3339 format. + - ``update_time``: Supports ``=``, ``!=``, ``<``, ``>``, + ``<=``, and ``>=`` comparisons. Values must be in RFC + 3339 format. + - ``end_time``: Supports ``=``, ``!=``, ``<``, ``>``, + ``<=``, and ``>=`` comparisons. Values must be in RFC + 3339 format. + - ``labels``: Supports key-value equality and key presence. + + Filter expressions can be combined together using logical + operators (``AND`` & ``OR``). For example: + ``pipeline_name="test" AND create_time>"2020-05-18T13:30:00Z"``. + + The syntax to define filter expression is based on + https://google.aip.dev/160. + :param page_size: Optional. The standard list page size. + :param page_token: Optional. The standard list page token. Typically obtained via + [ListPipelineJobsResponse.next_page_token][google.cloud.aiplatform.v1.ListPipelineJobsResponse.next_page_token] + of the previous + [PipelineService.ListPipelineJobs][google.cloud.aiplatform.v1.PipelineService.ListPipelineJobs] + call. + :param order_by: Optional. A comma-separated list of fields to order by. The default + sort order is in ascending order. Use "desc" after a field + name for descending. You can have multiple order_by fields + provided e.g. "create_time desc, end_time", "end_time, + start_time, update_time" For example, using "create_time + desc, end_time" will order results by create time in + descending order, and if there are multiple jobs having the + same create time, order them by the end time in ascending + order. if order_by is not specified, it will order by + default order is create time in descending order. Supported + fields: + + - ``create_time`` + - ``update_time`` + - ``end_time`` + - ``start_time`` + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
+ """ + client = self.get_pipeline_service_client(region) + parent = client.common_location_path(project_id, region) + + result = client.list_pipeline_jobs( + request={ + "parent": parent, + "page_size": page_size, + "page_token": page_token, + "filter": filter, + "order_by": order_by, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def delete_pipeline_job( + self, + project_id: str, + region: str, + pipeline_job_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """ + Deletes a PipelineJob. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param pipeline_job_id: Required. The ID of the PipelineJob resource to be deleted. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_pipeline_service_client(region) + name = client.pipeline_job_path(project_id, region, pipeline_job_id) + + result = client.delete_pipeline_job( + request={ + "name": name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result diff --git a/airflow/providers/google/cloud/links/vertex_ai.py b/airflow/providers/google/cloud/links/vertex_ai.py index c05c36ea6ce72..cb48aafac44d2 100644 --- a/airflow/providers/google/cloud/links/vertex_ai.py +++ b/airflow/providers/google/cloud/links/vertex_ai.py @@ -49,6 +49,10 @@ VERTEX_AI_BASE_LINK + "/locations/{region}/endpoints/{endpoint_id}?project={project_id}" ) VERTEX_AI_ENDPOINT_LIST_LINK = VERTEX_AI_BASE_LINK + "/endpoints?project={project_id}" +VERTEX_AI_PIPELINE_JOB_LINK = ( + VERTEX_AI_BASE_LINK + "/locations/{region}/pipelines/runs/{pipeline_id}?project={project_id}" +) +VERTEX_AI_PIPELINE_JOB_LIST_LINK = VERTEX_AI_BASE_LINK + "/pipelines/runs?project={project_id}" class VertexAIModelLink(BaseGoogleLink): @@ -319,3 +323,48 @@ def persist( "project_id": task_instance.project_id, }, ) + + +class VertexAIPipelineJobLink(BaseGoogleLink): + """Helper class for constructing Vertex AI PipelineJob link.""" + + name = "Pipeline Job" + key = "pipeline_job_conf" + format_str = VERTEX_AI_PIPELINE_JOB_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + pipeline_id: str, + ): + task_instance.xcom_push( + context=context, + key=VertexAIPipelineJobLink.key, + value={ + "pipeline_id": pipeline_id, + "region": task_instance.region, + "project_id": task_instance.project_id, + }, + ) + + +class VertexAIPipelineJobListLink(BaseGoogleLink): + """Helper class for constructing Vertex AI PipelineJobList link.""" + + name = "Pipeline Job List" + key = "pipeline_job_list_conf" + format_str = VERTEX_AI_PIPELINE_JOB_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + ): + task_instance.xcom_push( + context=context, + key=VertexAIPipelineJobListLink.key, + value={ + "project_id": task_instance.project_id, + }, + ) diff --git a/airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py b/airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py new file mode 100644 index 0000000000000..a221c3d46b552 --- /dev/null +++ b/airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py @@ -0,0 +1,464 @@ +# +# Licensed to the 
Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Vertex AI operators.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Sequence + +from google.api_core.exceptions import NotFound +from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault +from google.cloud.aiplatform_v1.types import PipelineJob + +from airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job import PipelineJobHook +from airflow.providers.google.cloud.links.vertex_ai import ( + VertexAIPipelineJobLink, + VertexAIPipelineJobListLink, +) +from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator + +if TYPE_CHECKING: + from google.api_core.retry import Retry + from google.cloud.aiplatform.metadata import experiment_resources + + from airflow.utils.context import Context + + +class RunPipelineJobOperator(GoogleCloudBaseOperator): + """ + Run Pipeline job. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param display_name: Required. The user-defined name of this Pipeline. + :param template_path: Required. The path of PipelineJob or PipelineSpec JSON or YAML file. It can be + a local path, a Google Cloud Storage URI (e.g. "gs://project.name"), an Artifact Registry URI + (e.g. "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"), or an HTTPS URI. + :param job_id: Optional. The unique ID of the job run. If not specified, pipeline name + timestamp + will be used. + :param pipeline_root: Optional. The root of the pipeline outputs. If not set, the staging bucket set + in aiplatform.init will be used. If that's not set a pipeline-specific artifacts bucket will be + used. + :param parameter_values: Optional. The mapping from runtime parameter names to its values that + control the pipeline run. + :param input_artifacts: Optional. The mapping from the runtime parameter name for this artifact to + its resource id. For example: "vertex_model":"456". Note: full resource name + ("projects/123/locations/us-central1/metadataStores/default/artifacts/456") cannot be used. + :param enable_caching: Optional. Whether to turn on caching for the run. + If this is not set, defaults to the compile time settings, which are True for all tasks by + default, while users may specify different caching options for individual tasks. + If this is set, the setting applies to all tasks in the pipeline. Overrides the compile time + settings. + :param encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer managed + encryption key used to protect the job. Has the form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. 
+ The key needs to be in the same region as where the compute resource is created. If this is set, + then all resources created by the PipelineJob will be encrypted with the provided encryption key. + Overrides encryption_spec_key_name set in aiplatform.init. + :param labels: Optional. The user defined metadata to organize PipelineJob. + :param failure_policy: Optional. The failure policy - "slow" or "fast". Currently, the default of a + pipeline is that the pipeline will continue to run until no more tasks can be executed, also + known as PIPELINE_FAILURE_POLICY_FAIL_SLOW (corresponds to "slow"). However, if a pipeline is set + to PIPELINE_FAILURE_POLICY_FAIL_FAST (corresponds to "fast"), it will stop scheduling any new + tasks when a task has failed. Any scheduled tasks will continue to completion. + :param service_account: Optional. Specifies the service account for workload run-as account. Users + submitting jobs must have act-as permission on this run-as account. + :param network: Optional. The full name of the Compute Engine network to which the job should be + peered. For example, projects/12345/global/networks/myVPC. + Private services access must already be configured for the network. If left unspecified, the + network set in aiplatform.init will be used. Otherwise, the job is not peered with any network. + :param create_request_timeout: Optional. The timeout for the create request in seconds. + :param experiment: Optional. The Vertex AI experiment name or instance to associate to this + PipelineJob. Metrics produced by the PipelineJob as system.Metric Artifacts will be associated as + metrics to the current Experiment Run. Pipeline parameters will be associated as parameters to + the current Experiment Run. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
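For illustration, a rough sketch of how this operator might be wired into a DAG (the project, bucket, template, and DAG values below are hypothetical; the system-test DAG referenced in the documentation changes is the canonical example):

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.vertex_ai.pipeline_job import RunPipelineJobOperator

with DAG("example_vertex_ai_pipeline", schedule=None, start_date=datetime(2023, 1, 1), catchup=False):
    # Submits the compiled pipeline and waits for it to finish; the job ID is pushed to XCom.
    run_pipeline_job = RunPipelineJobOperator(
        task_id="run_pipeline_job",
        project_id="my-project",
        region="us-central1",
        display_name="my-pipeline",
        template_path="gs://my-bucket/pipeline.json",
        parameter_values={"input_path": "gs://my-bucket/data.csv"},
    )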
+ """ + + template_fields = [ + "region", + "project_id", + "input_artifacts", + "impersonation_chain", + ] + operator_extra_links = (VertexAIPipelineJobLink(),) + + def __init__( + self, + *, + project_id: str, + region: str, + display_name: str, + template_path: str, + job_id: str | None = None, + pipeline_root: str | None = None, + parameter_values: dict[str, Any] | None = None, + input_artifacts: dict[str, str] | None = None, + enable_caching: bool | None = None, + encryption_spec_key_name: str | None = None, + labels: dict[str, str] | None = None, + failure_policy: str | None = None, + service_account: str | None = None, + network: str | None = None, + create_request_timeout: float | None = None, + experiment: str | experiment_resources.Experiment | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.region = region + self.project_id = project_id + self.display_name = display_name + self.template_path = template_path + self.job_id = job_id + self.pipeline_root = pipeline_root + self.parameter_values = parameter_values + self.input_artifacts = input_artifacts + self.enable_caching = enable_caching + self.encryption_spec_key_name = encryption_spec_key_name + self.labels = labels + self.failure_policy = failure_policy + self.service_account = service_account + self.network = network + self.create_request_timeout = create_request_timeout + self.experiment = experiment + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.hook: PipelineJobHook | None = None + + def execute(self, context: Context): + self.log.info("Running Pipeline job") + self.hook = PipelineJobHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + result = self.hook.run_pipeline_job( + project_id=self.project_id, + region=self.region, + display_name=self.display_name, + template_path=self.template_path, + job_id=self.job_id, + pipeline_root=self.pipeline_root, + parameter_values=self.parameter_values, + input_artifacts=self.input_artifacts, + enable_caching=self.enable_caching, + encryption_spec_key_name=self.encryption_spec_key_name, + labels=self.labels, + failure_policy=self.failure_policy, + service_account=self.service_account, + network=self.network, + create_request_timeout=self.create_request_timeout, + experiment=self.experiment, + ) + + pipeline_job = result.to_dict() + pipeline_job_id = self.hook.extract_pipeline_job_id(pipeline_job) + self.log.info("Pipeline job was created. Job id: %s", pipeline_job_id) + + self.xcom_push(context, key="pipeline_job_id", value=pipeline_job_id) + VertexAIPipelineJobLink.persist(context=context, task_instance=self, pipeline_id=pipeline_job_id) + return pipeline_job + + def on_kill(self) -> None: + """Callback called when the operator is killed; cancel any running job.""" + if self.hook: + self.hook.cancel_pipeline_job() + + +class GetPipelineJobOperator(GoogleCloudBaseOperator): + """ + Get a Pipeline job. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param pipeline_job_id: Required. The ID of the PipelineJob resource. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + + """ + + template_fields = [ + "region", + "pipeline_job_id", + "project_id", + "impersonation_chain", + ] + operator_extra_links = (VertexAIPipelineJobLink(),) + + def __init__( + self, + *, + project_id: str, + region: str, + pipeline_job_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.region = region + self.project_id = project_id + self.pipeline_job_id = pipeline_job_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = PipelineJobHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + try: + self.log.info("Get Pipeline job: %s", self.pipeline_job_id) + result = hook.get_pipeline_job( + project_id=self.project_id, + region=self.region, + pipeline_job_id=self.pipeline_job_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + VertexAIPipelineJobLink.persist( + context=context, task_instance=self, pipeline_id=self.pipeline_job_id + ) + self.log.info("Pipeline job was gotten.") + return PipelineJob.to_dict(result) + except NotFound: + self.log.info("The Pipeline job %s does not exist.", self.pipeline_job_id) + + +class ListPipelineJobOperator(GoogleCloudBaseOperator): + """Lists PipelineJob in a Location. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param filter: Optional. Lists the PipelineJobs that match the filter expression. The + following fields are supported: + + - ``pipeline_name``: Supports ``=`` and ``!=`` comparisons. + - ``display_name``: Supports ``=``, ``!=`` comparisons, and + ``:`` wildcard. + - ``pipeline_job_user_id``: Supports ``=``, ``!=`` + comparisons, and ``:`` wildcard. for example, can check + if pipeline's display_name contains *step* by doing + display_name:"*step*" + - ``create_time``: Supports ``=``, ``!=``, ``<``, ``>``, + ``<=``, and ``>=`` comparisons. Values must be in RFC + 3339 format. + - ``update_time``: Supports ``=``, ``!=``, ``<``, ``>``, + ``<=``, and ``>=`` comparisons. Values must be in RFC + 3339 format. + - ``end_time``: Supports ``=``, ``!=``, ``<``, ``>``, + ``<=``, and ``>=`` comparisons. Values must be in RFC + 3339 format. + - ``labels``: Supports key-value equality and key presence. + + Filter expressions can be combined together using logical + operators (``AND`` & ``OR``). For example: + ``pipeline_name="test" AND create_time>"2020-05-18T13:30:00Z"``. 
+ + The syntax to define filter expression is based on + https://google.aip.dev/160. + :param page_size: Optional. The standard list page size. + :param page_token: Optional. The standard list page token. Typically obtained via + [ListPipelineJobsResponse.next_page_token][google.cloud.aiplatform.v1.ListPipelineJobsResponse.next_page_token] + of the previous + [PipelineService.ListPipelineJobs][google.cloud.aiplatform.v1.PipelineService.ListPipelineJobs] + call. + :param order_by: Optional. A comma-separated list of fields to order by. The default + sort order is in ascending order. Use "desc" after a field + name for descending. You can have multiple order_by fields + provided e.g. "create_time desc, end_time", "end_time, + start_time, update_time" For example, using "create_time + desc, end_time" will order results by create time in + descending order, and if there are multiple jobs having the + same create time, order them by the end time in ascending + order. if order_by is not specified, it will order by + default order is create time in descending order. Supported + fields: + + - ``create_time`` + - ``update_time`` + - ``end_time`` + - ``start_time`` + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields = [ + "region", + "project_id", + "impersonation_chain", + ] + operator_extra_links = [ + VertexAIPipelineJobListLink(), + ] + + def __init__( + self, + *, + region: str, + project_id: str, + page_size: int | None = None, + page_token: str | None = None, + filter: str | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.region = region + self.project_id = project_id + self.page_size = page_size + self.page_token = page_token + self.filter = filter + self.order_by = order_by + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = PipelineJobHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + results = hook.list_pipeline_jobs( + region=self.region, + project_id=self.project_id, + page_size=self.page_size, + page_token=self.page_token, + filter=self.filter, + order_by=self.order_by, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + VertexAIPipelineJobListLink.persist(context=context, task_instance=self) + return [PipelineJob.to_dict(result) for result in results] + + +class DeletePipelineJobOperator(GoogleCloudBaseOperator): + """ + Delete a Pipeline job. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param pipeline_job_id: Required. The ID of the PipelineJob resource to be deleted. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields = [ + "region", + "project_id", + "pipeline_job_id", + "impersonation_chain", + ] + + def __init__( + self, + *, + project_id: str, + region: str, + pipeline_job_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.region = region + self.project_id = project_id + self.pipeline_job_id = pipeline_job_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = PipelineJobHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + try: + self.log.info("Deleting Pipeline job: %s", self.pipeline_job_id) + operation = hook.delete_pipeline_job( + region=self.region, + project_id=self.project_id, + pipeline_job_id=self.pipeline_job_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + hook.wait_for_operation(timeout=self.timeout, operation=operation) + self.log.info("Pipeline job was deleted.") + except NotFound: + self.log.info("The Pipeline Job ID %s does not exist.", self.pipeline_job_id) diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index 5ceb155495de5..8ba1498d55617 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -624,6 +624,7 @@ operators: - airflow.providers.google.cloud.operators.vertex_ai.endpoint_service - airflow.providers.google.cloud.operators.vertex_ai.hyperparameter_tuning_job - airflow.providers.google.cloud.operators.vertex_ai.model_service + - airflow.providers.google.cloud.operators.vertex_ai.pipeline_job - integration-name: Google Looker python-modules: - airflow.providers.google.cloud.operators.looker @@ -870,6 +871,7 @@ hooks: - airflow.providers.google.cloud.hooks.vertex_ai.endpoint_service - airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job - airflow.providers.google.cloud.hooks.vertex_ai.model_service + - airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job - integration-name: Google Looker python-modules: - airflow.providers.google.cloud.hooks.looker @@ -1132,6 +1134,8 @@ extra-links: - airflow.providers.google.cloud.links.vertex_ai.VertexAIBatchPredictionJobListLink - airflow.providers.google.cloud.links.vertex_ai.VertexAIEndpointLink - airflow.providers.google.cloud.links.vertex_ai.VertexAIEndpointListLink + - airflow.providers.google.cloud.links.vertex_ai.VertexAIPipelineJobLink + - airflow.providers.google.cloud.links.vertex_ai.VertexAIPipelineJobListLink - airflow.providers.google.cloud.links.workflows.WorkflowsWorkflowDetailsLink - airflow.providers.google.cloud.links.workflows.WorkflowsListOfWorkflowsLink - airflow.providers.google.cloud.links.workflows.WorkflowsExecutionLink diff --git a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst index f9e426a032cbb..5ba490bf55275 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst @@ -489,6 +489,46 @@ To delete specific version of model you can use :start-after: [START how_to_cloud_vertex_ai_delete_version_operator] :end-before: [END 
how_to_cloud_vertex_ai_delete_version_operator]
+Running a Pipeline Job
+^^^^^^^^^^^^^^^^^^^^^^
+
+To run a Google Vertex AI Pipeline Job, you can use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.pipeline_job.RunPipelineJobOperator`.
+The operator returns the pipeline job ID in :ref:`XCom` under the ``pipeline_job_id`` key.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_pipeline_job.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_run_pipeline_job_operator]
+    :end-before: [END how_to_cloud_vertex_ai_run_pipeline_job_operator]
+
+To delete a pipeline job, you can use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.pipeline_job.DeletePipelineJobOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_pipeline_job.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_delete_pipeline_job_operator]
+    :end-before: [END how_to_cloud_vertex_ai_delete_pipeline_job_operator]
+
+To get a pipeline job, you can use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.pipeline_job.GetPipelineJobOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_pipeline_job.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_get_pipeline_job_operator]
+    :end-before: [END how_to_cloud_vertex_ai_get_pipeline_job_operator]
+
+To get a list of pipeline jobs, you can use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.pipeline_job.ListPipelineJobOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_pipeline_job.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_list_pipeline_job_operator]
+    :end-before: [END how_to_cloud_vertex_ai_list_pipeline_job_operator]
+
 Reference
 ^^^^^^^^^
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index a71bd7b790c0b..daf2689f84516 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -866,6 +866,7 @@ KeyManagementServiceClient
 keyring
 keyspace
 keytab
+kfp
 Kibana
 killMode
 Kinesis
@@ -958,6 +959,7 @@ mesos
 MessageAttributes
 metaclass
 metadatabase
+metadataStores
 metarouter
 Metastore
 metastore
@@ -1189,6 +1191,7 @@ productionalize
 ProductSearchClient
 profiler
 programmatically
+proj
 projectId
 projectid
 proto
diff --git a/tests/providers/google/cloud/hooks/vertex_ai/test_pipeline_job.py b/tests/providers/google/cloud/hooks/vertex_ai/test_pipeline_job.py
new file mode 100644
index 0000000000000..e784dabb0dfd1
--- /dev/null
+++ b/tests/providers/google/cloud/hooks/vertex_ai/test_pipeline_job.py
@@ -0,0 +1,219 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations + +from unittest import mock + +from google.api_core.gapic_v1.method import DEFAULT + +from airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job import ( + PipelineJobHook, +) +from tests.providers.google.cloud.utils.base_gcp_mock import ( + mock_base_gcp_hook_default_project_id, + mock_base_gcp_hook_no_default_project_id, +) + +TEST_GCP_CONN_ID: str = "test-gcp-conn-id" +TEST_REGION: str = "test-region" +TEST_PROJECT_ID: str = "test-project-id" +TEST_PIPELINE_JOB: dict = {} +TEST_PIPELINE_JOB_ID = "test_pipeline_job_id" + +BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}" +PIPELINE_JOB_STRING = "airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job.{}" + + +class TestPipelineJobWithDefaultProjectIdHook: + def setup_method(self): + with mock.patch( + BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_default_project_id + ): + self.hook = PipelineJobHook(gcp_conn_id=TEST_GCP_CONN_ID) + + @mock.patch(PIPELINE_JOB_STRING.format("PipelineJobHook.get_pipeline_service_client")) + def test_create_pipeline_job(self, mock_client) -> None: + self.hook.create_pipeline_job( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + pipeline_job=TEST_PIPELINE_JOB, + pipeline_job_id=TEST_PIPELINE_JOB_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.create_pipeline_job.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + pipeline_job=TEST_PIPELINE_JOB, + pipeline_job_id=TEST_PIPELINE_JOB_ID, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + @mock.patch(PIPELINE_JOB_STRING.format("PipelineJobHook.get_pipeline_service_client")) + def test_delete_pipeline_job(self, mock_client) -> None: + self.hook.delete_pipeline_job( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + pipeline_job_id=TEST_PIPELINE_JOB_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.delete_pipeline_job.assert_called_once_with( + request=dict( + name=mock_client.return_value.pipeline_job_path.return_value, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.pipeline_job_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID + ) + + @mock.patch(PIPELINE_JOB_STRING.format("PipelineJobHook.get_pipeline_service_client")) + def test_get_pipeline_job(self, mock_client) -> None: + self.hook.get_pipeline_job( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + pipeline_job_id=TEST_PIPELINE_JOB_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.get_pipeline_job.assert_called_once_with( + request=dict( + name=mock_client.return_value.pipeline_job_path.return_value, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.pipeline_job_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID + ) + + @mock.patch(PIPELINE_JOB_STRING.format("PipelineJobHook.get_pipeline_service_client")) + def test_list_pipeline_jobs(self, mock_client) -> None: + self.hook.list_pipeline_jobs( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.list_pipeline_jobs.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + page_size=None, + page_token=None, + 
filter=None, + order_by=None, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + +class TestPipelineJobWithoutDefaultProjectIdHook: + def setup_method(self): + with mock.patch( + BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_no_default_project_id + ): + self.hook = PipelineJobHook(gcp_conn_id=TEST_GCP_CONN_ID) + + @mock.patch(PIPELINE_JOB_STRING.format("PipelineJobHook.get_pipeline_service_client")) + def test_create_pipeline_job(self, mock_client) -> None: + self.hook.create_pipeline_job( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + pipeline_job=TEST_PIPELINE_JOB, + pipeline_job_id=TEST_PIPELINE_JOB_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.create_pipeline_job.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + pipeline_job=TEST_PIPELINE_JOB, + pipeline_job_id=TEST_PIPELINE_JOB_ID, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + @mock.patch(PIPELINE_JOB_STRING.format("PipelineJobHook.get_pipeline_service_client")) + def test_delete_pipeline_job(self, mock_client) -> None: + self.hook.delete_pipeline_job( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + pipeline_job_id=TEST_PIPELINE_JOB_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.delete_pipeline_job.assert_called_once_with( + request=dict( + name=mock_client.return_value.pipeline_job_path.return_value, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.pipeline_job_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID + ) + + @mock.patch(PIPELINE_JOB_STRING.format("PipelineJobHook.get_pipeline_service_client")) + def test_get_pipeline_job(self, mock_client) -> None: + self.hook.get_pipeline_job( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + pipeline_job_id=TEST_PIPELINE_JOB_ID, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.get_pipeline_job.assert_called_once_with( + request=dict( + name=mock_client.return_value.pipeline_job_path.return_value, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.pipeline_job_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID + ) + + @mock.patch(PIPELINE_JOB_STRING.format("PipelineJobHook.get_pipeline_service_client")) + def test_list_pipeline_jobs(self, mock_client) -> None: + self.hook.list_pipeline_jobs( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.list_pipeline_jobs.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + page_size=None, + page_token=None, + filter=None, + order_by=None, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py index 91e1d4718a8a8..867291ae02189 100644 --- a/tests/providers/google/cloud/operators/test_vertex_ai.py +++ b/tests/providers/google/cloud/operators/test_vertex_ai.py @@ -76,6 +76,12 @@ SetDefaultVersionOnModelOperator, 
UploadModelOperator, ) +from airflow.providers.google.cloud.operators.vertex_ai.pipeline_job import ( + DeletePipelineJobOperator, + GetPipelineJobOperator, + ListPipelineJobOperator, + RunPipelineJobOperator, +) VERTEX_AI_PATH = "airflow.providers.google.cloud.operators.vertex_ai.{}" TIMEOUT = 120 @@ -176,6 +182,8 @@ TEST_BATCH_SIZE = 4000 TEST_VERSION_ALIASES = ["new-alias"] +TEST_TEMPLATE_PATH = "test_template_path" + class TestVertexAICreateCustomContainerTrainingJobOperator: @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook")) @@ -1714,3 +1722,134 @@ def test_execute(self, mock_hook, to_dict_mock): timeout=TIMEOUT, metadata=METADATA, ) + + +class TestVertexAIRunPipelineJobOperator: + @mock.patch(VERTEX_AI_PATH.format("pipeline_job.PipelineJob.to_dict")) + @mock.patch(VERTEX_AI_PATH.format("pipeline_job.PipelineJobHook")) + def test_execute(self, mock_hook, to_dict_mock): + op = RunPipelineJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + display_name=DISPLAY_NAME, + template_path=TEST_TEMPLATE_PATH, + job_id=TEST_PIPELINE_JOB_ID, + pipeline_root="", + parameter_values={}, + input_artifacts={}, + enable_caching=False, + encryption_spec_key_name="", + labels={}, + failure_policy="", + service_account="", + network="", + create_request_timeout=None, + experiment=None, + ) + op.execute(context={"ti": mock.MagicMock()}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.run_pipeline_job.assert_called_once_with( + project_id=GCP_PROJECT, + region=GCP_LOCATION, + display_name=DISPLAY_NAME, + template_path=TEST_TEMPLATE_PATH, + job_id=TEST_PIPELINE_JOB_ID, + pipeline_root="", + parameter_values={}, + input_artifacts={}, + enable_caching=False, + encryption_spec_key_name="", + labels={}, + failure_policy="", + service_account="", + network="", + create_request_timeout=None, + experiment=None, + ) + + +class TestVertexAIGetPipelineJobOperator: + @mock.patch(VERTEX_AI_PATH.format("pipeline_job.PipelineJob.to_dict")) + @mock.patch(VERTEX_AI_PATH.format("pipeline_job.PipelineJobHook")) + def test_execute(self, mock_hook, to_dict_mock): + op = GetPipelineJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + pipeline_job_id=TEST_PIPELINE_JOB_ID, + ) + op.execute(context={"ti": mock.MagicMock()}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.get_pipeline_job.assert_called_once_with( + project_id=GCP_PROJECT, + region=GCP_LOCATION, + pipeline_job_id=TEST_PIPELINE_JOB_ID, + retry=DEFAULT, + timeout=None, + metadata=(), + ) + + +class TestVertexAIDeletePipelineJobOperator: + @mock.patch(VERTEX_AI_PATH.format("pipeline_job.PipelineJobHook")) + def test_execute(self, mock_hook): + op = DeletePipelineJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + pipeline_job_id=TEST_PIPELINE_JOB_ID, + ) + op.execute(context={}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.delete_pipeline_job.assert_called_once_with( + project_id=GCP_PROJECT, + region=GCP_LOCATION, + pipeline_job_id=TEST_PIPELINE_JOB_ID, + retry=DEFAULT, + timeout=None, + metadata=(), + ) + + +class 
+class TestVertexAIListPipelineJobOperator:
+    @mock.patch(VERTEX_AI_PATH.format("pipeline_job.PipelineJobHook"))
+    def test_execute(self, mock_hook):
+        page_token = "page_token"
+        page_size = 42
+        filter = "filter"
+        order_by = "order_by"
+
+        op = ListPipelineJobOperator(
+            task_id=TASK_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            page_size=page_size,
+            page_token=page_token,
+            filter=filter,
+            order_by=order_by,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+        op.execute(context={"ti": mock.MagicMock()})
+        mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
+        mock_hook.return_value.list_pipeline_jobs.assert_called_once_with(
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            page_size=page_size,
+            page_token=page_token,
+            filter=filter,
+            order_by=order_by,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
diff --git a/tests/providers/google/cloud/operators/test_vertex_ai_system.py b/tests/providers/google/cloud/operators/test_vertex_ai_system.py
deleted file mode 100644
index 78aa7108f6f74..0000000000000
--- a/tests/providers/google/cloud/operators/test_vertex_ai_system.py
+++ /dev/null
@@ -1,55 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import pytest
-
-from tests.providers.google.cloud.utils.gcp_authenticator import GCP_VERTEX_AI_KEY
-from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context
-
-
-@pytest.mark.backend("mysql", "postgres")
-@pytest.mark.credential_file(GCP_VERTEX_AI_KEY)
-class TestVertexAIExampleDagsSystem(GoogleSystemTest):
-    @provide_gcp_context(GCP_VERTEX_AI_KEY)
-    def test_run_custom_jobs_example_dag(self):
-        self.run_dag(dag_id="example_gcp_vertex_ai_custom_jobs", dag_folder=CLOUD_DAG_FOLDER)
-
-    @provide_gcp_context(GCP_VERTEX_AI_KEY)
-    def test_run_dataset_example_dag(self):
-        self.run_dag(dag_id="example_gcp_vertex_ai_dataset", dag_folder=CLOUD_DAG_FOLDER)
-
-    @provide_gcp_context(GCP_VERTEX_AI_KEY)
-    def test_run_auto_ml_example_dag(self):
-        self.run_dag(dag_id="example_gcp_vertex_ai_auto_ml", dag_folder=CLOUD_DAG_FOLDER)
-
-    @provide_gcp_context(GCP_VERTEX_AI_KEY)
-    def test_run_batch_prediction_job_example_dag(self):
-        self.run_dag(dag_id="example_gcp_vertex_ai_batch_prediction_job", dag_folder=CLOUD_DAG_FOLDER)
-
-    @provide_gcp_context(GCP_VERTEX_AI_KEY)
-    def test_run_endpoint_example_dag(self):
-        self.run_dag(dag_id="example_gcp_vertex_ai_endpoint", dag_folder=CLOUD_DAG_FOLDER)
-
-    @provide_gcp_context(GCP_VERTEX_AI_KEY)
-    def test_run_hyperparameter_tuning_job_example_dag(self):
-        self.run_dag(dag_id="example_gcp_vertex_ai_hyperparameter_tuning_job", dag_folder=CLOUD_DAG_FOLDER)
-
-    @provide_gcp_context(GCP_VERTEX_AI_KEY)
-    def test_run_model_service_example_dag(self):
-        self.run_dag(dag_id="example_gcp_vertex_ai_model_service", dag_folder=CLOUD_DAG_FOLDER)
diff --git a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_pipeline_job.py b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_pipeline_job.py
new file mode 100644
index 0000000000000..16bc8c021686f
--- /dev/null
+++ b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_pipeline_job.py
@@ -0,0 +1,167 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# mypy ignore arg types (for templated fields)
+# type: ignore[arg-type]
+
+"""
+Example Airflow DAG for Google Vertex AI service testing Pipeline Job operations.
+""" +from __future__ import annotations + +import os +from datetime import datetime + +from airflow.models.dag import DAG +from airflow.providers.google.cloud.operators.gcs import ( + GCSCreateBucketOperator, + GCSDeleteBucketOperator, + GCSDeleteObjectsOperator, + GCSListObjectsOperator, + GCSSynchronizeBucketsOperator, +) +from airflow.providers.google.cloud.operators.vertex_ai.pipeline_job import ( + DeletePipelineJobOperator, + GetPipelineJobOperator, + ListPipelineJobOperator, + RunPipelineJobOperator, +) +from airflow.utils.trigger_rule import TriggerRule + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default") +DAG_ID = "example_vertex_ai_pipeline_job_operations" +REGION = "us-central1" +DISPLAY_NAME = f"pipeline-job-{ENV_ID}" + +RESOURCE_DATA_BUCKET = "airflow-system-tests-resources" +DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-") +TEMPLATE_PATH = "https://us-kfp.pkg.dev/ml-pipeline/google-cloud-registry/automl-tabular/sha256:85e4218fc6604ee82353c9d2ebba20289eb1b71930798c0bb8ce32d8a10de146" +OUTPUT_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}" + +PARAMETER_VALUES = { + "train_budget_milli_node_hours": 2000, + "optimization_objective": "minimize-log-loss", + "project": PROJECT_ID, + "location": REGION, + "root_dir": OUTPUT_BUCKET, + "target_column": "Adopted", + "training_fraction": 0.8, + "validation_fraction": 0.1, + "test_fraction": 0.1, + "prediction_type": "classification", + "data_source_csv_filenames": f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/vertex-ai/tabular-dataset.csv", + "transformations": f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/vertex-ai/column_transformations.json", +} + + +with DAG( + DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + catchup=False, + tags=["example", "vertex_ai", "pipeline_job"], +) as dag: + create_bucket = GCSCreateBucketOperator( + task_id="create_bucket", + bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, + storage_class="REGIONAL", + location=REGION, + ) + + move_pipeline_files = GCSSynchronizeBucketsOperator( + task_id="move_files_to_bucket", + source_bucket=RESOURCE_DATA_BUCKET, + source_object="vertex-ai/pipeline", + destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME, + destination_object="vertex-ai", + recursive=True, + ) + + # [START how_to_cloud_vertex_ai_run_pipeline_job_operator] + run_pipeline_job = RunPipelineJobOperator( + task_id="run_pipeline_job", + display_name=DISPLAY_NAME, + template_path=TEMPLATE_PATH, + parameter_values=PARAMETER_VALUES, + region=REGION, + project_id=PROJECT_ID, + ) + # [END how_to_cloud_vertex_ai_run_pipeline_job_operator] + + # [START how_to_cloud_vertex_ai_get_pipeline_job_operator] + get_pipeline_job = GetPipelineJobOperator( + task_id="get_pipeline_job", + project_id=PROJECT_ID, + region=REGION, + pipeline_job_id="{{ task_instance.xcom_pull(" + "task_ids='run_pipeline_job', key='pipeline_job_id') }}", + ) + # [END how_to_cloud_vertex_ai_get_pipeline_job_operator] + + # [START how_to_cloud_vertex_ai_delete_pipeline_job_operator] + delete_pipeline_job = DeletePipelineJobOperator( + task_id="delete_pipeline_job", + project_id=PROJECT_ID, + region=REGION, + pipeline_job_id="{{ task_instance.xcom_pull(" + "task_ids='run_pipeline_job', key='pipeline_job_id') }}", + trigger_rule=TriggerRule.ALL_DONE, + ) + # [END how_to_cloud_vertex_ai_delete_pipeline_job_operator] + + # [START how_to_cloud_vertex_ai_list_pipeline_job_operator] + list_pipeline_job = ListPipelineJobOperator( + task_id="list_pipeline_job", + 
+        region=REGION,
+        project_id=PROJECT_ID,
+    )
+    # [END how_to_cloud_vertex_ai_list_pipeline_job_operator]
+
+    list_buckets = GCSListObjectsOperator(task_id="list_buckets", bucket=DATA_SAMPLE_GCS_BUCKET_NAME)
+
+    delete_files = GCSDeleteObjectsOperator(
+        task_id="delete_files", bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, objects=list_buckets.output
+    )
+
+    delete_bucket = GCSDeleteBucketOperator(
+        task_id="delete_bucket",
+        bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
+    (
+        # TEST SETUP
+        create_bucket
+        >> move_pipeline_files
+        # TEST BODY
+        >> run_pipeline_job
+        >> get_pipeline_job
+        >> delete_pipeline_job
+        >> list_pipeline_job
+        # TEST TEARDOWN
+        >> list_buckets
+        >> delete_files
+        >> delete_bucket
+    )
+
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)