-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-8600] [ML] Naive Bayes API for spark.ml Pipelines #7284
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
208e166
3018a41
3220b82
a2b3088
c3de687
bc890f7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,178 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.ml.classification | ||
|
|
||
| import org.apache.spark.SparkException | ||
| import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} | ||
| import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param, DoubleParam} | ||
| import org.apache.spark.ml.util.Identifiable | ||
| import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes} | ||
| import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel} | ||
| import org.apache.spark.mllib.linalg._ | ||
| import org.apache.spark.mllib.regression.LabeledPoint | ||
| import org.apache.spark.rdd.RDD | ||
| import org.apache.spark.sql.DataFrame | ||
|
|
||
| /** | ||
| * Params for Naive Bayes Classifiers. | ||
| */ | ||
| private[ml] trait NaiveBayesParams extends PredictorParams { | ||
|
|
||
| /** | ||
| * The smoothing parameter. | ||
| * (default = 1.0). | ||
| * @group param | ||
| */ | ||
| final val lambda: DoubleParam = new DoubleParam(this, "lambda", "The smoothing parameter.", | ||
| ParamValidators.gtEq(0)) | ||
|
|
||
| /** @group getParam */ | ||
| final def getLambda: Double = $(lambda) | ||
|
|
||
| /** | ||
| * The model type which is a string (case-sensitive). | ||
| * Supported options: "multinomial" and "bernoulli". | ||
| * (default = multinomial) | ||
| * @group param | ||
| */ | ||
| final val modelType: Param[String] = new Param[String](this, "modelType", "The model type " + | ||
| "which is a string (case-sensitive). Supported options: multinomial (default) and bernoulli.", | ||
| ParamValidators.inArray[String](OldNaiveBayes.supportedModelTypes.toArray)) | ||
|
|
||
| /** @group getParam */ | ||
| final def getModelType: String = $(modelType) | ||
| } | ||
|
|
||
| /** | ||
| * Naive Bayes Classifiers. | ||
| * It supports both Multinomial NB | ||
| * ([[http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html]]) | ||
| * which can handle finitely supported discrete data. For example, by converting documents into | ||
| * TF-IDF vectors, it can be used for document classification. By making every vector a | ||
| * binary (0/1) data, it can also be used as Bernoulli NB | ||
| * ([[http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html]]). | ||
| * The input feature values must be nonnegative. | ||
| */ | ||
| class NaiveBayes(override val uid: String) | ||
| extends Predictor[Vector, NaiveBayes, NaiveBayesModel] | ||
| with NaiveBayesParams { | ||
|
|
||
| def this() = this(Identifiable.randomUID("nb")) | ||
|
|
||
| /** | ||
| * Set the smoothing parameter. | ||
| * Default is 1.0. | ||
| * @group setParam | ||
| */ | ||
| def setLambda(value: Double): this.type = set(lambda, value) | ||
| setDefault(lambda -> 1.0) | ||
|
|
||
| /** | ||
| * Set the model type using a string (case-sensitive). | ||
| * Supported options: "multinomial" and "bernoulli". | ||
| * Default is "multinomial" | ||
| */ | ||
| def setModelType(value: String): this.type = set(modelType, value) | ||
| setDefault(modelType -> OldNaiveBayes.Multinomial) | ||
|
|
||
| override protected def train(dataset: DataFrame): NaiveBayesModel = { | ||
| val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) | ||
| val oldModel = OldNaiveBayes.train(oldDataset, $(lambda), $(modelType)) | ||
| NaiveBayesModel.fromOld(oldModel, this) | ||
| } | ||
|
|
||
| override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra) | ||
| } | ||
|
|
||
| /** | ||
| * Model produced by [[NaiveBayes]] | ||
| */ | ||
| class NaiveBayesModel private[ml] ( | ||
| override val uid: String, | ||
| val pi: Vector, | ||
| val theta: Matrix) | ||
| extends PredictionModel[Vector, NaiveBayesModel] with NaiveBayesParams { | ||
|
|
||
| import OldNaiveBayes.{Bernoulli, Multinomial} | ||
|
|
||
| /** | ||
| * Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0. | ||
| * This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra | ||
| * application of this condition (in predict function). | ||
| */ | ||
| private lazy val (thetaMinusNegTheta, negThetaSum) = $(modelType) match { | ||
| case Multinomial => (None, None) | ||
| case Bernoulli => | ||
| val negTheta = theta.map(value => math.log(1.0 - math.exp(value))) | ||
| val ones = new DenseVector(Array.fill(theta.numCols){1.0}) | ||
| val thetaMinusNegTheta = theta.map { value => | ||
| value - math.log(1.0 - math.exp(value)) | ||
| } | ||
| (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones))) | ||
| case _ => | ||
| // This should never happen. | ||
| throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") | ||
| } | ||
|
|
||
| override protected def predict(features: Vector): Double = { | ||
| $(modelType) match { | ||
| case Multinomial => | ||
| val prob = theta.multiply(features) | ||
| BLAS.axpy(1.0, pi, prob) | ||
| prob.argmax | ||
| case Bernoulli => | ||
| features.foreachActive{ (index, value) => | ||
| if (value != 0.0 && value != 1.0) { | ||
| throw new SparkException( | ||
| s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features") | ||
| } | ||
| } | ||
| val prob = thetaMinusNegTheta.get.multiply(features) | ||
| BLAS.axpy(1.0, pi, prob) | ||
| BLAS.axpy(1.0, negThetaSum.get, prob) | ||
| prob.argmax | ||
| case _ => | ||
| // This should never happen. | ||
| throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") | ||
| } | ||
| } | ||
|
|
||
| override def copy(extra: ParamMap): NaiveBayesModel = { | ||
| copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra) | ||
| } | ||
|
|
||
| override def toString: String = { | ||
| s"NaiveBayesModel with ${pi.size} classes" | ||
| } | ||
|
|
||
| } | ||
|
|
||
| private[ml] object NaiveBayesModel { | ||
|
|
||
| /** Convert a model from the old API */ | ||
| def fromOld( | ||
| oldModel: OldNaiveBayesModel, | ||
| parent: NaiveBayes): NaiveBayesModel = { | ||
| val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb") | ||
| val labels = Vectors.dense(oldModel.labels) | ||
| val pi = Vectors.dense(oldModel.pi) | ||
| val theta = new DenseMatrix(oldModel.labels.length, oldModel.theta(0).length, | ||
| oldModel.theta.flatten, true) | ||
| new NaiveBayesModel(uid, pi, theta) | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -40,7 +40,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext} | |
| * where D is number of features | ||
| * @param modelType The type of NB model to fit can be "multinomial" or "bernoulli" | ||
| */ | ||
| class NaiveBayesModel private[mllib] ( | ||
| class NaiveBayesModel private[spark] ( | ||
| val labels: Array[Double], | ||
| val pi: Array[Double], | ||
| val theta: Array[Array[Double]], | ||
|
|
@@ -338,7 +338,7 @@ class NaiveBayes private ( | |
| BLAS.axpy(1.0, c2._2, c1._2) | ||
| (c1._1 + c2._1, c1._2) | ||
| } | ||
| ).collect() | ||
| ).collect().sortBy(_._1) | ||
|
|
||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sort by "labels" ascending so the pi and theta use the 0-based indices for labels. |
||
| val numLabels = aggregated.length | ||
| var numDocuments = 0L | ||
|
|
@@ -381,13 +381,13 @@ class NaiveBayes private ( | |
| object NaiveBayes { | ||
|
|
||
| /** String name for multinomial model type. */ | ||
| private[classification] val Multinomial: String = "multinomial" | ||
| private[spark] val Multinomial: String = "multinomial" | ||
|
|
||
| /** String name for Bernoulli model type. */ | ||
| private[classification] val Bernoulli: String = "bernoulli" | ||
| private[spark] val Bernoulli: String = "bernoulli" | ||
|
|
||
| /* Set of modelTypes that NaiveBayes supports */ | ||
| private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli) | ||
| private[spark] val supportedModelTypes = Set(Multinomial, Bernoulli) | ||
|
|
||
| /** | ||
| * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,98 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.ml.classification; | ||
|
|
||
| import java.io.Serializable; | ||
|
|
||
| import com.google.common.collect.Lists; | ||
| import org.junit.After; | ||
| import org.junit.Before; | ||
| import org.junit.Test; | ||
|
|
||
| import org.apache.spark.api.java.JavaRDD; | ||
| import org.apache.spark.api.java.JavaSparkContext; | ||
| import org.apache.spark.mllib.linalg.VectorUDT; | ||
| import org.apache.spark.mllib.linalg.Vectors; | ||
| import org.apache.spark.sql.DataFrame; | ||
| import org.apache.spark.sql.Row; | ||
| import org.apache.spark.sql.RowFactory; | ||
| import org.apache.spark.sql.SQLContext; | ||
| import org.apache.spark.sql.types.DataTypes; | ||
| import org.apache.spark.sql.types.Metadata; | ||
| import org.apache.spark.sql.types.StructField; | ||
| import org.apache.spark.sql.types.StructType; | ||
|
|
||
| public class JavaNaiveBayesSuite implements Serializable { | ||
|
|
||
| private transient JavaSparkContext jsc; | ||
| private transient SQLContext jsql; | ||
|
|
||
| @Before | ||
| public void setUp() { | ||
| jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); | ||
| jsql = new SQLContext(jsc); | ||
| } | ||
|
|
||
| @After | ||
| public void tearDown() { | ||
| jsc.stop(); | ||
| jsc = null; | ||
| } | ||
|
|
||
| public void validatePrediction(DataFrame predictionAndLabels) { | ||
| for (Row r : predictionAndLabels.collect()) { | ||
| double prediction = r.getAs(0); | ||
| double label = r.getAs(1); | ||
| assert(prediction == label); | ||
| } | ||
| } | ||
|
|
||
| @Test | ||
| public void naiveBayesDefaultParams() { | ||
| NaiveBayes nb = new NaiveBayes(); | ||
| assert(nb.getLabelCol() == "label"); | ||
| assert(nb.getFeaturesCol() == "features"); | ||
| assert(nb.getPredictionCol() == "prediction"); | ||
| assert(nb.getLambda() == 1.0); | ||
| assert(nb.getModelType() == "multinomial"); | ||
| } | ||
|
|
||
| @Test | ||
| public void testNaiveBayes() { | ||
| JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList( | ||
| RowFactory.create(0.0, Vectors.dense(1.0, 0.0, 0.0)), | ||
| RowFactory.create(0.0, Vectors.dense(2.0, 0.0, 0.0)), | ||
| RowFactory.create(1.0, Vectors.dense(0.0, 1.0, 0.0)), | ||
| RowFactory.create(1.0, Vectors.dense(0.0, 2.0, 0.0)), | ||
| RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 1.0)), | ||
| RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0)) | ||
| )); | ||
|
|
||
| StructType schema = new StructType(new StructField[]{ | ||
| new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), | ||
| new StructField("features", new VectorUDT(), false, Metadata.empty()) | ||
| }); | ||
|
|
||
| DataFrame dataset = jsql.createDataFrame(jrdd, schema); | ||
| NaiveBayes nb = new NaiveBayes().setLambda(0.5).setModelType("multinomial"); | ||
| NaiveBayesModel model = nb.fit(dataset); | ||
|
|
||
| DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label"); | ||
| validatePrediction(predictionAndLabels); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is great that you included this 😃 In the future, it's okay to just check compatibility in Java tests and leave correctness tests to Scala tests. |
||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
State default
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should defaults be stated here or in
NaiveBayes(where I'm suggesting they be set)?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the bulk of the description of a parameter is in the Param, rather than the getter/setter, I'd prefer the default be stated in the Param as well. But it's fine if it's stated in the setter method as well.