diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py b/airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py index 93cd6902bba75..a3d0d205136b1 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py @@ -114,6 +114,8 @@ def create_batch_prediction_job( labels: dict[str, str] | None = None, encryption_spec_key_name: str | None = None, sync: bool = True, + create_request_timeout: float | None = None, + batch_size: int | None = None, ) -> BatchPredictionJob: """ Create a batch prediction job. @@ -207,6 +209,14 @@ def create_batch_prediction_job( :param sync: Whether to execute this method synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. + :param create_request_timeout: Optional. The timeout for the create request in seconds. + :param batch_size: Optional. The number of records (e.g. instances) + of the operation given in each batch + to a machine replica. The machine type and the size of a single record should be considered + when setting this parameter: a higher value speeds up the batch operation's execution, + but too high a value will result in a whole batch not fitting in a machine's memory, + and the whole operation will fail. + The default value is the same as in the aiplatform's BatchPredictionJob. 
""" self._batch_prediction_job = BatchPredictionJob.create( job_display_name=job_display_name, @@ -232,6 +242,8 @@ def create_batch_prediction_job( credentials=self.get_credentials(), encryption_spec_key_name=encryption_spec_key_name, sync=sync, + create_request_timeout=create_request_timeout, + batch_size=batch_size, ) return self._batch_prediction_job diff --git a/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py b/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py index 5c123227751e0..dc4775fe24184 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py @@ -139,6 +139,14 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator): :param sync: Whether to execute this method synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. + :param create_request_timeout: Optional. The timeout for the create request in seconds. + :param batch_size: Optional. The number of the records (e.g. instances) + of the operation given in each batch + to a machine replica. Machine type, and size of a single record should be considered + when setting this parameter, higher value speeds up the batch operation's execution, + but too high value will result in a whole batch not fitting in a machine's memory, + and the whole operation will fail. + The default value is same as in the aiplatform's BatchPredictionJob. :param retry: Designation of what errors, if any, should be retried. :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. 
@@ -181,6 +189,8 @@ def __init__( labels: dict[str, str] | None = None, encryption_spec_key_name: str | None = None, sync: bool = True, + create_request_timeout: float | None = None, + batch_size: int | None = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, **kwargs, @@ -208,6 +218,8 @@ def __init__( self.labels = labels self.encryption_spec_key_name = encryption_spec_key_name self.sync = sync + self.create_request_timeout = create_request_timeout + self.batch_size = batch_size self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain self.hook: BatchPredictionJobHook | None = None @@ -241,6 +253,8 @@ def execute(self, context: Context): labels=self.labels, encryption_spec_key_name=self.encryption_spec_key_name, sync=self.sync, + create_request_timeout=self.create_request_timeout, + batch_size=self.batch_size, ) batch_prediction_job = result.to_dict() diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index f79c54509a564..41263cd87cb5f 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -95,7 +95,7 @@ dependencies: - google-api-python-client>=1.6.0,<2.0.0 - google-auth>=1.0.0 - google-auth-httplib2>=0.0.1 - - google-cloud-aiplatform>=1.7.1,<2.0.0 + - google-cloud-aiplatform>=1.13.1,<2.0.0 - google-cloud-automl>=2.1.0 - google-cloud-bigquery-datatransfer>=3.0.0 - google-cloud-bigtable>=2.0.0,<3.0.0 diff --git a/docs/apache-airflow-providers-google/index.rst b/docs/apache-airflow-providers-google/index.rst index d78fbf67dd086..85dc79385cfd8 100644 --- a/docs/apache-airflow-providers-google/index.rst +++ b/docs/apache-airflow-providers-google/index.rst @@ -115,7 +115,7 @@ PIP package Version required ``google-api-python-client`` ``>=1.6.0,<2.0.0`` ``google-auth`` ``>=1.0.0`` ``google-auth-httplib2`` ``>=0.0.1`` -``google-cloud-aiplatform`` ``>=1.7.1,<2.0.0`` +``google-cloud-aiplatform`` 
``>=1.13.1,<2.0.0`` ``google-cloud-automl`` ``>=2.1.0`` ``google-cloud-bigquery-datatransfer`` ``>=3.0.0`` ``google-cloud-bigtable`` ``>=2.0.0,<3.0.0`` diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 9777e3bfa1ec7..a9aae0f43b74a 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -346,7 +346,7 @@ "google-auth-httplib2>=0.0.1", "google-auth-oauthlib<1.0.0,>=0.3.0", "google-auth>=1.0.0", - "google-cloud-aiplatform>=1.7.1,<2.0.0", + "google-cloud-aiplatform>=1.13.1,<2.0.0", "google-cloud-automl>=2.1.0", "google-cloud-bigquery-datatransfer>=3.0.0", "google-cloud-bigtable>=2.0.0,<3.0.0", diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py index 2c5b9a694da81..f4b3ad154dda1 100644 --- a/tests/providers/google/cloud/operators/test_vertex_ai.py +++ b/tests/providers/google/cloud/operators/test_vertex_ai.py @@ -165,6 +165,9 @@ "export_format_id": "tf-saved-model", } +TEST_CREATE_REQUEST_TIMEOUT = 100.5 +TEST_BATCH_SIZE = 4000 + class TestVertexAICreateCustomContainerTrainingJobOperator: @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook")) @@ -989,6 +992,8 @@ def test_execute(self, mock_hook, to_dict_mock): model_name=TEST_MODEL_NAME, instances_format="jsonl", predictions_format="jsonl", + create_request_timeout=TEST_CREATE_REQUEST_TIMEOUT, + batch_size=TEST_BATCH_SIZE, ) op.execute(context={"ti": mock.MagicMock()}) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) @@ -1015,6 +1020,8 @@ def test_execute(self, mock_hook, to_dict_mock): labels=None, encryption_spec_key_name=None, sync=True, + create_request_timeout=TEST_CREATE_REQUEST_TIMEOUT, + batch_size=TEST_BATCH_SIZE, )