# Distributed training with TensorFlow Distribute Strategy API on Amazon SageMaker

---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/training|distributed_training|tensorflow|multi_worker_mirrored_strategy|tensorflow_multi_worker_mirrored_strategy.ipynb)

---

[TensorFlow's Distributed Training API](https://www.tensorflow.org/guide/distributed_training) natively supports multiple strategies for distributed training. In this example, we use the [SageMaker Python SDK](https://github.com/aws/sagemaker-python-sdk) to run a distributed training job with a TensorFlow training script and the SageMaker Deep Learning Container (DLC) for TensorFlow training. We use the popular MNIST dataset to train a classifier based on a simple neural network architecture.

We will start with a non-distributed neural network MNIST training script and then adapt it to use distributed training.

## Set up the environment

Let's start by setting up the environment:

In [None]:
! pip install -U sagemaker

In [None]:
import os
import sagemaker
from sagemaker import get_execution_role

sagemaker_session = sagemaker.Session()

role = get_execution_role()
region = sagemaker_session.boto_session.region_name

## Training Data

We will use the MNIST dataset, which has already been uploaded to the public S3 bucket ``sagemaker-example-files-prod-<REGION>`` under the prefix ``datasets/image/MNIST``. There are four ``.npy`` files under this prefix:
* ``input_train.npy``
* ``input_test.npy``
* ``input_train_labels.npy``
* ``input_test_labels.npy``

In [None]:
training_data_uri = "s3://sagemaker-example-files-prod-{}/datasets/image/MNIST/numpy".format(region)

## Construct the training script

This tutorial's training script is based on a [SageMaker MNIST example](https://github.com/aws/amazon-sagemaker-examples/blob/main/sagemaker-python-sdk/tensorflow_script_mode_training_and_serving/mnist-2.py). Here is the entire script:

In [None]:
# TensorFlow script
!pygmentize 'mnist.py'

## Create a training job using the `TensorFlow` estimator

The `sagemaker.tensorflow.TensorFlow` estimator locates the training container based on the framework version and the job type (inference or training), uploads your script to an S3 location, and creates a SageMaker training job. Let's call out a couple of important parameters here:

* `framework_version` is set to `'2.13.0'` to indicate the TensorFlow version used to run the training code. SageMaker uses this value to select the DLC. Here's the list of the [available Deep Learning Container Images](https://github.com/aws/deep-learning-containers/blob/master/available_images.md).

* `entry_point` is the absolute or relative path to the local Python source file that should be executed as the entry point to training. 



In [None]:
from sagemaker.tensorflow import TensorFlow

local_mode = True

if local_mode:
    instance_type = "local_gpu"
    instance_count = 1
else:
    instance_type = "ml.g5.xlarge"
    instance_count = 1

mnist_estimator = TensorFlow(
    entry_point="mnist.py",
    role=role,
    instance_count=instance_count,
    instance_type=instance_type,
    framework_version="2.13.0",
    py_version="py310",
)

## Calling ``fit``

To start a training job, we call `estimator.fit(training_data_uri)`.

An S3 location is used here as the input. `fit` creates a default channel named `'training'`, which points to this S3 location. In the training script we can then access the training data from the location stored in `SM_CHANNEL_TRAINING`. `fit` accepts a few other types of input as well. See the [API documentation](https://sagemaker.readthedocs.io/en/stable/estimators.html#sagemaker.estimator.EstimatorBase.fit) for details.

When training starts, the TensorFlow container executes `mnist.py`, passing `hyperparameters` and `model_dir` from the estimator as script arguments. Because we didn't define either in this example, no hyperparameters are passed, and `model_dir` defaults to `s3://<DEFAULT_BUCKET>/<TRAINING_JOB_NAME>`, so the script execution is as follows:
```bash
python mnist.py --model_dir s3://<DEFAULT_BUCKET>/<TRAINING_JOB_NAME>
```
When training is complete, the training job will upload the saved model to Amazon S3.
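To make the hand-off between the container and the script concrete, here is a minimal sketch of how a training script might read these values. It assumes the script uses `argparse`; the function name `parse_args` and the `--train` argument are illustrative and not necessarily what `mnist.py` uses. `SM_MODEL_DIR` and `SM_CHANNEL_TRAINING` are environment variables that SageMaker sets inside the training container.

```python
import argparse
import os


def parse_args(argv=None):
    """Collect the values the container hands to the training script."""
    parser = argparse.ArgumentParser()
    # Passed by the container on the command line (see the command above).
    parser.add_argument("--model_dir", type=str, default=None)
    # Read from environment variables that SageMaker sets in the container.
    parser.add_argument("--sm-model-dir", type=str, default=os.environ.get("SM_MODEL_DIR"))
    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAINING"))
    # parse_known_args tolerates extra arguments the script does not declare.
    args, _ = parser.parse_known_args(argv)
    return args


# Stand-in values so the sketch also runs outside a training container.
os.environ.setdefault("SM_MODEL_DIR", "/opt/ml/model")
os.environ.setdefault("SM_CHANNEL_TRAINING", "/opt/ml/input/data/training")

args = parse_args(["--model_dir", "s3://<DEFAULT_BUCKET>/<TRAINING_JOB_NAME>"])
print(args.model_dir, args.sm_model_dir, args.train)
```

Reading the SageMaker-provided paths as argument defaults keeps the script runnable locally, where the same values can be supplied on the command line instead.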

Call `fit` to train the model with the TensorFlow script:

In [None]:
mnist_estimator.fit(training_data_uri)

## Adapt the training job and training script to use distributed training

In this section, we use an adapted training script that leverages TensorFlow distributed training. We use the [`MultiWorkerMirroredStrategy`](https://www.tensorflow.org/guide/distributed_training#multiworkermirroredstrategy), which performs synchronous distributed data parallelism across multiple workers.

MultiWorkerMirroredStrategy has two implementations for cross-device communications:

1. RING is RPC-based and supports both CPUs and GPUs.

2. NCCL uses the [NVIDIA Collective Communications Library (NCCL)](https://developer.nvidia.com/nccl), which provides state-of-the-art performance on GPUs but doesn't support CPUs.

The script excerpt in step 1 requests NCCL explicitly, which suits the GPU instances used in this example. To defer the choice to TensorFlow instead, pass `CommunicationImplementation.AUTO`.
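As a rough sketch of how the RING/NCCL choice could be made at run time, the helper below picks an implementation from the number of visible GPUs. The function names `pick_communication` and `build_strategy` are hypothetical and are not part of the training script; TensorFlow is imported lazily so the selection logic can be tried without it installed.

```python
def pick_communication(num_gpus):
    """Return the name of a CommunicationImplementation member for this host."""
    # NCCL only works with GPUs; RING works on both CPUs and GPUs.
    return "NCCL" if num_gpus > 0 else "RING"


def build_strategy():
    """Build a MultiWorkerMirroredStrategy using the implementation picked above."""
    # Imported here so pick_communication stays usable without TensorFlow.
    import tensorflow as tf

    num_gpus = len(tf.config.list_physical_devices("GPU"))
    implementation = getattr(
        tf.distribute.experimental.CommunicationImplementation,
        pick_communication(num_gpus),
    )
    options = tf.distribute.experimental.CommunicationOptions(implementation=implementation)
    return tf.distribute.MultiWorkerMirroredStrategy(communication_options=options)


print(pick_communication(0), pick_communication(4))  # RING NCCL
```

An explicit check like this mainly helps when the same script has to run on both CPU and GPU instance types.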

Here are the changes we implement in the script:
1. Instantiate the Multi-Worker Mirrored Strategy and the Communication Option

```python
communication_options = tf.distribute.experimental.CommunicationOptions(
    implementation=tf.distribute.experimental.CommunicationImplementation.NCCL)
strategy = tf.distribute.MultiWorkerMirroredStrategy(
    communication_options=communication_options)
```

2. Print the number of devices (replicas) involved in the distributed strategy

```python
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
```

3. In the `main` method, move the model definition and compilation inside the strategy scope context to ensure they are distributed across the defined devices

```python
with strategy.scope():
    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(1024, activation=tf.nn.relu),
            tf.keras.layers.Dropout(0.4),
            tf.keras.layers.Dense(10, activation=tf.nn.softmax),
        ]
    )

    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
```

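4. Save the model to the SageMaker model directory only on the chief worker

```python
# Worker 0 acts as the chief and writes the model to the SageMaker model directory.
if strategy.cluster_resolver.task_id == 0:
    print("Saving model on chief")
    mnist_classifier.save(os.path.join(args.sm_model_dir, "000000001"))
else:
    # The other workers also call save(), but write to a scratch path under /tmp.
    print("Saving model in /tmp on worker")
    mnist_classifier.save(f"/tmp/{strategy.cluster_resolver.task_id}")
```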
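The chief/non-chief branching in the last step can also be factored into small helpers, which makes the path logic easy to check without starting a cluster. This is only a sketch: `is_chief` and `save_path` are hypothetical names, and treating a missing task type as chief is an assumption meant to cover single-worker runs.

```python
import os


def is_chief(task_type, task_id):
    """Return True for the worker responsible for writing the final model."""
    # No task type means no cluster is configured, so the lone process is the chief.
    if task_type is None or task_type == "chief":
        return True
    return task_type == "worker" and task_id == 0


def save_path(model_dir, task_type, task_id):
    """Return the model directory for the chief and a scratch path for other workers."""
    if is_chief(task_type, task_id):
        return os.path.join(model_dir, "000000001")
    return os.path.join("/tmp", str(task_id))


for task_id in range(2):
    print(task_id, save_path("/opt/ml/model", "worker", task_id))
```

In the script, the task type and index would come from `strategy.cluster_resolver.task_type` and `strategy.cluster_resolver.task_id`.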

---

Here is the entire script:

In [None]:
# TensorFlow script
!pygmentize 'mnist-distributed.py'

Now, we modify the `sagemaker.tensorflow.TensorFlow` estimator by changing the `entry_point` to the new script and adding a distribution strategy.

To enable [`MultiWorkerMirroredStrategy`](https://www.tensorflow.org/guide/distributed_training#multiworkermirroredstrategy) we use the following configuration:

```python
{
    "multi_worker_mirrored_strategy": {
        "enabled": True
    }
}
```

This distribution strategy option is available for TensorFlow 2.9 and later in the SageMaker Python SDK v2.xx.yy and later.
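TensorFlow's multi-worker strategies discover the cluster through the `TF_CONFIG` environment variable. Assuming the SageMaker container populates it when this option is enabled (worth confirming in the training job logs), each worker would see something like the hypothetical value below. The host names and port are illustrative, and `describe_task` is not part of the training script.

```python
import json
import os

# Hypothetical TF_CONFIG for worker 0 of a two-instance job (illustrative hosts and port).
example_tf_config = {
    "cluster": {"worker": ["algo-1:2222", "algo-2:2222"]},
    "task": {"type": "worker", "index": 0},
}


def describe_task(env=os.environ):
    """Summarize the cluster layout described by TF_CONFIG, if it is present."""
    tf_config = json.loads(env.get("TF_CONFIG", "{}"))
    workers = tf_config.get("cluster", {}).get("worker", [])
    task = tf_config.get("task", {})
    return {
        "num_workers": len(workers),
        "type": task.get("type"),
        "index": task.get("index"),
    }


# Pass a stand-in environment so the real process environment is left untouched.
print(describe_task({"TF_CONFIG": json.dumps(example_tf_config)}))
```

Logging this summary at the top of the training script is a quick way to verify that every instance joined the cluster with the expected index.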

In [None]:
local_mode = False

if local_mode:
    instance_type = "local_gpu"
    instance_count = 1
else:
    instance_type = "ml.g5.24xlarge"
    instance_count = 2

mnist_estimator_distributed = TensorFlow(
    entry_point="mnist-distributed.py",
    role=role,
    instance_count=instance_count,
    instance_type=instance_type,
    framework_version="2.13.0",
    py_version="py310",
    distribution={"multi_worker_mirrored_strategy": {"enabled": True}},
)

Call `fit` to start the distributed training job:

In [None]:
mnist_estimator_distributed.fit(training_data_uri)
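With synchronous data parallelism, every replica processes its own batch in each step, so the effective (global) batch size grows with the number of replicas. Here is a minimal sketch of the arithmetic for the non-local configuration above. It assumes 4 GPUs per `ml.g5.24xlarge` instance and a hypothetical per-replica batch size of 64; neither value is taken from the training script.

```python
def global_batch_size(per_replica_batch_size, num_replicas_in_sync):
    """Return the number of samples consumed per training step across all replicas."""
    return per_replica_batch_size * num_replicas_in_sync


instance_count = 2
gpus_per_instance = 4  # assumption for ml.g5.24xlarge
num_replicas = instance_count * gpus_per_instance

print(num_replicas, global_batch_size(64, num_replicas))  # 8 512
```

Inside the training script, the replica count is available as `strategy.num_replicas_in_sync`, so the global batch size can be derived there rather than hard-coded.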

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.


![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/training|distributed_training|tensorflow|multi_worker_mirrored_strategy|tensorflow_multi_worker_mirrored_strategy.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/training|distributed_training|tensorflow|multi_worker_mirrored_strategy|tensorflow_multi_worker_mirrored_strategy.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/training|distributed_training|tensorflow|multi_worker_mirrored_strategy|tensorflow_multi_worker_mirrored_strategy.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/training|distributed_training|tensorflow|multi_worker_mirrored_strategy|tensorflow_multi_worker_mirrored_strategy.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/training|distributed_training|tensorflow|multi_worker_mirrored_strategy|tensorflow_multi_worker_mirrored_strategy.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/training|distributed_training|tensorflow|multi_worker_mirrored_strategy|tensorflow_multi_worker_mirrored_strategy.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/training|distributed_training|tensorflow|multi_worker_mirrored_strategy|tensorflow_multi_worker_mirrored_strategy.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/training|distributed_training|tensorflow|multi_worker_mirrored_strategy|tensorflow_multi_worker_mirrored_strategy.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/training|distributed_training|tensorflow|multi_worker_mirrored_strategy|tensorflow_multi_worker_mirrored_strategy.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/training|distributed_training|tensorflow|multi_worker_mirrored_strategy|tensorflow_multi_worker_mirrored_strategy.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/training|distributed_training|tensorflow|multi_worker_mirrored_strategy|tensorflow_multi_worker_mirrored_strategy.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/training|distributed_training|tensorflow|multi_worker_mirrored_strategy|tensorflow_multi_worker_mirrored_strategy.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/training|distributed_training|tensorflow|multi_worker_mirrored_strategy|tensorflow_multi_worker_mirrored_strategy.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/training|distributed_training|tensorflow|multi_worker_mirrored_strategy|tensorflow_multi_worker_mirrored_strategy.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/training|distributed_training|tensorflow|multi_worker_mirrored_strategy|tensorflow_multi_worker_mirrored_strategy.ipynb)
