From 29805690500087551c6f474fda035d74724e8016 Mon Sep 17 00:00:00 2001
From: lewuathe
Date: Sun, 26 Apr 2015 22:32:15 +0900
Subject: [PATCH] [SPARK-6263] Python MLlib API missing items: Utils

---
 .../mllib/api/python/PythonMLLibAPI.scala |  6 +++++
 python/pyspark/mllib/tests.py             | 31 +++++++++++++++++++++++
 python/pyspark/mllib/util.py              | 14 ++++++++++++
 3 files changed, 51 insertions(+)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 6237b64c8f984..b181b3b13c38e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -71,6 +71,12 @@ private[python] class PythonMLLibAPI extends Serializable {
       minPartitions: Int): JavaRDD[LabeledPoint] =
     MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions)
 
+  def appendBias(data: Vector): Vector =
+    MLUtils.appendBias(data)
+
+  def loadVectors(jsc: JavaSparkContext, path: String): JavaRDD[Vector] =
+    MLUtils.loadVectors(jsc.sc, path)
+
   private def trainRegressionModel(
       learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel],
       data: JavaRDD[LabeledPoint],
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 1b008b93bc137..73b4706992ec5 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -46,6 +46,7 @@
 from pyspark.mllib.feature import Word2Vec
 from pyspark.mllib.feature import IDF
 from pyspark.mllib.feature import StandardScaler
+from pyspark.mllib.util import MLUtils
 from pyspark.serializers import PickleSerializer
 from pyspark.sql import SQLContext
 
@@ -789,6 +790,36 @@ def test_model_transform(self):
         self.assertEqual(model.transform([1.0, 2.0, 3.0]),
                          DenseVector([1.0, 2.0, 3.0]))
 
+
+class MLUtilsTests(MLlibTestCase):
+    def test_append_bias(self):
+        data = [1.0, 2.0, 3.0]
+        ret = MLUtils.appendBias(data)
+        self.assertEqual(ret[3], 1.0)
+
+    def test_load_vectors(self):
+        import os
+        import shutil
+        import tempfile
+        data = [
+            [1.0, 2.0, 3.0],
+            [1.0, 2.0, 3.0]
+        ]
+        # Save under a temporary directory so the test leaves no
+        # files behind in the current working directory.
+        temp_dir = tempfile.mkdtemp()
+        load_vectors_path = os.path.join(temp_dir, "test_load_vectors")
+        try:
+            self.sc.parallelize(data).saveAsTextFile(load_vectors_path)
+            ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path)
+            ret = ret_rdd.collect()
+            self.assertEqual(len(ret), 2)
+            self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0]))
+            self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0]))
+        finally:
+            shutil.rmtree(temp_dir)
+
+
 if __name__ == "__main__":
     if not _have_scipy:
         print("NOTE: Skipping SciPy tests as it does not seem to be installed")
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index 16a90db146ef0..4a1e069c6dbff 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -169,6 +169,20 @@ def loadLabeledPoints(sc, path, minPartitions=None):
         minPartitions = minPartitions or min(sc.defaultParallelism, 2)
         return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions)
 
+    @staticmethod
+    def appendBias(data):
+        """
+        Returns a new vector with `1.0` (bias) appended to the end.
+        """
+        return callMLlibFunc("appendBias", _convert_to_vector(data))
+
+    @staticmethod
+    def loadVectors(sc, path):
+        """
+        Loads vectors saved using `RDD[Vector].saveAsTextFile`.
+        """
+        return callMLlibFunc("loadVectors", sc, path)
+
 
 class Saveable(object):
     """
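
A minimal usage sketch of the two new Python methods, not part of the patch itself. It assumes a live SparkContext named `sc` (as in a pyspark shell), and "/tmp/test_vectors" is a hypothetical output path chosen for illustration:

from pyspark.mllib.util import MLUtils

# appendBias appends a 1.0 bias term to the end of the vector;
# list input is converted to a DenseVector on the Python side.
v = MLUtils.appendBias([1.0, 2.0, 3.0])
print(v)  # [1.0,2.0,3.0,1.0]

# loadVectors reads back vectors written one-per-line by
# RDD.saveAsTextFile ("/tmp/test_vectors" is a hypothetical path).
sc.parallelize([v, v]).saveAsTextFile("/tmp/test_vectors")
vectors = MLUtils.loadVectors(sc, "/tmp/test_vectors")
assert vectors.count() == 2

Both methods simply forward to the existing Scala MLUtils implementations through callMLlibFunc, matching the pattern already used by loadLabeledPoints above.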