Skip to content

Commit

Permalink
Deferrable mode for Custom Training Job operators
Browse files Browse the repository at this point in the history
  • Loading branch information
e-galan committed Mar 14, 2024
1 parent 1d54a9b commit 51e773f
Show file tree
Hide file tree
Showing 10 changed files with 2,964 additions and 102 deletions.
1,497 changes: 1,426 additions & 71 deletions airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion airflow/providers/google/cloud/links/vertex_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@

VERTEX_AI_BASE_LINK = "/vertex-ai"
VERTEX_AI_MODEL_LINK = (
VERTEX_AI_BASE_LINK + "/locations/{region}/models/{model_id}/deploy?project={project_id}"
VERTEX_AI_BASE_LINK
+ "/models/locations/{region}/models/{model_id}/versions/default/properties?project={project_id}"
)
VERTEX_AI_MODEL_LIST_LINK = VERTEX_AI_BASE_LINK + "/models?project={project_id}"
VERTEX_AI_MODEL_EXPORT_LINK = "/storage/browser/{bucket_name}/model-{model_id}?project={project_id}"
Expand Down
355 changes: 333 additions & 22 deletions airflow/providers/google/cloud/operators/vertex_ai/custom_job.py

Large diffs are not rendered by default.

94 changes: 94 additions & 0 deletions airflow/providers/google/cloud/triggers/vertex_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.vertex_ai.batch_prediction_job import BatchPredictionJobAsyncHook
from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobAsyncHook
from airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job import (
HyperparameterTuningJobAsyncHook,
)
Expand Down Expand Up @@ -189,3 +190,96 @@ async def _wait_job(self) -> types.PipelineJob:
poll_interval=self.poll_interval,
)
return job


class CustomTrainingJobTrigger(BaseVertexAIJobTrigger):
"""
Make async calls to Vertex AI to check the state of a running custom training job.
Return the job when it enters a completed state.
"""

job_type_verbose_name = "Custom Training Job"
job_serializer_class = types.TrainingPipeline
statuses_success = {
PipelineState.PIPELINE_STATE_PAUSED,
PipelineState.PIPELINE_STATE_SUCCEEDED,
}

@cached_property
def async_hook(self) -> CustomJobAsyncHook:
return CustomJobAsyncHook(
gcp_conn_id=self.conn_id,
impersonation_chain=self.impersonation_chain,
)

async def _wait_job(self) -> types.TrainingPipeline:
pipeline: types.TrainingPipeline = await self.async_hook.wait_for_training_pipeline(
project_id=self.project_id,
location=self.location,
pipeline_id=self.job_id,
poll_interval=self.poll_interval,
)
return pipeline


class CustomContainerTrainingJobTrigger(BaseVertexAIJobTrigger):
"""
Make async calls to Vertex AI to check the state of a running custom container training job.
Return the job when it enters a completed state.
"""

job_type_verbose_name = "Custom Container Training Job"
job_serializer_class = types.TrainingPipeline
statuses_success = {
PipelineState.PIPELINE_STATE_PAUSED,
PipelineState.PIPELINE_STATE_SUCCEEDED,
}

@cached_property
def async_hook(self) -> CustomJobAsyncHook:
return CustomJobAsyncHook(
gcp_conn_id=self.conn_id,
impersonation_chain=self.impersonation_chain,
)

async def _wait_job(self) -> types.TrainingPipeline:
pipeline: types.TrainingPipeline = await self.async_hook.wait_for_training_pipeline(
project_id=self.project_id,
location=self.location,
pipeline_id=self.job_id,
poll_interval=self.poll_interval,
)
return pipeline


class CustomPythonPackageTrainingJobTrigger(BaseVertexAIJobTrigger):
"""
Make async calls to Vertex AI to check the state of a running custom python package training job.
Return the job when it enters a completed state.
"""

job_type_verbose_name = "Custom Python Package Training Job"
job_serializer_class = types.TrainingPipeline
statuses_success = {
PipelineState.PIPELINE_STATE_PAUSED,
PipelineState.PIPELINE_STATE_SUCCEEDED,
}

@cached_property
def async_hook(self) -> CustomJobAsyncHook:
return CustomJobAsyncHook(
gcp_conn_id=self.conn_id,
impersonation_chain=self.impersonation_chain,
)

async def _wait_job(self) -> types.TrainingPipeline:
pipeline: types.TrainingPipeline = await self.async_hook.wait_for_training_pipeline(
project_id=self.project_id,
location=self.location,
pipeline_id=self.job_id,
poll_interval=self.poll_interval,
)
return pipeline
24 changes: 21 additions & 3 deletions docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ Preparation step

For each operator you must prepare and create dataset. Then put dataset id to ``dataset_id`` parameter in operator.

How to run Container Training Job
How to run a Custom Container Training Job
:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomContainerTrainingJobOperator`

Before start running this Job you should create a docker image with training script inside. Documentation how to
Expand All @@ -121,7 +121,16 @@ for container which will be created from this image in ``command`` parameter.
:start-after: [START how_to_cloud_vertex_ai_create_custom_container_training_job_operator]
:end-before: [END how_to_cloud_vertex_ai_create_custom_container_training_job_operator]

How to run Python Package Training Job
The :class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomContainerTrainingJobOperator`
also provides the deferrable mode:

.. exampleinclude:: /../../tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_container.py
:language: python
:dedent: 4
:start-after: [START how_to_cloud_vertex_ai_create_custom_container_training_job_operator_deferrable]
:end-before: [END how_to_cloud_vertex_ai_create_custom_container_training_job_operator_deferrable]

How to run a Python Package Training Job
:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomPythonPackageTrainingJobOperator`

Before start running this Job you should create a python package with training script inside. Documentation how to
Expand All @@ -135,7 +144,16 @@ parameter should has the name of script which will run your training task.
:start-after: [START how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator]
:end-before: [END how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator]

How to run Training Job
The :class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomPythonPackageTrainingJobOperator`
also provides the deferrable mode:

.. exampleinclude:: /../../tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job_python_package.py
:language: python
:dedent: 4
:start-after: [START how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator_deferrable]
:end-before: [END how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator_deferrable]

How to run a Custom Training Job
:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomTrainingJobOperator`.

For this Job you should put path to your local training script inside ``script_path`` parameter.
Expand Down
Loading

0 comments on commit 51e773f

Please sign in to comment.