Skip to content

Commit

Permalink
PySpark ml DecisionTreeClassifier, Regressor support export/import
Browse files Browse the repository at this point in the history
  • Loading branch information
GayathriMurali committed Mar 22, 2016
1 parent 0ce0163 commit 0154444
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 6 deletions.
16 changes: 14 additions & 2 deletions python/pyspark/ml/classification.py
Expand Up @@ -276,7 +276,8 @@ class GBTParams(TreeEnsembleParams):
@inherit_doc
class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams,
TreeClassifierParams, HasCheckpointInterval, HasSeed):
TreeClassifierParams, HasCheckpointInterval, HasSeed, MLWritable,
MLReadable):
"""
`http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree`
learning algorithm for classification.
Expand Down Expand Up @@ -311,6 +312,17 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> model.transform(test1).head().prediction
1.0
>>> dtc_path = temp_path + "/dtc"
>>> dt.save(dtc_path)
>>> dt2 = DecisionTreeClassifier.load(dtc_path)
>>> dt2.getMaxDepth()
2
>>> model_path = temp_path + "/dtc_model"
>>> model.save(model_path)
>>> model2 = DecisionTreeClassificationModel.load(model_path)
>>> model.featureImportances == model2.featureImportances
True
.. versionadded:: 1.4.0
"""

Expand Down Expand Up @@ -359,7 +371,7 @@ def _create_model(self, java_model):


@inherit_doc
class DecisionTreeClassificationModel(DecisionTreeModel):
class DecisionTreeClassificationModel(DecisionTreeModel, MLWritable, MLReadable):
"""
Model fitted by DecisionTreeClassifier.
Expand Down
16 changes: 14 additions & 2 deletions python/pyspark/ml/regression.py
Expand Up @@ -385,7 +385,7 @@ class GBTParams(TreeEnsembleParams):
@inherit_doc
class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval,
HasSeed):
HasSeed, MLWritable, MLReadable):
"""
`http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree`
learning algorithm for regression.
Expand All @@ -409,6 +409,18 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
>>> dtr_path = temp_path + "/dtr"
>>> dt.save(dtr_path)
>>> dt2 = DecisionTreeRegressor.load(dtr_path)
>>> dt2.getMaxDepth()
2
>>> model_path = temp_path + "/dtr_model"
>>> model.save(model_path)
>>> model2 = DecisionTreeRegressionModel.load(model_path)
>>> model.numNodes == model2.numNodes
True
>>> model.depth == model2.depth
True
.. versionadded:: 1.4.0
"""
Expand Down Expand Up @@ -454,7 +466,7 @@ def _create_model(self, java_model):


@inherit_doc
class DecisionTreeModel(JavaModel):
class DecisionTreeModel(JavaModel, MLWritable, MLReadable):
"""Abstraction for Decision Tree models.
.. versionadded:: 1.5.0
Expand Down
40 changes: 38 additions & 2 deletions python/pyspark/ml/tests.py
Expand Up @@ -38,13 +38,13 @@
import tempfile

from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier
from pyspark.ml.clustering import KMeans
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.feature import *
from pyspark.ml.param import Param, Params
from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
from pyspark.ml.regression import LinearRegression
from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor
from pyspark.ml.tuning import *
from pyspark.ml.util import keyword_only
from pyspark.mllib.linalg import DenseVector
Expand Down Expand Up @@ -560,6 +560,42 @@ def test_pipeline_persistence(self):
except OSError:
pass

def test_decisiontree_classifier(self):
    """Round-trip an unfitted DecisionTreeClassifier through save()/load().

    Verifies that, on the loaded instance, the ``maxDepth`` Param is
    parented to the loaded instance's own uid, and that the default value
    registered for ``maxDepth`` equals the original instance's default.
    """
    dt = DecisionTreeClassifier(maxDepth=1)
    path = tempfile.mkdtemp()
    try:
        dtc_path = path + "/dtc"
        dt.save(dtc_path)
        dt2 = DecisionTreeClassifier.load(dtc_path)
        self.assertEqual(dt2.uid, dt2.maxDepth.parent,
                         "Loaded DecisionTreeClassifier instance uid (%s) "
                         "did not match Param's uid (%s)"
                         % (dt2.uid, dt2.maxDepth.parent))
        self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth],
                         "Loaded DecisionTreeClassifier instance default params did not match " +
                         "original defaults")
    finally:
        # Clean up in `finally` so a failing save/load/assert does not leak
        # the temporary directory. Removal stays best-effort: an OSError
        # here must not mask the real test outcome.
        try:
            rmtree(path)
        except OSError:
            pass

def test_decisiontree_regressor(self):
    """Round-trip an unfitted DecisionTreeRegressor through save()/load().

    Verifies that, on the loaded instance, the ``maxDepth`` Param is
    parented to the loaded instance's own uid, and that the default value
    registered for ``maxDepth`` equals the original instance's default.
    """
    dt = DecisionTreeRegressor(maxDepth=1)
    path = tempfile.mkdtemp()
    try:
        dtr_path = path + "/dtr"
        dt.save(dtr_path)
        # Load with the same class that was saved; loading through
        # DecisionTreeClassifier would not exercise the regressor's reader.
        dt2 = DecisionTreeRegressor.load(dtr_path)
        self.assertEqual(dt2.uid, dt2.maxDepth.parent,
                         "Loaded DecisionTreeRegressor instance uid (%s) "
                         "did not match Param's uid (%s)"
                         % (dt2.uid, dt2.maxDepth.parent))
        self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth],
                         "Loaded DecisionTreeRegressor instance default params did not match " +
                         "original defaults")
    finally:
        # Clean up in `finally` so a failing save/load/assert does not leak
        # the temporary directory. Removal stays best-effort: an OSError
        # here must not mask the real test outcome.
        try:
            rmtree(path)
        except OSError:
            pass


class HasThrowableProperty(Params):

Expand Down

0 comments on commit 0154444

Please sign in to comment.