From 1161a2e2678daea4dba7a56b6413541f23317eae Mon Sep 17 00:00:00 2001
From: martinzapletal
Date: Wed, 24 Jun 2015 16:36:14 -0400
Subject: [PATCH 1/6] SPARK-8484. Added TrainValidationSplit for
 hyper-parameter tuning. It randomly splits the input dataset into train and
 validation sets and uses the evaluation metric on the validation set to
 select the best model.

---
 .../spark/ml/tuning/CrossValidator.scala      | 127 +++---
 .../ml/tuning/TrainValidationSplit.scala      | 108 ++++++++
 .../apache/spark/ml/tuning/Validation.scala   | 190 ++++++++++++++++++
 .../org/apache/spark/mllib/util/MLUtils.scala |  26 ++-
 .../spark/ml/tuning/CrossValidatorSuite.scala |  53 +----
 .../ml/tuning/TrainValidationSplitSuite.scala | 104 ++++++++++
 .../spark/ml/tuning/ValidationSuite.scala     |  56 ++++++
 7 files changed, 512 insertions(+), 152 deletions(-)
 create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
 create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tuning/Validation.scala
 create mode 100644 mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
 create mode 100644 mllib/src/test/scala/org/apache/spark/ml/tuning/ValidationSuite.scala

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index e2444ab65b43b..9b6b2bdd35211 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -27,43 +27,11 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.types.StructType
 
 /**
  * Params for [[CrossValidator]] and [[CrossValidatorModel]].
  */
-private[ml] trait CrossValidatorParams extends Params {
-
-  /**
-   * param for the estimator to be cross-validated
-   * @group param
-   */
-  val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
-
-  /** @group getParam */
-  def getEstimator: Estimator[_] = $(estimator)
-
-  /**
-   * param for estimator param maps
-   * @group param
-   */
-  val estimatorParamMaps: Param[Array[ParamMap]] =
-    new Param(this, "estimatorParamMaps", "param maps for the estimator")
-
-  /** @group getParam */
-  def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps)
-
-  /**
-   * param for the evaluator used to select hyper-parameters that maximize the cross-validated
-   * metric
-   * @group param
-   */
-  val evaluator: Param[Evaluator] = new Param(this, "evaluator",
-    "evaluator used to select hyper-parameters that maximize the cross-validated metric")
-
-  /** @group getParam */
-  def getEvaluator: Evaluator = $(evaluator)
-
+private[ml] trait CrossValidatorParams extends ValidationParams {
   /**
    * Param for number of folds for cross validation. Must be >= 2.
    * Default: 3
@@ -83,82 +51,53 @@ private[ml] trait CrossValidatorParams extends Params {
  * K-fold cross validation.
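+ *
+ * For example, a minimal usage sketch (assuming `dataset` is a DataFrame of
+ * labeled points, as in the accompanying test suite):
+ * {{{
+ *   val lr = new LogisticRegression
+ *   val lrParamMaps = new ParamGridBuilder()
+ *     .addGrid(lr.regParam, Array(0.001, 1000.0))
+ *     .build()
+ *   val cv = new CrossValidator()
+ *     .setEstimator(lr)
+ *     .setEstimatorParamMaps(lrParamMaps)
+ *     .setEvaluator(new BinaryClassificationEvaluator)
+ *     .setNumFolds(3)
+ *   val cvModel = cv.fit(dataset)
+ * }}}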
  */
 @Experimental
-class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel]
+class CrossValidator(uid: String)
+  extends Validation[CrossValidatorModel, CrossValidator](uid)
   with CrossValidatorParams with Logging {
 
   def this() = this(Identifiable.randomUID("cv"))
 
   private val f2jBLAS = new F2jBLAS
 
-  /** @group setParam */
-  def setEstimator(value: Estimator[_]): this.type = set(estimator, value)
-
-  /** @group setParam */
-  def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value)
-
-  /** @group setParam */
-  def setEvaluator(value: Evaluator): this.type = set(evaluator, value)
-
   /** @group setParam */
   def setNumFolds(value: Int): this.type = set(numFolds, value)
 
-  override def fit(dataset: DataFrame): CrossValidatorModel = {
+  override protected[ml] def validationLogic(
+      dataset: DataFrame,
+      est: Estimator[_],
+      eval: Evaluator,
+      epm: Array[ParamMap],
+      numModels: Int): Array[Double] = {
+
+    val schema = dataset.schema
     transformSchema(schema, logging = true)
     val sqlCtx = dataset.sqlContext
-    val est = $(estimator)
-    val eval = $(evaluator)
-    val epm = $(estimatorParamMaps)
-    val numModels = epm.length
+
     val metrics = new Array[Double](epm.length)
     val splits = MLUtils.kFold(dataset.rdd, $(numFolds), 0)
+
     splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
       val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
       val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
-      // multi-model training
       logDebug(s"Train split $splitIndex with multiple sets of parameters.")
-      val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
-      trainingDataset.unpersist()
+      val newMetrics = measureModels(trainingDataset, validationDataset, est, eval, epm, numModels)
+
       var i = 0
       while (i < numModels) {
-        // TODO: duplicate evaluator to take extra params from input
-        val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
-        logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
-        metrics(i) += metric
+        metrics(i) += newMetrics(i)
         i += 1
       }
-      validationDataset.unpersist()
     }
-    f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1)
-    logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
-    val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1)
-    logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
-    logInfo(s"Best cross-validation metric: $bestMetric.")
-    val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
-    copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
-  }
-
-  override def transformSchema(schema: StructType): StructType = {
-    $(estimator).transformSchema(schema)
-  }
 
-  override def validateParams(): Unit = {
-    super.validateParams()
-    val est = $(estimator)
-    for (paramMap <- $(estimatorParamMaps)) {
-      est.copy(paramMap).validateParams()
-    }
+    f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1)
+    metrics
   }
 
-  override def copy(extra: ParamMap): CrossValidator = {
-    val copied = defaultCopy(extra).asInstanceOf[CrossValidator]
-    if (copied.isDefined(estimator)) {
-      copied.setEstimator(copied.getEstimator.copy(extra))
-    }
-    if (copied.isDefined(evaluator)) {
-      copied.setEvaluator(copied.getEvaluator.copy(extra))
-    }
-    copied
+  override protected[ml] def createModel(
+      uid: String,
+      bestModel: Model[_],
+      metrics: Array[Double]): CrossValidatorModel = {
+    copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
   }
 }
 
@@ -168,23 +107,11 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
  */
 @Experimental
 class CrossValidatorModel private[ml] (
-    override val uid: String,
-    val bestModel: Model[_],
-    val avgMetrics: Array[Double])
-  extends Model[CrossValidatorModel] with CrossValidatorParams {
-
-  override def validateParams(): Unit = {
-    bestModel.validateParams()
-  }
-
-  override def transform(dataset: DataFrame): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
-    bestModel.transform(dataset)
-  }
-
-  override def transformSchema(schema: StructType): StructType = {
-    bestModel.transformSchema(schema)
-  }
+    uid: String,
+    bestModel: Model[_],
+    avgMetrics: Array[Double])
+  extends ValidationModel[CrossValidatorModel](uid, bestModel, avgMetrics)
+  with CrossValidatorParams {
 
   override def copy(extra: ParamMap): CrossValidatorModel = {
     val copied = new CrossValidatorModel(
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
new file mode 100644
index 0000000000000..04b098c5b159d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.evaluation.Evaluator
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.sql.DataFrame
+
+/**
+ * Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]].
+ */
+private[ml] trait TrainValidationSplitParams extends ValidationParams {
+  /**
+   * Param for ratio between train and validation data. Must be between 0 and 1.
+   * Default: 0.75
+   * @group param
+   */
+  val trainRatio: DoubleParam = new DoubleParam(this, "numFolds",
+    "ratio between training set and validation set (>= 0 && <= 1)", ParamValidators.inRange(0, 1))
+
+  /** @group getParam */
+  def getTrainPercent: Double = $(trainRatio)
+
+  setDefault(trainRatio -> 0.75)
+}
+
+/**
+ * :: Experimental ::
+ * Validation for hyper-parameter tuning.
+ * Randomly splits the input dataset into train and validation sets,
+ * and uses the evaluation metric on the validation set to select the best model.
+ * Similar to CrossValidator, but only splits the set once.
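+ *
+ * For example, a minimal usage sketch (assuming `dataset` is a DataFrame of
+ * labeled points and `lrParamMaps` was built with [[ParamGridBuilder]]):
+ * {{{
+ *   val tvs = new TrainValidationSplit()
+ *     .setEstimator(new LinearRegression)
+ *     .setEstimatorParamMaps(lrParamMaps)
+ *     .setEvaluator(new RegressionEvaluator)
+ *     .setTrainRatio(0.75)
+ *   val model = tvs.fit(dataset)
+ * }}}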
+ */
+@Experimental
+class TrainValidationSplit(uid: String)
+  extends Validation[TrainValidationSplitModel, TrainValidationSplit](uid)
+  with TrainValidationSplitParams with Logging {
+
+  def this() = this(Identifiable.randomUID("cv"))
+
+  /** @group setParam */
+  def setTrainRatio(value: Double): this.type = set(trainRatio, value)
+
+  override protected[ml] def validationLogic(
+      dataset: DataFrame,
+      est: Estimator[_],
+      eval: Evaluator,
+      epm: Array[ParamMap],
+      numModels: Int): Array[Double] = {
+
+    val schema = dataset.schema
+    transformSchema(schema, logging = true)
+    val sqlCtx = dataset.sqlContext
+
+    val splits = MLUtils.sample(dataset.rdd, $(trainRatio), 1)
+    val trainingDataset = sqlCtx.createDataFrame(splits._1, schema).cache()
+    val validationDataset = sqlCtx.createDataFrame(splits._2, schema).cache()
+    measureModels(trainingDataset, validationDataset, est, eval, epm, numModels)
+  }
+
+  override protected[ml] def createModel(
+      uid: String,
+      bestModel: Model[_],
+      metrics: Array[Double]): TrainValidationSplitModel = {
+    copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this))
+  }
+}
+
+/**
+ * :: Experimental ::
+ * Model from train validation split.
+ */
+@Experimental
+class TrainValidationSplitModel private[ml] (
+    uid: String,
+    bestModel: Model[_],
+    avgMetrics: Array[Double])
+  extends ValidationModel[TrainValidationSplitModel](uid, bestModel, avgMetrics)
+  with TrainValidationSplitParams {
+
+  override def copy(extra: ParamMap): TrainValidationSplitModel = {
+    val copied = new TrainValidationSplitModel (
+      uid,
+      bestModel.copy(extra).asInstanceOf[Model[_]],
+      avgMetrics.clone())
+    copyValues(copied, extra)
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/Validation.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/Validation.scala
new file mode 100644
index 0000000000000..3e93f7c8097ac
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/Validation.scala
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.ml.evaluation.Evaluator
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.{Model, Estimator}
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.StructType
+
+import scala.reflect.ClassTag
+
+/**
+ * :: DeveloperApi ::
+ * Common params for [[TrainValidationSplitParams]] and [[CrossValidatorParams]].
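+ * Concrete validators expose these params through the `setEstimator`,
+ * `setEstimatorParamMaps` and `setEvaluator` setters defined on [[Validation]].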
+ */
+@DeveloperApi
+private[ml] trait ValidationParams extends Params {
+
+  /**
+   * param for the estimator to be validated
+   * @group param
+   */
+  val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
+
+  /** @group getParam */
+  def getEstimator: Estimator[_] = $(estimator)
+
+  /**
+   * param for estimator param maps
+   * @group param
+   */
+  val estimatorParamMaps: Param[Array[ParamMap]] =
+    new Param(this, "estimatorParamMaps", "param maps for the estimator")
+
+  /** @group getParam */
+  def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps)
+
+  /**
+   * param for the evaluator used to select hyper-parameters that maximize the validated metric
+   * @group param
+   */
+  val evaluator: Param[Evaluator] = new Param(this, "evaluator",
+    "evaluator used to select hyper-parameters that maximize the validated metric")
+
+  /** @group getParam */
+  def getEvaluator: Evaluator = $(evaluator)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Abstract class for validation approaches for hyper-parameter tuning.
+ */
+@DeveloperApi
+private[ml] abstract class Validation[M <: Model[M], V <: Validation[M, _] : ClassTag]
+    (override val uid: String)
+  extends Estimator[M]
+  with Logging with ValidationParams {
+
+  def this() = this(Identifiable.randomUID("cv"))
+
+  /** @group setParam */
+  def setEstimator(value: Estimator[_]): V = set(estimator, value).asInstanceOf[V]
+
+  /** @group setParam */
+  def setEstimatorParamMaps(value: Array[ParamMap]): V =
+    set(estimatorParamMaps, value).asInstanceOf[V]
+
+  /** @group setParam */
+  def setEvaluator(value: Evaluator): V = set(evaluator, value).asInstanceOf[V]
+
+  override def fit(dataset: DataFrame): M = {
+    val sqlCtx = dataset.sqlContext
+    val est = $(estimator)
+    val eval = $(evaluator)
+    val epm = $(estimatorParamMaps)
+    val numModels = epm.length
+
+    val metrics = validationLogic(dataset, est, eval, epm, numModels)
+
+    logInfo(s"Average validation metrics: ${metrics.toSeq}")
+    val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1)
+    logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
+    logInfo(s"Best validation metric: $bestMetric.")
+    val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
+
+    createModel(uid, bestModel, metrics)
+  }
+
+  private[ml] def measureModels(
+      trainingDataset: DataFrame,
+      validationDataset: DataFrame,
+      est: Estimator[_],
+      eval: Evaluator,
+      epm: Array[ParamMap],
+      numModels: Int) = {
+
+    val metrics = new Array[Double](numModels)
+    val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
+    trainingDataset.unpersist()
+    var i = 0
+
+    // multi-model training
+    while (i < numModels) {
+      // TODO: duplicate evaluator to take extra params from input
+      val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
+      logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
+      metrics(i) += metric
+      i += 1
+    }
+    validationDataset.unpersist()
+
+    metrics
+  }
+
+  protected[ml] def validationLogic(
+      dataset: DataFrame,
+      est: Estimator[_],
+      eval: Evaluator,
+      epm: Array[ParamMap],
+      numModels: Int): Array[Double]
+
+  protected[ml] def createModel(uid: String, bestModel: Model[_], metrics: Array[Double]): M
+
+  override def transformSchema(schema: StructType): StructType = {
+    $(estimator).transformSchema(schema)
+  }
+
+  override def validateParams(): Unit = {
+    super.validateParams()
+    val est = $(estimator)
+    for (paramMap <- $(estimatorParamMaps)) {
+      est.copy(paramMap).validateParams()
+    }
+  }
+
+  override def copy(extra: ParamMap): V = {
+    val copied = defaultCopy(extra).asInstanceOf[V]
+    if (copied.isDefined(estimator)) {
+      copied.setEstimator(copied.getEstimator.copy(extra))
+    }
+    if (copied.isDefined(evaluator)) {
+      copied.setEvaluator(copied.getEvaluator.copy(extra))
+    }
+    copied
+  }
+}
+
+/**
+ * :: DeveloperApi ::
+ * Model from validation.
+ */
+@DeveloperApi
+private[ml] abstract class ValidationModel[M <: Model[M]] private[ml] (
+    override val uid: String,
+    val bestModel: Model[_],
+    val avgMetrics: Array[Double])
+  extends Model[M] {
+
+  override def validateParams(): Unit = {
+    bestModel.validateParams()
+  }
+
+  override def transform(dataset: DataFrame): DataFrame = {
+    transformSchema(dataset.schema, logging = true)
+    bestModel.transform(dataset)
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    bestModel.transformSchema(schema)
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 7c5cfa7bd84ce..a5454f193718f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.mllib.util
 
+import org.apache.spark.util.Utils
+
 import scala.reflect.ClassTag
 
 import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
@@ -258,14 +260,28 @@ object MLUtils {
   def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = {
     val numFoldsF = numFolds.toFloat
     (1 to numFolds).map { fold =>
-      val sampler = new BernoulliCellSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF,
-        complement = false)
-      val validation = new PartitionwiseSampledRDD(rdd, sampler, true, seed)
-      val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), true, seed)
-      (training, validation)
+      sample(rdd, (fold - 1) / numFoldsF, fold / numFoldsF, seed)
     }.toArray
   }
 
+  /**
+   * :: Experimental ::
+   * Return a pair of RDDs: the first element contains the training data (the complement of
+   * the validation data) and the second element contains the validation data, i.e. the
+   * elements whose sampling key falls in the range `[lb, ub)`. When called from [[kFold]],
+   * the validation data is a unique 1/kth of the data, where k = numFolds.
+   */
+  @Experimental
+  def sample[T: ClassTag](
+      rdd: RDD[T],
+      lb: Double,
+      ub: Double,
+      seed: Int = Utils.random.nextInt()): (RDD[T], RDD[T]) = {
+    val sampler = new BernoulliCellSampler[T](lb, ub, complement = false)
+    val validation = new PartitionwiseSampledRDD(rdd, sampler, true, seed)
+    val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), true, seed)
+    (training, validation)
+  }
+
   /**
    * Returns a new vector with `1.0` (bias) appended to the input vector.
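+   * For example, a dense vector `[1.0, 2.0]` becomes `[1.0, 2.0, 1.0]`.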
   */
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index db64511a76055..ffb566be85d22 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -18,29 +18,19 @@ package org.apache.spark.ml.tuning
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.classification.LogisticRegression
-import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
+import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, RegressionEvaluator}
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.param.shared.HasInputCol
 import org.apache.spark.ml.regression.LinearRegression
-import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.classification.LogisticRegressionSuite._
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
-import org.apache.spark.sql.{DataFrame, SQLContext}
-import org.apache.spark.sql.types.StructType
 
 class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
 
-  @transient var dataset: DataFrame = _
-
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    val sqlContext = new SQLContext(sc)
-    dataset = sqlContext.createDataFrame(
+  test("cross validation with logistic regression") {
+    val dataset = sqlContext.createDataFrame(
       sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
-  }
 
-  test("cross validation with logistic regression") {
     val lr = new LogisticRegression
     val lrParamMaps = new ParamGridBuilder()
       .addGrid(lr.regParam, Array(0.001, 1000.0))
@@ -90,7 +80,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
   }
 
   test("validateParams should check estimatorParamMaps") {
-    import CrossValidatorSuite._
+    import org.apache.spark.ml.tuning.ValidationSuite._
 
     val est = new MyEstimator("est")
     val eval = new MyEvaluator
@@ -111,35 +101,4 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
       cv.validateParams()
     }
   }
-}
-
-object CrossValidatorSuite {
-
-  abstract class MyModel extends Model[MyModel]
-
-  class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
-
-    override def validateParams(): Unit = require($(inputCol).nonEmpty)
-
-    override def fit(dataset: DataFrame): MyModel = {
-      throw new UnsupportedOperationException
-    }
-
-    override def transformSchema(schema: StructType): StructType = {
-      throw new UnsupportedOperationException
-    }
-
-    override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)
-  }
-
-  class MyEvaluator extends Evaluator {
-
-    override def evaluate(dataset: DataFrame): Double = {
-      throw new UnsupportedOperationException
-    }
-
-    override val uid: String = "eval"
-
-    override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra)
-  }
-}
+}
\ No newline at end of file
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
new file mode 100644
index 0000000000000..0f63694bcf17a
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, RegressionEvaluator}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.mllib.classification.LogisticRegressionSuite._
+import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
+
+class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext {
+  test("train validation with logistic regression") {
+    val dataset = sqlContext.createDataFrame(
+      sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
+
+    val lr = new LogisticRegression
+    val lrParamMaps = new ParamGridBuilder()
+      .addGrid(lr.regParam, Array(0.001, 1000.0))
+      .addGrid(lr.maxIter, Array(0, 10))
+      .build()
+    val eval = new BinaryClassificationEvaluator
+    val cv = new TrainValidationSplit()
+      .setEstimator(lr)
+      .setEstimatorParamMaps(lrParamMaps)
+      .setEvaluator(eval)
+      .setTrainRatio(0.5)
+    val cvModel = cv.fit(dataset)
+    val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
+    assert(cv.getTrainPercent === 0.5)
+    assert(parent.getRegParam === 0.001)
+    assert(parent.getMaxIter === 10)
+    assert(cvModel.avgMetrics.length === lrParamMaps.length)
+  }
+
+  test("train validation with linear regression") {
+    val dataset = sqlContext.createDataFrame(
+      sc.parallelize(LinearDataGenerator.generateLinearInput(
+        6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
+
+    val trainer = new LinearRegression
+    val lrParamMaps = new ParamGridBuilder()
+      .addGrid(trainer.regParam, Array(1000.0, 0.001))
+      .addGrid(trainer.maxIter, Array(0, 10))
+      .build()
+    val eval = new RegressionEvaluator()
+    val cv = new TrainValidationSplit()
+      .setEstimator(trainer)
+      .setEstimatorParamMaps(lrParamMaps)
+      .setEvaluator(eval)
+      .setTrainRatio(0.5)
+    val cvModel = cv.fit(dataset)
+    val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression]
+    assert(parent.getRegParam === 0.001)
+    assert(parent.getMaxIter === 10)
+    assert(cvModel.avgMetrics.length === lrParamMaps.length)
+
+    eval.setMetricName("r2")
+    val cvModel2 = cv.fit(dataset)
+    val parent2 = cvModel2.bestModel.parent.asInstanceOf[LinearRegression]
+    assert(parent2.getRegParam === 0.001)
+    assert(parent2.getMaxIter === 10)
+    assert(cvModel2.avgMetrics.length === lrParamMaps.length)
+  }
+
+  test("validateParams should check estimatorParamMaps") {
+    import org.apache.spark.ml.tuning.ValidationSuite._
+
+    val est = new MyEstimator("est")
+    val eval = new MyEvaluator
+    val paramMaps = new ParamGridBuilder()
+      .addGrid(est.inputCol, Array("input1", "input2"))
+      .build()
+
+    val cv = new TrainValidationSplit()
+      .setEstimator(est)
+      .setEstimatorParamMaps(paramMaps)
+      .setEvaluator(eval)
+      .setTrainRatio(0.5)
+    cv.validateParams() // This should pass.
+
+    val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
+    cv.setEstimatorParamMaps(invalidParamMaps)
+    intercept[IllegalArgumentException] {
+      cv.validateParams()
+    }
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidationSuite.scala
new file mode 100644
index 0000000000000..02db33071ece7
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidationSuite.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import org.apache.spark.ml.evaluation.Evaluator
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasInputCol
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.StructType
+
+object ValidationSuite {
+
+  abstract class MyModel extends Model[MyModel]
+
+  class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
+
+    override def validateParams(): Unit = require($(inputCol).nonEmpty)
+
+    override def fit(dataset: DataFrame): MyModel = {
+      throw new UnsupportedOperationException
+    }
+
+    override def transformSchema(schema: StructType): StructType = {
+      throw new UnsupportedOperationException
+    }
+
+    override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)
+  }
+
+  class MyEvaluator extends Evaluator {
+
+    override def evaluate(dataset: DataFrame): Double = {
+      throw new UnsupportedOperationException
+    }
+
+    override val uid: String = "eval"
+
+    override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra)
+  }
+}
\ No newline at end of file

From dff51c78b078a36b80f025400aaa7fc5c093bf5c Mon Sep 17 00:00:00 2001
From: martinzapletal
Date: Wed, 24 Jun 2015 16:58:26 -0400
Subject: [PATCH 2/6] SPARK-8484. Naming.
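
Renames getTrainPercent to getTrainRatio so the getter matches the trainRatio
param it reads. A hypothetical sketch of the renamed accessor:

    val tvs = new TrainValidationSplit().setTrainRatio(0.8)
    tvs.getTrainRatio  // returns 0.8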
---
 .../scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala | 2 +-
 .../org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala  | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 04b098c5b159d..b47808c9f8c3b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -39,7 +39,7 @@ private[ml] trait TrainValidationSplitParams extends ValidationParams {
     "ratio between training set and validation set (>= 0 && <= 1)", ParamValidators.inRange(0, 1))
 
   /** @group getParam */
-  def getTrainPercent: Double = $(trainRatio)
+  def getTrainRatio: Double = $(trainRatio)
 
   setDefault(trainRatio -> 0.75)
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index 0f63694bcf17a..1944da15a2354 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -43,7 +43,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
       .setTrainRatio(0.5)
     val cvModel = cv.fit(dataset)
     val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
-    assert(cv.getTrainPercent === 0.5)
+    assert(cv.getTrainRatio === 0.5)
     assert(parent.getRegParam === 0.001)
     assert(parent.getMaxIter === 10)
     assert(cvModel.avgMetrics.length === lrParamMaps.length)

From d033da473382359aba59afdd8a4f8746f2cd17d2 Mon Sep 17 00:00:00 2001
From: martinzapletal
Date: Wed, 24 Jun 2015 17:07:08 -0400
Subject: [PATCH 3/6] SPARK-8484. Newlines.

---
 .../scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala  | 2 +-
 .../test/scala/org/apache/spark/ml/tuning/ValidationSuite.scala | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index ffb566be85d22..a86a8f062e122 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -101,4 +101,4 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
       cv.validateParams()
     }
   }
-}
\ No newline at end of file
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidationSuite.scala
index 02db33071ece7..f6f558b4875a4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidationSuite.scala
@@ -53,4 +53,4 @@ object ValidationSuite {
 
     override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra)
   }
-}
\ No newline at end of file
+}

From ead62123944101aa21ba25292ead0956927e1312 Mon Sep 17 00:00:00 2001
From: martinzapletal
Date: Thu, 25 Jun 2015 13:17:58 -0400
Subject: [PATCH 4/6] Import sorting.
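
Groups the imports in the usual Spark order (scala.*, then third-party such as
breeze, then org.apache.spark.*), alphabetized within each group, as the diffs
below show.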
---
 .../org/apache/spark/ml/tuning/Validation.scala |  3 ++-
 .../org/apache/spark/mllib/util/MLUtils.scala   | 17 +++++++----------
 2 files changed, 9 insertions(+), 11 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/Validation.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/Validation.scala
index 3e93f7c8097ac..d08e6aad10955 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/Validation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/Validation.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml.tuning
 
+import scala.reflect.ClassTag
+
 import org.apache.spark.Logging
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.evaluation.Evaluator
@@ -26,7 +28,6 @@ import org.apache.spark.ml.{Model, Estimator}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
 
-import scala.reflect.ClassTag
 
 /**
  * :: DeveloperApi ::
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index a5454f193718f..2d207ffbe501a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -17,23 +17,20 @@
 
 package org.apache.spark.mllib.util
 
-import org.apache.spark.util.Utils
-
 import scala.reflect.ClassTag
 
 import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
 
-import org.apache.spark.annotation.Experimental
 import org.apache.spark.SparkContext
-import org.apache.spark.rdd.RDD
-import org.apache.spark.rdd.PartitionwiseSampledRDD
-import org.apache.spark.util.random.BernoulliCellSampler
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, Vectors}
+import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg.BLAS.dot
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.StreamingContext
-import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.util.Utils
+import org.apache.spark.util.random.BernoulliCellSampler
+
 
 /**
  * Helper methods to load, save and pre-process data used in ML Lib.

From 79928814ce8dbc98f0e32b9cea7de2b13aa98ed8 Mon Sep 17 00:00:00 2001
From: martinzapletal
Date: Thu, 25 Jun 2015 14:38:44 -0400
Subject: [PATCH 5/6] SPARK-8484. PR comments
 https://github.com/apache/spark/pull/6996
---
 .../spark/ml/tuning/CrossValidator.scala      | 10 +++---
 ...nSplit.scala => TrainValidatorSplit.scala} | 32 +++++++++----------
 .../{Validation.scala => Validator.scala}     | 10 +++---
 .../ml/tuning/TrainValidationSplitSuite.scala |  6 ++--
 4 files changed, 29 insertions(+), 29 deletions(-)
 rename mllib/src/main/scala/org/apache/spark/ml/tuning/{TrainValidationSplit.scala => TrainValidatorSplit.scala} (73%)
 rename mllib/src/main/scala/org/apache/spark/ml/tuning/{Validation.scala => Validator.scala} (94%)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 9b6b2bdd35211..64ff940d73478 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.DataFrame
 /**
  * Params for [[CrossValidator]] and [[CrossValidatorModel]].
  */
-private[ml] trait CrossValidatorParams extends ValidationParams {
+private[ml] trait CrossValidatorParams extends ValidatorParams {
   /**
    * Param for number of folds for cross validation. Must be >= 2.
    * Default: 3
@@ -52,7 +52,7 @@ private[ml] trait CrossValidatorParams extends ValidationParams {
  */
 @Experimental
 class CrossValidator(uid: String)
-  extends Validation[CrossValidatorModel, CrossValidator](uid)
+  extends Validator[CrossValidatorModel, CrossValidator](uid)
   with CrossValidatorParams with Logging {
 
   def this() = this(Identifiable.randomUID("cv"))
@@ -80,11 +80,11 @@ class CrossValidator(uid: String)
       val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
       val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
       logDebug(s"Train split $splitIndex with multiple sets of parameters.")
-      val newMetrics = measureModels(trainingDataset, validationDataset, est, eval, epm, numModels)
+      val metricsPerSplit = measureModels(trainingDataset, validationDataset, est, eval, epm, numModels)
 
       var i = 0
       while (i < numModels) {
-        metrics(i) += newMetrics(i)
+        metrics(i) += metricsPerSplit(i)
         i += 1
       }
     }
@@ -110,7 +110,7 @@ class CrossValidatorModel private[ml] (
     uid: String,
     bestModel: Model[_],
     avgMetrics: Array[Double])
-  extends ValidationModel[CrossValidatorModel](uid, bestModel, avgMetrics)
+  extends ValidatorModel[CrossValidatorModel](uid, bestModel, avgMetrics)
   with CrossValidatorParams {
 
   override def copy(extra: ParamMap): CrossValidatorModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidatorSplit.scala
similarity index 73%
rename from mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
rename to mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidatorSplit.scala
index b47808c9f8c3b..25da04c068dc5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidatorSplit.scala
@@ -27,15 +27,15 @@ import org.apache.spark.sql.DataFrame
 
 /**
- * Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]].
+ * Params for [[TrainValidatorSplit]] and [[TrainValidatorSplitModel]].
  */
-private[ml] trait TrainValidationSplitParams extends ValidationParams {
+private[ml] trait TrainValidatorSplitParams extends ValidatorParams {
   /**
    * Param for ratio between train and validation data. Must be between 0 and 1.
   * Default: 0.75
    * @group param
    */
-  val trainRatio: DoubleParam = new DoubleParam(this, "numFolds",
+  val trainRatio: DoubleParam = new DoubleParam(this, "trainRatio",
     "ratio between training set and validation set (>= 0 && <= 1)", ParamValidators.inRange(0, 1))
 
   /** @group getParam */
@@ -52,9 +52,9 @@ private[ml] trait TrainValidationSplitParams extends ValidationParams {
  * Similar to CrossValidator, but only splits the set once.
  */
 @Experimental
-class TrainValidationSplit(uid: String)
-  extends Validation[TrainValidationSplitModel, TrainValidationSplit](uid)
-  with TrainValidationSplitParams with Logging {
+class TrainValidatorSplit(uid: String)
+  extends Validator[TrainValidatorSplitModel, TrainValidatorSplit](uid)
+  with TrainValidatorSplitParams with Logging {
 
   def this() = this(Identifiable.randomUID("cv"))
 
@@ -72,17 +72,17 @@ class TrainValidationSplit(uid: String)
     transformSchema(schema, logging = true)
     val sqlCtx = dataset.sqlContext
 
-    val splits = MLUtils.sample(dataset.rdd, $(trainRatio), 1)
-    val trainingDataset = sqlCtx.createDataFrame(splits._1, schema).cache()
-    val validationDataset = sqlCtx.createDataFrame(splits._2, schema).cache()
+    val (training, validation) = MLUtils.sample(dataset.rdd, $(trainRatio), 1)
+    val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
+    val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
     measureModels(trainingDataset, validationDataset, est, eval, epm, numModels)
   }
 
   override protected[ml] def createModel(
       uid: String,
       bestModel: Model[_],
-      metrics: Array[Double]): TrainValidationSplitModel = {
-    copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this))
+      metrics: Array[Double]): TrainValidatorSplitModel = {
+    copyValues(new TrainValidatorSplitModel(uid, bestModel, metrics).setParent(this))
   }
 }
 
@@ -91,15 +91,15 @@ class TrainValidationSplit(uid: String)
  * Model from train validation split.
  */
 @Experimental
-class TrainValidationSplitModel private[ml] (
+class TrainValidatorSplitModel private[ml] (
     uid: String,
     bestModel: Model[_],
     avgMetrics: Array[Double])
-  extends ValidationModel[TrainValidationSplitModel](uid, bestModel, avgMetrics)
-  with TrainValidationSplitParams {
+  extends ValidatorModel[TrainValidatorSplitModel](uid, bestModel, avgMetrics)
+  with TrainValidatorSplitParams {
 
-  override def copy(extra: ParamMap): TrainValidationSplitModel = {
-    val copied = new TrainValidationSplitModel (
+  override def copy(extra: ParamMap): TrainValidatorSplitModel = {
+    val copied = new TrainValidatorSplitModel (
       uid,
       bestModel.copy(extra).asInstanceOf[Model[_]],
       avgMetrics.clone())
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/Validation.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/Validator.scala
similarity index 94%
rename from mllib/src/main/scala/org/apache/spark/ml/tuning/Validation.scala
rename to mllib/src/main/scala/org/apache/spark/ml/tuning/Validator.scala
index d08e6aad10955..dd385f4d7c6e2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/Validation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/Validator.scala
@@ -31,10 +31,10 @@
 
 /**
  * :: DeveloperApi ::
- * Common params for [[TrainValidationSplitParams]] and [[CrossValidatorParams]].
+ * Common params for [[TrainValidatorSplitParams]] and [[CrossValidatorParams]].
  */
 @DeveloperApi
-private[ml] trait ValidationParams extends Params {
+private[ml] trait ValidatorParams extends Params {
 
   /**
    * param for the estimator to be validated
@@ -71,10 +71,10 @@ private[ml] trait ValidationParams extends Params {
  * Abstract class for validation approaches for hyper-parameter tuning.
  */
 @DeveloperApi
-private[ml] abstract class Validation[M <: Model[M], V <: Validation[M, _] : ClassTag]
+private[ml] abstract class Validator[M <: Model[M], V <: Validator[M, _] : ClassTag]
     (override val uid: String)
   extends Estimator[M]
-  with Logging with ValidationParams {
+  with Logging with ValidatorParams {
 
   def this() = this(Identifiable.randomUID("cv"))
 
@@ -170,7 +170,7 @@ private[ml] abstract class Validation[M <: Model[M], V <: Validation[M, _] : Cla
  * Model from validation.
  */
 @DeveloperApi
-private[ml] abstract class ValidationModel[M <: Model[M]] private[ml] (
+private[ml] abstract class ValidatorModel[M <: Model[M]] private[ml] (
     override val uid: String,
     val bestModel: Model[_],
     val avgMetrics: Array[Double])
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index 1944da15a2354..c8fc2397d005c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -36,7 +36,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
       .addGrid(lr.maxIter, Array(0, 10))
       .build()
     val eval = new BinaryClassificationEvaluator
-    val cv = new TrainValidationSplit()
+    val cv = new TrainValidatorSplit()
       .setEstimator(lr)
       .setEstimatorParamMaps(lrParamMaps)
       .setEvaluator(eval)
@@ -60,7 +60,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
       .build()
     val eval = new RegressionEvaluator()
-    val cv = new TrainValidationSplit()
+    val cv = new TrainValidatorSplit()
       .setEstimator(trainer)
       .setEstimatorParamMaps(lrParamMaps)
       .setEvaluator(eval)
@@ -88,7 +88,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
       .build()
 
-    val cv = new TrainValidationSplit()
+    val cv = new TrainValidatorSplit()
       .setEstimator(est)
       .setEstimatorParamMaps(paramMaps)
       .setEvaluator(eval)

From be64a131d020f642d569baef75c7f5c05420bad9 Mon Sep 17 00:00:00 2001
From: martinzapletal
Date: Thu, 25 Jun 2015 14:44:45 -0400
Subject: [PATCH 6/6] SPARK-8484. PR comments
 https://github.com/apache/spark/pull/6996
---
 .../main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 64ff940d73478..568776c6b51b4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -80,7 +80,8 @@ class CrossValidator(uid: String)
       val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
       val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
       logDebug(s"Train split $splitIndex with multiple sets of parameters.")
-      val metricsPerSplit = measureModels(trainingDataset, validationDataset, est, eval, epm, numModels)
+      val metricsPerSplit =
+        measureModels(trainingDataset, validationDataset, est, eval, epm, numModels)
 
       var i = 0
       while (i < numModels) {
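
Taken together, the series leaves the tuning API as sketched below (a
hypothetical end-to-end example using the final names from these patches;
`dataset` is an assumed DataFrame of labeled points):

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidatorSplit}

    val lr = new LogisticRegression
    val paramGrid = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.001, 1000.0))
      .addGrid(lr.maxIter, Array(0, 10))
      .build()

    // Train each param map on 75% of the data and evaluate it on the rest.
    val tvs = new TrainValidatorSplit()
      .setEstimator(lr)
      .setEstimatorParamMaps(paramGrid)
      .setEvaluator(new BinaryClassificationEvaluator)
      .setTrainRatio(0.75)

    // fit() selects the best param map and refits it on the full dataset.
    val model = tvs.fit(dataset)
    val best = model.bestModel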