Skip to content

Commit

Permalink
Fix sample expunge VPC, if-len, and process deployment maps (#716)
Browse files Browse the repository at this point in the history
**Why?**

Amazon CodeGuru suggested improvements to:

* Use paginators in the sample-expunge-vpc code, after reviewing, this sample
  required some more attention.
* Change `if len(...) > or ==` statements to match against the value. A list
  with elements is True, an empty list is False. Improving readability.
* Move boto3 client and resource creation to the initialization phase of the
  process deployment map Lambda function, to speed up its execution in
  consecutive runs.
  • Loading branch information
sbkok committed Apr 19, 2024
1 parent 4ebf49b commit 325c960
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 97 deletions.
198 changes: 114 additions & 84 deletions samples/sample-expunge-vpc/src/lambda_vpc/lambda_function.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# Copyright Amazon.com Inc. or its affiliates.
# SPDX-License-Identifier: MIT-0

from crhelper import CfnResource
import logging
import boto3
from os import environ
import logging
import hashlib

from crhelper import CfnResource
import boto3

logger = logging.getLogger(__name__)

helper = CfnResource(
Expand All @@ -15,7 +16,7 @@
boto_level='CRITICAL'
)

region_name = environ['region_name']
REGION_NAME = environ['region_name']


def generate_dummy_resource_id(event):
Expand All @@ -27,7 +28,10 @@ def generate_dummy_resource_id(event):

def create_ec2_client(region_name, **kwargs):
if 'profile' in kwargs:
logger.info("Creating Boto3 EC2 Client using profile: {}".format(kwargs['profile']))
logger.info(
"Creating Boto3 EC2 Client using profile: %s",
kwargs['profile'],
)
session = boto3.Session(profile_name=kwargs['profile'])
client = session.client('ec2', region_name=region_name)
else:
Expand All @@ -38,110 +42,130 @@ def create_ec2_client(region_name, **kwargs):

def delete_subnets(client, vpc_id):
logger.info("Getting subnets for VPC")
subnet = client.describe_subnets(
subnet_paginator = client.get_paginator('describe_subnets')
subnet_pages = subnet_paginator.paginate(
Filters=[
{
'Name': 'vpc-id',
'Values': [
vpc_id
]
}
]
vpc_id,
],
},
],
)
logger.info(f"{len(subnet['Subnets'])} Subnets found")
for s in subnet['Subnets']:
logger.info(f"Deleting subnet with ID: {s['SubnetId']}")
client.delete_subnet(
SubnetId=s['SubnetId']
)
for subnet in subnet_pages:
logger.info("%d Subnets found", len(subnet['Subnets']))
for s in subnet['Subnets']:
logger.info("Deleting subnet with ID: %s", s['SubnetId'])
client.delete_subnet(
SubnetId=s['SubnetId'],
)


def delete_internet_gateway(client, vpc_id):
logger.info(f"Getting Internet Gateways attached to {vpc_id}")
igw = client.describe_internet_gateways(
logger.info("Getting Internet Gateways attached to %s", vpc_id)
igw_paginator = client.get_paginator('describe_internet_gateways')
igw_pages = igw_paginator.paginate(
Filters=[
{
'Name': 'attachment.vpc-id',
'Values': [
vpc_id,
]
],
},
]
],
)
logger.info(f"{len(igw['InternetGateways'])} Gateways found")
for gw in igw['InternetGateways']:
logger.info(f"Detaching internet gateway: {gw['InternetGatewayId']} from VPC")
client.detach_internet_gateway(
InternetGatewayId=gw['InternetGatewayId'],
VpcId=vpc_id
)
logger.info(f"Deleting internet gateway: {gw['InternetGatewayId']}")
client.delete_internet_gateway(
InternetGatewayId=gw['InternetGatewayId']
)
for page in igw_pages:
logger.info("%d Gateways found", len(page['InternetGateways']))
for gw in page['InternetGateways']:
logger.info(
"Detaching internet gateway: %s from VPC",
gw['InternetGatewayId'],
)
client.detach_internet_gateway(
InternetGatewayId=gw['InternetGatewayId'],
VpcId=vpc_id,
)
logger.info(
"Deleting internet gateway: %s",
gw['InternetGatewayId'],
)
client.delete_internet_gateway(
InternetGatewayId=gw['InternetGatewayId'],
)


