Skip to content

Commit

Permalink
Refactor plus switched prediction to use tf.Example as input
Browse files Browse the repository at this point in the history
  • Loading branch information
Alan Patterson committed Nov 20, 2017
1 parent 0fc74af commit dd7c2fd
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 14 deletions.
17 changes: 10 additions & 7 deletions predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,20 @@
tf.flags.DEFINE_string("signature_def", "proba",
"Stored signature key of method to call (proba|embedding)")
tf.flags.DEFINE_string("saved_model", None, "Directory of SavedModel")
tf.flags.DEFINE_string("tag", "serve", "SavedModel tag, serve|gpu")
tf.flags.DEFINE_boolean("debug", False, "Debug")
FLAGS = tf.flags.FLAGS

_TAG = "serve"


def RunModel(saved_model_dir, signature_def_key, text, ngrams_list=None):
def RunModel(saved_model_dir, signature_def_key, tag, text, ngrams_list=None):
saved_model = reader.read_saved_model(saved_model_dir)
meta_graph = None
for meta_graph_def in saved_model.meta_graphs:
if meta_graph_def.meta_info_def.tags == _TAG:
if tag in meta_graph_def.meta_info_def.tags:
meta_graph = meta_graph_def
break
if meta_graph_def is None:
raise ValueError("Cannot find saved_model with tag" + tag)
signature_def = signature_def_utils.get_signature_def_by_key(
meta_graph, signature_def_key)
text = text_utils.TokenizeText(text)
Expand All @@ -40,7 +42,7 @@ def RunModel(saved_model_dir, signature_def_key, text, ngrams_list=None):
ngrams = inputs.GenerateNgrams(text, ngrams_list)
example = inputs.BuildTextExample(text, ngrams=ngrams)
inputs_feed_dict = {
signature_def.inputs["inputs"].name: example,
signature_def.inputs["inputs"].name: [example],
}
if signature_def_key == "proba":
output_key = "scores"
Expand All @@ -50,7 +52,7 @@ def RunModel(saved_model_dir, signature_def_key, text, ngrams_list=None):
raise ValueError("Unrecognised signature_def %s" % (signature_def_key))
output_tensor = signature_def.outputs[output_key].name
with tf.Session() as sess:
loader.load(sess, [_TAG], saved_model_dir)
loader.load(sess, [tag], saved_model_dir)
outputs = sess.run(output_tensor,
feed_dict=inputs_feed_dict)
return outputs
Expand All @@ -59,7 +61,8 @@ def RunModel(saved_model_dir, signature_def_key, text, ngrams_list=None):
def main(_):
if not FLAGS.text:
raise ValueError("No --text provided")
outputs = RunModel(FLAGS.saved_model, FLAGS.signature_def, FLAGS.text, FLAGS.ngrams)
outputs = RunModel(FLAGS.saved_model, FLAGS.signature_def, FLAGS.tag,
FLAGS.text, FLAGS.ngrams)
if FLAGS.signature_def == "proba":
print("Proba:", outputs[0])
print("Class(1-N):", np.argmax(outputs[0]) + 1)
Expand Down
3 changes: 2 additions & 1 deletion process_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ def ParseTextInput(textfile, labelsfie, ngrams):
examples = []
with open(textfile) as f1, open(labelsfile) as f2:
for text, label in zip(f1, f2):
words = text_utils.TokenizeText(text)
examples.append({
"text": text_utils.TokenizeText(text),
"text": words,
"label": label,
})
if ngrams:
Expand Down
2 changes: 1 addition & 1 deletion text_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


def TokenizeText(text):
    """Lowercase *text* and split it into word tokens.

    Args:
        text: Raw input string.

    Returns:
        A list of word tokens produced by NLTK's ``word_tokenize``
        applied to the lowercased text.
    """
    # NOTE: the pre-fix line called the misspelled, undefined name
    # `word_tokenise`; `word_tokenize` (NLTK) is the correct helper.
    return word_tokenize(text.lower())


def ParseNgramsOpts(opts):
Expand Down
5 changes: 0 additions & 5 deletions todo.txt

This file was deleted.

0 comments on commit dd7c2fd

Please sign in to comment.