In this notebook we predict species on the Iris dataset using Spark MLlib. We start by importing all the dependencies needed for the task.

In [163]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.sql.functions import col
from pyspark.ml.classification import DecisionTreeClassifier, GBTClassifier, RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

We create a session named 'Iris' and load the .csv file.

In [164]:
spark = SparkSession.builder.appName('Iris').getOrCreate()
df = spark.read.csv('iris.csv', header = True).cache()
df.show(10)

+------------+-----------+------------+-----------+-------+----------+
|sepal_length|sepal_width|petal_length|petal_width|species|species_id|
+------------+-----------+------------+-----------+-------+----------+
|         5.1|        3.5|         1.4|        0.2| setosa|         1|
|         4.9|        3.0|         1.4|        0.2| setosa|         1|
|         4.7|        3.2|         1.3|        0.2| setosa|         1|
|         4.6|        3.1|         1.5|        0.2| setosa|         1|
|         5.0|        3.6|         1.4|        0.2| setosa|         1|
|         5.4|        3.9|         1.7|        0.4| setosa|         1|
|         4.6|        3.4|         1.4|        0.3| setosa|         1|
|         5.0|        3.4|         1.5|        0.2| setosa|         1|
|         4.4|        2.9|         1.4|        0.2| setosa|         1|
|         4.9|        3.1|         1.5|        0.1| setosa|         1|
+------------+-----------+------------+-----------+-------+----------+
only showing top 10 rows

We inspect the data type of each column.

In [165]:
df.dtypes

[('sepal_length', 'string'),
 ('sepal_width', 'string'),
 ('petal_length', 'string'),
 ('petal_width', 'string'),
 ('species', 'string'),
 ('species_id', 'string')]

The columns holding the flower measurements are of type 'string'. We cast them to 'float', and we cast the 'species_id' column to 'int'.

In [166]:
df = df.select(col('sepal_length').cast('float'),
               col('sepal_width').cast('float'),
               col('petal_length').cast('float'),
               col('petal_width').cast('float'),
               col('species_id').cast('int'),
               col('species'))
df.show()

+------------+-----------+------------+-----------+----------+-------+
|sepal_length|sepal_width|petal_length|petal_width|species_id|species|
+------------+-----------+------------+-----------+----------+-------+
|         5.1|        3.5|         1.4|        0.2|         1| setosa|
|         4.9|        3.0|         1.4|        0.2|         1| setosa|
|         4.7|        3.2|         1.3|        0.2|         1| setosa|
|         4.6|        3.1|         1.5|        0.2|         1| setosa|
|         5.0|        3.6|         1.4|        0.2|         1| setosa|
|         5.4|        3.9|         1.7|        0.4|         1| setosa|
|         4.6|        3.4|         1.4|        0.3|         1| setosa|
|         5.0|        3.4|         1.5|        0.2|         1| setosa|
|         4.4|        2.9|         1.4|        0.2|         1| setosa|
|         4.9|        3.1|         1.5|        0.1|         1| setosa|
|         5.4|        3.7|         1.5|        0.2|         1| setosa|
|     

In this case we know that the dataset has no null values, so we will not look for them. We check that the types are now correct.

In [167]:
df.dtypes

[('sepal_length', 'float'),
 ('sepal_width', 'float'),
 ('petal_length', 'float'),
 ('petal_width', 'float'),
 ('species_id', 'int'),
 ('species', 'string')]

The flower measurements are now stored as 'float' and the species identifier as 'int'. Next we use VectorAssembler to combine the columns used for classification into a single 'features' vector column.

In [168]:
required_features = ['sepal_length',
                     'sepal_width',
                     'petal_length',
                     'petal_width'
                    ]
assembler = VectorAssembler(inputCols = required_features, outputCol = 'features')
df = assembler.transform(df)

df.show(5)

+------------+-----------+------------+-----------+----------+-------+--------------------+
|sepal_length|sepal_width|petal_length|petal_width|species_id|species|            features|
+------------+-----------+------------+-----------+----------+-------+--------------------+
|         5.1|        3.5|         1.4|        0.2|         1| setosa|[5.09999990463256...|
|         4.9|        3.0|         1.4|        0.2|         1| setosa|[4.90000009536743...|
|         4.7|        3.2|         1.3|        0.2|         1| setosa|[4.69999980926513...|
|         4.6|        3.1|         1.5|        0.2|         1| setosa|[4.59999990463256...|
|         5.0|        3.6|         1.4|        0.2|         1| setosa|[5.0,3.5999999046...|
+------------+-----------+------------+-----------+----------+-------+--------------------+
only showing top 5 rows
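The long decimals in the 'features' column (5.09999990463256... instead of 5.1) come from the cast to 'float', which is a 32-bit type. The assembled vector stores its entries as 64-bit doubles, so the 32-bit rounding becomes visible. The effect can be reproduced in plain Python without a Spark session; this is only a sketch, and the helper name `to_float32` is hypothetical, not part of Spark.

```python
import struct

def to_float32(x: float) -> float:
    """Round a 64-bit Python float to the nearest 32-bit float and widen it back."""
    return struct.unpack('f', struct.pack('f', x))[0]

widened = to_float32(5.1)
print(widened)         # the float32 approximation of 5.1, shown at double precision
print(widened == 5.1)  # False: the 32-bit rounding is visible at 64-bit precision
```

Casting the measurement columns to 'double' instead of 'float' would likely avoid this artifact, since the vector entries are doubles anyway.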



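In our case the classes of the 'species' column already come mapped to integers in the 'species_id' column (the output above shows setosa as 1). Even so, we drop that column and use StringIndexer to encode 'species' as a numeric label index starting at 0. Note that handleInvalid = 'keep' reserves one extra index for labels not seen during fitting, which is why the prediction vectors further down have four entries instead of three.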

In [169]:
df = df.drop('species_id')
df = StringIndexer(inputCol = 'species', outputCol = 'species_id', 
                   handleInvalid = 'keep').fit(df).transform(df)
df.show()

+------------+-----------+------------+-----------+-------+--------------------+----------+
|sepal_length|sepal_width|petal_length|petal_width|species|            features|species_id|
+------------+-----------+------------+-----------+-------+--------------------+----------+
|         5.1|        3.5|         1.4|        0.2| setosa|[5.09999990463256...|       0.0|
|         4.9|        3.0|         1.4|        0.2| setosa|[4.90000009536743...|       0.0|
|         4.7|        3.2|         1.3|        0.2| setosa|[4.69999980926513...|       0.0|
|         4.6|        3.1|         1.5|        0.2| setosa|[4.59999990463256...|       0.0|
|         5.0|        3.6|         1.4|        0.2| setosa|[5.0,3.5999999046...|       0.0|
|         5.4|        3.9|         1.7|        0.4| setosa|[5.40000009536743...|       0.0|
|         4.6|        3.4|         1.4|        0.3| setosa|[4.59999990463256...|       0.0|
|         5.0|        3.4|         1.5|        0.2| setosa|[5.0,3.4000000953...|
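By default StringIndexer gives index 0 to the most frequent label, and recent Spark versions break frequency ties alphabetically, which matches setosa receiving 0.0 above. The plain-Python sketch below approximates that default ordering for illustration only. It is not Spark's implementation, the function name `index_labels` is hypothetical, and the balance of 50 rows per species is an assumption about the Iris file.

```python
from collections import Counter

def index_labels(labels):
    """Map each distinct label to an index: most frequent first, ties broken alphabetically."""
    counts = Counter(labels)
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: float(i) for i, label in enumerate(ordered)}

# Assumed class balance of the Iris dataset: 50 rows per species.
species = ['setosa'] * 50 + ['versicolor'] * 50 + ['virginica'] * 50
mapping = index_labels(species)
print(mapping)
```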

Once this is done, we drop the 'species' column.

In [170]:
df = df.drop('species')
df.show()

+------------+-----------+------------+-----------+--------------------+----------+
|sepal_length|sepal_width|petal_length|petal_width|            features|species_id|
+------------+-----------+------------+-----------+--------------------+----------+
|         5.1|        3.5|         1.4|        0.2|[5.09999990463256...|       0.0|
|         4.9|        3.0|         1.4|        0.2|[4.90000009536743...|       0.0|
|         4.7|        3.2|         1.3|        0.2|[4.69999980926513...|       0.0|
|         4.6|        3.1|         1.5|        0.2|[4.59999990463256...|       0.0|
|         5.0|        3.6|         1.4|        0.2|[5.0,3.5999999046...|       0.0|
|         5.4|        3.9|         1.7|        0.4|[5.40000009536743...|       0.0|
|         4.6|        3.4|         1.4|        0.3|[4.59999990463256...|       0.0|
|         5.0|        3.4|         1.5|        0.2|[5.0,3.4000000953...|       0.0|
|         4.4|        2.9|         1.4|        0.2|[4.40000009536743...|    

We now split the data to train the model. randomSplit assigns rows at random, with roughly 80% going to training and 20% to test.

In [171]:
(training_data, test_data) = df.randomSplit([0.8, 0.2])
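randomSplit does not cut the DataFrame at exact 80/20 row counts: each row is assigned to a split by an independent random draw, so the sizes only approximate the weights and change between runs unless a seed is passed (randomSplit accepts an optional seed argument). The plain-Python sketch below imitates that behavior. The function `bernoulli_split`, the seed value and the row count of 150 are illustrative assumptions, not Spark code.

```python
import random

def bernoulli_split(rows, train_fraction=0.8, seed=None):
    """Assign each row to train or test with an independent random draw."""
    rng = random.Random(seed)
    train, test = [], []
    for row in rows:
        (train if rng.random() < train_fraction else test).append(row)
    return train, test

rows = list(range(150))  # stand-in for the rows of the Iris dataset
train, test = bernoulli_split(rows, seed=42)
print(len(train), len(test))  # close to 120 and 30, but not guaranteed to be exact

# Reusing the seed reproduces the same split.
print(bernoulli_split(rows, seed=42) == (train, test))
```

Fixing the seed in the notebook would make the accuracy figures below reproducible across runs.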

## Decision Tree Classifier

In [172]:
dt = DecisionTreeClassifier(labelCol = 'species_id',
                             featuresCol = 'features',
                             maxDepth = 5)

We train the model with fit.

In [173]:
model1 = dt.fit(training_data)

We generate the species predictions.

In [174]:
pred1 = model1.transform(test_data)
pred1.show(5)

+------------+-----------+------------+-----------+--------------------+----------+------------------+-----------------+----------+
|sepal_length|sepal_width|petal_length|petal_width|            features|species_id|     rawPrediction|      probability|prediction|
+------------+-----------+------------+-----------+--------------------+----------+------------------+-----------------+----------+
|         4.4|        3.0|         1.3|        0.2|[4.40000009536743...|       0.0|[39.0,0.0,0.0,0.0]|[1.0,0.0,0.0,0.0]|       0.0|
|         4.4|        3.2|         1.3|        0.2|[4.40000009536743...|       0.0|[39.0,0.0,0.0,0.0]|[1.0,0.0,0.0,0.0]|       0.0|
|         4.6|        3.6|         1.0|        0.2|[4.59999990463256...|       0.0|[39.0,0.0,0.0,0.0]|[1.0,0.0,0.0,0.0]|       0.0|
|         4.7|        3.2|         1.6|        0.2|[4.69999980926513...|       0.0|[39.0,0.0,0.0,0.0]|[1.0,0.0,0.0,0.0]|       0.0|
|         4.8|        3.0|         1.4|        0.1|[4.80000019073486...|    
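In the output above, rawPrediction for a decision tree holds, as far as we understand Spark's implementation, the class counts of the training rows in the leaf that the sample reaches. The probability column is that vector normalized to sum to 1, and prediction is the index of the largest entry. A minimal plain-Python sketch of these two steps follows; the function name is hypothetical.

```python
def raw_to_prediction(raw):
    """Normalize a raw class-count vector and return (probabilities, predicted index)."""
    total = sum(raw)
    probability = [count / total for count in raw]
    prediction = float(max(range(len(raw)), key=lambda i: probability[i]))
    return probability, prediction

probability, prediction = raw_to_prediction([39.0, 0.0, 0.0, 0.0])
print(probability, prediction)  # [1.0, 0.0, 0.0, 0.0] 0.0
```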

We now compute the model's accuracy using MulticlassClassificationEvaluator.

In [175]:
eval1 = MulticlassClassificationEvaluator(labelCol = 'species_id', predictionCol = 'prediction',
                                          metricName = 'accuracy')
accuracy1 = eval1.evaluate(pred1)
print("Test accuracy with DTC => ", accuracy1)

Test accuracy with DTC =>  0.9642857142857143


The accuracy obtained is 96.43%.
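Accuracy is the fraction of test rows whose prediction equals the label. The reported value 0.9642857142857143 equals 27/28, which suggests that the test split held 28 rows and one of them was misclassified; this is an inference from the number, not something the notebook prints. A plain-Python sketch with invented labels:

```python
def accuracy(labels, predictions):
    """Fraction of positions where the prediction matches the label."""
    correct = sum(1 for y, p in zip(labels, predictions) if y == p)
    return correct / len(labels)

# Illustrative data only: 28 labels with exactly one wrong prediction.
labels = [0.0] * 10 + [1.0] * 9 + [2.0] * 9
predictions = list(labels)
predictions[-1] = 1.0  # flip one prediction to simulate a single error

print(accuracy(labels, predictions))  # 27/28
```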

## Gradient-Boosted Tree Classifier

In [178]:
gbt = GBTClassifier(labelCol = 'species_id', featuresCol = 'features',
                    maxIter = 10)

We fit the model and generate the predictions.

In [179]:
model2 = gbt.fit(training_data)
pred2 = model2.transform(test_data)
pred2.select('prediction','probability','features').show(5)

Py4JJavaError: An error occurred while calling o3174.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 299.0 failed 1 times, most recent failure: Lost task 0.0 in stage 299.0 (TID 288) (c4e9c0c4c96b executor driver): java.lang.IllegalArgumentException: requirement failed: GBTClassifier was given dataset with invalid label 2.0.  Labels must be in {0,1}; note that GBTClassifier currently only supports binary classification.
	at scala.Predef$.require(Predef.scala:281)
	at org.apache.spark.ml.classification.GBTClassifier.$anonfun$train$2(GBTClassifier.scala:176)
	at org.apache.spark.ml.classification.GBTClassifier.$anonfun$train$2$adapted(GBTClassifier.scala:173)
	at org.apache.spark.ml.PredictorParams.$anonfun$extractInstances$2(Predictor.scala:96)
	at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
	at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
	at scala.collection.TraversableOnce.foldLeft(TraversableOnce.scala:199)
	at scala.collection.TraversableOnce.foldLeft$(TraversableOnce.scala:192)
	at scala.collection.AbstractIterator.foldLeft(Iterator.scala:1431)
	at scala.collection.TraversableOnce.aggregate(TraversableOnce.scala:260)
	at scala.collection.TraversableOnce.aggregate$(TraversableOnce.scala:260)
	at scala.collection.AbstractIterator.aggregate(Iterator.scala:1431)
	at org.apache.spark.rdd.RDD.$anonfun$aggregate$2(RDD.scala:1207)
	at org.apache.spark.SparkContext.$anonfun$runJob$6(SparkContext.scala:2308)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:131)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:506)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1462)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:509)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	at java.base/java.lang.Thread.run(Thread.java:829)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2454)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2403)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2402)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2402)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1160)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1160)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1160)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2642)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2584)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2573)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:938)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2214)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2309)
	at org.apache.spark.rdd.RDD.$anonfun$aggregate$1(RDD.scala:1209)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:414)
	at org.apache.spark.rdd.RDD.aggregate(RDD.scala:1202)
	at org.apache.spark.ml.tree.impl.DecisionTreeMetadata$.buildMetadata(DecisionTreeMetadata.scala:125)
	at org.apache.spark.ml.tree.impl.GradientBoostedTrees$.boost(GradientBoostedTrees.scala:333)
	at org.apache.spark.ml.tree.impl.GradientBoostedTrees$.run(GradientBoostedTrees.scala:61)
	at org.apache.spark.ml.classification.GBTClassifier.$anonfun$train$1(GBTClassifier.scala:209)
	at org.apache.spark.ml.util.Instrumentation$.$anonfun$instrumented$1(Instrumentation.scala:191)
	at scala.util.Try$.apply(Try.scala:213)
	at org.apache.spark.ml.util.Instrumentation$.instrumented(Instrumentation.scala:191)
	at org.apache.spark.ml.classification.GBTClassifier.train(GBTClassifier.scala:170)
	at org.apache.spark.ml.classification.GBTClassifier.train(GBTClassifier.scala:58)
	at org.apache.spark.ml.Predictor.fit(Predictor.scala:151)
	at jdk.internal.reflect.GeneratedMethodAccessor136.invoke(Unknown Source)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:566)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:829)
Caused by: java.lang.IllegalArgumentException: requirement failed: GBTClassifier was given dataset with invalid label 2.0.  Labels must be in {0,1}; note that GBTClassifier currently only supports binary classification.
	at scala.Predef$.require(Predef.scala:281)
	at org.apache.spark.ml.classification.GBTClassifier.$anonfun$train$2(GBTClassifier.scala:176)
	at org.apache.spark.ml.classification.GBTClassifier.$anonfun$train$2$adapted(GBTClassifier.scala:173)
	at org.apache.spark.ml.PredictorParams.$anonfun$extractInstances$2(Predictor.scala:96)
	at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
	at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
	at scala.collection.TraversableOnce.foldLeft(TraversableOnce.scala:199)
	at scala.collection.TraversableOnce.foldLeft$(TraversableOnce.scala:192)
	at scala.collection.AbstractIterator.foldLeft(Iterator.scala:1431)
	at scala.collection.TraversableOnce.aggregate(TraversableOnce.scala:260)
	at scala.collection.TraversableOnce.aggregate$(TraversableOnce.scala:260)
	at scala.collection.AbstractIterator.aggregate(Iterator.scala:1431)
	at org.apache.spark.rdd.RDD.$anonfun$aggregate$2(RDD.scala:1207)
	at org.apache.spark.SparkContext.$anonfun$runJob$6(SparkContext.scala:2308)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:131)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:506)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1462)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:509)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	... 1 more


Here we see that the Gradient-Boosted Tree Classifier is not suitable for this kind of problem: Spark's GBTClassifier only supports binary classification, whereas our problem requires classifying into three categories. We therefore skip the accuracy calculation and move on to the Random Forest Classifier.
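As an aside, a common workaround for binary-only learners is the one-vs-rest reduction: train one binary classifier per class and pick the class whose classifier is most confident. Spark ML provides a OneVsRest meta-estimator in pyspark.ml.classification for this purpose, although this notebook does not test it with GBTClassifier. The plain-Python sketch below only illustrates the relabeling and the final decision rule; the function names and the scores are hypothetical.

```python
def one_vs_rest_labels(labels, num_classes):
    """For each class k, build a binary label list: 1.0 if the row is class k, else 0.0."""
    return {k: [1.0 if y == k else 0.0 for y in labels] for k in range(num_classes)}

def decide(scores_per_class):
    """Pick the class whose binary classifier reports the highest score."""
    return max(scores_per_class, key=scores_per_class.get)

labels = [0, 1, 2, 1, 0]
binary = one_vs_rest_labels(labels, num_classes=3)
print(binary[2])  # [0.0, 0.0, 1.0, 0.0, 0.0]

# Hypothetical confidence scores from the three binary classifiers for one row.
print(decide({0: 0.10, 1: 0.35, 2: 0.80}))  # 2
```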

## Random Forest Classifier

In [180]:
rfc = RandomForestClassifier(labelCol = 'species_id',
                             featuresCol = 'features')

In [181]:
model3 = rfc.fit(training_data)

In [182]:
pred3 = model3.transform(test_data)
pred3.show(5)

+------------+-----------+------------+-----------+--------------------+----------+------------------+-----------------+----------+
|sepal_length|sepal_width|petal_length|petal_width|            features|species_id|     rawPrediction|      probability|prediction|
+------------+-----------+------------+-----------+--------------------+----------+------------------+-----------------+----------+
|         4.4|        3.0|         1.3|        0.2|[4.40000009536743...|       0.0|[20.0,0.0,0.0,0.0]|[1.0,0.0,0.0,0.0]|       0.0|
|         4.4|        3.2|         1.3|        0.2|[4.40000009536743...|       0.0|[20.0,0.0,0.0,0.0]|[1.0,0.0,0.0,0.0]|       0.0|
|         4.6|        3.6|         1.0|        0.2|[4.59999990463256...|       0.0|[20.0,0.0,0.0,0.0]|[1.0,0.0,0.0,0.0]|       0.0|
|         4.7|        3.2|         1.6|        0.2|[4.69999980926513...|       0.0|[20.0,0.0,0.0,0.0]|[1.0,0.0,0.0,0.0]|       0.0|
|         4.8|        3.0|         1.4|        0.1|[4.80000019073486...|    
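For a random forest, rawPrediction aggregates the trees. As far as we understand Spark's implementation, each tree contributes the class-probability vector of the leaf it reaches, and these vectors are summed. With the default of 20 trees, a unanimous forest therefore yields [20.0, 0.0, 0.0, 0.0], as in the rows above. A minimal plain-Python sketch, with a hypothetical function name and invented per-tree outputs:

```python
def forest_raw_prediction(tree_probabilities):
    """Sum per-tree class-probability vectors element-wise."""
    num_classes = len(tree_probabilities[0])
    return [sum(tree[k] for tree in tree_probabilities) for k in range(num_classes)]

# 20 trees that all assign probability 1 to class 0.
unanimous = [[1.0, 0.0, 0.0, 0.0]] * 20
print(forest_raw_prediction(unanimous))  # [20.0, 0.0, 0.0, 0.0]

# A split vote: 15 trees for class 1, 5 trees for class 2.
mixed = [[0.0, 1.0, 0.0, 0.0]] * 15 + [[0.0, 0.0, 1.0, 0.0]] * 5
print(forest_raw_prediction(mixed))  # [0.0, 15.0, 5.0, 0.0]
```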

We now compute the model's accuracy using MulticlassClassificationEvaluator.

In [183]:
eval3 = MulticlassClassificationEvaluator(labelCol = 'species_id', predictionCol = 'prediction',
                                          metricName = 'accuracy')
accuracy3 = eval3.evaluate(pred3)
print("Test accuracy with RFC => ", accuracy3)

Test accuracy with RFC =>  0.9642857142857143


The accuracy obtained is 96.43%, the same as with the decision tree.