In [1]:
from pyspark import SparkContext
from pyspark.ml.feature import VectorAssembler,StringIndexer
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
from pyspark.mllib.util import MLUtils
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import StructType, IntegerType, DateType, StringType,StructField, FloatType

In [2]:
# Column layout of iris.txt: four float measurements followed by the species name.
schema = StructType(
    [StructField(name, FloatType())
     for name in ('sepal_length', 'sepal_width', 'petal_length', 'petal_width')]
    + [StructField('class', StringType())]
)

In [3]:
# `spark` is predefined by the pyspark shell (and some hosted notebooks) but not
# by a plain Python kernel. getOrCreate() returns the existing session when there
# is one and builds a new one otherwise, so Restart & Run All works everywhere.
spark = SparkSession.builder.getOrCreate()

# Load iris.txt as CSV: the first line is a header row, and the explicit schema
# types the four measurements as floats and the species column as a string.
data = (
    spark.read
    .schema(schema)
    .format("csv")
    .option("header", True)
    .load("iris.txt")
)

In [4]:
# Peek at the first two parsed rows.
data.show(n=2)

+------------+-----------+------------+-----------+-----------+
|sepal_length|sepal_width|petal_length|petal_width|      class|
+------------+-----------+------------+-----------+-----------+
|         5.1|        3.5|         1.4|        0.2|Iris-setosa|
|         4.9|        3.0|         1.4|        0.2|Iris-setosa|
+------------+-----------+------------+-----------+-----------+
only showing top 2 rows



In [5]:
# List each species label that occurs in the data.
data.select(col('class')).distinct().show()

+---------------+
|          class|
+---------------+
| Iris-virginica|
|    Iris-setosa|
|Iris-versicolor|
+---------------+



In [6]:
# Learn a string -> numeric index mapping for the species column, then append it
# as `indexedLabel`. NOTE(review): despite its name, `labelIndexer` holds the
# transformed DataFrame, not the fitted StringIndexer model.
_label_model = StringIndexer(inputCol="class", outputCol="indexedLabel").fit(data)
labelIndexer = _label_model.transform(data)

In [7]:
# Confirm the new indexedLabel column.
labelIndexer.show(n=2)

+------------+-----------+------------+-----------+-----------+------------+
|sepal_length|sepal_width|petal_length|petal_width|      class|indexedLabel|
+------------+-----------+------------+-----------+-----------+------------+
|         5.1|        3.5|         1.4|        0.2|Iris-setosa|         0.0|
|         4.9|        3.0|         1.4|        0.2|Iris-setosa|         0.0|
+------------+-----------+------------+-----------+-----------+------------+
only showing top 2 rows



In [8]:
# Pack the four measurements into a single vector column named `features`.
# NOTE(review): `assembler` holds the assembled DataFrame, not the
# VectorAssembler transformer itself.
_vector_assembler = VectorAssembler(
    inputCols=["sepal_length", "sepal_width", "petal_length", "petal_width"],
    outputCol="features",
)
assembler = _vector_assembler.transform(labelIndexer)

In [9]:
# Confirm the assembled features column.
assembler.show(n=2)

