Skip to content

Commit

Permalink
Add deferrable mode to BatchSensor (#30279)
Browse files Browse the repository at this point in the history
* Implement BatchAsyncSensor
  • Loading branch information
phanikumv committed Jun 14, 2023
1 parent 4e73e47 commit 688f91b
Show file tree
Hide file tree
Showing 5 changed files with 287 additions and 3 deletions.
44 changes: 43 additions & 1 deletion airflow/providers/amazon/aws/sensors/batch.py
Expand Up @@ -16,13 +16,15 @@
# under the License.
from __future__ import annotations

from datetime import timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Any, Sequence

from deprecated import deprecated

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.providers.amazon.aws.triggers.batch import BatchSensorTrigger
from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
Expand All @@ -41,6 +43,10 @@ class BatchSensor(BaseSensorOperator):
:param job_id: Batch job_id to check the state for
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
:param region_name: aws region name associated with the client
:param deferrable: Run sensor in the deferrable mode.
:param poke_interval: polling period in seconds to check for the status of the job.
:param max_retries: Number of times to poll for job state before
returning the current state.
"""

template_fields: Sequence[str] = ("job_id",)
Expand All @@ -53,12 +59,18 @@ def __init__(
job_id: str,
aws_conn_id: str = "aws_default",
region_name: str | None = None,
deferrable: bool = False,
poke_interval: float = 5,
max_retries: int = 5,
**kwargs,
):
super().__init__(**kwargs)
self.job_id = job_id
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.deferrable = deferrable
self.poke_interval = poke_interval
self.max_retries = max_retries

def poke(self, context: Context) -> bool:
job_description = self.hook.get_job_description(self.job_id)
Expand All @@ -75,6 +87,36 @@ def poke(self, context: Context) -> bool:

raise AirflowException(f"Batch sensor failed. Unknown AWS Batch job status: {state}")

def execute(self, context: Context) -> None:
    """
    Run the sensor, either synchronously or by deferring to a trigger.

    When ``deferrable`` is false the base sensor's ``execute`` is used unchanged.
    Otherwise the task is deferred to a ``BatchSensorTrigger`` and resumes in
    ``execute_complete`` once the trigger fires.

    :param context: the task execution context
    """
    if not self.deferrable:
        super().execute(context=context)
        return

    # Allow max_retries polls of poke_interval seconds each plus a 60 second
    # margin; with a falsy max_retries fall back to the task's execution_timeout.
    if self.max_retries:
        defer_timeout = timedelta(seconds=self.max_retries * self.poke_interval + 60)
    else:
        defer_timeout = self.execution_timeout

    sensor_trigger = BatchSensorTrigger(
        job_id=self.job_id,
        aws_conn_id=self.aws_conn_id,
        region_name=self.region_name,
        poke_interval=self.poke_interval,
    )
    self.defer(
        timeout=defer_timeout,
        trigger=sensor_trigger,
        method_name="execute_complete",
    )

def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
    """
    Callback invoked when the trigger fires; returns immediately.

    A ``"failure"`` status in the event fails the task; any other event is
    treated as success and its message is logged.

    :param context: the task execution context (not used here)
    :param event: payload of the trigger event; ``status`` and ``message``
        are read when present
    :raises AirflowException: if the event reports ``status == "failure"``
    """
    message = event.get("message")
    if event.get("status") == "failure":
        # A failure event without a message must still fail the task with an
        # AirflowException instead of surfacing as a KeyError.
        raise AirflowException(message or f"Batch job {self.job_id} failed.")
    if message:
        self.log.info(message)
    else:
        # Some trigger events carry no message; do not crash on them.
        self.log.info("Batch job %s completed.", self.job_id)

@deprecated(reason="use `hook` property instead.")
def get_hook(self) -> BatchClientHook:
"""Create and return a BatchClientHook."""
Expand Down
83 changes: 83 additions & 0 deletions airflow/providers/amazon/aws/triggers/batch.py
Expand Up @@ -105,3 +105,86 @@ async def run(self):
yield TriggerEvent({"status": "failure", "message": "Job Failed - max attempts reached."})
else:
yield TriggerEvent({"status": "success", "job_id": self.job_id})


class BatchSensorTrigger(BaseTrigger):
    """
    Poll AWS Batch for a submitted job until it reaches a failure or a success state.

    BatchSensorTrigger is fired as a deferred class with params to poll the job state in the Triggerer.

    :param job_id: the ID of the job to poll for completion
    :param region_name: AWS region name to use.
        Overrides the region_name in the connection (if provided)
    :param aws_conn_id: connection id of AWS credentials / region name. If None,
        credential boto3 strategy will be used
    :param poke_interval: polling period in seconds to check for the status of the job
    """

    def __init__(
        self,
        job_id: str,
        region_name: str | None,
        aws_conn_id: str | None = "aws_default",
        poke_interval: float = 5,
    ):
        super().__init__()
        self.job_id = job_id
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name
        self.poke_interval = poke_interval

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize the BatchSensorTrigger classpath and constructor arguments."""
        return (
            "airflow.providers.amazon.aws.triggers.batch.BatchSensorTrigger",
            {
                "job_id": self.job_id,
                "aws_conn_id": self.aws_conn_id,
                "region_name": self.region_name,
                "poke_interval": self.poke_interval,
            },
        )

    @cached_property
    def hook(self) -> BatchClientHook:
        """Return the (cached) hook built from this trigger's connection id and region."""
        return BatchClientHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)

    @staticmethod
    def _response_detail(error: WaiterError) -> str:
        """Return the service error message of a waiter error, or the error text if there is none."""
        # NOTE(review): a waiter that only ran out of attempts may carry a last
        # response without an "Error" entry, so never index into it blindly.
        last_response = getattr(error, "last_response", None)
        if isinstance(last_response, dict):
            message = (last_response.get("Error") or {}).get("Message")
            if message:
                return message
        return str(error)

    async def run(self):
        """
        Make an async connection to AWS Batch and periodically poll for the job status.

        The connection is made using the aiobotocore library.
        The statuses that indicate job completion are: 'SUCCEEDED'|'FAILED'.
        Exactly one event is yielded: a ``failure`` event when the waiter reports an
        error, otherwise a ``success`` event once the waiter completes.
        """
        delay = int(self.poke_interval)
        async with self.hook.async_conn as client:
            waiter = self.hook.get_waiter("batch_job_complete", deferrable=True, client=client)
            attempt = 0
            while True:
                attempt += 1
                try:
                    # A single waiter attempt per iteration; pacing between
                    # attempts is done by the sleep below.
                    await waiter.wait(
                        jobs=[self.job_id],
                        WaiterConfig={
                            "Delay": delay,
                            "MaxAttempts": 1,
                        },
                    )
                    break
                except WaiterError as error:
                    if "error" in str(error):
                        yield TriggerEvent({"status": "failure", "message": f"Job Failed: {error}"})
                        # Stop here: falling through would also emit the
                        # success event below after a failure.
                        return
                    self.log.info(
                        "Job response is %s. Retrying attempt %s",
                        self._response_detail(error),
                        attempt,
                    )
                    await asyncio.sleep(delay)

        yield TriggerEvent(
            {
                "status": "success",
                "job_id": self.job_id,
                "message": f"Job {self.job_id} Succeeded",
            }
        )
