Skip to content

Commit

Permalink
Add EC2HibernateInstanceOperator and EC2RebootInstanceOperator (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
dashton90 committed Nov 23, 2023
1 parent ca97fee commit ca1202f
Show file tree
Hide file tree
Showing 4 changed files with 340 additions and 0 deletions.
126 changes: 126 additions & 0 deletions airflow/providers/amazon/aws/operators/ec2.py
Expand Up @@ -19,6 +19,7 @@

from typing import TYPE_CHECKING, Sequence

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook

Expand Down Expand Up @@ -254,3 +255,128 @@ def execute(self, context: Context):
"MaxAttempts": self.max_attempts,
},
)


class EC2RebootInstanceOperator(BaseOperator):
"""
Reboot Amazon EC2 instances.
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:EC2RebootInstanceOperator`
:param instance_ids: ID of the instance(s) to be rebooted.
:param aws_conn_id: AWS connection to use
:param region_name: AWS region name associated with the client.
:param poll_interval: Number of seconds to wait before attempting to
check state of instance. Only used if wait_for_completion is True. Default is 20.
:param max_attempts: Maximum number of attempts when checking state of instance.
Only used if wait_for_completion is True. Default is 20.
:param wait_for_completion: If True, the operator will wait for the instance to be
in the `running` state before returning.
"""

template_fields: Sequence[str] = ("instance_ids", "region_name")
ui_color = "#eeaa11"
ui_fgcolor = "#ffffff"

