From 2a497a3e42574f5f17e4ec2133b8b7f7ccdd37bb Mon Sep 17 00:00:00 2001 From: Till Rohrmann Date: Tue, 28 Apr 2015 11:01:07 +0200 Subject: [PATCH] [FLINK-1937] [ml] Fixes sparse vector/matrix creation fromCOO with a single element --- .../org/apache/flink/ml/math/SparseMatrix.scala | 14 ++++++++++++++ .../org/apache/flink/ml/math/SparseVector.scala | 12 ++++++++++++ .../apache/flink/ml/math/SparseMatrixSuite.scala | 13 +++++++++++++ .../apache/flink/ml/math/SparseVectorSuite.scala | 10 ++++++++++ 4 files changed, 49 insertions(+) diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala index 061c464778403..fe58ddbf70a38 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala @@ -250,4 +250,18 @@ object SparseMatrix{ new SparseMatrix(numRows, numCols, prunedRowIndices, colPtrs, prunedData) } + + /** Convenience method to convert a single tuple with an integer value into a SparseMatrix. + * The problem is that when fromCOO is given a single tuple, the Scala type inference + * cannot infer that the tuple has to be of type (Int, Int, Double) because of the overloading + * with the Iterable type. 
+ * + * @param numRows Number of rows of the matrix + * @param numCols Number of columns of the matrix + * @param entry Single COO entry; its Int third component is converted to Double + * @return SparseMatrix created by delegating to fromCOO with the converted entry + */ + def fromCOO(numRows: Int, numCols: Int, entry: (Int, Int, Int)): SparseMatrix = { + fromCOO(numRows, numCols, (entry._1, entry._2, entry._3.toDouble)) + } } diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala index 6689aedba5426..6cf4c6318bd8f 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala @@ -178,4 +178,16 @@ object SparseVector { new SparseVector(size, indices, data) } + + /** Convenience method to be able to instantiate a SparseVector with a single element. The Scala + * type inference mechanism cannot infer that the second tuple value has to be of type Double + * if only a single tuple is provided. + * + * @param size Size of the vector + * @param entry Single (index, value) entry; the Int value is converted to Double + * @return SparseVector created by delegating to fromCOO with the converted entry + */ + def fromCOO(size: Int, entry: (Int, Int)): SparseVector = { + fromCOO(size, (entry._1, entry._2.toDouble)) + } } diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixSuite.scala index 132e7fe57ba35..74f2ccfff3c13 100644 --- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixSuite.scala +++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixSuite.scala @@ -24,6 +24,19 @@ class SparseMatrixSuite extends FlatSpec with Matchers { behavior of "Flink's SparseMatrix" + it should "contain a single element provided as a coordinate list (COO)" in { + val sparseMatrix = SparseMatrix.fromCOO(4, 4, (0, 0, 1)) + + sparseMatrix(0, 0) should equal(1) + + for(i <- 1 until sparseMatrix.size) { + val row = i / sparseMatrix.numCols + val col = i % sparseMatrix.numCols + + sparseMatrix(row, col) should equal(0) + } + } + it 
should "be initialized from a coordinate list representation (COO)" in { val data = List[(Int, Int, Double)]((0, 0, 0), (0, 1, 0), (3, 4, 43), (2, 1, 17), (3, 3, 88), (4 , 2, 99), (1, 4, 91), (3, 4, -1)) diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorSuite.scala index 97ef1cb458e36..cde141c0862ec 100644 --- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorSuite.scala +++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorSuite.scala @@ -24,6 +24,16 @@ class SparseVectorSuite extends FlatSpec with Matchers { behavior of "Flink's SparseVector" + it should "contain a single element provided as coordinate list (COO)" in { + val sparseVector = SparseVector.fromCOO(3, (0, 1)) + + sparseVector(0) should equal(1) + + for(index <- 1 until 3) { + sparseVector(index) should equal(0) + } + } + it should "contain the initialization data provided as coordinate list (COO)" in { val data = List[(Int, Double)]((0, 1), (2, 0), (4, 42), (0, 3)) val size = 5