[SPARK-28968][ML] Add HasNumFeatures in the scala side
### What changes were proposed in this pull request?
Add `HasNumFeatures` on the Scala side, with `1 << 18` (262144) as the default value.

### Why are the changes needed?
`HasNumFeatures` is already present on the Python side, so it is reasonable to keep the two sides in sync.
I did not find any other similar place.

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
Existing test suites
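
For illustration only (not part of the commit): a minimal sketch, assuming a build of `spark-mllib` that includes this change, showing that the shared default leaves behavior unchanged and matches the Python side:

```scala
import org.apache.spark.ml.feature.{FeatureHasher, HashingTF}

// Both transformers now inherit numFeatures from the shared HasNumFeatures
// trait; the default is still 1 << 18 = 262144, so existing pipelines are
// unaffected.
val tf = new HashingTF()
assert(tf.getNumFeatures == (1 << 18))

val hasher = new FeatureHasher()
assert(hasher.getNumFeatures == 262144)
```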

Closes #25671 from zhengruifeng/add_HasNumFeatures.

Authored-by: zhengruifeng <ruifengz@foxmail.com>
Signed-off-by: zhengruifeng <ruifengz@foxmail.com>
zhengruifeng committed Sep 6, 2019
1 parent cb0cddf commit 4664a08
Showing 7 changed files with 43 additions and 42 deletions.
mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala (2 additions & 17 deletions)

@@ -23,7 +23,7 @@ import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, StringArrayParam}
-import org.apache.spark.ml.param.shared.{HasInputCols, HasOutputCol}
+import org.apache.spark.ml.param.shared.{HasInputCols, HasNumFeatures, HasOutputCol}
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
 import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF}
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
@@ -83,7 +83,7 @@ import org.apache.spark.util.collection.OpenHashMap
  */
 @Since("2.3.0")
 class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transformer
-  with HasInputCols with HasOutputCol with DefaultParamsWritable {
+  with HasInputCols with HasOutputCol with HasNumFeatures with DefaultParamsWritable {

   @Since("2.3.0")
   def this() = this(Identifiable.randomUID("featureHasher"))
@@ -99,21 +99,6 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transformer
   val categoricalCols = new StringArrayParam(this, "categoricalCols",
     "numeric columns to treat as categorical")

-  /**
-   * Number of features. Should be greater than 0.
-   * (default = 2^18^)
-   * @group param
-   */
-  @Since("2.3.0")
-  val numFeatures = new IntParam(this, "numFeatures", "number of features (> 0)",
-    ParamValidators.gt(0))
-
-  setDefault(numFeatures -> (1 << 18))
-
-  /** @group getParam */
-  @Since("2.3.0")
-  def getNumFeatures: Int = $(numFeatures)
-
   /** @group setParam */
   @Since("2.3.0")
   def setNumFeatures(value: Int): this.type = set(numFeatures, value)
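
For context (not part of the diff): a usage sketch with made-up column names, showing that the public FeatureHasher API is unchanged; only the param definition, its default, and the getter moved into the shared trait:

```scala
import org.apache.spark.ml.feature.FeatureHasher

// setNumFeatures stays on FeatureHasher itself so it can keep returning
// this.type for call chaining; the param and getter now come from
// HasNumFeatures.
val hasher = new FeatureHasher()
  .setInputCols("real", "bool", "stringNum")  // hypothetical column names
  .setOutputCol("features")
  .setNumFeatures(1 << 10)  // overrides the shared 1 << 18 default
```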
mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala (4 additions & 16 deletions)
@@ -24,7 +24,7 @@ import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.param.shared.{HasInputCol, HasNumFeatures, HasOutputCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF}
 import org.apache.spark.sql.{DataFrame, Dataset}
@@ -43,7 +43,8 @@ import org.apache.spark.util.VersionUtils.majorMinorVersion
  */
 @Since("1.2.0")
 class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
-  extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable {
+  extends Transformer with HasInputCol with HasOutputCol with HasNumFeatures
+  with DefaultParamsWritable {

   private var hashFunc: Any => Int = FeatureHasher.murmur3Hash

@@ -58,15 +59,6 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   @Since("1.4.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)

-  /**
-   * Number of features. Should be greater than 0.
-   * (default = 2^18^)
-   * @group param
-   */
-  @Since("1.2.0")
-  val numFeatures = new IntParam(this, "numFeatures", "number of features (> 0)",
-    ParamValidators.gt(0))
-
   /**
    * Binary toggle to control term frequency counts.
    * If true, all non-zero counts are set to 1. This is useful for discrete probabilistic
@@ -79,11 +71,7 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
     "This is useful for discrete probabilistic models that model binary events rather " +
     "than integer counts")

-  setDefault(numFeatures -> (1 << 18), binary -> false)
-
-  /** @group getParam */
-  @Since("1.2.0")
-  def getNumFeatures: Int = $(numFeatures)
+  setDefault(binary -> false)

   /** @group setParam */
   @Since("1.2.0")
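
Likewise for HashingTF, a hedged sketch (assumes an active SparkSession named `spark`; the data is made up) of the transformer after the change; note that `binary` keeps its own `setDefault` since the `numFeatures` default now comes from the trait:

```scala
import org.apache.spark.ml.feature.HashingTF

// Toy input: each row holds a bag of words.
val df = spark.createDataFrame(Seq(
  (0, Seq("a", "b", "a")),
  (1, Seq("c", "a"))
)).toDF("id", "words")

val tf = new HashingTF()
  .setInputCol("words")
  .setOutputCol("tf")
  .setNumFeatures(1 << 10)  // small hash space for the example; default is 262144

tf.transform(df).show(truncate = false)
```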
mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala (4 additions & 3 deletions)

@@ -63,6 +63,8 @@ private[shared] object SharedParamsCodeGen {
     ParamDesc[Array[String]]("inputCols", "input column names"),
     ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),
     ParamDesc[Array[String]]("outputCols", "output column names"),
+    ParamDesc[Int]("numFeatures", "Number of features. Should be greater than 0",
+      Some("262144"), isValid = "ParamValidators.gt(0)"),
     ParamDesc[Int]("checkpointInterval", "set checkpoint interval (>= 1) or " +
       "disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed " +
       "every 10 iterations. Note: this setting will be ignored if the checkpoint directory " +
@@ -95,9 +97,8 @@
       Some("false"), isExpertParam = true),
     ParamDesc[String]("loss", "the loss function to be optimized", finalFields = false),
     ParamDesc[String]("distanceMeasure", "The distance measure. Supported options: 'euclidean'" +
-      " and 'cosine'", Some("org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN"),
-      isValid = "(value: String) => " +
-        "org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)"),
+      " and 'cosine'", Some("\"euclidean\""),
+      isValid = "ParamValidators.inArray(Array(\"euclidean\", \"cosine\"))"),
     ParamDesc[String]("validationIndicatorCol", "name of the column that indicates whether " +
       "each row is for training or for validation. False indicates training; true indicates " +
       "validation.")
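
The `isValid = "ParamValidators.gt(0)"` string above becomes the validator on the generated param, so invalid values are rejected at set time, as before. A small sketch (illustrative, not from the patch):

```scala
import org.apache.spark.ml.feature.HashingTF

// Param values are validated when set: ParamValidators.gt(0) rejects
// non-positive sizes with an IllegalArgumentException.
val tf = new HashingTF()
try {
  tf.setNumFeatures(0) // violates "Should be greater than 0"
} catch {
  case e: IllegalArgumentException => println(s"rejected: ${e.getMessage}")
}
```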
mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala (22 additions & 3 deletions)

@@ -274,6 +274,25 @@ trait HasOutputCols extends Params {
   final def getOutputCols: Array[String] = $(outputCols)
 }

+/**
+ * Trait for shared param numFeatures (default: 262144). This trait may be changed or
+ * removed between minor versions.
+ */
+@DeveloperApi
+trait HasNumFeatures extends Params {
+
+  /**
+   * Param for Number of features. Should be greater than 0.
+   * @group param
+   */
+  final val numFeatures: IntParam = new IntParam(this, "numFeatures", "Number of features. Should be greater than 0", ParamValidators.gt(0))
+
+  setDefault(numFeatures, 262144)
+
+  /** @group getParam */
+  final def getNumFeatures: Int = $(numFeatures)
+}
+
 /**
  * Trait for shared param checkpointInterval. This trait may be changed or
  * removed between minor versions.
@@ -506,7 +525,7 @@ trait HasLoss extends Params {
 }

 /**
- * Trait for shared param distanceMeasure (default: org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN). This trait may be changed or
+ * Trait for shared param distanceMeasure (default: "euclidean"). This trait may be changed or
  * removed between minor versions.
  */
 @DeveloperApi
@@ -516,9 +535,9 @@ trait HasDistanceMeasure extends Params {
   * Param for The distance measure. Supported options: 'euclidean' and 'cosine'.
   * @group param
   */
-  final val distanceMeasure: Param[String] = new Param[String](this, "distanceMeasure", "The distance measure. Supported options: 'euclidean' and 'cosine'", (value: String) => org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value))
+  final val distanceMeasure: Param[String] = new Param[String](this, "distanceMeasure", "The distance measure. Supported options: 'euclidean' and 'cosine'", ParamValidators.inArray(Array("euclidean", "cosine")))

-  setDefault(distanceMeasure, org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN)
+  setDefault(distanceMeasure, "euclidean")

   /** @group getParam */
   final def getDistanceMeasure: String = $(distanceMeasure)
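
Since the generated trait is a public `@DeveloperApi`, third-party `Params` implementations can reuse it too. A hypothetical sketch (`MyHasherParams` is invented for illustration):

```scala
import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.ml.param.shared.HasNumFeatures
import org.apache.spark.ml.util.Identifiable

// Mixing in HasNumFeatures provides the numFeatures param, its gt(0)
// validator, the 262144 default, and getNumFeatures.
class MyHasherParams(override val uid: String) extends Params with HasNumFeatures {
  def this() = this(Identifiable.randomUID("myHasherParams"))
  override def copy(extra: ParamMap): MyHasherParams = defaultCopy(extra)
}
```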
project/MimaExcludes.scala (6 additions & 0 deletions)
(The `numFeatures` param and its getter become `final` members inherited from the shared trait, so MiMa's binary-compatibility checker needs explicit exclusions for them.)

@@ -200,6 +200,12 @@ object MimaExcludes {
     ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
     ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),

+    // [SPARK-28968][ML] Add HasNumFeatures in the scala side
+    ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.FeatureHasher.getNumFeatures"),
+    ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.FeatureHasher.numFeatures"),
+    ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.HashingTF.getNumFeatures"),
+    ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.HashingTF.numFeatures"),
+
     // [SPARK-25908][CORE][SQL] Remove old deprecated items in Spark 3
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.BarrierTaskContext.isRunningLocally"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskContext.isRunningLocally"),
python/pyspark/ml/param/_shared_params_code_gen.py (2 additions & 1 deletion)
@@ -120,7 +120,8 @@ def get$Name(self):
     ("inputCols", "input column names.", None, "TypeConverters.toListString"),
     ("outputCol", "output column name.", "self.uid + '__output'", "TypeConverters.toString"),
     ("outputCols", "output column names.", None, "TypeConverters.toListString"),
-    ("numFeatures", "number of features.", None, "TypeConverters.toInt"),
+    ("numFeatures", "Number of features. Should be greater than 0.", "262144",
+     "TypeConverters.toInt"),
     ("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " +
      "E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: " +
      "this setting will be ignored if the checkpoint directory is not set in the SparkContext.",
python/pyspark/ml/param/shared.py (3 additions & 2 deletions)
@@ -281,13 +281,14 @@ def getOutputCols(self):

 class HasNumFeatures(Params):
     """
-    Mixin for param numFeatures: number of features.
+    Mixin for param numFeatures: Number of features. Should be greater than 0.
     """

-    numFeatures = Param(Params._dummy(), "numFeatures", "number of features.", typeConverter=TypeConverters.toInt)
+    numFeatures = Param(Params._dummy(), "numFeatures", "Number of features. Should be greater than 0.", typeConverter=TypeConverters.toInt)

     def __init__(self):
         super(HasNumFeatures, self).__init__()
+        self._setDefault(numFeatures=262144)

     def setNumFeatures(self, value):
         """
