diff --git a/airflow/providers/google/cloud/example_dags/example_dataflow.py b/airflow/providers/google/cloud/example_dags/example_dataflow.py
index 1364de183e64e..b7b0c017dc06e 100644
--- a/airflow/providers/google/cloud/example_dags/example_dataflow.py
+++ b/airflow/providers/google/cloud/example_dags/example_dataflow.py
@@ -34,6 +34,7 @@
 from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus
 from airflow.providers.google.cloud.operators.dataflow import (
     CheckJobRunning,
+    DataflowStopJobOperator,
     DataflowTemplatedJobStartOperator,
 )
 from airflow.providers.google.cloud.sensors.dataflow import (
@@ -261,3 +262,27 @@ def check_autoscaling_event(autoscaling_events: list[dict]) -> bool:
         location="europe-west3",
     )
     # [END howto_operator_start_template_job]
+
+with models.DAG(
+    "example_gcp_stop_dataflow_job",
+    default_args=default_args,
+    start_date=START_DATE,
+    catchup=False,
+    tags=["example"],
+) as dag_template:
+    # [START howto_operator_stop_dataflow_job]
+    stop_dataflow_job = DataflowStopJobOperator(
+        task_id="stop-dataflow-job",
+        location="europe-west3",
+        job_name_prefix="start-template-job",
+    )
+    # [END howto_operator_stop_dataflow_job]
+    start_template_job = DataflowTemplatedJobStartOperator(
+        task_id="start-template-job",
+        template="gs://dataflow-templates/latest/Word_Count",
+        parameters={"inputFile": "gs://dataflow-samples/shakespeare/kinglear.txt", "output": GCS_OUTPUT},
+        location="europe-west3",
+        append_job_name=False,
+    )
+
+    stop_dataflow_job >> start_template_job
diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py
index 490024b693a00..b9dbf9478cfc4 100644
--- a/airflow/providers/google/cloud/hooks/dataflow.py
+++ b/airflow/providers/google/cloud/hooks/dataflow.py
@@ -237,6 +237,8 @@ def _get_current_jobs(self) -> list[dict]:
         """
         if not self._multiple_jobs and self._job_id:
             return [self.fetch_job_by_id(self._job_id)]
+        elif self._jobs:
+            return [self.fetch_job_by_id(job["id"]) for job in self._jobs]
         elif self._job_name:
             jobs = self._fetch_jobs_by_prefix_name(self._job_name.lower())
             if len(jobs) == 1:
@@ -445,11 +447,11 @@ def _wait_for_states(self, expected_states: set[str]):
             job_states = {job["currentState"] for job in self._jobs}
             if not job_states.difference(expected_states):
                 return
-            unexpected_failed_end_states = expected_states - DataflowJobStatus.FAILED_END_STATES
+            unexpected_failed_end_states = DataflowJobStatus.FAILED_END_STATES - expected_states
             if unexpected_failed_end_states.intersection(job_states):
-                unexpected_failed_jobs = {
+                unexpected_failed_jobs = [
                     job for job in self._jobs if job["currentState"] in unexpected_failed_end_states
-                }
+                ]
                 raise AirflowException(
                     "Jobs failed: "
                     + ", ".join(
@@ -461,18 +463,19 @@ def cancel(self) -> None:
         """Cancels or drains current job"""
-        jobs = self.get_jobs()
-        job_ids = [job["id"] for job in jobs if job["currentState"] not in DataflowJobStatus.TERMINAL_STATES]
+        self._jobs = [
+            job for job in self.get_jobs() if job["currentState"] not in DataflowJobStatus.TERMINAL_STATES
+        ]
+        job_ids = [job["id"] for job in self._jobs]
         if job_ids:
-            batch = self._dataflow.new_batch_http_request()
             self.log.info("Canceling jobs: %s", ", ".join(job_ids))
-            for job in jobs:
+            for job in self._jobs:
                 requested_state = (
                     DataflowJobStatus.JOB_STATE_DRAINED
                     if self.drain_pipeline and job["type"] == DataflowJobType.JOB_TYPE_STREAMING
                     else DataflowJobStatus.JOB_STATE_CANCELLED
                 )
-                batch.add(
+                request = (
                     self._dataflow.projects()
                     .locations()
                     .jobs()
@@ -483,14 +486,16 @@ def cancel(self) -> None:
                         body={"requestedState": requested_state},
                     )
                 )
-            batch.execute()
+                request.execute(num_retries=self._num_retries)
             if self._cancel_timeout and isinstance(self._cancel_timeout, int):
                 timeout_error_message = (
                     f"Canceling jobs failed due to timeout ({self._cancel_timeout}s): {', '.join(job_ids)}"
                 )
                 tm = timeout(seconds=self._cancel_timeout, error_message=timeout_error_message)
                 with tm:
-                    self._wait_for_states({DataflowJobStatus.JOB_STATE_CANCELLED})
+                    self._wait_for_states(
+                        {DataflowJobStatus.JOB_STATE_CANCELLED, DataflowJobStatus.JOB_STATE_DRAINED}
+                    )
         else:
             self.log.info("No jobs to cancel")

diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py
index 6d0a0412ae1f4..7cfab6bbff598 100644
--- a/airflow/providers/google/cloud/operators/dataflow.py
+++ b/airflow/providers/google/cloud/operators/dataflow.py
@@ -1137,3 +1137,96 @@ def on_kill(self) -> None:
         self.dataflow_hook.cancel_job(
             job_id=self.job_id, project_id=self.project_id or self.dataflow_hook.project_id
         )
+
+
+class DataflowStopJobOperator(BaseOperator):
+    """
+    Stops the job with the specified name prefix or Job ID.
+    All jobs with the provided name prefix will be stopped.
+    Streaming jobs are drained by default.
+
+    The ``job_name_prefix`` and ``job_id`` parameters are mutually exclusive.
+
+    .. seealso::
+        For more details on stopping a pipeline see:
+        https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:DataflowStopJobOperator`
+
+    :param job_name_prefix: Name prefix specifying which jobs are to be stopped.
+    :param job_id: Job ID specifying which jobs are to be stopped.
+    :param project_id: Optional, the Google Cloud project ID in which to start a job.
+        If set to None or missing, the default project_id from the Google Cloud connection is used.
+    :param location: Optional, job location. If set to None or missing, "us-central1" will be used.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+        if any. For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :param poll_sleep: The time in seconds to sleep between polling Google
+        Cloud Platform for the Dataflow job status to confirm it's stopped.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param drain_pipeline: Optional, set to False if you want to stop a streaming job by canceling it
+        instead of draining. See: https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
+    :param stop_timeout: Wait time, in seconds, for a successful job cancel or drain.
+    """
+
+    def __init__(
+        self,
+        job_name_prefix: str | None = None,
+        job_id: str | None = None,
+        project_id: str | None = None,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+        gcp_conn_id: str = "google_cloud_default",
+        delegate_to: str | None = None,
+        poll_sleep: int = 10,
+        impersonation_chain: str | Sequence[str] | None = None,
+        stop_timeout: int | None = 10 * 60,
+        drain_pipeline: bool = True,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.poll_sleep = poll_sleep
+        self.stop_timeout = stop_timeout
+        self.job_name = job_name_prefix
+        self.job_id = job_id
+        self.project_id = project_id
+        self.location = location
+        self.gcp_conn_id = gcp_conn_id
+        self.delegate_to = delegate_to
+        self.impersonation_chain = impersonation_chain
+        self.hook: DataflowHook | None = None
+        self.drain_pipeline = drain_pipeline
+
+    def execute(self, context: Context) -> None:
+        self.dataflow_hook = DataflowHook(
+            gcp_conn_id=self.gcp_conn_id,
+            delegate_to=self.delegate_to,
+            poll_sleep=self.poll_sleep,
+            impersonation_chain=self.impersonation_chain,
+            cancel_timeout=self.stop_timeout,
+            drain_pipeline=self.drain_pipeline,
+        )
+        if self.job_id or self.dataflow_hook.is_job_dataflow_running(
+            name=self.job_name,
+            project_id=self.project_id,
+            location=self.location,
+        ):
+            self.dataflow_hook.cancel_job(
+                job_name=self.job_name,
+                project_id=self.project_id,
+                location=self.location,
+                job_id=self.job_id,
+            )
+        else:
+            self.log.info("No jobs to stop")
+
+        return None
diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst
index 32ac462f0c6aa..76829275a2ea8 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst
@@ -238,6 +238,24 @@ Here is an example of running Dataflow SQL job with
 See the `Dataflow SQL reference
 <https://cloud.google.com/dataflow/docs/reference/sql>`_.
 
+.. _howto/operator:DataflowStopJobOperator:
+
+Stopping a pipeline
+^^^^^^^^^^^^^^^^^^^
+To stop one or more Dataflow pipelines, you can use
+:class:`~airflow.providers.google.cloud.operators.dataflow.DataflowStopJobOperator`.
+Streaming pipelines are drained by default; set ``drain_pipeline`` to ``False`` to cancel them instead.
+Provide ``job_id`` to stop a specific job, or ``job_name_prefix`` to stop all jobs with the provided name prefix.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_dataflow.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_stop_dataflow_job]
+    :end-before: [END howto_operator_stop_dataflow_job]
+
+See: `Stopping a running pipeline
+<https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline>`_.
+
 .. _howto/operator:DataflowJobStatusSensor:
 .. _howto/operator:DataflowJobMetricsSensor:
 .. _howto/operator:DataflowJobMessagesSensor:
diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py
index d2db3b8e1b6c4..ae7478880b36a 100644
--- a/tests/providers/google/cloud/hooks/test_dataflow.py
+++ b/tests/providers/google/cloud/hooks/test_dataflow.py
@@ -1505,8 +1505,6 @@ def test_dataflow_job_cancel_job(self):
         get_method.assert_called_with(jobId=TEST_JOB_ID, location=TEST_LOCATION, projectId=TEST_PROJECT)
         get_method.return_value.execute.assert_called_with(num_retries=20)
 
-        self.mock_dataflow.new_batch_http_request.assert_called_once_with()
-        mock_batch = self.mock_dataflow.new_batch_http_request.return_value
         mock_update = mock_jobs.return_value.update
         mock_update.assert_called_once_with(
             body={"requestedState": "JOB_STATE_CANCELLED"},
@@ -1514,7 +1512,7 @@ def test_dataflow_job_cancel_job(self):
             location=TEST_LOCATION,
             projectId="test-project",
         )
-        mock_batch.add.assert_called_once_with(mock_update.return_value)
+        mock_update.return_value.execute.assert_called_once_with(num_retries=20)
 
     @mock.patch("airflow.providers.google.cloud.hooks.dataflow.timeout")
     @mock.patch("time.sleep")
@@ -1546,8 +1544,6 @@ def test_dataflow_job_cancel_job_cancel_timeout(self, mock_sleep, mock_timeout):
         get_method.assert_called_with(jobId=TEST_JOB_ID, location=TEST_LOCATION, projectId=TEST_PROJECT)
         get_method.return_value.execute.assert_called_with(num_retries=20)
 
-        self.mock_dataflow.new_batch_http_request.assert_called_once_with()
-        mock_batch = self.mock_dataflow.new_batch_http_request.return_value
         mock_update = mock_jobs.return_value.update
         mock_update.assert_called_once_with(
             body={"requestedState": "JOB_STATE_CANCELLED"},
@@ -1555,7 +1551,8 @@ def test_dataflow_job_cancel_job_cancel_timeout(self, mock_sleep, mock_timeout):
             location=TEST_LOCATION,
             projectId="test-project",
         )
-        mock_batch.add.assert_called_once_with(mock_update.return_value)
+        mock_update.return_value.execute.assert_called_once_with(num_retries=20)
+        mock_sleep.assert_has_calls([mock.call(4), mock.call(4), mock.call(4)])
         mock_timeout.assert_called_once_with(
             seconds=10, error_message="Canceling jobs failed due to timeout (10s): test-job-id"
         )
@@ -1603,9 +1600,6 @@ def test_dataflow_job_cancel_or_drain_job(self, drain_pipeline, job_type, reques
 
         get_method.return_value.execute.assert_called_once_with(num_retries=20)
 
-        self.mock_dataflow.new_batch_http_request.assert_called_once_with()
-
-        mock_batch = self.mock_dataflow.new_batch_http_request.return_value
         mock_update = self.mock_dataflow.projects.return_value.locations.return_value.jobs.return_value.update
         mock_update.assert_called_once_with(
             body={"requestedState": requested_state},
@@ -1613,8 +1607,7 @@ def test_dataflow_job_cancel_or_drain_job(self, drain_pipeline, job_type, reques
             location=TEST_LOCATION,
             projectId="test-project",
         )
-        mock_batch.add.assert_called_once_with(mock_update.return_value)
-        mock_batch.execute.assert_called_once()
+        mock_update.return_value.execute.assert_called_once_with(num_retries=20)
 
     def test_dataflow_job_cancel_job_no_running_jobs(self):
         mock_jobs = self.mock_dataflow.projects.return_value.locations.return_value.jobs
@@ -1643,7 +1636,6 @@ def test_dataflow_job_cancel_job_no_running_jobs(self):
         get_method.assert_called_with(jobId=TEST_JOB_ID, location=TEST_LOCATION, projectId=TEST_PROJECT)
         get_method.return_value.execute.assert_called_with(num_retries=20)
 
-        self.mock_dataflow.new_batch_http_request.assert_not_called()
         mock_jobs.return_value.update.assert_not_called()
 
     def test_fetch_list_job_messages_responses(self):
diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py
index 1a0f05d404ca8..40e37c9487593 100644
--- a/tests/providers/google/cloud/operators/test_dataflow.py
+++ b/tests/providers/google/cloud/operators/test_dataflow.py
@@ -29,6 +29,7 @@
     DataflowCreatePythonJobOperator,
     DataflowStartFlexTemplateOperator,
     DataflowStartSqlJobOperator,
+    DataflowStopJobOperator,
     DataflowTemplatedJobStartOperator,
 )
 from airflow.version import version
@@ -561,3 +562,56 @@ def test_execute(self, mock_hook):
         mock_hook.return_value.cancel_job.assert_called_once_with(
             job_id="test-job-id", project_id=None, location=None
         )
+
+
+class TestDataflowStopJobOperator(unittest.TestCase):
+    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
+    def test_exec_job_id(self, dataflow_mock):
+        """
+        Test DataflowHook is created and the right args are passed to cancel_job.
+        """
+        self.dataflow = DataflowStopJobOperator(
+            task_id=TASK_ID,
+            project_id=TEST_PROJECT,
+            job_id=JOB_ID,
+            poll_sleep=POLL_SLEEP,
+            location=TEST_LOCATION,
+        )
+        cancel_job_hook = dataflow_mock.return_value.cancel_job
+        self.dataflow.execute(None)
+        assert dataflow_mock.called
+        cancel_job_hook.assert_called_once_with(
+            job_name=None,
+            project_id=TEST_PROJECT,
+            location=TEST_LOCATION,
+            job_id=JOB_ID,
+        )
+
+    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
+    def test_exec_job_name_prefix(self, dataflow_mock):
+        """
+        Test DataflowHook is created and the right args are passed to cancel_job
+        and is_job_dataflow_running.
+        """
+        self.dataflow = DataflowStopJobOperator(
+            task_id=TASK_ID,
+            project_id=TEST_PROJECT,
+            job_name_prefix=JOB_NAME,
+            poll_sleep=POLL_SLEEP,
+            location=TEST_LOCATION,
+        )
+        is_job_running_hook = dataflow_mock.return_value.is_job_dataflow_running
+        cancel_job_hook = dataflow_mock.return_value.cancel_job
+        self.dataflow.execute(None)
+        assert dataflow_mock.called
+        is_job_running_hook.assert_called_once_with(
+            name=JOB_NAME,
+            project_id=TEST_PROJECT,
+            location=TEST_LOCATION,
+        )
+        cancel_job_hook.assert_called_once_with(
+            job_name=JOB_NAME,
+            project_id=TEST_PROJECT,
+            location=TEST_LOCATION,
+            job_id=None,
+        )
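
The example DAG and docs above exercise the ``job_name_prefix`` mode; the operator also accepts a ``job_id``, as its docstring describes. Below is a minimal sketch of that second mode, assuming a placeholder DAG id, project, location, and Dataflow job ID — only the operator itself and its ``job_id``, ``project_id``, ``location``, and ``drain_pipeline`` parameters come from this change.

# Minimal sketch: stop a single Dataflow job by ID instead of by name prefix.
# The DAG id, project, location, and job ID below are placeholders, not values from this change.
from datetime import datetime

from airflow import models
from airflow.providers.google.cloud.operators.dataflow import DataflowStopJobOperator

with models.DAG(
    "example_stop_dataflow_job_by_id",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,  # run on demand only
    catchup=False,
) as dag:
    stop_job_by_id = DataflowStopJobOperator(
        task_id="stop-dataflow-job-by-id",
        project_id="my-gcp-project",  # placeholder project
        location="europe-west3",
        job_id="2022-01-01_00_00_00-1234567890123456789",  # placeholder Dataflow job ID
        drain_pipeline=False,  # cancel instead of draining a streaming pipeline
    )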