change to optimization

hhbyyh committed Mar 23, 2015
1 parent 8cb16a6 commit f367cc9
Showing 1 changed file with 41 additions and 34 deletions.
75 changes: 41 additions & 34 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -249,32 +249,24 @@ class LDA private (


   /**
    * TODO: add API to take document paths once the tokenizer is ready.
    * Learn an LDA model using the given dataset, using the online variational Bayes (VB) algorithm.
    *
    * @param documents RDD of documents, which are term (word) count vectors paired with IDs.
    *                  The term count vectors are "bags of words" with a fixed-size vocabulary
    *                  (where the vocabulary size is the length of the vector).
    *                  Document IDs must be unique and >= 0.
-   * @param batchNumber Number of batches. For each batch, recommendation size is [4, 16384].
-   *                    -1 for automatic batchNumber.
+   * @param batchNumber Number of batches to split the input corpus into. For each batch, the
+   *                    recommended size is [4, 16384]. Use -1 for an automatic batch number.
    * @return Inferred LDA model
    */
   def runOnlineLDA(documents: RDD[(Long, Vector)], batchNumber: Int = -1): LDAModel = {
-    val D = documents.count().toInt
-    val batchSize =
-      if (batchNumber == -1) { // auto mode
-        if (D / 100 > 16384) 16384
-        else if (D / 100 < 4) 4
-        else D / 100
-      }
-      else {
-        require(batchNumber > 0, "batchNumber should be positive or -1")
-        D / batchNumber
-      }
+    require(batchNumber > 0 || batchNumber == -1,
+      s"batchNumber must be positive or -1, but was set to $batchNumber")
+
-    val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, batchSize)
-    (0 until onlineLDA.actualBatchNumber).map(_ => onlineLDA.next())
-    new LocalLDAModel(Matrices.fromBreeze(onlineLDA.lambda).transpose)
+    val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, batchNumber)
+    val model = onlineLDA.optimize()
+    new LocalLDAModel(Matrices.fromBreeze(model).transpose)
   }
 
   /** Java-friendly version of [[run()]] */
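
For context, a minimal sketch of how the API above might be invoked. It assumes an existing SparkContext `sc`; the corpus contents and variable names are illustrative, not part of this commit:

    import org.apache.spark.mllib.clustering.LDA
    import org.apache.spark.mllib.linalg.{Vector, Vectors}
    import org.apache.spark.rdd.RDD

    // Toy corpus: three documents over a four-term vocabulary, paired with
    // unique non-negative document IDs as the doc comment requires.
    val corpus: RDD[(Long, Vector)] = sc.parallelize(Seq(
      (0L, Vectors.dense(1.0, 0.0, 2.0, 0.0)),
      (1L, Vectors.dense(0.0, 3.0, 0.0, 1.0)),
      (2L, Vectors.dense(2.0, 1.0, 0.0, 0.0))))

    val lda = new LDA().setK(2)           // infer 2 topics
    val model = lda.runOnlineLDA(corpus)  // batchNumber = -1: batch size chosen automatically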
@@ -437,39 +429,54 @@ private[clustering] object LDA {
   private[clustering] class OnlineLDAOptimizer(
       private val documents: RDD[(Long, Vector)],
       private val k: Int,
-      private val batchSize: Int) extends Serializable{
+      private val batchNumber: Int) extends Serializable {
 
     private val vocabSize = documents.first._2.size
     private val D = documents.count().toInt
-    val actualBatchNumber = Math.ceil(D.toDouble / batchSize).toInt
+    private val batchSize =
+      if (batchNumber == -1) { // auto mode
+        if (D / 100 > 16384) 16384
+        else if (D / 100 < 4) 4
+        else D / 100
+      }
+      else {
+        D / batchNumber
+      }
 
     // Initialize the variational distribution q(beta|lambda)
-    var lambda = getGammaMatrix(k, vocabSize)                // K * V
+    private var lambda = getGammaMatrix(k, vocabSize)        // K * V
     private var Elogbeta = dirichlet_expectation(lambda)     // K * V
     private var expElogbeta = exp(Elogbeta)                  // K * V
 
-    private var batchId = 0
-    def next(): Unit = {
-      require(batchId < actualBatchNumber)
-      // weight of the mini-batch. 1024 down weights early iterations
-      val weight = math.pow(1024 + batchId, -0.5)
-      val batch = documents.sample(true, batchSize.toDouble / D)
-      batch.cache()
-      // Given a mini-batch of documents, estimates the parameters gamma controlling the
-      // variational distribution over the topic weights for each document in the mini-batch.
-      var stat = BDM.zeros[Double](k, vocabSize)
-      stat = batch.aggregate(stat)(seqOp, _ += _)
-      stat = stat :* expElogbeta
+    def optimize(): BDM[Double] = {
+      val actualBatchNumber = Math.ceil(D.toDouble / batchSize).toInt
+      for (i <- 1 to actualBatchNumber) {
+        val batch = documents.sample(true, batchSize.toDouble / D)
+
+        // Given a mini-batch of documents, estimates the parameters gamma controlling the
+        // variational distribution over the topic weights for each document in the mini-batch.
+        var stat = BDM.zeros[Double](k, vocabSize)
+        stat = batch.treeAggregate(stat)(gradient, _ += _)
+        update(stat, i)
+      }
+      lambda
+    }
 
+    private def update(raw: BDM[Double], iter: Int): Unit = {
+      // weight of the mini-batch. 1024 helps down-weight early iterations
+      val weight = math.pow(1024 + iter, -0.5)
+
+      // This step finishes computing the sufficient statistics for the M step
+      val stat = raw :* expElogbeta
+
       // Update lambda based on documents.
       lambda = lambda * (1 - weight) + (stat * D.toDouble / batchSize.toDouble + 1.0 / k) * weight
       Elogbeta = dirichlet_expectation(lambda)
       expElogbeta = exp(Elogbeta)
-      batchId += 1
     }
 
     // For each document d, update that document's gamma and phi
-    private def seqOp(stat: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
+    private def gradient(stat: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
       val termCounts = doc._2
       val (ids, cts) = termCounts match {
         case v: DenseVector => (((0 until v.size).toList), v.values)
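
The blend in `update` is the standard online variational Bayes step: the mini-batch estimate is mixed into lambda with a decaying weight rho_t = (tau0 + t)^(-kappa), hard-coded here as tau0 = 1024 and kappa = 0.5, matching the schedule of Hoffman et al.'s online LDA. A standalone Breeze sketch of just that step, with illustrative names:

    import breeze.linalg.{DenseMatrix => BDM}

    // lambdaHat stands for the mini-batch estimate, i.e. the
    // (stat * D / batchSize + 1.0 / k) term in the code above.
    def blend(lambda: BDM[Double], lambdaHat: BDM[Double], iter: Int,
              tau0: Double = 1024.0, kappa: Double = 0.5): BDM[Double] = {
      val rho = math.pow(tau0 + iter, -kappa)  // decaying step size: early batches count less
      lambda * (1.0 - rho) + lambdaHat * rho   // exponential moving average of the statistics
    }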
@@ -488,7 +495,7 @@ private[clustering] object LDA {
       val ctsVector = new BDV[Double](cts).t // 1 * ids
 
       // Iterate between gamma and phi until convergence
-      while (meanchange > 1e-6) {
+      while (meanchange > 1e-5) {
         val lastgamma = gammad
         // 1*K                  1 * ids               ids * k
         gammad = (expElogthetad :* ((ctsVector / phinorm) * (expElogbetad.t))) + 1.0/k
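
The loop whose threshold is relaxed here (1e-6 to 1e-5) is the per-document E step, alternating between gamma and phi until gamma stabilizes. The definition of `meanchange` lies outside this hunk; a plausible reading, consistent with common online LDA implementations, is the mean absolute difference between successive gamma estimates:

    import breeze.linalg.{sum, DenseVector => BDV}
    import breeze.numerics.abs

    // Hypothetical helper (not shown in this diff): average absolute
    // coordinate-wise change between two successive gamma vectors.
    def meanChange(gammad: BDV[Double], lastgamma: BDV[Double]): Double =
      sum(abs(gammad - lastgamma)) / gammad.length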
