Skip to content

Commit

Permalink
Adding GMM, tests, and Scaladoc.
Browse files Browse the repository at this point in the history
  • Loading branch information
etrain committed May 4, 2015
1 parent cc24026 commit bbe6274
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 5 deletions.
91 changes: 91 additions & 0 deletions src/main/scala/nodes/learning/GaussianMixtureModel.scala
@@ -0,0 +1,91 @@
package nodes.learning

import breeze.linalg._
import nodes.utils.external.ImageFeatures
import org.apache.spark.rdd.RDD
import pipelines._
import utils.MatrixUtils


/**
 * A mixture of Gaussians with diagonal variances, usually computed via some clustering process.
 *
 * `means` and `variances` must have identical shapes. Each column describes one mixture
 * center, so both are `dim` x `k`, and `weights` carries exactly one entry per center.
 * These shape constraints are checked at construction time.
 *
 * @param means Cluster centers, one column per center.
 * @param variances Cluster variances (diagonal), laid out like `means`.
 * @param weights Cluster weights, one per center.
 */
class GaussianMixtureModel(
    val means: DenseMatrix[Double],
    val variances: DenseMatrix[Double],
    val weights: DenseVector[Double])
  extends Transformer[DenseVector[Double], DenseVector[Double]] with Logging {

  /** Number of mixture centers (the column count of `means`). */
  val k: Int = means.cols

  /** Dimensionality of each center (the row count of `means`). */
  val dim: Int = means.rows

  require(
    variances.rows == dim && variances.cols == k,
    "GMM means and variances must be the same size.")
  require(k == weights.length, "Every GMM center must have a weight.")

  /**
   * Intended to return the soft assignment of each input vector to each cluster.
   *
   * This is not implemented yet: the body is `???`, so calling it throws
   * `scala.NotImplementedError`.
   *
   * @param in An RDD of vectors.
   * @return The soft assignments of the vectors according to the mixture model.
   */
  def apply(in: RDD[DenseVector[Double]]): RDD[DenseVector[Double]] = ???

}


/**
 * Fit a Gaussian Mixture model to Data.
 *
 * The actual estimation is delegated to the external `ImageFeatures` library.
 *
 * @param k Number of centers to estimate.
 */
class GaussianMixtureModelEstimator(k: Int) extends Estimator[RDD[DenseVector[Double]], RDD[DenseVector[Double]]] {

  /**
   * Fit a Gaussian mixture model with `k` centers to an RDD of samples.
   *
   * The samples are collected into a local array first, so this currently only works
   * on data that fits in local memory.
   *
   * @param samples Training samples - all elements must be the same size.
   * @return A Gaussian Mixture Model.
   */
  def fit(samples: RDD[DenseVector[Double]]): GaussianMixtureModel = {
    fit(samples.collect())
  }

  /**
   * Fit a Gaussian mixture model with `k` centers to a sample array.
   *
   * @param samples Sample Array - must be non-empty and all elements must be the same size.
   * @return A Gaussian Mixture Model.
   * @throws IllegalArgumentException if `samples` is empty, if the samples differ in size,
   *                                  or if the external library returns too few values.
   */
  def fit(samples: Array[DenseVector[Double]]): GaussianMixtureModel = {
    require(samples.nonEmpty, "Cannot fit a GMM to an empty set of samples.")

    val nDim = samples.head.length
    require(samples.forall(_.length == nDim), "All GMM samples must be the same size.")

    val extLib = new ImageFeatures

    // The external library takes all samples as one flat float array, one sample after another.
    val sampleFloats = samples.map(_.toArray.map(_.toFloat))
    val res = extLib.computeGMM(k, nDim, sampleFloats.flatten)

    // Result layout as consumed below: k*nDim means, then k*nDim variances, then the weights.
    // There is one weight per center, so the weight region holds k values (not k*nDim);
    // GaussianMixtureModel itself requires weights.length == k.
    val meanSize = k * nDim
    val varSize = k * nDim
    val coefSize = k
    val expectedSize = meanSize + varSize + coefSize
    require(res.size >= expectedSize,
      s"GMM library returned ${res.size} values, expected at least $expectedSize.")

    val varStart = meanSize
    val coefStart = meanSize + varSize

    // Each array region is expected to be centroid-major.
    val means = new DenseMatrix(nDim, k, res.slice(0, meanSize).map(_.toDouble))
    val vars = new DenseMatrix(nDim, k, res.slice(varStart, varStart + varSize).map(_.toDouble))
    val coefs = new DenseVector(res.slice(coefStart, coefStart + coefSize).map(_.toDouble))

    new GaussianMixtureModel(means, vars, coefs)
  }
}

