diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 3d7a91dd39a71..963f81cb3ec39 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -131,19 +131,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) } @Since("1.4.0") - override def transformSchema(schema: StructType): StructType = { - validateParams() - $(estimator).transformSchema(schema) - } - - @Since("1.4.0") - override def validateParams(): Unit = { - super.validateParams() - val est = $(estimator) - for (paramMap <- $(estimatorParamMaps)) { - est.copy(paramMap).validateParams() - } - } + override def transformSchema(schema: StructType): StructType = transformSchemaImpl(schema) @Since("1.4.0") override def copy(extra: ParamMap): CrossValidator = { @@ -331,11 +319,6 @@ class CrossValidatorModel private[ml] ( @Since("1.5.0") val avgMetrics: Array[Double]) extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable { - @Since("1.4.0") - override def validateParams(): Unit = { - bestModel.validateParams() - } - @Since("1.4.0") override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) @@ -344,7 +327,6 @@ class CrossValidatorModel private[ml] ( @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { - validateParams() bestModel.transformSchema(schema) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 4587e259e8bf7..70fa5f0234753 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -117,19 +117,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St } @Since("1.5.0") - override def transformSchema(schema: StructType): StructType = { - validateParams() - $(estimator).transformSchema(schema) - } - - @Since("1.5.0") - override def validateParams(): Unit = { - super.validateParams() - val est = $(estimator) - for (paramMap <- $(estimatorParamMaps)) { - est.copy(paramMap).validateParams() - } - } + override def transformSchema(schema: StructType): StructType = transformSchemaImpl(schema) @Since("1.5.0") override def copy(extra: ParamMap): TrainValidationSplit = { @@ -160,11 +148,6 @@ class TrainValidationSplitModel private[ml] ( @Since("1.5.0") val validationMetrics: Array[Double]) extends Model[TrainValidationSplitModel] with TrainValidationSplitParams { - @Since("1.5.0") - override def validateParams(): Unit = { - bestModel.validateParams() - } - @Since("1.5.0") override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) @@ -173,7 +156,6 @@ class TrainValidationSplitModel private[ml] ( @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { - validateParams() bestModel.transformSchema(schema) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala index 553f254172410..c004644ad8e55 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala @@ -21,6 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.Estimator import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{Param, ParamMap, Params} +import org.apache.spark.sql.types.StructType /** * :: DeveloperApi :: @@ -31,6 +32,7 @@ private[ml] trait ValidatorParams extends Params { /** * param for the estimator to be validated + * * @group param */ val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection") @@ -40,6 +42,7 @@ private[ml] trait ValidatorParams extends Params { /** * param for estimator param maps + * * @group param */ val estimatorParamMaps: Param[Array[ParamMap]] = @@ -50,6 +53,7 @@ private[ml] trait ValidatorParams extends Params { /** * param for the evaluator used to select hyper-parameters that maximize the validated metric + * * @group param */ val evaluator: Param[Evaluator] = new Param(this, "evaluator", @@ -57,4 +61,14 @@ private[ml] trait ValidatorParams extends Params { /** @group getParam */ def getEvaluator: Evaluator = $(evaluator) + + protected def transformSchemaImpl(schema: StructType): StructType = { + require($(estimatorParamMaps).nonEmpty, s"Validator requires non-empty estimatorParamMaps") + val firstEstimatorParamMap = $(estimatorParamMaps).head + val est = $(estimator) + for (paramMap <- $(estimatorParamMaps).tail) { + est.copy(paramMap).transformSchema(schema) + } + est.copy(firstEstimatorParamMap).transformSchema(schema) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 748868554fe65..a3366c0e5934c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -268,15 +268,10 @@ class ParamsSuite extends SparkFunSuite { solver.getParam("abc") } - intercept[IllegalArgumentException] { - solver.validateParams() - } - solver.copy(ParamMap(inputCol -> "input")).validateParams() solver.setInputCol("input") assert(solver.isSet(inputCol)) assert(solver.isDefined(inputCol)) assert(solver.getInputCol === "input") - solver.validateParams() intercept[IllegalArgumentException] { ParamMap(maxIter -> -10) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index 9d23547f28447..7d990ce0bcfd8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -34,10 +34,5 @@ class TestParams(override val uid: String) extends Params with HasHandleInvalid def clearMaxIter(): this.type = clear(maxIter) - override def validateParams(): Unit = { - super.validateParams() - require(isDefined(inputCol)) - } - override def copy(extra: ParamMap): TestParams = defaultCopy(extra) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 56545de14bd30..7af3c6d6ede47 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLog import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StructField, StructType} class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -96,7 +96,7 @@ class CrossValidatorSuite assert(cvModel2.avgMetrics.length === lrParamMaps.length) } - test("validateParams should check estimatorParamMaps") { + test("transformSchema should check estimatorParamMaps") { import CrossValidatorSuite.{MyEstimator, MyEvaluator} val est = new MyEstimator("est") @@ -110,12 +110,12 @@ class CrossValidatorSuite .setEstimatorParamMaps(paramMaps) .setEvaluator(eval) - cv.validateParams() // This should pass. + cv.transformSchema(new StructType()) // This should pass. val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "") cv.setEstimatorParamMaps(invalidParamMaps) intercept[IllegalArgumentException] { - cv.validateParams() + cv.transformSchema(new StructType()) } } @@ -311,14 +311,13 @@ object CrossValidatorSuite extends SparkFunSuite { class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol { - override def validateParams(): Unit = require($(inputCol).nonEmpty) - override def fit(dataset: DataFrame): MyModel = { throw new UnsupportedOperationException } override def transformSchema(schema: StructType): StructType = { - throw new UnsupportedOperationException + require($(inputCol).nonEmpty) + schema } override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 5fb80091d0b4b..cf8dcefebc3aa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -83,7 +83,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext assert(cvModel2.validationMetrics.length === lrParamMaps.length) } - test("validateParams should check estimatorParamMaps") { + test("transformSchema should check estimatorParamMaps") { import TrainValidationSplitSuite._ val est = new MyEstimator("est") @@ -97,12 +97,12 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext .setEstimatorParamMaps(paramMaps) .setEvaluator(eval) .setTrainRatio(0.5) - cv.validateParams() // This should pass. + cv.transformSchema(new StructType()) // This should pass. val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "") cv.setEstimatorParamMaps(invalidParamMaps) intercept[IllegalArgumentException] { - cv.validateParams() + cv.transformSchema(new StructType()) } } } @@ -113,14 +113,13 @@ object TrainValidationSplitSuite { class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol { - override def validateParams(): Unit = require($(inputCol).nonEmpty) - override def fit(dataset: DataFrame): MyModel = { throw new UnsupportedOperationException } override def transformSchema(schema: StructType): StructType = { - throw new UnsupportedOperationException + require($(inputCol).nonEmpty) + schema } override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)