Skip to content

Commit

Permalink
[SPARK-6263] Python MLlib API missing items: Utils
Browse files Browse the repository at this point in the history
  • Loading branch information
Lewuathe committed Apr 26, 2015
1 parent 9a5bbe0 commit 2980569
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ private[python] class PythonMLLibAPI extends Serializable {
minPartitions: Int): JavaRDD[LabeledPoint] =
MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions)

def appendBias(data: org.apache.spark.mllib.linalg.Vector)
= MLUtils.appendBias(data)

// Py4J-facing wrapper for MLUtils.loadVectors; unwraps the Java context
// to the underlying SparkContext before delegating.
def loadVectors(jsc: JavaSparkContext, path: String)
= MLUtils.loadVectors(jsc.sc, path)

private def trainRegressionModel(
learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel],
data: JavaRDD[LabeledPoint],
Expand Down
24 changes: 24 additions & 0 deletions python/pyspark/mllib/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from pyspark.mllib.feature import Word2Vec
from pyspark.mllib.feature import IDF
from pyspark.mllib.feature import StandardScaler
from pyspark.mllib.util import MLUtils
from pyspark.serializers import PickleSerializer
from pyspark.sql import SQLContext

Expand Down Expand Up @@ -789,6 +790,29 @@ def test_model_transform(self):
self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0]))


class MLUtilsTests(MLlibTestCase):
    """Tests for the Python MLUtils wrappers (appendBias, loadVectors)."""

    def test_append_bias(self):
        # appendBias should return the input with a trailing 1.0 bias term.
        data = [1.0, 2.0, 3.0]
        ret = MLUtils.appendBias(data)
        self.assertEqual(len(ret), 4)
        self.assertEqual(ret[3], 1.0)

    def test_load_vectors(self):
        import os
        import shutil
        import tempfile

        data = [
            [1.0, 2.0, 3.0],
            [1.0, 2.0, 3.0]
        ]
        # Use a unique temporary directory rather than a hard-coded relative
        # path, so concurrent test runs cannot collide and nothing is left
        # behind in the working directory.
        temp_dir = tempfile.mkdtemp()
        try:
            # saveAsTextFile requires that the target path not exist yet,
            # so save into a fresh subdirectory of the temp dir.
            load_vectors_path = os.path.join(temp_dir, "test_load_vectors")
            self.sc.parallelize(data).saveAsTextFile(load_vectors_path)
            ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path)
            ret = ret_rdd.collect()
            self.assertEqual(len(ret), 2)
            self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0]))
            self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0]))
        finally:
            # mkdtemp always created temp_dir, so rmtree is always valid here.
            # (The original rmtree'd the save path itself, which raised
            # FileNotFoundError in finally — masking the real failure — if
            # saveAsTextFile died before creating the directory.)
            shutil.rmtree(temp_dir)


if __name__ == "__main__":
if not _have_scipy:
print("NOTE: Skipping SciPy tests as it does not seem to be installed")
Expand Down
8 changes: 8 additions & 0 deletions python/pyspark/mllib/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,14 @@ def loadLabeledPoints(sc, path, minPartitions=None):
minPartitions = minPartitions or min(sc.defaultParallelism, 2)
return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions)

@staticmethod
def appendBias(data):
    """
    Append a bias term to *data* by delegating to the JVM-side
    ``MLUtils.appendBias`` through the MLlib Py4J bridge.
    """
    vec = _convert_to_vector(data)
    return callMLlibFunc("appendBias", vec)

@staticmethod
def loadVectors(sc, path):
    """
    Load vectors from *path* by delegating to the JVM-side
    ``MLUtils.loadVectors`` through the MLlib Py4J bridge.

    :param sc:   SparkContext used for loading.
    :param path: Input path of the saved vectors.
    """
    loaded = callMLlibFunc("loadVectors", sc, path)
    return loaded


class Saveable(object):
"""
Expand Down

0 comments on commit 2980569

Please sign in to comment.