From 198d181dfb2c04102afe40680a4637d951e92c0b Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 28 Jul 2015 15:00:25 -0700 Subject: [PATCH] [SPARK-7105] [PYSPARK] [MLLIB] Support model save/load in GMM This PR introduces save / load for GMM's in python API. Also I refactored `GaussianMixtureModel` and inherited it from `JavaModelWrapper` with model being `GaussianMixtureModelWrapper`, a wrapper which provides convenience methods to `GaussianMixtureModel` (due to serialization and deserialization issues) and I moved the creation of gaussians to the scala backend. Author: MechCoder Closes #7617 from MechCoder/python_gmm_save_load and squashes the following commits: 9c305aa [MechCoder] [SPARK-7105] [PySpark] [MLlib] Support model save/load in GMM --- .../python/GaussianMixtureModelWrapper.scala | 53 +++++++++++++ .../mllib/api/python/PythonMLLibAPI.scala | 13 +--- python/pyspark/mllib/clustering.py | 75 +++++++++++++------ python/pyspark/mllib/util.py | 6 ++ 4 files changed, 114 insertions(+), 33 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala new file mode 100644 index 0000000000000..0ec88ef77d695 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.api.python + +import java.util.{List => JList} + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrix} +import org.apache.spark.mllib.clustering.GaussianMixtureModel + +/** + * Wrapper around GaussianMixtureModel to provide helper methods in Python + */ +private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) { + val weights: Vector = Vectors.dense(model.weights) + val k: Int = weights.size + + /** + * Returns gaussians as a List of Vectors and Matrices corresponding each MultivariateGaussian + */ + val gaussians: JList[Object] = { + val modelGaussians = model.gaussians + var i = 0 + var mu = ArrayBuffer.empty[Vector] + var sigma = ArrayBuffer.empty[Matrix] + while (i < k) { + mu += modelGaussians(i).mu + sigma += modelGaussians(i).sigma + i += 1 + } + List(mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava + } + + def save(sc: SparkContext, path: String): Unit = model.save(sc, path) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index fda8d5a0b048f..6f080d32bbf4d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -364,7 +364,7 @@ private[python] class PythonMLLibAPI extends Serializable { seed: java.lang.Long, initialModelWeights: java.util.ArrayList[Double], initialModelMu: java.util.ArrayList[Vector], - initialModelSigma: java.util.ArrayList[Matrix]): JList[Object] = { + initialModelSigma: java.util.ArrayList[Matrix]): GaussianMixtureModelWrapper = { val gmmAlg = new GaussianMixture() .setK(k) .setConvergenceTol(convergenceTol) @@ -382,16 +382,7 @@ private[python] class PythonMLLibAPI extends Serializable { if (seed != null) gmmAlg.setSeed(seed) try { - val model = gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)) - var wt = ArrayBuffer.empty[Double] - var mu = ArrayBuffer.empty[Vector] - var sigma = ArrayBuffer.empty[Matrix] - for (i <- 0 until model.k) { - wt += model.weights(i) - mu += model.gaussians(i).mu - sigma += model.gaussians(i).sigma - } - List(Vectors.dense(wt.toArray), mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava + new GaussianMixtureModelWrapper(gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))) } finally { data.rdd.unpersist(blocking = false) } diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 58ad99d46e23b..900ade248c386 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -152,11 +152,19 @@ def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||" return KMeansModel([c.toArray() for c in centers]) -class GaussianMixtureModel(object): +@inherit_doc +class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader): + + """ + .. note:: Experimental - """A clustering model derived from the Gaussian Mixture Model method. + A clustering model derived from the Gaussian Mixture Model method. >>> from pyspark.mllib.linalg import Vectors, DenseMatrix + >>> from numpy.testing import assert_equal + >>> from shutil import rmtree + >>> import os, tempfile + >>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1, ... 0.9,0.8,0.75,0.935, ... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2)) @@ -169,6 +177,25 @@ class GaussianMixtureModel(object): True >>> labels[4]==labels[5] True + + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = GaussianMixtureModel.load(sc, path) + >>> assert_equal(model.weights, sameModel.weights) + >>> mus, sigmas = list( + ... zip(*[(g.mu, g.sigma) for g in model.gaussians])) + >>> sameMus, sameSigmas = list( + ... zip(*[(g.mu, g.sigma) for g in sameModel.gaussians])) + >>> mus == sameMus + True + >>> sigmas == sameSigmas + True + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass + >>> data = array([-5.1971, -2.5359, -3.8220, ... -5.2211, -5.0602, 4.7118, ... 6.8989, 3.4592, 4.6322, @@ -182,25 +209,15 @@ class GaussianMixtureModel(object): True >>> labels[3]==labels[4] True - >>> clusterdata_3 = sc.parallelize(data.reshape(15, 1)) - >>> im = GaussianMixtureModel([0.5, 0.5], - ... [MultivariateGaussian(Vectors.dense([-1.0]), DenseMatrix(1, 1, [1.0])), - ... MultivariateGaussian(Vectors.dense([1.0]), DenseMatrix(1, 1, [1.0]))]) - >>> model = GaussianMixture.train(clusterdata_3, 2, initialModel=im) """ - def __init__(self, weights, gaussians): - self._weights = weights - self._gaussians = gaussians - self._k = len(self._weights) - @property def weights(self): """ Weights for each Gaussian distribution in the mixture, where weights[i] is the weight for Gaussian i, and weights.sum == 1. """ - return self._weights + return array(self.call("weights")) @property def gaussians(self): @@ -208,12 +225,14 @@ def gaussians(self): Array of MultivariateGaussian where gaussians[i] represents the Multivariate Gaussian (Normal) Distribution for Gaussian i. """ - return self._gaussians + return [ + MultivariateGaussian(gaussian[0], gaussian[1]) + for gaussian in zip(*self.call("gaussians"))] @property def k(self): """Number of gaussians in mixture.""" - return self._k + return len(self.weights) def predict(self, x): """ @@ -238,17 +257,30 @@ def predictSoft(self, x): :return: membership_matrix. RDD of array of double values. """ if isinstance(x, RDD): - means, sigmas = zip(*[(g.mu, g.sigma) for g in self._gaussians]) + means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians]) membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector), - _convert_to_vector(self._weights), means, sigmas) + _convert_to_vector(self.weights), means, sigmas) return membership_matrix.map(lambda x: pyarray.array('d', x)) else: raise TypeError("x should be represented by an RDD, " "but got %s." % type(x)) + @classmethod + def load(cls, sc, path): + """Load the GaussianMixtureModel from disk. + + :param sc: SparkContext + :param path: str, path to where the model is stored. + """ + model = cls._load_java(sc, path) + wrapper = sc._jvm.GaussianMixtureModelWrapper(model) + return cls(wrapper) + class GaussianMixture(object): """ + .. note:: Experimental + Learning algorithm for Gaussian Mixtures using the expectation-maximization algorithm. :param data: RDD of data points @@ -271,11 +303,10 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia initialModelWeights = initialModel.weights initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)] initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)] - weight, mu, sigma = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector), - k, convergenceTol, maxIterations, seed, - initialModelWeights, initialModelMu, initialModelSigma) - mvg_obj = [MultivariateGaussian(mu[i], sigma[i]) for i in range(k)] - return GaussianMixtureModel(weight, mvg_obj) + java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector), + k, convergenceTol, maxIterations, seed, + initialModelWeights, initialModelMu, initialModelSigma) + return GaussianMixtureModel(java_model) class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader): diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 875d3b2d642c6..916de2d6fcdbd 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -21,7 +21,9 @@ if sys.version > '3': xrange = range + basestring = str +from pyspark import SparkContext from pyspark.mllib.common import callMLlibFunc, inherit_doc from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector @@ -223,6 +225,10 @@ class JavaSaveable(Saveable): """ def save(self, sc, path): + if not isinstance(sc, SparkContext): + raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) + if not isinstance(path, basestring): + raise TypeError("path should be a basestring, got type %s" % type(path)) self._java_model.save(sc._jsc.sc(), path)