Deprecate KMeans 'runs' param, which has been a no-op since 2.0; optimize/simplify code now that runs = 1 always. Most significantly, choose initializationSteps = 2 as the default for k-means|| (the default init mode), because the paper implies 2 steps are about as good as more and are much faster.
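For illustration only, here is a minimal caller-side sketch of what this change means in practice, using the public MLlib API; the SparkContext named sc and the toy data are assumptions for the example, not part of the commit:

import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

// Assumes an existing SparkContext named sc; the data is purely illustrative.
val data = sc.parallelize(Seq(
  Vectors.dense(0.0, 0.0), Vectors.dense(0.1, 0.1),
  Vectors.dense(9.0, 9.0), Vectors.dense(9.1, 9.1)))

val kmeans = new KMeans()
  .setK(2)
  .setMaxIterations(20)
  .setInitializationMode(KMeans.K_MEANS_PARALLEL)
  // setRuns is now a deprecated no-op: exactly one run is executed regardless
  // of the value passed, and calling it only logs a warning.
  .setRuns(5)
  // initializationSteps for k-means|| now defaults to 2 (previously 5);
  // setting it explicitly is only needed to override that default.
  .setInitializationSteps(2)

val model = kmeans.run(data)

Omitting both setters gives the new behavior by default: a single run with the faster 2-step k-means|| initialization.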
srowen committed Sep 3, 2016
1 parent f2d6e2e commit dbed7d7
Showing 1 changed file with 74 additions and 157 deletions:
mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -43,18 +43,17 @@ import org.apache.spark.util.random.XORShiftRandom
class KMeans private (
private var k: Int,
private var maxIterations: Int,
private var runs: Int,
private var initializationMode: String,
private var initializationSteps: Int,
private var epsilon: Double,
private var seed: Long) extends Serializable with Logging {

/**
* Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1,
* initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4, seed: random}.
* Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20,
* initializationMode: "k-means||", initializationSteps: 2, epsilon: 1e-4, seed: random}.
*/
@Since("0.8.0")
def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4, Utils.random.nextLong())
def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong())

/**
* Number of clusters to create (k).
@@ -112,15 +111,17 @@ class KMeans private (
* This function has no effect since Spark 2.0.0.
*/
@Since("1.4.0")
@deprecated("This has no effect and always returns 1", "2.1.0")
def getRuns: Int = {
logWarning("Getting number of runs has no effect since Spark 2.0.0.")
runs
1
}

/**
* This function has no effect since Spark 2.0.0.
*/
@Since("0.8.0")
@deprecated("This has no effect", "2.1.0")
def setRuns(runs: Int): this.type = {
logWarning("Setting number of runs has no effect since Spark 2.0.0.")
this
@@ -239,17 +240,9 @@ class KMeans private (

val initStartTime = System.nanoTime()

// Only one run is allowed when initialModel is given
val numRuns = if (initialModel.nonEmpty) {
if (runs > 1) logWarning("Ignoring runs; one run is allowed when initialModel is given.")
1
} else {
runs
}

val centers = initialModel match {
case Some(kMeansCenters) =>
Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s)))
kMeansCenters.clusterCenters.map(new VectorWithNorm(_))
case None =>
if (initializationMode == KMeans.RANDOM) {
initRandom(data)
@@ -258,149 +251,108 @@ class KMeans private (
}
}
val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) +
" seconds.")
logInfo(f"Initialization with $initializationMode took $initTimeInSeconds%.3f seconds.")

val active = Array.fill(numRuns)(true)
val costs = Array.fill(numRuns)(0.0)

var activeRuns = new ArrayBuffer[Int] ++ (0 until numRuns)
var active = true
var cost = 0.0
var iteration = 0

val iterationStartTime = System.nanoTime()

instr.foreach(_.logNumFeatures(centers(0)(0).vector.size))

// Execute iterations of Lloyd's algorithm until all runs have converged
while (iteration < maxIterations && !activeRuns.isEmpty) {
type WeightedPoint = (Vector, Long)
def mergeContribs(x: WeightedPoint, y: WeightedPoint): WeightedPoint = {
axpy(1.0, x._1, y._1)
(y._1, x._2 + y._2)
}

val activeCenters = activeRuns.map(r => centers(r)).toArray
val costAccums = activeRuns.map(_ => sc.doubleAccumulator)
instr.foreach(_.logNumFeatures(centers.head.vector.size))

val bcActiveCenters = sc.broadcast(activeCenters)
// Execute iterations of Lloyd's algorithm until converged
while (iteration < maxIterations && active) {
val costAccum = sc.doubleAccumulator
val bcCenters = sc.broadcast(centers)

// Find the sum and count of points mapping to each center
val totalContribs = data.mapPartitions { points =>
val thisActiveCenters = bcActiveCenters.value
val runs = thisActiveCenters.length
val k = thisActiveCenters(0).length
val dims = thisActiveCenters(0)(0).vector.size
val thisCenters = bcCenters.value
val k = thisCenters.length
val dims = thisCenters.head.vector.size

val sums = Array.fill(runs, k)(Vectors.zeros(dims))
val counts = Array.fill(runs, k)(0L)
val sums = Array.fill(k)(Vectors.zeros(dims))
val counts = Array.fill(k)(0L)

points.foreach { point =>
(0 until runs).foreach { i =>
val (bestCenter, cost) = KMeans.findClosest(thisActiveCenters(i), point)
costAccums(i).add(cost)
val sum = sums(i)(bestCenter)
axpy(1.0, point.vector, sum)
counts(i)(bestCenter) += 1
}
val (bestCenter, cost) = KMeans.findClosest(thisCenters, point)
costAccum.add(cost)
val sum = sums(bestCenter)
axpy(1.0, point.vector, sum)
counts(bestCenter) += 1
}

val contribs = for (i <- 0 until runs; j <- 0 until k) yield {
((i, j), (sums(i)(j), counts(i)(j)))
}
contribs.iterator
}.reduceByKey(mergeContribs).collectAsMap()

bcActiveCenters.destroy(blocking = false)

// Update the cluster centers and costs for each active run
for ((run, i) <- activeRuns.zipWithIndex) {
var changed = false
var j = 0
while (j < k) {
val (sum, count) = totalContribs((i, j))
if (count != 0) {
scal(1.0 / count, sum)
val newCenter = new VectorWithNorm(sum)
if (KMeans.fastSquaredDistance(newCenter, centers(run)(j)) > epsilon * epsilon) {
changed = true
}
centers(run)(j) = newCenter
}
j += 1
counts.indices.filter(counts(_) > 0).map(j => (j, (sums(j), counts(j)))).iterator
}.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
axpy(1.0, sum2, sum1)
(sum1, count1 + count2)
}.collectAsMap()

bcCenters.destroy(blocking = false)

// Update the cluster centers and costs
active = false
totalContribs.foreach { case (j, (sum, count)) =>
scal(1.0 / count, sum)
val newCenter = new VectorWithNorm(sum)
if (!active && KMeans.fastSquaredDistance(newCenter, centers(j)) > epsilon * epsilon) {
active = true
}
if (!changed) {
active(run) = false
logInfo("Run " + run + " finished in " + (iteration + 1) + " iterations")
}
costs(run) = costAccums(i).value
centers(j) = newCenter
}

activeRuns = activeRuns.filter(active(_))
cost = costAccum.value
iteration += 1
}

val iterationTimeInSeconds = (System.nanoTime() - iterationStartTime) / 1e9
logInfo(s"Iterations took " + "%.3f".format(iterationTimeInSeconds) + " seconds.")
logInfo(f"Iterations took $iterationTimeInSeconds%.3f seconds.")

if (iteration == maxIterations) {
logInfo(s"KMeans reached the max number of iterations: $maxIterations.")
} else {
logInfo(s"KMeans converged in $iteration iterations.")
}

val (minCost, bestRun) = costs.zipWithIndex.min

logInfo(s"The cost for the best run is $minCost.")
logInfo(s"The cost is $cost.")

new KMeansModel(centers(bestRun).map(_.vector))
new KMeansModel(centers.map(_.vector))
}

/**
* Initialize `runs` sets of cluster centers at random.
* Initialize a set of cluster centers at random.
*/
private def initRandom(data: RDD[VectorWithNorm])
: Array[Array[VectorWithNorm]] = {
// Sample all the cluster centers in one pass to avoid repeated scans
val sample = data.takeSample(true, runs * k, new XORShiftRandom(this.seed).nextInt()).toSeq
Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).map { v =>
new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm)
}.toArray)
private def initRandom(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = {
val sample = data.takeSample(true, k, new XORShiftRandom(this.seed).nextInt())
sample.map(v => new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm))
}

/**
* Initialize `runs` sets of cluster centers using the k-means|| algorithm by Bahmani et al.
* Initialize a set of cluster centers using the k-means|| algorithm by Bahmani et al.
* (Bahmani et al., Scalable K-Means++, VLDB 2012). This is a variant of k-means++ that tries
* to find dissimilar cluster centers by starting with a random center and then doing
* passes where more centers are chosen with probability proportional to their squared distance
* to the current cluster set. It results in a provable approximation to an optimal clustering.
*
* The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf.
*/
private def initKMeansParallel(data: RDD[VectorWithNorm])
: Array[Array[VectorWithNorm]] = {
private def initKMeansParallel(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = {
// Initialize empty centers and point costs.
val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm])
var costs = data.map(_ => Array.fill(runs)(Double.PositiveInfinity))
var costs = data.map(_ => Double.PositiveInfinity)

// Initialize each run's first center to a random point.
val seed = new XORShiftRandom(this.seed).nextInt()
val sample = data.takeSample(true, runs, seed).toSeq
val sample = data.takeSample(false, 1, seed)
// Could be empty if data is empty; fail with a better message early:
require(sample.size >= runs, s"Required $runs samples but got ${sample.size} from $data")
val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))

/** Merges new centers to centers. */
def mergeNewCenters(): Unit = {
var r = 0
while (r < runs) {
centers(r) ++= newCenters(r)
newCenters(r).clear()
r += 1
}
}
require(sample.nonEmpty, s"No samples available from $data")

val centers = ArrayBuffer[VectorWithNorm]()
var newCenters = Seq(sample.head.toDense)
centers ++= newCenters

// On each step, sample 2 * k points on average for each run with probability proportional
// to their squared distance from that run's centers. Note that only distances between points
// On each step, sample 2 * k points on average with probability proportional
// to their squared distance from the centers. Note that only distances between points
// and new centers are computed in each iteration.
var step = 0
var bcNewCentersList = ArrayBuffer[Broadcast[_]]()
Expand All @@ -409,74 +361,39 @@ class KMeans private (
bcNewCentersList += bcNewCenters
val preCosts = costs
costs = data.zip(preCosts).map { case (point, cost) =>
Array.tabulate(runs) { r =>
math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r))
}
math.min(KMeans.pointCost(bcNewCenters.value, point), cost)
}.persist(StorageLevel.MEMORY_AND_DISK)
val sumCosts = costs
.aggregate(new Array[Double](runs))(
seqOp = (s, v) => {
// s += v
var r = 0
while (r < runs) {
s(r) += v(r)
r += 1
}
s
},
combOp = (s0, s1) => {
// s0 += s1
var r = 0
while (r < runs) {
s0(r) += s1(r)
r += 1
}
s0
}
)
val sumCosts = costs.sum()

bcNewCenters.unpersist(blocking = false)
preCosts.unpersist(blocking = false)

val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) =>
val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointCosts) =>
val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
pointsWithCosts.flatMap { case (p, c) =>
val rs = (0 until runs).filter { r =>
rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r)
}
if (rs.nonEmpty) Some((p, rs)) else None
}
pointCosts.filter { case (_, c) => rand.nextDouble() < 2.0 * c * k / sumCosts }.map(_._1)
}.collect()
mergeNewCenters()
chosen.foreach { case (p, rs) =>
rs.foreach(newCenters(_) += p.toDense)
}
newCenters = chosen.map(_.toDense)
centers ++= newCenters
step += 1
}

mergeNewCenters()
costs.unpersist(blocking = false)
bcNewCentersList.foreach(_.destroy(false))

// Finally, we might have a set of more than k candidate centers for each run; weigh each
if (centers.size <= k) {
return centers.toArray
}

// Finally, we might have a set of more than k candidate centers; weight each
// candidate by the number of points in the dataset mapping to it and run a local k-means++
// on the weighted centers to pick just k of them
val bcCenters = data.context.broadcast(centers)
val weightMap = data.flatMap { p =>
Iterator.tabulate(runs) { r =>
((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0)
}
}.reduceByKey(_ + _).collectAsMap()
val countMap = data.map(p => KMeans.findClosest(bcCenters.value, p)._1).countByValue()

bcCenters.destroy(blocking = false)

val finalCenters = (0 until runs).par.map { r =>
val myCenters = centers(r).toArray
val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray
LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30)
}

finalCenters.toArray
val myWeights = centers.indices.map(countMap(_).toDouble).toArray
LocalKMeans.kMeansPlusPlus(0, centers.toArray, myWeights, k, 30)
}
}
