From e9ca16ec943b9553056482d0c085eacb6046821e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 9 Jan 2015 10:27:33 -0800 Subject: [PATCH] [SPARK-5145][Mllib] Add BLAS.dsyr and use it in GaussianMixtureEM This pr uses BLAS.dsyr to replace few implementations in GaussianMixtureEM. Author: Liang-Chi Hsieh Closes #3949 from viirya/blas_dsyr and squashes the following commits: 4e4d6cf [Liang-Chi Hsieh] Add unit test. Rename function name, modify doc and style. 3f57fd2 [Liang-Chi Hsieh] Add BLAS.dsyr and use it in GaussianMixtureEM. --- .../mllib/clustering/GaussianMixtureEM.scala | 10 +++-- .../org/apache/spark/mllib/linalg/BLAS.scala | 26 ++++++++++++ .../apache/spark/mllib/linalg/BLASSuite.scala | 41 +++++++++++++++++++ 3 files changed, 73 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala index bdf984aee4dae..3a6c0e681e3fa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.IndexedSeq import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix, diag, Transpose} import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors} +import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors, DenseVector, DenseMatrix, BLAS} import org.apache.spark.mllib.stat.impl.MultivariateGaussian import org.apache.spark.mllib.util.MLUtils @@ -151,9 +151,10 @@ class GaussianMixtureEM private ( var i = 0 while (i < k) { val mu = sums.means(i) / sums.weights(i) - val sigma = sums.sigmas(i) / sums.weights(i) - mu * new Transpose(mu) // TODO: Use BLAS.dsyr + BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu).asInstanceOf[DenseVector], + Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix]) weights(i) = sums.weights(i) / sumWeights - gaussians(i) = new MultivariateGaussian(mu, sigma) + gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i)) i = i + 1 } @@ -211,7 +212,8 @@ private object ExpectationSum { p(i) /= pSum sums.weights(i) += p(i) sums.means(i) += x * p(i) - sums.sigmas(i) += xxt * p(i) // TODO: use BLAS.dsyr + BLAS.syr(p(i), Vectors.fromBreeze(x).asInstanceOf[DenseVector], + Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix]) i = i + 1 } sums diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 9fed513becddc..3414daccd7ca4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -228,6 +228,32 @@ private[spark] object BLAS extends Serializable with Logging { } _nativeBLAS } + + /** + * A := alpha * x * x^T^ + A + * @param alpha a real scalar that will be multiplied to x * x^T^. + * @param x the vector x that contains the n elements. + * @param A the symmetric matrix A. Size of n x n. + */ + def syr(alpha: Double, x: DenseVector, A: DenseMatrix) { + val mA = A.numRows + val nA = A.numCols + require(mA == nA, s"A is not a symmetric matrix. A: $mA x $nA") + require(mA == x.size, s"The size of x doesn't match the rank of A. A: $mA x $nA, x: ${x.size}") + + nativeBLAS.dsyr("U", x.size, alpha, x.values, 1, A.values, nA) + + // Fill lower triangular part of A + var i = 0 + while (i < mA) { + var j = i + 1 + while (j < nA) { + A(j, i) = A(i, j) + j += 1 + } + i += 1 + } + } /** * C := alpha * A * B + beta * C diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 5d70c914f14b0..771878e925ea7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -127,6 +127,47 @@ class BLASSuite extends FunSuite { } } + test("syr") { + val dA = new DenseMatrix(4, 4, + Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8)) + val x = new DenseVector(Array(0.0, 2.7, 3.5, 2.1)) + val alpha = 0.15 + + val expected = new DenseMatrix(4, 4, + Array(0.0, 1.2, 2.2, 3.1, 1.2, 4.2935, 6.7175, 5.4505, 2.2, 6.7175, 3.6375, 4.1025, 3.1, + 5.4505, 4.1025, 1.4615)) + + syr(alpha, x, dA) + + assert(dA ~== expected absTol 1e-15) + + val dB = + new DenseMatrix(3, 4, Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0)) + + withClue("Matrix A must be a symmetric Matrix") { + intercept[Exception] { + syr(alpha, x, dB) + } + } + + val dC = + new DenseMatrix(3, 3, Array(0.0, 1.2, 2.2, 1.2, 3.2, 5.3, 2.2, 5.3, 1.8)) + + withClue("Size of vector must match the rank of matrix") { + intercept[Exception] { + syr(alpha, x, dC) + } + } + + val y = new DenseVector(Array(0.0, 2.7, 3.5, 2.1, 1.5)) + + withClue("Size of vector must match the rank of matrix") { + intercept[Exception] { + syr(alpha, y, dA) + } + } + } + test("gemm") { val dA =