Skip to content

Commit

Permalink
[SPARK-5145][Mllib] Add BLAS.dsyr and use it in GaussianMixtureEM
Browse files Browse the repository at this point in the history
This pr uses BLAS.dsyr to replace few implementations in GaussianMixtureEM.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #3949 from viirya/blas_dsyr and squashes the following commits:

4e4d6cf [Liang-Chi Hsieh] Add unit test. Rename function name, modify doc and style.
3f57fd2 [Liang-Chi Hsieh] Add BLAS.dsyr and use it in GaussianMixtureEM.
  • Loading branch information
viirya authored and mengxr committed Jan 9, 2015
1 parent b6aa557 commit e9ca16e
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import scala.collection.mutable.IndexedSeq

import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix, diag, Transpose}
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors}
import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors, DenseVector, DenseMatrix, BLAS}
import org.apache.spark.mllib.stat.impl.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils

Expand Down Expand Up @@ -151,9 +151,10 @@ class GaussianMixtureEM private (
var i = 0
while (i < k) {
val mu = sums.means(i) / sums.weights(i)
val sigma = sums.sigmas(i) / sums.weights(i) - mu * new Transpose(mu) // TODO: Use BLAS.dsyr
BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu).asInstanceOf[DenseVector],
Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
weights(i) = sums.weights(i) / sumWeights
gaussians(i) = new MultivariateGaussian(mu, sigma)
gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i))
i = i + 1
}

Expand Down Expand Up @@ -211,7 +212,8 @@ private object ExpectationSum {
p(i) /= pSum
sums.weights(i) += p(i)
sums.means(i) += x * p(i)
sums.sigmas(i) += xxt * p(i) // TODO: use BLAS.dsyr
BLAS.syr(p(i), Vectors.fromBreeze(x).asInstanceOf[DenseVector],
Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
i = i + 1
}
sums
Expand Down
26 changes: 26 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,32 @@ private[spark] object BLAS extends Serializable with Logging {
}
_nativeBLAS
}

/**
* A := alpha * x * x^T^ + A
* @param alpha a real scalar that will be multiplied to x * x^T^.
* @param x the vector x that contains the n elements.
* @param A the symmetric matrix A. Size of n x n.
*/
def syr(alpha: Double, x: DenseVector, A: DenseMatrix) {
val mA = A.numRows
val nA = A.numCols
require(mA == nA, s"A is not a symmetric matrix. A: $mA x $nA")
require(mA == x.size, s"The size of x doesn't match the rank of A. A: $mA x $nA, x: ${x.size}")

nativeBLAS.dsyr("U", x.size, alpha, x.values, 1, A.values, nA)

// Fill lower triangular part of A
var i = 0
while (i < mA) {
var j = i + 1
while (j < nA) {
A(j, i) = A(i, j)
j += 1
}
i += 1
}
}

/**
* C := alpha * A * B + beta * C
Expand Down
41 changes: 41 additions & 0 deletions mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,47 @@ class BLASSuite extends FunSuite {
}
}

test("syr") {
val dA = new DenseMatrix(4, 4,
Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8))
val x = new DenseVector(Array(0.0, 2.7, 3.5, 2.1))
val alpha = 0.15

val expected = new DenseMatrix(4, 4,
Array(0.0, 1.2, 2.2, 3.1, 1.2, 4.2935, 6.7175, 5.4505, 2.2, 6.7175, 3.6375, 4.1025, 3.1,
5.4505, 4.1025, 1.4615))

syr(alpha, x, dA)

assert(dA ~== expected absTol 1e-15)

val dB =
new DenseMatrix(3, 4, Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0))

withClue("Matrix A must be a symmetric Matrix") {
intercept[Exception] {
syr(alpha, x, dB)
}
}

val dC =
new DenseMatrix(3, 3, Array(0.0, 1.2, 2.2, 1.2, 3.2, 5.3, 2.2, 5.3, 1.8))

withClue("Size of vector must match the rank of matrix") {
intercept[Exception] {
syr(alpha, x, dC)
}
}

val y = new DenseVector(Array(0.0, 2.7, 3.5, 2.1, 1.5))

withClue("Size of vector must match the rank of matrix") {
intercept[Exception] {
syr(alpha, y, dA)
}
}
}

test("gemm") {

val dA =
Expand Down

0 comments on commit e9ca16e

Please sign in to comment.