Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.classification

import org.apache.spark.SparkException
import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor}
import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param, DoubleParam}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes}
import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel}
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame

/**
* Params for Naive Bayes Classifiers.
*/
private[ml] trait NaiveBayesParams extends PredictorParams {

/**
* The smoothing parameter.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

State default

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should defaults be stated here or in NaiveBayes (where I'm suggesting they be set)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the bulk of the description of a parameter is in the Param, rather than the getter/setter, I'd prefer the default be stated in the Param as well. But it's fine if it's stated in the setter method as well.

* (default = 1.0).
* @group param
*/
final val lambda: DoubleParam = new DoubleParam(this, "lambda", "The smoothing parameter.",
ParamValidators.gtEq(0))

/** @group getParam */
final def getLambda: Double = $(lambda)

/**
* The model type which is a string (case-sensitive).
* Supported options: "multinomial" and "bernoulli".
* (default = multinomial)
* @group param
*/
final val modelType: Param[String] = new Param[String](this, "modelType", "The model type " +
"which is a string (case-sensitive). Supported options: multinomial (default) and bernoulli.",
ParamValidators.inArray[String](OldNaiveBayes.supportedModelTypes.toArray))

/** @group getParam */
final def getModelType: String = $(modelType)
}

/**
* Naive Bayes Classifiers.
* It supports both Multinomial NB
* ([[http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html]])
* which can handle finitely supported discrete data. For example, by converting documents into
* TF-IDF vectors, it can be used for document classification. By making every vector a
* binary (0/1) data, it can also be used as Bernoulli NB
* ([[http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html]]).
* The input feature values must be nonnegative.
*/
class NaiveBayes(override val uid: String)
extends Predictor[Vector, NaiveBayes, NaiveBayesModel]
with NaiveBayesParams {

def this() = this(Identifiable.randomUID("nb"))

/**
* Set the smoothing parameter.
* Default is 1.0.
* @group setParam
*/
def setLambda(value: Double): this.type = set(lambda, value)
setDefault(lambda -> 1.0)

/**
* Set the model type using a string (case-sensitive).
* Supported options: "multinomial" and "bernoulli".
* Default is "multinomial"
*/
def setModelType(value: String): this.type = set(modelType, value)
setDefault(modelType -> OldNaiveBayes.Multinomial)

override protected def train(dataset: DataFrame): NaiveBayesModel = {
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val oldModel = OldNaiveBayes.train(oldDataset, $(lambda), $(modelType))
NaiveBayesModel.fromOld(oldModel, this)
}

override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra)
}

/**
* Model produced by [[NaiveBayes]]
*/
class NaiveBayesModel private[ml] (
override val uid: String,
val pi: Vector,
val theta: Matrix)
extends PredictionModel[Vector, NaiveBayesModel] with NaiveBayesParams {

import OldNaiveBayes.{Bernoulli, Multinomial}

/**
* Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
* This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
* application of this condition (in predict function).
*/
private lazy val (thetaMinusNegTheta, negThetaSum) = $(modelType) match {
case Multinomial => (None, None)
case Bernoulli =>
val negTheta = theta.map(value => math.log(1.0 - math.exp(value)))
val ones = new DenseVector(Array.fill(theta.numCols){1.0})
val thetaMinusNegTheta = theta.map { value =>
value - math.log(1.0 - math.exp(value))
}
(Option(thetaMinusNegTheta), Option(negTheta.multiply(ones)))
case _ =>
// This should never happen.
throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
}

override protected def predict(features: Vector): Double = {
$(modelType) match {
case Multinomial =>
val prob = theta.multiply(features)
BLAS.axpy(1.0, pi, prob)
prob.argmax
case Bernoulli =>
features.foreachActive{ (index, value) =>
if (value != 0.0 && value != 1.0) {
throw new SparkException(
s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features")
}
}
val prob = thetaMinusNegTheta.get.multiply(features)
BLAS.axpy(1.0, pi, prob)
BLAS.axpy(1.0, negThetaSum.get, prob)
prob.argmax
case _ =>
// This should never happen.
throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
}
}

override def copy(extra: ParamMap): NaiveBayesModel = {
copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra)
}

override def toString: String = {
s"NaiveBayesModel with ${pi.size} classes"
}

}

