Add deferrable mode in Redshift delete cluster (#30244)
Add the deferrable param in RedshiftDeleteClusterOperator.
This allows running RedshiftDeleteClusterOperator asynchronously:
the worker only submits the request to delete the Redshift cluster,
then defers to the trigger, which polls the waiter until the cluster
is removed, so the worker slot is not occupied for the whole
duration of the task.
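
For illustration, a minimal sketch of the new param in use inside a DAG; the DAG id, schedule, and cluster identifier below are hypothetical, not taken from this commit:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.redshift_cluster import RedshiftDeleteClusterOperator

with DAG("example_redshift_delete", start_date=datetime(2023, 6, 1), schedule=None, catchup=False):
    RedshiftDeleteClusterOperator(
        task_id="delete_cluster",
        cluster_identifier="my-redshift-cluster",  # hypothetical cluster name
        deferrable=True,   # submit the delete, then hand polling to the triggerer
        poll_interval=30,  # seconds between status checks
        max_attempts=30,   # give up after this many polls
    )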
pankajastro committed Jun 4, 2023
1 parent 86b5ba2 commit a247a8f
Showing 5 changed files with 285 additions and 6 deletions.
33 changes: 29 additions & 4 deletions airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -26,6 +26,7 @@
from airflow.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftCreateClusterSnapshotTrigger,
RedshiftCreateClusterTrigger,
RedshiftDeleteClusterTrigger,
RedshiftPauseClusterTrigger,
RedshiftResumeClusterTrigger,
)
@@ -629,6 +630,8 @@ class RedshiftDeleteClusterOperator(BaseOperator):
The default value is ``True``
:param aws_conn_id: aws connection to use
:param poll_interval: Time (in seconds) to wait between two consecutive calls to check cluster state
:param deferrable: Run the operator in deferrable mode.
:param max_attempts: (Deferrable mode only) The maximum number of attempts to be made.
"""

template_fields: Sequence[str] = ("cluster_identifier",)
@@ -643,7 +646,9 @@ def __init__(
final_cluster_snapshot_identifier: str | None = None,
wait_for_completion: bool = True,
aws_conn_id: str = "aws_default",
poll_interval: int = 30,
deferrable: bool = False,
max_attempts: int = 30,
**kwargs,
):
super().__init__(**kwargs)
@@ -658,8 +663,12 @@ def __init__(
self._attempts = 10
self._attempt_interval = 15
self.redshift_hook = RedshiftHook(aws_conn_id=aws_conn_id)
self.aws_conn_id = aws_conn_id
self.deferrable = deferrable
self.max_attempts = max_attempts

def execute(self, context: Context):
while self._attempts >= 1:
try:
self.redshift_hook.delete_cluster(
Expand All @@ -676,10 +685,26 @@ def execute(self, context: Context):
time.sleep(self._attempt_interval)
else:
raise

if self.deferrable:
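# Trigger timeout covers every poll (max_attempts * poll_interval) plus a 60-second buffer.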
self.defer(
timeout=timedelta(seconds=self.max_attempts * self.poll_interval + 60),
trigger=RedshiftDeleteClusterTrigger(
cluster_identifier=self.cluster_identifier,
poll_interval=self.poll_interval,
max_attempts=self.max_attempts,
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
)
elif self.wait_for_completion:
waiter = self.redshift_hook.get_conn().get_waiter("cluster_deleted")
waiter.wait(
ClusterIdentifier=self.cluster_identifier,
WaiterConfig={"Delay": self.poll_interval, "MaxAttempts": 30},
WaiterConfig={"Delay": self.poll_interval, "MaxAttempts": self.max_attempts},
)

def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error deleting cluster: {event}")
else:
self.log.info("Cluster deleted successfully")
75 changes: 75 additions & 0 deletions airflow/providers/amazon/aws/triggers/redshift_cluster.py
@@ -357,3 +357,78 @@ async def run(self):
)
else:
yield TriggerEvent({"status": "success", "message": "Cluster resumed"})


class RedshiftDeleteClusterTrigger(BaseTrigger):
"""
Trigger for RedshiftDeleteClusterOperator.

:param cluster_identifier: A unique identifier for the cluster.
:param max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
:param poll_interval: The amount of time in seconds to wait between attempts.
"""

def __init__(
self,
cluster_identifier: str,
max_attempts: int = 30,
aws_conn_id: str = "aws_default",
poll_interval: int = 30,
):
super().__init__()
self.cluster_identifier = cluster_identifier
self.max_attempts = max_attempts
self.aws_conn_id = aws_conn_id
self.poll_interval = poll_interval

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
"airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftDeleteClusterTrigger",
{
"cluster_identifier": self.cluster_identifier,
"max_attempts": self.max_attempts,
"aws_conn_id": self.aws_conn_id,
"poll_interval": self.poll_interval,
},
)

@cached_property
def hook(self):
return RedshiftHook(aws_conn_id=self.aws_conn_id)

async def run(self) -> AsyncIterator[TriggerEvent]:
async with self.hook.async_conn as client:
attempt = 0
waiter = client.get_waiter("cluster_deleted")
while attempt < self.max_attempts:
attempt = attempt + 1
try:
await waiter.wait(
ClusterIdentifier=self.cluster_identifier,
WaiterConfig={
"Delay": self.poll_interval,
"MaxAttempts": 1,
},
)
break
except WaiterError as error:
if "terminal failure" in str(error):
yield TriggerEvent(
{"status": "failure", "message": f"Delete Cluster Failed: {error}"}
)
break
self.log.info(
"Cluster status is %s. Retrying attempt %s/%s",
error.last_response["Clusters"][0]["ClusterStatus"],
attempt,
self.max_attempts,
)
await asyncio.sleep(int(self.poll_interval))

if attempt >= self.max_attempts:
yield TriggerEvent(
{"status": "failure", "message": "Delete Cluster Failed - max attempts reached."}
)
else:
yield TriggerEvent({"status": "success", "message": "Cluster deleted."})
@@ -53,7 +53,8 @@ Resume an Amazon Redshift cluster

To resume a 'paused' Amazon Redshift cluster you can use
:class:`RedshiftResumeClusterOperator <airflow.providers.amazon.aws.operators.redshift_cluster>`.
You can also run this operator in deferrable mode by setting ``deferrable`` param to ``True``.
This will ensure that the task is deferred from the Airflow worker slot and polling for the task status happens on the trigger.

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_redshift.py
:language: python
@@ -110,7 +111,8 @@ Delete an Amazon Redshift cluster
=================================

To delete an Amazon Redshift cluster you can use
:class:`RedshiftDeleteClusterOperator <airflow.providers.amazon.aws.operators.redshift_cluster>`.
You can also run this operator in deferrable mode by setting ``deferrable`` param to ``True``.

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_redshift.py
:language: python
44 changes: 44 additions & 0 deletions tests/providers/amazon/aws/operators/test_redshift_cluster.py
@@ -33,6 +33,7 @@
)
from airflow.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftCreateClusterSnapshotTrigger,
RedshiftDeleteClusterTrigger,
RedshiftPauseClusterTrigger,
RedshiftResumeClusterTrigger,
)
@@ -520,3 +521,46 @@ def test_delete_cluster_multiple_attempts_fail(self, _, mock_conn, mock_delete_c
redshift_operator.execute(None)

assert mock_delete_cluster.call_count == 10

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.delete_cluster")
def test_delete_cluster_deferrable_mode(self, mock_delete_cluster):
"""Test delete cluster operator with defer when deferrable param is true"""
mock_delete_cluster.return_value = True
delete_cluster = RedshiftDeleteClusterOperator(
task_id="task_test",
cluster_identifier="test_cluster",
deferrable=True,
wait_for_completion=False,
)

with pytest.raises(TaskDeferred) as exc:
delete_cluster.execute(context=None)

assert isinstance(
exc.value.trigger, RedshiftDeleteClusterTrigger
), "Trigger is not a RedshiftDeleteClusterTrigger"

def test_delete_cluster_execute_complete_success(self):
"""Asserts that logging occurs as expected"""
task = RedshiftDeleteClusterOperator(
task_id="task_test",
cluster_identifier="test_cluster",
deferrable=True,
wait_for_completion=False,
)
with mock.patch.object(task.log, "info") as mock_log_info:
task.execute_complete(context=None, event={"status": "success", "message": "Cluster deleted"})
mock_log_info.assert_called_with("Cluster deleted successfully")

def test_delete_cluster_execute_complete_fail(self):
redshift_operator = RedshiftDeleteClusterOperator(
task_id="task_test",
cluster_identifier="test_cluster",
deferrable=True,
wait_for_completion=False,
)

with pytest.raises(AirflowException):
redshift_operator.execute_complete(
context=None, event={"status": "error", "message": "test failure message"}
)
133 changes: 133 additions & 0 deletions tests/providers/amazon/aws/triggers/test_redshift_cluster.py
@@ -26,6 +26,7 @@
from airflow.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftCreateClusterSnapshotTrigger,
RedshiftCreateClusterTrigger,
RedshiftDeleteClusterTrigger,
RedshiftPauseClusterTrigger,
RedshiftResumeClusterTrigger,
)
@@ -500,3 +501,135 @@ async def test_redshift_resume_cluster_trigger_run_attempts_failed(
assert response == TriggerEvent(
{"status": "failure", "message": f"Resume Cluster Failed: {error_failed}"}
)


class TestRedshiftDeleteClusterTrigger:
def test_redshift_delete_cluster_trigger_serialize(self):
redshift_delete_cluster_trigger = RedshiftDeleteClusterTrigger(
cluster_identifier=TEST_CLUSTER_IDENTIFIER,
poll_interval=TEST_POLL_INTERVAL,
max_attempts=TEST_MAX_ATTEMPT,
aws_conn_id=TEST_AWS_CONN_ID,
)
class_path, args = redshift_delete_cluster_trigger.serialize()
assert (
class_path
== "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftDeleteClusterTrigger"
)
assert args["cluster_identifier"] == TEST_CLUSTER_IDENTIFIER
assert args["poll_interval"] == TEST_POLL_INTERVAL
assert args["max_attempts"] == TEST_MAX_ATTEMPT
assert args["aws_conn_id"] == TEST_AWS_CONN_ID

@pytest.mark.asyncio
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.async_conn")
async def test_redshift_delete_cluster_trigger_run(self, mock_async_conn):
a_mock = mock.MagicMock()
mock_async_conn.__aenter__.return_value = a_mock
a_mock.get_waiter().wait = AsyncMock()

redshift_delete_cluster_trigger = RedshiftDeleteClusterTrigger(
cluster_identifier=TEST_CLUSTER_IDENTIFIER,
poll_interval=TEST_POLL_INTERVAL,
max_attempts=TEST_MAX_ATTEMPT,
aws_conn_id=TEST_AWS_CONN_ID,
)

generator = redshift_delete_cluster_trigger.run()
response = await generator.asend(None)

assert response == TriggerEvent({"status": "success", "message": "Cluster deleted."})

@pytest.mark.asyncio
@mock.patch("asyncio.sleep")
@mock.patch.object(RedshiftHook, "async_conn")
async def test_redshift_delete_cluster_trigger_run_multiple_attempts(self, mock_async_conn, mock_sleep):
a_mock = mock.MagicMock()
mock_async_conn.__aenter__.return_value = a_mock
error = WaiterError(
name="test_name",
reason="test_reason",
last_response={"Clusters": [{"ClusterStatus": "available"}]},
)
a_mock.get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True])
mock_sleep.return_value = True

redshift_delete_cluster_trigger = RedshiftDeleteClusterTrigger(
cluster_identifier=TEST_CLUSTER_IDENTIFIER,
poll_interval=TEST_POLL_INTERVAL,
max_attempts=TEST_MAX_ATTEMPT,
aws_conn_id=TEST_AWS_CONN_ID,
)

generator = redshift_delete_cluster_trigger.run()
response = await generator.asend(None)

assert a_mock.get_waiter().wait.call_count == 3
assert response == TriggerEvent({"status": "success", "message": "Cluster deleted."})

@pytest.mark.asyncio
@mock.patch("asyncio.sleep")
@mock.patch.object(RedshiftHook, "async_conn")
async def test_redshift_delete_cluster_trigger_run_attempts_exceeded(self, mock_async_conn, mock_sleep):
a_mock = mock.MagicMock()
mock_async_conn.__aenter__.return_value = a_mock

error = WaiterError(
name="test_name",
reason="test_reason",
last_response={"Clusters": [{"ClusterStatus": "deleting"}]},
)
a_mock.get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True])
mock_sleep.return_value = True

redshift_delete_cluster_trigger = RedshiftDeleteClusterTrigger(
cluster_identifier=TEST_CLUSTER_IDENTIFIER,
poll_interval=TEST_POLL_INTERVAL,
max_attempts=2,
aws_conn_id=TEST_AWS_CONN_ID,
)

generator = redshift_delete_cluster_trigger.run()
response = await generator.asend(None)

assert a_mock.get_waiter().wait.call_count == 2
assert response == TriggerEvent(
{"status": "failure", "message": "Delete Cluster Failed - max attempts reached."}
)

@pytest.mark.asyncio
@mock.patch("asyncio.sleep")
@mock.patch.object(RedshiftHook, "async_conn")
async def test_redshift_delete_cluster_trigger_run_attempts_failed(self, mock_async_conn, mock_sleep):
a_mock = mock.MagicMock()
mock_async_conn.__aenter__.return_value = a_mock

error_available = WaiterError(
name="test_name",
reason="Max attempts exceeded",
last_response={"Clusters": [{"ClusterStatus": "deleting"}]},
)
error_failed = WaiterError(
name="test_name",
reason="Waiter encountered a terminal failure state:",
last_response={"Clusters": [{"ClusterStatus": "available"}]},
)
a_mock.get_waiter().wait.side_effect = AsyncMock(
side_effect=[error_available, error_available, error_failed]
)
mock_sleep.return_value = True

redshift_delete_cluster_trigger = RedshiftDeleteClusterTrigger(
cluster_identifier=TEST_CLUSTER_IDENTIFIER,
poll_interval=TEST_POLL_INTERVAL,
max_attempts=TEST_MAX_ATTEMPT,
aws_conn_id=TEST_AWS_CONN_ID,
)

generator = redshift_delete_cluster_trigger.run()
response = await generator.asend(None)

assert a_mock.get_waiter().wait.call_count == 3
assert response == TriggerEvent(
{"status": "failure", "message": f"Delete Cluster Failed: {error_failed}"}
)