def delete_route_tables(client, vpc_id):
logger.info(f"Getting Route Tables attached to {vpc_id}")
route_tables = client.describe_route_tables(
logger.info("Getting Route Tables attached to %s", vpc_id)
route_table_paginator = client.get_paginator('describe_route_tables')
route_table_pages = route_table_paginator.paginate(
Filters=[
{
'Name': 'vpc-id',
'Values': [
vpc_id
]
vpc_id,
],
},
]
],
)
logger.info(f"{len(route_tables['RouteTables'])} Route Tables found")
for route_table in route_tables['RouteTables']:
for route in route_table['Routes']:
if route['GatewayId'] != 'local':
client.delete_route(
DestinationCidrBlock=route['DestinationCidrBlock'],
RouteTableId=route_table['RouteTableId']
)
for route_tables in route_table_pages:
logger.info("%d Route Tables found", len(route_tables['RouteTables']))
for route_table in route_tables['RouteTables']:
for route in route_table['Routes']:
if route['GatewayId'] != 'local':
client.delete_route(
DestinationCidrBlock=route['DestinationCidrBlock'],
RouteTableId=route_table['RouteTableId'],
)


def delete_security_groups(client, vpc_id):
logger.info(f"Getting Security Groups attached to {vpc_id}")
groups = client.describe_security_groups(
logger.info("Getting Security Groups attached to %s", vpc_id)
security_group_paginator = client.get_paginator('describe_security_groups')
security_group_pages = security_group_paginator.paginate(
Filters=[
{
'Name': 'vpc-id',
'Values': [
vpc_id
]
vpc_id,
],
},
])
logger.info(f"{len(groups['SecurityGroups'])} Security Groups found")
for group in groups['SecurityGroups']:
if group['GroupName'] != "default":
logger.info(f"Deleting non default group: {group['GroupName']}")
client.delete_security_group(
GroupId=group['GroupId'],
GroupName=group['GroupName']
)
],
)
for groups in security_group_pages:
logger.info("%d Security Groups found", len(groups['SecurityGroups']))
for group in groups['SecurityGroups']:
if group['GroupName'] != "default":
logger.info(
"Deleting non-default group: %s",
group['GroupName'],
)
client.delete_security_group(
GroupId=group['GroupId'],
GroupName=group['GroupName'],
)


def remove_default_vpc(client):
vpcs = client.describe_vpcs()
for vpc in vpcs['Vpcs']:
if vpc['IsDefault']:
logger.info(f"Default VPC found. VPC ID: {vpc['VpcId']}")
vpc_paginator = client.get_paginator('describe_vpcs')
vpc_pages = vpc_paginator.paginate()
for vpcs in vpc_pages:
for vpc in vpcs['Vpcs']:
if vpc['IsDefault']:
logger.info("Default VPC found. VPC ID: %s", vpc['VpcId'])

delete_subnets(client, vpc['VpcId'])
delete_subnets(client, vpc['VpcId'])

delete_internet_gateway(client, vpc['VpcId'])
delete_internet_gateway(client, vpc['VpcId'])

delete_route_tables(client, vpc['VpcId'])
delete_route_tables(client, vpc['VpcId'])

delete_security_groups(client, vpc['VpcId'])
delete_security_groups(client, vpc['VpcId'])

logger.info(f"Deleting VPC: {vpc['VpcId']}")
client.delete_vpc(
VpcId=vpc['VpcId']
)
logger.info("Deleting VPC: %s", vpc['VpcId'])
client.delete_vpc(
VpcId=vpc['VpcId'],
)


def get_regions(client):
Expand All @@ -153,50 +177,56 @@ def get_regions(client):
def create(event, context):
logger.info("Stack creation therefore default VPCs are to be removed")

client = create_ec2_client(region_name)
client = create_ec2_client(REGION_NAME)
regions = get_regions(client)
for region in regions['Regions']:
logger.info(f"Creating ec2 client in {region['RegionName']} region")
logger.info("Creating ec2 client in %s region", region['RegionName'])
logger.info(
"Calling 'remove_default_vpc' function to remove default VPC and associated resources within the region")
"Calling 'remove_default_vpc' function to remove default VPC and "
"associated resources within the region",
)
ec2_client = create_ec2_client(region_name=region['RegionName'])
remove_default_vpc(ec2_client)
logger.info('~' * 72)

# Items stored in helper.Data will be saved
# as outputs in your resource in CloudFormation
helper.Data.update({})
return generate_dummy_resource_id(event) # This is the Physical resource of your ID
# This is the Physical resource of your ID:
return generate_dummy_resource_id(event)


