diff --git a/app/main/helpers/suppliers.py b/app/main/helpers/suppliers.py index 011a84742..67c6ebf77 100644 --- a/app/main/helpers/suppliers.py +++ b/app/main/helpers/suppliers.py @@ -1,5 +1,8 @@ import os import json +from typing import Union + +from flask import current_app def load_countries(): @@ -65,3 +68,18 @@ def get_company_details_from_supplier(supplier): "registered_name": supplier.get("registeredName"), "address": address } + + +def is_g12_recovery_supplier(supplier_id: Union[str, int]) -> bool: + supplier_ids_string = current_app.config.get('DM_G12_RECOVERY_SUPPLIER_IDS') or '' + + try: + supplier_ids = [int(s) for s in supplier_ids_string.split(sep=',')] + except AttributeError as e: + current_app.logger.error("DM_G12_RECOVERY_SUPPLIER_IDS not a string", extra={'error': str(e)}) + return False + except ValueError as e: + current_app.logger.error("DM_G12_RECOVERY_SUPPLIER_IDS not a list of supplier IDs", extra={'error': str(e)}) + return False + + return int(supplier_id) in supplier_ids diff --git a/app/main/views/suppliers.py b/app/main/views/suppliers.py index bb2e26024..e3b97add5 100644 --- a/app/main/views/suppliers.py +++ b/app/main/views/suppliers.py @@ -38,6 +38,7 @@ COUNTRY_TUPLE, get_country_name_from_country_code, supplier_company_details_are_complete, + is_g12_recovery_supplier, ) from ..helpers import login_required @@ -66,6 +67,8 @@ def dashboard(): )['suppliers'] supplier['contact'] = supplier['contactInformation'][0] + supplier['g12_recovery'] = is_g12_recovery_supplier(supplier['id']) + all_frameworks = sorted( data_api_client.find_frameworks()['frameworks'], key=lambda framework: framework['slug'], diff --git a/app/templates/suppliers/_frameworks_coming.html b/app/templates/suppliers/_frameworks_coming.html index 7e748fa1b..3da389308 100644 --- a/app/templates/suppliers/_frameworks_coming.html +++ b/app/templates/suppliers/_frameworks_coming.html @@ -1,3 +1,14 @@ {% for framework in frameworks.coming %} {# add banners/cards here for messages to suppliers about upcoming frameworks #} {% endfor %} + +{% if supplier.g12_recovery %} + {% with + main = true, + header_level = "h2", + messages = ['This is a placeholder'], + heading = "G-Cloud 12 recovery" + %} + {% include "toolkit/temporary-message.html" %} + {% endwith %} +{% endif %} diff --git a/config.py b/config.py index 4a3b476e8..421b9de5c 100644 --- a/config.py +++ b/config.py @@ -94,6 +94,8 @@ class Config(object): DM_LOG_PATH = None DM_APP_NAME = 'supplier-frontend' + DM_G12_RECOVERY_SUPPLIER_IDS = None + @staticmethod def init_app(app): repo_root = os.path.abspath(os.path.dirname(__file__)) @@ -130,6 +132,8 @@ class Test(Config): DM_ASSETS_URL = 'http://asset-host' + DM_G12_RECOVERY_SUPPLIER_IDS = "577184" + class Development(Config): DEBUG = True @@ -167,6 +171,8 @@ class Development(Config): DM_DNB_API_USERNAME = 'not_a_real_username' DM_DNB_API_PASSWORD = 'not_a_real_password' + DM_G12_RECOVERY_SUPPLIER_IDS = "577184" + class Live(Config): """Base config for deployed environments""" diff --git a/tests/app/main/helpers/test_suppliers.py b/tests/app/main/helpers/test_suppliers.py index 6f0080574..5d3dbf81b 100644 --- a/tests/app/main/helpers/test_suppliers.py +++ b/tests/app/main/helpers/test_suppliers.py @@ -2,10 +2,13 @@ from flask_wtf import FlaskForm from wtforms import StringField from wtforms.validators import Length -from app.main.helpers.suppliers import get_country_name_from_country_code, supplier_company_details_are_complete +from app.main.helpers.suppliers import get_country_name_from_country_code, supplier_company_details_are_complete, \ + is_g12_recovery_supplier from dmtestutils.api_model_stubs import SupplierStub +from tests.app.helpers import BaseApplicationTest + class TestGetCountryNameFromCountryCode: @pytest.mark.parametrize( @@ -49,3 +52,22 @@ class FormForTest(FlaskForm): field_three = StringField('Field three?', validators=[ Length(max=5, message="Field three must be under 5 characters.") ]) + + +class TestG12RecoverySupplier(BaseApplicationTest): + @pytest.mark.parametrize( + 'g12_recovery_supplier_ids, expected_result', + [ + (None, False), + ('', False), + (42, False), + ('12:32', False), + ([123456, 789012], False), + ('123456', True), + ('123456,789012', True), + ] + ) + def test_returns_expected_value_for_input(self, g12_recovery_supplier_ids, expected_result): + with self.app.app_context(): + self.app.config['DM_G12_RECOVERY_SUPPLIER_IDS'] = g12_recovery_supplier_ids + assert is_g12_recovery_supplier('123456') is expected_result diff --git a/tests/app/main/test_suppliers.py b/tests/app/main/test_suppliers.py index fd4b6e6f2..9159db0ea 100644 --- a/tests/app/main/test_suppliers.py +++ b/tests/app/main/test_suppliers.py @@ -627,6 +627,26 @@ def test_shows_continue_with_dos_link(self): assert continue_link assert continue_link[0].values()[0] == "/suppliers/frameworks/digital-outcomes-and-specialists" + def test_shows_placeholder_to_recovery_supplier(self): + self.data_api_client.get_supplier.return_value = get_supplier(id='577184') # Test.DM_G12_RECOVERY_SUPPLIER_IDS + self.login() + + with self.app.app_context(): + response = self.client.get("/suppliers") + document = html.fromstring(response.get_data(as_text=True)) + + assert document.xpath("//h2[normalize-space(string())='G-Cloud 12 recovery']") + + def test_does_not_show_placeholder(self): + self.data_api_client.get_supplier.return_value = get_supplier(id='123456') + self.login() + + with self.app.app_context(): + response = self.client.get("/suppliers") + document = html.fromstring(response.get_data(as_text=True)) + + assert not document.xpath("//h2[normalize-space(string())='G-Cloud 12 recovery']") + class TestSupplierDetails(BaseApplicationTest):