Skip to content

Commit

Permalink
Corrected a variety of style and naming issues.
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaloppo committed Dec 12, 2014
1 parent 8aaa17d commit 9770261
Show file tree
Hide file tree
Showing 5 changed files with 266 additions and 299 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,34 @@
package org.apache.spark.examples.mllib

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.GaussianMixtureModel
import org.apache.spark.mllib.clustering.GMMExpectationMaximization
import org.apache.spark.mllib.clustering.GaussianMixtureModelEM
import org.apache.spark.mllib.linalg.Vectors

object DenseGmmEM {
def main(args: Array[String]): Unit = {
if( args.length != 3 ) {
println("usage: DenseGmmEM <input file> <k> <delta>")
if (args.length != 3) {
println("usage: DenseGmmEM <input file> <k> <convergenceTol>")
} else {
run(args(0), args(1).toInt, args(2).toDouble)
}
}

def run(inputFile: String, k: Int, tol: Double) {
def run(inputFile: String, k: Int, convergenceTol: Double) {
val conf = new SparkConf().setAppName("Spark EM Sample")
val ctx = new SparkContext(conf)

val data = ctx.textFile(inputFile).map(line =>
Vectors.dense(line.trim.split(' ').map(_.toDouble))).cache()
val data = ctx.textFile(inputFile).map{ line =>
Vectors.dense(line.trim.split(' ').map(_.toDouble))
}.cache()

val clusters = GMMExpectationMaximization.train(data, k, tol)
val clusters = new GaussianMixtureModelEM()
.setK(k)
.setConvergenceTol(convergenceTol)
.run(data)

for(i <- 0 until clusters.k) {
println("w=%f mu=%s sigma=\n%s\n" format (clusters.w(i), clusters.mu(i), clusters.sigma(i)))
for (i <- 0 until clusters.k) {
println("weight=%f mu=%s sigma=\n%s\n" format
(clusters.weight(i), clusters.mu(i), clusters.sigma(i)))
}
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,18 @@ import org.apache.spark.mllib.linalg.Vector
* Multivariate Gaussian mixture model consisting of k Gaussians, where points are drawn
* from each Gaussian i=1..k with probability weight(i); mu(i) and sigma(i) are the respective
* mean and covariance for each Gaussian distribution i=1..k.
*
* @param weight Weights for each Gaussian distribution in the mixture, where weight(i) is
* the weight for Gaussian i, and weight.sum == 1
* @param mu Means for each Gaussian in the mixture, where mu(i) is the mean for Gaussian i
* @param sigma Covariance matrix for each Gaussian in the mixture, where sigma(i) is the
* covariance matrix for Gaussian i
*/
class GaussianMixtureModel(
val w: Array[Double],
val weight: Array[Double],
val mu: Array[Vector],
val sigma: Array[Matrix]) extends Serializable {

/** Number of Gaussians in the mixture */
def k: Int = w.length;
def k: Int = weight.length;
}
Loading

0 comments on commit 9770261

Please sign in to comment.