Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
A copy of the License is located at
 
  http://aws.amazon.com/apache2.0/

or in the "license" file accompanying this file. This file is distributed
on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
express or implied. See the License for the specific language governing
permissions and limitations under the License.

# SageMakerPySpark MNIST Example

1. [Introduction](#Introduction)
2. [Data Inspection](#Data-Inspection)
3. [Training the K-Means Model](#Training-the-K-Means-Model)
4. [Validate the Model for use](#Validate-the-Model-for-use)
5. [Bring your Own Algorithm](#Bring-your-Own-Algorithm)


## Introduction
This notebook shows how to group handwritten digits into clusters with the K-Means clustering algorithm, using the SageMaker PySpark SDK.

For more about SageMaker Spark, visit its GitHub repository at https://github.com/aws/sagemaker-spark.

We will train a K-Means clustering model on the MNIST dataset using Amazon SageMaker, host the trained model on Amazon SageMaker, and then make predictions against that hosted model.

First, we load the MNIST dataset into a Spark DataFrame. The dataset is available in LibSVM format at

s3://sagemaker-sample-data-[region, such as us-east-1]/spark/mnist/train/
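
In LibSVM format, each line holds a label followed by sparse `index:value` pairs whose indices start at 1. The sketch below parses one such line in plain Python to show the layout. The function name `parse_libsvm_line` and the sample line are made up for illustration; Spark's `libsvm` reader does this work for you when the data is loaded.

```python
def parse_libsvm_line(line, num_features):
    """Parse one LibSVM-formatted line into (label, dense feature list)."""
    parts = line.split()
    label = float(parts[0])
    features = [0.0] * num_features
    for item in parts[1:]:
        index, value = item.split(":")
        # LibSVM indices are 1-based; convert to 0-based list positions.
        features[int(index) - 1] = float(value)
    return label, features


# Hypothetical sample line: label 5 with three non-zero pixels out of 784.
label, features = parse_libsvm_line("5 153:3 154:18 784:255", num_features=784)
print(label, len(features), features[152], features[783])
```

Only non-zero entries are stored, which suits MNIST because most pixels in each image are zero.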

In [None]:
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
import os
import sagemaker_pyspark
import sagemaker
from sagemaker import get_execution_role

sagemaker_session = sagemaker.Session()

role = get_execution_role()

# Configure Spark to use the SageMaker Spark dependency jars
jars = sagemaker_pyspark.classpath_jars()

classpath = ":".join(jars)

# See the SageMaker Spark GitHub repo under sagemaker-pyspark-sdk
# to learn how to connect to a remote EMR cluster running Spark from a Notebook Instance.
spark = SparkSession.builder.config("spark.driver.extraClassPath", classpath)\
    .master("local[*]").getOrCreate()

In [None]:
# replace this with your own region, such as us-east-1
region = 'us-east-1'
trainingData = spark.read.format('libsvm')\
    .option('numFeatures', '784')\
    .load('s3a://sagemaker-sample-data-{}/spark/mnist/train/'.format(region))

testData = spark.read.format('libsvm')\
    .option('numFeatures', '784')\
    .load('s3a://sagemaker-sample-data-{}/spark/mnist/test/'.format(region))

## Data Inspection
To train and make inferences, our input DataFrame must have a column of Doubles (named "label" by default) and a column of Vectors of Doubles (named "features" by default).

Spark's LibSVM DataFrameReader loads a DataFrame already suitable for training and inference.
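
Each "features" vector has 784 entries because an MNIST image is 28 x 28 pixels stored row by row. As a rough illustration, the sketch below scatters a few sparse pixel values into a flat array and reshapes it into an image. The helper name `sparse_to_image` and the sample indices are assumptions made for this example, not part of the SageMaker Spark API.

```python
import numpy as np


def sparse_to_image(indices, values, side=28):
    """Scatter 0-based pixel indices and values into a side x side image."""
    flat = np.zeros(side * side)
    flat[np.asarray(indices, dtype=int)] = values
    # Row-major reshape: flat index i lands at (i // side, i % side).
    return flat.reshape((side, side))


# Hypothetical sparse pixels: flat index 29 maps to row 1, column 1.
image = sparse_to_image([0, 29, 783], [10.0, 20.0, 30.0])
print(image.shape)
```

The plotting helper later in this notebook relies on the same 784 to 28 x 28 reshape.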

In [None]:
trainingData.show()

## Training the K-Means Model
Now we create a KMeansSageMakerEstimator, which uses the KMeans Amazon SageMaker Algorithm to train on our input data, and uses the KMeans Amazon SageMaker model image to host our model.

Calling fit() on this estimator will train our model on Amazon SageMaker, and then create an Amazon SageMaker Endpoint to host our model.

We can then use the SageMakerModel returned by this call to fit() to transform DataFrames using our hosted model.

The following cell runs a training job and creates an endpoint to host the resulting model, so this cell can take up to twenty minutes to complete.
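
For intuition about what K-Means computes, here is a minimal NumPy sketch of the textbook Lloyd's algorithm. It is not the implementation that Amazon SageMaker runs, and the function name `lloyd_kmeans`, the first-k initialization, and the toy data are assumptions chosen to keep the example short and deterministic.

```python
import numpy as np


def assign_clusters(points, centroids):
    """Return the index of the nearest centroid for every point."""
    # Squared Euclidean distance from each point to each centroid.
    distances = ((points[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return distances.argmin(axis=1)


def lloyd_kmeans(points, k, iterations=10):
    """Textbook Lloyd's algorithm with a simplistic first-k initialization."""
    centroids = points[:k].copy()
    for _ in range(iterations):
        assignments = assign_clusters(points, centroids)
        for j in range(k):
            members = points[assignments == j]
            if len(members) > 0:  # keep the old centroid if a cluster empties
                centroids[j] = members.mean(axis=0)
    return centroids, assign_clusters(points, centroids)


# Toy data: two well-separated groups of 2-D points.
points = np.array(
    [[0, 0], [0, 1], [1, 0], [10, 10], [10, 11], [11, 10]], dtype=float
)
centroids, assignments = lloyd_kmeans(points, k=2)
print(assignments)
```

In the notebook, `setK(10)` plays the role of `k` and `setFeatureDim(784)` gives the dimensionality of each point.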

In [None]:
import random
from sagemaker_pyspark import IAMRole, S3DataPath
from sagemaker_pyspark.algorithms import KMeansSageMakerEstimator

# IAMRole wraps the execution role obtained above; pass your own role ARN here if needed
kmeans_estimator = KMeansSageMakerEstimator(
    sagemakerRole=IAMRole(role),
    trainingInstanceType='ml.p2.xlarge',
    trainingInstanceCount=1,
    endpointInstanceType='ml.c4.xlarge',
    endpointInitialInstanceCount=1)

kmeans_estimator.setK(10)
kmeans_estimator.setFeatureDim(784)

# train
model = kmeans_estimator.fit(trainingData)

## Validate the Model for use
Now we transform our test DataFrame.
To do this, the model serializes each row's "features" Vector of Doubles into Protobuf format for inference against the Amazon SageMaker Endpoint, then deserializes the Protobuf responses back into the DataFrame:

In [None]:
transformedData = model.transform(testData)

transformedData.show()
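
K-Means is unsupervised, so a cluster number does not correspond to a digit. Because the test data also carries a "label" column, one rough check is cluster purity: the fraction of rows whose label matches the most common label in their cluster. The sketch below computes it in plain Python; the function name `cluster_purity` and the sample lists are made up for illustration.

```python
from collections import Counter, defaultdict


def cluster_purity(labels, clusters):
    """Fraction of points whose label equals their cluster's majority label."""
    by_cluster = defaultdict(Counter)
    for label, cluster in zip(labels, clusters):
        by_cluster[cluster][label] += 1
    # Count, per cluster, how many points carry the most common label.
    majority_total = sum(
        counts.most_common(1)[0][1] for counts in by_cluster.values()
    )
    return majority_total / len(labels)


# Hypothetical labels and cluster assignments for six rows.
labels = [3, 3, 8, 1, 1, 1]
clusters = [0, 0, 0, 4, 4, 4]
purity = cluster_purity(labels, clusters)
print(purity)
```

To try this on the notebook's output, one option is to collect the "label" and "closest_cluster" columns from `transformedData` into two Python lists and pass them to the function.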

In [None]:
from pyspark.sql.types import DoubleType
import matplotlib.pyplot as plt
import numpy as np

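# helper function to display a digit on the given axes (or on a new figure)
def show_digit(img, caption='', xlabel='', subplot=None):
    if subplot is None:
        _, subplot = plt.subplots(1, 1)
    # each image is a flat vector of 784 pixels; reshape it to 28 x 28
    imgr = img.reshape((28, 28))
    subplot.axes.get_xaxis().set_ticks([])
    subplot.axes.get_yaxis().set_ticks([])
    # set the title and label on this subplot, not on the current pyplot axes
    subplot.set_title(caption)
    subplot.set_xlabel(xlabel)
    subplot.imshow(imgr, cmap='gray')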

images = np.array(transformedData.select("features").cache().take(250))
clusters = transformedData.select("closest_cluster").cache().take(250)

for cluster in range(10):
    print('\n\n\nCluster {}:'.format(int(cluster)))
    digits = [ img for l, img in zip(clusters, images) if int(l.closest_cluster) == cluster ]
    height=((len(digits)-1)//5)+1
    width=5
    plt.rcParams["figure.figsize"] = (width,height)
    _, subplots = plt.subplots(height, width)
    subplots=np.ndarray.flatten(subplots)
    for subplot, image in zip(subplots, digits):
        show_digit(image, subplot=subplot)
    for subplot in subplots[len(digits):]:
        subplot.axis('off')

    plt.show()

In [None]:
# Delete the endpoint

from sagemaker_pyspark import SageMakerResourceCleanup

resource_cleanup = SageMakerResourceCleanup(model.sagemakerClient)
resource_cleanup.deleteResources(model.getCreatedResources())

## Bring your Own Algorithm

The SageMaker Spark GitHub repository has more about SageMaker Spark, including how to use SageMaker Spark with your own algorithms on Amazon SageMaker: https://github.com/aws/sagemaker-spark
