Skip to content

Commit

Permalink
check job_status before BatchOperator execute in deferrable mode (#36523
Browse files Browse the repository at this point in the history
)
  • Loading branch information
Lee-W committed Jan 10, 2024
1 parent 6bd450d commit 88c9596
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 17 deletions.
48 changes: 33 additions & 15 deletions airflow/providers/amazon/aws/operators/batch.py
Expand Up @@ -230,36 +230,54 @@ def hook(self) -> BatchClientHook:
region_name=self.region_name,
)

def execute(self, context: Context):
def execute(self, context: Context) -> str | None:
"""Submit and monitor an AWS Batch job.
:raises: AirflowException
"""
self.submit_job(context)

if self.deferrable:
self.defer(
timeout=self.execution_timeout,
trigger=BatchJobTrigger(
job_id=self.job_id,
waiter_max_attempts=self.max_retries,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
waiter_delay=self.poll_interval,
),
method_name="execute_complete",
)
if not self.job_id:
raise AirflowException("AWS Batch job - job_id was not found")

job = self.hook.get_job_description(self.job_id)
job_status = job.get("status")
if job_status == self.hook.SUCCESS_STATE:
self.log.info("Job completed.")
return self.job_id
elif job_status == self.hook.FAILURE_STATE:
raise AirflowException(f"Error while running job: {self.job_id} is in {job_status} state")
elif job_status in self.hook.INTERMEDIATE_STATES:
self.defer(
timeout=self.execution_timeout,
trigger=BatchJobTrigger(
job_id=self.job_id,
waiter_max_attempts=self.max_retries,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
waiter_delay=self.poll_interval,
),
method_name="execute_complete",
)

raise AirflowException(f"Unexpected status: {job_status}")

if self.wait_for_completion:
self.monitor_job(context)

return self.job_id

def execute_complete(self, context, event=None):
def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
if event is None:
err_msg = "Trigger error: event is None"
self.log.info(err_msg)
raise AirflowException(err_msg)

if event["status"] != "success":
raise AirflowException(f"Error while running job: {event}")
else:
self.log.info("Job completed.")

self.log.info("Job completed.")
return event["job_id"]

def on_kill(self):
Expand Down
63 changes: 61 additions & 2 deletions tests/providers/amazon/aws/operators/test_batch.py
Expand Up @@ -268,8 +268,11 @@ def test_cant_set_old_and_new_override_param(self):
container_overrides={"a": "b"},
)

@mock.patch.object(BatchClientHook, "get_job_description")
@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
def test_defer_if_deferrable_param_set(self, mock_client):
def test_defer_if_deferrable_param_set(self, mock_client, mock_get_job_description):
mock_get_job_description.return_value = {"status": "SUBMITTED"}

batch = BatchOperator(
task_id="task",
job_name=JOB_NAME,
Expand All @@ -280,9 +283,65 @@ def test_defer_if_deferrable_param_set(self, mock_client):
)

with pytest.raises(TaskDeferred) as exc:
batch.execute(context=None)
batch.execute(self.mock_context)
assert isinstance(exc.value.trigger, BatchJobTrigger)

@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
def test_defer_but_failed_due_to_job_id_not_found(self, mock_client):
"""Test that an AirflowException is raised if job_id is not set before deferral."""
mock_client.return_value.submit_job.return_value = {
"jobName": JOB_NAME,
"jobId": None,
}

batch = BatchOperator(
task_id="task",
job_name=JOB_NAME,
job_queue="queue",
job_definition="hello-world",
do_xcom_push=False,
deferrable=True,
)
with pytest.raises(AirflowException) as exc:
batch.execute(self.mock_context)
assert "AWS Batch job - job_id was not found" in str(exc.value)

@mock.patch.object(BatchClientHook, "get_job_description")
@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
def test_defer_but_success_before_deferred(self, mock_client, mock_get_job_description):
"""Test that an AirflowException is raised if job_id is not set before deferral."""
mock_client.return_value.submit_job.return_value = RESPONSE_WITHOUT_FAILURES
mock_get_job_description.return_value = {"status": "SUCCEEDED"}

batch = BatchOperator(
task_id="task",
job_name=JOB_NAME,
job_queue="queue",
job_definition="hello-world",
do_xcom_push=False,
deferrable=True,
)
assert batch.execute(self.mock_context) == JOB_ID

@mock.patch.object(BatchClientHook, "get_job_description")
@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
def test_defer_but_fail_before_deferred(self, mock_client, mock_get_job_description):
"""Test that an AirflowException is raised if job_id is not set before deferral."""
mock_client.return_value.submit_job.return_value = RESPONSE_WITHOUT_FAILURES
mock_get_job_description.return_value = {"status": "FAILED"}

batch = BatchOperator(
task_id="task",
job_name=JOB_NAME,
job_queue="queue",
job_definition="hello-world",
do_xcom_push=False,
deferrable=True,
)
with pytest.raises(AirflowException) as exc:
batch.execute(self.mock_context)
assert f"Error while running job: {JOB_ID} is in FAILED state" in str(exc.value)

@mock.patch.object(BatchClientHook, "get_job_description")
@mock.patch.object(BatchClientHook, "wait_for_job")
@mock.patch.object(BatchClientHook, "check_job_success")
Expand Down

0 comments on commit 88c9596

Please sign in to comment.