Skip to content

Commit

Permalink
CreateBatchPredictionJobOperator Add batch_size param for Vertex AI…
Browse files Browse the repository at this point in the history
… BatchPredictionJob objects (#31118)

* Add batch_size param for BatchPredictionJob objects

Co-authored-by: Jarek Potiuk <jarek@potiuk.com>

---------

Co-authored-by: Jarek Potiuk <jarek@potiuk.com>
  • Loading branch information
VVildVVolf and potiuk committed May 13, 2023
1 parent 779af82 commit a66edcb
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 3 deletions.
Expand Up @@ -114,6 +114,8 @@ def create_batch_prediction_job(
labels: dict[str, str] | None = None,
encryption_spec_key_name: str | None = None,
sync: bool = True,
create_request_timeout: float | None = None,
batch_size: int | None = None,
) -> BatchPredictionJob:
"""
Create a batch prediction job.
Expand Down Expand Up @@ -207,6 +209,14 @@ def create_batch_prediction_job(
:param sync: Whether to execute this method synchronously. If False, this method will be executed in
concurrent Future and any downstream object will be immediately returned and synced when the
Future has completed.
:param create_request_timeout: Optional. The timeout for the create request in seconds.
:param batch_size: Optional. The number of records (e.g. instances)
of the operation given in each batch
to a machine replica. The machine type and the size of a single record should be considered
when setting this parameter; a higher value speeds up the batch operation's execution,
but too high a value will result in a whole batch not fitting in a machine's memory,
and the whole operation will fail.
The default value is the same as in the aiplatform's BatchPredictionJob.
"""
self._batch_prediction_job = BatchPredictionJob.create(
job_display_name=job_display_name,
Expand All @@ -232,6 +242,8 @@ def create_batch_prediction_job(
credentials=self.get_credentials(),
encryption_spec_key_name=encryption_spec_key_name,
sync=sync,
create_request_timeout=create_request_timeout,
batch_size=batch_size,
)
return self._batch_prediction_job

Expand Down
Expand Up @@ -139,6 +139,14 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
:param sync: Whether to execute this method synchronously. If False, this method will be executed in
concurrent Future and any downstream object will be immediately returned and synced when the
Future has completed.
:param create_request_timeout: Optional. The timeout for the create request in seconds.
:param batch_size: Optional. The number of records (e.g. instances)
of the operation given in each batch
to a machine replica. The machine type and the size of a single record should be considered
when setting this parameter; a higher value speeds up the batch operation's execution,
but too high a value will result in a whole batch not fitting in a machine's memory,
and the whole operation will fail.
The default value is the same as in the aiplatform's BatchPredictionJob.
:param retry: Designation of what errors, if any, should be retried.
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
Expand Down Expand Up @@ -181,6 +189,8 @@ def __init__(
labels: dict[str, str] | None = None,
encryption_spec_key_name: str | None = None,
sync: bool = True,
create_request_timeout: float | None = None,
batch_size: int | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
Expand Down Expand Up @@ -208,6 +218,8 @@ def __init__(
self.labels = labels
self.encryption_spec_key_name = encryption_spec_key_name
self.sync = sync
self.create_request_timeout = create_request_timeout
self.batch_size = batch_size
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.hook: BatchPredictionJobHook | None = None
Expand Down Expand Up @@ -241,6 +253,8 @@ def execute(self, context: Context):
labels=self.labels,
encryption_spec_key_name=self.encryption_spec_key_name,
sync=self.sync,
create_request_timeout=self.create_request_timeout,
batch_size=self.batch_size,
)

batch_prediction_job = result.to_dict()
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/google/provider.yaml
Expand Up @@ -95,7 +95,7 @@ dependencies:
- google-api-python-client>=1.6.0,<2.0.0
- google-auth>=1.0.0
- google-auth-httplib2>=0.0.1
- google-cloud-aiplatform>=1.7.1,<2.0.0
- google-cloud-aiplatform>=1.13.1,<2.0.0
- google-cloud-automl>=2.1.0
- google-cloud-bigquery-datatransfer>=3.0.0
- google-cloud-bigtable>=2.0.0,<3.0.0
Expand Down
2 changes: 1 addition & 1 deletion docs/apache-airflow-providers-google/index.rst
Expand Up @@ -115,7 +115,7 @@ PIP package Version required
``google-api-python-client`` ``>=1.6.0,<2.0.0``
``google-auth`` ``>=1.0.0``
``google-auth-httplib2`` ``>=0.0.1``
``google-cloud-aiplatform`` ``>=1.7.1,<2.0.0``
``google-cloud-aiplatform`` ``>=1.13.1,<2.0.0``
``google-cloud-automl`` ``>=2.1.0``
``google-cloud-bigquery-datatransfer`` ``>=3.0.0``
``google-cloud-bigtable`` ``>=2.0.0,<3.0.0``
Expand Down
2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
Expand Up @@ -346,7 +346,7 @@
"google-auth-httplib2>=0.0.1",
"google-auth-oauthlib<1.0.0,>=0.3.0",
"google-auth>=1.0.0",
"google-cloud-aiplatform>=1.7.1,<2.0.0",
"google-cloud-aiplatform>=1.13.1,<2.0.0",
"google-cloud-automl>=2.1.0",
"google-cloud-bigquery-datatransfer>=3.0.0",
"google-cloud-bigtable>=2.0.0,<3.0.0",
Expand Down
7 changes: 7 additions & 0 deletions tests/providers/google/cloud/operators/test_vertex_ai.py
Expand Up @@ -165,6 +165,9 @@
"export_format_id": "tf-saved-model",
}

TEST_CREATE_REQUEST_TIMEOUT = 100.5
TEST_BATCH_SIZE = 4000


class TestVertexAICreateCustomContainerTrainingJobOperator:
@mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
Expand Down Expand Up @@ -989,6 +992,8 @@ def test_execute(self, mock_hook, to_dict_mock):
model_name=TEST_MODEL_NAME,
instances_format="jsonl",
predictions_format="jsonl",
create_request_timeout=TEST_CREATE_REQUEST_TIMEOUT,
batch_size=TEST_BATCH_SIZE,
)
op.execute(context={"ti": mock.MagicMock()})
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
Expand All @@ -1015,6 +1020,8 @@ def test_execute(self, mock_hook, to_dict_mock):
labels=None,
encryption_spec_key_name=None,
sync=True,
create_request_timeout=TEST_CREATE_REQUEST_TIMEOUT,
batch_size=TEST_BATCH_SIZE,
)


Expand Down

0 comments on commit a66edcb

Please sign in to comment.