Add kmeans initial seed to pyspark API
str-janus committed Dec 4, 2014
1 parent d005429 commit 35c1884
Showing 4 changed files with 33 additions and 6 deletions.
6 changes: 5 additions & 1 deletion mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -266,12 +266,16 @@ class PythonMLLibAPI extends Serializable {
       k: Int,
       maxIterations: Int,
       runs: Int,
-      initializationMode: String): KMeansModel = {
+      initializationMode: String,
+      seed: java.lang.Long): KMeansModel = {
     val kMeansAlg = new KMeans()
       .setK(k)
       .setMaxIterations(maxIterations)
       .setRuns(runs)
       .setInitializationMode(initializationMode)
+
+    if (seed != null) kMeansAlg.setSeed(seed)
+
     try {
       kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
     } finally {
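The seed crosses the bridge as a boxed java.lang.Long rather than a primitive so that Python's None can arrive as a Java null; the guard above then leaves the time-derived default untouched unless the caller actually supplied a seed. A minimal Python sketch of that optional-override pattern (all names here are hypothetical, not Spark APIs):

    class Builder(object):
        """Hypothetical stand-in for the Scala KMeans builder; not Spark code."""
        def __init__(self):
            self.seed = None  # unset: the Scala side would fall back to System.nanoTime()

        def set_seed(self, seed):
            self.seed = seed
            return self

    def configure(seed=None):
        alg = Builder()
        # Mirrors `if (seed != null) kMeansAlg.setSeed(seed)`: only override
        # the default when a seed was actually provided.
        if seed is not None:
            alg.set_seed(seed)
        return alg

    assert configure().seed is None
    assert configure(seed=42).seed == 42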
12 changes: 9 additions & 3 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -43,14 +43,20 @@ class KMeans private (
     private var runs: Int,
     private var initializationMode: String,
     private var initializationSteps: Int,
-    private var epsilon: Double) extends Serializable with Logging {
+    private var epsilon: Double,
+    private var seed: Long = System.nanoTime()) extends Serializable with Logging {
 
   /**
    * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1,
    * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4}.
    */
   def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4)
 
+  def setSeed(seed: Long): this.type = {
+    this.seed = seed
+    this
+  }
+
   /** Set the number of clusters to create (k). Default: 2. */
   def setK(k: Int): this.type = {
     this.k = k
@@ -255,7 +261,7 @@ class KMeans private (
   private def initRandom(data: RDD[VectorWithNorm])
   : Array[Array[VectorWithNorm]] = {
     // Sample all the cluster centers in one pass to avoid repeated scans
-    val sample = data.takeSample(true, runs * k, new XORShiftRandom().nextInt()).toSeq
+    val sample = data.takeSample(true, runs * k, new XORShiftRandom(this.seed).nextInt()).toSeq
     Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).map { v =>
       new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm)
     }.toArray)
@@ -273,7 +279,7 @@ class KMeans private (
   private def initKMeansParallel(data: RDD[VectorWithNorm])
   : Array[Array[VectorWithNorm]] = {
     // Initialize each run's center to a random point
-    val seed = new XORShiftRandom().nextInt()
+    val seed = new XORShiftRandom(this.seed).nextInt()
     val sample = data.takeSample(true, runs, seed).toSeq
     val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
 
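Both initializers now derive their sampling seed from the instance-level seed instead of a fresh XORShiftRandom(), so fixing the seed fixes the sampled initial centers across runs. A toy Python sketch of the idea (ToyKMeans is illustrative only, not Spark code):

    import random

    class ToyKMeans(object):
        """Illustrative miniature of the seeding pattern above; not Spark code."""
        def __init__(self, seed=None):
            # Fall back to a nondeterministic seed, like System.nanoTime() above.
            self.seed = seed if seed is not None else random.SystemRandom().getrandbits(31)

        def init_random(self, data, k):
            # Derive the sampling seed from self.seed, as the diff does with
            # new XORShiftRandom(this.seed).nextInt().
            return random.Random(self.seed).sample(data, k)

    data = list(range(100))
    run1 = ToyKMeans(seed=42).init_random(data, 3)
    run2 = ToyKMeans(seed=42).init_random(data, 3)
    assert run1 == run2  # same seed, same initial centers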
4 changes: 2 additions & 2 deletions python/pyspark/mllib/clustering.py
@@ -78,10 +78,10 @@ def predict(self, x):
 class KMeans(object):
 
     @classmethod
-    def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"):
+    def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||", seed=None):
         """Train a k-means clustering model."""
         model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations,
-                              runs, initializationMode)
+                              runs, initializationMode, seed)
         centers = callJavaFunc(rdd.context, model.clusterCenters)
         return KMeansModel([c.toArray() for c in centers])
 
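A minimal usage sketch of the new parameter (assuming a live SparkContext bound to sc): training twice with the same seed should yield identical centers, while the default seed=None keeps the previous nondeterministic behavior.

    from pyspark.mllib.clustering import KMeans

    points = sc.parallelize([[x, x] for x in range(0, 100, 10)])
    model_a = KMeans.train(points, 3, initializationMode="k-means||", seed=42)
    model_b = KMeans.train(points, 3, initializationMode="k-means||", seed=42)
    assert [c.tolist() for c in model_a.centers] == [c.tolist() for c in model_b.centers]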
17 changes: 17 additions & 0 deletions python/pyspark/mllib/tests.py
@@ -129,6 +129,23 @@ def test_clustering(self):
         self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1]))
         self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3]))
 
+    def test_clustering_deterministic(self):
+        from pyspark.mllib.clustering import KMeans
+        X = range(0, 100, 10)
+        Y = range(0, 100, 10)
+        data = [[x, y] for x, y in zip(X, Y)]
+        clusters1 = KMeans.train(self.sc.parallelize(data),
+                                 3, initializationMode="k-means||", seed=42)
+        clusters2 = KMeans.train(self.sc.parallelize(data),
+                                 3, initializationMode="k-means||", seed=42)
+        clusters3 = KMeans.train(self.sc.parallelize(data),
+                                 3, initializationMode="k-means||", seed=42)
+        centers1 = array(clusters1.centers).flatten().tolist()
+        centers2 = array(clusters2.centers).flatten().tolist()
+        centers3 = array(clusters3.centers).flatten().tolist()
+        self.assertListEqual(centers1, centers2)
+        self.assertListEqual(centers1, centers3)
+
     def test_classification(self):
         from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
         from pyspark.mllib.tree import DecisionTree
