diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py
index d650775b9967e..eeec715edfc35 100644
--- a/airflow/providers/google/cloud/operators/dataflow.py
+++ b/airflow/providers/google/cloud/operators/dataflow.py
@@ -600,7 +600,7 @@ def __init__(
         options: dict[str, Any] | None = None,
         dataflow_default_options: dict[str, Any] | None = None,
         parameters: dict[str, str] | None = None,
-        location: str = DEFAULT_DATAFLOW_LOCATION,
+        location: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
         poll_sleep: int = 10,
         impersonation_chain: str | Sequence[str] | None = None,
@@ -690,7 +690,7 @@ def set_current_job(current_job):
             trigger=TemplateJobStartTrigger(
                 project_id=self.project_id,
                 job_id=job_id,
-                location=self.location,
+                location=self.location if self.location else DEFAULT_DATAFLOW_LOCATION,
                 gcp_conn_id=self.gcp_conn_id,
                 poll_sleep=self.poll_sleep,
                 impersonation_chain=self.impersonation_chain,
diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py
index f514b2705b0b4..c17358ba78cea 100644
--- a/tests/providers/google/cloud/hooks/test_dataflow.py
+++ b/tests/providers/google/cloud/hooks/test_dataflow.py
@@ -330,6 +330,50 @@ def test_start_python_dataflow_with_custom_region_as_parameter(
             job_id=mock.ANY, job_name=job_name, location=TEST_LOCATION
         )
 
+    @mock.patch(DATAFLOW_STRING.format("uuid.uuid4"))
+    @mock.patch(DATAFLOW_STRING.format("DataflowHook.wait_for_done"))
+    @mock.patch(DATAFLOW_STRING.format("process_line_and_extract_dataflow_job_id_callback"))
+    def test_start_python_dataflow_with_no_custom_region_or_region(
+        self, mock_callback_on_job_id, mock_dataflow_wait_for_done, mock_uuid
+    ):
+        mock_beam_start_python_pipeline = self.dataflow_hook.beam_hook.start_python_pipeline
+        mock_uuid.return_value = MOCK_UUID
+        on_new_job_id_callback = MagicMock()
+        py_requirements = ["pandas", "numpy"]
+        job_name = f"{JOB_NAME}-{MOCK_UUID_PREFIX}"
+
+        passed_variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
+
+        with pytest.warns(AirflowProviderDeprecationWarning, match="This method is deprecated"):
+            self.dataflow_hook.start_python_dataflow(
+                job_name=JOB_NAME,
+                variables=passed_variables,
+                dataflow=PY_FILE,
+                py_options=PY_OPTIONS,
+                py_interpreter=DEFAULT_PY_INTERPRETER,
+                py_requirements=py_requirements,
+                on_new_job_id_callback=on_new_job_id_callback,
+            )
+
+        expected_variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
+        expected_variables["job_name"] = job_name
+        expected_variables["region"] = DEFAULT_DATAFLOW_LOCATION
+
+        mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback)
+        mock_beam_start_python_pipeline.assert_called_once_with(
+            variables=expected_variables,
+            py_file=PY_FILE,
+            py_interpreter=DEFAULT_PY_INTERPRETER,
+            py_options=PY_OPTIONS,
+            py_requirements=py_requirements,
+            py_system_site_packages=False,
+            process_line_callback=mock_callback_on_job_id.return_value,
+        )
+
+        mock_dataflow_wait_for_done.assert_called_once_with(
+            job_id=mock.ANY, job_name=job_name, location=DEFAULT_DATAFLOW_LOCATION
+        )
+
     @mock.patch(DATAFLOW_STRING.format("uuid.uuid4"))
     @mock.patch(DATAFLOW_STRING.format("DataflowHook.wait_for_done"))
     @mock.patch(DATAFLOW_STRING.format("process_line_and_extract_dataflow_job_id_callback"))
diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py
index c68938042ffb2..72db54c109f11 100644
--- a/tests/providers/google/cloud/operators/test_dataflow.py
+++ b/tests/providers/google/cloud/operators/test_dataflow.py
@@ -76,6 +76,7 @@
     },
 }
 TEST_LOCATION = "custom-location"
+TEST_REGION = "custom-region"
 TEST_PROJECT = "test-project"
 TEST_SQL_JOB_NAME = "test-sql-job-name"
 TEST_DATASET = "test-dataset"
@@ -534,6 +535,42 @@ def test_validation_deferrable_params_raises_error(self):
         with pytest.raises(ValueError):
             DataflowTemplatedJobStartOperator(**init_kwargs)
 
+    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook.start_template_dataflow")
+    def test_start_with_custom_region(self, dataflow_mock):
+        init_kwargs = {
+            "task_id": TASK_ID,
+            "template": TEMPLATE,
+            "dataflow_default_options": {
+                "region": TEST_REGION,
+            },
+            "poll_sleep": POLL_SLEEP,
+            "wait_until_finished": True,
+            "cancel_timeout": CANCEL_TIMEOUT,
+        }
+        operator = DataflowTemplatedJobStartOperator(**init_kwargs)
+        operator.execute(None)
+        assert dataflow_mock.called
+        _, kwargs = dataflow_mock.call_args_list[0]
+        assert kwargs["variables"]["region"] == TEST_REGION
+        assert kwargs["location"] is None
+
+    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook.start_template_dataflow")
+    def test_start_with_location(self, dataflow_mock):
+        init_kwargs = {
+            "task_id": TASK_ID,
+            "template": TEMPLATE,
+            "location": TEST_LOCATION,
+            "poll_sleep": POLL_SLEEP,
+            "wait_until_finished": True,
+            "cancel_timeout": CANCEL_TIMEOUT,
+        }
+        operator = DataflowTemplatedJobStartOperator(**init_kwargs)
+        operator.execute(None)
+        assert dataflow_mock.called
+        _, kwargs = dataflow_mock.call_args_list[0]
+        assert not kwargs["variables"]
+        assert kwargs["location"] == TEST_LOCATION
+
 
 class TestDataflowStartFlexTemplateOperator:
     @pytest.fixture
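Usage sketch (not part of the patch; the task ids and template path below are hypothetical, chosen only to illustrate the behavior the new tests pin down): with this change, leaving location unset lets a region passed through dataflow_default_options take effect, an explicit location is still forwarded to the hook as before, and only the deferrable trigger falls back to DEFAULT_DATAFLOW_LOCATION ("us-central1").

    from airflow.providers.google.cloud.operators.dataflow import DataflowTemplatedJobStartOperator

    # Region supplied via pipeline options: the operator now passes location=None,
    # so the hook resolves the region from the variables (first new operator test).
    start_by_region = DataflowTemplatedJobStartOperator(
        task_id="start_template_by_region",
        template="gs://dataflow-templates/latest/Word_Count",
        dataflow_default_options={"region": "custom-region"},
    )

    # Explicit location: forwarded to the hook unchanged (second new operator test).
    start_by_location = DataflowTemplatedJobStartOperator(
        task_id="start_template_by_location",
        template="gs://dataflow-templates/latest/Word_Count",
        location="custom-location",
    )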