From 208e166707b44815cffee088457ad29e01f87474 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 8 Jul 2015 16:31:17 +0800 Subject: [PATCH 1/6] Naive Bayes API for spark.ml Pipelines --- .../spark/ml/classification/NaiveBayes.scala | 188 ++++++++++++++++++ .../mllib/classification/NaiveBayes.scala | 2 +- .../apache/spark/mllib/linalg/Matrices.scala | 2 +- .../classification/JavaNaiveBayesSuite.java | 98 +++++++++ .../ml/classification/NaiveBayesSuite.scala | 122 ++++++++++++ 5 files changed, 410 insertions(+), 2 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala create mode 100644 mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala new file mode 100644 index 0000000000000..1b5e18b3e1065 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.SparkException +import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} +import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param, DoubleParam} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes} +import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel} +import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, Vector} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + +/** + * Params for Naive Bayes Classifiers. + */ +private[ml] trait NaiveBayesParams extends PredictorParams { + + /** + * The smoothing parameter. + * @group param + */ + final val lambda: DoubleParam = new DoubleParam(this, "lambda", "The smoothing parameter.", + ParamValidators.gtEq(0)) + setDefault(lambda -> 1.0) + + /** @group getParam */ + final def getLambda: Double = $(lambda) + + /** + * The model type which is a string (case-sensitive). + * Supported options: "multinomial" (default) and "bernoulli". + * @group param + */ + final val modelType: Param[String] = new Param[String](this, "modelType", + "The model type which is a string (case-sensitive). Supported options: " + + "\"multinomial\" (default) and \"bernoulli\".") + setDefault(modelType -> "multinomial") + + /** @group getParam */ + final def getModelType: String = $(modelType) +} + +/** + * Naive Bayes Classifiers. + * It supports both Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle + * all kinds of discrete data and Bernoulli NB ([[http://tinyurl.com/p7c96j6]]) + * which can only handle 0-1 vector. + */ +class NaiveBayes(override val uid: String) + extends Predictor[Vector, NaiveBayes, NaiveBayesModel] + with NaiveBayesParams { + + def this() = this(Identifiable.randomUID("nb")) + + /** + * Set the smoothing parameter. + * Default is 1.0. + * @group setParam + */ + def setLambda(value: Double): this.type = set(lambda, value) + + /** + * Set the model type using a string (case-sensitive). + * Supported options: "multinomial" and "bernoulli". + * Default is "multinomial" + * @return + */ + def setModelType(value: String): this.type = set(modelType, value) + + override protected def train(dataset: DataFrame): NaiveBayesModel = { + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) + val oldModel = OldNaiveBayes.train(oldDataset, $(lambda), $(modelType)) + NaiveBayesModel.fromOld(oldModel, this) + } + + override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra) +} + +/** + * Model produced by [[NaiveBayes]] + */ +class NaiveBayesModel private[ml] ( + override val uid: String, + val labels: Array[Double], + val pi: Array[Double], + val theta: Array[Array[Double]], + val modelType: String) + extends PredictionModel[Vector, NaiveBayesModel] { + + /** String name for multinomial model type. */ + private[classification] val Multinomial: String = "multinomial" + + /** String name for Bernoulli model type. */ + private[classification] val Bernoulli: String = "bernoulli" + + /** Set of modelTypes that NaiveBayes supports */ + private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli) + + private val piVector = new DenseVector(pi) + private val thetaMatrix = new DenseMatrix(labels.length, theta(0).length, theta.flatten, true) + + require(supportedModelTypes.contains(modelType), + s"NaiveBayes was created with an unknown modelType: ${modelType}.") + + /** + * Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0. + * This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra + * application of this condition (in predict function). + */ + private val (thetaMinusNegTheta, negThetaSum) = modelType match { + case Multinomial => (None, None) + case Bernoulli => + val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value))) + val ones = new DenseVector(Array.fill(thetaMatrix.numCols){1.0}) + val thetaMinusNegTheta = thetaMatrix.map { value => + value - math.log(1.0 - math.exp(value)) + } + (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones))) + case _ => + // This should never happen. + throw new UnknownError(s"Invalid modelType: ${modelType}.") + } + + override protected def predict(features: Vector): Double = { + modelType match { + case Multinomial => + val prob = thetaMatrix.multiply(features) + BLAS.axpy(1.0, piVector, prob) + labels(prob.argmax) + case Bernoulli => + features.foreachActive{ (index, value) => + if (value != 0.0 && value != 1.0) { + throw new SparkException( + s"Bernoulli naive Bayes requires 0 or 1 feature values but found ${features}") + } + } + val prob = thetaMinusNegTheta.get.multiply(features) + BLAS.axpy(1.0, piVector, prob) + BLAS.axpy(1.0, negThetaSum.get, prob) + labels(prob.argmax) + case _ => + // This should never happen. + throw new UnknownError(s"Invalid modelType: ${modelType}.") + } + } + + override def copy(extra: ParamMap): NaiveBayesModel = { + copyValues(new NaiveBayesModel(uid, labels, pi, theta, modelType), extra) + } + + override def toString: String = { + s"NaiveBayesModel with ${labels.size} classes" + } + + private[ml] def toOld: OldNaiveBayesModel = { + new OldNaiveBayesModel(labels, pi, theta, modelType) + } + +} + +private[ml] object NaiveBayesModel { + + /** (private[ml]) Convert a model from the old API */ + def fromOld( + oldModel: OldNaiveBayesModel, + parent: NaiveBayes): NaiveBayesModel = { + val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb") + new NaiveBayesModel(uid, oldModel.labels, oldModel.pi, oldModel.theta, oldModel.modelType) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index f51ee36d0dfcb..29a9cf6dab187 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext} * where D is number of features * @param modelType The type of NB model to fit can be "multinomial" or "bernoulli" */ -class NaiveBayesModel private[mllib] ( +class NaiveBayesModel private[spark] ( val labels: Array[Double], val pi: Array[Double], val theta: Array[Array[Double]], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 75e7004464af9..219a0b28139c1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -291,7 +291,7 @@ class DenseMatrix( override def copy: DenseMatrix = new DenseMatrix(numRows, numCols, values.clone()) - private[mllib] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f), + private[spark] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f), isTransposed) private[mllib] def update(f: Double => Double): DenseMatrix = { diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java new file mode 100644 index 0000000000000..09a9fba0c19cf --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class JavaNaiveBayesSuite implements Serializable { + + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + public void validatePrediction(DataFrame predictionAndLabels) { + for (Row r : predictionAndLabels.collect()) { + double prediction = r.getAs(0); + double label = r.getAs(1); + assert(prediction == label); + } + } + + @Test + public void naiveBayesDefaultParams() { + NaiveBayes nb = new NaiveBayes(); + assert(nb.getLabelCol() == "label"); + assert(nb.getFeaturesCol() == "features"); + assert(nb.getPredictionCol() == "prediction"); + assert(nb.getLambda() == 1.0); + assert(nb.getModelType() == "multinomial"); + } + + @Test + public void testNaiveBayes() { + JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(0.0, Vectors.dense(1.0, 0.0, 0.0)), + RowFactory.create(0.0, Vectors.dense(2.0, 0.0, 0.0)), + RowFactory.create(1.0, Vectors.dense(0.0, 1.0, 0.0)), + RowFactory.create(1.0, Vectors.dense(0.0, 2.0, 0.0)), + RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 1.0)), + RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0)) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("features", new VectorUDT(), false, Metadata.empty()) + }); + + DataFrame dataset = jsql.createDataFrame(jrdd, schema); + NaiveBayes nb = new NaiveBayes().setLambda(0.5).setModelType("multinomial"); + NaiveBayesModel model = nb.fit(dataset); + + DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label"); + validatePrediction(predictionAndLabels); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala new file mode 100644 index 0000000000000..cff8f4f254c79 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.classification.NaiveBayesSuite._ +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row + +class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { + + def validatePrediction(predictionAndLabels: DataFrame): Unit = { + val numOfErrorPredictions = predictionAndLabels.collect().count { + case Row(prediction: Double, label: Double) => + prediction != label + } + // At least 80% of the predictions should be on. + assert(numOfErrorPredictions < predictionAndLabels.count() / 5) + } + + def validateModelFit( + piData: Array[Double], + thetaData: Array[Array[Double]], + model: NaiveBayesModel): Unit = { + def closeFit(d1: Double, d2: Double, precision: Double): Boolean = { + (d1 - d2).abs <= precision + } + val modelIndex = (0 until piData.length).zip(model.labels.map(_.toInt)) + for (i <- modelIndex) { + assert(closeFit(math.exp(piData(i._2)), math.exp(model.pi(i._1)), 0.05)) + } + for (i <- modelIndex) { + for (j <- 0 until thetaData(i._2).length) { + assert(closeFit(math.exp(thetaData(i._2)(j)), math.exp(model.theta(i._1)(j)), 0.05)) + } + } + } + + test("params") { + ParamsSuite.checkParams(new NaiveBayes) + val model = new NaiveBayesModel("nb", labels = Array(0.0, 1.0), + pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), + "multinomial") + ParamsSuite.checkParams(model) + } + + test("naive bayes: default params") { + val nb = new NaiveBayes + assert(nb.getLabelCol === "label") + assert(nb.getFeaturesCol === "features") + assert(nb.getPredictionCol === "prediction") + assert(nb.getLambda === 1.0) + assert(nb.getModelType === "multinomial") + } + + test("Naive Bayes Multinomial") { + val nPoints = 1000 + val pi = Array(0.5, 0.1, 0.4).map(math.log) + val theta = Array( + Array(0.70, 0.10, 0.10, 0.10), // label 0 + Array(0.10, 0.70, 0.10, 0.10), // label 1 + Array(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) + + val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput( + pi, theta, nPoints, 42, "multinomial")) + val nb = new NaiveBayes().setLambda(1.0).setModelType("multinomial") + val model = nb.fit(testDataset) + + validateModelFit(pi, theta, model) + assert(model.hasParent) + + val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( + pi, theta, nPoints, 17, "multinomial")) + + val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") + + validatePrediction(predictionAndLabels) + } + + test("Naive Bayes Bernoulli") { + val nPoints = 10000 + val pi = Array(0.5, 0.3, 0.2).map(math.log) + val theta = Array( + Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0 + Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1 + Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2 + ).map(_.map(math.log)) + + val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput( + pi, theta, nPoints, 45, "bernoulli")) + val nb = new NaiveBayes().setLambda(1.0).setModelType("bernoulli") + val model = nb.fit(testDataset) + + validateModelFit(pi, theta, model) + assert(model.hasParent) + + val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( + pi, theta, nPoints, 20, "bernoulli")) + + val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") + + validatePrediction(predictionAndLabels) + } +} From 3018a41c90fae16856cfaaf5bf464e63ef065934 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sun, 12 Jul 2015 15:13:41 +0800 Subject: [PATCH 2/6] address comments --- .../spark/ml/classification/NaiveBayes.scala | 80 ++++++++++--------- .../apache/spark/mllib/linalg/Matrices.scala | 4 +- .../ml/classification/NaiveBayesSuite.scala | 47 ++++++----- 3 files changed, 69 insertions(+), 62 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 1b5e18b3e1065..dee0cda6b4c13 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -23,7 +23,7 @@ import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param, DoubleParam} import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes} import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel} -import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, Vector} +import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame @@ -35,24 +35,24 @@ private[ml] trait NaiveBayesParams extends PredictorParams { /** * The smoothing parameter. + * (default = 1.0). * @group param */ final val lambda: DoubleParam = new DoubleParam(this, "lambda", "The smoothing parameter.", ParamValidators.gtEq(0)) - setDefault(lambda -> 1.0) /** @group getParam */ final def getLambda: Double = $(lambda) /** * The model type which is a string (case-sensitive). - * Supported options: "multinomial" (default) and "bernoulli". + * Supported options: "multinomial" and "bernoulli". + * (default = multinomial) * @group param */ final val modelType: Param[String] = new Param[String](this, "modelType", "The model type which is a string (case-sensitive). Supported options: " + "\"multinomial\" (default) and \"bernoulli\".") - setDefault(modelType -> "multinomial") /** @group getParam */ final def getModelType: String = $(modelType) @@ -60,9 +60,13 @@ private[ml] trait NaiveBayesParams extends PredictorParams { /** * Naive Bayes Classifiers. - * It supports both Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle - * all kinds of discrete data and Bernoulli NB ([[http://tinyurl.com/p7c96j6]]) - * which can only handle 0-1 vector. + * It supports both Multinomial NB + * ([[http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html]]) + * which can handle finitely supported discrete data. For example, by converting documents into + * TF-IDF vectors, it can be used for document classification. By making every vector a + * binary (0/1) data, it can also be used as Bernoulli NB + * ([[http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html]]). + * The input feature values must be nonnegative. */ class NaiveBayes(override val uid: String) extends Predictor[Vector, NaiveBayes, NaiveBayesModel] @@ -76,14 +80,15 @@ class NaiveBayes(override val uid: String) * @group setParam */ def setLambda(value: Double): this.type = set(lambda, value) + setDefault(lambda -> 1.0) /** * Set the model type using a string (case-sensitive). * Supported options: "multinomial" and "bernoulli". * Default is "multinomial" - * @return */ def setModelType(value: String): this.type = set(modelType, value) + setDefault(modelType -> "multinomial") override protected def train(dataset: DataFrame): NaiveBayesModel = { val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) @@ -99,26 +104,16 @@ class NaiveBayes(override val uid: String) */ class NaiveBayesModel private[ml] ( override val uid: String, - val labels: Array[Double], - val pi: Array[Double], - val theta: Array[Array[Double]], + val labels: Vector, + val pi: Vector, + val theta: Matrix, val modelType: String) extends PredictionModel[Vector, NaiveBayesModel] { - /** String name for multinomial model type. */ - private[classification] val Multinomial: String = "multinomial" - - /** String name for Bernoulli model type. */ - private[classification] val Bernoulli: String = "bernoulli" - - /** Set of modelTypes that NaiveBayes supports */ - private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli) - - private val piVector = new DenseVector(pi) - private val thetaMatrix = new DenseMatrix(labels.length, theta(0).length, theta.flatten, true) + import NaiveBayesModel.{Bernoulli, Multinomial, supportedModelTypes} require(supportedModelTypes.contains(modelType), - s"NaiveBayes was created with an unknown modelType: ${modelType}.") + s"NaiveBayes was created with an unknown modelType: $modelType.") /** * Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0. @@ -128,37 +123,37 @@ class NaiveBayesModel private[ml] ( private val (thetaMinusNegTheta, negThetaSum) = modelType match { case Multinomial => (None, None) case Bernoulli => - val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value))) - val ones = new DenseVector(Array.fill(thetaMatrix.numCols){1.0}) - val thetaMinusNegTheta = thetaMatrix.map { value => + val negTheta = theta.map(value => math.log(1.0 - math.exp(value))) + val ones = new DenseVector(Array.fill(theta.numCols){1.0}) + val thetaMinusNegTheta = theta.map { value => value - math.log(1.0 - math.exp(value)) } (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones))) case _ => // This should never happen. - throw new UnknownError(s"Invalid modelType: ${modelType}.") + throw new UnknownError(s"Invalid modelType: $modelType.") } override protected def predict(features: Vector): Double = { modelType match { case Multinomial => - val prob = thetaMatrix.multiply(features) - BLAS.axpy(1.0, piVector, prob) + val prob = theta.multiply(features) + BLAS.axpy(1.0, pi, prob) labels(prob.argmax) case Bernoulli => features.foreachActive{ (index, value) => if (value != 0.0 && value != 1.0) { throw new SparkException( - s"Bernoulli naive Bayes requires 0 or 1 feature values but found ${features}") + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features") } } val prob = thetaMinusNegTheta.get.multiply(features) - BLAS.axpy(1.0, piVector, prob) + BLAS.axpy(1.0, pi, prob) BLAS.axpy(1.0, negThetaSum.get, prob) labels(prob.argmax) case _ => // This should never happen. - throw new UnknownError(s"Invalid modelType: ${modelType}.") + throw new UnknownError(s"Invalid modelType: $modelType.") } } @@ -170,19 +165,28 @@ class NaiveBayesModel private[ml] ( s"NaiveBayesModel with ${labels.size} classes" } - private[ml] def toOld: OldNaiveBayesModel = { - new OldNaiveBayesModel(labels, pi, theta, modelType) - } - } private[ml] object NaiveBayesModel { - /** (private[ml]) Convert a model from the old API */ + /** String name for multinomial model type. */ + private[classification] val Multinomial: String = "multinomial" + + /** String name for Bernoulli model type. */ + private[classification] val Bernoulli: String = "bernoulli" + + /** Set of modelTypes that NaiveBayes supports */ + private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli) + + /** Convert a model from the old API */ def fromOld( oldModel: OldNaiveBayesModel, parent: NaiveBayes): NaiveBayesModel = { val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb") - new NaiveBayesModel(uid, oldModel.labels, oldModel.pi, oldModel.theta, oldModel.modelType) + val labels = Vectors.dense(oldModel.labels) + val pi = Vectors.dense(oldModel.pi) + val theta = new DenseMatrix(oldModel.labels.length, oldModel.theta(0).length, + oldModel.theta.flatten, true) + new NaiveBayesModel(uid, labels, pi, theta, oldModel.modelType) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 219a0b28139c1..042347af9ea0b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -98,7 +98,7 @@ sealed trait Matrix extends Serializable { /** Map the values of this matrix using a function. Generates a new matrix. Performs the * function on only the backing array. For example, an operation such as addition or * subtraction will only be performed on the non-zero values in a `SparseMatrix`. */ - private[mllib] def map(f: Double => Double): Matrix + private[spark] def map(f: Double => Double): Matrix /** Update all the values of this matrix using the function f. Performed in-place on the * backing array. For example, an operation such as addition or subtraction will only be @@ -557,7 +557,7 @@ class SparseMatrix( new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone()) } - private[mllib] def map(f: Double => Double) = + private[spark] def map(f: Double => Double) = new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f), isTransposed) private[mllib] def update(f: Double => Double): SparseMatrix = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index cff8f4f254c79..9935aedd2d413 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -19,7 +19,9 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.classification.NaiveBayesSuite._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Row @@ -36,28 +38,25 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { } def validateModelFit( - piData: Array[Double], - thetaData: Array[Array[Double]], + piData: Vector, + thetaData: Matrix, model: NaiveBayesModel): Unit = { - def closeFit(d1: Double, d2: Double, precision: Double): Boolean = { - (d1 - d2).abs <= precision - } - val modelIndex = (0 until piData.length).zip(model.labels.map(_.toInt)) + val modelIndex = (0 until piData.toArray.length).zip(model.labels.toArray.map(_.toInt)) for (i <- modelIndex) { - assert(closeFit(math.exp(piData(i._2)), math.exp(model.pi(i._1)), 0.05)) + assert(math.exp(model.pi(i._1)) ~== math.exp(piData(i._2)) absTol 0.05, "pi mismatch") } - for (i <- modelIndex) { - for (j <- 0 until thetaData(i._2).length) { - assert(closeFit(math.exp(thetaData(i._2)(j)), math.exp(model.theta(i._1)(j)), 0.05)) - } + for (i <- modelIndex; + j <- 0 until thetaData.numCols) { + assert(math.exp(model.theta(i._1, j)) ~== math.exp(thetaData(i._2, j)) absTol 0.05, + "theta mismatch") } } test("params") { ParamsSuite.checkParams(new NaiveBayes) - val model = new NaiveBayesModel("nb", labels = Array(0.0, 1.0), - pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), - "multinomial") + val model = new NaiveBayesModel("nb", labels = Vectors.dense(Array(0.0, 1.0)), + pi = Vectors.dense(Array(0.2, 0.8)), theta = new DenseMatrix(2, 3, + Array(0.1, 0.2, 0.3, 0.4, 0.6, 0.4)), "multinomial") ParamsSuite.checkParams(model) } @@ -72,15 +71,17 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { test("Naive Bayes Multinomial") { val nPoints = 1000 - val pi = Array(0.5, 0.1, 0.4).map(math.log) - val theta = Array( + val piArray = Array(0.5, 0.1, 0.4).map(math.log) + val thetaArray = Array( Array(0.70, 0.10, 0.10, 0.10), // label 0 Array(0.10, 0.70, 0.10, 0.10), // label 1 Array(0.10, 0.10, 0.70, 0.10) // label 2 ).map(_.map(math.log)) + val pi = Vectors.dense(piArray) + val theta = new DenseMatrix(3, 4, thetaArray.flatten, true) val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput( - pi, theta, nPoints, 42, "multinomial")) + piArray, thetaArray, nPoints, 42, "multinomial")) val nb = new NaiveBayes().setLambda(1.0).setModelType("multinomial") val model = nb.fit(testDataset) @@ -88,7 +89,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.hasParent) val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( - pi, theta, nPoints, 17, "multinomial")) + piArray, thetaArray, nPoints, 17, "multinomial")) val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") @@ -97,15 +98,17 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { test("Naive Bayes Bernoulli") { val nPoints = 10000 - val pi = Array(0.5, 0.3, 0.2).map(math.log) - val theta = Array( + val piArray = Array(0.5, 0.3, 0.2).map(math.log) + val thetaArray = Array( Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0 Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1 Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2 ).map(_.map(math.log)) + val pi = Vectors.dense(piArray) + val theta = new DenseMatrix(3, 12, thetaArray.flatten, true) val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput( - pi, theta, nPoints, 45, "bernoulli")) + piArray, thetaArray, nPoints, 45, "bernoulli")) val nb = new NaiveBayes().setLambda(1.0).setModelType("bernoulli") val model = nb.fit(testDataset) @@ -113,7 +116,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.hasParent) val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( - pi, theta, nPoints, 20, "bernoulli")) + piArray, thetaArray, nPoints, 20, "bernoulli")) val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") From 3220b829f41c78ea5a61189e38e9d13ec1c17830 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sun, 12 Jul 2015 15:58:34 +0800 Subject: [PATCH 3/6] trigger jenkins --- .../org/apache/spark/ml/classification/NaiveBayesSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 9935aedd2d413..a7f8a1424009a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -90,7 +90,6 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 17, "multinomial")) - val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") validatePrediction(predictionAndLabels) @@ -117,7 +116,6 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 20, "bernoulli")) - val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") validatePrediction(predictionAndLabels) From a2b30880ec479ac644098c9c58fd3db4a12d4c59 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 15 Jul 2015 18:33:19 +0800 Subject: [PATCH 4/6] address comments --- .../spark/ml/classification/NaiveBayes.scala | 39 +++++++------------ .../mllib/classification/NaiveBayes.scala | 6 +-- .../ml/classification/NaiveBayesSuite.scala | 2 +- 3 files changed, 17 insertions(+), 30 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index dee0cda6b4c13..ea1f4cb288e79 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -50,9 +50,9 @@ private[ml] trait NaiveBayesParams extends PredictorParams { * (default = multinomial) * @group param */ - final val modelType: Param[String] = new Param[String](this, "modelType", - "The model type which is a string (case-sensitive). Supported options: " + - "\"multinomial\" (default) and \"bernoulli\".") + final val modelType: Param[String] = new Param[String](this, "modelType", "The model type " + + "which is a string (case-sensitive). Supported options: multinomial (default) and bernoulli.", + ParamValidators.inArray[String](OldNaiveBayes.supportedModelTypes.toArray)) /** @group getParam */ final def getModelType: String = $(modelType) @@ -88,7 +88,7 @@ class NaiveBayes(override val uid: String) * Default is "multinomial" */ def setModelType(value: String): this.type = set(modelType, value) - setDefault(modelType -> "multinomial") + setDefault(modelType -> OldNaiveBayes.Multinomial) override protected def train(dataset: DataFrame): NaiveBayesModel = { val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) @@ -106,21 +106,17 @@ class NaiveBayesModel private[ml] ( override val uid: String, val labels: Vector, val pi: Vector, - val theta: Matrix, - val modelType: String) - extends PredictionModel[Vector, NaiveBayesModel] { + val theta: Matrix) + extends PredictionModel[Vector, NaiveBayesModel] with NaiveBayesParams { - import NaiveBayesModel.{Bernoulli, Multinomial, supportedModelTypes} - - require(supportedModelTypes.contains(modelType), - s"NaiveBayes was created with an unknown modelType: $modelType.") + import OldNaiveBayes.{Bernoulli, Multinomial} /** * Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0. * This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra * application of this condition (in predict function). */ - private val (thetaMinusNegTheta, negThetaSum) = modelType match { + private lazy val (thetaMinusNegTheta, negThetaSum) = $(modelType) match { case Multinomial => (None, None) case Bernoulli => val negTheta = theta.map(value => math.log(1.0 - math.exp(value))) @@ -131,11 +127,11 @@ class NaiveBayesModel private[ml] ( (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones))) case _ => // This should never happen. - throw new UnknownError(s"Invalid modelType: $modelType.") + throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") } override protected def predict(features: Vector): Double = { - modelType match { + $(modelType) match { case Multinomial => val prob = theta.multiply(features) BLAS.axpy(1.0, pi, prob) @@ -153,12 +149,12 @@ class NaiveBayesModel private[ml] ( labels(prob.argmax) case _ => // This should never happen. - throw new UnknownError(s"Invalid modelType: $modelType.") + throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") } } override def copy(extra: ParamMap): NaiveBayesModel = { - copyValues(new NaiveBayesModel(uid, labels, pi, theta, modelType), extra) + copyValues(new NaiveBayesModel(uid, labels, pi, theta), extra) } override def toString: String = { @@ -169,15 +165,6 @@ class NaiveBayesModel private[ml] ( private[ml] object NaiveBayesModel { - /** String name for multinomial model type. */ - private[classification] val Multinomial: String = "multinomial" - - /** String name for Bernoulli model type. */ - private[classification] val Bernoulli: String = "bernoulli" - - /** Set of modelTypes that NaiveBayes supports */ - private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli) - /** Convert a model from the old API */ def fromOld( oldModel: OldNaiveBayesModel, @@ -187,6 +174,6 @@ private[ml] object NaiveBayesModel { val pi = Vectors.dense(oldModel.pi) val theta = new DenseMatrix(oldModel.labels.length, oldModel.theta(0).length, oldModel.theta.flatten, true) - new NaiveBayesModel(uid, labels, pi, theta, oldModel.modelType) + new NaiveBayesModel(uid, labels, pi, theta) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 29a9cf6dab187..f0490115a95cf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -381,13 +381,13 @@ class NaiveBayes private ( object NaiveBayes { /** String name for multinomial model type. */ - private[classification] val Multinomial: String = "multinomial" + private[spark] val Multinomial: String = "multinomial" /** String name for Bernoulli model type. */ - private[classification] val Bernoulli: String = "bernoulli" + private[spark] val Bernoulli: String = "bernoulli" /* Set of modelTypes that NaiveBayes supports */ - private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli) + private[spark] val supportedModelTypes = Set(Multinomial, Bernoulli) /** * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index a7f8a1424009a..17ad062e921e4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -56,7 +56,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { ParamsSuite.checkParams(new NaiveBayes) val model = new NaiveBayesModel("nb", labels = Vectors.dense(Array(0.0, 1.0)), pi = Vectors.dense(Array(0.2, 0.8)), theta = new DenseMatrix(2, 3, - Array(0.1, 0.2, 0.3, 0.4, 0.6, 0.4)), "multinomial") + Array(0.1, 0.2, 0.3, 0.4, 0.6, 0.4))) ParamsSuite.checkParams(model) } From c3de6874b6b7a73e652cb129d0bb18327594f32f Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 16 Jul 2015 11:20:03 +0800 Subject: [PATCH 5/6] remove labels from ml.NaiveBayesModel --- .../spark/ml/classification/NaiveBayes.scala | 29 +++++++++++++++---- .../mllib/classification/NaiveBayes.scala | 2 +- .../ml/classification/NaiveBayesSuite.scala | 17 ++++------- 3 files changed, 29 insertions(+), 19 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index ea1f4cb288e79..4a3ae126c810b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -92,6 +92,24 @@ class NaiveBayes(override val uid: String) override protected def train(dataset: DataFrame): NaiveBayesModel = { val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) + val instance = oldDataset.map{ + case LabeledPoint(label: Double, features: Vector) => label } + .treeAggregate(new MultiClassSummarizer)( + seqOp = (c, v) => (c, v) match { + case (classSummarizer: MultiClassSummarizer, label: Double) => classSummarizer.add(label) + }, + combOp = (c1, c2) => (c1, c2) match { + case (classSummarizer1: MultiClassSummarizer, classSummarizer2: MultiClassSummarizer) => + classSummarizer1.merge(classSummarizer2) + }) + val numInvalid = instance.countInvalid + val numClasses = instance.numClasses + if (numInvalid != 0) { + val msg = s"Classification labels should be in {0 to ${numClasses - 1} " + + s"Found $numInvalid invalid labels." + logError(msg) + throw new SparkException(msg) + } val oldModel = OldNaiveBayes.train(oldDataset, $(lambda), $(modelType)) NaiveBayesModel.fromOld(oldModel, this) } @@ -104,7 +122,6 @@ class NaiveBayes(override val uid: String) */ class NaiveBayesModel private[ml] ( override val uid: String, - val labels: Vector, val pi: Vector, val theta: Matrix) extends PredictionModel[Vector, NaiveBayesModel] with NaiveBayesParams { @@ -135,7 +152,7 @@ class NaiveBayesModel private[ml] ( case Multinomial => val prob = theta.multiply(features) BLAS.axpy(1.0, pi, prob) - labels(prob.argmax) + prob.argmax case Bernoulli => features.foreachActive{ (index, value) => if (value != 0.0 && value != 1.0) { @@ -146,7 +163,7 @@ class NaiveBayesModel private[ml] ( val prob = thetaMinusNegTheta.get.multiply(features) BLAS.axpy(1.0, pi, prob) BLAS.axpy(1.0, negThetaSum.get, prob) - labels(prob.argmax) + prob.argmax case _ => // This should never happen. throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") @@ -154,11 +171,11 @@ class NaiveBayesModel private[ml] ( } override def copy(extra: ParamMap): NaiveBayesModel = { - copyValues(new NaiveBayesModel(uid, labels, pi, theta), extra) + copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra) } override def toString: String = { - s"NaiveBayesModel with ${labels.size} classes" + s"NaiveBayesModel with ${pi.size} classes" } } @@ -174,6 +191,6 @@ private[ml] object NaiveBayesModel { val pi = Vectors.dense(oldModel.pi) val theta = new DenseMatrix(oldModel.labels.length, oldModel.theta(0).length, oldModel.theta.flatten, true) - new NaiveBayesModel(uid, labels, pi, theta) + new NaiveBayesModel(uid, pi, theta) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index f0490115a95cf..0eb453ec3af46 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -338,7 +338,7 @@ class NaiveBayes private ( BLAS.axpy(1.0, c2._2, c1._2) (c1._1 + c2._1, c1._2) } - ).collect() + ).collect().sortBy(_._1) val numLabels = aggregated.length var numDocuments = 0L diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 17ad062e921e4..76381a2741296 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -41,22 +41,15 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { piData: Vector, thetaData: Matrix, model: NaiveBayesModel): Unit = { - val modelIndex = (0 until piData.toArray.length).zip(model.labels.toArray.map(_.toInt)) - for (i <- modelIndex) { - assert(math.exp(model.pi(i._1)) ~== math.exp(piData(i._2)) absTol 0.05, "pi mismatch") - } - for (i <- modelIndex; - j <- 0 until thetaData.numCols) { - assert(math.exp(model.theta(i._1, j)) ~== math.exp(thetaData(i._2, j)) absTol 0.05, - "theta mismatch") - } + assert(Vectors.dense(model.pi.toArray.map(math.exp)) ~== + Vectors.dense(piData.toArray.map(math.exp)) absTol 0.05, "pi mismatch") + assert(model.theta.map(math.exp) ~== thetaData.map(math.exp) absTol 0.05, "theta mismatch") } test("params") { ParamsSuite.checkParams(new NaiveBayes) - val model = new NaiveBayesModel("nb", labels = Vectors.dense(Array(0.0, 1.0)), - pi = Vectors.dense(Array(0.2, 0.8)), theta = new DenseMatrix(2, 3, - Array(0.1, 0.2, 0.3, 0.4, 0.6, 0.4))) + val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)), + theta = new DenseMatrix(2, 3, Array(0.1, 0.2, 0.3, 0.4, 0.6, 0.4))) ParamsSuite.checkParams(model) } From bc890f7ca99009a7f93e19710ada91a6477ee996 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 17 Jul 2015 22:17:59 +0800 Subject: [PATCH 6/6] remove labels valid check --- .../spark/ml/classification/NaiveBayes.scala | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 4a3ae126c810b..1f547e4a98af7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -92,24 +92,6 @@ class NaiveBayes(override val uid: String) override protected def train(dataset: DataFrame): NaiveBayesModel = { val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) - val instance = oldDataset.map{ - case LabeledPoint(label: Double, features: Vector) => label } - .treeAggregate(new MultiClassSummarizer)( - seqOp = (c, v) => (c, v) match { - case (classSummarizer: MultiClassSummarizer, label: Double) => classSummarizer.add(label) - }, - combOp = (c1, c2) => (c1, c2) match { - case (classSummarizer1: MultiClassSummarizer, classSummarizer2: MultiClassSummarizer) => - classSummarizer1.merge(classSummarizer2) - }) - val numInvalid = instance.countInvalid - val numClasses = instance.numClasses - if (numInvalid != 0) { - val msg = s"Classification labels should be in {0 to ${numClasses - 1} " + - s"Found $numInvalid invalid labels." - logError(msg) - throw new SparkException(msg) - } val oldModel = OldNaiveBayes.train(oldDataset, $(lambda), $(modelType)) NaiveBayesModel.fromOld(oldModel, this) }