[SPARK-30938][ML][MLLIB] BinaryClassificationMetrics optimization #27682

Closed · wants to merge 1 commit
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.evaluation
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.mllib.evaluation.binary._
-import org.apache.spark.rdd.{RDD, UnionRDD}
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}

/**
@@ -101,10 +101,19 @@ class BinaryClassificationMetrics @Since("3.0.0") (
@Since("1.0.0")
def roc(): RDD[(Double, Double)] = {
val rocCurve = createCurve(FalsePositiveRate, Recall)
val sc = confusions.context
val first = sc.makeRDD(Seq((0.0, 0.0)), 1)
val last = sc.makeRDD(Seq((1.0, 1.0)), 1)
new UnionRDD[(Double, Double)](sc, Seq(first, rocCurve, last))
val numParts = rocCurve.getNumPartitions
rocCurve.mapPartitionsWithIndex { case (pid, iter) =>
if (numParts == 1) {
require(pid == 0)
Iterator.single((0.0, 0.0)) ++ iter ++ Iterator.single((1.0, 1.0))
} else if (pid == 0) {
Iterator.single((0.0, 0.0)) ++ iter
} else if (pid == numParts - 1) {
iter ++ Iterator.single((1.0, 1.0))
} else {
iter
}
}
}

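The change above replaces a UnionRDD of three RDDs (a one-element RDD for the (0.0, 0.0) endpoint, the curve itself, and a one-element RDD for (1.0, 1.0)) with a single mapPartitionsWithIndex pass that splices the endpoints into the first and last partitions. The same technique in isolation, as a minimal sketch: it assumes a live SparkContext, and `withEndpoints` is a hypothetical helper name, not part of the patch.

```scala
import org.apache.spark.rdd.RDD

// Hypothetical helper: prepend `head` to the first partition and append `last`
// to the final partition, without allocating any extra RDDs or partitions.
def withEndpoints(
    rdd: RDD[(Double, Double)],
    head: (Double, Double),
    last: (Double, Double)): RDD[(Double, Double)] = {
  val numParts = rdd.getNumPartitions
  rdd.mapPartitionsWithIndex { case (pid, iter) =>
    // On a 1-partition RDD both branches below apply to pid 0.
    val prefix = if (pid == 0) Iterator.single(head) else Iterator.empty
    val suffix = if (pid == numParts - 1) Iterator.single(last) else Iterator.empty
    prefix ++ iter ++ suffix
  }
}
```

Unlike the old union-based version, the result keeps the curve's partition count, so no extra scheduler tasks are launched just to carry the two synthetic endpoints.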
/**
@@ -124,7 +133,13 @@
  def pr(): RDD[(Double, Double)] = {
    val prCurve = createCurve(Recall, Precision)
    val (_, firstPrecision) = prCurve.first()
-    confusions.context.parallelize(Seq((0.0, firstPrecision)), 1).union(prCurve)
+    prCurve.mapPartitionsWithIndex { case (pid, iter) =>
+      if (pid == 0) {
+        Iterator.single((0.0, firstPrecision)) ++ iter
+      } else {
+        iter
+      }
+    }
  }

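The same idea applies to pr(): the synthetic (0.0, firstPrecision) point is spliced into partition 0 instead of being unioned in as a separate one-partition RDD. (Note that prCurve.first() is an action in both versions, so the small job to fetch the first precision is unchanged.) A quick illustration of the partitioning difference, as a sketch assuming a live SparkContext `sc` and made-up curve values:

```scala
// A toy 4-partition "curve".
val curve = sc.parallelize((1 to 8).map(i => (i / 8.0, 1.0 - i / 8.0)), 4)

// Old shape: union with a one-element, one-partition RDD adds a partition
// (and a task) that exists only to carry the synthetic first point.
val unioned = sc.parallelize(Seq((0.0, 1.0)), 1).union(curve)
assert(unioned.getNumPartitions == 5)

// New shape: splice the point into partition 0; the partition count is unchanged.
val spliced = curve.mapPartitionsWithIndex { case (pid, iter) =>
  if (pid == 0) Iterator.single((0.0, 1.0)) ++ iter else iter
}
assert(spliced.getNumPartitions == 4)
```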
/**
@@ -182,28 +197,40 @@
        val countsSize = counts.count()
        // Group the iterator into chunks of about countsSize / numBins points,
        // so that the resulting number of bins is about numBins
-        var grouping = countsSize / numBins
+        val grouping = countsSize / numBins
        if (grouping < 2) {
          // numBins was more than half of the size; no real point in down-sampling to bins
          logInfo(s"Curve is too small ($countsSize) for $numBins bins to be useful")
          counts
        } else {
-          if (grouping >= Int.MaxValue) {
-            logWarning(
-              s"Curve too large ($countsSize) for $numBins bins; capping at ${Int.MaxValue}")
-            grouping = Int.MaxValue
-          }
-          counts.mapPartitions(_.grouped(grouping.toInt).map { pairs =>
-            // The score of the combined point will be just the last one's score, which is also
-            // the minimal in each chunk since all scores are already sorted in descending.
-            val lastScore = pairs.last._1
-            // The combined point will contain all counts in this chunk. Thus, calculated
-            // metrics (like precision, recall, etc.) on its score (or so-called threshold) are
-            // the same as those without sampling.
-            val agg = new BinaryLabelCounter()
-            pairs.foreach(pair => agg += pair._2)
-            (lastScore, agg)
-          })
+          counts.mapPartitions { iter =>
+            if (iter.hasNext) {
+              var score = Double.NaN
+              var agg = new BinaryLabelCounter()
+              var cnt = 0L
+              iter.flatMap { pair =>
+                score = pair._1
+                agg += pair._2
+                cnt += 1
+                if (cnt == grouping) {
+                  // The score of the combined point will be just the last one's score,
+                  // which is also the minimal in each chunk since all scores are already
+                  // sorted in descending.
+                  // The combined point will contain all counts in this chunk. Thus, calculated
+                  // metrics (like precision, recall, etc.) on its score (or so-called threshold)
+                  // are the same as those without sampling.
+                  val ret = (score, agg)
+                  agg = new BinaryLabelCounter()
+                  cnt = 0
+                  Some(ret)
+                } else None
+              } ++ {
+                if (cnt > 0) {
+                  Iterator.single((score, agg))
+                } else Iterator.empty
+              }
+            } else Iterator.empty
+          }
        }
      }

Author's review comment, attached to the removed `if (grouping >= Int.MaxValue)` branch:
Iterator.grouped(size: Int) does not support a grouping larger than Int.MaxValue, so the old
code capped it there. After this change, BinaryClassificationMetrics can handle a grouping
larger than Int.MaxValue.

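The author's comment above pinpoints the core of this hunk: scala.collection.Iterator.grouped(size: Int) only accepts an Int, so the old code had to cap grouping at Int.MaxValue, and for extremely long curves the resulting bins could be coarser than requested. The new flatMap-based accumulator counts with a Long instead. The same pattern on a plain Scala iterator, as a minimal sketch; `groupedReduce` is an illustrative name, not an API in the patch.

```scala
// Reduce an iterator in chunks of `size` elements, where `size` may exceed
// Int.MaxValue; a trailing partial chunk is flushed at the end, mirroring
// the `if (cnt > 0)` tail in the patch.
def groupedReduce[A](iter: Iterator[A], size: Long)(merge: (A, A) => A): Iterator[A] = {
  require(size >= 1L)
  var acc: Option[A] = None
  var cnt = 0L
  iter.flatMap { a =>
    acc = acc.map(merge(_, a)).orElse(Some(a))
    cnt += 1L
    if (cnt == size) {
      val out = acc.get
      acc = None
      cnt = 0L
      Iterator.single(out)
    } else Iterator.empty
  } ++ acc.iterator // `++` is evaluated lazily, so it sees the final state of `acc`
}

// Example: chunks of two, summed.
assert(groupedReduce(Iterator(1, 2, 3, 4, 5), 2L)(_ + _).toList == List(3, 7, 5))
```

In the patch itself, the chunk's representative score is carried in a separate `score` var, because the last (smallest) score of each descending-sorted chunk is the threshold that the merged counts correspond to.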