From 487dc247ac8b0629edf5201e36512773f56226b0 Mon Sep 17 00:00:00 2001
From: Vishnu Prasad
Date: Wed, 20 Apr 2016 05:50:25 +0530
Subject: [PATCH] [SPARK-14739][PySPARK] Fix for Both Sparse and Dense Vector
 Parsing Errors

Parsing the string form of an empty vector failed: ''.split(',') returns
[''], and float('') / int('') raise ValueError. Skip empty tokens when
building the value list in DenseVector.parse and the index and value lists
in SparseVector.parse, and add tests that parse '[]' and '(8, [], [])'.

---
Reviewer note: the filter on the indices comprehension must test `ind`,
the comprehension's own loop variable. No name `val` is bound at that
point in SparseVector.parse, so `if val` there raises NameError (which the
surrounding `except ValueError` does not catch) on every call.

 python/pyspark/mllib/linalg/__init__.py | 6 +++---
 python/pyspark/mllib/tests.py           | 4 ++++
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py
index abf00a4737948..66a41500233bb 100644
--- a/python/pyspark/mllib/linalg/__init__.py
+++ b/python/pyspark/mllib/linalg/__init__.py
@@ -293,7 +293,7 @@ def parse(s):
         s = s[start + 1: end]
 
         try:
-            values = [float(val) for val in s.split(',')]
+            values = [float(val) for val in s.split(',') if val]
         except ValueError:
             raise ValueError("Unable to parse values from %s" % s)
         return DenseVector(values)
@@ -586,7 +586,7 @@ def parse(s):
         new_s = s[ind_start + 1: ind_end]
         ind_list = new_s.split(',')
         try:
-            indices = [int(ind) for ind in ind_list]
+            indices = [int(ind) for ind in ind_list if ind]
         except ValueError:
             raise ValueError("Unable to parse indices from %s." % new_s)
         s = s[ind_end + 1:].strip()
@@ -599,7 +599,7 @@ def parse(s):
             raise ValueError("Values array should end with ']'.")
         val_list = s[val_start + 1: val_end].split(',')
         try:
-            values = [float(val) for val in val_list]
+            values = [float(val) for val in val_list if val]
         except ValueError:
             raise ValueError("Unable to parse values from %s." % s)
         return SparseVector(size, indices, values)
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index f272da56d1aee..7a957e01e9da3 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -401,6 +401,10 @@ def test_parse_vector(self):
         self.assertTrue(Vectors.parse(str(a)), a)
         a = SparseVector(10, [0, 1], [4, 5])
         self.assertTrue(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a)
+        a = DenseVector([])
+        self.assertEqual(Vectors.parse('[]'), a)
+        a = SparseVector(8, [], [])
+        self.assertEqual(Vectors.parse('(8, [], [])'), a)
 
     def test_norms(self):
         a = DenseVector([0, 2, 3, -1])