# Horovod Distributed Training with SageMaker TensorFlow script mode

Horovod is a distributed training framework based on Message Passing Interface (MPI). For information about Horovod, see [Horovod README](https://github.com/uber/horovod).

You can perform distributed training with Horovod on SageMaker by using the SageMaker TensorFlow container. If MPI is enabled when you create the training job, SageMaker sets up the MPI environment and runs the `mpirun` command to launch the training script. How to configure MPI settings for a training job is described later in this example.

In this example notebook, we create a Horovod training job that uses the MNIST data set.

## Set up the environment

We get the `IAM` role that this notebook is running as and pass that role to the TensorFlow estimator that SageMaker uses to get data and perform training.

In [32]:
import sagemaker
import os
from sagemaker.utils import sagemaker_timestamp
from sagemaker.tensorflow import TensorFlow
from sagemaker import get_execution_role
import time

sagemaker_session = sagemaker.Session()

default_s3_bucket = sagemaker_session.default_bucket()
sagemaker_iam_role = get_execution_role()

train_script = "mnist_hvd.py"

## Prepare Data for training

Now we download the MNIST dataset to the local `/tmp/data/` directory and then upload it to an S3 bucket. After uploading the dataset to S3, we delete the data from `/tmp/data/`. 

In [33]:
import os
import shutil

import numpy as np

import keras
from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

s3_train_path = "s3://{}/mnist/train.npz".format(default_s3_bucket)
s3_test_path = "s3://{}/mnist/test.npz".format(default_s3_bucket)

# Create local directory
! mkdir -p /tmp/data/mnist_train
! mkdir -p /tmp/data/mnist_test

# Save data locally
np.savez('/tmp/data/mnist_train/train.npz', data=x_train, labels=y_train)
np.savez('/tmp/data/mnist_test/test.npz', data=x_test, labels=y_test)

# Upload the dataset to s3
! aws s3 cp /tmp/data/mnist_train/train.npz $s3_train_path
! aws s3 cp /tmp/data/mnist_test/test.npz $s3_test_path

print('training data at ', s3_train_path)
print('test data at ', s3_test_path)
! rm -rf /tmp/data

upload: ../../../../../tmp/data/mnist_train/train.npz to s3://sagemaker-us-east-1-328296961357/mnist/train.npz
upload: ../../../../../tmp/data/mnist_test/test.npz to s3://sagemaker-us-east-1-328296961357/mnist/test.npz
training data at  s3://sagemaker-us-east-1-328296961357/mnist/train.npz
test data at  s3://sagemaker-us-east-1-328296961357/mnist/test.npz


## Write a script for Horovod distributed training

This example is based on the [Keras MNIST Horovod example](https://github.com/uber/horovod/blob/master/examples/keras_mnist.py) in the Horovod GitHub repository.

To run this script, we have to make the following modifications:

### 1. Accept `--model_dir` as a command-line argument
Modify the script to accept `model_dir` as a command-line argument that defines the directory path (for example, `/opt/ml/model/`) where the output model is saved. SageMaker deletes the training cluster when training completes, so anything left only on the cluster is lost. When the training job completes, SageMaker writes the data stored in `/opt/ml/model/` to an S3 bucket, so saving the model to that directory preserves it.

This also allows the SageMaker training job to integrate with other SageMaker services, such as hosted inference endpoints or batch transform jobs. It also allows you to host the trained model outside of SageMaker.

The following code adds `model_dir` as a command-line argument to the script:

```
parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', type=str)
```

For more details, see the [SageMaker Containers README](https://github.com/aws/sagemaker-containers/blob/master/README.md).
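
As a slightly fuller sketch, the parser below falls back to the `SM_MODEL_DIR` environment variable when `--model_dir` is not passed. The `parse_args` helper, the fallback, and the use of `parse_known_args` are assumptions made for illustration and are not taken from `mnist_hvd.py`.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse the command-line arguments passed to the training script."""
    parser = argparse.ArgumentParser()
    # Assumption: fall back to SM_MODEL_DIR, then to /opt/ml/model,
    # when the flag is not given on the command line.
    parser.add_argument(
        "--model_dir",
        type=str,
        default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"),
    )
    # parse_known_args tolerates extra arguments the script does not declare.
    args, _unknown = parser.parse_known_args(argv)
    return args


args = parse_args(["--model_dir", "/tmp/model", "--epochs", "3"])
print(args.model_dir)
```

Passing `argv=None` makes `argparse` read `sys.argv`, which is what a training script would normally do.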

### 2. Load train and test data

You can get the local directory paths where the `train` and `test` data are downloaded by reading the environment variables `SM_CHANNEL_TRAIN` and `SM_CHANNEL_TEST`, respectively.
After you get the directory paths, load the data into memory.

Here is the code:

```
x_train = np.load(os.path.join(os.environ['SM_CHANNEL_TRAIN'], 'train.npz'))['data']
y_train = np.load(os.path.join(os.environ['SM_CHANNEL_TRAIN'], 'train.npz'))['labels']

x_test = np.load(os.path.join(os.environ['SM_CHANNEL_TEST'], 'test.npz'))['data']
y_test = np.load(os.path.join(os.environ['SM_CHANNEL_TEST'], 'test.npz'))['labels']
```

For a list of all environment variables set by SageMaker that are accessible inside a training script, see [SageMaker Containers](https://github.com/aws/sagemaker-containers/blob/master/README.md).
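
The snippet above opens each `.npz` archive twice. A minimal sketch of a helper that opens each archive once is shown below. The `load_channel` name is hypothetical, and a temporary directory stands in for a SageMaker channel directory so that the sketch runs outside a training job.

```python
import os
import tempfile

import numpy as np


def load_channel(channel_dir, filename):
    """Return the 'data' and 'labels' arrays from one .npz file in a channel directory."""
    # Open the archive once and read both arrays before it is closed.
    with np.load(os.path.join(channel_dir, filename)) as archive:
        return archive["data"], archive["labels"]


# Stand-in for a channel directory such as os.environ['SM_CHANNEL_TRAIN'].
demo_dir = tempfile.mkdtemp()
np.savez(
    os.path.join(demo_dir, "train.npz"),
    data=np.zeros((4, 28, 28), dtype=np.uint8),
    labels=np.arange(4),
)

x_demo, y_demo = load_channel(demo_dir, "train.npz")
print(x_demo.shape, y_demo.shape)
```

Inside the training script, the equivalent call would be `load_channel(os.environ['SM_CHANNEL_TRAIN'], 'train.npz')`.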

### 3. Save the model only at the master node

Because in Horovod the training is distributed to multiple nodes, the model should only be saved by the master node. The following code in the script does this:

```
# Horovod: Save model only on worker 0 (i.e. master)
if hvd.rank() == 0:
    saved_model_path = tf.contrib.saved_model.save_keras_model(model, args.model_dir)
```
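
The rank check can be illustrated without Horovod installed. The sketch below uses a hypothetical `save_on_rank_zero` helper that receives the rank as an argument. In the training script the rank would come from `hvd.rank()`, and the save function would write the model to `args.model_dir`.

```python
def save_on_rank_zero(rank, save_fn):
    """Call save_fn only on rank 0 and return its result; other ranks return None."""
    if rank == 0:
        return save_fn()
    return None


# Simulate four workers; only rank 0 performs the stand-in save.
results = [save_on_rank_zero(rank, lambda: "/opt/ml/model") for rank in range(4)]
print(results)
```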

### Training script

Here is the final training script.

In [34]:
!pygmentize 'mnist_hvd.py'

/bin/sh: pigmentize: command not found


## Test locally using SageMaker Python SDK TensorFlow Estimator

You can use the SageMaker Python SDK TensorFlow estimator to easily train locally and in SageMaker.

This notebook shows how to use the SageMaker Python SDK to run your code in a local container before deploying to SageMaker's managed training or hosting environments. Just change your estimator's `train_instance_type` to `local` or `local_gpu`. For more information, see: https://github.com/aws/sagemaker-python-sdk#local-mode.

To use this feature, you need to install docker-compose (and nvidia-docker if you are training with a GPU). Run the following script to install docker-compose or nvidia-docker-compose and to configure the notebook environment.

**Note**: You can only run a single local notebook at a time.

In [35]:
!/bin/bash ./setup.sh

/bin/bash: ./setup.sh: No such file or directory


To train locally, set `train_instance_type` to `local`:

In [36]:
train_instance_type='local'
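
If the same notebook may run on both CPU and GPU hosts, the choice between `local` and `local_gpu` could be made at run time. The sketch below is an assumption-laden illustration: the `pick_local_instance_type` helper is hypothetical, and it treats the presence of the `nvidia-smi` tool on the `PATH` as a sign that a GPU is available.

```python
import shutil


def pick_local_instance_type():
    """Return 'local_gpu' if nvidia-smi is on the PATH, otherwise 'local'."""
    # Assumption: a GPU host has the nvidia-smi command installed.
    return "local_gpu" if shutil.which("nvidia-smi") else "local"


print(pick_local_instance_type())
```

The result could then be assigned to `train_instance_type` instead of the hard-coded value.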

The MPI environment for Horovod can be configured by setting the following flags in the `mpi` field of the `distributions` dictionary that you pass to the TensorFlow estimator:

* ``enabled (bool)``: If set to ``True``, the MPI setup is performed and the ``mpirun`` command is executed.
* ``processes_per_host (int) [Optional]``: Number of processes MPI should launch on each host. This should not be greater than the available slots on the selected instance type. Set this flag for multi-CPU or multi-GPU training.
* ``custom_mpi_options (str) [Optional]``: Any ``mpirun`` flags passed in this field are added to the ``mpirun`` command that SageMaker executes to launch distributed Horovod training.

For more information about the `distributions` dictionary, see the SageMaker Python SDK [README](https://github.com/aws/sagemaker-python-sdk/blob/v1.17.3/src/sagemaker/tensorflow/README.rst).
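
The three flags above can be assembled with a small helper. This is only a sketch: `build_mpi_distributions` is a hypothetical name, not part of the SageMaker Python SDK, and it produces the same dictionary shape that this notebook writes by hand.

```python
def build_mpi_distributions(processes_per_host=None, custom_mpi_options=None):
    """Build a dictionary with MPI enabled and the optional flags set."""
    mpi = {"enabled": True}
    if processes_per_host is not None:
        if processes_per_host < 1:
            raise ValueError("processes_per_host must be at least 1")
        mpi["processes_per_host"] = processes_per_host
    if custom_mpi_options:
        mpi["custom_mpi_options"] = custom_mpi_options
    return {"mpi": mpi}


print(build_mpi_distributions())
print(build_mpi_distributions(processes_per_host=2, custom_mpi_options="-verbose"))
```

Omitting the optional keys when they are not set keeps the dictionary identical to the minimal `{'mpi': {'enabled': True}}` form.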

First, enable MPI:

In [37]:
distributions = {'mpi': {'enabled': True}}

Now, we create the TensorFlow estimator, passing the `train_instance_type` and `distributions`:

In [38]:
estimator_local = TensorFlow(entry_point=train_script,
                       role=sagemaker_iam_role,
                       train_instance_count=1,  # instance_count is not defined yet at this point
                       train_instance_type=train_instance_type,
                       script_mode=True,
                       framework_version='1.12',
                       distributions=distributions,
                       base_job_name='hvd-mnist-local')

tensorflow py2 container will be deprecated soon.


Call `fit()` to start the local training 

In [39]:
#%%time
#To save time and skip 'local' mode, leave the next line commented out
#estimator_local.fit({"train":s3_train_path, "test":s3_test_path})

## Train in SageMaker

After you test the training job locally, run it on SageMaker:

First, change the instance type from `local` to a valid SageMaker instance type, for example `ml.c4.xlarge`.

In [40]:
#train_instance_type='ml.p2.xlarge' #1 K80 GPU
#train_instance_type='ml.p2.8xlarge' #8 K80 GPU
#train_instance_type='ml.c4.xlarge' #4 vCPU
#train_instance_type='ml.c4.4xlarge' #16 vCPU
train_instance_type='ml.c5.4xlarge' #16 vCPU
instance_count = 2

You can also provide custom MPI options in the `custom_mpi_options` field of the `distributions` dictionary. SageMaker adds them to the `mpirun` command that it executes:

In [41]:
distributions = {'mpi': {'enabled': True, "custom_mpi_options": "-verbose --NCCL_DEBUG=INFO"}}

Now, we create the TensorFlow estimator, passing the `train_instance_type` and `distributions`, to launch the training job in SageMaker.

In [42]:
estimator = TensorFlow(entry_point=train_script,
                       role=sagemaker_iam_role,
                       train_instance_count=instance_count,
                       train_instance_type=train_instance_type,
                       script_mode=True,
                       framework_version='1.12',
                       distributions=distributions,
                       base_job_name='hvd-mnist')

tensorflow py2 container will be deprecated soon.


Call `fit()` to start the training

In [43]:
%%time
print( "instance_type:", train_instance_type, "instance_count:", instance_count, "processes_per_host: 1")
estimator.fit({"train":s3_train_path, "test":s3_test_path})

## Horovod training in SageMaker using multiple CPU/GPU

To enable multiple CPUs or GPUs for Horovod training, set the `processes_per_host` field in the `mpi` section of the `distributions` dictionary to the number of processes to run on each instance.

In [44]:
instance_count = 2
processes_per_host = 2
print( "instance_type:", train_instance_type, "instance_count:", instance_count, "processes_per_host:", processes_per_host)
distributions = {'mpi': {'enabled': True, 
                         "custom_mpi_options": "-verbose --NCCL_DEBUG=INFO -x OMPI_MCA_btl_vader_single_copy_mechanism=none", 
                         "processes_per_host": processes_per_host}}

4 instance_count:  4 processes_per_host:  4
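
The total number of Horovod processes is the instance count multiplied by `processes_per_host`. The sketch below makes that arithmetic explicit. The `horovod_world_size` helper and its `slots_per_instance` argument are hypothetical, and the caller is assumed to supply the slot count (for example, the vCPU or GPU count of the chosen instance type).

```python
def horovod_world_size(instance_count, processes_per_host, slots_per_instance=None):
    """Return the total number of Horovod processes across all instances."""
    if instance_count < 1 or processes_per_host < 1:
        raise ValueError("instance_count and processes_per_host must be at least 1")
    # Hypothetical check: reject more processes than the caller-supplied slot count.
    if slots_per_instance is not None and processes_per_host > slots_per_instance:
        raise ValueError("processes_per_host exceeds the available slots")
    return instance_count * processes_per_host


print(horovod_world_size(2, 2))
print(horovod_world_size(4, 4, slots_per_instance=16))
```

With 4 instances and 4 processes per host, the result is 16, which is consistent with the rank tags `[1,0]` through `[1,15]` in the recorded training log.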


Now, we create the TensorFlow estimator, passing the `train_instance_type` and `distributions`:

In [45]:
estimator = TensorFlow(entry_point=train_script,
                       role=sagemaker_iam_role,
                       train_instance_count=instance_count,
                       train_instance_type=train_instance_type,
                       script_mode=True,
                       framework_version='1.12',
                       distributions=distributions,
                       base_job_name='hvd-mnist-multi-cpu')

tensorflow py2 container will be deprecated soon.


Call `fit()` to start the training

In [46]:
%%time
estimator.fit({"train":s3_train_path, "test":s3_test_path})

2019-06-06 16:41:19 Starting - Starting the training job...
2019-06-06 16:41:20 Starting - Launching requested ML instances......
2019-06-06 16:42:31 Starting - Preparing the instances for training............
2019-06-06 16:44:45 Downloading - Downloading input data
2019-06-06 16:44:45 Training - Downloading the training image.
2019-06-06 16:44:51,827 sagemaker-containers INFO     Imported framework sagemaker_tensorflow_container.training
2019-06-06 16:44:51,833 sagemaker-containers INFO     No GPUs detected (normal if no gpus installed)
2019-06-06 16:44:52,144 sagemaker-containers INFO     Imported framework sagemaker_tensorflow_container.training
2019-06-06 16:44:52,151 sagemaker-containers INFO     No GPUs detected (normal if no gpus installed)
2019-06-06 16:44:52,395 sagemaker-containers INFO     No GPUs detected (normal if no gpus installed)
2019-06-06 16:44:52,405 sagemaker-containers INFO     Starting MPI run as worker node.



2019-06-06 16:44:54 Training - Training image download completed. Training in progress.
[1,9]<stdout>:  128/60000 [..............................] - ETA: 23:22 - loss: 2.3010 - acc: 0.1250
[1,8]<stdout>:  128/60000 [..............................] - ETA: 23:22 - loss: 2.3074 - acc: 0.0625
[1,10]<stdout>:  128/60000 [..............................] - ETA: 23:22 - loss: 2.3099 - acc: 0.1016
[1,14]<stdout>:  128/60000 [..............................] - ETA: 23:21 - loss: 2.3198 - acc: 0.0703
[1,11]<stdout>:  128/60000 [..............................] - ETA: 23:21 - loss: 2.3214 - acc: 0.0391
[1,12]<stdout>:  128/60000 [..............................] - ETA: 23:21 - loss: 2.3087 - acc: 0.1094
[1,15]<stdout>:  128/60000 [..............................] - ETA: 23:21 - loss: 2.3185 - acc: 0.0625
[1,13]<stdout>:  128/60000 [..............................] - ETA: 23:21 - loss: 2.3108 - acc: 0.0859
[1,10]<stdout>: 2176/60000 [>.............................] - ETA: 7:29 - loss: 2.5796 - acc: 0.2160
[1,0]<stdout>: 2816/60000 [>.............................] - ETA: 7:11 - loss: 2.6113 - acc: 0.2319
[1,14]<stdout>: 4224/60000 [=>............................] - ETA: 6:46 - loss: 2.3360 - acc: 0.3106
[1,8]<stdout>: 5632/60000 [=>............................] - ETA: 6:33 - loss: 2.0046 - acc: 0.3983
[1,3]<stdout>: 5632/60000 [=>............................] - ETA: 6:33 - loss: 1.9954 - acc: 0.4203
[1,0]<stdout>: 7680/60000 [==>...........................] - ETA: 6:15 - loss: 1.5550 - acc: 0.5320
[1,13]<stdout>: 7680/60000 [==>...........................] - ETA: 6:15 - loss: 1.5601 - acc: 0.5309
[1,6]<stdout>: 9088/60000 [===>..........................] - ETA: 6:02 - loss: 1.3762 - acc: 0.5924
[1,5]<stdout>:11136/60000 [====>.........................] - ETA: 5:46 - loss: 1.1810 - acc: 0.6450
[1,1]<stdout>:11776/60000 [====>.........................] - ETA: 5:40 - loss: 1.1301 - acc: 0.6616

[1,3]<stdout>:Test loss: 0.15149432390481232
[1,3]<stdout>:Test accuracy: 0.9499
[1,1]<stdout>:Test loss: 0.15149432390481232
[1,1]<stdout>:Test accuracy: 0.9499
[1,0]<stdout>:Test loss: 0.15149432390481232
[1,0]<stdout>:Test accuracy: 0.9499
[1,2]<stdout>:Test loss: 0.15149432390481232
[1,2]<stdout>:Test accuracy: 0.9499
[1,0]<stderr>:Consider using a TensorFlow optimizer from `tf.train`.
[1,0]<stdout>:Model successfully saved at: /opt/ml/model/temp-1559839932
[1,9]<stdout>:Test loss: 0.15149432390481232
[1,9]<stdout>:Test accuracy: 0.9499
[1,6]<stdout>:Test loss: 0.15149432390481232
[1,6]<stdout>:Test accuracy: 0.9499
[1,4]<stdout>:Test loss: 0.15149432390481232
[1,4]<stdout>:Test accuracy: 0.9499
[1,7]<stdout>:Test loss: 0.15149432390481232
[1,7]<stdout>:Test accuracy: 0.9499
[1,10]<stdout>:Test loss: 0.15149432390481232
[1,10]<stdout>:Test accuracy: 0.9499
[1,15]<stdout>:Test loss: 0.15149432390481232
[1,15]<stdout>:Test accuracy: 0.9499
[1,5]<stdout>:Test loss: 0.15149432390481232
[1,5]<stdout>:Test accuracy: 0.9499
[1,11]<stdout>:Test loss: 0.15149432390481232
[1,11]<stdout>:Test accuracy: 0.9499
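
Each per-process line in the log output above carries a tag of the form `[1,<rank>]<stdout>:`, where the second number appears to be the Horovod rank. As a rough sketch, the hypothetical `accuracy_by_rank` helper below collects the reported test accuracy for each rank from such lines. The tag pattern is inferred from the recorded output, not from any documented format guarantee.

```python
import re

# Matches tagged lines such as "[1,3]<stdout>:Test accuracy: 0.9499".
TAGGED_LINE = re.compile(r"\[(\d+),(\d+)\]<(stdout|stderr)>:(.*)")
ACCURACY = re.compile(r"Test accuracy:\s*([0-9.]+)")


def accuracy_by_rank(log_lines):
    """Map each rank to the test accuracy it reported, if any."""
    accuracies = {}
    for line in log_lines:
        tagged = TAGGED_LINE.search(line)
        if not tagged:
            continue
        rank, message = int(tagged.group(2)), tagged.group(4)
        metric = ACCURACY.search(message)
        if metric:
            accuracies[rank] = float(metric.group(1))
    return accuracies


sample = [
    "[1,1]<stdout>:Test loss: 0.15149432390481232",
    "[1,1]<stdout>:Test accuracy: 0.9499",
    "[1,0]<stdout>:Test accuracy: 0.9499",
]
print(accuracy_by_rank(sample))
```

All ranks reporting the same accuracy, as in this run, is what one would expect when every worker evaluates the same synchronized model on the same test set.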

## Improving Horovod training performance on SageMaker

Performing Horovod training inside a VPC reduces the network latency between nodes, which can improve the performance and stability of Horovod training jobs.

For a detailed explanation of how to configure a VPC for SageMaker training, see [Secure Training and Inference with VPC](https://github.com/aws/sagemaker-python-sdk#secure-training-and-inference-with-vpc).

### Set up VPC infrastructure
We will set up the following resources as part of the VPC stack:
* `VPC`: AWS Virtual Private Cloud with a CIDR block.
* `Subnets`: Two subnets with the CIDR blocks `10.0.0.0/24` and `10.0.1.0/24`.
* `Security Group`: Defines the open ingress and egress ports, such as TCP.
* `VpcEndpoint`: S3 VPC endpoint that allows SageMaker's VPC cluster to download data from S3.
* `Route Table`: Defines routes and is tied to the subnets and the VPC.

The complete CloudFormation template for setting up the VPC stack is available [here](./vpc_infra_cfn.json).

In [47]:
print(sagemaker_iam_role)
import boto3
from botocore.exceptions import ClientError
from time import sleep

def create_vpc_infra(stack_name="hvdvpcstack"):
    """Create the VPC stack if it does not exist and return its subnets and security groups."""
    cfn = boto3.client("cloudformation")

    # Read the template with a context manager so the file handle is closed.
    with open("vpc_infra_cfn.json", "r") as template_file:
        cfn_template = template_file.read()

    try:
        cfn.create_stack(StackName=stack_name, TemplateBody=cfn_template)
    except ClientError as e:
        if e.response['Error']['Code'] == 'AlreadyExistsException':
            print("Stack: {} already exists, so skipping stack creation.".format(stack_name))
        else:
            print("Unexpected error: %s" % e)
            raise

    describe_stack = cfn.describe_stacks(StackName=stack_name)["Stacks"][0]

    # Poll until CloudFormation leaves the CREATE_IN_PROGRESS state.
    while describe_stack["StackStatus"] == "CREATE_IN_PROGRESS":
        sleep(5)
        describe_stack = cfn.describe_stacks(StackName=stack_name)["Stacks"][0]

    if describe_stack["StackStatus"] != "CREATE_COMPLETE":
        raise ValueError("Stack creation failed in state: {}".format(describe_stack["StackStatus"]))

    print("Stack: {} created successfully with status: {}".format(stack_name, describe_stack["StackStatus"]))

    subnets = []
    security_groups = []

    # Collect the subnet and security group IDs from the stack outputs.
    for output_field in describe_stack["Outputs"]:
        if output_field["OutputKey"] == "SecurityGroupId":
            security_groups.append(output_field["OutputValue"])
        if output_field["OutputKey"] in ("Subnet1Id", "Subnet2Id"):
            subnets.append(output_field["OutputValue"])

    return subnets, security_groups


subnets, security_groups = create_vpc_infra()
print("Subnets: {}".format(subnets))
print("Security Groups: {}".format(security_groups))

arn:aws:iam::328296961357:role/service-role/AmazonSageMaker-ExecutionRole-20190430T072462
Stack: hvdvpcstack already exists, so skipping stack creation.
Stack: hvdvpcstack created successfully with status: CREATE_COMPLETE
Subnets: ['subnet-06ebce2df1c3f0303', 'subnet-0bdaf2d5c019f21c9']
Security Groups: ['sg-01f003850777a9140']
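
As an alternative to a hand-written polling loop, boto3 provides a `stack_create_complete` waiter for CloudFormation clients. The sketch below is a hedged illustration of that approach. The `wait_for_stack` helper and its default delay and attempt count are assumptions, and the client is passed in as an argument so that the function itself does not need AWS credentials to be defined.

```python
def wait_for_stack(cfn, stack_name, delay_seconds=5, max_attempts=120):
    """Wait for the stack to reach CREATE_COMPLETE and return its outputs as a dict."""
    # The waiter polls describe_stacks and raises an error on failure states.
    waiter = cfn.get_waiter("stack_create_complete")
    waiter.wait(
        StackName=stack_name,
        WaiterConfig={"Delay": delay_seconds, "MaxAttempts": max_attempts},
    )
    stack = cfn.describe_stacks(StackName=stack_name)["Stacks"][0]
    # Flatten the output list into {OutputKey: OutputValue}.
    return {o["OutputKey"]: o["OutputValue"] for o in stack.get("Outputs", [])}


# Example usage (requires AWS credentials, so it is left commented out):
# import boto3
# outputs = wait_for_stack(boto3.client("cloudformation"), "hvdvpcstack")
# subnets = [outputs["Subnet1Id"], outputs["Subnet2Id"]]
# security_groups = [outputs["SecurityGroupId"]]
```

Returning the outputs as a dictionary makes the lookup by `OutputKey` explicit and avoids looping over the output list at each call site.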


### VPC training in SageMaker
Now, we create the TensorFlow estimator, passing the `train_instance_type` and `distributions`.

In [48]:
estimator = TensorFlow(entry_point=train_script,
                       role=sagemaker_iam_role,
                       train_instance_count=instance_count,
                       train_instance_type=train_instance_type,
                       script_mode=True,
                       framework_version='1.12',
                       distributions=distributions,
                       security_group_ids=security_groups,
                       subnets=subnets,
                       base_job_name='hvd-mnist-vpc')

tensorflow py2 container will be deprecated soon.


Call `fit()` to start the training

In [49]:
%%time
estimator.fit({"train":s3_train_path, "test":s3_test_path})

2019-06-06 16:53:21 Starting - Starting the training job...
2019-06-06 16:53:23 Starting - Launching requested ML instances......
2019-06-06 16:54:32 Starting - Preparing the instances for training......
2019-06-06 16:55:49 Downloading - Downloading input data...
2019-06-06 16:56:12 Training - Training image download completed. Training in progress.
2019-06-06 16:56:14,141 sagemaker-containers INFO     Imported framework sagemaker_tensorflow_container.training
2019-06-06 16:56:14,147 sagemaker-containers INFO     No GPUs detected (normal if no gpus installed)
2019-06-06 16:56:14,558 sagemaker-containers INFO     No GPUs detected (normal if no gpus installed)
2019-06-06 16:56:14,568 sagemaker-containers INFO     Starting MPI run as worker node.
2019-06-06 16:56:14,568 sagemaker-containers INFO     Creating SSH daemon.
2019-06-06 16:56:14,574 sagemaker-containers INFO     Waiting for MPI workers to establish their SSH connections

2019-06-06 16:56:14,340 sagemaker-containers INFO     Imported framework sagemaker_tensorflow_container.training
2019-06-06 16:56:14,346 sagemaker-containers INFO     No GPUs detected (normal if no gpus installed)
2019-06-06 16:56:14,819 sagemaker-containers INFO     No GPUs detected (normal if no gpus installed)
2019-06-06 16:56:14,830 sagemaker-containers INFO     Starting MPI run as worker node.
2019-06-06 16:56:14,830 sagemaker-containers INFO     Waiting for MPI Master to create SSH daemon.
2019-06-06 16:56:14,836 paramiko.transport INFO     Connected (version 2.0, client OpenSSH_7.2p2)
  m.add_string(self.Q_C.public_numbers().encode_point())
  self.curve, Q_S_bytes
  hm.add_string(self.Q_C.public_numbers().encode_point())
2019-06-06 16:56:14,985 paramiko.transport INFO     Authentication (publickey) successful!
2019-06-06 16:56:15,029 sagemaker-containers INFO     Can connect to host algo-1

[1,8]<stdout>: 2176/60000 [>.............................] - ETA: 7:14 - loss: 2.2703 - acc: 0.2707
[1,15]<stdout>: 2944/60000 [>.............................] - ETA: 6:52 - loss: 2.1598 - acc: 0.3240
[1,14]<stdout>: 4224/60000 [=>............................] - ETA: 6:31 - loss: 1.7878 - acc: 0.4451
[1,0]<stdout>: 4224/60000 [=>............................] - ETA: 6:31 - loss: 1.7839 - acc: 0.4344
[1,10]<stdout>: 7040/60000 [==>...........................] - ETA: 5:58 - loss: 1.2020 - acc: 0.6220
[1,10]<stdout>: 7680/60000 [==>...........................] - ETA: 5:52 - loss: 1.1219 - acc: 0.6479
[1,8]<stdout>: 7680/60000 [==>...........................] - ETA: 5:52 - loss: 1.1269 - acc: 0.6474
[1,12]<stdout>: 9728/60000 [===>..........................] - ETA: 5:37 - loss: 0.9367 - acc: 0.7096

[31m8 - acc: 0.7389[1,8]<stdout>:#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#01511136/60000 [====>.........................] - ETA: 5:26 - loss: 0.8321 - acc: 0.7415[1,1]<stdout>:#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#015[1,1]<stdout>:11136/60000 [====>.........................][1,1]<stdout>: - ETA: 5:26 - loss: 0.8497 - acc: 0.7354[1,5]<stdout>:#010#010#010#010#010#010#010#010#010#010#010#010#010

[31m#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#01513184/60000 [=====>........................] - ETA: 5:10 - loss: 0.7282 - acc: 0.7750[1,0]<stdout>:#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#015[1,0]<stdout>:13184/60000 [=====>........................][1,7]<stdout>:#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010



























































[1,1]<stdout>:Test loss: 0.030449212397501743
[1,1]<stdout>:Test accuracy: 0.9903
[1,3]<stdout>:Test loss: 0.030449212397501743
[1,3]<stdout>:Test accuracy: 0.9903
[1,2]<stdout>:Test loss: 0.030449212397501743
[1,2]<stdout>:Test accuracy: 0.9903
[1,0]<stdout>:Test loss: 0.030449212397501743
[1,0]<stdout>:Test accuracy: 0.9903
[1,0]<stderr>:
[1,0]<stderr>:Consider using a TensorFlow optimizer from `tf.train`.
[1,0]<stdout>:Model successfully saved at: /opt/ml/model/temp-1559840596
[1,7]<stdout>:Test loss: 0.030449212397501743
[1,7]<stdout>:Test accuracy: 0.9903
[1,11]<stdout>:Test loss: 0.030449212397501743
[1,11]<stdout>:Test accuracy: 0.9903
[1,15]<stdout>:Test loss: 0.030449212397501743
[1,15]<stdout>:Test accuracy: 0.9903
[1,10]<stdout>:Test loss: 0.030449212397501743
[1,10]<stdout>:Test accuracy: 0.9903
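
Each line in the log above carries an `mpirun` tag of the form `[job,rank]<stream>:`, and every rank shown reports the same test loss and accuracy. The following is a minimal sketch of how such a log could be parsed to check that agreement programmatically. The helper names `parse_rank_metrics` and `all_ranks_agree` are hypothetical and are not part of SageMaker or Horovod, and the tag format is assumed from the output shown here.

```python
import re
from collections import defaultdict

# Assumed tag format, taken from the log above: "[job,rank]<stdout>:" or "[job,rank]<stderr>:".
TAG = re.compile(r"\[(\d+),(\d+)\]<(?:stdout|stderr)>:")
# ANSI colour codes such as "[31m" and "[0m", with or without the leading ESC byte.
ANSI = re.compile(r"\x1b?\[\d+m")
METRIC = re.compile(r"Test (loss|accuracy):\s*([0-9.]+)")


def parse_rank_metrics(log_text):
    """Return {rank: {"loss": float, "accuracy": float}} for the tagged lines found."""
    metrics = defaultdict(dict)
    for raw in log_text.splitlines():
        line = ANSI.sub("", raw)
        tag = TAG.search(line)
        if tag is None:
            continue
        rank = int(tag.group(2))
        # A tag can be repeated mid-line when output is interleaved, so drop all of them.
        body = TAG.sub("", line)
        found = METRIC.search(body)
        if found:
            metrics[rank][found.group(1)] = float(found.group(2))
    return dict(metrics)


def all_ranks_agree(metrics, name):
    """True if every rank that reported `name` reported the same value."""
    values = {m[name] for m in metrics.values() if name in m}
    return len(values) == 1


sample = """[1,1]<stdout>:Test loss: 0.030449212397501743
[1,1]<stdout>:Test accuracy: 0.9903
[1,0]<stdout>:Test loss: 0.030449212397501743
[1,0]<stdout>:Test accuracy: [1,0]<stdout>:0.9903"""

per_rank = parse_rank_metrics(sample)
print(per_rank)
print("accuracy agrees across ranks:", all_ranks_agree(per_rank, "accuracy"))
```

Removing every tag before matching the metric is a deliberate choice: interleaved output can split a single message across two tags, as in the rank 0 accuracy line of the original log.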

After training completes, you can host the saved model by using TensorFlow Serving on SageMaker. For an example, see the [TensorFlow Serving container notebook](https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/tensorflow_serving_container/tensorflow_serving_container.ipynb).

## Reference Links
* [SageMaker Container MPI Support](https://github.com/aws/sagemaker-containers/blob/master/src/sagemaker_containers/_mpi.py)
* [Horovod Official Documentation](https://github.com/uber/horovod)
* [SageMaker TensorFlow script mode example](https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/tensorflow_script_mode_quickstart/tensorflow_script_mode_quickstart.ipynb)