Skip to content

Commit

Permalink
Get failure information on EMR job failure (#32151)
Browse files Browse the repository at this point in the history
* Use util wait method in wait_for_completion.
* Add logs to display failure reason if EMR Job fails
* Fix waiter parameters, use FailureDetails instead of StateChangeReason
* Only log failure details if they are included in the response
  • Loading branch information
syedahsn committed Aug 8, 2023
1 parent ad9d8d4 commit 8bbea92
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 15 deletions.
32 changes: 21 additions & 11 deletions airflow/providers/amazon/aws/hooks/emr.py
Expand Up @@ -26,7 +26,7 @@

from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.utils.helpers import prune_dict
from airflow.providers.amazon.aws.utils.waiter_with_logging import wait


class EmrHook(AwsBaseHook):
Expand Down Expand Up @@ -158,6 +158,9 @@ def add_job_flow_steps(
:param execution_role_arn: The ARN of the runtime role for a step on the cluster.
"""
config = {}
waiter_delay = waiter_delay or 30
waiter_max_attempts = waiter_max_attempts or 60

if execution_role_arn:
config["ExecutionRoleArn"] = execution_role_arn
response = self.get_conn().add_job_flow_steps(JobFlowId=job_flow_id, Steps=steps, **config)
Expand All @@ -169,16 +172,23 @@ def add_job_flow_steps(
if wait_for_completion:
waiter = self.get_conn().get_waiter("step_complete")
for step_id in response["StepIds"]:
waiter.wait(
ClusterId=job_flow_id,
StepId=step_id,
WaiterConfig=prune_dict(
{
"Delay": waiter_delay,
"MaxAttempts": waiter_max_attempts,
}
),
)
try:
wait(
waiter=waiter,
waiter_max_attempts=waiter_max_attempts,
waiter_delay=waiter_delay,
args={"ClusterId": job_flow_id, "StepId": step_id},
failure_message=f"EMR Steps failed: {step_id}",
status_message="EMR Step status is",
status_args=["Step.Status.State", "Step.Status.StateChangeReason"],
)
except AirflowException as ex:
if "EMR Steps failed" in str(ex):
resp = self.get_conn().describe_step(ClusterId=job_flow_id, StepId=step_id)
failure_details = resp["Step"]["Status"].get("FailureDetails", None)
if failure_details:
self.log.error("EMR Steps failed: %s", failure_details)
raise
return response["StepIds"]

def test_connection(self):
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/operators/emr.py
Expand Up @@ -100,8 +100,8 @@ def __init__(
aws_conn_id: str = "aws_default",
steps: list[dict] | str | None = None,
wait_for_completion: bool = False,
waiter_delay: int | None = None,
waiter_max_attempts: int | None = None,
waiter_delay: int | None = 30,
waiter_max_attempts: int | None = 60,
execution_role_arn: str | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
Expand Down
38 changes: 38 additions & 0 deletions tests/providers/amazon/aws/hooks/test_emr.py
Expand Up @@ -22,6 +22,7 @@

import boto3
import pytest
from botocore.exceptions import WaiterError
from moto import mock_emr

from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -113,6 +114,43 @@ def test_add_job_flow_steps_wait_for_completion(self, mock_conn):

mock_conn.get_waiter.assert_called_once_with("step_complete")

@mock.patch("time.sleep", return_value=True)
@mock.patch.object(EmrHook, "conn")
def test_add_job_flow_steps_raises_exception_on_failure(self, mock_conn, mock_sleep, caplog):
    """Check that a failed EMR step raises and that its failure details are logged.

    The mocked waiter raises two ``WaiterError`` instances in a row; the second
    one carries the reason ``"terminal failure"``. The hook call is expected to
    surface an ``AirflowException``, and the last captured log message is
    expected to contain the ``FailureDetails`` value returned by the mocked
    ``describe_step`` call.

    ``time.sleep`` is patched so that any waiting between attempts is instant;
    ``mock_sleep`` itself is not inspected.
    """
    emr_hook = EmrHook(aws_conn_id="aws_default", emr_conn_id="emr_default", region_name="us-east-1")

    # Response handed back when the hook looks up the step after the failure.
    failed_status = {
        "State": "FAILED",
        "FailureDetails": "test failure details",
    }
    mock_conn.describe_step.return_value = {"Step": {"Status": failed_status}}

    # A single step is submitted, so exactly one step id comes back.
    mock_conn.add_job_flow_steps.return_value = {
        "StepIds": ["step_id"],
        "ResponseMetadata": {"HTTPStatusCode": 200},
    }

    hadoop_jar_step = {
        "Args": ["test args"],
        "Jar": "test.jar",
    }
    job_steps = [
        {
            "ActionOnFailure": "test_step",
            "HadoopJarStep": hadoop_jar_step,
            "Name": "step_1",
        }
    ]

    # Successive waiter errors: an ordinary one first, then the terminal failure.
    waiter_errors = [
        WaiterError(name="test_error", reason=reason, last_response={})
        for reason in ("test_reason", "terminal failure")
    ]
    mocked_waiter = mock_conn.get_waiter()
    mocked_waiter.wait.side_effect = waiter_errors

    with pytest.raises(AirflowException):
        emr_hook.add_job_flow_steps(
            job_flow_id="job_flow_id",
            steps=job_steps,
            wait_for_completion=True,
        )

    # Checked after the raising call so the assertions actually execute.
    last_log_message = caplog.messages[-1]
    assert "test failure details" in last_log_message
    mock_conn.get_waiter.assert_called_with("step_complete")

@mock_emr
def test_create_job_flow_extra_args(self):
"""
Expand Down
4 changes: 2 additions & 2 deletions tests/providers/amazon/aws/operators/test_emr_add_steps.py
Expand Up @@ -241,8 +241,8 @@ def test_wait_for_completion(self, mock_add_job_flow_steps, *_):
job_flow_id=job_flow_id,
steps=[],
wait_for_completion=False,
waiter_delay=None,
waiter_max_attempts=None,
waiter_delay=30,
waiter_max_attempts=60,
execution_role_arn=None,
)

Expand Down

0 comments on commit 8bbea92

Please sign in to comment.