In [5]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import count, when, col

In [6]:
# Build the SparkSession
spark = SparkSession.builder \
   .master("local") \
   .appName("DDAM Project") \
   .config("spark.executor.memory", "1gb") \
   .getOrCreate()
   
sc = spark.sparkContext

## Import

In [7]:
df_boxscore = spark.read.csv("data/boxscore_clean.csv", header=True, inferSchema=True)

In [8]:
print(df_boxscore.count())

41633


In [9]:
# Renaming
df_boxscore = df_boxscore.withColumnRenamed("pos_clean", "Position")

In [10]:
df_boxscore.printSchema()

root
 |-- playerName: string (nullable = true)
 |-- game_id: integer (nullable = true)
 |-- teamName: string (nullable = true)
 |-- FG: integer (nullable = true)
 |-- FGA: integer (nullable = true)
 |-- 3P: integer (nullable = true)
 |-- 3PA: integer (nullable = true)
 |-- FT: integer (nullable = true)
 |-- FTA: integer (nullable = true)
 |-- ORB: integer (nullable = true)
 |-- DRB: integer (nullable = true)
 |-- TRB: integer (nullable = true)
 |-- AST: integer (nullable = true)
 |-- STL: integer (nullable = true)
 |-- BLK: integer (nullable = true)
 |-- TOV: integer (nullable = true)
 |-- PF: integer (nullable = true)
 |-- PTS: integer (nullable = true)
 |-- +/-: integer (nullable = true)
 |-- isStarter: integer (nullable = true)
 |-- seasonStartYear: integer (nullable = true)
 |-- isRegular: integer (nullable = true)
 |-- Ht: string (nullable = true)
 |-- Wt: double (nullable = true)
 |-- MP_seconds: integer (nullable = true)
 |-- Position: string (nullable = true)



## Position one-hot-enconding

In [11]:
from pyspark.ml.feature import Imputer
from pyspark.ml.feature import StringIndexer

indexer = StringIndexer(inputCol="Position", outputCol="PosNum", handleInvalid="skip")

df_boxscore = indexer.fit(df_boxscore).transform(df_boxscore)

In [12]:
from pyspark.ml.feature import OneHotEncoder
onehotencoder_embarked_vector = OneHotEncoder(inputCol="PosNum", outputCol="PosVec")
df_boxscore = onehotencoder_embarked_vector.fit(df_boxscore).transform(df_boxscore)
df_boxscore.show()

+----------------+-------+--------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---------+---------------+---------+----+-----+----------+--------+------+-------------+
|      playerName|game_id|            teamName| FG|FGA| 3P|3PA| FT|FTA|ORB|DRB|TRB|AST|STL|BLK|TOV| PF|PTS|+/-|isStarter|seasonStartYear|isRegular|  Ht|   Wt|MP_seconds|Position|PosNum|       PosVec|
+----------------+-------+--------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---------+---------------+---------+----+-----+----------+--------+------+-------------+
|   Desmond Mason|   4577| Seattle SuperSonics|  3|  9|  0|  3|  1|  2|  3|  2|  5|  2|  0|  1|  1|  5|  7|-15|        1|           2000|        1| 6-7|224.0|      1021|       F|   1.0|(2,[1],[1.0])|
|Rubén Wolkowyski|   4577| Seattle SuperSonics|  0|  4|  0|  0|  0|  0|  1|  1|  2|  0|  0|  1|  0|  1|  0| -2|        0|           2000|        1|6-10|270.0|       707|       F|   1.0|(2,[1],[1.0])|


In [13]:
column2drop = ('playerName', 'game_id', 'teamName', 'seasonStartYear', 'Ht', 'Pos', 'Position')
df_boxscore = df_boxscore.drop(*column2drop)

In [14]:
df_boxscore.printSchema()

root
 |-- FG: integer (nullable = true)
 |-- FGA: integer (nullable = true)
 |-- 3P: integer (nullable = true)
 |-- 3PA: integer (nullable = true)
 |-- FT: integer (nullable = true)
 |-- FTA: integer (nullable = true)
 |-- ORB: integer (nullable = true)
 |-- DRB: integer (nullable = true)
 |-- TRB: integer (nullable = true)
 |-- AST: integer (nullable = true)
 |-- STL: integer (nullable = true)
 |-- BLK: integer (nullable = true)
 |-- TOV: integer (nullable = true)
 |-- PF: integer (nullable = true)
 |-- PTS: integer (nullable = true)
 |-- +/-: integer (nullable = true)
 |-- isStarter: integer (nullable = true)
 |-- isRegular: integer (nullable = true)
 |-- Wt: double (nullable = true)
 |-- MP_seconds: integer (nullable = true)
 |-- PosNum: double (nullable = false)
 |-- PosVec: vector (nullable = true)



