From e7f12fa3e1d3273f558f90455c6c5be8e6a9c8f6 Mon Sep 17 00:00:00 2001
From: Sean Owen
Date: Sat, 3 Sep 2016 15:11:37 +0100
Subject: [PATCH] Don't create duplicate cluster centers. This means < k
 clusters may be returned when there are < k inputs

---
 .../org/apache/spark/mllib/clustering/KMeans.scala | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 711d904ca2881..888c4ab61a833 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -56,7 +56,8 @@ class KMeans private (
   def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong())
 
   /**
-   * Number of clusters to create (k).
+   * Number of clusters to create (k). Note that if the input has fewer than k elements,
+   * then it's possible that fewer than k clusters are created.
    */
   @Since("1.4.0")
   def getK: Int = k
@@ -135,7 +136,7 @@ class KMeans private (
 
   /**
    * Set the number of steps for the k-means|| initialization mode. This is an advanced
-   * setting -- the default of 5 is almost always enough. Default: 5.
+   * setting -- the default of 2 is almost always enough. Default: 2.
    */
   @Since("0.8.0")
   def setInitializationSteps(initializationSteps: Int): this.type = {
@@ -269,11 +270,10 @@ class KMeans private (
       // Find the sum and count of points mapping to each center
       val totalContribs = data.mapPartitions { points =>
         val thisCenters = bcCenters.value
-        val k = thisCenters.length
         val dims = thisCenters.head.vector.size
 
-        val sums = Array.fill(k)(Vectors.zeros(dims))
-        val counts = Array.fill(k)(0L)
+        val sums = Array.fill(thisCenters.length)(Vectors.zeros(dims))
+        val counts = Array.fill(thisCenters.length)(0L)
 
         points.foreach { point =>
           val (bestCenter, cost) = KMeans.findClosest(thisCenters, point)
@@ -324,7 +324,7 @@ class KMeans private (
    * Initialize set of cluster centers at random.
    */
   private def initRandom(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = {
-    val sample = data.takeSample(true, k, new XORShiftRandom(this.seed).nextInt())
+    val sample = data.takeSample(false, k, new XORShiftRandom(this.seed).nextInt())
     sample.map(v => new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm))
  }
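
For context (not part of the patch): below is a minimal sketch of the user-visible
effect of this change, using the public mllib KMeans API. The object name, the toy
data set, k = 5, and the seed are invented for illustration; a local SparkContext is
assumed. Random initialization now draws its sample with takeSample(false, ...), i.e.
without replacement, so when the input has fewer than k points the returned model can
contain fewer than k centers instead of padding out to k with duplicates.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.clustering.KMeans
    import org.apache.spark.mllib.linalg.Vectors

    object FewerThanKExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setMaster("local[2]").setAppName("kmeans-dup-centers"))

        // Only 3 input points, but we ask for k = 5 clusters.
        val data = sc.parallelize(Seq(
          Vectors.dense(0.0, 0.0),
          Vectors.dense(1.0, 1.0),
          Vectors.dense(9.0, 9.0)
        ))

        // Before this patch, random init sampled with replacement, so the model
        // always reported k centers, some of them duplicates. After it, sampling
        // is without replacement, so at most 3 distinct centers come back here.
        val model = new KMeans()
          .setK(5)
          .setInitializationMode(KMeans.RANDOM)
          .setSeed(42L)
          .run(data)

        println(s"requested k = 5, actual clusters = ${model.clusterCenters.length}")
        sc.stop()
      }
    }

Callers that assumed model.clusterCenters.length == getK would need to read the actual
length, which is what the updated getK scaladoc above warns about.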