# Multi-worker training with Estimator

## Overview

<div class="alert alert-block alert-info">
    <b>Note:</b> Although Estimators can be used with the tf.distribute API, Keras is the recommended choice with tf.distribute; see the multi-worker training with Keras tutorial. Estimator training with tf.distribute.Strategy has limited support.
</div>

This tutorial demonstrates how tf.distribute.Strategy can be used for distributed multi-worker training with tf.estimator. If you write your code using tf.estimator, and you're interested in scaling beyond a single machine with high performance, this tutorial is for you.

Before getting started, please read the distribution strategy guide. The multi-GPU training tutorial is also relevant, because this tutorial uses the same model.

## Setup

First, set up TensorFlow and the necessary imports.

In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals 

In [2]:
import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()

import os, json

## Input function

This tutorial uses the MNIST dataset from TensorFlow Datasets. The code here is similar to the multi-GPU training tutorial with one key difference: when using Estimator for multi-worker training, it is necessary to shard the dataset by the number of workers to ensure model convergence. The input data is sharded by worker index, so that each worker processes a distinct 1/num_workers portion of the dataset.

In [3]:
BUFFER_SIZE = 10000
BATCH_SIZE = 64

def input_fn(mode, input_context=None):
  datasets, info = tfds.load(name='mnist',
                             with_info=True,
                             as_supervised=True)
  mnist_dataset = (datasets['train'] if mode == tf.estimator.ModeKeys.TRAIN else
                   datasets['test'])

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label

  if input_context:
    mnist_dataset = mnist_dataset.shard(input_context.num_input_pipelines,
                                        input_context.input_pipeline_id)
  return mnist_dataset.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
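
In input_fn, Dataset.shard(num_shards, index) keeps only the elements whose position i satisfies i % num_shards == index, with input_context.num_input_pipelines and input_context.input_pipeline_id supplying the two arguments. The framework-free sketch below mimics that rule on a plain Python list using a hypothetical shard helper (not part of TensorFlow), to show that two workers receive disjoint shards that together cover the whole dataset.

```python
def shard(elements, num_shards, index):
    """Hypothetical stand-in for tf.data.Dataset.shard on a Python list.

    Keeps the elements at positions i where i % num_shards == index.
    """
    if not 0 <= index < num_shards:
        raise ValueError("index must be in [0, num_shards)")
    return [e for i, e in enumerate(elements) if i % num_shards == index]


dataset = list(range(10))  # stand-in for ten training examples
num_workers = 2

# Each worker keeps a distinct shard, mirroring input_pipeline_id in input_fn.
shards = [shard(dataset, num_workers, worker) for worker in range(num_workers)]
print(shards)  # [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]
```

Because the shards are disjoint, each worker reads 1/num_workers of the examples per epoch instead of a duplicate copy of the full dataset.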

Another reasonable approach to achieve convergence would be to shuffle the dataset with distinct seeds at each worker.
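
As a framework-free sketch of that alternative, assuming a hypothetical worker_shuffle helper and an arbitrary base seed of 42, each worker below keeps the full dataset but shuffles it with its own seed. In input_fn, the analogous change would be to drop the shard call and pass a per-worker value, for example one derived from input_context.input_pipeline_id, as the seed argument of Dataset.shuffle.

```python
import random


def worker_shuffle(elements, worker_index, base_seed=42):
    """Return a shuffled copy of the full dataset using a per-worker seed."""
    rng = random.Random(base_seed + worker_index)  # distinct seed per worker
    shuffled = list(elements)
    rng.shuffle(shuffled)
    return shuffled


dataset = list(range(10))
orders = [worker_shuffle(dataset, worker) for worker in range(2)]
for worker, order in enumerate(orders):
    print(worker, order)
```

Unlike sharding, every worker still reads every example; only the order differs, and reusing the same seed reproduces the same order.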

## Multi-worker configuration

One of the key differences in this tutorial (compared to the multi-GPU training tutorial) is the multi-worker setup. The TF_CONFIG environment variable is the standard way to specify the cluster configuration to each worker that is part of the cluster.

TF_CONFIG has two components: cluster and task. The cluster component describes the entire cluster, namely its workers and parameter servers. The task component describes the current task. In this example, the task type is worker and the task index is 0.

For illustration purposes, this tutorial shows how to set a TF_CONFIG with 2 workers on localhost. In practice, you would create multiple workers on external IP addresses and ports, and set TF_CONFIG on each worker accordingly, i.e. give each worker its own task index.

<div class="alert alert-block alert-danger">
    <b>Warning:</b> Do not execute the following code in Colab. TensorFlow's runtime will attempt to create a gRPC server at the specified IP address and port, which will likely fail.
</div>

In [4]:
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
})
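
Only the task index differs between workers; the cluster section is identical everywhere. The hypothetical make_tf_config helper below (not a TensorFlow API) sketches how the TF_CONFIG string for each worker could be generated from one list of addresses. Each worker process would then assign its own string to os.environ['TF_CONFIG'] before creating the strategy.

```python
import json


def make_tf_config(workers, index):
    """Build the TF_CONFIG JSON string for the worker at position `index`."""
    return json.dumps({
        'cluster': {'worker': list(workers)},        # shared by all workers
        'task': {'type': 'worker', 'index': index},  # unique per worker
    })


WORKERS = ["localhost:12345", "localhost:23456"]
configs = [make_tf_config(WORKERS, i) for i in range(len(WORKERS))]
print(configs[1])
```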

## Define the model

Write the layers, the optimizer, and the loss function for training. This tutorial defines the model with Keras layers, similar to the multi-GPU training tutorial.

In [5]:
LEARNING_RATE = 1e-4
def model_fn(features, labels, mode):
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])
  logits = model(features, training=(mode == tf.estimator.ModeKeys.TRAIN))

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {'logits': logits}
    return tf.estimator.EstimatorSpec(mode, predictions=predictions)

  optimizer = tf.compat.v1.train.GradientDescentOptimizer(
      learning_rate=LEARNING_RATE)
  loss = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(labels, logits)
  loss = tf.reduce_sum(loss) * (1. / BATCH_SIZE)
  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(mode, loss=loss)

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=optimizer.minimize(
          loss, tf.compat.v1.train.get_or_create_global_step()))

<div class="alert alert-block alert-info">
    <b>Note:</b> While the learning rate is fixed in this example, in general it may be necessary to adjust the learning rate based on the global batch size.
</div>
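
A minimal sketch of such an adjustment, assuming BATCH_SIZE is the per-replica (per-GPU) batch size and each worker has a single GPU. It applies the linear scaling rule (scale the learning rate by the ratio of the new global batch size to the batch size the base rate was tuned for), which is a common heuristic and not something this tutorial prescribes; global_batch_size and linearly_scaled_lr are hypothetical helpers.

```python
def global_batch_size(per_replica_batch, num_workers, replicas_per_worker=1):
    """Total number of examples processed per step across the cluster."""
    return per_replica_batch * num_workers * replicas_per_worker


def linearly_scaled_lr(base_lr, base_batch, new_batch):
    """Linear scaling heuristic: scale the LR by new_batch / base_batch."""
    return base_lr * new_batch / base_batch


BATCH_SIZE = 64
LEARNING_RATE = 1e-4

global_batch = global_batch_size(BATCH_SIZE, num_workers=2)
scaled_lr = linearly_scaled_lr(LEARNING_RATE, BATCH_SIZE, global_batch)
print(global_batch, scaled_lr)
```

Treat the scaled value as a starting point for tuning rather than a guaranteed optimum.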

## MultiWorkerMirroredStrategy

To train the model, use an instance of tf.distribute.experimental.MultiWorkerMirroredStrategy. MultiWorkerMirroredStrategy creates copies of all variables in the model's layers on each device across all workers. It uses CollectiveOps, a TensorFlow op for collective communication, to aggregate gradients and keep the variables in sync. The tf.distribute.Strategy guide has more details about this strategy.

In [6]:
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

INFO:tensorflow:Enabled multi-worker collective ops with available devices: ['/job:worker/replica:0/task:0/device:CPU:0', '/job:worker/replica:0/task:0/device:GPU:0']
INFO:tensorflow:Using MirroredStrategy with devices ('/job:worker/task:0/device:GPU:0',)
INFO:tensorflow:MultiWorkerMirroredStrategy with cluster_spec = {'worker': ['localhost:12345', 'localhost:23456']}, task_type = 'worker', task_id = 0, num_workers = 2, local_devices = ('/job:worker/task:0/device:GPU:0',), communication = CollectiveCommunication.AUTO
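
Conceptually, the variables stay in sync because every replica starts from identical copies and applies the same aggregated gradient. The toy simulation below illustrates only that idea: it is not how CollectiveOps are implemented, it assumes a sum reduction and plain SGD, and all_reduce_sum and sgd_step are hypothetical helpers.

```python
def all_reduce_sum(per_replica_grads):
    """Element-wise sum across replicas; every replica receives the same result."""
    total = [sum(values) for values in zip(*per_replica_grads)]
    return [list(total) for _ in per_replica_grads]


def sgd_step(variables, grads, lr):
    """One plain gradient descent update."""
    return [v - lr * g for v, g in zip(variables, grads)]


# Two replicas hold identical copies of the same two variables ...
replica_vars = [[0.5, -0.5], [0.5, -0.5]]
# ... but compute different gradients on their own shards of the data.
local_grads = [[0.1, 0.2], [0.3, -0.4]]

reduced = all_reduce_sum(local_grads)
replica_vars = [sgd_step(v, g, lr=0.1) for v, g in zip(replica_vars, reduced)]
print(replica_vars[0] == replica_vars[1])  # True
```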


## Train and evaluate the model

Next, specify the distribution strategy in the RunConfig for the estimator, and train and evaluate by invoking tf.estimator.train_and_evaluate. This tutorial distributes only the training by specifying the strategy via train_distribute. It is also possible to distribute the evaluation via eval_distribute.

In [None]:
config = tf.estimator.RunConfig(train_distribute=strategy)

classifier = tf.estimator.Estimator(
    model_fn=model_fn, model_dir='/tmp/multiworker', config=config)
tf.estimator.train_and_evaluate(
    classifier,
    train_spec=tf.estimator.TrainSpec(input_fn=input_fn),
    eval_spec=tf.estimator.EvalSpec(input_fn=input_fn)
)

INFO:tensorflow:TF_CONFIG environment variable: {'cluster': {'worker': ['localhost:12345', 'localhost:23456']}, 'task': {'type': 'worker', 'index': 0}}
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:RunConfig initialized for Distribute Coordinator with INDEPENDENT_WORKER mode
INFO:tensorflow:Using config: {'_model_dir': '/tmp/multiworker', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.collective_all_reduce_strategy.CollectiveAllReduceStrategy object at 0x000002F03C5066C8>, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': N



Instructions for updating:
If using Keras pass *_constraint arguments to layers.


INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Collective batch_all_reduce: 6 all-reduces, num_workers = 2, communication_hint = AUTO, num_packs = 1


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Collective batch_all_reduce: 1 all-reduces, num_workers = 2, communication_hint = AUTO, num_packs = 1


Cause: could not parse the source code:

      lambda scaffold: scaffold.ready_op, args=(grouped_scaffold,))

This error may be avoided by creating the lambda in a standalone statement.



INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:all_hooks [<tensorflow_estimator.python.estimator.util.DistributedIteratorInitializerHook object at 0x000002F013F2AD88>, <tensorflow.python.training.basic_session_run_hooks.NanTensorHook object at 0x000002F03C9A9048>, <tensorflow.python.training.basic_session_run_hooks.LoggingTensorHook object at 0x000002F03C85E488>, <tensorflow.python.training.basic_session_run_hooks.StepCounterHook object at 0x000002F01A3EA908>, <tensorflow.python.training.basic_session_run_hooks.SummarySaverHook object at 0x000002F03C961B08>, <tensorflow.python.training.basic_session_run_hooks.CheckpointSaverHook object at 0x000002F013EF5B48>]


INFO:tensorflow:Creating chief session creator with config: device_filters: "/job:worker/task:0"
allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
    scoped_allocator_optimization: ON
    scoped_allocator_opts {
      enable_op: "CollectiveReduce"
    }
  }
}
experimental {
  collective_group_leader: "/job:worker/replica:0/task:0"
}



Instructions for updating:
Use the iterator's `initializer` property instead.


INFO:tensorflow:Graph was finalized.



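The example above distributes only training. As mentioned earlier, evaluation can also be distributed through the eval_distribute argument of tf.estimator.RunConfig. A minimal configuration sketch, assuming the same strategy object is reused for both and the rest of the code is unchanged:

```python
import tensorflow as tf

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
config = tf.estimator.RunConfig(
    train_distribute=strategy,  # distribute training, as in this tutorial
    eval_distribute=strategy)   # additionally distribute evaluation
```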
## Optimize training performance

You now have a model and a multi-worker capable Estimator powered by tf.distribute.Strategy. You can try the following techniques to optimize performance of multi-worker training:

* Increase the batch size: The batch size specified here is per GPU. In general, the largest batch size that fits in GPU memory is advisable.
* Cast variables: Cast the variables to a lower-precision type such as tf.float16 if possible. The official ResNet model includes an example of how this can be done.
* Use collective communication: MultiWorkerMirroredStrategy provides multiple collective communication implementations.

    * RING implements ring-based collectives using gRPC as the cross-host communication layer.
    * NCCL uses NVIDIA's NCCL to implement collectives.
    * AUTO defers the choice to the runtime.
    
The best choice of collective implementation depends on the number and kind of GPUs and on the network interconnect in the cluster. To override the automatic choice, pass a valid value to the communication parameter of MultiWorkerMirroredStrategy's constructor, e.g. communication=tf.distribute.experimental.CollectiveCommunication.NCCL.
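
To build intuition for the RING option, the framework-free sketch below simulates the classic two-phase ring all-reduce (reduce-scatter followed by all-gather) over N simulated workers. It illustrates the general algorithm under the assumption of a sum reduction; it is not TensorFlow's implementation, and ring_all_reduce is a hypothetical helper.

```python
def ring_all_reduce(buffers):
    """Sum-reduce equal-length vectors held by simulated workers in a ring.

    Phase 1 (reduce-scatter, N - 1 steps) leaves each worker owning one fully
    summed chunk; phase 2 (all-gather, N - 1 steps) circulates those chunks
    until every worker holds the complete sum. In every step, worker i sends
    exactly one chunk to its neighbor (i + 1) % N.
    """
    n = len(buffers)
    length = len(buffers[0])
    # Split positions into n contiguous chunks whose sizes differ by at most 1.
    bounds = [(k * length // n, (k + 1) * length // n) for k in range(n)]
    data = [list(b) for b in buffers]  # work on copies of the inputs

    def run_phase(chunk_for, reduce):
        for step in range(n - 1):
            # Collect every send first so all transfers in a step are simultaneous.
            sends = []
            for i in range(n):
                lo, hi = bounds[chunk_for(i, step)]
                sends.append(((i + 1) % n, lo, data[i][lo:hi]))
            for dest, lo, payload in sends:
                for j, value in enumerate(payload):
                    if reduce:
                        data[dest][lo + j] += value  # accumulate partial sums
                    else:
                        data[dest][lo + j] = value   # overwrite with final sum

    # Reduce-scatter: worker i forwards chunk (i - step) % n.
    run_phase(lambda i, step: (i - step) % n, reduce=True)
    # Worker i now owns the fully reduced chunk (i + 1) % n; all-gather it.
    run_phase(lambda i, step: (i + 1 - step) % n, reduce=False)
    return data


result = ring_all_reduce([[1, 2, 3], [10, 20, 30], [100, 200, 300]])
print(result)  # [[111, 222, 333], [111, 222, 333], [111, 222, 333]]
```

Each worker exchanges data only with its neighbor, and over the 2 * (N - 1) steps it sends about 2 * (N - 1) / N times the size of its vector, so per-worker traffic stays nearly constant as workers are added.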

## Other code examples

1. End-to-end example for multi-worker training in tensorflow/ecosystem using Kubernetes templates. This example starts with a Keras model and converts it to an Estimator using the tf.keras.estimator.model_to_estimator API.
2. Official models, many of which can be configured to run multiple distribution strategies.