Skip to content

Commit

Permalink
Fix CloudRunExecuteJobOperator not able to retrieve the Cloud Run job…
Browse files Browse the repository at this point in the history
… status in deferrable mode (#36012)

Co-authored-by: Ulada Zakharava <Vlada_Zakharava@epam.com>
  • Loading branch information
VladaZakharova and Ulada Zakharava committed Dec 1, 2023
1 parent fd03dc2 commit 86b1bd2
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 39 deletions.
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/operators/cloud_run.py
Expand Up @@ -321,10 +321,10 @@ def execute(self, context: Context):
def execute_complete(self, context: Context, event: dict):
status = event["status"]

if status == RunJobStatus.TIMEOUT:
if status == RunJobStatus.TIMEOUT.value:
raise AirflowException("Operation timed out")

if status == RunJobStatus.FAIL:
if status == RunJobStatus.FAIL.value:
error_code = event["operation_error_code"]
error_message = event["operation_error_message"]
raise AirflowException(
Expand Down
14 changes: 7 additions & 7 deletions airflow/providers/google/cloud/triggers/cloud_run.py
Expand Up @@ -102,21 +102,21 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
while timeout is None or timeout > 0:
operation: operations_pb2.Operation = await hook.get_operation(self.operation_name)
if operation.done:
# An operation can only have one of those two combinations: if it is succeeded, then
# the response field will be populated, else, then the error field will be.
if operation.response is not None:
# An operation can only have one of those two combinations: if it is failed, then
# the error field will be populated, else, then the response field will be.
if operation.error.SerializeToString():
yield TriggerEvent(
{
"status": RunJobStatus.SUCCESS,
"status": RunJobStatus.FAIL.value,
"operation_error_code": operation.error.code,
"operation_error_message": operation.error.message,
"job_name": self.job_name,
}
)
else:
yield TriggerEvent(
{
"status": RunJobStatus.FAIL,
"operation_error_code": operation.error.code,
"operation_error_message": operation.error.message,
"status": RunJobStatus.SUCCESS.value,
"job_name": self.job_name,
}
)
Expand Down
6 changes: 3 additions & 3 deletions tests/providers/google/cloud/operators/test_cloud_run.py
Expand Up @@ -166,7 +166,7 @@ def test_execute_deferrable_execute_complete_method_timeout(self, hook_mock):
task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, deferrable=True
)

event = {"status": RunJobStatus.TIMEOUT, "job_name": JOB_NAME}
event = {"status": RunJobStatus.TIMEOUT.value, "job_name": JOB_NAME}

with pytest.raises(AirflowException) as e:
operator.execute_complete(mock.MagicMock(), event)
Expand All @@ -183,7 +183,7 @@ def test_execute_deferrable_execute_complete_method_fail(self, hook_mock):
error_message = "error message"

event = {
"status": RunJobStatus.FAIL,
"status": RunJobStatus.FAIL.value,
"operation_error_code": error_code,
"operation_error_message": error_message,
"job_name": JOB_NAME,
Expand All @@ -204,7 +204,7 @@ def test_execute_deferrable_execute_complete_method_success(self, hook_mock):
task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, deferrable=True
)

event = {"status": RunJobStatus.SUCCESS, "job_name": JOB_NAME}
event = {"status": RunJobStatus.SUCCESS.value, "job_name": JOB_NAME}

result = operator.execute_complete(mock.MagicMock(), event)
assert result["name"] == JOB_NAME
Expand Down
59 changes: 32 additions & 27 deletions tests/providers/google/cloud/triggers/test_cloud_run.py
Expand Up @@ -20,13 +20,16 @@
from unittest import mock

import pytest
from google.protobuf.any_pb2 import Any
from google.rpc.status_pb2 import Status

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.triggers.cloud_run import CloudRunJobFinishedTrigger, RunJobStatus
from airflow.triggers.base import TriggerEvent

OPERATION_NAME = "operation"
JOB_NAME = "jobName"
ERROR_CODE = 13
ERROR_MESSAGE = "Some message"
PROJECT_ID = "projectId"
LOCATION = "us-central1"
GCP_CONNECTION_ID = "gcp_connection_id"
Expand Down Expand Up @@ -73,20 +76,21 @@ async def test_trigger_on_operation_completed_yield_successfully(
Tests the CloudRunJobFinishedTrigger fires once the job execution reaches a successful state.
"""

done = True
name = "name"
error_code = 10
error_message = "message"
async def _mock_operation(name):
operation = mock.MagicMock()
operation.done = True
operation.name = "name"
operation.error = Any()
operation.error.ParseFromString(b"")
return operation

mock_hook.return_value.get_operation.return_value = self._mock_operation(
done, name, error_code, error_message
)
mock_hook.return_value.get_operation = _mock_operation
generator = trigger.run()
actual = await generator.asend(None) # type:ignore[attr-defined]
assert (
TriggerEvent(
{
"status": RunJobStatus.SUCCESS,
"status": RunJobStatus.SUCCESS.value,
"job_name": JOB_NAME,
}
)
Expand All @@ -102,18 +106,28 @@ async def test_trigger_on_operation_failed_yield_error(
Tests the CloudRunJobFinishedTrigger raises an exception once the job execution fails.
"""

done = False
name = "name"
error_code = 10
error_message = "message"
async def _mock_operation(name):
operation = mock.MagicMock()
operation.done = True
operation.name = "name"
operation.error = Status(code=13, message="Some message")
return operation

mock_hook.return_value.get_operation.return_value = self._mock_operation(
done, name, error_code, error_message
)
mock_hook.return_value.get_operation = _mock_operation
generator = trigger.run()

with pytest.raises(expected_exception=AirflowException):
await generator.asend(None) # type:ignore[attr-defined]
actual = await generator.asend(None) # type:ignore[attr-defined]
assert (
TriggerEvent(
{
"status": RunJobStatus.FAIL.value,
"operation_error_code": ERROR_CODE,
"operation_error_message": ERROR_MESSAGE,
"job_name": JOB_NAME,
}
)
== actual
)

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.triggers.cloud_run.CloudRunAsyncHook")
Expand Down Expand Up @@ -144,12 +158,3 @@ async def _mock_operation(name):
)
== actual
)

async def _mock_operation(self, done, name, error_code, error_message):
operation = mock.MagicMock()
operation.done = done
operation.name = name
operation.error = mock.MagicMock()
operation.error.message = error_message
operation.error.code = error_code
return operation

0 comments on commit 86b1bd2

Please sign in to comment.