diff --git a/src/mock_vws/_mock_web_services_api.py b/src/mock_vws/_mock_web_services_api.py index 0aa8c6cbe..cb9d8f060 100644 --- a/src/mock_vws/_mock_web_services_api.py +++ b/src/mock_vws/_mock_web_services_api.py @@ -43,7 +43,9 @@ validate_width, ) from ._services_validators.auth_validators import ( + validate_access_key_exists, validate_auth_header_exists, + validate_auth_header_has_signature, validate_authorization, ) from ._services_validators.date_validators import ( @@ -222,6 +224,8 @@ def decorator(method: Callable[..., str]) -> Callable[..., str]: ] common_decorators = [ + validate_access_key_exists, + validate_auth_header_has_signature, validate_auth_header_exists, set_content_length_header, update_request_count, diff --git a/src/mock_vws/_query_validators/auth_validators.py b/src/mock_vws/_query_validators/auth_validators.py index b2341e97f..ad3ed394e 100644 --- a/src/mock_vws/_query_validators/auth_validators.py +++ b/src/mock_vws/_query_validators/auth_validators.py @@ -101,7 +101,7 @@ def validate_client_key_exists( Returns: The result of calling the endpoint. - An ``UNAUTHORIZED`` FOOBAR. + An ``UNAUTHORIZED`` response if the client key is unknown. """ request, context = args diff --git a/src/mock_vws/_services_validators/auth_validators.py b/src/mock_vws/_services_validators/auth_validators.py index 1b2f3b173..fd524294c 100644 --- a/src/mock_vws/_services_validators/auth_validators.py +++ b/src/mock_vws/_services_validators/auth_validators.py @@ -47,6 +47,79 @@ def validate_auth_header_exists( return json_dump(body) +@wrapt.decorator +def validate_access_key_exists( + wrapped: Callable[..., str], + instance: Any, + args: Tuple[_RequestObjectProxy, _Context], + kwargs: Dict, +) -> str: + """ + Validate the authorization header includes an access key for a database. + + Args: + wrapped: An endpoint function for `requests_mock`. + instance: The class that the endpoint function is in. + args: The arguments given to the endpoint function. + kwargs: The keyword arguments given to the endpoint function. + + Returns: + The result of calling the endpoint. + An ``UNAUTHORIZED`` response if the access key is unknown. + """ + request, context = args + + header = request.headers['Authorization'] + first_part, _ = header.split(b':') + _, access_key = first_part.split(b' ') + for database in instance.databases: + if access_key == database.server_access_key: + return wrapped(*args, **kwargs) + + context.status_code = codes.BAD_REQUEST + + body = { + 'transaction_id': uuid.uuid4().hex, + 'result_code': ResultCodes.FAIL.value, + } + return json_dump(body) + + +@wrapt.decorator +def validate_auth_header_has_signature( + wrapped: Callable[..., str], + instance: Any, # pylint: disable=unused-argument + args: Tuple[_RequestObjectProxy, _Context], + kwargs: Dict, +) -> str: + """ + Validate the authorization header includes a signature. + + Args: + wrapped: An endpoint function for `requests_mock`. + instance: The class that the endpoint function is in. + args: The arguments given to the endpoint function. + kwargs: The keyword arguments given to the endpoint function. + + Returns: + The result of calling the endpoint. + An ``UNAUTHORIZED`` response if the "Authorization" header is not as + expected. + """ + request, context = args + + header = request.headers['Authorization'] + if header.count(b':') == 1 and header.split(b':')[1]: + return wrapped(*args, **kwargs) + + context.status_code = codes.BAD_REQUEST + body = { + 'transaction_id': uuid.uuid4().hex, + 'result_code': ResultCodes.FAIL.value, + } + return json_dump(body) + + @wrapt.decorator def validate_authorization( wrapped: Callable[..., str], @@ -78,9 +151,9 @@ def validate_authorization( if database is not None: return wrapped(*args, **kwargs) - context.status_code = codes.BAD_REQUEST + context.status_code = codes.UNAUTHORIZED body = { 'transaction_id': uuid.uuid4().hex, - 'result_code': ResultCodes.FAIL.value, + 'result_code': ResultCodes.AUTHENTICATION_FAILURE.value, } return json_dump(body) diff --git a/tests/mock_vws/test_authorization_header.py b/tests/mock_vws/test_authorization_header.py index 90aa082ce..82d16d82a 100644 --- a/tests/mock_vws/test_authorization_header.py +++ b/tests/mock_vws/test_authorization_header.py @@ -246,6 +246,27 @@ def test_bad_access_key_query( ) assert response.text == expected_text + def test_bad_secret_key_services( + self, + vuforia_database: VuforiaDatabase, + ) -> None: + """ + If the server secret key given is incorrect, an + ``AuthenticationFailure`` response is returned. + """ + keys = vuforia_database + keys.server_secret_key = b'example' + response = get_vws_target( + target_id=uuid.uuid4().hex, + vuforia_database=keys, + ) + + assert_vws_failure( + response=response, + status_code=codes.UNAUTHORIZED, + result_code=ResultCodes.AUTHENTICATION_FAILURE, + ) + def test_bad_secret_key_query( self, vuforia_database: VuforiaDatabase, diff --git a/tests/mock_vws/test_usage.py b/tests/mock_vws/test_usage.py index 567d99fa9..90a1cd8e5 100644 --- a/tests/mock_vws/test_usage.py +++ b/tests/mock_vws/test_usage.py @@ -54,7 +54,7 @@ def request_mocked_address() -> None: url='https://vws.vuforia.com/summary', headers={ 'Date': rfc_1123_date(), - 'Authorization': 'bad_auth_token', + 'Authorization': b'bad_auth_token', }, data=b'', )