Skip to content
This repository has been archived by the owner on Dec 17, 2021. It is now read-only.

Commit

Permalink
2016_10_21_12:10:47
Browse files Browse the repository at this point in the history
  • Loading branch information
elibixby committed Oct 21, 2016
1 parent 1cc8e14 commit 7a002e7
Show file tree
Hide file tree
Showing 9 changed files with 647 additions and 326 deletions.
48 changes: 19 additions & 29 deletions iris/local_predict.py
Expand Up @@ -26,59 +26,50 @@
import collections
import json
import os
import subprocess
import sys

from google.cloud.ml import features
from google.cloud.ml import session_bundle


def _default_project():
get_project = ['gcloud', 'config', 'list', 'project',
'--format=value(core.project)']

with open(os.devnull, 'w') as dev_null:
return subprocess.check_output(get_project, stderr=dev_null).strip()


def local_predict(args):
def local_predict(input_data, model_dir, dry_run):
"""Runs prediction locally."""

session, _ = session_bundle.load_session_bundle_from_path(args.model_dir)
session, _ = session_bundle.load_session_bundle_from_path(model_dir)
# get the mappings between aliases and tensor names
# for both inputs and outputs
input_alias_map = json.loads(session.graph.get_collection('inputs')[0])
output_alias_map = json.loads(session.graph.get_collection('outputs')[0])
aliases, tensor_names = zip(*output_alias_map.items())

metadata_path = os.path.join(args.model_dir, 'metadata.yaml')
metadata_path = os.path.join(model_dir, 'metadata.yaml')
transformer = features.FeatureProducer(metadata_path)
for input_file in args.input:
for input_file in input_data:
with open(input_file) as f:
feed_dict = collections.defaultdict(list)
for line in f:
preprocessed = transformer.preprocess(line)
feed_dict[input_alias_map.values()[0]].append(
preprocessed.SerializeToString())
if args.dry_run:
print 'Feed data dict %s to graph and fetch %s' % (
feed_dict, tensor_names)
else:
result = session.run(fetches=tensor_names, feed_dict=feed_dict)
for row in zip(*result):
print json.dumps({name: value.tolist()
for name, value in zip(aliases, row)})
result = session.run(fetches=tensor_names, feed_dict=feed_dict)
for row in zip(*result):
print json.dumps({
name: (value.tolist() if getattr(value, 'tolist', None) else
value)
for name, value in zip(aliases, row)
})


def parse_args():
def parse_args(args):
"""Parses arguments specified on the command-line."""

argparser = argparse.ArgumentParser('Predict on the Iris model.')

argparser.add_argument(
'input',
'input_data',
nargs='+',
help=('The input data file/file patterns. Multiple '
'files can be specified if more than one file patterns is needed.'))
help=('The input data file. Multiple files can be specified if more than '
'one file is needed.'))

argparser.add_argument(
'--model_dir',
Expand All @@ -91,9 +82,8 @@ def parse_args():
action='store_true',
help='Instead of executing commands, prints them.')

return argparser.parse_args()

return argparser.parse_args(args)

if __name__ == '__main__':
arguments = parse_args()
local_predict(arguments)
parsed_args = parse_args(sys.argv[1:])
local_predict(**vars(parsed_args))

0 comments on commit 7a002e7

Please sign in to comment.