# Training and hosting SageMaker Models using the Apache MXNet Module API

The **SageMaker Python SDK** makes it easy to train and deploy MXNet models. In this example, we train a simple neural network using the Apache MXNet [Module API](https://mxnet.incubator.apache.org/api/python/module.html) and the MNIST dataset. The MNIST dataset is widely used for handwritten digit classification and consists of 70,000 labeled 28x28 pixel grayscale images of handwritten digits. The dataset is split into 60,000 training images and 10,000 test images. There are 10 classes (one for each of the 10 digits). The task is to train a model on the 60,000 training images and then test its classification accuracy on the 10,000 test images.

### Setup

First, we create the AWS clients and read the SageBuild stack outputs that are needed later in the example.

In [None]:
import json
import os
import tarfile
import time

import boto3
from botocore.exceptions import ClientError
from sagemaker import get_execution_role

# One client per AWS service used in this notebook.
cf = boto3.client('cloudformation')
s3 = boto3.client('s3')
sns = boto3.client('sns')
step = boto3.client('stepfunctions')
# Note: this name holds the SageMaker runtime client, not the sagemaker SDK module.
sagemaker = boto3.client('sagemaker-runtime')
ssm = boto3.client('ssm')

with open('../SageBuild/config.json') as json_file:  
    config = json.load(json_file)
StackName=config["StackName"]

result=cf.describe_stacks(
    StackName=StackName
)
outputs={}
for output in result['Stacks'][0]['Outputs']:
    outputs[output['OutputKey']]=output['OutputValue']
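
The loop above flattens the stack's output list into a plain dictionary. As a minimal sketch, the same step can be wrapped in a small helper; `stack_outputs` is a hypothetical name, not part of SageBuild or boto3, and the sample response below is illustrative data shaped like the relevant part of a `describe_stacks` result. The helper assumes the `Outputs` key may be missing when a stack defines no outputs.

```python
def stack_outputs(describe_stacks_response):
    """Return {OutputKey: OutputValue} for the first stack in the response."""
    stack = describe_stacks_response["Stacks"][0]
    # Default to an empty list in case the stack has no "Outputs" entry.
    return {o["OutputKey"]: o["OutputValue"] for o in stack.get("Outputs", [])}


# Illustrative data only; real values come from cf.describe_stacks(StackName=...).
sample_response = {
    "Stacks": [
        {
            "Outputs": [
                {"OutputKey": "CodeBucket", "OutputValue": "example-code-bucket"},
                {"OutputKey": "DataBucket", "OutputValue": "example-data-bucket"},
            ]
        }
    ]
}
sample_outputs = stack_outputs(sample_response)
```

With the real response, `outputs = stack_outputs(result)` would produce the same dictionary as the loop.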

We need to make sure the SageBuild template is configured correctly for MXNet. The following code sets the stack's ``ConfigFramework`` parameter to ``BYOD`` and waits for the stack update to finish.

In [None]:
# Start from the stack's current parameters and switch the framework setting.
parameters=result["Stacks"][0]["Parameters"]
for parameter in parameters:
    if parameter["ParameterKey"]=="ConfigFramework":
        parameter["ParameterValue"]="BYOD"

try:
    cf.update_stack(
        StackName=StackName,
        UsePreviousTemplate=True,
        Parameters=parameters,
        Capabilities=[
            'CAPABILITY_NAMED_IAM',
        ]
    )
    waiter = cf.get_waiter('stack_update_complete')
    print("Waiting for stack update")
    waiter.wait(
        StackName=StackName,
        WaiterConfig={
            'Delay':10,
            'MaxAttempts':600
        }
    )

except ClientError as e:
    # CloudFormation reports an unchanged stack as an error; treat that case as success.
    if e.response["Error"]["Message"]=="No updates are to be performed.":
        pass
    else:
        raise
print("stack ready!")
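
The parameter loop above does nothing if the stack has no ``ConfigFramework`` parameter, so a misspelled key would go unnoticed. The sketch below shows one way to make that case explicit. `override_parameter` is a hypothetical helper introduced here for illustration; it works on a copy of the list rather than editing it in place.

```python
def override_parameter(parameters, key, value):
    """Return a copy of `parameters` with the entry named `key` set to `value`.

    Raises KeyError if no entry has that ParameterKey.
    """
    updated = []
    found = False
    for parameter in parameters:
        parameter = dict(parameter)  # copy so the caller's list is left untouched
        if parameter["ParameterKey"] == key:
            parameter["ParameterValue"] = value
            found = True
        updated.append(parameter)
    if not found:
        raise KeyError(key)
    return updated


# Illustrative data shaped like result["Stacks"][0]["Parameters"].
current = [
    {"ParameterKey": "ConfigFramework", "ParameterValue": "MXNET"},
    {"ParameterKey": "NoteBookInstanceType", "ParameterValue": "ml.t2.medium"},
]
desired = override_parameter(current, "ConfigFramework", "BYOD")
```

The resulting list can be passed as the `Parameters` argument of `cf.update_stack` in the same way as the list built by the loop.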

### Update SageBuild Parameters

In [None]:
store=outputs["ParameterStore"]
result=ssm.get_parameter(Name=store)

params=json.loads(result["Parameter"]["Value"])
params["traininstancetype"]="ml.p3.2xlarge"
params["trainvolumesize"]="1000"
params["channels"]={
    "train":{
        "path":"test"
    }
}
params["train"]=False
params["build"]={
    "Inference":False,
    "Training":False
}
params["hyperparameters"]={
    "sagemaker_container_log_level":"200",
    "sagemaker_enable_cloudwatch_metrics":"false",
    "sagemaker_job_name":'"{}-{}"'.format(StackName,"v1"),
    "sagemaker_program":'"train.py"',
    "sagemaker_region":'"{}"'.format(config["Region"]),
    "sagemaker_submit_directory":'"s3://{}/gan.tar.gz"'.format(outputs["CodeBucket"]),
}
params["modelhostingenvironment"]={
    "SAGEMAKER_CONTAINER_LOG_LEVEL":"200",
    "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS":"false",
    "SAGEMAKER_PROGRAM":"host.py",
    "SAGEMAKER_REGION":config["Region"],
    "SAGEMAKER_SUBMIT_DIRECTORY":"s3://{}/gan.tar.gz".format(outputs["CodeBucket"]),
    "UPLOAD_BUCKET":outputs["DataBucket"]
}
params["dockerfile_path_Training"]="gpu"
params["dockerfile_path_Inference"]="cpu"
params["pyversion"]="py2"

ssm.put_parameter(
    Name=store,
    Type="String",
    Overwrite=True,
    Value=json.dumps(params)
)
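
The hyperparameter values above are strings that contain their own double quotes, which suggests each value is stored in JSON-encoded form. Assuming that is the intent, `json.dumps` produces the same strings without hand-written quoting. `encode_hyperparameters` is a hypothetical helper, and the stack name, region, and bucket below are placeholder values.

```python
import json


def encode_hyperparameters(values):
    """JSON-encode each value, so strings end up wrapped in double quotes."""
    return {name: json.dumps(value) for name, value in values.items()}


encoded = encode_hyperparameters({
    "sagemaker_container_log_level": 200,
    "sagemaker_enable_cloudwatch_metrics": False,
    "sagemaker_job_name": "example-stack-v1",    # placeholder job name
    "sagemaker_program": "train.py",
    "sagemaker_region": "us-east-1",             # placeholder region
    "sagemaker_submit_directory": "s3://example-code-bucket/gan.tar.gz",  # placeholder bucket
})
# encoded["sagemaker_program"] is the 10-character string "train.py" including its quotes,
# and the numeric and boolean values become "200" and "false".
```

This keeps the quoting rule in one place, which helps when a value contains characters that would need escaping.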

## The training script

The ``mnist.py`` script provides all the code we need for training and hosting a SageMaker model. The script we will use is adapted from the Apache MXNet [MNIST tutorial](https://mxnet.incubator.apache.org/tutorials/python/mnist.html).
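
The parameters set earlier point ``sagemaker_submit_directory`` at a ``gan.tar.gz`` archive in the code bucket, but this notebook does not show how that archive is produced. The sketch below is one possible way to build it with the standard library. `package_source` is a hypothetical helper, and placing the scripts at the top level of the archive is an assumption about what the container expects.

```python
import os
import tarfile


def package_source(source_dir, archive_path):
    """Write every entry under `source_dir` into a gzip-compressed tar archive."""
    with tarfile.open(archive_path, "w:gz") as archive:
        for name in sorted(os.listdir(source_dir)):
            # arcname drops the local directory prefix so files sit at the archive root.
            archive.add(os.path.join(source_dir, name), arcname=name)
    return archive_path


# Possible usage, assuming a local "src" directory holding train.py and host.py:
# archive_path = package_source("src", "gan.tar.gz")
# s3.upload_file(archive_path, outputs["CodeBucket"], "gan.tar.gz")
```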

### Start Train/Deploy pipeline

In [None]:
%%time
result=sns.publish(
    TopicArn=outputs['LaunchTopic'],
    Message="{}" # the message content is ignored; publishing to the topic starts the build
)
print(result)
# give the state machine a moment to start the new execution
time.sleep(5)
# list the running executions of our state machine to find the one just started
result=step.list_executions(
    stateMachineArn=outputs['StateMachine'],
    statusFilter="RUNNING"
)['executions']

if len(result) > 0:
    response = step.describe_execution(
        executionArn=result[0]['executionArn']
    )
    status=response['status']
    print(status,response['name'])
    #poll status till execution finishes
    while status == "RUNNING":
        print('.',end="")
        time.sleep(5)
        status=step.describe_execution(executionArn=result[0]['executionArn'])['status']
    print()
    print(status)
else:
    print("no running tasks")
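
The polling loop above runs until the execution leaves the ``RUNNING`` state, with no upper bound on how long it waits. A minimal sketch of the same loop with a timeout follows. `wait_for_execution` is a hypothetical helper; it only relies on the client exposing `describe_execution(executionArn=...)`, as the Step Functions client used above does, and the default timeout of one hour is an arbitrary choice.

```python
import time


def wait_for_execution(client, execution_arn, poll_seconds=5, timeout_seconds=3600):
    """Poll an execution until it stops running and return its final status.

    Raises TimeoutError if it is still RUNNING once `timeout_seconds` have passed.
    """
    deadline = time.monotonic() + timeout_seconds
    while True:
        status = client.describe_execution(executionArn=execution_arn)["status"]
        if status != "RUNNING":
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError("execution still running: {}".format(execution_arn))
        time.sleep(poll_seconds)


# Possible usage with the objects defined in this notebook:
# final_status = wait_for_execution(step, result[0]["executionArn"])
```

Returning the final status lets the caller decide how to react to ``FAILED`` or ``ABORTED`` instead of only printing it.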

### Making an inference request

The request handling behavior of the Endpoint is determined by the ``mnist.py`` script. In this case, the script doesn't include any request handling functions, so the Endpoint will use the default handlers provided by SageMaker. These default handlers allow us to perform inference on input data encoded as a multi-dimensional JSON array.

To see inference in action, draw a digit in the image box below. The pixel data from your drawing will be loaded into a ``data`` variable in this notebook. 

*Note: after drawing the image, you'll need to move to the next notebook cell.*

In [None]:
from IPython.display import Image

result=sagemaker.invoke_endpoint(
    EndpointName=outputs["SageMakerEndpoint"],
    Body=json.dumps({
        "row":9,
        "path":12
    }),
    ContentType="application/json",
    Accept="application/json"
)

# the response body is a stream; read it once and parse the JSON payload
body=json.loads(result["Body"].read())
Image(url=body["url"])
