From 3f57fd2d2292be74e660dcb781c5c3a1d9a60ea2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 8 Jan 2015 18:32:07 +0800 Subject: [PATCH] Add BLAS.dsyr and use it in GaussianMixtureEM. --- .../mllib/clustering/GaussianMixtureEM.scala | 10 ++++---- .../org/apache/spark/mllib/linalg/BLAS.scala | 23 +++++++++++++++++++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala index bdf984aee4dae..daeae9321fb5d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.IndexedSeq import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix, diag, Transpose} import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors} +import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors, DenseVector, DenseMatrix, BLAS} import org.apache.spark.mllib.stat.impl.MultivariateGaussian import org.apache.spark.mllib.util.MLUtils @@ -151,9 +151,10 @@ class GaussianMixtureEM private ( var i = 0 while (i < k) { val mu = sums.means(i) / sums.weights(i) - val sigma = sums.sigmas(i) / sums.weights(i) - mu * new Transpose(mu) // TODO: Use BLAS.dsyr + BLAS.dsyr(-sums.weights(i), Vectors.fromBreeze(mu).asInstanceOf[DenseVector], + Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix]) weights(i) = sums.weights(i) / sumWeights - gaussians(i) = new MultivariateGaussian(mu, sigma) + gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i)) i = i + 1 } @@ -211,7 +212,8 @@ private object ExpectationSum { p(i) /= pSum sums.weights(i) += p(i) sums.means(i) += x * p(i) - sums.sigmas(i) += xxt * p(i) // TODO: use BLAS.dsyr + BLAS.dsyr(p(i), Vectors.fromBreeze(x).asInstanceOf[DenseVector], + Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix]) i = i + 1 } sums diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 9fed513becddc..9f621d5ec2e88 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -228,6 +228,29 @@ private[spark] object BLAS extends Serializable with Logging { } _nativeBLAS } + + /** + * A := alpha * x * x^T + A + */ + def dsyr(alpha: Double, x: DenseVector, A: DenseMatrix) { + val mA = A.numRows + val nA = A.numCols + require(mA == nA, s"A is not a symmetric matrix. A: $mA x $nA") + require(mA == x.size, s"The size of x doesn't match the rank of A. A: $mA x $nA, x: ${x.size}") + + nativeBLAS.dsyr("U", x.size, alpha, x.values, 1, A.values, nA) + + // Fill lower triangular part of A + var i = 0 + while (i < mA) { + var j = i + 1 + while(j < nA) { + A(j, i) = A(i, j) + j += 1 + } + i += 1 + } + } /** * C := alpha * A * B + beta * C