+------------+-----------+------------+-----------+-----------+------------+--------------------+
|sepal_length|sepal_width|petal_length|petal_width|      class|indexedLabel|            features|
+------------+-----------+------------+-----------+-----------+------------+--------------------+
|         5.1|        3.5|         1.4|        0.2|Iris-setosa|         0.0|[5.09999990463256...|
|         4.9|        3.0|         1.4|        0.2|Iris-setosa|         0.0|[4.90000009536743...|
+------------+-----------+------------+-----------+-----------+------------+--------------------+
only showing top 2 rows



In [10]:
# LabeledPoint and DecisionTree belong to the RDD-based pyspark.mllib API, so the
# pyspark.ml vector column produced by VectorAssembler is converted to mllib vectors.
dfVector = MLUtils.convertVectorColumnsFromML(
    assembler,
    "features",
)

In [11]:
def _to_labeled_point(row):
    """Build an MLlib LabeledPoint from a Row carrying `label` and `features`."""
    return LabeledPoint(row.label, row.features)


# NOTE(review): this rebinds `data` from the raw iris DataFrame to an RDD of
# LabeledPoint, so re-running earlier cells that call DataFrame methods on
# `data` (show, select, StringIndexer.fit) fails after this cell has run.
data = (
    dfVector
    .select(col("indexedLabel").alias("label"), col("features"))
    .rdd
    .map(_to_labeled_point)
)

In [12]:
# Hold out ~30% of the points for evaluation. A fixed seed makes the split, and
# therefore the reported test error, reproducible across Restart & Run All.
# Without it, PySpark picks a fresh random seed on every run.
SPLIT_SEED = 42
(trainingData, testData) = data.randomSplit([0.7, 0.3], seed=SPLIT_SEED)

In [13]:
# Inspect a handful of training points.
trainingData.take(num=5)

[LabeledPoint(0.0, [5.099999904632568,3.5,1.399999976158142,0.20000000298023224]),
 LabeledPoint(0.0, [4.699999809265137,3.200000047683716,1.2999999523162842,0.20000000298023224]),
 LabeledPoint(0.0, [4.599999904632568,3.0999999046325684,1.5,0.20000000298023224]),
 LabeledPoint(0.0, [5.400000095367432,3.9000000953674316,1.7000000476837158,0.4000000059604645]),
 LabeledPoint(0.0, [5.0,3.4000000953674316,1.5,0.20000000298023224])]

In [14]:
# Gini-impurity classification tree over the three iris species (indexed 0..2
# by StringIndexer). An empty categoricalFeaturesInfo treats every feature as
# continuous.
NUM_CLASSES = 3
MAX_DEPTH = 5
MAX_BINS = 32
model = DecisionTree.trainClassifier(
    trainingData,
    numClasses=NUM_CLASSES,
    categoricalFeaturesInfo={},
    impurity='gini',
    maxDepth=MAX_DEPTH,
    maxBins=MAX_BINS,
)

In [15]:
# Score the held-out points using only their feature vectors.
_test_features = testData.map(lambda point: point.features)
predictions = model.predict(_test_features)

In [16]:
# Pair each true label with its prediction. RDD.zip requires both sides to line
# up element-for-element; this relies on `predictions` being an order-preserving,
# one-to-one map over testData.
_test_labels = testData.map(lambda point: point.label)
labelsAndPredictions = _test_labels.zip(predictions)

In [18]:
# Fraction of held-out points whose predicted label differs from the true label.
# An empty test split (possible with very small inputs) yields NaN instead of
# raising ZeroDivisionError.
_n_test = testData.count()
_n_wrong = labelsAndPredictions.filter(lambda lp: lp[0] != lp[1]).count()
testErr = _n_wrong / float(_n_test) if _n_test else float('nan')

In [19]:
# Report the held-out error, then dump the learned split rules.
print('Test Error =', testErr)
print('Learned classification tree model:')
print(model.toDebugString())

Test Error = 0.1590909090909091
Learned classification tree model:
DecisionTreeModel classifier of depth 5 with 13 nodes
  If (feature 3 <= 0.800000011920929)
   Predict: 0.0
  Else (feature 3 > 0.800000011920929)
   If (feature 3 <= 1.6500000357627869)
    Predict: 1.0
   Else (feature 3 > 1.6500000357627869)
    If (feature 3 <= 1.75)
     If (feature 0 <= 4.950000047683716)
      Predict: 2.0
     Else (feature 0 > 4.950000047683716)
      Predict: 1.0
    Else (feature 3 > 1.75)
     If (feature 2 <= 4.8500001430511475)
      If (feature 0 <= 5.950000047683716)
       Predict: 1.0
      Else (feature 0 > 5.950000047683716)
       Predict: 2.0
     Else (feature 2 > 4.8500001430511475)
      Predict: 2.0



In [None]:
# Optional: persist the trained tree and reload it as a round-trip check.
# Disabled by default so Restart & Run All does not write to disk. The
# SparkContext is obtained explicitly because `sc` is only predefined by the
# pyspark shell.
SAVE_MODEL = False
MODEL_PATH = "DecisionTreeModel"
if SAVE_MODEL:
    _sc = SparkContext.getOrCreate()
    model.save(_sc, MODEL_PATH)
    sameModel = DecisionTreeModel.load(_sc, MODEL_PATH)