Skip to content

Commit

Permalink
Create operators for VertexAI Pipeline Job (#34915)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaksYermak committed Oct 18, 2023
1 parent 86e27c7 commit f16906d
Show file tree
Hide file tree
Showing 11 changed files with 1,526 additions and 56 deletions.
33 changes: 32 additions & 1 deletion airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py
Expand Up @@ -18,6 +18,7 @@
"""This module contains a Google Cloud Vertex AI hook."""
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Sequence

from google.api_core.client_options import ClientOptions
Expand All @@ -31,7 +32,7 @@
)
from google.cloud.aiplatform_v1 import JobServiceClient, PipelineServiceClient

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook

Expand Down Expand Up @@ -378,13 +379,19 @@ def cancel_pipeline_job(
[google.rpc.Status.code][google.rpc.Status.code] of 1, corresponding to ``Code.CANCELLED``, and
[PipelineJob.state][google.cloud.aiplatform.v1.PipelineJob.state] is set to ``CANCELLED``.
This method is deprecated, please use `PipelineJobHook.cancel_pipeline_job` method.
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param region: Required. The ID of the Google Cloud region that the service belongs to.
:param pipeline_job: The name of the PipelineJob to cancel.
:param retry: Designation of what errors, if any, should be retried.
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
"""
warnings.warn(
"This method is deprecated, please use `PipelineJobHook.cancel_pipeline_job` method.",
AirflowProviderDeprecationWarning,
)
client = self.get_pipeline_service_client(region)
name = client.pipeline_job_path(project_id, region, pipeline_job)

Expand Down Expand Up @@ -493,6 +500,8 @@ def create_pipeline_job(
"""
Creates a PipelineJob. A PipelineJob will run immediately when created.
This method is deprecated, please use `PipelineJobHook.create_pipeline_job` method.
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param region: Required. The ID of the Google Cloud region that the service belongs to.
:param pipeline_job: Required. The PipelineJob to create.
Expand All @@ -504,6 +513,10 @@ def create_pipeline_job(
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
"""
warnings.warn(
"This method is deprecated, please use `PipelineJobHook.create_pipeline_job` method.",
AirflowProviderDeprecationWarning,
)
client = self.get_pipeline_service_client(region)
parent = client.common_location_path(project_id, region)

Expand Down Expand Up @@ -1752,13 +1765,19 @@ def delete_pipeline_job(
"""
Deletes a PipelineJob.
This method is deprecated, please use `PipelineJobHook.delete_pipeline_job` method.
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param region: Required. The ID of the Google Cloud region that the service belongs to.
:param pipeline_job: Required. The name of the PipelineJob resource to be deleted.
:param retry: Designation of what errors, if any, should be retried.
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
"""
warnings.warn(
"This method is deprecated, please use `PipelineJobHook.delete_pipeline_job` method.",
AirflowProviderDeprecationWarning,
)
client = self.get_pipeline_service_client(region)
name = client.pipeline_job_path(project_id, region, pipeline_job)

Expand Down Expand Up @@ -1851,13 +1870,19 @@ def get_pipeline_job(
"""
Gets a PipelineJob.
This method is deprecated, please use `PipelineJobHook.get_pipeline_job` method.
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param region: Required. The ID of the Google Cloud region that the service belongs to.
:param pipeline_job: Required. The name of the PipelineJob resource.
:param retry: Designation of what errors, if any, should be retried.
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
"""
warnings.warn(
"This method is deprecated, please use `PipelineJobHook.get_pipeline_job` method.",
AirflowProviderDeprecationWarning,
)
client = self.get_pipeline_service_client(region)
name = client.pipeline_job_path(project_id, region, pipeline_job)

Expand Down Expand Up @@ -1953,6 +1978,8 @@ def list_pipeline_jobs(
"""
Lists PipelineJobs in a Location.
This method is deprecated, please use `PipelineJobHook.list_pipeline_jobs` method.
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param region: Required. The ID of the Google Cloud region that the service belongs to.
:param filter: Optional. Lists the PipelineJobs that match the filter expression. The
Expand Down Expand Up @@ -2008,6 +2035,10 @@ def list_pipeline_jobs(
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
"""
warnings.warn(
"This method is deprecated, please use `PipelineJobHook.list_pipeline_jobs` method.",
AirflowProviderDeprecationWarning,
)
client = self.get_pipeline_service_client(region)
parent = client.common_location_path(project_id, region)

Expand Down

0 comments on commit f16906d

Please sign in to comment.