@helper.update
def update(event, context):
logger.info("Stack update therefore no real changes required on resources")
return generate_dummy_resource_id(event) # This is the Physical resource of your ID
# This is the Physical resource of your ID:
return generate_dummy_resource_id(event)


@helper.delete
def delete(event, context):
logger.info("Stack deletion therefore default VPCs are to be recreated")

client = create_ec2_client(region_name)
client = create_ec2_client(REGION_NAME)
regions = get_regions(client)
for region in regions['Regions']:
logger.info(f"Creating ec2 client in {region['RegionName']} region")
logger.info("Creating ec2 client in %s region", region['RegionName'])
ec2_client = create_ec2_client(region_name=region['RegionName'])
vpcs = ec2_client.describe_vpcs()
if len(vpcs['Vpcs']) == 0:
vpc_paginator = ec2_client.get_paginator('describe_vpcs')
vpc_pages = vpc_paginator.paginate()
default_vpc_found = False
for vpcs in vpc_pages:
for vpc in vpcs['Vpcs']:
if vpc['IsDefault']:
default_vpc_found = True
if not default_vpc_found:
logger.info("Creating default VPC in region")
ec2_client.create_default_vpc()
else:
for vpc in vpcs['Vpcs']:
if not vpc['IsDefault']:
logger.info("Creating default VPC in region")
ec2_client.create_default_vpc()

logger.info('~' * 72)

return generate_dummy_resource_id(event) # This is the Physical resource of your ID
# This is the Physical resource of your ID
return generate_dummy_resource_id(event)


def lambda_handler(event, context):
Expand Down
1 change: 1 addition & 0 deletions samples/sample-expunge-vpc/src/lambda_vpc/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
crhelper
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def grant_access_to_s3_buckets(self, bucket_names):
'calling grant_s3_buckets_access for bucket_names %s',
bucket_names,
)
if len(bucket_names) == 0:
if not bucket_names:
return

statement = self._get_statement('S3')
Expand Down Expand Up @@ -112,7 +112,7 @@ def grant_access_to_kms_keys(self, kms_key_arns):
'calling grant_usage_on_kms_keys for key arns %s',
kms_key_arns,
)
if len(kms_key_arns) == 0:
if not kms_key_arns:
return

statement = self._get_statement('KMS')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def lambda_handler(event, _):
)

output = {}
if len(out_of_date_pipelines) > 0:
if out_of_date_pipelines:
output["pipelines_to_be_deleted"] = out_of_date_pipelines

data_md5 = hashlib.md5(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
ADF_VERSION = os.getenv("ADF_VERSION")
ADF_VERSION_METADATA_KEY = "adf_version"

S3_RESOURCE = boto3.resource("s3")
SFN_CLIENT = boto3.client("stepfunctions")


class DeploymentMapFileData(TypedDict):
"""
Expand Down Expand Up @@ -211,16 +214,14 @@ def lambda_handler(event, context):
dict: The input event is returned.
"""
output = event.copy()
s3_resource = boto3.resource("s3")
sfn_client = boto3.client("stepfunctions")
s3_details = get_details_from_event(event)
deployment_map = get_file_from_s3(s3_details, s3_resource)
deployment_map = get_file_from_s3(s3_details, S3_RESOURCE)
if deployment_map.get("content"):
deployment_map["content"]["definition_bucket"] = s3_details.get(
"object_key",
)
start_executions(
sfn_client,
SFN_CLIENT,
s3_details.get("object_key"),
deployment_map["content"],
codepipeline_execution_id=deployment_map.get("execution_id"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ def _delete_base_stacks(
)
if not matches_search:
continue
if len(stack.get('ParentId', '')) > 0:
if stack.get('ParentId', ''):
# Skip nested stacks
continue

Expand Down Expand Up @@ -635,7 +635,7 @@ def _get_stack_status(self, name):
response = self.client.describe_stacks(
StackName=name,
)
if response and len(response.get('Stacks', [])) > 0:
if response and response.get('Stacks', []):
return response['Stacks'][0]['StackStatus']
return None
except (ClientError, ValidationError) as error:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def _flatten_list(input_list):
result = []
for item in input_list:
if isinstance(item, list):
if len(item) > 0:
if item:
result.extend(
_flatten_list(item),
)
Expand Down
Loading

0 comments on commit 325c960

Please sign in to comment.