
Fix minors
yu-iskw committed Oct 29, 2015
1 parent 57b06ba commit 5da05d3
Showing 2 changed files with 19 additions and 15 deletions.
@@ -119,11 +119,11 @@ class BisectingKMeans private (
val sc = input.sparkContext
val startTime = System.currentTimeMillis()
var data = initData(input).cache()
// this is used for managing calculated cached RDDs
var updatedDataHistory = Array.empty[RDD[(Long, BV[Double])]]

-// `clusterStats` is described as binary tree structure
+// `clusterStats` is described as binary tree structure as Map
// `clusterStats(1)` means the root of a binary tree
// `clusterStats(2n)` and `clusterStats(2n+1)` are the children of `clusterStats(n)`
var leafClusterStats = summarizeClusters(data)
var dividableLeafClusters = leafClusterStats.filter(_._2.isDividable)
var clusterStats = leafClusterStats
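
A note on the indexing scheme the comments above describe: keeping the tree in a Map keyed by Long makes parent and child lookups pure arithmetic. A minimal sketch (the helper names are illustrative, not part of the patch):

def childIndices(n: Long): (Long, Long) = (2 * n, 2 * n + 1) // children of node n
def parentIndex(n: Long): Long = n / 2 // integer division maps both children back to n

// the root is clusterStats(1); its children sit at indices 2 and 3
assert(childIndices(1) == (2L, 3L))
assert(parentIndex(2) == 1L && parentIndex(3) == 1L)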
@@ -143,16 +143,15 @@ class BisectingKMeans private (
dividableLeafClusters = leafClusterStats.filter(_._2.isDividable)
clusterStats = clusterStats ++ leafClusterStats

-// update each index
+// keep recent 2 cached RDDs in order to run more quickly
updatedDataHistory = updatedDataHistory ++ Array(dividedData)
data = dividedData
-// keep recent 2 cached RDDs in order to run more quickly
+step += 1
if (updatedDataHistory.length > 1) {
val head = updatedDataHistory.head
updatedDataHistory = updatedDataHistory.tail
head.unpersist()
}
-step += 1
}
// create a map of cluster node with their costs
val nodes = createClusterNodes(data, clusterStats)
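
The reordering above keeps at most the current RDD plus one predecessor cached and unpersists anything older, which is what the retained comment means by keeping the recent two. A condensed sketch of that eviction step, assuming an illustrative helper name:

import org.apache.spark.rdd.RDD

// Append the newest cached RDD, then unpersist and drop the oldest entry,
// mirroring the loop body above.
def appendAndEvict[T](history: Array[RDD[T]], latest: RDD[T]): Array[RDD[T]] = {
  val updated = history :+ latest
  if (updated.length > 1) {
    updated.head.unpersist() // free the stale RDD's cached blocks
    updated.tail
  } else updated
}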
@@ -312,24 +311,25 @@ private[clustering] object BisectingKMeans {
if (dividableClusterStats.isEmpty) {
return data
}

// extract dividable input data
val dividableData = data.filter { case (idx, point) => dividableClusterStats.contains(idx)}

// get next initial centers
var newCenters = initNextCenters(dividableData, dividableClusterStats)
-// TODO Supports distance metrics other Euclidean distance metric
-val metric = (bv1: BV[Double], bv2: BV[Double]) => breezeNorm(bv1 - bv2, 2.0)
-val bcMetric = sc.broadcast(metric)
// pairs of cluster index and (sums, #points, sumOfSquares)
var stats = Map.empty[Long, (BV[Double], Long, Double)]

var nextData = data
var subIter = 0
var totalSumOfSquares = Double.MaxValue
var oldTotalSumOfSquares = Double.MaxValue
var relativeError = Double.MaxValue
val dimension = dividableData.first()._2.size
-// TODO add a set method for the threshold, instead of 1e-4
+
+// TODO Supports distance metrics other Euclidean distance metric
+val metric = (bv1: BV[Double], bv2: BV[Double]) => breezeNorm(bv1 - bv2, 2.0)
+val bcMetric = sc.broadcast(metric)
+
while (subIter < maxIterations && relativeError > 1e-4) {
+// TODO add a set method for the threshold, instead of 1e-4
+
// convert each index into the closest child index
val bcNewCenters = sc.broadcast(newCenters)
nextData = dividableData.map { case (idx, point) =>
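
For context on the relocated metric block: the Euclidean distance closure is built once on the driver and broadcast, so each map task can score the two candidate child centers of a dividable cluster. A sketch of that selection under assumed shapes (the `nearestChild` helper is not in the patch):

import breeze.linalg.{norm => breezeNorm, Vector => BV}

val metric = (bv1: BV[Double], bv2: BV[Double]) => breezeNorm(bv1 - bv2, 2.0)

// A point assigned to cluster idx can only move to child 2*idx or 2*idx + 1;
// pick whichever center is nearer under the metric.
def nearestChild(idx: Long, point: BV[Double], centers: Map[Long, BV[Double]]): Long = {
  val candidates = Seq(2 * idx, 2 * idx + 1).filter(centers.contains)
  candidates.minBy(c => metric(point, centers(c)))
}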
@@ -360,7 +360,7 @@ private[clustering] object BisectingKMeans {
// calculate the center of each cluster
newCenters = tempStats.map {case (idx, (sums, n, sumOfNorm)) => (idx, sums :/ n.toDouble)}

-totalSumOfSquares = stats.map{case (idx, (sums, n, sumOfNorm)) => sumOfNorm}.sum
+totalSumOfSquares = tempStats.map{case (idx, (sums, n, sumOfNorm)) => sumOfNorm}.sum
relativeError = math.abs(totalSumOfSquares - oldTotalSumOfSquares) / totalSumOfSquares
oldTotalSumOfSquares = totalSumOfSquares
subIter += 1
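
The one-line fix above matters: at this point `stats` still holds the previous iteration's statistics, so summing it made the convergence check lag a step behind, while `tempStats` holds the values just computed. A self-contained illustration of the corrected summation:

import breeze.linalg.{DenseVector => BDV, Vector => BV}

// per-cluster statistics as (sums, #points, sumOfSquares), as in the patch
val tempStats: Map[Long, (BV[Double], Long, Double)] = Map(
  2L -> (BDV(1.0, 2.0), 10L, 4.0),
  3L -> (BDV(3.0, 4.0), 12L, 6.0)
)
val totalSumOfSquares = tempStats.map { case (_, (_, _, sumOfNorm)) => sumOfNorm }.sum
assert(totalSumOfSquares == 10.0)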
@@ -28,7 +28,8 @@ import org.apache.spark.mllib.util.TestingUtils._
class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {

test("run") {
-val algo = new BisectingKMeans().setK(123).setSeed(1)
+val k = 123
+val algo = new BisectingKMeans().setK(k).setSeed(1)
val localSeed: Seq[Vector] = (0 to 999).map(i => Vectors.dense(i.toDouble, i.toDouble)).toSeq
val data = sc.parallelize(localSeed, 2)
val model = algo.run(data)
@@ -40,6 +41,9 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(model.node.getChildren.head.getParent.get === model.node)
assert(model.node.getChildren.apply(1).getParent.get === model.node)
assert(model.getClusters.forall(_.getParent.isDefined))
+
+val predicted = model.predict(data)
+assert(predicted.distinct.count() === k)
}

test("run with too many cluster size than the records") {
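
The added assertions verify that `predict` uses every one of the `k` requested clusters. Assuming `predict` returns an RDD of cluster assignments, as the test implies, a natural follow-up check (illustrative, not in the suite) is that the cluster sizes account for every input point:

val sizes = model.predict(data).map((_, 1L)).reduceByKey(_ + _).collect()
assert(sizes.length === k) // all k clusters are non-empty
assert(sizes.map(_._2).sum === data.count()) // every point was assigned exactly once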
