# Multi-worker training with Estimator


> Warning: Estimators are not recommended for new code. Estimators run `v1.Session`-style code, which is more difficult to write correctly and can behave unexpectedly, especially when combined with TF 2 code. Estimators do fall under [compatibility guarantees](https://tensorflow.org/guide/versions), but will receive no fixes other than for security vulnerabilities. See the [migration guide](https://tensorflow.org/guide/migrate) for details.

## Overview

Note: While you can use Estimators with the `tf.distribute` API, it's recommended to use Keras with `tf.distribute` instead; see [multi-worker training with Keras](multi_worker_with_keras.ipynb). Estimator training with `tf.distribute.Strategy` has limited support.


This tutorial demonstrates how `tf.distribute.Strategy` can be used for distributed multi-worker training with `tf.estimator`.  If you write your code using `tf.estimator`, and you're interested in scaling beyond a single machine with high performance, this tutorial is for you.

Before getting started, please read the [distribution strategy](../../guide/distributed_training.ipynb) guide.  The [multi-GPU training tutorial](./keras.ipynb) is also relevant, because this tutorial uses the same model.


## Setup

First, set up TensorFlow and the necessary imports.

In [1]:
import tensorflow_datasets as tfds
import tensorflow as tf

import os, json

Note: Starting with TF 2.4, `MultiWorkerMirroredStrategy` fails with Estimators when eager execution is enabled (the default). The error in TF 2.4 is `TypeError: cannot pickle '_thread.lock' object`; see [issue #46556](https://github.com/tensorflow/tensorflow/issues/46556) for details. The workaround is to disable eager execution.

In [2]:
tf.compat.v1.disable_eager_execution()

## Input function

This tutorial uses the MNIST dataset from [TensorFlow Datasets](https://www.tensorflow.org/datasets). The code here is similar to the [multi-GPU training tutorial](./keras.ipynb), with one key difference: when using Estimator for multi-worker training, you must shard the dataset by the number of workers to ensure model convergence. The input data is sharded by worker index, so that each worker processes a distinct `1/num_workers` portion of the dataset.

In [3]:
BUFFER_SIZE = 10000
BATCH_SIZE = 64

def input_fn(mode, input_context=None):
  datasets, info = tfds.load(name='mnist',
                                with_info=True,
                                as_supervised=True)
  mnist_dataset = (datasets['train'] if mode == tf.estimator.ModeKeys.TRAIN else
                   datasets['test'])

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label

  if input_context:
    mnist_dataset = mnist_dataset.shard(input_context.num_input_pipelines,
                                        input_context.input_pipeline_id)
  return mnist_dataset.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

Another reasonable approach to achieve convergence would be to shuffle the dataset with distinct seeds at each worker.
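To make the sharding step concrete, here is a minimal pure-Python sketch of the selection rule that `Dataset.shard(num_shards, index)` applies: it keeps the elements whose position modulo `num_shards` equals `index`. The `shard` helper and the ten-element list are hypothetical stand-ins for illustration; they are not part of the tutorial's input pipeline.

```python
def shard(examples, num_shards, index):
    """Keep every `num_shards`-th example, starting at offset `index`."""
    return examples[index::num_shards]

# Hypothetical stand-in for a dataset of 10 examples.
examples = list(range(10))
num_workers = 2

# One shard per worker, selected by the worker's index.
shards = [shard(examples, num_workers, i) for i in range(num_workers)]

for i, worker_examples in enumerate(shards):
    print(f"worker {i}: {worker_examples}")
```

Each worker sees a disjoint subset, and together the shards cover the whole dataset exactly once.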

## Multi-worker configuration

One of the key differences in this tutorial (compared to the [multi-GPU training tutorial](./keras.ipynb)) is the multi-worker setup.  The `TF_CONFIG` environment variable is the standard way to specify the cluster configuration to each worker that is part of the cluster.

There are two components of `TF_CONFIG`: `cluster` and `task`.  `cluster` provides information about the entire cluster, namely the workers and parameter servers in the cluster.  `task` provides information about the current task. The first component `cluster` is the same for all workers and parameter servers in the cluster, and the second component `task` is different on each worker and parameter server and specifies its own `type` and `index`. In this example, the task `type` is `worker` and the task `index` is `0`.

For illustration purposes, this tutorial shows how to set a `TF_CONFIG` with 2 workers on `localhost`. In practice, you would create the workers on external IP addresses and ports, and set `TF_CONFIG` on each worker appropriately, i.e., modify the task `index`.

Warning: *Do not execute the following code in Colab.*  TensorFlow's runtime will attempt to create a gRPC server at the specified IP address and port, which will likely fail. See the [Keras version](multi_worker_with_keras.ipynb) of this tutorial for an example of how you can test-run multiple workers on a single machine.

```
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
})
```
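Since the `cluster` component is shared and only the task `index` differs between workers, the per-worker values can be generated from one host list. The following is a small sketch under that assumption; `make_tf_config` is a hypothetical helper, not a TensorFlow API, and it only builds the JSON strings without setting any environment variable.

```python
import json

def make_tf_config(worker_hosts, task_index):
    """Build the TF_CONFIG JSON string for one worker (hypothetical helper)."""
    return json.dumps({
        'cluster': {'worker': list(worker_hosts)},  # identical on every worker
        'task': {'type': 'worker', 'index': task_index},  # differs per worker
    })

worker_hosts = ["localhost:12345", "localhost:23456"]
configs = [make_tf_config(worker_hosts, i) for i in range(len(worker_hosts))]

for i, config in enumerate(configs):
    print(f"worker {i}: {config}")
```

Each worker process would then export its own string as `TF_CONFIG` before creating the strategy.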


## Define the model

Write the layers, the optimizer, and the loss function for training.  This tutorial defines the model with Keras layers, similar to the [multi-GPU training tutorial](./keras.ipynb).

In [4]:
LEARNING_RATE = 1e-4
def model_fn(features, labels, mode):
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])
  logits = model(features, training=False)

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {'logits': logits}
    # EstimatorSpec takes `mode` as its first argument; it has no `labels` parameter.
    return tf.estimator.EstimatorSpec(mode, predictions=predictions)

  optimizer = tf.compat.v1.train.GradientDescentOptimizer(
      learning_rate=LEARNING_RATE)
  loss = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(labels, logits)
  loss = tf.reduce_sum(loss) * (1. / BATCH_SIZE)
  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(mode, loss=loss)

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=optimizer.minimize(
          loss, tf.compat.v1.train.get_or_create_global_step()))

Note: Although the learning rate is fixed in this example, in general it may be necessary to adjust the learning rate based on the global batch size.
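As a rough illustration of that adjustment, the sketch below applies linear scaling, one common heuristic in which the learning rate grows in proportion to the number of replicas. The heuristic and both helper functions are assumptions made for illustration; this tutorial does not prescribe a particular rule.

```python
PER_REPLICA_BATCH_SIZE = 64   # BATCH_SIZE in this tutorial
BASE_LEARNING_RATE = 1e-4     # LEARNING_RATE in this tutorial

def global_batch_size(per_replica_batch_size, num_replicas):
    """Total number of examples consumed per training step across all replicas."""
    return per_replica_batch_size * num_replicas

def scaled_learning_rate(base_lr, num_replicas):
    """Linear scaling heuristic (an assumption, not a rule from this tutorial)."""
    return base_lr * num_replicas

num_replicas = 2  # e.g. two workers with one device each
print(global_batch_size(PER_REPLICA_BATCH_SIZE, num_replicas))
print(scaled_learning_rate(BASE_LEARNING_RATE, num_replicas))
```

Whether linear scaling helps depends on the model and optimizer, so treat the scaled value as a starting point for tuning.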

## MultiWorkerMirroredStrategy

To train the model, use an instance of `tf.distribute.experimental.MultiWorkerMirroredStrategy`.  `MultiWorkerMirroredStrategy` creates copies of all variables in the model's layers on each device across all workers.  It uses `CollectiveOps`, a TensorFlow op for collective communication, to aggregate gradients and keep the variables in sync.  The [`tf.distribute.Strategy` guide](../../guide/distributed_training.ipynb) has more details about this strategy.

In [5]:
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

Instructions for updating:
use distribute.MultiWorkerMirroredStrategy instead
INFO:tensorflow:Using MirroredStrategy with devices ('/device:GPU:0',)
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0',), communication = CommunicationImplementation.AUTO


2024-02-26 20:16:08.098151: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1
2024-02-26 20:16:08.098176: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 8.00 GB
2024-02-26 20:16:08.098182: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 2.67 GB
2024-02-26 20:16:08.098212: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:303] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-02-26 20:16:08.098226: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:269] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


## Train and evaluate the model

Next, specify the distribution strategy in the `RunConfig` for the estimator, and train and evaluate by invoking `tf.estimator.train_and_evaluate`.  This tutorial distributes only the training by specifying the strategy via `train_distribute`.  It is also possible to distribute the evaluation via `eval_distribute`.

In [6]:
config = tf.estimator.RunConfig(train_distribute=strategy)

classifier = tf.estimator.Estimator(
    model_fn=model_fn, model_dir='/tmp/multiworker', config=config)
tf.estimator.train_and_evaluate(
    classifier,
    train_spec=tf.estimator.TrainSpec(input_fn=input_fn),
    eval_spec=tf.estimator.EvalSpec(input_fn=input_fn)
)

Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/multiworker', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.collective_all_reduce_strategy._CollectiveAllReduceStrategyExperimental object at 0x29a169be0>, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service':

INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Create CheckpointSaverHook.


Instructions for updating:
Use the iterator's `initializer` property instead.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Running local_init_op.


2024-02-26 20:16:09.471952: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:303] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-02-26 20:16:09.471969: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:269] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
2024-02-26 20:16:09.474986: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:375] MLIR V1 optimization pass is not enabled
2024-02-26 20:16:09.476574: W tensorflow/core/common_runtime/colocation_graph.cc:1213] Failed to place the graph without changing the devices of some resources. Some of the operations (that had to be colocated with resource generating operations) are not supported on the resources' devices. Current candidate devices are [
  /job:localhost/replica:0/tas

INFO:tensorflow:Done running local_init_op.


2024-02-26 20:16:09.565159: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2024-02-26 20:16:09.571434: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2024-02-26 20:16:09.597206: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...


INFO:tensorflow:Saving checkpoints for 0 into /tmp/multiworker/model.ckpt.
2024-02-26 20:16:09.879254: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...


2024-02-26 20:16:10.086768: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2024-02-26 20:16:10.096939: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2024-02-26 20:16:10.097161: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2024-02-26 20:16:10.098729: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2024-02-26 20:16:10.100024: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2024-02-26 20:16:10.101210: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2024-02-26 20:16:10.102031: I tensorflow/core/grappler/optimizers/cust

2024-02-26 20:16:10.126408: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2024-02-26 20:16:10.140036: W tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc:926] error: NOT_FOUND: No attr named 'num_host_args' in NodeDef:
	 [[{{node sequential/conv2d/Relu}}]]
2024-02-26 20:16:10.140047: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:954] scoped_allocator_optimizer failed: NOT_FOUND: No attr named 'num_host_args' in NodeDef:
	 [[{{node sequential/conv2d/Relu}}]]


INFO:tensorflow:loss = 2.3218899, step = 0
2024-02-26 20:16:10.463860: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2024-02-26 20:16:10.476882: W tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc:926] error: NOT_FOUND: No attr named 'num_host_args' in NodeDef:
	 [[{{node sequential/conv2d/Relu}}]]
2024-02-26 20:16:10.476895: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:954] scoped_allocator_optimizer failed: NOT_FOUND: No attr named 'num_host_args' in NodeDef:
	 [[{{node sequential/conv2d/Relu}}]]


INFO:tensorflow:global_step/sec: 178.646


INFO:tensorflow:loss = 2.3844895, step = 100 (0.560 sec)


INFO:tensorflow:global_step/sec: 191.081


INFO:tensorflow:loss = 2.2860851, step = 200 (0.523 sec)


INFO:tensorflow:global_step/sec: 197.632


INFO:tensorflow:loss = 2.290071, step = 300 (0.506 sec)


INFO:tensorflow:global_step/sec: 193.557


INFO:tensorflow:loss = 2.3090718, step = 400 (0.517 sec)


INFO:tensorflow:global_step/sec: 193.318


INFO:tensorflow:loss = 2.311561, step = 500 (0.517 sec)


INFO:tensorflow:global_step/sec: 187.526


INFO:tensorflow:loss = 2.2744005, step = 600 (0.533 sec)


INFO:tensorflow:global_step/sec: 172.513


INFO:tensorflow:loss = 2.2919688, step = 700 (0.580 sec)


INFO:tensorflow:global_step/sec: 153.807


INFO:tensorflow:loss = 2.2944775, step = 800 (0.650 sec)


INFO:tensorflow:global_step/sec: 194.85


INFO:tensorflow:loss = 2.2893872, step = 900 (0.513 sec)


INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 938...


2024-02-26 20:16:15.594206: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous recv item cancelled. Key hash: 18076806549494980904
2024-02-26 20:16:15.594224: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous recv item cancelled. Key hash: 4023121871814650357
2024-02-26 20:16:15.594227: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous recv item cancelled. Key hash: 6246213777179886435
2024-02-26 20:16:15.594232: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous recv item cancelled. Key hash: 15519404885938365122
2024-02-26 20:16:15.594235: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous recv item cancelled. Key hash: 722670692967963994
2024-02-26 20:16:15.594239: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous recv item cancelled. Key hash: 2113668336618674152
INFO:tensorflow:Saving checkpoints for 938 into /tmp/multiworker/model.ckpt.


INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 938...


INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Starting evaluation at 2024-02-26T20:16:16


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Restoring parameters from /tmp/multiworker/model.ckpt-938


2024-02-26 20:16:16.047993: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:303] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-02-26 20:16:16.048011: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:269] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


INFO:tensorflow:Running local_init_op.


2024-02-26 20:16:16.054151: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2024-02-26 20:16:16.066302: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


INFO:tensorflow:Done running local_init_op.


2024-02-26 20:16:16.076449: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2024-02-26 20:16:16.085053: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2024-02-26 20:16:16.100619: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2024-02-26 20:16:16.138522: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2024-02-26 20:16:16.146248: W tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc:926] error: NOT_FOUND: No attr named 'num_host_args' in NodeDef:
	 [[{{node sequential/conv2d/Relu}}]]
2024-02-26 20:16:16.146264: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:954] scoped_allocator_optimizer failed: NOT_FOUND: No

INFO:tensorflow:Evaluation [10/100]


INFO:tensorflow:Evaluation [20/100]


INFO:tensorflow:Evaluation [30/100]


INFO:tensorflow:Evaluation [40/100]


INFO:tensorflow:Evaluation [50/100]


INFO:tensorflow:Evaluation [60/100]


INFO:tensorflow:Evaluation [70/100]


INFO:tensorflow:Evaluation [80/100]


INFO:tensorflow:Evaluation [90/100]


INFO:tensorflow:Evaluation [100/100]
2024-02-26 20:16:16.562289: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


INFO:tensorflow:Inference Time : 0.59839s


INFO:tensorflow:Finished evaluation at 2024-02-26-20:16:16


INFO:tensorflow:Saving dict for global step 938: global_step = 938, loss = 2.2844214


INFO:tensorflow:Saving 'checkpoint_path' summary for global step 938: /tmp/multiworker/model.ckpt-938


INFO:tensorflow:Loss for final step: 1.1206634.


({'loss': 2.2844214, 'global_step': 938}, [])
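
The log above stops at step 938 and reports a final-step loss of about 1.12, roughly half of the preceding values. One plausible reading, sketched below, is that the run made a single unsharded pass over the training data and that the last batch was only partially filled, while `model_fn` still divides the summed loss by the full `BATCH_SIZE`. The figure of 60,000 training examples is an assumption (the standard MNIST training split), not a value printed in the log.

```python
import math

NUM_TRAIN_EXAMPLES = 60000  # assumption: standard MNIST training split
BATCH_SIZE = 64             # as defined in the input function cell

# Number of batches in one pass when the final partial batch is kept.
steps = math.ceil(NUM_TRAIN_EXAMPLES / BATCH_SIZE)

# Size of the final, partially filled batch.
last_batch_size = NUM_TRAIN_EXAMPLES - (steps - 1) * BATCH_SIZE

# model_fn scales the summed loss by 1 / BATCH_SIZE, so a smaller batch
# shrinks the reported loss by this factor.
loss_scale = last_batch_size / BATCH_SIZE

print(steps, last_batch_size, loss_scale)
```

Under these assumptions the step count matches the 938 reported in the log, and a half-sized final batch would account for the halved loss.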

## Optimize training performance

You now have a model and a multi-worker capable Estimator powered by `tf.distribute.Strategy`.  You can try the following techniques to optimize performance of multi-worker training:

*   *Increase the batch size:* The batch size specified here is per-GPU.  In general, the largest batch size that fits the GPU memory is advisable.
*   *Cast variables:* Cast the variables to a lower-precision floating-point type such as `tf.float16` if possible.  The official ResNet model includes [an example](https://github.com/tensorflow/models/blob/8367cf6dabe11adf7628541706b660821f397dce/official/resnet/resnet_model.py#L466) of how this can be done.
*   *Use collective communication:* `MultiWorkerMirroredStrategy` provides multiple [collective communication implementations](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/distribute/cross_device_ops.py).  
    * `RING` implements ring-based collectives using gRPC as the cross-host communication layer.  
    * `NCCL` uses [Nvidia's NCCL](https://developer.nvidia.com/nccl) to implement collectives.  
    * `AUTO` defers the choice to the runtime.
    
    The best choice of collective implementation depends upon the number and kind of GPUs, and the network interconnect in the cluster.  To override the automatic choice, specify a valid value to the `communication` parameter of `MultiWorkerMirroredStrategy`'s constructor, e.g. `communication=tf.distribute.experimental.CollectiveCommunication.NCCL`.

Visit the [Performance section](../../guide/function.ipynb) in the guide to learn more about other strategies and [tools](../../guide/profiler.md) you can use to optimize the performance of your TensorFlow models.


## Other code examples

1.   [End-to-end example](https://github.com/tensorflow/ecosystem/tree/master/distribution_strategy) for multi-worker training in tensorflow/ecosystem using Kubernetes templates. This example starts with a Keras model and converts it to an Estimator using the `tf.keras.estimator.model_to_estimator` API.
2.   [Official models](https://github.com/tensorflow/models/tree/master/official), many of which can be configured to run multiple distribution strategies.
