Commit

Improve initNextCenters
yu-iskw committed Oct 29, 2015
1 parent 5da05d3 commit a50689a
Showing 2 changed files with 26 additions and 43 deletions.
@@ -17,8 +17,9 @@

package org.apache.spark.mllib.clustering

import breeze.linalg.{Vector => BV, norm => breezeNorm}
import breeze.linalg.{Vector => BV, SparseVector => BSV, norm => breezeNorm}

import org.apache.spark.util.random.XORShiftRandom
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.annotation.Since
import org.apache.spark.mllib.linalg.{Vector, Vectors}
@@ -138,7 +139,7 @@ class BisectingKMeans private (

// can be clustered if the number of divided clusterStats is equal to 0
// TODO Remove non-leaf cluster stats from `leafClusterStats`
val dividedData = divideClusters(data, dividableLeafClusters, maxIterations).cache()
val dividedData = divideClusters(data, dividableLeafClusters, maxIterations, seed).cache()
leafClusterStats = summarizeClusters(dividedData)
dividableLeafClusters = leafClusterStats.filter(_._2.isDividable)
clusterStats = clusterStats ++ leafClusterStats
@@ -248,47 +249,24 @@ private[clustering] object BisectingKMeans {
/**
* Gets the initial centers for bisecting k-means
*
* @param data pairs of point and its cluster index
* @param stats pairs of cluster index and cluster statistics
* @param seed random seed
*/
def initNextCenters(
data: RDD[(Long, BV[Double])],
stats: collection.Map[Long, BisectingClusterStat]): collection.Map[Long, BV[Double]] = {

// Since combining sampleByKey and groupByKey would be more expensive,
// the following approach is preferable.
val bcIndeces = data.sparkContext.broadcast(stats.keySet)
val samples = data.mapPartitions { iter =>
val map = collection.mutable.Map.empty[Long, collection.mutable.ArrayBuffer[BV[Double]]]

bcIndeces.value.foreach {i => map(i) = collection.mutable.ArrayBuffer.empty[BV[Double]]}
val LOCAL_SAMPLE_SIZE = 100
iter.foreach { case (i, point) =>
map(i).append(point)
// to avoid increasing the memory usage on each map thread,
// the number of elements is capped once it exceeds the local sample size.
if (map(i).size > LOCAL_SAMPLE_SIZE) {
val elements = map(i).sortWith((a, b) => breezeNorm(a, 2.0) < breezeNorm(b, 2.0))
map(i) = collection.mutable.ArrayBuffer(elements.head, elements.last)
}
}
stats: collection.Map[Long, BisectingClusterStat],
seed: Long
): collection.Map[Long, BV[Double]] = {

// in order to reduce the shuffle size, take only two elements
map.filterNot(_._2.isEmpty).map { case (i, points) =>
val elements = map(i).toSeq.sortWith((a, b) => breezeNorm(a, 2.0) < breezeNorm(b, 2.0))
i -> collection.mutable.ArrayBuffer(elements.head, elements.last)
}.toIterator
}.reduceByKey { case (points1, points2) =>
points1.union(points2)
}.collect()

val nextCenters = samples.flatMap { case (i, points) =>
val elements = points.toSeq.sortWith((a, b) => breezeNorm(a, 2.0) < breezeNorm(b, 2.0))
Array((2 * i, elements.head), (2 * i + 1, elements.last))
val random = new XORShiftRandom()
random.setSeed(seed)
val nextCenters = stats.flatMap { case (idx, clusterStats) =>
val center = clusterStats.mean
val stdev = math.sqrt(clusterStats.sumOfSquares) / clusterStats.rows
val activeKeys = clusterStats.mean.activeKeysIterator.toArray
val activeValues = activeKeys.map(i => random.nextDouble() * stdev)
val perturbation = new BSV[Double](activeKeys, activeValues, clusterStats.mean.size)
Array((2 * idx, center - perturbation), (2 * idx + 1, center + perturbation))
}.toMap
if (!stats.keySet.flatMap(idx => Array(2 * idx, 2 * idx + 1)).forall(nextCenters.contains(_))) {
throw new SparkException("Failed to initialize centers for next step")
}
nextCenters
}
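For context on the hunk above: the rewritten initNextCenters no longer samples points from the RDD. Each dividable cluster's two child centers are derived directly from its summary as mean - perturbation and mean + perturbation, where the perturbation is random noise over the mean's active indices scaled by sqrt(sumOfSquares) / rows and drawn from an RNG seeded with the caller-supplied seed. A minimal, self-contained sketch of that idea (not part of the patch; it uses plain scala.util.Random in place of Spark's internal XORShiftRandom, and splitCenter is an illustrative name):

import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
import scala.util.Random

// Derive two candidate child centers from one cluster summary.
def splitCenter(
    mean: BV[Double],
    sumOfSquares: Double,
    rows: Long,
    seed: Long): (BV[Double], BV[Double]) = {
  val random = new Random(seed)
  // scale the random offset by the cluster's average deviation
  val stdev = math.sqrt(sumOfSquares) / rows
  // perturb only the active (non-zero) indices so sparse means stay sparse
  val activeKeys = mean.activeKeysIterator.toArray
  val activeValues = activeKeys.map(_ => random.nextDouble() * stdev)
  val perturbation = new BSV[Double](activeKeys, activeValues, mean.size)
  (mean - perturbation, mean + perturbation)
}

// Example with made-up numbers: the same seed always yields the same pair.
val (left, right) = splitCenter(BDV(1.0, 2.0, 3.0), sumOfSquares = 4.5, rows = 3L, seed = 1L)
println(s"left: $left, right: $right")

Compared with the old sampling approach, this avoids an extra pass and shuffle over the data, since the centers are computed on the driver from the cluster statistics alone, and it makes the split reproducible for a given seed.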

@@ -298,11 +276,15 @@ private[clustering] object BisectingKMeans {
* @param data pairs of point and its cluster index
* @param clusterStats target clusters to divide
* @param maxIterations the maximum iterations to calculate clusters statistics
* @param seed random seed
*/
def divideClusters(
data: RDD[(Long, BV[Double])],
clusterStats: collection.Map[Long, BisectingClusterStat],
maxIterations: Int): RDD[(Long, BV[Double])] = {
maxIterations: Int,
seed: Long
): RDD[(Long, BV[Double])] = {

val sc = data.sparkContext
val appName = sc.appName

@@ -315,7 +297,7 @@ private[clustering] object BisectingKMeans {
// extract dividable input data
val dividableData = data.filter { case (idx, point) => dividableClusterStats.contains(idx)}
// get next initial centers
var newCenters = initNextCenters(dividableData, dividableClusterStats)
var newCenters = initNextCenters(dividableClusterStats, seed)
var nextData = data
var subIter = 0
var totalSumOfSquares = Double.MaxValue
@@ -596,6 +578,7 @@ private[clustering] case class BisectingClusterStat (
rows: Long,
mean: BV[Double],
sumOfSquares: Double) extends Serializable {
require(sumOfSquares >= 0.0)

def isDividable: Boolean = sumOfSquares > 0 && rows >= 2
}
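A side note on the guard added above: require(sumOfSquares >= 0.0) rejects an impossible negative within-cluster error, and isDividable is what decides which leaves get split, requiring at least two rows and non-zero error. A small illustration with made-up values (constructor arguments follow the case class definition above):

// Illustrative values only; BisectingClusterStat(rows, mean, sumOfSquares) as defined above.
val identicalPoints = new BisectingClusterStat(2L, BV[Double](1.0, 1.0), 0.0) // zero error -> not dividable
val singlePoint = new BisectingClusterStat(1L, BV[Double](5.0, 5.0), 0.0)     // one row -> not dividable
val spreadCluster = new BisectingClusterStat(3L, BV[Double](2.0, 2.0), 6.0)   // dividable
assert(!identicalPoints.isDividable && !singlePoint.isDividable && spreadCluster.isDividable)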
@@ -116,7 +116,7 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
2L -> new BisectingClusterStat(2, BV[Double](1.0, 1.0) * 2.0, 0.0),
3L -> new BisectingClusterStat(2, BV[Double](2.0, 2.0) * 2.0, 0.0)
)
val initNextCenters = BisectingKMeans.initNextCenters(data, stats)
val initNextCenters = BisectingKMeans.initNextCenters(stats, 1)
assert(initNextCenters.size === 4)
assert(initNextCenters.keySet === Set(4, 5, 6, 7))
}
@@ -140,7 +140,7 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
val data = sc.parallelize(seed, 1)
val leafClusterStats = BisectingKMeans.summarizeClusters(data)
val dividableLeafClusters = leafClusterStats.filter(_._2.isDividable)
val result = BisectingKMeans.divideClusters(data, dividableLeafClusters, 20).collect()
val result = BisectingKMeans.divideClusters(data, dividableLeafClusters, 20, 1).collect()

val expected = Seq(
(4, Vectors.dense(0.0, 0.0)), (4, Vectors.dense(1.0, 1.0)), (4, Vectors.dense(2.0, 2.0)),
@@ -185,7 +185,7 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
)
val data = sc.parallelize(local, 1)
val stats = BisectingKMeans.summarizeClusters(data)
val dividedData = BisectingKMeans.divideClusters(data, stats, 20).collect()
val dividedData = BisectingKMeans.divideClusters(data, stats, 20, 1).collect()

assert(dividedData(0) == (4L, BV[Double](0.9, 0.9)))
assert(dividedData(1) == (4L, BV[Double](1.1, 1.1)))
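One property these assertions implicitly rely on: because the seed is now threaded through divideClusters into initNextCenters, a fixed seed with the same inputs should make the division reproducible, so checks on exact centers stay stable. A hypothetical extra check along those lines (not in the suite), reusing the data and stats values defined in this test:

// Same seed and inputs should yield identical cluster assignments and points.
val first = BisectingKMeans.divideClusters(data, stats, 20, 1).collect()
val second = BisectingKMeans.divideClusters(data, stats, 20, 1).collect()
assert(first.sameElements(second))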