Skip to content

Commit

Permalink
Cancel workflow in on_kill in DataprocInstantiate{Inline}WorkflowTemp…
Browse files Browse the repository at this point in the history
…lateOperator (#34957)

* Cancel operation in on_kill in DataprocInstantiateWorkflowTemplateOperator

* Test on_kill method in DataprocInstantiateWorkflowTemplateOperator
  • Loading branch information
michalsosn committed Oct 16, 2023
1 parent 105743e commit 0b49f33
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 11 deletions.
42 changes: 32 additions & 10 deletions airflow/providers/google/cloud/operators/dataproc.py
Expand Up @@ -1790,6 +1790,7 @@ class DataprocInstantiateWorkflowTemplateOperator(GoogleCloudBaseOperator):
account from the list granting this role to the originating account (templated).
:param deferrable: Run operator in the deferrable mode.
:param polling_interval_seconds: Time (seconds) to wait between calls to check the run status.
:param cancel_on_kill: Flag which indicates whether cancel the workflow, when on_kill is called
"""

template_fields: Sequence[str] = ("template_id", "impersonation_chain", "request_id", "parameters")
Expand All @@ -1812,6 +1813,7 @@ def __init__(
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
polling_interval_seconds: int = 10,
cancel_on_kill: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -1830,6 +1832,8 @@ def __init__(
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable
self.polling_interval_seconds = polling_interval_seconds
self.cancel_on_kill = cancel_on_kill
self.operation_name: str | None = None

def execute(self, context: Context):
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
Expand All @@ -1845,24 +1849,26 @@ def execute(self, context: Context):
timeout=self.timeout,
metadata=self.metadata,
)
self.workflow_id = operation.operation.name.split("/")[-1]
operation_name = operation.operation.name
self.operation_name = operation_name
workflow_id = operation_name.split("/")[-1]
project_id = self.project_id or hook.project_id
if project_id:
DataprocWorkflowLink.persist(
context=context,
operator=self,
workflow_id=self.workflow_id,
workflow_id=workflow_id,
region=self.region,
project_id=project_id,
)
self.log.info("Template instantiated. Workflow Id : %s", self.workflow_id)
self.log.info("Template instantiated. Workflow Id : %s", workflow_id)
if not self.deferrable:
hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=operation)
self.log.info("Workflow %s completed successfully", self.workflow_id)
self.log.info("Workflow %s completed successfully", workflow_id)
else:
self.defer(
trigger=DataprocWorkflowTrigger(
name=operation.operation.name,
name=operation_name,
project_id=self.project_id,
region=self.region,
gcp_conn_id=self.gcp_conn_id,
Expand All @@ -1884,6 +1890,11 @@ def execute_complete(self, context, event=None) -> None:

self.log.info("Workflow %s completed successfully", event["operation_name"])

def on_kill(self) -> None:
if self.cancel_on_kill and self.operation_name:
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
hook.get_operations_client(region=self.region).cancel_operation(name=self.operation_name)


class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator):
"""Instantiate a WorkflowTemplate Inline on Google Cloud Dataproc.
Expand Down Expand Up @@ -1926,6 +1937,7 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
account from the list granting this role to the originating account (templated).
:param deferrable: Run operator in the deferrable mode.
:param polling_interval_seconds: Time (seconds) to wait between calls to check the run status.
:param cancel_on_kill: Flag which indicates whether cancel the workflow, when on_kill is called
"""

template_fields: Sequence[str] = ("template", "impersonation_chain")
Expand All @@ -1946,6 +1958,7 @@ def __init__(
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
polling_interval_seconds: int = 10,
cancel_on_kill: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -1963,6 +1976,8 @@ def __init__(
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable
self.polling_interval_seconds = polling_interval_seconds
self.cancel_on_kill = cancel_on_kill
self.operation_name: str | None = None

def execute(self, context: Context):
self.log.info("Instantiating Inline Template")
Expand All @@ -1977,23 +1992,25 @@ def execute(self, context: Context):
timeout=self.timeout,
metadata=self.metadata,
)
self.workflow_id = operation.operation.name.split("/")[-1]
operation_name = operation.operation.name
self.operation_name = operation_name
workflow_id = operation_name.split("/")[-1]
if project_id:
DataprocWorkflowLink.persist(
context=context,
operator=self,
workflow_id=self.workflow_id,
workflow_id=workflow_id,
region=self.region,
project_id=project_id,
)
if not self.deferrable:
self.log.info("Template instantiated. Workflow Id : %s", self.workflow_id)
self.log.info("Template instantiated. Workflow Id : %s", workflow_id)
operation.result()
self.log.info("Workflow %s completed successfully", self.workflow_id)
self.log.info("Workflow %s completed successfully", workflow_id)
else:
self.defer(
trigger=DataprocWorkflowTrigger(
name=operation.operation.name,
name=operation_name,
project_id=self.project_id or hook.project_id,
region=self.region,
gcp_conn_id=self.gcp_conn_id,
Expand All @@ -2015,6 +2032,11 @@ def execute_complete(self, context, event=None) -> None:

self.log.info("Workflow %s completed successfully", event["operation_name"])

def on_kill(self) -> None:
if self.cancel_on_kill and self.operation_name:
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
hook.get_operations_client(region=self.region).cancel_operation(name=self.operation_name)


class DataprocSubmitJobOperator(GoogleCloudBaseOperator):
"""Submit a job to a cluster.
Expand Down
64 changes: 63 additions & 1 deletion tests/providers/google/cloud/operators/test_dataproc.py
Expand Up @@ -1399,7 +1399,7 @@ def test_update_cluster_operator_extra_links(dag_maker, create_task_instance_of_
assert ti.task.get_extra_links(ti, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED


class TestDataprocWorkflowTemplateInstantiateOperator:
class TestDataprocInstantiateWorkflowTemplateOperator:
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute(self, mock_hook):
version = 6
Expand Down Expand Up @@ -1463,6 +1463,37 @@ def test_execute_call_defer_method(self, mock_trigger_hook, mock_hook):
assert isinstance(exc.value.trigger, DataprocWorkflowTrigger)
assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_on_kill(self, mock_hook):
operation_name = "operation_name"
mock_hook.return_value.instantiate_workflow_template.return_value.operation.name = operation_name
op = DataprocInstantiateWorkflowTemplateOperator(
task_id=TASK_ID,
template_id=TEMPLATE_ID,
region=GCP_REGION,
project_id=GCP_PROJECT,
version=2,
parameters={},
request_id=REQUEST_ID,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
cancel_on_kill=False,
)

op.execute(context=mock.MagicMock())

op.on_kill()
mock_hook.return_value.get_operations_client.return_value.cancel_operation.assert_not_called()

op.cancel_on_kill = True
op.on_kill()
mock_hook.return_value.get_operations_client.return_value.cancel_operation.assert_called_once_with(
name=operation_name
)


@pytest.mark.need_serialized_dag
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
Expand Down Expand Up @@ -1561,6 +1592,37 @@ def test_execute_call_defer_method(self, mock_trigger_hook, mock_hook):
assert isinstance(exc.value.trigger, DataprocWorkflowTrigger)
assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_on_kill(self, mock_hook):
operation_name = "operation_name"
mock_hook.return_value.instantiate_inline_workflow_template.return_value.operation.name = (
operation_name
)
op = DataprocInstantiateInlineWorkflowTemplateOperator(
task_id=TASK_ID,
template={},
region=GCP_REGION,
project_id=GCP_PROJECT,
request_id=REQUEST_ID,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
cancel_on_kill=False,
)

op.execute(context=mock.MagicMock())

op.on_kill()
mock_hook.return_value.get_operations_client.return_value.cancel_operation.assert_not_called()

op.cancel_on_kill = True
op.on_kill()
mock_hook.return_value.get_operations_client.return_value.cancel_operation.assert_called_once_with(
name=operation_name
)


@pytest.mark.need_serialized_dag
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
Expand Down

0 comments on commit 0b49f33

Please sign in to comment.