[SPARK-10884][ML] Support prediction on single instance for regression and classification related models

## What changes were proposed in this pull request?

Support prediction on a single instance for regression and classification related models (i.e., PredictionModel, ClassificationModel, and their subclasses).
Add corresponding test cases.
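For illustration, this is the kind of call the change enables. A minimal sketch assuming Spark 2.4+, where `training` is a hypothetical DataFrame with "label" and "features" columns:

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.linalg.Vectors

val model = new LogisticRegression().fit(training)

// Before this change predict() was protected, so scoring one vector meant
// building a single-row DataFrame and calling transform(). Now it is direct:
val prediction: Double = model.predict(Vectors.dense(0.5, -1.2, 3.4))
```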

## How was this patch tested?

Test cases added.

Author: WeichenXu <weichen.xu@databricks.com>

Closes #19381 from WeichenXu123/single_prediction.
WeichenXu123 authored and jkbradley committed Mar 21, 2018
1 parent 500b21c commit bf09f2f
Showing 25 changed files with 176 additions and 23 deletions.
5 changes: 3 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -219,7 +219,8 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,

/**
* Predict label for the given features.
- * This internal method is used to implement `transform()` and output [[predictionCol]].
+ * This method is used to implement `transform()` and output [[predictionCol]].
*/
-protected def predict(features: FeaturesType): Double
+@Since("2.4.0")
+def predict(features: FeaturesType): Double
}
@@ -18,7 +18,7 @@
package org.apache.spark.ml.classification

import org.apache.spark.SparkException
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
@@ -192,12 +192,12 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur

/**
* Predict label for the given features.
- * This internal method is used to implement `transform()` and output [[predictionCol]].
+ * This method is used to implement `transform()` and output [[predictionCol]].
*
* This default implementation for classification predicts the index of the maximum value
* from `predictRaw()`.
*/
-override protected def predict(features: FeaturesType): Double = {
+override def predict(features: FeaturesType): Double = {
raw2prediction(predictRaw(features))
}

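For context, the default classification path above reduces to an argmax over the per-class raw scores. A self-contained sketch of that behavior (an illustration, not the Spark source):

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}

// predictRaw() yields one raw score per class; the predicted label is the
// index of the largest score.
def raw2predictionSketch(rawPrediction: Vector): Double = rawPrediction.argmax.toDouble

val raw = Vectors.dense(0.1, 2.3, 0.6)   // hypothetical raw scores for 3 classes
assert(raw2predictionSketch(raw) == 1.0) // class 1 has the largest score
```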
@@ -181,7 +181,7 @@ class DecisionTreeClassificationModel private[ml] (
private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses)

-override protected def predict(features: Vector): Double = {
+override def predict(features: Vector): Double = {
rootNode.predictImpl(features).prediction
}

@@ -267,7 +267,7 @@ class GBTClassificationModel private[ml](
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}

-override protected def predict(features: Vector): Double = {
+override def predict(features: Vector): Double = {
// If thresholds defined, use predictRaw to get probabilities, otherwise use optimization
if (isDefined(thresholds)) {
super.predict(features)
@@ -316,7 +316,7 @@ class LinearSVCModel private[classification] (
BLAS.dot(features, coefficients) + intercept
}

-override protected def predict(features: Vector): Double = {
+override def predict(features: Vector): Double = {
if (margin(features) > $(threshold)) 1.0 else 0.0
}

@@ -1090,7 +1090,7 @@ class LogisticRegressionModel private[spark] (
* Predict label for the given feature vector.
* The behavior of this can be adjusted using `thresholds`.
*/
-override protected def predict(features: Vector): Double = if (isMultinomial) {
+override def predict(features: Vector): Double = if (isMultinomial) {
super.predict(features)
} else {
// Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
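The binary branch is cut off at the hunk boundary; it compares the class-1 probability against `getThreshold`. A sketch of that logic with hypothetical values:

```scala
// Binary logistic prediction, sketched: compare P(y = 1 | x) to the threshold.
val probabilityOfOne = 0.7 // hypothetical probability for class 1
val threshold = 0.5        // hypothetical getThreshold value
val prediction = if (probabilityOfOne > threshold) 1.0 else 0.0
```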
@@ -322,7 +322,7 @@ class MultilayerPerceptronClassificationModel private[ml] (
* Predict label for the given features.
* This internal method is used to implement `transform()` and output [[predictionCol]].
*/
-override protected def predict(features: Vector): Double = {
+override def predict(features: Vector): Double = {
LabelConverter.decodeLabel(mlpModel.predict(features))
}

@@ -178,7 +178,7 @@ class DecisionTreeRegressionModel private[ml] (
private[ml] def this(rootNode: Node, numFeatures: Int) =
this(Identifiable.randomUID("dtr"), rootNode, numFeatures)

-override protected def predict(features: Vector): Double = {
+override def predict(features: Vector): Double = {
rootNode.predictImpl(features).prediction
}

@@ -230,7 +230,7 @@ class GBTRegressionModel private[ml](
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}

-override protected def predict(features: Vector): Double = {
+override def predict(features: Vector): Double = {
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
// Classifies by thresholding sum of weighted tree predictions
val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
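The summation that follows `treePredictions` (cut off at the hunk boundary) amounts to a dot product of the per-tree predictions with the tree weights. A sketch with hypothetical values:

```scala
// GBT regression prediction, sketched: weighted sum of per-tree predictions.
val treePredictions = Array(0.8, -0.3, 0.5) // hypothetical per-tree outputs
val treeWeights = Array(1.0, 0.1, 0.1)      // hypothetical learned tree weights
val prediction = treePredictions.zip(treeWeights).map { case (p, w) => p * w }.sum
```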
@@ -1010,7 +1010,7 @@ class GeneralizedLinearRegressionModel private[ml] (

private lazy val familyAndLink = FamilyAndLink(this)

-override protected def predict(features: Vector): Double = {
+override def predict(features: Vector): Double = {
predict(features, 0.0)
}

@@ -699,7 +699,7 @@ class LinearRegressionModel private[ml] (
}


-override protected def predict(features: Vector): Double = {
+override def predict(features: Vector): Double = {
dot(features, coefficients) + intercept
}

@@ -199,7 +199,7 @@ class RandomForestRegressionModel private[ml] (
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}

-override protected def predict(features: Vector): Double = {
+override def predict(features: Vector): Double = {
// TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
// Predict average of tree predictions.
// Ignore the weights since all are 1.0 for now.
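The averaging described by the comments above can be sketched as follows (hypothetical values; in the model itself each per-tree prediction comes from `rootNode.predictImpl`):

```scala
// Random-forest regression prediction, sketched: unweighted mean of the
// individual tree predictions (all tree weights are currently 1.0).
val treePredictions = Array(0.8, 1.2, 0.7) // hypothetical per-tree outputs
val prediction = treePredictions.sum / treePredictions.length
```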
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode}
+import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
@@ -264,6 +264,21 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
Vector, DecisionTreeClassificationModel](this, newTree, newData)
}

test("prediction on single instance") {
val rdd = continuousDataPointsForMulticlassRDD
val dt = new DecisionTreeClassifier()
.setImpurity("Gini")
.setMaxDepth(4)
.setMaxBins(100)
val categoricalFeatures = Map(0 -> 3)
val numClasses = 3

val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
val newTree = dt.fit(newData)

testPredictionModelSinglePrediction(newTree, newData)
}

test("training with 1-category categorical feature") {
val data = sc.parallelize(Seq(
LabeledPoint(0, Vectors.dense(0, 2, 3)),
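The helper `testPredictionModelSinglePrediction` used by this and the following suites is added to MLTest elsewhere in this commit and is not shown in this excerpt. Presumably it verifies, row by row, that the newly public predict() agrees with the prediction column produced by transform(); a sketch under that assumption:

```scala
import org.apache.spark.ml.PredictionModel
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.{Dataset, Row}

// Assumed behavior: predict() on each features vector must match the value
// transform() writes to the prediction column for that row.
def singlePredictionCheck(model: PredictionModel[Vector, _], dataset: Dataset[_]): Unit = {
  model.transform(dataset)
    .select(model.getFeaturesCol, model.getPredictionCol)
    .collect()
    .foreach { case Row(features: Vector, prediction: Double) =>
      assert(prediction == model.predict(features))
    }
}
```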
@@ -197,6 +197,15 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
Vector, GBTClassificationModel](this, gbtModel, validationDataset)
}

test("prediction on single instance") {

val gbt = new GBTClassifier().setSeed(123)
val trainingDataset = trainData.toDF("label", "features")
val gbtModel = gbt.fit(trainingDataset)

testPredictionModelSinglePrediction(gbtModel, trainingDataset)
}

test("GBT parameter stepSize should be in interval (0, 1]") {
withClue("GBT parameter stepSize should be in interval (0, 1]") {
intercept[IllegalArgumentException] {
@@ -201,6 +201,12 @@ class LinearSVCSuite extends MLTest with DefaultReadWriteTest {
dataset.as[LabeledPoint], estimator, modelEquals, 42L)
}

test("prediction on single instance") {
val trainer = new LinearSVC()
val model = trainer.fit(smallBinaryDataset)
testPredictionModelSinglePrediction(model, smallBinaryDataset)
}

test("linearSVC comparison with R e1071 and scikit-learn") {
val trainer1 = new LinearSVC()
.setRegParam(0.00002) // set regParam = 2.0 / datasize / c
@@ -499,6 +499,15 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
Vector, LogisticRegressionModel](this, model, smallBinaryDataset)
}

test("prediction on single instance") {
val blor = new LogisticRegression().setFamily("binomial")
val blorModel = blor.fit(smallBinaryDataset)
testPredictionModelSinglePrediction(blorModel, smallBinaryDataset)
val mlor = new LogisticRegression().setFamily("multinomial")
val mlorModel = mlor.fit(smallMultinomialDataset)
testPredictionModelSinglePrediction(mlorModel, smallMultinomialDataset)
}

test("coefficients and intercept methods") {
val mlr = new LogisticRegression().setMaxIter(1).setFamily("multinomial")
val mlrModel = mlr.fit(smallMultinomialDataset)
@@ -76,6 +76,18 @@ class MultilayerPerceptronClassifierSuite extends MLTest with DefaultReadWriteTe
}
}

test("prediction on single instance") {
val layers = Array[Int](2, 5, 2)
val trainer = new MultilayerPerceptronClassifier()
.setLayers(layers)
.setBlockSize(1)
.setSeed(123L)
.setMaxIter(100)
.setSolver("l-bfgs")
val model = trainer.fit(dataset)
testPredictionModelSinglePrediction(model, dataset)
}

test("Predicted class probabilities: calibration on toy dataset") {
val layers = Array[Int](4, 5, 2)

@@ -167,6 +167,28 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
Vector, NaiveBayesModel](this, model, testDataset)
}

test("prediction on single instance") {
val nPoints = 1000
val piArray = Array(0.5, 0.1, 0.4).map(math.log)
val thetaArray = Array(
Array(0.70, 0.10, 0.10, 0.10), // label 0
Array(0.10, 0.70, 0.10, 0.10), // label 1
Array(0.10, 0.10, 0.70, 0.10) // label 2
).map(_.map(math.log))
val pi = Vectors.dense(piArray)
val theta = new DenseMatrix(3, 4, thetaArray.flatten, true)

val trainDataset =
generateNaiveBayesInput(piArray, thetaArray, nPoints, seed, "multinomial").toDF()
val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial")
val model = nb.fit(trainDataset)

val validationDataset =
generateNaiveBayesInput(piArray, thetaArray, nPoints, 17, "multinomial").toDF()

testPredictionModelSinglePrediction(model, validationDataset)
}

test("Naive Bayes with weighted samples") {
val numClasses = 3
def modelEquals(m1: NaiveBayesModel, m2: NaiveBayesModel): Unit = {
@@ -155,6 +155,22 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
Vector, RandomForestClassificationModel](this, model, df)
}

test("prediction on single instance") {
val rdd = orderedLabeledPoints5_20
val rf = new RandomForestClassifier()
.setImpurity("Gini")
.setMaxDepth(3)
.setNumTrees(3)
.setSeed(123)
val categoricalFeatures = Map.empty[Int, Int]
val numClasses = 2

val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
val model = rf.fit(df)

testPredictionModelSinglePrediction(model, df)
}

test("Fitting without numClasses in metadata") {
val df: DataFrame = TreeTests.featureImportanceData(sc).toDF()
val rf = new RandomForestClassifier().setMaxDepth(1).setNumTrees(1)
@@ -136,6 +136,21 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
assert(importances.toArray.forall(_ >= 0.0))
}

test("prediction on single instance") {
val dt = new DecisionTreeRegressor()
.setImpurity("variance")
.setMaxDepth(3)
.setSeed(123)

// In this data, feature 1 is very important.
val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
val categoricalFeatures = Map.empty[Int, Int]
val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)

val model = dt.fit(df)
testPredictionModelSinglePrediction(model, df)
}

test("should support all NumericType labels and not support other types") {
val dt = new DecisionTreeRegressor().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor](
@@ -99,6 +99,14 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
}
}

test("prediction on single instance") {
val gbt = new GBTRegressor()
.setMaxDepth(2)
.setMaxIter(2)
val model = gbt.fit(trainData.toDF())
testPredictionModelSinglePrediction(model, validationData.toDF)
}

test("Checkpointing") {
val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString
@@ -211,6 +211,14 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest
assert(model.getLink === "identity")
}

test("prediction on single instance") {
val glr = new GeneralizedLinearRegression
val model = glr.setFamily("gaussian").setLink("identity")
.fit(datasetGaussianIdentity)

testPredictionModelSinglePrediction(model, datasetGaussianIdentity)
}

test("generalized linear regression: gaussian family against glm") {
/*
R code:
@@ -636,6 +636,13 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
}
}

test("prediction on single instance") {
val trainer = new LinearRegression
val model = trainer.fit(datasetWithDenseFeature)

testPredictionModelSinglePrediction(model, datasetWithDenseFeature)
}

test("linear regression model with constant label") {
/*
R code:
@@ -19,22 +19,22 @@ package org.apache.spark.ml.regression

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}

/**
* Test suite for [[RandomForestRegressor]].
*/
-class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
-  with DefaultReadWriteTest{
+class RandomForestRegressorSuite extends MLTest with DefaultReadWriteTest{

import RandomForestRegressorSuite.compareAPIs
import testImplicits._

private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _

@@ -74,6 +74,20 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
regressionTestWithContinuousFeatures(rf)
}

test("prediction on single instance") {
val rf = new RandomForestRegressor()
.setImpurity("variance")
.setMaxDepth(2)
.setMaxBins(10)
.setNumTrees(1)
.setFeatureSubsetStrategy("auto")
.setSeed(123)

val df = orderedLabeledPoints50_1000.toDF()
val model = rf.fit(df)
testPredictionModelSinglePrediction(model, df)
}

test("Feature importance with toy data") {
val rf = new RandomForestRegressor()
.setImpurity("variance")
