Skip to content

Commit

Permalink
[SPARK-7105] [PySpark] [MLlib] Support model save/load in GMM
Browse files Browse the repository at this point in the history
  • Loading branch information
MechCoder committed Jul 23, 2015
1 parent 2f5cbd8 commit 9c305aa
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 33 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.api.python

import java.util.{List => JList}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrix}
import org.apache.spark.mllib.clustering.GaussianMixtureModel

/**
* Wrapper around GaussianMixtureModel to provide helper methods in Python
*/
private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) {
val weights: Vector = Vectors.dense(model.weights)
val k: Int = weights.size

/**
* Returns gaussians as a List of Vectors and Matrices corresponding each MultivariateGaussian
*/
val gaussians: JList[Object] = {
val modelGaussians = model.gaussians
var i = 0
var mu = ArrayBuffer.empty[Vector]
var sigma = ArrayBuffer.empty[Matrix]
while (i < k) {
mu += modelGaussians(i).mu
sigma += modelGaussians(i).sigma
i += 1
}
List(mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
}

def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
}
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ private[python] class PythonMLLibAPI extends Serializable {
seed: java.lang.Long,
initialModelWeights: java.util.ArrayList[Double],
initialModelMu: java.util.ArrayList[Vector],
initialModelSigma: java.util.ArrayList[Matrix]): JList[Object] = {
initialModelSigma: java.util.ArrayList[Matrix]): GaussianMixtureModelWrapper = {
val gmmAlg = new GaussianMixture()
.setK(k)
.setConvergenceTol(convergenceTol)
Expand All @@ -382,16 +382,7 @@ private[python] class PythonMLLibAPI extends Serializable {
if (seed != null) gmmAlg.setSeed(seed)

try {
val model = gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
var wt = ArrayBuffer.empty[Double]
var mu = ArrayBuffer.empty[Vector]
var sigma = ArrayBuffer.empty[Matrix]
for (i <- 0 until model.k) {
wt += model.weights(i)
mu += model.gaussians(i).mu
sigma += model.gaussians(i).sigma
}
List(Vectors.dense(wt.toArray), mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
new GaussianMixtureModelWrapper(gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)))
} finally {
data.rdd.unpersist(blocking = false)
}
Expand Down
75 changes: 53 additions & 22 deletions python/pyspark/mllib/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,19 @@ def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"
return KMeansModel([c.toArray() for c in centers])


class GaussianMixtureModel(object):
@inherit_doc
class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader):

"""
.. note:: Experimental
"""A clustering model derived from the Gaussian Mixture Model method.
A clustering model derived from the Gaussian Mixture Model method.
>>> from pyspark.mllib.linalg import Vectors, DenseMatrix
>>> from numpy.testing import assert_equal
>>> from shutil import rmtree
>>> import os, tempfile
>>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1,
... 0.9,0.8,0.75,0.935,
... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2))
Expand All @@ -169,6 +177,25 @@ class GaussianMixtureModel(object):
True
>>> labels[4]==labels[5]
True
>>> path = tempfile.mkdtemp()
>>> model.save(sc, path)
>>> sameModel = GaussianMixtureModel.load(sc, path)
>>> assert_equal(model.weights, sameModel.weights)
>>> mus, sigmas = list(
... zip(*[(g.mu, g.sigma) for g in model.gaussians]))
>>> sameMus, sameSigmas = list(
... zip(*[(g.mu, g.sigma) for g in sameModel.gaussians]))
>>> mus == sameMus
True
>>> sigmas == sameSigmas
True
>>> from shutil import rmtree
>>> try:
... rmtree(path)
... except OSError:
... pass
>>> data = array([-5.1971, -2.5359, -3.8220,
... -5.2211, -5.0602, 4.7118,
... 6.8989, 3.4592, 4.6322,
Expand All @@ -182,38 +209,30 @@ class GaussianMixtureModel(object):
True
>>> labels[3]==labels[4]
True
>>> clusterdata_3 = sc.parallelize(data.reshape(15, 1))
>>> im = GaussianMixtureModel([0.5, 0.5],
... [MultivariateGaussian(Vectors.dense([-1.0]), DenseMatrix(1, 1, [1.0])),
... MultivariateGaussian(Vectors.dense([1.0]), DenseMatrix(1, 1, [1.0]))])
>>> model = GaussianMixture.train(clusterdata_3, 2, initialModel=im)
"""

def __init__(self, weights, gaussians):
self._weights = weights
self._gaussians = gaussians
self._k = len(self._weights)

@property
def weights(self):
"""
Weights for each Gaussian distribution in the mixture, where weights[i] is
the weight for Gaussian i, and weights.sum == 1.
"""
return self._weights
return array(self.call("weights"))

@property
def gaussians(self):
"""
Array of MultivariateGaussian where gaussians[i] represents
the Multivariate Gaussian (Normal) Distribution for Gaussian i.
"""
return self._gaussians
return [
MultivariateGaussian(gaussian[0], gaussian[1])
for gaussian in zip(*self.call("gaussians"))]

@property
def k(self):
"""Number of gaussians in mixture."""
return self._k
return len(self.weights)

def predict(self, x):
"""
Expand All @@ -238,17 +257,30 @@ def predictSoft(self, x):
:return: membership_matrix. RDD of array of double values.
"""
if isinstance(x, RDD):
means, sigmas = zip(*[(g.mu, g.sigma) for g in self._gaussians])
means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians])
membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector),
_convert_to_vector(self._weights), means, sigmas)
_convert_to_vector(self.weights), means, sigmas)
return membership_matrix.map(lambda x: pyarray.array('d', x))
else:
raise TypeError("x should be represented by an RDD, "
"but got %s." % type(x))

@classmethod
def load(cls, sc, path):
"""Load the GaussianMixtureModel from disk.
:param sc: SparkContext
:param path: str, path to where the model is stored.
"""
model = cls._load_java(sc, path)
wrapper = sc._jvm.GaussianMixtureModelWrapper(model)
return cls(wrapper)


class GaussianMixture(object):
"""
.. note:: Experimental
Learning algorithm for Gaussian Mixtures using the expectation-maximization algorithm.
:param data: RDD of data points
Expand All @@ -271,11 +303,10 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia
initialModelWeights = initialModel.weights
initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)]
initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)]
weight, mu, sigma = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector),
k, convergenceTol, maxIterations, seed,
initialModelWeights, initialModelMu, initialModelSigma)
mvg_obj = [MultivariateGaussian(mu[i], sigma[i]) for i in range(k)]
return GaussianMixtureModel(weight, mvg_obj)
java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector),
k, convergenceTol, maxIterations, seed,
initialModelWeights, initialModelMu, initialModelSigma)
return GaussianMixtureModel(java_model)


class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader):
Expand Down
6 changes: 6 additions & 0 deletions python/pyspark/mllib/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@

if sys.version > '3':
xrange = range
basestring = str

from pyspark import SparkContext
from pyspark.mllib.common import callMLlibFunc, inherit_doc
from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector

Expand Down Expand Up @@ -223,6 +225,10 @@ class JavaSaveable(Saveable):
"""

def save(self, sc, path):
if not isinstance(sc, SparkContext):
raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
if not isinstance(path, basestring):
raise TypeError("path should be a basestring, got type %s" % type(path))
self._java_model.save(sc._jsc.sc(), path)


Expand Down

0 comments on commit 9c305aa

Please sign in to comment.