Skip to content

Commit

Permalink
Add deferrable mode in EMR operator and sensor (#32029)
Browse files Browse the repository at this point in the history
* Add deferrable mode in EMR operator and sensor

Add the deferrable param in EmrContainerOperator/EmrStepSensor/EmrJobFlowSensor.
This allows running the operator/sensors asynchronously:
the worker only submits the job and then defers to the trigger,
which polls and waits for the job status,
so the worker slot is not occupied for the whole period of task execution.
  • Loading branch information
pankajastro committed Jun 27, 2023
1 parent 3a85d4e commit 06b5a1e
Show file tree
Hide file tree
Showing 10 changed files with 382 additions and 30 deletions.
48 changes: 41 additions & 7 deletions airflow/providers/amazon/aws/operators/emr.py
Expand Up @@ -30,6 +30,7 @@
from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri
from airflow.providers.amazon.aws.triggers.emr import (
EmrAddStepsTrigger,
EmrContainerTrigger,
EmrCreateJobFlowTrigger,
EmrTerminateJobFlowTrigger,
)
Expand Down Expand Up @@ -480,6 +481,7 @@ class EmrContainerOperator(BaseOperator):
Defaults to None, which will poll until the job is *not* in a pending, submitted, or running state.
:param tags: The tags assigned to job runs.
Defaults to None
:param deferrable: Run operator in the deferrable mode.
"""

template_fields: Sequence[str] = (
Expand Down Expand Up @@ -508,6 +510,7 @@ def __init__(
max_tries: int | None = None,
tags: dict | None = None,
max_polling_attempts: int | None = None,
deferrable: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand All @@ -524,6 +527,7 @@ def __init__(
self.max_polling_attempts = max_polling_attempts
self.tags = tags
self.job_id: str | None = None
self.deferrable = deferrable

if max_tries:
warnings.warn(
Expand Down Expand Up @@ -556,27 +560,57 @@ def execute(self, context: Context) -> str | None:
self.client_request_token,
self.tags,
)
if self.deferrable:
query_status = self.hook.check_query_status(job_id=self.job_id)
self.check_failure(query_status)
if query_status in EmrContainerHook.SUCCESS_STATES:
return self.job_id
timeout = (
timedelta(seconds=self.max_polling_attempts * self.poll_interval)
if self.max_polling_attempts
else self.execution_timeout
)
self.defer(
timeout=timeout,
trigger=EmrContainerTrigger(
virtual_cluster_id=self.virtual_cluster_id,
job_id=self.job_id,
aws_conn_id=self.aws_conn_id,
poll_interval=self.poll_interval,
),
method_name="execute_complete",
)
if self.wait_for_completion:
query_status = self.hook.poll_query_status(
self.job_id,
max_polling_attempts=self.max_polling_attempts,
poll_interval=self.poll_interval,
)

if query_status in EmrContainerHook.FAILURE_STATES:
error_message = self.hook.get_job_failure_reason(self.job_id)
raise AirflowException(
f"EMR Containers job failed. Final state is {query_status}. "
f"query_execution_id is {self.job_id}. Error: {error_message}"
)
elif not query_status or query_status in EmrContainerHook.INTERMEDIATE_STATES:
self.check_failure(query_status)
if not query_status or query_status in EmrContainerHook.INTERMEDIATE_STATES:
raise AirflowException(
f"Final state of EMR Containers job is {query_status}. "
f"Max tries of poll status exceeded, query_execution_id is {self.job_id}."
)

return self.job_id

def check_failure(self, query_status):
    """
    Raise when the job run is in one of the hook's failure states.

    :param query_status: state of the job run to inspect
    :raises AirflowException: if ``query_status`` is in ``EmrContainerHook.FAILURE_STATES``;
        the message includes the reason returned by ``hook.get_job_failure_reason``
    """
    # Guard clause: any state outside the failure set is not this method's concern.
    if query_status not in EmrContainerHook.FAILURE_STATES:
        return

    failure_reason = self.hook.get_job_failure_reason(self.job_id)
    raise AirflowException(
        f"EMR Containers job failed. Final state is {query_status}. "
        f"query_execution_id is {self.job_id}. Error: {failure_reason}"
    )

def execute_complete(self, context, event=None):
    """
    Resume after the trigger fires and return the id of the finished job run.

    :param context: task context (not used here)
    :param event: payload of the trigger event; expected to carry ``status`` and,
        on success, ``job_id``
    :return: the ``job_id`` reported by the trigger
    :raises AirflowException: if no event was delivered or its status is not ``success``
    """
    # ``event`` defaults to None; subscripting it directly would surface as an
    # opaque TypeError instead of a task failure with a readable message.
    if event is None or event.get("status") != "success":
        raise AirflowException(f"Error while running job: {event}")

    # The success event yielded by EmrContainerTrigger carries only ``status`` and
    # ``job_id``, so the message is optional rather than a hard key lookup.
    message = event.get("message")
    if message:
        self.log.info("%s", message)
    return event["job_id"]

def on_kill(self) -> None:
"""Cancel the submitted job run."""
if self.job_id:
Expand Down
68 changes: 64 additions & 4 deletions airflow/providers/amazon/aws/sensors/emr.py
Expand Up @@ -26,7 +26,11 @@
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri
from airflow.providers.amazon.aws.triggers.emr import EmrContainerSensorTrigger
from airflow.providers.amazon.aws.triggers.emr import (
EmrContainerTrigger,
EmrStepSensorTrigger,
EmrTerminateJobFlowTrigger,
)
from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
Expand Down Expand Up @@ -310,7 +314,7 @@ def execute(self, context: Context):
)
self.defer(
timeout=timeout,
trigger=EmrContainerSensorTrigger(
trigger=EmrContainerTrigger(
virtual_cluster_id=self.virtual_cluster_id,
job_id=self.job_id,
aws_conn_id=self.aws_conn_id,
Expand Down Expand Up @@ -406,9 +410,12 @@ class EmrJobFlowSensor(EmrBaseSensor):
:param job_flow_id: job_flow_id to check the state of
:param target_states: the target states, sensor waits until
job flow reaches any of these states
job flow reaches any of these states. In deferrable mode the sensor
runs until the job flow reaches a terminal state.
:param failed_states: the failure states, sensor fails when
job flow reaches any of these states
:param max_attempts: Maximum number of tries before failing
:param deferrable: Run sensor in the deferrable mode.
"""

template_fields: Sequence[str] = ("job_flow_id", "target_states", "failed_states")
Expand All @@ -424,12 +431,16 @@ def __init__(
job_flow_id: str,
target_states: Iterable[str] | None = None,
failed_states: Iterable[str] | None = None,
max_attempts: int = 60,
deferrable: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.job_flow_id = job_flow_id
self.target_states = target_states or ["TERMINATED"]
self.failed_states = failed_states or ["TERMINATED_WITH_ERRORS"]
self.max_attempts = max_attempts
self.deferrable = deferrable

def get_emr_response(self, context: Context) -> dict[str, Any]:
"""
Expand Down Expand Up @@ -488,6 +499,26 @@ def failure_message_from_response(response: dict[str, Any]) -> str | None:
)
return None

def execute(self, context: Context) -> None:
    """
    Run the sensor, deferring to a trigger when deferrable mode is enabled.

    In non-deferrable mode the parent ``execute`` is used unchanged. In deferrable
    mode a single ``poke`` is made first; only if it is falsy does the task defer
    to ``EmrTerminateJobFlowTrigger`` and resume in ``execute_complete``.

    :param context: task context forwarded to ``poke`` / the parent ``execute``
    """
    if not self.deferrable:
        super().execute(context=context)
        return

    if self.poke(context):
        # The first poke already succeeded, so there is nothing to defer for.
        return

    # Overall deferral budget: one poke interval per allowed attempt.
    wait_budget = timedelta(seconds=self.poke_interval * self.max_attempts)
    trigger = EmrTerminateJobFlowTrigger(
        job_flow_id=self.job_flow_id,
        max_attempts=self.max_attempts,
        aws_conn_id=self.aws_conn_id,
        poll_interval=int(self.poke_interval),
    )
    self.defer(timeout=wait_budget, trigger=trigger, method_name="execute_complete")

def execute_complete(self, context, event=None):
    """
    Resume after the trigger fires and fail the task unless it reported success.

    :param context: task context (not used here)
    :param event: payload of the trigger event; expected to carry ``status``
    :raises AirflowException: if no event was delivered or its status is not ``success``
    """
    # ``event`` defaults to None; guard it so a missing payload fails with a
    # readable AirflowException instead of a TypeError from subscripting None.
    if event is None or event.get("status") != "success":
        raise AirflowException(f"Error while running job: {event}")
    self.log.info("Job completed.")


class EmrStepSensor(EmrBaseSensor):
"""
Expand All @@ -503,9 +534,12 @@ class EmrStepSensor(EmrBaseSensor):
:param job_flow_id: job_flow_id which contains the step check the state of
:param step_id: step to check the state of
:param target_states: the target states, sensor waits until
step reaches any of these states
step reaches any of these states. In deferrable mode the sensor
waits until the step reaches a terminal state.
:param failed_states: the failure states, sensor fails when
step reaches any of these states
:param max_attempts: Maximum number of tries before failing
:param deferrable: Run sensor in the deferrable mode.
"""

template_fields: Sequence[str] = ("job_flow_id", "step_id", "target_states", "failed_states")
Expand All @@ -522,13 +556,17 @@ def __init__(
step_id: str,
target_states: Iterable[str] | None = None,
failed_states: Iterable[str] | None = None,
max_attempts: int = 60,
deferrable: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.job_flow_id = job_flow_id
self.step_id = step_id
self.target_states = target_states or ["COMPLETED"]
self.failed_states = failed_states or ["CANCELLED", "FAILED", "INTERRUPTED"]
self.max_attempts = max_attempts
self.deferrable = deferrable

def get_emr_response(self, context: Context) -> dict[str, Any]:
"""
Expand Down Expand Up @@ -587,3 +625,25 @@ def failure_message_from_response(response: dict[str, Any]) -> str | None:
f"with message {fail_details.get('Message')} and log file {fail_details.get('LogFile')}"
)
return None

def execute(self, context: Context) -> None:
    """
    Run the sensor, deferring to a trigger when deferrable mode is enabled.

    In non-deferrable mode the parent ``execute`` is used unchanged. In deferrable
    mode a single ``poke`` is made first; only if it is falsy does the task defer
    to ``EmrStepSensorTrigger`` and resume in ``execute_complete``.

    :param context: task context forwarded to ``poke`` / the parent ``execute``
    """
    if not self.deferrable:
        super().execute(context=context)
        return

    if self.poke(context):
        # The first poke already succeeded, so there is nothing to defer for.
        return

    # Overall deferral budget: one poke interval per allowed attempt.
    wait_budget = timedelta(seconds=self.max_attempts * self.poke_interval)
    trigger = EmrStepSensorTrigger(
        job_flow_id=self.job_flow_id,
        step_id=self.step_id,
        aws_conn_id=self.aws_conn_id,
        max_attempts=self.max_attempts,
        poke_interval=int(self.poke_interval),
    )
    self.defer(timeout=wait_budget, trigger=trigger, method_name="execute_complete")

def execute_complete(self, context, event=None):
    """
    Resume after the trigger fires and fail the task unless it reported success.

    :param context: task context (not used here)
    :param event: payload of the trigger event; expected to carry ``status``
    :raises AirflowException: if no event was delivered or its status is not ``success``
    """
    # ``event`` defaults to None; guard it so a missing payload fails with a
    # readable AirflowException instead of a TypeError from subscripting None.
    if event is None or event.get("status") != "success":
        raise AirflowException(f"Error while running job: {event}")

    self.log.info("Job completed.")
67 changes: 64 additions & 3 deletions airflow/providers/amazon/aws/triggers/emr.py
Expand Up @@ -24,6 +24,7 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook
from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.helpers import prune_dict

Expand Down Expand Up @@ -249,7 +250,7 @@ async def run(self):
)


