Skip to content

Commit

Permalink
Don't create duplicate cluster centers. This means < k clusters may be returned when there are < k inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
srowen committed Sep 3, 2016
1 parent 0a51e1e commit e7f12fa
Showing 1 changed file with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ class KMeans private (
def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong())

/**
- * Number of clusters to create (k).
+ * Number of clusters to create (k). Note that if the input has fewer than k elements,
+ * then it's possible that fewer than k clusters are created.
*/
@Since("1.4.0")
def getK: Int = k
Expand Down Expand Up @@ -135,7 +136,7 @@ class KMeans private (

/**
* Set the number of steps for the k-means|| initialization mode. This is an advanced
- * setting -- the default of 5 is almost always enough. Default: 5.
+ * setting -- the default of 2 is almost always enough. Default: 2.
*/
@Since("0.8.0")
def setInitializationSteps(initializationSteps: Int): this.type = {
Expand Down Expand Up @@ -269,11 +270,10 @@ class KMeans private (
// Find the sum and count of points mapping to each center
val totalContribs = data.mapPartitions { points =>
val thisCenters = bcCenters.value
- val k = thisCenters.length
val dims = thisCenters.head.vector.size

- val sums = Array.fill(k)(Vectors.zeros(dims))
- val counts = Array.fill(k)(0L)
+ val sums = Array.fill(thisCenters.length)(Vectors.zeros(dims))
+ val counts = Array.fill(thisCenters.length)(0L)

points.foreach { point =>
val (bestCenter, cost) = KMeans.findClosest(thisCenters, point)
Expand Down Expand Up @@ -324,7 +324,7 @@ class KMeans private (
* Initialize set of cluster centers at random.
*/
private def initRandom(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = {
- val sample = data.takeSample(true, k, new XORShiftRandom(this.seed).nextInt())
+ val sample = data.takeSample(false, k, new XORShiftRandom(this.seed).nextInt())
sample.map(v => new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm))
}

Expand Down

0 comments on commit e7f12fa

Please sign in to comment.