Skip to content

Commit

Permalink
equals is only used internally, so rename it to _equals
Browse files Browse the repository at this point in the history
  • Loading branch information
yanboliang committed Sep 14, 2015
1 parent d3f8c14 commit 3b8ac7a
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 20 deletions.
25 changes: 7 additions & 18 deletions python/pyspark/mllib/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,9 @@ def __eq__(self, other):
return np.array_equal(self.array, other.array)
elif isinstance(other, SparseVector):
if len(self) != other.size:
return false
return Vectors.equals(list(xrange(len(self))), self.array, other.indices, other.values)
return NotImplemented
return False
return Vectors._equals(list(xrange(len(self))), self.array, other.indices, other.values)
return False

def __ne__(self, other):
return not self == other
Expand Down Expand Up @@ -739,9 +739,9 @@ def __eq__(self, other):
and np.array_equal(other.values, self.values)
elif isinstance(other, DenseVector):
if self.size != len(other):
return false
return Vectors.equals(self.indices, self.values, list(xrange(len(other))), other.array)
return NotImplemented
return False
return Vectors._equals(self.indices, self.values, list(xrange(len(other))), other.array)
return False

def __getitem__(self, index):
inds = self.indices
Expand Down Expand Up @@ -879,21 +879,10 @@ def zeros(size):
return DenseVector(np.zeros(size))

@staticmethod
def equals(v1_indices, v1_values, v2_indices, v2_values):
def _equals(v1_indices, v1_values, v2_indices, v2_values):
"""
Check equality between sparse/dense vectors,
v1_indices and v2_indices are assumed to be strictly increasing.
>>> indices = [1, 2, 4]
>>> values = [1., 3., 2.]
>>> Vectors.equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.])
True
>>> Vectors.equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.])
False
>>> Vectors.equals(indices, values, list(range(5)), [0., 3., 0., 2.])
False
>>> Vectors.equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.])
False
"""
v1_size = len(v1_values)
v2_size = len(v2_values)
Expand Down
12 changes: 10 additions & 2 deletions python/pyspark/mllib/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,15 +208,23 @@ def test_eq(self):
v1 = DenseVector([0.0, 1.0, 0.0, 5.5])
v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
v3 = DenseVector([0.0, 1.0, 0.0, 5.5])
v4 = SparseVector(4, [(1, 1.0), (3, 5.5)])
v4 = SparseVector(6, [(1, 1.0), (3, 5.5)])
v5 = DenseVector([0.0, 1.0, 0.0, 2.5])
v6 = SparseVector(4, [(1, 1.0), (3, 2.5)])
self.assertTrue(v1 == v2)
self.assertTrue(v1 == v3)
self.assertTrue(v2 == v4)
self.assertFalse(v2 == v4)
self.assertFalse(v1 == v5)
self.assertFalse(v1 == v6)

def test_equals(self):
indices = [1, 2, 4]
values = [1., 3., 2.]
self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.]))
self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.]))
self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.]))
self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.]))

def test_conversion(self):
# numpy arrays should be automatically upcast to float64
# tests for fix of [SPARK-5089]
Expand Down

0 comments on commit 3b8ac7a

Please sign in to comment.