/
BisectingKMeans.scala
567 lines (516 loc) · 19.7 KB
/
BisectingKMeans.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.clustering
import java.util.Random
import scala.annotation.tailrec
import scala.collection.mutable
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
import org.apache.spark.ml.util.Instrumentation
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS.axpy
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
/**
 * A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques"
 * by Steinbach, Karypis, and Kumar, with modification to fit Spark.
 * The algorithm starts from a single cluster that contains all points.
 * Iteratively it finds divisible clusters on the bottom level and bisects each of them using
 * k-means, until there are `k` leaf clusters in total or no leaf clusters are divisible.
 * The bisecting steps of clusters on the same level are grouped together to increase parallelism.
 * If bisecting all divisible clusters on the bottom level would result more than `k` leaf clusters,
 * larger clusters get higher priority.
 *
 * @param k the desired number of leaf clusters (default: 4). The actual number could be smaller if
 *          there are no divisible leaf clusters.
 * @param maxIterations the max number of k-means iterations to split clusters (default: 20)
 * @param minDivisibleClusterSize the minimum number of points (if greater than or equal 1.0) or
 *                                the minimum proportion of points (if less than 1.0) of a divisible
 *                                cluster (default: 1)
 * @param seed a random seed (default: hash value of the class name)
 * @param distanceMeasure the distance suite used by the algorithm; validated against the supported
 *                        measures in `setDistanceMeasure`
 *
 * @see <a href="http://glaros.dtc.umn.edu/gkhome/fetch/papers/docclusterKDDTMW00.pdf">
 *     Steinbach, Karypis, and Kumar, A comparison of document clustering techniques,
 *     KDD Workshop on Text Mining, 2000.</a>
 */
