From 86b1bd22d14792d89ddc43627e4a72dcb628c5f0 Mon Sep 17 00:00:00 2001 From: VladaZakharova <80038284+VladaZakharova@users.noreply.github.com> Date: Fri, 1 Dec 2023 18:01:33 +0100 Subject: [PATCH] Fix CloudRunExecuteJobOperator not able to retrieve the Cloud Run job status in deferrable mode (#36012) Co-authored-by: Ulada Zakharava --- .../google/cloud/operators/cloud_run.py | 4 +- .../google/cloud/triggers/cloud_run.py | 14 ++--- .../google/cloud/operators/test_cloud_run.py | 6 +- .../google/cloud/triggers/test_cloud_run.py | 59 ++++++++++--------- 4 files changed, 44 insertions(+), 39 deletions(-) diff --git a/airflow/providers/google/cloud/operators/cloud_run.py b/airflow/providers/google/cloud/operators/cloud_run.py index 14d27810dab5b..91b3ae6cea399 100644 --- a/airflow/providers/google/cloud/operators/cloud_run.py +++ b/airflow/providers/google/cloud/operators/cloud_run.py @@ -321,10 +321,10 @@ def execute(self, context: Context): def execute_complete(self, context: Context, event: dict): status = event["status"] - if status == RunJobStatus.TIMEOUT: + if status == RunJobStatus.TIMEOUT.value: raise AirflowException("Operation timed out") - if status == RunJobStatus.FAIL: + if status == RunJobStatus.FAIL.value: error_code = event["operation_error_code"] error_message = event["operation_error_message"] raise AirflowException( diff --git a/airflow/providers/google/cloud/triggers/cloud_run.py b/airflow/providers/google/cloud/triggers/cloud_run.py index 9506245d20cd3..f47a7ac1b34ac 100644 --- a/airflow/providers/google/cloud/triggers/cloud_run.py +++ b/airflow/providers/google/cloud/triggers/cloud_run.py @@ -102,21 +102,21 @@ async def run(self) -> AsyncIterator[TriggerEvent]: while timeout is None or timeout > 0: operation: operations_pb2.Operation = await hook.get_operation(self.operation_name) if operation.done: - # An operation can only have one of those two combinations: if it is succeeded, then - # the response field will be populated, else, then the error field will be. - if operation.response is not None: + # An operation can only have one of those two combinations: if it is failed, then + # the error field will be populated, else, then the response field will be. + if operation.error.SerializeToString(): yield TriggerEvent( { - "status": RunJobStatus.SUCCESS, + "status": RunJobStatus.FAIL.value, + "operation_error_code": operation.error.code, + "operation_error_message": operation.error.message, "job_name": self.job_name, } ) else: yield TriggerEvent( { - "status": RunJobStatus.FAIL, - "operation_error_code": operation.error.code, - "operation_error_message": operation.error.message, + "status": RunJobStatus.SUCCESS.value, "job_name": self.job_name, } ) diff --git a/tests/providers/google/cloud/operators/test_cloud_run.py b/tests/providers/google/cloud/operators/test_cloud_run.py index 152e625a233c9..829518e0d0e96 100644 --- a/tests/providers/google/cloud/operators/test_cloud_run.py +++ b/tests/providers/google/cloud/operators/test_cloud_run.py @@ -166,7 +166,7 @@ def test_execute_deferrable_execute_complete_method_timeout(self, hook_mock): task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, deferrable=True ) - event = {"status": RunJobStatus.TIMEOUT, "job_name": JOB_NAME} + event = {"status": RunJobStatus.TIMEOUT.value, "job_name": JOB_NAME} with pytest.raises(AirflowException) as e: operator.execute_complete(mock.MagicMock(), event) @@ -183,7 +183,7 @@ def test_execute_deferrable_execute_complete_method_fail(self, hook_mock): error_message = "error message" event = { - "status": RunJobStatus.FAIL, + "status": RunJobStatus.FAIL.value, "operation_error_code": error_code, "operation_error_message": error_message, "job_name": JOB_NAME, @@ -204,7 +204,7 @@ def test_execute_deferrable_execute_complete_method_success(self, hook_mock): task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, deferrable=True ) - event = {"status": RunJobStatus.SUCCESS, "job_name": JOB_NAME} + event = {"status": RunJobStatus.SUCCESS.value, "job_name": JOB_NAME} result = operator.execute_complete(mock.MagicMock(), event) assert result["name"] == JOB_NAME diff --git a/tests/providers/google/cloud/triggers/test_cloud_run.py b/tests/providers/google/cloud/triggers/test_cloud_run.py index 30d56241ed4c7..d64c4cee106bf 100644 --- a/tests/providers/google/cloud/triggers/test_cloud_run.py +++ b/tests/providers/google/cloud/triggers/test_cloud_run.py @@ -20,13 +20,16 @@ from unittest import mock import pytest +from google.protobuf.any_pb2 import Any +from google.rpc.status_pb2 import Status -from airflow.exceptions import AirflowException from airflow.providers.google.cloud.triggers.cloud_run import CloudRunJobFinishedTrigger, RunJobStatus from airflow.triggers.base import TriggerEvent OPERATION_NAME = "operation" JOB_NAME = "jobName" +ERROR_CODE = 13 +ERROR_MESSAGE = "Some message" PROJECT_ID = "projectId" LOCATION = "us-central1" GCP_CONNECTION_ID = "gcp_connection_id" @@ -73,20 +76,21 @@ async def test_trigger_on_operation_completed_yield_successfully( Tests the CloudRunJobFinishedTrigger fires once the job execution reaches a successful state. """ - done = True - name = "name" - error_code = 10 - error_message = "message" + async def _mock_operation(name): + operation = mock.MagicMock() + operation.done = True + operation.name = "name" + operation.error = Any() + operation.error.ParseFromString(b"") + return operation - mock_hook.return_value.get_operation.return_value = self._mock_operation( - done, name, error_code, error_message - ) + mock_hook.return_value.get_operation = _mock_operation generator = trigger.run() actual = await generator.asend(None) # type:ignore[attr-defined] assert ( TriggerEvent( { - "status": RunJobStatus.SUCCESS, + "status": RunJobStatus.SUCCESS.value, "job_name": JOB_NAME, } ) @@ -102,18 +106,28 @@ async def test_trigger_on_operation_failed_yield_error( Tests the CloudRunJobFinishedTrigger raises an exception once the job execution fails. """ - done = False - name = "name" - error_code = 10 - error_message = "message" + async def _mock_operation(name): + operation = mock.MagicMock() + operation.done = True + operation.name = "name" + operation.error = Status(code=13, message="Some message") + return operation - mock_hook.return_value.get_operation.return_value = self._mock_operation( - done, name, error_code, error_message - ) + mock_hook.return_value.get_operation = _mock_operation generator = trigger.run() - with pytest.raises(expected_exception=AirflowException): - await generator.asend(None) # type:ignore[attr-defined] + actual = await generator.asend(None) # type:ignore[attr-defined] + assert ( + TriggerEvent( + { + "status": RunJobStatus.FAIL.value, + "operation_error_code": ERROR_CODE, + "operation_error_message": ERROR_MESSAGE, + "job_name": JOB_NAME, + } + ) + == actual + ) @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.triggers.cloud_run.CloudRunAsyncHook") @@ -144,12 +158,3 @@ async def _mock_operation(name): ) == actual ) - - async def _mock_operation(self, done, name, error_code, error_message): - operation = mock.MagicMock() - operation.done = done - operation.name = name - operation.error = mock.MagicMock() - operation.error.message = error_message - operation.error.code = error_code - return operation