Skip to content

Commit

Permalink
Add retry configuration in EmrContainerOperator (#37426)
Browse files Browse the repository at this point in the history
  • Loading branch information
vincbeck committed Feb 14, 2024
1 parent 56c27f8 commit f91c93c
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 1 deletion.
6 changes: 6 additions & 0 deletions airflow/providers/amazon/aws/hooks/emr.py
Expand Up @@ -383,6 +383,7 @@ def submit_job(
configuration_overrides: dict | None = None,
client_request_token: str | None = None,
tags: dict | None = None,
retry_max_attempts: int | None = None,
) -> str:
"""
Submit a job to the EMR Containers API and return the job ID.
Expand All @@ -402,6 +403,7 @@ def submit_job(
:param client_request_token: The client idempotency token of the job run request.
Use this if you want to specify a unique ID to prevent two jobs from getting started.
:param tags: The tags assigned to job runs.
:param retry_max_attempts: The maximum number of attempts on the job's driver.
:return: The ID of the job run request.
"""
params = {
Expand All @@ -415,6 +417,10 @@ def submit_job(
}
if client_request_token:
params["clientToken"] = client_request_token
if retry_max_attempts:
params["retryPolicyConfiguration"] = {
"maxAttempts": retry_max_attempts,
}

response = self.conn.start_job_run(**params)

Expand Down
5 changes: 5 additions & 0 deletions airflow/providers/amazon/aws/operators/emr.py
Expand Up @@ -503,6 +503,8 @@ class EmrContainerOperator(BaseOperator):
:param max_tries: Deprecated - use max_polling_attempts instead.
:param max_polling_attempts: Maximum number of times to wait for the job run to finish.
Defaults to None, which will poll until the job is *not* in a pending, submitted, or running state.
:param job_retry_max_attempts: Maximum number of times to retry when the EMR job fails.
Defaults to None, which disable the retry.
:param tags: The tags assigned to job runs.
Defaults to None
:param deferrable: Run operator in the deferrable mode.
Expand Down Expand Up @@ -534,6 +536,7 @@ def __init__(
max_tries: int | None = None,
tags: dict | None = None,
max_polling_attempts: int | None = None,
job_retry_max_attempts: int | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs: Any,
) -> None:
Expand All @@ -549,6 +552,7 @@ def __init__(
self.wait_for_completion = wait_for_completion
self.poll_interval = poll_interval
self.max_polling_attempts = max_polling_attempts
self.job_retry_max_attempts = job_retry_max_attempts
self.tags = tags
self.job_id: str | None = None
self.deferrable = deferrable
Expand Down Expand Up @@ -583,6 +587,7 @@ def execute(self, context: Context) -> str | None:
self.configuration_overrides,
self.client_request_token,
self.tags,
self.job_retry_max_attempts,
)
if self.deferrable:
query_status = self.hook.check_query_status(job_id=self.job_id)
Expand Down
Expand Up @@ -71,7 +71,7 @@ def test_execute_without_failure(
self.emr_container.execute(None)

mock_submit_job.assert_called_once_with(
"test_emr_job", "arn:aws:somerole", "6.3.0-latest", {}, {}, GENERATED_UUID, {}
"test_emr_job", "arn:aws:somerole", "6.3.0-latest", {}, {}, GENERATED_UUID, {}, None
)
mock_check_query_status.assert_called_once_with("jobid_123456")
assert self.emr_container.release_label == "6.3.0-latest"
Expand Down
1 change: 1 addition & 0 deletions tests/system/providers/amazon/aws/example_emr_eks.py
Expand Up @@ -282,6 +282,7 @@ def delete_virtual_cluster(virtual_cluster_id):
)
# [END howto_operator_emr_container]
job_starter.wait_for_completion = False
job_starter.job_retry_max_attempts = 5

# [START howto_sensor_emr_container]
job_waiter = EmrContainerSensor(
Expand Down

0 comments on commit f91c93c

Please sign in to comment.