In [1]:
from pyspark.sql import SparkSession

spark = SparkSession.\
    builder.\
    master('local').\
    appName('tree-methods-basics').\
    getOrCreate()

In [2]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import (
    DecisionTreeClassifier,
    RandomForestClassifier,
    GBTClassifier
)
from pyspark.ml.regression import (
    DecisionTreeRegressor,
    RandomForestRegressor,
    GBTRegressor
)

In [3]:
data = spark.read.format('libsvm').\
    load('D:/learn-ab/learning-PySpark/sample-data/sample-libsvm-data.txt')

In [4]:
data.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [5]:
data.printSchema()

root
 |-- label: double (nullable = true)
 |-- features: vector (nullable = true)



In [6]:
train_data, test_data = data.randomSplit([0.7, 0.3])

In [7]:
?DecisionTreeClassifier

[1;31mInit signature:[0m
[0mDecisionTreeClassifier[0m[1;33m([0m[1;33m
[0m    [1;33m*[0m[1;33m,[0m[1;33m
[0m    [0mfeaturesCol[0m[1;33m:[0m [0mstr[0m [1;33m=[0m [1;34m'features'[0m[1;33m,[0m[1;33m
[0m    [0mlabelCol[0m[1;33m:[0m [0mstr[0m [1;33m=[0m [1;34m'label'[0m[1;33m,[0m[1;33m
[0m    [0mpredictionCol[0m[1;33m:[0m [0mstr[0m [1;33m=[0m [1;34m'prediction'[0m[1;33m,[0m[1;33m
[0m    [0mprobabilityCol[0m[1;33m:[0m [0mstr[0m [1;33m=[0m [1;34m'probability'[0m[1;33m,[0m[1;33m
[0m    [0mrawPredictionCol[0m[1;33m:[0m [0mstr[0m [1;33m=[0m [1;34m'rawPrediction'[0m[1;33m,[0m[1;33m
[0m    [0mmaxDepth[0m[1;33m:[0m [0mint[0m [1;33m=[0m [1;36m5[0m[1;33m,[0m[1;33m
[0m    [0mmaxBins[0m[1;33m:[0m [0mint[0m [1;33m=[0m [1;36m32[0m[1;33m,[0m[1;33m
[0m    [0mminInstancesPerNode[0m[1;33m:[0m [0mint[0m [1;33m=[0m [1;36m1[0m[1;33m,[0m[1;33m
[0m    [0mminInfoGain[0m[1;33m:[0m [0mfl

In [8]:
dc_clf = DecisionTreeClassifier()
rf_clf = RandomForestClassifier()
gb_clf = GBTClassifier()

In [9]:
dc_clf_model = dc_clf.fit(train_data)
rf_clf_model = rf_clf.fit(train_data)
gb_clf_model = gb_clf.fit(train_data)

In [10]:
dc_clf_pred = dc_clf_model.transform(test_data)
rf_clf_pred = rf_clf_model.transform(test_data)
gb_clf_pred = gb_clf_model.transform(test_data)

In [11]:
dc_clf_pred.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[100,101,102...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[125,126,127...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[127,128,129...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[152,153,154...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[152,153,154...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[154,155,156...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[181,182,183...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  1.0|(692,[97,98,99,12...|   [0.0,36.0]|  [0.0,1.0]|       1.0|
|  1.0|(69

In [12]:
rf_clf_pred.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[100,101,102...|   [15.0,5.0]|[0.75,0.25]|       0.0|
|  0.0|(692,[124,125,126...|   [20.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [20.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[125,126,127...|   [19.0,1.0]|[0.95,0.05]|       0.0|
|  0.0|(692,[126,127,128...|   [20.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [20.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[127,128,129...|   [20.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[152,153,154...|   [20.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[152,153,154...|   [20.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[154,155,156...|   [20.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[181,182,183...|   [18.0,2.0]|  [0.9,0.1]|       0.0|
|  1.0|(692,[97,98,99,12...|   [4.0,16.0]|  [0.2,0.8]|       1.0|
|  1.0|(69

In [13]:
gb_clf_pred.show()

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[100,101,102...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[125,126,127...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[127,128,129...|[1.29370791752444...|[0.93004727978904...|       0.0|
|  0.0|(692,[152,153,154...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[152,153,154...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[154

In [14]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [15]:
?MulticlassClassificationEvaluator

[1;31mInit signature:[0m
[0mMulticlassClassificationEvaluator[0m[1;33m([0m[1;33m
[0m    [1;33m*[0m[1;33m,[0m[1;33m
[0m    [0mpredictionCol[0m[1;33m:[0m [0mstr[0m [1;33m=[0m [1;34m'prediction'[0m[1;33m,[0m[1;33m
[0m    [0mlabelCol[0m[1;33m:[0m [0mstr[0m [1;33m=[0m [1;34m'label'[0m[1;33m,[0m[1;33m
[0m    [0mmetricName[0m[1;33m:[0m [1;34m'MulticlassClassificationEvaluatorMetricType'[0m [1;33m=[0m [1;34m'f1'[0m[1;33m,[0m[1;33m
[0m    [0mweightCol[0m[1;33m:[0m [0mOptional[0m[1;33m[[0m[0mstr[0m[1;33m][0m [1;33m=[0m [1;32mNone[0m[1;33m,[0m[1;33m
[0m    [0mmetricLabel[0m[1;33m:[0m [0mfloat[0m [1;33m=[0m [1;36m0.0[0m[1;33m,[0m[1;33m
[0m    [0mbeta[0m[1;33m:[0m [0mfloat[0m [1;33m=[0m [1;36m1.0[0m[1;33m,[0m[1;33m
[0m    [0mprobabilityCol[0m[1;33m:[0m [0mstr[0m [1;33m=[0m [1;34m'probability'[0m[1;33m,[0m[1;33m
[0m    [0meps[0m[1;33m:[0m [0mfloat[0m [1;33m=[0m [1;36m1e-15[

In [16]:
acc_eval = MulticlassClassificationEvaluator(
    metricName='accuracy'
)

In [17]:
print("Decision Tree Accuracy     : ", acc_eval.evaluate(dc_clf_pred))
print("Random Forest Accuracy     : ", acc_eval.evaluate(rf_clf_pred))
print("Gradient Boosting Accuracy : ", acc_eval.evaluate(gb_clf_pred))

Decision Tree Accuracy     :  0.967741935483871
Random Forest Accuracy     :  0.967741935483871
Gradient Boosting Accuracy :  0.967741935483871


In [18]:
dc_clf_model.featureImportances

SparseVector(692, {100: 0.0565, 406: 0.9435})

In [19]:
rf_clf_model.featureImportances

SparseVector(692, {99: 0.0018, 237: 0.0023, 244: 0.0755, 263: 0.0145, 268: 0.0019, 272: 0.1033, 301: 0.0032, 344: 0.0029, 349: 0.0208, 352: 0.0114, 358: 0.0025, 375: 0.0121, 379: 0.0445, 398: 0.0166, 401: 0.0077, 406: 0.0547, 407: 0.0449, 413: 0.0311, 428: 0.035, 433: 0.1413, 434: 0.05, 435: 0.0171, 440: 0.0334, 457: 0.014, 460: 0.0017, 461: 0.0955, 463: 0.0527, 489: 0.0386, 490: 0.0052, 491: 0.0028, 496: 0.0027, 511: 0.0117, 539: 0.036, 572: 0.0027, 577: 0.0032, 665: 0.0045})

In [20]:
gb_clf_model.featureImportances

SparseVector(692, {100: 0.0297, 267: 0.0068, 293: 0.0007, 320: 0.0023, 406: 0.5065, 433: 0.0688, 434: 0.2258, 490: 0.1424, 568: 0.017})