[SPARK-32310][ML][PYSPARK] ML params default value parity in feature and tuning

### What changes were proposed in this pull request?
Set default param values in the shared params traits for the feature and tuning algorithms, in both Scala and Python, rather than in each concrete estimator or transformer.
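
As an illustration of the pattern this commit applies (the names `ExampleScalerParams`, `ExampleScaler`, `ExampleScalerModel`, `lower`, and `DefaultParityCheck` are hypothetical and do not appear in the patch; only the Spark ML param API calls are real), a minimal sketch: the default is declared once in the shared params trait, so every class mixing the trait in picks up the same value.

```scala
// Illustrative sketch only: hypothetical classes, real org.apache.spark.ml.param API.
import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params}
import org.apache.spark.ml.util.Identifiable

// Shared params trait: after this change, the default lives here, next to the param.
trait ExampleScalerParams extends Params {
  final val lower = new DoubleParam(this, "lower", "lower quantile (illustrative)")
  def getLower: Double = $(lower)
  setDefault(lower -> 0.25) // single source of truth for the default
}

// Both the estimator-like and the model-like class mix the trait in,
// so they automatically agree on the default value.
class ExampleScaler(override val uid: String) extends ExampleScalerParams {
  def this() = this(Identifiable.randomUID("exampleScaler"))
  override def copy(extra: ParamMap): ExampleScaler = defaultCopy(extra)
}

class ExampleScalerModel(override val uid: String) extends ExampleScalerParams {
  override def copy(extra: ParamMap): ExampleScalerModel = defaultCopy(extra)
}

object DefaultParityCheck extends App {
  val estimator = new ExampleScaler()
  val model = new ExampleScalerModel("exampleScalerModel")
  // Neither object called set(lower, ...); both fall back to the trait default.
  println(s"estimator default lower = ${estimator.getLower}") // 0.25
  println(s"model default lower     = ${model.getLower}")     // 0.25
}
```

Running the object prints 0.25 for both, without either concrete class calling `setDefault` itself.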

### Why are the changes needed?
Give each ML estimator and its corresponding transformer/model the same default param values, and keep those defaults consistent between the Scala and Python APIs.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Existing and modified tests

Closes #29153 from huaxingao/default2.

Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: Huaxin Gao <huaxing@us.ibm.com>
huaxingao committed Aug 3, 2020
1 parent c6109ba commit bc78859
Showing 22 changed files with 274 additions and 135 deletions.
@@ -64,6 +64,8 @@ private[feature] trait ImputerParams extends Params with HasInputCol with HasInp
/** @group getParam */
def getMissingValue: Double = $(missingValue)

setDefault(strategy -> Imputer.mean, missingValue -> Double.NaN)

/** Returns the input and output column names corresponding in pair. */
private[feature] def getInOutCols(): (Array[String], Array[String]) = {
if (isSet(inputCol)) {
@@ -144,8 +146,6 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
@Since("3.0.0")
def setRelativeError(value: Double): this.type = set(relativeError, value)

setDefault(strategy -> Imputer.mean, missingValue -> Double.NaN)

override def fit(dataset: Dataset[_]): ImputerModel = {
transformSchema(dataset.schema, logging = true)
val spark = dataset.sparkSession
@@ -58,6 +58,8 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
/** @group getParam */
def getMax: Double = $(max)

setDefault(min -> 0.0, max -> 1.0)

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})")
@@ -93,8 +95,6 @@ class MinMaxScaler @Since("1.5.0") (@Since("1.5.0") override val uid: String)
@Since("1.5.0")
def this() = this(Identifiable.randomUID("minMaxScal"))

setDefault(min -> 0.0, max -> 1.0)

/** @group setParam */
@Since("1.5.0")
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -53,21 +53,20 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
"during fitting, invalid data will result in an error.",
ParamValidators.inArray(OneHotEncoder.supportedHandleInvalids))

setDefault(handleInvalid, OneHotEncoder.ERROR_INVALID)

/**
* Whether to drop the last category in the encoded vector (default: true)
* @group param
*/
@Since("2.3.0")
final val dropLast: BooleanParam =
new BooleanParam(this, "dropLast", "whether to drop the last category")
setDefault(dropLast -> true)

/** @group getParam */
@Since("2.3.0")
def getDropLast: Boolean = $(dropLast)

setDefault(handleInvalid -> OneHotEncoder.ERROR_INVALID, dropLast -> true)

/** Returns the input and output column names corresponding in pair. */
private[feature] def getInOutCols(): (Array[String], Array[String]) = {
if (isSet(inputCol)) {
@@ -46,7 +46,6 @@ private[feature] trait QuantileDiscretizerBase extends Params
val numBuckets = new IntParam(this, "numBuckets", "Number of buckets (quantiles, or " +
"categories) into which data points are grouped. Must be >= 2.",
ParamValidators.gtEq(2))
setDefault(numBuckets -> 2)

/** @group getParam */
def getNumBuckets: Int = getOrDefault(numBuckets)
@@ -82,7 +81,8 @@ private[feature] trait QuantileDiscretizerBase extends Params
"how to handle invalid entries. Options are skip (filter out rows with invalid values), " +
"error (throw an error), or keep (keep invalid values in a special additional bucket).",
ParamValidators.inArray(Bucketizer.supportedHandleInvalids))
setDefault(handleInvalid, Bucketizer.ERROR_INVALID)

setDefault(handleInvalid -> Bucketizer.ERROR_INVALID, numBuckets -> 2)
}

/**
@@ -60,7 +60,6 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol with
@Since("2.1.0")
val forceIndexLabel: BooleanParam = new BooleanParam(this, "forceIndexLabel",
"Force to index label whether it is numeric or string")
setDefault(forceIndexLabel -> false)

/** @group getParam */
@Since("2.1.0")
@@ -80,7 +79,6 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol with
"type. Options are 'skip' (filter out rows with invalid data), error (throw an error), " +
"or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
setDefault(handleInvalid, StringIndexer.ERROR_INVALID)

/**
* Param for how to order categories of a string FEATURE column used by `StringIndexer`.
@@ -113,12 +111,14 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol with
"The default value is 'frequencyDesc'. When the ordering is set to 'alphabetDesc', " +
"RFormula drops the same category as R when encoding strings.",
ParamValidators.inArray(StringIndexer.supportedStringOrderType))
setDefault(stringIndexerOrderType, StringIndexer.frequencyDesc)

/** @group getParam */
@Since("2.3.0")
def getStringIndexerOrderType: String = $(stringIndexerOrderType)

setDefault(forceIndexLabel -> false, handleInvalid -> StringIndexer.ERROR_INVALID,
stringIndexerOrderType -> StringIndexer.frequencyDesc)

protected def hasLabelCol(schema: StructType): Boolean = {
schema.map(_.name).contains($(labelCol))
}
@@ -50,8 +50,6 @@ private[feature] trait RobustScalerParams extends Params with HasInputCol with H
/** @group getParam */
def getLower: Double = $(lower)

setDefault(lower -> 0.25)

/**
* Upper quantile to calculate quantile range, shared by all features
* Default: 0.75
@@ -64,8 +62,6 @@ private[feature] trait RobustScalerParams extends Params with HasInputCol with H
/** @group getParam */
def getUpper: Double = $(upper)

setDefault(upper -> 0.75)

/**
* Whether to center the data with median before scaling.
* It will build a dense output, so take care when applying to sparse input.
@@ -78,8 +74,6 @@ private[feature] trait RobustScalerParams extends Params with HasInputCol with H
/** @group getParam */
def getWithCentering: Boolean = $(withCentering)

setDefault(withCentering -> false)

/**
* Whether to scale the data to quantile range.
* Default: true
@@ -91,7 +85,7 @@ private[feature] trait RobustScalerParams extends Params with HasInputCol with H
/** @group getParam */
def getWithScaling: Boolean = $(withScaling)

setDefault(withScaling -> true)
setDefault(withScaling -> true, lower -> 0.25, upper -> 0.75, withCentering -> false)

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
@@ -50,7 +50,6 @@ private[feature] trait SelectorParams
"Number of features that selector will select, ordered by ascending p-value. If the" +
" number of features is < numTopFeatures, then this will select all features.",
ParamValidators.gtEq(1))
setDefault(numTopFeatures -> 50)

/** @group getParam */
@Since("3.1.0")
@@ -66,7 +65,6 @@ private[feature] trait SelectorParams
final val percentile = new DoubleParam(this, "percentile",
"Percentile of features that selector will select, ordered by ascending p-value.",
ParamValidators.inRange(0, 1))
setDefault(percentile -> 0.1)

/** @group getParam */
@Since("3.1.0")
@@ -81,7 +79,6 @@ private[feature] trait SelectorParams
@Since("3.1.0")
final val fpr = new DoubleParam(this, "fpr", "The highest p-value for features to be kept.",
ParamValidators.inRange(0, 1))
setDefault(fpr -> 0.05)

/** @group getParam */
@Since("3.1.0")
@@ -96,7 +93,6 @@ private[feature] trait SelectorParams
@Since("3.1.0")
final val fdr = new DoubleParam(this, "fdr",
"The upper bound of the expected false discovery rate.", ParamValidators.inRange(0, 1))
setDefault(fdr -> 0.05)

/** @group getParam */
def getFdr: Double = $(fdr)
@@ -110,7 +106,6 @@ private[feature] trait SelectorParams
@Since("3.1.0")
final val fwe = new DoubleParam(this, "fwe",
"The upper bound of the expected family-wise error rate.", ParamValidators.inRange(0, 1))
setDefault(fwe -> 0.05)

/** @group getParam */
def getFwe: Double = $(fwe)
@@ -125,12 +120,13 @@ private[feature] trait SelectorParams
"The selector type. Supported options: numTopFeatures, percentile, fpr, fdr, fwe",
ParamValidators.inArray(Array("numTopFeatures", "percentile", "fpr", "fdr",
"fwe")))
setDefault(selectorType -> "numTopFeatures")

/** @group getParam */
@Since("3.1.0")
def getSelectorType: String = $(selectorType)

setDefault(numTopFeatures -> 50, percentile -> 0.1, fpr -> 0.05, fdr -> 0.05, fwe -> 0.05,
selectorType -> "numTopFeatures")
}

/**
@@ -56,8 +56,6 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi
"or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
ParamValidators.inArray(StringIndexer.supportedHandleInvalids))

setDefault(handleInvalid, StringIndexer.ERROR_INVALID)

/**
* Param for how to order labels of string column. The first label after ordering is assigned
* an index of 0.
@@ -80,6 +78,9 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi
s"Supported options: ${StringIndexer.supportedStringOrderType.mkString(", ")}.",
ParamValidators.inArray(StringIndexer.supportedStringOrderType))

setDefault(handleInvalid -> StringIndexer.ERROR_INVALID,
stringOrderType -> StringIndexer.frequencyDesc)

/** @group getParam */
@Since("2.3.0")
def getStringOrderType: String = $(stringOrderType)
@@ -155,7 +156,6 @@ class StringIndexer @Since("1.4.0") (
/** @group setParam */
@Since("2.3.0")
def setStringOrderType(value: String): this.type = set(stringOrderType, value)
setDefault(stringOrderType, StringIndexer.frequencyDesc)

/** @group setParam */
@Since("1.4.0")
@@ -60,8 +60,6 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
"number of categories of the feature).",
ParamValidators.inArray(VectorIndexer.supportedHandleInvalids))

setDefault(handleInvalid, VectorIndexer.ERROR_INVALID)

/**
* Threshold for the number of values a categorical feature can take.
* If a feature is found to have {@literal >} maxCategories values, then it is declared
@@ -75,10 +73,10 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
" If a feature is found to have > maxCategories values, then it is declared continuous.",
ParamValidators.gtEq(2))

setDefault(maxCategories -> 20)

/** @group getParam */
def getMaxCategories: Int = $(maxCategories)

setDefault(maxCategories -> 20, handleInvalid -> VectorIndexer.ERROR_INVALID)
}

/**
@@ -57,8 +57,6 @@ final class VectorSlicer @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
"An array of indices to select features from a vector column." +
" There can be no overlap with names.", VectorSlicer.validIndices)

setDefault(indices -> Array.emptyIntArray)

/** @group getParam */
@Since("1.5.0")
def getIndices: Array[Int] = $(indices)
@@ -79,8 +77,6 @@ final class VectorSlicer @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
"An array of feature names to select features from a vector column." +
" There can be no overlap with indices.", VectorSlicer.validNames)

setDefault(names -> Array.empty[String])

/** @group getParam */
@Since("1.5.0")
def getNames: Array[String] = $(names)
@@ -97,6 +93,8 @@ final class VectorSlicer @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
@Since("1.5.0")
def setOutputCol(value: String): this.type = set(outputCol, value)

setDefault(indices -> Array.emptyIntArray, names -> Array.empty[String])

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
// Validity checks
@@ -47,7 +47,6 @@ private[feature] trait Word2VecBase extends Params
final val vectorSize = new IntParam(
this, "vectorSize", "the dimension of codes after transforming from words (> 0)",
ParamValidators.gt(0))
setDefault(vectorSize -> 100)

/** @group getParam */
def getVectorSize: Int = $(vectorSize)
@@ -60,7 +59,6 @@ private[feature] trait Word2VecBase extends Params
final val windowSize = new IntParam(
this, "windowSize", "the window size (context words from [-window, window]) (> 0)",
ParamValidators.gt(0))
setDefault(windowSize -> 5)

/** @group expertGetParam */
def getWindowSize: Int = $(windowSize)
@@ -73,7 +71,6 @@ private[feature] trait Word2VecBase extends Params
final val numPartitions = new IntParam(
this, "numPartitions", "number of partitions for sentences of words (> 0)",
ParamValidators.gt(0))
setDefault(numPartitions -> 1)

/** @group getParam */
def getNumPartitions: Int = $(numPartitions)
@@ -86,7 +83,6 @@ private[feature] trait Word2VecBase extends Params
*/
final val minCount = new IntParam(this, "minCount", "the minimum number of times a token must " +
"appear to be included in the word2vec model's vocabulary (>= 0)", ParamValidators.gtEq(0))
setDefault(minCount -> 5)

/** @group getParam */
def getMinCount: Int = $(minCount)
@@ -101,13 +97,12 @@ private[feature] trait Word2VecBase extends Params
final val maxSentenceLength = new IntParam(this, "maxSentenceLength", "Maximum length " +
"(in words) of each sentence in the input data. Any sentence longer than this threshold will " +
"be divided into chunks up to the size (> 0)", ParamValidators.gt(0))
setDefault(maxSentenceLength -> 1000)

/** @group getParam */
def getMaxSentenceLength: Int = $(maxSentenceLength)

setDefault(stepSize -> 0.025)
setDefault(maxIter -> 1)
setDefault(vectorSize -> 100, windowSize -> 5, numPartitions -> 1, minCount -> 5,
maxSentenceLength -> 1000, stepSize -> 0.025, maxIter -> 1)

/**
* Validate and transform the input schema.
16 changes: 5 additions & 11 deletions mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -333,8 +333,6 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
"Fraction of the training data used for learning each decision tree, in range (0, 1].",
ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))

setDefault(subsamplingRate -> 1.0)

/** @group getParam */
final def getSubsamplingRate: Double = $(subsamplingRate)

@@ -386,10 +384,10 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
|| Try(value.toInt).filter(_ > 0).isSuccess
|| Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess)

setDefault(featureSubsetStrategy -> "auto")

/** @group getParam */
final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase(Locale.ROOT)

setDefault(subsamplingRate -> 1.0, featureSubsetStrategy -> "auto")
}

/**
@@ -448,8 +446,6 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
new IntParam(this, "numTrees", "Number of trees to train (at least 1)",
ParamValidators.gtEq(1))

setDefault(numTrees -> 20)

/** @group getParam */
final def getNumTrees: Int = $(numTrees)

@@ -461,11 +457,11 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
final val bootstrap: BooleanParam = new BooleanParam(this, "bootstrap",
"Whether bootstrap samples are used when building trees.")

setDefault(bootstrap -> true)

/** @group getParam */
@Since("3.0.0")
final def getBootstrap: Boolean = $(bootstrap)

setDefault(numTrees -> 20, bootstrap -> true)
}

private[ml] trait RandomForestClassifierParams
@@ -518,9 +514,7 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
"(a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.",
ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))

setDefault(maxIter -> 20, stepSize -> 0.1, validationTol -> 0.01)

setDefault(featureSubsetStrategy -> "all")
setDefault(maxIter -> 20, stepSize -> 0.1, validationTol -> 0.01, featureSubsetStrategy -> "all")

/** (private[ml]) Create a BoostingStrategy instance to use with the old API. */
private[ml] def getOldBoostingStrategy(
@@ -55,8 +55,6 @@ private[ml] trait CrossValidatorParams extends ValidatorParams {
/** @group getParam */
def getNumFolds: Int = $(numFolds)

setDefault(numFolds -> 3)

/**
* Param for the column name of user specified fold number. Once this is specified,
* `CrossValidator` won't do random k-fold split. Note that this column should be
@@ -68,7 +66,7 @@ private[ml] trait CrossValidatorParams extends ValidatorParams {

def getFoldCol: String = $(foldCol)

setDefault(foldCol, "")
setDefault(foldCol -> "", numFolds -> 3)
}

/**