9 changes: 9 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/batch.rst
Expand Up @@ -77,6 +77,15 @@ use :class:`~airflow.providers.amazon.aws.sensors.batch.BatchSensor`.
:start-after: [START howto_sensor_batch]
:end-before: [END howto_sensor_batch]

In order to monitor the state of the AWS Batch Job asynchronously, use
:class:`~airflow.providers.amazon.aws.sensors.batch.BatchSensor` with the
parameter ``deferrable`` set to True.

Since this releases the Airflow worker slot while waiting, it leads to more
efficient utilization of the resources available in your Airflow deployment.
Note that this requires the triggerer component to be running in your
Airflow deployment.

.. _howto/sensor:BatchComputeEnvironmentSensor:

Wait on an AWS Batch compute environment status
Expand Down
40 changes: 39 additions & 1 deletion tests/providers/amazon/aws/sensors/test_batch.py
Expand Up @@ -20,16 +20,18 @@

import pytest

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.providers.amazon.aws.sensors.batch import (
BatchComputeEnvironmentSensor,
BatchJobQueueSensor,
BatchSensor,
)
from airflow.providers.amazon.aws.triggers.batch import BatchSensorTrigger

# Shared constants for the sensor tests in this module.
TASK_ID = "batch_job_sensor"
# Batch job id handed to the sensors under test.
JOB_ID = "8222a1c2-b246-4e19-b1b8-0039bb4407c0"
# Region name handed to the sensors under test.
AWS_REGION = "eu-west-1"


