## Confidentiality

This notebook was downloaded from TensorFlow and is for demonstration purposes only.

Please do not copy or distribute this notebook.

## Overview

Note: While you can use Estimators with the `tf.distribute` API, it's recommended to use Keras with `tf.distribute` instead; see [multi-worker training with Keras](multi_worker_with_keras.ipynb). Estimator training with `tf.distribute.Strategy` has limited support.

This tutorial demonstrates how `tf.distribute.Strategy` can be used for distributed multi-worker training with `tf.estimator`. It is aimed at readers who write code using `tf.estimator` and are interested in scaling beyond a single machine with high performance.

Before getting started, please read the [distribution strategy](../../guide/distributed_training.ipynb) guide. The [multi-GPU training tutorial](./keras.ipynb) is also relevant because this tutorial uses the same model.


## Setup

First, set up TensorFlow and the necessary imports.

In [None]:
import tensorflow_datasets as tfds
import tensorflow as tf

import json, math, os, sys

## Input function
Next, load the MNIST dataset from [TensorFlow Datasets](https://www.tensorflow.org/datasets). The code here is similar to the [multi-GPU training tutorial](./keras.ipynb) with one key difference: when using Estimator for multi-worker training, it is necessary to shard the dataset by the number of workers to ensure model convergence. The input data is sharded by worker index, so that each worker processes a distinct `1/num_workers` portion of the dataset.

In [None]:
BUFFER_SIZE = 10000
BATCH_SIZE = 64

def input_fn(mode, input_context=None):
  datasets, info = tfds.load(name='mnist',
                                with_info=True,
                                as_supervised=True)
  mnist_dataset = (datasets['train'] if mode == tf.estimator.ModeKeys.TRAIN else
                   datasets['test'])

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label

  if input_context:
    mnist_dataset = mnist_dataset.shard(input_context.num_input_pipelines,
                                        input_context.input_pipeline_id)
  return mnist_dataset.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

Another reasonable approach to achieve convergence would be to shuffle the dataset with distinct seeds at each worker.
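To make the sharding step concrete, here is a minimal pure-Python sketch of the selection rule that `Dataset.shard(num_shards, index)` applies: it keeps the elements whose position modulo `num_shards` equals `index`. The helper name `shard` and the toy data are illustrative assumptions and are not part of the tutorial code.

```python
def shard(elements, num_shards, index):
    """Keep every num_shards-th element, starting at position `index`.

    Hypothetical stand-in that mirrors the selection rule of
    tf.data.Dataset.shard(num_shards, index) on a plain list.
    """
    return [x for i, x in enumerate(elements) if i % num_shards == index]


examples = list(range(10))  # toy stand-in for dataset examples
num_workers = 2             # assumed cluster size, for illustration only

# One shard per worker; in the tutorial the index comes from
# input_context.input_pipeline_id.
shards = [shard(examples, num_workers, i) for i in range(num_workers)]
```

With two workers, the first shard holds the even positions and the second the odd positions, so the workers never see the same example and together they cover the whole dataset.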

## Multi-worker configuration

One of the key differences here (compared to the [multi-GPU training tutorial](./keras.ipynb)) is the multi-worker setup.  The `TF_CONFIG` environment variable is the standard way to specify the cluster configuration to each worker that is part of the cluster.

There are two components of `TF_CONFIG`: `cluster` and `task`.  `cluster` provides information about the entire cluster, namely the workers and parameter servers in the cluster.  `task` provides information about the current task. The first component `cluster` is the same for all workers and parameter servers in the cluster, and the second component `task` is different on each worker and parameter server and specifies its own `type` and `index`. In this example, the task `type` is `worker` and the task `index` is `0`.

For illustration purposes, the example below shows how to set `TF_CONFIG` with 2 workers on `localhost`. In practice, you would create multiple workers on an external IP address and port, and set `TF_CONFIG` on each worker appropriately, i.e. modify the task `index`.

*Do not execute the following code in Colab.*  TensorFlow's runtime will attempt to create a gRPC server at the specified IP address and port, which will likely fail.

```
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
})
```
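As a rough sketch of how the per-worker values differ, the snippet below builds one `TF_CONFIG` string per worker: the `cluster` part is identical everywhere and only the task `index` changes. The helper name `make_tf_config` is a hypothetical convenience, not a TensorFlow API. The snippet only builds strings and does not touch `os.environ`, so it does not start a gRPC server.

```python
import json


def make_tf_config(worker_hosts, task_index):
    """Return the TF_CONFIG JSON string for one worker (hypothetical helper)."""
    return json.dumps({
        'cluster': {'worker': list(worker_hosts)},          # same on every worker
        'task': {'type': 'worker', 'index': task_index},    # differs per worker
    })


workers = ["localhost:12345", "localhost:23456"]

# configs[i] is the value that worker i would place in os.environ['TF_CONFIG'].
configs = [make_tf_config(workers, i) for i in range(len(workers))]
```

Each worker process would assign its own entry to `os.environ['TF_CONFIG']` before creating the strategy.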


## Defining the model

Write the layers, the optimizer, and the loss function for training. Here the model is defined with Keras layers, similar to the [multi-GPU training tutorial](./keras.ipynb).

In [None]:
LEARNING_RATE = 1e-4
def model_fn(features, labels, mode):
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])
  logits = model(features, training=False)

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {'logits': logits}
    # EstimatorSpec takes `mode` as its first argument and has no `labels` argument.
    return tf.estimator.EstimatorSpec(mode, predictions=predictions)

  optimizer = tf.compat.v1.train.GradientDescentOptimizer(
      learning_rate=LEARNING_RATE)
  loss = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(labels, logits)
  loss = tf.reduce_sum(loss) * (1. / BATCH_SIZE)
  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(mode, loss=loss)

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=optimizer.minimize(
          loss, tf.compat.v1.train.get_or_create_global_step()))

Note: Although the learning rate is fixed here, in general it may be necessary to adjust the learning rate based on the global batch size.
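One way to think about the note above: the global batch size grows with the number of replicas, and a common heuristic is to scale the learning rate by the same factor. The sketch below assumes that linear scaling rule, which is an illustrative choice and not something this tutorial prescribes. The helper names are hypothetical.

```python
BATCH_SIZE = 64        # per-replica batch size, as in the input function
LEARNING_RATE = 1e-4   # base learning rate, as in the model function


def global_batch_size(per_replica_batch, num_replicas):
    """Number of examples consumed across all replicas in one step."""
    return per_replica_batch * num_replicas


def scaled_learning_rate(base_lr, num_replicas):
    """Linear scaling heuristic (an assumption, not the tutorial's method)."""
    return base_lr * num_replicas


# Example: two workers with one replica each.
two_worker_batch = global_batch_size(BATCH_SIZE, 2)
two_worker_lr = scaled_learning_rate(LEARNING_RATE, 2)
```

Whether linear scaling helps depends on the model and optimizer, so treat the result as a starting point for tuning.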

## MultiWorkerMirroredStrategy

To train the model, use an instance of `tf.distribute.experimental.MultiWorkerMirroredStrategy`.  `MultiWorkerMirroredStrategy` creates copies of all variables in the model's layers on each device across all workers.  It uses `CollectiveOps`, a TensorFlow op for collective communication, to aggregate gradients and keep the variables in sync.  The [`tf.distribute.Strategy` guide](../../guide/distributed_training.ipynb) has more details about this strategy.

In [None]:
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/device:CPU:0',)
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:CPU:0',), communication = CollectiveCommunication.AUTO
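To illustrate the idea of aggregating gradients so that every copy of the variables stays in sync, here is a toy pure-Python sketch. It does not use TensorFlow or real collective ops. The mean reduction and the helper names `all_reduce_mean` and `sync_step` are assumptions chosen for illustration; the actual reduction and its scaling are handled by the strategy.

```python
def all_reduce_mean(per_worker_grads):
    """Average gradients element-wise across workers (toy all-reduce)."""
    num_workers = len(per_worker_grads)
    return [sum(vals) / num_workers for vals in zip(*per_worker_grads)]


def sync_step(weights, per_worker_grads, lr):
    """Apply the same aggregated gradient to the weights.

    Every worker would perform this identical update, so all copies of
    the variables remain equal after the step.
    """
    avg = all_reduce_mean(per_worker_grads)
    return [w - lr * g for w, g in zip(weights, avg)]


weights = [1.0, 2.0]
grads = [[0.5, 1.0],   # gradients computed on worker 0's batch
         [1.5, 3.0]]   # gradients computed on worker 1's batch
new_weights = sync_step(weights, grads, lr=0.1)
```

Because each worker applies the same aggregated gradient, the mirrored variables never diverge even though each worker saw different data.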


## Training and evaluating the model

Next, specify the distribution strategy in the `RunConfig` for the estimator, then train and evaluate by invoking `tf.estimator.train_and_evaluate`. This tutorial distributes only the training by specifying the strategy via `train_distribute`. It is also possible to distribute the evaluation via `eval_distribute`.

In [None]:
config = tf.estimator.RunConfig(train_distribute=strategy)

classifier = tf.estimator.Estimator(
    model_fn=model_fn, model_dir='/tmp/multiworker', config=config)
tf.estimator.train_and_evaluate(
    classifier,
    train_spec=tf.estimator.TrainSpec(input_fn=input_fn),
    eval_spec=tf.estimator.EvalSpec(input_fn=input_fn)
)

INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/multiworker', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.collective_all_reduce_strategy.CollectiveAllReduceStrategy object at 0x7f39e47b1c88>, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master':

local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead set
data_dir=gs://tfds-data/datasets.



Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.0. Subsequent calls will reuse this data.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.

Cause: could not parse the source code:

      lambda scaffold: scaffold.ready_op, args=(grouped_scaffold,))

This error may be avoided by creating the lambda in a standalone statement.

INFO:tensorflow:Create CheckpointSaverHook.
Instructions for updating:
Use the iterator's `initializer` property instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/multiworker/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 2.3051615, step = 0
INFO:tensorflow:global_step/sec: 30.6716
INFO:tensorflow:loss = 2.3318462, step = 100 (3.266 sec)
INFO:tensorflow:global_step/sec: 30.5966
INFO:tensorflow:loss = 2.3198657, step = 200 (3.269 sec)
INFO:tensorflow:global_step/sec: 30.5773
INFO:tensorflow:loss = 2.2994285, step = 300 (3.269 sec)
INFO:tensorflow:global_step/sec: 30.9944
INFO:tensorflow:loss = 2.3085122, step = 400 (3.225 sec)
INFO:tensorflow:global_step/sec: 29.833
INFO:tensorflow:loss = 2.2896461, step = 500 (3.357 sec)
INFO:tensorflow:global_step/sec: 30.8019
INFO:tensorflow:loss = 2.2889292, step = 600 (3.243 sec)
INFO:tensorflow:global_step/sec: 30.1936
INFO:tensorflow:loss = 2.2855854, step = 700 (3.315 sec)
INFO:tensorflow:global_step/sec: 32.1128
INFO:tensorflow:loss = 2.2826686, step = 800 (3.108 sec)
INFO:tensorflow:global_step/sec: 43.9904
INFO:tensorflow:loss = 2.2744117, step = 900 (2.274 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 938...
INFO:tensorflow:Saving checkpoints for 938 into /tmp/multiworker/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 938...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-10-14T13:20:16Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/multiworker/model.ckpt-938
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Evaluation [80/100]
INFO:tensorflow:Evaluation [90/100]
INFO:tensorflow:Evaluation [100/100]
INFO:tensorflow:Inference Time : 2.65838s
INFO:tensorflow:Finished evaluation at 2020-10-14-13:20:18
INFO:tensorflow:Saving dict for global step 938: global_step = 938, loss = 2.2764425
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 938: /tmp/multiworker/model.ckpt-938
INFO:tensorflow:Loss for final step: 1.133571.


({'global_step': 938, 'loss': 2.2764425}, [])
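The step count in the log can be checked with a little arithmetic. The sketch below assumes the standard MNIST training split of 60,000 examples and that training runs for one pass over the data; the helper name `steps_per_epoch` is hypothetical, and the multi-worker case assumes the examples divide evenly among workers.

```python
import math

NUM_TRAIN_EXAMPLES = 60000  # size of the MNIST training split
BATCH_SIZE = 64             # per-replica batch size from the input function


def steps_per_epoch(num_examples, batch_size, num_workers=1):
    """Steps for one pass when the data is sharded evenly across workers."""
    per_worker_examples = num_examples // num_workers
    return math.ceil(per_worker_examples / batch_size)


single_worker_steps = steps_per_epoch(NUM_TRAIN_EXAMPLES, BATCH_SIZE)
two_worker_steps = steps_per_epoch(NUM_TRAIN_EXAMPLES, BATCH_SIZE, num_workers=2)

# The last batch is partial: 60000 is not a multiple of 64.
last_batch_size = NUM_TRAIN_EXAMPLES - (single_worker_steps - 1) * BATCH_SIZE
```

On a single worker this gives 938 steps, matching the final `global_step`. The last batch holds only 32 examples, and because the model function divides the summed loss by the fixed `BATCH_SIZE` of 64, this is a plausible reason the "Loss for final step" is about half the other reported losses.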

## Optimizing training performance

You now have a model and a multi-worker capable Estimator powered by `tf.distribute.Strategy`. You can try the following techniques to optimize the performance of multi-worker training:

*   *Increase the batch size:* The batch size specified here is per-GPU.  In general, the largest batch size that fits the GPU memory is advisable.
*   *Cast variables:* Cast the variables to `tf.float` if possible.  The official ResNet model includes [an example](https://github.com/tensorflow/models/blob/8367cf6dabe11adf7628541706b660821f397dce/official/resnet/resnet_model.py#L466) of how this can be done.
*   *Use collective communication:* `MultiWorkerMirroredStrategy` provides multiple [collective communication implementations](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/distribute/cross_device_ops.py).  
    * `RING` implements ring-based collectives using gRPC as the cross-host communication layer.  
    * `NCCL` uses [Nvidia's NCCL](https://developer.nvidia.com/nccl) to implement collectives.  
    * `AUTO` defers the choice to the runtime.
    
    The best choice of collective implementation depends upon the number and type of GPUs, and the network interconnect in the cluster.  To override the automatic choice, specify a valid value to the `communication` parameter of `MultiWorkerMirroredStrategy`'s constructor, e.g. `communication=tf.distribute.experimental.CollectiveCommunication.NCCL`.

Visit the [Performance section](../../guide/function.ipynb) in the guide to learn more about other strategies and [tools](../../guide/profiler.md) you can use to optimize the performance of TensorFlow models.
