diff --git a/app/__init__.py b/app/__init__.py index 7704e7375..9fad21e30 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -31,7 +31,6 @@ from app.clients.document_download import DocumentDownloadClient from app.clients.email.aws_ses import AwsSesClient from app.clients.email.aws_ses_stub import AwsSesStubClient -from app.clients.pinpoint.aws_pinpoint import AwsPinpointClient from app.clients.sms.aws_sns import AwsSnsClient from notifications_utils import logging, request_helper from notifications_utils.clients.encryption.encryption_client import Encryption @@ -91,17 +90,16 @@ def apply_driver_hacks(self, app, info, options): "pool_pre_ping": True, } ) -migrate = Migrate() +migrate = None notify_celery = NotifyCelery() -aws_ses_client = AwsSesClient() -aws_ses_stub_client = AwsSesStubClient() +aws_ses_client = None +aws_ses_stub_client = None aws_sns_client = AwsSnsClient() -aws_cloudwatch_client = AwsCloudwatchClient() -aws_pinpoint_client = AwsPinpointClient() +aws_cloudwatch_client = None encryption = Encryption() -zendesk_client = ZendeskClient() +zendesk_client = None redis_store = RedisClient() -document_download_client = DocumentDownloadClient() +document_download_client = None socketio = SocketIO( cors_allowed_origins=[ @@ -118,7 +116,39 @@ def apply_driver_hacks(self, app, info, options): authenticated_service = LocalProxy(lambda: g.authenticated_service) +def get_zendesk_client(): + global zendesk_client + # Our unit tests mock anyway + if os.environ.get("NOTIFY_ENVIRONMENT") == "test": + return None + if zendesk_client is None: + zendesk_client = ZendeskClient() + return zendesk_client + + +def get_aws_ses_client(): + global aws_ses_client + if os.environ.get("NOTIFY_ENVIRONMENT") == "test": + return AwsSesClient() + if aws_ses_client is None: + raise RuntimeError(f"Celery not initialized aws_ses_client: {aws_ses_client}") + return aws_ses_client + + +def get_document_download_client(): + global document_download_client + # Our unit tests mock anyway + if os.environ.get("NOTIFY_ENVIRONMENT") == "test": + return None + if document_download_client is None: + raise RuntimeError( + f"Celery not initialized document_download_client: {document_download_client}" + ) + return document_download_client + + def create_app(application): + global zendesk_client, migrate, document_download_client, aws_ses_client, aws_ses_stub_client from app.config import configs notify_environment = os.environ["NOTIFY_ENVIRONMENT"] @@ -135,15 +165,26 @@ def create_app(application): register_socket_handlers(socketio) request_helper.init_app(application) db.init_app(application) - migrate.init_app(application, db=db) - zendesk_client.init_app(application) logging.init_app(application) aws_sns_client.init_app(application) + # start lazy initialization for gevent + migrate = Migrate() + migrate.init_app(application, db=db) + if zendesk_client is None: + zendesk_client = ZendeskClient() + zendesk_client.init_app(application) + document_download_client = DocumentDownloadClient() + document_download_client.init_app(application) + aws_cloudwatch_client = AwsCloudwatchClient() + aws_cloudwatch_client.init_app(application) + aws_ses_client = AwsSesClient() aws_ses_client.init_app() + aws_ses_stub_client = AwsSesStubClient() aws_ses_stub_client.init_app(stub_url=application.config["SES_STUB_URL"]) - aws_cloudwatch_client.init_app(application) - aws_pinpoint_client.init_app(application) + + # end lazy initialization + # If a stub url is provided for SES, then use the stub client rather than the real SES boto client email_clients = ( [aws_ses_stub_client] @@ -157,7 +198,6 @@ def create_app(application): notify_celery.init_app(application) encryption.init_app(application) redis_store.init_app(application) - document_download_client.init_app(application) register_blueprint(application) diff --git a/app/celery/scheduled_tasks.py b/app/celery/scheduled_tasks.py index 05d4f72a7..a8a097030 100644 --- a/app/celery/scheduled_tasks.py +++ b/app/celery/scheduled_tasks.py @@ -5,7 +5,7 @@ from sqlalchemy import between, select, union from sqlalchemy.exc import SQLAlchemyError -from app import db, notify_celery, redis_store, zendesk_client +from app import db, get_zendesk_client, notify_celery, redis_store from app.celery.tasks import ( get_recipient_csv_and_template_and_sender_id, process_incomplete_jobs, @@ -44,6 +44,8 @@ MAX_NOTIFICATION_FAILS = 10000 +zendesk_client = get_zendesk_client() + @notify_celery.task(name="run-scheduled-jobs") def run_scheduled_jobs(): diff --git a/app/job/rest.py b/app/job/rest.py index 45207147c..fa7165d28 100644 --- a/app/job/rest.py +++ b/app/job/rest.py @@ -49,7 +49,6 @@ @job_blueprint.route("/", methods=["GET"]) def get_job_by_service_and_job_id(service_id, job_id): - current_app.logger.info(hilite("ENTER get_job_by_service_and_job_id")) check_suspicious_id(service_id, job_id) job = dao_get_job_by_service_id_and_job_id(service_id, job_id) statistics = dao_get_notification_outcomes_for_job(service_id, job_id) @@ -150,7 +149,6 @@ def get_all_notifications_for_service_job(service_id, job_id): notifications = notification_with_template_schema.dump( paginated_notifications.items, many=True ) - current_app.logger.info(hilite("Got the dumped notifications and returning")) return ( jsonify( diff --git a/setup.cfg b/setup.cfg index f6dc999cb..c3af9dc69 100644 --- a/setup.cfg +++ b/setup.cfg @@ -6,7 +6,7 @@ xfail_strict=true exclude = venv*,__pycache__,node_modules,cache,migrations,build,sample_cap_xml_documents.py max-line-length = 120 # W504 line break after binary operator -extend_ignore=B306, W504, E203 +extend_ignore=B306, W504, E203, F824 [isort] profile = black diff --git a/tests/app/celery/test_scheduled_tasks.py b/tests/app/celery/test_scheduled_tasks.py index 894afa6f7..f162541cf 100644 --- a/tests/app/celery/test_scheduled_tasks.py +++ b/tests/app/celery/test_scheduled_tasks.py @@ -498,10 +498,9 @@ def test_check_for_services_with_high_failure_rates_or_sending_to_tv_numbers( ): mock_logger = mocker.patch("app.celery.tasks.current_app.logger.warning") mock_create_ticket = mocker.spy(NotifySupportTicket, "__init__") - mock_send_ticket_to_zendesk = mocker.patch( - "app.celery.scheduled_tasks.zendesk_client.send_ticket_to_zendesk", - autospec=True, - ) + mock_zendesk_client = MagicMock() + mocker.patch("app.celery.scheduled_tasks.zendesk_client", mock_zendesk_client) + mock_send_ticket_to_zendesk = mock_zendesk_client.send_ticket_to_zendesk mock_failure_rates = mocker.patch( "app.celery.scheduled_tasks.dao_find_services_with_high_failure_rates", return_value=failure_rates, diff --git a/tests/app/clients/test_aws_cloudwatch.py b/tests/app/clients/test_aws_cloudwatch.py index 15f57516e..4671dd56d 100644 --- a/tests/app/clients/test_aws_cloudwatch.py +++ b/tests/app/clients/test_aws_cloudwatch.py @@ -7,12 +7,13 @@ import pytest from flask import current_app -from app import aws_cloudwatch_client from app.clients.cloudwatch.aws_cloudwatch import AwsCloudwatchClient def test_check_sms_no_event_error_condition(notify_api, mocker): - boto_mock = mocker.patch.object(aws_cloudwatch_client, "_client", create=True) + client = AwsCloudwatchClient() + + boto_mock = mocker.patch.object(client, "_client", create=True) # TODO # we do this to get the AWS account number, and it seems like unit tests locally have # access to the env variables but when we push the PR they do not. Is there a better way to get it? @@ -21,9 +22,9 @@ def test_check_sms_no_event_error_condition(notify_api, mocker): notification_id = "bbb" boto_mock.filter_log_events.return_value = [] with notify_api.app_context(): - aws_cloudwatch_client.init_app(current_app) + client.init_app(current_app) try: - aws_cloudwatch_client.check_sms(message_id, notification_id) + client.check_sms(message_id, notification_id) assert 1 == 0 except Exception: assert 1 == 1 @@ -61,7 +62,9 @@ def side_effect(filterPattern, logGroupName, startTime, endTime): def test_extract_account_number_gov_cloud(): domain_arn = "arn:aws-us-gov:ses:us-gov-west-1:12345:identity/ses-abc.xxx.xxx.xxx" - actual_account_number = aws_cloudwatch_client._extract_account_number(domain_arn) + client = AwsCloudwatchClient() + client.init_app(current_app) + actual_account_number = client._extract_account_number(domain_arn) assert len(actual_account_number) == 6 expected_account_number = "12345" assert actual_account_number[4] == expected_account_number @@ -69,19 +72,24 @@ def test_extract_account_number_gov_cloud(): def test_extract_account_number_gov_staging(): domain_arn = "arn:aws:ses:us-south-14:12345:identity/ses-abc.xxx.xxx.xxx" - actual_account_number = aws_cloudwatch_client._extract_account_number(domain_arn) + client = AwsCloudwatchClient() + client.init_app(current_app) + actual_account_number = client._extract_account_number(domain_arn) assert len(actual_account_number) == 6 expected_account_number = "12345" assert actual_account_number[4] == expected_account_number def test_event_to_db_format_with_missing_fields(): + client = AwsCloudwatchClient() + client.init_app(current_app) + event = { "notification": {"messageId": "12345"}, "status": "UNKNOWN", "delivery": {}, } - result = aws_cloudwatch_client.event_to_db_format(event) + result = client.event_to_db_format(event) assert result == { "notification.messageId": "12345", "status": "UNKNOWN", @@ -104,7 +112,10 @@ def test_event_to_db_format_with_string_input(): }, } ) - result = aws_cloudwatch_client.event_to_db_format(event) + client = AwsCloudwatchClient() + client.init_app(current_app) + + result = client.event_to_db_format(event) assert result == { "notification.messageId": "67890", "status": "FAILED", @@ -128,8 +139,7 @@ def fake_event(): } -@patch("app.clients.cloudwatch.aws_cloudwatch.current_app") -def test_warn_if_dev_is_opted_out(current_app_mock): +def test_warn_if_dev_is_opted_out(): # os.environ["NOTIFIY_ENVIRONMENT"] = "development" client = AwsCloudwatchClient() logline = client.warn_if_dev_is_opted_out("Number is opted out", "notif123") @@ -182,8 +192,8 @@ def test_extract_account_number(): @patch("app.clients.cloudwatch.aws_cloudwatch.client") def test_get_log_with_pagination(mock_client): client = AwsCloudwatchClient() + client.init_app(current_app) client._client = mock_client - mock_client.filter_log_events.side_effect = [ {"events": [{"message": "msg1"}], "nextToken": "abc"}, {"events": [{"message": "msg2"}]}, @@ -198,8 +208,8 @@ def test_get_log_with_pagination(mock_client): assert logs[1]["message"] == "msg2" -@patch("app.clients.cloudwatch.aws_cloudwatch.current_app") -def test_get_receipts(mock_current_app): +# @patch("app.clients.cloudwatch.aws_cloudwatch.current_app") +def test_get_receipts(): client = AwsCloudwatchClient() client._get_log = MagicMock( return_value=[ @@ -225,9 +235,9 @@ def test_get_receipts(mock_current_app): assert event["status"] == "DELIVERED" -@patch("app.clients.cloudwatch.aws_cloudwatch.current_app") +# @patch("app.clients.cloudwatch.aws_cloudwatch.current_app") @patch("app.clients.cloudwatch.aws_cloudwatch.cloud_config") -def test_check_delivery_receipts(mock_cloud_config, current_app_mock): +def test_check_delivery_receipts(mock_cloud_config): client = AwsCloudwatchClient() mock_cloud_config.sns_regions = "us-north-1" mock_cloud_config.ses_domain_arn = ( diff --git a/tests/app/clients/test_aws_ses.py b/tests/app/clients/test_aws_ses.py index 302fe2dd9..a4fe8003e 100644 --- a/tests/app/clients/test_aws_ses.py +++ b/tests/app/clients/test_aws_ses.py @@ -5,13 +5,14 @@ import botocore import pytest -from app import AwsSesStubClient, aws_ses_client from app.clients.email import EmailClientNonRetryableException from app.clients.email.aws_ses import ( + AwsSesClient, AwsSesClientException, AwsSesClientThrottlingSendRateException, get_aws_responses, ) +from app.clients.email.aws_ses_stub import AwsSesStubClient from app.enums import NotificationStatus, StatisticsType @@ -65,6 +66,8 @@ def test_should_be_none_if_unrecognised_status_code(): def test_send_email_handles_reply_to_address( notify_api, mocker, reply_to_address, expected_value ): + aws_ses_client = AwsSesClient() + boto_mock = mocker.patch.object(aws_ses_client, "_client", create=True) with notify_api.app_context(): @@ -82,6 +85,7 @@ def test_send_email_handles_reply_to_address( def test_send_email_handles_punycode_to_address(notify_api, mocker): + aws_ses_client = AwsSesClient() boto_mock = mocker.patch.object(aws_ses_client, "_client", create=True) with notify_api.app_context(): @@ -104,6 +108,7 @@ def test_send_email_handles_punycode_to_address(notify_api, mocker): def test_send_email_raises_invalid_parameter_value_error_as_EmailClientNonRetryableException( mocker, ): + aws_ses_client = AwsSesClient() boto_mock = mocker.patch.object(aws_ses_client, "_client", create=True) error_response = { "Error": { @@ -130,6 +135,7 @@ def test_send_email_raises_invalid_parameter_value_error_as_EmailClientNonRetrya def test_send_email_raises_send_rate_throttling_as_AwsSesClientThrottlingSendRateException( mocker, ): + aws_ses_client = AwsSesClient() boto_mock = mocker.patch.object(aws_ses_client, "_client", create=True) error_response = { "Error": { @@ -151,6 +157,7 @@ def test_send_email_raises_send_rate_throttling_as_AwsSesClientThrottlingSendRat def test_send_email_does_not_raise_AwsSesClientThrottlingSendRateException_if_non_send_rate_throttling( mocker, ): + aws_ses_client = AwsSesClient() boto_mock = mocker.patch.object(aws_ses_client, "_client", create=True) error_response = { "Error": { @@ -170,6 +177,7 @@ def test_send_email_does_not_raise_AwsSesClientThrottlingSendRateException_if_no def test_send_email_raises_other_errs_as_AwsSesClientException(mocker): + aws_ses_client = AwsSesClient() boto_mock = mocker.patch.object(aws_ses_client, "_client", create=True) error_response = { "Error": { diff --git a/tests/app/delivery/test_send_to_providers.py b/tests/app/delivery/test_send_to_providers.py index a06d13a51..ba0837f8e 100644 --- a/tests/app/delivery/test_send_to_providers.py +++ b/tests/app/delivery/test_send_to_providers.py @@ -1,6 +1,6 @@ import json from collections import namedtuple -from unittest.mock import ANY +from unittest.mock import ANY, MagicMock import pytest from flask import current_app @@ -142,9 +142,19 @@ def test_should_send_personalised_template_to_correct_email_provider_and_persist template=sample_email_template_with_html, ) db_notification.personalisation = {"name": "Jo"} - mocker.patch("app.aws_ses_client.send_email", return_value="reference") + + mock_boto_client = mocker.patch("boto3.client") + mock_ses = MagicMock() + mock_boto_client.return_value = mock_ses + mock_ses.send_email.return_value = "reference" + mock_ses.name = "ses" + mocker.patch( + "app.delivery.send_to_providers.provider_to_use", + return_value=mock_ses, + ) + send_to_providers.send_email_to_provider(db_notification) - app.aws_ses_client.send_email.assert_called_once_with( + mock_ses.send_email.assert_called_once_with( f'"Sample service" ', "jo.smith@example.com", "Jo some HTML", @@ -153,10 +163,10 @@ def test_should_send_personalised_template_to_correct_email_provider_and_persist reply_to_address=None, ) - assert "