Skip to content

Commit

Permalink
Optimize deferrable mode execution for DataprocSubmitJobOperator (#…
Browse files (browse the repository at this point in the history)
  • Loading branch information
phanikumv committed Jun 4, 2023
1 parent 9276310 commit 495ae23
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
8 changes: 8 additions & 0 deletions airflow/providers/google/cloud/operators/dataproc.py
Expand Up @@ -2030,6 +2030,14 @@ def execute(self, context: Context):

self.job_id = new_job_id
if self.deferrable:
job = self.hook.get_job(project_id=self.project_id, region=self.region, job_id=self.job_id)
state = job.status.state
if state == JobStatus.State.DONE:
return self.job_id
elif state == JobStatus.State.ERROR:
raise AirflowException(f"Job failed:\n{job}")
elif state == JobStatus.State.CANCELLED:
raise AirflowException(f"Job was cancelled:\n{job}")
self.defer(
trigger=DataprocSubmitTrigger(
job_id=self.job_id,
Expand Down
28 changes: 27 additions & 1 deletion tests/providers/google/cloud/operators/test_dataproc.py
Expand Up @@ -23,7 +23,7 @@
import pytest
from google.api_core.exceptions import AlreadyExists, NotFound
from google.api_core.retry import Retry
from google.cloud.dataproc_v1 import Batch
from google.cloud.dataproc_v1 import Batch, JobStatus

from airflow.exceptions import (
AirflowException,
Expand Down Expand Up @@ -1058,6 +1058,32 @@ def test_execute_deferrable(self, mock_trigger_hook, mock_hook):
assert isinstance(exc.value.trigger, DataprocSubmitTrigger)
assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@mock.patch("airflow.providers.google.cloud.operators.dataproc.DataprocSubmitJobOperator.defer")
def test_dataproc_operator_execute_async_done_before_defer(self, mock_defer, mock_hook):
    """A deferrable submit whose job is already DONE must return without deferring.

    The operator checks the job status once right after submission; when the
    job has already finished it should hand back the job ID directly instead
    of handing control to the triggerer.
    """
    # The whole DataprocHook class is patched, so the operator talks to
    # ``mock_hook.return_value``. Both the submitted job and the status lookup
    # therefore have to be configured on that instance mock; patching
    # ``DataprocHook.submit_job`` separately would configure a mock the
    # operator never sees.
    hook_instance = mock_hook.return_value
    hook_instance.submit_job.return_value.reference.job_id = TEST_JOB_ID
    hook_instance.get_job.return_value.status.state = JobStatus.State.DONE

    op = DataprocSubmitJobOperator(
        task_id=TASK_ID,
        region=GCP_REGION,
        project_id=GCP_PROJECT,
        job={},
        gcp_conn_id=GCP_CONN_ID,
        retry=RETRY,
        asynchronous=True,
        timeout=TIMEOUT,
        metadata=METADATA,
        request_id=REQUEST_ID,
        impersonation_chain=IMPERSONATION_CHAIN,
        deferrable=True,
    )

    result = op.execute(context=self.mock_context)

    # The finished job's ID is returned as the task result ...
    assert result == TEST_JOB_ID
    # ... after exactly one status lookup for the submitted job ...
    hook_instance.get_job.assert_called_once_with(
        project_id=GCP_PROJECT, region=GCP_REGION, job_id=TEST_JOB_ID
    )
    # ... and the task is never deferred to the trigger.
    mock_defer.assert_not_called()

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_on_kill(self, mock_hook):
job = {}
Expand Down

0 comments on commit 495ae23

Please sign in to comment.