Skip to content

Commit

Permalink
[SPARK-12006][ML][PYTHON] Fix GMM failure if initialModel is not None
Browse files Browse the repository at this point in the history
If initial model passed to GMM is not empty it causes net.razorvine.pickle.PickleException. It can be fixed by converting initialModel.weights to list.

Author: zero323 <matthew.szymkiewicz@gmail.com>

Closes #10644 from zero323/SPARK-12006.
  • Loading branch information
zero323 authored and jkbradley committed Jan 7, 2016
1 parent 8113dbd commit 592f649
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/pyspark/mllib/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia
if initialModel.k != k:
raise Exception("Mismatched cluster count, initialModel.k = %s, however k = %s"
% (initialModel.k, k))
initialModelWeights = initialModel.weights
initialModelWeights = list(initialModel.weights)
initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)]
initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)]
java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector),
Expand Down
12 changes: 12 additions & 0 deletions python/pyspark/mllib/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,18 @@ def test_gmm_deterministic(self):
for c1, c2 in zip(clusters1.weights, clusters2.weights):
self.assertEqual(round(c1, 7), round(c2, 7))

def test_gmm_with_initial_model(self):
from pyspark.mllib.clustering import GaussianMixture
data = self.sc.parallelize([
(-10, -5), (-9, -4), (10, 5), (9, 4)
])

gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001,
maxIterations=10, seed=63)
gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001,
maxIterations=10, seed=63, initialModel=gmm1)
self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0)

def test_classification(self):
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\
Expand Down

0 comments on commit 592f649

Please sign in to comment.