[SPARK-13761] [ML] Remove remaining uses of validateParams #11790

Closed · wants to merge 2 commits
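The changes below delete the remaining validateParams() overrides and move those checks into transformSchema, backed by a shared transformSchemaImpl helper in ValidatorParams. As a rough, user-level illustration of the effect (a spark-shell style sketch against the spark.ml API of this branch; the estimator, grid values, and empty schema are illustrative, not taken from this PR):

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.types.StructType

val lr = new LogisticRegression()
val grid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.1))
  .build()

val cv = new CrossValidator()
  .setEstimator(lr)
  .setEstimatorParamMaps(grid)
  .setEvaluator(new BinaryClassificationEvaluator())

// transformSchema now also performs the per-param-map check that
// validateParams() used to do: each candidate map is applied to a copy of the
// estimator and its schema check is run. With this empty schema the call
// fails fast with an IllegalArgumentException (no label/features columns).
try cv.transformSchema(new StructType())
catch { case e: IllegalArgumentException => println(s"rejected: ${e.getMessage}") }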
@@ -131,19 +131,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
   }
 
   @Since("1.4.0")
-  override def transformSchema(schema: StructType): StructType = {
-    validateParams()
-    $(estimator).transformSchema(schema)
-  }
-
-  @Since("1.4.0")
-  override def validateParams(): Unit = {
-    super.validateParams()
-    val est = $(estimator)
-    for (paramMap <- $(estimatorParamMaps)) {
-      est.copy(paramMap).validateParams()
-    }
-  }
+  override def transformSchema(schema: StructType): StructType = transformSchemaImpl(schema)
 
   @Since("1.4.0")
   override def copy(extra: ParamMap): CrossValidator = {
@@ -331,11 +319,6 @@ class CrossValidatorModel private[ml] (
     @Since("1.5.0") val avgMetrics: Array[Double])
   extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {
 
-  @Since("1.4.0")
-  override def validateParams(): Unit = {
-    bestModel.validateParams()
-  }
-
   @Since("1.4.0")
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
@@ -344,7 +327,6 @@ class CrossValidatorModel private[ml] (
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     bestModel.transformSchema(schema)
   }
 
@@ -117,19 +117,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
   }
 
   @Since("1.5.0")
-  override def transformSchema(schema: StructType): StructType = {
-    validateParams()
-    $(estimator).transformSchema(schema)
-  }
-
-  @Since("1.5.0")
-  override def validateParams(): Unit = {
-    super.validateParams()
-    val est = $(estimator)
-    for (paramMap <- $(estimatorParamMaps)) {
-      est.copy(paramMap).validateParams()
-    }
-  }
+  override def transformSchema(schema: StructType): StructType = transformSchemaImpl(schema)
 
   @Since("1.5.0")
   override def copy(extra: ParamMap): TrainValidationSplit = {
@@ -160,11 +148,6 @@ class TrainValidationSplitModel private[ml] (
     @Since("1.5.0") val validationMetrics: Array[Double])
   extends Model[TrainValidationSplitModel] with TrainValidationSplitParams {
 
-  @Since("1.5.0")
-  override def validateParams(): Unit = {
-    bestModel.validateParams()
-  }
-
   @Since("1.5.0")
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
@@ -173,7 +156,6 @@ class TrainValidationSplitModel private[ml] (
 
   @Since("1.5.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     bestModel.transformSchema(schema)
   }
 
@@ -21,6 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.Estimator
 import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param.{Param, ParamMap, Params}
+import org.apache.spark.sql.types.StructType
 
 /**
  * :: DeveloperApi ::
@@ -31,6 +32,7 @@ private[ml] trait ValidatorParams extends Params {
 
   /**
    * param for the estimator to be validated
+   *

Review comments on the line added above:
Contributor: Not sure if this is intentional. But it's OK.
Member (author): Not intentional, just IntelliJ. I'll leave it though.

    * @group param
    */
   val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
@@ -40,6 +42,7 @@ private[ml] trait ValidatorParams extends Params {
 
   /**
    * param for estimator param maps
+   *
    * @group param
    */
   val estimatorParamMaps: Param[Array[ParamMap]] =
@@ -50,11 +53,22 @@ private[ml] trait ValidatorParams extends Params {
 
   /**
    * param for the evaluator used to select hyper-parameters that maximize the validated metric
+   *
    * @group param
    */
   val evaluator: Param[Evaluator] = new Param(this, "evaluator",
     "evaluator used to select hyper-parameters that maximize the validated metric")
 
   /** @group getParam */
   def getEvaluator: Evaluator = $(evaluator)
+
+  protected def transformSchemaImpl(schema: StructType): StructType = {
+    require($(estimatorParamMaps).nonEmpty, s"Validator requires non-empty estimatorParamMaps")
+    val firstEstimatorParamMap = $(estimatorParamMaps).head
+    val est = $(estimator)
+    for (paramMap <- $(estimatorParamMaps).tail) {
+      est.copy(paramMap).transformSchema(schema)
+    }
+    est.copy(firstEstimatorParamMap).transformSchema(schema)
+  }
 }
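A note on the helper added above, for readers skimming the diff: transformSchemaImpl checks the schema against every candidate param map by copying the estimator, and returns the output schema produced by the first map as the validator's own output schema. Below is a minimal, dependency-free Scala sketch of that pattern; MiniEstimator, Schema, and ParamMap are illustrative stand-ins, not Spark's classes:

object ValidatorPatternSketch {
  type Schema = Set[String]           // stand-in for StructType: just column names
  type ParamMap = Map[String, String] // stand-in for ml.param.ParamMap

  // A tiny "estimator" that must know its input column before it can
  // describe its output schema.
  final class MiniEstimator(val inputCol: String = "") {
    def copy(extra: ParamMap): MiniEstimator =
      new MiniEstimator(extra.getOrElse("inputCol", inputCol))

    def transformSchema(schema: Schema): Schema = {
      require(inputCol.nonEmpty, "inputCol must be set and non-empty")
      require(schema.contains(inputCol), s"missing column: $inputCol")
      schema + (inputCol + "_out")
    }
  }

  // Mirrors transformSchemaImpl: validate the schema under every param map,
  // then return the output schema produced by the first one.
  def transformSchemaImpl(est: MiniEstimator, maps: Seq[ParamMap], schema: Schema): Schema = {
    require(maps.nonEmpty, "Validator requires non-empty estimatorParamMaps")
    for (m <- maps.tail) est.copy(m).transformSchema(schema)
    est.copy(maps.head).transformSchema(schema)
  }

  def main(args: Array[String]): Unit = {
    val schema: Schema = Set("text")
    val good = Seq(Map("inputCol" -> "text"))
    println(transformSchemaImpl(new MiniEstimator(), good, schema)) // Set(text, text_out)

    val bad = good :+ Map("inputCol" -> "") // invalid: empty input column
    try transformSchemaImpl(new MiniEstimator(), bad, schema)
    catch { case e: IllegalArgumentException => println("rejected: " + e.getMessage) }
  }
}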
@@ -268,15 +268,10 @@ class ParamsSuite extends SparkFunSuite {
       solver.getParam("abc")
     }
 
-    intercept[IllegalArgumentException] {
-      solver.validateParams()
-    }
-    solver.copy(ParamMap(inputCol -> "input")).validateParams()
     solver.setInputCol("input")
     assert(solver.isSet(inputCol))
     assert(solver.isDefined(inputCol))
     assert(solver.getInputCol === "input")
-    solver.validateParams()
     intercept[IllegalArgumentException] {
       ParamMap(maxIter -> -10)
     }
@@ -34,10 +34,5 @@ class TestParams(override val uid: String) extends Params with HasHandleInvalid
 
   def clearMaxIter(): this.type = clear(maxIter)
 
-  override def validateParams(): Unit = {
-    super.validateParams()
-    require(isDefined(inputCol))
-  }
-
   override def copy(extra: ParamMap): TestParams = defaultCopy(extra)
 }
@@ -30,7 +30,7 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLog
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructField, StructType}
 
 class CrossValidatorSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -96,7 +96,7 @@ class CrossValidatorSuite
     assert(cvModel2.avgMetrics.length === lrParamMaps.length)
   }
 
-  test("validateParams should check estimatorParamMaps") {
+  test("transformSchema should check estimatorParamMaps") {
     import CrossValidatorSuite.{MyEstimator, MyEvaluator}
 
     val est = new MyEstimator("est")
@@ -110,12 +110,12 @@ class CrossValidatorSuite
       .setEstimatorParamMaps(paramMaps)
       .setEvaluator(eval)
 
-    cv.validateParams() // This should pass.
+    cv.transformSchema(new StructType()) // This should pass.
 
     val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
     cv.setEstimatorParamMaps(invalidParamMaps)
     intercept[IllegalArgumentException] {
-      cv.validateParams()
+      cv.transformSchema(new StructType())
     }
   }
 
@@ -311,14 +311,13 @@ object CrossValidatorSuite extends SparkFunSuite {
 
   class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
 
-    override def validateParams(): Unit = require($(inputCol).nonEmpty)
-
     override def fit(dataset: DataFrame): MyModel = {
       throw new UnsupportedOperationException
     }
 
     override def transformSchema(schema: StructType): StructType = {
-      throw new UnsupportedOperationException
+      require($(inputCol).nonEmpty)
+      schema
     }
 
     override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)
@@ -83,7 +83,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
     assert(cvModel2.validationMetrics.length === lrParamMaps.length)
   }
 
-  test("validateParams should check estimatorParamMaps") {
+  test("transformSchema should check estimatorParamMaps") {
     import TrainValidationSplitSuite._
 
     val est = new MyEstimator("est")
@@ -97,12 +97,12 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
       .setEstimatorParamMaps(paramMaps)
       .setEvaluator(eval)
       .setTrainRatio(0.5)
-    cv.validateParams() // This should pass.
+    cv.transformSchema(new StructType()) // This should pass.
 
     val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
    cv.setEstimatorParamMaps(invalidParamMaps)
     intercept[IllegalArgumentException] {
-      cv.validateParams()
+      cv.transformSchema(new StructType())
     }
   }
 }
@@ -113,14 +113,13 @@ object TrainValidationSplitSuite {
 
   class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
 
-    override def validateParams(): Unit = require($(inputCol).nonEmpty)
-
     override def fit(dataset: DataFrame): MyModel = {
       throw new UnsupportedOperationException
     }
 
     override def transformSchema(schema: StructType): StructType = {
-      throw new UnsupportedOperationException
+      require($(inputCol).nonEmpty)
+      schema
     }
 
     override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)