In [None]:
import tensorflow as tf
tf.__version__

## Global variables

In [None]:
import tensorflow as tf
import os

# Storage root: a GCS bucket when a `bucket_name` variable exists in the
# kernel, otherwise a local directory.
try:
  ROOT_DIR = 'gs://{}'.format(bucket_name)
except NameError:
  ROOT_DIR = './tutorial'

MODEL_NAME = 'mnist-with-keras-estimator'
DATA_DIR = '{}/data'.format(ROOT_DIR)
MODEL_DIR = '{}/{}'.format(ROOT_DIR, MODEL_NAME)
EXPORT_DIR = '{}/{}/export'.format(ROOT_DIR, MODEL_NAME)
CHECKPOINT_PATH = '{}/checkpoint'.format(MODEL_DIR)

# Start from a clean slate: drop MODEL_DIR (and everything under it, which
# includes EXPORT_DIR) left over from a previous run.
if tf.gfile.IsDirectory(MODEL_DIR):
  tf.logging.info('delete {}'.format(MODEL_DIR))
  tf.gfile.DeleteRecursively(MODEL_DIR)

# MakeDirs creates missing parents and succeeds when a directory already
# exists, so this also works when ROOT_DIR is present but DATA_DIR is not
# (MkDir on an existing ROOT_DIR would raise).
if not tf.gfile.IsDirectory(DATA_DIR):
  tf.logging.info('create {}'.format(DATA_DIR))
  tf.gfile.MakeDirs(DATA_DIR)

## Save input data in TFRecord format

In [None]:
import numpy as np

# Load MNIST, then hold out the last 5000 training samples as the eval split.
train, test = tf.keras.datasets.mnist.load_data()
train_images, train_labels = train
X_test, y_test = test

X_train, X_eval = train_images[:-5000], train_images[-5000:]
y_train, y_eval = train_labels[:-5000], train_labels[-5000:]

In [None]:
def _bytes_feature(value):
  """Wrap one bytes value in a `tf.train.Feature` with a one-element BytesList."""
  wrapped = tf.train.BytesList(value=[value])
  return tf.train.Feature(bytes_list=wrapped)
def _int64_feature(value):
  """Wrap one integer value in a `tf.train.Feature` with a one-element Int64List."""
  wrapped = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=wrapped)

def create_example(image, label):
  """Build a `tf.train.Example` holding the raw bytes of one image and its label.

  Both arguments must expose `.tobytes()`; each is stored as a bytes feature
  under the keys 'image' and 'label'.
  """
  raw_fields = {
      'image': _bytes_feature(image.tobytes()),
      'label': _bytes_feature(label.tobytes()),
  }
  features = tf.train.Features(feature=raw_fields)
  return tf.train.Example(features=features)

def convert_to_tfrecord(images, labels, output_file):
  """Serialize paired images/labels as `tf.train.Example`s into one TFRecord file.

  Args:
    images: iterable of image arrays.
    labels: iterable of labels, paired with `images` by position (zip).
    output_file: path of the TFRecord file to write.
  """
  with tf.python_io.TFRecordWriter(output_file) as writer:
    for sample_image, sample_label in zip(images, labels):
      serialized = create_example(sample_image, sample_label).SerializeToString()
      writer.write(serialized)
      
      
# Write one TFRecord file per split under DATA_DIR.
for split_name, split_images, split_labels in [
    ('train', X_train, y_train),
    ('eval', X_eval, y_eval),
    ('test', X_test, y_test),
]:
  convert_to_tfrecord(
      split_images, split_labels,
      output_file='{}/{}.tfrecord'.format(DATA_DIR, split_name))

## Input pipeline

In [None]:
# Default number of examples per batch for generate_input_fn.
BATCH_SIZE = 50
# Default number of passes over the training data for generate_input_fn.
N_EPOCHS = 1
  
