Skip to content

Commit

Permalink
Add EC2CreateInstanceOperator, EC2TerminateInstanceOperator (#29548)
Browse files Browse the repository at this point in the history
* Add EC2CreateInstanceOperator and EC2TerminteInstanceOperator
Change system test to use the new operators
Add unit tests for new operators

* Add support for multiple ids to EC2TerminateInstanceOperator
Change system test to terminate without stopping instances

* Fix failing tests for terminate operator

* Update doc strings to add that the operators can create/terminate multiple instances
Add tests for creating/terminating multiple instances

* Fix system test so it passes
Fix doc string on EC2TerminateInstanceOperator

---------

Co-authored-by: syedahsn <syedahsn@ud74d1a752d7e5b.ant.amazon.com>
  • Loading branch information
syedahsn and syedahsn committed Mar 7, 2023
1 parent b069df9 commit d2cc9df
Show file tree
Hide file tree
Showing 4 changed files with 345 additions and 55 deletions.
138 changes: 138 additions & 0 deletions airflow/providers/amazon/aws/operators/ec2.py
Expand Up @@ -116,3 +116,141 @@ def execute(self, context: Context):
target_state="stopped",
check_interval=self.check_interval,
)


class EC2CreateInstanceOperator(BaseOperator):
"""
Create and start a specified number of EC2 Instances using boto3
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:EC2CreateInstanceOperator`
:param image_id: ID of the AMI used to create the instance.
:param max_count: Maximum number of instances to launch. Defaults to 1.
:param min_count: Minimum number of instances to launch. Defaults to 1.
:param aws_conn_id: AWS connection to use
:param region_name: AWS region name associated with the client.
:param poll_interval: Number of seconds to wait before attempting to
check state of instance. Only used if wait_for_completion is True. Default is 20.
:param max_attempts: Maximum number of attempts when checking state of instance.
Only used if wait_for_completion is True. Default is 20.
:param config: Dictionary for arbitrary parameters to the boto3 run_instances call.
:param wait_for_completion: If True, the operator will wait for the instance to be
in the `running` state before returning.
"""

template_fields: Sequence[str] = (
"image_id",
"max_count",
"min_count",
"aws_conn_id",
"region_name",
"config",
"wait_for_completion",
)

