Skip to content

Commit

Permalink
[SPARK-6226][MLLIB] add save/load in PySpark's KMeansModel
Browse files Browse the repository at this point in the history
Use `_py2java` and `_java2py` to convert the Python model to/from the Java model. cc yinxusen

Author: Xiangrui Meng <meng@databricks.com>

Closes #5049 from mengxr/SPARK-6226-mengxr and squashes the following commits:

570ba81 [Xiangrui Meng] fix python style
b10b911 [Xiangrui Meng] add save/load in PySpark's KMeansModel
  • Loading branch information
mengxr committed Mar 17, 2015
1 parent d9f3e01 commit c94d062
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.mllib.clustering

import scala.collection.JavaConverters._

import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
Expand All @@ -34,6 +36,9 @@ import org.apache.spark.sql.Row
*/
class KMeansModel (val clusterCenters: Array[Vector]) extends Saveable with Serializable {

/**
 * A Java-friendly constructor that takes an Iterable of Vectors.
 * The iterable is viewed as a Scala collection, copied into an array, and
 * passed to the primary constructor as `clusterCenters`.
 */
def this(centers: java.lang.Iterable[Vector]) = this(centers.asScala.toArray)

/** Total number of clusters, i.e. the length of `clusterCenters`. */
def k: Int = clusterCenters.length

Expand Down
28 changes: 25 additions & 3 deletions python/pyspark/mllib/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@

from pyspark import RDD
from pyspark import SparkContext
from pyspark.mllib.common import callMLlibFunc, callJavaFunc
from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector
from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _py2java, _java2py
from pyspark.mllib.linalg import SparseVector, _convert_to_vector
from pyspark.mllib.stat.distribution import MultivariateGaussian
from pyspark.mllib.util import Saveable, Loader, inherit_doc

__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture']


class KMeansModel(object):
@inherit_doc
class KMeansModel(Saveable, Loader):

"""A clustering model derived from the k-means method.
Expand Down Expand Up @@ -55,6 +57,16 @@ class KMeansModel(object):
True
>>> type(model.clusterCenters)
<type 'list'>
>>> import os, tempfile
>>> path = tempfile.mkdtemp()
>>> model.save(sc, path)
>>> sameModel = KMeansModel.load(sc, path)
>>> sameModel.predict(sparse_data[0]) == model.predict(sparse_data[0])
True
>>> try:
... os.removedirs(path)
... except OSError:
... pass
"""

def __init__(self, centers):
Expand All @@ -77,6 +89,16 @@ def predict(self, x):
best_distance = distance
return best

def save(self, sc, path):
    """Save this model to ``path`` through the JVM-side KMeansModel.

    The centers are converted to MLlib vectors, shipped to the JVM, wrapped
    in a Java ``KMeansModel``, and that Java model's ``save`` is invoked.

    :param sc: SparkContext whose gateway (``sc._jvm`` / ``sc._jsc``) is used.
    :param path: destination path handed to the Java model's ``save``.
    """
    # Build a real list rather than using ``map``: on Python 3 ``map`` returns
    # a lazy iterator, which would not match the ``isinstance(obj, list)``
    # branch of ``_py2java`` that converts each element into a Java list.
    centers = [_convert_to_vector(c) for c in self.centers]
    java_centers = _py2java(sc, centers)
    java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel(java_centers)
    java_model.save(sc._jsc.sc(), path)

@classmethod
def load(cls, sc, path):
    """Load a model from ``path`` and return it as a Python KMeansModel.

    Delegates to the JVM-side ``KMeansModel.load``, then converts the
    cluster centers it exposes back to Python objects with ``_java2py``.

    :param sc: SparkContext whose gateway (``sc._jvm`` / ``sc._jsc``) is used.
    :param path: path handed to the Java ``KMeansModel.load``.
    :return: a ``KMeansModel`` built from the converted cluster centers.
    """
    jvm_clustering = sc._jvm.org.apache.spark.mllib.clustering
    loaded = jvm_clustering.KMeansModel.load(sc._jsc.sc(), path)
    py_centers = _java2py(sc, loaded.clusterCenters())
    return KMeansModel(py_centers)


class KMeans(object):

Expand Down
4 changes: 2 additions & 2 deletions python/pyspark/mllib/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def _py2java(sc, obj):
obj = _to_java_object_rdd(obj)
elif isinstance(obj, SparkContext):
obj = obj._jsc
elif isinstance(obj, list) and (obj or isinstance(obj[0], JavaObject)):
obj = ListConverter().convert(obj, sc._gateway._gateway_client)
elif isinstance(obj, list):
obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client)
elif isinstance(obj, JavaObject):
pass
elif isinstance(obj, (int, long, float, bool, basestring)):
Expand Down

0 comments on commit c94d062

Please sign in to comment.