## Data preparation

In [15]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.feature import IndexToString, StringIndexer, VectorIndexer
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.sql import SparkSession
from pyspark.ml.stat import Correlation
from pyspark.ml.feature import VectorAssembler

In [16]:
num_col = [item[0] for item in df_boxscore.dtypes if not item[1].startswith('string')]
num_col.remove("PosNum")
num_col.remove("Wt")
num_col.remove("PosVec")
print(num_col)

['FG', 'FGA', '3P', '3PA', 'FT', 'FTA', 'ORB', 'DRB', 'TRB', 'AST', 'STL', 'BLK', 'TOV', 'PF', 'PTS', '+/-', 'isStarter', 'isRegular', 'MP_seconds']


In [17]:
assembler = VectorAssembler(inputCols=num_col, outputCol="features")

output_dataset = assembler.transform(df_boxscore)

classificationData = output_dataset.select("features", "PosNum")

classificationData.show(truncate=False)

+------------------------------------------------------------------------------------+------+
|features                                                                            |PosNum|
+------------------------------------------------------------------------------------+------+
|[3.0,9.0,0.0,3.0,1.0,2.0,3.0,2.0,5.0,2.0,0.0,1.0,1.0,5.0,7.0,-15.0,1.0,1.0,1021.0]  |1.0   |
|(19,[1,6,7,8,11,13,15,17,18],[4.0,1.0,1.0,2.0,1.0,1.0,-2.0,1.0,707.0])              |1.0   |
|(19,[1,6,8,15,17,18],[1.0,1.0,1.0,7.0,1.0,534.0])                                   |1.0   |
|(19,[1,6,7,8,9,13,15,17,18],[2.0,1.0,2.0,3.0,2.0,2.0,-7.0,1.0,433.0])               |1.0   |
|(19,[6,8,17,18],[1.0,1.0,1.0,20.0])                                                 |0.0   |
|[2.0,7.0,0.0,0.0,3.0,6.0,8.0,3.0,11.0,0.0,1.0,0.0,0.0,4.0,7.0,-10.0,0.0,1.0,1635.0] |0.0   |
|[3.0,5.0,0.0,0.0,1.0,4.0,0.0,2.0,2.0,0.0,1.0,0.0,0.0,1.0,7.0,-4.0,0.0,1.0,464.0]    |2.0   |
|(19,[13,15,17,18],[1.0,-2.0,1.0,342.0])                    

In [18]:
(trainingData, testData) = classificationData.randomSplit([0.7, 0.3],seed=0)

In [19]:
trainingData.show(50, truncate=False)

