Skip to content

Commit

Permalink
Remove type checks for list, pyarray etc
Browse files Browse the repository at this point in the history
  • Loading branch information
MechCoder committed Jul 2, 2015
1 parent 0ee5dd4 commit fcad0a3
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 27 deletions.
31 changes: 6 additions & 25 deletions python/pyspark/mllib/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,24 +577,16 @@ def dot(self, other):
...
AssertionError: dimension mismatch
"""
if type(other) == np.ndarray:
assert len(self) == _vector_size(other), "dimension mismatch"
if isinstance(other, np.ndarray):
if other.ndim == 2:
results = [self.dot(other[:, i]) for i in xrange(other.shape[1])]
return np.array(results)
elif other.ndim > 2:
elif other.ndim == 1:
return np.dot(other[self.indices], self.values)
else:
raise ValueError("Cannot call dot with %d-dimensional array" % other.ndim)

assert len(self) == _vector_size(other), "dimension mismatch"

if isinstance(other, array.array):
result = 0.0
for i, ind in enumerate(self.indices):
result += self.values[i] * other[ind]
return result

elif isinstance(other, np.ndarray):
return np.dot(other[self.indices], self.values)

elif isinstance(other, DenseVector):
return np.dot(other.array[self.indices], self.values)

Expand Down Expand Up @@ -641,19 +633,8 @@ def squared_distance(self, other):
AssertionError: dimension mismatch
"""
assert len(self) == _vector_size(other), "dimension mismatch"
if isinstance(other, list) or isinstance(other, array.array):
result = 0.0
j = 0 # index into our own array
for i, val in enumerate(other):
if j < len(self.indices) and self.indices[j] == i:
diff = self.values[j] - val
result += diff * diff
j += 1
else:
result += val * val
return result

elif isinstance(other, np.ndarray) or isinstance(other, DenseVector):
if isinstance(other, np.ndarray) or isinstance(other, DenseVector):
if isinstance(other, np.ndarray) and other.ndim != 1:
raise Exception("Cannot call squared_distance with %d-dimensional array" %
other.ndim)
Expand Down
4 changes: 2 additions & 2 deletions python/pyspark/mllib/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def test_dot(self):
[1., 2., 3., 4.],
[1., 2., 3., 4.],
[1., 2., 3., 4.]])
arr = pyarray('d', [0, 1, 2, 3])
arr = pyarray.array('d', [0, 1, 2, 3])
self.assertEquals(10.0, sv.dot(dv))
self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat)))
self.assertEquals(30.0, dv.dot(dv))
Expand All @@ -142,7 +142,7 @@ def test_squared_distance(self):
dv = DenseVector(array([1., 2., 3., 4.]))
lst = DenseVector([4, 3, 2, 1])
lst1 = [4, 3, 2, 1]
arr = pyarray('d', [0, 2, 1, 3])
arr = pyarray.array('d', [0, 2, 1, 3])
narr = array([0, 2, 1, 3])
self.assertEquals(15.0, _squared_distance(sv, dv))
self.assertEquals(25.0, _squared_distance(sv, lst))
Expand Down

0 comments on commit fcad0a3

Please sign in to comment.