Skip to content

Commit

Permalink
Add Deferrable option to EmrCreateJobFlowOperator (#31641)
Browse files Browse the repository at this point in the history
* Update documentation to include deferrable mode

* Add doc strings to unit tests

* Rebase onto main
Remove caching of hook in trigger
raise Exception directly from Trigger
  • Loading branch information
syedahsn committed Jun 14, 2023
1 parent 212a37f commit 6720456
Show file tree
Hide file tree
Showing 6 changed files with 340 additions and 9 deletions.
35 changes: 29 additions & 6 deletions airflow/providers/amazon/aws/operators/emr.py
Expand Up @@ -19,6 +19,7 @@

import ast
import warnings
from datetime import timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence
from uuid import uuid4
Expand All @@ -27,7 +28,7 @@
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri
from airflow.providers.amazon.aws.triggers.emr import EmrAddStepsTrigger
from airflow.providers.amazon.aws.triggers.emr import EmrAddStepsTrigger, EmrCreateJobFlowTrigger
from airflow.providers.amazon.aws.utils.waiter import waiter
from airflow.utils.helpers import exactly_one, prune_dict
from airflow.utils.types import NOTSET, ArgNotSet
Expand Down Expand Up @@ -624,6 +625,9 @@ class EmrCreateJobFlowOperator(BaseOperator):
wait_for_completion=True, None = no limit) (Deprecated. Please use waiter_max_attempts.)
:param waiter_check_interval_seconds: Number of seconds between polling the jobflow state. Defaults to 60
seconds. (Deprecated. Please use waiter_delay.)
:param deferrable: If True, the operator will wait asynchronously for the job flow to complete.
    This implies waiting for completion. This mode requires aiobotocore module to be installed.
    (default: False)
"""

template_fields: Sequence[str] = (
Expand Down Expand Up @@ -652,6 +656,7 @@ def __init__(
waiter_delay: int | None | ArgNotSet = NOTSET,
waiter_countdown: int | None = None,
waiter_check_interval_seconds: int = 60,
deferrable: bool = False,
**kwargs: Any,
):
if waiter_max_attempts is NOTSET:
Expand All @@ -676,10 +681,9 @@ def __init__(
self.job_flow_overrides = job_flow_overrides or {}
self.region_name = region_name
self.wait_for_completion = wait_for_completion
self.waiter_max_attempts = waiter_max_attempts
self.waiter_delay = waiter_delay

self._job_flow_id: str | None = None
self.waiter_max_attempts = int(waiter_max_attempts) # type: ignore[arg-type]
self.waiter_delay = int(waiter_delay) # type: ignore[arg-type]
self.deferrable = deferrable

@cached_property
def _emr_hook(self) -> EmrHook:
Expand Down Expand Up @@ -720,7 +724,19 @@ def execute(self, context: Context) -> str | None:
job_flow_id=self._job_flow_id,
log_uri=get_log_uri(emr_client=self._emr_hook.conn, job_flow_id=self._job_flow_id),
)

if self.deferrable:
self.defer(
trigger=EmrCreateJobFlowTrigger(
job_flow_id=self._job_flow_id,
aws_conn_id=self.aws_conn_id,
poll_interval=self.waiter_delay,
max_attempts=self.waiter_max_attempts,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
)
if self.wait_for_completion:
self._emr_hook.get_waiter("job_flow_waiting").wait(
ClusterId=self._job_flow_id,
Expand All @@ -734,6 +750,13 @@ def execute(self, context: Context) -> str | None:

return self._job_flow_id

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
    """
    Resume after the trigger fires and return the job flow id.

    :param context: The task context (unused).
    :param event: The payload yielded by ``EmrCreateJobFlowTrigger``. ``None``
        means the trigger died without yielding a ``TriggerEvent``.
    :raises AirflowException: If the trigger reported anything other than success.
    :return: The id of the created job flow.
    """
    # Guard against a missing event: the old ``event["status"]`` lookup raised a
    # bare TypeError when the trigger exited without yielding.
    if event is None or event.get("status") != "success":
        raise AirflowException(f"Error creating jobFlow: {event}")
    self.log.info("JobFlow created successfully")
    return event["job_flow_id"]

def on_kill(self) -> None:
"""
Terminate the EMR cluster (job flow). If TerminationProtected=True on the cluster,
Expand Down
76 changes: 76 additions & 0 deletions airflow/providers/amazon/aws/triggers/emr.py
Expand Up @@ -21,8 +21,10 @@

from botocore.exceptions import WaiterError

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr import EmrHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.helpers import prune_dict


class EmrAddStepsTrigger(BaseTrigger):
Expand Down Expand Up @@ -97,3 +99,77 @@ async def run(self):
yield TriggerEvent({"status": "failure", "message": "Steps failed: max attempts reached"})
else:
yield TriggerEvent({"status": "success", "message": "Steps completed", "step_ids": self.step_ids})


