diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index f976d2f97b043..6237b64c8f984 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -985,8 +985,10 @@ private[spark] object SerDe extends Serializable { val m: DenseMatrix = obj.asInstanceOf[DenseMatrix] val bytes = new Array[Byte](8 * m.values.size) val order = ByteOrder.nativeOrder() + val isTransposed = if (m.isTransposed) 1 else 0 ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().put(m.values) + out.write(Opcodes.MARK) out.write(Opcodes.BININT) out.write(PickleUtils.integer_to_bytes(m.numRows)) out.write(Opcodes.BININT) @@ -994,19 +996,22 @@ private[spark] object SerDe extends Serializable { out.write(Opcodes.BINSTRING) out.write(PickleUtils.integer_to_bytes(bytes.length)) out.write(bytes) - out.write(Opcodes.TUPLE3) + out.write(Opcodes.BININT) + out.write(PickleUtils.integer_to_bytes(isTransposed)) + out.write(Opcodes.TUPLE) } def construct(args: Array[Object]): Object = { - if (args.length != 3) { - throw new PickleException("should be 3") + if (args.length != 4) { + throw new PickleException("should be 4") } val bytes = getBytes(args(2)) val n = bytes.length / 8 val values = new Array[Double](n) val order = ByteOrder.nativeOrder() ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().get(values) - new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values) + val isTransposed = args(3).asInstanceOf[Int] == 1 + new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values, isTransposed) } } diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 38b3aa3ad460e..c3729de752895 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -638,9 +638,10 @@ class Matrix(object): Represents a local matrix. """ - def __init__(self, numRows, numCols): + def __init__(self, numRows, numCols, isTransposed=False): self.numRows = numRows self.numCols = numCols + self.isTransposed = isTransposed def toArray(self): """ @@ -662,14 +663,16 @@ class DenseMatrix(Matrix): """ Column-major dense matrix. """ - def __init__(self, numRows, numCols, values): - Matrix.__init__(self, numRows, numCols) + def __init__(self, numRows, numCols, values, isTransposed=False): + Matrix.__init__(self, numRows, numCols, isTransposed) values = self._convert_to_array(values, np.float64) assert len(values) == numRows * numCols self.values = values def __reduce__(self): - return DenseMatrix, (self.numRows, self.numCols, self.values.tostring()) + return DenseMatrix, ( + self.numRows, self.numCols, self.values.tostring(), + int(self.isTransposed)) def toArray(self): """ @@ -680,15 +683,23 @@ def toArray(self): array([[ 0., 2.], [ 1., 3.]]) """ - return self.values.reshape((self.numRows, self.numCols), order='F') + if self.isTransposed: + return np.asfortranarray( + self.values.reshape((self.numRows, self.numCols))) + else: + return self.values.reshape((self.numRows, self.numCols), order='F') def toSparse(self): """Convert to SparseMatrix""" - indices = np.nonzero(self.values)[0] + if self.isTransposed: + values = np.ravel(self.toArray(), order='F') + else: + values = self.values + indices = np.nonzero(values)[0] colCounts = np.bincount(indices // self.numRows) colPtrs = np.cumsum(np.hstack( (0, colCounts, np.zeros(self.numCols - colCounts.size)))) - values = self.values[indices] + values = values[indices] rowIndices = indices % self.numRows return SparseMatrix(self.numRows, self.numCols, colPtrs, rowIndices, values) @@ -701,21 +712,28 @@ def __getitem__(self, indices): if j >= self.numCols or j < 0: raise ValueError("Column index %d is out of range [0, %d)" % (j, self.numCols)) - return self.values[i + j * self.numRows] + + if self.isTransposed: + return self.values[i * self.numCols + j] + else: + return self.values[i + j * self.numRows] def __eq__(self, other): - return (isinstance(other, DenseMatrix) and - self.numRows == other.numRows and - self.numCols == other.numCols and - all(self.values == other.values)) + if (not isinstance(other, DenseMatrix) or + self.numRows != other.numRows or + self.numCols != other.numCols): + return False + + self_values = np.ravel(self.toArray(), order='F') + other_values = np.ravel(other.toArray(), order='F') + return all(self_values == other_values) class SparseMatrix(Matrix): """Sparse Matrix stored in CSC format.""" def __init__(self, numRows, numCols, colPtrs, rowIndices, values, isTransposed=False): - Matrix.__init__(self, numRows, numCols) - self.isTransposed = isTransposed + Matrix.__init__(self, numRows, numCols, isTransposed) self.colPtrs = self._convert_to_array(colPtrs, np.int32) self.rowIndices = self._convert_to_array(rowIndices, np.int32) self.values = self._convert_to_array(values, np.float64) @@ -777,8 +795,7 @@ def toArray(self): return A def toDense(self): - densevals = np.reshape( - self.toArray(), (self.numRows * self.numCols), order='F') + densevals = np.ravel(self.toArray(), order='F') return DenseMatrix(self.numRows, self.numCols, densevals) # TODO: More efficient implementation: diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index c6ed5acd1770e..3671ce8ccc475 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -193,6 +193,22 @@ def test_sparse_matrix(self): self.assertEquals(expected[i][j], sm1t[i, j]) self.assertTrue(array_equal(sm1t.toArray(), expected)) + def test_dense_matrix_is_transposed(self): + mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) + mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) + self.assertEquals(mat1, mat) + + expected = [[0, 4], [1, 6], [3, 9]] + for i in range(3): + for j in range(2): + self.assertEquals(mat1[i, j], expected[i][j]) + self.assertTrue(array_equal(mat1.toArray(), expected)) + + sm = mat1.toSparse() + self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2])) + self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5])) + self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9])) + class ListTests(PySparkTestCase):