## Training with Chainer

[VGG](https://arxiv.org/pdf/1409.1556v6.pdf) is an architecture for deep convolutional networks. In this example, we train a convolutional network to perform image classification using the CIFAR-10 dataset. CIFAR-10 consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images. We'll train a model on Amazon SageMaker, deploy it to a SageMaker endpoint, and then classify images using the deployed model.
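A defining trait of VGG is that it stacks 3x3 convolutions with stride 1 and padding 1, which keep the spatial size unchanged, and uses 2x2 max-pooling with stride 2 to halve it. The sketch below traces a 32x32 CIFAR-10 image through five such blocks, as in the VGG paper's configurations. The network actually trained here is defined in `net.py`, which is not shown, so treat the block count and the `convs_per_block` parameter as assumptions for illustration.

```python
def conv_out(size, kernel=3, stride=1, pad=1):
    """Output side length of a square convolution or pooling window."""
    return (size + 2 * pad - kernel) // stride + 1


def pool_out(size, kernel=2, stride=2):
    """Output side length of a 2x2, stride-2 max-pooling layer (no padding)."""
    return conv_out(size, kernel=kernel, stride=stride, pad=0)


def vgg_feature_sizes(input_size=32, num_blocks=5, convs_per_block=2):
    """Spatial size after each VGG-style block of 3x3 convs plus one pool."""
    sizes = [input_size]
    size = input_size
    for _ in range(num_blocks):
        for _ in range(convs_per_block):
            size = conv_out(size)  # 3x3, stride 1, pad 1: size is preserved
        size = pool_out(size)  # 2x2, stride 2: size is halved
        sizes.append(size)
    return sizes


print(vgg_feature_sizes())  # [32, 16, 8, 4, 2, 1]
```

Because the convolutions never shrink the feature map, only the pooling layers determine the final 1x1 spatial size, which is why five pooling stages fit a 32x32 input exactly.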

To train with a Chainer script, we construct a `Chainer` estimator using the [sagemaker-python-sdk](https://github.com/aws/sagemaker-python-sdk). We pass in an `entry_point`: the name of a script that defines functions with specific signatures (`train` and `model_fn`). SageMaker runs this script in a container that invokes these functions to train and load Chainer models.

For more on the Chainer container, please visit the sagemaker-chainer-containers repository:
https://github.com/aws/sagemaker-chainer-containers

In [1]:
# Setup
from sagemaker import get_execution_role
import sagemaker

sagemaker_session = sagemaker.Session()

# This role retrieves the SageMaker-compatible role used by this Notebook Instance.
role = get_execution_role()

## Downloading training and test data

We use Chainer's `get_cifar10` helper to download and preprocess the CIFAR-10 data. It returns the training and test datasets as sequences of `(image, label)` pairs.

In [2]:
import chainer

from chainer.datasets import get_cifar10

train, test = get_cifar10()



## Uploading the data

We save the preprocessed data to the local filesystem, and then use the `sagemaker.Session.upload_data` function to upload our datasets to S3. The return values `train_input` and `test_input` identify the S3 locations, which we will use when we start the training job.
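The value `upload_data` returns is an S3 URI built from a bucket and the `key_prefix`; by default the bucket is the session's default bucket, named `sagemaker-<region>-<account-id>`. The sketch below reproduces that URI format without calling AWS, so you can predict where the data will land. Both helper names are hypothetical and the account ID is a placeholder.

```python
def default_bucket_name(region, account_id):
    """Default SageMaker bucket name: sagemaker-<region>-<account-id>."""
    return 'sagemaker-{}-{}'.format(region, account_id)


def expected_upload_uri(bucket, key_prefix, file_name=None):
    """S3 URI reported for a directory upload or a single-file upload.

    Hypothetical helper that only formats strings; it does not talk to AWS.
    """
    if file_name is None:
        # Directory upload: the URI is the key prefix itself.
        return 's3://{}/{}'.format(bucket, key_prefix)
    # Single-file upload: the file name is appended to the prefix.
    return 's3://{}/{}/{}'.format(bucket, key_prefix, file_name)


bucket = default_bucket_name('us-west-2', '123456789012')  # placeholder account ID
print(expected_upload_uri(bucket, 'notebook/chainer_cifar/train'))
# s3://sagemaker-us-west-2-123456789012/notebook/chainer_cifar/train
```

Since the notebook uploads the `data/train` and `data/test` directories, the directory form applies, and each channel's URI is simply its key prefix under the bucket.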

In [3]:
import os
import shutil

import numpy as np

train_data = [element[0] for element in train]
train_labels = [element[1] for element in train]

test_data = [element[0] for element in test]
test_labels = [element[1] for element in test]

# exist_ok=True also creates data/test when data/train already exists
os.makedirs('data/train', exist_ok=True)
os.makedirs('data/test', exist_ok=True)

np.savez('data/train/train.npz', data=train_data, labels=train_labels)
np.savez('data/test/test.npz', data=test_data, labels=test_labels)

# Upload preprocessed data to S3
train_input = sagemaker_session.upload_data(path=os.path.join('data', 'train'),
                                            key_prefix='notebook/chainer_cifar/train')
test_input = sagemaker_session.upload_data(path=os.path.join('data', 'test'),
                                           key_prefix='notebook/chainer_cifar/test')

# Remove data from notebook instance (to conserve disk space)
shutil.rmtree('data')
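`np.savez` stores each keyword argument as a named array, so the training script can read the files back with `np.load` using the same `data` and `labels` keys. The round trip below uses a small in-memory stand-in instead of the real files, and also shows that the Python lists built above are stacked into single arrays on save. The training script's own loading code is not shown in this notebook, so this is a sketch of the expected file format, not a copy of that code.

```python
import io

import numpy as np

# Toy stand-ins for the CIFAR-10 lists: 4 images of shape (3, 32, 32) in [0, 1].
rng = np.random.RandomState(0)
data = [rng.rand(3, 32, 32).astype(np.float32) for _ in range(4)]
labels = [0, 3, 5, 9]

# Write to an in-memory buffer instead of data/train/train.npz.
buffer = io.BytesIO()
np.savez(buffer, data=data, labels=labels)  # lists are converted to arrays
buffer.seek(0)

# Read back: arrays are looked up by the keyword names used in np.savez.
with np.load(buffer) as archive:
    loaded_data = archive['data']
    loaded_labels = archive['labels']

print(loaded_data.shape, loaded_data.dtype)  # (4, 3, 32, 32) float32
print(loaded_labels.tolist())  # [0, 3, 5, 9]
```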

## Writing the Chainer training script to run on Amazon SageMaker

We need to provide a training script that can run on the SageMaker platform. The training script is essentially the same as one you would write for local training, except that it must provide a function `train` that returns a trained model.

Since we will use the same script to host the Chainer model, the script also needs a function `model_fn` that loads the model (by default, Chainer models are saved to disk as `model.npz`). When SageMaker calls your `train` and `model_fn` functions, it passes in arguments that describe the training environment.
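The script's `train` function takes `hyperparameters`, `num_gpus`, `output_data_dir`, and `channel_input_dirs`. Here `hyperparameters` is assumed to hold the dictionary passed to the estimator, and `channel_input_dirs` is assumed to map each channel name given to `fit` (`train`, `test`) to a local directory in the container. The sketch below shows one way a script might turn those arguments into settings and file paths; the helper name and its default values are hypothetical, and the `/opt/ml/input/data/<channel>` paths follow SageMaker's container convention.

```python
import posixpath  # container paths are POSIX paths


def resolve_training_inputs(hyperparameters, channel_input_dirs):
    """Pick out hyperparameters and input files for a train() call.

    Hypothetical helper: the defaults are assumptions, not the values used
    by chainer_cifar_vgg_single_machine.py.
    """
    # int() also accepts hyperparameters that arrive as strings.
    epochs = int(hyperparameters.get('epochs', 10))
    batch_size = int(hyperparameters.get('batch_size', 64))
    # Each channel directory holds the files uploaded under that channel.
    train_file = posixpath.join(channel_input_dirs['train'], 'train.npz')
    test_file = posixpath.join(channel_input_dirs['test'], 'test.npz')
    return epochs, batch_size, train_file, test_file


dirs = {'train': '/opt/ml/input/data/train', 'test': '/opt/ml/input/data/test'}
print(resolve_training_inputs({'epochs': 50, 'batch_size': 64}, dirs))
# (50, 64, '/opt/ml/input/data/train/train.npz', '/opt/ml/input/data/test/test.npz')
```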

The script below uses `chainer` to train on any number of GPUs on a single machine; read through it to see how this works. For more on implementing these functions, see the documentation at https://github.com/aws/sagemaker-python-sdk.

In [4]:
!cat 'code/chainer_cifar_vgg_single_machine.py'

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from __future__ import print_function, absolute_import

import os

import numpy as np

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer import serializers
from chainer.training import extensions

import net


def train(hyperparameters, num_gpus, output_data_dir, channel_input_dirs):
    """
    This function is called by the Chainer cont

## Running the training script on SageMaker

The `Chainer` class allows us to run our training function as a training job on SageMaker infrastructure. We configure it with our training script, an IAM role, the number of training instances, and the training instance type. In this case we run our training job on a single `ml.p3.2xlarge` instance.

Chainer scripts can distribute training across instances with the `chainermn` package, which distributes training with MPI and runs the script with `mpirun`. This script does not use `chainermn`, so it should only be run on one instance.
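With 50000 training images, a `batch_size` of 64, and 50 epochs, the job runs 50000 / 64 = 781.25 iterations per epoch. The sketch below turns that into a total iteration count and a rough wall-clock estimate. The rate of 26 iterations per second is an assumption taken from the rate this job's training log reports on one `ml.p3.2xlarge`; it will differ on other instance types and network definitions.

```python
import math


def estimate_training(num_examples, batch_size, epochs, iters_per_sec):
    """Iterations per epoch, total iterations, and seconds for an epoch-based run."""
    iters_per_epoch = num_examples / batch_size
    total_iters = math.ceil(iters_per_epoch * epochs)
    seconds = total_iters / iters_per_sec
    return iters_per_epoch, total_iters, seconds


per_epoch, total, seconds = estimate_training(50000, 64, 50, 26.0)
print(per_epoch)                # 781.25
print(total)                    # 39063
print(round(seconds / 60, 1))   # 25.0
```

This is why the progress bar in the log reports fractional positions within an epoch: the iterator keeps drawing batches across epoch boundaries, so an epoch ends partway through a batch.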

In [5]:
from sagemaker.chainer.estimator import Chainer

chainer_estimator = Chainer(entry_point='chainer_cifar_vgg_single_machine.py', source_dir="code", role=role,
                            sagemaker_session=sagemaker_session,
                            train_instance_count=1, train_instance_type='ml.p3.2xlarge',
                            hyperparameters={'epochs': 50, 'batch_size': 64})

chainer_estimator.fit({'train': train_input, 'test': test_input})

INFO:sagemaker:Creating training-job with name: sagemaker-chainer-2018-05-04-20-05-57-636


...................................................
2018-05-04 20:10:03,802 INFO - root - running container entrypoint
2018-05-04 20:10:03,802 INFO - root - starting train task
2018-05-04 20:10:03,814 INFO - container_support.app - started training: {'train_fn': <function train at 0x7f95f865abf8>}
Downloading s3://sagemaker-us-west-2-038453126632/sagemaker-chainer-2018-05-04-20-05-57-636/source/sourcedir.tar.gz to /tmp/script.tar.gz
2018-05-04 20:10:03,970 INFO - botocore.vendored.requests.packages.urllib3.connectionpool - Starting new HTTP connection (1): 169.254.170.2
2018-05-04 20:10:04,055 INFO - botocore.vendored.requests.packages.urllib3.connectionpool - Starting new HTTPS connection (1): sagemaker-us-west-2-038453126632.s3.amazonaws.com
2018-05-04 20:10:04,091 INFO - botocore.vendored.requests.packages.urllib3.connectionpool - Starting new HTTPS connection (2): sagemaker-us-west-2-038453126632.s3.amazonaws.com
...
     total [##................................................]  5.89%
this epoch [###############################################...] 94.40%
      2300 iter, 2 epoch / 50 epochs
    26.406 iters/sec. Estimated time to finish: 0:23:12.182037.
3           1.45249     1.33205               0.456486       0.517815                  106.216
...
7           0.881214    0.845455              0.705466       0.720541                  227.037
...
11          0.720952    0.878712              0.767646       0.712182                  347.608
...
18          0.625197    0.867325              0.801436       0.728603                  558.752
...
22          0.598742    0.611572              0.810159       0.797373                  678.981
...
33          0.409315    0.528008              0.868246       0.828822                  1009.78
...
37          0.398674    0.454225              0.872502       0.852707                  1130.43
...
41          0.388386    0.482514              0.875999       0.847134                  1251.06
...
     total [################################################..] 97.02%
this epoch [#########################.........................] 51.20%
     37900 iter, 48 epoch / 50 epochs
    25.716 iters/sec. Estimated time to finish: 0:00:45.206018.

Our Chainer script writes various artifacts, such as plots, to the directory `output_data_dir`, whose contents SageMaker uploads to S3. Now we download and extract these artifacts.

In [6]:
from s3_util import retrieve_output_from_s3

chainer_training_job = chainer_estimator.latest_training_job.name

desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=chainer_training_job)
output_data = desc['ModelArtifacts']['S3ModelArtifacts'].replace('model.tar.gz', 'output.tar.gz')

retrieve_output_from_s3(output_data, 'output/single_machine_cifar')
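`S3ModelArtifacts` points at the job's `model.tar.gz`, and the cell above relies on `output.tar.gz` sitting next to it under the same prefix. Note that `str.replace` rewrites every occurrence of `model.tar.gz` in the URI, not just the file name. The sketch below replaces only the trailing file name and splits the result into the bucket and key that S3 download calls take. These helpers are hypothetical and are not the contents of the `s3_util` module, which is not shown; the example URI is made up.

```python
def split_s3_uri(uri):
    """Split 's3://bucket/key' into (bucket, key)."""
    prefix = 's3://'
    if not uri.startswith(prefix):
        raise ValueError('not an S3 URI: {}'.format(uri))
    bucket, _, key = uri[len(prefix):].partition('/')
    return bucket, key


def output_artifact_uri(model_artifact_uri):
    """Swap only the trailing model.tar.gz for output.tar.gz."""
    suffix = 'model.tar.gz'
    if not model_artifact_uri.endswith(suffix):
        raise ValueError('expected a URI ending in {}'.format(suffix))
    return model_artifact_uri[:-len(suffix)] + 'output.tar.gz'


uri = output_artifact_uri('s3://my-bucket/my-training-job/output/model.tar.gz')
print(uri)                # s3://my-bucket/my-training-job/output/output.tar.gz
print(split_s3_uri(uri))  # ('my-bucket', 'my-training-job/output/output.tar.gz')
```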

In [7]:
# Executing as code to reload images so that browsers don't render cached images.
from IPython.display import Markdown
import time
_nonce = time.time()

Markdown("""
These plots show the accuracy and loss over epochs:

<img style="display: inline;" src="output/single_machine_cifar/accuracy.png?{0}" />
<img style="display: inline;" src="output/single_machine_cifar/loss.png?{0}" />""".format(_nonce))



These plots show the accuracy and loss over epochs:

<img style="display: inline;" src="output/single_machine_cifar/accuracy.png?1525466159.2398272" />
<img style="display: inline;" src="output/single_machine_cifar/loss.png?1525466159.2398272" />

## Deploying the Trained Model

After training, we use the Chainer estimator object to create and deploy a hosted prediction endpoint. We can use a CPU-based instance for inference (in this case an `ml.m4.xlarge`), even though we trained on a GPU instance.

The predictor object returned by `deploy` lets us call the new endpoint and perform inference on our sample images. 

In [8]:
predictor = chainer_estimator.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')

INFO:sagemaker:Creating model with name: sagemaker-chainer-2018-05-04-20-05-57-636
INFO:sagemaker:Creating endpoint with name sagemaker-chainer-2018-05-04-20-05-57-636


------------------------------------------------------------!

### CIFAR-10 sample images

We'll use these CIFAR-10 sample images to test the service:

<img style="display: inline; height: 32px; margin: 0.25em" src="images/airplane1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/automobile1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/bird1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/cat1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/deer1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/dog1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/frog1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/horse1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/ship1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/truck1.png" />



## Predicting using SageMaker Endpoint

We batch the images together into a single NumPy array to obtain multiple inferences with a single prediction request.
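A batch is one array of shape `(N, channels, height, width)`, so ten 3x32x32 images become a single `(10, 3, 32, 32)` array. The sketch below builds such a batch from stand-in images and round-trips it through NumPy's `.npy` byte format, assuming that is how the Chainer predictor serializes requests (the SDK's Chainer predictor defaulted to NPY serialization at the time; check your SDK version). `np.stack` is used here because it raises an error if the images' shapes differ.

```python
import io

import numpy as np

# Ten stand-in images in the layout read_image() produces: (3, 32, 32) float32.
images = [np.full((3, 32, 32), i / 10.0, dtype=np.float32) for i in range(10)]

# One request carries the whole batch: shape (N, channels, height, width).
batch = np.stack(images)

# Serialize to .npy bytes, as a request body might be encoded.
buffer = io.BytesIO()
np.save(buffer, batch)
payload = buffer.getvalue()

# The receiving side restores the same array from those bytes.
restored = np.load(io.BytesIO(payload))
print(batch.shape)                      # (10, 3, 32, 32)
print(np.array_equal(batch, restored))  # True
```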

In [9]:
from skimage import io
import numpy as np

def read_image(filename):
    img = io.imread(filename)[:, :, :3]   # keep RGB, drop an alpha channel if present
    img = img.transpose(2, 0, 1)          # HWC -> CHW, the layout used in training
    img = img.astype(np.float32) / 255.   # scale to [0, 1], like get_cifar10
    return img


def read_images(filenames):
    return np.array([read_image(f) for f in filenames])

filenames = ['images/airplane1.png',
             'images/automobile1.png',
             'images/bird1.png',
             'images/cat1.png',
             'images/deer1.png',
             'images/dog1.png',
             'images/frog1.png',
             'images/horse1.png',
             'images/ship1.png',
             'images/truck1.png']

image_data = read_images(filenames)

The predictor runs inference on our input data and returns, for each image, an array of class scores; the argmax of each array is the predicted label.

In [10]:
response = predictor.predict(image_data)

for i, prediction in enumerate(response):
    print('image {}: prediction: {}'.format(i, prediction.argmax(axis=0)))

image 0: prediction: 0
image 1: prediction: 1
image 2: prediction: 2
image 3: prediction: 3
image 4: prediction: 4
image 5: prediction: 5
image 6: prediction: 6
image 7: prediction: 7
image 8: prediction: 8
image 9: prediction: 9
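The printed values are class indices. CIFAR-10 numbers its classes in the order airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck, which is also the order of `filenames`, so prediction `i` for image `i` means every sample was classified correctly. The sketch below maps rows of scores to class names; the score values are made up for illustration, whereas real rows come from `predictor.predict`.

```python
# CIFAR-10 class names in label-index order (0 = airplane, ..., 9 = truck).
CIFAR10_CLASSES = ('airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck')


def argmax(scores):
    """Index of the largest score (the first one wins on ties)."""
    return max(range(len(scores)), key=lambda i: scores[i])


def label_names(batch_scores):
    """Map each row of per-class scores to its CIFAR-10 class name."""
    return [CIFAR10_CLASSES[argmax(row)] for row in batch_scores]


# Made-up scores for two images.
scores = [[0.1, 2.3, 0.0, -1.0, 0.2, 0.3, 0.1, 0.0, 0.4, 0.9],
          [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 3.0, 0.1]]
print(label_names(scores))  # ['automobile', 'ship']
```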


## Cleanup

After you have finished with this example, remember to delete the prediction endpoint to release the instance(s) associated with it.

In [11]:
sagemaker.Session().delete_endpoint(predictor.endpoint)

INFO:sagemaker:Deleting endpoint with name: sagemaker-chainer-2018-05-04-20-05-57-636