def generate_input_fn(file_pattern, mode, batch_size=BATCH_SIZE, count=N_EPOCHS):
  """Create an Estimator `input_fn` reading TFRecord files matching `file_pattern`.

  Args:
    file_pattern: pattern handed to `tf.data.Dataset.list_files`.
    mode: a `tf.estimator.ModeKeys` value. Only TRAIN enables
      cache/shuffle/repeat; every other mode makes a single ordered pass.
    batch_size: number of examples per batch.
    count: value passed to `Dataset.repeat` in TRAIN mode; unused otherwise.

  Returns:
    A zero-argument function returning a `(features, labels)` pair of tensors:
    images flattened to 28*28 float32 values scaled to [0.0, 1.0], and labels
    one-hot encoded (depth 10, int32).
  """

  def _decode_example(serialized_example):
    # Both fields were written as raw bytes, so both are parsed as strings.
    schema = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(serialized_example, features=schema)

    # Normalize from [0, 255] to [0.0, 1.0]
    pixels = tf.cast(tf.decode_raw(parsed['image'], tf.uint8), tf.float32)
    image = tf.reshape(pixels, [28*28]) / 255.0

    # The label is a single uint8; squeeze it to a scalar before one-hot.
    class_id = tf.reshape(tf.decode_raw(parsed['label'], tf.uint8), [])
    return image, tf.one_hot(class_id, 10, dtype=tf.int32)

  def input_fn():
    matching_files = tf.data.Dataset.list_files(file_pattern)
    records = tf.data.TFRecordDataset(matching_files)

    if mode == tf.estimator.ModeKeys.TRAIN:
      records = records.cache().shuffle(10000).repeat(count=count)

    batches = records.map(_decode_example).batch(batch_size)
    features, labels = batches.make_one_shot_iterator().get_next()
    return features, labels

  return input_fn

In [None]:
# One input_fn per split; only the training one shuffles and repeats.
train_input_fn = generate_input_fn(
    '{}/train.tfrecord'.format(DATA_DIR),
    tf.estimator.ModeKeys.TRAIN,
    batch_size=BATCH_SIZE,
    count=N_EPOCHS)

eval_input_fn = generate_input_fn(
    '{}/eval.tfrecord'.format(DATA_DIR),
    tf.estimator.ModeKeys.EVAL,
    count=1)

test_input_fn = generate_input_fn(
    '{}/test.tfrecord'.format(DATA_DIR),
    tf.estimator.ModeKeys.PREDICT)

## Model Definition

