[SPARK-17389] [SPARK-3261] [MLLIB] Significant KMeans speedup with better choice of init steps, optimizing to remove 'runs' #14948

Closed · wants to merge 3 commits
235 changes: 76 additions & 159 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -43,21 +43,21 @@ import org.apache.spark.util.random.XORShiftRandom
class KMeans private (
private var k: Int,
private var maxIterations: Int,
private var runs: Int,
private var initializationMode: String,
private var initializationSteps: Int,
private var epsilon: Double,
private var seed: Long) extends Serializable with Logging {

/**
* Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1,
* initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4, seed: random}.
* Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20,
* initializationMode: "k-means||", initializationSteps: 2, epsilon: 1e-4, seed: random}.
*/
@Since("0.8.0")
def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4, Utils.random.nextLong())
def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong())

/**
* Number of clusters to create (k).
* Number of clusters to create (k). Note that if the input has fewer than k elements,
* then it's possible that fewer than k clusters are created.
*/
@Since("1.4.0")
def getK: Int = k
@@ -112,15 +112,17 @@ class KMeans private (
* This function has no effect since Spark 2.0.0.
*/
@Since("1.4.0")
@deprecated("This has no effect and always returns 1", "2.1.0")
def getRuns: Int = {
logWarning("Getting number of runs has no effect since Spark 2.0.0.")
runs
1
}

/**
* This function has no effect since Spark 2.0.0.
*/
@Since("0.8.0")
@deprecated("This has no effect", "2.1.0")
def setRuns(runs: Int): this.type = {
logWarning("Setting number of runs has no effect since Spark 2.0.0.")
this
@@ -134,7 +136,7 @@

/**
* Set the number of steps for the k-means|| initialization mode. This is an advanced
* setting -- the default of 5 is almost always enough. Default: 5.
* setting -- the default of 2 is almost always enough. Default: 2.
*/
@Since("0.8.0")
def setInitializationSteps(initializationSteps: Int): this.type = {
@@ -239,17 +241,9 @@ class KMeans private (

val initStartTime = System.nanoTime()

// Only one run is allowed when initialModel is given
val numRuns = if (initialModel.nonEmpty) {
if (runs > 1) logWarning("Ignoring runs; one run is allowed when initialModel is given.")
1
} else {
runs
}

val centers = initialModel match {
case Some(kMeansCenters) =>
Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s)))
kMeansCenters.clusterCenters.map(new VectorWithNorm(_))
case None =>
if (initializationMode == KMeans.RANDOM) {
initRandom(data)
@@ -258,149 +252,107 @@
}
}
val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) +
" seconds.")
logInfo(f"Initialization with $initializationMode took $initTimeInSeconds%.3f seconds.")

val active = Array.fill(numRuns)(true)
val costs = Array.fill(numRuns)(0.0)

var activeRuns = new ArrayBuffer[Int] ++ (0 until numRuns)
var active = true
var cost = 0.0
var iteration = 0

val iterationStartTime = System.nanoTime()

instr.foreach(_.logNumFeatures(centers(0)(0).vector.size))

// Execute iterations of Lloyd's algorithm until all runs have converged
while (iteration < maxIterations && !activeRuns.isEmpty) {
type WeightedPoint = (Vector, Long)
def mergeContribs(x: WeightedPoint, y: WeightedPoint): WeightedPoint = {
axpy(1.0, x._1, y._1)
(y._1, x._2 + y._2)
}

val activeCenters = activeRuns.map(r => centers(r)).toArray
val costAccums = activeRuns.map(_ => sc.doubleAccumulator)
instr.foreach(_.logNumFeatures(centers.head.vector.size))

val bcActiveCenters = sc.broadcast(activeCenters)
// Execute iterations of Lloyd's algorithm until converged
while (iteration < maxIterations && active) {
val costAccum = sc.doubleAccumulator
val bcCenters = sc.broadcast(centers)

// Find the sum and count of points mapping to each center
val totalContribs = data.mapPartitions { points =>
val thisActiveCenters = bcActiveCenters.value
val runs = thisActiveCenters.length
val k = thisActiveCenters(0).length
val dims = thisActiveCenters(0)(0).vector.size
val thisCenters = bcCenters.value
val dims = thisCenters.head.vector.size

val sums = Array.fill(runs, k)(Vectors.zeros(dims))
val counts = Array.fill(runs, k)(0L)
val sums = Array.fill(thisCenters.length)(Vectors.zeros(dims))
val counts = Array.fill(thisCenters.length)(0L)

points.foreach { point =>
(0 until runs).foreach { i =>
val (bestCenter, cost) = KMeans.findClosest(thisActiveCenters(i), point)
costAccums(i).add(cost)
val sum = sums(i)(bestCenter)
axpy(1.0, point.vector, sum)
counts(i)(bestCenter) += 1
}
val (bestCenter, cost) = KMeans.findClosest(thisCenters, point)
costAccum.add(cost)
val sum = sums(bestCenter)
axpy(1.0, point.vector, sum)
counts(bestCenter) += 1
}

val contribs = for (i <- 0 until runs; j <- 0 until k) yield {
((i, j), (sums(i)(j), counts(i)(j)))
}
contribs.iterator
}.reduceByKey(mergeContribs).collectAsMap()

bcActiveCenters.destroy(blocking = false)

// Update the cluster centers and costs for each active run
for ((run, i) <- activeRuns.zipWithIndex) {
var changed = false
var j = 0
while (j < k) {
val (sum, count) = totalContribs((i, j))
if (count != 0) {
scal(1.0 / count, sum)
val newCenter = new VectorWithNorm(sum)
if (KMeans.fastSquaredDistance(newCenter, centers(run)(j)) > epsilon * epsilon) {
changed = true
}
centers(run)(j) = newCenter
}
j += 1
counts.indices.filter(counts(_) > 0).map(j => (j, (sums(j), counts(j)))).iterator
}.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
axpy(1.0, sum2, sum1)
(sum1, count1 + count2)
}.collectAsMap()

bcCenters.destroy(blocking = false)

// Update the cluster centers and costs
active = false
totalContribs.foreach { case (j, (sum, count)) =>
scal(1.0 / count, sum)
val newCenter = new VectorWithNorm(sum)
if (!active && KMeans.fastSquaredDistance(newCenter, centers(j)) > epsilon * epsilon) {
active = true
}
if (!changed) {
active(run) = false
logInfo("Run " + run + " finished in " + (iteration + 1) + " iterations")
}
costs(run) = costAccums(i).value
centers(j) = newCenter
}

activeRuns = activeRuns.filter(active(_))
cost = costAccum.value
iteration += 1
}

val iterationTimeInSeconds = (System.nanoTime() - iterationStartTime) / 1e9
logInfo(s"Iterations took " + "%.3f".format(iterationTimeInSeconds) + " seconds.")
logInfo(f"Iterations took $iterationTimeInSeconds%.3f seconds.")

if (iteration == maxIterations) {
logInfo(s"KMeans reached the max number of iterations: $maxIterations.")
} else {
logInfo(s"KMeans converged in $iteration iterations.")
}

val (minCost, bestRun) = costs.zipWithIndex.min

logInfo(s"The cost for the best run is $minCost.")
logInfo(s"The cost is $cost.")

new KMeansModel(centers(bestRun).map(_.vector))
new KMeansModel(centers.map(_.vector))
}

/**
* Initialize `runs` sets of cluster centers at random.
* Initialize set of cluster centers at random.
*/
private def initRandom(data: RDD[VectorWithNorm])
: Array[Array[VectorWithNorm]] = {
// Sample all the cluster centers in one pass to avoid repeated scans
val sample = data.takeSample(true, runs * k, new XORShiftRandom(this.seed).nextInt()).toSeq
Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).map { v =>
new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm)
}.toArray)
private def initRandom(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = {
val sample = data.takeSample(false, k, new XORShiftRandom(this.seed).nextInt())
sample.map(v => new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm))
Reviewer comment (Contributor): This can be simplified as data.takeSample(false, k, new XORShiftRandom(this.seed).nextLong()).toSeq.toArray.
}
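For illustration, if the reviewer's suggestion were applied, initRandom might reduce to the sketch below. This is hypothetical and not part of this diff; it is a drop-in sketch for the method shown above (so it relies on the surrounding class's k and seed members) and assumes that reusing the sampled VectorWithNorm instances directly, i.e. dropping the defensive copy of each vector, is acceptable.

```scala
private def initRandom(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = {
  // Sample k points without replacement and use them as the initial centers.
  // takeSample already returns an Array, so no copy or conversion is needed.
  data.takeSample(false, k, new XORShiftRandom(this.seed).nextLong())
}
```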

/**
* Initialize `runs` sets of cluster centers using the k-means|| algorithm by Bahmani et al.
* Initialize set of cluster centers using the k-means|| algorithm by Bahmani et al.
* (Bahmani et al., Scalable K-Means++, VLDB 2012). This is a variant of k-means++ that tries
* to find dissimilar cluster centers by starting with a random center and then doing
* passes where more centers are chosen with probability proportional to their squared distance
* to the current cluster set. It results in a provable approximation to an optimal clustering.
*
* The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf.
*/
private def initKMeansParallel(data: RDD[VectorWithNorm])
: Array[Array[VectorWithNorm]] = {
private def initKMeansParallel(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = {
// Initialize empty centers and point costs.
val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm])
var costs = data.map(_ => Array.fill(runs)(Double.PositiveInfinity))
var costs = data.map(_ => Double.PositiveInfinity)

// Initialize each run's first center to a random point.
val seed = new XORShiftRandom(this.seed).nextInt()
val sample = data.takeSample(true, runs, seed).toSeq
val sample = data.takeSample(false, 1, seed)
// Could be empty if data is empty; fail with a better message early:
require(sample.size >= runs, s"Required $runs samples but got ${sample.size} from $data")
val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))

/** Merges new centers to centers. */
def mergeNewCenters(): Unit = {
var r = 0
while (r < runs) {
centers(r) ++= newCenters(r)
newCenters(r).clear()
r += 1
}
}
require(sample.nonEmpty, s"No samples available from $data")

val centers = ArrayBuffer[VectorWithNorm]()
var newCenters = Seq(sample.head.toDense)
centers ++= newCenters

// On each step, sample 2 * k points on average for each run with probability proportional
// to their squared distance from that run's centers. Note that only distances between points
// On each step, sample 2 * k points on average with probability proportional
// to their squared distance from the centers. Note that only distances between points
// and new centers are computed in each iteration.
var step = 0
var bcNewCentersList = ArrayBuffer[Broadcast[_]]()
@@ -409,74 +361,39 @@ class KMeans private (
bcNewCentersList += bcNewCenters
val preCosts = costs
costs = data.zip(preCosts).map { case (point, cost) =>
Array.tabulate(runs) { r =>
math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r))
}
math.min(KMeans.pointCost(bcNewCenters.value, point), cost)
}.persist(StorageLevel.MEMORY_AND_DISK)
val sumCosts = costs
.aggregate(new Array[Double](runs))(
seqOp = (s, v) => {
// s += v
var r = 0
while (r < runs) {
s(r) += v(r)
r += 1
}
s
},
combOp = (s0, s1) => {
// s0 += s1
var r = 0
while (r < runs) {
s0(r) += s1(r)
r += 1
}
s0
}
)
val sumCosts = costs.sum()

bcNewCenters.unpersist(blocking = false)
preCosts.unpersist(blocking = false)

val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) =>
val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointCosts) =>
val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
pointsWithCosts.flatMap { case (p, c) =>
val rs = (0 until runs).filter { r =>
rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r)
}
if (rs.nonEmpty) Some((p, rs)) else None
}
pointCosts.filter { case (_, c) => rand.nextDouble() < 2.0 * c * k / sumCosts }.map(_._1)
}.collect()
mergeNewCenters()
chosen.foreach { case (p, rs) =>
rs.foreach(newCenters(_) += p.toDense)
}
newCenters = chosen.map(_.toDense)
centers ++= newCenters
step += 1
}

mergeNewCenters()
costs.unpersist(blocking = false)
bcNewCentersList.foreach(_.destroy(false))

// Finally, we might have a set of more than k candidate centers for each run; weigh each
if (centers.size <= k) {
return centers.toArray
}

// Finally, we might have a set of more than k candidate centers; weight each
// candidate by the number of points in the dataset mapping to it and run a local k-means++
// on the weighted centers to pick just k of them
val bcCenters = data.context.broadcast(centers)
val weightMap = data.flatMap { p =>
Iterator.tabulate(runs) { r =>
((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0)
}
}.reduceByKey(_ + _).collectAsMap()
val countMap = data.map(p => KMeans.findClosest(bcCenters.value, p)._1).countByValue()

bcCenters.destroy(blocking = false)

val finalCenters = (0 until runs).par.map { r =>
val myCenters = centers(r).toArray
val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray
LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30)
}

finalCenters.toArray
val myWeights = centers.indices.map(countMap.getOrElse(_, 0L).toDouble).toArray
LocalKMeans.kMeansPlusPlus(0, centers.toArray, myWeights, k, 30)
}
}
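To make the k-means|| initialization above easier to follow outside Spark, here is a minimal, self-contained Scala sketch of the same oversampling idea on an in-memory dataset. It illustrates the algorithm from Bahmani et al. as described in the doc comment of initKMeansParallel, not the MLlib implementation; the helper names, the plain Array[Double] point type, and the local Random are assumptions made for brevity.

```scala
import scala.util.Random

object KMeansParallelInitSketch {
  type Point = Array[Double]

  // Squared Euclidean distance between two points.
  def squaredDistance(a: Point, b: Point): Double =
    a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum

  // Cost of a point: squared distance to its closest center among `centers`.
  def pointCost(centers: Seq[Point], p: Point): Double =
    centers.map(squaredDistance(_, p)).min

  // Oversampling phase of k-means||: start from one random point, then in each
  // step keep every point with probability proportional to its current cost,
  // choosing about 2 * k new candidates per step in expectation.
  def init(points: IndexedSeq[Point], k: Int, steps: Int, rnd: Random): Seq[Point] = {
    require(points.nonEmpty, "cannot initialize centers from an empty dataset")
    var centers = Seq(points(rnd.nextInt(points.length)))
    var costs = points.map(p => pointCost(centers, p))
    for (_ <- 0 until steps) {
      val sumCosts = costs.sum
      // Sample new candidate centers, oversampling proportionally to cost.
      val chosen = points.zip(costs).collect {
        case (p, c) if rnd.nextDouble() < 2.0 * c * k / sumCosts => p
      }
      centers = centers ++ chosen
      // Only distances to the newly chosen centers need to be recomputed.
      costs = points.zip(costs).map { case (p, c) =>
        if (chosen.isEmpty) c else math.min(pointCost(chosen, p), c)
      }
    }
    // The real implementation then weights each candidate by the number of
    // points assigned to it and runs a local k-means++ to reduce to k centers.
    centers
  }
}
```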

@@ -75,7 +75,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {

// Make sure code runs.
var model = KMeans.train(data, k = 2, maxIterations = 1)
assert(model.clusterCenters.size === 2)
assert(model.clusterCenters.size === 1)
Author comment (Member): The optimization above that returns early if the number of candidate centers is <= k causes this behavior change. I think the new behavior is more correct, because before you could get duplicate centers. This actually fixed SPARK-3261 too.

}
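To make the author's point concrete, here is a hedged sketch of the behavior change. The data construction is hypothetical and mirrors the test above; it assumes the usual test fixtures (a SparkContext named sc and the spark.mllib Vectors/KMeans imports).

```scala
// Four copies of the same point, i.e. only one distinct point.
val data = sc.parallelize(Array.fill(4)(Vectors.dense(1.0, 2.0, 3.0)))

// Previously, k-means|| could still return k = 2 centers here, with the single
// distinct point duplicated; with the early return when the number of candidate
// centers is <= k, the model now contains just 1 center.
val model = KMeans.train(data, k = 2, maxIterations = 1)
assert(model.clusterCenters.length === 1)
```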

test("more clusters than points") {
@@ -87,7 +87,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {

// Make sure code runs.
var model = KMeans.train(data, k = 3, maxIterations = 1)
assert(model.clusterCenters.size === 3)
assert(model.clusterCenters.size === 2)
}

test("deterministic initialization") {
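Finally, a minimal usage sketch of the API after this change (illustrative only; it assumes an existing SparkContext named sc and spark.mllib on the classpath, and the toy data values are made up):

```scala
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

// Toy data: two well-separated groups of points.
val data = sc.parallelize(Seq(
  Vectors.dense(0.0, 0.0), Vectors.dense(0.1, 0.1),
  Vectors.dense(9.0, 9.0), Vectors.dense(9.1, 9.1)
))

// A single run of Lloyd's algorithm with k-means|| initialization; runs is no
// longer a constructor parameter, and setRuns/getRuns are deprecated no-ops.
val model = new KMeans()
  .setK(2)
  .setMaxIterations(10)
  .setInitializationMode(KMeans.K_MEANS_PARALLEL)
  .setSeed(42L)
  .run(data)

model.clusterCenters.foreach(println)
```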