Commit

feat: add argument max retry and retry delay
valter-silva-au committed Aug 3, 2023
1 parent 35775ea commit 5a15d5b
Showing 1 changed file with 151 additions and 25 deletions.
scan.py: 176 changes (151 additions & 25 deletions)
@@ -1,32 +1,37 @@
 # -*- coding: utf-8 -*-
-import boto3
-import logging
-import json
+# Required modules
 import argparse
+import boto3
+import botocore
 import concurrent.futures
+import datetime
+import json
+import logging
 import os
-from datetime import datetime
-import traceback
-import botocore
+import time
+import traceback
+from datetime import datetime
 
-MAX_RETRIES = 3
-
-# Get the current timestamp
+# Define the timestamp as a string, which will be the same throughout the execution of the script.
 timestamp = datetime.now().isoformat(timespec="minutes")
 
 
 class DateTimeEncoder(json.JSONEncoder):
     """Custom JSONEncoder that supports encoding datetime objects."""
 
     def default(self, o):
         if isinstance(o, datetime):
             return o.isoformat()
         return super().default(o)
 
 
 def setup_logging(log_dir, log_level):
     """Set up the logging system."""
     os.makedirs(log_dir, exist_ok=True)
     log_filename = f"aws_resources_{timestamp}.log"
     log_file = os.path.join(log_dir, log_filename)
 
+    # Configure the logger
+    logger = logging.getLogger(__name__)
+    logger.setLevel(log_level)
+    handler = logging.FileHandler(log_file)
@@ -40,9 +45,16 @@ def setup_logging(log_dir, log_level):
     return logging.getLogger(__name__)
 
 
-def api_call_with_retry(client, function_name, parameters):
+def api_call_with_retry(client, function_name, parameters, max_retries, retry_delay):
     """
     Make an API call with exponential backoff.
+
+    This function will make an API call with retries. It will exponentially back off
+    with a delay of `retry_delay ** attempt` seconds for transient errors.
     """
 
     def api_call():
-        for attempt in range(MAX_RETRIES):
+        for attempt in range(max_retries):
             try:
                 function_to_call = getattr(client, function_name)
                 if parameters:
@@ -52,21 +64,39 @@ def api_call():
             except botocore.exceptions.ClientError as error:
                 error_code = error.response["Error"]["Code"]
                 if error_code == "Throttling":
-                    if attempt < (MAX_RETRIES - 1):  # no delay on last attempt
-                        time.sleep(2**attempt)
+                    if attempt < (max_retries - 1):  # no delay on last attempt
+                        time.sleep(retry_delay**attempt)
                     continue
+                elif error_code == "RequestLimitExceeded":
+                    time.sleep(retry_delay**attempt)
+                    continue
                 else:
                     raise
             except botocore.exceptions.BotoCoreError:
-                if attempt < (MAX_RETRIES - 1):  # no delay on last attempt
-                    time.sleep(2**attempt)
+                if attempt < (max_retries - 1):  # no delay on last attempt
+                    time.sleep(retry_delay**attempt)
                 continue
         return None
 
     return api_call
 
 
-def _get_service_data(session, region_name, service, log):
+def _get_service_data(session, region_name, service, log, max_retries, retry_delay):
+    """
+    Get data for a specific AWS service in a region.
+
+    Arguments:
+    session -- The boto3 Session.
+    region_name -- The AWS region to process.
+    service -- The AWS service to scan.
+    log -- The logger object.
+    max_retries -- The maximum number of retries for each service.
+    retry_delay -- The delay before each retry.
+
+    Returns:
+    service_data -- The service data.
+    """
+
     function = service["function"]
     result_key = service.get("result_key", None)
     parameters = service.get("parameters", None)
@@ -88,7 +118,9 @@ def _get_service_data(session, region_name, service, log):
             region_name,
         )
         return None
-    api_call = api_call_with_retry(client, function, parameters)
+    api_call = api_call_with_retry(
+        client, function, parameters, max_retries, retry_delay
+    )
     if result_key:
         response = api_call().get(result_key)
     else:
@@ -117,13 +149,41 @@ def _get_service_data(session, region_name, service, log):
     return {"region": region_name, "service": service["service"], "result": response}
 
 
-def process_region(region, services, session, log):
+def process_region(
+    region, services, session, log, max_retries, retry_delay, concurrent_services
+):
+    """
+    Processes a single AWS region.
+
+    Arguments:
+    region -- The AWS region to process.
+    services -- The AWS services to scan.
+    session -- The boto3 Session.
+    log -- The logger object.
+    max_retries -- The maximum number of retries for each service.
+    retry_delay -- The delay before each retry.
+    concurrent_services -- The number of services to process concurrently for each region.
+
+    Returns:
+    region_results -- The scan results for the region.
+    """
+
     log.info("Started processing for region: %s", region)
 
     region_results = []
-    with concurrent.futures.ThreadPoolExecutor() as executor:
+    with concurrent.futures.ThreadPoolExecutor(
+        max_workers=concurrent_services
+    ) as executor:
         future_to_service = {
-            executor.submit(_get_service_data, session, region, service, log): service
+            executor.submit(
+                _get_service_data,
+                session,
+                region,
+                service,
+                log,
+                max_retries,
+                retry_delay,
+            ): service
             for service in services
         }
         for future in concurrent.futures.as_completed(future_to_service):
@@ -150,8 +210,29 @@ def display_time(seconds):
     return f"{int(hours)}h:{int(minutes)}m:{int(seconds)}s"
 
 
-def main(scan, regions, output_dir, log_level):
-    import time
+def main(
+    scan,
+    regions,
+    output_dir,
+    log_level,
+    max_retries,
+    retry_delay,
+    concurrent_regions,
+    concurrent_services,
+):
+    """
+    Main function to perform the AWS services scan.
+
+    Arguments:
+    scan -- The path to the JSON file containing the AWS services to scan.
+    regions -- The AWS regions to scan.
+    output_dir -- The directory to store the results.
+    log_level -- The log level for the script.
+    max_retries -- The maximum number of retries for each service.
+    retry_delay -- The delay before each retry.
+    concurrent_regions -- The number of regions to process concurrently.
+    concurrent_services -- The number of services to process concurrently for each region.
+    """
 
     session = boto3.Session()
     log = setup_logging(output_dir, log_level)
@@ -170,9 +251,20 @@ def main(scan, regions, output_dir, log_level):
     start_time = time.time()
 
     results = []
-    with concurrent.futures.ThreadPoolExecutor() as executor:
+    with concurrent.futures.ThreadPoolExecutor(
+        max_workers=concurrent_regions
+    ) as executor:
         future_to_region = {
-            executor.submit(process_region, region, services, session, log): region
+            executor.submit(
+                process_region,
+                region,
+                services,
+                session,
+                log,
+                max_retries,
+                retry_delay,
+                concurrent_services,
+            ): region
             for region in regions
         }
        for future in concurrent.futures.as_completed(future_to_region):
@@ -219,5 +311,39 @@ def main(scan, regions, output_dir, log_level):
         default="INFO",
         help="Set the logging level (e.g., DEBUG, INFO, WARNING, ERROR, CRITICAL)",
     )
+    # New arguments
+    parser.add_argument(
+        "--max-retries",
+        type=int,
+        default=3,
+        help="Maximum number of retries for each service",
+    )
+    parser.add_argument(
+        "--retry-delay",
+        type=int,
+        default=2,
+        help="Delay (in seconds) before each retry",
+    )
+    parser.add_argument(
+        "--concurrent-regions",
+        type=int,
+        default=None,
+        help="Number of regions to process concurrently. Default is None, which means the script will use as many as possible",
+    )
+    parser.add_argument(
+        "--concurrent-services",
+        type=int,
+        default=None,
+        help="Number of services to process concurrently for each region. Default is None, which means the script will use as many as possible",
+    )
     args = parser.parse_args()
-    main(args.scan, args.regions, args.output_dir, args.log_level)
+    main(
+        args.scan,
+        args.regions,
+        args.output_dir,
+        args.log_level,
+        args.max_retries,
+        args.retry_delay,
+        args.concurrent_regions,
+        args.concurrent_services,
+    )
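For reference, a minimal sketch (not part of the commit) of the backoff schedule the new arguments produce: for throttled calls, api_call_with_retry sleeps retry_delay ** attempt seconds after each failed attempt and skips the sleep on the final one, so the defaults (--max-retries 3, --retry-delay 2) wait 1 s and then 2 s before giving up. The helper name backoff_schedule below is hypothetical, purely for illustration:

    # Illustration only: reproduces the sleep schedule api_call_with_retry uses
    # for throttled calls. backoff_schedule is a hypothetical helper, not in scan.py.
    def backoff_schedule(max_retries, retry_delay):
        """Sleep durations (in seconds) between attempts; no sleep after the last."""
        return [retry_delay**attempt for attempt in range(max_retries - 1)]

    print(backoff_schedule(3, 2))  # defaults -> [1, 2]
    print(backoff_schedule(5, 3))  # e.g. --max-retries 5 --retry-delay 3 -> [1, 3, 9, 27]

Note that --concurrent-regions and --concurrent-services default to None; on Python 3.8+ ThreadPoolExecutor then uses min(32, os.cpu_count() + 4) worker threads, rather than literally "as many as possible".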