In [None]:
def get_keras_model():
  """Return a compiled Keras MLP for flattened 28*28 inputs and 10 classes.

  Architecture: 784 -> Dense(300, relu) -> Dense(100, relu) -> Dense(10, softmax),
  trained with categorical cross-entropy and SGD (lr=0.005), tracking accuracy.
  """
  mlp = tf.keras.Sequential([
      tf.keras.layers.InputLayer(input_shape=[28*28]),
      tf.keras.layers.Dense(300, activation='relu'),
      tf.keras.layers.Dense(100, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax'),
  ])
  sgd = tf.keras.optimizers.SGD(lr=0.005)
  mlp.compile(loss='categorical_crossentropy',
              optimizer=sgd,
              metrics=['accuracy'])
  return mlp

## Create Estimator

In [None]:
# Build the compiled Keras model; later cells read its input tensor name.
model = get_keras_model()

In [None]:
# Convert the compiled Keras model into an Estimator that uses MODEL_DIR as its model directory.
estimator = tf.keras.estimator.model_to_estimator(model, model_dir=MODEL_DIR)

## Train

In [None]:
# Train for 1000 steps, drawing batches from train_input_fn.
estimator.train(input_fn=train_input_fn, steps=1000)

## Eval

In [None]:
# Evaluate on the held-out eval split; the returned metrics are the cell output.
estimator.evaluate(input_fn=eval_input_fn)

## Predict

In [None]:
# predict() is lazy: nothing runs until the returned iterator is consumed.
iterator = estimator.predict(input_fn=test_input_fn)

In [None]:
# Pull a single prediction from the iterator and display it.
next(iterator)

In [None]:
# Close the prediction iterator explicitly instead of exhausting the test set.
iterator.close()

## Export

In [None]:
# Name of the model's input tensor with the ':<index>' suffix stripped; used as
# the feature key in serving_input_fn.
input_feature_name = model.input.name.split(':')[0]

In [None]:
def preprocess(x):
  """Flatten image data into rows of 28*28 values and divide by 255.

  This is not an identity mapping: it applies the same flattening and
  [0, 255] -> [0.0, 1.0] scaling that the TFRecord parsing in
  generate_input_fn applies to training images.
  """
  flattened = tf.reshape(x, [-1, 28*28])
  return flattened / 255.0

def serving_input_fn():
  """Serving receiver: accepts raw images under key 'X' and feeds the model.

  Returns:
    A `ServingInputReceiver` whose receiver tensor 'X' is a float32
    placeholder of shape [None, 28, 28] and whose single feature is keyed by
    `input_feature_name`.
  """
  raw_images = tf.placeholder(tf.float32, shape=[None, 28, 28])
  # map_fn passes preprocess one [28, 28] image at a time and each comes back
  # reshaped to [1, 28*28], so the stacked feature tensor has an extra axis of
  # size 1: [batch, 1, 28*28]. Consumers of the exported model's output must
  # account for that axis.
  model_input = tf.map_fn(preprocess, raw_images)
  return tf.estimator.export.ServingInputReceiver(
      {input_feature_name: model_input}, {'X': raw_images})

In [None]:
# Export a SavedModel under EXPORT_DIR. The returned path is decoded from
# bytes so it can be interpolated into shell commands and passed to the predictor.
b_export_dir = estimator.export_savedmodel(EXPORT_DIR, serving_input_fn)
export_dir = b_export_dir.decode('utf-8')

In [None]:
#!saved_model_cli show --dir {export_dir}
#!saved_model_cli show --dir {export_dir} --tag_set serve
!saved_model_cli show --dir {export_dir} --tag_set serve --signature_def serving_default

In [None]:
# Number of test images used for the sanity check of the exported model.
N_EXAMPLES = 1000

# Load the exported SavedModel as a plain Python callable.
predictor_fn = tf.contrib.predictor.from_saved_model(
  export_dir=export_dir, signature_def_key='serving_default')

_X = X_test[:N_EXAMPLES]
_y = y_test[:N_EXAMPLES]

output = predictor_fn({'X': _X})

# The signature has exactly one output, keyed by the auto-generated name of the
# last Keras layer. That name depends on how many layers were created in this
# session (and on the TF version), so take the only entry instead of
# hard-coding a key such as 'dense_3'. Unpacking fails loudly if the model ever
# gains a second output.
(probabilities,) = output.values()

# Classes live on the last axis regardless of any size-1 axis introduced by
# the serving input function.
class_ids = np.argmax(probabilities, axis=-1).reshape(-1)

# Mean of the per-example matches; stays correct even if fewer than
# N_EXAMPLES test images are available.
accuracy = np.mean(_y == class_ids)
print(accuracy)

## Train_and_Evaluate

In [None]:
# Training spec. train_and_evaluate expects a training input that does not run
# out before max_steps is reached, whereas train_input_fn only makes N_EPOCHS
# pass(es) over the data. Build a dedicated input_fn that repeats indefinitely
# (count=None) so that max_steps is the actual stop condition.
continuous_train_input_fn = generate_input_fn(
    file_pattern='{}/train.tfrecord'.format(DATA_DIR),
    mode=tf.estimator.ModeKeys.TRAIN,
    batch_size=BATCH_SIZE, count=None)
train_spec = tf.estimator.TrainSpec(
    input_fn=continuous_train_input_fn, max_steps=10000)

# Export a SavedModel (under the 'export' name) after evaluation.
exporter = tf.estimator.LatestExporter(
    name='export', serving_input_receiver_fn=serving_input_fn)

# Validation option.
eval_spec = tf.estimator.EvalSpec(
    input_fn=eval_input_fn,
    steps=None,          # evaluate until eval_input_fn is exhausted
    start_delay_secs=60, # start evaluating after N seconds
    throttle_secs=60,    # minimum number of seconds between evaluations
    exporters=exporter,
)

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)