In [1]:
import findspark

# Point findspark at the local Spark installation. This runs before any of the
# pyspark imports below -- presumably so that pyspark becomes importable from
# this interpreter (confirm against the findspark docs).
# NOTE(review): the path is specific to one Windows machine; adjust per host.
findspark.init('C:/Users/Bruno/anaconda3/Lib/site-packages/Spark')

In [2]:
from __future__ import print_function

from pyspark.ml.regression import DecisionTreeRegressor
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler

if __name__ == "__main__":
    # Script entry point: fit a decision-tree regressor on the real-estate CSV
    # and print one (prediction, actual PriceOfUnitArea) tuple per test row.

    # Create (or reuse) the SparkSession for this job.
    spark = SparkSession.builder.appName("DecisionTree").getOrCreate()

    # Everything that touches the session runs under try/finally so the
    # session is stopped even when loading, training or prediction raises.
    try:
        # Load the data as a DataFrame; the first CSV row is the header and
        # column types are inferred from the data.
        data = (
            spark.read
            .option("header", "true")
            .option("inferSchema", "true")
            .csv("C:/Users/Bruno/Desktop/Python/Projectos Big Data/Spark/realestate.csv")
        )

        # Pack the three predictor columns into the single "features" vector
        # column that the regressor below is configured to read.
        assembler = (
            VectorAssembler()
            .setInputCols(["HouseAge", "DistanceToMRT", "NumberConvenienceStores"])
            .setOutputCol("features")
        )

        # Keep only the label column and the assembled feature vector.
        df = assembler.transform(data).select("PriceOfUnitArea", "features")

        # Split the rows 50/50 into training and test sets. No seed is given,
        # so the split (and therefore the output) differs from run to run.
        trainingDF, testDF = df.randomSplit([0.5, 0.5])

        # Configure the decision tree: features in "features", label in
        # "PriceOfUnitArea".
        dtr = (
            DecisionTreeRegressor()
            .setFeaturesCol("features")
            .setLabelCol("PriceOfUnitArea")
        )

        # Train the model on the training half.
        model = dtr.fit(trainingDF)

        # Score the test half; transform() adds a "prediction" column.
        fullPredictions = model.transform(testDF)

        # Fetch prediction and known label together in ONE pass. Selecting both
        # columns from the same DataFrame keeps each pair row-aligned without
        # relying on RDD.zip's identical-partitioning requirement, and makes
        # caching the scored DataFrame unnecessary (it is evaluated once).
        predictionAndLabel = fullPredictions.select(
            "prediction", "PriceOfUnitArea"
        ).collect()

        # Print the predicted and actual value for each test point as a plain
        # (prediction, label) tuple.
        for row in predictionAndLabel:
            print(tuple(row))
    finally:
        # Stop the session on both the success and the failure path.
        spark.stop()

(20.0, 11.2)
(26.758333333333336, 11.6)
(20.0, 12.2)
(26.758333333333336, 12.9)
(17.799999999999997, 13.0)
(20.0, 13.2)
(17.799999999999997, 13.4)
(20.100000000000005, 13.7)
(16.599999999999998, 14.7)
(20.0, 15.0)
(16.599999999999998, 15.6)
(17.799999999999997, 15.6)
(20.0, 17.4)
(26.758333333333336, 18.3)
(17.799999999999997, 18.6)
(20.0, 18.8)
(20.100000000000005, 19.1)
(17.799999999999997, 19.2)
(26.758333333333336, 20.5)
(20.100000000000005, 20.7)
(26.758333333333336, 20.9)
(27.857142857142854, 21.4)
(26.758333333333336, 21.8)
(43.2, 22.0)
(26.758333333333336, 22.1)
(43.2, 22.3)
(26.758333333333336, 22.3)
(20.100000000000005, 22.8)
(20.100000000000005, 22.8)
(26.758333333333336, 22.9)
(26.758333333333336, 23.0)
(30.14285714285714, 23.1)
(20.100000000000005, 23.1)
(30.14285714285714, 23.5)
(26.758333333333336, 23.6)
(26.758333333333336, 23.6)
(20.100000000000005, 23.8)
(30.14285714285714, 23.9)
(26.758333333333336, 24.6)
(43.2, 24.7)
(39.8, 24.7)
(26.758333333333336, 24.7)
(20.10000