diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 1929f9d02156e..4c643d9c6ce4b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.tree +import org.apache.spark.mllib.linalg.{Vectors, Vector} /** * Abstraction for Decision Tree models. @@ -70,6 +71,10 @@ private[ml] trait TreeEnsembleModel { /** Weights for each tree, zippable with [[trees]] */ def treeWeights: Array[Double] + /** Weights used by the python wrappers. */ + // Note: An array cannot be returned directly due to serialization problems. + def pyTreeWeights: Vector = Vectors.dense(treeWeights) + /** Summary of the model */ override def toString: String = { // Implementing classes should generally override this method to be more descriptive. diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 2142b2a7bd966..40017fc39c08f 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -265,7 +265,8 @@ class TreeEnsembleModels(JavaModel): @property def treeWeights(self): - return list(self._call_java("treeWeights")) + """Return the weights for each tree""" + return list(self._call_java("pyTreeWeights")) def __repr__(self): return self._call_java("toString")