From 62d83252abf3f61b4b15de63759276ac4f440620 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Tue, 28 Jul 2020 13:15:48 -0500 Subject: [PATCH 1/5] Support multiple accept --- .../default_inference_handler.py | 7 +++++-- src/sagemaker_inference/encoder.py | 3 +++ src/sagemaker_inference/utils.py | 13 +++++++++++++ test/unit/test_utils.py | 19 ++++++++++++++++++- 4 files changed, 39 insertions(+), 3 deletions(-) diff --git a/src/sagemaker_inference/default_inference_handler.py b/src/sagemaker_inference/default_inference_handler.py index b7cd6cf..1b85707 100644 --- a/src/sagemaker_inference/default_inference_handler.py +++ b/src/sagemaker_inference/default_inference_handler.py @@ -15,7 +15,7 @@ """ import textwrap -from sagemaker_inference import decoder, encoder +from sagemaker_inference import content_types, decoder, encoder, utils class DefaultInferenceHandler(object): @@ -85,4 +85,7 @@ def default_output_fn(self, prediction, accept): # pylint: disable=no-self-use obj: prediction data. """ - return encoder.encode(prediction, accept), accept + for content_type in utils.parse_accept(accept): + if content_type in encoder.SUPPORTED_CONTENT_TYPES: + return encoder.encode(prediction, content_type), content_type + return encoder.encode(prediction, content_types.JSON), content_types.JSON diff --git a/src/sagemaker_inference/encoder.py b/src/sagemaker_inference/encoder.py index 2f54086..fdf38a0 100644 --- a/src/sagemaker_inference/encoder.py +++ b/src/sagemaker_inference/encoder.py @@ -87,6 +87,9 @@ def _array_to_csv(array_like): } +SUPPORTED_CONTENT_TYPES = set(_encoder_map.keys()) + + def encode(array_like, content_type): """Encode an array-like object in a specific content_type to a numpy array. diff --git a/src/sagemaker_inference/utils.py b/src/sagemaker_inference/utils.py index 15eb941..0ba4905 100644 --- a/src/sagemaker_inference/utils.py +++ b/src/sagemaker_inference/utils.py @@ -66,3 +66,16 @@ def retrieve_content_type_header(request_property): return request_property[key] return None + + +def parse_accept(accept): + """Parses the Accept header sent with a request. + + Args: + accept (str): the value of an Accept header. + + Returns: + (list): A list containing the MIME types that the client is able to + understand. + """ + return accept.split(", ") diff --git a/test/unit/test_utils.py b/test/unit/test_utils.py index 2113d62..e930ca3 100644 --- a/test/unit/test_utils.py +++ b/test/unit/test_utils.py @@ -13,7 +13,12 @@ from mock import Mock, mock_open, patch import pytest -from sagemaker_inference.utils import read_file, retrieve_content_type_header, write_file +from sagemaker_inference.utils import ( + parse_accept, + read_file, + retrieve_content_type_header, + write_file, +) TEXT = "text" CONTENT_TYPE = "content_type" @@ -74,3 +79,15 @@ def test_content_type_header(content_type_key): result = retrieve_content_type_header(request_property) assert result == CONTENT_TYPE + + +@pytest.mark.parametrize( + "input, expected", + [ + ("application/json", ["application/json"]), + ("application/json, text/csv", ["application/json", "text/csv"]), + ], +) +def test_parse_accept(input, expected): + actual = parse_accept(input) + assert actual == expected From 310e9cb632847e4b1216a1d8240205b2d469f8bc Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Tue, 28 Jul 2020 13:24:50 -0500 Subject: [PATCH 2/5] Update test_default_inference_handler.py --- test/unit/test_default_inference_handler.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/test/unit/test_default_inference_handler.py b/test/unit/test_default_inference_handler.py index c82a49e..f5937ac 100644 --- a/test/unit/test_default_inference_handler.py +++ b/test/unit/test_default_inference_handler.py @@ -24,11 +24,20 @@ def test_default_input_fn(loads): loads.assert_called_with(42, content_types.JSON) +@pytest.mark.parametrize( + "accept, expected_content_type", + [ + ("text/csv", "text/csv"), + ("text/csv, application/json", "text/csv"), + ("unsupported/type, text/csv", "text/csv"), + ("unsupported/type", "application/json"), + ], +) @patch("sagemaker_inference.encoder.encode", lambda prediction, accept: prediction ** 2) -def test_default_output_fn(): - result, accept = DefaultInferenceHandler().default_output_fn(2, content_types.CSV) +def test_default_output_fn(accept, expected_content_type): + result, content_type = DefaultInferenceHandler().default_output_fn(2, accept) assert result == 4 - assert accept == content_types.CSV + assert content_type == expected_content_type def test_default_model_fn(): From 22653c800bd913ba57f78f35f75a05608031279c Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Wed, 29 Jul 2020 11:57:59 -0500 Subject: [PATCH 3/5] Address review comments --- src/sagemaker_inference/default_inference_handler.py | 8 +++++--- src/sagemaker_inference/encoder.py | 3 --- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/sagemaker_inference/default_inference_handler.py b/src/sagemaker_inference/default_inference_handler.py index 1b85707..b9092e2 100644 --- a/src/sagemaker_inference/default_inference_handler.py +++ b/src/sagemaker_inference/default_inference_handler.py @@ -15,13 +15,15 @@ """ import textwrap -from sagemaker_inference import content_types, decoder, encoder, utils +from sagemaker_inference import content_types, decoder, encoder, errors, utils class DefaultInferenceHandler(object): """Bare-bones implementation of default inference functions. """ + SUPPORTED_CONTENT_TYPES = {content_types.NPY, content_types.JSON, content_types.CSV} + def default_model_fn(self, model_dir): """Function responsible for loading the model. @@ -86,6 +88,6 @@ def default_output_fn(self, prediction, accept): # pylint: disable=no-self-use """ for content_type in utils.parse_accept(accept): - if content_type in encoder.SUPPORTED_CONTENT_TYPES: + if content_type in self.SUPPORTED_CONTENT_TYPES: return encoder.encode(prediction, content_type), content_type - return encoder.encode(prediction, content_types.JSON), content_types.JSON + raise errors.UnsupportedFormatError(accept) diff --git a/src/sagemaker_inference/encoder.py b/src/sagemaker_inference/encoder.py index fdf38a0..2f54086 100644 --- a/src/sagemaker_inference/encoder.py +++ b/src/sagemaker_inference/encoder.py @@ -87,9 +87,6 @@ def _array_to_csv(array_like): } -SUPPORTED_CONTENT_TYPES = set(_encoder_map.keys()) - - def encode(array_like, content_type): """Encode an array-like object in a specific content_type to a numpy array. From 81b80fb67f9f09aa51a3327d0504c4fa42c65931 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Wed, 29 Jul 2020 11:59:47 -0500 Subject: [PATCH 4/5] Update test_default_inference_handler.py --- test/unit/test_default_inference_handler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/unit/test_default_inference_handler.py b/test/unit/test_default_inference_handler.py index f5937ac..e6527ae 100644 --- a/test/unit/test_default_inference_handler.py +++ b/test/unit/test_default_inference_handler.py @@ -30,7 +30,6 @@ def test_default_input_fn(loads): ("text/csv", "text/csv"), ("text/csv, application/json", "text/csv"), ("unsupported/type, text/csv", "text/csv"), - ("unsupported/type", "application/json"), ], ) @patch("sagemaker_inference.encoder.encode", lambda prediction, accept: prediction ** 2) From ad7e5838286f628e56ea3410d5a4e9dce00e32b4 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Wed, 29 Jul 2020 12:07:59 -0500 Subject: [PATCH 5/5] Address review comments --- src/sagemaker_inference/default_inference_handler.py | 6 ++---- src/sagemaker_inference/encoder.py | 3 +++ src/sagemaker_inference/utils.py | 2 +- test/unit/test_utils.py | 1 + 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/sagemaker_inference/default_inference_handler.py b/src/sagemaker_inference/default_inference_handler.py index b9092e2..ad01285 100644 --- a/src/sagemaker_inference/default_inference_handler.py +++ b/src/sagemaker_inference/default_inference_handler.py @@ -15,15 +15,13 @@ """ import textwrap -from sagemaker_inference import content_types, decoder, encoder, errors, utils +from sagemaker_inference import decoder, encoder, errors, utils class DefaultInferenceHandler(object): """Bare-bones implementation of default inference functions. """ - SUPPORTED_CONTENT_TYPES = {content_types.NPY, content_types.JSON, content_types.CSV} - def default_model_fn(self, model_dir): """Function responsible for loading the model. @@ -88,6 +86,6 @@ def default_output_fn(self, prediction, accept): # pylint: disable=no-self-use """ for content_type in utils.parse_accept(accept): - if content_type in self.SUPPORTED_CONTENT_TYPES: + if content_type in encoder.SUPPORTED_CONTENT_TYPES: return encoder.encode(prediction, content_type), content_type raise errors.UnsupportedFormatError(accept) diff --git a/src/sagemaker_inference/encoder.py b/src/sagemaker_inference/encoder.py index 2f54086..fdf38a0 100644 --- a/src/sagemaker_inference/encoder.py +++ b/src/sagemaker_inference/encoder.py @@ -87,6 +87,9 @@ def _array_to_csv(array_like): } +SUPPORTED_CONTENT_TYPES = set(_encoder_map.keys()) + + def encode(array_like, content_type): """Encode an array-like object in a specific content_type to a numpy array. diff --git a/src/sagemaker_inference/utils.py b/src/sagemaker_inference/utils.py index 0ba4905..592d54f 100644 --- a/src/sagemaker_inference/utils.py +++ b/src/sagemaker_inference/utils.py @@ -78,4 +78,4 @@ def parse_accept(accept): (list): A list containing the MIME types that the client is able to understand. """ - return accept.split(", ") + return accept.replace(" ", "").split(",") diff --git a/test/unit/test_utils.py b/test/unit/test_utils.py index e930ca3..e739686 100644 --- a/test/unit/test_utils.py +++ b/test/unit/test_utils.py @@ -86,6 +86,7 @@ def test_content_type_header(content_type_key): [ ("application/json", ["application/json"]), ("application/json, text/csv", ["application/json", "text/csv"]), + ("application/json,text/csv", ["application/json", "text/csv"]), ], ) def test_parse_accept(input, expected):