class TestBatchSensor:
Expand Down Expand Up @@ -195,3 +197,39 @@ def test_poke_invalid(self, mock_batch_client):
jobQueues=[self.job_queue],
)
assert "AWS Batch job queue failed" in str(ctx.value)


class TestBatchAsyncSensor:
    """Tests for ``BatchSensor`` running in deferrable mode."""

    # One sensor instance shared by all tests; none of them mutates it.
    TASK = BatchSensor(task_id="task", job_id=JOB_ID, region_name=AWS_REGION, deferrable=True)

    def test_batch_sensor_async(self):
        """
        Asserts that a task is deferred and a BatchSensorTrigger will be fired
        when a BatchSensor in deferrable mode is executed.
        """
        with pytest.raises(TaskDeferred) as exc:
            self.TASK.execute({})
        assert isinstance(exc.value.trigger, BatchSensorTrigger), "Trigger is not a BatchSensorTrigger"

    def test_batch_sensor_async_execute_failure(self):
        """Tests that an AirflowException is raised in case of error event"""
        with pytest.raises(AirflowException) as exc_info:
            self.TASK.execute_complete(
                context={}, event={"status": "failure", "message": "test failure message"}
            )

        assert str(exc_info.value) == "test failure message"

    @pytest.mark.parametrize(
        "event",
        [{"status": "success", "message": f"AWS Batch job ({JOB_ID}) succeeded"}],
    )
    def test_batch_sensor_async_execute_complete(self, event):
        """Tests that execute_complete method returns None and that it prints expected log"""
        with mock.patch.object(self.TASK.log, "info") as mock_log_info:
            assert self.TASK.execute_complete(context={}, event=event) is None

        mock_log_info.assert_called_with(event["message"])
114 changes: 113 additions & 1 deletion tests/providers/amazon/aws/triggers/test_batch.py
Expand Up @@ -20,15 +20,17 @@
from unittest.mock import AsyncMock

import pytest
from botocore.exceptions import WaiterError

from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger, BatchSensorTrigger
from airflow.triggers.base import TriggerEvent

# Shared constants used to build the triggers under test.
BATCH_JOB_ID = "job_id"
POLL_INTERVAL = 5
MAX_ATTEMPT = 5
AWS_CONN_ID = "aws_batch_job_conn"
AWS_REGION = "us-east-2"
# Skip every test in this module when aiobotocore cannot be imported.
pytest.importorskip("aiobotocore")


class TestBatchOperatorTrigger:
Expand Down Expand Up @@ -69,3 +71,113 @@ async def test_batch_job_trigger_run(self, mock_async_conn, mock_get_waiter):
response = await generator.asend(None)

assert response == TriggerEvent({"status": "success", "job_id": BATCH_JOB_ID})


