##### Copyright 2019 The TensorFlow Authors.

```
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```

# Estimators

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/guide/estimator"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/estimator.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/docs/blob/master/site/en/guide/estimator.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/estimator.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

This document introduces `tf.estimator`—a high-level TensorFlow
API. Estimators encapsulate the following actions:

* training
* evaluation
* prediction
* export for serving

You can either use the pre-made Estimators that TensorFlow provides or write your
own custom Estimators. All Estimators, whether pre-made or custom, are
classes based on the `tf.estimator.Estimator` class.
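As a quick check of this class hierarchy, the sketch below confirms that two pre-made Estimators subclass `tf.estimator.Estimator` and that the base class exposes a method for each action listed above. It assumes a TensorFlow 2.x release that still includes `tf.estimator`. The API is deprecated, and recent TensorFlow releases no longer ship it.

```python
import tensorflow as tf

# Pre-made Estimators are subclasses of the base tf.estimator.Estimator class.
for cls in (tf.estimator.LinearClassifier, tf.estimator.DNNClassifier):
    print(cls.__name__, issubclass(cls, tf.estimator.Estimator))

# The base class provides one entry point per action: training, evaluation,
# prediction, and export for serving.
for method in ("train", "evaluate", "predict", "export_saved_model"):
    print(method, callable(getattr(tf.estimator.Estimator, method)))
```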

For a quick example, try the [Estimator tutorials](../tutorials/estimator/linear.ipynb). For an overview of the API design, see the [white paper](https://arxiv.org/abs/1708.02637).

## Advantages

Like a `tf.keras.Model`, an Estimator is a model-level abstraction. `tf.estimator` provides some capabilities that are still under development for `tf.keras`:

* Parameter-server-based training
* Full [TFX](http://tensorflow.org/tfx) integration

## Estimator capabilities
Estimators provide the following benefits:

* You can run Estimator-based models on a local host or on a distributed multi-server environment without changing your model. Furthermore, you can run Estimator-based models on CPUs, GPUs, or TPUs without recoding your model.
* Estimators provide a safe distributed training loop that controls how and when to:    
    * load data
    * handle exceptions
    * create checkpoint files and recover from failures
    * save summaries for TensorBoard
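Checkpoint and summary behavior is configured through `tf.estimator.RunConfig`. Its fields appear in the `Using config:` log later in this guide. The following minimal sketch assumes a TensorFlow 2.x release that still includes `tf.estimator`. The temporary `model_dir`, the step counts, and the single feature `x` are illustrative choices.

```python
import tempfile
import tensorflow as tf

# Illustrative values: save a checkpoint every 200 steps, write TensorBoard
# summaries every 50 steps, and keep only the 3 most recent checkpoints.
config = tf.estimator.RunConfig(
    model_dir=tempfile.mkdtemp(),
    save_checkpoints_steps=200,
    save_summary_steps=50,
    keep_checkpoint_max=3)

# Any Estimator accepts the same config. Its training loop then writes
# checkpoints and summaries into config.model_dir.
feature_columns = [tf.feature_column.numeric_column("x")]
estimator = tf.estimator.LinearRegressor(
    feature_columns=feature_columns, config=config)

print(config.save_checkpoints_steps, config.save_summary_steps)
print(estimator.model_dir == config.model_dir)
```

Checkpoints are written to `model_dir`. As a result, calling `train` again on an Estimator that uses the same `model_dir` resumes from the latest checkpoint instead of starting over.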

When writing an application with Estimators, you must separate the data input
pipeline from the model. This separation simplifies experiments with
different datasets.

## Pre-made Estimators

Pre-made Estimators enable you to work at a much higher conceptual level than the base TensorFlow APIs. You no longer have to worry about creating the computational graph or sessions since Estimators handle all the "plumbing" for you. Furthermore, pre-made Estimators let you experiment with different model architectures by making only minimal code changes.  `tf.estimator.DNNClassifier`, for example, is a pre-made Estimator class that trains classification models based on dense, feed-forward neural networks.

### Structure of a pre-made Estimator program

A TensorFlow program relying on a pre-made Estimator typically consists of the following four steps:

#### 1. Write one or more dataset importing functions.

For example, you might create one function to import the training set and another function to import the test set. Each dataset importing function must return two objects:

* a dictionary in which the keys are feature names and the values are Tensors (or SparseTensors) containing the corresponding feature data
* a Tensor containing one or more labels

For example, the following code illustrates the basic skeleton for an input function:

```
def input_fn(dataset):
    ...  # manipulate dataset, extracting the feature dict and the label
    return feature_dict, label
```

See the [data guide](../../guide/data.md) for details.
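An input function for small in-memory data can, for instance, build a `tf.data.Dataset` whose elements are `(feature_dict, label)` pairs. Estimators accept an input function that returns such a dataset, and the Keras example later in this guide works this way. In this sketch, the feature names mirror the columns defined in the next step, while the values and batch size are made up.

```python
import tensorflow as tf

def input_fn(batch_size=2):
    # Made-up example data: three rows with three numeric features each.
    features = {
        "population": [1200.0, 560.0, 3400.0],
        "crime_rate": [0.5, 1.2, 0.3],
        "median_education": [12.0, 10.5, 14.0],
    }
    labels = [1, 0, 1]
    # Each element is a (feature_dict, label) pair.
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    return dataset.shuffle(3).batch(batch_size)

# Inspect one batch eagerly.
feature_batch, label_batch = next(iter(input_fn()))
print(sorted(feature_batch.keys()), label_batch.shape)
```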

#### 2. Define the feature columns.

Each `tf.feature_column` identifies a feature name, its type, and any input pre-processing. For example, the following snippet creates three feature columns that hold integer or floating-point data. The first two feature columns simply identify the feature's name and type. The third feature column also specifies a lambda that the program invokes to normalize the raw data, here by subtracting a global mean:

```
import tensorflow as tf

# Hypothetical placeholder; in practice, compute this mean from the training data.
global_education_mean = 12.0

# Define three numeric feature columns.
population = tf.feature_column.numeric_column('population')
crime_rate = tf.feature_column.numeric_column('crime_rate')
median_education = tf.feature_column.numeric_column(
  'median_education',
  normalizer_fn=lambda x: x - global_education_mean)
```
For further information, see the [feature columns tutorial](https://www.tensorflow.org/tutorials/keras/feature_columns).
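To see what `normalizer_fn` does, you can pass a batch through a column with `tf.keras.layers.DenseFeatures`, which converts feature columns into a dense tensor. This sketch assumes TensorFlow 2.x with the Keras 2 API, where `DenseFeatures` is available. It also assumes a hypothetical mean of `12.0`.

```python
import tensorflow as tf

# Hypothetical value; in practice, compute it from the training data.
global_education_mean = 12.0

median_education = tf.feature_column.numeric_column(
    'median_education',
    normalizer_fn=lambda x: x - global_education_mean)

# DenseFeatures applies the column's normalizer_fn to the raw input.
layer = tf.keras.layers.DenseFeatures([median_education])
dense = layer({'median_education': tf.constant([[10.0], [14.0]])})
print(dense.numpy())
```

Each value is shifted by the mean before it reaches the model, so the inputs 10.0 and 14.0 become -2.0 and 2.0.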

#### 3. Instantiate the relevant pre-made Estimator.

For example, here's a sample instantiation of a pre-made Estimator named `LinearClassifier`:

```
# Instantiate an estimator, passing the feature columns.
estimator = tf.estimator.LinearClassifier(
  feature_columns=[population, crime_rate, median_education])
```
For further information, see the [linear classifier tutorial](https://www.tensorflow.org/tutorials/estimator/linear).

#### 4. Call a training, evaluation, or inference method.

For example, all Estimators provide a `train` method, which trains a model.

```
# `my_training_set` is the input function created in Step 1
estimator.train(input_fn=my_training_set, steps=2000)
```
You can see an example of this below.
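The following sketch puts the four steps together. It trains, evaluates, and predicts with a `LinearClassifier`, assuming a TensorFlow 2.x release that still includes `tf.estimator`. The single feature `x`, the made-up data, the step count, and the temporary `model_dir` are illustrative only.

```python
import tempfile
import tensorflow as tf

FEATURES = {"x": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]}  # made-up data
LABELS = [0, 0, 0, 1, 1, 1]

# Step 1: input functions for training, evaluation, and prediction.
def train_input_fn():
    # Shuffle and repeat so that training can run for any number of steps.
    ds = tf.data.Dataset.from_tensor_slices((FEATURES, LABELS))
    return ds.shuffle(6).repeat().batch(6)

def eval_input_fn():
    # A single, unshuffled pass over the same data.
    return tf.data.Dataset.from_tensor_slices((FEATURES, LABELS)).batch(6)

def predict_input_fn():
    # Prediction needs features only.
    return tf.data.Dataset.from_tensor_slices({"x": [0.5, 4.5]}).batch(2)

# Step 2: feature columns.
feature_columns = [tf.feature_column.numeric_column("x")]

# Step 3: a pre-made Estimator.
estimator = tf.estimator.LinearClassifier(
    feature_columns=feature_columns, model_dir=tempfile.mkdtemp())

# Step 4: train, evaluate, and predict with the same Estimator object.
estimator.train(input_fn=train_input_fn, steps=50)
metrics = estimator.evaluate(input_fn=eval_input_fn)
predictions = list(estimator.predict(input_fn=predict_input_fn))

print(metrics["accuracy"], metrics["global_step"])
print([p["probabilities"] for p in predictions])
```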

### Benefits of pre-made Estimators

Pre-made Estimators encode best practices, providing the following benefits:

* Best practices for determining where different parts of the computational graph should run, implementing strategies on a single machine or on a cluster.
* Best practices for event (summary) writing and universally useful summaries.

If you don't use pre-made Estimators, you must implement the preceding features yourself.

## Custom Estimators

The heart of every Estimator—whether pre-made or custom—is its *model function*, which is a method that builds graphs for training, evaluation, and prediction. When you are using a pre-made Estimator, someone else has already implemented the model function. When relying on a custom Estimator, you must write the model function yourself.
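Here is a minimal sketch of a custom Estimator whose model function fits `y = w * x + b`. A single `model_fn` builds the graph for every mode and returns a `tf.estimator.EstimatorSpec`. That spec carries predictions in `PREDICT` mode, a loss in `EVAL` mode, or a loss plus a training op in `TRAIN` mode.

The feature name `x`, the made-up data, the learning rate, and the choice of `tf.compat.v1.train.GradientDescentOptimizer` are assumptions for illustration. The sketch also assumes a TensorFlow 2.x release that still includes `tf.estimator`.

```python
import tempfile
import tensorflow as tf

def model_fn(features, labels, mode):
    # Graph for y_hat = w * x + b, with w and b as trainable variables.
    x = tf.cast(features["x"], tf.float32)
    w = tf.compat.v1.get_variable(
        "w", shape=[], initializer=tf.compat.v1.zeros_initializer())
    b = tf.compat.v1.get_variable(
        "b", shape=[], initializer=tf.compat.v1.zeros_initializer())
    y_hat = w * x + b

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions={"y": y_hat})

    # Mean squared error against the labels.
    loss = tf.reduce_mean(tf.square(y_hat - tf.cast(labels, tf.float32)))
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=loss)

    # TRAIN: one gradient-descent update per step. Incrementing the global
    # step lets the Estimator track progress and number its checkpoints.
    optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.05)
    train_op = optimizer.minimize(
        loss, global_step=tf.compat.v1.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

def input_fn():
    # Made-up data that follows y = 2x + 1.
    xs = [0.0, 1.0, 2.0, 3.0]
    ys = [1.0, 3.0, 5.0, 7.0]
    return tf.data.Dataset.from_tensor_slices(({"x": xs}, ys)).repeat().batch(4)

estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=tempfile.mkdtemp())
estimator.train(input_fn=input_fn, steps=200)
metrics = estimator.evaluate(input_fn=input_fn, steps=1)
print(metrics["loss"], metrics["global_step"])
```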

## Recommended workflow

1. Assuming a suitable pre-made Estimator exists, use it to build your first model and use its results to establish a baseline.
2. Build and test your overall pipeline with this pre-made Estimator, including the integrity and reliability of your data.
3. If suitable alternative pre-made Estimators are available, run experiments to determine which pre-made Estimator produces the best results.
4. If needed, further improve your model by building your own custom Estimator.
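Step 3 can be as simple as looping over candidate pre-made Estimators that share the same feature columns and input functions, because swapping architectures changes only the constructor call. In this sketch, the candidates, `hidden_units=[8, 8]`, the made-up data, and the step count are illustrative assumptions. In practice you would evaluate on held-out data rather than on the training data used here.

```python
import tempfile
import tensorflow as tf

FEATURES = {"x": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]}  # made-up data
LABELS = [0, 0, 0, 1, 1, 1]

def train_input_fn():
    return (tf.data.Dataset.from_tensor_slices((FEATURES, LABELS))
            .shuffle(6).repeat().batch(6))

def eval_input_fn():
    return tf.data.Dataset.from_tensor_slices((FEATURES, LABELS)).batch(6)

feature_columns = [tf.feature_column.numeric_column("x")]

# Each candidate is built from the same feature columns; only the class
# and its architecture-specific arguments differ.
candidates = {
    "linear": lambda d: tf.estimator.LinearClassifier(
        feature_columns=feature_columns, model_dir=d),
    "dnn": lambda d: tf.estimator.DNNClassifier(
        hidden_units=[8, 8], feature_columns=feature_columns, model_dir=d),
}

results = {}
for name, build in candidates.items():
    estimator = build(tempfile.mkdtemp())
    estimator.train(input_fn=train_input_fn, steps=100)
    results[name] = float(estimator.evaluate(input_fn=eval_input_fn)["accuracy"])

print(results)
```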

```
import tensorflow as tf
import tensorflow_datasets as tfds

tfds.disable_progress_bar()
```

## Create an Estimator from a Keras model

You can convert existing Keras models to Estimators with `tf.keras.estimator.model_to_estimator`. Doing so gives your Keras
model access to the Estimator's strengths, such as distributed training.

Instantiate a Keras MobileNet V2 model as a frozen feature extractor and add a pooling layer and a single-unit classification head. Then compile the model with the optimizer, loss, and metrics to train with:

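```
# Pretrained MobileNet V2 (ImageNet weights by default) without its top
# classification layers, frozen so that only the new head is trained.
keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False

# Add pooling and a single-logit head for binary (cat vs. dog) classification.
estimator_model = tf.keras.Sequential([
    keras_mobilenet_v2,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1)
])

# Compile the model
estimator_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy'])
```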

Downloading data from https://github.com/JonathanCMitchell/mobilenet_v2_keras/releases/download/v1.1/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5


Create an `Estimator` from the compiled Keras model. The initial model state of the Keras model is preserved in the created `Estimator`:

```
est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
```

INFO:tensorflow:Using default config.
INFO:tensorflow:Using the Keras model provided.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp_4cm12_n', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


Treat the derived `Estimator` as you would any other `Estimator`.

```
IMG_SIZE = 160  # All images will be resized to 160x160

def preprocess(image, label):
  image = tf.cast(image, tf.float32)
  image = (image/127.5) - 1  # Scale pixel values from [0, 255] to [-1, 1]
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return image, label
```
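The `(image/127.5) - 1` step maps pixel values from `[0, 255]` to `[-1, 1]`, which is the range MobileNet V2 expects. As a sanity check that is not part of the original pipeline, this sketch compares the manual scaling with `tf.keras.applications.mobilenet_v2.preprocess_input`, which applies the same transformation.

```python
import numpy as np
import tensorflow as tf

pixels = np.array([0.0, 127.5, 255.0], dtype=np.float32)

# The manual scaling used in preprocess().
manual = pixels / 127.5 - 1

# Keras's MobileNet V2 preprocessing. Pass a copy because it can modify
# NumPy float arrays in place.
reference = tf.keras.applications.mobilenet_v2.preprocess_input(pixels.copy())

print(manual, reference)
```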

```
def train_input_fn(batch_size):
  # as_supervised=True yields (image, label) tuples.
  # cats_vs_dogs only provides a 'train' split.
  data = tfds.load('cats_vs_dogs', as_supervised=True)
  train_data = data['train']
  train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
  return train_data
```
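To inspect what `train_input_fn` produces without downloading `cats_vs_dogs`, you can run the same `map`/`shuffle`/`batch` chain on synthetic data. In this sketch, the eight random 200x300 images, their labels, and the copy of `preprocess` stand in for the real dataset and for the function defined above.

```python
import tensorflow as tf

IMG_SIZE = 160

def preprocess(image, label):
    # Same steps as above: scale to [-1, 1], then resize.
    image = tf.cast(image, tf.float32)
    image = (image / 127.5) - 1
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    return image, label

# Synthetic stand-in for cats_vs_dogs: 8 random 200x300 RGB images.
images = tf.random.uniform((8, 200, 300, 3), maxval=256, dtype=tf.int32)
labels = tf.constant([0, 1] * 4, dtype=tf.int64)

def synthetic_input_fn(batch_size):
    data = tf.data.Dataset.from_tensor_slices((images, labels))
    return data.map(preprocess).shuffle(500).batch(batch_size)

image_batch, label_batch = next(iter(synthetic_input_fn(4)))
print(image_batch.shape, label_batch.shape)
```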

To train, call the Estimator's `train` method:

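```
# The Estimator calls `input_fn` itself, so wrap `train_input_fn` in a lambda
# to pass the batch size.
est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=500)
```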

Downloading and preparing dataset cats_vs_dogs/4.0.0 (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0...
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0.incompleteO0ACND/cats_vs_dogs-train.tfrecord
Dataset cats_vs_dogs downloaded and prepared to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmp_4cm12_n/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: /tmp/tmp_4cm12_n/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 158 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp_4cm12_n/model.ckpt.
INFO:tensorflow:loss = 0.7499167, step = 0
INFO:tensorflow:global_step/sec: 32.7934
INFO:tensorflow:loss = 0.3240549, step = 100 (3.051 sec)
INFO:tensorflow:global_step/sec: 36.8126
INFO:tensorflow:loss = 0.12278967, step = 200 (2.716 sec)
INFO:tensorflow:global_step/sec: 37.0478
INFO:tensorflow:loss = 0.2140803, step = 300 (2.699 sec)
INFO:tensorflow:global_step/sec: 37.223
INFO:tensorflow:loss = 0.10027408, step = 400 (2.686 sec)
INFO:tensorflow:Saving checkpoints for 500 into /tmp/tmp_4cm12_n/model.ckpt.
INFO:tensorflow:Loss for final step: 0.15134794.

<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7fc0940b90b8>

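Similarly, to evaluate, call the Estimator's `evaluate` method. The `cats_vs_dogs` dataset only provides a `train` split, so this example evaluates on 10 batches drawn from the training data: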

```
est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
```

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-03-28T01:34:35Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp_4cm12_n/model.ckpt-500
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 2.04883s
INFO:tensorflow:Finished evaluation at 2020-03-28-01:34:38
INFO:tensorflow:Saving dict for global step 500: accuracy = 0.446875, global_step = 500, loss = 0.7215727
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 500: /tmp/tmp_4cm12_n/model.ckpt-500

{'accuracy': 0.446875, 'loss': 0.7215727, 'global_step': 500}

For more details, please refer to the documentation for `tf.keras.estimator.model_to_estimator`.
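For a self-contained version of the same conversion that runs in seconds, the sketch below converts a tiny compiled Keras regression model and trains it with an input function. It follows the same compile, convert, `train`, `evaluate` pattern used above. The one-layer model, the made-up data for `y = 2x + 1`, and the step counts are illustrative assumptions. A TensorFlow 2.x release with Keras 2 and `tf.estimator` is assumed.

```python
import tempfile
import tensorflow as tf

# A one-layer Keras model for a made-up target y = 2x + 1.
keras_model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
keras_model.compile(optimizer="sgd", loss="mse")

# Convert the compiled model; the Estimator writes checkpoints to model_dir.
est = tf.keras.estimator.model_to_estimator(
    keras_model=keras_model, model_dir=tempfile.mkdtemp())

def input_fn():
    xs = tf.constant([[0.0], [1.0], [2.0], [3.0]])
    ys = 2.0 * xs + 1.0
    # Features are a plain tensor here, as in train_input_fn above.
    return tf.data.Dataset.from_tensor_slices((xs, ys)).repeat().batch(4)

est.train(input_fn=input_fn, steps=100)
metrics = est.evaluate(input_fn=input_fn, steps=1)
print(metrics)
```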