Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test(ecs): add unit tests for drain hook lambda #27803

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions packages/aws-cdk-lib/aws-ecs/lib/drain-hook/lambda-source/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,33 @@

def lambda_handler(event, context):
print(json.dumps(dict(event, ResponseURL='...')))

cluster = os.environ['CLUSTER']
snsTopicArn = event['Records'][0]['Sns']['TopicArn']
lifecycle_event = json.loads(event['Records'][0]['Sns']['Message'])
instance_id = lifecycle_event.get('EC2InstanceId')
instance_id = lifecycle_event.get('EC2InstanceId', None)

if not instance_id:
print('Got event without EC2InstanceId: %s', json.dumps(dict(event, ResponseURL='...')))
print(f'Got event without EC2InstanceId: {json.dumps(dict(event, ResponseURL="..."))}')
return

instance_arn = container_instance_arn(cluster, instance_id)
print('Instance %s has container instance ARN %s' % (lifecycle_event['EC2InstanceId'], instance_arn))
print(f'Instance {lifecycle_event["EC2InstanceId"]} has container instance ARN {instance_arn}')

if not instance_arn:
return

task_arns = container_instance_task_arns(cluster, instance_arn)

if task_arns:
print('Instance ARN %s has task ARNs %s' % (instance_arn, ', '.join(task_arns)))
print(f'Instance ARN {instance_arn} has task ARNs {", ".join(task_arns)}')

while has_tasks(cluster, instance_arn, task_arns):
time.sleep(10)

complete_lifecycle_action(instance_id, lifecycle_event)


