From 495ae23d45eda52f2b368d0afa4213f4e69e97cd Mon Sep 17 00:00:00 2001 From: Phani Kumar <94376113+phanikumv@users.noreply.github.com> Date: Mon, 5 Jun 2023 04:58:33 +0530 Subject: [PATCH] Optimize deferrable mode execution for `DataprocSubmitJobOperator` (#31317) --- .../google/cloud/operators/dataproc.py | 8 ++++++ .../google/cloud/operators/test_dataproc.py | 28 ++++++++++++++++++- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index 985278d4d9d06..3019a536d91a2 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -2030,6 +2030,14 @@ def execute(self, context: Context): self.job_id = new_job_id if self.deferrable: + job = self.hook.get_job(project_id=self.project_id, region=self.region, job_id=self.job_id) + state = job.status.state + if state == JobStatus.State.DONE: + return self.job_id + elif state == JobStatus.State.ERROR: + raise AirflowException(f"Job failed:\n{job}") + elif state == JobStatus.State.CANCELLED: + raise AirflowException(f"Job was cancelled:\n{job}") self.defer( trigger=DataprocSubmitTrigger( job_id=self.job_id, diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index 9c2e202f97c5a..5494a88ba9330 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -23,7 +23,7 @@ import pytest from google.api_core.exceptions import AlreadyExists, NotFound from google.api_core.retry import Retry -from google.cloud.dataproc_v1 import Batch +from google.cloud.dataproc_v1 import Batch, JobStatus from airflow.exceptions import ( AirflowException, @@ -1058,6 +1058,32 @@ def test_execute_deferrable(self, mock_trigger_hook, mock_hook): assert isinstance(exc.value.trigger, DataprocSubmitTrigger) assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + @mock.patch("airflow.providers.google.cloud.operators.dataproc.DataprocSubmitJobOperator.defer") + @mock.patch("airflow.providers.google.cloud.operators.dataproc.DataprocHook.submit_job") + def test_dataproc_operator_execute_async_done_before_defer(self, mock_submit_job, mock_defer, mock_hook): + mock_submit_job.return_value.reference.job_id = TEST_JOB_ID + job_status = mock_hook.return_value.get_job.return_value.status + job_status.state = JobStatus.State.DONE + + op = DataprocSubmitJobOperator( + task_id=TASK_ID, + region=GCP_REGION, + project_id=GCP_PROJECT, + job={}, + gcp_conn_id=GCP_CONN_ID, + retry=RETRY, + asynchronous=True, + timeout=TIMEOUT, + metadata=METADATA, + request_id=REQUEST_ID, + impersonation_chain=IMPERSONATION_CHAIN, + deferrable=True, + ) + + op.execute(context=self.mock_context) + assert not mock_defer.called + @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_on_kill(self, mock_hook): job = {}