private[ml] object NaiveBayesModel {

/** Convert a model from the old API */
def fromOld(
oldModel: OldNaiveBayesModel,
parent: NaiveBayes): NaiveBayesModel = {
val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb")
val labels = Vectors.dense(oldModel.labels)
val pi = Vectors.dense(oldModel.pi)
val theta = new DenseMatrix(oldModel.labels.length, oldModel.theta(0).length,
oldModel.theta.flatten, true)
new NaiveBayesModel(uid, pi, theta)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
* where D is number of features
* @param modelType The type of NB model to fit can be "multinomial" or "bernoulli"
*/
class NaiveBayesModel private[mllib] (
class NaiveBayesModel private[spark] (
val labels: Array[Double],
val pi: Array[Double],
val theta: Array[Array[Double]],
Expand Down Expand Up @@ -338,7 +338,7 @@ class NaiveBayes private (
BLAS.axpy(1.0, c2._2, c1._2)
(c1._1 + c2._1, c1._2)
}
).collect()
).collect().sortBy(_._1)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sort by "labels" ascending so the pi and theta use the 0-based indices for labels.

val numLabels = aggregated.length
var numDocuments = 0L
Expand Down Expand Up @@ -381,13 +381,13 @@ class NaiveBayes private (
object NaiveBayes {

/** String name for multinomial model type. */
private[classification] val Multinomial: String = "multinomial"
private[spark] val Multinomial: String = "multinomial"

/** String name for Bernoulli model type. */
private[classification] val Bernoulli: String = "bernoulli"
private[spark] val Bernoulli: String = "bernoulli"

/* Set of modelTypes that NaiveBayes supports */
private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli)
private[spark] val supportedModelTypes = Set(Multinomial, Bernoulli)

/**
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ sealed trait Matrix extends Serializable {
/** Map the values of this matrix using a function. Generates a new matrix. Performs the
* function on only the backing array. For example, an operation such as addition or
* subtraction will only be performed on the non-zero values in a `SparseMatrix`. */
private[mllib] def map(f: Double => Double): Matrix
private[spark] def map(f: Double => Double): Matrix

/** Update all the values of this matrix using the function f. Performed in-place on the
* backing array. For example, an operation such as addition or subtraction will only be
Expand Down Expand Up @@ -291,7 +291,7 @@ class DenseMatrix(

override def copy: DenseMatrix = new DenseMatrix(numRows, numCols, values.clone())

private[mllib] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f),
private[spark] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f),
isTransposed)

private[mllib] def update(f: Double => Double): DenseMatrix = {
Expand Down Expand Up @@ -557,7 +557,7 @@ class SparseMatrix(
new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone())
}

private[mllib] def map(f: Double => Double) =
private[spark] def map(f: Double => Double) =
new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f), isTransposed)

private[mllib] def update(f: Double => Double): SparseMatrix = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.classification;

import java.io.Serializable;

import com.google.common.collect.Lists;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class JavaNaiveBayesSuite implements Serializable {

private transient JavaSparkContext jsc;
private transient SQLContext jsql;

@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
jsql = new SQLContext(jsc);
}

@After
public void tearDown() {
jsc.stop();
jsc = null;
}

public void validatePrediction(DataFrame predictionAndLabels) {
for (Row r : predictionAndLabels.collect()) {
double prediction = r.getAs(0);
double label = r.getAs(1);
assert(prediction == label);
}
}

@Test
public void naiveBayesDefaultParams() {
NaiveBayes nb = new NaiveBayes();
assert(nb.getLabelCol() == "label");
assert(nb.getFeaturesCol() == "features");
assert(nb.getPredictionCol() == "prediction");
assert(nb.getLambda() == 1.0);
assert(nb.getModelType() == "multinomial");
}

@Test
public void testNaiveBayes() {
JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList(
RowFactory.create(0.0, Vectors.dense(1.0, 0.0, 0.0)),
RowFactory.create(0.0, Vectors.dense(2.0, 0.0, 0.0)),
RowFactory.create(1.0, Vectors.dense(0.0, 1.0, 0.0)),
RowFactory.create(1.0, Vectors.dense(0.0, 2.0, 0.0)),
RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 1.0)),
RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0))
));

StructType schema = new StructType(new StructField[]{
new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
new StructField("features", new VectorUDT(), false, Metadata.empty())
});

DataFrame dataset = jsql.createDataFrame(jrdd, schema);
NaiveBayes nb = new NaiveBayes().setLambda(0.5).setModelType("multinomial");
NaiveBayesModel model = nb.fit(dataset);

DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label");
validatePrediction(predictionAndLabels);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great that you included this 😃 In the future, it's okay to just check compatibility in Java tests and leave correctness tests to Scala tests.

}
}
Loading