def __init__(
self,
*,
instance_ids: str | list[str],
aws_conn_id: str = "aws_default",
region_name: str | None = None,
poll_interval: int = 20,
max_attempts: int = 20,
wait_for_completion: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.instance_ids = instance_ids
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.poll_interval = poll_interval
self.max_attempts = max_attempts
self.wait_for_completion = wait_for_completion

def execute(self, context: Context):
if isinstance(self.instance_ids, str):
self.instance_ids = [self.instance_ids]
ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type")
self.log.info("Rebooting EC2 instances %s", ", ".join(self.instance_ids))
ec2_hook.conn.reboot_instances(InstanceIds=self.instance_ids)

if self.wait_for_completion:
ec2_hook.get_waiter("instance_running").wait(
InstanceIds=self.instance_ids,
WaiterConfig={
"Delay": self.poll_interval,
"MaxAttempts": self.max_attempts,
},
)


class EC2HibernateInstanceOperator(BaseOperator):
"""
Hibernate Amazon EC2 instances.
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:EC2HibernateInstanceOperator`
:param instance_ids: ID of the instance(s) to be hibernated.
:param aws_conn_id: AWS connection to use
:param region_name: AWS region name associated with the client.
:param poll_interval: Number of seconds to wait before attempting to
check state of instance. Only used if wait_for_completion is True. Default is 20.
:param max_attempts: Maximum number of attempts when checking state of instance.
Only used if wait_for_completion is True. Default is 20.
:param wait_for_completion: If True, the operator will wait for the instance to be
in the `stopped` state before returning.
"""

template_fields: Sequence[str] = ("instance_ids", "region_name")
ui_color = "#eeaa11"
ui_fgcolor = "#ffffff"

def __init__(
self,
*,
instance_ids: str | list[str],
aws_conn_id: str = "aws_default",
region_name: str | None = None,
poll_interval: int = 20,
max_attempts: int = 20,
wait_for_completion: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.instance_ids = instance_ids
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.poll_interval = poll_interval
self.max_attempts = max_attempts
self.wait_for_completion = wait_for_completion

def execute(self, context: Context):
if isinstance(self.instance_ids, str):
self.instance_ids = [self.instance_ids]
ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type")
self.log.info("Hibernating EC2 instances %s", ", ".join(self.instance_ids))
instances = ec2_hook.get_instances(instance_ids=self.instance_ids)

for instance in instances:
hibernation_options = instance.get("HibernationOptions")
if not hibernation_options or not hibernation_options["Configured"]:
raise AirflowException(f"Instance {instance['InstanceId']} is not configured for hibernation")

ec2_hook.conn.stop_instances(InstanceIds=self.instance_ids, Hibernate=True)

if self.wait_for_completion:
ec2_hook.get_waiter("instance_stopped").wait(
InstanceIds=self.instance_ids,
WaiterConfig={
"Delay": self.poll_interval,
"MaxAttempts": self.max_attempts,
},
)
28 changes: 28 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/ec2.rst
Expand Up @@ -86,6 +86,34 @@ To terminate an Amazon EC2 instance you can use
:start-after: [START howto_operator_ec2_terminate_instance]
:end-before: [END howto_operator_ec2_terminate_instance]

.. _howto/operator:EC2RebootInstanceOperator:

Reboot an Amazon EC2 instance
================================

To reboot an Amazon EC2 instance you can use
:class:`~airflow.providers.amazon.aws.operators.ec2.EC2RebootInstanceOperator`.

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_ec2.py
:language: python
:dedent: 4
:start-after: [START howto_operator_ec2_reboot_instance]
:end-before: [END howto_operator_ec2_reboot_instance]

.. _howto/operator:EC2HibernateInstanceOperator:

Hibernate an Amazon EC2 instance
================================

To hibernate an Amazon EC2 instance you can use
:class:`~airflow.providers.amazon.aws.operators.ec2.EC2HibernateInstanceOperator`.

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_ec2.py
:language: python
:dedent: 4
:start-after: [START howto_operator_ec2_hibernate_instance]
:end-before: [END howto_operator_ec2_hibernate_instance]

Sensors
-------

Expand Down
167 changes: 167 additions & 0 deletions tests/providers/amazon/aws/operators/test_ec2.py
Expand Up @@ -17,11 +17,15 @@
# under the License.
from __future__ import annotations

import pytest
from moto import mock_ec2

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
from airflow.providers.amazon.aws.operators.ec2 import (
EC2CreateInstanceOperator,
EC2HibernateInstanceOperator,
EC2RebootInstanceOperator,
EC2StartInstanceOperator,
EC2StopInstanceOperator,
EC2TerminateInstanceOperator,
Expand Down Expand Up @@ -205,3 +209,166 @@ def test_stop_instance(self):
stop_test.execute(None)
# assert instance state is running
assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "stopped"


class TestEC2HibernateInstanceOperator(BaseEc2TestClass):
def test_init(self):
ec2_operator = EC2HibernateInstanceOperator(
task_id="task_test",
instance_ids="i-123abc",
)
assert ec2_operator.task_id == "task_test"
assert ec2_operator.instance_ids == "i-123abc"

@mock_ec2
def test_hibernate_instance(self):
# create instance
ec2_hook = EC2Hook()
create_instance = EC2CreateInstanceOperator(
image_id=self._get_image_id(ec2_hook),
task_id="test_create_instance",
config={"HibernationOptions": {"Configured": True}},
)
instance_id = create_instance.execute(None)

# hibernate instance
hibernate_test = EC2HibernateInstanceOperator(
task_id="hibernate_test",
instance_ids=instance_id[0],
)
hibernate_test.execute(None)
# assert instance state is stopped
assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "stopped"

@mock_ec2
def test_hibernate_multiple_instances(self):
ec2_hook = EC2Hook()
create_instances = EC2CreateInstanceOperator(
task_id="test_create_multiple_instances",
image_id=self._get_image_id(hook=ec2_hook),
config={"HibernationOptions": {"Configured": True}},
min_count=5,
max_count=5,
)
instance_ids = create_instances.execute(None)
assert len(instance_ids) == 5

for id in instance_ids:
assert ec2_hook.get_instance_state(instance_id=id) == "running"

hibernate_instance = EC2HibernateInstanceOperator(
task_id="test_hibernate_instance", instance_ids=instance_ids
)
hibernate_instance.execute(None)
for id in instance_ids:
assert ec2_hook.get_instance_state(instance_id=id) == "stopped"

@mock_ec2
def test_cannot_hibernate_instance(self):
# create instance
ec2_hook = EC2Hook()
create_instance = EC2CreateInstanceOperator(
image_id=self._get_image_id(ec2_hook),
task_id="test_create_instance",
)
instance_id = create_instance.execute(None)

# hibernate instance
hibernate_test = EC2HibernateInstanceOperator(
task_id="hibernate_test",
instance_ids=instance_id[0],
)

# assert hibernating an instance not configured for hibernation raises an error
with pytest.raises(
AirflowException,
match="Instance .* is not configured for hibernation",
):
hibernate_test.execute(None)

# assert instance state is running
assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "running"

@mock_ec2
def test_cannot_hibernate_some_instances(self):
# create instance
ec2_hook = EC2Hook()
create_instance_hibernate = EC2CreateInstanceOperator(
image_id=self._get_image_id(ec2_hook),
task_id="test_create_instance",
config={"HibernationOptions": {"Configured": True}},
)
instance_id_hibernate = create_instance_hibernate.execute(None)
create_instance_cannot_hibernate = EC2CreateInstanceOperator(
image_id=self._get_image_id(ec2_hook),
task_id="test_create_instance",
)
instance_id_cannot_hibernate = create_instance_cannot_hibernate.execute(None)
instance_ids = [instance_id_hibernate[0], instance_id_cannot_hibernate[0]]

# hibernate instance
hibernate_test = EC2HibernateInstanceOperator(
task_id="hibernate_test",
instance_ids=instance_ids,
)
# assert hibernating an instance not configured for hibernation raises an error
with pytest.raises(
AirflowException,
match="Instance .* is not configured for hibernation",
):
hibernate_test.execute(None)

# assert instance state is running
for id in instance_ids:
assert ec2_hook.get_instance_state(instance_id=id) == "running"


class TestEC2RebootInstanceOperator(BaseEc2TestClass):
def test_init(self):
ec2_operator = EC2RebootInstanceOperator(
task_id="task_test",
instance_ids="i-123abc",
)
assert ec2_operator.task_id == "task_test"
assert ec2_operator.instance_ids == "i-123abc"

@mock_ec2
def test_reboot_instance(self):
# create instance
ec2_hook = EC2Hook()
create_instance = EC2CreateInstanceOperator(
image_id=self._get_image_id(ec2_hook),
task_id="test_create_instance",
)
instance_id = create_instance.execute(None)

# reboot instance
reboot_test = EC2RebootInstanceOperator(
task_id="reboot_test",
instance_ids=instance_id[0],
)
reboot_test.execute(None)
# assert instance state is running
assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "running"

@mock_ec2
def test_reboot_multiple_instances(self):
ec2_hook = EC2Hook()
create_instances = EC2CreateInstanceOperator(
task_id="test_create_multiple_instances",
image_id=self._get_image_id(hook=ec2_hook),
min_count=5,
max_count=5,
)
instance_ids = create_instances.execute(None)
assert len(instance_ids) == 5

for id in instance_ids:
assert ec2_hook.get_instance_state(instance_id=id) == "running"

terminate_instance = EC2RebootInstanceOperator(
task_id="test_reboot_instance", instance_ids=instance_ids
)
terminate_instance.execute(None)
for id in instance_ids:
assert ec2_hook.get_instance_state(instance_id=id) == "running"

0 comments on commit ca1202f

Please sign in to comment.