class EmrCreateJobFlowTrigger(BaseTrigger):
    """
    Trigger for EmrCreateJobFlowOperator.

    The trigger will asynchronously poll the boto3 API and wait for the
    JobFlow to finish executing.

    :param job_flow_id: The id of the job flow to wait for.
    :param poll_interval: The amount of time in seconds to wait between attempts.
    :param max_attempts: The maximum number of attempts to be made.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    """

    def __init__(
        self,
        job_flow_id: str,
        poll_interval: int,
        max_attempts: int,
        aws_conn_id: str,
    ):
        # BaseTrigger bookkeeping (logger, etc.) — the super() call was missing.
        super().__init__()
        self.job_flow_id = job_flow_id
        self.poll_interval = poll_interval
        self.max_attempts = max_attempts
        self.aws_conn_id = aws_conn_id

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Return the classpath and constructor kwargs needed to recreate this trigger."""
        # NOTE: the numeric fields are serialized as strings; run() coerces them
        # back with int(), so both str and int payloads deserialize correctly.
        return (
            self.__class__.__module__ + "." + self.__class__.__qualname__,
            {
                "job_flow_id": self.job_flow_id,
                "poll_interval": str(self.poll_interval),
                "max_attempts": str(self.max_attempts),
                "aws_conn_id": self.aws_conn_id,
            },
        )

    async def run(self):
        """
        Poll the job flow state until it reaches a settled state or attempts run out.

        Yields a success TriggerEvent when the waiter matches a success acceptor;
        raises AirflowException on a terminal failure or when max_attempts is
        exhausted without the waiter succeeding.
        """
        self.hook = EmrHook(aws_conn_id=self.aws_conn_id)
        async with self.hook.async_conn as client:
            waiter = self.hook.get_waiter("job_flow_waiting", deferrable=True, client=client)
            attempt = 0
            while attempt < int(self.max_attempts):
                attempt += 1
                try:
                    await waiter.wait(
                        ClusterId=self.job_flow_id,
                        WaiterConfig=prune_dict(
                            {
                                # One waiter attempt per loop iteration so we control
                                # the sleep (asyncio.sleep) and the attempt counting.
                                "Delay": self.poll_interval,
                                "MaxAttempts": 1,
                            }
                        ),
                    )
                    # Success must be reported here, inside the loop: the old
                    # ``break`` + post-loop ``attempt >= max_attempts`` check raised
                    # "max attempts reached" even when the waiter succeeded on the
                    # final attempt.
                    yield TriggerEvent(
                        {
                            "status": "success",
                            "message": "JobFlow completed successfully",
                            "job_flow_id": self.job_flow_id,
                        }
                    )
                    return
                except WaiterError as error:
                    if "terminal failure" in str(error):
                        raise AirflowException(f"JobFlow creation failed: {error}")
                    self.log.info(
                        "Status of jobflow is %s - %s",
                        error.last_response["Cluster"]["Status"]["State"],
                        error.last_response["Cluster"]["Status"]["StateChangeReason"],
                    )
                    await asyncio.sleep(int(self.poll_interval))
            raise AirflowException(f"JobFlow creation failed - max attempts reached: {self.max_attempts}")
6 changes: 3 additions & 3 deletions airflow/providers/amazon/aws/waiters/emr.json
Expand Up @@ -58,19 +58,19 @@
"acceptors": [
{
"matcher": "path",
"argument": "cluster.status",
"argument": "Cluster.Status.State",
"expected": "WAITING",
"state": "success"
},
{
"matcher": "path",
"argument": "cluster.status",
"argument": "Cluster.Status.State",
"expected": "TERMINATED",
"state": "success"
},
{
"matcher": "path",
"argument": "cluster.status",
"argument": "Cluster.Status.State",
"expected": "TERMINATED_WITH_ERRORS",
"state": "failure"
}
Expand Down
4 changes: 4 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/emr/emr.rst
Expand Up @@ -47,6 +47,10 @@ Create an EMR job flow

You can use :class:`~airflow.providers.amazon.aws.operators.emr.EmrCreateJobFlowOperator` to
create a new EMR job flow. The cluster will be terminated automatically after finishing the steps.
This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter.
Using ``deferrable`` mode will release worker slots and lead to more efficient utilization of
resources within the Airflow cluster. However, this mode requires the Airflow triggerer to be
available in your deployment.

JobFlow configuration
"""""""""""""""""""""
Expand Down
28 changes: 28 additions & 0 deletions tests/providers/amazon/aws/operators/test_emr_create_job_flow.py
Expand Up @@ -22,12 +22,15 @@
from unittest import mock
from unittest.mock import MagicMock, patch

import pytest
from botocore.waiter import Waiter
from jinja2 import StrictUndefined

from airflow.exceptions import TaskDeferred
from airflow.models import DAG, DagRun, TaskInstance
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator
from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger
from airflow.utils import timezone
from tests.providers.amazon.aws.utils.test_waiter import assert_expected_waiter_type
from tests.test_utils import AIRFLOW_MAIN_FOLDER
Expand Down Expand Up @@ -192,3 +195,28 @@ def test_execute_with_wait(self, mock_waiter, *_):
assert self.operator.execute(self.mock_context) == JOB_FLOW_ID
mock_waiter.assert_called_once_with(mock.ANY, ClusterId=JOB_FLOW_ID, WaiterConfig=mock.ANY)
assert_expected_waiter_type(mock_waiter, "job_flow_waiting")

@patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
def test_create_job_flow_deferrable(self, _):
    """
    Test to make sure that the operator raises a TaskDeferred exception
    if run in deferrable mode.
    """
    self.emr_client_mock.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN

    # Wire up a fake boto3 session whose client() hands back the EMR client mock.
    session_mock = MagicMock()
    session_mock.client.return_value = self.emr_client_mock
    mocked_boto3_session = MagicMock(return_value=session_mock)

    self.operator.deferrable = True
    with patch("boto3.session.Session", mocked_boto3_session):
        with patch("airflow.providers.amazon.aws.hooks.base_aws.isinstance") as isinstance_mock:
            isinstance_mock.return_value = True
            with pytest.raises(TaskDeferred) as deferred:
                self.operator.execute(self.mock_context)

    assert isinstance(
        deferred.value.trigger, EmrCreateJobFlowTrigger
    ), "Trigger is not a EmrCreateJobFlowTrigger"

0 comments on commit 6720456

Please sign in to comment.