Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 22 additions & 18 deletions src/tf_container/proxy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ def _create_classification_request(self, data):
def _create_feature_dict_list(self, data):
"""
Parses the input data and returns a [dict<string, iterable>] which will be used to create the tf examples.
If the input data is not a dict, a dictionary will be created with the default predict key PREDICT_INPUTS
If the input data is not a dict, a dictionary will be created with the default key PREDICT_INPUTS.
Used on the code path for creating ClassificationRequests.

Examples:
input => output
Expand Down Expand Up @@ -184,43 +185,46 @@ def _raise_not_implemented_exception(self, data):

def _create_input_map(self, data):
    """
    Parses the input data and returns a dict<string, TensorProto> which will be used to create the
    PredictRequest. If the input data is not a dict, a dictionary will be created with the default
    predict key PREDICT_INPUTS as input.

    Examples:
        input => output
        -------------------------------------------------
        {'inputs': tensor_proto}             => {'inputs': tensor_proto}
        tensor_proto                         => {PREDICT_INPUTS: tensor_proto}
        [1,2,3]                              => {PREDICT_INPUTS: TensorProto(1,2,3)}
        {'custom_tensor_name': [1, 2, 3]}    => {'custom_tensor_name': TensorProto(1,2,3)}

    Args:
        data: request data. Can be any of: ndarray-like, TensorProto,
              dict<str, TensorProto>, dict<str, ndarray-like>

    Returns:
        dict<string, TensorProto>

    Raises:
        ValueError: if any value cannot be converted to a TensorProto
            (raised by _value_to_tensor).
    """
    if isinstance(data, dict):
        # Keep the caller-provided tensor names; convert each value individually.
        return {k: self._value_to_tensor(v) for k, v in data.items()}

    # When input data is not a dict, no tensor names are given, so use default
    return {self.input_tensor_name: self._value_to_tensor(data)}
def _value_to_tensor(self, value):
    """Converts the given value to a tensor_pb2.TensorProto. Used on code path for creating PredictRequests.

    Args:
        value: a tensor_pb2.TensorProto (returned unchanged), or any ndarray-like
               value (list, tuple, numpy.ndarray) numpy can convert to an array.

    Returns:
        tensor_pb2.TensorProto

    Raises:
        ValueError: if the value cannot be converted to a TensorProto.
    """
    if isinstance(value, tensor_pb2.TensorProto):
        return value

    msg = """Unable to convert value to TensorProto: {}.
    Valid formats: tensor_pb2.TensorProto, list, numpy.ndarray"""
    try:
        # TODO: tensorflow container supports prediction requests with ONLY one tensor as input
        # list(...) keeps this working on both Python 2 and Python 3, where
        # dict.values() returns a non-indexable view.
        input_type = list(self.input_type_map.values())[0]
        ndarray = np.asarray(value)
        return make_tensor_proto(values=ndarray, dtype=input_type, shape=ndarray.shape)
    except Exception:
        # Deliberately broad: any conversion failure is reported uniformly to the caller.
        raise ValueError(msg.format(value))


def _create_tf_example(feature_dict):
Expand Down
4 changes: 2 additions & 2 deletions src/tf_container/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __init__(self, grpc_proxy_client, transform_fn=None, input_fn=None, output_f

@staticmethod
def _parse_json_request(serialized_data):
'''
"""
json deserialization works in the following order:
1 - tries to deserialize the payload as a tensor using google.protobuf.json_format.Parse(
payload, tensor_pb2.TensorProto())
Expand All @@ -170,7 +170,7 @@ def _parse_json_request(serialized_data):

Returns:
deserialized object
'''
"""
try:
return json_format.Parse(serialized_data, tensor_pb2.TensorProto())
except json_format.ParseError:
Expand Down
13 changes: 13 additions & 0 deletions test/integ/container_tests/layers_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,16 @@ def test_json_request():
prediction_result = json.loads(serialized_output)

assert len(prediction_result['outputs']['probabilities']['floatVal']) == 10


def test_json_dict_of_lists():
    """A JSON dict of plain lists posted to the container should yield a 10-class probability vector."""
    # 784 inputs = one flattened 28x28 image; list(xrange(...)) replaces the
    # redundant comprehension [x for x in xrange(784)].
    data = {'inputs': list(xrange(784))}

    url = "http://localhost:8080/invocations"
    serialized_output = requests.post(url,
                                      json.dumps(data),
                                      headers={'Content-type': 'application/json'}).content

    prediction_result = json.loads(serialized_output)

    assert len(prediction_result['outputs']['probabilities']['floatVal']) == 10
19 changes: 16 additions & 3 deletions test/unit/test_proxy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def set_up():
patcher.start()
from tf_container.proxy_client import GRPCProxyClient
proxy_client = GRPCProxyClient(9000, input_tensor_name='inputs', signature_name='serving_default')
proxy_client.input_type_map['sometype'] = 'somedtype'

yield mock, proxy_client
patcher.stop()
Expand Down Expand Up @@ -253,15 +254,27 @@ def test_predict_with_predict_request(set_up, set_up_requests):
assert prediction == predict_fn.return_value


@patch('tf_container.proxy_client.make_tensor_proto', side_effect=Exception('tensor proto failed!'))
def test_predict_with_invalid_payload(make_tensor_proto, set_up, set_up_requests):
    """predict() surfaces a ValueError when the payload cannot be converted to a TensorProto."""
    _, proxy_client = set_up

    # complex is not ndarray-convertible to the configured dtype, and
    # make_tensor_proto is patched to fail, so conversion must raise.
    data = complex('1+2j')

    with pytest.raises(ValueError) as error:
        proxy_client.predict(data)

    assert 'Unable to convert value to TensorProto' in str(error)


@patch('tf_container.proxy_client.make_tensor_proto', return_value='MyTensorProto')
def test_predict_create_input_map_with_dict_of_lists(make_tensor_proto, set_up, set_up_requests):
    """A dict of plain lists is converted value-by-value, preserving the caller's tensor names."""
    _, proxy_client = set_up

    data = {'mytensor': [1, 2, 3]}

    result = proxy_client._create_input_map(data)
    assert result == {'mytensor': 'MyTensorProto'}
    # Assert on call_count rather than assert_called_once(): on pre-2.0 mock
    # (the era of this Python 2 codebase), assert_called_once() does not exist
    # and Mock silently treats it as a passing child-mock attribute access.
    assert make_tensor_proto.call_count == 1


def test_classification_with_classification_request(set_up, set_up_requests):
Expand Down