Skip to content

Commit

Permalink
Add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
srowen committed Oct 20, 2016
1 parent 85c9857 commit ebebcb9
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,8 @@ class KMeans private (
private def initRandom(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = {
// Select without replacement; may still produce duplicates if the data has < k distinct
// points, so deduplicate the centroids to match the behavior of k-means|| in the same situation
data.takeSample(false, k, new XORShiftRandom(this.seed).nextInt()).
map(_.vector).distinct.map(new VectorWithNorm(_))
data.takeSample(false, k, new XORShiftRandom(this.seed).nextInt())
.map(_.vector).distinct.map(new VectorWithNorm(_))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,34 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(model.clusterCenters.head ~== center absTol 1E-5)
}

test("no distinct points") {
test("fewer distinct points than clusters") {
val data = sc.parallelize(
Array(
Vectors.dense(1.0, 2.0, 3.0),
Vectors.dense(1.0, 2.0, 3.0),
Vectors.dense(1.0, 2.0, 3.0)),
2)
val center = Vectors.dense(1.0, 2.0, 3.0)

// Make sure code runs.
var model = KMeans.train(data, k = 2, maxIterations = 1)
assert(model.clusterCenters.size === 1)
var model = KMeans.train(data, k = 2, maxIterations = 1, initializationMode = "random")
assert(model.clusterCenters.length === 1)

model = KMeans.train(data, k = 2, maxIterations = 1, initializationMode = "k-means||")
assert(model.clusterCenters.length === 1)
}


test("fewer clusters than points") {
val data = sc.parallelize(
Array(
Vectors.dense(1.0, 2.0, 3.0),
Vectors.dense(1.0, 3.0, 4.0)),
2)

var model = KMeans.train(data, k = 1, maxIterations = 1, initializationMode = "random")
assert(model.clusterCenters.length === 1)

model = KMeans.train(data, k = 1, maxIterations = 1, initializationMode = "k-means||")
assert(model.clusterCenters.length === 1)
}

test("more clusters than points") {
Expand All @@ -85,9 +101,11 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
Vectors.dense(1.0, 3.0, 4.0)),
2)

// Make sure code runs.
var model = KMeans.train(data, k = 3, maxIterations = 1)
assert(model.clusterCenters.size === 2)
var model = KMeans.train(data, k = 3, maxIterations = 1, initializationMode = "random")
assert(model.clusterCenters.length === 2)

model = KMeans.train(data, k = 3, maxIterations = 1, initializationMode = "k-means||")
assert(model.clusterCenters.length === 2)
}

test("deterministic initialization") {
Expand Down

0 comments on commit ebebcb9

Please sign in to comment.