Skip to content

Commit

Permalink
In GaussianMixtureModel: Changed name of weight, gaussian to weights,…
Browse files Browse the repository at this point in the history
… gaussians. Other sources modified accordingly.
  • Loading branch information
tgaloppo committed Jan 20, 2015
1 parent 091e8da commit 3ef6c7f
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ object DenseGmmEM {

for (i <- 0 until clusters.k) {
println("weight=%f\nmu=%s\nsigma=\n%s\n" format
(clusters.weight(i), clusters.gaussian(i).mu, clusters.gaussian(i).sigma))
(clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma))
}

println("Cluster labels (first <= 100):")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class GaussianMixtureEM private (
// diagonal covariance matrices using component variances
// derived from the samples
val (weights, gaussians) = initialModel match {
case Some(gmm) => (gmm.weight, gmm.gaussian)
case Some(gmm) => (gmm.weights, gmm.gaussians)

case None => {
val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ import org.apache.spark.mllib.util.MLUtils
* covariance matrix for Gaussian i
*/
class GaussianMixtureModel(
val weight: Array[Double],
val gaussian: Array[MultivariateGaussian]) extends Serializable {
val weights: Array[Double],
val gaussians: Array[MultivariateGaussian]) extends Serializable {

require(weight.length == gaussian.length, "Length of weight and Gaussian arrays must match")
require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match")

/** Number of gaussians in mixture */
def k: Int = weight.length
def k: Int = weights.length

/** Maps given points to their cluster indices. */
def predict(points: RDD[Vector]): RDD[Int] = {
Expand All @@ -56,10 +56,10 @@ class GaussianMixtureModel(
*/
def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = {
val sc = points.sparkContext
val dists = sc.broadcast(gaussian)
val weights = sc.broadcast(weight)
val bcDists = sc.broadcast(gaussians)
val bcWeights = sc.broadcast(weights)
points.map { x =>
computeSoftAssignments(x.toBreeze.toDenseVector, dists.value, weights.value, k)
computeSoftAssignments(x.toBreeze.toDenseVector, bcDists.value, bcWeights.value, k)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
val seeds = Array(314589, 29032897, 50181, 494821, 4660)
seeds.foreach { seed =>
val gmm = new GaussianMixtureEM().setK(1).setSeed(seed).run(data)
assert(gmm.weight(0) ~== Ew absTol 1E-5)
assert(gmm.gaussian(0).mu ~== Emu absTol 1E-5)
assert(gmm.gaussian(0).sigma ~== Esigma absTol 1E-5)
assert(gmm.weights(0) ~== Ew absTol 1E-5)
assert(gmm.gaussians(0).mu ~== Emu absTol 1E-5)
assert(gmm.gaussians(0).sigma ~== Esigma absTol 1E-5)
}
}

Expand Down Expand Up @@ -73,11 +73,11 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
.setInitialModel(initialGmm)
.run(data)

assert(gmm.weight(0) ~== Ew(0) absTol 1E-3)
assert(gmm.weight(1) ~== Ew(1) absTol 1E-3)
assert(gmm.gaussian(0).mu ~== Emu(0) absTol 1E-3)
assert(gmm.gaussian(1).mu ~== Emu(1) absTol 1E-3)
assert(gmm.gaussian(0).sigma ~== Esigma(0) absTol 1E-3)
assert(gmm.gaussian(1).sigma ~== Esigma(1) absTol 1E-3)
assert(gmm.weights(0) ~== Ew(0) absTol 1E-3)
assert(gmm.weights(1) ~== Ew(1) absTol 1E-3)
assert(gmm.gaussians(0).mu ~== Emu(0) absTol 1E-3)
assert(gmm.gaussians(1).mu ~== Emu(1) absTol 1E-3)
assert(gmm.gaussians(0).sigma ~== Esigma(0) absTol 1E-3)
assert(gmm.gaussians(1).sigma ~== Esigma(1) absTol 1E-3)
}
}

0 comments on commit 3ef6c7f

Please sign in to comment.