Description
Hey there,
So I've recently been speccing out a project that would require multiple inputs given to a model. I had a few hiccups getting this to work with a TensorFlow model with Keras, but I figured out a workaround. However, I don't think this workaround should have to be built by the customer; it should be handled by default.
Basically, when the serving container receives a request, it converts the body into a tensor proto using json_format.Parse. This means it cannot handle dictionaries. However, the SageMaker SDK's predictor will happily take in, serialize, and send off a dictionary from Python. This was a confusing and frustrating difference in API to deal with.
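For concreteness, the client-side call in question looks roughly like this. This is a minimal sketch assuming the SageMaker Python SDK's TensorFlowPredictor against an already-deployed endpoint; the endpoint name and the input tensor names ("input_1", "input_2") are hypothetical:

import numpy as np
from sagemaker.tensorflow.model import TensorFlowPredictor

# Hypothetical name of an already-deployed multi-input endpoint.
predictor = TensorFlowPredictor("my-multi-input-endpoint")

# The SDK JSON-serializes this dict (keyed by input tensor name) and sends
# it off without complaint, but the container's default input_fn then fails
# to parse it with json_format.Parse.
payload = {
    "input_1": np.random.rand(1, 10).tolist(),
    "input_2": np.random.rand(1, 4).tolist(),
}
prediction = predictor.predict(payload)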
However, it's fairly easy to resolve, as you can simply use tf.make_tensor_proto on each element after JSON deserialization, as below:
import json

import tensorflow


def input_fn(serialized_input, content_type):
    """An input_fn that deserializes a JSON request body into TensorProto(s)."""
    if content_type == "application/json":
        deserialized_input = json.loads(serialized_input)
        if isinstance(deserialized_input, dict):
            # Multi-input model: build one TensorProto per named input.
            deserialized_tensorproto = {
                k: tensorflow.make_tensor_proto(v)
                for k, v in deserialized_input.items()
            }
        else:
            # Single input: convert the deserialized payload directly.
            deserialized_tensorproto = tensorflow.make_tensor_proto(deserialized_input)
        return deserialized_tensorproto
    else:
        # Handle other content-types here or raise an Exception
        # if the content type is not supported.
        raise ValueError("Unsupported content type: {}".format(content_type))
Unfortunately this doesn't use json_format, which I'm not sure is a requirement. This input_fn can be used to override the default one and provide the multi-input functionality, as done in this demo gist; however, it would be great not to have to include this workaround!
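For what it's worth, a quick local check (a sketch; the input names are hypothetical) shows the override handling both payload shapes:

import json

# Multi-input payload, keyed by input tensor name.
multi = json.dumps({"input_1": [[1.0, 2.0]], "input_2": [[3.0]]})
protos = input_fn(multi, "application/json")
assert set(protos) == {"input_1", "input_2"}  # one TensorProto per input

# A single-input payload still works via the else branch.
single = json.dumps([[1.0, 2.0, 3.0]])
proto = input_fn(single, "application/json")
print(type(proto).__name__)  # TensorProto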