diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index d6f8b29a43dfd..b0e14cb8296a6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -19,8 +19,8 @@ package org.apache.spark.mllib.clustering
 
 import java.util.Random
 
-import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum}
-import breeze.numerics.{abs, exp}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, all, normalize, sum}
+import breeze.numerics.{abs, exp, trigamma}
 import breeze.stats.distributions.{Gamma, RandBasis}
 
 import org.apache.spark.annotation.DeveloperApi
@@ -239,22 +239,26 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
   /** alias for docConcentration */
   private var alpha: Vector = Vectors.dense(0)
 
-  /** (private[clustering] for debugging) Get docConcentration */
+  /** (for debugging) Get docConcentration */
   private[clustering] def getAlpha: Vector = alpha
 
   /** alias for topicConcentration */
   private var eta: Double = 0
 
-  /** (private[clustering] for debugging) Get topicConcentration */
+  /** (for debugging) Get topicConcentration */
   private[clustering] def getEta: Double = eta
 
   private var randomGenerator: java.util.Random = null
 
+  /** (for debugging) Whether to sample mini-batches with replacement. (default = true) */
+  private var sampleWithReplacement: Boolean = true
+
   // Online LDA specific parameters
   // Learning rate is: (tau0 + t)^{-kappa}
   private var tau0: Double = 1024
   private var kappa: Double = 0.51
   private var miniBatchFraction: Double = 0.05
+  private var optimizeAlpha: Boolean = false
 
   // internal data structure
   private var docs: RDD[(Long, Vector)] = null
@@ -262,7 +266,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
   /** Dirichlet parameter for the posterior over topics */
   private var lambda: BDM[Double] = null
 
-  /** (private[clustering] for debugging) Get parameter for topics */
+  /** (for debugging) Get parameter for topics */
   private[clustering] def getLambda: BDM[Double] = lambda
 
   /** Current iteration (count of invocations of [[next()]]) */
@@ -325,7 +329,22 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
   }
 
   /**
-   * (private[clustering])
+   * Indicates whether alpha (the Dirichlet parameter for the document-topic distribution)
+   * will be optimized during training.
+   */
+  def getOptimizeAlpha: Boolean = this.optimizeAlpha
+
+  /**
+   * Sets whether to optimize the alpha parameter during training.
+   *
+   * Default: false
+   */
+  def setOptimizeAlpha(optimizeAlpha: Boolean): this.type = {
+    this.optimizeAlpha = optimizeAlpha
+    this
+  }
+
+  /**
    * Set the Dirichlet parameter for the posterior over topics.
    * This is only used for testing now. In the future, it can help support training stop/resume.
    */
@@ -335,7 +354,6 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
   }
 
   /**
-   * (private[clustering])
   * Used for random initialization of the variational parameters.
    * Larger value produces values closer to 1.0.
    * This is only used for testing currently.
@@ -345,6 +363,15 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
     this
   }
 
+  /**
+   * Sets whether to sample mini-batches with or without replacement. (default = true)
+   * This is only used for testing currently.
+   */
+  private[clustering] def setSampleWithReplacement(replace: Boolean): this.type = {
+    this.sampleWithReplacement = replace
+    this
+  }
+
   override private[clustering] def initialize(
       docs: RDD[(Long, Vector)],
       lda: LDA): OnlineLDAOptimizer = {
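For orientation, here is a minimal sketch of how a caller could enable the new option once the patch is applied. Everything except `setOptimizeAlpha` is existing `LDA`/`OnlineLDAOptimizer` API; `corpus`, `setK(10)`, and `setMiniBatchFraction(0.05)` are illustrative assumptions, not part of the patch. (`setSampleWithReplacement` is `private[clustering]`, so it is only reachable from tests.)

```scala
import org.apache.spark.mllib.clustering.{LDA, OnlineLDAOptimizer}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

// Hypothetical caller-side usage: learn the document-topic prior during training.
val optimizer = new OnlineLDAOptimizer()
  .setMiniBatchFraction(0.05)   // sample 5% of documents per iteration
  .setOptimizeAlpha(true)       // added by this patch; default is false

val corpus: RDD[(Long, Vector)] = ??? // (docId, termCount vector) pairs, assumed given

val model = new LDA()
  .setK(10)                     // illustrative topic count
  .setOptimizer(optimizer)
  .run(corpus)
```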
@@ -376,7 +403,8 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
   }
 
   override private[clustering] def next(): OnlineLDAOptimizer = {
-    val batch = docs.sample(withReplacement = true, miniBatchFraction, randomGenerator.nextLong())
+    val batch = docs.sample(withReplacement = sampleWithReplacement, miniBatchFraction,
+      randomGenerator.nextLong())
     if (batch.isEmpty()) return this
     submitMiniBatch(batch)
   }
@@ -418,6 +446,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
 
     // Note that this is an optimization to avoid batch.count
     updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt)
+    if (optimizeAlpha) updateAlpha(gammat)
     this
   }
 
@@ -433,13 +462,38 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
       weight * (stat * (corpusSize.toDouble / batchSize.toDouble) + eta)
   }
 
-  /** Calculates learning rate rho, which decays as a function of [[iteration]] */
+  /**
+   * Update alpha based on `gammat`, the inferred topic distributions for documents in the
+   * current mini-batch. Uses the Newton-Raphson method.
+   * @see Section 3.3, Huang: Maximum Likelihood Estimation of Dirichlet Distribution Parameters
+   *      (http://jonathan-huang.org/research/dirichlet/dirichlet.pdf)
+   */
+  private def updateAlpha(gammat: BDM[Double]): Unit = {
+    val weight = rho()
+    val N = gammat.rows.toDouble
+    val alpha = this.alpha.toBreeze.toDenseVector
+    val logphat: BDM[Double] = sum(LDAUtils.dirichletExpectation(gammat)(::, breeze.linalg.*)) / N
+    val gradf = N * (-LDAUtils.dirichletExpectation(alpha) + logphat.toDenseVector)
+
+    val c = N * trigamma(sum(alpha))
+    val q = -N * trigamma(alpha)
+    val b = sum(gradf / q) / (1D / c + sum(1D / q))
+
+    val dalpha = -(gradf - b) / q
+
+    if (all((weight * dalpha + alpha) :> 0D)) {
+      alpha :+= weight * dalpha
+      this.alpha = Vectors.dense(alpha.toArray)
+    }
+  }
+
+  /** Calculate learning rate rho for the current [[iteration]]. */
   private def rho(): Double = {
     math.pow(getTau0 + this.iteration, -getKappa)
   }
 
   /**
-   * Get a random matrix to initialize lambda
+   * Get a random matrix to initialize lambda.
   */
   private def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
     val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister(
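To make the Newton step in `updateAlpha` easier to follow, here is a self-contained sketch of the same update on plain Breeze vectors. `dirichletExpectation` is reproduced inline (it mirrors `LDAUtils.dirichletExpectation`); `newtonStepAlpha` is a hypothetical helper name, and `logphat`, `N`, and `weight` are assumed to be the mini-batch statistics and learning rate computed as in the patch. Following Huang (Section 3.3), the Hessian decomposes as a diagonal matrix plus a rank-one term, so the Newton direction is computed in O(k) without forming or inverting a k-by-k matrix.

```scala
import breeze.linalg.{DenseVector, all, sum}
import breeze.numerics.{digamma, trigamma}

// Mirrors LDAUtils.dirichletExpectation: E[log theta_k] for theta ~ Dirichlet(alpha).
def dirichletExpectation(alpha: DenseVector[Double]): DenseVector[Double] =
  digamma(alpha) - digamma(sum(alpha))

/**
 * One damped Newton-Raphson step for the Dirichlet MLE (Huang, Section 3.3).
 * The Hessian is diag(q) + c * 1*1^T, so the step only needs the shared
 * correction term `b` and elementwise operations.
 *
 * @param alpha   current Dirichlet parameter (k-dimensional, all entries > 0)
 * @param logphat average of E[log theta_d] over the N mini-batch documents
 * @param weight  damping factor (the decaying learning rate rho)
 */
def newtonStepAlpha(
    alpha: DenseVector[Double],
    logphat: DenseVector[Double],
    N: Double,
    weight: Double): DenseVector[Double] = {
  val gradf = N * (-dirichletExpectation(alpha) + logphat) // gradient of log-likelihood
  val c = N * trigamma(sum(alpha))                         // rank-one Hessian component
  val q = -N * trigamma(alpha)                             // diagonal Hessian component
  val b = sum(gradf / q) / (1D / c + sum(1D / q))          // shared correction term
  val dalpha = -(gradf - b) / q                            // Newton direction
  // Take the damped step only if it keeps every component of alpha positive;
  // otherwise leave alpha unchanged, as the patch does.
  if (all((alpha + weight * dalpha) :> 0D)) alpha + weight * dalpha else alpha
}
```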
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index f2b94707fd0ff..fdc2554ab853e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -400,6 +400,40 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("OnlineLDAOptimizer alpha hyperparameter optimization") {
+    val k = 2
+    val docs = sc.parallelize(toyData)
+    val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51)
+      .setGammaShape(100).setOptimizeAlpha(true).setSampleWithReplacement(false)
+    val lda = new LDA().setK(k)
+      .setDocConcentration(1D / k)
+      .setTopicConcentration(0.01)
+      .setMaxIterations(100)
+      .setOptimizer(op)
+      .setSeed(12345)
+    val ldaModel: LocalLDAModel = lda.run(docs).asInstanceOf[LocalLDAModel]
+
+    /* Verify the results with gensim:
+       import numpy as np
+       from gensim import models
+       corpus = [
+         [(0, 1.0), (1, 1.0)],
+         [(1, 1.0), (2, 1.0)],
+         [(0, 1.0), (2, 1.0)],
+         [(3, 1.0), (4, 1.0)],
+         [(3, 1.0), (5, 1.0)],
+         [(4, 1.0), (5, 1.0)]]
+       np.random.seed(2345)
+       lda = models.ldamodel.LdaModel(
+         corpus=corpus, alpha='auto', eta=0.01, num_topics=2, update_every=0, passes=100,
+         decay=0.51, offset=1024)
+       print(lda.alpha)
+       > [ 0.42582646  0.43511073]
+     */
+
+    assert(ldaModel.docConcentration ~== Vectors.dense(0.42582646, 0.43511073) absTol 0.05)
+  }
+
   test("model save/load") {
     // Test for LocalLDAModel.
     val localModel = new LocalLDAModel(tinyTopics,
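A note on the gensim cross-check above: gensim's `offset` and `decay` play the same role as this optimizer's `tau0` and `kappa` (the test passes the same values, 1024 and 0.51, to both), so both implementations anneal the lambda and alpha updates with the weight rho_t = (tau0 + t)^(-kappa). A minimal standalone sketch of that schedule, with approximate values in the comments:

```scala
// rho_t = (tau0 + t)^(-kappa): the damping applied to each mini-batch update.
def rho(tau0: Double, kappa: Double, iteration: Int): Double =
  math.pow(tau0 + iteration, -kappa)

rho(1024, 0.51, 1)    // ~0.0291: even the first update is heavily damped
rho(1024, 0.51, 100)  // ~0.0278: with a large tau0 the rate decays slowly
```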