@Since("1.6.0")
class BisectingKMeans private (
    private var k: Int,
    private var maxIterations: Int,
    private var minDivisibleClusterSize: Double,
    private var seed: Long,
    private var distanceMeasure: String) extends Logging {

  import BisectingKMeans._

  /**
   * Constructs with the default configuration
   */
  @Since("1.6.0")
  def this() = this(4, 20, 1.0, classOf[BisectingKMeans].getName.##, DistanceMeasure.EUCLIDEAN)

  /**
   * Sets the desired number of leaf clusters (default: 4).
   * The actual number could be smaller if there are no divisible leaf clusters.
   */
  @Since("1.6.0")
  def setK(k: Int): this.type = {
    require(k > 0, s"k must be positive but got $k.")
    this.k = k
    this
  }

  /**
   * Gets the desired number of leaf clusters.
   */
  @Since("1.6.0")
  def getK: Int = this.k

  /**
   * Sets the max number of k-means iterations to split clusters (default: 20).
   */
  @Since("1.6.0")
  def setMaxIterations(maxIterations: Int): this.type = {
    require(maxIterations > 0, s"maxIterations must be positive but got $maxIterations.")
    this.maxIterations = maxIterations
    this
  }

  /**
   * Gets the max number of k-means iterations to split clusters.
   */
  @Since("1.6.0")
  def getMaxIterations: Int = this.maxIterations

  /**
   * Sets the minimum number of points (if greater than or equal to `1.0`) or the minimum proportion
   * of points (if less than `1.0`) of a divisible cluster (default: 1).
   */
  @Since("1.6.0")
  def setMinDivisibleClusterSize(minDivisibleClusterSize: Double): this.type = {
    require(minDivisibleClusterSize > 0.0,
      s"minDivisibleClusterSize must be positive but got $minDivisibleClusterSize.")
    this.minDivisibleClusterSize = minDivisibleClusterSize
    this
  }

  /**
   * Gets the minimum number of points (if greater than or equal to `1.0`) or the minimum proportion
   * of points (if less than `1.0`) of a divisible cluster.
   */
  @Since("1.6.0")
  def getMinDivisibleClusterSize: Double = minDivisibleClusterSize

  /**
   * Sets the random seed (default: hash value of the class name).
   */
  @Since("1.6.0")
  def setSeed(seed: Long): this.type = {
    this.seed = seed
    this
  }

  /**
   * Gets the random seed.
   */
  @Since("1.6.0")
  def getSeed: Long = this.seed

  /**
   * The distance suite used by the algorithm.
   */
  @Since("2.4.0")
  def getDistanceMeasure: String = distanceMeasure

  /**
   * Set the distance suite used by the algorithm.
   */
  @Since("2.4.0")
  def setDistanceMeasure(distanceMeasure: String): this.type = {
    // Fails fast on unsupported measure names before any job is launched.
    DistanceMeasure.validateDistanceMeasure(distanceMeasure)
    this.distanceMeasure = distanceMeasure
    this
  }

  /**
   * Unweighted entry point: wraps every point with a unit weight (1.0) and delegates to
   * `runWithWeight`.
   *
   * NOTE(review): the `instr` parameter is accepted but `None` is forwarded to `runWithWeight`,
   * so instrumentation from this path is dropped — confirm this is intentional.
   */
  private[spark] def run(
      input: RDD[Vector],
      instr: Option[Instrumentation]): BisectingKMeansModel = {
    val instances: RDD[(Vector, Double)] = input.map {
      case (point) => (point, 1.0)
    }
    runWithWeight(instances, None)
  }

  /**
   * Weighted driver of the bisecting k-means algorithm.
   *
   * Starting from a single root cluster containing all points, it repeatedly (level by level):
   *  1. selects divisible leaf clusters (large enough and with non-trivial cost),
   *  2. splits each selected center into two perturbed children,
   *  3. refines the child centers with `maxIterations` rounds of k-means-style reassignment,
   *  4. re-keys every point to its new (child) cluster index.
   * It stops when `k` leaves exist, no cluster is divisible, or the tree index space is exhausted.
   *
   * @param input pairs of (point, weight)
   * @param instr optional instrumentation used to log example counts and weight sums
   * @return the fitted model wrapping the clustering tree and its total leaf cost
   */
  private[spark] def runWithWeight(
      input: RDD[(Vector, Double)],
      instr: Option[Instrumentation]): BisectingKMeansModel = {
    // Feature dimension is taken from the first vector; all rows are assumed to share it.
    val d = input.map(_._1.size).first
    logInfo(s"Feature dimension: $d.")
    val dMeasure: DistanceMeasure = DistanceMeasure.decodeFromString(this.distanceMeasure)
    // Compute and cache vector norms for fast distance computation.
    val norms = input.map(d => Vectors.norm(d._1, 2.0))
    val vectors = input.zip(norms).map {
      case ((x, weight), norm) => new VectorWithNorm(x, norm, weight)
    }
    // Only persist ourselves if the caller has not already cached the input.
    if (input.getStorageLevel == StorageLevel.NONE) {
      vectors.persist(StorageLevel.MEMORY_AND_DISK)
    }
    // Every point starts in the root cluster (index 1).
    var assignments = vectors.map(v => (ROOT_INDEX, v))
    var activeClusters = summarize(d, assignments, dMeasure)
    instr.foreach(_.logNumExamples(activeClusters.values.map(_.size).sum))
    instr.foreach(_.logSumOfWeights(activeClusters.values.map(_.weightSum).sum))
    val rootSummary = activeClusters(ROOT_INDEX)
    val n = rootSummary.size
    logInfo(s"Number of points: $n.")
    logInfo(s"Initial cost: ${rootSummary.cost}.")
    // minDivisibleClusterSize >= 1.0 means an absolute count; < 1.0 means a fraction of n.
    val minSize = if (minDivisibleClusterSize >= 1.0) {
      math.ceil(minDivisibleClusterSize).toLong
    } else {
      math.ceil(minDivisibleClusterSize * n).toLong
    }
    logInfo(s"The minimum number of points of a divisible cluster is $minSize.")
    var inactiveClusters = mutable.Seq.empty[(Long, ClusterSummary)]
    val random = new Random(seed)
    // The root already counts as one leaf, so k - 1 more splits are needed at most.
    var numLeafClustersNeeded = k - 1
    var level = 1
    // preIndices/indices keep the previous/current persisted assignment keys so that only the
    // two most recent index RDDs are cached at any time.
    var preIndices: RDD[Long] = null
    var indices: RDD[Long] = null
    // LEVEL_LIMIT bounds the tree depth so child indices never overflow Long.
    while (activeClusters.nonEmpty && numLeafClustersNeeded > 0 && level < LEVEL_LIMIT) {
      // Divisible clusters are sufficiently large and have non-trivial cost.
      var divisibleClusters = activeClusters.filter { case (_, summary) =>
        (summary.size >= minSize) && (summary.cost > MLUtils.EPSILON * summary.size)
      }
      // If we don't need all divisible clusters, take the larger ones.
      if (divisibleClusters.size > numLeafClustersNeeded) {
        divisibleClusters = divisibleClusters.toSeq.sortBy { case (_, summary) =>
          -summary.size
        }.take(numLeafClustersNeeded)
          .toMap
      }
      if (divisibleClusters.nonEmpty) {
        val divisibleIndices = divisibleClusters.keys.toSet
        logInfo(s"Dividing ${divisibleIndices.size} clusters on level $level.")
        // Seed each split with two centers perturbed symmetrically around the parent center.
        var newClusterCenters = divisibleClusters.flatMap { case (index, summary) =>
          val (left, right) = splitCenter(summary.center, random, dMeasure)
          Iterator((leftChildIndex(index), left), (rightChildIndex(index), right))
        }.map(identity) // workaround for a Scala bug (SI-7005) that produces a not serializable map
        var newClusters: Map[Long, ClusterSummary] = null
        var newAssignments: RDD[(Long, VectorWithNorm)] = null
        // Local k-means refinement restricted to points whose parent cluster is being divided.
        for (iter <- 0 until maxIterations) {
          newAssignments = updateAssignments(assignments, divisibleIndices, newClusterCenters,
            dMeasure)
            .filter { case (index, _) =>
              divisibleIndices.contains(parentIndex(index))
            }
          newClusters = summarize(d, newAssignments, dMeasure)
          // map(identity) again materializes the lazy mapValues view into a serializable Map.
          newClusterCenters = newClusters.mapValues(_.center).map(identity)
        }
        // Release the index RDD from two levels ago before persisting the new one.
        if (preIndices != null) {
          preIndices.unpersist()
        }
        preIndices = indices
        indices = updateAssignments(assignments, divisibleIndices, newClusterCenters, dMeasure).keys
          .persist(StorageLevel.MEMORY_AND_DISK)
        assignments = indices.zip(vectors)
        // Divided parents (and any non-divided active clusters) are frozen; children go active.
        inactiveClusters ++= activeClusters
        activeClusters = newClusters
        numLeafClustersNeeded -= divisibleClusters.size
      } else {
        logInfo(s"None active and divisible clusters left on level $level. Stop iterations.")
        inactiveClusters ++= activeClusters
        activeClusters = Map.empty
      }
      level += 1
    }
    // Clean up all cached RDDs before building the final (driver-side) tree.
    if (preIndices != null) {
      preIndices.unpersist()
    }
    if (indices != null) {
      indices.unpersist()
    }
    vectors.unpersist()
    val clusters = activeClusters ++ inactiveClusters
    val root = buildTree(clusters, dMeasure)
    // Total cost is the sum over leaf clusters only; internal nodes would double-count points.
    val totalCost = root.leafNodes.map(_.cost).sum
    new BisectingKMeansModel(root, this.distanceMeasure, totalCost)
  }

  /**
   * Runs the bisecting k-means algorithm.
   * @param input RDD of vectors
   * @return model for the bisecting kmeans
   */
  @Since("1.6.0")
  def run(input: RDD[Vector]): BisectingKMeansModel = {
    run(input, None)
  }

  /**
   * Java-friendly version of `run()`.
   */
  def run(data: JavaRDD[Vector]): BisectingKMeansModel = run(data.rdd)
}
/**
 * Companion helpers for the bisecting k-means driver.
 *
 * Cluster indices use binary-heap numbering over `Long`: the root is 1, and a node `i` has
 * children `2*i` (left) and `2*i + 1` (right), so the parent of `i` is `i / 2`. This lets the
 * tree structure live implicitly in the keys of the per-point assignment RDD.
 */
private object BisectingKMeans extends Serializable {
  /** The index of the root node of a tree. */
  private val ROOT_INDEX: Long = 1
  // Largest index whose children (2*i and 2*i+1) still fit in a Long without overflow.
  private val MAX_DIVISIBLE_CLUSTER_INDEX: Long = Long.MaxValue / 2
  // Maximum tree depth: log2(Long.MaxValue) ~ 63 levels, computed via base-10 logs.
  private val LEVEL_LIMIT = math.log10(Long.MaxValue) / math.log10(2)

  /** Returns the left child index of the given node index. */
  private def leftChildIndex(index: Long): Long = {
    require(index <= MAX_DIVISIBLE_CLUSTER_INDEX, s"Child index out of bound: 2 * $index.")
    2 * index
  }

  /** Returns the right child index of the given node index. */
  private def rightChildIndex(index: Long): Long = {
    require(index <= MAX_DIVISIBLE_CLUSTER_INDEX, s"Child index out of bound: 2 * $index + 1.")
    2 * index + 1
  }

  /** Returns the parent index of the given node index, or 0 if the input is 1 (root). */
  private def parentIndex(index: Long): Long = {
    index / 2
  }

  /**
   * Summarizes data by each cluster as Map.
   *
   * Aggregates distributedly per cluster key, then collects the (small) per-cluster summaries
   * to the driver.
   *
   * @param d feature dimension
   * @param assignments pairs of point and its cluster index
   * @return a map from cluster indices to corresponding cluster summaries
   */
  private def summarize(
      d: Int,
      assignments: RDD[(Long, VectorWithNorm)],
      distanceMeasure: DistanceMeasure): Map[Long, ClusterSummary] = {
    assignments.aggregateByKey(new ClusterSummaryAggregator(d, distanceMeasure))(
      seqOp = (agg, v) => agg.add(v),
      combOp = (agg1, agg2) => agg1.merge(agg2)
    ).mapValues(_.summary)
      .collect().toMap
  }

  /**
   * Cluster summary aggregator.
   *
   * Mutable accumulator shipped to executors by `aggregateByKey`; must stay `Serializable`.
   * Tracks point count, weight sum, the (measure-specific) running vector sum, and the
   * weighted sum of squared norms used for the cost estimate.
   *
   * @param d feature dimension
   */
  private class ClusterSummaryAggregator(val d: Int, val distanceMeasure: DistanceMeasure)
    extends Serializable {
    // Number of points added so far.
    private var n: Long = 0L
    // Total weight of the points added so far.
    private var weightSum: Double = 0.0
    // Running sum of the points, updated by the distance measure (measure-specific encoding).
    private val sum: Vector = Vectors.zeros(d)
    // Weighted sum of squared L2 norms, used in the cost computation below.
    private var sumSq: Double = 0.0

    /** Adds a point. */
    def add(v: VectorWithNorm): this.type = {
      n += 1L
      weightSum += v.weight
      // TODO: use a numerically stable approach to estimate cost
      sumSq += v.norm * v.norm * v.weight
      distanceMeasure.updateClusterSum(v, sum)
      this
    }

    /** Merges another aggregator. */
    def merge(other: ClusterSummaryAggregator): this.type = {
      n += other.n
      weightSum += other.weightSum
      sumSq += other.sumSq
      // sum += other.sum, done in place via BLAS.
      axpy(1.0, other.sum, sum)
      this
    }

    /** Returns the summary. */
    def summary: ClusterSummary = {
      // sum.copy: centroid computation must not mutate the accumulator's running sum.
      val center = distanceMeasure.centroid(sum.copy, weightSum)
      val cost = distanceMeasure.clusterCost(center, new VectorWithNorm(sum), weightSum,
        sumSq)
      ClusterSummary(n, weightSum, center, cost)
    }
  }

  /**
   * Bisects a cluster center.
   *
   * Produces two centers placed symmetrically around the given center, offset by random noise
   * scaled to 1e-4 of the center's norm.
   *
   * @param center current cluster center
   * @param random a random number generator
   * @return initial centers
   */
  private def splitCenter(
      center: VectorWithNorm,
      random: Random,
      distanceMeasure: DistanceMeasure): (VectorWithNorm, VectorWithNorm) = {
    val d = center.vector.size
    val norm = center.norm
    // Perturbation magnitude is relative to the center's norm, so splits scale with the data.
    val level = 1e-4 * norm
    val noise = Vectors.dense(Array.fill(d)(random.nextDouble()))
    distanceMeasure.symmetricCentroids(level, noise, center.vector)
  }

  /**
   * Updates assignments.
   *
   * Points in a divisible cluster are re-keyed to whichever child center is closest; points in
   * non-divisible clusters (or whose cluster produced no surviving child centers) keep their
   * current index.
   *
   * @param assignments current assignments
   * @param divisibleIndices divisible cluster indices
   * @param newClusterCenters new cluster centers
   * @return new assignments
   */
  private def updateAssignments(
      assignments: RDD[(Long, VectorWithNorm)],
      divisibleIndices: Set[Long],
      newClusterCenters: Map[Long, VectorWithNorm],
      distanceMeasure: DistanceMeasure): RDD[(Long, VectorWithNorm)] = {
    assignments.map { case (index, v) =>
      if (divisibleIndices.contains(index)) {
        val children = Seq(leftChildIndex(index), rightChildIndex(index))
        // A child may be absent from newClusterCenters; only compare against present ones.
        val newClusterChildren = children.filter(newClusterCenters.contains)
        // Map each candidate center back to its child id so findClosest's winner can be re-keyed.
        val newClusterChildrenCenterToId =
          newClusterChildren.map(id => newClusterCenters(id) -> id).toMap
        val newClusterChildrenCenters = newClusterChildrenCenterToId.keys.toArray
        if (newClusterChildren.nonEmpty) {
          val selected = distanceMeasure.findClosest(newClusterChildrenCenters, v)._1
          val center = newClusterChildrenCenters(selected)
          val id = newClusterChildrenCenterToId(center)
          (id, v)
        } else {
          (index, v)
        }
      } else {
        (index, v)
      }
    }
  }

  /**
   * Builds a clustering tree by re-indexing internal and leaf clusters.
   *
   * Heap-style raw indices are remapped to the model's convention: leaves get non-negative
   * indices counting up from 0, internal nodes get negative indices counting down from -1.
   *
   * @param clusters a map from cluster indices to corresponding cluster summaries
   * @return the root node of the clustering tree
   */
  private def buildTree(
      clusters: Map[Long, ClusterSummary],
      distanceMeasure: DistanceMeasure): ClusteringTreeNode = {
    var leafIndex = 0
    var internalIndex = -1

    /**
     * Builds a subtree from this given node index.
     */
    def buildSubTree(rawIndex: Long): ClusteringTreeNode = {
      val cluster = clusters(rawIndex)
      val size = cluster.size
      val center = cluster.center
      val cost = cluster.cost
      // A node is internal iff its left child exists in the cluster map.
      val isInternal = clusters.contains(leftChildIndex(rawIndex))
      if (isInternal) {
        val index = internalIndex
        internalIndex -= 1
        val leftIndex = leftChildIndex(rawIndex)
        val rightIndex = rightChildIndex(rawIndex)
        val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains)
        // Dendrogram height: max distance from this center to any existing child's center.
        val height = indexes.map { childIndex =>
          distanceMeasure.distance(center, clusters(childIndex).center)
        }.max
        val children = indexes.map(buildSubTree).toArray
        new ClusteringTreeNode(index, size, center, cost, height, children)
      } else {
        val index = leafIndex
        leafIndex += 1
        val height = 0.0
        new ClusteringTreeNode(index, size, center, cost, height, Array.empty)
      }
    }

    buildSubTree(ROOT_INDEX)
  }

  /**
   * Summary of a cluster.
   *
   * @param size the number of points within this cluster
   * @param weightSum the weightSum within this cluster
   * @param center the center of the points within this cluster
   * @param cost the sum of squared distances to the center
   */
  private case class ClusterSummary(
      size: Long,
      weightSum: Double,
      center: VectorWithNorm,
      cost: Double)
}
/**
 * Represents a node in a clustering tree.
 *
 * @param index node index, negative for internal nodes and non-negative for leaf nodes
 * @param size size of the cluster
 * @param centerWithNorm cluster center with norm
 * @param cost cost of the cluster, i.e., the sum of squared distances to the center
 * @param height height of the node in the dendrogram. Currently this is defined as the max distance
 *               from the center to the centers of the children's, but subject to change.
 * @param children children nodes
 */
