Commit

Remove unnecessary imports
Lewuathe committed Jun 21, 2015
1 parent 7ec04db · commit 25d3c9d
Showing 2 changed files with 47 additions and 8 deletions.
python/pyspark/mllib/tests.py (43 additions, 0 deletions)
@@ -49,6 +49,7 @@
 from pyspark.mllib.feature import IDF
 from pyspark.mllib.feature import StandardScaler
 from pyspark.mllib.feature import ElementwiseProduct
+from pyspark.mllib.util import MLUtils
 from pyspark.serializers import PickleSerializer
 from pyspark.streaming import StreamingContext
 from pyspark.sql import SQLContext
@@ -1010,6 +1011,48 @@ def collect(rdd):
         self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]])


+class MLUtilsTests(MLlibTestCase):
+    def test_append_bias(self):
+        data = [2.0, 2.0, 2.0]
+        ret = MLUtils.appendBias(data)
+        self.assertEqual(ret[3], 1.0)
+        self.assertEqual(type(ret), DenseVector)
+
+    def test_append_bias_with_vector(self):
+        data = Vectors.dense([2.0, 2.0, 2.0])
+        ret = MLUtils.appendBias(data)
+        self.assertEqual(ret[3], 1.0)
+        self.assertEqual(type(ret), DenseVector)
+
+    def test_append_bias_with_sp_vector(self):
+        data = Vectors.sparse(3, {0: 2.0, 2: 2.0})
+        expected = Vectors.sparse(4, {0: 2.0, 2: 2.0, 3: 1.0})
+        # Returned value must be SparseVector
+        ret = MLUtils.appendBias(data)
+        self.assertEqual(ret, expected)
+        self.assertEqual(type(ret), SparseVector)
+
+    def test_load_vectors(self):
+        import shutil
+        data = [
+            [1.0, 2.0, 3.0],
+            [1.0, 2.0, 3.0]
+        ]
+        temp_dir = tempfile.mkdtemp()
+        load_vectors_path = os.path.join(temp_dir, "test_load_vectors")
+        try:
+            self.sc.parallelize(data).saveAsTextFile(load_vectors_path)
+            ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path)
+            ret = ret_rdd.collect()
+            self.assertEqual(len(ret), 2)
+            self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0]))
+            self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0]))
+        except:
+            self.fail()
+        finally:
+            shutil.rmtree(load_vectors_path)
+
+
 if __name__ == "__main__":
     if not _have_scipy:
         print("NOTE: Skipping SciPy tests as it does not seem to be installed")
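For context, a minimal sketch of the behavior these new MLUtilsTests pin down, assuming a PySpark shell of this era with SciPy installed. appendBias itself runs locally and needs only NumPy (plus SciPy for sparse input); loadVectors takes a live SparkContext, as the test above shows.

from pyspark.mllib.linalg import Vectors
from pyspark.mllib.util import MLUtils

# appendBias adds a trailing 1.0 and keeps the vector family:
MLUtils.appendBias([2.0, 2.0, 2.0])
# -> DenseVector([2.0, 2.0, 2.0, 1.0])
MLUtils.appendBias(Vectors.sparse(3, {0: 2.0, 2: 2.0}))
# -> SparseVector(4, {0: 2.0, 2: 2.0, 3: 1.0})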
python/pyspark/mllib/util.py (4 additions, 8 deletions)
@@ -29,7 +29,7 @@
 xrange = range

 from pyspark.mllib.common import callMLlibFunc, inherit_doc
-from pyspark.mllib.linalg import Vector, Vectors, DenseVector, SparseVector, _convert_to_vector
+from pyspark.mllib.linalg import Vector, Vectors, SparseVector, _convert_to_vector


 class MLUtils(object):
@@ -183,15 +183,11 @@ def appendBias(data):
         """
         vec = _convert_to_vector(data)
         if isinstance(vec, SparseVector):
-            if _have_scipy:
-                l = scipy.sparse.csc_matrix(np.append(vec.toArray(), 1.0))
-                return _convert_to_vector(l.T)
-            else:
-                raise TypeError("Cannot append bias %s into sparce "
-                                "vector because of lack of scipy" % type(vec))
+            l = scipy.sparse.csc_matrix(np.append(vec.toArray(), 1.0))
+            return _convert_to_vector(l.T)
         elif isinstance(vec, Vector):
             vec = vec.toArray()
-            return np.append(vec, 1.0).tolist()
+            return _convert_to_vector(np.append(vec, 1.0).tolist())

     @staticmethod
     def loadVectors(sc, path):
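Besides removing the _have_scipy guard (SciPy is effectively assumed for the sparse path), this hunk changes the dense-path return type: the old code returned a plain Python list, while the new code wraps the result in _convert_to_vector so callers get a DenseVector, which is exactly what test_append_bias and test_append_bias_with_vector assert above. A before/after sketch using the same names as the diff:

import numpy as np
from pyspark.mllib.linalg import _convert_to_vector

vec = [2.0, 2.0, 2.0]
old = np.append(vec, 1.0).tolist()                      # old behavior: [2.0, 2.0, 2.0, 1.0], a plain list
new = _convert_to_vector(np.append(vec, 1.0).tolist())  # new behavior: DenseVector([2.0, 2.0, 2.0, 1.0])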
