From 2ea17afb63f976500273518bf1b32f9efe250812 Mon Sep 17 00:00:00 2001
From: WeichenXu
Date: Fri, 29 Dec 2017 20:06:56 -0800
Subject: [PATCH] [SPARK-22881][ML][TEST] ML regression package testsuite add
 StructuredStreaming test

## What changes were proposed in this pull request?

ML regression package testsuite add StructuredStreaming test

In order to make testsuite easier to modify, new helper function added in `MLTest`:

```
def testTransformerByGlobalCheckFunc[A : Encoder](
    dataframe: DataFrame,
    transformer: Transformer,
    firstResultCol: String,
    otherResultCols: String*)
    (globalCheckFunction: Seq[Row] => Unit): Unit
```

## How was this patch tested?

N/A

Author: WeichenXu
Author: Bago Amirbekian

Closes #19979 from WeichenXu123/ml_stream_test.
---
 .../AFTSurvivalRegressionSuite.scala          | 19 ++++----
 .../DecisionTreeRegressorSuite.scala          | 43 ++++++++---------
 .../ml/regression/GBTRegressorSuite.scala     | 23 ++++-----
 .../GeneralizedLinearRegressionSuite.scala    | 47 ++++++++++---------
 .../regression/IsotonicRegressionSuite.scala  | 43 +++++++----------
 .../ml/regression/LinearRegressionSuite.scala | 25 +++++-----
 .../org/apache/spark/ml/util/MLTest.scala     | 39 +++++++++++----
 .../apache/spark/ml/util/MLTestSuite.scala    | 12 ++++-
 .../spark/sql/streaming/StreamTest.scala      | 27 +++++------
 9 files changed, 147 insertions(+), 131 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index 02e5c6d294f44..4e4ff71c9de90 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -19,19 +19,16 @@ package org.apache.spark.ml.regression
 
 import scala.util.Random
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions.{col, lit}
 import org.apache.spark.sql.types._
 
-class AFTSurvivalRegressionSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class AFTSurvivalRegressionSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -191,8 +188,8 @@ class AFTSurvivalRegressionSuite
     assert(model.predict(features) ~== responsePredictR relTol 1E-3)
     assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
 
-    model.transform(datasetUnivariate).select("features", "prediction", "quantiles")
-      .collect().foreach {
+    testTransformer[(Vector, Double, Double)](datasetUnivariate, model,
+      "features", "prediction", "quantiles") {
       case Row(features: Vector, prediction: Double, quantiles: Vector) =>
         assert(prediction ~== model.predict(features) relTol 1E-5)
         assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5)
@@ -261,8 +258,8 @@ class AFTSurvivalRegressionSuite
     assert(model.predict(features) ~== responsePredictR relTol 1E-3)
     assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
 
-    model.transform(datasetMultivariate).select("features", "prediction", "quantiles")
-      .collect().foreach {
+    testTransformer[(Vector, Double, Double)](datasetMultivariate, model,
+      "features", "prediction", "quantiles") {
       case Row(features: Vector, prediction: Double, quantiles: Vector) =>
         assert(prediction ~== model.predict(features) relTol 1E-5)
         assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5)
@@ -331,8 +328,8 @@ class AFTSurvivalRegressionSuite
     assert(model.predict(features) ~== responsePredictR relTol 1E-3)
     assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
 
-    model.transform(datasetMultivariate).select("features", "prediction", "quantiles")
-      .collect().foreach {
+    testTransformer[(Vector, Double, Double)](datasetMultivariate, model,
+      "features", "prediction", "quantiles") {
       case Row(features: Vector, prediction: Double, quantiles: Vector) =>
         assert(prediction ~== model.predict(features) relTol 1E-5)
         assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 642f266891b57..68a1218c23ece 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -21,19 +21,18 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 
-class DecisionTreeRegressorSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
 
   import DecisionTreeRegressorSuite.compareAPIs
+  import testImplicits._
 
   private var categoricalDataPointsRDD: RDD[LabeledPoint] = _
 
@@ -89,14 +88,11 @@ class DecisionTreeRegressorSuite
     val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0)
     val model = dt.fit(df)
 
-    val predictions = model.transform(df)
-      .select(model.getFeaturesCol, model.getVarianceCol)
-      .collect()
-
-    predictions.foreach { case Row(features: Vector, variance: Double) =>
-      val expectedVariance = model.rootNode.predictImpl(features).impurityStats.calculate()
-      assert(variance === expectedVariance,
-        s"Expected variance $expectedVariance but got $variance.")
+    testTransformer[(Vector, Double)](df, model, "features", "variance") {
+      case Row(features: Vector, variance: Double) =>
+        val expectedVariance = model.rootNode.predictImpl(features).impurityStats.calculate()
+        assert(variance === expectedVariance,
+          s"Expected variance $expectedVariance but got $variance.")
     }
 
     val varianceData: RDD[LabeledPoint] = TreeTests.varianceData(sc)
@@ -104,18 +100,19 @@ class DecisionTreeRegressorSuite
     dt.setMaxDepth(1)
       .setMaxBins(6)
       .setSeed(0)
-    val transformVarDF = dt.fit(varianceDF).transform(varianceDF)
-    val calculatedVariances = transformVarDF.select(dt.getVarianceCol).collect().map {
-      case Row(variance: Double) => variance
-    }
-
-    // Since max depth is set to 1, the best split point is that which splits the data
-    // into (0.0, 1.0, 2.0) and (10.0, 12.0, 14.0). The predicted variance for each
-    // data point in the left node is 0.667 and for each data point in the right node
-    // is 2.667
-    val expectedVariances = Array(0.667, 0.667, 0.667, 2.667, 2.667, 2.667)
-    calculatedVariances.zip(expectedVariances).foreach { case (actual, expected) =>
-      assert(actual ~== expected absTol 1e-3)
+
+    testTransformerByGlobalCheckFunc[(Vector, Double)](varianceDF, dt.fit(varianceDF),
+      "variance") { case rows: Seq[Row] =>
+      val calculatedVariances = rows.map(_.getDouble(0))
+
+      // Since max depth is set to 1, the best split point is that which splits the data
+      // into (0.0, 1.0, 2.0) and (10.0, 12.0, 14.0). The predicted variance for each
+      // data point in the left node is 0.667 and for each data point in the right node
+      // is 2.667
+      val expectedVariances = Array(0.667, 0.667, 0.667, 2.667, 2.667, 2.667)
+      calculatedVariances.zip(expectedVariances).foreach { case (actual, expected) =>
+        assert(actual ~== expected absTol 1e-3)
+      }
     }
   }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index ecbb57126d759..11c593b521e65 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -19,22 +19,20 @@ package org.apache.spark.ml.regression
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.LabeledPoint
-import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.util.Utils
 
 /**
  * Test suite for [[GBTRegressor]].
  */
-class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
-  with DefaultReadWriteTest {
+class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
 
   import GBTRegressorSuite.compareAPIs
   import testImplicits._
@@ -91,11 +89,14 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
     val model = gbt.fit(df)
 
     MLTestingUtils.checkCopyAndUids(gbt, model)
-    val preds = model.transform(df)
-    val predictions = preds.select("prediction").rdd.map(_.getDouble(0))
-    // Checks based on SPARK-8736 (to ensure it is not doing classification)
-    assert(predictions.max() > 2)
-    assert(predictions.min() < -1)
+
+    testTransformerByGlobalCheckFunc[(Double, Vector)](df, model, "prediction") {
+      case rows: Seq[Row] =>
+        val predictions = rows.map(_.getDouble(0))
+        // Checks based on SPARK-8736 (to ensure it is not doing classification)
+        assert(predictions.max > 2)
+        assert(predictions.min < -1)
+    }
   }
 
   test("Checkpointing") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index df7dee869d058..ef2ff94a5e213 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.ml.feature.{Instance, OffsetInstance}
 import org.apache.spark.ml.feature.{LabeledPoint, RFormula}
 import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors}
 import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.random._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -33,8 +33,7 @@ import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.FloatType
 
-class GeneralizedLinearRegressionSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -268,8 +267,8 @@ class GeneralizedLinearRegressionSuite
         s"$link link and fitIntercept = $fitIntercept.")
 
       val familyLink = FamilyAndLink(trainer)
-      model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
-        .foreach {
+      testTransformer[(Double, Vector)](dataset, model,
+        "features", "prediction", "linkPrediction") {
           case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
             val eta = BLAS.dot(features, model.coefficients) + model.intercept
             val prediction2 = familyLink.fitted(eta)
@@ -278,7 +277,7 @@ class GeneralizedLinearRegressionSuite
               s"gaussian family, $link link and fitIntercept = $fitIntercept.")
             assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
               s"GLM with gaussian family, $link link and fitIntercept = $fitIntercept.")
-        }
+      }
 
       idx += 1
     }
@@ -384,8 +383,8 @@ class GeneralizedLinearRegressionSuite
         s"$link link and fitIntercept = $fitIntercept.")
 
       val familyLink = FamilyAndLink(trainer)
-      model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
-        .foreach {
+      testTransformer[(Double, Vector)](dataset, model,
+        "features", "prediction", "linkPrediction") {
          case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
            val eta = BLAS.dot(features, model.coefficients) + model.intercept
            val prediction2 = familyLink.fitted(eta)
@@ -394,7 +393,7 @@ class GeneralizedLinearRegressionSuite
              s"binomial family, $link link and fitIntercept = $fitIntercept.")
            assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
              s"GLM with binomial family, $link link and fitIntercept = $fitIntercept.")
-        }
+      }
 
       idx += 1
     }
@@ -456,8 +455,8 @@ class GeneralizedLinearRegressionSuite
         s"$link link and fitIntercept = $fitIntercept.")
 
       val familyLink = FamilyAndLink(trainer)
-      model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
-        .foreach {
+      testTransformer[(Double, Vector)](dataset, model,
+        "features", "prediction", "linkPrediction") {
          case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
            val eta = BLAS.dot(features, model.coefficients) + model.intercept
            val prediction2 = familyLink.fitted(eta)
@@ -466,7 +465,7 @@ class GeneralizedLinearRegressionSuite
              s"poisson family, $link link and fitIntercept = $fitIntercept.")
            assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
              s"GLM with poisson family, $link link and fitIntercept = $fitIntercept.")
-        }
+      }
 
       idx += 1
     }
@@ -562,8 +561,8 @@ class GeneralizedLinearRegressionSuite
         s"$link link and fitIntercept = $fitIntercept.")
 
      val familyLink = FamilyAndLink(trainer)
-      model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
-        .foreach {
+      testTransformer[(Double, Vector)](dataset, model,
+        "features", "prediction", "linkPrediction") {
          case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
            val eta = BLAS.dot(features, model.coefficients) + model.intercept
            val prediction2 = familyLink.fitted(eta)
@@ -572,7 +571,7 @@ class GeneralizedLinearRegressionSuite
              s"gamma family, $link link and fitIntercept = $fitIntercept.")
            assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
              s"GLM with gamma family, $link link and fitIntercept = $fitIntercept.")
-        }
+      }
 
       idx += 1
     }
@@ -649,8 +648,8 @@ class GeneralizedLinearRegressionSuite
         s"and variancePower = $variancePower.")
 
       val familyLink = FamilyAndLink(trainer)
-      model.transform(datasetTweedie).select("features", "prediction", "linkPrediction").collect()
-        .foreach {
+      testTransformer[(Double, Double, Vector)](datasetTweedie, model,
+        "features", "prediction", "linkPrediction") {
          case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
            val eta = BLAS.dot(features, model.coefficients) + model.intercept
            val prediction2 = familyLink.fitted(eta)
@@ -661,7 +660,8 @@ class GeneralizedLinearRegressionSuite
            assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
              s"GLM with tweedie family, linkPower = $linkPower, fitIntercept = $fitIntercept " +
              s"and variancePower = $variancePower.")
-        }
+      }
+
       idx += 1
     }
   }
@@ -724,8 +724,8 @@ class GeneralizedLinearRegressionSuite
         s"fitIntercept = $fitIntercept and variancePower = $variancePower.")
 
       val familyLink = FamilyAndLink(trainer)
-      model.transform(datasetTweedie).select("features", "prediction", "linkPrediction").collect()
-        .foreach {
+      testTransformer[(Double, Double, Vector)](datasetTweedie, model,
+        "features", "prediction", "linkPrediction") {
          case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
            val eta = BLAS.dot(features, model.coefficients) + model.intercept
            val prediction2 = familyLink.fitted(eta)
@@ -736,7 +736,8 @@ class GeneralizedLinearRegressionSuite
            assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
              s"GLM with tweedie family, fitIntercept = $fitIntercept " +
              s"and variancePower = $variancePower.")
-        }
+      }
+
       idx += 1
     }
   }
@@ -861,8 +862,8 @@ class GeneralizedLinearRegressionSuite
         s" and fitIntercept = $fitIntercept.")
 
       val familyLink = FamilyAndLink(trainer)
-      model.transform(dataset).select("features", "offset", "prediction", "linkPrediction")
-        .collect().foreach {
+      testTransformer[(Double, Double, Double, Vector)](dataset, model,
+        "features", "offset", "prediction", "linkPrediction") {
         case Row(features: DenseVector, offset: Double, prediction1: Double,
           linkPrediction1: Double) =>
           val eta = BLAS.dot(features, model.coefficients) + model.intercept + offset
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
index 180f5f7ce5ab2..18fbbce936a2e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -17,15 +17,12 @@
 
 package org.apache.spark.ml.regression
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.sql.{DataFrame, Row}
 
-class IsotonicRegressionSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class IsotonicRegressionSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -44,13 +41,11 @@ class IsotonicRegressionSuite
 
     val model = ir.fit(dataset)
 
-    val predictions = model
-      .transform(dataset)
-      .select("prediction").rdd.map { case Row(pred) =>
-        pred
-      }.collect()
-
-    assert(predictions === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18))
+    testTransformerByGlobalCheckFunc[(Double, Double, Double)](dataset, model,
+      "prediction") { case rows: Seq[Row] =>
+      val predictions = rows.map(_.getDouble(0))
+      assert(predictions === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18))
+    }
 
     assert(model.boundaries === Vectors.dense(0, 1, 3, 4, 5, 6, 7, 8))
     assert(model.predictions === Vectors.dense(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0))
@@ -64,13 +59,11 @@ class IsotonicRegressionSuite
     val model = ir.fit(dataset)
     val features = generatePredictionInput(Seq(-2.0, -1.0, 0.5, 0.75, 1.0, 2.0, 9.0))
 
-    val predictions = model
-      .transform(features)
-      .select("prediction").rdd.map {
-        case Row(pred) => pred
-      }.collect()
-
-    assert(predictions === Array(7, 7, 6, 5.5, 5, 4, 1))
+    testTransformerByGlobalCheckFunc[Tuple1[Double]](features, model,
+      "prediction") { case rows: Seq[Row] =>
+      val predictions = rows.map(_.getDouble(0))
+      assert(predictions === Array(7, 7, 6, 5.5, 5, 4, 1))
+    }
   }
 
   test("params validation") {
@@ -157,13 +150,11 @@ class IsotonicRegressionSuite
 
     val features = generatePredictionInput(Seq(2.0, 3.0, 4.0, 5.0))
 
-    val predictions = model
-      .transform(features)
-      .select("prediction").rdd.map {
-        case Row(pred) => pred
-      }.collect()
-
-    assert(predictions === Array(3.5, 5.0, 5.0, 5.0))
+    testTransformerByGlobalCheckFunc[Tuple1[Double]](features, model,
+      "prediction") { case rows: Seq[Row] =>
+      val predictions = rows.map(_.getDouble(0))
+      assert(predictions === Array(3.5, 5.0, 5.0, 5.0))
+    }
   }
 
   test("read/write") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 9bb2895858f33..d42cb1714478f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -19,14 +19,13 @@ package org.apache.spark.ml.regression
 
 import scala.util.Random
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors}
 import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
+import org.apache.spark.mllib.util.LinearDataGenerator
 import org.apache.spark.sql.{DataFrame, Row}
 
 class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
@@ -363,8 +362,8 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
     assert(model2.intercept ~== interceptR2 relTol 1E-3)
     assert(model2.coefficients ~= coefficientsR2 relTol 1E-3)
 
-    model1.transform(datasetWithDenseFeature).select("features", "prediction")
-      .collect().foreach {
+    testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
+      "features", "prediction") {
       case Row(features: DenseVector, prediction1: Double) =>
         val prediction2 =
           features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) +
@@ -416,8 +415,8 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
     assert(model2.intercept ~== interceptR2 absTol 1E-2)
     assert(model2.coefficients ~= coefficientsR2 relTol 1E-2)
 
-    model1.transform(datasetWithDenseFeature).select("features", "prediction")
-      .collect().foreach {
+    testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
+      "features", "prediction") {
       case Row(features: DenseVector, prediction1: Double) =>
        val prediction2 =
          features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) +
@@ -467,7 +466,8 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
     assert(model2.intercept ~== interceptR2 relTol 1E-2)
     assert(model2.coefficients ~= coefficientsR2 relTol 1E-2)
 
-    model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach {
+    testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
+      "features", "prediction") {
       case Row(features: DenseVector, prediction1: Double) =>
         val prediction2 =
           features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) +
@@ -518,7 +518,8 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
     assert(model2.intercept ~== interceptR2 absTol 1E-2)
     assert(model2.coefficients ~= coefficientsR2 relTol 1E-2)
 
-    model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach {
+    testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
+      "features", "prediction") {
       case Row(features: DenseVector, prediction1: Double) =>
         val prediction2 =
           features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) +
@@ -570,8 +571,8 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
     assert(model2.intercept ~== interceptR2 relTol 1E-2)
     assert(model2.coefficients ~= coefficientsR2 relTol 1E-2)
 
-    model1.transform(datasetWithDenseFeature).select("features", "prediction")
-      .collect().foreach {
+    testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
+      "features", "prediction") {
      case Row(features: DenseVector, prediction1: Double) =>
        val prediction2 =
         features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) +
@@ -624,8 +625,8 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
     assert(model2.intercept ~== interceptR2 absTol 1E-2)
     assert(model2.coefficients ~= coefficientsR2 relTol 1E-2)
 
-    model1.transform(datasetWithDenseFeature).select("features", "prediction")
-      .collect().foreach {
+    testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
+      "features", "prediction") {
       case Row(features: DenseVector, prediction1: Double) =>
         val prediction2 =
           features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) +
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
index 7a5426ebadaa5..17678aa611a48 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
@@ -53,12 +53,12 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
     }
   }
 
-  def testTransformerOnStreamData[A : Encoder](
+  private[util] def testTransformerOnStreamData[A : Encoder](
       dataframe: DataFrame,
      transformer: Transformer,
      firstResultCol: String,
      otherResultCols: String*)
-      (checkFunction: Row => Unit): Unit = {
+      (globalCheckFunction: Seq[Row] => Unit): Unit = {
    val columnNames = dataframe.schema.fieldNames
    val stream = MemoryStream[A]
@@ -70,22 +70,43 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
       .select(firstResultCol, otherResultCols: _*)
     testStream(streamOutput) (
       AddData(stream, data: _*),
-      CheckAnswer(checkFunction)
+      CheckAnswer(globalCheckFunction)
     )
   }
 
+  private[util] def testTransformerOnDF(
+      dataframe: DataFrame,
+      transformer: Transformer,
+      firstResultCol: String,
+      otherResultCols: String*)
+      (globalCheckFunction: Seq[Row] => Unit): Unit = {
+    val dfOutput = transformer.transform(dataframe)
+    val outputs = dfOutput.select(firstResultCol, otherResultCols: _*).collect()
+    globalCheckFunction(outputs)
+  }
+
   def testTransformer[A : Encoder](
       dataframe: DataFrame,
       transformer: Transformer,
       firstResultCol: String,
       otherResultCols: String*)
       (checkFunction: Row => Unit): Unit = {
-    testTransformerOnStreamData(dataframe, transformer, firstResultCol,
-      otherResultCols: _*)(checkFunction)
+    testTransformerByGlobalCheckFunc(
+      dataframe,
+      transformer,
+      firstResultCol,
+      otherResultCols: _*) { rows: Seq[Row] => rows.foreach(checkFunction(_)) }
+  }
 
-    val dfOutput = transformer.transform(dataframe)
-    dfOutput.select(firstResultCol, otherResultCols: _*).collect().foreach { row =>
-      checkFunction(row)
-    }
+  def testTransformerByGlobalCheckFunc[A : Encoder](
+      dataframe: DataFrame,
+      transformer: Transformer,
+      firstResultCol: String,
+      otherResultCols: String*)
+      (globalCheckFunction: Seq[Row] => Unit): Unit = {
+    testTransformerOnStreamData(dataframe, transformer, firstResultCol,
+      otherResultCols: _*)(globalCheckFunction)
+    testTransformerOnDF(dataframe, transformer, firstResultCol,
+      otherResultCols: _*)(globalCheckFunction)
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala
index 56217ec4f3b0c..20c5b5395f6a4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.ml.util
 
-import org.apache.spark.ml.{PipelineModel, Transformer}
 import org.apache.spark.ml.feature.StringIndexer
 import org.apache.spark.sql.Row
 
@@ -32,10 +31,13 @@ class MLTestSuite extends MLTest {
     val indexer = new StringIndexer().setStringOrderType("alphabetAsc")
       .setInputCol("label").setOutputCol("indexed")
     val indexerModel = indexer.fit(data)
-    testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
+    testTransformer[(Int, String)](data, indexerModel, "id", "indexed") {
       case Row(id: Int, indexed: Double) =>
         assert(id === indexed.toInt)
     }
+    testTransformerByGlobalCheckFunc[(Int, String)] (data, indexerModel, "id", "indexed") { rows =>
+      assert(rows.map(_.getDouble(1)).max === 5.0)
+    }
 
     intercept[Exception] {
       testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
@@ -43,5 +45,11 @@ class MLTestSuite extends MLTest {
         assert(id != indexed.toInt)
       }
     }
+    intercept[Exception] {
+      testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
+        rows: Seq[Row] =>
+          assert(rows.map(_.getDouble(1)).max === 1.0)
+      }
+    }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index fb9ebc81dd750..4b7f0fbe97d4e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -137,8 +137,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
 
     def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false, false)
 
-    def apply(checkFunction: Row => Unit): CheckAnswerRowsByFunc =
-      CheckAnswerRowsByFunc(checkFunction, false)
+    def apply(globalCheckFunction: Seq[Row] => Unit): CheckAnswerRowsByFunc =
+      CheckAnswerRowsByFunc(globalCheckFunction, false)
   }
 
   /**
@@ -161,8 +161,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
 
     def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true, false)
 
-    def apply(checkFunction: Row => Unit): CheckAnswerRowsByFunc =
-      CheckAnswerRowsByFunc(checkFunction, true)
+    def apply(globalCheckFunction: Seq[Row] => Unit): CheckAnswerRowsByFunc =
+      CheckAnswerRowsByFunc(globalCheckFunction, true)
   }
 
   case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean, isSorted: Boolean)
@@ -177,9 +177,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer"
   }
 
-  case class CheckAnswerRowsByFunc(checkFunction: Row => Unit, lastOnly: Boolean)
-    extends StreamAction with StreamMustBeRunning {
-    override def toString: String = s"$operatorName: ${checkFunction.toString()}"
+  case class CheckAnswerRowsByFunc(
+      globalCheckFunction: Seq[Row] => Unit,
+      lastOnly: Boolean) extends StreamAction with StreamMustBeRunning {
+    override def toString: String = s"$operatorName"
     private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc"
   }
 
@@ -639,14 +640,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
             error => failTest(error)
           }
 
-        case CheckAnswerRowsByFunc(checkFunction, lastOnly) =>
+        case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) =>
           val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
-          sparkAnswer.foreach { row =>
-            try {
-              checkFunction(row)
-            } catch {
-              case e: Throwable => failTest(e.toString)
-            }
+          try {
+            globalCheckFunction(sparkAnswer)
+          } catch {
+            case e: Throwable => failTest(e.toString)
           }
       }
       pos += 1
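
---

Note (not part of the patch): a minimal usage sketch of the two helpers this change introduces, adapted from the `MLTestSuite` hunks above. It assumes a test suite that mixes in `MLTest` (which provides `testImplicits` and these helpers); the toy data and the expected maximum index are illustrative assumptions, not values from the patch.

```
// Sketch only: assumes a suite extending MLTest, with testImplicits._ imported
// so that toDF is available. Data and expected values are illustrative.
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.sql.Row

val data = Seq((0, "a"), (1, "b"), (2, "c")).toDF("id", "label")
val indexerModel = new StringIndexer()
  .setStringOrderType("alphabetAsc")
  .setInputCol("label")
  .setOutputCol("indexed")
  .fit(data)

// Per-row check: the transformer is exercised on both a batch DataFrame and a
// MemoryStream-backed streaming DataFrame, and the function is applied to each row.
testTransformer[(Int, String)](data, indexerModel, "id", "indexed") {
  case Row(id: Int, indexed: Double) => assert(id === indexed.toInt)
}

// Global check: the whole result set is handed to the function at once, which is
// what tests asserting over aggregates (e.g. min/max of a prediction column) need.
testTransformerByGlobalCheckFunc[(Int, String)](data, indexerModel, "id", "indexed") { rows =>
  assert(rows.map(_.getDouble(1)).max === 2.0)
}
```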