Skip to content

Commit

Permalink
Adding KMeans train with seed and Scala unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
str-janus committed Dec 4, 2014
1 parent 616d111 commit 5d087b4
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,31 @@ object KMeans {
.run(data)
}

/**
* Trains a k-means model using the given set of parameters.
*
* @param data training points stored as `RDD[Array[Double]]`
* @param k number of clusters
* @param maxIterations max number of iterations
* @param runs number of parallel runs, defaults to 1. The best model is returned.
* @param initializationMode initialization model, either "random" or "k-means||" (default).
* @param seed seed value for cluster initialization
*/
def train(
data: RDD[Vector],
k: Int,
maxIterations: Int,
runs: Int,
initializationMode: String,
seed: Long): KMeansModel = {
new KMeans().setK(k)
.setMaxIterations(maxIterations)
.setRuns(runs)
.setInitializationMode(initializationMode)
.setSeed(seed)
.run(data)
}

/**
* Trains a k-means model using specified parameters and the default values for unspecified.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,27 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
assert(model.clusterCenters.size === 3)
}

test("deterministic initilization") {
// Create a large-ish set of point to cluster
val points = List.tabulate(1000)(n => Vectors.dense(n,n))
val rdd = sc.parallelize(points, 3)

for (initMode <- Seq(RANDOM, K_MEANS_PARALLEL)) {
// Create three deterministic models and compare cluster means
val model1 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1, initializationMode = initMode, seed = 42)
val centers1 = model1.clusterCenters

val model2 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1, initializationMode = initMode, seed = 42)
val centers2 = model2.clusterCenters

val model3 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1, initializationMode = initMode, seed = 42)
val centers3 = model3.clusterCenters

assert(centers1.deep == centers2.deep)
assert(centers1.deep == centers3.deep)
}
}

test("single cluster with big dataset") {
val smallData = Array(
Vectors.dense(1.0, 2.0, 6.0),
Expand Down

0 comments on commit 5d087b4

Please sign in to comment.