Implementing a simple random forest classifier in pyspark's MLlib - using the same vegetation cover dataset

In [1]:
from pyspark.sql import SparkSession
from pyspark import SparkConf
import pyspark.sql.functions as f
import os, sys

os.environ['PYSPARK_PYTHON'] = sys.executable
os.environ['PYSPARK_DRIVER_PYTHON'] = sys.executable

In [2]:
from pyspark.sql.types import DoubleType

In [3]:
from pyspark.ml.feature import VectorAssembler

In [4]:
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit
from pyspark.ml.feature import VectorIndexer

In [5]:
from pyspark.ml.classification import RandomForestClassifier

In [6]:
conf = SparkConf()
conf.set("spark.app.name","trees")
conf.set("spark.master","local[*]")
conf.set("spark.driver.memory","8g")

<pyspark.conf.SparkConf at 0x15effc900d0>

In [7]:
spark = SparkSession.builder\
                    .config(conf=conf)\
                    .getOrCreate()

Read in the data - (hint it has no headers)

In [8]:
data_without_header = spark.read\
                           .format("csv")\
                           .option("header",False)\
                           .option("inferSchema",True)\
                           .load(r"C:\Users\blais\Documents\ML\data\covertype\covtype.data")

In [9]:
data_without_header.printSchema()

root
 |-- _c0: integer (nullable = true)
 |-- _c1: integer (nullable = true)
 |-- _c2: integer (nullable = true)
 |-- _c3: integer (nullable = true)
 |-- _c4: integer (nullable = true)
 |-- _c5: integer (nullable = true)
 |-- _c6: integer (nullable = true)
 |-- _c7: integer (nullable = true)
 |-- _c8: integer (nullable = true)
 |-- _c9: integer (nullable = true)
 |-- _c10: integer (nullable = true)
 |-- _c11: integer (nullable = true)
 |-- _c12: integer (nullable = true)
 |-- _c13: integer (nullable = true)
 |-- _c14: integer (nullable = true)
 |-- _c15: integer (nullable = true)
 |-- _c16: integer (nullable = true)
 |-- _c17: integer (nullable = true)
 |-- _c18: integer (nullable = true)
 |-- _c19: integer (nullable = true)
 |-- _c20: integer (nullable = true)
 |-- _c21: integer (nullable = true)
 |-- _c22: integer (nullable = true)
 |-- _c23: integer (nullable = true)
 |-- _c24: integer (nullable = true)
 |-- _c25: integer (nullable = true)
 |-- _c26: integer (nullable = true)
 |-- _

We know what the column names are from inspecting the info file - put the col names in a list and then attach them to the dataframe.

In [10]:
colnames = ["Elevation", "Aspect", "Slope", \
            "Horizontal_Distance_To_Hydrology", \
            "Vertical_Distance_To_Hydrology", "Horizontal_Distance_To_Roadways", \
            "Hillshade_9am", "Hillshade_Noon", "Hillshade_3pm", \
            "Horizontal_Distance_To_Fire_Points"] + \
           [f"Wilderness_Area_{i}" for i in range(4)] + \
           [f"Soil_Type_{i}" for i in range(40)] + \
           ["Cover_Type"]

Adding the col names to the dataframe:

In [11]:
data = data_without_header.toDF(*colnames).withColumn("Cover_Type",f.col("Cover_Type").cast(DoubleType()))

Also - our decision tree algorithm will generally perform better if we convert the onehot encoded values back to categoricals.

In [12]:
def decode_one_hot(data):
    wilderness_cols = [f"Wilderness_Area_{i}" for i in range(4)]
    wilderness_assembler = VectorAssembler()\
                                            .setInputCols(wilderness_cols)\
                                            .setOutputCol("wilderness")
    unhot_udf = f.udf(lambda v: v.toArray().tolist().index(1))
    with_wilderness = wilderness_assembler.transform(data)\
                                                         .withColumn("wilderness", unhot_udf(f.col("wilderness")))\
                                                         .drop(*wilderness_cols)
    soil_cols = [f"Soil_Type_{i}" for i in range(40)]
    soil_assembler = VectorAssembler().setInputCols(soil_cols).setOutputCol("soil")
    with_soil = soil_assembler.transform(with_wilderness).withColumn("soil", unhot_udf(f.col("soil"))).drop(*soil_cols)
    return with_soil

