[SPARK-3424][MLLIB] cache point distances during k-means|| init
This PR ports the following feature implemented in #2634 by derrickburns:

* During k-means|| initialization, we should cache costs (squared distances) previously computed, so that each step only computes distances to the centers added in the previous step (a minimal sketch follows).
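
Below is a single-machine sketch of that caching idea, not the MLlib code itself: plain Scala arrays stand in for RDDs, and `squaredDistance` is a hypothetical stand-in for MLlib's norm-accelerated `KMeans.pointCost`.

```scala
// Each point keeps its best (lowest) cost seen so far, so a new step only
// measures distances to the centers added in the previous step.
object CostCacheSketch {
  type Point = Array[Double]

  // Stand-in for MLlib's fast distance computation.
  def squaredDistance(a: Point, b: Point): Double =
    a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum

  /** Update cached costs against only the newly added centers. */
  def updateCosts(points: Array[Point], cachedCosts: Array[Double],
      newCenters: Seq[Point]): Array[Double] =
    points.zip(cachedCosts).map { case (p, c) =>
      (c +: newCenters.map(squaredDistance(p, _))).min
    }

  def main(args: Array[String]): Unit = {
    val points = Array(Array(0.0, 0.0), Array(1.0, 1.0), Array(4.0, 4.0))
    var costs = Array.fill(points.length)(Double.PositiveInfinity)
    costs = updateCosts(points, costs, Seq(Array(0.0, 0.0))) // step 0
    costs = updateCosts(points, costs, Seq(Array(4.0, 4.0))) // step 1: only the new center is measured
    println(costs.mkString(", ")) // 0.0, 2.0, 0.0
  }
}
```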

It also contains the following optimizations (both sketched below):

* aggregate sumCosts directly
* run multiple (#runs) local k-means++ in parallel
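
Here is a hedged sketch of the first optimization; a plain Scala fold stands in for `RDD.aggregate`, and the in-place while loop plays the role of the BLAS-style `axpy` call used in the patch:

```scala
// Each point carries one cost per run; we sum those vectors element-wise in
// a single pass instead of emitting (run, cost) pairs and reducing by key.
val runs = 3
val pointCosts: Seq[Array[Double]] = Seq(
  Array(1.0, 2.0, 3.0),
  Array(0.5, 0.5, 0.5),
  Array(2.0, 0.0, 1.0))
val sumCosts = pointCosts.foldLeft(new Array[Double](runs)) { (s, v) =>
  var i = 0
  while (i < runs) { s(i) += v(i); i += 1 } // s += v, like axpy(1.0, v, s)
  s
}
println(sumCosts.mkString(", ")) // 3.5, 2.5, 4.5
```

The second optimization amounts to the one-line change `(0 until runs).par.map { r => ... }` in the final stage, which runs the per-run `LocalKMeans.kMeansPlusPlus` calls on Scala parallel collections.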

I compared the performance locally on mnist-digit. Before this patch:

![before](https://cloud.githubusercontent.com/assets/829644/5845647/93080862-a172-11e4-9a35-044ec711afc4.png)

with this patch:

![after](https://cloud.githubusercontent.com/assets/829644/5845653/a47c29e8-a172-11e4-8e9f-08db57fe3502.png)

It is clear that with this patch each k-means|| step takes about the same amount of time; previously, each step re-computed distances to all centers accumulated so far, so later steps grew progressively more expensive.

Authors:
  Derrick Burns <derrickburns@gmail.com>
  Xiangrui Meng <meng@databricks.com>

Closes #4144 from mengxr/SPARK-3424-kmeans-parallel and squashes the following commits:

0a875ec [Xiangrui Meng] address comments
4341bb8 [Xiangrui Meng] do not re-compute point distances during k-means||
mengxr committed Jan 22, 2015
1 parent 27bccc5 commit ca7910d
Showing 1 changed file with 50 additions and 15 deletions: mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
```diff
@@ -279,45 +279,80 @@ class KMeans private (
    */
   private def initKMeansParallel(data: RDD[VectorWithNorm])
   : Array[Array[VectorWithNorm]] = {
-    // Initialize each run's center to a random point
+    // Initialize empty centers and point costs.
+    val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm])
+    var costs = data.map(_ => Vectors.dense(Array.fill(runs)(Double.PositiveInfinity))).cache()
+
+    // Initialize each run's first center to a random point.
     val seed = new XORShiftRandom(this.seed).nextInt()
     val sample = data.takeSample(true, runs, seed).toSeq
-    val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
+    val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
+
+    /** Merges new centers to centers. */
+    def mergeNewCenters(): Unit = {
+      var r = 0
+      while (r < runs) {
+        centers(r) ++= newCenters(r)
+        newCenters(r).clear()
+        r += 1
+      }
+    }
 
     // On each step, sample 2 * k points on average for each run with probability proportional
-    // to their squared distance from that run's current centers
+    // to their squared distance from that run's centers. Note that only distances between points
+    // and new centers are computed in each iteration.
     var step = 0
     while (step < initializationSteps) {
-      val bcCenters = data.context.broadcast(centers)
-      val sumCosts = data.flatMap { point =>
-        (0 until runs).map { r =>
-          (r, KMeans.pointCost(bcCenters.value(r), point))
-        }
-      }.reduceByKey(_ + _).collectAsMap()
-      val chosen = data.mapPartitionsWithIndex { (index, points) =>
+      val bcNewCenters = data.context.broadcast(newCenters)
+      val preCosts = costs
+      costs = data.zip(preCosts).map { case (point, cost) =>
+          Vectors.dense(
+            Array.tabulate(runs) { r =>
+              math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r))
+            })
+        }.cache()
+      val sumCosts = costs
+        .aggregate(Vectors.zeros(runs))(
+          seqOp = (s, v) => {
+            // s += v
+            axpy(1.0, v, s)
+            s
+          },
+          combOp = (s0, s1) => {
+            // s0 += s1
+            axpy(1.0, s1, s0)
+            s0
+          }
+        )
+      preCosts.unpersist(blocking = false)
+      val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) =>
         val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
-        points.flatMap { p =>
+        pointsWithCosts.flatMap { case (p, c) =>
           (0 until runs).filter { r =>
-            rand.nextDouble() < 2.0 * KMeans.pointCost(bcCenters.value(r), p) * k / sumCosts(r)
+            rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r)
           }.map((_, p))
         }
       }.collect()
+      mergeNewCenters()
       chosen.foreach { case (r, p) =>
-        centers(r) += p.toDense
+        newCenters(r) += p.toDense
       }
       step += 1
     }
 
+    mergeNewCenters()
+    costs.unpersist(blocking = false)
+
     // Finally, we might have a set of more than k candidate centers for each run; weigh each
     // candidate by the number of points in the dataset mapping to it and run a local k-means++
     // on the weighted centers to pick just k of them
     val bcCenters = data.context.broadcast(centers)
     val weightMap = data.flatMap { p =>
-      (0 until runs).map { r =>
+      Iterator.tabulate(runs) { r =>
         ((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0)
       }
     }.reduceByKey(_ + _).collectAsMap()
-    val finalCenters = (0 until runs).map { r =>
+    val finalCenters = (0 until runs).par.map { r =>
       val myCenters = centers(r).toArray
       val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray
       LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30)
```
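
For context on the sampling line `rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r)`: a point survives with probability proportional to its cached cost, scaled so that about `2 * k` points are expected per run per step. A hedged, single-run sketch in plain Scala (illustrative values only):

```scala
// A point with cached cost c survives with probability 2.0 * c * k / sumCost,
// so the expected number of samples per step is sum_p 2k * c(p) / sumCost = 2k
// (probabilities above 1 simply make that point certain to be chosen).
import scala.util.Random

val k = 2
val rand = new Random(42)
val costs = Seq(0.0, 2.0, 8.0, 10.0) // cached squared distances to centers
val sumCost = costs.sum              // 20.0
val chosenIdx = costs.zipWithIndex.collect {
  case (c, i) if rand.nextDouble() < 2.0 * c * k / sumCost => i
}
println(chosenIdx) // cost-0 points are never chosen; heavy points almost always are
```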
