diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 089019dd8b0a1..7acf765869693 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - + package org.apache.spark.ml.tuning import org.apache.spark.SparkFunSuite @@ -33,7 +33,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext test("train validation with logistic regression") { val dataset = sqlContext.createDataFrame( sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) - + val lr = new LogisticRegression val lrParamMaps = new ParamGridBuilder() .addGrid(lr.regParam, Array(0.001, 1000.0)) @@ -52,7 +52,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext assert(parent.getMaxIter === 10) assert(cvModel.avgMetrics.length === lrParamMaps.length) } - + test("train validation with linear regression") { val dataset = sqlContext.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( @@ -82,7 +82,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext assert(parent2.getMaxIter === 10) assert(cvModel2.avgMetrics.length === lrParamMaps.length) } - + test("validateParams should check estimatorParamMaps") { import TrainValidationSplitSuite._