DataflowStopJobOperator Operator (#27033)
Voldurk committed Oct 31, 2022
1 parent eb8c0cf commit 50d217a
Showing 6 changed files with 209 additions and 22 deletions.
25 changes: 25 additions & 0 deletions airflow/providers/google/cloud/example_dags/example_dataflow.py
@@ -34,6 +34,7 @@
from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus
from airflow.providers.google.cloud.operators.dataflow import (
CheckJobRunning,
DataflowStopJobOperator,
DataflowTemplatedJobStartOperator,
)
from airflow.providers.google.cloud.sensors.dataflow import (
@@ -261,3 +262,27 @@ def check_autoscaling_event(autoscaling_events: list[dict]) -> bool:
location="europe-west3",
)
# [END howto_operator_start_template_job]

with models.DAG(
"example_gcp_stop_dataflow_job",
default_args=default_args,
start_date=START_DATE,
catchup=False,
tags=["example"],
) as dag_template:
# [START howto_operator_stop_dataflow_job]
stop_dataflow_job = DataflowStopJobOperator(
task_id="stop-dataflow-job",
location="europe-west3",
job_name_prefix="start-template-job",
)
# [END howto_operator_stop_dataflow_job]
start_template_job = DataflowTemplatedJobStartOperator(
task_id="start-template-job",
template="gs://dataflow-templates/latest/Word_Count",
parameters={"inputFile": "gs://dataflow-samples/shakespeare/kinglear.txt", "output": GCS_OUTPUT},
location="europe-west3",
append_job_name=False,
)

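# Stop any job already running with the "start-template-job" name prefix before the new templated job is started.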
stop_dataflow_job >> start_template_job
25 changes: 15 additions & 10 deletions airflow/providers/google/cloud/hooks/dataflow.py
@@ -237,6 +237,8 @@ def _get_current_jobs(self) -> list[dict]:
"""
if not self._multiple_jobs and self._job_id:
return [self.fetch_job_by_id(self._job_id)]
elif self._jobs:
return [self.fetch_job_by_id(job["id"]) for job in self._jobs]
elif self._job_name:
jobs = self._fetch_jobs_by_prefix_name(self._job_name.lower())
if len(jobs) == 1:
@@ -445,11 +447,11 @@ def _wait_for_states(self, expected_states: set[str]):
job_states = {job["currentState"] for job in self._jobs}
if not job_states.difference(expected_states):
return
unexpected_failed_end_states = expected_states - DataflowJobStatus.FAILED_END_STATES
unexpected_failed_end_states = DataflowJobStatus.FAILED_END_STATES - expected_states
if unexpected_failed_end_states.intersection(job_states):
unexpected_failed_jobs = {
unexpected_failed_jobs = [
job for job in self._jobs if job["currentState"] in unexpected_failed_end_states
}
]
raise AirflowException(
"Jobs failed: "
+ ", ".join(
@@ -461,18 +463,19 @@ def _wait_for_states(self, expected_states: set[str]):

def cancel(self) -> None:
"""Cancels or drains current job"""
jobs = self.get_jobs()
job_ids = [job["id"] for job in jobs if job["currentState"] not in DataflowJobStatus.TERMINAL_STATES]
self._jobs = [
job for job in self.get_jobs() if job["currentState"] not in DataflowJobStatus.TERMINAL_STATES
]
job_ids = [job["id"] for job in self._jobs]
if job_ids:
batch = self._dataflow.new_batch_http_request()
self.log.info("Canceling jobs: %s", ", ".join(job_ids))
for job in jobs:
for job in self._jobs:
requested_state = (
DataflowJobStatus.JOB_STATE_DRAINED
if self.drain_pipeline and job["type"] == DataflowJobType.JOB_TYPE_STREAMING
else DataflowJobStatus.JOB_STATE_CANCELLED
)
batch.add(
request = (
self._dataflow.projects()
.locations()
.jobs()
@@ -483,14 +486,16 @@ def cancel(self) -> None:
body={"requestedState": requested_state},
)
)
batch.execute()
request.execute(num_retries=self._num_retries)
if self._cancel_timeout and isinstance(self._cancel_timeout, int):
timeout_error_message = (
f"Canceling jobs failed due to timeout ({self._cancel_timeout}s): {', '.join(job_ids)}"
)
tm = timeout(seconds=self._cancel_timeout, error_message=timeout_error_message)
with tm:
self._wait_for_states({DataflowJobStatus.JOB_STATE_CANCELLED})
self._wait_for_states(
{DataflowJobStatus.JOB_STATE_CANCELLED, DataflowJobStatus.JOB_STATE_DRAINED}
)
else:
self.log.info("No jobs to cancel")

93 changes: 93 additions & 0 deletions airflow/providers/google/cloud/operators/dataflow.py
@@ -1137,3 +1137,96 @@ def on_kill(self) -> None:
self.dataflow_hook.cancel_job(
job_id=self.job_id, project_id=self.project_id or self.dataflow_hook.project_id
)


class DataflowStopJobOperator(BaseOperator):
"""
Stops the job with the specified name prefix or Job ID.
All jobs with the provided name prefix will be stopped.
Streaming jobs are drained by default.
Parameters ``job_name_prefix`` and ``job_id`` are mutually exclusive.
.. seealso::
For more details on stopping a pipeline see:
https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:DataflowStopJobOperator`
:param job_name_prefix: Name prefix specifying which jobs are to be stopped.
:param job_id: Job ID specifying which jobs are to be stopped.
:param project_id: Optional, the Google Cloud project ID where the job is running.
If set to None or missing, the default project_id from the Google Cloud connection is used.
:param location: Optional, job location. If set to None or missing, "us-central1" will be used.
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
:param delegate_to: The account to impersonate using domain-wide delegation of authority,
if any. For this to work, the service account making the request must have
domain-wide delegation enabled.
:param poll_sleep: The time in seconds to sleep between polling Google
Cloud Platform for the dataflow job status to confirm it's stopped.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param drain_pipeline: Optional, set to False if you want to stop a streaming job by canceling it
instead of draining it. See: https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
:param stop_timeout: Wait time, in seconds, for successful job canceling/draining.
"""

def __init__(
self,
job_name_prefix: str | None = None,
job_id: str | None = None,
project_id: str | None = None,
location: str = DEFAULT_DATAFLOW_LOCATION,
gcp_conn_id: str = "google_cloud_default",
delegate_to: str | None = None,
poll_sleep: int = 10,
impersonation_chain: str | Sequence[str] | None = None,
stop_timeout: int | None = 10 * 60,
drain_pipeline: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.poll_sleep = poll_sleep
self.stop_timeout = stop_timeout
self.job_name = job_name_prefix
self.job_id = job_id
self.project_id = project_id
self.location = location
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
self.impersonation_chain = impersonation_chain
self.dataflow_hook: DataflowHook | None = None
self.drain_pipeline = drain_pipeline

def execute(self, context: Context) -> None:
self.dataflow_hook = DataflowHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
poll_sleep=self.poll_sleep,
impersonation_chain=self.impersonation_chain,
cancel_timeout=self.stop_timeout,
drain_pipeline=self.drain_pipeline,
)
if self.job_id or self.dataflow_hook.is_job_dataflow_running(
name=self.job_name,
project_id=self.project_id,
location=self.location,
):
self.dataflow_hook.cancel_job(
job_name=self.job_name,
project_id=self.project_id,
location=self.location,
job_id=self.job_id,
)
else:
self.log.info("No jobs to stop")

return None
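
A minimal sketch of the two mutually exclusive targeting modes described in the docstring above (the DAG ID, task IDs, name prefix, and job ID are all illustrative, not part of this change):

from datetime import datetime

from airflow import models
from airflow.providers.google.cloud.operators.dataflow import DataflowStopJobOperator

with models.DAG(
    "example_stop_job_targeting",
    start_date=datetime(2022, 10, 1),
    catchup=False,
) as dag:
    # Stop every non-terminated job whose name starts with this prefix.
    stop_by_prefix = DataflowStopJobOperator(
        task_id="stop-by-prefix",
        job_name_prefix="streaming-wordcount",
        location="europe-west3",
    )
    # Or stop one specific job by its ID; job_id and job_name_prefix
    # should not be passed together.
    stop_by_id = DataflowStopJobOperator(
        task_id="stop-by-id",
        job_id="2022-10-31_00_00_00-1234567890123456789",
        location="europe-west3",
    )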
18 changes: 18 additions & 0 deletions docs/apache-airflow-providers-google/operators/cloud/dataflow.rst
@@ -238,6 +238,24 @@ Here is an example of running Dataflow SQL job with
See the `Dataflow SQL reference
<https://cloud.google.com/dataflow/docs/reference/sql>`_.

.. _howto/operator:DataflowStopJobOperator:

Stopping a pipeline
^^^^^^^^^^^^^^^^^^^
To stop one or more Dataflow pipelines, you can use
:class:`~airflow.providers.google.cloud.operators.dataflow.DataflowStopJobOperator`.
Streaming pipelines are drained by default; set ``drain_pipeline`` to ``False`` to cancel them instead.
Provide ``job_id`` to stop a specific job, or ``job_name_prefix`` to stop all jobs with the provided name prefix.

.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_dataflow.py
:language: python
:dedent: 4
:start-after: [START howto_operator_stop_dataflow_job]
:end-before: [END howto_operator_stop_dataflow_job]

See: `Stopping a running pipeline
<https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline>`_.
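
For instance, to cancel (rather than drain) one specific streaming job by its ID, a task might look like the sketch below (the job ID shown is only a placeholder):

.. code-block:: python

    stop_streaming_job = DataflowStopJobOperator(
        task_id="stop_streaming_job",
        job_id="2022-10-31_00_00_00-1234567890123456789",
        location="europe-west3",
        drain_pipeline=False,  # cancel instead of draining
    )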

.. _howto/operator:DataflowJobStatusSensor:
.. _howto/operator:DataflowJobMetricsSensor:
.. _howto/operator:DataflowJobMessagesSensor:
16 changes: 4 additions & 12 deletions tests/providers/google/cloud/hooks/test_dataflow.py
@@ -1505,16 +1505,14 @@ def test_dataflow_job_cancel_job(self):
get_method.assert_called_with(jobId=TEST_JOB_ID, location=TEST_LOCATION, projectId=TEST_PROJECT)
get_method.return_value.execute.assert_called_with(num_retries=20)

self.mock_dataflow.new_batch_http_request.assert_called_once_with()
mock_batch = self.mock_dataflow.new_batch_http_request.return_value
mock_update = mock_jobs.return_value.update
mock_update.assert_called_once_with(
body={"requestedState": "JOB_STATE_CANCELLED"},
jobId="test-job-id",
location=TEST_LOCATION,
projectId="test-project",
)
mock_batch.add.assert_called_once_with(mock_update.return_value)
mock_update.return_value.execute.assert_called_once_with(num_retries=20)

@mock.patch("airflow.providers.google.cloud.hooks.dataflow.timeout")
@mock.patch("time.sleep")
@@ -1546,16 +1544,15 @@ def test_dataflow_job_cancel_job_cancel_timeout(self, mock_sleep, mock_timeout):
get_method.assert_called_with(jobId=TEST_JOB_ID, location=TEST_LOCATION, projectId=TEST_PROJECT)
get_method.return_value.execute.assert_called_with(num_retries=20)

self.mock_dataflow.new_batch_http_request.assert_called_once_with()
mock_batch = self.mock_dataflow.new_batch_http_request.return_value
mock_update = mock_jobs.return_value.update
mock_update.assert_called_once_with(
body={"requestedState": "JOB_STATE_CANCELLED"},
jobId="test-job-id",
location=TEST_LOCATION,
projectId="test-project",
)
mock_batch.add.assert_called_once_with(mock_update.return_value)
mock_update.return_value.execute.assert_called_once_with(num_retries=20)

mock_sleep.assert_has_calls([mock.call(4), mock.call(4), mock.call(4)])
mock_timeout.assert_called_once_with(
seconds=10, error_message="Canceling jobs failed due to timeout (10s): test-job-id"
@@ -1603,18 +1600,14 @@ def test_dataflow_job_cancel_or_drain_job(self, drain_pipeline, job_type, requested_state):

get_method.return_value.execute.assert_called_once_with(num_retries=20)

self.mock_dataflow.new_batch_http_request.assert_called_once_with()

mock_batch = self.mock_dataflow.new_batch_http_request.return_value
mock_update = self.mock_dataflow.projects.return_value.locations.return_value.jobs.return_value.update
mock_update.assert_called_once_with(
body={"requestedState": requested_state},
jobId="test-job-id",
location=TEST_LOCATION,
projectId="test-project",
)
mock_batch.add.assert_called_once_with(mock_update.return_value)
mock_batch.execute.assert_called_once()
mock_update.return_value.execute.assert_called_once_with(num_retries=20)

def test_dataflow_job_cancel_job_no_running_jobs(self):
mock_jobs = self.mock_dataflow.projects.return_value.locations.return_value.jobs
@@ -1643,7 +1636,6 @@ def test_dataflow_job_cancel_job_no_running_jobs(self):
get_method.assert_called_with(jobId=TEST_JOB_ID, location=TEST_LOCATION, projectId=TEST_PROJECT)
get_method.return_value.execute.assert_called_with(num_retries=20)

self.mock_dataflow.new_batch_http_request.assert_not_called()
mock_jobs.return_value.update.assert_not_called()

def test_fetch_list_job_messages_responses(self):
54 changes: 54 additions & 0 deletions tests/providers/google/cloud/operators/test_dataflow.py
@@ -29,6 +29,7 @@
DataflowCreatePythonJobOperator,
DataflowStartFlexTemplateOperator,
DataflowStartSqlJobOperator,
DataflowStopJobOperator,
DataflowTemplatedJobStartOperator,
)
from airflow.version import version
@@ -561,3 +562,56 @@ def test_execute(self, mock_hook):
mock_hook.return_value.cancel_job.assert_called_once_with(
job_id="test-job-id", project_id=None, location=None
)


class TestDataflowStopJobOperator(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
def test_exec_job_id(self, dataflow_mock):
"""
Test DataflowHook is created and the right args are passed to cancel_job.
"""
self.dataflow = DataflowStopJobOperator(
task_id=TASK_ID,
project_id=TEST_PROJECT,
job_id=JOB_ID,
poll_sleep=POLL_SLEEP,
location=TEST_LOCATION,
)
cancel_job_hook = dataflow_mock.return_value.cancel_job
self.dataflow.execute(None)
assert dataflow_mock.called
cancel_job_hook.assert_called_once_with(
job_name=None,
project_id=TEST_PROJECT,
location=TEST_LOCATION,
job_id=JOB_ID,
)

@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
def test_exec_job_name_prefix(self, dataflow_mock):
"""
Test DataflowHook is created and the right args are passed to cancel_job
and is_job_dataflow_running.
"""
self.dataflow = DataflowStopJobOperator(
task_id=TASK_ID,
project_id=TEST_PROJECT,
job_name_prefix=JOB_NAME,
poll_sleep=POLL_SLEEP,
location=TEST_LOCATION,
)
is_job_running_hook = dataflow_mock.return_value.is_job_dataflow_running
cancel_job_hook = dataflow_mock.return_value.cancel_job
self.dataflow.execute(None)
assert dataflow_mock.called
is_job_running_hook.assert_called_once_with(
name=JOB_NAME,
project_id=TEST_PROJECT,
location=TEST_LOCATION,
)
cancel_job_hook.assert_called_once_with(
job_name=JOB_NAME,
project_id=TEST_PROJECT,
location=TEST_LOCATION,
job_id=None,
)