def complete_lifecycle_action(instance_id, lifecycle_event):
try:
print('Terminating instance %s' % instance_id)
autoscaling.complete_lifecycle_action(
Expand All @@ -40,19 +45,29 @@ def lambda_handler(event, context):

def container_instance_arn(cluster, instance_id):
"""Turn an instance ID into a container instance ARN."""
arns = ecs.list_container_instances(cluster=cluster, filter='ec2InstanceId==' + instance_id)['containerInstanceArns']
arns = list_container_instances(cluster, instance_id)
if not arns:
return None
return arns[0]


def list_container_instances(cluster, instance_id):
return ecs.list_container_instances(cluster=cluster, filter='ec2InstanceId==' + instance_id)['containerInstanceArns']


def container_instance_task_arns(cluster, instance_arn):
"""Fetch tasks for a container instance ARN."""
arns = ecs.list_tasks(cluster=cluster, containerInstance=instance_arn)['taskArns']
arns = list_tasks(cluster, instance_arn)
return arns


def list_tasks(cluster, instance_arn):
return ecs.list_tasks(cluster=cluster, containerInstance=instance_arn)['taskArns']


def has_tasks(cluster, instance_arn, task_arns):
"""Return True if the instance is running tasks for the given cluster."""
instances = ecs.describe_container_instances(cluster=cluster, containerInstances=[instance_arn])['containerInstances']
instances = describe_container_instances(cluster, instance_arn)
if not instances:
return False
instance = instances[0]
Expand All @@ -66,7 +81,7 @@ def has_tasks(cluster, instance_arn, task_arns):

if task_arns:
# Fetch details for tasks running on the container instance
tasks = ecs.describe_tasks(cluster=cluster, tasks=task_arns)['tasks']
tasks = describe_tasks(cluster, task_arns)
if tasks:
# Consider any non-stopped tasks as running
task_count = sum(task['lastStatus'] != 'STOPPED' for task in tasks) + instance['pendingTasksCount']
Expand All @@ -79,6 +94,15 @@ def has_tasks(cluster, instance_arn, task_arns):

return task_count > 0


def describe_container_instances(cluster, instance_arn):
return ecs.describe_container_instances(cluster=cluster, containerInstances=[instance_arn])['containerInstances']


def describe_tasks(cluster, task_arns):
return ecs.describe_tasks(cluster=cluster, tasks=task_arns)['tasks']


def set_container_instance_to_draining(cluster, instance_arn):
ecs.update_container_instances_state(
cluster=cluster,
Expand Down
9 changes: 9 additions & 0 deletions packages/aws-cdk-lib/aws-ecs/test/drain-hook/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
FROM public.ecr.aws/lambda/python:3.7

ADD . /opt/lambda
WORKDIR /opt/lambda

RUN pip3 install boto3==1.17.42
RUN python3 test_index.py

ENTRYPOINT [ "/bin/bash" ]
27 changes: 27 additions & 0 deletions packages/aws-cdk-lib/aws-ecs/test/drain-hook/test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/bin/bash
#---------------------------------------------------------------------------------------------------
# executes unit tests
#
# prepares a staging directory with the requirements
set -e
script_dir=$(cd $(dirname $0) && pwd)

# prepare staging directory
staging=$(mktemp -d)
mkdir -p ${staging}
cd ${staging}

# copy src and overlay with test
cp ${script_dir}/../../lib/drain-hook/lambda-source/index.py $PWD
cp ${script_dir}/test_index.py $PWD
cp ${script_dir}/Dockerfile $PWD

DRAIN_HOOK_TEST_NO_DOCKER=${DRAIN_HOOK_TEST_NO_DOCKER:-""}
DOCKER_CMD=${CDK_DOCKER:-docker}

if [ -z ${DRAIN_HOOK_TEST_NO_DOCKER} ]; then
# this will run our tests inside the right environment
$DOCKER_CMD build .
else
python3 test_index.py
fi
114 changes: 114 additions & 0 deletions packages/aws-cdk-lib/aws-ecs/test/drain-hook/test_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import unittest
import os
import sys
from unittest.mock import patch

os.environ["CLUSTER"] = "my-cluster"

try:
# this is available only if executed with ./test.sh
import index
except ModuleNotFoundError as _:
print(
"Unable to import index. Use ./test.sh to run these tests. "
+ 'If you want to avoid running them in docker, run "DRAIN_HOOK_TEST_NO_DOCKER=true ./test.sh"'
)
sys.exit(1)


def make_event():
records = []
records.append({'Sns': {'Message': '{"EC2InstanceId": "i-xxxxxx", "LifecycleHookName": "my-hook", "LifecycleActionToken": "my-token", "AutoScalingGroupName": "my-asg"}'}})
return {'Records': records}


def make_event_no_instance_id():
records = []
records.append({'Sns': {'Message': '{"food": "bar"}'}})
return {'Records': records}


class DrainHookTest(unittest.TestCase):
@patch("index.list_container_instances")
def test_no_instance_id(self, list):
event = make_event_no_instance_id()
index.lambda_handler(event, {})
list.assert_not_called()

@patch("index.complete_lifecycle_action")
@patch("index.list_tasks")
@patch("index.list_container_instances")
def test_no_instance_arn(self, list, tasks, complete):
event = make_event()

list.return_value = []
index.lambda_handler(event, {})

list.assert_called_once_with(
os.environ["CLUSTER"],
"i-xxxxxx",
)
tasks.assert_not_called()
complete.assert_not_called()

@patch("index.complete_lifecycle_action")
@patch("index.describe_container_instances")
@patch("index.list_tasks")
@patch("index.list_container_instances")
def test_no_list_tasks_no_container_instances(self, list, tasks, describe, complete):
event = make_event()

list.return_value = ['some-container-instance-arn']
tasks.return_value = []
describe.return_value = []
index.lambda_handler(event, {})

list.assert_called_once_with(
os.environ["CLUSTER"],
"i-xxxxxx",
)
tasks.assert_called_once_with(
os.environ["CLUSTER"],
'some-container-instance-arn',
)
describe.assert_called_once_with(
os.environ["CLUSTER"],
'some-container-instance-arn',
)
complete.assert_called_once()

@patch("index.complete_lifecycle_action")
@patch("index.describe_tasks")
@patch("index.describe_container_instances")
@patch("index.list_tasks")
@patch("index.list_container_instances")
def test_has_list_tasks_no_describe_tasks(self, list, tasks, describe, describe_tasks, complete):
event = make_event()

list.return_value = ['some-container-instance-arn']
tasks.return_value = ['task-arn']
describe.return_value = [{'id': 'i-xxxx', 'status': 'TERMINATED', 'runningTasksCount': 0, 'pendingTasksCount': 0}]
describe_tasks.return_value = []
index.lambda_handler(event, {})

list.assert_called_once_with(
os.environ["CLUSTER"],
"i-xxxxxx",
)
tasks.assert_called_once_with(
os.environ["CLUSTER"],
'some-container-instance-arn',
)
describe.assert_called_once_with(
os.environ["CLUSTER"],
'some-container-instance-arn',
)
describe_tasks.assert_called_once_with(
os.environ["CLUSTER"],
tasks.return_value,
)
complete.assert_called_once()


if __name__ == "__main__":
unittest.main()