#!/usr/bin/env python
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import shutil
import numpy as np
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.INFO)

BUCKET = None   # set from task.py
PATTERN = 'of'  # sharded output files contain '-of-' (e.g. train.csv-00000-of-00012), so this matches all files
TRAIN_STEPS = 10000
CSV_COLUMNS = 'weight_pounds,is_male,mother_age,plurality,gestation_weeks,key'.split(',')
LABEL_COLUMN = 'weight_pounds'
KEY_COLUMN = 'key'
# one default per CSV column, in order; the defaults also fix the parsed dtypes
DEFAULTS = [[0.0], ['null'], [0.0], ['null'], [0.0], ['nokey']]
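
# Illustration only: each CSV line from the preprocessing pipeline must match
# CSV_COLUMNS/DEFAULTS positionally; a row might look like (made-up values):
#   7.25,True,26,Single(1),39,2126480030009879160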

def read_dataset(prefix, pattern, batch_size=512):
  # use the prefix (train/eval) to create the GCS file pattern
  filename = 'gs://{}/babyweight/preproc/{}*{}*'.format(BUCKET, prefix, pattern)
  if prefix == 'train':
    mode = tf.estimator.ModeKeys.TRAIN
    num_epochs = None  # read indefinitely
  else:
    mode = tf.estimator.ModeKeys.EVAL
    num_epochs = 1  # end-of-input after one pass

  # the actual input function passed to TensorFlow
  def _input_fn():
    # the filename could be a path to one file or a file pattern
    input_file_names = tf.train.match_filenames_once(filename)
    filename_queue = tf.train.string_input_producer(
        input_file_names, shuffle=True, num_epochs=num_epochs)

    # read up to batch_size CSV lines at a time
    reader = tf.TextLineReader()
    _, value = reader.read_up_to(filename_queue, num_records=batch_size)
    if mode == tf.estimator.ModeKeys.TRAIN:
      value = tf.train.shuffle_batch(
          [value], batch_size, capacity=10 * batch_size,
          min_after_dequeue=batch_size, enqueue_many=True,
          allow_smaller_final_batch=False)

    # parse the CSV lines into per-column tensors
    value_column = tf.expand_dims(value, -1)
    columns = tf.decode_csv(value_column, record_defaults=DEFAULTS)
    features = dict(zip(CSV_COLUMNS, columns))
    features.pop(KEY_COLUMN)  # the key is not a training feature
    label = features.pop(LABEL_COLUMN)
    return features, label
  return _input_fn
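
# Illustration only (not executed): read_dataset returns a no-argument closure,
# which is what the Estimator API expects, so the pipeline is built lazily
# inside the Estimator's own graph:
#   train_input_fn = read_dataset('train', PATTERN)  # shuffled, repeats forever
#   eval_input_fn = read_dataset('eval', PATTERN)    # single epoch, then OutOfRange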

def get_wide_deep():
  # define column types
  is_male = tf.feature_column.categorical_column_with_vocabulary_list(
      'is_male', ['True', 'False', 'Unknown'])
  mother_age = tf.feature_column.numeric_column('mother_age')
  plurality = tf.feature_column.categorical_column_with_vocabulary_list(
      'plurality', ['Single(1)', 'Twins(2)', 'Triplets(3)',
                    'Quadruplets(4)', 'Quintuplets(5)', 'Multiple(2+)'])
  gestation_weeks = tf.feature_column.numeric_column('gestation_weeks')

  # discretize the continuous columns
  age_buckets = tf.feature_column.bucketized_column(
      mother_age, boundaries=np.arange(15, 45, 1).tolist())
  gestation_buckets = tf.feature_column.bucketized_column(
      gestation_weeks, boundaries=np.arange(17, 47, 1).tolist())

  # sparse columns are wide
  wide = [is_male,
          plurality,
          age_buckets,
          gestation_buckets]

  # feature cross all the wide columns and embed into a lower dimension
  crossed = tf.feature_column.crossed_column(wide, hash_bucket_size=20000)
  embed = tf.feature_column.embedding_column(crossed, 3)

  # continuous columns are deep
  deep = [mother_age,
          gestation_weeks,
          embed]
  return wide, deep
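
# Illustration only: the four wide columns crossed into 20000 hash buckets and
# embedded into 3 dimensions give the deep part one 3-float dense vector per
# example, alongside the raw mother_age and gestation_weeks inputs:
#   wide, deep = get_wide_deep()
#   print([c.name for c in wide])  # e.g. ['is_male', 'plurality', ...]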

def serving_input_fn():
  feature_placeholders = {
      'is_male': tf.placeholder(tf.string, [None]),
      'mother_age': tf.placeholder(tf.float32, [None]),
      'plurality': tf.placeholder(tf.string, [None]),
      'gestation_weeks': tf.placeholder(tf.float32, [None])
  }
  # add a batch dimension so the serving tensors match the training feature shapes
  features = {
      key: tf.expand_dims(tensor, -1)
      for key, tensor in feature_placeholders.items()
  }
  return tf.estimator.export.ServingInputReceiver(features, feature_placeholders)
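
# Illustration only: once exported, each placeholder receives a batch of
# scalars and expand_dims reshapes it to [batch, 1] to match the training
# features. A single JSON prediction instance might look like (made-up values):
#   {"is_male": "True", "mother_age": 26.0,
#    "plurality": "Single(1)", "gestation_weeks": 39.0}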

from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils


# legacy path: tf.contrib.learn.Experiment is deprecated in TensorFlow 1.x in
# favor of tf.estimator.train_and_evaluate (used below)
def experiment_fn(output_dir):
  wide, deep = get_wide_deep()
  return tf.contrib.learn.Experiment(
      tf.estimator.DNNLinearCombinedRegressor(model_dir=output_dir,
                                              linear_feature_columns=wide,
                                              dnn_feature_columns=deep,
                                              dnn_hidden_units=[64, 32]),
      train_input_fn=read_dataset('train', PATTERN),
      eval_input_fn=read_dataset('eval', PATTERN),
      export_strategies=[saved_model_export_utils.make_export_strategy(
          serving_input_fn,
          default_output_alternative_key=None,
          exports_to_keep=1
      )],
      train_steps=TRAIN_STEPS,
      eval_steps=None
  )

def train_and_evaluate(output_dir):
  wide, deep = get_wide_deep()
  estimator = tf.estimator.DNNLinearCombinedRegressor(
      model_dir=output_dir,
      linear_feature_columns=wide,
      dnn_feature_columns=deep,
      dnn_hidden_units=[64, 32])
  train_spec = tf.estimator.TrainSpec(
      input_fn=read_dataset('train', PATTERN),
      max_steps=TRAIN_STEPS)
  exporter = tf.estimator.LatestExporter('exporter', serving_input_fn)
  eval_spec = tf.estimator.EvalSpec(
      input_fn=read_dataset('eval', PATTERN),
      steps=None,  # evaluate until the input_fn raises end-of-input
      exporters=exporter)
  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
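
# Minimal smoke test, assuming task.py normally injects BUCKET and invokes
# train_and_evaluate; the bucket name and output directory here are placeholders.
if __name__ == '__main__':
  BUCKET = 'my-bucket'  # placeholder; real runs set this from task.py
  shutil.rmtree('babyweight_trained', ignore_errors=True)  # start from scratch
  train_and_evaluate('babyweight_trained')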