Skip to content

Commit

Permalink
[SPARK-8711] [ML] Add additional methods to PySpark ML tree models
Browse files Browse the repository at this point in the history
Add numNodes and depth to treeModels, add treeWeights to ensemble Models.
Add __repr__ to all models.

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #7095 from MechCoder/missing_methods_tree and squashes the following commits:

23b08be [MechCoder] private [spark]
38a0860 [MechCoder] rename pyTreeWeights to javaTreeWeights
6d16ad8 [MechCoder] Fix Python 3 Error
47d7023 [MechCoder] Use np.allclose and treeEnsembleModel -> TreeEnsembleMethods
819098c [MechCoder] [SPARK-8711] [ML] Add additional methods ot PySpark ML tree models
  • Loading branch information
MechCoder authored and mengxr committed Jul 7, 2015
1 parent 0a63d7a commit 1dbc4a1
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.ml.tree

import org.apache.spark.mllib.linalg.{Vectors, Vector}

/**
* Abstraction for Decision Tree models.
Expand Down Expand Up @@ -70,6 +71,10 @@ private[ml] trait TreeEnsembleModel {
/** Weights for each tree, zippable with [[trees]] */
def treeWeights: Array[Double]

/** Weights used by the python wrappers. */
// Note: An array cannot be returned directly due to serialization problems.
private[spark] def javaTreeWeights: Vector = Vectors.dense(treeWeights)

/** Summary of the model */
override def toString: String = {
// Implementing classes should generally override this method to be more descriptive.
Expand Down
20 changes: 16 additions & 4 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import *
from pyspark.ml.regression import RandomForestParams
from pyspark.ml.regression import (
RandomForestParams, DecisionTreeModel, TreeEnsembleModels)
from pyspark.mllib.common import inherit_doc


Expand Down Expand Up @@ -202,6 +203,10 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> td = si_model.transform(df)
>>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed")
>>> model = dt.fit(td)
>>> model.numNodes
3
>>> model.depth
1
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
Expand Down Expand Up @@ -269,7 +274,8 @@ def getImpurity(self):
return self.getOrDefault(self.impurity)


class DecisionTreeClassificationModel(JavaModel):
@inherit_doc
class DecisionTreeClassificationModel(DecisionTreeModel):
"""
Model fitted by DecisionTreeClassifier.
"""
Expand All @@ -284,6 +290,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
It supports both binary and multiclass labels, as well as both continuous and categorical
features.
>>> from numpy import allclose
>>> from pyspark.mllib.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
>>> df = sqlContext.createDataFrame([
Expand All @@ -294,6 +301,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> td = si_model.transform(df)
>>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed", seed=42)
>>> model = rf.fit(td)
>>> allclose(model.treeWeights, [1.0, 1.0])
True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
Expand Down Expand Up @@ -423,7 +432,7 @@ def getFeatureSubsetStrategy(self):
return self.getOrDefault(self.featureSubsetStrategy)


class RandomForestClassificationModel(JavaModel):
class RandomForestClassificationModel(TreeEnsembleModels):
"""
Model fitted by RandomForestClassifier.
"""
Expand All @@ -438,6 +447,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
It supports binary labels, as well as both continuous and categorical features.
Note: Multiclass labels are not currently supported.
>>> from numpy import allclose
>>> from pyspark.mllib.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
>>> df = sqlContext.createDataFrame([
Expand All @@ -448,6 +458,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
>>> td = si_model.transform(df)
>>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed")
>>> model = gbt.fit(td)
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
Expand Down Expand Up @@ -558,7 +570,7 @@ def getStepSize(self):
return self.getOrDefault(self.stepSize)


class GBTClassificationModel(JavaModel):
class GBTClassificationModel(TreeEnsembleModels):
"""
Model fitted by GBTClassifier.
"""
Expand Down
46 changes: 43 additions & 3 deletions python/pyspark/ml/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> dt = DecisionTreeRegressor(maxDepth=2)
>>> model = dt.fit(df)
>>> model.depth
1
>>> model.numNodes
3
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
Expand Down Expand Up @@ -239,7 +243,37 @@ def getImpurity(self):
return self.getOrDefault(self.impurity)


class DecisionTreeRegressionModel(JavaModel):
@inherit_doc
class DecisionTreeModel(JavaModel):

@property
def numNodes(self):
"""Return number of nodes of the decision tree."""
return self._call_java("numNodes")

@property
def depth(self):
"""Return depth of the decision tree."""
return self._call_java("depth")

def __repr__(self):
return self._call_java("toString")


@inherit_doc
class TreeEnsembleModels(JavaModel):

@property
def treeWeights(self):
"""Return the weights for each tree"""
return list(self._call_java("javaTreeWeights"))

def __repr__(self):
return self._call_java("toString")


@inherit_doc
class DecisionTreeRegressionModel(DecisionTreeModel):
"""
Model fitted by DecisionTreeRegressor.
"""
Expand All @@ -253,12 +287,15 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
learning algorithm for regression.
It supports both continuous and categorical features.
>>> from numpy import allclose
>>> from pyspark.mllib.linalg import Vectors
>>> df = sqlContext.createDataFrame([
... (1.0, Vectors.dense(1.0)),
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42)
>>> model = rf.fit(df)
>>> allclose(model.treeWeights, [1.0, 1.0])
True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
Expand Down Expand Up @@ -389,7 +426,7 @@ def getFeatureSubsetStrategy(self):
return self.getOrDefault(self.featureSubsetStrategy)


class RandomForestRegressionModel(JavaModel):
class RandomForestRegressionModel(TreeEnsembleModels):
"""
Model fitted by RandomForestRegressor.
"""
Expand All @@ -403,12 +440,15 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
learning algorithm for regression.
It supports both continuous and categorical features.
>>> from numpy import allclose
>>> from pyspark.mllib.linalg import Vectors
>>> df = sqlContext.createDataFrame([
... (1.0, Vectors.dense(1.0)),
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> gbt = GBTRegressor(maxIter=5, maxDepth=2)
>>> model = gbt.fit(df)
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
Expand Down Expand Up @@ -518,7 +558,7 @@ def getStepSize(self):
return self.getOrDefault(self.stepSize)


class GBTRegressionModel(JavaModel):
class GBTRegressionModel(TreeEnsembleModels):
"""
Model fitted by GBTRegressor.
"""
Expand Down

0 comments on commit 1dbc4a1

Please sign in to comment.