From 7a972c73192431d1c5980cfe40c76a1b88da87f8 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Fri, 30 Sep 2016 21:53:41 +0800 Subject: [PATCH 1/2] create pr --- .../apache/spark/ml/classification/NaiveBayes.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 0d652aa4c65a1..048b340b9955e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -25,6 +25,8 @@ import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.HasWeightCol import org.apache.spark.ml.util._ +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.sql.Row import org.apache.spark.sql.Dataset import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.sql.types.DoubleType @@ -362,9 +364,11 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath).select("pi", "theta").head() - val pi = data.getAs[Vector](0) - val theta = data.getAs[Matrix](1) + val data = sparkSession.read.parquet(dataPath) + val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi") + val Row(pi: Vector, theta: Matrix) = MLUtils.convertMatrixColumnsToML(vecConverted, "theta") + .select("pi", "theta") + .head() val model = new NaiveBayesModel(metadata.uid, pi, theta) DefaultParamsReader.getAndSetParams(model, metadata) From d3c91b3b998ba0be03172e8b560a51cd056b15fd Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Fri, 30 Sep 2016 22:11:22 +0800 Subject: [PATCH 2/2] fix style --- .../scala/org/apache/spark/ml/classification/NaiveBayes.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 048b340b9955e..6775745167b08 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -26,8 +26,7 @@ import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.HasWeightCol import org.apache.spark.ml.util._ import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.Row -import org.apache.spark.sql.Dataset +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.sql.types.DoubleType