diff --git a/src/sagemaker_inference/default_inference_handler.py b/src/sagemaker_inference/default_inference_handler.py index b7cd6cf..ad01285 100644 --- a/src/sagemaker_inference/default_inference_handler.py +++ b/src/sagemaker_inference/default_inference_handler.py @@ -15,7 +15,7 @@ """ import textwrap -from sagemaker_inference import decoder, encoder +from sagemaker_inference import decoder, encoder, errors, utils class DefaultInferenceHandler(object): @@ -85,4 +85,7 @@ def default_output_fn(self, prediction, accept): # pylint: disable=no-self-use obj: prediction data. """ - return encoder.encode(prediction, accept), accept + for content_type in utils.parse_accept(accept): + if content_type in encoder.SUPPORTED_CONTENT_TYPES: + return encoder.encode(prediction, content_type), content_type + raise errors.UnsupportedFormatError(accept) diff --git a/src/sagemaker_inference/encoder.py b/src/sagemaker_inference/encoder.py index 2f54086..fdf38a0 100644 --- a/src/sagemaker_inference/encoder.py +++ b/src/sagemaker_inference/encoder.py @@ -87,6 +87,9 @@ def _array_to_csv(array_like): } +SUPPORTED_CONTENT_TYPES = set(_encoder_map.keys()) + + def encode(array_like, content_type): """Encode an array-like object in a specific content_type to a numpy array. diff --git a/src/sagemaker_inference/utils.py b/src/sagemaker_inference/utils.py index 15eb941..592d54f 100644 --- a/src/sagemaker_inference/utils.py +++ b/src/sagemaker_inference/utils.py @@ -66,3 +66,16 @@ def retrieve_content_type_header(request_property): return request_property[key] return None + + +def parse_accept(accept): + """Parses the Accept header sent with a request. + + Args: + accept (str): the value of an Accept header. + + Returns: + (list): A list containing the MIME types that the client is able to + understand. + """ + return accept.replace(" ", "").split(",") diff --git a/test/unit/test_default_inference_handler.py b/test/unit/test_default_inference_handler.py index c82a49e..e6527ae 100644 --- a/test/unit/test_default_inference_handler.py +++ b/test/unit/test_default_inference_handler.py @@ -24,11 +24,19 @@ def test_default_input_fn(loads): loads.assert_called_with(42, content_types.JSON) +@pytest.mark.parametrize( + "accept, expected_content_type", + [ + ("text/csv", "text/csv"), + ("text/csv, application/json", "text/csv"), + ("unsupported/type, text/csv", "text/csv"), + ], +) @patch("sagemaker_inference.encoder.encode", lambda prediction, accept: prediction ** 2) -def test_default_output_fn(): - result, accept = DefaultInferenceHandler().default_output_fn(2, content_types.CSV) +def test_default_output_fn(accept, expected_content_type): + result, content_type = DefaultInferenceHandler().default_output_fn(2, accept) assert result == 4 - assert accept == content_types.CSV + assert content_type == expected_content_type def test_default_model_fn(): diff --git a/test/unit/test_utils.py b/test/unit/test_utils.py index 2113d62..e739686 100644 --- a/test/unit/test_utils.py +++ b/test/unit/test_utils.py @@ -13,7 +13,12 @@ from mock import Mock, mock_open, patch import pytest -from sagemaker_inference.utils import read_file, retrieve_content_type_header, write_file +from sagemaker_inference.utils import ( + parse_accept, + read_file, + retrieve_content_type_header, + write_file, +) TEXT = "text" CONTENT_TYPE = "content_type" @@ -74,3 +79,16 @@ def test_content_type_header(content_type_key): result = retrieve_content_type_header(request_property) assert result == CONTENT_TYPE + + +@pytest.mark.parametrize( + "input, expected", + [ + ("application/json", ["application/json"]), + ("application/json, text/csv", ["application/json", "text/csv"]), + ("application/json,text/csv", ["application/json", "text/csv"]), + ], +) +def test_parse_accept(input, expected): + actual = parse_accept(input) + assert actual == expected