From 7917bc1e1b40e617a0d4c70fac6134a6ce76e510 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Sat, 28 Mar 2015 00:42:35 +0530 Subject: [PATCH 1/3] [SPARK-6576] DenseMatrix in PySpark should support indexing --- python/pyspark/mllib/linalg.py | 11 +++++++++++ python/pyspark/mllib/tests.py | 8 ++++++++ 2 files changed, 19 insertions(+) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index f5aad28afda0f..de44826cedbab 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -634,6 +634,17 @@ def toArray(self): """ return self.values.reshape((self.numRows, self.numCols), order='F') + def __getitem__(self, indices): + i = indices[0] + j = indices[1] + if i > self.numRows: + raise ValueError("Row index %d should not be greater than numRows %d." % + (i, self.numRows)) + if j > self.numCols: + raise ValueError("Column index %d should not be greater than numCols %d." % + (i, self.numCols)) + return self.values[i * self.numCols + j] + def __eq__(self, other): return (isinstance(other, DenseMatrix) and self.numRows == other.numRows and diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 3bb0f0ca68128..65b763ebc4689 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -135,6 +135,14 @@ def test_sparse_vector_indexing(self): for ind in [4, -5, 7.8]: self.assertRaises(ValueError, sv.__getitem__, ind) + def test_matrix_indexing(self): + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) + numcols = mat.numCols + for ind, val in enumerate(mat.values): + rowind = ind // numcols + colind = ind - rowind * numcols + self.assertEquals(mat[rowind, colind], val) + class ListTests(PySparkTestCase): From a0620254debb8ab0c305a7c9962c1e6030c66289 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Wed, 1 Apr 2015 11:47:30 +0530 Subject: [PATCH 2/3] Matrices are stored in column order --- python/pyspark/mllib/linalg.py | 19 ++++++++++--------- python/pyspark/mllib/tests.py | 9 ++++----- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index de44826cedbab..dfac543c6f77a 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -635,15 +635,16 @@ def toArray(self): return self.values.reshape((self.numRows, self.numCols), order='F') def __getitem__(self, indices): - i = indices[0] - j = indices[1] - if i > self.numRows: - raise ValueError("Row index %d should not be greater than numRows %d." % - (i, self.numRows)) - if j > self.numCols: - raise ValueError("Column index %d should not be greater than numCols %d." % - (i, self.numCols)) - return self.values[i * self.numCols + j] + i, j = indices + if i > self.numRows or i < 0: + raise ValueError( + "Row index %d should not be greater than numRows %d " + "or negative." % (i, self.numRows)) + if j > self.numCols or j < 0: + raise ValueError( + "Column index %d should not be greater than numCols %d " + "or negative." % (j, self.numCols)) + return self.values[i + j * self.numRows] def __eq__(self, other): return (isinstance(other, DenseMatrix) and diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 65b763ebc4689..893fc6f491ab3 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -137,11 +137,10 @@ def test_sparse_vector_indexing(self): def test_matrix_indexing(self): mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) - numcols = mat.numCols - for ind, val in enumerate(mat.values): - rowind = ind // numcols - colind = ind - rowind * numcols - self.assertEquals(mat[rowind, colind], val) + expected = [[0, 6], [1, 8], [4, 10]] + for i in range(3): + for j in range(2): + self.assertEquals(mat[i, j], expected[i][j]) class ListTests(PySparkTestCase): From a7350788e3ae943ed898cb36ed500880459b7596 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Thu, 2 Apr 2015 02:26:46 +0530 Subject: [PATCH 3/3] Change bounds --- python/pyspark/mllib/linalg.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index dfac543c6f77a..39b77a7e575fd 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -636,14 +636,12 @@ def toArray(self): def __getitem__(self, indices): i, j = indices - if i > self.numRows or i < 0: - raise ValueError( - "Row index %d should not be greater than numRows %d " - "or negative." % (i, self.numRows)) - if j > self.numCols or j < 0: - raise ValueError( - "Column index %d should not be greater than numCols %d " - "or negative." % (j, self.numCols)) + if i < 0 or i >= self.numRows: + raise ValueError("Row index %d is out of range [0, %d)" + % (i, self.numRows)) + if j >= self.numCols or j < 0: + raise ValueError("Column index %d is out of range [0, %d)" + % (j, self.numCols)) return self.values[i + j * self.numRows] def __eq__(self, other):