Skip to content

Commit

Permalink
[SPARK-13035][ML][PYSPARK] PySpark ml.clustering support export/import
Browse files (browse the repository at this point in the history)
Add export/import (save/load) support to PySpark ml.clustering: KMeans and KMeansModel can now be saved to and loaded from a path.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #10999 from yanboliang/spark-13035.
  • Loading branch information
yanboliang authored and mengxr committed Feb 11, 2016
1 parent 2426eb3 commit 30e0095
Showing 1 changed file with 25 additions and 4 deletions.
29 changes: 25 additions & 4 deletions python/pyspark/ml/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@
#

from pyspark import since
from pyspark.ml.util import keyword_only
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import *
from pyspark.mllib.common import inherit_doc

__all__ = ['KMeans', 'KMeansModel']


class KMeansModel(JavaModel):
class KMeansModel(JavaModel, MLWritable, MLReadable):
"""
Model fitted by KMeans.
Expand All @@ -46,7 +46,8 @@ def computeCost(self, dataset):


@inherit_doc
class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed):
class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed,
MLWritable, MLReadable):
"""
K-means clustering with support for multiple parallel runs and a k-means++ like initialization
mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested,
Expand All @@ -69,6 +70,25 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
True
>>> rows[2].prediction == rows[3].prediction
True
>>> import os, tempfile
>>> path = tempfile.mkdtemp()
>>> kmeans_path = path + "/kmeans"
>>> kmeans.save(kmeans_path)
>>> kmeans2 = KMeans.load(kmeans_path)
>>> kmeans2.getK()
2
>>> model_path = path + "/kmeans_model"
>>> model.save(model_path)
>>> model2 = KMeansModel.load(model_path)
>>> model.clusterCenters()[0] == model2.clusterCenters()[0]
array([ True, True], dtype=bool)
>>> model.clusterCenters()[1] == model2.clusterCenters()[1]
array([ True, True], dtype=bool)
>>> from shutil import rmtree
>>> try:
... rmtree(path)
... except OSError:
... pass
.. versionadded:: 1.5.0
"""
Expand Down Expand Up @@ -157,9 +177,10 @@ def getInitSteps(self):

if __name__ == "__main__":
import doctest
import pyspark.ml.clustering
from pyspark.context import SparkContext
from pyspark.sql import SQLContext
globs = globals().copy()
globs = pyspark.ml.clustering.__dict__.copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
sc = SparkContext("local[2]", "ml.clustering tests")
Expand Down

0 comments on commit 30e0095

Please sign in to comment.