
[SPARK-11560] [MLLIB] Optimize KMeans implementation / remove 'runs' #15342

Closed
wants to merge 6 commits into apache:master from srowen:SPARK-11560

Conversation

@srowen (Member) commented Oct 4, 2016:

What changes were proposed in this pull request?

This is a revival of #14948 and related to #14937. This removes the 'runs' parameter, which has already been disabled, from the K-means implementation and further deprecates API methods that involve it.

This also happens to resolve the issue that K-means should not return duplicate centers, meaning that it may return fewer than k centroids if not enough data is available.
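
For illustration, a minimal sketch of the API change (hedged: `sc` is an assumed SparkContext, and the exact deprecated overloads are those shown in the diff comments below):

    import org.apache.spark.mllib.clustering.KMeans
    import org.apache.spark.mllib.linalg.Vectors

    val data = sc.parallelize(Seq(
      Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0), Vectors.dense(9.0, 8.0)))

    // Deprecated by this PR: the 'runs' argument (the 4th) is already a no-op.
    val oldStyle = KMeans.train(data, 2, 20, 5)

    // Preferred: the same call without 'runs'.
    val newStyle = KMeans.train(data, 2, 20)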

How was this patch tested?

Existing tests

@srowen (Member, Author) commented Oct 4, 2016:

CC @yanboliang

@SparkQA commented Oct 4, 2016:

Test build #66310 has finished for PR 15342 at commit cd14b65.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

  }.toArray)

  private def initRandom(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = {
    val sample = data.takeSample(false, k, new XORShiftRandom(this.seed).nextInt())
    sample.map(v => new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm))
Contributor:

sample.map(_.toDense)?

}

  /**
-  * Initialize `runs` sets of cluster centers using the k-means|| algorithm by Bahmani et al.
+  * Initialize set of cluster centers using the k-means|| algorithm by Bahmani et al.
Contributor:

nit: "Initialize a set"

}

  /**
-  * Initialize `runs` sets of cluster centers using the k-means|| algorithm by Bahmani et al.
+  * Initialize set of cluster centers using the k-means|| algorithm by Bahmani et al.
   * (Bahmani et al., Scalable K-Means++, VLDB 2012). This is a variant of k-means++ that tries
   * to find with dissimilar cluster centers by starting with a random center and then doing
Contributor:

"to find with dissimilar" -> "to find dissimilar" (while we're here)

  // Initialize empty centers and point costs.
- val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm])
- var costs = data.map(_ => Array.fill(runs)(Double.PositiveInfinity))
+ var costs = data.map(_ => Double.PositiveInfinity)

  // Initialize each run's first center to a random point.
Contributor:

"Initialize the first center to a random point."

@@ -558,6 +475,7 @@ object KMeans {
   * Trains a k-means model using specified parameters and the default values for unspecified.
   */
  @Since("0.8.0")
+ @deprecated("Use train method without 'runs'", "2.1.0")
Contributor:

There are two other train signatures that use runs, but have not been marked as deprecated.

Member Author (srowen):

Yes, though there's no alternative to those with the same arguments. We could add another overload and deprecate the others. I'm OK with that too; it just felt a little gross to add yet more.

Contributor:

I think we should add them for completeness, and deprecate all overloads using runs.

  costs.unpersist(blocking = false)
  bcNewCentersList.foreach(_.destroy(false))

- // Finally, we might have a set of more than k candidate centers for each run; weigh each
+ if (centers.size <= k) {
+   return centers.toArray
Contributor:

I prefer to avoid the return keyword and just put the other code under the else here. But it is a small preference.
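
That is, roughly (a sketch of the suggested restructuring, with the candidate-weighing logic elided):

    if (centers.size <= k) {
      centers.toArray
    } else {
      // weigh each candidate center and select k of them via local k-means++
    }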

- val costs = Array.fill(numRuns)(0.0)
- var activeRuns = new ArrayBuffer[Int] ++ (0 until numRuns)
+ var active = true
Contributor:

I find converged to be more intuitive, but not a strong preference.

@sethah (Contributor) commented Oct 4, 2016:

Looking good. My main concern is that now you can have the following:

scala> model.getK
res2: Int = 3

scala> model.clusterCenters.length
res3: Int = 1

We could set the model's k to match the number of cluster centers before creating the model, during training. Or we could leave it, but then what does k mean, if not the number of centers?

@srowen (Member, Author) commented Oct 5, 2016:

That's right. k seems like the requested number of centroids, which may not match the actual number in corner cases. What about just documenting that more?

@srowen (Member, Author) commented Oct 5, 2016:

Otherwise updated to reflect all the other review comments, thanks.

@SparkQA commented Oct 5, 2016:

Test build #66381 has finished for PR 15342 at commit ebbb852.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@yanboliang (Contributor) commented Oct 5, 2016:

I'd prefer to maintain the original logic that keeps model.clusterCenters.length equal to k. Was there some discussion motivating this change?
I checked the popular Python machine learning library scikit-learn: it returns the requested number of centroids even if k is greater than the number of distinct data points:
[screenshot: scikit-learn returning k centroids]

And R's kmeans throws an error if there are more cluster centers than distinct data points:
[screenshot: R kmeans error message]

@srowen (Member, Author) commented Oct 5, 2016:

This is what SPARK-3261 is about. It's a corner case, to be sure. To me it seems like having duplicate centroids is worse, because the model loses some of its meaning: points may be arbitrarily assigned to one or the other of two identical centroids. Of the 3 possible behaviors, it looks like we have all 3 on the table:

  1. error
  2. return < k centroids
  3. return k centroids

I suppose I prefer the new behavior, but I can't say I feel that strongly. I guess matching scikit-learn has some value.
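
To make the corner case concrete, a toy sketch (hypothetical, not the PR's code) of the three options with 2 distinct points and k = 3:

    import org.apache.spark.mllib.linalg.Vectors

    val points = Seq(Vectors.dense(1.0), Vectors.dense(1.0), Vectors.dense(2.0))
    val k = 3
    // 1. error: require(points.distinct.length >= k, "fewer distinct points than k")
    // 2. return fewer than k centroids (this PR's behavior before it was backed out):
    val deduped = points.distinct.take(k)   // length 2, i.e. fewer than k
    // 3. return k centroids, duplicates allowed (pre-PR / scikit-learn behavior):
    //    two identical centroids then compete arbitrarily for the same points.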

@sethah (Contributor) commented Oct 6, 2016:

What are the circumstances that lead to duplicate cluster centers, other than the obvious one of having less data than requested centers? The comment on the original JIRA said that training on 1.3M points while asking for 10k clusters returned only ~1k centers.

@srowen (Member, Author) commented Oct 6, 2016:

Good question. I think he's saying that it returned 1K centers after this change. It's a good point that this would also speed things up considerably, because computing the distance to duplicate centroids is all superfluous work.

@sethah (Contributor) commented Oct 7, 2016:

@srowen That is not the impression that I got from "I just ran clustering on 1.3M points, asking for 10,000 clusters. This clustering run resulted in 1019 unique cluster centers."

@derrickburns Can you clarify a bit here? Also, could you tell us the nature of the data that was used for your clustering?

@srowen (Member, Author) commented Oct 8, 2016:

I'm wondering: what's the use case for allowing duplicate centroids? It doesn't have a reasonable meaning, and it slows down execution. I don't feel that strongly about it, and I'd like to get the change to remove 'runs' in regardless, so I could back that out, but I'd be a little more convinced if it were more than just matching scikit-learn.

@srowen (Member, Author) commented Oct 8, 2016:

I backed out the change for SPARK-3261; that part is actually tiny and separable now anyway. We can discuss it here too, but I wanted to split it from the main change for expediency.

@srowen srowen changed the title [SPARK-11560] [SPARK-3261] [MLLIB] Optimize KMeans implementation / remove 'runs' / KMeans clusterer can return duplicate cluster centers [SPARK-11560] [MLLIB] Optimize KMeans implementation / remove 'runs' Oct 8, 2016
@SparkQA commented Oct 8, 2016:

Test build #66578 has finished for PR 15342 at commit 68e3d90.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@srowen (Member, Author) commented Oct 10, 2016:

@sethah are you OK with this part? We can still talk about the k centroids bit, either here or on the JIRA.

@sethah (Contributor) commented Oct 10, 2016:

@srowen I will take a look shortly.


  /**
-  * Number of clusters to create (k).
+  * Number of clusters to create (k). Note that if the input has fewer than k elements,
+  * then it's possible that fewer than k clusters are created.
Contributor:

If we back out the change to avoid duplicate centroids, doesn't this annotation become invalid?

Member Author (srowen):

Oops, right.

      new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm)
    }.toArray)

  private def initRandom(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = {
    data.takeSample(true, k, new XORShiftRandom(this.seed).nextInt()).map(_.toDense)
Contributor:

Is it necessary to convert the vector to a dense one?

Contributor:

I guess not, but the centers become immediately dense in the first iteration of runAlgorithm.

Member Author (srowen):

At least, that's what the existing code did.

Contributor:

Maybe we can optimize this at #14937, since that will treat dense and sparse vectors differently.
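
For reference on what the conversion does, a sketch mirroring the private helper implied by the quoted diff lines (VectorWithNorm is package-private in Spark; this shape is inferred, not verbatim):

    import org.apache.spark.mllib.linalg.{Vector, Vectors}

    class VectorWithNorm(val vector: Vector, val norm: Double) {
      def this(vector: Vector) = this(vector, Vectors.norm(vector, 2.0))
      // Copies the (possibly sparse) values into a dense array; the norm is reused.
      def toDense: VectorWithNorm =
        new VectorWithNorm(Vectors.dense(vector.toArray), norm)
    }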

  // On each step, sample 2 * k points on average for each run with probability proportional
  // to their squared distance from that run's centers. Note that only distances between points
  val centers = ArrayBuffer[VectorWithNorm]()
  var newCenters = Seq(sample.head.toDense)
Contributor (@yanboliang, Oct 10, 2016):

Maybe irrelevant to this PR, but why do we need to convert it to a dense vector?

Contributor:

I'm not sure this is performance critical, but centers ++= newCenters will be faster if newCenters is an Array instead of List.

Member Author (srowen):

Why is it faster if it's an Array instead of a Seq? Or am I getting the comments cross-wired?

Contributor:

ArrayBuffer.++= is optimized for IndexedSeq, but not for List.

Contributor:

The difference is probably negligible. I just thought we could use Array if there is no specific preference for using a List.

Member Author (srowen):

It's a Seq, but yeah, no specific preference. Where do you see that optimization, BTW? I just see it implemented for TraversableOnce, and that's what my IDE says it calls even when given an Array.
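
For reference, a small sketch of the fast path in question, assuming the Scala 2.11/2.12 collections library (where ArrayBuffer.++= pattern-matches on IndexedSeqLike):

    import scala.collection.mutable.ArrayBuffer

    val buf = ArrayBuffer.empty[Int]
    // An Array is wrapped as a WrappedArray (an IndexedSeq), so ++= can read its
    // length, grow the backing array once, and bulk-copy the elements.
    buf ++= Array(1, 2, 3)
    // A List has no cheap length, so ++= falls back to the generic
    // element-by-element append path.
    buf ++= List(4, 5, 6)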

- chosen.foreach { case (p, rs) =>
-   rs.foreach(newCenters(_) += p.toDense)
- }
+ newCenters = chosen.map(_.toDense)
Contributor:

Ditto.

  bcCenters.destroy(blocking = false)

  // Update the cluster centers and costs
  converged = true
Contributor:

I think changed would be more intuitive.

Member Author (srowen):

Hm, I thought the opposite flag, converged, was more intuitive. If you don't feel strongly about it, let's leave it, but if you'd moderately prefer changed then I don't mind. I think it's the same thing with the flag inverted.


  // Update the cluster centers and costs
  converged = true
  totalContribs.foreach { case (j, (sum, count)) =>
Contributor (@yanboliang, Oct 10, 2016):

Compared with the original code, foreach may be slower than a while loop if k is large.

Member Author (srowen):

Why is that? I'm aware that Scala for-comprehensions can desugar into something surprisingly expensive, but this seems clearer and about the same speed as a while loop.

Contributor:

In general, while is faster than foreach (which creates and calls an anonymous function), but I'd be surprised if it affected performance here, because we only run this once per iteration and the bulk of the cost will be the distributed computation.
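
A sketch of the two styles under discussion, assuming totalContribs is the per-iteration map from center index to (vector sum, count) suggested by the diff context:

    // Assumed shape for illustration:
    val totalContribs: Map[Int, (Array[Double], Long)] =
      Map(0 -> (Array(3.0, 3.0), 2L))

    // PR style: concise, but invokes a closure per element.
    totalContribs.foreach { case (j, (sum, count)) =>
      // ... update center j from (sum, count)
    }

    // Pre-PR style: a while loop avoids the closure, which can matter for very
    // large k, though this runs only once per iteration over at most k entries.
    val entries = totalContribs.toArray
    var i = 0
    while (i < entries.length) {
      val (j, (sum, count)) = entries(i)
      // ... update center j
      i += 1
    }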

Contributor (@sethah) left a review:

A few minor things, but otherwise LGTM.

  bcCenters.destroy(blocking = false)

  // Update the cluster centers and costs
  converged = true
Contributor:

Why don't we just leave converged false, and only change it to true inside the foreach?

Member Author (srowen):

Unless I'm overlooking some obviously nicer expression, I think the loop is going to work the same either way: you have to assume you terminate unless a distance proves otherwise, per iteration.

Contributor:

The logic is the same, yes, but it seems really strange to set something to false, then each iteration set it to true and then set it back to false under some condition. Why not leave it false and change it to true when the convergence criterion is met? This is basically a trivial detail, so only change it if you want. I'm fine either way.

Member Author (srowen):

I don't think it can be done the way you're suggesting; it's not just preference. Usually you could just set it with a nice simple .forall call, as you're suggesting, but here we also need the side effect of visiting each element. To do both, I think we have to 'unroll' the equivalent logic, and it amounts to this.
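
Concretely, a sketch of the 'unrolled' pattern being described (names and shapes are assumptions based on the diff context; Vectors.sqdist stands in for the private distance helper):

    import org.apache.spark.mllib.linalg.Vectors

    // Assumed stand-ins for the private KMeans state:
    val epsilon = 1e-4
    val centers = Array(Vectors.dense(0.0, 0.0), Vectors.dense(5.0, 5.0))
    val totalContribs: Map[Int, (Array[Double], Long)] =
      Map(0 -> (Array(1.0, 1.0), 2L), 1 -> (Array(11.0, 9.0), 2L))

    var converged = true
    totalContribs.foreach { case (j, (sum, count)) =>
      var i = 0
      while (i < sum.length) { sum(i) /= count; i += 1 }  // mean of assigned points
      val newCenter = Vectors.dense(sum)
      // Updating centers(j) is a side effect needed for every j, which is why a
      // plain .forall over the contributions doesn't fit here.
      if (converged && Vectors.sqdist(newCenter, centers(j)) > epsilon * epsilon) {
        converged = false  // at least one center moved more than epsilon
      }
      centers(j) = newCenter
    }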

Contributor:

Yep, you're correct. Thanks!


@srowen (Member, Author) commented Oct 10, 2016:

@sethah OK I will add a new overload of train and deprecate the others.
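
For reference, a sketch of what such an overload plausibly looks like (the builder-style body is an assumption; the authoritative version is in the final diff):

    @Since("2.1.0")
    def train(
        data: RDD[Vector],
        k: Int,
        maxIterations: Int,
        initializationMode: String,
        seed: Long): KMeansModel = {
      new KMeans().setK(k)
        .setMaxIterations(maxIterations)
        .setInitializationMode(initializationMode)
        .setSeed(seed)
        .run(data)
    }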

@SparkQA commented Oct 10, 2016:

Test build #66673 has finished for PR 15342 at commit 5cb9e5f.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

   * on system time.
   */
  @Since("2.1.0")
  def train(data: RDD[Vector],
Contributor:

minor: style should match other train signatures

Contributor:

+1 @sethah

@@ -531,6 +471,7 @@ object KMeans {
   * "k-means||". (default: "k-means||")
   */
  @Since("0.8.0")
+ @deprecated("Use train method without 'runs'", "2.1.0")
  def train(
Contributor:

This signature does not have a direct alternative without runs.

Contributor:

+1 @sethah

@SparkQA commented Oct 10, 2016:

Test build #66684 has finished for PR 15342 at commit 84fb22f.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@yanboliang (Contributor) commented:
Only the last two minor items; otherwise this looks ready to me. Thanks!

@srowen (Member, Author) commented Oct 11, 2016:

Yeah, but now we have yet two more overloads. I had intended to point people to one new overload, but I guess it's weird to make people specify the seed arg. And optional args, the normal solution, break binary compatibility IIRC.
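
A toy sketch of the binary-compatibility concern (hypothetical Api object; general Scala behavior, not anything in this PR):

    object Api {
      // v1 shipped these two overloads:
      def train(k: Int): String = train(k, 20)
      def train(k: Int, maxIterations: Int): String = s"model($k, $maxIterations)"
    }
    // Replacing both with a single defaulted method,
    //   def train(k: Int, maxIterations: Int = 20): String = ...
    // stays source-compatible, but the one-argument train(Int) symbol disappears
    // from the bytecode, so callers compiled against v1 hit NoSuchMethodError.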

@SparkQA commented Oct 11, 2016:

Test build #66729 has finished for PR 15342 at commit ba52582.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@sethah (Contributor) commented Oct 11, 2016:

LGTM

@srowen (Member, Author) commented Oct 12, 2016:

Merged to master. I'm going to reopen a PR for just the duplicate centroids issue to re-table that.

@srowen srowen closed this Oct 12, 2016
@srowen srowen deleted the SPARK-11560 branch October 12, 2016 09:02
asfgit pushed a commit that referenced this pull request Oct 12, 2016
Author: Sean Owen <sowen@cloudera.com>

Closes #15342 from srowen/SPARK-11560.
uzadude pushed a commit to uzadude/spark that referenced this pull request Jan 27, 2017