-
Notifications
You must be signed in to change notification settings - Fork 28.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-9793] [MLlib] [PySpark] PySpark DenseVector, SparseVector implement __eq__ and __hash__ correctly #8166
Changes from 6 commits
1e9d1bc
7489a44
83f51ed
fca0f5a
d3f8c14
3b8ac7a
b58d1bb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,6 +25,7 @@ | |
|
||
import sys | ||
import array | ||
import struct | ||
|
||
if sys.version >= '3': | ||
basestring = str | ||
|
@@ -122,6 +123,15 @@ def _format_float_list(l): | |
return [_format_float(x) for x in l] | ||
|
||
|
||
def _double_to_long_bits(value):
    """
    Return the IEEE 754 bit pattern of a double as an unsigned 64-bit integer.

    Every NaN is mapped to the single canonical quiet-NaN pattern
    ``0x7ff8000000000000`` so that all NaN values yield the same bits.

    :param value: number accepted by ``struct.pack('d', ...)``.
    :return: non-negative integer holding the 64 bits of ``value``.
    """
    if value != value:
        # Only NaN compares unequal to itself; collapse every NaN payload
        # to the canonical non-signaling NaN.
        return 0x7ff8000000000000
    # Pack the double into 8 bytes, then reinterpret those bytes as an
    # unsigned 64-bit integer (both formats use the same byte order).
    return struct.unpack('Q', struct.pack('d', value))[0]
