# Multi-worker training with Keras



## Setup

Start with some necessary imports:

In [None]:
import json
import os
import sys

Before importing TensorFlow, make a few changes to the environment:

* In a real-world application, each worker would be on a different machine. For the purposes of this tutorial, all the workers will run on **this** machine. So disable all GPUs to prevent errors caused by all workers trying to use the same GPU.

In [None]:
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

* Reset the `TF_CONFIG` environment variable (you'll learn more about this later):

In [None]:
os.environ.pop('TF_CONFIG', None)

* Make sure that the current directory is on Python's path—this allows the notebook to import the files written by `%%writefile` later:


In [None]:
if '.' not in sys.path:
  sys.path.insert(0, '.')

Finally, import TensorFlow:

In [None]:
import tensorflow as tf

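### Dataset and model definition

Next, create an `mnist_setup.py` file with a simple model and dataset setup. The worker processes in this tutorial will import this file: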

In [None]:
%%writefile mnist_setup.py

import os
import tensorflow as tf
import numpy as np

def mnist_dataset(batch_size):
  (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
  # The `x` arrays are in uint8 and have values in the [0, 255] range.
  # You need to convert them to float32 with values in the [0, 1] range.
  x_train = x_train / np.float32(255)
  y_train = y_train.astype(np.int64)
  train_dataset = tf.data.Dataset.from_tensor_slices(
      (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
  return train_dataset

def build_and_compile_cnn_model():
  model = tf.keras.Sequential([
      tf.keras.layers.InputLayer(input_shape=(28, 28)),
      tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(10)
  ])
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
      metrics=['accuracy'])
  return model

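### Model training on a single worker

Before distributing anything, train the model for a small number of epochs on a single worker to make sure everything works correctly. As training progresses, the loss should drop and the accuracy should increase.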

In [None]:
import mnist_setup

batch_size = 64
single_worker_dataset = mnist_setup.mnist_dataset(batch_size)
single_worker_model = mnist_setup.build_and_compile_cnn_model()
single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)

## Multi-worker configuration


### A cluster with jobs and tasks

In TensorFlow, distributed training involves a `'cluster'` with several jobs, and each of the jobs may have one or more `'task'`s.

You will need the `TF_CONFIG` configuration environment variable for training on multiple machines, each of which possibly has a different role. `TF_CONFIG` is a JSON string used to specify the cluster configuration for each worker that is part of the cluster.

There are two components of a `TF_CONFIG` variable: `'cluster'` and `'task'`.

* A `'cluster'` is the same for all workers and provides information about the training cluster, which is a dict consisting of different types of jobs, such as `'worker'` or `'chief'`.
    - In multi-worker training with `tf.distribute.MultiWorkerMirroredStrategy`, there is usually one `'worker'` that takes on more responsibilities, such as saving a checkpoint and writing a summary file for TensorBoard, in addition to what a regular `'worker'` does. Such a worker is referred to as the chief worker (with a job name `'chief'`).
    - It is customary for the worker with `'index'` `0` to be the `'chief'`.

* A `'task'` provides information on the current task and is different for each worker. It specifies the `'type'` and `'index'` of that worker.

In [None]:
tf_config = {
    'cluster': {
        'worker': ['localhost:12345', 'localhost:23456']
    },
    'task': {'type': 'worker', 'index': 0}
}
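Because only the `'task'` entry differs between workers, you can generate every worker's configuration from a single cluster description. The following is a minimal plain-Python sketch; `make_tf_configs` is a hypothetical helper, not part of TensorFlow:

```python
import json


# Hypothetical helper (not a TensorFlow API).
def make_tf_configs(worker_addresses):
  """Returns one serialized TF_CONFIG string per worker.

  Every worker gets the same 'cluster' entry; only 'task' differs.
  """
  configs = []
  for index in range(len(worker_addresses)):
    config = {
        'cluster': {'worker': list(worker_addresses)},
        'task': {'type': 'worker', 'index': index},
    }
    configs.append(json.dumps(config))
  return configs


configs = make_tf_configs(['localhost:12345', 'localhost:23456'])
for config in configs:
  print(config)
```

In a real deployment, each machine would set only its own string as `TF_CONFIG`, with real host names in place of `localhost`.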

Note that `tf_config` is just a local variable in Python. To use it for training configuration, serialize it as JSON and place it in a `TF_CONFIG` environment variable.

In [None]:
json.dumps(tf_config)

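### Environment variables and subprocesses in notebooks

Subprocesses inherit environment variables from their parent process. So, if you set an environment variable in this Jupyter notebook process: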

In [None]:
os.environ['GREETINGS'] = 'Hello!'

... then you can access the environment variable from the subprocesses:

In [None]:
%%bash
echo ${GREETINGS}
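The same inheritance applies to any subprocess started from Python, not only to `%%bash` cells. The following minimal sketch uses the standard-library `subprocess` module. It also shows that you can pass a modified copy of the environment to a single child process without changing the notebook's own environment:

```python
import os
import subprocess
import sys

os.environ['GREETINGS'] = 'Hello!'
child_code = "import os; print(os.environ['GREETINGS'])"

# The child process inherits os.environ from this (parent) process.
inherited = subprocess.run(
    [sys.executable, '-c', child_code],
    capture_output=True, text=True, check=True)
print(inherited.stdout.strip())

# Passing an explicit env overrides the variable for this child only.
child_env = dict(os.environ, GREETINGS='Hi from a custom env!')
overridden = subprocess.run(
    [sys.executable, '-c', child_code],
    capture_output=True, text=True, check=True, env=child_env)
print(overridden.stdout.strip())
```

Passing `env=` per process is one way to give each worker its own `TF_CONFIG` without mutating the notebook's environment between launches.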

In the next section, you'll use this method to pass the `TF_CONFIG` to the worker subprocesses. You would never really launch your jobs this way in a real-world scenario—this tutorial is just showing how to do it with a minimal multi-worker example.

## Train the model

To train the model, first create an instance of `tf.distribute.MultiWorkerMirroredStrategy`:

In [None]:
strategy = tf.distribute.MultiWorkerMirroredStrategy()

With the integration of the `tf.distribute.Strategy` API into `tf.keras`, the only change you need to make to distribute the training across multiple workers is enclosing the model building and `model.compile()` call inside `strategy.scope()`. The distribution strategy's scope dictates how and where the variables are created. In the case of `MultiWorkerMirroredStrategy`, the variables created are `MirroredVariable`s, and they are replicated on each of the workers.


In [None]:
with strategy.scope():
  # Model building/compiling need to be within `strategy.scope()`.
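  multi_worker_model = mnist_setup.build_and_compile_cnn_model()

To actually run with `MultiWorkerMirroredStrategy`, you need to run worker processes and pass a `TF_CONFIG` to each of them. Note that `TF_CONFIG` is read when `MultiWorkerMirroredStrategy()` is constructed, so it must be set before the strategy instance is created.

Like the `mnist_setup.py` file written earlier, here is the `main.py` script that each of the workers will run:
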
In [None]:
%%writefile main.py

import os
import json

import tensorflow as tf
import mnist_setup

per_worker_batch_size = 64
tf_config = json.loads(os.environ['TF_CONFIG'])
num_workers = len(tf_config['cluster']['worker'])

strategy = tf.distribute.MultiWorkerMirroredStrategy()

global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist_setup.mnist_dataset(global_batch_size)

with strategy.scope():
  # Model building/compiling need to be within `strategy.scope()`.
  multi_worker_model = mnist_setup.build_and_compile_cnn_model()


multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)
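In `main.py`, `global_batch_size`, which is passed to `Dataset.batch`, is set to `per_worker_batch_size * num_workers`. Once the dataset is distributed, each worker then processes batches of `per_worker_batch_size` examples regardless of how many workers there are. The arithmetic can be sketched in plain Python without importing TensorFlow; `global_batch_size_from_tf_config` is a hypothetical helper that mirrors the logic in `main.py`:

```python
import json


# Hypothetical helper (not a TensorFlow API) mirroring main.py's batch logic.
def global_batch_size_from_tf_config(tf_config_json, per_worker_batch_size):
  tf_config = json.loads(tf_config_json)
  # Like main.py, this counts only the 'worker' job in the cluster.
  num_workers = len(tf_config['cluster']['worker'])
  return per_worker_batch_size * num_workers


tf_config_json = json.dumps({
    'cluster': {'worker': ['localhost:12345', 'localhost:23456']},
    'task': {'type': 'worker', 'index': 0},
})
global_batch = global_batch_size_from_tf_config(tf_config_json, 64)

# Each training step consumes one global batch across the whole cluster.
steps_per_epoch = 70
examples_per_epoch = global_batch * steps_per_epoch
print(global_batch, examples_per_epoch)
```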

The current directory now contains both Python files:

In [None]:
%%bash
ls *.py

Serialize the `TF_CONFIG` to JSON and add it to the environment variables:

In [None]:
os.environ['TF_CONFIG'] = json.dumps(tf_config)

Now, you can launch a worker process that will run `main.py` and use the `TF_CONFIG`:

In [None]:
# first kill any previous runs
%killbgscripts

In [None]:
%%bash --bg
python main.py &> job_0.log

In [None]:
import time
time.sleep(10)
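Rather than sleeping for a fixed 10 seconds, you could poll the worker's log file until it contains the `Started server with target` line that the worker writes once it is ready. The following is a minimal plain-Python sketch; `wait_for_log_line` is a hypothetical helper, and the exact log text may vary between TensorFlow versions:

```python
import os
import tempfile
import time


# Hypothetical helper (not a TensorFlow API).
def wait_for_log_line(log_path, text, timeout_secs=60.0, poll_secs=0.5):
  """Polls log_path until a line containing text appears.

  Returns True if the text was found, or False once timeout_secs elapses.
  """
  deadline = time.monotonic() + timeout_secs
  while time.monotonic() < deadline:
    if os.path.exists(log_path):
      with open(log_path, errors='replace') as f:
        if any(text in line for line in f):
          return True
    time.sleep(poll_secs)
  return False


# Demonstration with a temporary file standing in for job_0.log.
with tempfile.TemporaryDirectory() as tmp_dir:
  demo_log = os.path.join(tmp_dir, 'job_0.log')
  with open(demo_log, 'w') as f:
    f.write('Started server with target: grpc://localhost:12345\n')
  found = wait_for_log_line(demo_log, 'Started server with target',
                            timeout_secs=2)
  missing = wait_for_log_line(demo_log, 'no such text',
                              timeout_secs=0.2, poll_secs=0.05)
print(found, missing)
```

In the notebook, you would call `wait_for_log_line('job_0.log', 'Started server with target')` in place of `time.sleep(10)`.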

Now, inspect what's been output to the worker's log file so far:

In [None]:
%%bash
cat job_0.log

The last line of the log file should say: `Started server with target: grpc://localhost:12345`. The first worker is now ready and is waiting for all the other worker(s) to be ready to proceed.

So update the `tf_config` for the second worker's process to pick up:

In [None]:
tf_config['task']['index'] = 1
os.environ['TF_CONFIG'] = json.dumps(tf_config)

Launch the second worker. This will start the training since all the workers are active (so there's no need to background this process):

In [None]:
%%bash
python main.py

If you recheck the logs written by the first worker, you'll see that it participated in training the model:

In [None]:
%%bash
cat job_0.log

Note: This may run slower than the test run at the beginning of this tutorial because running multiple workers on a single machine only adds overhead. The goal here is not to improve the training time but to give an example of multi-worker training.


In [None]:
# Delete the `TF_CONFIG`, and kill any background tasks so they don't affect the next section.
os.environ.pop('TF_CONFIG', None)
%killbgscripts

## Multi-worker training in depth


### Dataset sharding

In multi-worker training, _dataset sharding_ is needed to ensure convergence and performance.

The example in the previous section relies on the default autosharding provided by the `tf.distribute.Strategy` API. You can control the sharding by setting the `tf.data.experimental.AutoShardPolicy` of the `tf.data.experimental.DistributeOptions`.

To learn more about _auto-sharding_, refer to the [Distributed input guide](https://www.tensorflow.org/tutorials/distribute/input#sharding).

Here is a quick example of how to turn the auto sharding off, so that each replica processes every example (_not recommended_):


In [None]:
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF

global_batch_size = 64
multi_worker_dataset = mnist_setup.mnist_dataset(batch_size=global_batch_size)
dataset_no_auto_shard = multi_worker_dataset.with_options(options)
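To make the difference concrete, here is a plain-Python sketch of sharding by element index. It is modeled on the behavior of `tf.data.Dataset.shard(num_shards, index)`, which keeps every `num_shards`-th element starting at `index`, and resembles `AutoShardPolicy.DATA`. It illustrates the idea only and is not TensorFlow's autosharding implementation:

```python
# Illustration only: a list-based stand-in for per-worker sharding.
def shard(elements, num_shards, index):
  """Keeps the elements whose position modulo num_shards equals index."""
  return [x for i, x in enumerate(elements) if i % num_shards == index]


examples = list(range(10))
num_workers = 2

# With sharding, each worker sees a disjoint subset of the examples.
sharded = [shard(examples, num_workers, w) for w in range(num_workers)]

# With AutoShardPolicy.OFF, every worker iterates over the full dataset.
unsharded = [list(examples) for _ in range(num_workers)]

print(sharded)
print(unsharded)
```

With sharding, one pass across both workers covers each example once. With sharding turned off, each example is processed `num_workers` times per pass, which is why turning it off is not recommended.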

### Evaluation

If you pass the `validation_data` into `Model.fit` as well, it will alternate between training and evaluation for each epoch. The evaluation work is distributed across the same set of workers, and its results are aggregated and available to all workers.

Similar to training, the validation dataset is automatically sharded at the file level. You need to set a global batch size in the validation dataset and set the `validation_steps`.

A repeated dataset (by calling `tf.data.Dataset.repeat`) is recommended for evaluation.
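Because a repeated dataset has no natural end, `validation_steps` tells `Model.fit` how many batches to evaluate. One simple way to choose it, assuming you want to cover every validation example at least once per evaluation, is to round up the number of global batches. The following is a plain-Python sketch; `validation_steps_for` is a hypothetical helper:

```python
import math


# Hypothetical helper (not a TensorFlow API).
def validation_steps_for(num_validation_examples, global_batch_size):
  """Global batches needed to see every validation example at least once."""
  return math.ceil(num_validation_examples / global_batch_size)


# For example: MNIST's 10,000 test images, with 64 examples per worker
# and 2 workers.
steps = validation_steps_for(10_000, 64 * 2)
print(steps)
```

You would then pass something like `validation_data=val_dataset, validation_steps=steps` to `Model.fit`, where `val_dataset` is a hypothetical repeated validation dataset batched with the global batch size.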

Alternatively, you can also create another task that periodically reads checkpoints and runs the evaluation. This is what Estimator does. But this is not a recommended way to perform evaluation and thus its details are omitted.

### Performance

To tweak the performance of multi-worker training, you can try the following:

- `tf.distribute.MultiWorkerMirroredStrategy` provides multiple [collective communication implementations](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/CommunicationImplementation):
    - `RING` implements ring-based collectives using gRPC as the cross-host communication layer.
    - `NCCL` uses the [NVIDIA Collective Communication Library](https://developer.nvidia.com/nccl) to implement collectives.
    - `AUTO` defers the choice to the runtime.

    The best choice of collective implementation depends on the number of GPUs, the type of GPUs, and the network interconnects in the cluster. To override the automatic choice, specify the `communication_options` parameter of `MultiWorkerMirroredStrategy`'s constructor. For example:

    ```python
    strategy = tf.distribute.MultiWorkerMirroredStrategy(
        communication_options=tf.distribute.experimental.CommunicationOptions(
            implementation=tf.distribute.experimental.CommunicationImplementation.NCCL))
    ```

- Cast the variables to a lower-precision floating-point type, such as `tf.float16`, if possible:
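    - The official ResNet model includes [an example](https://github.com/tensorflow/models/blob/8367cf6dabe11adf7628541706b660821f397dce/official/resnet/resnet_model.py#L466) of how to do this.

### Model saving and loading

To save your model using `Model.save` or `tf.saved_model.save`, each worker needs a different saving destination. The chief saves to the provided model directory, and each non-chief worker saves to its own temporary directory. The temporary directories must be unique per worker to prevent errors from multiple workers writing to the same location. All workers still need to save, because saving may aggregate variables, which requires every worker to participate in the collective communication. The model saved in every directory is identical, and typically only the chief's copy should be referenced for restoring or serving.

To determine whether the current worker is the chief, the code below uses the `task_type` and `task_id` attributes of the strategy's cluster resolver:
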
In [None]:
model_path = '/tmp/keras-model'

def _is_chief(task_type, task_id):
  # Note: there are two possible `TF_CONFIG` configurations.
  #   1) In addition to `worker` tasks, a `chief` task type is used;
  #      in this case, this function should be modified to
  #      `return task_type == 'chief'`.
  #   2) Only the `worker` task type is used; in this case, worker 0 is
  #      regarded as the chief. The implementation demonstrated here
  #      is for this case.
  # For the purpose of this Colab section, the `task_type is None` case
  # is added because it is effectively run with only a single worker.
  return (task_type == 'worker' and task_id == 0) or task_type is None

def _get_temp_dir(dirpath, task_id):
  base_dirpath = 'workertemp_' + str(task_id)
  temp_dir = os.path.join(dirpath, base_dirpath)
  tf.io.gfile.makedirs(temp_dir)
  return temp_dir

def write_filepath(filepath, task_type, task_id):
  dirpath = os.path.dirname(filepath)
  base = os.path.basename(filepath)
  if not _is_chief(task_type, task_id):
    dirpath = _get_temp_dir(dirpath, task_id)
  return os.path.join(dirpath, base)

task_type, task_id = (strategy.cluster_resolver.task_type,
                      strategy.cluster_resolver.task_id)
write_model_path = write_filepath(model_path, task_type, task_id)

With that, you're now ready to save:

In [None]:
multi_worker_model.save(write_model_path)

Later on, the model should only be loaded from the path the chief saved to, so remove the temporary copies that the non-chief workers saved:

In [None]:
if not _is_chief(task_type, task_id):
  tf.io.gfile.rmtree(os.path.dirname(write_model_path))
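The `_is_chief` helper used above hard-codes the workers-only layout. If you want one check that works for both layouts described in its comments, you could make the decision from the `TF_CONFIG` contents directly. The following is a plain-Python sketch; `is_chief_from_tf_config` and `make_tf_config_json` are hypothetical helpers and the `host0:2222`-style addresses are placeholders. In a TensorFlow program you would normally rely on `strategy.cluster_resolver`, as the code above does:

```python
import json


# Hypothetical helper (not a TensorFlow API).
def is_chief_from_tf_config(tf_config_json):
  """Returns True if this task should act as the chief.

  1) If the cluster has a dedicated 'chief' job, only that task is the chief.
  2) If the cluster has only 'worker' tasks, worker 0 is the chief.
  An empty TF_CONFIG (a single-process run) is treated as the chief.
  """
  if not tf_config_json:
    return True
  tf_config = json.loads(tf_config_json)
  task = tf_config.get('task', {})
  if 'chief' in tf_config.get('cluster', {}):
    return task.get('type') == 'chief'
  return task.get('type') == 'worker' and task.get('index') == 0


# Hypothetical helper that builds a TF_CONFIG string for the demo.
def make_tf_config_json(cluster, task_type, index):
  return json.dumps({'cluster': cluster,
                     'task': {'type': task_type, 'index': index}})


workers_only = {'worker': ['host0:2222', 'host1:2222']}
with_chief = {'chief': ['host0:2222'], 'worker': ['host1:2222']}

print(is_chief_from_tf_config(make_tf_config_json(workers_only, 'worker', 0)))
print(is_chief_from_tf_config(make_tf_config_json(workers_only, 'worker', 1)))
print(is_chief_from_tf_config(make_tf_config_json(with_chief, 'chief', 0)))
print(is_chief_from_tf_config(make_tf_config_json(with_chief, 'worker', 0)))
print(is_chief_from_tf_config(''))
```

In a worker script, you would call `is_chief_from_tf_config(os.environ.get('TF_CONFIG', ''))`.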

When it's time to load, use the convenient `tf.keras.models.load_model` API and continue with further work.

Here, assume that only a single worker loads the model and continues training. In that case, you do not call `tf.keras.models.load_model` within another `strategy.scope()` (note that `strategy = tf.distribute.MultiWorkerMirroredStrategy()`, as defined earlier):

In [None]:
loaded_model = tf.keras.models.load_model(model_path)

# Now that the model is restored, you can continue training.
loaded_model.fit(single_worker_dataset, epochs=2, steps_per_epoch=20)

### Checkpoint saving and restoring

On the other hand, checkpointing allows you to save your model's weights and restore them without having to save the whole model.

Here, you'll create one `tf.train.Checkpoint` that tracks the model, which is managed by the `tf.train.CheckpointManager`, so that only the latest checkpoint is preserved:

In [None]:
checkpoint_dir = '/tmp/ckpt'

checkpoint = tf.train.Checkpoint(model=multi_worker_model)
write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint, directory=write_checkpoint_dir, max_to_keep=1)

Once the `CheckpointManager` is set up, you're ready to save the checkpoint and then remove the checkpoints that the non-chief workers saved:

In [None]:
checkpoint_manager.save()
if not _is_chief(task_type, task_id):
  tf.io.gfile.rmtree(write_checkpoint_dir)

Now, when you need to restore the model, you can find the latest checkpoint saved using the convenient `tf.train.latest_checkpoint` function. After restoring the checkpoint, you can continue with training.

In [None]:
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint.restore(latest_checkpoint)
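multi_worker_model.fit(multi_worker_dataset, epochs=2, steps_per_epoch=20)

### BackupAndRestore callback

The `tf.keras.callbacks.BackupAndRestore` callback provides fault tolerance. At the end of each epoch, it backs up the training state (the model weights and the current epoch number) in a temporary checkpoint under `backup_dir`. If training is interrupted and restarted, the state is restored from the most recent backup at the start of the new `Model.fit` run, and the temporary checkpoint is deleted once `Model.fit` completes:
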
In [None]:
# Multi-worker training with `MultiWorkerMirroredStrategy`
# and the `BackupAndRestore` callback.

callbacks = [tf.keras.callbacks.BackupAndRestore(backup_dir='/tmp/backup')]
with strategy.scope():
  multi_worker_model = mnist_setup.build_and_compile_cnn_model()
multi_worker_model.fit(multi_worker_dataset,
                       epochs=3,
                       steps_per_epoch=70,
                       callbacks=callbacks)