Skip to content

Commit

Permalink
hide transposeMultiply; add rng to rand and randn; add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Nov 26, 2014
1 parent bf1a6aa commit 6bfd8a4
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 11 deletions.
20 changes: 9 additions & 11 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@

package org.apache.spark.mllib.linalg

import java.util.Arrays
import java.util.{Random, Arrays}

import breeze.linalg.{Matrix => BM, DenseMatrix => BDM, CSCMatrix => BSM}

import org.apache.spark.util.random.XORShiftRandom

/**
* Trait for a local matrix.
*/
Expand Down Expand Up @@ -67,14 +65,14 @@ sealed trait Matrix extends Serializable {
}

/** Convenience method for `Matrix`^T^-`DenseMatrix` multiplication. */
def transposeMultiply(y: DenseMatrix): DenseMatrix = {
private[mllib] def transposeMultiply(y: DenseMatrix): DenseMatrix = {
  // Zero-initialised result buffer of size numCols x y.numCols; gemm fills it in place.
  val product: DenseMatrix = Matrices.zeros(numCols, y.numCols).asInstanceOf[DenseMatrix]
  // NOTE(review): the leading `true, false` flags presumably request this^T and a
  // non-transposed y (standard GEMM convention) -- confirm against BLAS.gemm.
  BLAS.gemm(true, false, 1.0, this, y, 0.0, product)
  product
}

/** Convenience method for `Matrix`^T^-`DenseVector` multiplication. */
def transposeMultiply(y: DenseVector): DenseVector = {
private[mllib] def transposeMultiply(y: DenseVector): DenseVector = {
val output = new DenseVector(new Array[Double](numCols))
BLAS.gemv(true, 1.0, this, y, 0.0, output)
output
Expand Down Expand Up @@ -291,22 +289,22 @@ object Matrices {
* Generate a `DenseMatrix` consisting of i.i.d. uniform random numbers.
* @param numRows number of rows of the matrix
* @param numCols number of columns of the matrix
* @param rng a random number generator
* @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1)
*/
def rand(numRows: Int, numCols: Int): Matrix = {
val rand = new XORShiftRandom
new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rand.nextDouble()))
def rand(numRows: Int, numCols: Int, rng: Random): Matrix = {
  // One draw from the caller-supplied generator per matrix entry.
  val entries = Array.fill(numRows * numCols)(rng.nextDouble())
  new DenseMatrix(numRows, numCols, entries)
}

/**
* Generate a `DenseMatrix` consisting of i.i.d. Gaussian random numbers.
* @param numRows number of rows of the matrix
* @param numCols number of columns of the matrix
* @param rng a random number generator
* @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1)
*/
def randn(numRows: Int, numCols: Int): Matrix = {
val rand = new XORShiftRandom
new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rand.nextGaussian()))
def randn(numRows: Int, numCols: Int, rng: Random): Matrix = {
  // One Gaussian draw from the caller-supplied generator per matrix entry.
  val entries = Array.fill(numRows * numCols)(rng.nextGaussian())
  new DenseMatrix(numRows, numCols, entries)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@

package org.apache.spark.mllib.linalg

import java.util.Random

import org.mockito.Mockito.when
import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar._

class MatricesSuite extends FunSuite {
test("dense matrix construction") {
Expand Down Expand Up @@ -112,4 +116,50 @@ class MatricesSuite extends FunSuite {
assert(sparseMat(0, 1) === 10.0)
assert(sparseMat.values(2) === 10.0)
}

test("zeros") {
  // Cast to DenseMatrix so the backing `values` array can be inspected.
  val zeroMat = Matrices.zeros(2, 3).asInstanceOf[DenseMatrix]
  assert(zeroMat.numRows === 2)
  assert(zeroMat.numCols === 3)
  // No entry may differ from 0.0.
  assert(!zeroMat.values.exists(_ != 0.0))
}

test("ones") {
  // Cast to DenseMatrix so the backing `values` array can be inspected.
  val onesMat = Matrices.ones(2, 3).asInstanceOf[DenseMatrix]
  assert(onesMat.numRows === 2)
  assert(onesMat.numCols === 3)
  // No entry may differ from 1.0.
  assert(!onesMat.values.exists(_ != 1.0))
}

test("eye") {
  // Cast to DenseMatrix so the backing `values` array can be inspected.
  val mat = Matrices.eye(2).asInstanceOf[DenseMatrix]
  // Check both dimensions: the original asserted numCols twice and never
  // verified numRows.
  assert(mat.numRows === 2)
  assert(mat.numCols === 2)
  // Expected entries, listed in the flat order of the `values` array.
  assert(mat.values.toSeq === Seq(1.0, 0.0, 0.0, 1.0))
}

test("rand") {
  // A mocked generator makes the "random" entries deterministic.
  val mockRng = mock[Random]
  when(mockRng.nextDouble()).thenReturn(1.0, 2.0, 3.0, 4.0)
  val expected = Seq(1.0, 2.0, 3.0, 4.0)
  val randMat = Matrices.rand(2, 2, mockRng).asInstanceOf[DenseMatrix]
  assert(randMat.numRows === 2)
  assert(randMat.numCols === 2)
  // Entries must appear in the order the generator produced them.
  assert(randMat.values.toSeq === expected)
}

test("randn") {
  // A mocked generator makes the "random" entries deterministic.
  val mockRng = mock[Random]
  when(mockRng.nextGaussian()).thenReturn(1.0, 2.0, 3.0, 4.0)
  val expected = Seq(1.0, 2.0, 3.0, 4.0)
  val randnMat = Matrices.randn(2, 2, mockRng).asInstanceOf[DenseMatrix]
  assert(randnMat.numRows === 2)
  assert(randnMat.numCols === 2)
  // Entries must appear in the order the generator produced them.
  assert(randnMat.values.toSeq === expected)
}

test("diag") {
  val diagonal = Vectors.dense(1.0, 2.0)
  // Cast to DenseMatrix so the backing `values` array can be inspected.
  val diagMat = Matrices.diag(diagonal).asInstanceOf[DenseMatrix]
  assert(diagMat.numRows === 2)
  assert(diagMat.numCols === 2)
  // Diagonal entries 1.0 and 2.0, zeros elsewhere, in `values` order.
  val expected = Seq(1.0, 0.0, 0.0, 2.0)
  assert(diagMat.values.toSeq === expected)
}
}

0 comments on commit 6bfd8a4

Please sign in to comment.