Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-8711] [ML] Add additional methods to PySpark ML tree models #7095

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.ml.tree

import org.apache.spark.mllib.linalg.{Vectors, Vector}

/**
* Abstraction for Decision Tree models.
Expand Down Expand Up @@ -70,6 +71,10 @@ private[ml] trait TreeEnsembleModel {
/** Weights for each tree, zippable with [[trees]] */
def treeWeights: Array[Double]

/**
 * Tree weights wrapped in a dense [[Vector]] built from [[treeWeights]], for use by the
 * Python wrappers. Package private: Scala and Java callers should use [[treeWeights]] directly.
 */
// Note: An array cannot be returned directly due to serialization problems.
private[ml] def javaTreeWeights: Vector = Vectors.dense(treeWeights)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make it package private. Java users do not really need it.


/** Summary of the model */
override def toString: String = {
// Implementing classes should generally override this method to be more descriptive.
Expand Down
20 changes: 16 additions & 4 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import *
from pyspark.ml.regression import RandomForestParams
from pyspark.ml.regression import (
RandomForestParams, DecisionTreeModel, TreeEnsembleModels)
from pyspark.mllib.common import inherit_doc


Expand Down Expand Up @@ -202,6 +203,10 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> td = si_model.transform(df)
>>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed")
>>> model = dt.fit(td)
>>> model.numNodes
3
>>> model.depth
1
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
Expand Down Expand Up @@ -269,7 +274,8 @@ def getImpurity(self):
return self.getOrDefault(self.impurity)


class DecisionTreeClassificationModel(JavaModel):
@inherit_doc
class DecisionTreeClassificationModel(DecisionTreeModel):
"""
Model fitted by DecisionTreeClassifier.
"""
Expand All @@ -284,6 +290,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
It supports both binary and multiclass labels, as well as both continuous and categorical
features.

>>> from numpy import allclose
>>> from pyspark.mllib.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
>>> df = sqlContext.createDataFrame([
Expand All @@ -294,6 +301,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> td = si_model.transform(df)
>>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed", seed=42)
>>> model = rf.fit(td)
>>> allclose(model.treeWeights, [1.0, 1.0])
True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
Expand Down Expand Up @@ -423,7 +432,7 @@ def getFeatureSubsetStrategy(self):
return self.getOrDefault(self.featureSubsetStrategy)


class RandomForestClassificationModel(JavaModel):
class RandomForestClassificationModel(TreeEnsembleModels):
"""
Model fitted by RandomForestClassifier.
"""
Expand All @@ -438,6 +447,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
It supports binary labels, as well as both continuous and categorical features.
Note: Multiclass labels are not currently supported.

>>> from numpy import allclose
>>> from pyspark.mllib.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
>>> df = sqlContext.createDataFrame([
Expand All @@ -448,6 +458,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
>>> td = si_model.transform(df)
>>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed")
>>> model = gbt.fit(td)
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
Expand Down Expand Up @@ -558,7 +570,7 @@ def getStepSize(self):
return self.getOrDefault(self.stepSize)


class GBTClassificationModel(JavaModel):
class GBTClassificationModel(TreeEnsembleModels):
"""
Model fitted by GBTClassifier.
"""
Expand Down
46 changes: 43 additions & 3 deletions python/pyspark/ml/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> dt = DecisionTreeRegressor(maxDepth=2)
>>> model = dt.fit(df)
>>> model.depth
1
>>> model.numNodes
3
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
Expand Down Expand Up @@ -239,7 +243,37 @@ def getImpurity(self):
return self.getOrDefault(self.impurity)


class DecisionTreeRegressionModel(JavaModel):
@inherit_doc
class DecisionTreeModel(JavaModel):
    """
    Base class for decision tree models backed by a Java model object.

    Exposes read-only tree statistics that are fetched from the Java side
    on every access.
    """

    @property
    def numNodes(self):
        """Number of nodes in the decision tree, as reported by the Java model."""
        node_count = self._call_java("numNodes")
        return node_count

    @property
    def depth(self):
        """Depth of the decision tree, as reported by the Java model."""
        tree_depth = self._call_java("depth")
        return tree_depth

    def __repr__(self):
        # Delegate to the Java model's own string summary.
        description = self._call_java("toString")
        return description


@inherit_doc
class TreeEnsembleModels(JavaModel):
    """
    Base class for tree ensemble models backed by a Java model object.
    """

    @property
    def treeWeights(self):
        """Weights of the individual trees in the ensemble, as a Python list."""
        # The Java side exposes the weights through `javaTreeWeights`;
        # materialize the returned object into a plain list for Python callers.
        java_weights = self._call_java("javaTreeWeights")
        return list(java_weights)

    def __repr__(self):
        # Delegate to the Java model's own string summary.
        description = self._call_java("toString")
        return description


@inherit_doc
class DecisionTreeRegressionModel(DecisionTreeModel):
"""
Model fitted by DecisionTreeRegressor.
"""
Expand All @@ -253,12 +287,15 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
learning algorithm for regression.
It supports both continuous and categorical features.

>>> from numpy import allclose
>>> from pyspark.mllib.linalg import Vectors
>>> df = sqlContext.createDataFrame([
... (1.0, Vectors.dense(1.0)),
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42)
>>> model = rf.fit(df)
>>> allclose(model.treeWeights, [1.0, 1.0])
True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
Expand Down Expand Up @@ -389,7 +426,7 @@ def getFeatureSubsetStrategy(self):
return self.getOrDefault(self.featureSubsetStrategy)


class RandomForestRegressionModel(JavaModel):
class RandomForestRegressionModel(TreeEnsembleModels):
"""
Model fitted by RandomForestRegressor.
"""
Expand All @@ -403,12 +440,15 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
learning algorithm for regression.
It supports both continuous and categorical features.

>>> from numpy import allclose
>>> from pyspark.mllib.linalg import Vectors
>>> df = sqlContext.createDataFrame([
... (1.0, Vectors.dense(1.0)),
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> gbt = GBTRegressor(maxIter=5, maxDepth=2)
>>> model = gbt.fit(df)
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
Expand Down Expand Up @@ -518,7 +558,7 @@ def getStepSize(self):
return self.getOrDefault(self.stepSize)


class GBTRegressionModel(JavaModel):
class GBTRegressionModel(TreeEnsembleModels):
"""
Model fitted by GBTRegressor.
"""
Expand Down