From 42c61eaeecff0c37fee1be35ba48494b70639b70 Mon Sep 17 00:00:00 2001 From: Shahid Date: Mon, 23 Jul 2018 00:58:08 +0530 Subject: [PATCH] [Minor][ML]Added UT for checking maximum number of features for GeneralizedLinearRegression and WeightedLeastSquares Currently GeneralizedLinearRegression and WeightedLeastSquare doesn't support features more than 4096. But there is no UT added to check this. In this PR, I have added UT for checking the behaviour of both the algorithms, if the features are more than 4096. --- .../ml/optim/WeightedLeastSquaresSuite.scala | 18 ++++++++++++++++- .../GeneralizedLinearRegressionSuite.scala | 20 ++++++++++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala index 093d02ea7a14b..5ef2c4d1d7381 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.optim -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg.{BLAS, Vectors} import org.apache.spark.ml.util.TestingUtils._ @@ -539,4 +539,20 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext idx += 1 } } + + test("number of features more than MAX_NUM_FEATURES") { + // Evaluate with an Instances RDD, that contains more than the maximum number of features. + val numFeatures = WeightedLeastSquares.MAX_NUM_FEATURES + 1 + val inst = sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.sparse(numFeatures, Array(0, 4), Array(3.0, 8.0))), + Instance(17.0, 2.0, Vectors.sparse(numFeatures, Array(1, 5), Array(3.0, 8.0)))), 2) + + val wls = new WeightedLeastSquares(fitIntercept = true, regParam = 0.5, elasticNetParam = 0.0, + standardizeFeatures = false, standardizeLabel = false) + + intercept[SparkException] { + wls.fit(inst) + }.getMessage.contains(s"we set the max number of features to " + + s"${WeightedLeastSquares.MAX_NUM_FEATURES} but got $numFeatures.") + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 997c50157dcda..829ed986dca98 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -19,11 +19,12 @@ package org.apache.spark.ml.regression import scala.util.Random -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.{Instance, OffsetInstance} import org.apache.spark.ml.feature.{LabeledPoint, RFormula} import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors} +import org.apache.spark.ml.optim.WeightedLeastSquares import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ @@ -1664,6 +1665,23 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest } } + test("number of features more than MAX_NUM_FEATURES") { + // Evaluate with a dataset that contains more than the maximum number of features. + val numFeatues = WeightedLeastSquares.MAX_NUM_FEATURES + 1 + val dataset = Seq( + Instance(17.0, 1.0, Vectors.sparse(numFeatues, Array(0, 4), Array(3.0, 8.0))), + Instance(19.0, 1.0, Vectors.sparse(numFeatues, Array(1, 5), Array(4.0, 9.0)))) + .toDF() + + val trainer = new GeneralizedLinearRegression() + .setMaxIter(1) + + intercept[SparkException] { + trainer.fit(dataset) + }.getMessage.contains(s"GeneralizedLinearRegression only supports number of features <=" + + s" ${WeightedLeastSquares.MAX_NUM_FEATURES}.") + } + test("evaluate with labels that are not doubles") { // Evaulate with a dataset that contains Labels not as doubles to verify correct casting val dataset = Seq(