[SPARK-2308][MLLIB] Add Mini-Batch KMeans Clustering method #1248
Conversation
Can one of the admins verify this patch?
Broad question -- this seems to duplicate a lot of KMeans.scala. Can it not be a variant rather than a separate implementation? Or at least refactor the substantial commonality?
The main function (runBreeze) of KMeans is not compatible, since KMeans optimizes multiple runs by striping iterations across the runs. With MiniBatch, each run's iteration uses a different randomly-sampled subset of points, so the runs have to be done independently. I can pull out other shared functions, though.
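To make that structural point concrete, here is a minimal sketch in plain Scala (not the PR's code; `updateCenters`, the data, and the parameter values are made up) of why the runs cannot be striped through one pass: every iteration of every run has to draw its own fresh mini-batch.

```scala
import scala.util.Random

// Hypothetical sketch: mini-batch runs iterate independently, because each
// iteration of each run draws its own random subset of the points.
object IndependentRunsSketch {
  type Point = Array[Double]

  // Placeholder update step; the real PR updates cluster centers here.
  def updateCenters(centers: Array[Point], batch: Seq[Point]): Array[Point] = centers

  def main(args: Array[String]): Unit = {
    val rng = new Random(42)
    val data = Array.fill(1000)(Array(rng.nextDouble(), rng.nextDouble()))
    val (runs, maxIterations, batchSize, k) = (3, 10, 100, 5)

    for (run <- 0 until runs) {
      var centers = Array.fill(k)(data(rng.nextInt(data.length)))
      var iteration = 0
      while (iteration < maxIterations) {
        // A fresh random mini-batch per iteration; this is what prevents
        // striping all runs through a single pass over the data.
        val batch = rng.shuffle(data.toSeq).take(batchSize)
        centers = updateCenters(centers, batch)
        iteration += 1
      }
    }
  }
}
```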
…ctions common to KMeans and KMeansMiniBatch objects.
…vate KMeans classes. Moved KMeansMiniBatch.{initRandom, initKMeansMiniBatchParallel} there
…s are only for one run.
…n stochastic nature, use epsilons instead of direct comparison of floats.
Sean, I updated the code to factor out the common bits into a KMeansCommons file, using traits for both the objects and the classes. I updated the KMeansMiniBatch tests so they are customized for KMeansMiniBatch, don't duplicate testing of common code, and account for the stochastic nature by comparing errors against an epsilon instead of comparing the floats directly. I also realized that I had failed to implement a key part of the MiniBatch algorithm, so that is now included. Please review again.
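As a rough illustration of the epsilon-based checks mentioned above (the helper name and the tolerance value are assumptions, not the PR's actual test code):

```scala
// Hypothetical sketch of tolerance-based comparison for stochastic results,
// instead of exact floating-point equality.
object EpsilonAssertSketch {
  val epsilon = 1e-4

  // True when a and b agree within a relative tolerance (absolute for small values).
  def approxEqual(a: Double, b: Double, eps: Double = epsilon): Boolean =
    math.abs(a - b) <= eps * math.max(1.0, math.max(math.abs(a), math.abs(b)))

  def main(args: Array[String]): Unit = {
    val expectedCost = 25.0
    val observedCost = 25.0003 // cost from a stochastic run
    assert(approxEqual(expectedCost, observedCost), "costs differ beyond tolerance")
  }
}
```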
// Execute iterations of Lloyd's algorithm until all runs have converged
while (iteration < maxIterations) {
  ...
  val sampledPoints = data.sample(false, batchSize)
sample actually takes a Double as the second argument for the sampling rate (between 0.0 and 1.0). I think you want takeSample here in order to get the exact batch size (but be warned that takeSample actually needs to collect the sample to the driver). Or you could compute the sampling rate and use sample instead, for an approximate sample size (you can only do this for free if the size of the RDD is already known; otherwise you need to do a count, which requires a pass over the entire RDD).
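For reference, a small sketch of the two options described above (the SparkContext setup, the data, and batchSize are illustrative, not the PR's code):

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Sketch of the two RDD sampling options: exact-size takeSample (collected to
// the driver) versus fraction-based sample (stays distributed, approximate size).
object SamplingSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sampling-sketch").setMaster("local[2]"))
    val data = sc.parallelize(1 to 10000)
    val batchSize = 100

    // Option 1: takeSample returns exactly batchSize elements, but collects
    // them to the driver as a local Array.
    val exactBatch: Array[Int] = data.takeSample(withReplacement = false, num = batchSize)

    // Option 2: sample takes a fraction in [0, 1] and stays distributed, but
    // the resulting size is only approximately batchSize; computing the
    // fraction needs the RDD's size (count() is a full pass if it isn't known).
    val fraction = batchSize.toDouble / data.count()
    val approxBatch = data.sample(withReplacement = false, fraction = fraction)

    println(s"exact: ${exactBatch.length}, approx: ${approxBatch.count()}")
    sc.stop()
  }
}
```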
Thanks, dorx! I'm surprised this didn't result in a compilation or run-time error. I'll update the code.
…e tests accordingly.
MiniBatch KMeans needs a sampling method that runs in O(k) time, where k is the number of data points to sample. Current Spark sampling methods run in O(n) time, where n is the number of data points in the RDD. I'm closing this PR until a better sampling method is found.
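To illustrate the complexity gap with a hypothetical local example (not a Spark API): drawing k items from an indexed in-memory collection touches only k elements, whereas RDD.sample and takeSample have to pass over all n elements across the partitions.

```scala
import scala.util.Random

// Hypothetical local illustration: index-based sampling is O(k), while
// Spark's RDD sampling methods scan the whole dataset, i.e. O(n).
object LocalSamplingSketch {
  // Draw k elements (with replacement) by random index lookups.
  def sampleK[T](data: IndexedSeq[T], k: Int, rng: Random): Seq[T] =
    Seq.fill(k)(data(rng.nextInt(data.length)))

  def main(args: Array[String]): Unit = {
    val rng = new Random(0)
    val data = Vector.range(0, 1000000)
    println(sampleK(data, 5, rng))
  }
}
```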
Mini-batch is a variant of KMeans that uses a randomly-sampled subset of the data points in each iteration instead of the full dataset, improving performance (and, in some cases, accuracy). The mini-batch version is compatible with the KMeans|| initialization algorithm currently implemented in MLlib. A rough sketch of the per-iteration update step is included after this description.
This PR adds the KMeansMiniBatch clustering algorithm, tests, and documentation updates.
Discussed in SPARK-2308
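To make the update step concrete, here is a hedged sketch of mini-batch KMeans in plain Scala, using the classic per-center-count learning rate; the names, data, and parameter values are made up, and the PR's actual implementation may differ.

```scala
import scala.util.Random

// Sketch of a mini-batch KMeans update: each sampled point pulls its nearest
// center toward it with a per-center learning rate that decays with the count.
object MiniBatchUpdateSketch {
  type Point = Array[Double]

  // Index of the center closest to p (squared Euclidean distance).
  def nearest(centers: Array[Point], p: Point): Int =
    centers.indices.minBy { i =>
      centers(i).zip(p).map { case (c, x) => (c - x) * (c - x) }.sum
    }

  def main(args: Array[String]): Unit = {
    val rng = new Random(7)
    val data = Array.fill(2000)(Array(rng.nextDouble(), rng.nextDouble()))
    val k = 5
    val batchSize = 100
    val iterations = 20

    val centers = Array.fill(k)(data(rng.nextInt(data.length)).clone())
    val counts = Array.fill(k)(0L)

    for (_ <- 0 until iterations) {
      // Draw a fresh mini-batch for this iteration.
      val batch = Array.fill(batchSize)(data(rng.nextInt(data.length)))
      for (p <- batch) {
        val c = nearest(centers, p)
        counts(c) += 1
        val eta = 1.0 / counts(c) // per-center learning rate
        // Move the center a small step toward the sampled point.
        for (d <- centers(c).indices)
          centers(c)(d) = (1.0 - eta) * centers(c)(d) + eta * p(d)
      }
    }
    println(centers.map(_.mkString("(", ", ", ")")).mkString("\n"))
  }
}
```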