Skip to content

Commit

Permalink
improvement: remove unnecessary expand and squeeze dimension and intr…
Browse files Browse the repository at this point in the history
…oduce image send in a batch mode
  • Loading branch information
vbezgachev committed Jun 12, 2018
1 parent 1a48ab2 commit 8407dd1
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 29 deletions.
69 changes: 47 additions & 22 deletions svnh_semi_supervised_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,52 +18,77 @@
from tensorflow_serving.apis import prediction_service_pb2
from tensorflow.contrib.util import make_tensor_proto

from tensorflow.contrib.util import make_tensor_proto

from os import listdir
from os.path import isfile, join


def parse_args():
parser = ArgumentParser(description="Request a TensorFlow server for a prediction on the image")
parser.add_argument("-s", "--server",
dest="server",
parser = ArgumentParser(description='Request a TensorFlow server for a prediction on the image')
parser.add_argument('-s', '--server',
dest='server',
default='172.17.0.2:9000',
help="prediction service host:port")
parser.add_argument("-i", "--image",
dest="image",
default="",
help="path to image in JPEG format",)
help='prediction service host:port')
parser.add_argument('-i', '--image_path',
dest='image_path',
default='',
help='path to images folder',)
parser.add_argument('-b', '--batch_mode',
dest='batch_mode',
default='true',
help='send image as batch or one-by-one')
args = parser.parse_args()

host, port = args.server.split(':')

return host, port, args.image
return host, port, args.image_path, args.batch_mode == 'true'


def main():
# parse command line arguments
host, port, image = parse_args()
host, port, image_path, batch_mode = parse_args()

channel = implementations.insecure_channel(host, int(port))
stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

filenames = [(image_path + '/' + f) for f in listdir(image_path) if isfile(join(image_path, f))]
files = []
imagedata = []
for filename in filenames:
f = open(filename, 'rb')
files.append(f)

# Send request
with open(image, 'rb') as f:
# See prediction_service.proto for gRPC request/response details.
data = f.read()
imagedata.append(data)

start = time.time()
start = time.time()

if batch_mode:
print('In batch mode')
request = predict_pb2.PredictRequest()

# Call GAN model to make prediction on the image
request.model_spec.name = 'gan'
request.model_spec.signature_name = 'predict_images'
request.inputs['images'].CopyFrom(make_tensor_proto(data, shape=[1]))

result = stub.Predict(request, 60.0) # 60 secs timeout

end = time.time()
time_diff = end - start
request.inputs['images'].CopyFrom(make_tensor_proto(imagedata, shape=[len(imagedata)]))

result = stub.Predict(request, 60.0)
print(result)
print('time elapased: {}'.format(time_diff))
else:
print('In one-by-one mode')
for data in imagedata:
request = predict_pb2.PredictRequest()
request.model_spec.name = 'gan'
request.model_spec.signature_name = 'predict_images'

request.inputs['images'].CopyFrom(make_tensor_proto(data, shape=[1]))

result = stub.Predict(request, 60.0) # 60 secs timeout
print(result)

end = time.time()
time_diff = end - start
print('time elapased: {}'.format(time_diff))


if __name__ == '__main__':
Expand Down
7 changes: 0 additions & 7 deletions svnh_semi_supervised_model_saved.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,6 @@ def preprocess_image(image_buffer):
# adjust_* ops all require this range for dtype float.
image = tf.image.convert_image_dtype(image, dtype=tf.float32)

# Networks accept images in batches.
# The first dimension usually represents the batch size.
# In our case the batch size is one.
image = tf.expand_dims(image, 0)

# Finally, rescale to [-1,1] instead of [0, 1)
image = tf.subtract(image, 0.5)
image = tf.multiply(image, 2.0)
Expand All @@ -65,8 +60,6 @@ def main(_):
tf_example = tf.parse_example(serialized_tf_example, feature_configs)
jpegs = tf_example['image/encoded']
images = tf.map_fn(preprocess_image, jpegs, dtype=tf.float32)
images = tf.squeeze(images, [0])
# now the image shape is (1, ?, ?, 3)

# Create GAN model
z_size = 100
Expand Down

0 comments on commit 8407dd1

Please sign in to comment.