# Training PyTorch Models using Horovod with Open MPI

*(This notebook was tested with the "Python 3 (PyTorch CPU Optimized)" kernel.)*

## Background

[Horovod](https://horovod.readthedocs.io/en/stable/summary_include.html) is a distributed deep learning training framework for TensorFlow, Keras, PyTorch, and MXNet. Horovod can be used to train your models faster on Amazon SageMaker, and this notebook shows you how to run Horovod directly with Open MPI.

Amazon SageMaker is a fully-managed service that provides developers and data scientists with the ability to build, train, and deploy machine learning (ML) models quickly. Amazon SageMaker removes the heavy lifting from each step of the machine learning process to make it easier to develop high-quality models. The SageMaker Python SDK makes it easy to train and deploy models in Amazon SageMaker with several different machine learning and deep learning frameworks, including PyTorch.

In this notebook, we:

* Upload a data set to S3
* Train a simple neural network with Horovod
* Deploy a model to a SageMaker endpoint

## Setup

For this example, we use the MNIST dataset, which is widely used for handwritten digit classification. It consists of 70,000 labeled 28x28 pixel grayscale images of handwritten digits, split into 60,000 training images and 10,000 test images. There are 10 classes (one for each of the 10 digits).

### Downloading the Data

The MNIST data set can be downloaded with the `torchvision.datasets` package.

In [None]:
import torch
import torchvision

# ipywidgets is required to display loading bars in Jupyter. Because ipywidgets is not installed in
# the container by default, we're temporarily disabling the loading bar.
def gen_bar_updater():
    return lambda count, block_size, total_size: None
torchvision.datasets.utils.gen_bar_updater = gen_bar_updater

dataset = torchvision.datasets.MNIST(
    root='./data',
    transform=torchvision.transforms.ToTensor(),
    download=True,
)

### Previewing the Data

Let's preview what the input images look like. We can do this using `matplotlib`.

In [None]:
import matplotlib.pyplot as plt

dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)
images, labels = next(iter(dataloader))
grid = torchvision.utils.make_grid(images)
plt.imshow(grid.permute(1, 2, 0))  
plt.show()

### Uploading the Data

We use the `sagemaker.s3.S3Uploader.upload` function to upload our dataset to an S3 location. The return value, `inputs`, identifies the location of our data.

To upload our data to S3, we need to specify an S3 bucket and prefix. The bucket should be in the same region as the notebook instance, the training job, and the hosting endpoint.

In [None]:
import sagemaker
from sagemaker.s3 import S3Uploader

sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket()
prefix = 'pytorch-horovod-mpi-example'
inputs = S3Uploader.upload('data', 's3://{}/{}/data'.format(bucket, prefix))

## Training

### Entrypoint Script

To use Horovod with Open MPI, we use a shell script as the training entry point. The shell script invokes our Python training script as part of an Open MPI command. You can learn more about running Horovod with MPI at https://horovod.readthedocs.io/en/stable/mpirun.html.

Here's what the shell script looks like:

In [None]:
!cat src/train.sh

We're using 4 GPUs because that's the number of GPUs on an ml.p2.8xlarge instance, but if you're using a different instance type, you may need to change the value of the `-np` parameter.
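
Rather than hardcoding the process count, a launcher could derive `-np` from the `SM_NUM_GPUS` environment variable that SageMaker sets in the training container. The following is a minimal Python sketch of that idea, not the contents of `src/train.sh`. The `build_mpirun_command` helper is hypothetical, and the Open MPI flags are the ones suggested for a single machine in the Horovod `mpirun` documentation linked above.

```python
import os


def build_mpirun_command(script="train.py", num_processes=None):
    """Build an Open MPI command line that starts one process per GPU."""
    if num_processes is None:
        # SM_NUM_GPUS is set by SageMaker; fall back to 1 when it is absent or 0
        num_processes = max(int(os.environ.get("SM_NUM_GPUS", "1")), 1)
    return [
        "mpirun", "-np", str(num_processes),
        "-bind-to", "none", "-map-by", "slot",
        "-x", "NCCL_DEBUG=INFO", "-x", "LD_LIBRARY_PATH", "-x", "PATH",
        "-mca", "pml", "ob1", "-mca", "btl", "^openib",
        "python", script,
    ]


# Show the command for an explicit process count of 4
print(" ".join(build_mpirun_command(num_processes=4)))
```

Returning the command as a list (rather than one string) means it can be passed straight to `subprocess.run` without shell quoting concerns.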

### Training Script

#### Horovod-specific Code

As you read the training script below, take note of these Horovod-specific modifications:

* We start off by running `hvd.init()` and pinning each GPU to a single process
* While constructing the `torch.optim.SGD` optimizer, we scale the learning rate by the number of workers (if you're wondering why we do this, see [this paper](https://arxiv.org/abs/1706.02677))
* We wrap `torch.optim.SGD` with `hvd.DistributedOptimizer` and broadcast the initial model state to the other processes
* We guard the `save` function with a `hvd.local_rank() != 0` check so that the model is saved only on worker 0. This prevents other workers from corrupting the checkpoint
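
The modifications listed above can be sketched roughly as follows. This is an illustrative outline rather than the contents of `src/train.py`: the helper names, the default `base_lr` and `momentum` values, and the `model.pth` file name are assumptions made for the example. The Horovod calls themselves (`hvd.init`, `hvd.local_rank`, `hvd.size`, `hvd.DistributedOptimizer`, `hvd.broadcast_parameters`, `hvd.broadcast_optimizer_state`) are part of the `horovod.torch` API.

```python
import os


def scale_learning_rate(base_lr, num_workers):
    """Scale the single-worker learning rate linearly with the worker count."""
    return base_lr * num_workers


def make_distributed_optimizer(model, base_lr=0.01, momentum=0.5):
    """Apply the Horovod setup steps to a torch.nn.Module and return its optimizer."""
    # Imported here so the helper above can be used without torch or Horovod installed
    import torch
    import horovod.torch as hvd

    hvd.init()
    if torch.cuda.is_available():
        # Pin this process to one GPU, chosen by its rank on the local machine
        torch.cuda.set_device(hvd.local_rank())
        model.cuda()

    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=scale_learning_rate(base_lr, hvd.size()),
        momentum=momentum,
    )
    # Combine gradients across workers on every optimizer step
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters()
    )
    # Start every worker from the same model and optimizer state
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)
    return optimizer


def save(model, model_dir):
    """Write the checkpoint from a single worker only."""
    import torch
    import horovod.torch as hvd

    if hvd.local_rank() != 0:
        return  # every other worker skips the write
    # "model.pth" is a hypothetical file name chosen for this sketch
    torch.save(model.state_dict(), os.path.join(model_dir, "model.pth"))
```

On a single instance, `hvd.local_rank()` and `hvd.rank()` are the same value, so the guard in `save` lets exactly one process write the checkpoint.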

For more information about using Horovod with PyTorch, please see the [Horovod documentation](https://horovod.readthedocs.io/en/stable/pytorch.html).

#### SageMaker-specific Code

Outside of the Horovod-specific code, the training code is very similar to a training script we might run outside of Amazon SageMaker, but we can access useful properties about the training environment through various environment variables. For this notebook, our script retrieves the following environment variable values:

* `SM_HOSTS`: a list of hosts on the container network.
* `SM_CURRENT_HOST`: the name of the current container on the container network.
* `SM_MODEL_DIR`: the location for model artifacts. This directory is uploaded to Amazon S3 at the end of the training job.
* `SM_CHANNEL_TRAINING`: the location of our training data.
* `SM_NUM_GPUS`: the number of GPUs available to the current container.
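
As a rough illustration of how a script might collect these values, here is a minimal sketch. The `read_sagemaker_env` helper and its fallback defaults (for running outside SageMaker) are assumptions for this example, not part of `src/train.py`. Note that `SM_HOSTS` holds a JSON-encoded list, so it needs to be parsed rather than used as a plain string.

```python
import json
import os


def read_sagemaker_env(environ=os.environ):
    """Collect the SageMaker training settings used in this notebook."""
    return {
        # JSON-encoded list of host names, e.g. '["algo-1", "algo-2"]'
        "hosts": json.loads(environ.get("SM_HOSTS", '["localhost"]')),
        "current_host": environ.get("SM_CURRENT_HOST", "localhost"),
        "model_dir": environ.get("SM_MODEL_DIR", "./model"),
        "data_dir": environ.get("SM_CHANNEL_TRAINING", "./data"),
        "num_gpus": int(environ.get("SM_NUM_GPUS", "0")),
    }


settings = read_sagemaker_env()
print(settings["hosts"], settings["num_gpus"])
```

Passing the environment as a parameter keeps the helper easy to exercise locally with a plain dictionary.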

For more about writing a PyTorch training script with SageMaker, please see the [SageMaker documentation](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#prepare-a-pytorch-training-script). 

In [None]:
!pygmentize src/train.py

### Starting the Training Job

The IAM role ARN is used to give training and hosting access to your data. See the Amazon SageMaker Roles documentation for how to create one. Note that if more than one role is required for notebook instances, training, and/or hosting, you should replace `sagemaker.get_execution_role()` with the appropriate full IAM role ARN string(s).

In [None]:
from sagemaker.pytorch import PyTorch
role = sagemaker.get_execution_role()

The `PyTorch` class allows us to run our training script as a job on SageMaker. We configure it with our entry point script and its source directory, an IAM role, the PyTorch framework version, and the number and type of training instances. In this case we are going to run our training job on an ml.p2.8xlarge instance, which contains four Tesla K80 GPUs.

In [None]:
estimator = PyTorch(entry_point='train.sh',
                    source_dir='src',
                    role=role,
                    framework_version='1.4.0',
                    train_instance_count=1,
                    train_instance_type='ml.p2.8xlarge')

After we've constructed our `PyTorch` object, we can fit it using the data we uploaded to S3. 

In [None]:
estimator.fit(inputs)

## Hosting

### Creating an Endpoint

After we train our model, we can deploy it to a SageMaker Endpoint, which serves prediction requests in real time. An implementation of `model_fn` is required for our entry point script. We are going to use the default implementations of `input_fn`, `predict_fn`, `output_fn`, and `transform_fn` defined in [sagemaker-pytorch-containers](https://github.com/aws/sagemaker-pytorch-training-toolkit).

Here's the inference script we're using:

In [None]:
!pygmentize src/inference.py

The arguments to the deploy function allow us to set the number and type of instances that are used for the Endpoint. These do not need to be the same as the values we used for the training job. Here we deploy the model to a single ml.c5.xlarge instance.

In [None]:
predictor = estimator.deploy(
    initial_instance_count=1, 
    instance_type='ml.c5.xlarge',
    entry_point='inference.py', 
    source_dir='src'
)

### Classifying Digits

We can now use the predictor to classify hand-written digits.

In [None]:
import numpy as np

# The input images are displayed again
grid = torchvision.utils.make_grid(images)
plt.imshow(grid.permute(1, 2, 0))  
plt.show()

outputs = predictor.predict(images)
predictions = np.argmax(outputs, axis=1)

print('Predictions:', ' '.join('%4s' % predictions[i] for i in range(4)))

As expected, our model correctly classifies the digits.
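
To make the post-processing step concrete: the endpoint returns one row of 10 class scores per image, and `np.argmax(outputs, axis=1)` picks the index of the largest score in each row, which is the predicted digit. Here is a small standalone illustration in plain Python. The `argmax_rows` helper is hypothetical, and the scores are made up for the example rather than taken from the endpoint.

```python
def argmax_rows(scores):
    """Return the index of the largest value in each row, like np.argmax(scores, axis=1)."""
    return [max(range(len(row)), key=row.__getitem__) for row in scores]


# Made-up scores for two images across the 10 digit classes
fake_outputs = [
    [-9.1, -8.7, -7.5, -6.2, -9.9, -0.1, -8.0, -9.4, -7.7, -8.8],  # largest at index 5
    [-0.2, -9.0, -7.1, -8.3, -9.6, -8.4, -7.9, -9.2, -8.1, -8.5],  # largest at index 0
]

print(argmax_rows(fake_outputs))  # [5, 0]
```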

## Cleanup

After you have finished with this example, remember to delete the prediction endpoint to release the instance(s) associated with it!

In [None]:
predictor.delete_endpoint()