Skip to content

Commit

Permalink
JSON serializer: predictor.predict accepts dictionaries (#62)
Browse files Browse the repository at this point in the history
Add support for serializing python dictionaries to json
Add prediction with dictionary in tf iris integ test
  • Loading branch information
andremoeller authored and lukmis committed Feb 15, 2018
1 parent 795b030 commit d47f6d1
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 8 deletions.
3 changes: 2 additions & 1 deletion src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
class MXNetPredictor(RealTimePredictor):
"""A RealTimePredictor for inference against MXNet Endpoints.
This is able to serialize Python lists and numpy arrays to multidimensional tensors for MXNet inference."""
This is able to serialize Python lists, dictionaries, and numpy arrays to multidimensional tensors for MXNet
inference."""

def __init__(self, endpoint_name, sagemaker_session=None):
"""Initialize an ``MXNetPredictor``.
Expand Down
11 changes: 8 additions & 3 deletions src/sagemaker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,12 @@ def __call__(self, data):
if isinstance(data, list):
if not len(data) > 0:
raise ValueError("empty array can't be serialized")
return _json_serialize_python_array(data)
return _json_serialize_python_object(data)

if isinstance(data, dict):
if not len(data.keys()) > 0:
raise ValueError("empty dictionary can't be serialized")
return _json_serialize_python_object(data)

# files and buffers
if hasattr(data, 'read'):
Expand All @@ -254,10 +259,10 @@ def __call__(self, data):

def _json_serialize_numpy_array(data):
# numpy arrays can't be serialized but we know they have uniform type
return _json_serialize_python_array(data.tolist())
return _json_serialize_python_object(data.tolist())


def _json_serialize_python_array(data):
def _json_serialize_python_object(data):
return _json_serialize_object(data)


Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/tensorflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@


class TensorFlowPredictor(RealTimePredictor):
"""A ``RealTimePredictor`` for inference against MXNet ``Endpoint``s."""
"""A ``RealTimePredictor`` for inference against TensorFlow ``Endpoint``s.
This is able to serialize Python lists, dictionaries, and numpy arrays to multidimensional tensors for MXNet
inference"""
def __init__(self, endpoint_name, sagemaker_session=None):
"""Initialize an ``TensorFlowPredictor``.
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/tensorflow/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self):
self.content_type = CONTENT_TYPE_OCTET_STREAM

def __call__(self, data):
# isintance does not work here because a same protobuf message can be imported from a different module.
# isinstance does not work here because a same protobuf message can be imported from a different module.
# for example sagemaker.tensorflow.tensorflow_serving.regression_pb2 and tensorflow_serving.apis.regression_pb2
predict_type = data.__class__.__name__

Expand Down
9 changes: 7 additions & 2 deletions tests/integ/test_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,13 @@ def test_tf(sagemaker_session):
with timeout_and_delete_endpoint(estimator=estimator, minutes=20):
json_predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.c4.xlarge')

result = json_predictor.predict([6.4, 3.2, 4.5, 1.5])
print('predict result: {}'.format(result))
features = [6.4, 3.2, 4.5, 1.5]
dict_result = json_predictor.predict({'inputs': features})
print('predict result: {}'.format(dict_result))
list_result = json_predictor.predict(features)
print('predict result: {}'.format(list_result))

assert dict_result == list_result


def test_tf_async(sagemaker_session):
Expand Down
14 changes: 14 additions & 0 deletions tests/unit/test_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,26 @@ def test_json_serializer_python_array():
assert result == '[1, 2, 3]'


def test_json_serializer_python_dictionary():
d = {"gender": "m", "age": 22, "city": "Paris"}

result = json_serializer(d)

assert json.loads(result) == d


def test_json_serializer_python_invalid_empty():
with pytest.raises(ValueError) as error:
json_serializer([])
assert "empty array" in str(error)


def test_json_serializer_python_dictionary_invalid_empty():
with pytest.raises(ValueError) as error:
json_serializer({})
assert "empty dictionary" in str(error)


def test_json_serializer_csv_buffer():
csv_file_path = os.path.join(DATA_DIR, "with_integers.csv")
with open(csv_file_path) as csv_file:
Expand Down

0 comments on commit d47f6d1

Please sign in to comment.