Skip to content

Commit

Permalink
DataflowTemplatedJobStartOperator: fix overwriting of location with the default value when a region is provided (#31082)
Browse files Browse the repository at this point in the history

* Fix overwriting of location with the default value when a region is provided.

* Update tests/providers/google/cloud/operators/test_dataflow.py

Co-authored-by: Pankaj Singh <98807258+pankajastro@users.noreply.github.com>

* Fix incompatible type for passing location to the TemplateJobStartTrigger.

---------

Co-authored-by: Pankaj Singh <98807258+pankajastro@users.noreply.github.com>
  • Loading branch information
VVildVVolf and pankajastro committed May 8, 2023
1 parent 00a527f commit 810b5d4
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 2 deletions.
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/operators/dataflow.py
Expand Up @@ -600,7 +600,7 @@ def __init__(
options: dict[str, Any] | None = None,
dataflow_default_options: dict[str, Any] | None = None,
parameters: dict[str, str] | None = None,
location: str = DEFAULT_DATAFLOW_LOCATION,
location: str | None = None,
gcp_conn_id: str = "google_cloud_default",
poll_sleep: int = 10,
impersonation_chain: str | Sequence[str] | None = None,
Expand Down Expand Up @@ -690,7 +690,7 @@ def set_current_job(current_job):
trigger=TemplateJobStartTrigger(
project_id=self.project_id,
job_id=job_id,
location=self.location,
location=self.location if self.location else DEFAULT_DATAFLOW_LOCATION,
gcp_conn_id=self.gcp_conn_id,
poll_sleep=self.poll_sleep,
impersonation_chain=self.impersonation_chain,
Expand Down
44 changes: 44 additions & 0 deletions tests/providers/google/cloud/hooks/test_dataflow.py
Expand Up @@ -330,6 +330,50 @@ def test_start_python_dataflow_with_custom_region_as_parameter(
job_id=mock.ANY, job_name=job_name, location=TEST_LOCATION
)

@mock.patch(DATAFLOW_STRING.format("uuid.uuid4"))
@mock.patch(DATAFLOW_STRING.format("DataflowHook.wait_for_done"))
@mock.patch(DATAFLOW_STRING.format("process_line_and_extract_dataflow_job_id_callback"))
def test_start_python_dataflow_with_no_custom_region_or_region(
    self, mock_callback_on_job_id, mock_dataflow_wait_for_done, mock_uuid
):
    """When neither the variables nor the call supply a region, the default Dataflow location is used."""
    mock_uuid.return_value = MOCK_UUID
    beam_start_mock = self.dataflow_hook.beam_hook.start_python_pipeline
    job_id_callback = MagicMock()
    requirements = ["pandas", "numpy"]
    expected_job_name = f"{JOB_NAME}-{MOCK_UUID_PREFIX}"
    variables_in = copy.deepcopy(DATAFLOW_VARIABLES_PY)

    # start_python_dataflow is deprecated, so the call must emit the provider warning.
    with pytest.warns(AirflowProviderDeprecationWarning, match="This method is deprecated"):
        self.dataflow_hook.start_python_dataflow(
            job_name=JOB_NAME,
            variables=variables_in,
            dataflow=PY_FILE,
            py_options=PY_OPTIONS,
            py_interpreter=DEFAULT_PY_INTERPRETER,
            py_requirements=requirements,
            on_new_job_id_callback=job_id_callback,
        )

    # The hook is expected to inject the generated job name and fall back to the default region.
    expected_variables = {
        **copy.deepcopy(DATAFLOW_VARIABLES_PY),
        "job_name": expected_job_name,
        "region": DEFAULT_DATAFLOW_LOCATION,
    }

    mock_callback_on_job_id.assert_called_once_with(job_id_callback)
    beam_start_mock.assert_called_once_with(
        variables=expected_variables,
        py_file=PY_FILE,
        py_interpreter=DEFAULT_PY_INTERPRETER,
        py_options=PY_OPTIONS,
        py_requirements=requirements,
        py_system_site_packages=False,
        process_line_callback=mock_callback_on_job_id.return_value,
    )
    mock_dataflow_wait_for_done.assert_called_once_with(
        job_id=mock.ANY, job_name=expected_job_name, location=DEFAULT_DATAFLOW_LOCATION
    )

@mock.patch(DATAFLOW_STRING.format("uuid.uuid4"))
@mock.patch(DATAFLOW_STRING.format("DataflowHook.wait_for_done"))
@mock.patch(DATAFLOW_STRING.format("process_line_and_extract_dataflow_job_id_callback"))
Expand Down
37 changes: 37 additions & 0 deletions tests/providers/google/cloud/operators/test_dataflow.py
Expand Up @@ -76,6 +76,7 @@
},
}
# Fixture constants shared by the operator tests below.
TEST_LOCATION = "custom-location"  # explicit `location` argument passed to operators
TEST_REGION = "custom-region"  # `region` supplied via dataflow_default_options
TEST_PROJECT = "test-project"
TEST_SQL_JOB_NAME = "test-sql-job-name"
TEST_DATASET = "test-dataset"
Expand Down Expand Up @@ -534,6 +535,42 @@ def test_validation_deferrable_params_raises_error(self):
with pytest.raises(ValueError):
DataflowTemplatedJobStartOperator(**init_kwargs)

@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook.start_template_dataflow")
def test_start_with_custom_region(self, dataflow_mock):
    """A region given in dataflow_default_options must reach the hook without a default location overwriting it."""
    operator = DataflowTemplatedJobStartOperator(
        task_id=TASK_ID,
        template=TEMPLATE,
        dataflow_default_options={"region": TEST_REGION},
        poll_sleep=POLL_SLEEP,
        wait_until_finished=True,
        cancel_timeout=CANCEL_TIMEOUT,
    )
    operator.execute(None)

    assert dataflow_mock.called
    call_kwargs = dataflow_mock.call_args_list[0].kwargs
    # The custom region survives, and no location is injected alongside it.
    assert call_kwargs["variables"]["region"] == TEST_REGION
    assert call_kwargs["location"] is None

@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook.start_template_dataflow")
def test_start_with_location(self, dataflow_mock):
    """An explicit `location` argument is forwarded to the hook as-is, with no region in the variables."""
    operator = DataflowTemplatedJobStartOperator(
        task_id=TASK_ID,
        template=TEMPLATE,
        location=TEST_LOCATION,
        poll_sleep=POLL_SLEEP,
        wait_until_finished=True,
        cancel_timeout=CANCEL_TIMEOUT,
    )
    operator.execute(None)

    assert dataflow_mock.called
    call_kwargs = dataflow_mock.call_args_list[0].kwargs
    # No region was configured, so the variables stay empty and location carries through.
    assert not call_kwargs["variables"]
    assert call_kwargs["location"] == TEST_LOCATION


class TestDataflowStartFlexTemplateOperator:
@pytest.fixture
Expand Down

0 comments on commit 810b5d4

Please sign in to comment.