First, we import the necessary modules and create the SparkSession with the SageMaker-Spark dependencies attached.

In [1]:
import os
import boto3

from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession

import sagemaker
from sagemaker import get_execution_role
import sagemaker_pyspark

role = get_execution_role()

# Configure Spark to use the SageMaker Spark dependency jars
jars = sagemaker_pyspark.classpath_jars()

classpath = ":".join(sagemaker_pyspark.classpath_jars())

# See the SageMaker Spark Github to learn how to connect to EMR from a notebook instance
spark = SparkSession.builder.config("spark.driver.extraClassPath", classpath)\
    .master("local[*]").getOrCreate()
    
spark


In [None]:
## Initialize a spark session and add sagemaker jars 
import sagemaker_pyspark
from pyspark.sql import SparkSession
from pyspark import SparkConf
from sagemaker_pyspark import IAMRole, classpath_jars

classpath = ":".join(sagemaker_pyspark.classpath_jars())
spark = SparkSession.builder.config("spark.driver.extraClassPath", classpath).getOrCreate()

# Load the sagemaker_pyspark classpath. If you used --jars to submit your job
# there is no need to do this in code.
#conf = (SparkConf()
#        .set("spark.driver.extraClassPath", ":".join(classpath_jars())))
#SparkContext(conf=conf)



In [44]:
## In PySpark, we recommend using "s3://" 
##to access the EMR file system(EMRFS) in EMR and "s3a://" to access S3A file system in other environments

import boto3
import datetime

region = boto3.Session().region_name
spark._jsc.hadoopConfiguration().set('fs.s3a.endpoint', 's3.{}.amazonaws.com'.format(region))



## Converting Data Frame to Dynamic Frame and writing to S3
today = datetime.datetime.today()
year = today.year
#month = today.month
month = 2
#day = today.day
day = 25

## Read the training data 
s3_train_bucket = "readmission-data-ehr"
s3_train_bucket_prefix = "train-data"

training_data = spark.read.format("parquet").load('s3a://'+s3_train_bucket+'/'+s3_train_bucket_prefix+'/'+str(year)+'/'+str(month)+'/'+str(day)+'/')


In [45]:
training_data.printSchema()

root
 |-- features: vector (nullable = true)
 |-- readmission: integer (nullable = true)



In [54]:
## Renaming the column as SageMaker XGBoost Algorithm is expecting features and label columns
from pyspark.sql.types import DoubleType
training_data = training_data.withColumn("label",training_data.readmission.cast(DoubleType())).drop("readmission")
training_data = training_data.withColumn("features",training_data.features.cast(DoubleType())).drop("readmission")

training_data.printSchema()

root
 |-- features: vector (nullable = true)
 |-- label: double (nullable = true)



In [94]:
from pyspark.sql import functions as F
training_data.select(F.to_json(training_data.features).alias("json")).collect()


AnalysisException: "cannot resolve 'structstojson(`features`)' due to data type mismatch: Input type vector must be a struct, array of structs or a map or array of map.;;\n'Project [structstojson(features#477, Some(Universal)) AS json#1319]\n+- Project [features#477, label#489]\n   +- Project [features#477, readmission#478, cast(readmission#478 as double) AS label#489]\n      +- Project [features#477, readmission#478, cast(readmission#478 as double) AS label#485]\n         +- Relation[features#477,readmission#478] parquet\n"

In [55]:
training_data.show(vertical=True, truncate=100)

