# Horovod Distributed Training with SageMaker TensorFlow script mode

Horovod is a distributed deep learning training framework that uses MPI for communication between training processes. You can find more details in the [Horovod README](https://github.com/uber/horovod).

Horovod distributed training can be performed on SageMaker using the SageMaker TensorFlow container. SageMaker creates the MPI environment and executes the `mpirun` command to launch the training script.

The MPI environment for Horovod can be configured with the following flags in the SageMaker SDK, passed in the `mpi` entry of the estimator's `distributions` argument:

* ``enabled (bool)``: If set to ``True``, the MPI setup is performed and ``mpirun`` command is executed.
* ``processes_per_host (int)``: Number of processes MPI should launch on each host. Note that this should not be greater than the number of available slots on the selected instance type.
* ``custom_mpi_options (str)``: Additional command line arguments to pass to ``mpirun``.
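
To make the three flags concrete, the sketch below builds a `distributions` dictionary from them and then approximates the `mpirun` argument list those settings imply. The helpers `build_mpi_distribution` and `sketch_mpirun_command` are hypothetical, written only for illustration; they are not part of the SageMaker SDK. The host names `algo-1` and `algo-2` follow the naming SageMaker uses for hosts in a training cluster. The generated command is a deliberate simplification, assuming Open MPI's `-np`, `-H host:slots`, and `-x` options; the command SageMaker actually runs contains additional options that are not reproduced here.

```python
import shlex


def build_mpi_distribution(processes_per_host, custom_mpi_options=None):
    """Return a `distributions` dict using the MPI flags described above.

    Hypothetical helper for illustration; not part of the SageMaker SDK.
    """
    if processes_per_host < 1:
        raise ValueError("processes_per_host must be at least 1")
    mpi = {"enabled": True, "processes_per_host": processes_per_host}
    if custom_mpi_options:
        mpi["custom_mpi_options"] = custom_mpi_options
    return {"mpi": mpi}


def sketch_mpirun_command(hosts, distribution, script="mnist_hvd.py"):
    """Approximate the mpirun argument list implied by the MPI flags.

    Simplified illustration only: the real command SageMaker runs includes
    additional options that are not reproduced here.
    """
    mpi = distribution["mpi"]
    if not mpi.get("enabled"):
        return None  # MPI disabled: no MPI setup and no mpirun launch
    slots = mpi["processes_per_host"]
    command = [
        "mpirun",
        "-np", str(slots * len(hosts)),  # total processes across all hosts
        "-H", ",".join(f"{host}:{slots}" for host in hosts),  # slots per host
    ]
    # custom_mpi_options are appended to the command line as extra arguments
    command += shlex.split(mpi.get("custom_mpi_options", ""))
    command += ["python", script]
    return command


if __name__ == "__main__":
    dist = build_mpi_distribution(2, custom_mpi_options="-x NCCL_DEBUG=INFO")
    print(sketch_mpirun_command(["algo-1", "algo-2"], dist))
```

With two hosts and `processes_per_host=2`, the sketch yields `-np 4 -H algo-1:2,algo-2:2`, which shows why `processes_per_host` multiplies with the instance count to give the total number of Horovod processes.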

In this example notebook, we create an MNIST Horovod training job.

## Set up the environment

In [None]:
import sagemaker
from sagemaker import get_execution_role
from sagemaker.tensorflow import TensorFlow

# Session used for all SageMaker API calls in this notebook
sage_session = sagemaker.Session()

# IAM role that the training job assumes
role = get_execution_role()

# AWS account ID and region of the current session
account = sage_session.boto_session.client('sts').get_caller_identity()['Account']
region = sage_session.boto_session.region_name


## Review the script for Horovod distributed training

In [None]:
!cat 'mnist_hvd.py'
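
Each process launched by `mpirun` runs its own copy of the script and is identified by a Horovod rank. The sketch below shows, in plain Python, the kind of per-rank decisions a Horovod training script commonly makes: reading a distinct shard of the data, scaling the learning rate by the number of processes, and letting only rank 0 write checkpoints. It is an illustration of the general pattern, not the contents of `mnist_hvd.py`. The function `plan_for_rank` is hypothetical; in a real script `rank` and `size` would come from `hvd.rank()` and `hvd.size()` after `hvd.init()`.

```python
def plan_for_rank(rank, size, num_examples, base_learning_rate):
    """Per-process settings commonly derived from Horovod's rank and size.

    Hypothetical helper: `rank` and `size` are plain integers here, standing
    in for hvd.rank() and hvd.size() in a real Horovod script.
    """
    if not 0 <= rank < size:
        raise ValueError("rank must satisfy 0 <= rank < size")
    return {
        # Strided shard: rank r reads examples r, r + size, r + 2 * size, ...
        "shard": range(rank, num_examples, size),
        # Scale the learning rate by the number of processes, because the
        # effective batch size grows with the number of workers
        "learning_rate": base_learning_rate * size,
        # Only one process writes checkpoints, to avoid conflicting writes
        "writes_checkpoints": rank == 0,
    }


if __name__ == "__main__":
    # 2 instances x 2 processes per host = 4 Horovod processes
    for rank in range(4):
        plan = plan_for_rank(rank, size=4, num_examples=10, base_learning_rate=0.001)
        print(rank, list(plan["shard"]), plan["writes_checkpoints"])
```

The strided shards are disjoint and together cover every example once, so no sample is processed twice per epoch across the cluster.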

## Initialize job parameters

In [None]:
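# Set to True to run the job in SageMaker local mode instead of on managed instances
is_local_mode = False

# Total number of Horovod processes = instance_count * processes_per_host
processes_per_host = 2
instance_count = 2
instance_type = "ml.c4.xlarge" if not is_local_mode else "local"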
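
Before launching the job, it can help to check the parameters against the slot limit mentioned for `processes_per_host`. The sketch below computes the total number of Horovod processes and rejects a `processes_per_host` value larger than the instance's vCPU count. It assumes one MPI slot per vCPU on CPU instances such as `ml.c4.xlarge` (4 vCPUs); `VCPUS_PER_INSTANCE` and `total_processes` are hypothetical names, and instance types missing from the table (including `"local"`) are simply not checked.

```python
# Hypothetical lookup table of vCPUs per instance type (ml.c4.xlarge has 4).
# Extend it with the instance types you actually use.
VCPUS_PER_INSTANCE = {"ml.c4.xlarge": 4}


def total_processes(instance_count, processes_per_host, instance_type):
    """Return the total number of Horovod processes for the job.

    Assumes one MPI slot per vCPU; instance types not in the table
    (for example "local") are not checked.
    """
    slots = VCPUS_PER_INSTANCE.get(instance_type)
    if slots is not None and processes_per_host > slots:
        raise ValueError(
            f"{processes_per_host} processes per host exceeds "
            f"{slots} slots on {instance_type}"
        )
    return instance_count * processes_per_host


if __name__ == "__main__":
    print(total_processes(2, 2, "ml.c4.xlarge"))  # 2 hosts x 2 processes
```

In the notebook, the same check would be `total_processes(instance_count, processes_per_host, instance_type)`, run before creating the estimator.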

## Create a training job using the sagemaker.TensorFlow estimator

In [None]:
estimator = TensorFlow(entry_point="mnist_hvd.py",
                       role=role,  # execution role retrieved during setup
                       train_instance_count=instance_count,
                       train_instance_type=instance_type,
                       sagemaker_session=sage_session,
                       script_mode=True,
                       framework_version='1.12',
                       distributions={
                           'mpi': {
                               'enabled': True,
                               # Reuse the job parameter instead of hard-coding it
                               'processes_per_host': processes_per_host
                           }
                       },
                       base_job_name='hvd-mnist')

estimator.fit()