+------------------------------------------------------------------------------------+------+
|features                                                                            |PosNum|
+------------------------------------------------------------------------------------+------+
|(19,[],[])                                                                          |1.0   |
|(19,[0,1,2,3,4,5,6,8,14,17,18],[1.0,1.0,1.0,1.0,1.0,2.0,1.0,1.0,4.0,1.0,645.0])     |1.0   |
|(19,[0,1,2,3,4,5,7,8,14,17,18],[5.0,7.0,3.0,3.0,4.0,6.0,1.0,1.0,17.0,1.0,846.0])    |1.0   |
|(19,[0,1,2,3,4,5,9,13,14,17,18],[4.0,9.0,1.0,2.0,6.0,7.0,1.0,1.0,15.0,1.0,1327.0])  |1.0   |
|(19,[0,1,2,3,4,5,9,14,15,17,18],[1.0,1.0,1.0,1.0,2.0,2.0,1.0,5.0,-1.0,1.0,1031.0])  |0.0   |
|(19,[0,1,2,3,4,5,9,14,15,17,18],[1.0,1.0,1.0,1.0,2.0,2.0,1.0,5.0,2.0,1.0,260.0])    |0.0   |
|(19,[0,1,2,3,4,5,9,14,15,17,18],[1.0,7.0,1.0,4.0,2.0,2.0,5.0,5.0,-17.0,1.0,1049.0]) |0.0   |
|(19,[0,1,2,3,4,5,9,14,15,17,18],[1.0,7.0,1.0,4.0,2.0,2.0,5.

## Decision Tree

In [20]:
from pyspark.ml.classification import DecisionTreeClassifier, RandomForestClassifier

dt = DecisionTreeClassifier(labelCol="PosNum", featuresCol="features")

dt = dt.fit(trainingData)

In [21]:
dt.featureImportances

SparseVector(19, {3: 0.4584, 6: 0.1168, 7: 0.0249, 8: 0.1153, 9: 0.2077, 11: 0.0543, 13: 0.0082, 16: 0.0145})

In [22]:
# Make predictions
predictions = dt.transform(testData)

predictions.show(20)

+--------------------+------+--------------------+--------------------+----------+
|            features|PosNum|       rawPrediction|         probability|prediction|
+--------------------+------+--------------------+--------------------+----------+
|          (19,[],[])|   0.0|[1050.0,832.0,385.0]|[0.46316718129686...|       0.0|
|(19,[0,1,2,3,4,5,...|   0.0|  [2493.0,348.0,6.0]|[0.87565858798735...|       0.0|
|(19,[0,1,2,3,4,5,...|   0.0| [2021.0,545.0,35.0]|[0.77700884275278...|       0.0|
|(19,[0,1,2,3,4,5,...|   0.0|  [2493.0,348.0,6.0]|[0.87565858798735...|       0.0|
|(19,[0,1,2,3,4,5,...|   0.0| [2021.0,545.0,35.0]|[0.77700884275278...|       0.0|
|(19,[0,1,2,3,4,5,...|   1.0|  [131.0,138.0,19.0]|[0.45486111111111...|       1.0|
|(19,[0,1,2,3,4,5,...|   1.0|  [131.0,138.0,19.0]|[0.45486111111111...|       1.0|
|(19,[0,1,2,3,4,5,...|   0.0|[1312.0,646.0,105.0]|[0.63596703829374...|       0.0|
|(19,[0,1,2,3,4,5,...|   0.0|[1312.0,646.0,105.0]|[0.63596703829374...|       0.0|
|(19

In [23]:
# Select (prediction, true label) and compute test error
evaluator = MulticlassClassificationEvaluator(labelCol="PosNum"
                                              , predictionCol="prediction"
                                              , metricName="accuracy")

accuracy = evaluator.evaluate(predictions)
print(accuracy)
print("Test Error = %g" % (1.0 - accuracy))

0.5814176245210728
Test Error = 0.418582


## Gradient-Boosted Tree

In [25]:
from pyspark.ml.classification import GBTClassifier

# Create a GBTClassifier instance
gbt = GBTClassifier(labelCol="PosNum", featuresCol="features", maxIter=10)

# Train the model
model = gbt.fit(trainingData)

# Make predictions on the test data
predictions = model.transform(testData)

# Evaluate the model
evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)
print("Accuracy:", accuracy)
print("Test Error = %g" % (1.0 - accuracy))

# You can also view the feature importances
feature_importances = model.featureImportances
print("Feature Importances:", feature_importances)

Py4JJavaError: An error occurred while calling o353.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 36.0 failed 1 times, most recent failure: Lost task 0.0 in stage 36.0 (TID 33) (192.168.1.8 executor driver): java.lang.RuntimeException: Labels MUST be in {0, 1}, but got 2.0
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.project_doConsume_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
	at scala.collection.TraversableOnce.foldLeft(TraversableOnce.scala:199)
	at scala.collection.TraversableOnce.foldLeft$(TraversableOnce.scala:192)
	at scala.collection.AbstractIterator.foldLeft(Iterator.scala:1431)
	at scala.collection.TraversableOnce.aggregate(TraversableOnce.scala:260)
	at scala.collection.TraversableOnce.aggregate$(TraversableOnce.scala:260)
	at scala.collection.AbstractIterator.aggregate(Iterator.scala:1431)
	at org.apache.spark.rdd.RDD.$anonfun$aggregate$2(RDD.scala:1223)
	at org.apache.spark.SparkContext.$anonfun$runJob$6(SparkContext.scala:2492)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:161)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
	at java.base/java.lang.Thread.run(Thread.java:1589)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2844)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2780)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2779)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2779)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1242)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1242)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1242)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3048)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2982)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2971)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:984)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2398)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2493)
	at org.apache.spark.rdd.RDD.$anonfun$aggregate$1(RDD.scala:1225)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:407)
	at org.apache.spark.rdd.RDD.aggregate(RDD.scala:1218)
	at org.apache.spark.ml.tree.impl.DecisionTreeMetadata$.buildMetadata(DecisionTreeMetadata.scala:125)
	at org.apache.spark.ml.tree.impl.GradientBoostedTrees$.boost(GradientBoostedTrees.scala:333)
	at org.apache.spark.ml.tree.impl.GradientBoostedTrees$.run(GradientBoostedTrees.scala:61)
	at org.apache.spark.ml.classification.GBTClassifier.$anonfun$train$1(GBTClassifier.scala:201)
	at org.apache.spark.ml.util.Instrumentation$.$anonfun$instrumented$1(Instrumentation.scala:191)
	at scala.util.Try$.apply(Try.scala:213)
	at org.apache.spark.ml.util.Instrumentation$.instrumented(Instrumentation.scala:191)
	at org.apache.spark.ml.classification.GBTClassifier.train(GBTClassifier.scala:170)
	at org.apache.spark.ml.classification.GBTClassifier.train(GBTClassifier.scala:58)
	at org.apache.spark.ml.Predictor.fit(Predictor.scala:114)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:76)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:52)
	at java.base/java.lang.reflect.Method.invoke(Method.java:578)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:1589)
Caused by: java.lang.RuntimeException: Labels MUST be in {0, 1}, but got 2.0
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.project_doConsume_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
	at scala.collection.TraversableOnce.foldLeft(TraversableOnce.scala:199)
	at scala.collection.TraversableOnce.foldLeft$(TraversableOnce.scala:192)
	at scala.collection.AbstractIterator.foldLeft(Iterator.scala:1431)
	at scala.collection.TraversableOnce.aggregate(TraversableOnce.scala:260)
	at scala.collection.TraversableOnce.aggregate$(TraversableOnce.scala:260)
	at scala.collection.AbstractIterator.aggregate(Iterator.scala:1431)
	at org.apache.spark.rdd.RDD.$anonfun$aggregate$2(RDD.scala:1223)
	at org.apache.spark.SparkContext.$anonfun$runJob$6(SparkContext.scala:2492)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:161)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
	... 1 more


In [None]:
# Save the model for future use
#model.save("path_to_save_model")

# Load the saved model
#loaded_model = GBTClassifier.load("path_to_saved_model")

## Multilayer Perceptron Classifier

In [34]:
from pyspark.ml.classification import MultilayerPerceptronClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

layers = [19, 20, 10, 3]  # Define the layers of the neural network
mlp = MultilayerPerceptronClassifier(labelCol="PosNum", featuresCol="features", layers=layers, seed=1234)

model = mlp.fit(trainingData)

predictions = model.transform(testData)

# Evaluate the model
evaluator = MulticlassClassificationEvaluator(labelCol="PosNum", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)
print("Accuracy:", accuracy)

Accuracy: 0.5830938697318008


In [35]:
predictions.show(20)

+--------------------+------+--------------------+--------------------+----------+
|            features|PosNum|       rawPrediction|         probability|prediction|
+--------------------+------+--------------------+--------------------+----------+
|          (19,[],[])|   0.0|[0.22675666490199...|[0.34289128806061...|       1.0|
|(19,[0,1,2,3,4,5,...|   0.0|[1.40293352602737...|[0.73677563936177...|       0.0|
|(19,[0,1,2,3,4,5,...|   0.0|[1.37340800021238...|[0.72848575524147...|       0.0|
|(19,[0,1,2,3,4,5,...|   0.0|[1.40659596981844...|[0.73781480705336...|       0.0|
|(19,[0,1,2,3,4,5,...|   0.0|[1.40506285250744...|[0.73737955780179...|       0.0|
|(19,[0,1,2,3,4,5,...|   1.0|[0.79855463231135...|[0.53155658310424...|       0.0|
|(19,[0,1,2,3,4,5,...|   1.0|[1.26167844519160...|[0.69468219473418...|       0.0|
|(19,[0,1,2,3,4,5,...|   0.0|[1.40385532799476...|[0.73703750720679...|       0.0|
|(19,[0,1,2,3,4,5,...|   0.0|[1.19087804004461...|[0.67253631433828...|       0.0|
|(19

In [36]:
model

MultilayerPerceptronClassificationModel: uid=MultilayerPerceptronClassifier_dd3785011acc, numLayers=4, numClasses=3, numFeatures=19