Skip to content

Commit

Permalink
change: lazy import of tensorflow module (#1062)
Browse files Browse the repository at this point in the history
  • Loading branch information
mvsusp committed Sep 24, 2019
1 parent 810e96b commit 2ee203f
Showing 1 changed file with 30 additions and 12 deletions.
42 changes: 30 additions & 12 deletions src/sagemaker/tensorflow/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,31 @@
import google.protobuf.json_format as json_format
from google.protobuf.message import DecodeError
from protobuf_to_dict import protobuf_to_dict
from tensorflow.core.framework import tensor_pb2 # pylint: disable=no-name-in-module
from tensorflow.python.framework import tensor_util # pylint: disable=no-name-in-module

from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_OCTET_STREAM, CONTENT_TYPE_CSV
from sagemaker.predictor import json_serializer, csv_serializer
from tensorflow_serving.apis import predict_pb2, classification_pb2, inference_pb2, regression_pb2

_POSSIBLE_RESPONSES = [
predict_pb2.PredictResponse,
classification_pb2.ClassificationResponse,
inference_pb2.MultiInferenceResponse,
regression_pb2.RegressionResponse,
tensor_pb2.TensorProto,
]

def _possible_responses():
"""
Returns: Possible available request types.
"""
from tensorflow.core.framework import tensor_pb2 # pylint: disable=no-name-in-module
from tensorflow_serving.apis import (
predict_pb2,
classification_pb2,
inference_pb2,
regression_pb2,
)

return [
predict_pb2.PredictResponse,
classification_pb2.ClassificationResponse,
inference_pb2.MultiInferenceResponse,
regression_pb2.RegressionResponse,
tensor_pb2.TensorProto,
]


REGRESSION_REQUEST = "RegressionRequest"
MULTI_INFERENCE_REQUEST = "MultiInferenceRequest"
Expand Down Expand Up @@ -88,7 +99,7 @@ def __call__(self, stream, content_type):
finally:
stream.close()

for possible_response in _POSSIBLE_RESPONSES:
for possible_response in _possible_responses():
try:
response = possible_response()
response.ParseFromString(data)
Expand All @@ -114,6 +125,9 @@ def __call__(self, data):
Args:
data:
"""

from tensorflow.core.framework import tensor_pb2 # pylint: disable=no-name-in-module

if isinstance(data, tensor_pb2.TensorProto):
return json_format.MessageToJson(data)
return json_serializer(data)
Expand All @@ -139,7 +153,7 @@ def __call__(self, stream, content_type):
finally:
stream.close()

for possible_response in _POSSIBLE_RESPONSES:
for possible_response in _possible_responses():
try:
return protobuf_to_dict(json_format.Parse(data, possible_response()))
except (UnicodeDecodeError, DecodeError, json_format.ParseError):
Expand All @@ -164,6 +178,10 @@ def __call__(self, data):
data:
"""
to_serialize = data

from tensorflow.core.framework import tensor_pb2 # pylint: disable=no-name-in-module
from tensorflow.python.framework import tensor_util # pylint: disable=no-name-in-module

if isinstance(data, tensor_pb2.TensorProto):
to_serialize = tensor_util.MakeNdarray(data)
return csv_serializer(to_serialize)
Expand Down

0 comments on commit 2ee203f

Please sign in to comment.