Skip to content

Commit

Permalink
SPARK-5019 - GaussianMixtureModel exposes instances of MultivariateGa…
Browse files Browse the repository at this point in the history
…ussian rather than mean/covariance matrices
  • Loading branch information
tgaloppo committed Jan 17, 2015
1 parent 5d9fa55 commit 091e8da
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 24 deletions.
Expand Up @@ -54,7 +54,7 @@ object DenseGmmEM {

for (i <- 0 until clusters.k) {
println("weight=%f\nmu=%s\nsigma=\n%s\n" format
(clusters.weight(i), clusters.mu(i), clusters.sigma(i)))
(clusters.weight(i), clusters.gaussian(i).mu, clusters.gaussian(i).sigma))
}

println("Cluster labels (first <= 100):")
Expand Down
Expand Up @@ -134,9 +134,7 @@ class GaussianMixtureEM private (
// diagonal covariance matrices using component variances
// derived from the samples
val (weights, gaussians) = initialModel match {
case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map { case(mu, sigma) =>
new MultivariateGaussian(mu, sigma)
})
case Some(gmm) => (gmm.weight, gmm.gaussian)

case None => {
val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed)
Expand Down Expand Up @@ -176,10 +174,7 @@ class GaussianMixtureEM private (
iter += 1
}

// Need to convert the breeze matrices to MLlib matrices
val means = Array.tabulate(k) { i => gaussians(i).mu }
val sigmas = Array.tabulate(k) { i => gaussians(i).sigma }
new GaussianMixtureModel(weights, means, sigmas)
new GaussianMixtureModel(weights, gaussians)
}

/** Average of dense breeze vectors */
Expand Down
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering
import breeze.linalg.{DenseVector => BreezeVector}

import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.{Matrix, Vector}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils

Expand All @@ -37,8 +37,9 @@ import org.apache.spark.mllib.util.MLUtils
*/
class GaussianMixtureModel(
val weight: Array[Double],
val mu: Array[Vector],
val sigma: Array[Matrix]) extends Serializable {
val gaussian: Array[MultivariateGaussian]) extends Serializable {

require(weight.length == gaussian.length, "Length of weight and Gaussian arrays must match")

/** Number of gaussians in mixture */
def k: Int = weight.length
Expand All @@ -55,11 +56,7 @@ class GaussianMixtureModel(
*/
def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = {
val sc = points.sparkContext
val dists = sc.broadcast {
(0 until k).map { i =>
new MultivariateGaussian(mu(i).toBreeze.toDenseVector, sigma(i).toBreeze.toDenseMatrix)
}.toArray
}
val dists = sc.broadcast(gaussian)
val weights = sc.broadcast(weight)
points.map { x =>
computeSoftAssignments(x.toBreeze.toDenseVector, dists.value, weights.value, k)
Expand Down
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering
import org.scalatest.FunSuite

import org.apache.spark.mllib.linalg.{Vectors, Matrices}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._

Expand All @@ -40,8 +41,8 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
seeds.foreach { seed =>
val gmm = new GaussianMixtureEM().setK(1).setSeed(seed).run(data)
assert(gmm.weight(0) ~== Ew absTol 1E-5)
assert(gmm.mu(0) ~== Emu absTol 1E-5)
assert(gmm.sigma(0) ~== Esigma absTol 1E-5)
assert(gmm.gaussian(0).mu ~== Emu absTol 1E-5)
assert(gmm.gaussian(0).sigma ~== Esigma absTol 1E-5)
}
}

Expand All @@ -57,8 +58,10 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
// we set an initial gaussian to induce expected results
val initialGmm = new GaussianMixtureModel(
Array(0.5, 0.5),
Array(Vectors.dense(-1.0), Vectors.dense(1.0)),
Array(Matrices.dense(1, 1, Array(1.0)), Matrices.dense(1, 1, Array(1.0)))
Array(
new MultivariateGaussian(Vectors.dense(-1.0), Matrices.dense(1, 1, Array(1.0))),
new MultivariateGaussian(Vectors.dense(1.0), Matrices.dense(1, 1, Array(1.0)))
)
)

val Ew = Array(1.0 / 3.0, 2.0 / 3.0)
Expand All @@ -72,9 +75,9 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex

assert(gmm.weight(0) ~== Ew(0) absTol 1E-3)
assert(gmm.weight(1) ~== Ew(1) absTol 1E-3)
assert(gmm.mu(0) ~== Emu(0) absTol 1E-3)
assert(gmm.mu(1) ~== Emu(1) absTol 1E-3)
assert(gmm.sigma(0) ~== Esigma(0) absTol 1E-3)
assert(gmm.sigma(1) ~== Esigma(1) absTol 1E-3)
assert(gmm.gaussian(0).mu ~== Emu(0) absTol 1E-3)
assert(gmm.gaussian(1).mu ~== Emu(1) absTol 1E-3)
assert(gmm.gaussian(0).sigma ~== Esigma(0) absTol 1E-3)
assert(gmm.gaussian(1).sigma ~== Esigma(1) absTol 1E-3)
}
}

0 comments on commit 091e8da

Please sign in to comment.