In [None]:
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.feature.HashingTF
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils 

import com.amazonaws.services.sagemaker.sparksdk.IAMRole
import com.amazonaws.services.sagemaker.sparksdk.algorithms.XGBoostSageMakerEstimator
import com.amazonaws.services.sagemaker.sparksdk.SageMakerResourceCleanup

In [None]:
// Load 2 types of emails from text files: spam and ham (non-spam).
// Each line has text from one email.

// Normalize one email: convert to lower case, remove punctuation and numbers
// (everything that is not a space or a-z), then trim whitespace.
// This adds 0.6% accuracy!
// Locale.ROOT keeps the lower-casing independent of the JVM default locale:
// under e.g. a Turkish locale "I" lower-cases to a dotless i, which the regex
// below would then delete instead of keeping it as "i".
val normalizeEmail: String => String =
  line => line.toLowerCase(java.util.Locale.ROOT).replaceAll("[^ a-z]", "").trim()

// Read the text file(s) at `path` (one email per line) and normalize every line.
// Shared by both corpora so spam and ham are always preprocessed identically.
def loadEmails(path: String): org.apache.spark.rdd.RDD[String] =
  sc.textFile(path).map(normalizeEmail)

val spam = loadEmails("s3://sagemaker-eu-west-1-123456789012/spam")

val ham = loadEmails("s3://sagemaker-eu-west-1-123456789012/ham")

spam.take(5)

In [None]:
// Create a HashingTF instance to map email text to vectors of features.
val tf = new HashingTF(numFeatures = 200)

// Split an email into words on single spaces. Once punctuation and digits have
// been stripped, an email can contain runs of consecutive spaces (and an empty
// email splits to Array("")), so empty strings are dropped here rather than
// being hashed into a feature as if they were a word.
val tokenize: String => Array[String] =
  email => email.split(" ").filter(_.nonEmpty)

// Each email is split into words, and each word is mapped to one feature.
val spamFeatures = spam.map(email => tf.transform(tokenize(email)))
val hamFeatures = ham.map(email => tf.transform(tokenize(email)))

// Display features for a spam sample
spamFeatures.take(1)
// Display features for a ham sample
hamFeatures.take(1)

In [None]:
// Attach the class label to every feature vector:
// 1.0 for spam (positive examples), 0.0 for ham (negative examples).
val positiveExamples = spamFeatures.map(LabeledPoint(1.0, _))
val negativeExamples = hamFeatures.map(LabeledPoint(0.0, _))

// Show one labeled spam example
positiveExamples.take(1)
// Show one labeled ham example
negativeExamples.take(1)

In [None]:
// Merge the spam and ham examples into a single data set (++ is RDD union).
val data = positiveExamples ++ negativeExamples

// The XGBoost built-in algo requires a libsvm-formatted DataFrame:
// turn the RDD into a DataFrame, then convert its mllib vector columns
// to their ml counterparts.
val data_libsvm = {
  val labeledFrame = data.toDF()
  MLUtils.convertVectorColumnsToML(labeledFrame)
}

// Peek at the first two rows
data_libsvm.take(2)

In [None]:
// Split the data set 80/20 into a training set and a test set.
// The seed is fixed so that the split, and therefore the accuracy measured on
// the test set, is repeatable from one run of the notebook to the next;
// without it every run evaluates the model on a different sample.
val splitSeed = 42L
val Array(trainingData, testData) = data_libsvm.randomSplit(Array(0.8, 0.2), splitSeed)

In [None]:
// Placeholder: replace with the role that is handed to the estimator below.
val roleArn = "YOUR_SAGEMAKER_ROLE"

// Configure the XGBoost estimator: the IAM role, plus instance type and count
// for training and for the endpoint.
val xgboost_estimator = new XGBoostSageMakerEstimator(
  sagemakerRole = IAMRole(roleArn),
  trainingInstanceType = "ml.m5.large",
  trainingInstanceCount = 1,
  endpointInstanceType = "ml.t2.medium",
  endpointInitialInstanceCount = 1
)

// Hyperparameters: logistic objective for the two-class (spam/ham) problem,
// trained for 25 rounds.
xgboost_estimator.setObjective("binary:logistic")
xgboost_estimator.setNumRound(25)

In [None]:
// Fit the estimator on the training split. The resulting model is what the
// following cells use for predictions and for looking up created resources.
val xgboost_model = xgboost_estimator.fit(trainingData)

In [None]:
// Run the test split through the trained model to obtain predictions.
// The result is cached because it is consumed by more than one action (the
// preview below, and any later counts over it); DataFrames are evaluated
// lazily, so without the cache each action would re-run the transform.
val transformedData = xgboost_model.transform(testData).cache()

// Preview the first five rows
transformedData.head(5)

In [None]:
// Turn the raw prediction score into a hard 0.0 / 1.0 class label:
// scores strictly above 0.5 become 1.0, everything else 0.0.
val roundedData = {
  val aboveThreshold = $"prediction" > 0.5
  transformedData.withColumn(
    "prediction_rounded",
    when(aboveThreshold, 1.0).otherwise(0.0)
  )
}

In [None]:
// Accuracy = rows whose rounded prediction equals the label / all rows.
val accuracy = {
  val correctRows = roundedData.filter($"label" === $"prediction_rounded").count()
  val allRows = roundedData.count()
  // Convert to Double before dividing so the ratio is not truncated.
  correctRows.toDouble / allRows
}

In [None]:
// List the resources the model reports as created; the same value is passed
// to the cleanup step at the end of the notebook.
xgboost_model.getCreatedResources

In [None]:
// Delete everything the model reports as created, using a cleanup helper
// built on the model's own SageMaker client.
val cleanup = new SageMakerResourceCleanup(xgboost_model.sagemakerClient)
val createdResources = xgboost_model.getCreatedResources
cleanup.deleteResources(createdResources)