Added overrides parameter to CloudRunExecuteJobOperator (#34874)
ChloeSheasby committed Oct 25, 2023
1 parent 5f4d2b5 commit 0bb5631
Showing 6 changed files with 133 additions and 10 deletions.
12 changes: 9 additions & 3 deletions airflow/providers/google/cloud/hooks/cloud_run.py
@@ -18,7 +18,7 @@
from __future__ import annotations

import itertools
- from typing import TYPE_CHECKING, Iterable, Sequence
+ from typing import TYPE_CHECKING, Any, Iterable, Sequence

from google.cloud.run_v2 import (
CreateJobRequest,
@@ -113,9 +113,15 @@ def update_job(

@GoogleBaseHook.fallback_to_default_project_id
def execute_job(
- self, job_name: str, region: str, project_id: str = PROVIDE_PROJECT_ID
+ self,
+ job_name: str,
+ region: str,
+ project_id: str = PROVIDE_PROJECT_ID,
+ overrides: dict[str, Any] | None = None,
) -> operation.Operation:
- run_job_request = RunJobRequest(name=f"projects/{project_id}/locations/{region}/jobs/{job_name}")
+ run_job_request = RunJobRequest(
+     name=f"projects/{project_id}/locations/{region}/jobs/{job_name}", overrides=overrides
+ )
operation = self.get_conn().run_job(request=run_job_request)
return operation

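With this change the hook simply forwards the optional overrides mapping into RunJobRequest, so per-execution overrides reach the Cloud Run API unchanged. A minimal usage sketch of calling the hook directly, assuming the hook class exposed by this module is CloudRunHook and a "google_cloud_default" connection is configured; the job, region and project values are placeholders:

from airflow.providers.google.cloud.hooks.cloud_run import CloudRunHook

hook = CloudRunHook(gcp_conn_id="google_cloud_default")

# Per-execution overrides; keys mirror google.cloud.run_v2 RunJobRequest.Overrides.
overrides = {
    "container_overrides": [{"args": ["python", "main.py"]}],
    "task_count": 1,
    "timeout": "60s",
}

operation = hook.execute_job(
    job_name="job1",
    region="us-central1",
    project_id="my-project",
    overrides=overrides,
)
result = operation.result()  # long-running operation; blocks until the execution finishes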
7 changes: 5 additions & 2 deletions airflow/providers/google/cloud/operators/cloud_run.py
@@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations

- from typing import TYPE_CHECKING, Sequence
+ from typing import TYPE_CHECKING, Any, Sequence

from google.cloud.run_v2 import Job

@@ -248,6 +248,7 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
:param job_name: Required. The name of the job to update.
:param job: Required. The job descriptor containing the new configuration of the job to update.
The name field will be replaced by job_name
:param overrides: Optional map of overrides applied to this execution only; the supported keys mirror the fields of RunJobRequest.Overrides (container_overrides, task_count, timeout).
:param gcp_conn_id: The connection ID used to connect to Google Cloud.
:param polling_period_seconds: Optional: Control the rate of the poll for the result of deferrable run.
By default, the trigger will poll every 10 seconds.
@@ -270,6 +271,7 @@ def __init__(
project_id: str,
region: str,
job_name: str,
overrides: dict[str, Any] | None = None,
polling_period_seconds: float = 10,
timeout_seconds: float | None = None,
gcp_conn_id: str = "google_cloud_default",
@@ -281,6 +283,7 @@ def __init__(
self.project_id = project_id
self.region = region
self.job_name = job_name
self.overrides = overrides
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.polling_period_seconds = polling_period_seconds
@@ -293,7 +296,7 @@ def execute(self, context: Context):
gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
)
self.operation = hook.execute_job(
- region=self.region, project_id=self.project_id, job_name=self.job_name
+ region=self.region, project_id=self.project_id, job_name=self.job_name, overrides=self.overrides
)

if not self.deferrable:
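For orientation, a minimal sketch of the new parameter at the DAG level, mirroring the documentation and system-test examples below; the DAG id, task id, project, region, job name and override values are placeholders, and the job itself must already exist:

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.cloud_run import CloudRunExecuteJobOperator

with DAG(dag_id="example_cloud_run_overrides", start_date=datetime(2023, 10, 1), schedule=None) as dag:
    execute_with_overrides = CloudRunExecuteJobOperator(
        task_id="execute-job-with-overrides",
        project_id="my-project",
        region="us-central1",
        job_name="job1",
        # Applied to this execution only; keys mirror RunJobRequest.Overrides.
        overrides={
            "container_overrides": [{"args": ["python", "main.py"]}],
            "task_count": 1,
            "timeout": "60s",
        },
    )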
docs/apache-airflow-providers-google/operators/cloud/cloud_run.rst
@@ -77,6 +77,15 @@ or you can define the same operator in the deferrable mode:
:start-after: [START howto_operator_cloud_run_execute_job_deferrable_mode]
:end-before: [END howto_operator_cloud_run_execute_job_deferrable_mode]

You can also specify overrides for a single execution, for example to pass new container arguments or environment variables, change the task count, or set a different timeout:

:class:`~airflow.providers.google.cloud.operators.cloud_run.CloudRunExecuteJobOperator`

.. exampleinclude:: /../../tests/system/providers/google/cloud/cloud_run/example_cloud_run.py
:language: python
:dedent: 4
:start-after: [START howto_operator_cloud_run_execute_job_with_overrides]
:end-before: [END howto_operator_cloud_run_execute_job_with_overrides]
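For reference, the overrides mapping mirrors the fields of ``RunJobRequest.Overrides``; a minimal sketch with placeholder values:

.. code-block:: python

    overrides = {
        "container_overrides": [
            {
                "args": ["python", "main.py"],
                "env": [{"name": "ENV_VAR", "value": "value"}],
            }
        ],
        "task_count": 1,
        "timeout": "60s",
    }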


Update a job
15 changes: 12 additions & 3 deletions tests/providers/google/cloud/hooks/test_cloud_run.py
@@ -34,7 +34,7 @@
from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id


- class TestCloudBathHook:
+ class TestCloudRunHook:
def dummy_get_credentials(self):
pass

@@ -111,9 +111,18 @@ def test_execute_job(self, mock_batch_service_client, cloud_run_hook):
job_name = "job1"
region = "region1"
project_id = "projectid"
- run_job_request = RunJobRequest(name=f"projects/{project_id}/locations/{region}/jobs/{job_name}")
+ overrides = {
+     "container_overrides": [{"args": ["python", "main.py"]}],
+     "task_count": 1,
+     "timeout": "60s",
+ }
+ run_job_request = RunJobRequest(
+     name=f"projects/{project_id}/locations/{region}/jobs/{job_name}", overrides=overrides
+ )

- cloud_run_hook.execute_job(job_name=job_name, region=region, project_id=project_id)
+ cloud_run_hook.execute_job(
+     job_name=job_name, region=region, project_id=project_id, overrides=overrides
+ )
cloud_run_hook._client.run_job.assert_called_once_with(request=run_job_request)

@mock.patch(
68 changes: 67 additions & 1 deletion tests/providers/google/cloud/operators/test_cloud_run.py
@@ -96,7 +96,7 @@ def test_execute_success(self, hook_mock):
operator.execute(context=mock.MagicMock())

hook_mock.return_value.execute_job.assert_called_once_with(
- job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID
+ job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID, overrides=None
)

@mock.patch(CLOUD_RUN_HOOK_PATH)
@@ -209,6 +209,72 @@ def test_execute_deferrable_execute_complete_method_success(self, hook_mock):
result = operator.execute_complete(mock.MagicMock(), event)
assert result["name"] == JOB_NAME

@mock.patch(CLOUD_RUN_HOOK_PATH)
def test_execute_overrides(self, hook_mock):
hook_mock.return_value.get_job.return_value = JOB
hook_mock.return_value.execute_job.return_value = self._mock_operation(3, 3, 0)

overrides = {
"container_overrides": [{"args": ["python", "main.py"]}],
"task_count": 1,
"timeout": "60s",
}

operator = CloudRunExecuteJobOperator(
task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, overrides=overrides
)

operator.execute(context=mock.MagicMock())

hook_mock.return_value.execute_job.assert_called_once_with(
job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID, overrides=overrides
)

@mock.patch(CLOUD_RUN_HOOK_PATH)
def test_execute_overrides_with_invalid_task_count(self, hook_mock):
overrides = {
"container_overrides": [{"args": ["python", "main.py"]}],
"task_count": -1,
"timeout": "60s",
}

operator = CloudRunExecuteJobOperator(
task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, overrides=overrides
)

with pytest.raises(AirflowException):
operator.execute(context=mock.MagicMock())

@mock.patch(CLOUD_RUN_HOOK_PATH)
def test_execute_overrides_with_invalid_timeout(self, hook_mock):
overrides = {
"container_overrides": [{"args": ["python", "main.py"]}],
"task_count": 1,
"timeout": "60",
}

operator = CloudRunExecuteJobOperator(
task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, overrides=overrides
)

with pytest.raises(AirflowException):
operator.execute(context=mock.MagicMock())

@mock.patch(CLOUD_RUN_HOOK_PATH)
def test_execute_overrides_with_invalid_container_args(self, hook_mock):
overrides = {
"container_overrides": [{"name": "job", "args": "python main.py"}],
"task_count": 1,
"timeout": "60s",
}

operator = CloudRunExecuteJobOperator(
task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, overrides=overrides
)

with pytest.raises(AirflowException):
operator.execute(context=mock.MagicMock())

def _mock_operation(self, task_count, succeeded_count, failed_count):
operation = mock.MagicMock()
operation.result.return_value = self._mock_execution(task_count, succeeded_count, failed_count)
tests/system/providers/google/cloud/cloud_run/example_cloud_run.py
@@ -44,12 +44,14 @@
job_name_prefix = "cloudrun-system-test-job"
job1_name = f"{job_name_prefix}1"
job2_name = f"{job_name_prefix}2"
job3_name = f"{job_name_prefix}3"

create1_task_name = "create-job1"
create2_task_name = "create-job2"

execute1_task_name = "execute-job1"
execute2_task_name = "execute-job2"
execute3_task_name = "execute-job3"

update_job1_task_name = "update-job1"

@@ -70,6 +72,9 @@ def _assert_executed_jobs_xcom(ti):
job2_dicts = ti.xcom_pull(task_ids=[execute2_task_name], key="return_value")
assert job2_name in job2_dicts[0]["name"]

job3_dicts = ti.xcom_pull(task_ids=[execute3_task_name], key="return_value")
assert job3_name in job3_dicts[0]["name"]


def _assert_created_jobs_xcom(ti):
job1_dicts = ti.xcom_pull(task_ids=[create1_task_name], key="return_value")
@@ -181,6 +186,31 @@ def _create_job_with_label():
)
# [END howto_operator_cloud_run_execute_job_deferrable_mode]

# [START howto_operator_cloud_run_execute_job_with_overrides]
overrides = {
"container_overrides": [
{
"name": "job",
"args": ["python", "main.py"],
"env": [{"name": "ENV_VAR", "value": "value"}],
"clearArgs": False,
}
],
"task_count": 1,
"timeout": "60s",
}

execute3 = CloudRunExecuteJobOperator(
task_id=execute3_task_name,
project_id=PROJECT_ID,
region=region,
overrides=overrides,
job_name=job3_name,
dag=dag,
deferrable=False,
)
# [END howto_operator_cloud_run_execute_job_with_overrides]

assert_executed_jobs = PythonOperator(
task_id="assert-executed-jobs", python_callable=_assert_executed_jobs_xcom, dag=dag
)
@@ -237,7 +267,7 @@ def _create_job_with_label():
(
(create1, create2)
>> assert_created_jobs
- >> (execute1, execute2)
+ >> (execute1, execute2, execute3)
>> assert_executed_jobs
>> list_jobs_limit
>> assert_jobs_limit
