Skip to content

Commit

Permalink
Minor fixes and tweaks.
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaloppo committed Dec 19, 2014
1 parent 1de73f3 commit b97fe00
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ object DenseGmmEM {
}
}

def run(inputFile: String, k: Int, convergenceTol: Double) {
private def run(inputFile: String, k: Int, convergenceTol: Double) {
val conf = new SparkConf().setAppName("Spark EM Sample")
val ctx = new SparkContext(conf)

val data = ctx.textFile(inputFile).map{ line =>
Vectors.dense(line.trim.split(' ').map(_.toDouble))
}.cache
}.cache()

val clusters = new GaussianMixtureModelEM()
.setK(k)
Expand All @@ -55,11 +55,11 @@ object DenseGmmEM {
(clusters.weight(i), clusters.mu(i), clusters.sigma(i)))
}

println("Cluster labels:")
println("Cluster labels (first <= 100):")
val (responsibilityMatrix, clusterLabels) = clusters.predict(data)
for (x <- clusterLabels.collect) {
clusterLabels.take(100).foreach{ x =>
print(" " + x)
}
println
println()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,12 @@ class GaussianMixtureModelEM private (
// (U, U) => U for aggregation
/**
 * Combiner for two expectation-sum accumulators: every component of `m2` is
 * added into the matching component of `m1`, and `m1` itself is returned.
 *
 * NOTE(review): this span is a rendered diff hunk without +/- markers, so the
 * `for` header and the `var i` / `while` / `i = i + 1` lines appear together;
 * presumably the `for` header is the removed form and the `while` loop is its
 * replacement -- confirm against the repository before reading this as code.
 *
 * @param m1 accumulator that receives the sums (its components are updated)
 * @param m2 accumulator whose components are read and added into `m1`
 * @return `m1`, after the additions
 */
private def addExpectationSums(m1: ExpectationSum, m2: ExpectationSum): ExpectationSum = {
// Only index 0 of the first component is combined; it is handled outside the loop.
m1._1(0) += m2._1(0)
for (i <- 0 until m1._2.length) {
var i = 0
while (i < m1._2.length) {
// Components 2, 3 and 4 are walked in parallel, bounded by the length of m1._2.
m1._2(i) += m2._2(i)
m1._3(i) += m2._3(i)
m1._4(i) += m2._4(i)
i = i + 1
}
m1
}
Expand All @@ -90,11 +92,13 @@ class GaussianMixtureModelEM private (
val pSum = p.sum
sums._1(0) += math.log(pSum)
val xxt = x * new Transpose(x)
for (i <- 0 until k) {
var i = 0
while (i < k) {
p(i) /= pSum
sums._2(i) += p(i)
sums._3(i) += x * p(i)
sums._4(i) += xxt * p(i)
i = i + 1
}
sums
}
Expand Down Expand Up @@ -123,7 +127,7 @@ class GaussianMixtureModelEM private (
}

/**
 * Return the user-supplied initial GMM, if one was supplied.
 *
 * @return the `initialGmm` value, passed through unchanged
 */
// NOTE(review): the two accessors below differ only in the spelling of the name
// (`getInitialiGmm` vs `getInitialGmm`); this is a rendered diff, so presumably
// the first is the removed line and the second is the added one -- confirm.
def getInitialiGmm: Option[GaussianMixtureModel] = initialGmm
def getInitialGmm: Option[GaussianMixtureModel] = initialGmm

/** Set the number of Gaussians in the mixture model. Default: 2 */
def setK(k: Int): this.type = {
Expand Down Expand Up @@ -182,7 +186,7 @@ class GaussianMixtureModelEM private (

case None => {
val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt)
((0 until k).map(_ => 1.0 / k).toArray, (0 until k).map{ i =>
(Array.fill[Double](k)(1.0 / k), (0 until k).map{ i =>
val slice = samples.view(i * nSamples, (i + 1) * nSamples)
new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
}.toArray)
Expand Down

0 comments on commit b97fe00

Please sign in to comment.