[SPARK-17704][ML][MLlib] ChiSqSelector performance improvement. #15299

Closed (2 commits)
@@ -193,7 +193,7 @@ final class ChiSqSelectorModel private[ml] (

import ChiSqSelectorModel._

- /** list of indices to select (filter). Must be ordered asc */
+ /** list of indices to select (filter). */
@Since("1.6.0")
val selectedFeatures: Array[Int] = chiSqSelector.selectedFeatures

@@ -35,14 +35,15 @@ import org.apache.spark.sql.{Row, SparkSession}
/**
* Chi Squared selector model.
*
- * @param selectedFeatures list of indices to select (filter). Must be ordered asc
+ * @param selectedFeatures list of indices to select (filter).
Contributor

I wonder if we should say "since the model requires sorted indices, selectedFeatures will be sorted" or something - just to make it clear the model does have this requirement, but takes care of that itself?

Member Author

I don't mind that, though my original reasoning behind this small change was that the sorting is wholly an implementation detail: callers don't need to promise sorted indices, or be promised them. It's a very small point, but do we even need to promise here that these are sorted?
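
To make the "implementation detail" point concrete, here is a minimal sketch, not the Spark class: `SelectorSketch` is a hypothetical stand-in that works on a plain `Array[Double]` instead of Spark's `Vector`. It assumes only what the diff shows, namely that the model keeps a privately sorted copy of the indices and exposes `selectedFeatures` as given.

```scala
object SortInternallySketch {
  // Hypothetical stand-in for ChiSqSelectorModel; not the Spark API.
  final class SelectorSketch(val selectedFeatures: Array[Int]) {
    // Sorted copy used internally; callers may pass indices in any order.
    private val filterIndices: Array[Int] = selectedFeatures.sorted

    // Keep only the values at the selected indices, in ascending index order.
    def transform(dense: Array[Double]): Array[Double] =
      filterIndices.map(i => dense(i))
  }

  def main(args: Array[String]): Unit = {
    val model = new SelectorSketch(Array(3, 0, 2))
    // prints "10.0, 12.0, 13.0"
    println(model.transform(Array(10.0, 11.0, 12.0, 13.0)).mkString(", "))
  }
}
```

In this sketch the public field keeps the caller's order while the transform relies on the sorted copy, so neither side has to make a promise about ordering.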

*/
@Since("1.3.0")
class ChiSqSelectorModel @Since("1.3.0") (
@Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable {

- require(isSorted(selectedFeatures), "Array has to be sorted asc")
+ private val filterIndices = selectedFeatures.sorted

+ @deprecated("not intended for subclasses to use", "2.1.0")
Contributor

Yeah, I also fail to see why this needs to be exposed. +1 on deprecation.

protected def isSorted(array: Array[Int]): Boolean = {
var i = 1
val len = array.length
@@ -61,17 +62,16 @@ class ChiSqSelectorModel @Since("1.3.0") (
*/
@Since("1.3.0")
override def transform(vector: Vector): Vector = {
- compress(vector, selectedFeatures)
+ compress(vector)
}

/**
* Returns a vector with features filtered.
* Preserves the order of filtered features the same as their indices are stored.
* Might be moved to Vector as .slice
* @param features vector
- * @param filterIndices indices of features to filter, must be ordered asc
*/
- private def compress(features: Vector, filterIndices: Array[Int]): Vector = {
+ private def compress(features: Vector): Vector = {
features match {
case SparseVector(size, indices, values) =>
val newSize = filterIndices.length
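
The hunk is cut off here, so the following is only a sketch of why a sparse `compress` would want `filterIndices` in ascending order. It is not the Spark source. `compressSparse` is a hypothetical helper on plain arrays, and it assumes both index arrays are ascending so that one linear merge pass is enough.

```scala
import scala.collection.mutable.ArrayBuilder

object CompressSketch {
  /**
   * Keeps the entries of a sparse vector whose index appears in sortedFilter.
   * Both `indices` and `sortedFilter` are assumed to be in ascending order.
   * Returns the new (indices, values) pair; new indices are positions in sortedFilter.
   */
  def compressSparse(
      indices: Array[Int],
      values: Array[Double],
      sortedFilter: Array[Int]): (Array[Int], Array[Double]) = {
    val newIndices = new ArrayBuilder.ofInt
    val newValues = new ArrayBuilder.ofDouble
    var i = 0 // cursor into the sparse vector's indices
    var j = 0 // cursor into the filter
    while (i < indices.length && j < sortedFilter.length) {
      if (indices(i) == sortedFilter(j)) {
        newIndices += j // position of this feature in the compressed vector
        newValues += values(i)
        i += 1
        j += 1
      } else if (indices(i) < sortedFilter(j)) {
        i += 1 // this non-zero entry is not selected
      } else {
        j += 1 // this selected feature is zero in the input
      }
    }
    (newIndices.result(), newValues.result())
  }
}
```

If the filter were unsorted, the merge would skip matches, which is why the sorted copy has to exist somewhere, whether the caller or the model provides it.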
@@ -230,23 +230,23 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
*/
@Since("1.3.0")
def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = {
- val chiSqTestResult = Statistics.chiSqTest(data)
+ val chiSqTestResult = Statistics.chiSqTest(data).zipWithIndex
val features = selectorType match {
case ChiSqSelector.KBest =>
- chiSqTestResult.zipWithIndex
+ chiSqTestResult
.sortBy { case (res, _) => -res.statistic }
.take(numTopFeatures)
case ChiSqSelector.Percentile =>
- chiSqTestResult.zipWithIndex
+ chiSqTestResult
.sortBy { case (res, _) => -res.statistic }
.take((chiSqTestResult.length * percentile).toInt)
case ChiSqSelector.FPR =>
- chiSqTestResult.zipWithIndex
- .filter{ case (res, _) => res.pValue < alpha }
+ chiSqTestResult
+ .filter { case (res, _) => res.pValue < alpha }
case errorType =>
throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
}
- val indices = features.map { case (_, indices) => indices }.sorted
+ val indices = features.map { case (_, index) => index }
new ChiSqSelectorModel(indices)
}
}
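
For readers without a Spark session at hand, the three selection branches above can be approximated on plain data. This is a sketch under assumptions: `TestResult` is a hypothetical stand-in for the per-feature chi-squared result (only `statistic` and `pValue` are modelled), and `kBest`, `percentile`, and `fpr` are illustrative names, not Spark methods.

```scala
object SelectionSketch {
  // Hypothetical stand-in for one per-feature chi-squared test result.
  final case class TestResult(statistic: Double, pValue: Double)

  // Indices of the k features with the largest statistic, in ranking order.
  def kBest(results: Array[TestResult], k: Int): Array[Int] =
    results.zipWithIndex
      .sortBy { case (res, _) => -res.statistic }
      .take(k)
      .map { case (_, index) => index }

  // Indices of the top fraction p of features by statistic.
  def percentile(results: Array[TestResult], p: Double): Array[Int] =
    kBest(results, (results.length * p).toInt)

  // Indices of features whose p-value is below alpha, in original order.
  def fpr(results: Array[TestResult], alpha: Double): Array[Int] =
    results.zipWithIndex
      .filter { case (res, _) => res.pValue < alpha }
      .map { case (_, index) => index }
}
```

With results `(4.0, 0.04)`, `(1.0, 0.5)`, `(9.0, 0.01)`, `kBest(results, 2)` yields indices `2, 0`, which are not in ascending order. That is the case the model-side `selectedFeatures.sorted` in this change is meant to absorb.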
2 changes: 1 addition & 1 deletion python/pyspark/ml/feature.py
@@ -2705,7 +2705,7 @@ class ChiSqSelectorModel(JavaModel, JavaMLReadable, JavaMLWritable):
@since("2.0.0")
def selectedFeatures(self):
"""
- List of indices to select (filter). Must be ordered asc.
+ List of indices to select (filter).
"""
return self._call_java("selectedFeatures")
