[SPARK-11560] [MLLIB] Optimize KMeans implementation / remove 'runs' #15342
Conversation
CC @yanboliang

Test build #66310 has finished for PR 15342 at commit
 }.toArray)
 private def initRandom(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = {
   val sample = data.takeSample(false, k, new XORShiftRandom(this.seed).nextInt())
   sample.map(v => new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm))
sample.map(_.toDense)
?
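The suggestion is that the hand-rolled dense copy is exactly what toDense already does. A self-contained sketch, using a stand-in for MLlib's private VectorWithNorm (all names below are illustrative, not the real class):

```scala
// Stand-in for MLlib's private VectorWithNorm; just enough to show the refactor.
class VectorWithNorm(val vector: Array[Double], val norm: Double) {
  def toDense: VectorWithNorm = new VectorWithNorm(vector.clone(), norm)
}

val sample = Array(new VectorWithNorm(Array(1.0, 2.0), math.sqrt(5.0)))

// Before: spell out the dense copy field by field.
val verbose = sample.map(v => new VectorWithNorm(v.vector.clone(), v.norm))
// After: let toDense do the same work.
val concise = sample.map(_.toDense)

assert(verbose(0).vector.sameElements(concise(0).vector))
assert(verbose(0).norm == concise(0).norm)
```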
 }
 /**
- * Initialize `runs` sets of cluster centers using the k-means|| algorithm by Bahmani et al.
+ * Initialize set of cluster centers using the k-means|| algorithm by Bahmani et al.
nit: "Initialize a set"
 }
 /**
- * Initialize `runs` sets of cluster centers using the k-means|| algorithm by Bahmani et al.
+ * Initialize set of cluster centers using the k-means|| algorithm by Bahmani et al.
  * (Bahmani et al., Scalable K-Means++, VLDB 2012). This is a variant of k-means++ that tries
  * to find with dissimilar cluster centers by starting with a random center and then doing
"to find with dissimilar" -> "to find dissimilar" (while we're here)
 // Initialize empty centers and point costs.
-val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm])
-var costs = data.map(_ => Array.fill(runs)(Double.PositiveInfinity))
+var costs = data.map(_ => Double.PositiveInfinity)
 // Initialize each run's first center to a random point.
"Initialize the first center to a random point."
@@ -558,6 +475,7 @@ object KMeans {
    * Trains a k-means model using specified parameters and the default values for unspecified.
    */
   @Since("0.8.0")
+  @deprecated("Use train method without 'runs'", "2.1.0")
There are two other train signatures that use runs, but have not been marked as deprecated.
Yes, though there's no alternative to those with the same arguments. We could add another overload and deprecate the others. I'm OK with that too, just felt a little gross to add yet more.
I think we should add them for completeness, and deprecate all overloads using runs.
 costs.unpersist(blocking = false)
 bcNewCentersList.foreach(_.destroy(false))

 // Finally, we might have a set of more than k candidate centers for each run; weigh each
+if (centers.size <= k) {
+  return centers.toArray
I prefer to avoid the return keyword and just put the other code under the else here. But it is a small preference.
 val costs = Array.fill(numRuns)(0.0)

 var activeRuns = new ArrayBuffer[Int] ++ (0 until numRuns)
 var active = true
I find converged to be more intuitive, but not a strong preference.
Looking good. My main concern is that now you can have the following:

scala> model.getK
res2: Int = 3

scala> model.clusterCenters.length
res3: Int = 1

We could set the model
That's right.

Otherwise updated to reflect all the other review comments, thanks.
Test build #66381 has finished for PR 15342 at commit
This is what SPARK-3261 is about. It's a corner case, to be sure. To me it seems like having duplicate centroids is worse because the model loses some of its meaning: points may be arbitrarily assigned to one or the other of two identical centroids. Of the 3 possible behaviors, it looks like we have all 3 on the table:

I suppose I prefer the new behavior, but I can't say I feel that strongly. I guess matching scikit has some value.
What are the circumstances that lead to duplicate cluster centers, other than the obvious one of having less data than requested centers? The comment on the original JIRA said training 1.3M points asking for 10k clusters only returned 1k centers.
Good question. I think he's saying that it returned 1K centers after this change. It's a good point that this would also speed things up considerably, because computing the distance to duplicate centroids is all superfluous work.
@srowen That is not the impression that I got from "I just ran clustering on 1.3M points, asking for 10,000 clusters. This clustering run resulted in 1019 unique cluster centers." @derrickburns Can you clarify a bit here? Also, could you tell us the nature of the data that was used for your clustering?
I'm wondering, what's the use case for allowing duplicate centroids? It doesn't have a reasonable meaning and does slow down execution. I don't feel so strongly about it and would like to get the change to remove "runs" in regardless, so I could back that out, but I'd be a little more convinced if it were more than just matching scikit.
I backed out the change for SPARK-3261; that part is actually tiny and separable now anyway. We can discuss that here too but wanted to split it from the main change for expediency.
Test build #66578 has finished for PR 15342 at commit
@sethah are you OK with this part? We can still talk about the k centroids bit, either here or on the JIRA.

@srowen I will take a look shortly.
 /**
- * Number of clusters to create (k).
+ * Number of clusters to create (k). Note that if the input has fewer than k elements,
+ * then it's possible that fewer than k clusters are created.
If we back out the change to avoid duplicate centroids, does this annotation become invalid?
Oops, right.
 new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm)
 }.toArray)
 private def initRandom(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = {
   data.takeSample(true, k, new XORShiftRandom(this.seed).nextInt()).map(_.toDense)
Is it necessary to cast the vector to a dense one?
I guess not, but the centers become immediately dense in the first iteration of runAlgorithm.
At least, it's what the existing code did.
Maybe we can optimize this at #14937, since it will have different treatment for dense and sparse vectors.
 // On each step, sample 2 * k points on average for each run with probability proportional
 // to their squared distance from that run's centers. Note that only distances between points
 val centers = ArrayBuffer[VectorWithNorm]()
 var newCenters = Seq(sample.head.toDense)
Maybe irrelevant to this PR, but why do we need to cast it to a dense one?
I'm not sure this is performance critical, but centers ++= newCenters will be faster if newCenters is an Array instead of a List.
Why is it faster if it's an Array instead of Seq? or am I getting the comments cross-wired?
ArrayBuffer.++= is optimized for IndexedSeq, but not for List.
The difference is probably negligible. I just thought we could use Array if there is no specific preference for using a List.
It's a Seq but yeah, no specific preference. Where do you see that optimization BTW? I just see it implemented for TraversableOnce, and that's what my IDE says it calls even when given an Array.
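The performance claim above can be sanity-checked in isolation. This sketch builds the same buffer from a List and from an Array (timings are machine-dependent and purely illustrative; the only point is that the contents end up identical either way):

```scala
import scala.collection.mutable.ArrayBuffer

// Same contents, different source collection types.
val n = 100000
val asList: List[Int] = List.tabulate(n)(identity)
val asArray: Array[Int] = Array.tabulate(n)(identity)

// Tiny helper: run a block and report elapsed milliseconds.
def timed[A](body: => A): (A, Double) = {
  val t0 = System.nanoTime()
  val res = body
  (res, (System.nanoTime() - t0) / 1e6)
}

val buf1 = ArrayBuffer[Int]()
val (_, msList) = timed(buf1 ++= asList)
val buf2 = ArrayBuffer[Int]()
val (_, msArray) = timed(buf2 ++= asArray)

// Either way the buffer is identical; any difference is pure copy cost.
assert(buf1 == buf2)
println(f"List: $msList%.2f ms, Array: $msArray%.2f ms")
```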
 chosen.foreach { case (p, rs) =>
   rs.foreach(newCenters(_) += p.toDense)
 }
 newCenters = chosen.map(_.toDense)
Ditto.
 bcCenters.destroy(blocking = false)

 // Update the cluster centers and costs
 converged = true
I think changed would be more intuitive.
Hm, I thought the opposite flag, converged, was more intuitive. If you don't feel strongly about it, let's leave it, but if you'd moderately prefer changed then I don't mind that. I think it's the same thing with the flag inverted.
 // Update the cluster centers and costs
 converged = true
 totalContribs.foreach { case (j, (sum, count)) =>
Compared with the original code, foreach may be slower than a while loop if you have a large k.
Why is that? I'm aware that Scala for comprehensions can desugar into something surprisingly expensive, but this seems clearer and about the same as a while loop.
In general while is faster than foreach (creating and calling an anonymous function), but I'd be surprised if it affected performance here because we are only running this once per iteration and the bulk of the cost will be distributed computation.
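For reference, the two loop styles under discussion look roughly like this; totalContribs is stubbed with toy data shaped like the aggregate in the diff above (center index mapped to a running sum and count), so the values are illustrative only:

```scala
// Stand-in for totalContribs: (center index, (running sum, count)).
val totalContribs: Array[(Int, (Double, Long))] =
  Array((0, (10.0, 5L)), (1, (6.0, 2L)))

val viaForeach = new Array[Double](2)
val viaWhile = new Array[Double](2)

// Style 1: foreach with a pattern match (allocates a closure; reads clearly).
totalContribs.foreach { case (j, (sum, count)) =>
  viaForeach(j) = sum / count
}

// Style 2: manual while loop (no closure; the classic hot-path idiom).
var i = 0
while (i < totalContribs.length) {
  val (j, (sum, count)) = totalContribs(i)
  viaWhile(j) = sum / count
  i += 1
}

// Both styles compute the same per-center means.
assert(viaForeach.sameElements(viaWhile))
```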
A few minor things, but otherwise LGTM.
 bcCenters.destroy(blocking = false)

 // Update the cluster centers and costs
 converged = true
Why don't we just leave converged false, and only change it to true inside the foreach?
Unless I'm overlooking some obviously nicer expression, I think the loop is going to work the same either way: you have to assume you terminate unless a distance proves otherwise, per iteration.
The logic is the same, yes, but it seems really strange to set something to false, then each iteration set it to true and then set it back to false under some condition. Why not leave it false and change it to true if the convergence criteria are met? This is basically a trivial detail, so only change it if you want. I'm fine either way.
I don't think it can be done the way you're suggesting; it's not just preference. You could just set it with a nice simple call to .forall as you're suggesting, usually, but here we also need the side effect of visiting each element. To do both I think we have to 'unroll' the equivalent logic and it amounts to this.
Yep, you're correct. Thanks!
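The constraint being described, in a minimal sketch: a .forall would short-circuit on the first moved center, but the loop must also update every center as a side effect, so the flag is unrolled instead (all names and values below are illustrative, not the PR's actual code):

```scala
// Illustrative one-dimensional "centers"; epsilon is the convergence tolerance.
val oldCenters = Array(1.0, 2.0, 3.0)
val newCenters = Array(1.0, 2.5, 3.0)
val epsilon = 0.1

// Assume convergence, then let any center that moved too far falsify it.
// A .forall would short-circuit and skip the remaining updates, so the
// flag is threaded through an exhaustive foreach instead.
var converged = true
newCenters.indices.foreach { j =>
  if (math.abs(newCenters(j) - oldCenters(j)) > epsilon) {
    converged = false
  }
  oldCenters(j) = newCenters(j)  // side effect: every center is updated
}

assert(!converged)
assert(oldCenters.sameElements(newCenters))
```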
@sethah OK I will add a new overload of
Test build #66673 has finished for PR 15342 at commit
  * on system time.
  */
 @Since("2.1.0")
 def train(data: RDD[Vector],
minor: style should match other train signatures
+1 @sethah
@@ -531,6 +471,7 @@ object KMeans {
    * "k-means||". (default: "k-means||")
    */
   @Since("0.8.0")
+  @deprecated("Use train method without 'runs'", "2.1.0")
   def train(
This signature does not have a direct alternative without runs.
+1 @sethah
Test build #66684 has finished for PR 15342 at commit
Only the last two minor items; otherwise, this looks ready to me. Thanks!
Yeah, but now we have yet 2 more overloads. I had intended to point people to one new overload, but I guess it's weird to make people specify the seed arg. And optional args, the normal solution, break binary compatibility IIRC.
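The overloads-vs-default-arguments tradeoff mentioned here can be sketched with simplified stand-in signatures (this is not the actual KMeans API; names and return values are illustrative):

```scala
object Trainer {
  // Old signature keeps working but is deprecated, as in the PR.
  @deprecated("Use train without 'runs'", "2.1.0")
  def train(data: Seq[Double], k: Int, maxIterations: Int, runs: Int): String =
    train(data, k, maxIterations)  // 'runs' is simply ignored

  // New overload without 'runs'. A default arg like `seed: Long = ...` would
  // be nicer to call, but Scala compiles defaults into synthetic methods
  // (train$default$N), so retrofitting one changes the binary interface;
  // hence a plain overload instead.
  def train(data: Seq[Double], k: Int, maxIterations: Int): String =
    s"model(k=$k, iters=$maxIterations, n=${data.length})"
}

val viaOld = Trainer.train(Seq(1.0, 2.0, 3.0), 2, 10, 5)
val viaNew = Trainer.train(Seq(1.0, 2.0, 3.0), 2, 10)
assert(viaOld == viaNew)
```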
Test build #66729 has finished for PR 15342 at commit
LGTM
Merged to master. I'm going to reopen a PR for just the duplicate centroids issue to re-table that.
What changes were proposed in this pull request?
This is a revival of #14948 and related to #14937. This removes the 'runs' parameter, which has already been disabled, from the K-means implementation and further deprecates API methods that involve it.
This also happens to resolve the issue that K-means should not return duplicate centers, meaning that it may return less than k centroids if not enough data is available.
How was this patch tested?
Existing tests