In [None]:
import mxnet as mx
import sagemaker
from sagemaker.mxnet import MXNet as MXNetEstimator

In [None]:
# Fetch the CIFAR-10 dataset; the MXNet helper places it under ./data.
mx.test_utils.get_cifar10()

# Push the local data/cifar directory through a SageMaker session, stored
# under the 'data/cifar10' key prefix. `inputs` holds whatever upload_data
# returns and is handed to estimator.fit() later in the notebook.
sagemaker_session = sagemaker.Session()
inputs = sagemaker_session.upload_data(
    key_prefix='data/cifar10',
    path='data/cifar',
)

In [None]:
# Describe the training job: the multiple_gpus_sagemaker.py script runs on a
# single ml.p2.8xlarge instance, under the notebook's execution role, and
# receives the batch size and epoch count as hyperparameters.
estimator = MXNetEstimator(
    role=sagemaker.get_execution_role(),
    entry_point='multiple_gpus_sagemaker.py',
    train_instance_type='ml.p2.8xlarge',
    train_instance_count=1,
    hyperparameters={
        'batch_size': 512,
        'epochs': 30,
    },
)

# Launch training against the previously uploaded dataset.
estimator.fit(inputs)

In [None]:
# Host the trained model on one ml.m4.xlarge instance; the returned predictor
# is used below to send images for classification.
predictor = estimator.deploy(
    instance_type='ml.m4.xlarge',
    initial_instance_count=1,
)

In [None]:
from skimage import io
import numpy as np

def read_image(filename):
    """Load one image file and return it as a 4-D array with a leading axis of 1.

    The file is decoded with ``skimage.io.imread``. The decoded array must be
    3-dimensional; its last axis is moved to the front (axes ``(2, 0, 1)``)
    and a new axis of length 1 is prepended.

    NOTE(review): presumably the decoded layout is (height, width, channels),
    making the result (1, channels, height, width) -- confirm for the images
    actually used.

    :param filename: path of the image file to read.
    :return: ``numpy.ndarray`` of shape ``(1, d2, d0, d1)`` where
        ``(d0, d1, d2)`` is the shape of the decoded image.
    """
    pixels = np.array(io.imread(filename))
    # Bring the last axis to the front.
    reordered = np.transpose(pixels, (2, 0, 1))
    # Prepend a singleton axis so the result is a batch of one.
    return reordered[np.newaxis]


def read_images(filenames):
    """Load every file in ``filenames`` with ``read_image``.

    :param filenames: iterable of image file paths.
    :return: list with one ``read_image`` result per path, in input order.
    """
    loaded = []
    for path in filenames:
        loaded.append(read_image(path))
    return loaded

In [None]:
# Class index -> human-readable CIFAR-10 class name, used to decode the
# integer returned for each prediction.
# NOTE(review): presumably this ordering matches the label ids the model was
# trained with -- confirm against the training script.
classes_map = dict(enumerate([
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]))

# One sample image per class, listed in the same order as classes_map.
filenames = [
    'images/airplane1.png',
    'images/automobile1.png',
    'images/bird1.png',
    'images/cat1.png',
    'images/deer1.png',
    'images/dog1.png',
    'images/frog1.png',
    'images/horse1.png',
    'images/ship1.png',
    'images/truck1.png',
]

# Load all sample images into arrays ready to send to the endpoint.
image_data = read_images(filenames)

<img style="display: inline; height: 32px; margin: 0.25em" src="images/airplane1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/automobile1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/bird1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/cat1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/deer1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/dog1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/frog1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/horse1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/ship1.png" />
<img style="display: inline; height: 32px; margin: 0.25em" src="images/truck1.png" />

In [None]:
# Classify each sample image through the endpoint and print the decoded class.
# NOTE(review): int(...) assumes the endpoint answers with a single
# scalar-like class index -- confirm against the entry-point script.
for index, image in enumerate(image_data):
    predicted_class = classes_map[int(predictor.predict(image))]
    print('image {}: class: {}'.format(index, predicted_class))

In [None]:
# Tear down the hosted endpoint so it stops accruing charges. This goes
# through the predictor rather than estimator.delete_endpoint(): the estimator
# method is deprecated in SageMaker Python SDK v1 and removed in v2, while the
# predictor method targets the endpoint this predictor is bound to.
predictor.delete_endpoint()