Skip to content

Commit

Permalink
[SPARK-10278] [MLLIB] [PYSPARK] Add @SInCE annotation to pyspark.mlli…
Browse files Browse the repository at this point in the history
…b.tree

Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com>

Closes #8685 from yu-iskw/SPARK-10278.
  • Loading branch information
yu-iskw authored and mengxr committed Sep 17, 2015
1 parent 0ded87a commit 39b44cb
Showing 1 changed file with 35 additions and 1 deletion.
36 changes: 35 additions & 1 deletion python/pyspark/mllib/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import random

from pyspark import SparkContext, RDD
from pyspark import SparkContext, RDD, since
from pyspark.mllib.common import callMLlibFunc, inherit_doc, JavaModelWrapper
from pyspark.mllib.linalg import _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
Expand All @@ -30,6 +30,11 @@


class TreeEnsembleModel(JavaModelWrapper, JavaSaveable):
"""TreeEnsembleModel
.. versionadded:: 1.3.0
"""
@since("1.3.0")
def predict(self, x):
"""
Predict values for a single data point or an RDD of points using
Expand All @@ -45,12 +50,14 @@ def predict(self, x):
else:
return self.call("predict", _convert_to_vector(x))

@since("1.3.0")
def numTrees(self):
"""
Get number of trees in ensemble.
"""
return self.call("numTrees")

@since("1.3.0")
def totalNumNodes(self):
"""
Get total number of nodes, summed over all trees in the
Expand All @@ -62,6 +69,7 @@ def __repr__(self):
""" Summary of model """
return self._java_model.toString()

@since("1.3.0")
def toDebugString(self):
""" Full model """
return self._java_model.toDebugString()
Expand All @@ -72,7 +80,10 @@ class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader):
.. note:: Experimental
A decision tree model for classification or regression.
.. versionadded:: 1.1.0
"""
@since("1.1.0")
def predict(self, x):
"""
Predict the label of one or more examples.
Expand All @@ -90,16 +101,23 @@ def predict(self, x):
else:
return self.call("predict", _convert_to_vector(x))

@since("1.1.0")
def numNodes(self):
"""Get number of nodes in tree, including leaf nodes."""
return self._java_model.numNodes()

@since("1.1.0")
def depth(self):
"""Get depth of tree.
E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes.
"""
return self._java_model.depth()

def __repr__(self):
""" summary of model. """
return self._java_model.toString()

@since("1.2.0")
def toDebugString(self):
""" full model. """
return self._java_model.toDebugString()
Expand All @@ -115,6 +133,8 @@ class DecisionTree(object):
Learning algorithm for a decision tree model for classification or
regression.
.. versionadded:: 1.1.0
"""

@classmethod
Expand All @@ -127,6 +147,7 @@ def _train(cls, data, type, numClasses, features, impurity="gini", maxDepth=5, m
return DecisionTreeModel(model)

@classmethod
@since("1.1.0")
def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo,
impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1,
minInfoGain=0.0):
Expand Down Expand Up @@ -185,6 +206,7 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo,
impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)

@classmethod
@since("1.1.0")
def trainRegressor(cls, data, categoricalFeaturesInfo,
impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1,
minInfoGain=0.0):
Expand Down Expand Up @@ -239,6 +261,8 @@ class RandomForestModel(TreeEnsembleModel, JavaLoader):
.. note:: Experimental
Represents a random forest model.
.. versionadded:: 1.2.0
"""

@classmethod
Expand All @@ -252,6 +276,8 @@ class RandomForest(object):
Learning algorithm for a random forest model for classification or
regression.
.. versionadded:: 1.2.0
"""

supportedFeatureSubsetStrategies = ("auto", "all", "sqrt", "log2", "onethird")
Expand All @@ -271,6 +297,7 @@ def _train(cls, data, algo, numClasses, categoricalFeaturesInfo, numTrees,
return RandomForestModel(model)

@classmethod
@since("1.2.0")
def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees,
featureSubsetStrategy="auto", impurity="gini", maxDepth=4, maxBins=32,
seed=None):
Expand Down Expand Up @@ -352,6 +379,7 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees,
maxDepth, maxBins, seed)

@classmethod
@since("1.2.0")
def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto",
impurity="variance", maxDepth=4, maxBins=32, seed=None):
"""
Expand Down Expand Up @@ -418,6 +446,8 @@ class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader):
.. note:: Experimental
Represents a gradient-boosted tree model.
.. versionadded:: 1.3.0
"""

@classmethod
Expand All @@ -431,6 +461,8 @@ class GradientBoostedTrees(object):
Learning algorithm for a gradient boosted trees model for
classification or regression.
.. versionadded:: 1.3.0
"""

@classmethod
Expand All @@ -443,6 +475,7 @@ def _train(cls, data, algo, categoricalFeaturesInfo,
return GradientBoostedTreesModel(model)

@classmethod
@since("1.3.0")
def trainClassifier(cls, data, categoricalFeaturesInfo,
loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3,
maxBins=32):
Expand Down Expand Up @@ -505,6 +538,7 @@ def trainClassifier(cls, data, categoricalFeaturesInfo,
loss, numIterations, learningRate, maxDepth, maxBins)

@classmethod
@since("1.3.0")
def trainRegressor(cls, data, categoricalFeaturesInfo,
loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3,
maxBins=32):
Expand Down

0 comments on commit 39b44cb

Please sign in to comment.