Skip to content

Commit

Permalink
Python API for Gaussian Mixture Model
Browse files Browse the repository at this point in the history
  • Loading branch information
FlytxtRnD committed Jan 15, 2015
1 parent 4b325c7 commit fda60f3
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 2 deletions.
50 changes: 50 additions & 0 deletions examples/src/main/python/mllib/gaussian_mixture_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
A Gaussian Mixture Model clustering program using MLlib.
This example requires NumPy (http://www.numpy.org/).
"""

import sys

import numpy as np
from pyspark import SparkContext
from pyspark.mllib.clustering import GaussianMixtureEM


def parseVector(line):
    """Parse one line of delimited numbers into a NumPy array of floats.

    Fields may be separated by commas or, when the line contains no comma,
    by runs of whitespace.

    Raises ValueError if a field is not a valid float or the line is blank.
    """
    # Lines containing a comma are split exactly as before; otherwise
    # sep=None makes str.split break on any run of whitespace.
    sep = ',' if ',' in line else None
    fields = line.split(sep)
    if not fields:
        # Only reachable for a blank whitespace-delimited line; fail the same
        # way float('') would instead of returning an empty vector.
        raise ValueError("could not parse a vector from a blank line")
    return np.array([float(x) for x in fields])


if __name__ == "__main__":
    # Expect exactly three arguments after the script name.
    if len(sys.argv) != 4:
        sys.stderr.write(
            "Usage: gaussian_mixture_model <input_file> <k> <convergenceTol> \n")
        # sys.exit instead of the bare exit() helper, which is only injected
        # by the site module and is meant for interactive use.
        sys.exit(-1)

    # Parse the numeric arguments first so bad input fails before a
    # SparkContext is started.
    k = int(sys.argv[2])
    convergenceTol = float(sys.argv[3])

    sc = SparkContext(appName="GMM")
    try:
        lines = sc.textFile(sys.argv[1])
        data = lines.map(parseVector)
        model = GaussianMixtureEM.train(data, k, convergenceTol)
        for i in range(k):
            # Printed as one tuple so the output is the same whether print is
            # a statement (Python 2) or a function (Python 3).
            print(("weight = ", model.weight[i], "mu = ", model.mu[i],
                   "sigma = ", model.sigma[i].toArray()))
        print(model.predictLabels(data).collect())
    finally:
        # Stop the context even when loading, training or prediction fails.
        sc.stop()
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,46 @@ class PythonMLLibAPI extends Serializable {
}
}

/**
 * Java stub for Python mllib GaussianMixtureEM.train().
 *
 * Builds a GaussianMixtureEM configured with `k`, `convergenceTol` and
 * `maxIterations`, runs it on the input persisted at MEMORY_AND_DISK, and
 * returns the resulting model's weight, mu and sigma (in that order) as a
 * Java list. The input RDD is unpersisted afterwards whether or not the run
 * succeeded.
 */
def trainGaussianMixtureEM(
    data: JavaRDD[Vector],
    k: Int,
    convergenceTol: Double,
    maxIterations: Int): JList[Object] = {
  val estimator = new GaussianMixtureEM()
    .setK(k)
    .setConvergenceTol(convergenceTol)
    .setMaxIterations(maxIterations)
  try {
    val cachedInput = data.rdd.persist(StorageLevel.MEMORY_AND_DISK)
    val fitted = estimator.run(cachedInput)
    // Each component is exposed as a plain Object in the returned Java list.
    val components = List(fitted.weight, fitted.mu, fitted.sigma)
    components.map(component => component.asInstanceOf[Object]).asJava
  } finally {
    data.rdd.unpersist(blocking = false)
  }
}

/**
 * Java stub for Python mllib GaussianMixtureModel.predictSoft()
 *
 * Rebuilds a GaussianMixtureModel from the weight, mean and sigma values
 * passed from Python (cast to Array[Double], Vector and Matrix respectively)
 * and returns the result of its predictSoft on `data`.
 */
def findPredict(
    data: JavaRDD[Vector],
    wt: Object,
    mu: Array[Object],
    si: Array[Object]): RDD[Array[Double]] = {
  val weight = wt.asInstanceOf[Array[Double]]
  val mean = mu.map(_.asInstanceOf[Vector])
  val sigma = si.map(_.asInstanceOf[Matrix])
  val model = new GaussianMixtureModel(weight, mean, sigma)
  // The input is deliberately not persisted here: the RDD returned below is
  // only evaluated once the caller runs an action on it, by which point a
  // persist()/unpersist() pair inside this method has already been undone.
  // The unpersist would also discard any caching the caller had set up on
  // the same RDD.
  model.predictSoft(data.rdd)
}

/**
* A wrapper of MatrixFactorizationModel to provide helper methods for Python
*/
Expand Down
56 changes: 54 additions & 2 deletions python/pyspark/mllib/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

from pyspark import SparkContext
from pyspark.mllib.common import callMLlibFunc, callJavaFunc
from pyspark.mllib.linalg import SparseVector, _convert_to_vector
from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector

__all__ = ['KMeansModel', 'KMeans']
__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixtureEM']


class KMeansModel(object):
Expand Down Expand Up @@ -86,6 +86,58 @@ def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"
return KMeansModel([c.toArray() for c in centers])


class GaussianMixtureModel(object):

    """Clustering model produced by the Gaussian Mixture Model method.

    Stores the ``weight``, ``mu`` and ``sigma`` values handed to the
    constructor and uses them to assign points to clusters.

    >>> from numpy import array
    >>> clusterdata_1 = sc.parallelize(array([6.0, 9.0,5.0, 10.0,4.0, 11.0]).reshape(3,2))
    >>> model = GaussianMixtureEM.train(clusterdata_1, 1, 0.0001, maxIterations=10)
    >>> labels = model.predictLabels(clusterdata_1).collect()
    >>> labels[0]==labels[1]==labels[2]
    True
    >>> clusterdata_2 = sc.parallelize(array([-1,-5,-3,-9,-4,-6,9,5,4,3,11,4]).reshape(6,2))
    >>> model = GaussianMixtureEM.train(clusterdata_2, 2, 0.0001, maxIterations=10)
    >>> labels = model.predictLabels(clusterdata_2).collect()
    >>> labels[0]==labels[1]==labels[2]
    True
    >>> labels[3]==labels[4]==labels[5]
    True
    """

    def __init__(self, weight, mu, sigma):
        # Keep the three model components exactly as supplied.
        self.weight, self.mu, self.sigma = weight, mu, sigma

    def predictLabels(self, X):
        """
        Return, for every point in X, the index of the cluster in which the
        point has its largest membership under this model.
        """
        def _best_cluster(memberships):
            # Position of the first occurrence of the maximum membership.
            return memberships.index(max(memberships))

        return self.predictSoft(X).map(_best_cluster)

    def predictSoft(self, X):
        """
        Return the membership of each point in X in every cluster of this
        model.
        """
        vectors = X.map(_convert_to_vector)
        return callMLlibFunc("findPredict", vectors,
                             self.weight, self.mu, self.sigma)


class GaussianMixtureEM(object):

    """Entry point for training a GaussianMixtureModel."""

    @classmethod
    def train(cls, rdd, k, convergenceTol, maxIterations=100):
        """
        Train a Gaussian Mixture clustering model.

        Every element of ``rdd`` is passed through ``_convert_to_vector`` and
        the result is sent, with ``k``, ``convergenceTol`` and
        ``maxIterations``, to the "trainGaussianMixtureEM" MLlib function.
        The three values it returns are wrapped in a GaussianMixtureModel.
        """
        vectors = rdd.map(_convert_to_vector)
        trained = callMLlibFunc("trainGaussianMixtureEM", vectors, k,
                                convergenceTol, maxIterations)
        # The call is expected to return exactly (weight, mu, sigma).
        weight, mu, sigma = trained
        return GaussianMixtureModel(weight, mu, sigma)


def _test():
import doctest
globs = globals().copy()
Expand Down

0 comments on commit fda60f3

Please sign in to comment.