class EmrContainerSensorTrigger(BaseTrigger):
class EmrContainerTrigger(BaseTrigger):
"""
Poll for the status of EMR container until reaches terminal state.
Expand Down Expand Up @@ -278,9 +279,9 @@ def hook(self) -> EmrContainerHook:
return EmrContainerHook(self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes EmrContainerSensorTrigger arguments and classpath."""
"""Serializes EmrContainerTrigger arguments and classpath."""
return (
"airflow.providers.amazon.aws.triggers.emr.EmrContainerSensorTrigger",
"airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger",
{
"virtual_cluster_id": self.virtual_cluster_id,
"job_id": self.job_id,
Expand Down Expand Up @@ -317,3 +318,63 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
await asyncio.sleep(int(self.poll_interval))

yield TriggerEvent({"status": "success", "job_id": self.job_id})


class EmrStepSensorTrigger(BaseTrigger):
    """
    Poll for the status of an EMR step until it reaches a terminal state.

    :param job_flow_id: id of the job flow that contains the step to check
    :param step_id: step to check the state of
    :param aws_conn_id: Reference to AWS connection id
    :param max_attempts: The maximum number of attempts to be made
    :param poke_interval: polling period in seconds to check for the status
    """

    def __init__(
        self,
        job_flow_id: str,
        step_id: str,
        aws_conn_id: str = "aws_default",
        max_attempts: int = 60,
        poke_interval: int = 30,
        **kwargs: Any,
    ):
        self.job_flow_id = job_flow_id
        self.step_id = step_id
        self.aws_conn_id = aws_conn_id
        self.max_attempts = max_attempts
        self.poke_interval = poke_interval
        super().__init__(**kwargs)

    @cached_property
    def hook(self) -> EmrHook:
        """EMR hook bound to ``aws_conn_id``; created on first access."""
        return EmrHook(self.aws_conn_id)

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serializes EmrStepSensorTrigger arguments and classpath."""
        return (
            "airflow.providers.amazon.aws.triggers.emr.EmrStepSensorTrigger",
            {
                "job_flow_id": self.job_flow_id,
                "step_id": self.step_id,
                "aws_conn_id": self.aws_conn_id,
                "max_attempts": self.max_attempts,
                "poke_interval": self.poke_interval,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Wait on the ``step_wait_for_terminal`` waiter, then emit a success event.

        Failures reported by ``async_wait`` propagate as exceptions; a success event
        is yielded only when the wait returns normally.
        """
        async with self.hook.async_conn as client:
            # ``step_wait_for_terminal`` is a custom waiter shipped in the provider's
            # waiters/emr.json, so it has to be resolved through the hook: the raw
            # service client neither knows that waiter nor accepts the
            # ``deferrable`` / ``client`` keyword arguments.
            waiter = self.hook.get_waiter("step_wait_for_terminal", deferrable=True, client=client)
            await async_wait(
                waiter=waiter,
                waiter_delay=self.poke_interval,
                waiter_max_attempts=self.max_attempts,
                args={"ClusterId": self.job_flow_id, "StepId": self.step_id},
                failure_message=f"Error while waiting for step {self.step_id} to complete",
                status_message=f"Step id: {self.step_id}, Step is still in non-terminal state",
                status_args=["Step.Status.State"],
            )

        yield TriggerEvent({"status": "success"})
31 changes: 31 additions & 0 deletions airflow/providers/amazon/aws/waiters/emr.json
Expand Up @@ -94,6 +94,37 @@
"state": "failure"
}
]
},
"step_wait_for_terminal": {
"operation": "DescribeStep",
"delay": 30,
"maxAttempts": 60,
"acceptors": [
{
"matcher": "path",
"argument": "Step.Status.State",
"expected": "COMPLETED",
"state": "success"
},
{
"matcher": "path",
"argument": "Step.Status.State",
"expected": "CANCELLED",
"state": "failure"
},
{
"matcher": "path",
"argument": "Step.Status.State",
"expected": "FAILED",
"state": "failure"
},
{
"matcher": "path",
"argument": "Step.Status.State",
"expected": "INTERRUPTED",
"state": "failure"
}
]
}
}
}
8 changes: 7 additions & 1 deletion tests/providers/amazon/aws/hooks/test_emr.py
Expand Up @@ -32,7 +32,13 @@ class TestEmrHook:
def test_service_waiters(self):
    """The hook must list every official EMR waiter plus the provider's custom ones."""
    hook = EmrHook(aws_conn_id=None)
    official_waiters = hook.conn.waiter_names
    # Single definition of the custom waiter list; the earlier four-item
    # assignment was immediately overwritten and therefore dead.
    custom_waiters = [
        "job_flow_waiting",
        "job_flow_terminated",
        "notebook_running",
        "notebook_stopped",
        "step_wait_for_terminal",
    ]

    assert sorted(hook.list_waiters()) == sorted([*official_waiters, *custom_waiters])

Expand Down

0 comments on commit 06b5a1e

Please sign in to comment.