From c66476ea8d31db92a6723d5f702bc1be6e0ea2aa Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 20 Jun 2016 13:48:01 -0700 Subject: [PATCH 1/2] Added import for DecisionTreeRegressionModel to fix NameError in GBT model --- python/pyspark/ml/classification.py | 6 ++++-- python/pyspark/ml/regression.py | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 121b9262dd9de..880a44804c9ff 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -21,8 +21,8 @@ from pyspark import since, keyword_only from pyspark.ml import Estimator, Model from pyspark.ml.param.shared import * -from pyspark.ml.regression import ( - RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels) +from pyspark.ml.regression import (DecisionTreeModel, DecisionTreeRegressionModel, + RandomForestParams, TreeEnsembleModels, TreeEnsembleParams) from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams from pyspark.ml.wrapper import JavaWrapper @@ -798,6 +798,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol True >>> model.treeWeights == model2.treeWeights True + >>> model.trees + [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...] .. versionadded:: 1.4.0 """ diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index db31993f0fb70..8d2378d51fb7e 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -994,6 +994,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, True >>> model.treeWeights == model2.treeWeights True + >>> model.trees + [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...] .. versionadded:: 1.4.0 """ From dc5a6de599cbc9eb5bffed2f14073cab37bb334e Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 20 Jun 2016 13:59:08 -0700 Subject: [PATCH 2/2] reformatted imports to be more consistent --- python/pyspark/ml/classification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 880a44804c9ff..a3cd91790c42e 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -21,8 +21,8 @@ from pyspark import since, keyword_only from pyspark.ml import Estimator, Model from pyspark.ml.param.shared import * -from pyspark.ml.regression import (DecisionTreeModel, DecisionTreeRegressionModel, - RandomForestParams, TreeEnsembleModels, TreeEnsembleParams) +from pyspark.ml.regression import DecisionTreeModel, DecisionTreeRegressionModel, \ + RandomForestParams, TreeEnsembleModels, TreeEnsembleParams from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams from pyspark.ml.wrapper import JavaWrapper