# Create an Estimator from a Keras model

TensorFlow Estimators are fully supported in TensorFlow, and can be created from new and existing tf.keras models. This tutorial contains a complete, minimal example of that process.

In [None]:
from warnings import filterwarnings
filterwarnings('ignore')

In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals

In [3]:
import tensorflow as tf
import numpy as np
import tensorflow_datasets as tfds
import os

gpus = tf.config.experimental.list_physical_devices('GPU')
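# Limit GPU memory: expose the first GPU as one virtual device capped at 2048 MB
MEMORY_LIMIT_CONFIG = [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2048)]
print(MEMORY_LIMIT_CONFIG, gpus)
if gpus:  # skip on CPU-only machines, where gpus[0] would raise IndexError
    tf.config.experimental.set_virtual_device_configuration(gpus[0], MEMORY_LIMIT_CONFIG)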

[VirtualDeviceConfiguration(memory_limit=2048)] [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
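If a hard 2048 MB cap is not required, a common alternative is to let TensorFlow allocate GPU memory on demand. The sketch below assumes a fresh runtime: memory growth and a virtual-device configuration cannot both be applied to the same GPU, and either one has to be set before the GPU is first used.

```python
import tensorflow as tf

# List physical GPUs; on a CPU-only machine this is an empty list and the loop is a no-op.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    # Allocate GPU memory incrementally instead of reserving it all up front.
    # This must run before the GPU is initialized, otherwise TensorFlow raises a RuntimeError.
    tf.config.experimental.set_memory_growth(gpu, True)
print('{} GPU(s) configured for memory growth'.format(len(gpus)))
```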


In [4]:
tf.__version__

'2.0.0'

## Create a simple Keras model

In [5]:
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(1, activation='sigmoid')])

model.compile(loss='categorical_crossentropy', optimizer='adam')
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 16)                80        
_________________________________________________________________
dropout (Dropout)            (None, 16)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 17        
Total params: 97
Trainable params: 97
Non-trainable params: 0
_________________________________________________________________
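The parameter counts in the summary follow directly from the layer sizes: a `Dense` layer with `n` inputs and `u` units has `n * u` weights plus `u` biases, and `Dropout` has no weights. A plain-Python sketch of that arithmetic (the helper name `dense_params` is hypothetical):

```python
def dense_params(n_inputs: int, units: int) -> int:
    """Trainable parameters of a Dense layer: one weight per input-unit pair plus one bias per unit."""
    return n_inputs * units + units

# Layer sizes from the model above: 4 features -> 16 hidden units -> 1 output.
hidden = dense_params(4, 16)   # 80
output = dense_params(16, 1)   # 17
dropout = 0                    # Dropout has no trainable weights
total = hidden + dropout + output
print(hidden, output, total)   # 80 17 97
```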


## Create an input function

Use the [Datasets API](https://www.tensorflow.org/guide/data) to scale to large datasets or multi-device training.

Estimators need control of when and how their input pipeline is built. To allow this, they require an "input function", or `input_fn`. The Estimator calls this function with no arguments. Here the `input_fn` returns a `tf.data.Dataset` whose elements are `(features, labels)` pairs:

* `features`: a dictionary mapping input names (here `'dense_input'`) to feature tensors
* `labels`: a tensor of ground-truth labels


> `input_fn`: A function that provides input data for training as minibatches. See [Premade Estimators](https://tensorflow.org/guide/premade_estimators#create_input_functions) for more information. The function should construct and return one of the following:
>
> * A `tf.data.Dataset` object: Outputs of `Dataset` object must be a tuple `(features, labels)` with same constraints as below.
> * A tuple `(features, labels)`: Where `features` is a `tf.Tensor` or a dictionary of string feature name to `Tensor` and `labels` is a `Tensor` or a dictionary of string label name to `Tensor`. Both `features` and `labels` are consumed by `model_fn`. They should satisfy the expectation of `model_fn` from inputs.
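The same contract applies when the data is already in memory rather than coming from TFDS. A minimal sketch, assuming NumPy arrays; the random `x` and `y` below are placeholders, not iris data, and `numpy_input_fn` is a hypothetical name:

```python
import numpy as np
import tensorflow as tf

# Placeholder in-memory data: 150 rows of 4 float features, integer labels in {0, 1, 2}.
x = np.random.rand(150, 4).astype(np.float32)
y = np.random.randint(0, 3, size=(150,)).astype(np.int64)

def numpy_input_fn():
    # Each element is a (features, labels) pair; features is a dict keyed by the
    # Keras model's input name, as the quoted contract requires.
    dataset = tf.data.Dataset.from_tensor_slices(({'dense_input': x}, y))
    return dataset.shuffle(150).batch(32).repeat()

features_spec, labels_spec = numpy_input_fn().element_spec
print(features_spec['dense_input'].shape, labels_spec.shape)  # (None, 4) (None,)
```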

In [6]:
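def input_fn():
    split = tfds.Split.TRAIN
    # as_supervised=True yields (features, label) tuples instead of a feature dict
    dataset = tfds.load('iris', split=split, as_supervised=True)
    # Wrap the features in a dict keyed by the Keras model's input name ('dense_input')
    dataset = dataset.map(lambda features, labels: ({'dense_input':features}, labels))
    # Batch first, then repeat indefinitely; the Estimator's `steps` argument bounds the loop
    dataset = dataset.batch(32).repeat()
    return dataset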

In [7]:
ds = input_fn()
ds.element_spec

({'dense_input': TensorSpec(shape=(None, 4), dtype=tf.float32, name=None)},
 TensorSpec(shape=(None,), dtype=tf.int64, name=None))
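The batch dimension in `element_spec` is `None` because `batch()` is called without `drop_remainder=True`, so a smaller final batch is possible. Since `input_fn` batches before repeating, that partial batch appears at the end of every pass over the data. A toy sketch with 10 elements and batch size 4 (sizes chosen only for illustration) showing how the order of the two calls changes the batch boundaries:

```python
import tensorflow as tf

# Ten toy elements (0..9).
toy = tf.data.Dataset.range(10)

batch_then_repeat = [b.numpy().tolist() for b in toy.batch(4).repeat(2)]
repeat_then_batch = [b.numpy().tolist() for b in toy.repeat(2).batch(4)]

print(batch_then_repeat)
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9], [0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(repeat_then_batch)
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 0, 1], [2, 3, 4, 5], [6, 7, 8, 9]]

# Dropping the remainder makes the static batch dimension known.
print(toy.batch(4).element_spec.shape, toy.batch(4, drop_remainder=True).element_spec.shape)
```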

In [8]:
# Example: pull one batch from the dataset.
# ft_batch is a dictionary mapping feature names to tensors.
ft_batch, label_batch = next(iter(ds))
for k, v in ft_batch.items():
    print('{}.shape: {}'.format(k, v.shape))
print('label.shape:', label_batch.shape)

dense_input.shape: (32, 4)
label.shape: (32,)
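The labels are integer class ids (`tf.int64`, shape `(32,)`) and iris has three classes, whereas the model above ends in a single sigmoid unit and is compiled with `categorical_crossentropy`, which expects one-hot targets over several outputs. The logged loss values should therefore not be read as a meaningful classification loss. A sketch of a variant whose output matches the data, assuming integer labels 0 to 2; the name `model_3class` and the explicitly named `InputLayer` (which keeps the `'dense_input'` key regardless of earlier auto-generated layer names) are additions for illustration:

```python
import tensorflow as tf

# Hypothetical 3-class variant of the model above.
model_3class = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape=(4,), name='dense_input'),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(3)  # one raw logit per class (no activation)
])

# Integer labels + logits -> sparse categorical cross-entropy with from_logits=True.
model_3class.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam')
model_3class.summary()  # Dense layers: 4*16+16 = 80 and 16*3+3 = 51 params, 131 total
```

To use it, pass it to `tf.keras.estimator.model_to_estimator` in place of `model` with a fresh `model_dir`, since the checkpoints in `./estimator` were written for the original architecture.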


## Create an Estimator from the tf.keras model
A `tf.keras.Model` can be trained with the [tf.estimator](https://www.tensorflow.org/api_docs/python/tf/estimator) API by converting the model to a `tf.estimator.Estimator` object with `tf.keras.estimator.model_to_estimator`.

In [9]:
root_dir = './estimator'
model_dir = root_dir
keras_estimator = tf.keras.estimator.model_to_estimator(keras_model=model, model_dir=model_dir)

INFO:tensorflow:Using default config.

INFO:tensorflow:Using the Keras model provided.

Instructions for updating:
If using Keras pass *_constraint arguments to layers.

INFO:tensorflow:Using config: {'_model_dir': './estimator', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f78b06d6750>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
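The defaults in the config dump above (checkpoints every 600 seconds, keep 5, log every 100 steps) can be overridden by passing a `tf.estimator.RunConfig` through the `config` argument of `model_to_estimator`. A sketch with hypothetical values; the temporary directory replaces `./estimator` so the example does not touch the tutorial's checkpoints, and `small_model` is a stand-in model:

```python
import tempfile
import tensorflow as tf

# Hypothetical run settings. RunConfig rejects setting both save_checkpoints_steps
# and save_checkpoints_secs, so only the step-based interval is given.
run_config = tf.estimator.RunConfig(
    model_dir=tempfile.mkdtemp(),  # placeholder directory
    tf_random_seed=42,
    save_checkpoints_steps=10,
    keep_checkpoint_max=3,
    log_step_count_steps=5)

# model_to_estimator requires a compiled Keras model.
small_model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(3)])
small_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam')

# The Estimator takes its model directory from the config when model_dir is omitted.
est = tf.keras.estimator.model_to_estimator(keras_model=small_model, config=run_config)
print(est.config.save_checkpoints_steps, est.model_dir)
```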


Train and evaluate the estimator.  
**Note:** the Estimator calls `input_fn` once per `train` or `evaluate` call to build its input pipeline; each step then consumes one batch of 32 examples from the returned dataset.
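Because `input_fn` batches before repeating, the number of steps maps directly onto passes over the data. A plain-Python sketch of that arithmetic, assuming the TFDS `iris` train split holds the full 150 examples:

```python
import math

num_examples = 150   # assumption: size of the TFDS 'iris' train split
batch_size = 32      # from input_fn

# batch() before repeat(): every pass ends with one partial batch.
batches_per_epoch = math.ceil(num_examples / batch_size)
last_batch = num_examples - (batches_per_epoch - 1) * batch_size

train_steps, eval_steps = 25, 10
print(batches_per_epoch, last_batch)      # 5 22
print(train_steps / batches_per_epoch)    # 5.0 passes over the data during training
print(eval_steps / batches_per_epoch)     # 2.0 passes during evaluation
```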

In [10]:
keras_estimator.train(input_fn=input_fn, steps=25)

Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='./estimator/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})

INFO:tensorflow:Warm-starting from: ./estimator/keras/keras_model.ckpt

INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.

INFO:tensorflow:Warm-started 4 variables.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Saving checkpoints for 0 into ./estimator/model.ckpt.

INFO:tensorflow:loss = 111.757324, step = 0

INFO:tensorflow:Saving checkpoints for 25 into ./estimator/model.ckpt.

INFO:tensorflow:Loss for final step: 83.65043.


<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f78b03e6bd0>

In [11]:
eval_result = keras_estimator.evaluate(input_fn=input_fn, steps=10)
print('Eval result: {}'.format(eval_result))

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Starting evaluation at 2020-01-19T23:19:43Z

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Restoring parameters from ./estimator/model.ckpt-25

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Evaluation [1/10]

INFO:tensorflow:Evaluation [2/10]

INFO:tensorflow:Evaluation [3/10]

INFO:tensorflow:Evaluation [4/10]

INFO:tensorflow:Evaluation [5/10]

INFO:tensorflow:Evaluation [6/10]

INFO:tensorflow:Evaluation [7/10]

INFO:tensorflow:Evaluation [8/10]

INFO:tensorflow:Evaluation [9/10]

INFO:tensorflow:Evaluation [10/10]

INFO:tensorflow:Finished evaluation at 2020-01-19-23:19:43

INFO:tensorflow:Saving dict for global step 25: global_step = 25, loss = 100.48047

INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: ./estimator/model.ckpt-25


Eval result: {'loss': 100.48047, 'global_step': 25}
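The section stops at evaluation. A natural next step is inference with `Estimator.predict`, whose `input_fn` returns features only and should be finite so that prediction terminates. The sketch below is self-contained and rests on assumptions not in the tutorial: random placeholder data instead of iris, a tiny 3-logit model, hypothetical function names, and a temporary `model_dir`.

```python
import tempfile
import numpy as np
import tensorflow as tf

# Placeholder data (random, not iris): 8 rows of 4 features, labels in {0, 1, 2}.
x = np.random.rand(8, 4).astype(np.float32)
y = np.random.randint(0, 3, size=(8,)).astype(np.int64)

clf = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape=(4,), name='dense_input'),
    tf.keras.layers.Dense(3)])  # one logit per class
clf.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam')
est = tf.keras.estimator.model_to_estimator(keras_model=clf, model_dir=tempfile.mkdtemp())

def train_input_fn():
    # (features, labels) pairs, repeated; `steps` bounds training.
    return tf.data.Dataset.from_tensor_slices(({'dense_input': x}, y)).batch(4).repeat()

def predict_input_fn():
    # Features only, and finite, so predict() stops after one pass.
    return tf.data.Dataset.from_tensor_slices({'dense_input': x}).batch(4)

est.train(input_fn=train_input_fn, steps=2)

# predict() yields one dict per example, keyed by the Keras output layer name.
predictions = list(est.predict(input_fn=predict_input_fn))
logits = next(iter(predictions[0].values()))
print(len(predictions), logits.shape, int(np.argmax(logits)))
```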
