Fix delay in Dataproc CreateBatch operator #26126

Merged: 3 commits, Oct 10, 2022
9 changes: 7 additions & 2 deletions airflow/providers/google/cloud/hooks/dataproc.py
@@ -253,10 +253,15 @@ def get_batch_client(self, region: str | None = None) -> BatchControllerClient:
credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
)

def wait_for_operation(self, operation: Operation, timeout: float | None = None):
def wait_for_operation(
self,
operation: Operation,
timeout: float | None = None,
result_retry: Retry | _MethodDefault = DEFAULT,
):
"""Waits for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
return operation.result(timeout=timeout, retry=result_retry)
except Exception:
error = operation.exception(timeout=timeout)
raise AirflowException(error)
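The hook change above is the core of the fix: wait_for_operation() now accepts a result_retry and hands it to Operation.result(), where it governs how the long-running operation is polled. As a minimal, self-contained sketch of how such a google.api_core Retry object behaves on its own (the predicate and timings here are illustrative, not taken from the PR):

from google.api_core.exceptions import ServiceUnavailable
from google.api_core.retry import Retry

attempts = []

# Retry ServiceUnavailable with delays of roughly 1s then 2s (capped at 4s),
# so the wrapped call succeeds on its third attempt.
@Retry(
    predicate=lambda exc: isinstance(exc, ServiceUnavailable),
    initial=1.0,
    multiplier=2.0,
    maximum=4.0,
)
def flaky_call() -> str:
    attempts.append(1)
    if len(attempts) < 3:
        raise ServiceUnavailable("transient")
    return "done"

print(flaky_call())  # prints "done" after two retried failures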
8 changes: 7 additions & 1 deletion airflow/providers/google/cloud/operators/dataproc.py
@@ -2038,6 +2038,8 @@ class DataprocCreateBatchOperator(BaseOperator):
the first ``google.longrunning.Operation`` created and stored in the backend is returned.
:param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
retried.
:param result_retry: Result retry object used to retry requests. Can be used to decrease the delay
between executing chained tasks in a DAG by specifying the exact polling interval in seconds.
:param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
``retry`` is specified, the timeout applies to each individual attempt.
:param metadata: Additional metadata that is provided to the method.
@@ -2074,6 +2076,7 @@ def __init__(
metadata: Sequence[tuple[str, str]] = (),
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
result_retry: Retry | _MethodDefault = DEFAULT,
**kwargs,
):
super().__init__(**kwargs)
@@ -2083,6 +2086,7 @@ def __init__(
self.batch_id = batch_id
self.request_id = request_id
self.retry = retry
self.result_retry = result_retry
self.timeout = timeout
self.metadata = metadata
self.gcp_conn_id = gcp_conn_id
@@ -2107,7 +2111,9 @@ def execute(self, context: Context):
)
if self.operation is None:
raise RuntimeError("The operation should be set here!")
result = hook.wait_for_operation(timeout=self.timeout, operation=self.operation)
result = hook.wait_for_operation(
timeout=self.timeout, result_retry=self.result_retry, operation=self.operation
)
self.log.info("Batch %s created", self.batch_id)
except AlreadyExists:
self.log.info("Batch with given id already exists")
31 changes: 31 additions & 0 deletions tests/providers/google/cloud/operators/test_dataproc.py
@@ -193,6 +193,7 @@

TIMEOUT = 120
RETRY = mock.MagicMock(Retry)
RESULT_RETRY = mock.MagicMock(Retry)
METADATA = [("key", "value")]
REQUEST_ID = "request_id_uuid"

@@ -1706,6 +1707,36 @@ def test_execute(self, mock_hook, to_dict_mock):
metadata=METADATA,
)

@mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_with_result_retry(self, mock_hook, to_dict_mock):
op = DataprocCreateBatchOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
batch=BATCH,
batch_id=BATCH_ID,
request_id=REQUEST_ID,
retry=RETRY,
result_retry=RESULT_RETRY,
timeout=TIMEOUT,
metadata=METADATA,
)
op.execute(context=MagicMock())
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.create_batch.assert_called_once_with(
region=GCP_LOCATION,
project_id=GCP_PROJECT,
batch=BATCH,
batch_id=BATCH_ID,
request_id=REQUEST_ID,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
)

@mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_batch_failed(self, mock_hook, to_dict_mock):
tests/system/providers/google/cloud/dataproc/example_dataproc_batch.py
@@ -22,6 +22,8 @@
import os
from datetime import datetime

from google.api_core.retry import Retry

from airflow import models
from airflow.providers.google.cloud.operators.dataproc import (
DataprocCreateBatchOperator,
@@ -36,6 +38,7 @@
PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "")
REGION = "europe-west1"
BATCH_ID = f"test-batch-id-{ENV_ID}"
BATCH_ID_2 = f"test-batch-id-{ENV_ID}-2"
BATCH_CONFIG = {
"spark_batch": {
"jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
@@ -58,14 +61,26 @@
region=REGION,
batch=BATCH_CONFIG,
batch_id=BATCH_ID,
timeout=5.0,
)

create_batch_2 = DataprocCreateBatchOperator(
task_id="create_batch_2",
project_id=PROJECT_ID,
region=REGION,
batch=BATCH_CONFIG,
batch_id=BATCH_ID_2,
result_retry=Retry(maximum=10.0, initial=10.0, multiplier=1.0),
)
# [END how_to_cloud_dataproc_create_batch_operator]
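With initial=10.0, multiplier=1.0, and maximum=10.0, create_batch_2 polls the operation result at a fixed ten-second interval, so downstream tasks can start roughly ten seconds after the batch completes rather than waiting out the client library's default exponential backoff (the exact default schedule depends on the installed google-api-core version).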

# [START how_to_cloud_dataproc_get_batch_operator]
get_batch = DataprocGetBatchOperator(
task_id="get_batch", project_id=PROJECT_ID, region=REGION, batch_id=BATCH_ID
)

get_batch_2 = DataprocGetBatchOperator(
task_id="get_batch_2", project_id=PROJECT_ID, region=REGION, batch_id=BATCH_ID_2
)
# [END how_to_cloud_dataproc_get_batch_operator]

# [START how_to_cloud_dataproc_list_batches_operator]
@@ -80,10 +95,23 @@
delete_batch = DataprocDeleteBatchOperator(
task_id="delete_batch", project_id=PROJECT_ID, region=REGION, batch_id=BATCH_ID
)
delete_batch.trigger_rule = TriggerRule.ALL_DONE

delete_batch_2 = DataprocDeleteBatchOperator(
task_id="delete_batch_2", project_id=PROJECT_ID, region=REGION, batch_id=BATCH_ID_2
)
# [END how_to_cloud_dataproc_delete_batch_operator]
delete_batch.trigger_rule = TriggerRule.ALL_DONE

create_batch >> get_batch >> list_batches >> delete_batch
(
create_batch
>> create_batch_2
>> get_batch
>> get_batch_2
>> list_batches
>> delete_batch
>> delete_batch_2
)

from tests.system.utils.watcher import watcher
