In [1]:
!pip install pyspark

Collecting pyspark
  Downloading pyspark-3.5.1.tar.gz (317.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m317.0/317.0 MB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... [?25l[?25hdone
  Created wheel for pyspark: filename=pyspark-3.5.1-py2.py3-none-any.whl size=317488493 sha256=abf823cd865b884efe3ab898f0fd22ce303f61dd3c85ea0ee5ec9acfc69943b2
  Stored in directory: /root/.cache/pip/wheels/80/1d/60/2c256ed38dddce2fdd93be545214a63e02fbd8d74fb0b7f3a6
Successfully built pyspark
Installing collected packages: pyspark
Successfully installed pyspark-3.5.1


In [4]:
from pyspark.ml.regression import DecisionTreeRegressor
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler

In [6]:
# Reuse an already-active Spark session if there is one; otherwise create a
# new session under this application name.
spark = (
    SparkSession.builder
    .appName("DecisionTreeExample")
    .getOrCreate()
)

Read data

In [7]:
# Load the CSV, treating the first line as column names and asking Spark to
# infer each column's type from the values.
data = (
    spark.read
    .options(header="true", inferSchema="true")
    .csv("realestate.csv")
)

In [11]:
# Print the first two rows of the loaded DataFrame as a quick sanity check.
data.show(2)

+---+---------------+--------+-------------+-----------------------+--------+---------+---------------+
| No|TransactionDate|HouseAge|DistanceToMRT|NumberConvenienceStores|Latitude|Longitude|PriceOfUnitArea|
+---+---------------+--------+-------------+-----------------------+--------+---------+---------------+
|  1|       2012.917|    32.0|     84.87882|                     10|24.98298|121.54024|           37.9|
|  2|       2012.917|    19.5|     306.5947|                      9|24.98034|121.53951|           42.2|
+---+---------------+--------+-------------+-----------------------+--------+---------+---------------+
only showing top 2 rows



In [13]:
# Combine the five predictor columns into a single vector column named
# "features", which the regressor below is configured to read.
assembler = VectorAssembler(
    inputCols=[
        "HouseAge",
        "DistanceToMRT",
        "NumberConvenienceStores",
        "Latitude",
        "Longitude",
    ],
    outputCol="features",
)

In [14]:
# Add the assembled "features" vector to the data, then keep only the label
# column and the feature vector.
df = (
    assembler
    .transform(data)
    .select("PriceOfUnitArea", "features")
)

In [16]:
# 80/20 train/test split. A fixed seed keeps the split (and therefore the
# fitted tree and the predictions shown below) stable across kernel restarts,
# provided the input DataFrame is loaded the same way each time.
trainTest = df.randomSplit([0.8, 0.2], seed=42)
trainDF, testDF = trainTest

Initialize DecisionTreeRegressor:

From https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.regression.DecisionTreeRegressor.html:

class pyspark.ml.regression.DecisionTreeRegressor(*, featuresCol: str = 'features', labelCol: str = 'label', predictionCol: str = 'prediction', maxDepth: int = 5, maxBins: int = 32, minInstancesPerNode: int = 1, minInfoGain: float = 0.0, maxMemoryInMB: int = 256, cacheNodeIds: bool = False, checkpointInterval: int = 10, impurity: str = 'variance', seed: Optional[int] = None, varianceCol: Optional[str] = None, weightCol: Optional[str] = None, leafCol: str = '', minWeightFractionPerNode: float = 0.0)

In [17]:
# Configure the regressor: read inputs from the "features" vector column and
# use the unit-area price as the regression target. All other
# hyperparameters are left at their defaults.
spark_DecisionTree = DecisionTreeRegressor(
    labelCol="PriceOfUnitArea",
    featuresCol="features",
)

In [18]:
# Fit the decision tree on the training split only; the test split is held
# out for the predictions below.
model = spark_DecisionTree.fit(trainDF)

In [21]:
# Score the held-out rows. The result is cached because several later cells
# read from it.
predictions = (
    model
    .transform(testDF)
    .cache()
)

In [23]:
# Print the first five scored rows (label, features, prediction).
predictions.show(5)

+---------------+--------------------+------------------+
|PriceOfUnitArea|            features|        prediction|
+---------------+--------------------+------------------+
|           12.8|[16.5,4082.015,0....|           15.6625|
|           15.5|[26.9,4449.27,0.0...| 18.08888888888889|
|           17.4|[27.1,4412.765,1....|12.899999999999999|
|           18.6|[13.5,4197.349,0....|16.100000000000023|
|           20.5|[16.3,4066.587,0....|           15.6625|
+---------------+--------------------+------------------+
only showing top 5 rows



We convert the columns to RDDs because extracting individual values from an RDD is easier than from a DataFrame.

In [24]:
# Each select yields single-column Rows; unwrap every Row to its bare value
# so the RDDs hold plain numbers rather than Row objects.
predicted_values = (
    predictions
    .select("prediction")
    .rdd
    .map(lambda row: row["prediction"])
)
label_values = (
    predictions
    .select("PriceOfUnitArea")
    .rdd
    .map(lambda row: row["PriceOfUnitArea"])
)

In [26]:
# Build the (prediction, label) pairs from a single select so each prediction
# is paired with the label from the same row. RDD.zip requires both RDDs to
# have the same number of partitions and the same number of elements per
# partition, which two independently derived RDDs are not guaranteed to
# satisfy.
predicted_values_and_labels = [
    (row["prediction"], row["PriceOfUnitArea"])
    for row in predictions.select("prediction", "PriceOfUnitArea").collect()
]

In [30]:
# Print every (predicted, actual) pair, one tuple per line.
for predicted, actual in predicted_values_and_labels:
    print((predicted, actual))

(15.6625, 12.8)
(18.08888888888889, 15.5)
(12.899999999999999, 17.4)
(16.100000000000023, 18.6)
(15.6625, 20.5)
(18.08888888888889, 20.7)
(26.586567164179108, 21.3)
(26.450000000000003, 21.5)
(26.586567164179108, 22.3)
(12.899999999999999, 22.6)
(26.586567164179108, 23.2)
(26.586567164179108, 23.5)
(37.21666666666667, 23.5)
(26.586567164179108, 24.7)
(18.1, 24.7)
(26.586567164179108, 24.8)
(26.586567164179108, 25.6)
(26.586567164179108, 25.7)
(26.450000000000003, 27.0)
(26.586567164179108, 27.7)
(26.586567164179108, 28.4)
(26.586567164179108, 29.3)
(26.586567164179108, 29.4)
(26.586567164179108, 29.5)
(39.64090909090909, 30.0)
(28.720000000000006, 30.5)
(35.03333333333334, 30.9)
(39.64090909090909, 35.5)
(26.586567164179108, 35.6)
(39.64090909090909, 35.7)
(28.720000000000006, 36.8)
(35.03333333333334, 37.5)
(35.03333333333334, 38.1)
(35.03333333333334, 38.3)
(51.52857142857143, 38.4)
(39.64090909090909, 38.8)
(39.64090909090909, 40.3)
(41.71111111111111, 40.5)
(28.720000000000006, 40.

In [31]:
# Stop the Spark session now that the analysis is finished.
spark.stop()