Skip to content

Commit

Permalink
[SPARK-8018] [MLLIB] KMeans should accept initial cluster centers as …
Browse files Browse the repository at this point in the history
…param

 This allows Kmeans to be initialized using an existing set of cluster centers provided as  a KMeansModel object. This mode of initialization performs a single run.

Author: FlytxtRnD <meethu.mathew@flytxt.com>

Closes #6737 from FlytxtRnD/Kmeans-8018 and squashes the following commits:

94b56df [FlytxtRnD] style correction
ef95ee2 [FlytxtRnD] style correction
c446c58 [FlytxtRnD] documentation and numRuns warning change
06d13ef [FlytxtRnD] numRuns corrected
d12336e [FlytxtRnD] numRuns variable modifications
07f8554 [FlytxtRnD] remove setRuns from setIntialModel
e721dfe [FlytxtRnD] Merge remote-tracking branch 'upstream/master' into Kmeans-8018
242ead1 [FlytxtRnD] corrected == to === in assert
714acb5 [FlytxtRnD] added numRuns
60c8ce2 [FlytxtRnD] ignore runs parameter and initialModel test suite changed
582e6d9 [FlytxtRnD] Merge remote-tracking branch 'upstream/master' into Kmeans-8018
3f5fc8e [FlytxtRnD] test case modified and one runs condition added
cd5dc5c [FlytxtRnD] Merge remote-tracking branch 'upstream/master' into Kmeans-8018
16f1b53 [FlytxtRnD] Merge branch 'Kmeans-8018', remote-tracking branch 'upstream/master' into Kmeans-8018
e9c35d7 [FlytxtRnD] Remove getInitialModel and match cluster count criteria
6959861 [FlytxtRnD] Accept initial cluster centers in KMeans
  • Loading branch information
FlytxtRnD authored and jkbradley committed Jul 15, 2015
1 parent 4692769 commit 3f6296f
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/mllib-clustering.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ guaranteed to find a globally optimal solution, and when run multiple times on
a given dataset, the algorithm returns the best clustering result).
* *initializationSteps* determines the number of steps in the k-means\|\| algorithm.
* *epsilon* determines the distance threshold within which we consider k-means to have converged.
* *initialModel* is an optional set of cluster centers used for initialization. If this parameter is supplied, only one run is performed.

**Examples**

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,21 @@ class KMeans private (
this
}

// Initial cluster centers can be provided as a KMeansModel object rather than using the
// random or k-means|| initializationMode
private var initialModel: Option[KMeansModel] = None

/**
* Set the initial starting point, bypassing the random initialization or k-means||
* The condition model.k == this.k must be met, failure results
* in an IllegalArgumentException.
*/
def setInitialModel(model: KMeansModel): this.type = {
require(model.k == k, "mismatched cluster count")
initialModel = Some(model)
this
}

/**
* Train a K-means model on the given set of points; `data` should be cached for high
* performance, because this is an iterative algorithm.
Expand Down Expand Up @@ -193,20 +208,34 @@ class KMeans private (

val initStartTime = System.nanoTime()

val centers = if (initializationMode == KMeans.RANDOM) {
initRandom(data)
// Only one run is allowed when initialModel is given
val numRuns = if (initialModel.nonEmpty) {
if (runs > 1) logWarning("Ignoring runs; one run is allowed when initialModel is given.")
1
} else {
initKMeansParallel(data)
runs
}

val centers = initialModel match {
case Some(kMeansCenters) => {
Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s)))
}
case None => {
if (initializationMode == KMeans.RANDOM) {
initRandom(data)
} else {
initKMeansParallel(data)
}
}
}
val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) +
" seconds.")

val active = Array.fill(runs)(true)
val costs = Array.fill(runs)(0.0)
val active = Array.fill(numRuns)(true)
val costs = Array.fill(numRuns)(0.0)

var activeRuns = new ArrayBuffer[Int] ++ (0 until runs)
var activeRuns = new ArrayBuffer[Int] ++ (0 until numRuns)
var iteration = 0

val iterationStartTime = System.nanoTime()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,28 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
}

test("Initialize using given cluster centers") {
val points = Seq(
Vectors.dense(0.0, 0.0),
Vectors.dense(1.0, 0.0),
Vectors.dense(0.0, 1.0),
Vectors.dense(1.0, 1.0)
)
val rdd = sc.parallelize(points, 3)
// creating an initial model
val initialModel = new KMeansModel(Array(points(0), points(2)))

val returnModel = new KMeans()
.setK(2)
.setMaxIterations(0)
.setInitialModel(initialModel)
.run(rdd)
// comparing the returned model and the initial model
assert(returnModel.clusterCenters(0) === initialModel.clusterCenters(0))
assert(returnModel.clusterCenters(1) === initialModel.clusterCenters(1))
}

}

object KMeansSuite extends SparkFunSuite {
Expand Down

0 comments on commit 3f6296f

Please sign in to comment.