-RECORD 0--------------------------------------------------------------------------------------------------------
 features | (322,[0,5,7,8,12,13,108,144,314,315,316,317,318,320,321],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1512398... 
 label    | 0.0                                                                                                  
-RECORD 1--------------------------------------------------------------------------------------------------------
 features | (322,[0,5,7,8,12,17,103,144,314,315,316,317,318,320,321],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1512398... 
 label    | 0.0                                                                                                  
-RECORD 2--------------------------------------------------------------------------------------------------------
 features | (322,[1,5,7,8,12,13,102,144,314,315,316,318,320,321],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1512398.25,... 
 label    | 0.0                                                                         

In [None]:
def toCSVLine(data):
    r = ','.join(str(d) for d in data[1])
    return str(data[0]) + "," + r


region = boto3.Session().region_name
spark._jsc.hadoopConfiguration().set('fs.s3a.endpoint', 's3.{}.amazonaws.com'.format(region))
spark._jsc.hadoopConfiguration().set('fs.s3a.cse.enabled','true')
spark._jsc.hadoopConfiguration().set('fs.s3a.cse.kms.keyId','06c7674c-c4d9-49f9-b52a-114f2d68737b')

transformed_train_rdd = training_data.rdd.map(lambda x: (x.label, x.features))
lines = transformed_train_rdd.map(toCSVLine)
lines.saveAsTextFile('s3a://' + s3_train_bucket + '/' + s3_train_bucket_prefix+ '/' + 'train')

In [88]:
from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker_pyspark import SageMakerEstimator
from sagemaker_pyspark.transformation.deserializers import XGBoostCSVRowDeserializer 
from sagemaker_pyspark.transformation.serializers import ProtobufRequestRowSerializer
from sagemaker_pyspark import IAMRole
from sagemaker_pyspark import RandomNamePolicyFactory
from sagemaker_pyspark import EndpointCreationPolicy

import sagemaker
from sagemaker import get_execution_role
import sagemaker_pyspark

role = get_execution_role()

# Create an Estimator from scratch
estimator = SageMakerEstimator(
    trainingImage = get_image_uri(region, 'xgboost', '0.90-1'), # Training image 
    modelImage = get_image_uri(region, 'xgboost', '0.90-1'), # Model image
    requestRowSerializer = ProtobufRequestRowSerializer(),
    responseRowDeserializer = XGBoostCSVRowDeserializer(),
    hyperParameters = {"objective": "multi:softmax", "num_class" : 2}, # Set parameters for K-Means
    sagemakerRole = IAMRole(role),
    trainingInstanceType = "ml.m4.xlarge",
    trainingInstanceCount = 1,
    endpointInstanceType = "ml.t2.medium",
    endpointInitialInstanceCount = 1,
    trainingSparkDataFormat = "csv",
    namePolicyFactory = RandomNamePolicyFactory("sparksm-4-"),
    endpointCreationPolicy = EndpointCreationPolicy.DO_NOT_CREATE
    )

In [89]:
customModel = estimator.fit(training_data)

Py4JJavaError: An error occurred while calling o1649.fit.
: java.lang.UnsupportedOperationException: CSV data source does not support struct<type:tinyint,size:int,indices:array<int>,values:array<double>> data type.
	at org.apache.spark.sql.execution.datasources.csv.CSVUtils$.org$apache$spark$sql$execution$datasources$csv$CSVUtils$$verifyType$1(CSVUtils.scala:127)
	at org.apache.spark.sql.execution.datasources.csv.CSVUtils$$anonfun$verifySchema$1.apply(CSVUtils.scala:131)
	at org.apache.spark.sql.execution.datasources.csv.CSVUtils$$anonfun$verifySchema$1.apply(CSVUtils.scala:131)
	at scala.collection.Iterator$class.foreach(Iterator.scala:893)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1336)
	at scala.collection.IterableLike$class.foreach(IterableLike.scala:72)
	at org.apache.spark.sql.types.StructType.foreach(StructType.scala:99)
	at org.apache.spark.sql.execution.datasources.csv.CSVUtils$.verifySchema(CSVUtils.scala:131)
	at org.apache.spark.sql.execution.datasources.csv.CSVFileFormat.prepareWrite(CSVFileFormat.scala:65)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:140)
	at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:154)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.sideEffectResult$lzycompute(commands.scala:104)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.sideEffectResult(commands.scala:102)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.doExecute(commands.scala:122)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:131)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:127)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
	at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:127)
	at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:80)
	at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:80)
	at org.apache.spark.sql.DataFrameWriter$$anonfun$runCommand$1.apply(DataFrameWriter.scala:656)
	at org.apache.spark.sql.DataFrameWriter$$anonfun$runCommand$1.apply(DataFrameWriter.scala:656)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:77)
	at org.apache.spark.sql.DataFrameWriter.runCommand(DataFrameWriter.scala:656)
	at org.apache.spark.sql.DataFrameWriter.saveToV1Source(DataFrameWriter.scala:273)
	at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:267)
	at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:225)
	at com.amazonaws.services.sagemaker.sparksdk.internal.DataUploader.writeData(DataUploader.scala:111)
	at com.amazonaws.services.sagemaker.sparksdk.internal.DataUploader.uploadData(DataUploader.scala:90)
	at com.amazonaws.services.sagemaker.sparksdk.SageMakerEstimator.fit(SageMakerEstimator.scala:301)
	at com.amazonaws.services.sagemaker.sparksdk.SageMakerEstimator.fit(SageMakerEstimator.scala:175)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:745)
