Skip to content

Commit

Permalink
Handle wrong value of DM_G12_RECOVERY_SUPPLIER_IDS
Browse files Browse the repository at this point in the history
It's set from an environment variable. So it should always be set to something. However, in case we mess things up, log the error and treat the supplier as not part of the G12 recovery.
  • Loading branch information
bjgill committed Dec 1, 2020
1 parent 51be819 commit 5a608a3
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 4 deletions.
15 changes: 12 additions & 3 deletions app/main/helpers/suppliers.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,15 @@ def get_company_details_from_supplier(supplier):


def is_g12_recovery_supplier(supplier_id: Union[str, int]) -> bool:
supplier_ids = current_app.config.get('DM_G12_RECOVERY_SUPPLIER_IDS') or ''

return str(supplier_id) in supplier_ids.split(sep=',')
supplier_ids_string = current_app.config.get('DM_G12_RECOVERY_SUPPLIER_IDS') or ''

try:
supplier_ids = [int(s) for s in supplier_ids_string.split(sep=',')]
except AttributeError as e:
current_app.logger.error("DM_G12_RECOVERY_SUPPLIER_IDS not a string", extra={'error': str(e)})
return False
except ValueError as e:
current_app.logger.error("DM_G12_RECOVERY_SUPPLIER_IDS not a list of supplier IDs", extra={'error': str(e)})
return False

return int(supplier_id) in supplier_ids
24 changes: 23 additions & 1 deletion tests/app/main/helpers/test_suppliers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
from flask_wtf import FlaskForm
from wtforms import StringField
from wtforms.validators import Length
from app.main.helpers.suppliers import get_country_name_from_country_code, supplier_company_details_are_complete
from app.main.helpers.suppliers import get_country_name_from_country_code, supplier_company_details_are_complete, \
is_g12_recovery_supplier

from dmtestutils.api_model_stubs import SupplierStub

from tests.app.helpers import BaseApplicationTest


class TestGetCountryNameFromCountryCode:
@pytest.mark.parametrize(
Expand Down Expand Up @@ -49,3 +52,22 @@ class FormForTest(FlaskForm):
field_three = StringField('Field three?', validators=[
Length(max=5, message="Field three must be under 5 characters.")
])


class TestG12RecoverySupplier(BaseApplicationTest):
@pytest.mark.parametrize(
'g12_recovery_supplier_ids, expected_result',
[
(None, False),
('', False),
(42, False),
('12:32', False),
([123456, 789012], False),
('123456', True),
('123456,789012', True),
]
)
def test_returns_expected_value_for_input(self, g12_recovery_supplier_ids, expected_result):
with self.app.app_context():
self.app.config['DM_G12_RECOVERY_SUPPLIER_IDS'] = g12_recovery_supplier_ids
assert is_g12_recovery_supplier('123456') is expected_result

0 comments on commit 5a608a3

Please sign in to comment.