diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py index 7cf2cbfeacd17..e917d1d81d8d3 100644 --- a/airflow/providers/amazon/aws/operators/batch.py +++ b/airflow/providers/amazon/aws/operators/batch.py @@ -112,6 +112,7 @@ class BatchOperator(BaseOperator): "array_properties", "node_overrides", "parameters", + "retry_strategy", "waiters", "tags", "wait_for_completion", @@ -122,6 +123,7 @@ class BatchOperator(BaseOperator): "container_overrides": "json", "parameters": "json", "node_overrides": "json", + "retry_strategy": "json", } @property @@ -160,6 +162,7 @@ def __init__( share_identifier: str | None = None, scheduling_priority_override: int | None = None, parameters: dict | None = None, + retry_strategy: dict | None = None, job_id: str | None = None, waiters: Any | None = None, max_retries: int = 4200, @@ -201,6 +204,7 @@ def __init__( self.scheduling_priority_override = scheduling_priority_override self.array_properties = array_properties self.parameters = parameters or {} + self.retry_strategy = retry_strategy or {} self.waiters = waiters self.tags = tags or {} self.wait_for_completion = wait_for_completion @@ -287,6 +291,7 @@ def submit_job(self, context: Context): "tags": self.tags, "containerOverrides": self.container_overrides, "nodeOverrides": self.node_overrides, + "retryStrategy": self.retry_strategy, "shareIdentifier": self.share_identifier, "schedulingPriorityOverride": self.scheduling_priority_override, } diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py index 8eb6601dfd9f7..8a0d0e788a197 100644 --- a/tests/providers/amazon/aws/operators/test_batch.py +++ b/tests/providers/amazon/aws/operators/test_batch.py @@ -63,6 +63,7 @@ def setup_method(self, _, get_client_type_mock): max_retries=self.MAX_RETRIES, status_retries=self.STATUS_RETRIES, parameters=None, + retry_strategy=None, container_overrides={}, array_properties=None, aws_conn_id="airflow_test", @@ -96,6 +97,7 @@ def test_init(self): assert self.batch.hook.max_retries == self.MAX_RETRIES assert self.batch.hook.status_retries == self.STATUS_RETRIES assert self.batch.parameters == {} + assert self.batch.retry_strategy == {} assert self.batch.container_overrides == {} assert self.batch.array_properties is None assert self.batch.node_overrides is None @@ -119,6 +121,7 @@ def test_template_fields_overrides(self): "array_properties", "node_overrides", "parameters", + "retry_strategy", "waiters", "tags", "wait_for_completion", @@ -143,6 +146,7 @@ def test_execute_without_failures(self, check_mock, wait_mock, job_description_m containerOverrides={}, jobDefinition="hello-world", parameters={}, + retryStrategy={}, tags={}, ) @@ -166,6 +170,7 @@ def test_execute_with_failures(self): containerOverrides={}, jobDefinition="hello-world", parameters={}, + retryStrategy={}, tags={}, ) @@ -232,6 +237,7 @@ def test_override_not_sent_if_not_set(self, client_mock, override): "jobName": JOB_NAME, "jobDefinition": "hello-world", "parameters": {}, + "retryStrategy": {}, "tags": {}, } if override == "overrides":