Add deferrable mode to DataprocInstantiateInlineWorkflowTemplateOperator (#30878)

Co-authored-by: Beata Kossakowska <bkossakowska@google.com>
bkossakowska and Beata Kossakowska committed Apr 27, 2023
1 parent e5d304a commit 0d95ace
Showing 7 changed files with 186 additions and 29 deletions.
42 changes: 37 additions & 5 deletions airflow/providers/google/cloud/operators/dataproc.py
@@ -1780,7 +1780,6 @@ def execute(self, context: Context):
else:
self.defer(
trigger=DataprocWorkflowTrigger(
template_name=self.template_id,
name=operation.operation.name,
project_id=self.project_id,
region=self.region,
@@ -1843,6 +1842,8 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param deferrable: Run operator in the deferrable mode.
:param polling_interval_seconds: Time (seconds) to wait between calls to check the run status.
"""

template_fields: Sequence[str] = ("template", "impersonation_chain")
@@ -1861,9 +1862,13 @@ def __init__(
metadata: Sequence[tuple[str, str]] = (),
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = False,
polling_interval_seconds: int = 10,
**kwargs,
) -> None:
super().__init__(**kwargs)
if deferrable and polling_interval_seconds <= 0:
raise ValueError("Invalid value for polling_interval_seconds. Expected value greater than 0")
self.template = template
self.project_id = project_id
self.region = region
@@ -1874,13 +1879,15 @@ def __init__(
self.metadata = metadata
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable
self.polling_interval_seconds = polling_interval_seconds

def execute(self, context: Context):
self.log.info("Instantiating Inline Template")
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
operation = hook.instantiate_inline_workflow_template(
template=self.template,
project_id=self.project_id,
project_id=self.project_id or hook.project_id,
region=self.region,
request_id=self.request_id,
retry=self.retry,
@@ -1891,9 +1898,34 @@ def execute(self, context: Context):
DataprocLink.persist(
context=context, task_instance=self, url=DATAPROC_WORKFLOW_LINK, resource=self.workflow_id
)
self.log.info("Template instantiated. Workflow Id : %s", self.workflow_id)
operation.result()
self.log.info("Workflow %s completed successfully", self.workflow_id)
if not self.deferrable:
self.log.info("Template instantiated. Workflow Id : %s", self.workflow_id)
operation.result()
self.log.info("Workflow %s completed successfully", self.workflow_id)
else:
self.defer(
trigger=DataprocWorkflowTrigger(
name=operation.operation.name,
project_id=self.project_id or hook.project_id,
region=self.region,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
polling_interval_seconds=self.polling_interval_seconds,
),
method_name="execute_complete",
)

def execute_complete(self, context, event=None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event["status"] == "failed" or event["status"] == "error":
self.log.exception("Unexpected error in the operation.")
raise AirflowException(event["message"])

self.log.info("Workflow %s completed successfully", event["operation_name"])


class DataprocSubmitJobOperator(GoogleCloudBaseOperator):
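For orientation, here is a minimal usage sketch of the deferrable path added above: with deferrable=True the operator submits the template and then defers to DataprocWorkflowTrigger instead of blocking on operation.result(). All values (project, region, template contents, task id) are illustrative placeholders, not part of this commit.

from airflow.providers.google.cloud.operators.dataproc import (
    DataprocInstantiateInlineWorkflowTemplateOperator,
)

# Trimmed-down workflow template, used only for illustration.
WORKFLOW_TEMPLATE = {
    "id": "example-workflow",
    "placement": {"managed_cluster": {"cluster_name": "example-cluster", "config": {}}},
    "jobs": [{"step_id": "pig_job_1", "pig_job": {"query_list": {"queries": ["define sin HiveUDF('sin');"]}}}],
}

instantiate_inline_async = DataprocInstantiateInlineWorkflowTemplateOperator(
    task_id="instantiate_inline_workflow_template_async",
    template=WORKFLOW_TEMPLATE,
    region="europe-west1",
    project_id="my-gcp-project",   # hypothetical project id
    deferrable=True,               # hand polling off to the triggerer
    polling_interval_seconds=10,   # must be > 0 when deferrable=True
)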
4 changes: 1 addition & 3 deletions airflow/providers/google/cloud/triggers/dataproc.py
@@ -290,16 +290,14 @@ class DataprocWorkflowTrigger(DataprocBaseTrigger):
Implementation leverages asynchronous transport.
"""

def __init__(self, template_name: str, name: str, **kwargs: Any):
def __init__(self, name: str, **kwargs: Any):
super().__init__(**kwargs)
self.template_name = template_name
self.name = name

def serialize(self):
return (
"airflow.providers.google.cloud.triggers.dataproc.DataprocWorkflowTrigger",
{
"template_name": self.template_name,
"name": self.name,
"project_id": self.project_id,
"region": self.region,
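Since template_name is dropped from the trigger, its constructor and serialized kwargs now carry only the operation name plus the shared base-trigger fields. A small sketch, with illustrative argument values, of how the operator-side defer call builds the trigger after this change:

from airflow.providers.google.cloud.triggers.dataproc import DataprocWorkflowTrigger

trigger = DataprocWorkflowTrigger(
    name="projects/my-project/regions/europe-west1/operations/op-123",  # operation name (illustrative)
    project_id="my-project",
    region="europe-west1",
    gcp_conn_id="google_cloud_default",
    impersonation_chain=None,
    polling_interval_seconds=10,
)

classpath, kwargs = trigger.serialize()
# kwargs now contains name, project_id, region and the remaining base-trigger
# fields -- there is no longer a "template_name" key.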
@@ -264,7 +264,7 @@ Once a workflow is created users can trigger it using

You can also run this operator in deferrable mode:

.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_workflow.py
.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_workflow_deferrable.py
:language: python
:dedent: 4
:start-after: [START how_to_cloud_dataproc_trigger_workflow_template_async]
@@ -279,6 +279,15 @@ The inline operator is an alternative. It creates a workflow, run it, and delete
:start-after: [START how_to_cloud_dataproc_instantiate_inline_workflow_template]
:end-before: [END how_to_cloud_dataproc_instantiate_inline_workflow_template]

You can also run this operator in deferrable mode:

.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_workflow_deferrable.py
:language: python
:dedent: 4
:start-after: [START how_to_cloud_dataproc_instantiate_inline_workflow_template_async]
:end-before: [END how_to_cloud_dataproc_instantiate_inline_workflow_template_async]


Create a Batch
--------------

27 changes: 27 additions & 0 deletions tests/providers/google/cloud/operators/test_dataproc.py
@@ -1464,6 +1464,33 @@ def test_execute(self, mock_hook):
metadata=METADATA,
)

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
def test_execute_call_defer_method(self, mock_trigger_hook, mock_hook):
operator = DataprocInstantiateInlineWorkflowTemplateOperator(
task_id=TASK_ID,
template={},
region=GCP_REGION,
project_id=GCP_PROJECT,
request_id=REQUEST_ID,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
deferrable=True,
)

with pytest.raises(TaskDeferred) as exc:
operator.execute(mock.MagicMock())

mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)

mock_hook.return_value.instantiate_inline_workflow_template.assert_called_once()

assert isinstance(exc.value.trigger, DataprocWorkflowTrigger)
assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME


@pytest.mark.need_serialized_dag
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
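Not part of this commit, but a possible companion test sketch for the new execute_complete handler, which raises AirflowException on a failed/error event and otherwise logs success. It assumes the surrounding module's constants (TASK_ID, GCP_REGION, GCP_PROJECT) and imports (pytest, mock, AirflowException) are available:

@pytest.mark.parametrize("status", ["failed", "error"])
def test_execute_complete_raises_on_error_event(status):
    operator = DataprocInstantiateInlineWorkflowTemplateOperator(
        task_id=TASK_ID,
        template={},
        region=GCP_REGION,
        project_id=GCP_PROJECT,
    )
    # An event shaped like the one emitted by DataprocWorkflowTrigger on failure.
    event = {"status": status, "message": "Workflow failed", "operation_name": "op-1"}
    with pytest.raises(AirflowException):
        operator.execute_complete(context=mock.MagicMock(), event=event)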
3 changes: 0 additions & 3 deletions tests/providers/google/cloud/triggers/test_dataproc.py
@@ -44,7 +44,6 @@
TEST_CLUSTER_NAME = "cluster_name"
TEST_POLL_INTERVAL = 5
TEST_GCP_CONN_ID = "google_cloud_default"
TEST_TEMPLATE_NAME = "template_name"
TEST_OPERATION_NAME = "name"


@@ -76,7 +75,6 @@ def batch_trigger():
@pytest.fixture
def workflow_trigger():
return DataprocWorkflowTrigger(
template_name=TEST_TEMPLATE_NAME,
name=TEST_OPERATION_NAME,
project_id=TEST_PROJECT_ID,
region=TEST_REGION,
@@ -291,7 +289,6 @@ def test_async_cluster_trigger_serialization_should_execute_successfully(self, w
classpath, kwargs = workflow_trigger.serialize()
assert classpath == "airflow.providers.google.cloud.triggers.dataproc.DataprocWorkflowTrigger"
assert kwargs == {
"template_name": TEST_TEMPLATE_NAME,
"name": TEST_OPERATION_NAME,
"project_id": TEST_PROJECT_ID,
"region": TEST_REGION,
19 changes: 2 additions & 17 deletions tests/system/providers/google/cloud/dataproc/example_dataproc_workflow.py
@@ -31,7 +31,7 @@

ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
DAG_ID = "dataproc_workflow"
PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "")
PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")

REGION = "europe-west1"
CLUSTER_NAME = f"cluster-dataproc-workflow-{ENV_ID}"
@@ -83,28 +83,13 @@
)
# [END how_to_cloud_dataproc_trigger_workflow_template]

# [START how_to_cloud_dataproc_trigger_workflow_template_async]
trigger_workflow_async = DataprocInstantiateWorkflowTemplateOperator(
task_id="trigger_workflow_async",
region=REGION,
project_id=PROJECT_ID,
template_id=WORKFLOW_NAME,
deferrable=True,
)
# [END how_to_cloud_dataproc_trigger_workflow_template_async]

# [START how_to_cloud_dataproc_instantiate_inline_workflow_template]
instantiate_inline_workflow_template = DataprocInstantiateInlineWorkflowTemplateOperator(
task_id="instantiate_inline_workflow_template", template=WORKFLOW_TEMPLATE, region=REGION
)
# [END how_to_cloud_dataproc_instantiate_inline_workflow_template]

(
create_workflow_template
>> trigger_workflow
>> instantiate_inline_workflow_template
>> trigger_workflow_async
)
(create_workflow_template >> trigger_workflow >> instantiate_inline_workflow_template)

from tests.system.utils.watcher import watcher

109 changes: 109 additions & 0 deletions tests/system/providers/google/cloud/dataproc/example_dataproc_workflow_deferrable.py
@@ -0,0 +1,109 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Example Airflow DAG for Dataproc workflow operators.
"""
from __future__ import annotations

import os
from datetime import datetime

from airflow import models
from airflow.providers.google.cloud.operators.dataproc import (
DataprocCreateWorkflowTemplateOperator,
DataprocInstantiateInlineWorkflowTemplateOperator,
DataprocInstantiateWorkflowTemplateOperator,
)

ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
DAG_ID = "dataproc_workflow"
PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")

REGION = "europe-west1"
CLUSTER_NAME = f"cluster-dataproc-workflow-{ENV_ID}"
CLUSTER_CONFIG = {
"master_config": {
"num_instances": 1,
"machine_type_uri": "n1-standard-4",
"disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 1024},
},
"worker_config": {
"num_instances": 2,
"machine_type_uri": "n1-standard-4",
"disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 1024},
},
}
PIG_JOB = {"query_list": {"queries": ["define sin HiveUDF('sin');"]}}
WORKFLOW_NAME = "airflow-dataproc-test"
WORKFLOW_TEMPLATE = {
"id": WORKFLOW_NAME,
"placement": {
"managed_cluster": {
"cluster_name": CLUSTER_NAME,
"config": CLUSTER_CONFIG,
}
},
"jobs": [{"step_id": "pig_job_1", "pig_job": PIG_JOB}],
}


with models.DAG(
DAG_ID,
schedule="@once",
start_date=datetime(2021, 1, 1),
catchup=False,
tags=["example", "dataproc"],
) as dag:
create_workflow_template = DataprocCreateWorkflowTemplateOperator(
task_id="create_workflow_template",
template=WORKFLOW_TEMPLATE,
project_id=PROJECT_ID,
region=REGION,
)

# [START how_to_cloud_dataproc_trigger_workflow_template_async]
trigger_workflow_async = DataprocInstantiateWorkflowTemplateOperator(
task_id="trigger_workflow_async",
region=REGION,
project_id=PROJECT_ID,
template_id=WORKFLOW_NAME,
deferrable=True,
)
# [END how_to_cloud_dataproc_trigger_workflow_template_async]

# [START how_to_cloud_dataproc_instantiate_inline_workflow_template_async]
instantiate_inline_workflow_template_async = DataprocInstantiateInlineWorkflowTemplateOperator(
task_id="instantiate_inline_workflow_template_async",
template=WORKFLOW_TEMPLATE,
region=REGION,
deferrable=True,
)
# [END how_to_cloud_dataproc_instantiate_inline_workflow_template_async]

(create_workflow_template >> trigger_workflow_async >> instantiate_inline_workflow_template_async)

from tests.system.utils.watcher import watcher

# This test needs watcher in order to properly mark success/failure
# when "teardown" task with trigger rule is part of the DAG
list(dag.tasks) >> watcher()


from tests.system.utils import get_test_run # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)
