diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
index 9354a53884b33..1601d6c84e217 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
@@ -17,8 +17,9 @@
 
 package org.apache.spark.mllib.clustering
 
-import breeze.linalg.{Vector => BV, norm => breezeNorm}
+import breeze.linalg.{SparseVector => BSV, Vector => BV, norm => breezeNorm}
 
+import org.apache.spark.util.random.XORShiftRandom
 import org.apache.spark.{Logging, SparkException}
 import org.apache.spark.annotation.Since
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
@@ -138,7 +139,7 @@ class BisectingKMeans private (
 
       // can be clustered if the number of divided clusterStats is equal to 0
       // TODO Remove non-leaf cluster stats from `leafClusterStats`
-      val dividedData = divideClusters(data, dividableLeafClusters, maxIterations).cache()
+      val dividedData = divideClusters(data, dividableLeafClusters, maxIterations, seed).cache()
       leafClusterStats = summarizeClusters(dividedData)
       dividableLeafClusters = leafClusterStats.filter(_._2.isDividable)
       clusterStats = clusterStats ++ leafClusterStats
@@ -248,47 +249,24 @@ private[clustering] object BisectingKMeans {
   /**
    * Gets the initial centers for bisecting k-means
    *
-   * @param data pairs of point and its cluster index
    * @param stats pairs of cluster index and cluster statistics
+   * @param seed random seed
    */
   def initNextCenters(
-      data: RDD[(Long, BV[Double])],
-      stats: collection.Map[Long, BisectingClusterStat]): collection.Map[Long, BV[Double]] = {
-
-    // Since the combination sampleByKey and groupByKey is more expensive,
-    // this as follows would be better.
-    val bcIndeces = data.sparkContext.broadcast(stats.keySet)
-    val samples = data.mapPartitions { iter =>
-      val map = collection.mutable.Map.empty[Long, collection.mutable.ArrayBuffer[BV[Double]]]
-
-      bcIndeces.value.foreach {i => map(i) = collection.mutable.ArrayBuffer.empty[BV[Double]]}
-      val LOCAL_SAMPLE_SIZE = 100
-      iter.foreach { case (i, point) =>
-        map(i).append(point)
-        // to avoid to increase the memory usage on each map thread,
-        // the number of elements is cut off at the right time.
-        if (map(i).size > LOCAL_SAMPLE_SIZE) {
-          val elements = map(i).sortWith((a, b) => breezeNorm(a, 2.0) < breezeNorm(b, 2.0))
-          map(i) = collection.mutable.ArrayBuffer(elements.head, elements.last)
-        }
-      }
+      stats: collection.Map[Long, BisectingClusterStat],
+      seed: Long
+    ): collection.Map[Long, BV[Double]] = {
 
-      // in order to reduce the shuffle size, take only two elements
-      map.filterNot(_._2.isEmpty).map { case (i, points) =>
-        val elements = map(i).toSeq.sortWith((a, b) => breezeNorm(a, 2.0) < breezeNorm(b, 2.0))
-        i -> collection.mutable.ArrayBuffer(elements.head, elements.last)
-      }.toIterator
-    }.reduceByKey { case (points1, points2) =>
-      points1.union(points2)
-    }.collect()
-
-    val nextCenters = samples.flatMap { case (i, points) =>
-      val elements = points.toSeq.sortWith((a, b) => breezeNorm(a, 2.0) < breezeNorm(b, 2.0))
-      Array((2 * i, elements.head), (2 * i + 1, elements.last))
+    val random = new XORShiftRandom()
+    random.setSeed(seed)
+    val nextCenters = stats.flatMap { case (idx, clusterStats) =>
+      val center = clusterStats.mean
+      val stdev = math.sqrt(clusterStats.sumOfSquares) / clusterStats.rows
+      val activeKeys = center.activeKeysIterator.toArray
+      val activeValues = activeKeys.map(_ => random.nextDouble() * stdev)
+      val perturbation = new BSV[Double](activeKeys, activeValues, center.size)
+      Array((2 * idx, center - perturbation), (2 * idx + 1, center + perturbation))
     }.toMap
 
-    if (!stats.keySet.flatMap(idx => Array(2 * idx, 2 * idx + 1)).forall(nextCenters.contains(_))) {
-      throw new SparkException("Failed to initialize centers for next step")
-    }
     nextCenters
   }
@@ -298,11 +276,15 @@ private[clustering] object BisectingKMeans {
    * @param data pairs of point and its cluster index
    * @param clusterStats target clusters to divide
    * @param maxIterations the maximum iterations to calculate clusters statistics
+   * @param seed random seed
    */
   def divideClusters(
       data: RDD[(Long, BV[Double])],
       clusterStats: collection.Map[Long, BisectingClusterStat],
-      maxIterations: Int): RDD[(Long, BV[Double])] = {
+      maxIterations: Int,
+      seed: Long
+    ): RDD[(Long, BV[Double])] = {
+
     val sc = data.sparkContext
     val appName = sc.appName
 
@@ -315,7 +297,7 @@ private[clustering] object BisectingKMeans {
     // extract dividable input data
     val dividableData = data.filter { case (idx, point) => dividableClusterStats.contains(idx)}
     // get next initial centers
-    var newCenters = initNextCenters(dividableData, dividableClusterStats)
+    var newCenters = initNextCenters(dividableClusterStats, seed)
     var nextData = data
     var subIter = 0
     var totalSumOfSquares = Double.MaxValue
@@ -596,6 +578,7 @@ private[clustering] case class BisectingClusterStat (
     rows: Long,
     mean: BV[Double],
     sumOfSquares: Double) extends Serializable {
+  require(sumOfSquares >= 0.0, s"sumOfSquares must be non-negative: $sumOfSquares")
 
   def isDividable: Boolean = sumOfSquares > 0 && rows >= 2
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala
index f0947e336cf7c..74e12d00c2022 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala
@@ -116,7 +116,7 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
       2L -> new BisectingClusterStat(2, BV[Double](1.0, 1.0) * 2.0, 0.0),
       3L -> new BisectingClusterStat(2, BV[Double](2.0, 2.0) * 2.0, 0.0)
     )
-    val initNextCenters = BisectingKMeans.initNextCenters(data, stats)
+    val initNextCenters = BisectingKMeans.initNextCenters(stats, 1)
     assert(initNextCenters.size === 4)
     assert(initNextCenters.keySet === Set(4, 5, 6, 7))
   }
@@ -140,7 +140,7 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
     val data = sc.parallelize(seed, 1)
     val leafClusterStats = BisectingKMeans.summarizeClusters(data)
     val dividableLeafClusters = leafClusterStats.filter(_._2.isDividable)
-    val result = BisectingKMeans.divideClusters(data, dividableLeafClusters, 20).collect()
+    val result = BisectingKMeans.divideClusters(data, dividableLeafClusters, 20, 1).collect()
     val expected = Seq(
       (4, Vectors.dense(0.0, 0.0)), (4, Vectors.dense(1.0, 1.0)), (4, Vectors.dense(2.0, 2.0)),
@@ -185,7 +185,7 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
     )
     val data = sc.parallelize(local, 1)
     val stats = BisectingKMeans.summarizeClusters(data)
-    val dividedData = BisectingKMeans.divideClusters(data, stats, 20).collect()
+    val dividedData = BisectingKMeans.divideClusters(data, stats, 20, 1).collect()
     assert(dividedData(0) == (4L, BV[Double](0.9, 0.9)))
     assert(dividedData(1) == (4L, BV[Double](1.1, 1.1)))
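
Note for reviewers: the core change replaces the sampling-based seeding in `initNextCenters` (which ran a `mapPartitions`/`reduceByKey` job over the data) with a local, seeded perturbation of each dividable cluster's mean: child `2i` gets `mean - perturbation` and child `2i + 1` gets `mean + perturbation`, with the perturbation scaled by `sqrt(sumOfSquares) / rows`. Since every parent now deterministically produces both children, the old `SparkException` guard on missing child centers is no longer needed, which is why the patch drops it. Below is a minimal, Spark-free sketch of that scheme for reference; `ClusterStat`, plain `Array[Double]`, and `scala.util.Random` are simplified stand-ins for `BisectingClusterStat`, Breeze vectors, and `XORShiftRandom`, not the patched API itself.

```scala
import scala.util.Random

// Standalone sketch of the perturbation-based seeding introduced above.
object InitNextCentersSketch {
  // Simplified stand-in for BisectingClusterStat: row count, mean vector,
  // and accumulated sum of squared deviations.
  final case class ClusterStat(rows: Long, mean: Array[Double], sumOfSquares: Double)

  def initNextCenters(
      stats: Map[Long, ClusterStat],
      seed: Long): Map[Long, Array[Double]] = {
    val random = new Random(seed)
    stats.flatMap { case (idx, stat) =>
      // Same heuristic scale as the patch: sqrt(sum of squares) / row count.
      val stdev = math.sqrt(stat.sumOfSquares) / stat.rows
      val perturbation = stat.mean.map(_ => random.nextDouble() * stdev)
      val left = stat.mean.zip(perturbation).map { case (m, p) => m - p }
      val right = stat.mean.zip(perturbation).map { case (m, p) => m + p }
      // Child indices follow the binary-tree numbering: 2i and 2i + 1.
      Map(2 * idx -> left, (2 * idx + 1) -> right)
    }
  }

  def main(args: Array[String]): Unit = {
    val stats = Map(1L -> ClusterStat(rows = 4, mean = Array(1.0, 2.0), sumOfSquares = 8.0))
    // Cluster 1 yields children 2 and 3, symmetric about the parent mean.
    initNextCenters(stats, seed = 1L).toSeq.sortBy(_._1).foreach { case (idx, c) =>
      println(s"$idx -> ${c.mkString("[", ", ", "]")}")
    }
  }
}
```

One design consequence worth noting: because the centers are now derived only from the per-cluster statistics and the seed, seeding no longer touches the RDD at all, and passing the same `seed` through `run` -> `divideClusters` -> `initNextCenters` makes the whole bisection reproducible, which is what the updated tests rely on when they assert exact child assignments.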