Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 53 additions & 13 deletions app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from app.clients.document_download import DocumentDownloadClient
from app.clients.email.aws_ses import AwsSesClient
from app.clients.email.aws_ses_stub import AwsSesStubClient
from app.clients.pinpoint.aws_pinpoint import AwsPinpointClient
from app.clients.sms.aws_sns import AwsSnsClient
from notifications_utils import logging, request_helper
from notifications_utils.clients.encryption.encryption_client import Encryption
Expand Down Expand Up @@ -91,17 +90,16 @@ def apply_driver_hacks(self, app, info, options):
"pool_pre_ping": True,
}
)
migrate = Migrate()
migrate = None
notify_celery = NotifyCelery()
aws_ses_client = AwsSesClient()
aws_ses_stub_client = AwsSesStubClient()
aws_ses_client = None
aws_ses_stub_client = None
aws_sns_client = AwsSnsClient()
aws_cloudwatch_client = AwsCloudwatchClient()
aws_pinpoint_client = AwsPinpointClient()
aws_cloudwatch_client = None
encryption = Encryption()
zendesk_client = ZendeskClient()
zendesk_client = None
redis_store = RedisClient()
document_download_client = DocumentDownloadClient()
document_download_client = None

socketio = SocketIO(
cors_allowed_origins=[
Expand All @@ -118,7 +116,39 @@ def apply_driver_hacks(self, app, info, options):
authenticated_service = LocalProxy(lambda: g.authenticated_service)


def get_zendesk_client():
global zendesk_client
# Our unit tests mock anyway
if os.environ.get("NOTIFY_ENVIRONMENT") == "test":
return None
if zendesk_client is None:
zendesk_client = ZendeskClient()
return zendesk_client


def get_aws_ses_client():
global aws_ses_client
if os.environ.get("NOTIFY_ENVIRONMENT") == "test":
return AwsSesClient()
if aws_ses_client is None:
raise RuntimeError(f"Celery not initialized aws_ses_client: {aws_ses_client}")
return aws_ses_client


def get_document_download_client():
global document_download_client
# Our unit tests mock anyway
if os.environ.get("NOTIFY_ENVIRONMENT") == "test":
return None
if document_download_client is None:
raise RuntimeError(
f"Celery not initialized document_download_client: {document_download_client}"
)
return document_download_client


def create_app(application):
global zendesk_client, migrate, document_download_client, aws_ses_client, aws_ses_stub_client
from app.config import configs

notify_environment = os.environ["NOTIFY_ENVIRONMENT"]
Expand All @@ -135,15 +165,26 @@ def create_app(application):
register_socket_handlers(socketio)
request_helper.init_app(application)
db.init_app(application)
migrate.init_app(application, db=db)
zendesk_client.init_app(application)
logging.init_app(application)
aws_sns_client.init_app(application)

# start lazy initialization for gevent
migrate = Migrate()
migrate.init_app(application, db=db)
if zendesk_client is None:
zendesk_client = ZendeskClient()
zendesk_client.init_app(application)
document_download_client = DocumentDownloadClient()
document_download_client.init_app(application)
aws_cloudwatch_client = AwsCloudwatchClient()
aws_cloudwatch_client.init_app(application)
aws_ses_client = AwsSesClient()
aws_ses_client.init_app()
aws_ses_stub_client = AwsSesStubClient()
aws_ses_stub_client.init_app(stub_url=application.config["SES_STUB_URL"])
aws_cloudwatch_client.init_app(application)
aws_pinpoint_client.init_app(application)

# end lazy initialization

# If a stub url is provided for SES, then use the stub client rather than the real SES boto client
email_clients = (
[aws_ses_stub_client]
Expand All @@ -157,7 +198,6 @@ def create_app(application):
notify_celery.init_app(application)
encryption.init_app(application)
redis_store.init_app(application)
document_download_client.init_app(application)

register_blueprint(application)

Expand Down
4 changes: 3 additions & 1 deletion app/celery/scheduled_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sqlalchemy import between, select, union
from sqlalchemy.exc import SQLAlchemyError

from app import db, notify_celery, redis_store, zendesk_client
from app import db, get_zendesk_client, notify_celery, redis_store
from app.celery.tasks import (
get_recipient_csv_and_template_and_sender_id,
process_incomplete_jobs,
Expand Down Expand Up @@ -44,6 +44,8 @@

MAX_NOTIFICATION_FAILS = 10000

zendesk_client = get_zendesk_client()


@notify_celery.task(name="run-scheduled-jobs")
def run_scheduled_jobs():
Expand Down
2 changes: 0 additions & 2 deletions app/job/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@

@job_blueprint.route("/<job_id>", methods=["GET"])
def get_job_by_service_and_job_id(service_id, job_id):
current_app.logger.info(hilite("ENTER get_job_by_service_and_job_id"))
check_suspicious_id(service_id, job_id)
job = dao_get_job_by_service_id_and_job_id(service_id, job_id)
statistics = dao_get_notification_outcomes_for_job(service_id, job_id)
Expand Down Expand Up @@ -150,7 +149,6 @@ def get_all_notifications_for_service_job(service_id, job_id):
notifications = notification_with_template_schema.dump(
paginated_notifications.items, many=True
)
current_app.logger.info(hilite("Got the dumped notifications and returning"))

return (
jsonify(
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ xfail_strict=true
exclude = venv*,__pycache__,node_modules,cache,migrations,build,sample_cap_xml_documents.py
max-line-length = 120
# W504 line break after binary operator
extend_ignore=B306, W504, E203
extend_ignore=B306, W504, E203, F824

[isort]
profile = black
Expand Down
7 changes: 3 additions & 4 deletions tests/app/celery/test_scheduled_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,10 +498,9 @@ def test_check_for_services_with_high_failure_rates_or_sending_to_tv_numbers(
):
mock_logger = mocker.patch("app.celery.tasks.current_app.logger.warning")
mock_create_ticket = mocker.spy(NotifySupportTicket, "__init__")
mock_send_ticket_to_zendesk = mocker.patch(
"app.celery.scheduled_tasks.zendesk_client.send_ticket_to_zendesk",
autospec=True,
)
mock_zendesk_client = MagicMock()
mocker.patch("app.celery.scheduled_tasks.zendesk_client", mock_zendesk_client)
mock_send_ticket_to_zendesk = mock_zendesk_client.send_ticket_to_zendesk
mock_failure_rates = mocker.patch(
"app.celery.scheduled_tasks.dao_find_services_with_high_failure_rates",
return_value=failure_rates,
Expand Down
40 changes: 25 additions & 15 deletions tests/app/clients/test_aws_cloudwatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
import pytest
from flask import current_app

from app import aws_cloudwatch_client
from app.clients.cloudwatch.aws_cloudwatch import AwsCloudwatchClient


def test_check_sms_no_event_error_condition(notify_api, mocker):
boto_mock = mocker.patch.object(aws_cloudwatch_client, "_client", create=True)
client = AwsCloudwatchClient()

boto_mock = mocker.patch.object(client, "_client", create=True)
# TODO
# we do this to get the AWS account number, and it seems like unit tests locally have
# access to the env variables but when we push the PR they do not. Is there a better way to get it?
Expand All @@ -21,9 +22,9 @@ def test_check_sms_no_event_error_condition(notify_api, mocker):
notification_id = "bbb"
boto_mock.filter_log_events.return_value = []
with notify_api.app_context():
aws_cloudwatch_client.init_app(current_app)
client.init_app(current_app)
try:
aws_cloudwatch_client.check_sms(message_id, notification_id)
client.check_sms(message_id, notification_id)
assert 1 == 0
except Exception:
assert 1 == 1
Expand Down Expand Up @@ -61,27 +62,34 @@ def side_effect(filterPattern, logGroupName, startTime, endTime):

def test_extract_account_number_gov_cloud():
domain_arn = "arn:aws-us-gov:ses:us-gov-west-1:12345:identity/ses-abc.xxx.xxx.xxx"
actual_account_number = aws_cloudwatch_client._extract_account_number(domain_arn)
client = AwsCloudwatchClient()
client.init_app(current_app)
actual_account_number = client._extract_account_number(domain_arn)
assert len(actual_account_number) == 6
expected_account_number = "12345"
assert actual_account_number[4] == expected_account_number


def test_extract_account_number_gov_staging():
domain_arn = "arn:aws:ses:us-south-14:12345:identity/ses-abc.xxx.xxx.xxx"
actual_account_number = aws_cloudwatch_client._extract_account_number(domain_arn)
client = AwsCloudwatchClient()
client.init_app(current_app)
actual_account_number = client._extract_account_number(domain_arn)
assert len(actual_account_number) == 6
expected_account_number = "12345"
assert actual_account_number[4] == expected_account_number


def test_event_to_db_format_with_missing_fields():
client = AwsCloudwatchClient()
client.init_app(current_app)

event = {
"notification": {"messageId": "12345"},
"status": "UNKNOWN",
"delivery": {},
}
result = aws_cloudwatch_client.event_to_db_format(event)
result = client.event_to_db_format(event)
assert result == {
"notification.messageId": "12345",
"status": "UNKNOWN",
Expand All @@ -104,7 +112,10 @@ def test_event_to_db_format_with_string_input():
},
}
)
result = aws_cloudwatch_client.event_to_db_format(event)
client = AwsCloudwatchClient()
client.init_app(current_app)

result = client.event_to_db_format(event)
assert result == {
"notification.messageId": "67890",
"status": "FAILED",
Expand All @@ -128,8 +139,7 @@ def fake_event():
}


@patch("app.clients.cloudwatch.aws_cloudwatch.current_app")
def test_warn_if_dev_is_opted_out(current_app_mock):
def test_warn_if_dev_is_opted_out():
# os.environ["NOTIFIY_ENVIRONMENT"] = "development"
client = AwsCloudwatchClient()
logline = client.warn_if_dev_is_opted_out("Number is opted out", "notif123")
Expand Down Expand Up @@ -182,8 +192,8 @@ def test_extract_account_number():
@patch("app.clients.cloudwatch.aws_cloudwatch.client")
def test_get_log_with_pagination(mock_client):
client = AwsCloudwatchClient()
client.init_app(current_app)
client._client = mock_client

mock_client.filter_log_events.side_effect = [
{"events": [{"message": "msg1"}], "nextToken": "abc"},
{"events": [{"message": "msg2"}]},
Expand All @@ -198,8 +208,8 @@ def test_get_log_with_pagination(mock_client):
assert logs[1]["message"] == "msg2"


@patch("app.clients.cloudwatch.aws_cloudwatch.current_app")
def test_get_receipts(mock_current_app):
# @patch("app.clients.cloudwatch.aws_cloudwatch.current_app")
def test_get_receipts():
client = AwsCloudwatchClient()
client._get_log = MagicMock(
return_value=[
Expand All @@ -225,9 +235,9 @@ def test_get_receipts(mock_current_app):
assert event["status"] == "DELIVERED"


@patch("app.clients.cloudwatch.aws_cloudwatch.current_app")
# @patch("app.clients.cloudwatch.aws_cloudwatch.current_app")
@patch("app.clients.cloudwatch.aws_cloudwatch.cloud_config")
def test_check_delivery_receipts(mock_cloud_config, current_app_mock):
def test_check_delivery_receipts(mock_cloud_config):
client = AwsCloudwatchClient()
mock_cloud_config.sns_regions = "us-north-1"
mock_cloud_config.ses_domain_arn = (
Expand Down
10 changes: 9 additions & 1 deletion tests/app/clients/test_aws_ses.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import botocore
import pytest

from app import AwsSesStubClient, aws_ses_client
from app.clients.email import EmailClientNonRetryableException
from app.clients.email.aws_ses import (
AwsSesClient,
AwsSesClientException,
AwsSesClientThrottlingSendRateException,
get_aws_responses,
)
from app.clients.email.aws_ses_stub import AwsSesStubClient
from app.enums import NotificationStatus, StatisticsType


Expand Down Expand Up @@ -65,6 +66,8 @@ def test_should_be_none_if_unrecognised_status_code():
def test_send_email_handles_reply_to_address(
notify_api, mocker, reply_to_address, expected_value
):
aws_ses_client = AwsSesClient()

boto_mock = mocker.patch.object(aws_ses_client, "_client", create=True)

with notify_api.app_context():
Expand All @@ -82,6 +85,7 @@ def test_send_email_handles_reply_to_address(


def test_send_email_handles_punycode_to_address(notify_api, mocker):
aws_ses_client = AwsSesClient()
boto_mock = mocker.patch.object(aws_ses_client, "_client", create=True)

with notify_api.app_context():
Expand All @@ -104,6 +108,7 @@ def test_send_email_handles_punycode_to_address(notify_api, mocker):
def test_send_email_raises_invalid_parameter_value_error_as_EmailClientNonRetryableException(
mocker,
):
aws_ses_client = AwsSesClient()
boto_mock = mocker.patch.object(aws_ses_client, "_client", create=True)
error_response = {
"Error": {
Expand All @@ -130,6 +135,7 @@ def test_send_email_raises_invalid_parameter_value_error_as_EmailClientNonRetrya
def test_send_email_raises_send_rate_throttling_as_AwsSesClientThrottlingSendRateException(
mocker,
):
aws_ses_client = AwsSesClient()
boto_mock = mocker.patch.object(aws_ses_client, "_client", create=True)
error_response = {
"Error": {
Expand All @@ -151,6 +157,7 @@ def test_send_email_raises_send_rate_throttling_as_AwsSesClientThrottlingSendRat
def test_send_email_does_not_raise_AwsSesClientThrottlingSendRateException_if_non_send_rate_throttling(
mocker,
):
aws_ses_client = AwsSesClient()
boto_mock = mocker.patch.object(aws_ses_client, "_client", create=True)
error_response = {
"Error": {
Expand All @@ -170,6 +177,7 @@ def test_send_email_does_not_raise_AwsSesClientThrottlingSendRateException_if_no


def test_send_email_raises_other_errs_as_AwsSesClientException(mocker):
aws_ses_client = AwsSesClient()
boto_mock = mocker.patch.object(aws_ses_client, "_client", create=True)
error_response = {
"Error": {
Expand Down
Loading