object GaussianMixtureModel {

  /**
   * Build a [[GaussianMixtureModel]] from three CSV files.
   *
   * The files are read in the order means, variances, weights. The weights matrix is
   * turned into a vector by wrapping its backing data array.
   *
   * @param meanFile Path of the CSV file holding the cluster centers.
   * @param varsFile Path of the CSV file holding the cluster variances.
   * @param weightsFile Path of the CSV file holding the cluster weights.
   * @return The model described by the three files.
   */
  def load(meanFile: String, varsFile: String, weightsFile: String): GaussianMixtureModel =
    new GaussianMixtureModel(
      MatrixUtils.loadCSVFile(meanFile),
      MatrixUtils.loadCSVFile(varsFile),
      DenseVector(MatrixUtils.loadCSVFile(weightsFile).data))
}
4 changes: 2 additions & 2 deletions src/test/scala/utils/TestUtils.scala
Expand Up @@ -22,7 +22,7 @@ object TestUtils {
* @param pathInTestResources Input path.
* @return Resource file name.
*/
def getTestResourceURI(pathInTestResources: String) = {
getClass.getClassLoader.getResource(pathInTestResources).toURI
def getTestResourceFileName(pathInTestResources: String): String = {
  // ClassLoader.getResource returns null when the resource is missing; wrap it in an
  // Option so a bad path fails with a message naming the resource, not a bare NPE.
  // Note: URL.getFile does not decode percent-escapes, so a path containing spaces
  // comes back with "%20" in it.
  Option(getClass.getClassLoader.getResource(pathInTestResources))
    .map(_.getFile)
    .getOrElse(
      throw new IllegalArgumentException(s"Test resource not found: $pathInTestResources"))
}
}
30 changes: 27 additions & 3 deletions src/test/scala/utils/external/ImageFeaturesSuite.scala
Expand Up @@ -2,6 +2,7 @@ package nodes.utils.external

import breeze.linalg._
import breeze.numerics._
import nodes.learning.GaussianMixtureModel
import org.scalatest.FunSuite
import pipelines.Logging
import utils.TestUtils._
Expand All @@ -28,14 +29,37 @@ class ImageFeaturesSuite extends FunSuite with Logging {
val result = new DenseMatrix(descriptorLength, numCols, rawDescDataShort.map(_.toDouble))

log.info(s"SIFT is ${result.toArray.sum}")
assert(Stats.aboutEq(result.toArray.sum, 8.6163289E7), "SUM of SIFTs must match the expected norm.")
assert(Stats.aboutEq(result.toArray.sum, 8.6163289E7), "SUM of SIFTs must match the expected sum.")

}

test("Load SIFT Descriptors and compute Fisher Vector Features") {

val siftDescriptor = MatrixUtils.loadCSVFile(TestUtils.getTestResourceURI("images/feats.csv").toString)
//val gmm =
val siftDescriptor = MatrixUtils.loadCSVFile(TestUtils.getTestResourceFileName("images/feats.csv").toString)

val gmmMeans = TestUtils.getTestResourceFileName("images/voc_codebook/means.csv")
val gmmVars = TestUtils.getTestResourceFileName("images/voc_codebook/variances.csv")
val gmmWeights = TestUtils.getTestResourceFileName("images/voc_codebook/priors")

val gmm = GaussianMixtureModel.load(gmmMeans, gmmVars, gmmWeights)

val nCenters = gmm.means.cols
val nDim = gmm.means.rows

val extLib = new ImageFeatures

val fisherVector = extLib.calcAndGetFVs(
gmm.means.toArray.map(_.toFloat),
nCenters,
nDim,
gmm.variances.toArray.map(_.toFloat),
gmm.weights.toArray.map(_.toFloat),
siftDescriptor.toArray.map(_.toFloat))

log.info(s"Fisher Vector is ${fisherVector.sum}")
assert(Stats.aboutEq(fisherVector.sum, 40.109097, 1e-4), "SUM of Fisher Vectors must match expected sum.")

}


}

0 comments on commit bbe6274

Please sign in to comment.