Change decay parameterization
- Use a single halfLife parameter that now determines the decay factor
directly
- Allow specification of timeUnit for the halfLife as “batches” or
“points”
- Documentation adjusted accordingly
freeman-lab committed Oct 31, 2014
1 parent 9f7aea9 commit 0411bf5
Showing 3 changed files with 35 additions and 130 deletions.
8 changes: 7 additions & 1 deletion docs/mllib-clustering.md
@@ -174,7 +174,13 @@ to the cluster thus far, `$x_t$` is the new cluster center from the current batc
is the number of points added to the cluster in the current batch. The decay factor `$\alpha$`
can be used to ignore the past: with `$\alpha$=1` all data will be used from the beginning;
with `$\alpha$=0` only the most recent data will be used. This is analogous to an
exponentially-weighted moving average.
+
+The decay can be specified using a `halfLife` parameter, which determines the
+correct decay factor `a` such that, for data acquired
+at time `t`, its contribution by time `t + halfLife` will have dropped to 0.5.
+The unit of time can be specified either as `batches` or `points`, and the update rule
+will be adjusted accordingly.

### Examples

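The relationship between the half-life and the decay factor is direct: `a = exp(ln(0.5) / halfLife) = 0.5^(1 / halfLife)`, so after `halfLife` time units the weight on past data is exactly halved. A standalone Scala sketch (not part of the patch; the half-life value is an assumed example) that checks this, using the same expression `setHalfLife` uses below:

```scala
// Sketch: the decay factor implied by a half-life, and a check that the
// past's weight halves after exactly `halfLife` time units.
object HalfLifeSketch {
  def main(args: Array[String]): Unit = {
    val halfLife = 10.0 // assumed example value
    // same expression as setHalfLife: a = exp(ln(0.5) / h) = 0.5^(1/h)
    val decayFactor = math.exp(math.log(0.5) / halfLife)
    println(f"decay factor a = $decayFactor%.6f")                        // ~0.933033
    println(f"weight after h = ${math.pow(decayFactor, halfLife)}%.6f")  // 0.500000
  }
}
```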
mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -39,28 +39,28 @@ import org.apache.spark.util.Utils
*
* The update algorithm uses the "mini-batch" KMeans rule,
* generalized to incorporate forgetfulness (i.e. decay).
- * The basic update rule (for each cluster) is:
+ * The update rule (for each cluster) is:
*
- * c_t+1 = [(c_t * n_t) + (x_t * m_t)] / [n_t + m_t]
- * n_t+1 = n_t + m_t
+ * c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t * a + m_t]
+ * n_t+1 = n_t * a + m_t
*
* Where c_t is the previously estimated centroid for that cluster,
* n_t is the number of points assigned to it thus far, x_t is the centroid
* estimated on the current batch, and m_t is the number of points assigned
* to that centroid in the current batch.
*
- * This update rule is modified with a decay factor 'a' that scales
- * the contribution of the clusters as estimated thus far.
- * If a=1, all batches are weighted equally. If a=0, new centroids
+ * The decay factor 'a' scales the contribution of the clusters as estimated thus far,
+ * applying 'a' as a discount weighting on the past estimates when evaluating
+ * new incoming data. If a=1, all batches are weighted equally. If a=0, new centroids
* are determined entirely by recent data. Lower values correspond to
* more forgetting.
*
- * Decay can optionally be specified as a decay fraction 'q',
- * which corresponds to the fraction of batches (or points)
- * after which the past will be reduced to a contribution of 0.5.
- * This decay fraction can be specified in units of 'points' or 'batches'.
- * If 'batches', behavior will be independent of the number of points per batch;
- * if 'points', the expected number of points per batch must be specified.
+ * Decay can optionally be specified by a half-life and associated
+ * time unit. The time unit can either be a batch of data or a single
+ * data point. For data acquired at time t, the half-life h is defined
+ * such that at time t + h the discount applied to the data from t is 0.5.
+ * The definition remains the same whether the time unit is given
+ * as batches or points.
*
*/
@DeveloperApi
@@ -69,7 +69,7 @@ class StreamingKMeansModel(
val clusterCounts: Array[Long]) extends KMeansModel(clusterCenters) with Logging {

/** Perform a k-means update on a batch of data. */
-  def update(data: RDD[Vector], a: Double, units: String): StreamingKMeansModel = {
+  def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = {

val centers = clusterCenters
val counts = clusterCounts
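To make the rule in the comment above concrete, here is one update step applied to scalar "centroids". This is a sketch with assumed numbers (a = 0.9, 100 past points at mean 10, a batch of 50 points at mean 20), not code from the patch; it uses the incremental form with the normalized scale factor `lambda`, which is algebraically identical to the closed form in the comment:

```scala
// Sketch: one step of the forgetful mini-batch update with scalar centroids.
object UpdateStepSketch {
  def main(args: Array[String]): Unit = {
    val a = 0.9                       // decay factor (assumed)
    val (cOld, nOld) = (10.0, 100.0)  // centroid and point count so far (assumed)
    val (cNew, mNew) = (20.0, 50.0)   // batch centroid and batch count (assumed)
    // normalized scale factor, as in the "batches" case of the code
    val lambda = mNew / (a * nOld + mNew)
    val cUpdated = cOld + (cNew - cOld) * lambda
    // identical to (a * nOld * cOld + mNew * cNew) / (a * nOld + mNew)
    println(f"lambda = $lambda%.4f, updated centroid = $cUpdated%.4f") // 0.3571, 13.5714
    // with a = 0 the past is ignored: lambda = 1 and the centroid jumps to 20.0
  }
}
```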
@@ -94,12 +94,12 @@
val newCount = count
val newCentroid = mean / newCount.toDouble
// compute the normalized scale factor that controls forgetting
-val decayFactor = units match {
-  case "batches" => newCount / (a * oldCount + newCount)
-  case "points" => newCount / (math.pow(a, newCount) * oldCount + newCount)
+val lambda = timeUnit match {
+  case "batches" => newCount / (decayFactor * oldCount + newCount)
+  case "points" => newCount / (math.pow(decayFactor, newCount) * oldCount + newCount)
}
// perform the update
-val updatedCentroid = oldCentroid + (newCentroid - oldCentroid) * decayFactor
+val updatedCentroid = oldCentroid + (newCentroid - oldCentroid) * lambda
// store the new counts and centers
counts(label) = oldCount + newCount
centers(label) = Vectors.fromBreeze(updatedCentroid)
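The `timeUnit` match above is the entire difference between the two modes: with `batches` the past is discounted by the decay factor once per batch, while with `points` the discount compounds once per received point (`a^m` for a batch of `m` points), so larger batches forget more. A sketch under assumed counts:

```scala
// Sketch: how the time unit changes the discount on past data (assumed values).
object TimeUnitSketch {
  def main(args: Array[String]): Unit = {
    val decayFactor = 0.9
    val (oldCount, newCount) = (100.0, 50.0)
    val lambdaBatches = newCount / (decayFactor * oldCount + newCount)
    val lambdaPoints  = newCount / (math.pow(decayFactor, newCount) * oldCount + newCount)
    // one batch elapsed: the past is discounted once, by 0.9
    println(f"batches: lambda = $lambdaBatches%.4f") // ~0.3571
    // 50 points elapsed: the past is discounted by 0.9^50 ~ 0.005, nearly forgotten
    println(f"points:  lambda = $lambdaPoints%.4f")  // ~0.9898
  }
}
```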
@@ -134,8 +134,8 @@
@DeveloperApi
class StreamingKMeans(
var k: Int,
-    var a: Double,
-    var units: String) extends Logging {
+    var decayFactor: Double,
+    var timeUnit: String) extends Logging {

protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)

@@ -149,30 +149,18 @@

/** Set the decay factor directly (for forgetful algorithms). */
def setDecayFactor(a: Double): this.type = {
-    this.a = a
+    this.decayFactor = a
this
}

- /** Set the decay units for forgetful algorithms ("batches" or "points"). */
- def setUnits(units: String): this.type = {
-   if (units != "batches" && units != "points") {
-     throw new IllegalArgumentException("Invalid units for decay: " + units)
+ /** Set the half-life and time unit ("batches" or "points") for forgetful algorithms. */
+ def setHalfLife(halfLife: Double, timeUnit: String): this.type = {
+   if (timeUnit != "batches" && timeUnit != "points") {
+     throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit)
   }
-   this.units = units
-   this
- }
-
- /** Set decay fraction in units of batches. */
- def setDecayFractionBatches(q: Double): this.type = {
-   this.a = math.log(1 - q) / math.log(0.5)
-   this.units = "batches"
-   this
- }
-
- /** Set decay fraction in units of points. Must specify expected number of points per batch. */
- def setDecayFractionPoints(q: Double, m: Double): this.type = {
-   this.a = math.pow(math.log(1 - q) / math.log(0.5), 1/m)
-   this.units = "points"
+   this.decayFactor = math.exp(math.log(0.5) / halfLife)
+   logInfo("Setting decay factor to: %g".format(this.decayFactor))
+   this.timeUnit = timeUnit
   this
 }
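Putting the new setters together, configuration after this change might look like the following spark-shell-style sketch. The specific values (k = 2, a half-life of 5 batches, the initial centers) are assumed examples, not from the patch:

```scala
// Sketch: configuring the estimator with the new half-life parameterization.
import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vectors

val model = new StreamingKMeans()
  .setK(2)
  .setHalfLife(5.0, "batches") // weight of past data halves every 5 batches
  .setInitialCenters(Array(Vectors.dense(0.0), Vectors.dense(1.0)))

// the decay factor can also be set directly, e.g. .setDecayFactor(0.5);
// training then proceeds as before, via model.trainOn(...)
```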

@@ -216,7 +204,7 @@ class StreamingKMeans(
def trainOn(data: DStream[Vector]) {
this.assertInitialized()
data.foreachRDD { (rdd, time) =>
-      model = model.update(rdd, this.a, this.units)
+      model = model.update(rdd, this.decayFactor, this.timeUnit)
}
}

mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
@@ -17,7 +17,6 @@

package org.apache.spark.mllib.clustering

-import scala.collection.mutable.ArrayBuffer
import scala.util.Random

import org.scalatest.FunSuite
@@ -98,94 +97,6 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase {

}

test("drifting with fractional decay in units of batches") {

val numBatches1 = 50
val numBatches2 = 50
val numPoints = 1
val q = 0.25
val k = 1
val d = 1
val r = 2.0

// create model with two clusters
val model = new StreamingKMeans()
.setK(1)
.setDecayFractionBatches(q)
.setInitialCenters(Array(Vectors.dense(0.0)))

// create two batches of data with different, pre-specified centers
// to simulate a transition from one cluster to another
val (input1, centers1) = StreamingKMeansDataGenerator(
numPoints, numBatches1, k, d, r, 42, initCenters = Array(Vectors.dense(100.0)))
val (input2, centers2) = StreamingKMeansDataGenerator(
numPoints, numBatches2, k, d, r, 84, initCenters = Array(Vectors.dense(0.0)))

// store the history
val history = new ArrayBuffer[Double](numBatches1 + numBatches2)

// setup and run the model training
val ssc = setupStreams(input1 ++ input2, (inputDStream: DStream[Vector]) => {
model.trainOn(inputDStream)
// extract the center (in this case one-dimensional)
inputDStream.foreachRDD(x => history.append(model.latestModel().clusterCenters(0)(0)))
inputDStream.count()
})
runStreams(ssc, numBatches1 + numBatches2, numBatches1 + numBatches2)

// check that the fraction of batches required to reach 50
// equals the setting of q, by finding the index of the first batch
// below 50 and comparing to total number of batches received
val halvedIndex = history.zipWithIndex.filter( x => x._1 < 50)(0)._2.toDouble
val fraction = (halvedIndex - numBatches1.toDouble) / halvedIndex
assert(fraction ~== q absTol 1E-1)

}

test("drifting with fractional decay in units of points") {

val numBatches1 = 50
val numBatches2 = 50
val numPoints = 10
val q = 0.25
val k = 1
val d = 1
val r = 2.0

// create model with two clusters
val model = new StreamingKMeans()
.setK(1)
.setDecayFractionPoints(q, numPoints)
.setInitialCenters(Array(Vectors.dense(0.0)))

// create two batches of data with different, pre-specified centers
// to simulate a transition from one cluster to another
val (input1, centers1) = StreamingKMeansDataGenerator(
numPoints, numBatches1, k, d, r, 42, initCenters = Array(Vectors.dense(100.0)))
val (input2, centers2) = StreamingKMeansDataGenerator(
numPoints, numBatches2, k, d, r, 84, initCenters = Array(Vectors.dense(0.0)))

// store the history
val history = new ArrayBuffer[Double](numBatches1 + numBatches2)

// setup and run the model training
val ssc = setupStreams(input1 ++ input2, (inputDStream: DStream[Vector]) => {
model.trainOn(inputDStream)
// extract the center (in this case one-dimensional)
inputDStream.foreachRDD(x => history.append(model.latestModel().clusterCenters(0)(0)))
inputDStream.count()
})
runStreams(ssc, numBatches1 + numBatches2, numBatches1 + numBatches2)

// check that the fraction of batches required to reach 50
// equals the setting of q, by finding the index of the first batch
// below 50 and comparing to total number of batches received
val halvedIndex = history.zipWithIndex.filter( x => x._1 < 50)(0)._2.toDouble
val fraction = (halvedIndex - numBatches1.toDouble) / halvedIndex
assert(fraction ~== q absTol 1E-1)

}

def StreamingKMeansDataGenerator(
numPoints: Int,
numBatches: Int,
