[SPARK-22881][ML][TEST] ML regression package testsuite add StructuredStreaming test

## What changes were proposed in this pull request?

Add StructuredStreaming tests to the ML regression package test suite.

To make the test suites easier to modify, a new helper function was added to `MLTest`:
```
def testTransformerByGlobalCheckFunc[A : Encoder](
      dataframe: DataFrame,
      transformer: Transformer,
      firstResultCol: String,
      otherResultCols: String*)
      (globalCheckFunction: Seq[Row] => Unit): Unit
```
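For context, a minimal usage sketch (the `df` and `model` names and the tuple/column types here are placeholders; the concrete call sites appear in the diffs below). Unlike the per-row `testTransformer` check, the global variant hands every output row, presumably from both the batch and the streaming execution of the transformer, to a single check function, which is what assertions over the whole result set need:

```
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.Row

// Sketch: run `model` over `df` and pass ALL rows of the "prediction"
// column to one global check function (assumes a non-empty result).
testTransformerByGlobalCheckFunc[(Double, Vector)](df, model, "prediction") {
  rows: Seq[Row] =>
    val predictions = rows.map(_.getDouble(0))
    // A property of the whole result set, not of any single row:
    assert(predictions.max > predictions.min)
}
```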

## How was this patch tested?

N/A

Author: WeichenXu <weichen.xu@databricks.com>
Author: Bago Amirbekian <bago@databricks.com>

Closes #19979 from WeichenXu123/ml_stream_test.
WeichenXu123 authored and jkbradley committed Dec 30, 2017
1 parent 8169630 commit 2ea17af
Showing 9 changed files with 147 additions and 131 deletions.
mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -19,19 +19,16 @@ package org.apache.spark.ml.regression

 import scala.util.Random

-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions.{col, lit}
 import org.apache.spark.sql.types._

-class AFTSurvivalRegressionSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class AFTSurvivalRegressionSuite extends MLTest with DefaultReadWriteTest {

   import testImplicits._

@@ -191,8 +188,8 @@ class AFTSurvivalRegressionSuite
     assert(model.predict(features) ~== responsePredictR relTol 1E-3)
     assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)

-    model.transform(datasetUnivariate).select("features", "prediction", "quantiles")
-      .collect().foreach {
+    testTransformer[(Vector, Double, Double)](datasetUnivariate, model,
+      "features", "prediction", "quantiles") {
       case Row(features: Vector, prediction: Double, quantiles: Vector) =>
         assert(prediction ~== model.predict(features) relTol 1E-5)
         assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5)
@@ -261,8 +258,8 @@ class AFTSurvivalRegressionSuite
     assert(model.predict(features) ~== responsePredictR relTol 1E-3)
     assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)

-    model.transform(datasetMultivariate).select("features", "prediction", "quantiles")
-      .collect().foreach {
+    testTransformer[(Vector, Double, Double)](datasetMultivariate, model,
+      "features", "prediction", "quantiles") {
       case Row(features: Vector, prediction: Double, quantiles: Vector) =>
         assert(prediction ~== model.predict(features) relTol 1E-5)
         assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5)
@@ -331,8 +328,8 @@ class AFTSurvivalRegressionSuite
     assert(model.predict(features) ~== responsePredictR relTol 1E-3)
     assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)

-    model.transform(datasetMultivariate).select("features", "prediction", "quantiles")
-      .collect().foreach {
+    testTransformer[(Vector, Double, Double)](datasetMultivariate, model,
+      "features", "prediction", "quantiles") {
       case Row(features: Vector, prediction: Double, quantiles: Vector) =>
         assert(prediction ~== model.predict(features) relTol 1E-5)
         assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5)
mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -21,19 +21,18 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
   DecisionTreeSuite => OldDecisionTreeSuite}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}

-class DecisionTreeRegressorSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {

   import DecisionTreeRegressorSuite.compareAPIs
+  import testImplicits._

   private var categoricalDataPointsRDD: RDD[LabeledPoint] = _
@@ -89,33 +88,31 @@ class DecisionTreeRegressorSuite
     val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0)
     val model = dt.fit(df)

-    val predictions = model.transform(df)
-      .select(model.getFeaturesCol, model.getVarianceCol)
-      .collect()
-
-    predictions.foreach { case Row(features: Vector, variance: Double) =>
-      val expectedVariance = model.rootNode.predictImpl(features).impurityStats.calculate()
-      assert(variance === expectedVariance,
-        s"Expected variance $expectedVariance but got $variance.")
+    testTransformer[(Vector, Double)](df, model, "features", "variance") {
+      case Row(features: Vector, variance: Double) =>
+        val expectedVariance = model.rootNode.predictImpl(features).impurityStats.calculate()
+        assert(variance === expectedVariance,
+          s"Expected variance $expectedVariance but got $variance.")
     }

     val varianceData: RDD[LabeledPoint] = TreeTests.varianceData(sc)
     val varianceDF = TreeTests.setMetadata(varianceData, Map.empty[Int, Int], 0)
     dt.setMaxDepth(1)
       .setMaxBins(6)
       .setSeed(0)
-    val transformVarDF = dt.fit(varianceDF).transform(varianceDF)
-    val calculatedVariances = transformVarDF.select(dt.getVarianceCol).collect().map {
-      case Row(variance: Double) => variance
-    }

-    // Since max depth is set to 1, the best split point is that which splits the data
-    // into (0.0, 1.0, 2.0) and (10.0, 12.0, 14.0). The predicted variance for each
-    // data point in the left node is 0.667 and for each data point in the right node
-    // is 2.667
-    val expectedVariances = Array(0.667, 0.667, 0.667, 2.667, 2.667, 2.667)
-    calculatedVariances.zip(expectedVariances).foreach { case (actual, expected) =>
-      assert(actual ~== expected absTol 1e-3)
+    testTransformerByGlobalCheckFunc[(Vector, Double)](varianceDF, dt.fit(varianceDF),
+      "variance") { case rows: Seq[Row] =>
+      val calculatedVariances = rows.map(_.getDouble(0))
+
+      // Since max depth is set to 1, the best split point is that which splits the data
+      // into (0.0, 1.0, 2.0) and (10.0, 12.0, 14.0). The predicted variance for each
+      // data point in the left node is 0.667 and for each data point in the right node
+      // is 2.667
+      val expectedVariances = Array(0.667, 0.667, 0.667, 2.667, 2.667, 2.667)
+      calculatedVariances.zip(expectedVariances).foreach { case (actual, expected) =>
+        assert(actual ~== expected absTol 1e-3)
+      }
     }
   }

mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -19,22 +19,20 @@ package org.apache.spark.ml.regression

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.LabeledPoint
-import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.util.Utils

 /**
  * Test suite for [[GBTRegressor]].
  */
-class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
-  with DefaultReadWriteTest {
+class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {

   import GBTRegressorSuite.compareAPIs
   import testImplicits._
@@ -91,11 +89,14 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
     val model = gbt.fit(df)

     MLTestingUtils.checkCopyAndUids(gbt, model)
-    val preds = model.transform(df)
-    val predictions = preds.select("prediction").rdd.map(_.getDouble(0))
-    // Checks based on SPARK-8736 (to ensure it is not doing classification)
-    assert(predictions.max() > 2)
-    assert(predictions.min() < -1)
+
+    testTransformerByGlobalCheckFunc[(Double, Vector)](df, model, "prediction") {
+      case rows: Seq[Row] =>
+        val predictions = rows.map(_.getDouble(0))
+        // Checks based on SPARK-8736 (to ensure it is not doing classification)
+        assert(predictions.max > 2)
+        assert(predictions.min < -1)
+    }
   }

   test("Checkpointing") {
mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -25,16 +25,15 @@ import org.apache.spark.ml.feature.{Instance, OffsetInstance}
 import org.apache.spark.ml.feature.{LabeledPoint, RFormula}
 import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors}
 import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.random._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.FloatType

-class GeneralizedLinearRegressionSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest {

   import testImplicits._

@@ -268,8 +267,8 @@ class GeneralizedLinearRegressionSuite
         s"$link link and fitIntercept = $fitIntercept.")

       val familyLink = FamilyAndLink(trainer)
-      model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
-        .foreach {
+      testTransformer[(Double, Vector)](dataset, model,
+        "features", "prediction", "linkPrediction") {
         case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
           val eta = BLAS.dot(features, model.coefficients) + model.intercept
           val prediction2 = familyLink.fitted(eta)
@@ -278,7 +277,7 @@
             s"gaussian family, $link link and fitIntercept = $fitIntercept.")
           assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
             s"GLM with gaussian family, $link link and fitIntercept = $fitIntercept.")
-        }
+      }

       idx += 1
     }
@@ -384,8 +383,8 @@ class GeneralizedLinearRegressionSuite
         s"$link link and fitIntercept = $fitIntercept.")

       val familyLink = FamilyAndLink(trainer)
-      model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
-        .foreach {
+      testTransformer[(Double, Vector)](dataset, model,
+        "features", "prediction", "linkPrediction") {
         case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
           val eta = BLAS.dot(features, model.coefficients) + model.intercept
           val prediction2 = familyLink.fitted(eta)
@@ -394,7 +393,7 @@
             s"binomial family, $link link and fitIntercept = $fitIntercept.")
           assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
             s"GLM with binomial family, $link link and fitIntercept = $fitIntercept.")
-        }
+      }

       idx += 1
     }
@@ -456,8 +455,8 @@ class GeneralizedLinearRegressionSuite
         s"$link link and fitIntercept = $fitIntercept.")

       val familyLink = FamilyAndLink(trainer)
-      model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
-        .foreach {
+      testTransformer[(Double, Vector)](dataset, model,
+        "features", "prediction", "linkPrediction") {
         case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
           val eta = BLAS.dot(features, model.coefficients) + model.intercept
           val prediction2 = familyLink.fitted(eta)
@@ -466,7 +465,7 @@
             s"poisson family, $link link and fitIntercept = $fitIntercept.")
           assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
             s"GLM with poisson family, $link link and fitIntercept = $fitIntercept.")
-        }
+      }

       idx += 1
     }
@@ -562,8 +561,8 @@ class GeneralizedLinearRegressionSuite
         s"$link link and fitIntercept = $fitIntercept.")

       val familyLink = FamilyAndLink(trainer)
-      model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
-        .foreach {
+      testTransformer[(Double, Vector)](dataset, model,
+        "features", "prediction", "linkPrediction") {
         case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
           val eta = BLAS.dot(features, model.coefficients) + model.intercept
           val prediction2 = familyLink.fitted(eta)
@@ -572,7 +571,7 @@
             s"gamma family, $link link and fitIntercept = $fitIntercept.")
           assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
             s"GLM with gamma family, $link link and fitIntercept = $fitIntercept.")
-        }
+      }

       idx += 1
     }
@@ -649,8 +648,8 @@ class GeneralizedLinearRegressionSuite
         s"and variancePower = $variancePower.")

       val familyLink = FamilyAndLink(trainer)
-      model.transform(datasetTweedie).select("features", "prediction", "linkPrediction").collect()
-        .foreach {
+      testTransformer[(Double, Double, Vector)](datasetTweedie, model,
+        "features", "prediction", "linkPrediction") {
         case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
           val eta = BLAS.dot(features, model.coefficients) + model.intercept
           val prediction2 = familyLink.fitted(eta)
@@ -661,7 +660,8 @@
           assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
             s"GLM with tweedie family, linkPower = $linkPower, fitIntercept = $fitIntercept " +
             s"and variancePower = $variancePower.")
-        }
+      }
+
       idx += 1
     }
   }
@@ -724,8 +724,8 @@ class GeneralizedLinearRegressionSuite
         s"fitIntercept = $fitIntercept and variancePower = $variancePower.")

       val familyLink = FamilyAndLink(trainer)
-      model.transform(datasetTweedie).select("features", "prediction", "linkPrediction").collect()
-        .foreach {
+      testTransformer[(Double, Double, Vector)](datasetTweedie, model,
+        "features", "prediction", "linkPrediction") {
         case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
           val eta = BLAS.dot(features, model.coefficients) + model.intercept
           val prediction2 = familyLink.fitted(eta)
@@ -736,7 +736,8 @@
           assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
             s"GLM with tweedie family, fitIntercept = $fitIntercept " +
             s"and variancePower = $variancePower.")
-        }
+      }
+
       idx += 1
     }
   }
@@ -861,8 +862,8 @@ class GeneralizedLinearRegressionSuite
         s" and fitIntercept = $fitIntercept.")

       val familyLink = FamilyAndLink(trainer)
-      model.transform(dataset).select("features", "offset", "prediction", "linkPrediction")
-        .collect().foreach {
+      testTransformer[(Double, Double, Double, Vector)](dataset, model,
+        "features", "offset", "prediction", "linkPrediction") {
         case Row(features: DenseVector, offset: Double, prediction1: Double,
           linkPrediction1: Double) =>
           val eta = BLAS.dot(features, model.coefficients) + model.intercept + offset