Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
120 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
package nodes.learning | ||
|
||
import breeze.linalg._ | ||
import nodes.utils.external.ImageFeatures | ||
import org.apache.spark.rdd.RDD | ||
import pipelines._ | ||
import utils.MatrixUtils | ||
|
||
|
||
/**
 * A mixture of Gaussians with diagonal covariance, usually produced by a clustering/fitting step.
 *
 * The model is stored column-wise: column `i` of `means` and `variances` belongs to
 * mixture component `i`, and `weights(i)` is that component's weight.
 *
 * @param means Cluster centers, one column per component.
 * @param variances Diagonal cluster variances; must have the same shape as `means`.
 * @param weights Cluster weights; must contain exactly one entry per component.
 */
class GaussianMixtureModel(
    val means: DenseMatrix[Double],
    val variances: DenseMatrix[Double],
    val weights: DenseVector[Double])
  extends Transformer[DenseVector[Double], DenseVector[Double]] with Logging {

  /** Number of mixture components (the column count of `means`). */
  val k: Int = means.cols

  /** Dimensionality of each component (the row count of `means`). */
  val dim: Int = means.rows

  // Fail fast at construction time when the three parameter blocks disagree in shape.
  require(
    variances.rows == dim && variances.cols == k,
    "GMM means and variances must be the same size.")
  require(weights.length == k, "Every GMM center must have a weight.")

  /**
   * Not implemented yet: calling this throws `NotImplementedError` (it is defined as `???`).
   * The intended behavior is to return the soft assignment of each input vector to every cluster.
   *
   * @param in The vectors to assign.
   * @return The soft assignments of the vectors according to the mixture model.
   */
  def apply(in: RDD[DenseVector[Double]]): RDD[DenseVector[Double]] = ???

}
|
||
|
||
/**
 * Fit a Gaussian Mixture model to Data.
 *
 * @param k Number of centers to estimate.
 */
class GaussianMixtureModelEstimator(k: Int) extends Estimator[RDD[DenseVector[Double]], RDD[DenseVector[Double]]] {

  /**
   * Currently this model works on items that fit in local memory: the whole RDD is
   * collected to the driver and handed to the array-based `fit`.
   *
   * @param samples Training vectors - all elements must be the same size.
   * @return A PipelineNode (Transformer) which can be called on new data.
   */
  def fit(samples: RDD[DenseVector[Double]]): GaussianMixtureModel = {
    fit(samples.collect())
  }

  /**
   * Fit a Gaussian mixture model with `k` centers to a sample array.
   *
   * @param samples Sample Array - must be non-empty and all elements must be the same size.
   * @return A Gaussian Mixture Model.
   * @throws IllegalArgumentException if `samples` is empty, the samples differ in length,
   *                                  or the external library returns fewer values than expected.
   */
  def fit(samples: Array[DenseVector[Double]]): GaussianMixtureModel = {
    require(samples.nonEmpty, "Cannot fit a GMM to an empty sample array.")
    val nDim = samples.head.length
    // The samples are flattened into one buffer below; a ragged input would silently
    // shift every following sample, so reject it up front.
    require(samples.forall(_.length == nDim), "All GMM samples must have the same dimension.")

    val extLib = new ImageFeatures

    // Flatten the samples into a single sample-major float buffer for the external library.
    val sampleFloats = samples.flatMap(_.toArray.map(_.toFloat))
    val res = extLib.computeGMM(k, nDim, sampleFloats)

    val meanSize = k * nDim
    val varSize = k * nDim
    // One weight per center. This must be k (not k * nDim), because GaussianMixtureModel
    // requires weights.length == k.
    val coefSize = k

    val expectedSize = meanSize + varSize + coefSize
    require(
      res.length >= expectedSize,
      s"GMM result has ${res.length} values but at least $expectedSize were expected.")

    // Each array region is expected to be centroid-major, which lines up with the
    // column-major layout of DenseMatrix: one column per center.
    val means = new DenseMatrix(nDim, k, res.slice(0, meanSize).map(_.toDouble))
    val vars = new DenseMatrix(nDim, k, res.slice(meanSize, meanSize + varSize).map(_.toDouble))
    val coefs = new DenseVector(res.slice(meanSize + varSize, expectedSize).map(_.toDouble))

    new GaussianMixtureModel(means, vars, coefs)
  }
}
|
||
object GaussianMixtureModel {

  /**
   * Build a [[GaussianMixtureModel]] from three CSV files.
   *
   * The means and variances matrices are used exactly as loaded. The weights matrix is
   * converted to a vector by wrapping its backing `data` array.
   *
   * @param meanFile Path of the CSV file holding the cluster means.
   * @param varsFile Path of the CSV file holding the cluster variances.
   * @param weightsFile Path of the CSV file holding the cluster weights.
   * @return The model assembled from the three files.
   */
  def load(meanFile: String, varsFile: String, weightsFile: String): GaussianMixtureModel = {
    // Files are read in the order means, variances, weights.
    val centers = MatrixUtils.loadCSVFile(meanFile)
    val spreads = MatrixUtils.loadCSVFile(varsFile)
    val mixing = MatrixUtils.loadCSVFile(weightsFile)

    new GaussianMixtureModel(centers, spreads, DenseVector(mixing.data))
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters