Skip to content

Commit

Permalink
[SPARK-7401] Vectorize dot product and sq_dist
Browse files: browse the repository at this point in the history
  • Loading branch information
MechCoder committed Jul 2, 2015
1 parent 5fa0863 commit e5f1de0
Showing 1 changed file with 23 additions and 5 deletions.
28 changes: 23 additions & 5 deletions python/pyspark/mllib/linalg.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -586,13 +586,19 @@ def dot(self, other):

assert len(self) == _vector_size(other), "dimension mismatch"

if type(other) in (np.ndarray, array.array, DenseVector):
if type(other) == array.array:
result = 0.0
for i in xrange(len(self.indices)):
result += self.values[i] * other[self.indices[i]]
for i, ind in enumerate(self.indices):
result += self.values[i] * other[ind]
return result

elif type(other) is SparseVector:
elif isinstance(other, np.ndarray):
return np.dot(other[self.indices], self.values)

elif isinstance(other, DenseVector):
return np.dot(other.array[self.indices], self.values)

elif isinstance(other, SparseVector):
result = 0.0
i, j = 0, 0
while i < len(self.indices) and j < len(other.indices):
Expand Down Expand Up @@ -635,7 +641,7 @@ def squared_distance(self, other):
AssertionError: dimension mismatch
"""
assert len(self) == _vector_size(other), "dimension mismatch"
if type(other) in (list, array.array, DenseVector, np.array, np.ndarray):
if type(other) in (list, array.array):
if type(other) is np.array and other.ndim != 1:
raise Exception("Cannot call squared_distance with %d-dimensional array" %
other.ndim)
Expand All @@ -650,6 +656,18 @@ def squared_distance(self, other):
result += other[i] * other[i]
return result

elif type(other) in (np.array, np.ndarray, DenseVector):
if type(other) == DenseVector:
other = other.array
sparse_ind = np.zeros(other.size, dtype=bool)
sparse_ind[self.indices] = True
dist = other[sparse_ind] - self.values
result = np.dot(dist, dist)

other_ind = other[~sparse_ind]
result += np.dot(other_ind, other_ind)
return result

elif type(other) is SparseVector:
result = 0.0
i, j = 0, 0
Expand Down

0 comments on commit e5f1de0

Please sign in to comment.