Commit 6f0b600

Fix delay in Dataproc CreateBatch operator (#26126)

VladaZakharova committed Oct 10, 2022
1 parent 7601460 commit 6f0b600

Showing 4 changed files with 75 additions and 5 deletions.
9 changes: 7 additions & 2 deletions airflow/providers/google/cloud/hooks/dataproc.py
@@ -253,10 +253,15 @@ def get_batch_client(self, region: str | None = None) -> BatchControllerClient:
             credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
         )
 
-    def wait_for_operation(self, operation: Operation, timeout: float | None = None):
+    def wait_for_operation(
+        self,
+        operation: Operation,
+        timeout: float | None = None,
+        result_retry: Retry | _MethodDefault = DEFAULT,
+    ):
         """Waits for long-lasting operation to complete."""
         try:
-            return operation.result(timeout=timeout)
+            return operation.result(timeout=timeout, retry=result_retry)
         except Exception:
             error = operation.exception(timeout=timeout)
             raise AirflowException(error)
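The retry argument that Operation.result() accepts paces how often the long-running operation is re-polled while it is not yet done. To see why the default schedule delays chained tasks, here is a worked example; the default backoff values used below (initial=1 s, multiplier=2.0, maximum=60 s) are an assumption about google-api-core's polling defaults, not something stated in this diff:

def poll_times(initial: float, multiplier: float, maximum: float, horizon: float) -> list[float]:
    # Cumulative times at which Operation.result() would re-poll the backend.
    times, elapsed, delay = [], 0.0, initial
    while elapsed < horizon:
        elapsed += delay
        times.append(elapsed)
        delay = min(delay * multiplier, maximum)
    return times

print(poll_times(1.0, 2.0, 60.0, 180.0))   # [1.0, 3.0, 7.0, 15.0, 31.0, 63.0, 123.0, 183.0]
print(poll_times(10.0, 1.0, 10.0, 180.0))  # [10.0, 20.0, 30.0, ..., 170.0, 180.0]

A batch that finishes around t = 70 s is first observed at t = 123 s under the exponential schedule, but by t = 80 s with a fixed 10-second interval; that gap is the delay this commit lets users tune away via result_retry.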
8 changes: 7 additions & 1 deletion airflow/providers/google/cloud/operators/dataproc.py
@@ -2038,6 +2038,8 @@ class DataprocCreateBatchOperator(BaseOperator):
         the first ``google.longrunning.Operation`` created and stored in the backend is returned.
     :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
         retried.
+    :param result_retry: Result retry object used to retry requests. Can be used to decrease the delay
+        between chained tasks in a DAG by specifying the exact number of seconds to wait between polls
+        of the operation result.
     :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
         ``retry`` is specified, the timeout applies to each individual attempt.
     :param metadata: Additional metadata that is provided to the method.
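A hedged usage sketch of the distinction (the task id, project, batch id, batch spec, and Retry values below are illustrative placeholders, not from this commit): retry governs re-sending the CreateBatch request itself, while result_retry paces the polling of the returned operation's result.

from google.api_core.retry import Retry

from airflow.providers.google.cloud.operators.dataproc import DataprocCreateBatchOperator

create_batch = DataprocCreateBatchOperator(
    task_id="create_batch",
    project_id="my-project",  # illustrative placeholder
    region="europe-west1",
    batch={"spark_batch": {"jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"]}},
    batch_id="my-batch-id",
    retry=Retry(maximum=60.0),  # retries the CreateBatch API request on transient errors
    result_retry=Retry(maximum=10.0, initial=10.0, multiplier=1.0),  # polls the result every 10 s
)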
@@ -2074,6 +2076,7 @@ def __init__(
         metadata: Sequence[tuple[str, str]] = (),
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
+        result_retry: Retry | _MethodDefault = DEFAULT,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -2083,6 +2086,7 @@ def __init__(
         self.batch_id = batch_id
         self.request_id = request_id
         self.retry = retry
+        self.result_retry = result_retry
         self.timeout = timeout
         self.metadata = metadata
         self.gcp_conn_id = gcp_conn_id
@@ -2107,7 +2111,9 @@ def execute(self, context: Context):
             )
             if self.operation is None:
                 raise RuntimeError("The operation should be set here!")
-            result = hook.wait_for_operation(timeout=self.timeout, operation=self.operation)
+            result = hook.wait_for_operation(
+                timeout=self.timeout, result_retry=self.result_retry, operation=self.operation
+            )
             self.log.info("Batch %s created", self.batch_id)
         except AlreadyExists:
             self.log.info("Batch with given id already exists")
31 changes: 31 additions & 0 deletions tests/providers/google/cloud/operators/test_dataproc.py
@@ -193,6 +193,7 @@

 TIMEOUT = 120
 RETRY = mock.MagicMock(Retry)
+RESULT_RETRY = mock.MagicMock(Retry)
 METADATA = [("key", "value")]
 REQUEST_ID = "request_id_uuid"

@@ -1706,6 +1707,36 @@ def test_execute(self, mock_hook, to_dict_mock):
             metadata=METADATA,
         )
 
+    @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_execute_with_result_retry(self, mock_hook, to_dict_mock):
+        op = DataprocCreateBatchOperator(
+            task_id=TASK_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            batch=BATCH,
+            batch_id=BATCH_ID,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            result_retry=RESULT_RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+        op.execute(context=MagicMock())
+        mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
+        mock_hook.return_value.create_batch.assert_called_once_with(
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            batch=BATCH,
+            batch_id=BATCH_ID,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
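The assertion above checks that result_retry stays out of the create_batch request itself. A further assertion one could add (hypothetical, not part of this commit) to pin down the hand-off to wait_for_operation:

mock_hook.return_value.wait_for_operation.assert_called_once_with(
    timeout=TIMEOUT,
    result_retry=RESULT_RETRY,
    operation=mock_hook.return_value.create_batch.return_value,
)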

     @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute_batch_failed(self, mock_hook, to_dict_mock):
32 changes: 30 additions & 2 deletions tests/system/providers/google/cloud/dataproc/example_dataproc_batch.py
@@ -22,6 +22,8 @@
 import os
 from datetime import datetime
 
+from google.api_core.retry import Retry
+
 from airflow import models
 from airflow.providers.google.cloud.operators.dataproc import (
     DataprocCreateBatchOperator,
@@ -36,6 +38,7 @@
 PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "")
 REGION = "europe-west1"
 BATCH_ID = f"test-batch-id-{ENV_ID}"
+BATCH_ID_2 = f"test-batch-id-{ENV_ID}-2"
 BATCH_CONFIG = {
     "spark_batch": {
         "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
@@ -58,14 +61,26 @@
         region=REGION,
         batch=BATCH_CONFIG,
         batch_id=BATCH_ID,
+        timeout=5.0,
     )
 
+    create_batch_2 = DataprocCreateBatchOperator(
+        task_id="create_batch_2",
+        project_id=PROJECT_ID,
+        region=REGION,
+        batch=BATCH_CONFIG,
+        batch_id=BATCH_ID_2,
+        result_retry=Retry(maximum=10.0, initial=10.0, multiplier=1.0),
+    )
     # [END how_to_cloud_dataproc_create_batch_operator]
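Note: with initial == maximum == 10.0 and multiplier == 1.0 the backoff interval never grows, so create_batch_2's operation is polled at a steady 10-second cadence, and its downstream task starts at most roughly ten seconds after the batch actually completes (under the Retry semantics assumed in the sketch following the hooks diff above).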

     # [START how_to_cloud_dataproc_get_batch_operator]
     get_batch = DataprocGetBatchOperator(
         task_id="get_batch", project_id=PROJECT_ID, region=REGION, batch_id=BATCH_ID
     )
+
+    get_batch_2 = DataprocGetBatchOperator(
+        task_id="get_batch_2", project_id=PROJECT_ID, region=REGION, batch_id=BATCH_ID_2
+    )
     # [END how_to_cloud_dataproc_get_batch_operator]

     # [START how_to_cloud_dataproc_list_batches_operator]
@@ -80,10 +95,23 @@
     delete_batch = DataprocDeleteBatchOperator(
         task_id="delete_batch", project_id=PROJECT_ID, region=REGION, batch_id=BATCH_ID
     )
-    delete_batch.trigger_rule = TriggerRule.ALL_DONE
+
+    delete_batch_2 = DataprocDeleteBatchOperator(
+        task_id="delete_batch_2", project_id=PROJECT_ID, region=REGION, batch_id=BATCH_ID_2
+    )
     # [END how_to_cloud_dataproc_delete_batch_operator]
+    delete_batch.trigger_rule = TriggerRule.ALL_DONE
 
-    create_batch >> get_batch >> list_batches >> delete_batch
+    (
+        create_batch
+        >> create_batch_2
+        >> get_batch
+        >> get_batch_2
+        >> list_batches
+        >> delete_batch
+        >> delete_batch_2
+    )
 
     from tests.system.utils.watcher import watcher
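The remainder of the file is collapsed in this diff. In Airflow system tests, the imported watcher is conventionally wired up as follows; this is the standard pattern, shown as an assumption since it is not part of the visible hunk:

# Standard system-test wiring (assumed, since the closing lines are collapsed):
# make every task upstream of the watcher so that a failed task with
# trigger_rule=ALL_DONE still marks the test run as failed.
list(dag.tasks) >> watcher()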