In [13]:
decoded_data = decode_one_hot(data)

In [14]:
decoded_data.printSchema()

root
 |-- Elevation: integer (nullable = true)
 |-- Aspect: integer (nullable = true)
 |-- Slope: integer (nullable = true)
 |-- Horizontal_Distance_To_Hydrology: integer (nullable = true)
 |-- Vertical_Distance_To_Hydrology: integer (nullable = true)
 |-- Horizontal_Distance_To_Roadways: integer (nullable = true)
 |-- Hillshade_9am: integer (nullable = true)
 |-- Hillshade_Noon: integer (nullable = true)
 |-- Hillshade_3pm: integer (nullable = true)
 |-- Horizontal_Distance_To_Fire_Points: integer (nullable = true)
 |-- Cover_Type: double (nullable = true)
 |-- wilderness: string (nullable = true)
 |-- soil: string (nullable = true)



In [15]:
# need to create a special unique column that incorporates all the unique string and wilderness types and add an index as well

In [16]:
decoded_data = decoded_data\
                            .withColumn("id", f.monotonically_increasing_id())\
                            .withColumn("wilderness_soil", f.expr("CONCAT(wilderness,'-',soil)"))

Creating a 90-10% stratified split on the wilderness-soil column

In [17]:
unique_wilderness_soil = list(map(lambda x: x.wilderness_soil, decoded_data.select("wilderness_soil").distinct().collect()))

In [18]:
train_data = decoded_data.sampleBy("wilderness_soil",{label:0.9 for label in unique_wilderness_soil},seed=42)

In [19]:
train_data.cache()

DataFrame[Elevation: int, Aspect: int, Slope: int, Horizontal_Distance_To_Hydrology: int, Vertical_Distance_To_Hydrology: int, Horizontal_Distance_To_Roadways: int, Hillshade_9am: int, Hillshade_Noon: int, Hillshade_3pm: int, Horizontal_Distance_To_Fire_Points: int, Cover_Type: double, wilderness: string, soil: string, id: bigint, wilderness_soil: string]

In [20]:
train_data_ids = tuple(map(lambda x: x.id, train_data.select("id").distinct().collect()))

In [21]:
train_data.count()

522991

In [22]:
test_data = decoded_data.where(f"id NOT IN {train_data_ids}")

In [23]:
test_data.cache()

DataFrame[Elevation: int, Aspect: int, Slope: int, Horizontal_Distance_To_Hydrology: int, Vertical_Distance_To_Hydrology: int, Horizontal_Distance_To_Roadways: int, Hillshade_9am: int, Hillshade_Noon: int, Hillshade_3pm: int, Horizontal_Distance_To_Fire_Points: int, Cover_Type: double, wilderness: string, soil: string, id: bigint, wilderness_soil: string]

In [24]:
test_data.show(5)

ERROR:root:KeyboardInterrupt while sending command.
Traceback (most recent call last):
  File "c:\Users\blais\Documents\ML\venv\lib\site-packages\py4j\java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
  File "c:\Users\blais\Documents\ML\venv\lib\site-packages\py4j\clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
  File "C:\Users\blais\AppData\Local\Programs\Python\Python310\lib\socket.py", line 705, in readinto
    return self._sock.recv_into(b)
KeyboardInterrupt


KeyboardInterrupt: 

In [None]:
test_data.count()

58021

In [None]:
decoded_data.count()

581012

**Just Performed a 90-10 stratified train-test split:**

drop unecessary columns now:

In [None]:
train_data = train_data.drop("id","wilderness_soil")
test_data = test_data.drop("id","wilderness_soil")

In [None]:
train_data = train_data.withColumn("soil",f.col("soil").cast(DoubleType())).withColumn("wilderness",f.col("wilderness").cast(DoubleType()))
test_data = test_data.withColumn("soil",f.col("soil").cast(DoubleType())).withColumn("wilderness",f.col("wilderness").cast(DoubleType()))

Need to write my pipeline and train

