diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 69291fb26ed7a..d1953a12dbe90 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.mllib.clustering import breeze.linalg.{Vector => BV} @@ -13,12 +30,45 @@ import org.apache.spark.SparkContext._ import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.StreamingContext._ +/** + * :: DeveloperApi :: + * + * StreamingKMeansModel extends MLlib's KMeansModel for streaming + * algorithms, so it can keep track of the number of points assigned + * to each cluster, and also update the model by doing a single iteration + * of the standard KMeans algorithm. + * + * The update algorithm uses the "mini-batch" KMeans rule, + * generalized to incorporate forgetfulness (i.e. decay). 
+ * The basic update rule (for each cluster) is: + * + * c_t+1 = [(c_t * n_t) + (x_t * m_t)] / [n_t + m_t] + * n_t+1 = n_t + m_t + * + * Where c_t is the previously estimated centroid for that cluster, + * n_t is the number of points assigned to it thus far, x_t is the centroid + * estimated on the current batch, and m_t is the number of points assigned + * to that centroid in the current batch. + * + * This update rule is modified with a decay factor 'a' that scales + * the contribution of the clusters as estimated thus far. + * If a=1, all batches are weighted equally. If a=0, new centroids + * are determined entirely by recent data. Lower values correspond to + * more forgetting. + * + * Decay can optionally be specified as a decay fraction 'q', + * which corresponds to the fraction of batches (or points) + * after which the past will be reduced to a contribution of 0.5. + * This decay fraction can be specified in units of 'points' or 'batches'. + * If 'batches', behavior will be independent of the number of points per batch; + * if 'points', the expected number of points per batch must be specified. + */ @DeveloperApi class StreamingKMeansModel( override val clusterCenters: Array[Vector], val clusterCounts: Array[Long]) extends KMeansModel(clusterCenters) { - /** do a sequential KMeans update on a batch of data **/ + // do a sequential KMeans update on a batch of data def update(data: RDD[Vector], a: Double, units: String): StreamingKMeansModel = { val centers = clusterCenters @@ -70,39 +120,49 @@ class StreamingKMeans( def this() = this(2, 1.0, "batches") + /** Set the number of clusters. */ def setK(k: Int): this.type = { this.k = k this } + /** Set the decay factor directly (for forgetful algorithms). */ def setDecayFactor(a: Double): this.type = { this.a = a this } + /** Set the decay units for forgetful algorithms ("batches" or "points"). 
 */ def setUnits(units: String): this.type = { + if (units != "batches" && units != "points") { + throw new IllegalArgumentException("Invalid units for decay: " + units) + } this.units = units this } + /** Set decay fraction in units of batches. */ def setDecayFractionBatches(q: Double): this.type = { this.a = math.log(1 - q) / math.log(0.5) this.units = "batches" this } + /** Set decay fraction in units of points. Must specify expected number of points per batch. */ def setDecayFractionPoints(q: Double, m: Double): this.type = { this.a = math.pow(math.log(1 - q) / math.log(0.5), 1/m) this.units = "points" this } + /** Specify initial centers directly. */ def setInitialCenters(initialCenters: Array[Vector]): this.type = { val clusterCounts = Array.fill(this.k)(0).map(_.toLong) this.model = new StreamingKMeansModel(initialCenters, clusterCounts) this } + /** Initialize random centers, requiring only the number of dimensions. */ def setRandomCenters(d: Int): this.type = { val initialCenters = (0 until k).map(_ => Vectors.dense(Array.fill(d)(nextGaussian()))).toArray val clusterCounts = Array.fill(0)(d).map(_.toLong) @@ -110,10 +170,19 @@ class StreamingKMeans( this } + /** Return the latest model. */ def latestModel(): StreamingKMeansModel = { model } + /** + * Update the clustering model by training on batches of data from a DStream. + * This operation registers a DStream for training the model, + * checks whether the cluster centers have been initialized, + * and updates the model using each batch of data from the stream. + * + * @param data DStream containing vector data + */ def trainOn(data: DStream[Vector]) { this.isInitialized data.foreachRDD { (rdd, time) => @@ -121,16 +190,34 @@ class StreamingKMeans( } } + /** + * Use the clustering model to make predictions on batches of data from a DStream. 
+ * + * @param data DStream containing vector data + * @return DStream containing predictions + */ def predictOn(data: DStream[Vector]): DStream[Int] = { this.isInitialized data.map(model.predict) } + /** + * Use the model to make predictions on the values of a DStream and carry over its keys. + * + * @param data DStream containing (key, feature vector) pairs + * @tparam K key type + * @return DStream containing the input keys and the predictions as values + */ def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = { this.isInitialized data.mapValues(model.predict) } + /** + * Check whether cluster centers have been initialized. + * + * @return true if cluster centers have been initialized + */ def isInitialized: Boolean = { if (Option(model.clusterCenters) == None) { logError("Initial cluster centers must be set before starting predictions") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index a930bba0c5e19..5c23b04961df2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.mllib.clustering import scala.collection.mutable.ArrayBuffer @@ -43,6 +60,7 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase { assert(centers(0) ~== model.latestModel().clusterCenters(0) absTol 1E-1) // estimated center from streaming should exactly match the arithmetic mean of all data points + // because the decay factor is set to 1.0 val grandMean = input.flatten.map(x => x.toBreeze).reduce(_+_) / (numBatches * numPoints).toDouble assert(model.latestModel().clusterCenters(0) ~== Vectors.dense(grandMean.toArray) absTol 1E-5) @@ -74,7 +92,7 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase { runStreams(ssc, numBatches, numBatches) // check that estimated centers are close to true centers - // NOTE this depends on the initialization! allow for binary flip + // NOTE exact assignment depends on the initialization! assert(centers(0) ~== model.latestModel().clusterCenters(0) absTol 1E-1) assert(centers(1) ~== model.latestModel().clusterCenters(1) absTol 1E-1)