diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index d3d082135800e..41e4c41d038af 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -988,6 +988,7 @@ private[spark] object SerDe extends Serializable { val isTransposed = if (m.isTransposed) 1 else 0 ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().put(m.values) + out.write(Opcodes.MARK) out.write(Opcodes.BININT) out.write(PickleUtils.integer_to_bytes(m.numRows)) out.write(Opcodes.BININT) @@ -1008,7 +1009,7 @@ private[spark] object SerDe extends Serializable { val n = bytes.length / 8 val values = new Array[Double](n) val order = ByteOrder.nativeOrder() - val isTransposed = args(3).asInstanceOf[Boolean] + val isTransposed = if (args(3).asInstanceOf[Int] == 1) true else false ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().get(values) new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values, isTransposed) } diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index a808f1bfc17e9..c3729de752895 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -671,7 +671,8 @@ def __init__(self, numRows, numCols, values, isTransposed=False): def __reduce__(self): return DenseMatrix, ( - self.numRows, self.numCols, self.values.tostring(), self.isTransposed) + self.numRows, self.numCols, self.values.tostring(), + int(self.isTransposed)) def toArray(self): """ diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 1aae57eeb05a5..3671ce8ccc475 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -85,7 +85,7 @@ def test_serialize(self): self._test_serialize(DenseVector(pyarray.array('d', range(10)))) self._test_serialize(SparseVector(4, {1: 1, 3: 2})) self._test_serialize(SparseVector(3, {})) - # self._test_serialize(DenseMatrix(2, 3, range(6))) + self._test_serialize(DenseMatrix(2, 3, range(6))) def test_dot(self): sv = SparseVector(4, {1: 1, 3: 2})