Time for train - test splits - these should be stratified though - to ensure each split includes the same proportion of wilderness and soil. 

In [None]:
inputCols = [c for c in train_data.columns if c!="Cover_Type"]

First is the assembler - 

In [None]:
assembler = VectorAssembler().setInputCols(inputCols).setOutputCol("featureVector")

In [None]:
indexer = VectorIndexer()\
                        .setMaxCategories(40)\
                        .setInputCol("featureVector")\
                        .setOutputCol("indexedVector")

In [None]:
classifier = RandomForestClassifier(seed=1234, labelCol="Cover_Type",
                                    featuresCol="indexedVector",
                                    predictionCol="prediction")

In [None]:
pipeline = Pipeline().setStages([assembler, indexer, classifier])

In [None]:
paramGrid = ParamGridBuilder().\
                              addGrid(classifier.impurity, ["gini","entropy"])\
                              .addGrid(classifier.maxDepth, [1,20])\
                              .addGrid(classifier.maxBins,[40,300])\
                              .addGrid(classifier.minInfoGain, [0.0, 0.05])\
                              .build()

In [None]:
multiclassEval = MulticlassClassificationEvaluator()\
                                                    .setLabelCol("Cover_Type")\
                                                    .setPredictionCol("prediction")\
                                                    .setMetricName("accuracy")

ConnectionRefusedError: [WinError 10061] No connection could be made because the target machine actively refused it

In [None]:
validator = TrainValidationSplit(seed=1234,
                                estimator=pipeline,
                                evaluator=multiclassEval,
                                estimatorParamMaps=paramGrid,
                                trainRatio=0.9)

In [None]:
validator_model = validator.fit(train_data)

Py4JJavaError: An error occurred while calling o9177.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 256.0 failed 1 times, most recent failure: Lost task 1.0 in stage 256.0 (TID 12578) (host.docker.internal executor driver): java.lang.OutOfMemoryError: Java heap space
	at java.base/java.lang.Double.valueOf(Double.java:524)
	at scala.runtime.BoxesRunTime.boxToDouble(BoxesRunTime.java:81)
	at org.apache.spark.ml.tree.CategoricalSplit.shouldGoLeft(Split.scala:109)
	at org.apache.spark.ml.tree.LearningNode.predictImpl(Node.scala:343)
	at org.apache.spark.ml.tree.impl.RandomForest$.$anonfun$findBestSplits$8(RandomForest.scala:560)
	at org.apache.spark.ml.tree.impl.RandomForest$.$anonfun$findBestSplits$8$adapted(RandomForest.scala:558)
	at org.apache.spark.ml.tree.impl.RandomForest$$$Lambda$6231/0x0000000801c66840.apply(Unknown Source)
	at scala.collection.immutable.HashMap$HashMap1.foreach(HashMap.scala:400)
	at scala.collection.immutable.HashMap$HashTrieMap.foreach(HashMap.scala:728)
	at org.apache.spark.ml.tree.impl.RandomForest$.binSeqOp$1(RandomForest.scala:558)
	at org.apache.spark.ml.tree.impl.RandomForest$.$anonfun$findBestSplits$24(RandomForest.scala:655)
	at org.apache.spark.ml.tree.impl.RandomForest$$$Lambda$6230/0x0000000801c66040.apply(Unknown Source)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
	at org.apache.spark.ml.tree.impl.RandomForest$.$anonfun$findBestSplits$21(RandomForest.scala:655)
	at org.apache.spark.ml.tree.impl.RandomForest$$$Lambda$6214/0x0000000801c45840.apply(Unknown Source)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:858)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:858)
	at org.apache.spark.rdd.RDD$$Lambda$1960/0x0000000800d4b040.apply(Unknown Source)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:104)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:54)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:621)
	at org.apache.spark.executor.Executor$TaskRunner$$Lambda$1727/0x0000000800c67440.apply(Unknown Source)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2898)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2834)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2833)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2833)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1253)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1253)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1253)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3102)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3036)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3025)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:995)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2393)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2414)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2433)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2458)
	at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1049)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:410)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:1048)
	at org.apache.spark.rdd.PairRDDFunctions.$anonfun$collectAsMap$1(PairRDDFunctions.scala:738)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:410)
	at org.apache.spark.rdd.PairRDDFunctions.collectAsMap(PairRDDFunctions.scala:737)
	at org.apache.spark.ml.tree.impl.RandomForest$.findBestSplits(RandomForest.scala:663)
	at org.apache.spark.ml.tree.impl.RandomForest$.runBagged(RandomForest.scala:208)
	at org.apache.spark.ml.tree.impl.RandomForest$.run(RandomForest.scala:302)
	at org.apache.spark.ml.classification.RandomForestClassifier.$anonfun$train$1(RandomForestClassifier.scala:168)
	at org.apache.spark.ml.util.Instrumentation$.$anonfun$instrumented$1(Instrumentation.scala:191)
	at scala.util.Try$.apply(Try.scala:213)
	at org.apache.spark.ml.util.Instrumentation$.instrumented(Instrumentation.scala:191)
	at org.apache.spark.ml.classification.RandomForestClassifier.train(RandomForestClassifier.scala:139)
	at org.apache.spark.ml.classification.RandomForestClassifier.train(RandomForestClassifier.scala:47)
	at org.apache.spark.ml.Predictor.fit(Predictor.scala:114)
	at org.apache.spark.ml.Predictor.fit(Predictor.scala:78)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:566)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:834)
