Skip to content

Commit

Permalink
AWS region refactoring
Browse files Browse the repository at this point in the history
* Make AWS region a CLI option and update the appropriate classes to take it as an argument
* Update the corresponding unit tests

* https://mitlibraries.atlassian.net/browse/DLSPP-130
  • Loading branch information
ehanson8 committed Jan 25, 2022
1 parent 3eb7328 commit 3482bef
Show file tree
Hide file tree
Showing 9 changed files with 34 additions and 25 deletions.
22 changes: 15 additions & 7 deletions awd/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ def doi_to_be_retried(doi, doi_items):


@click.group()
@click.option(
"--aws_region",
required=True,
default=config.AWS_REGION_NAME,
help="The AWS region to use for clients.",
)
@click.option(
"--doi_table",
required=True,
Expand Down Expand Up @@ -68,11 +74,12 @@ def doi_to_be_retried(doi, doi_items):
"--log_recipient_email",
required=True,
default=config.LOG_RECIPIENT_EMAIL,
help="The email address receiving the logs. Repeatable",
help="The email address receiving the logs.",
)
@click.pass_context
def cli(
ctx,
aws_region,
doi_table,
sqs_base_url,
sqs_output_queue,
Expand All @@ -87,6 +94,7 @@ def cli(
)
ctx.ensure_object(dict)
ctx.obj["stream"] = stream
ctx.obj["aws_region"] = aws_region
ctx.obj["doi_table"] = doi_table
ctx.obj["sqs_base_url"] = sqs_base_url
ctx.obj["sqs_output_queue"] = sqs_output_queue
Expand Down Expand Up @@ -137,8 +145,8 @@ def deposit(
date = datetime.today().strftime("%m-%d-%Y %H:%M:%S")
stream = ctx.obj["stream"]
s3_client = s3.S3()
sqs_client = sqs.SQS()
dynamodb_client = dynamodb.DynamoDB()
sqs_client = sqs.SQS(ctx.obj["aws_region"])
dynamodb_client = dynamodb.DynamoDB(ctx.obj["aws_region"])

try:
s3_client.client.list_objects_v2(Bucket=bucket)
Expand Down Expand Up @@ -223,7 +231,7 @@ def deposit(
s3_client.archive_file_with_new_key(bucket, doi_file, "archived")
logger.debug("Submission process has completed")

ses_client = ses.SES()
ses_client = ses.SES(ctx.obj["aws_region"])
message = ses_client.create_email(
f"Automated Wiley deposit errors {date}",
stream.getvalue(),
Expand Down Expand Up @@ -254,8 +262,8 @@ def listen(
):
date = datetime.today().strftime("%m-%d-%Y %H:%M:%S")
stream = ctx.obj["stream"]
sqs = SQS()
dynamodb_client = DynamoDB()
sqs = SQS(ctx.obj["aws_region"])
dynamodb_client = DynamoDB(ctx.obj["aws_region"])
try:
for message in sqs.receive(
ctx.obj["sqs_base_url"], ctx.obj["sqs_output_queue"]
Expand Down Expand Up @@ -298,7 +306,7 @@ def listen(
logger.error(
f"Failure while retrieving SQS messages, {e.response['Error']['Message']}"
)
ses_client = SES()
ses_client = SES(ctx.obj["aws_region"])
message = ses_client.create_email(
f"DSS results {date}",
stream.getvalue(),
Expand Down
2 changes: 1 addition & 1 deletion awd/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
AWS_REGION_NAME = "us-east-1"

if ENV == "stage" or ENV == "prod":
ssm = SSM()
ssm = SSM(AWS_REGION_NAME)
DOI_TABLE = ssm.get_parameter_value(f"{WILEY_SSM_PATH}dynamodb_table_name")
METADATA_URL = ssm.get_parameter_value(f"{WILEY_SSM_PATH}wiley_metadata_url")
CONTENT_URL = ssm.get_parameter_value(f"{WILEY_SSM_PATH}wiley_content_url")
Expand Down
4 changes: 2 additions & 2 deletions awd/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
class DynamoDB:
"""An DynamoDB class that provides a generic boto3 DynamoDB client."""

def __init__(self):
def __init__(self, region):
self.client = boto3.client(
"dynamodb",
region_name="us-east-1",
region_name=region,
)

def add_doi_item_to_database(self, doi_table, doi):
Expand Down
4 changes: 2 additions & 2 deletions awd/ses.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
class SES:
"""An SES class that provides a generic boto3 SES client."""

def __init__(self):
self.client = client("ses", region_name="us-east-1")
def __init__(self, region):
self.client = client("ses", region_name=region)

def check_permissions(self, source_email_address, recipient_email_address):
"""Verify that an email can be sent from the specified email address"""
Expand Down
4 changes: 2 additions & 2 deletions awd/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
class SQS:
"""An SQS class that provides a generic boto3 SQS client."""

def __init__(self):
self.client = client("sqs", region_name="us-east-1")
def __init__(self, region):
self.client = client("sqs", region_name=region)

def check_read_permissions(self, sqs_base_url, queue_name):
"""Verify that messages can be received from the specified queue."""
Expand Down
4 changes: 2 additions & 2 deletions awd/ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ class SSM:
still used by boto3. This service contains Parameter Store, which is using for
storing values that can be retrieved via an SSM client."""

def __init__(self):
self.client = client("ssm", region_name="us-east-1")
def __init__(self, region):
self.client = client("ssm", region_name=region)

def check_permissions(self, ssm_path):
"""Check whether we can retrieve an encrypted ssm parameter.
Expand Down
9 changes: 5 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from click.testing import CliRunner
from moto import mock_dynamodb2, mock_iam, mock_s3, mock_ses, mock_sqs, mock_ssm

from awd import config
from awd.dynamodb import DynamoDB
from awd.s3 import S3
from awd.ses import SES
Expand Down Expand Up @@ -65,7 +66,7 @@ def test_aws_user(aws_credentials):

@pytest.fixture(scope="function")
def dynamodb_class():
return DynamoDB()
return DynamoDB(config.AWS_REGION_NAME)


@pytest.fixture(scope="function")
Expand All @@ -75,17 +76,17 @@ def s3_class():

@pytest.fixture(scope="function")
def ses_class():
return SES()
return SES(config.AWS_REGION_NAME)


@pytest.fixture(scope="function")
def sqs_class():
return SQS()
return SQS(config.AWS_REGION_NAME)


@pytest.fixture(scope="function")
def ssm_class():
return SSM()
return SSM(config.AWS_REGION_NAME)


@pytest.fixture(scope="function")
Expand Down
4 changes: 2 additions & 2 deletions tests/test_ses.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from botocore.exceptions import ClientError
from moto.core import set_initial_no_auth_action_count

from awd import ses
from awd import config, ses


def test_check_permissions_success(mocked_ses, ses_class):
Expand All @@ -23,7 +23,7 @@ def test_check_permissions_raises_error_if_address_not_verified(
os.environ["AWS_ACCESS_KEY_ID"] = test_aws_user["AccessKeyId"]
os.environ["AWS_SECRET_ACCESS_KEY"] = test_aws_user["SecretAccessKey"]
boto3.setup_default_session()
ses_class = ses.SES()
ses_class = ses.SES(config.AWS_REGION_NAME)
with pytest.raises(ClientError) as e:
ses_class.check_permissions("noreply@example.com", "mock@mock.mock")
assert e.value.response["Error"]["Message"] == (
Expand Down
6 changes: 3 additions & 3 deletions tests/test_sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from botocore.exceptions import ClientError
from moto.core import set_initial_no_auth_action_count

from awd import sqs
from awd import config, sqs


def test_check_read_permissions_success(
Expand Down Expand Up @@ -43,7 +43,7 @@ def test_check_read_permissions_raises_error_if_no_permission(
os.environ["AWS_ACCESS_KEY_ID"] = test_aws_user["AccessKeyId"]
os.environ["AWS_SECRET_ACCESS_KEY"] = test_aws_user["SecretAccessKey"]
boto3.setup_default_session()
sqs_class = sqs.SQS()
sqs_class = sqs.SQS(config.AWS_REGION_NAME)
with pytest.raises(ClientError) as e:
sqs_class.check_read_permissions(
"https://queue.amazonaws.com/123456789012/", "mock-output-queue"
Expand Down Expand Up @@ -88,7 +88,7 @@ def test_check_write_permissions_raises_error_if_no_permission(
os.environ["AWS_ACCESS_KEY_ID"] = test_aws_user["AccessKeyId"]
os.environ["AWS_SECRET_ACCESS_KEY"] = test_aws_user["SecretAccessKey"]
boto3.setup_default_session()
sqs_class = sqs.SQS()
sqs_class = sqs.SQS(config.AWS_REGION_NAME)
with pytest.raises(ClientError) as e:
sqs_class.check_write_permissions(
"https://queue.amazonaws.com/123456789012/", "empty_input_queue"
Expand Down

0 comments on commit 3482bef

Please sign in to comment.