Dataproc submit job operator async (#25302)
bjankie1 committed Aug 22, 2022
1 parent bc04c5f commit ecf0460
Showing 7 changed files with 1,477 additions and 4 deletions.
746 changes: 746 additions & 0 deletions airflow/providers/google/cloud/hooks/dataproc.py

Large diffs are not rendered by default.
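
The 746 new lines in hooks/dataproc.py are not rendered here; they add the asynchronous DataprocAsyncHook that the new trigger further down imports. A minimal sketch of how that hook is consumed, assuming a configured google_cloud_default connection and placeholder project/region/job IDs — the get_job call is the one the trigger actually makes, everything else is illustrative:

import asyncio

from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook


async def poll_once():
    # DataprocAsyncHook exposes awaitable calls; get_job is what DataprocBaseTrigger uses.
    hook = DataprocAsyncHook(gcp_conn_id="google_cloud_default")
    # Placeholder identifiers; none of these values come from this commit.
    job = await hook.get_job(project_id="my-project", region="us-central1", job_id="job-123")
    print(job.status.state)


if __name__ == "__main__":
    asyncio.run(poll_once())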

76 changes: 74 additions & 2 deletions airflow/providers/google/cloud/operators/dataproc.py
@@ -32,7 +32,7 @@
from google.api_core.exceptions import AlreadyExists, NotFound
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.api_core.retry import Retry, exponential_sleep_generator
from google.cloud.dataproc_v1 import Batch, Cluster
from google.cloud.dataproc_v1 import Batch, Cluster, JobStatus
from google.protobuf.duration_pb2 import Duration
from google.protobuf.field_mask_pb2 import FieldMask

@@ -50,6 +50,7 @@
DataprocLink,
DataprocListLink,
)
from airflow.providers.google.cloud.triggers.dataproc import DataprocBaseTrigger
from airflow.utils import timezone

if TYPE_CHECKING:
@@ -867,6 +868,9 @@ class DataprocJobBaseOperator(BaseOperator):
:param asynchronous: Flag to return after submitting the job to the Dataproc API.
This is useful for submitting long running jobs and
waiting on them asynchronously using the DataprocJobSensor
:param deferrable: Run operator in the deferrable mode
:param polling_interval_seconds: time in seconds between polling for job completion.
The value is considered only when running in deferrable mode. Must be greater than 0.
:var dataproc_job_id: The actual "jobId" as submitted to the Dataproc API.
This is useful for identifying or linking to the job in the Google Cloud Console
@@ -894,9 +898,13 @@ def __init__(
        job_error_states: Optional[Set[str]] = None,
        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
        asynchronous: bool = False,
        deferrable: bool = False,
        polling_interval_seconds: int = 10,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        if deferrable and polling_interval_seconds <= 0:
            raise ValueError("Invalid value for polling_interval_seconds. Expected value greater than 0")
        self.gcp_conn_id = gcp_conn_id
        self.delegate_to = delegate_to
        self.labels = labels
@@ -914,6 +922,8 @@ def __init__(
        self.job: Optional[dict] = None
        self.dataproc_job_id = None
        self.asynchronous = asynchronous
        self.deferrable = deferrable
        self.polling_interval_seconds = polling_interval_seconds

    def create_job_template(self) -> DataProcJobBuilder:
        """Initialize `self.job_template` with default values"""
@@ -958,6 +968,19 @@ def execute(self, context: 'Context'):
                context=context, task_instance=self, url=DATAPROC_JOB_LOG_LINK, resource=job_id
            )

            if self.deferrable:
                self.defer(
                    trigger=DataprocBaseTrigger(
                        job_id=job_id,
                        project_id=self.project_id,
                        region=self.region,
                        delegate_to=self.delegate_to,
                        gcp_conn_id=self.gcp_conn_id,
                        impersonation_chain=self.impersonation_chain,
                        polling_interval_seconds=self.polling_interval_seconds,
                    ),
                    method_name="execute_complete",
                )
            if not self.asynchronous:
                self.log.info('Waiting for job %s to complete', job_id)
                self.hook.wait_for_job(job_id=job_id, region=self.region, project_id=self.project_id)
@@ -966,6 +989,20 @@ def execute(self, context: 'Context'):
        else:
            raise AirflowException("Create a job template before")

    def execute_complete(self, context, event=None) -> None:
        """
        Callback for when the trigger fires - returns immediately.
        Relies on trigger to throw an exception, otherwise it assumes execution was
        successful.
        """
        job_state = event["job_state"]
        job_id = event["job_id"]
        if job_state == JobStatus.State.ERROR:
            raise AirflowException(f'Job failed:\n{job_id}')
        if job_state == JobStatus.State.CANCELLED:
            raise AirflowException(f'Job was cancelled:\n{job_id}')
        self.log.info("%s completed successfully.", self.task_id)

    def on_kill(self) -> None:
        """
        Callback called when the operator is killed.
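
A note on the handshake wired up above: self.defer(...) suspends the task and hands DataprocBaseTrigger to the triggerer process; when the trigger yields a TriggerEvent, Airflow resumes the task by calling the method named in method_name with the event payload. A small sketch of the payload shape execute_complete expects, using placeholder values — only the dict keys and the JobStatus states come from this commit:

from google.cloud.dataproc_v1 import JobStatus

# DataprocBaseTrigger (added below) yields TriggerEvent({"job_id": ..., "job_state": ...}).
sample_event = {"job_id": "job-123", "job_state": JobStatus.State.DONE}  # placeholder values

# execute_complete() treats ERROR and CANCELLED as failures and anything else as success:
if sample_event["job_state"] in (JobStatus.State.ERROR, JobStatus.State.CANCELLED):
    raise RuntimeError(f"Dataproc job {sample_event['job_id']} did not succeed")
print(f"{sample_event['job_id']} completed successfully")
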
@@ -1771,6 +1808,9 @@ class DataprocSubmitJobOperator(BaseOperator):
:param asynchronous: Flag to return after submitting the job to the Dataproc API.
This is useful for submitting long running jobs and
waiting on them asynchronously using the DataprocJobSensor
:param deferrable: Run operator in the deferrable mode
:param polling_interval_seconds: time in seconds between polling for job completion.
The value is considered only when running in deferrable mode. Must be greater than 0.
:param cancel_on_kill: Flag which indicates whether to cancel the hook's job when on_kill is called
:param wait_timeout: How many seconds wait for job to be ready. Used only if ``asynchronous`` is False
"""
@@ -1793,11 +1833,15 @@ def __init__(
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
asynchronous: bool = False,
deferrable: bool = False,
polling_interval_seconds: int = 10,
cancel_on_kill: bool = True,
wait_timeout: Optional[int] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
if deferrable and polling_interval_seconds <= 0:
raise ValueError("Invalid value for polling_interval_seconds. Expected value greater than 0")
self.project_id = project_id
self.region = region
self.job = job
@@ -1808,6 +1852,8 @@ def __init__(
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain
        self.asynchronous = asynchronous
        self.deferrable = deferrable
        self.polling_interval_seconds = polling_interval_seconds
        self.cancel_on_kill = cancel_on_kill
        self.hook: Optional[DataprocHook] = None
        self.job_id: Optional[str] = None
@@ -1833,7 +1879,19 @@ def execute(self, context: 'Context'):
)

self.job_id = new_job_id
if not self.asynchronous:
if self.deferrable:
self.defer(
trigger=DataprocBaseTrigger(
job_id=self.job_id,
project_id=self.project_id,
region=self.region,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
polling_interval_seconds=self.polling_interval_seconds,
),
method_name="execute_complete",
)
elif not self.asynchronous:
self.log.info('Waiting for job %s to complete', new_job_id)
self.hook.wait_for_job(
job_id=new_job_id, region=self.region, project_id=self.project_id, timeout=self.wait_timeout
@@ -1842,6 +1900,20 @@ def execute(self, context: 'Context'):

        return self.job_id

    def execute_complete(self, context, event=None) -> None:
        """
        Callback for when the trigger fires - returns immediately.
        Relies on trigger to throw an exception, otherwise it assumes execution was
        successful.
        """
        job_state = event["job_state"]
        job_id = event["job_id"]
        if job_state == JobStatus.State.ERROR:
            raise AirflowException(f'Job failed:\n{job_id}')
        if job_state == JobStatus.State.CANCELLED:
            raise AirflowException(f'Job was cancelled:\n{job_id}')
        self.log.info("%s completed successfully.", self.task_id)

    def on_kill(self):
        if self.job_id and self.cancel_on_kill:
            self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id, region=self.region)
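
Worth noting how the new flag relates to the pre-existing asynchronous option visible above: asynchronous=True returns right after submission and leaves the completion check to a separate DataprocJobSensor, while deferrable=True frees the worker slot and resumes the same task from the trigger event. A rough sketch of the older sensor-based pattern for comparison, with placeholder DAG, project, region, and job values (none of them come from this commit):

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataproc import DataprocSubmitJobOperator
from airflow.providers.google.cloud.sensors.dataproc import DataprocJobSensor

# Placeholder job config and identifiers.
SPARK_JOB = {
    "placement": {"cluster_name": "my-cluster"},
    "spark_job": {
        "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
        "main_class": "org.apache.spark.examples.SparkPi",
    },
}

with DAG("dataproc_async_vs_deferrable_demo", start_date=datetime(2022, 8, 1), schedule_interval=None):
    submit = DataprocSubmitJobOperator(
        task_id="submit_spark",
        job=SPARK_JOB,
        region="us-central1",
        project_id="my-project",
        asynchronous=True,  # return right after submission
    )
    wait = DataprocJobSensor(
        task_id="wait_for_spark",
        region="us-central1",
        project_id="my-project",
        # execute() returns the job id (see `return self.job_id` above), so it is available via XCom.
        dataproc_job_id="{{ ti.xcom_pull(task_ids='submit_spark') }}",
    )
    submit >> wait
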
86 changes: 86 additions & 0 deletions airflow/providers/google/cloud/triggers/dataproc.py
@@ -0,0 +1,86 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
"""This module contains Google Dataproc triggers."""

import asyncio
from typing import Optional, Sequence, Union

from google.cloud.dataproc_v1 import JobStatus

from airflow import AirflowException
from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook
from airflow.triggers.base import BaseTrigger, TriggerEvent


class DataprocBaseTrigger(BaseTrigger):
    """
    Trigger that periodically polls information from Dataproc API to verify job status.
    Implementation leverages asynchronous transport.
    """

    def __init__(
        self,
        job_id: str,
        region: str,
        project_id: Optional[str] = None,
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
        delegate_to: Optional[str] = None,
        polling_interval_seconds: int = 30,
    ):
        super().__init__()
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain
        self.job_id = job_id
        self.project_id = project_id
        self.region = region
        self.polling_interval_seconds = polling_interval_seconds
        self.delegate_to = delegate_to
        self.hook = DataprocAsyncHook(
            delegate_to=self.delegate_to,
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )

    def serialize(self):
        return (
            "airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger",
            {
                "job_id": self.job_id,
                "project_id": self.project_id,
                "region": self.region,
                "gcp_conn_id": self.gcp_conn_id,
                "delegate_to": self.delegate_to,
                "impersonation_chain": self.impersonation_chain,
                "polling_interval_seconds": self.polling_interval_seconds,
            },
        )

    async def run(self):
        while True:
            job = await self.hook.get_job(project_id=self.project_id, region=self.region, job_id=self.job_id)
            state = job.status.state
            self.log.info("Dataproc job: %s is in state: %s", self.job_id, state)
            if state in (JobStatus.State.ERROR, JobStatus.State.DONE, JobStatus.State.CANCELLED):
                if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED):
                    break
                elif state == JobStatus.State.ERROR:
                    raise AirflowException(f"Dataproc job execution failed {self.job_id}")
            await asyncio.sleep(self.polling_interval_seconds)
        yield TriggerEvent({"job_id": self.job_id, "job_state": state})
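
For readers new to triggers: serialize() must return the importable class path plus the keyword arguments needed to rebuild the trigger inside the triggerer process, and run() is an async generator whose yielded TriggerEvent wakes the deferred task. A simplified round-trip sketch of what the triggerer does with that pair (placeholder IDs; this is not the actual triggerer code):

from importlib import import_module

from airflow.providers.google.cloud.triggers.dataproc import DataprocBaseTrigger

# Placeholder identifiers, not values from this commit.
trigger = DataprocBaseTrigger(job_id="job-123", region="us-central1", project_id="my-project")
classpath, kwargs = trigger.serialize()

# The triggerer re-imports the class from the serialized path and rebuilds the trigger,
# which is why every __init__ argument must round-trip through the kwargs dict above.
module_name, class_name = classpath.rsplit(".", 1)
rebuilt = getattr(import_module(module_name), class_name)(**kwargs)
assert isinstance(rebuilt, DataprocBaseTrigger)
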
@@ -174,6 +174,14 @@ Example of the configuration for a Spark Job:
:start-after: [START how_to_cloud_dataproc_spark_config]
:end-before: [END how_to_cloud_dataproc_spark_config]

Example of the configuration for a Spark Job running in `deferrable mode <https://airflow.apache.org/docs/apache-airflow/stable/concepts/deferring.html>`__:

.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_spark_deferrable.py
:language: python
:dedent: 0
:start-after: [START how_to_cloud_dataproc_spark_deferrable_config]
:end-before: [END how_to_cloud_dataproc_spark_deferrable_config]
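
The referenced example_dataproc_spark_deferrable.py is not part of the rendered diff. A minimal sketch of the kind of DAG the directive pulls in, with placeholder project, region, cluster, and jar values; the only substantive difference from the synchronous Spark example is the deferrable flag:

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataproc import DataprocSubmitJobOperator

# Placeholder values; the committed example file may differ.
PROJECT_ID = "my-project"
REGION = "us-central1"
SPARK_JOB = {
    "reference": {"project_id": PROJECT_ID},
    "placement": {"cluster_name": "my-cluster"},
    "spark_job": {
        "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
        "main_class": "org.apache.spark.examples.SparkPi",
    },
}

with DAG("dataproc_spark_deferrable_demo", start_date=datetime(2022, 8, 1), schedule_interval=None):
    spark_task = DataprocSubmitJobOperator(
        task_id="spark_task",
        job=SPARK_JOB,
        region=REGION,
        project_id=PROJECT_ID,
        deferrable=True,
        polling_interval_seconds=30,  # must be > 0 when deferrable=True
    )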

Example of the configuration for a Hive Job:

.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_hive.py