@Since("1.6.0")
private[clustering] class ClusteringTreeNode private[clustering] (
    val index: Int,
    val size: Long,
    private[clustering] val centerWithNorm: VectorWithNorm,
    val cost: Double,
    val height: Double,
    val children: Array[ClusteringTreeNode]) extends Serializable {

  /** Whether this is a leaf node. */
  val isLeaf: Boolean = children.isEmpty

  // Enforce the indexing convention: leaves are non-negative, internal nodes are negative.
  require((isLeaf && index >= 0) || (!isLeaf && index < 0))

  /** Cluster center. */
  def center: Vector = centerWithNorm.vector

  /** Predicts the leaf cluster node index that the input point belongs to. */
  def predict(point: Vector, distanceMeasure: DistanceMeasure): Int = {
    val (index, _) = predict(new VectorWithNorm(point), distanceMeasure)
    index
  }

  /** Returns the full prediction path from root to leaf. */
  def predictPath(point: Vector, distanceMeasure: DistanceMeasure): Array[ClusteringTreeNode] = {
    predictPath(new VectorWithNorm(point), distanceMeasure).toArray
  }

  /**
   * Returns the full prediction path from root to leaf.
   *
   * Note: at each internal node only the selected child is prepended, so the node `predictPath`
   * is called on appears in the result only when it is itself a leaf.
   */
  private def predictPath(
      pointWithNorm: VectorWithNorm,
      distanceMeasure: DistanceMeasure): List[ClusteringTreeNode] = {
    if (isLeaf) {
      this :: Nil
    } else {
      // Descend into the child whose center is closest to the point.
      val selected = children.minBy { child =>
        distanceMeasure.distance(child.centerWithNorm, pointWithNorm)
      }
      selected :: selected.predictPath(pointWithNorm, distanceMeasure)
    }
  }

  /**
   * Computes the cost of the input point.
   */
  def computeCost(point: Vector, distanceMeasure: DistanceMeasure): Double = {
    val (_, cost) = predict(new VectorWithNorm(point), distanceMeasure)
    cost
  }

  /**
   * Predicts the cluster index and the cost of the input point.
   */
  private def predict(
      pointWithNorm: VectorWithNorm,
      distanceMeasure: DistanceMeasure): (Int, Double) = {
    // Seed the recursion with the cost relative to this node's own center.
    predict(pointWithNorm, distanceMeasure.cost(centerWithNorm, pointWithNorm), distanceMeasure)
  }

  /**
   * Predicts the cluster index and the cost of the input point.
   *
   * Tail-recursive descent: at each internal node, recurse into the child with the minimum cost.
   *
   * @param pointWithNorm input point
   * @param cost the cost to the current center
   * @return (predicted leaf cluster index, cost)
   */
  @tailrec
  private def predict(
      pointWithNorm: VectorWithNorm,
      cost: Double,
      distanceMeasure: DistanceMeasure): (Int, Double) = {
    if (isLeaf) {
      (index, cost)
    } else {
      val (selectedChild, minCost) = children.map { child =>
        (child, distanceMeasure.cost(child.centerWithNorm, pointWithNorm))
      }.minBy(_._2)
      selectedChild.predict(pointWithNorm, minCost, distanceMeasure)
    }
  }

  /**
   * Returns all leaf nodes from this node.
   */
  def leafNodes: Array[ClusteringTreeNode] = {
    if (isLeaf) {
      Array(this)
    } else {
      children.flatMap(_.leafNodes)
    }
  }
}