|
||
|
||
class VectorUDT(UserDefinedType): | ||
""" | ||
SQL user-defined type (UDT) for Vector. | ||
|
@@ -404,11 +414,31 @@ def __repr__(self): | |
return "DenseVector([%s])" % (', '.join(_format_float(i) for i in self.array)) | ||
|
||
def __eq__(self, other):
    """
    Test this DenseVector for equality with another vector.

    Two DenseVectors are equal when ``np.array_equal`` holds for their
    arrays.  A DenseVector equals a SparseVector when both have the same
    size and ``Vectors._equals`` matches their non-zero entries.  Any
    other operand compares unequal.
    """
    if isinstance(other, DenseVector):
        return np.array_equal(self.array, other.array)
    elif isinstance(other, SparseVector):
        if len(self) != other.size:
            return False
        # A dense vector stores a value at every position, so its index
        # sequence is simply 0 .. len(self) - 1.
        return Vectors._equals(list(xrange(len(self))), self.array, other.indices, other.values)
    return False
|
||
def __ne__(self, other):
    """
    Return the negation of ``self == other``.
    """
    return not self == other
|
||
def __hash__(self):
    """
    Hash the vector from its size and its leading non-zero entries.

    The hash starts from ``31 + len(self)`` and, for each non-zero entry,
    folds in the entry's position and the 64-bit pattern of its value
    (upper and lower halves XOR-ed together).  Zero entries are skipped,
    and at most 128 non-zero entries are considered.
    """
    size = len(self)
    result = 31 + size
    nnz = 0
    i = 0
    # Stop after 128 non-zero entries so hashing cost stays bounded.
    while i < size and nnz < 128:
        if self.array[i] != 0:
            result = 31 * result + i
            bits = _double_to_long_bits(self.array[i])
            # Mix the high 32 bits into the low ones before folding.
            result = 31 * result + (bits ^ (bits >> 32))
            nnz += 1
        i += 1
    return result
|
||
def __getattr__(self, item):
    """
    Forward the lookup of ``item`` to ``self.array``.
    """
    return getattr(self.array, item)
|
||
|
@@ -704,20 +734,14 @@ def __repr__(self): | |
return "SparseVector({0}, {{{1}}})".format(self.size, entries) | ||
|
||
def __eq__(self, other):
    """
    Test this SparseVector for equality with another vector.

    Vectors are equal when they have the same size and the same non-zero
    entries; an explicitly stored zero is equivalent to an absent entry.
    Any operand that is neither a SparseVector nor a DenseVector compares
    unequal.

    >>> v1 = SparseVector(4, [(1, 1.0), (3, 5.5)])
    >>> v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
    >>> v1 == v2
    True
    >>> v1 != v2
    False
    """
    if isinstance(other, SparseVector):
        if other.size != self.size:
            return False
        # Fast path: identical stored indices and values.
        if np.array_equal(other.indices, self.indices) \
                and np.array_equal(other.values, self.values):
            return True
        # The stored entries differ, but the vectors may still be the same
        # if the difference is only explicitly stored zeros.  Comparing the
        # non-zero entries keeps equality transitive with the dense
        # comparison below, which ignores zeros in the same way.
        return Vectors._equals(self.indices, self.values, other.indices, other.values)
    elif isinstance(other, DenseVector):
        if self.size != len(other):
            return False
        # A dense vector stores a value at every position, so its index
        # sequence is simply 0 .. len(other) - 1.
        return Vectors._equals(self.indices, self.values, list(xrange(len(other))), other.array)
    return False
|
||
def __getitem__(self, index): | ||
inds = self.indices | ||
|
@@ -739,6 +763,19 @@ def __getitem__(self, index): | |
def __ne__(self, other):
    """
    Return the negation of ``self.__eq__(other)``.
    """
    return not self.__eq__(other)
|
||
def __hash__(self):
    """
    Hash the vector from its size and its leading non-zero entries.

    The hash starts from ``31 + self.size`` and, for each stored entry
    whose value is non-zero, folds in the entry's index and the 64-bit
    pattern of its value (upper and lower halves XOR-ed together).
    Explicitly stored zeros are skipped, and at most 128 non-zero entries
    are considered.
    """
    result = 31 + self.size
    nnz = 0
    i = 0
    # Stop after 128 non-zero entries so hashing cost stays bounded.
    while i < len(self.values) and nnz < 128:
        if self.values[i] != 0:
            # int() normalizes the stored index to a plain Python integer.
            result = 31 * result + int(self.indices[i])
            bits = _double_to_long_bits(self.values[i])
            # Mix the high 32 bits into the low ones before folding.
            result = 31 * result + (bits ^ (bits >> 32))
            nnz += 1
        i += 1
    return result
|
||
|
||
class Vectors(object): | ||
|
||
|
@@ -841,6 +878,31 @@ def parse(s): | |
def zeros(size):
    """
    Create a DenseVector wrapping ``np.zeros(size)``.
    """
    return DenseVector(np.zeros(size))
|
||
@staticmethod
def _equals(v1_indices, v1_values, v2_indices, v2_values):
    """
    Check equality between sparse/dense vectors given as parallel
    index/value sequences.

    Entries whose value is 0 are ignored, so an explicitly stored zero is
    equivalent to an absent entry.  Vector sizes are not compared here;
    callers are expected to check them first.

    v1_indices and v2_indices are assumed to be strictly increasing;
    this method does not verify it.  Only the first ``len(values)``
    indices of each vector are examined.
    """
    v1_size = len(v1_values)
    v2_size = len(v2_values)
    k1 = 0
    k2 = 0
    while True:
        # Advance each cursor to its next non-zero entry.
        while k1 < v1_size and v1_values[k1] == 0:
            k1 += 1
        while k2 < v2_size and v2_values[k2] == 0:
            k2 += 1

        if k1 >= v1_size or k2 >= v2_size:
            # At least one side is exhausted: the vectors are equal only
            # if both ran out of non-zero entries together.
            return k1 >= v1_size and k2 >= v2_size

        # The current non-zero entries must agree in position and value.
        if v1_indices[k1] != v2_indices[k2] or v1_values[k1] != v2_values[k2]:
            return False
        k1 += 1
        k2 += 1
|
||
|
||
class Matrix(object): | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -194,6 +194,37 @@ def test_squared_distance(self): | |
self.assertEquals(3.0, _squared_distance(sv, arr)) | ||
self.assertEquals(3.0, _squared_distance(sv, narr)) | ||
|
||
def test_hash(self):
    """
    Dense and sparse vectors with the same entries hash alike; vectors
    with different entries used here hash differently.
    """
    v1 = DenseVector([0.0, 1.0, 0.0, 5.5])
    v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
    v3 = DenseVector([1.0, 1.0, 0.0, 5.5])
    v4 = SparseVector(4, [(1, 1.0), (3, 2.5)])
    # Equality assertions report both hash values on failure, unlike
    # assertTrue/assertFalse on a pre-computed boolean.
    self.assertEqual(hash(v1), hash(v2))
    self.assertNotEqual(hash(v1), hash(v3))
    self.assertNotEqual(hash(v2), hash(v3))
    self.assertNotEqual(hash(v2), hash(v4))
|
||
def test_eq(self):
    """
    Exercise ``==`` across dense/dense, dense/sparse and sparse/sparse
    pairs, including size and value mismatches.
    """
    v1 = DenseVector([0.0, 1.0, 0.0, 5.5])
    v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
    v3 = DenseVector([0.0, 1.0, 0.0, 5.5])
    v4 = SparseVector(6, [(1, 1.0), (3, 5.5)])
    v5 = DenseVector([0.0, 1.0, 0.0, 2.5])
    v6 = SparseVector(4, [(1, 1.0), (3, 2.5)])
    # ``==`` is spelled out (rather than assertEqual/assertNotEqual) so
    # that every check goes through __eq__ and never through __ne__.
    self.assertTrue(v1 == v2)    # dense vs sparse, same entries
    self.assertTrue(v1 == v3)    # dense vs dense, same entries
    self.assertFalse(v2 == v4)   # sparse vs sparse, different size
    self.assertFalse(v1 == v5)   # dense vs dense, different value
    self.assertFalse(v1 == v6)   # dense vs sparse, different value
|
||
def test_equals(self):
    """
    Compare a sparse (indices, values) pair against dense value lists
    through ``Vectors._equals``.
    """
    indices = [1, 2, 4]
    values = [1., 3., 2.]
    # Same non-zero entries at the same positions.
    self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.]))
    # Same positions, values swapped.
    self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.]))
    # Shorter dense value list with different entries.
    self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.]))
    # Extra non-zero entry at position 3.
    self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.]))
|
||
def test_conversion(self): | ||
# numpy arrays should be automatically upcast to float64 | ||
# tests for fix of [SPARK-5089] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can make the code more readable: