From e6b4cc5d977163a354e1906a3c7269af263c05d2 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 17 Mar 2015 21:57:03 +0530 Subject: [PATCH 1/2] [SPARK-6364] Implement equals and hashcode for Matrix --- .../apache/spark/mllib/linalg/Matrices.scala | 54 +++++++++++++++++++ .../spark/mllib/linalg/MatricesSuite.scala | 38 +++++++++++++ 2 files changed, 92 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 85e63b1382b5e..2aee7e926503a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -114,6 +114,50 @@ sealed trait Matrix extends Serializable { * corresponding value in the matrix with type `Double`. */ private[spark] def foreachActive(f: (Int, Int, Double) => Unit) + + override def hashCode(): Int = { + var result: Int = numRows * numCols + 31 + this.foreachActive { case (rowInd, colInd, value) => + // ignore explict 0 for comparison between sparse and dense + if (value != 0) { + result = 31 * result + rowInd + (numRows * colInd) + // refer to {@link java.util.Arrays.equals} for hash algorithm + val bits = java.lang.Double.doubleToLongBits(value) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + } + result + } + + override def equals(other: Any): Boolean = { + other match { + case mat: Matrix => + if (mat.numRows != this.numRows || mat.numCols != this.numCols) return false + (this, mat) match { + case (dm1: DenseMatrix, dm2: DenseMatrix) => + Arrays.equals(dm1.toArray, dm2.toArray) + case (sm1: SparseMatrix, sm2: SparseMatrix) => + // For the case in which one matrix is CSC and the other is CSR + // the values, colPtrs and rowIndices need not be the same. + // When both matrices are of the same type, it is sufficient to check that + // the values, colPtrs and rowIndices are the same. + if (sm1.isTransposed != sm2.isTransposed) { + if (sm1.values.length != sm2.values.length) return false + sm1.foreachActive { + case (i, j, value) => if (value != sm2(i, j)) return false + } + } else { + if (sm1.values != sm2.values) return false + if (sm1.colPtrs != sm2.colPtrs) return false + if (sm1.rowIndices != sm2.rowIndices) return false + } + true + case (dm1: DenseMatrix, sm1: SparseMatrix) => Matrices.equals(dm1, sm1) + case (sm1: SparseMatrix, dm1: DenseMatrix) => Matrices.equals(dm1, sm1) + } + case _ => false + } + } } @DeveloperApi @@ -814,6 +858,16 @@ object Matrices { } } + /** + * Check equality between sparse/dense matrices + */ + private[mllib] def equals(denseMat: DenseMatrix, sparseMat: SparseMatrix): Boolean = { + sparseMat.foreachActive { (row, col, value) => + if (value != denseMat(row, col)) return false + } + return true + } + /** * Generate a `Matrix` consisting of zeros. * @param numRows number of rows of the matrix diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 8dbb70f5d1c4c..e80b8cafd5305 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -454,5 +454,43 @@ class MatricesSuite extends SparkFunSuite { lines = mat.toString(5, 100).lines.toArray assert(lines.size == 5 && lines.forall(_.size <= 100)) + + test("equals") { + // A == (A.T).T + val dm = new DenseMatrix(2, 2, Array(0.0, 1.4, 0.3, 3.5)) + val dm2 = new DenseMatrix(2, 3, Array(0.0, 1.4, 2.9, 0.0, 0.0, 0.0)) + val dmt = new DenseMatrix(2, 2, Array(0.0, 0.3, 1.4, 3.5)).transpose + val dmt2 = new DenseMatrix(3, 2, Array(0.0, 2.9, 0.0, 1.4, 0.0, 0.0)).transpose + assert(dm == dmt) + assert(dm2 == dmt2) + assert(dm != dmt2) + assert(dm2 != dmt) + + // Check that dense matrix == corresponding sparse matrix. + val sm = new SparseMatrix(2, 2, Array(0, 1, 3), Array(1, 0, 1), Array(1.4, 0.3, 3.5)) + val sm2 = new SparseMatrix(2, 3, Array(0, 1, 2, 2), Array(1, 0), Array(1.4, 2.9)) + val sm3 = new SparseMatrix(2, 2, Array(0, 1, 3), Array(1, 0, 1), Array(1.4, 0.9, 3.5)) + val sm4 = new SparseMatrix(2, 3, Array(0, 1, 2, 2), Array(1, 0), Array(1.4, 2.2)) + assert(dm == sm) + assert(dm2 == sm2) + assert(dm != sm3) + assert(dm2 != sm4) + + // Check that dense matrix == corresponding CSC matrix. + val csr = new SparseMatrix(2, 2, Array(0, 1, 3), Array(1, 0, 1), Array(0.3, 1.4, 3.5)) + val csr2 = new SparseMatrix(3, 2, Array(0, 1, 2), Array(1, 0), Array(2.9, 1.4)) + assert(dm == csr.transpose) + assert(dm2 == csr2.transpose) + assert(dm != csr2.transpose) + assert(dm2 != csr.transpose) + + // Check equality between csr and csc matrices + assert(sm == csr.transpose) + assert(sm2 == csr2.transpose) + assert(sm != csr2.transpose) + assert(sm2 != csr.transpose) + + Seq(dm, dm2, sm, sm2).foreach{mat => assert(mat == mat.asInstanceOf[Matrix])} + } } From 7f46d17eb7511c0f3395fa8f99e7713b7ad5f7b3 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Fri, 19 Jun 2015 23:30:59 +0530 Subject: [PATCH 2/2] improved hashing --- .../main/scala/org/apache/spark/mllib/linalg/Matrices.scala | 6 ++++-- .../scala/org/apache/spark/mllib/linalg/MatricesSuite.scala | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 2aee7e926503a..6d8db7e656561 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -116,11 +116,13 @@ sealed trait Matrix extends Serializable { private[spark] def foreachActive(f: (Int, Int, Double) => Unit) override def hashCode(): Int = { - var result: Int = numRows * numCols + 31 + var result: Int = 31 + numRows + result = 31 * result + numCols this.foreachActive { case (rowInd, colInd, value) => // ignore explict 0 for comparison between sparse and dense if (value != 0) { - result = 31 * result + rowInd + (numRows * colInd) + result = 31 * result + rowInd + result = 31 * result + colInd // refer to {@link java.util.Arrays.equals} for hash algorithm val bits = java.lang.Double.doubleToLongBits(value) result = 31 * result + (bits ^ (bits >>> 32)).toInt diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index e80b8cafd5305..b92f3f2962054 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -454,6 +454,7 @@ class MatricesSuite extends SparkFunSuite { lines = mat.toString(5, 100).lines.toArray assert(lines.size == 5 && lines.forall(_.size <= 100)) + } test("equals") { // A == (A.T).T