class TestBatchSensorTrigger:
    """Tests for ``BatchSensorTrigger`` serialization and its ``run`` event stream."""

    TRIGGER = BatchSensorTrigger(
        job_id=BATCH_JOB_ID,
        region_name=AWS_REGION,
        aws_conn_id=AWS_CONN_ID,
        poke_interval=POLL_INTERVAL,
    )

    def test_batch_sensor_trigger_serialization(self):
        """
        Asserts that the BatchSensorTrigger correctly serializes its arguments
        and classpath.
        """
        classpath, kwargs = self.TRIGGER.serialize()
        assert classpath == "airflow.providers.amazon.aws.triggers.batch.BatchSensorTrigger"
        assert kwargs == {
            "job_id": BATCH_JOB_ID,
            "region_name": AWS_REGION,
            "aws_conn_id": AWS_CONN_ID,
            "poke_interval": POLL_INTERVAL,
        }

    @pytest.mark.asyncio
    @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_waiter")
    @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.async_conn")
    async def test_batch_job_trigger_run(self, mock_async_conn, mock_get_waiter):
        """Test that the success event is returned once the waiter completes."""
        # NOTE(review): this test builds a BatchOperatorTrigger, not a
        # BatchSensorTrigger - confirm whether it belongs in this class.
        the_mock = mock.MagicMock()
        mock_async_conn.__aenter__.return_value = the_mock

        mock_get_waiter().wait = AsyncMock()

        batch_trigger = BatchOperatorTrigger(
            job_id=BATCH_JOB_ID,
            poll_interval=POLL_INTERVAL,
            max_retries=MAX_ATTEMPT,
            aws_conn_id=AWS_CONN_ID,
            region_name=AWS_REGION,
        )

        generator = batch_trigger.run()
        response = await generator.asend(None)

        assert response == TriggerEvent({"status": "success", "job_id": BATCH_JOB_ID})

    @pytest.mark.asyncio
    @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_waiter")
    @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.async_conn")
    async def test_batch_sensor_trigger_completed(self, mock_async_conn, mock_get_waiter):
        """Test if the success event is returned from trigger."""
        the_mock = mock.MagicMock()
        mock_async_conn.__aenter__.return_value = the_mock

        # The waiter returns without raising, i.e. the job completed.
        mock_get_waiter().wait = AsyncMock()

        trigger = BatchSensorTrigger(
            job_id=BATCH_JOB_ID,
            region_name=AWS_REGION,
            aws_conn_id=AWS_CONN_ID,
        )
        generator = trigger.run()
        actual_response = await generator.asend(None)
        assert (
            TriggerEvent(
                {"status": "success", "job_id": BATCH_JOB_ID, "message": f"Job {BATCH_JOB_ID} Succeeded"}
            )
            == actual_response
        )

    @pytest.mark.asyncio
    @mock.patch("asyncio.sleep")
    @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_waiter")
    @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.async_conn")
    async def test_batch_sensor_trigger_failure(self, mock_async_conn, mock_get_waiter, mock_sleep):
        """Test if the failure event is returned from trigger."""
        a_mock = mock.MagicMock()
        mock_async_conn.__aenter__.return_value = a_mock

        name = "batch_job_complete"
        reason = (
            "An error occurred (UnrecognizedClientException): The security token included in the "
            "request is invalid. "
        )
        # A plain dict; a trailing comma here would silently turn it into a 1-tuple.
        last_response = {"Error": {"Message": "The security token included in the request is invalid."}}

        error_failed = WaiterError(
            name=name,
            reason=reason,
            last_response=last_response,
        )

        mock_get_waiter().wait = AsyncMock(side_effect=[error_failed])
        mock_sleep.return_value = True

        trigger = BatchSensorTrigger(job_id=BATCH_JOB_ID, region_name=AWS_REGION, aws_conn_id=AWS_CONN_ID)
        generator = trigger.run()
        actual_response = await generator.asend(None)
        assert actual_response == TriggerEvent(
            {"status": "failure", "message": f"Job Failed: Waiter {name} failed: {reason}"}
        )

0 comments on commit 688f91b

Please sign in to comment.