# MXNet MNIST Bring Your Own Model (BYOM): Train Locally, Deploy on SageMaker

In this notebook, we train a model locally on the notebook instance, then deploy it to Amazon SageMaker and request predictions from it. The same approach works for a model trained anywhere else: all that is needed is the exported model files and the entry point script containing the model definitions.

First, download the MNIST data using the MXNet utilities.

In [None]:
import mxnet as mx
data = mx.test_utils.get_mnist()

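Train a typical LeNet model with MXNet. The ``train`` function comes from the local ``mnist.py`` script, which is also used as the entry point when deploying.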

In [None]:
from mnist import train

# Train LeNet on the downloaded data; train() is defined in the local mnist.py script.
model = train(data=data)
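
The network definition in ``mnist.py`` is not shown here. Assuming it follows the classic MXNet LeNet layout (two 5x5 convolutions with 20 and 50 filters, each followed by 2x2 max pooling with stride 2, then fully connected layers of 500 and 10 units), the side length of each feature map follows from out = (in + 2*pad - kernel) // stride + 1. The sketch below traces those shapes; the layer list is an assumption, not read from ``mnist.py``.

```python
# Sketch: trace feature-map shapes through an assumed classic LeNet.
# The layer configuration is an assumption; mnist.py may differ.

def conv_out(size, kernel, stride=1, pad=0):
    """Output side length of a square convolution or pooling window."""
    return (size + 2 * pad - kernel) // stride + 1

def lenet_shapes(size=28):
    """Return (stage name, channels, side length) for each stage."""
    shapes = [("input", 1, size)]
    size = conv_out(size, kernel=5)            # conv1: 5x5, 20 filters
    shapes.append(("conv1", 20, size))
    size = conv_out(size, kernel=2, stride=2)  # pool1: 2x2 max, stride 2
    shapes.append(("pool1", 20, size))
    size = conv_out(size, kernel=5)            # conv2: 5x5, 50 filters
    shapes.append(("conv2", 50, size))
    size = conv_out(size, kernel=2, stride=2)  # pool2: 2x2 max, stride 2
    shapes.append(("pool2", 50, size))
    return shapes

for name, channels, side in lenet_shapes():
    print(f"{name:6s} {channels:3d} x {side} x {side}")

# Flattened input size of the first fully connected layer (500 units).
_, c, s = lenet_shapes()[-1]
print("flatten:", c * s * s)
```

Under these assumptions a 28x28 image shrinks to 24, 12, 8 and finally 4 pixels per side, so the first fully connected layer sees 50 * 4 * 4 = 800 inputs.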

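Export the model and save it. As in the TensorFlow example, the exported files must follow a particular structure: ``save_checkpoint`` writes the network definition and the trained weights into the ``model`` directory, which is then packaged into a ``model.tar.gz`` archive, the format SageMaker expects for model artifacts.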

In [None]:
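import os
import tarfile

# Write the checkpoint: model/model-symbol.json (network) and model/model-0000.params (weights).
os.makedirs('model', exist_ok=True)
model.save_checkpoint('model/model', 0)

# Package the model directory into the archive SageMaker will download.
with tarfile.open('model.tar.gz', mode='w:gz') as archive:
    archive.add('model', recursive=True)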
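
``save_checkpoint`` follows MXNet's checkpoint naming convention, ``<prefix>-symbol.json`` for the network and ``<prefix>-<epoch as 4 digits>.params`` for the weights. Because ``archive.add('model', ...)`` keeps the directory name, those files end up under a ``model/`` folder inside the tarball rather than at its root. The standard-library sketch below uses two hypothetical helpers, ``expected_checkpoint_files`` and ``list_archive``, to rebuild that layout with placeholder files and list the archive members, so you can check the layout your ``mnist.py`` expects before uploading.

```python
import os
import tarfile
import tempfile

def expected_checkpoint_files(prefix, epoch):
    """MXNet checkpoint naming: '<prefix>-symbol.json' and '<prefix>-%04d.params'."""
    return [f"{prefix}-symbol.json", f"{prefix}-{epoch:04d}.params"]

def list_archive(path):
    """Return the sorted names of the regular files stored in a .tar.gz."""
    with tarfile.open(path, mode="r:gz") as archive:
        return sorted(m.name for m in archive.getmembers() if m.isfile())

# Self-contained demo: create placeholder files in a temporary directory,
# package them the same way as the notebook cell, then inspect the result.
with tempfile.TemporaryDirectory() as tmp:
    model_dir = os.path.join(tmp, "model")
    os.makedirs(model_dir)
    for name in expected_checkpoint_files("model", 0):
        with open(os.path.join(model_dir, name), "w") as f:
            f.write("placeholder")
    tar_path = os.path.join(tmp, "model.tar.gz")
    with tarfile.open(tar_path, mode="w:gz") as archive:
        # arcname="model" mirrors archive.add('model', recursive=True) run from tmp.
        archive.add(model_dir, arcname="model", recursive=True)
    members = list_archive(tar_path)

print(members)  # ['model/model-0000.params', 'model/model-symbol.json']
```

In the notebook itself, ``list_archive('model.tar.gz')`` shows the real archive contents.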

Open a SageMaker session and upload the model archive to the session's default S3 bucket.

In [None]:
import sagemaker

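sagemaker_session = sagemaker.Session()
# upload_data returns the S3 URI of the uploaded archive, e.g. s3://<default bucket>/model/model.tar.gz
inputs = sagemaker_session.upload_data(path='model.tar.gz', key_prefix='model')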
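
When a single file is uploaded, its S3 key is ``key_prefix`` joined with the file name, so the URI has the form ``s3://<bucket>/model/model.tar.gz``. The hypothetical helpers below (standard library only, with a placeholder bucket name) build and split such URIs, which is a quick way to confirm that the string passed as ``model_data`` names the object that was actually uploaded.

```python
from urllib.parse import urlparse

def s3_uri(bucket, *parts):
    """Join a bucket and key parts into an s3:// URI, dropping stray slashes."""
    key = "/".join(p.strip("/") for p in parts if p.strip("/"))
    return f"s3://{bucket}/{key}"

def split_s3_uri(uri):
    """Split 's3://bucket/key' into (bucket, key)."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"not an S3 URI: {uri!r}")
    return parsed.netloc, parsed.path.lstrip("/")

# Placeholder bucket name for illustration; not a real bucket.
uri = s3_uri("my-example-bucket", "model", "model.tar.gz")
print(uri)                # s3://my-example-bucket/model/model.tar.gz
print(split_s3_uri(uri))  # ('my-example-bucket', 'model/model.tar.gz')
```

In the notebook, ``split_s3_uri(inputs)`` shows which bucket and key the archive landed in.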

Use the ``sagemaker.mxnet.model.MXNetModel`` to create a new model that can be deployed.

In [None]:
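from sagemaker import get_execution_role
from sagemaker.mxnet.model import MXNetModel

# IAM role SageMaker uses to read the model from S3 and host it; on a notebook instance this is
# the instance's own role. Outside SageMaker, pass an explicit role ARN instead.
sagemaker_model = MXNetModel(model_data=inputs,  # S3 URI returned by upload_data
                             role=get_execution_role(),
                             entry_point='mnist.py')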
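
At hosting time, the SageMaker MXNet container extracts ``model.tar.gz`` into the model directory (``/opt/ml/model``) and calls the ``model_fn(model_dir)`` function from the ``entry_point`` script to load the network. Since the archive built earlier nests the checkpoint under ``model/``, that loader has to look in the right place. The function below is a hypothetical, standard-library-only sketch of the file-locating step such a ``model_fn`` might perform; the MXNet loading call itself (for example ``mx.model.load_checkpoint(prefix, epoch)``) appears only as a comment so the example runs without MXNet installed.

```python
import os
import tempfile

def find_checkpoint_prefix(model_dir, epoch=0):
    """Return the prefix P for which P-symbol.json and P-%04d.params both exist.

    Searches model_dir and its subdirectories, so it works whether the
    checkpoint sits at the archive root or under a nested 'model/' folder.
    """
    params_suffix = f"-{epoch:04d}.params"
    for root, _dirs, files in os.walk(model_dir):
        for name in sorted(files):
            if name.endswith("-symbol.json"):
                prefix = os.path.join(root, name[: -len("-symbol.json")])
                if os.path.exists(prefix + params_suffix):
                    return prefix
    raise FileNotFoundError(f"no MXNet checkpoint for epoch {epoch} under {model_dir}")

# Demo: mimic the extracted archive layout <model_dir>/model/model-*.
with tempfile.TemporaryDirectory() as model_dir:
    nested = os.path.join(model_dir, "model")
    os.makedirs(nested)
    for name in ("model-symbol.json", "model-0000.params"):
        open(os.path.join(nested, name), "w").close()
    prefix = find_checkpoint_prefix(model_dir)
    print(os.path.relpath(prefix, model_dir))  # model/model
    # Inside a real model_fn, the next step would be something like:
    # sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, 0)
```

Searching recursively instead of hard-coding a path keeps the loader working if the archive is later rebuilt with the checkpoint at its root.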

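Deploy the model to a real-time endpoint backed by a single ``ml.c4.xlarge`` instance. This step can take several minutes.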

In [None]:
predictor = sagemaker_model.deploy(initial_instance_count=1,
                                   instance_type='ml.c4.xlarge')

We can now use this predictor to classify hand-written digits.

In [None]:
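# data['test_data'] has shape (N, 1, 28, 28); take the 28x28 image of the first test sample.
predict_sample = data['test_data'][0][0]
# Send this single image to the endpoint for classification.
response = predictor.predict(predict_sample)
print('Raw prediction result:')
print(response)
print('True label:', data['test_label'][0])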
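
The exact shape of ``response`` depends on the output handling in ``mnist.py``, which is not shown. Assuming it returns 10 class scores (one per digit 0 to 9), possibly wrapped in an outer batch list, the predicted digit is the index of the largest score. The ``predicted_digit`` helper below is hypothetical, and the scores in the example are made up for illustration.

```python
def predicted_digit(response):
    """Return the index of the highest score in a 10-way prediction.

    Accepts either a flat list of scores or a batch of one, e.g. [[...]].
    """
    scores = response[0] if response and isinstance(response[0], list) else response
    if len(scores) != 10:
        raise ValueError(f"expected 10 class scores, got {len(scores)}")
    return max(range(len(scores)), key=lambda i: scores[i])

# Made-up scores for illustration only (not real model output).
example = [[0.01, 0.0, 0.02, 0.0, 0.0, 0.01, 0.0, 0.93, 0.01, 0.02]]
print(predicted_digit(example))  # 7
```

Comparing ``predicted_digit(response)`` with ``data['test_label'][0]`` gives a quick sanity check of the deployed model.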

(Optional) Clean up: delete the endpoint, which is billed for as long as it runs, and remove the local model files.

In [None]:
print(predictor.endpoint)

In [None]:
# Reuse the session opened earlier to delete the hosted endpoint.
sagemaker_session.delete_endpoint(predictor.endpoint)

In [None]:
import shutil

# Remove the local archive and the checkpoint directory created earlier.
os.remove('model.tar.gz')
shutil.rmtree('model')
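
Cleanup cells like this one raise an error when a path is already gone, for example when the cell is run a second time. A tolerant variant, sketched below with a hypothetical ``cleanup`` helper, skips anything that does not exist. It only touches local paths; the archive uploaded to S3 is left in place.

```python
import os
import shutil
import tempfile

def cleanup(*paths):
    """Remove files or directory trees, ignoring paths that do not exist.

    Returns the list of paths that were actually removed.
    """
    removed = []
    for path in paths:
        if os.path.isdir(path):
            shutil.rmtree(path)
            removed.append(path)
        elif os.path.exists(path):
            os.remove(path)
            removed.append(path)
    return removed

# Demo in a temporary directory; in the notebook you would call
# cleanup('model.tar.gz', 'model') from the working directory.
with tempfile.TemporaryDirectory() as tmp:
    tar_path = os.path.join(tmp, "model.tar.gz")
    model_dir = os.path.join(tmp, "model")
    open(tar_path, "w").close()
    os.makedirs(model_dir)
    first = cleanup(tar_path, model_dir)
    second = cleanup(tar_path, model_dir)  # second call is a no-op
    print(len(first), len(second))  # 2 0
```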