[SPARK-8519][SPARK-11560] [ML] [MLlib] Optimize KMeans implementation. #14937
Conversation
Test build #64854 has finished for PR 14937 at commit
Test build #64890 has finished for PR 14937 at commit
@yanboliang here are a few other changes I made in my PR that accidentally duplicated some of this work. Refer to #14948 for details. For your consideration:

- I think getRuns/setRuns should be formally deprecated and the runs param to the constructor removed (it's private). The remaining mentions of 'runs' in the docs should be removed at this point too.
- mergeContribs and the type WeightedPoint don't really serve a purpose at this point and can be 'inlined', IMHO.
- Minor: contribs.iterator can really be an iterator over only the triples with non-zero counts, which eliminates the filtering by 0 counts.
- The "run finished" log message is obsolete now.
- Minor, but in k-means|| the sample of 1 element is very slightly better if it's without replacement. It won't matter much, but otherwise you might sample a couple of elements.
- pointsWithCosts.flatMap might be a little faster as filter + map instead, because virtually every element is filtered out.
- mergeNewCenters() is pretty superfluous, because it's simpler to compute newCenters and then add it to centers in the same loop, with no clear() or multiple calls to update it.
- weightMap can be computed with countByValue directly.
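The last suggestion above can be illustrated with a small plain-Scala analogue (the assignment values here are made up for illustration, not taken from the patch). On a Spark RDD the equivalent one-liner would be `assignments.countByValue()`:

```scala
// Collections analogue of the countByValue suggestion: instead of building a
// mutable weight map by hand, count each cluster assignment directly.
// (On an RDD this is just `assignments.countByValue()`.)
val assignments = Seq(0, 1, 1, 2, 1, 0)   // assumed cluster indices for six points

val weightMap: Map[Int, Long] =
  assignments.groupBy(identity).map { case (k, vs) => (k, vs.size.toLong) }
```

The result maps each cluster index to the number of points assigned to it, with no explicit mutation or merge logic.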
@srowen Thanks for your suggestions, I will update it soon.
Test build #64927 has finished for PR 14937 at commit
norms.persist()
val zippedData = data.zip(norms).map { case (v, norm) =>
  new VectorWithNorm(v, norm)
val zippedData = data.map { x => new VectorWithNorm(x) }
Below you wrote `map(new VectorWithNorm(_))`, which is indeed a little more compact. Not a big deal, but if you make another pass you might standardize this.
@yanboliang are you still working on this? It seems like an important change and I'd love to help get it in.
@srowen Yes, I'm working on this. You can see the performance test results in the PR description. We found that the optimized k-means achieves a performance improvement of about 2-4x by using native BLAS level 3 matrix-matrix multiplications for dense input. However, we saw performance degradation for sparse input. For example, the new implementation spent almost twice as much time as the old one when training a k-means model on the famous MNIST dataset. Digging into this problem, I found there is no native BLAS gemm implementation for multiplying a sparse matrix with a dense one, so sparse input cannot benefit from native BLAS to speed up. I searched and found that sparse BLAS libraries exist, but it looks like netlib does not support exporting them.
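The dense-path speedup discussed here rests on the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2(a . b): with precomputed norms, all pairwise point-center distances reduce to one matrix product A * B^T, which native BLAS (gemm) executes in bulk. The following is a minimal pure-Scala sketch of that arithmetic, not the patch's code; the real implementation delegates the product to BLAS level 3:

```scala
// Sketch of the distance identity behind the dense-path optimization:
//   ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * (a . b)
// Each `dot` below stands in for one entry of the gemm result A * B^T.
def pairwiseSqDist(points: Array[Array[Double]],
                   centers: Array[Array[Double]]): Array[Array[Double]] = {
  val pNorms = points.map(p => p.map(x => x * x).sum)   // ||a||^2 per point
  val cNorms = centers.map(c => c.map(x => x * x).sum)  // ||b||^2 per center
  points.indices.toArray.map { i =>
    centers.indices.toArray.map { j =>
      val dot = points(i).zip(centers(j)).map { case (a, b) => a * b }.sum
      pNorms(i) + cNorms(j) - 2.0 * dot
    }
  }
}
```

Because the dot products dominate the cost, batching them into a single dense gemm call is where the 2-4x gain comes from; a sparse A has no native gemm path, which matches the degradation observed for sparse input.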
@yanboliang would it be useful if I worked on a PR to just remove
@srowen Please feel free to send that PR. This PR involves some significant changes and should be carefully discussed, so it may not be merged quickly. Thanks!
## What changes were proposed in this pull request?

This is a revival of #14948 and related to #14937. This removes the 'runs' parameter, which has already been disabled, from the K-means implementation and further deprecates API methods that involve it. This also happens to resolve the issue that K-means should not return duplicate centers, meaning that it may return fewer than k centroids if not enough data is available.

## How was this patch tested?

Existing tests

Author: Sean Owen <sowen@cloudera.com>

Closes #15342 from srowen/SPARK-11560.
@yanboliang I began to run some performance tests on this patch today. With this patch as it is, I am seeing a huge performance degradation. The most critical reason is the slicing (copying) of the centers array inside the innermost while loop. The reason I ask is that I don't see how the results posted in this PR could even occur against the current patch. Were those from an older version? I know this PR has gone through several iterations, so I'm just trying to get a sense of where those results came from. It would be great if we could resolve the merge conflicts and start moving review along. I'm happy to help :)
@sethah I think the test results can be reproduced against the current patch; however, there are two issues that should be considered:

```scala
val df = spark.read.format("libsvm").options(Map("vectorType" -> "dense")).load(path)
```

Spark loads libsvm-format datasets into SparseVector/SparseMatrix by default, which falls into the branch for processing sparse data and causes huge performance degradation. Could you share some of your test details? If you have already considered the above two tips, please let me know as well. I'm on a business trip and will resolve the merge conflicts in a few days. I'd appreciate hearing your thoughts on this issue. Thanks.
@yanboliang I ran some tests on a 3-node bare-metal cluster (144 cores, 384 GB RAM) on some dense synthetic data. I installed OpenBLAS customized for the hardware on the nodes (I can confirm it's successfully using NativeBLAS; I'm not positive it's optimized, though). With this patch at first, I was seeing something like 10-minute iteration times, compared to ~30 seconds on the master branch. After refactoring the code to avoid some copying, I am still seeing about a 3-5x slowdown with this approach. I am still working through some of the timings and I haven't done much experimentation with the block size. I will give more details at some point. For now, I can point out that copying the center here seems to have a huge impact.
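The cost of that copy is easy to see in isolation. Below is a hypothetical sketch, not the patch's code (`flatCenters`, `d`, and both helper functions are invented for illustration), contrasting a fresh slice per call with in-place offset indexing over the same flat, row-major centers array:

```scala
// Hypothetical illustration of the copying issue flagged in review: slicing
// the flat centers array inside the innermost loop allocates a new Array per
// (point, center) pair, while offset indexing reads the same memory copy-free.
val d = 3                                                // assumed feature dimension
val flatCenters = Array(0.0, 0.0, 0.0, 1.0, 1.0, 1.0)    // two centers, row-major

def sqDistWithCopy(p: Array[Double], j: Int): Double = {
  val c = flatCenters.slice(j * d, (j + 1) * d)          // copies d doubles per call
  p.zip(c).map { case (a, b) => (a - b) * (a - b) }.sum
}

def sqDistNoCopy(p: Array[Double], j: Int): Double = {
  var s = 0.0
  var i = 0
  while (i < d) {                                        // reads in place, no allocation
    val diff = p(i) - flatCenters(j * d + i)
    s += diff * diff
    i += 1
  }
  s
}
```

Both compute the same squared distance; the difference is purely the per-call allocation, which is exactly the kind of overhead that dominates inside a triple-nested loop at scale.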
A small update: I have run a few tests on a refactored version of this patch which avoids some data copying. I have found at least one case where the current patch is faster, but many where it is not. I'll try to post formal results at some point. (All test cases use dense data, btw.) In the meantime, I think it would be helpful to have more detail about the tests above. They are rather small datasets. How many centers were used? How were the timings observed? Thanks!
@sethah You can try the following piece of code even on a single node:

```scala
import org.apache.spark.ml.clustering.KMeans

val dataset = spark.read.format("libsvm").options(Map("vectorType" -> "dense")).load("/Users/yliang/Downloads/libsvm/combined")
val kmeans = new KMeans().setK(3).setSeed(1L).setTol(1E-16).setMaxIter(100).setInitMode("random")
val model = kmeans.fit(dataset)
```

You can find the dataset at https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html .
After this PR:
I think the value of
@yanboliang I ran the test. The master branch runs in 10 seconds and the current patch runs in 6 seconds. Still, the results are meaningless in my opinion on such a small dataset. I also ran both branches at larger scale and saw that the master branch takes ~20 seconds per iteration in one case while this patch takes 10 minutes. I traced it down to the way the data is being copied. Could you also run tests at scale to verify this? Again, with some refactoring I ran some very preliminary tests (data size approximately 100 GB with 100-1k clusters) and saw that this branch improves performance in some cases and degrades it in others. We need to test this at scale to really understand the implications, I think. I will try to summarize my results sometime in the next week. I think we will see performance gains when the number of features/clusters is large.
@sethah Yeah, I agree it's better to run more tests against large-scale data. If the number of features or clusters is large, the cost of slicing the centers array, and some other places I did not pay much attention to, can be optimized. And we definitely should really understand the performance test results, so feel free to share yours.
Test build #85510 has finished for PR 14937 at commit
val model = runAlgorithm(zippedData, instr)
norms.unpersist()
blockData.persist()
blockData.count()
Just wondering: is this `count()` only for executing the `persist()`? If so, I think it might slow down the algorithm in case the size of the data is severely large.
I think this is to make sure this 'child' RDD is materialized before its 'parent' is unpersisted, or else we lose the value of caching the parent.
}
val model = runAlgorithm(zippedData, instr)
norms.unpersist()
blockData.persist()
I think it'd be good if the storage level were specified as `MEMORY_AND_DISK`.
What changes were proposed in this pull request?
Use BLAS Level 3 matrix-matrix multiplications to compute pairwise distance in k-means.
This is the updated version of #10806.
Performance
Below are some performance tests I have run so far. I am happy to add more cases or trials if necessary.
Note: Since sparse matrix multiplications cannot benefit from native BLAS, we see performance degradation for sparse input. We should figure out a way to accelerate sparse matrix multiplications.
How was this patch tested?
Existing unit tests.