## Distributed Training with Chainer and ChainerMN

Chainer can train in two modes: single-machine and distributed. Unlike the single-machine notebook example, which trains an image classification model on the CIFAR-10 dataset on one instance, this notebook uses a Chainer script with `chainermn` to distribute training across multiple instances.

[VGG](https://arxiv.org/pdf/1409.1556v6.pdf) is an architecture for deep convolutional networks. In this example, we train a convolutional network to perform image classification using the CIFAR-10 dataset on multiple instances. CIFAR-10 consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images. We'll train a model on Amazon SageMaker, deploy it to a SageMaker endpoint, and then classify images using the deployed model.

To train with a Chainer script, we construct a ```Chainer``` estimator using the [sagemaker-python-sdk](https://github.com/aws/sagemaker-python-sdk). We pass in an `entry_point`, the name of a script that defines two functions with specific signatures (`train` and `model_fn`). SageMaker runs this script in a container that invokes these functions to train and load Chainer models.

For more on the Chainer container, please visit the sagemaker-chainer-containers repository:
https://github.com/aws/sagemaker-chainer-containers

In [1]:
# Setup
from sagemaker import get_execution_role
import sagemaker

sagemaker_session = sagemaker.Session()

# Retrieve the SageMaker-compatible IAM role used by this Notebook Instance.
role = get_execution_role()


## Downloading training and test data

We use helper functions provided by `chainer` to download and preprocess the CIFAR-10 data.

In [2]:
import chainer

from chainer.datasets import get_cifar10

train, test = get_cifar10()



## Uploading the data

We save the preprocessed data to the local filesystem, and then use the `sagemaker.Session.upload_data` function to upload our datasets to an S3 location. The return value `inputs` identifies the S3 location, which we will use when we start the Training Job.

In [3]:
import os
import shutil

import numpy as np

train_data = [element[0] for element in train]
train_labels = [element[1] for element in train]

test_data = [element[0] for element in test]
test_labels = [element[1] for element in test]

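# exist_ok=True keeps this cell re-runnable and still creates both directories
# even if one of them already exists.
os.makedirs('data/train', exist_ok=True)
os.makedirs('data/test', exist_ok=True)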

np.savez('data/train/train.npz', data=train_data, labels=train_labels)
np.savez('data/test/test.npz', data=test_data, labels=test_labels)

# Upload preprocessed data to S3 
train_input = sagemaker_session.upload_data(path=os.path.join('data', 'train'),
                                            key_prefix='notebook/chainer_cifar/train')
test_input = sagemaker_session.upload_data(path=os.path.join('data', 'test'),
                                           key_prefix='notebook/chainer_cifar/test')

# Remove data from notebook instance (to save disk space)
shutil.rmtree('data')
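
For reference, here is a minimal sketch of how a script might read one of these `.npz` files back. It writes a small synthetic array to a temporary directory rather than touching S3. The helper name `load_npz_channel` and the array shapes are assumptions for illustration, not code taken from the training script.

```python
import os
import tempfile

import numpy as np


def load_npz_channel(channel_dir, filename):
    """Load the `data` and `labels` arrays saved with np.savez (hypothetical helper)."""
    with np.load(os.path.join(channel_dir, filename)) as archive:
        return archive['data'], archive['labels']


# Synthetic stand-in for a few CIFAR-10 samples: 4 images, 3 channels, 32x32 pixels.
fake_data = np.zeros((4, 3, 32, 32), dtype=np.float32)
fake_labels = np.array([0, 1, 2, 3], dtype=np.int32)

with tempfile.TemporaryDirectory() as channel_dir:
    # Same keyword names (`data`, `labels`) as the np.savez calls in the cell above.
    np.savez(os.path.join(channel_dir, 'train.npz'), data=fake_data, labels=fake_labels)
    data, labels = load_npz_channel(channel_dir, 'train.npz')

print(data.shape, labels.shape)
```

The keys used when loading must match the keyword names passed to `np.savez`.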

## Writing the Chainer training script to run on Amazon SageMaker

We need to provide a training script that can run on the SageMaker platform. The training script is essentially the same as one you would write for local training, except that it must provide a function `train` that returns a trained model.

Since we will use the same script to host the Chainer model, the script also needs a function `model_fn` that loads the model -- by default, Chainer models are saved to disk as `model.npz`. When SageMaker calls your `train` and `model_fn` functions, it will pass in arguments that describe the training environment.

The script below uses a `chainermn` Communicator to distribute training to multiple nodes. The Communicator depends on MPI (Message Passing Interface), so the Chainer container running on SageMaker runs this script with `mpirun` if the `Chainer` estimator specifies a `train_instance_count` of two or more, or if `use_mpi` is set to `True`.

By default, one process is created per GPU (on GPU instances), or one per host (on CPU instances, which are not recommended for this notebook).
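
The default process count described above can be sketched as a small calculation. The function `num_mpi_processes` is a hypothetical helper written for this illustration; it is not part of the SageMaker SDK or the Chainer container.

```python
def num_mpi_processes(instance_count, gpus_per_instance):
    """Return the default process count: one per GPU, or one per host on CPU instances."""
    processes_per_host = gpus_per_instance if gpus_per_instance > 0 else 1
    return instance_count * processes_per_host


# Two ml.p3.2xlarge instances, each with a single GPU.
print(num_mpi_processes(instance_count=2, gpus_per_instance=1))

# A hypothetical pair of CPU-only instances still gets one process per host.
print(num_mpi_processes(instance_count=2, gpus_per_instance=0))
```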

See https://github.com/chainer/chainermn for more on running Chainer on multiple nodes.  

In [4]:
!cat 'code/chainer_cifar_vgg_distributed.py'

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from __future__ import print_function, absolute_import

import os

import numpy as np

import chainer
import chainer.functions as F
import chainer.links as L
import chainermn
from chainer import initializers
from chainer import serializers
from chainer import training
from chainer.training import extensions

import net


def train(hyperparameters, num_gpus, output_data_dir, channel_input_dirs):
    

## Running the training script on SageMaker

The ```Chainer``` class allows us to run our training function as a training job on SageMaker infrastructure. We need to configure it with our training script, an IAM role, the number of training instances, and the training instance type. In this case we will run our training job on two `ml.p3.2xlarge` instances.

This script uses the `chainermn` package, which distributes training with MPI. Your script is run with `mpirun`, so a ChainerMN Communicator object can be used to distribute training. Arguments to `mpirun` are set to sensible defaults, but you can configure how your script is run in distributed mode. See the ```Chainer``` class documentation for more on configuring MPI.
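
In data-parallel training of this kind, each process usually works on its own slice of the training set. The sketch below simulates that idea in plain Python, without MPI. `shard_indices` is a hypothetical helper; ChainerMN's own `scatter_dataset` may divide the data differently (for example, after shuffling), so treat this only as an illustration of the concept.

```python
def shard_indices(num_examples, num_processes, rank):
    """Return the contiguous range of example indices assigned to `rank`."""
    base, extra = divmod(num_examples, num_processes)
    # The first `extra` ranks each take one additional example.
    start = rank * base + min(rank, extra)
    stop = start + base + (1 if rank < extra else 0)
    return range(start, stop)


# 50000 CIFAR-10 training images split across two processes.
shards = [shard_indices(50000, 2, rank) for rank in range(2)]
for rank, shard in enumerate(shards):
    print('rank {}: {} examples'.format(rank, len(shard)))
```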

In [5]:
from sagemaker.chainer.estimator import Chainer

chainer_estimator = Chainer(entry_point='chainer_cifar_vgg_distributed.py', source_dir="code", role=role,
                            use_mpi=True, sagemaker_session=sagemaker_session,
                            train_instance_count=2, train_instance_type='ml.p3.2xlarge',
                            hyperparameters={'epochs': 30, 'batch_size': 256})

chainer_estimator.fit({'train': train_input, 'test': test_input})

INFO:sagemaker:Creating training-job with name: sagemaker-chainer-2018-05-04-20-04-24-955


......................................................
2018-05-04 20:08:47,957 INFO - root - running container entrypoint
2018-05-04 20:08:47,958 INFO - root - starting train task
2018-05-04 20:08:47,970 INFO - container_support.app - started training: {'train_fn': <function train at 0x7f01a48dabf8>}
Downloading s3://sagemaker-us-west-2-038453126632/sagemaker-chainer-2018-05-04-20-04-24-955/source/sourcedir.tar.gz to /tmp/script.tar.gz
2018-05-04 20:08:48,125 INFO - botocore.vendored.requests.packages.urllib3.connectionpool - Starting new HTTP connection (1): 169.254.170.2
2018-05-04 20:08:48,213 INFO - botocore.vendored.requests.packages.urllib3.connectionpool - Starting new HTTPS connection (1): sagemaker-us-west-2-038453126632.s3.amazonaws.com
2018-05-04 20:08:48,259 INFO - botocore.vendored.requests.packages.urllib3.connectionpool - Starting new HTTPS connection (2): sagemaker-us-west-2-038453126632.s3.amazonaws.com

===== Job Complete =====
Billable seconds: 2700


Our Chainer script writes various artifacts, such as plots, to the directory `output_data_dir`, whose contents SageMaker uploads to S3. Now we download and extract these artifacts.

In [6]:
from s3_util import retrieve_output_from_s3

chainer_training_job = chainer_estimator.latest_training_job.name

desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=chainer_training_job)
output_data = desc['ModelArtifacts']['S3ModelArtifacts'].replace('model.tar.gz', 'output.tar.gz')

retrieve_output_from_s3(output_data, 'output/distributed_cifar')
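
The string manipulation in this cell can be tried in isolation. The sketch below uses a made-up S3 URI (real values come from `describe_training_job`) and relies on the same assumption as the cell above, namely that `output.tar.gz` sits next to `model.tar.gz` under the same prefix.

```python
from urllib.parse import urlparse

# Hypothetical artifact location used only for illustration.
model_artifacts = 's3://example-bucket/example-job/output/model.tar.gz'

# Swap the file name to point at the archive of output_data_dir contents.
output_data = model_artifacts.replace('model.tar.gz', 'output.tar.gz')

# Split the URI into bucket and key, for example to pass to an S3 client.
parsed = urlparse(output_data)
bucket, key = parsed.netloc, parsed.path.lstrip('/')
print(bucket, key)
```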

In [7]:
# Executing as code to reload images so that browsers don't render cached images.
from IPython.display import Markdown
import time
_nonce = time.time()

Markdown("""
These plots show the accuracy and loss over epochs:

<img style="display: inline;" src="output/distributed_cifar/algo-1/accuracy.png?{0}" />
<img style="display: inline;" src="output/distributed_cifar/algo-1/loss.png?{0}" />""".format(_nonce))



These plots show the accuracy and loss over epochs:

<img style="display: inline;" src="output/distributed_cifar/algo-1/accuracy.png?1525465777.5663912" />
<img style="display: inline;" src="output/distributed_cifar/algo-1/loss.png?1525465777.5663912" />

## Deploying the Trained Model

After training, we use the Chainer estimator object to create and deploy a hosted prediction endpoint. We can use a CPU-based instance for inference (in this case an `ml.m4.xlarge`), even though we trained on GPU instances.

The predictor object returned by `deploy` lets us call the new endpoint and perform inference on our sample images. 

In [8]:
predictor = chainer_estimator.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')

INFO:sagemaker:Creating model with name: sagemaker-chainer-2018-05-04-20-04-24-955
INFO:sagemaker:Creating endpoint with name sagemaker-chainer-2018-05-04-20-04-24-955


-------------------------------------------------------------!

### CIFAR10 sample images

We'll use these CIFAR10 sample images to test the service:

<img style="display: inline; height: 32px; margin: 0.25em" src="images/airplane1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/automobile1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/bird1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/cat1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/deer1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/dog1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/frog1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/horse1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/ship1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/truck1.png" />



## Predicting using SageMaker Endpoint

We batch the images together into a single NumPy array to obtain multiple inferences with a single prediction request.

In [9]:
from skimage import io
import numpy as np

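def read_image(filename):
    # Load the image in height-width-channel (HWC) order.
    img = io.imread(filename)
    # Move channels first (CHW), the layout the model expects.
    img = np.array(img).transpose(2, 0, 1)
    # Convert to float32 and scale pixel values from [0, 255] to [0, 1].
    img = img.astype(np.float32)
    img *= 1. / 255.
    # Fails loudly if the image is not a 32x32 three-channel image.
    img = img.reshape(3, 32, 32)
    return img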


def read_images(filenames):
    return np.array([read_image(f) for f in filenames])

filenames = ['images/airplane1.png',
             'images/automobile1.png',
             'images/bird1.png',
             'images/cat1.png',
             'images/deer1.png',
             'images/dog1.png',
             'images/frog1.png',
             'images/horse1.png',
             'images/ship1.png',
             'images/truck1.png']

image_data = read_images(filenames)
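
To see the resulting array layout without the sample image files, the same preprocessing steps can be applied to a synthetic image. This is a sketch under the assumption that each input is a 32x32 RGB image stored as `uint8`; the all-white image is made up for illustration.

```python
import numpy as np

# Synthetic 32x32 RGB image in height-width-channel (HWC) order.
hwc = np.full((32, 32, 3), 255, dtype=np.uint8)

# Move channels first (CHW) and scale pixel values from [0, 255] to [0, 1].
chw = hwc.transpose(2, 0, 1).astype(np.float32) / 255.0

# Stack ten images into one batch of shape (N, C, H, W).
batch = np.stack([chw] * 10)
print(batch.shape, batch.dtype, batch.max())
```

A single request carrying an array of this shape yields one prediction per image in the batch.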

The predictor runs inference on our input data and returns one prediction per input image. The argmax of each prediction gives the predicted label for that image.

In [10]:
response = predictor.predict(image_data)

for i, prediction in enumerate(response):
    print('image {}: prediction: {}'.format(i, prediction.argmax(axis=0)))

image 0: prediction: 0
image 1: prediction: 1
image 2: prediction: 2
image 3: prediction: 3
image 4: prediction: 4
image 5: prediction: 5
image 6: prediction: 6
image 7: prediction: 7
image 8: prediction: 8
image 9: prediction: 9
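
The printed indices correspond to CIFAR-10 class names in their standard order. A minimal sketch of that mapping follows; the helper `predicted_label` and the example scores are hypothetical and are not output from the endpoint.

```python
# CIFAR-10 class names in label-index order.
CIFAR10_LABELS = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                  'dog', 'frog', 'horse', 'ship', 'truck']


def predicted_label(scores):
    """Return the class name whose score is largest (the argmax)."""
    index = max(range(len(scores)), key=lambda i: scores[i])
    return CIFAR10_LABELS[index]


# Made-up scores for one image; index 3 has the highest value.
example_scores = [0.01, 0.02, 0.03, 0.70, 0.04, 0.10, 0.03, 0.03, 0.02, 0.02]
print(predicted_label(example_scores))
```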


## Cleanup

After you have finished with this example, remember to delete the prediction endpoint to release the instance(s) associated with it.

In [11]:
sagemaker.Session().delete_endpoint(predictor.endpoint)

INFO:sagemaker:Deleting endpoint with name: sagemaker-chainer-2018-05-04-20-04-24-955
