
[SPARK-13761] [ML] Deprecate validateParams #11620

Closed
wants to merge 9 commits
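This PR deprecates Params.validateParams() (slated for removal in 2.1.0) and moves each stage's parameter checks, including checks that involve interactions between multiple parameters, into PipelineStage.transformSchema(), which already performs input/output column validation.

A minimal sketch of the resulting pattern (a hypothetical ScaleStage, not part of this diff; plain Scala over Spark SQL types):

```scala
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

// Hypothetical stage illustrating the convention this PR adopts: every check,
// including interactions between parameters, runs in transformSchema rather
// than in a separate validateParams() call.
class ScaleStage(min: Double, max: Double, inputCol: String, outputCol: String) {
  def transformSchema(schema: StructType): StructType = {
    // Multi-parameter interaction check, formerly the job of validateParams():
    require(min < max, s"min ($min) must be strictly less than max ($max)")
    // Input/output column checks, which always belonged in transformSchema:
    require(schema(inputCol).dataType == DoubleType,
      s"Input column $inputCol must be of DoubleType")
    require(!schema.fieldNames.contains(outputCol),
      s"Output column $outputCol already exists")
    StructType(schema.fields :+ StructField(outputCol, DoubleType, nullable = false))
  }
}
```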
14 changes: 0 additions & 14 deletions mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -110,12 +110,6 @@ class Pipeline @Since("1.4.0") (
@Since("1.2.0")
def getStages: Array[PipelineStage] = $(stages).clone()

@Since("1.4.0")
override def validateParams(): Unit = {
super.validateParams()
$(stages).foreach(_.validateParams())
}

/**
* Fits the pipeline to the input dataset with additional parameters. If a stage is an
* [[Estimator]], its [[Estimator#fit]] method will be called on the input dataset to fit a model.
@@ -175,7 +169,6 @@ class Pipeline @Since("1.4.0") (

@Since("1.2.0")
override def transformSchema(schema: StructType): StructType = {
- validateParams()
val theStages = $(stages)
require(theStages.toSet.size == theStages.length,
"Cannot have duplicate components in a pipeline.")
@@ -297,12 +290,6 @@ class PipelineModel private[ml] (
this(uid, stages.asScala.toArray)
}

@Since("1.4.0")
override def validateParams(): Unit = {
super.validateParams()
stages.foreach(_.validateParams())
}

@Since("1.2.0")
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
@@ -311,7 +298,6 @@ class PipelineModel private[ml] (

@Since("1.2.0")
override def transformSchema(schema: StructType): StructType = {
- validateParams()
stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur))
}

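Because Pipeline.transformSchema and PipelineModel.transformSchema fold each stage's transformSchema over the input schema, per-stage checks still run when a pipeline is validated or fit; no recursive validateParams() pass is needed. A hedged usage sketch (df is an assumed input DataFrame):

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{VectorAssembler, VectorSlicer}

val assembler = new VectorAssembler()
  .setInputCols(Array("a", "b"))
  .setOutputCol("features")
// Invalid on purpose: neither indices nor names are set on the slicer.
val slicer = new VectorSlicer()
  .setInputCol("features")
  .setOutputCol("sliced")
val pipeline = new Pipeline().setStages(Array(assembler, slicer))
// pipeline.fit(df) now fails inside VectorSlicer.transformSchema with
// "VectorSlicer requires that at least one feature be selected." where it
// previously surfaced through the recursive validateParams() calls.
```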
1 change: 0 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -46,7 +46,6 @@ private[ml] trait PredictorParams extends Params
schema: StructType,
fitting: Boolean,
featuresDataType: DataType): StructType = {
- validateParams()
// TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
if (fitting) {
1 change: 0 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -103,7 +103,6 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
protected def validateInputType(inputType: DataType): Unit = {}

override def transformSchema(schema: StructType): StructType = {
- validateParams()
val inputType = schema($(inputCol)).dataType
validateInputType(inputType)
if (schema.fieldNames.contains($(outputCol))) {
mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -81,7 +81,6 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
- validateParams()
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
}
9 changes: 2 additions & 7 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -263,13 +263,6 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
- validateParams()
- SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
- SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
- }
-
- @Since("1.6.0")
- override def validateParams(): Unit = {
if (isSet(docConcentration)) {
if (getDocConcentration.length != 1) {
require(getDocConcentration.length == getK, s"LDA docConcentration was of length" +
@@ -297,6 +290,8 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
s" must be >= 1. Found value: $getTopicConcentration")
}
}
+ SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+ SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
}

private[clustering] def getOldOptimizer: OldLDAOptimizer = getOptimizer match {
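With this change, LDA's docConcentration and topicConcentration requirements run as part of validateAndTransformSchema. An assumed illustration (df is a placeholder DataFrame):

```scala
import org.apache.spark.ml.clustering.LDA

// Invalid on purpose: with k = 3, a docConcentration vector of length 2 is
// neither a single value nor of length k.
val lda = new LDA().setK(3).setDocConcentration(Array(0.1, 0.2))
// lda.fit(df) now raises the requirement failure from
// validateAndTransformSchema rather than from a separate validateParams() pass.
```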
mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -86,7 +86,6 @@ final class Bucketizer(override val uid: String)
}

override def transformSchema(schema: StructType): StructType = {
- validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
SchemaUtils.appendColumn(schema, prepOutputField(schema))
}
mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -88,7 +88,6 @@ final class ChiSqSelector(override val uid: String)
}

override def transformSchema(schema: StructType): StructType = {
- validateParams()
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
@@ -136,7 +135,6 @@ final class ChiSqSelectorModel private[ml] (
}

override def transformSchema(schema: StructType): StructType = {
- validateParams()
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
val newField = prepOutputField(schema)
val outputFields = schema.fields :+ newField
mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -70,7 +70,6 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
- validateParams()
val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false))
SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates)
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -69,7 +69,6 @@ class HashingTF(override val uid: String)
}

override def transformSchema(schema: StructType): StructType = {
- validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[ArrayType],
s"The input column must be ArrayType, but got $inputType.")
1 change: 0 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -52,7 +52,6 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
* Validate and transform the input schema.
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
- validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
}
mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
@@ -61,13 +61,15 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer
// optimistic schema; does not contain any ML attributes
@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
- validateParams()
+ require(get(inputCols).isDefined, "Input cols must be defined first.")
+ require(get(outputCol).isDefined, "Output col must be defined first.")
+ require($(inputCols).length > 0, "Input cols must have non-zero length.")
+ require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.")
StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false))
}

@Since("1.6.0")
override def transform(dataset: DataFrame): DataFrame = {
- validateParams()
val inputFeatures = $(inputCols).map(c => dataset.schema(c))
val featureEncoders = getFeatureEncoders(inputFeatures)
val featureAttrs = getFeatureAttrs(inputFeatures)
@@ -217,13 +219,6 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer
@Since("1.6.0")
override def copy(extra: ParamMap): Interaction = defaultCopy(extra)

@Since("1.6.0")
override def validateParams(): Unit = {
require(get(inputCols).isDefined, "Input cols must be defined first.")
require(get(outputCol).isDefined, "Output col must be defined first.")
require($(inputCols).length > 0, "Input cols must have non-zero length.")
require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.")
}
}

@Since("1.6.0")
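Interaction's column-parameter requirements now live in transformSchema, so schema validation alone surfaces them. A hedged illustration:

```scala
import org.apache.spark.ml.feature.Interaction
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

// Invalid on purpose: inputCols is never set.
val interaction = new Interaction().setOutputCol("interacted")
val schema = StructType(Seq(StructField("a", DoubleType), StructField("b", DoubleType)))
// interaction.transformSchema(schema) now fails with
// "Input cols must be defined first."
```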
mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
@@ -37,7 +37,6 @@ private[feature] trait MaxAbsScalerParams extends Params with HasInputCol with H

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
- validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -59,7 +59,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
- validateParams()
+ require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})")
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
@@ -69,9 +69,6 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
StructType(outputFields)
}

- override def validateParams(): Unit = {
-   require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})")
- }
}

/**
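A hedged example of the relocated min/max interaction check (df is a placeholder DataFrame with a vector column "features"):

```scala
import org.apache.spark.ml.feature.MinMaxScaler

// Invalid on purpose: min must be strictly less than max.
val scaler = new MinMaxScaler()
  .setInputCol("features")
  .setOutputCol("scaled")
  .setMin(2.0)
  .setMax(1.0)
// scaler.fit(df) now fails in validateAndTransformSchema with
// "The specified min(2.0) is larger or equal to max(1.0)"
```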
mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -66,7 +66,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
def setOutputCol(value: String): this.type = set(outputCol, value)

override def transformSchema(schema: StructType): StructType = {
- validateParams()
val inputColName = $(inputCol)
val outputColName = $(outputCol)

2 changes: 0 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -77,7 +77,6 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
}

override def transformSchema(schema: StructType): StructType = {
- validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
@@ -133,7 +132,6 @@ class PCAModel private[ml] (
}

override def transformSchema(schema: StructType): StructType = {
- validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -78,7 +78,6 @@ final class QuantileDiscretizer(override val uid: String)
def setSeed(value: Long): this.type = set(seed, value)

override def transformSchema(schema: StructType): StructType = {
- validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
val inputFields = schema.fields
require(inputFields.forall(_.name != $(outputCol)),
mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -167,7 +167,6 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R

// optimistic schema; does not contain any ML attributes
override def transformSchema(schema: StructType): StructType = {
- validateParams()
if (hasLabelCol(schema)) {
StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true))
} else {
@@ -200,7 +199,6 @@ class RFormulaModel private[feature](
}

override def transformSchema(schema: StructType): StructType = {
- validateParams()
checkCanTransform(schema)
val withFeatures = pipelineModel.transformSchema(schema)
if (hasLabelCol(withFeatures)) {
@@ -263,7 +261,6 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
}

override def transformSchema(schema: StructType): StructType = {
- validateParams()
StructType(schema.fields.filter(col => !columnsToPrune.contains(col.name)))
}

@@ -312,7 +309,6 @@ private class VectorAttributeRewriter(
}

override def transformSchema(schema: StructType): StructType = {
- validateParams()
StructType(
schema.fields.filter(_.name != vectorCol) ++
schema.fields.filter(_.name == vectorCol))
mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -74,7 +74,6 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor

@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
- validateParams()
val sc = SparkContext.getOrCreate()
val sqlContext = SQLContext.getOrCreate(sc)
val dummyRDD = sc.parallelize(Seq(Row.empty))
mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -94,7 +94,6 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
}

override def transformSchema(schema: StructType): StructType = {
- validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
@@ -144,7 +143,6 @@ class StandardScalerModel private[ml] (
}

override def transformSchema(schema: StructType): StructType = {
- validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -145,7 +145,6 @@ class StopWordsRemover(override val uid: String)
}

override def transformSchema(schema: StructType): StructType = {
- validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.sameType(ArrayType(StringType)),
s"Input type must be ArrayType(StringType) but got $inputType.")
mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -39,7 +39,6 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
- validateParams()
val inputColName = $(inputCol)
val inputDataType = schema(inputColName).dataType
require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
@@ -275,7 +274,6 @@ class IndexToString private[ml] (override val uid: String)
final def getLabels: Array[String] = $(labels)

override def transformSchema(schema: StructType): StructType = {
- validateParams()
val inputColName = $(inputCol)
val inputDataType = schema(inputColName).dataType
require(inputDataType.isInstanceOf[NumericType],
mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -106,7 +106,6 @@ class VectorAssembler(override val uid: String)
}

override def transformSchema(schema: StructType): StructType = {
- validateParams()
val inputColNames = $(inputCols)
val outputColName = $(outputCol)
val inputDataTypes = inputColNames.map(name => schema(name).dataType)
mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -126,7 +126,6 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod
}

override def transformSchema(schema: StructType): StructType = {
- validateParams()
// We do not transfer feature metadata since we do not know what types of features we will
// produce in transform().
val dataType = new VectorUDT
@@ -355,7 +354,6 @@ class VectorIndexerModel private[ml] (
}

override def transformSchema(schema: StructType): StructType = {
- validateParams()
val dataType = new VectorUDT
require(isDefined(inputCol),
s"VectorIndexerModel requires input column parameter: $inputCol")
mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -89,11 +89,6 @@ final class VectorSlicer(override val uid: String)
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)

- override def validateParams(): Unit = {
-   require($(indices).length > 0 || $(names).length > 0,
-     s"VectorSlicer requires that at least one feature be selected.")
- }

override def transform(dataset: DataFrame): DataFrame = {
// Validity checks
transformSchema(dataset.schema)
@@ -139,7 +134,8 @@ final class VectorSlicer(override val uid: String)
}

override def transformSchema(schema: StructType): StructType = {
- validateParams()
+ require($(indices).length > 0 || $(names).length > 0,
+   s"VectorSlicer requires that at least one feature be selected.")
SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)

if (schema.fieldNames.contains($(outputCol))) {
mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -92,7 +92,6 @@ private[feature] trait Word2VecBase extends Params
* Validate and transform the input schema.
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
- validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
}
7 changes: 4 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -58,9 +58,8 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
/**
* Assert that the given value is valid for this parameter.
*
- * Note: Parameter checks involving interactions between multiple parameters should be
- * implemented in [[Params.validateParams()]]. Checks for input/output columns should be
- * implemented in [[org.apache.spark.ml.PipelineStage.transformSchema()]].
+ * Note: Parameter checks involving interactions between multiple parameters and input/output
+ * columns should be implemented in [[org.apache.spark.ml.PipelineStage.transformSchema()]].
*
* DEVELOPERS: This method is only called by [[ParamPair]], which means that all parameters
* should be specified via [[ParamPair]].
@@ -549,7 +548,9 @@ trait Params extends Identifiable with Serializable {
* Parameter value checks which do not depend on other parameters are handled by
* [[Param.validate()]]. This method does not handle input/output column parameters;
* those are checked during schema validation.
+ * @deprecated Will be removed in 2.1.0. All the checks should be merged into transformSchema
*/
@deprecated("Will be removed in 2.1.0. Checks should be merged into transformSchema.", "2.0.0")
Member
It looks like this now causes a number of deprecation warnings in the Spark code, which we're trying to get rid of. Can most of the remaining usages be transformed to not use this method?

Member
Apologies! I should have checked the Jenkins logs. I'll send a clean-up PR.

Contributor Author
Hi @jkbradley, have you already started on this? Sorry for the trouble. I didn't remove the ones in CrossValidator and TrainValidationSplit because I think it can be handy to run some validation before submitting the paramMap. Let me know if I can help in any way.

Member
No problem; I just sent a PR for it.

def validateParams(): Unit = {
// Do nothing by default. Override to handle Param interactions.
}
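For stages that override validateParams today, the migration is mechanical: move the body into transformSchema. A minimal, self-contained sketch (hypothetical ClipTransformer, using the DataFrame-era Transformer API this diff shows):

```scala
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{DoubleParam, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

// Hypothetical custom stage showing the migration: the interaction check that
// previously lived in an overridden validateParams() now runs in transformSchema.
class ClipTransformer(override val uid: String) extends Transformer {
  def this() = this(Identifiable.randomUID("clip"))

  val lower = new DoubleParam(this, "lower", "lower clip bound")
  val upper = new DoubleParam(this, "upper", "upper clip bound")
  def setBounds(lo: Double, hi: Double): this.type = { set(lower, lo); set(upper, hi) }

  override def transformSchema(schema: StructType): StructType = {
    // Formerly the body of validateParams():
    require($(lower) < $(upper), s"lower (${$(lower)}) must be < upper (${$(upper)})")
    StructType(schema.fields :+ StructField("clipped", DoubleType, nullable = false))
  }

  override def transform(dataset: DataFrame): DataFrame = {
    transformSchema(dataset.schema) // runs all checks up front
    dataset // actual clipping elided; validation is the point of this sketch
  }

  override def copy(extra: ParamMap): ClipTransformer = defaultCopy(extra)
}
```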
mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -162,7 +162,6 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
- validateParams()
SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
val ratingType = schema($(ratingCol)).dataType
@@ -220,7 +219,6 @@ class ALSModel private[ml] (

@Since("1.3.0")
override def transformSchema(schema: StructType): StructType = {
- validateParams()
SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)