Add deferrable mode to DataprocCreateClusterOperator and DataprocUpdateClusterOperator (#28529)

Co-authored-by: Beata Kossakowska <bkossakowska@google.com>
bkossakowska and Beata Kossakowska committed Jan 25, 2023
1 parent b8f15a9 commit 9fd8013
Showing 6 changed files with 515 additions and 22 deletions.
96 changes: 84 additions & 12 deletions airflow/providers/google/cloud/operators/dataproc.py
@@ -26,13 +26,13 @@
import uuid
import warnings
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Any, Sequence

from google.api_core import operation # type: ignore
from google.api_core.exceptions import AlreadyExists, NotFound
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.api_core.retry import Retry, exponential_sleep_generator
from google.cloud.dataproc_v1 import Batch, Cluster, JobStatus
from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus
from google.protobuf.duration_pb2 import Duration
from google.protobuf.field_mask_pb2 import FieldMask

@@ -50,7 +50,7 @@
DataprocLink,
DataprocListLink,
)
from airflow.providers.google.cloud.triggers.dataproc import DataprocBaseTrigger
from airflow.providers.google.cloud.triggers.dataproc import DataprocClusterTrigger, DataprocSubmitTrigger
from airflow.utils import timezone

if TYPE_CHECKING:
@@ -438,6 +438,8 @@ class DataprocCreateClusterOperator(BaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param deferrable: Run the operator in deferrable mode.
:param polling_interval_seconds: Time (seconds) to wait between calls to check the run status.
"""

template_fields: Sequence[str] = (
@@ -470,6 +472,8 @@ def __init__(
metadata: Sequence[tuple[str, str]] = (),
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = False,
polling_interval_seconds: int = 10,
**kwargs,
) -> None:

@@ -502,7 +506,8 @@ def __init__(
del kwargs[arg]

super().__init__(**kwargs)

if deferrable and polling_interval_seconds <= 0:
raise ValueError("Invalid value for polling_interval_seconds. Expected value greater than 0")
self.cluster_config = cluster_config
self.cluster_name = cluster_name
self.labels = labels
@@ -517,9 +522,11 @@ def __init__(
self.use_if_exists = use_if_exists
self.impersonation_chain = impersonation_chain
self.virtual_cluster_config = virtual_cluster_config
self.deferrable = deferrable
self.polling_interval_seconds = polling_interval_seconds

def _create_cluster(self, hook: DataprocHook):
operation = hook.create_cluster(
return hook.create_cluster(
project_id=self.project_id,
region=self.region,
cluster_name=self.cluster_name,
@@ -531,9 +538,6 @@ def _create_cluster(self, hook: DataprocHook):
timeout=self.timeout,
metadata=self.metadata,
)
cluster = operation.result()
self.log.info("Cluster created.")
return cluster

def _delete_cluster(self, hook):
self.log.info("Deleting the cluster")
@@ -596,7 +600,25 @@ def execute(self, context: Context) -> dict:
)
try:
# First try to create a new cluster
cluster = self._create_cluster(hook)
operation = self._create_cluster(hook)
if not self.deferrable:
cluster = hook.wait_for_operation(
timeout=self.timeout, result_retry=self.retry, operation=operation
)
self.log.info("Cluster created.")
return Cluster.to_dict(cluster)
else:
self.defer(
trigger=DataprocClusterTrigger(
cluster_name=self.cluster_name,
project_id=self.project_id,
region=self.region,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
polling_interval_seconds=self.polling_interval_seconds,
),
method_name="execute_complete",
)
except AlreadyExists:
if not self.use_if_exists:
raise
@@ -618,6 +640,21 @@ def execute(self, context: Context) -> dict:

return Cluster.to_dict(cluster)

def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
cluster_state = event["cluster_state"]
cluster_name = event["cluster_name"]

if cluster_state == ClusterStatus.State.ERROR:
raise AirflowException(f"Cluster is in ERROR state:\n{cluster_name}")

self.log.info("%s completed successfully.", self.task_id)
return event["cluster"]


class DataprocScaleClusterOperator(BaseOperator):
"""
@@ -974,7 +1011,7 @@ def execute(self, context: Context):

if self.deferrable:
self.defer(
trigger=DataprocBaseTrigger(
trigger=DataprocSubmitTrigger(
job_id=job_id,
project_id=self.project_id,
region=self.region,
@@ -1888,7 +1925,7 @@ def execute(self, context: Context):
self.job_id = new_job_id
if self.deferrable:
self.defer(
trigger=DataprocBaseTrigger(
trigger=DataprocSubmitTrigger(
job_id=self.job_id,
project_id=self.project_id,
region=self.region,
@@ -1964,6 +2001,8 @@ class DataprocUpdateClusterOperator(BaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param deferrable: Run the operator in deferrable mode.
:param polling_interval_seconds: Time (seconds) to wait between calls to check the run status.
"""

template_fields: Sequence[str] = (
@@ -1991,9 +2030,13 @@ def __init__(
metadata: Sequence[tuple[str, str]] = (),
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = False,
polling_interval_seconds: int = 10,
**kwargs,
):
super().__init__(**kwargs)
if deferrable and polling_interval_seconds <= 0:
raise ValueError("Invalid value for polling_interval_seconds. Expected value greater than 0")
self.project_id = project_id
self.region = region
self.cluster_name = cluster_name
@@ -2006,6 +2049,8 @@ def __init__(
self.metadata = metadata
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable
self.polling_interval_seconds = polling_interval_seconds

def execute(self, context: Context):
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
@@ -2026,9 +2071,36 @@ def execute(self, context: Context):
timeout=self.timeout,
metadata=self.metadata,
)
operation.result()

if not self.deferrable:
hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=operation)
else:
self.defer(
trigger=DataprocClusterTrigger(
cluster_name=self.cluster_name,
project_id=self.project_id,
region=self.region,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
polling_interval_seconds=self.polling_interval_seconds,
),
method_name="execute_complete",
)
self.log.info("Updated %s cluster.", self.cluster_name)

def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
cluster_state = event["cluster_state"]
cluster_name = event["cluster_name"]

if cluster_state == ClusterStatus.State.ERROR:
raise AirflowException(f"Cluster is in ERROR state:\n{cluster_name}")
self.log.info("%s completed successfully.", self.task_id)


class DataprocCreateBatchOperator(BaseOperator):
"""
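The DataprocUpdateClusterOperator changes above follow the same deferral pattern. A hedged sketch of resizing workers asynchronously (the mask, timeout, and identifiers are illustrative, not taken from this commit):

    from airflow.providers.google.cloud.operators.dataproc import DataprocUpdateClusterOperator

    # e.g. inside the same DAG as the create task above
    scale_workers = DataprocUpdateClusterOperator(
        task_id="scale_workers",
        project_id="my-project",            # placeholder
        region="europe-west1",              # placeholder
        cluster_name="example-cluster",     # placeholder
        cluster={"config": {"worker_config": {"num_instances": 4}}},
        update_mask={"paths": ["config.worker_config.num_instances"]},
        graceful_decommission_timeout={"seconds": 600},  # illustrative value
        deferrable=True,
        polling_interval_seconds=30,
    )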
68 changes: 64 additions & 4 deletions airflow/providers/google/cloud/triggers/dataproc.py
@@ -20,16 +20,16 @@

import asyncio
import warnings
from typing import Sequence
from typing import Any, AsyncIterator, Sequence

from google.cloud.dataproc_v1 import JobStatus
from google.cloud.dataproc_v1 import ClusterStatus, JobStatus

from airflow import AirflowException
from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook
from airflow.triggers.base import BaseTrigger, TriggerEvent


class DataprocBaseTrigger(BaseTrigger):
class DataprocSubmitTrigger(BaseTrigger):
"""
Trigger that periodically polls the Dataproc API to verify job status.
The implementation leverages asynchronous transport.
@@ -65,7 +65,7 @@ def __init__(

def serialize(self):
return (
"airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger",
"airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger",
{
"job_id": self.job_id,
"project_id": self.project_id,
@@ -89,3 +89,63 @@ async def run(self):
raise AirflowException(f"Dataproc job execution failed {self.job_id}")
await asyncio.sleep(self.polling_interval_seconds)
yield TriggerEvent({"job_id": self.job_id, "job_state": state})


class DataprocClusterTrigger(BaseTrigger):
"""
Trigger that periodically polls the Dataproc API to verify cluster status.
The implementation leverages asynchronous transport.
"""

def __init__(
self,
cluster_name: str,
region: str,
project_id: str | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
polling_interval_seconds: int = 10,
):
super().__init__()
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.cluster_name = cluster_name
self.project_id = project_id
self.region = region
self.polling_interval_seconds = polling_interval_seconds

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
"airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger",
{
"cluster_name": self.cluster_name,
"project_id": self.project_id,
"region": self.region,
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
"polling_interval_seconds": self.polling_interval_seconds,
},
)

async def run(self) -> AsyncIterator["TriggerEvent"]:
hook = self._get_hook()
while True:
cluster = await hook.get_cluster(
project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
)
state = cluster.status.state
self.log.info("Dataproc cluster: %s is in state: %s", self.cluster_name, state)
if state in (
ClusterStatus.State.ERROR,
ClusterStatus.State.RUNNING,
):
break
self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
await asyncio.sleep(self.polling_interval_seconds)
yield TriggerEvent({"cluster_name": self.cluster_name, "cluster_state": state, "cluster": cluster})

def _get_hook(self) -> DataprocAsyncHook:
return DataprocAsyncHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
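
For reference, a deferred task can outlive its worker because the trigger is rebuilt from its serialized form: Airflow stores the classpath and kwargs returned by serialize(), and the triggerer re-instantiates the trigger from them. A small sketch of that round trip, with placeholder names:

    from airflow.providers.google.cloud.triggers.dataproc import DataprocClusterTrigger

    trigger = DataprocClusterTrigger(
        cluster_name="example-cluster",     # placeholder
        project_id="my-project",            # placeholder
        region="europe-west1",              # placeholder
        polling_interval_seconds=10,
    )
    classpath, kwargs = trigger.serialize()
    assert classpath == "airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger"
    # The triggerer process re-creates an equivalent instance from the stored kwargs:
    rebuilt = DataprocClusterTrigger(**kwargs)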
16 changes: 16 additions & 0 deletions docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
@@ -75,6 +75,14 @@ With this configuration we can create the cluster:
:start-after: [START how_to_cloud_dataproc_create_cluster_operator_in_gke]
:end-before: [END how_to_cloud_dataproc_create_cluster_operator_in_gke]

You can use deferrable mode for this action to run the operator asynchronously:

.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_deferrable.py
:language: python
:dedent: 4
:start-after: [START how_to_cloud_dataproc_create_cluster_operator_async]
:end-before: [END how_to_cloud_dataproc_create_cluster_operator_async]

Generating Cluster Config
^^^^^^^^^^^^^^^^^^^^^^^^^
You can also generate **CLUSTER_CONFIG** using functional API,
@@ -111,6 +119,14 @@ To update a cluster you can use:
:start-after: [START how_to_cloud_dataproc_update_cluster_operator]
:end-before: [END how_to_cloud_dataproc_update_cluster_operator]

You can use deferrable mode for this action to run the operator asynchronously:

.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_deferrable.py
:language: python
:dedent: 4
:start-after: [START how_to_cloud_dataproc_update_cluster_operator_async]
:end-before: [END how_to_cloud_dataproc_update_cluster_operator_async]

Deleting a cluster
------------------