def __init__(
self,
image_id: str,
max_count: int = 1,
min_count: int = 1,
aws_conn_id: str = "aws_default",
region_name: str | None = None,
poll_interval: int = 20,
max_attempts: int = 20,
config: dict | None = None,
wait_for_completion: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.image_id = image_id
self.max_count = max_count
self.min_count = min_count
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.poll_interval = poll_interval
self.max_attempts = max_attempts
self.config = config or {}
self.wait_for_completion = wait_for_completion

def execute(self, context: Context):
ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type")
instances = ec2_hook.conn.run_instances(
ImageId=self.image_id,
MinCount=self.min_count,
MaxCount=self.max_count,
**self.config,
)["Instances"]
instance_ids = []
for instance in instances:
instance_ids.append(instance["InstanceId"])
self.log.info("Created EC2 instance %s", instance["InstanceId"])

if self.wait_for_completion:
ec2_hook.get_waiter("instance_running").wait(
InstanceIds=[instance["InstanceId"]],
WaiterConfig={
"Delay": self.poll_interval,
"MaxAttempts": self.max_attempts,
},
)

return instance_ids


class EC2TerminateInstanceOperator(BaseOperator):
"""
Terminate EC2 Instances using boto3
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:EC2TerminateInstanceOperator`
:param instance_id: ID of the instance to be terminated.
:param aws_conn_id: AWS connection to use
:param region_name: AWS region name associated with the client.
:param poll_interval: Number of seconds to wait before attempting to
check state of instance. Only used if wait_for_completion is True. Default is 20.
:param max_attempts: Maximum number of attempts when checking state of instance.
Only used if wait_for_completion is True. Default is 20.
:param wait_for_completion: If True, the operator will wait for the instance to be
in the `terminated` state before returning.
"""

template_fields: Sequence[str] = ("instance_ids", "region_name", "aws_conn_id", "wait_for_completion")

def __init__(
self,
instance_ids: str | list[str],
aws_conn_id: str = "aws_default",
region_name: str | None = None,
poll_interval: int = 20,
max_attempts: int = 20,
wait_for_completion: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.instance_ids = instance_ids
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.poll_interval = poll_interval
self.max_attempts = max_attempts
self.wait_for_completion = wait_for_completion

def execute(self, context: Context):
if isinstance(self.instance_ids, str):
self.instance_ids = [self.instance_ids]
ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type")
ec2_hook.conn.terminate_instances(InstanceIds=self.instance_ids)

for instance_id in self.instance_ids:
self.log.info("Terminating EC2 instance %s", instance_id)
if self.wait_for_completion:
ec2_hook.get_waiter("instance_terminated").wait(
InstanceIds=[instance_id],
WaiterConfig={
"Delay": self.poll_interval,
"MaxAttempts": self.max_attempts,
},
)
28 changes: 28 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/ec2.rst
Expand Up @@ -58,6 +58,34 @@ To stop an Amazon EC2 instance you can use
:start-after: [START howto_operator_ec2_stop_instance]
:end-before: [END howto_operator_ec2_stop_instance]

.. _howto/operator:EC2CreateInstanceOperator:

Create and start an Amazon EC2 instance
=======================================

To create and start an Amazon EC2 instance you can use
:class:`~airflow.providers.amazon.aws.operators.ec2.EC2CreateInstanceOperator`.

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_ec2.py
:language: python
:dedent: 4
:start-after: [START howto_operator_ec2_create_instance]
:end-before: [END howto_operator_ec2_create_instance]

.. _howto/operator:EC2TerminateInstanceOperator:

Terminate an Amazon EC2 instance
================================

To terminate an Amazon EC2 instance you can use
:class:`~airflow.providers.amazon.aws.operators.ec2.EC2TerminateInstanceOperator`.

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_ec2.py
:language: python
:dedent: 4
:start-after: [START howto_operator_ec2_terminate_instance]
:end-before: [END howto_operator_ec2_terminate_instance]

Sensors
-------

Expand Down
130 changes: 118 additions & 12 deletions tests/providers/amazon/aws/operators/test_ec2.py
Expand Up @@ -20,23 +20,121 @@
from moto import mock_ec2

from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
from airflow.providers.amazon.aws.operators.ec2 import EC2StartInstanceOperator, EC2StopInstanceOperator
from airflow.providers.amazon.aws.operators.ec2 import (
EC2CreateInstanceOperator,
EC2StartInstanceOperator,
EC2StopInstanceOperator,
EC2TerminateInstanceOperator,
)


class BaseEc2TestClass:
@classmethod
def _create_instance(cls, hook: EC2Hook):
"""Create Instance and return instance id."""
def _get_image_id(cls, hook):
"""Get a valid image id to create an instance."""
conn = hook.get_conn()
try:
ec2_client = conn.meta.client
except AttributeError:
ec2_client = conn

# We need existed AMI Image ID otherwise `moto` will raise DeprecationWarning.
# We need an existing AMI Image ID otherwise `moto` will raise DeprecationWarning.
images = ec2_client.describe_images()["Images"]
response = ec2_client.run_instances(MaxCount=1, MinCount=1, ImageId=images[0]["ImageId"])
return response["Instances"][0]["InstanceId"]
return images[0]["ImageId"]


class TestEC2CreateInstanceOperator(BaseEc2TestClass):
def test_init(self):
ec2_operator = EC2CreateInstanceOperator(
task_id="test_create_instance",
image_id="test_image_id",
)

assert ec2_operator.task_id == "test_create_instance"
assert ec2_operator.image_id == "test_image_id"
assert ec2_operator.max_count == 1
assert ec2_operator.min_count == 1
assert ec2_operator.max_attempts == 20
assert ec2_operator.poll_interval == 20

@mock_ec2
def test_create_instance(self):
ec2_hook = EC2Hook()
create_instance = EC2CreateInstanceOperator(
image_id=self._get_image_id(ec2_hook),
task_id="test_create_instance",
)
instance_id = create_instance.execute(None)

assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "running"

@mock_ec2
def test_create_multiple_instances(self):
ec2_hook = EC2Hook()
create_instances = EC2CreateInstanceOperator(
task_id="test_create_multiple_instances",
image_id=self._get_image_id(hook=ec2_hook),
min_count=5,
max_count=5,
)
instance_ids = create_instances.execute(None)
assert len(instance_ids) == 5

for id in instance_ids:
assert ec2_hook.get_instance_state(instance_id=id) == "running"


class TestEC2TerminateInstanceOperator(BaseEc2TestClass):
def test_init(self):
ec2_operator = EC2TerminateInstanceOperator(
task_id="test_terminate_instance",
instance_ids="test_image_id",
)

assert ec2_operator.task_id == "test_terminate_instance"
assert ec2_operator.max_attempts == 20
assert ec2_operator.poll_interval == 20

@mock_ec2
def test_terminate_instance(self):
ec2_hook = EC2Hook()

create_instance = EC2CreateInstanceOperator(
image_id=self._get_image_id(ec2_hook),
task_id="test_create_instance",
)
instance_id = create_instance.execute(None)

assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "running"

terminate_instance = EC2TerminateInstanceOperator(
task_id="test_terminate_instance", instance_ids=instance_id
)
terminate_instance.execute(None)

assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "terminated"

@mock_ec2
def test_terminate_multiple_instances(self):
ec2_hook = EC2Hook()
create_instances = EC2CreateInstanceOperator(
task_id="test_create_multiple_instances",
image_id=self._get_image_id(hook=ec2_hook),
min_count=5,
max_count=5,
)
instance_ids = create_instances.execute(None)
assert len(instance_ids) == 5

for id in instance_ids:
assert ec2_hook.get_instance_state(instance_id=id) == "running"

terminate_instance = EC2TerminateInstanceOperator(
task_id="test_terminate_instance", instance_ids=instance_ids
)
terminate_instance.execute(None)
for id in instance_ids:
assert ec2_hook.get_instance_state(instance_id=id) == "terminated"


class TestEC2StartInstanceOperator(BaseEc2TestClass):
Expand All @@ -58,16 +156,20 @@ def test_init(self):
def test_start_instance(self):
# create instance
ec2_hook = EC2Hook()
instance_id = self._create_instance(ec2_hook)
create_instance = EC2CreateInstanceOperator(
image_id=self._get_image_id(ec2_hook),
task_id="test_create_instance",
)
instance_id = create_instance.execute(None)

# start instance
start_test = EC2StartInstanceOperator(
task_id="start_test",
instance_id=instance_id,
instance_id=instance_id[0],
)
start_test.execute(None)
# assert instance state is running
assert ec2_hook.get_instance_state(instance_id=instance_id) == "running"
assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "running"


class TestEC2StopInstanceOperator(BaseEc2TestClass):
Expand All @@ -89,13 +191,17 @@ def test_init(self):
def test_stop_instance(self):
# create instance
ec2_hook = EC2Hook()
instance_id = self._create_instance(ec2_hook)
create_instance = EC2CreateInstanceOperator(
image_id=self._get_image_id(ec2_hook),
task_id="test_create_instance",
)
instance_id = create_instance.execute(None)

# stop instance
stop_test = EC2StopInstanceOperator(
task_id="stop_test",
instance_id=instance_id,
instance_id=instance_id[0],
)
stop_test.execute(None)
# assert instance state is running
assert ec2_hook.get_instance_state(instance_id=instance_id) == "stopped"
assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "stopped"

0 comments on commit d2cc9df

Please sign in to comment.