Caused by: java.lang.OutOfMemoryError: Java heap space
	at java.base/java.lang.Double.valueOf(Double.java:524)
	at scala.runtime.BoxesRunTime.boxToDouble(BoxesRunTime.java:81)
	at org.apache.spark.ml.tree.CategoricalSplit.shouldGoLeft(Split.scala:109)
	at org.apache.spark.ml.tree.LearningNode.predictImpl(Node.scala:343)
	at org.apache.spark.ml.tree.impl.RandomForest$.$anonfun$findBestSplits$8(RandomForest.scala:560)
	at org.apache.spark.ml.tree.impl.RandomForest$.$anonfun$findBestSplits$8$adapted(RandomForest.scala:558)
	at org.apache.spark.ml.tree.impl.RandomForest$$$Lambda$6231/0x0000000801c66840.apply(Unknown Source)
	at scala.collection.immutable.HashMap$HashMap1.foreach(HashMap.scala:400)
	at scala.collection.immutable.HashMap$HashTrieMap.foreach(HashMap.scala:728)
	at org.apache.spark.ml.tree.impl.RandomForest$.binSeqOp$1(RandomForest.scala:558)
	at org.apache.spark.ml.tree.impl.RandomForest$.$anonfun$findBestSplits$24(RandomForest.scala:655)
	at org.apache.spark.ml.tree.impl.RandomForest$$$Lambda$6230/0x0000000801c66040.apply(Unknown Source)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
	at org.apache.spark.ml.tree.impl.RandomForest$.$anonfun$findBestSplits$21(RandomForest.scala:655)
	at org.apache.spark.ml.tree.impl.RandomForest$$$Lambda$6214/0x0000000801c45840.apply(Unknown Source)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:858)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:858)
	at org.apache.spark.rdd.RDD$$Lambda$1960/0x0000000800d4b040.apply(Unknown Source)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:104)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:54)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:621)
	at org.apache.spark.executor.Executor$TaskRunner$$Lambda$1727/0x0000000800c67440.apply(Unknown Source)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)


ERROR:root:Exception while sending command.
Traceback (most recent call last):
  File "c:\Users\blais\Documents\ML\venv\lib\site-packages\py4j\clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
  File "C:\Users\blais\AppData\Local\Programs\Python\Python310\lib\socket.py", line 705, in readinto
    return self._sock.recv_into(b)
ConnectionResetError: [WinError 10054] An existing connection was forcibly closed by the remote host

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "c:\Users\blais\Documents\ML\venv\lib\site-packages\py4j\java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
  File "c:\Users\blais\Documents\ML\venv\lib\site-packages\py4j\clientserver.py", line 539, in send_command
    raise Py4JNetworkError(
py4j.protocol.Py4JNetworkError: Error while sending or receiving
ERROR:root:Exception while sending command.
Traceback (mos