From 5682bfcc74fa0eb5d3085477c093f8625b62843c Mon Sep 17 00:00:00 2001
From: Yuhao Yang
Date: Wed, 9 Mar 2016 16:00:25 -0800
Subject: [PATCH 1/5] deprecate validateParams

---
 .../main/scala/org/apache/spark/ml/Pipeline.scala | 14 --------------
 .../main/scala/org/apache/spark/ml/Predictor.scala | 1 -
 .../scala/org/apache/spark/ml/Transformer.scala | 1 -
 .../org/apache/spark/ml/clustering/KMeans.scala | 1 -
 .../scala/org/apache/spark/ml/clustering/LDA.scala | 9 ++-------
 .../org/apache/spark/ml/feature/Bucketizer.scala | 1 -
 .../apache/spark/ml/feature/ChiSqSelector.scala | 2 --
 .../apache/spark/ml/feature/CountVectorizer.scala | 1 -
 .../org/apache/spark/ml/feature/HashingTF.scala | 1 -
 .../scala/org/apache/spark/ml/feature/IDF.scala | 1 -
 .../org/apache/spark/ml/feature/Interaction.scala | 13 ++++---------
 .../org/apache/spark/ml/feature/MaxAbsScaler.scala | 1 -
 .../org/apache/spark/ml/feature/MinMaxScaler.scala | 5 +----
 .../apache/spark/ml/feature/OneHotEncoder.scala | 1 -
 .../scala/org/apache/spark/ml/feature/PCA.scala | 2 --
 .../spark/ml/feature/QuantileDiscretizer.scala | 1 -
 .../org/apache/spark/ml/feature/RFormula.scala | 4 ----
 .../apache/spark/ml/feature/SQLTransformer.scala | 1 -
 .../apache/spark/ml/feature/StandardScaler.scala | 2 --
 .../apache/spark/ml/feature/StopWordsRemover.scala | 1 -
 .../apache/spark/ml/feature/StringIndexer.scala | 2 --
 .../apache/spark/ml/feature/VectorAssembler.scala | 1 -
 .../apache/spark/ml/feature/VectorIndexer.scala | 2 --
 .../org/apache/spark/ml/feature/VectorSlicer.scala | 8 ++------
 .../org/apache/spark/ml/feature/Word2Vec.scala | 1 -
 .../scala/org/apache/spark/ml/param/params.scala | 7 ++++---
 .../org/apache/spark/ml/recommendation/ALS.scala | 2 --
 .../ml/regression/AFTSurvivalRegression.scala | 1 -
 .../spark/ml/regression/IsotonicRegression.scala | 1 -
 .../org/apache/spark/ml/clustering/LDASuite.scala | 12 +++++++-----
 .../spark/ml/feature/MinMaxScalerSuite.scala | 10 ++++++----
 .../spark/ml/feature/VectorSlicerSuite.scala | 8 ++++----
 32 files changed, 30 insertions(+), 88 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index cbac7bbf49fc4..f4c6214a56360 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -110,12 +110,6 @@ class Pipeline @Since("1.4.0") (
   @Since("1.2.0")
   def getStages: Array[PipelineStage] = $(stages).clone()
 
-  @Since("1.4.0")
-  override def validateParams(): Unit = {
-    super.validateParams()
-    $(stages).foreach(_.validateParams())
-  }
-
   /**
    * Fits the pipeline to the input dataset with additional parameters. If a stage is an
    * [[Estimator]], its [[Estimator#fit]] method will be called on the input dataset to fit a model.
@@ -175,7 +169,6 @@ class Pipeline @Since("1.4.0") (
   @Since("1.2.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val theStages = $(stages)
     require(theStages.toSet.size == theStages.length,
       "Cannot have duplicate components in a pipeline.")
@@ -297,12 +290,6 @@ class PipelineModel private[ml] (
     this(uid, stages.asScala.toArray)
   }
 
-  @Since("1.4.0")
-  override def validateParams(): Unit = {
-    super.validateParams()
-    stages.foreach(_.validateParams())
-  }
-
   @Since("1.2.0")
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
@@ -311,7 +298,6 @@ class PipelineModel private[ml] (
 
   @Since("1.2.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur))
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 4b27ee6c5a414..ebe48700f8717 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -46,7 +46,6 @@ private[ml] trait PredictorParams extends Params
       schema: StructType,
       fitting: Boolean,
       featuresDataType: DataType): StructType = {
-    validateParams()
     // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
     SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
     if (fitting) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index fdce273193b7c..1f3325ad09ef1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -103,7 +103,6 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
   protected def validateInputType(inputType: DataType): Unit = {}
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputType = schema($(inputCol)).dataType
     validateInputType(inputType)
     if (schema.fieldNames.contains($(outputCol))) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 79332b0d02157..ab00127899edf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -81,7 +81,6 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 6304b20d544ad..0de82b49ff6f5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -263,13 +263,6 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
-    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
-    SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
-  }
-
-  @Since("1.6.0")
-  override def validateParams(): Unit = {
     if (isSet(docConcentration)) {
       if (getDocConcentration.length != 1) {
         require(getDocConcentration.length == getK, s"LDA docConcentration was of length" +
@@ -297,6 +290,8 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
           s" must be >= 1. Found value: $getTopicConcentration")
       }
     }
+    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+    SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
   }
 
   private[clustering] def getOldOptimizer: OldLDAOptimizer = getOptimizer match {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 0c75317d82703..33abc7c99d4b0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -86,7 +86,6 @@ final class Bucketizer(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
     SchemaUtils.appendColumn(schema, prepOutputField(schema))
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 4abc459f5369a..b9e9d56853605 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -88,7 +88,6 @@ final class ChiSqSelector(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
@@ -136,7 +135,6 @@ final class ChiSqSelectorModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     val newField = prepOutputField(schema)
     val outputFields = schema.fields :+ newField
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index cf151458f0917..f7d08b39a9746 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -70,7 +70,6 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit
 
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
     val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false))
     SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates)
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 8af00581f7e54..61a78d73c4347 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -69,7 +69,6 @@ class HashingTF(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[ArrayType],
       s"The input column must be ArrayType, but got $inputType.")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index cebbe5c162f79..f36cf503a0b80 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -52,7 +52,6 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
    * Validate and transform the input schema.
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
index 7d2a1da990fce..d3fe6e528f0b6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
@@ -61,13 +61,15 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer
   // optimistic schema; does not contain any ML attributes
   @Since("1.6.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
+    require(get(inputCols).isDefined, "Input cols must be defined first.")
+    require(get(outputCol).isDefined, "Output col must be defined first.")
+    require($(inputCols).length > 0, "Input cols must have non-zero length.")
+    require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.")
     StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false))
   }
 
   @Since("1.6.0")
   override def transform(dataset: DataFrame): DataFrame = {
-    validateParams()
     val inputFeatures = $(inputCols).map(c => dataset.schema(c))
     val featureEncoders = getFeatureEncoders(inputFeatures)
     val featureAttrs = getFeatureAttrs(inputFeatures)
@@ -217,13 +219,6 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer
 
   @Since("1.6.0")
   override def copy(extra: ParamMap): Interaction = defaultCopy(extra)
 
-  @Since("1.6.0")
-  override def validateParams(): Unit = {
-    require(get(inputCols).isDefined, "Input cols must be defined first.")
-    require(get(outputCol).isDefined, "Output col must be defined first.")
-    require($(inputCols).length > 0, "Input cols must have non-zero length.")
-    require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.")
-  }
 }
 
 @Since("1.6.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
index 09fad236422b7..7de5a4d5d314c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
@@ -37,7 +37,6 @@ private[feature] trait MaxAbsScalerParams extends Params with HasInputCol with H
 
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index 3b4209bbc49ad..b13684a1cb76a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -59,7 +59,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
 
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
+    require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})")
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")
@@ -69,9 +69,6 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
     StructType(outputFields)
   }
 
-  override def validateParams(): Unit = {
-    require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})")
-  }
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index e9df161c00b83..1e4028af3b69f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -66,7 +66,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputColName = $(inputCol)
     val outputColName = $(outputCol)
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 80b124f74716d..305c3d187fcbb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -77,7 +77,6 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")
@@ -133,7 +132,6 @@ class PCAModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 18896fcc4d8c1..e830d2a9adc41 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -78,7 +78,6 @@ final class QuantileDiscretizer(override val uid: String)
   def setSeed(value: Long): this.type = set(seed, value)
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
     val inputFields = schema.fields
     require(inputFields.forall(_.name != $(outputCol)),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index c21da218b36d6..ab5f4a1a9a6c4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -167,7 +167,6 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
 
   // optimistic schema; does not contain any ML attributes
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     if (hasLabelCol(schema)) {
       StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true))
     } else {
@@ -200,7 +199,6 @@ class RFormulaModel private[feature](
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     checkCanTransform(schema)
     val withFeatures = pipelineModel.transformSchema(schema)
     if (hasLabelCol(withFeatures)) {
@@ -263,7 +261,6 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     StructType(schema.fields.filter(col => !columnsToPrune.contains(col.name)))
   }
 
@@ -312,7 +309,6 @@ private class VectorAttributeRewriter(
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     StructType(
       schema.fields.filter(_.name != vectorCol) ++
         schema.fields.filter(_.name == vectorCol))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index af6494b234cee..e0ca45b9a6190 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -74,7 +74,6 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
 
   @Since("1.6.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val sc = SparkContext.getOrCreate()
     val sqlContext = SQLContext.getOrCreate(sc)
     val dummyRDD = sc.parallelize(Seq(Row.empty))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 9952d3bc9f1a5..26ee8e1bf1669 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -94,7 +94,6 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")
@@ -144,7 +143,6 @@ class StandardScalerModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 0d4c968633295..0a0e0b0960c88 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -145,7 +145,6 @@ class StopWordsRemover(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.sameType(ArrayType(StringType)),
       s"Input type must be ArrayType(StringType) but got $inputType.")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 7dd794b9d7d1d..c579a0d68ec6c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -39,7 +39,6 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
 
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputColName = $(inputCol)
     val inputDataType = schema(inputColName).dataType
     require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
@@ -275,7 +274,6 @@ class IndexToString private[ml] (override val uid: String)
   final def getLabels: Array[String] = $(labels)
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputColName = $(inputCol)
     val inputDataType = schema(inputColName).dataType
     require(inputDataType.isInstanceOf[NumericType],
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 7ff5ad143f80b..957e8e7a5983c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -106,7 +106,6 @@ class VectorAssembler(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputColNames = $(inputCols)
     val outputColName = $(outputCol)
     val inputDataTypes = inputColNames.map(name => schema(name).dataType)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 5c11760fab9b2..bf4aef2a74c71 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -126,7 +126,6 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     // We do not transfer feature metadata since we do not know what types of features we will
     // produce in transform().
     val dataType = new VectorUDT
@@ -355,7 +354,6 @@ class VectorIndexerModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val dataType = new VectorUDT
     require(isDefined(inputCol),
       s"VectorIndexerModel requires input column parameter: $inputCol")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
index 300d63bd3a0da..b60e82de00c08 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -89,11 +89,6 @@ final class VectorSlicer(override val uid: String)
   /** @group setParam */
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
-  override def validateParams(): Unit = {
-    require($(indices).length > 0 || $(names).length > 0,
-      s"VectorSlicer requires that at least one feature be selected.")
-  }
-
   override def transform(dataset: DataFrame): DataFrame = {
     // Validity checks
     transformSchema(dataset.schema)
@@ -139,7 +134,8 @@ final class VectorSlicer(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
+    require($(indices).length > 0 || $(names).length > 0,
+      s"VectorSlicer requires that at least one feature be selected.")
     SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
 
     if (schema.fieldNames.contains($(outputCol))) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 3d3c7bdc2f4d8..95bae1c8a3127 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -92,7 +92,6 @@ private[feature] trait Word2VecBase extends Params
    * Validate and transform the input schema.
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index d7d6c0f5fa16e..77abe94771762 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -58,9 +58,8 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
   /**
    * Assert that the given value is valid for this parameter.
    *
-   * Note: Parameter checks involving interactions between multiple parameters should be
-   * implemented in [[Params.validateParams()]]. Checks for input/output columns should be
-   * implemented in [[org.apache.spark.ml.PipelineStage.transformSchema()]].
+   * Note: Parameter checks involving interactions between multiple parameters and input/output
+   * columns should be implemented in [[org.apache.spark.ml.PipelineStage.transformSchema()]].
    *
    * DEVELOPERS: This method is only called by [[ParamPair]], which means that all parameters
    * should be specified via [[ParamPair]].
@@ -549,7 +548,9 @@ trait Params extends Identifiable with Serializable {
    * Parameter value checks which do not depend on other parameters are handled by
    * [[Param.validate()]]. This method does not handle input/output column parameters;
    * those are checked during schema validation.
+   * @deprecated All the checks should be merged into transformSchema
    */
+  @deprecated("All the checks should be merged into transformSchema", "2.0.0")
   def validateParams(): Unit = {
     // Do nothing by default. Override to handle Param interactions.
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index dacdac9a1df16..f3bc9f095a8c1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -162,7 +162,6 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
     SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
     val ratingType = schema($(ratingCol)).dataType
@@ -220,7 +219,6 @@ class ALSModel private[ml] (
 
   @Since("1.3.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
     SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
     SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index e4339d67b928d..0901642d392d3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -99,7 +99,6 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
   protected def validateAndTransformSchema(
       schema: StructType,
       fitting: Boolean): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     if (fitting) {
       SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index 36b006c10e1fb..20a09982014ca 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -105,7 +105,6 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
   protected[ml] def validateAndTransformSchema(
       schema: StructType,
       fitting: Boolean): StructType = {
-    validateParams()
     if (fitting) {
       SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
       if (hasWeightCol) {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index a3a8f65eac176..dd3f4c6e5391d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -138,16 +138,18 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
       new LDA().setTopicConcentration(-1.1)
     }
 
-    // validateParams()
-    lda.validateParams()
+    val dummyDF = sqlContext.createDataFrame(Seq(
+      (1, Vectors.dense(1.0, 2.0)))).toDF("id", "features")
+    // validate parameters
+    lda.transformSchema(dummyDF.schema)
     lda.setDocConcentration(1.1)
-    lda.validateParams()
+    lda.transformSchema(dummyDF.schema)
     lda.setDocConcentration(Range(0, lda.getK).map(_ + 2.0).toArray)
-    lda.validateParams()
+    lda.transformSchema(dummyDF.schema)
     lda.setDocConcentration(Range(0, lda.getK - 1).map(_ + 2.0).toArray)
     withClue("LDA docConcentration validity check failed for bad array length") {
       intercept[IllegalArgumentException] {
-        lda.validateParams()
+        lda.transformSchema(dummyDF.schema)
       }
     }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
index 035bfc07b684d..87206c777e352 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
@@ -57,13 +57,15 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
 
   test("MinMaxScaler arguments max must be larger than min") {
     withClue("arguments max must be larger than min") {
+      val dummyDF = sqlContext.createDataFrame(Seq(
+        (1, Vectors.dense(1.0, 2.0)))).toDF("id", "feature")
       intercept[IllegalArgumentException] {
-        val scaler = new MinMaxScaler().setMin(10).setMax(0)
-        scaler.validateParams()
+        val scaler = new MinMaxScaler().setMin(10).setMax(0).setInputCol("feature")
+        scaler.transformSchema(dummyDF.schema)
       }
       intercept[IllegalArgumentException] {
-        val scaler = new MinMaxScaler().setMin(0).setMax(0)
-        scaler.validateParams()
+        val scaler = new MinMaxScaler().setMin(0).setMax(0).setInputCol("feature")
+        scaler.transformSchema(dummyDF.schema)
       }
     }
   }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
index 94191e5df383b..6bb4678dc5f97 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
@@ -21,21 +21,21 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructField, StructType}
 
 class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   test("params") {
-    val slicer = new VectorSlicer
+    val slicer = new VectorSlicer().setInputCol("feature")
     ParamsSuite.checkParams(slicer)
     assert(slicer.getIndices.length === 0)
     assert(slicer.getNames.length === 0)
     withClue("VectorSlicer should not have any features selected by default") {
       intercept[IllegalArgumentException] {
-        slicer.validateParams()
+        slicer.transformSchema(StructType(Seq(StructField("feature", new VectorUDT, true))))
       }
     }
   }

From 3ccf90a7857075ae6f4b59e7bf3a178521eb0604 Mon Sep 17 00:00:00 2001
From: Yuhao Yang
Date: Thu, 10 Mar 2016 00:30:59 -0800
Subject: [PATCH 2/5] fix glm

---
 .../GeneralizedLinearRegression.scala | 48 +++++++++++++------
 1 file changed, 33 insertions(+), 15 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index de1dff9421145..87b62500fc1ce 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.regression
 
 import breeze.stats.distributions.{Gaussian => GD}
 import org.apache.hadoop.fs.Path
 
+import org.apache.spark.sql.types.{StructType, DataType}
 import org.apache.spark.{Logging, SparkException}
 import org.apache.spark.annotation.{Experimental, Since}
@@ -45,7 +46,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
    * to be used in the model.
    * Supported options: "gaussian", "binomial", "poisson" and "gamma".
    * Default is "gaussian".
-   * @group param
+   *
+   * @group param
    */
   @Since("2.0.0")
   final val family: Param[String] = new Param(this, "family",
@@ -61,7 +63,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
    * Param for the name of link function which provides the relationship
    * between the linear predictor and the mean of the distribution function.
    * Supported options: "identity", "log", "inverse", "logit", "probit", "cloglog" and "sqrt".
-   * @group param
+   *
+   * @group param
    */
   @Since("2.0.0")
   final val link: Param[String] = new Param(this, "link", "The name of link function " +
@@ -77,7 +80,9 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
   import GeneralizedLinearRegression._
 
   @Since("2.0.0")
-  override def validateParams(): Unit = {
+  override def validateAndTransformSchema(schema: StructType,
+      fitting: Boolean,
+      featuresDataType: DataType): StructType = {
     if ($(solver) == "irls") {
       setDefault(maxIter -> 25)
     }
@@ -86,6 +91,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
       Family.fromName($(family)) -> Link.fromName($(link))), "Generalized Linear Regression " +
         s"with ${$(family)} family does not support ${$(link)} link function.")
     }
+    super.validateAndTransformSchema(schema, fitting, featuresDataType)
   }
 }
@@ -117,7 +123,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
   /**
    * Sets the value of param [[family]].
    * Default is "gaussian".
-   * @group setParam
+   *
+   * @group setParam
    */
   @Since("2.0.0")
   def setFamily(value: String): this.type = set(family, value)
@@ -125,7 +132,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
 
   /**
    * Sets the value of param [[link]].
-   * @group setParam
+   *
+   * @group setParam
    */
   @Since("2.0.0")
   def setLink(value: String): this.type = set(link, value)
@@ -133,7 +141,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
   /**
    * Sets if we should fit the intercept.
    * Default is true.
-   * @group setParam
+   *
+   * @group setParam
   */
   @Since("2.0.0")
   def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
@@ -141,7 +150,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
   /**
    * Sets the maximum number of iterations.
    * Default is 25 if the solver algorithm is "irls".
-   * @group setParam
+   *
+   * @group setParam
   */
   @Since("2.0.0")
   def setMaxIter(value: Int): this.type = set(maxIter, value)
@@ -150,7 +160,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
    * Sets the convergence tolerance of iterations.
    * Smaller value will lead to higher accuracy with the cost of more iterations.
    * Default is 1E-6.
-   * @group setParam
+   *
+   * @group setParam
   */
   @Since("2.0.0")
   def setTol(value: Double): this.type = set(tol, value)
@@ -159,7 +170,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
   /**
    * Sets the regularization parameter.
    * Default is 0.0.
-   * @group setParam
+   *
+   * @group setParam
   */
   @Since("2.0.0")
   def setRegParam(value: Double): this.type = set(regParam, value)
@@ -169,7 +181,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
    * Sets the value of param [[weightCol]].
    * If this is not set or empty, we treat all instance weights as 1.0.
    * Default is empty, so all instances have weight one.
-   * @group setParam
+   *
+   * @group setParam
   */
   @Since("2.0.0")
   def setWeightCol(value: String): this.type = set(weightCol, value)
@@ -178,7 +191,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
   /**
    * Sets the solver algorithm used for optimization.
    * Currently only support "irls" which is also the default solver.
-   * @group setParam
+   *
+   * @group setParam
   */
   @Since("2.0.0")
   def setSolver(value: String): this.type = set(solver, value)
@@ -305,7 +319,8 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
 
   /**
    * A description of the error distribution to be used in the model.
-   * @param name the name of the family.
+   *
+   * @param name the name of the family.
   */
   private[ml] abstract class Family(val name: String) extends Serializable {
 
@@ -326,7 +341,8 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
 
     /**
     * Gets the [[Family]] object from its name.
-     * @param name family name: "gaussian", "binomial", "poisson" or "gamma".
+     *
+     * @param name family name: "gaussian", "binomial", "poisson" or "gamma".
     */
     def fromName(name: String): Family = {
       name match {
@@ -447,7 +463,8 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
    * A description of the link function to be used in the model.
    * The link function provides the relationship between the linear predictor
    * and the mean of the distribution function.
-   * @param name the name of link function.
+   *
+   * @param name the name of link function.
   */
   private[ml] abstract class Link(val name: String) extends Serializable {
 
@@ -465,7 +482,8 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
 
     /**
     * Gets the [[Link]] object from its name.
-     * @param name link name: "identity", "logit", "log",
+     *
+     * @param name link name: "identity", "logit", "log",
      *             "inverse", "probit", "cloglog" or "sqrt".
     */
     def fromName(name: String): Link = {

From 29c5a2fef94347aff95a22f1f75d940a7bccf0a1 Mon Sep 17 00:00:00 2001
From: Yuhao Yang
Date: Thu, 10 Mar 2016 09:46:58 -0800
Subject: [PATCH 3/5] style fix

---
 .../spark/ml/regression/GeneralizedLinearRegression.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 87b62500fc1ce..a44f31bf9c274 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -19,7 +19,6 @@ package org.apache.spark.ml.regression
 
 import breeze.stats.distributions.{Gaussian => GD}
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.sql.types.{StructType, DataType}
 import org.apache.spark.{Logging, SparkException}
 import org.apache.spark.annotation.{Experimental, Since}
@@ -33,6 +32,7 @@ import org.apache.spark.mllib.linalg.{BLAS, Vector}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{DataType, StructType}
 
 /**
  * Params for Generalized Linear Regression.

From 91f72a9d89f2bd7c83adb71573305b6559d6e71e Mon Sep 17 00:00:00 2001
From: Yuhao Yang
Date: Wed, 16 Mar 2016 13:38:17 -0700
Subject: [PATCH 4/5] comments and format

---
 .../src/main/scala/org/apache/spark/ml/param/params.scala | 4 ++--
 .../spark/ml/regression/GeneralizedLinearRegression.scala | 7 ++++---
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 77abe94771762..69ada2d02f637 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -548,9 +548,9 @@ trait Params extends Identifiable with Serializable {
    * Parameter value checks which do not depend on other parameters are handled by
    * [[Param.validate()]]. This method does not handle input/output column parameters;
    * those are checked during schema validation.
-   * @deprecated All the checks should be merged into transformSchema
+   * @deprecated Will be removed in 2.1.0. All the checks should be merged into transformSchema
   */
-  @deprecated("All the checks should be merged into transformSchema", "2.0.0")
+  @deprecated("Will be removed in 2.1.0. Checks should be merged into transformSchema.", "2.0.0")
   def validateParams(): Unit = {
     // Do nothing by default. Override to handle Param interactions.
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 8dc2c9a3013c4..f7a1c0e04857a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -46,8 +46,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
    * to be used in the model.
    * Supported options: "gaussian", "binomial", "poisson" and "gamma".
    * Default is "gaussian".
-   * 
-   * @group param
+   *
+   * @group param
   */
   @Since("2.0.0")
   final val family: Param[String] = new Param(this, "family",
@@ -80,7 +80,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
   import GeneralizedLinearRegression._
 
   @Since("2.0.0")
-  override def validateAndTransformSchema(schema: StructType,
+  override def validateAndTransformSchema(
+      schema: StructType,
       fitting: Boolean,
       featuresDataType: DataType): StructType = {
     if ($(solver) == "irls") {

From f3480440e7b7ac590f710709ca6ca0e13399ca11 Mon Sep 17 00:00:00 2001
From: Yuhao Yang
Date: Wed, 16 Mar 2016 16:17:04 -0700
Subject: [PATCH 5/5] revert some unintentional comment change

---
 .../GeneralizedLinearRegression.scala | 40 ++++++------------
 1 file changed, 13 insertions(+), 27 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index f7a1c0e04857a..46ba5589ff85f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -46,7 +46,6 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
    * to be used in the model.
    * Supported options: "gaussian", "binomial", "poisson" and "gamma".
    * Default is "gaussian".
-   *
    * @group param
   */
   @Since("2.0.0")
@@ -63,8 +62,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
    * Param for the name of link function which provides the relationship
    * between the linear predictor and the mean of the distribution function.
    * Supported options: "identity", "log", "inverse", "logit", "probit", "cloglog" and "sqrt".
-   *
-   * @group param
+   * @group param
   */
   @Since("2.0.0")
   final val link: Param[String] = new Param(this, "link", "The name of link function " +
@@ -124,8 +122,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
   /**
    * Sets the value of param [[family]].
    * Default is "gaussian".
-   *
-   * @group setParam
+   * @group setParam
   */
   @Since("2.0.0")
   def setFamily(value: String): this.type = set(family, value)
@@ -133,8 +130,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
 
   /**
    * Sets the value of param [[link]].
-   *
-   * @group setParam
+   * @group setParam
   */
   @Since("2.0.0")
   def setLink(value: String): this.type = set(link, value)
@@ -142,8 +138,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
   /**
    * Sets if we should fit the intercept.
    * Default is true.
-   *
-   * @group setParam
+   * @group setParam
   */
   @Since("2.0.0")
   def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
@@ -151,8 +146,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
   /**
    * Sets the maximum number of iterations.
    * Default is 25 if the solver algorithm is "irls".
-   *
-   * @group setParam
+   * @group setParam
   */
   @Since("2.0.0")
   def setMaxIter(value: Int): this.type = set(maxIter, value)
@@ -161,8 +155,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
    * Sets the convergence tolerance of iterations.
    * Smaller value will lead to higher accuracy with the cost of more iterations.
    * Default is 1E-6.
-   *
-   * @group setParam
+   * @group setParam
   */
   @Since("2.0.0")
   def setTol(value: Double): this.type = set(tol, value)
@@ -171,8 +164,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
   /**
    * Sets the regularization parameter.
    * Default is 0.0.
-   *
-   * @group setParam
+   * @group setParam
   */
   @Since("2.0.0")
   def setRegParam(value: Double): this.type = set(regParam, value)
@@ -182,8 +174,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
    * Sets the value of param [[weightCol]].
    * If this is not set or empty, we treat all instance weights as 1.0.
    * Default is empty, so all instances have weight one.
-   *
-   * @group setParam
+   * @group setParam
   */
   @Since("2.0.0")
   def setWeightCol(value: String): this.type = set(weightCol, value)
@@ -192,8 +183,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
   /**
    * Sets the solver algorithm used for optimization.
    * Currently only support "irls" which is also the default solver.
-   *
-   * @group setParam
+   * @group setParam
   */
   @Since("2.0.0")
   def setSolver(value: String): this.type = set(solver, value)
@@ -337,8 +327,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
 
   /**
    * A description of the error distribution to be used in the model.
-   *
-   * @param name the name of the family.
+   * @param name the name of the family.
   */
   private[ml] abstract class Family(val name: String) extends Serializable {
 
@@ -375,8 +364,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
 
     /**
     * Gets the [[Family]] object from its name.
-     * @param name family name: "gaussian", "binomial", "poisson" or "gamma".
+     * @param name family name: "gaussian", "binomial", "poisson" or "gamma".
     */
     def fromName(name: String): Family = {
       name match {
@@ -555,8 +543,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
    * A description of the link function to be used in the model.
    * The link function provides the relationship between the linear predictor
    * and the mean of the distribution function.
-   *
-   * @param name the name of link function.
+   * @param name the name of link function.
   */
   private[ml] abstract class Link(val name: String) extends Serializable {
 
@@ -574,8 +561,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
 
     /**
     * Gets the [[Link]] object from its name.
-     *
-     * @param name link name: "identity", "logit", "log",
+     * @param name link name: "identity", "logit", "log",
      *             "inverse", "probit", "cloglog" or "sqrt".
     */
     def fromName(name: String): Link = {
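
Note: third-party PipelineStage implementations that still override the now-deprecated
validateParams() face the same migration the built-in stages make in PATCH 1. The sketch
below is illustrative only and not part of this series: ClampedScaler is a made-up stage,
and it is written as if it lived in the org.apache.spark.ml.feature package, since
SchemaUtils and the shared HasInputCol/HasOutputCol traits are package-private. The point
is the pattern every file above follows: multi-parameter interaction checks move out of
validateParams() and into transformSchema(), which transform() already calls, so a
Pipeline rejects a misconfigured stage during its single schema-validation pass.

    package org.apache.spark.ml.feature

    import org.apache.spark.ml.Transformer
    import org.apache.spark.ml.param.{DoubleParam, ParamMap}
    import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
    import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.{DoubleType, StructType}

    // Hypothetical stage: multiplies a Double input column by a factor that must
    // lie inside a (lower, upper) range -- a two-parameter interaction of the same
    // shape as MinMaxScaler's min/max.
    class ClampedScaler(override val uid: String) extends Transformer
      with HasInputCol with HasOutputCol {

      def this() = this(Identifiable.randomUID("clampedScaler"))

      val factor = new DoubleParam(this, "factor", "multiplier applied to the input column")
      val lower = new DoubleParam(this, "lower", "smallest allowed factor")
      val upper = new DoubleParam(this, "upper", "largest allowed factor")
      setDefault(factor -> 1.0, lower -> 0.0, upper -> 10.0)

      def setInputCol(value: String): this.type = set(inputCol, value)
      def setOutputCol(value: String): this.type = set(outputCol, value)
      def setFactor(value: Double): this.type = set(factor, value)
      def setLower(value: Double): this.type = set(lower, value)
      def setUpper(value: Double): this.type = set(upper, value)

      // Before this series, these interaction checks would sit in an override of
      // validateParams(); now they run on every schema-validation pass, before
      // any data is touched.
      override def transformSchema(schema: StructType): StructType = {
        require($(lower) < $(upper),
          s"lower (${$(lower)}) must be less than upper (${$(upper)})")
        require($(factor) >= $(lower) && $(factor) <= $(upper),
          s"factor (${$(factor)}) must lie in [${$(lower)}, ${$(upper)}]")
        SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
        SchemaUtils.appendColumn(schema, $(outputCol), DoubleType)
      }

      override def transform(dataset: DataFrame): DataFrame = {
        transformSchema(dataset.schema) // parameter and column checks happen here
        dataset.withColumn($(outputCol), col($(inputCol)) * $(factor))
      }

      override def copy(extra: ParamMap): ClampedScaler = defaultCopy(extra)
    }

Built this way, new ClampedScaler().setLower(10.0).setUpper(0.0) fails with
IllegalArgumentException as soon as transformSchema runs, matching the updated
MinMaxScalerSuite and VectorSlicerSuite tests above.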