In [1]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("new_tree").getOrCreate()

In [4]:
data = spark.read.csv("/dataset/tree/College.csv", inferSchema=True, header=True)

In [5]:
from pyspark.ml.feature import VectorAssembler

In [6]:
data.printSchema()

root
 |-- School: string (nullable = true)
 |-- Private: string (nullable = true)
 |-- Apps: integer (nullable = true)
 |-- Accept: integer (nullable = true)
 |-- Enroll: integer (nullable = true)
 |-- Top10perc: integer (nullable = true)
 |-- Top25perc: integer (nullable = true)
 |-- F_Undergrad: integer (nullable = true)
 |-- P_Undergrad: integer (nullable = true)
 |-- Outstate: integer (nullable = true)
 |-- Room_Board: integer (nullable = true)
 |-- Books: integer (nullable = true)
 |-- Personal: integer (nullable = true)
 |-- PhD: integer (nullable = true)
 |-- Terminal: integer (nullable = true)
 |-- S_F_Ratio: double (nullable = true)
 |-- perc_alumni: integer (nullable = true)
 |-- Expend: integer (nullable = true)
 |-- Grad_Rate: integer (nullable = true)



In [11]:
assembler = VectorAssembler(inputCols=["Apps",
                                        "Accept",
                                        "Enroll",
                                        "Top10perc",
                                        "Top25perc",
                                        "F_Undergrad",
                                        "P_Undergrad",
                                        "Outstate",
                                        "Room_Board",
                                        "Books",
                                        "Personal",
                                        "PhD",
                                       "Terminal",
                                       "S_F_Ratio",
                                       "perc_alumni",
                                       "Expend",
                                       "Grad_Rate"], outputCol="features")

In [12]:
output = assembler.transform(data)

In [14]:
from pyspark.ml.feature import StringIndexer

In [15]:
indexer = StringIndexer(inputCol="Private", outputCol="PrivateIndex")

In [19]:
output_fixed = indexer.fit(output).transform(output)

In [18]:
output_fixed.printSchema()

root
 |-- School: string (nullable = true)
 |-- Private: string (nullable = true)
 |-- Apps: integer (nullable = true)
 |-- Accept: integer (nullable = true)
 |-- Enroll: integer (nullable = true)
 |-- Top10perc: integer (nullable = true)
 |-- Top25perc: integer (nullable = true)
 |-- F_Undergrad: integer (nullable = true)
 |-- P_Undergrad: integer (nullable = true)
 |-- Outstate: integer (nullable = true)
 |-- Room_Board: integer (nullable = true)
 |-- Books: integer (nullable = true)
 |-- Personal: integer (nullable = true)
 |-- PhD: integer (nullable = true)
 |-- Terminal: integer (nullable = true)
 |-- S_F_Ratio: double (nullable = true)
 |-- perc_alumni: integer (nullable = true)
 |-- Expend: integer (nullable = true)
 |-- Grad_Rate: integer (nullable = true)
 |-- features: vector (nullable = true)
 |-- PrivateIndex: double (nullable = false)



In [20]:
final_data = output_fixed.select("PrivateIndex","features")
final_data.show()

+------------+--------------------+
|PrivateIndex|            features|
+------------+--------------------+
|         0.0|[1660.0,1232.0,72...|
|         0.0|[2186.0,1924.0,51...|
|         0.0|[1428.0,1097.0,33...|
|         0.0|[417.0,349.0,137....|
|         0.0|[193.0,146.0,55.0...|
|         0.0|[587.0,479.0,158....|
|         0.0|[353.0,340.0,103....|
|         0.0|[1899.0,1720.0,48...|
|         0.0|[1038.0,839.0,227...|
|         0.0|[582.0,498.0,172....|
|         0.0|[1732.0,1425.0,47...|
|         0.0|[2652.0,1900.0,48...|
|         0.0|[1179.0,780.0,290...|
|         0.0|[1267.0,1080.0,38...|
|         0.0|[494.0,313.0,157....|
|         0.0|[1420.0,1093.0,22...|
|         0.0|[4302.0,992.0,418...|
|         0.0|[1216.0,908.0,423...|
|         0.0|[1130.0,704.0,322...|
|         1.0|[3540.0,2001.0,10...|
+------------+--------------------+
only showing top 20 rows



In [21]:
train_data, test_data = final_data.randomSplit([0.7,0.3])

In [22]:
from pyspark.ml.classification import (RandomForestClassifier, DecisionTreeClassifier, GBTClassifier)

In [24]:
from pyspark.ml import pipeline

In [25]:
dtc = DecisionTreeClassifier(labelCol="PrivateIndex", featuresCol="features")
rfc = RandomForestClassifier(labelCol="PrivateIndex", featuresCol="features")
gbt = GBTClassifier(labelCol="PrivateIndex", featuresCol="features")

In [26]:
dtc_model = dtc.fit(train_data)
rfc_model = rfc.fit(train_data)
gbt_model = gbt.fit(train_data)

In [27]:
dtc_preds = dtc_model.transform(test_data)
rfc_preds = rfc_model.transform(test_data)
gbt_preds = gbt_model.transform(test_data)

In [45]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator

In [40]:
binary_eval = BinaryClassificationEvaluator(labelCol="PrivateIndex", rawPredictionCol="prediction")

In [41]:
print("Decision tree", binary_eval.evaluate(dtc_preds))
print("Random tree", binary_eval.evaluate(rfc_preds))
print("Greident tree", binary_eval.evaluate(gbt_preds))

Decision tree 0.9430107526881721
Random tree 0.9435483870967742
Greident tree 0.9457885304659499


In [42]:
dtc_preds.printSchema()

root
 |-- PrivateIndex: double (nullable = false)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)



In [43]:
rfc_preds.printSchema()

root
 |-- PrivateIndex: double (nullable = false)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)



In [44]:
gbt_preds.printSchema()

root
 |-- PrivateIndex: double (nullable = false)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)



In [50]:
multi_eval = MulticlassClassificationEvaluator(labelCol = "PrivateIndex", metricName="accuracy")

In [51]:
print("Multi evaluate", multi_eval.evaluate(rfc_preds))

Multi evaluate 0.9710743801652892
