Add drain option when canceling Dataflow pipelines #11374

Merged · 4 commits · Oct 29, 2020
23 changes: 19 additions & 4 deletions airflow/providers/google/cloud/hooks/dataflow.py
@@ -145,6 +145,8 @@ class _DataflowJobsController(LoggingMixin):
:param num_retries: Maximum number of retries in case of connection problems.
:param multiple_jobs: If set to true this task will be searched by name prefix (``name`` parameter),
not by specific job ID, then actions will be performed on all matching jobs.
:param drain_pipeline: Optional, set to True if you want to stop a streaming job by draining it
instead of canceling it.
"""

def __init__(
@@ -157,6 +159,7 @@ def __init__(
job_id: Optional[str] = None,
num_retries: int = 0,
multiple_jobs: bool = False,
drain_pipeline: bool = False,
) -> None:

super().__init__()
@@ -168,6 +171,7 @@ def __init__(
self._job_id = job_id
self._num_retries = num_retries
self._poll_sleep = poll_sleep
self.drain_pipeline = drain_pipeline
self._jobs: Optional[List[dict]] = None

def is_job_running(self) -> bool:
@@ -304,22 +308,27 @@ def get_jobs(self, refresh=False) -> List[dict]:
return self._jobs

def cancel(self) -> None:
"""Cancels current job"""
"""Cancels or drains current job"""
jobs = self.get_jobs()
job_ids = [job['id'] for job in jobs if job['currentState'] not in DataflowJobStatus.TERMINAL_STATES]
if job_ids:
batch = self._dataflow.new_batch_http_request()
self.log.info("Canceling jobs: %s", ", ".join(job_ids))
for job_id in job_ids:
for job in jobs:
requested_state = (
DataflowJobStatus.JOB_STATE_DRAINED
if self.drain_pipeline and job['type'] == DataflowJobType.JOB_TYPE_STREAMING
else DataflowJobStatus.JOB_STATE_CANCELLED
)
batch.add(
self._dataflow.projects()
.locations()
.jobs()
.update(
projectId=self._project_number,
location=self._job_location,
jobId=job_id,
body={"requestedState": DataflowJobStatus.JOB_STATE_CANCELLED},
jobId=job['id'],
body={"requestedState": requested_state},
)
)
batch.execute()
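
Reviewer note: the state-selection rule added above boils down to the sketch below (the standalone helper and its name are illustrative only, not part of this change). Drain is only requested for streaming jobs; batch jobs are always asked to cancel, even when `drain_pipeline=True`, as the new tests further down confirm.

```python
# Illustrative helper (not in this PR): mirrors the requested_state expression
# inside _DataflowJobsController.cancel().
def _requested_state(job: dict, drain_pipeline: bool) -> str:
    # Draining only applies to streaming pipelines; batch jobs are always
    # cancelled regardless of the drain_pipeline flag.
    if drain_pipeline and job["type"] == "JOB_TYPE_STREAMING":
        return "JOB_STATE_DRAINED"
    return "JOB_STATE_CANCELLED"
```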
@@ -427,8 +436,10 @@ def __init__(
delegate_to: Optional[str] = None,
poll_sleep: int = 10,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
drain_pipeline: bool = False,
) -> None:
self.poll_sleep = poll_sleep
self.drain_pipeline = drain_pipeline
super().__init__(
gcp_conn_id=gcp_conn_id,
delegate_to=delegate_to,
@@ -464,6 +475,7 @@ def _start_dataflow(
job_id=job_id,
num_retries=self.num_retries,
multiple_jobs=multiple_jobs,
drain_pipeline=self.drain_pipeline,
)
job_controller.wait_for_done()

@@ -633,6 +645,7 @@ def start_template_dataflow(
location=location,
poll_sleep=self.poll_sleep,
num_retries=self.num_retries,
drain_pipeline=self.drain_pipeline,
)
jobs_controller.wait_for_done()
return response["job"]
@@ -870,6 +883,7 @@ def is_job_dataflow_running(
name=name,
location=location,
poll_sleep=self.poll_sleep,
drain_pipeline=self.drain_pipeline,
)
return jobs_controller.is_job_running()

@@ -903,5 +917,6 @@ def cancel_job(
job_id=job_id,
location=location,
poll_sleep=self.poll_sleep,
drain_pipeline=self.drain_pipeline,
)
jobs_controller.cancel()
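
Illustrative usage of the extended hook — a sketch only: the connection id, project, job id and location below are placeholders, and the argument names follow the hunks shown above.

```python
from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

# A hook built with drain_pipeline=True will request JOB_STATE_DRAINED for
# streaming jobs in cancel_job() instead of JOB_STATE_CANCELLED.
hook = DataflowHook(
    gcp_conn_id="google_cloud_default",
    drain_pipeline=True,
)
hook.cancel_job(
    job_id="2020-10-29_00_00_00-1234567890123456789",  # placeholder job id
    project_id="my-gcp-project",                        # placeholder project
    location="us-central1",
)
```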
25 changes: 21 additions & 4 deletions airflow/providers/google/cloud/operators/dataflow.py
@@ -219,7 +219,9 @@ def __init__(

def execute(self, context):
self.hook = DataflowHook(
gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, poll_sleep=self.poll_sleep
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
poll_sleep=self.poll_sleep,
)
dataflow_options = copy.copy(self.dataflow_default_options)
dataflow_options.update(self.options)
@@ -467,6 +469,10 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
For this to work, the service account making the request must have
domain-wide delegation enabled.
:type delegate_to: str
:param drain_pipeline: Optional, set to True if you want to stop a streaming job by draining it
instead of canceling it when killing the task instance. See:
https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
:type drain_pipeline: bool
"""

template_fields = ["body", 'location', 'project_id', 'gcp_conn_id']
@@ -479,6 +485,7 @@ def __init__(
project_id: Optional[str] = None,
gcp_conn_id: str = 'google_cloud_default',
delegate_to: Optional[str] = None,
drain_pipeline: bool = False,
*args,
**kwargs,
) -> None:
@@ -490,11 +497,11 @@ def __init__(
self.delegate_to = delegate_to
self.job_id = None
self.hook: Optional[DataflowHook] = None
self.drain_pipeline = drain_pipeline

def execute(self, context):
self.hook = DataflowHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, drain_pipeline=self.drain_pipeline
)

def set_current_job_id(job_id):
@@ -515,6 +522,7 @@ def on_kill(self) -> None:
self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id)


# pylint: disable=too-many-instance-attributes
class DataflowCreatePythonJobOperator(BaseOperator):
"""
Launching Cloud Dataflow jobs written in python. Note that both
@@ -582,6 +590,10 @@ class DataflowCreatePythonJobOperator(BaseOperator):
Cloud Platform for the dataflow job status while the job is in the
JOB_STATE_RUNNING state.
:type poll_sleep: int
:param drain_pipeline: Optional, set to True if you want to stop a streaming job by draining it
instead of canceling it when killing the task instance. See:
https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
:type drain_pipeline: bool
"""

template_fields = ['options', 'dataflow_default_options', 'job_name', 'py_file']
@@ -603,6 +615,7 @@ def __init__( # pylint: disable=too-many-arguments
gcp_conn_id: str = 'google_cloud_default',
delegate_to: Optional[str] = None,
poll_sleep: int = 10,
drain_pipeline: bool = False,
**kwargs,
) -> None:

@@ -624,6 +637,7 @@ def __init__( # pylint: disable=too-many-arguments
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
self.poll_sleep = poll_sleep
self.drain_pipeline = drain_pipeline
self.job_id = None
self.hook = None

@@ -638,7 +652,10 @@ def execute(self, context):
self.py_file = tmp_gcs_file.name

self.hook = DataflowHook(
gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, poll_sleep=self.poll_sleep
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
poll_sleep=self.poll_sleep,
drain_pipeline=self.drain_pipeline,
)
dataflow_options = self.dataflow_default_options.copy()
dataflow_options.update(self.options)
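
For context, a minimal DAG sketch using the new flag on `DataflowCreatePythonJobOperator` — the GCS path, job name and pipeline options are placeholders; only `drain_pipeline` is the argument added by this PR.

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataflow import DataflowCreatePythonJobOperator

with DAG(
    dag_id="example_dataflow_drain",
    start_date=datetime(2020, 10, 1),
    schedule_interval=None,
) as dag:
    start_streaming_job = DataflowCreatePythonJobOperator(
        task_id="start_python_streaming_job",
        py_file="gs://my-bucket/pipelines/streaming_wordcount.py",  # placeholder
        job_name="streaming-wordcount",                             # placeholder
        options={"streaming": True},                                # placeholder
        drain_pipeline=True,  # on task kill, drain the streaming job instead of canceling it
    )
```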
60 changes: 60 additions & 0 deletions tests/providers/google/cloud/hooks/test_dataflow.py
@@ -656,6 +656,7 @@ def test_start_template_dataflow(self, mock_conn, mock_controller, mock_uuid):
poll_sleep=10,
project_number=TEST_PROJECT,
location=DEFAULT_DATAFLOW_LOCATION,
drain_pipeline=False,
)
mock_controller.return_value.wait_for_done.assert_called_once()

@@ -692,6 +693,7 @@ def test_start_template_dataflow_with_custom_region_as_variable(
poll_sleep=10,
project_number=TEST_PROJECT,
location=TEST_LOCATION,
drain_pipeline=False,
)
mock_controller.return_value.wait_for_done.assert_called_once()

@@ -730,6 +732,7 @@ def test_start_template_dataflow_with_custom_region_as_parameter(
poll_sleep=10,
project_number=TEST_PROJECT,
location=TEST_LOCATION,
drain_pipeline=False,
)
mock_controller.return_value.wait_for_done.assert_called_once()

@@ -772,6 +775,7 @@ def test_start_template_dataflow_with_runtime_env(self, mock_conn, mock_dataflow
num_retries=5,
poll_sleep=10,
project_number=TEST_PROJECT,
drain_pipeline=False,
)
mock_uuid.assert_called_once_with()

@@ -818,6 +822,7 @@ def test_start_template_dataflow_update_runtime_env(self, mock_conn, mock_datafl
num_retries=5,
poll_sleep=10,
project_number=TEST_PROJECT,
drain_pipeline=False,
)
mock_uuid.assert_called_once_with()

@@ -868,6 +873,7 @@ def test_cancel_job(self, mock_get_conn, jobs_controller):
name=UNIQUE_JOB_NAME,
poll_sleep=10,
project_number=TEST_PROJECT,
drain_pipeline=False,
)
jobs_controller.cancel()

@@ -1196,6 +1202,60 @@ def test_dataflow_job_cancel_job(self):
)
mock_batch.add.assert_called_once_with(mock_update.return_value)

@parameterized.expand(
[
(False, "JOB_TYPE_BATCH", "JOB_STATE_CANCELLED"),
(False, "JOB_TYPE_STREAMING", "JOB_STATE_CANCELLED"),
(True, "JOB_TYPE_BATCH", "JOB_STATE_CANCELLED"),
(True, "JOB_TYPE_STREAMING", "JOB_STATE_DRAINED"),
]
)
def test_dataflow_job_cancel_or_drain_job(self, drain_pipeline, job_type, requested_state):
job = {
"id": TEST_JOB_ID,
"name": UNIQUE_JOB_NAME,
"currentState": DataflowJobStatus.JOB_STATE_RUNNING,
"type": job_type,
}
get_method = self.mock_dataflow.projects.return_value.locations.return_value.jobs.return_value.get
get_method.return_value.execute.return_value = job
# fmt: off
job_list_nest_method = (self.mock_dataflow
.projects.return_value.
locations.return_value.
jobs.return_value.list_next)
job_list_nest_method.return_value = None
# fmt: on
dataflow_job = _DataflowJobsController(
dataflow=self.mock_dataflow,
project_number=TEST_PROJECT,
name=UNIQUE_JOB_NAME,
location=TEST_LOCATION,
poll_sleep=10,
job_id=TEST_JOB_ID,
num_retries=20,
multiple_jobs=False,
drain_pipeline=drain_pipeline,
)
dataflow_job.cancel()

get_method.assert_called_once_with(jobId=TEST_JOB_ID, location=TEST_LOCATION, projectId=TEST_PROJECT)

get_method.return_value.execute.assert_called_once_with(num_retries=20)

self.mock_dataflow.new_batch_http_request.assert_called_once_with()

mock_batch = self.mock_dataflow.new_batch_http_request.return_value
mock_update = self.mock_dataflow.projects.return_value.locations.return_value.jobs.return_value.update
mock_update.assert_called_once_with(
body={'requestedState': requested_state},
jobId='test-job-id',
location=TEST_LOCATION,
projectId='test-project',
)
mock_batch.add.assert_called_once_with(mock_update.return_value)
mock_batch.execute.assert_called_once()

def test_dataflow_job_cancel_job_no_running_jobs(self):
mock_jobs = self.mock_dataflow.projects.return_value.locations.return_value.jobs
get_method = mock_jobs.return_value.get
@@ -110,7 +110,7 @@ def test_successful_run(self):
hook_instance.start_python_dataflow.return_value = None
summary.execute(None)
mock_dataflow_hook.assert_called_once_with(
gcp_conn_id='google_cloud_default', delegate_to=None, poll_sleep=10
gcp_conn_id='google_cloud_default', delegate_to=None, poll_sleep=10, drain_pipeline=False
)
hook_instance.start_python_dataflow.assert_called_once_with(
job_name='{{task.task_id}}',