From 1daee613e08e727f272d380b64466bbdc8faba02 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Wed, 30 Nov 2016 18:18:35 +0800 Subject: [PATCH 1/5] recreate pr --- .../apache/spark/ml/clustering/KMeans.scala | 26 ++++++++++++-- .../spark/mllib/clustering/KMeans.scala | 34 +++++++++++++++++-- .../spark/ml/clustering/KMeansSuite.scala | 8 +++++ 3 files changed, 63 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index e02b532ca8a93..d1841aa71f060 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -84,6 +84,20 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") def getInitSteps: Int = $(initSteps) + /** + * The fraction of the data to update centers per iteration. Must be > 0 and ≤ 1. + * Default: 1.0. + * @group param + */ + @Since("2.2.0") + final val miniBatchFraction = new DoubleParam(this, "k", "The fraction of the data to update" + + " clustering centers per iteration. Must be in (0, 1].", + ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) + + /** @group getParam */ + @Since("2.2.0") + def getMiniBatchFraction: Double = $(miniBatchFraction) + /** * Validates and transforms the input schema. * @param schema input schema @@ -260,7 +274,8 @@ class KMeans @Since("1.5.0") ( maxIter -> 20, initMode -> MLlibKMeans.K_MEANS_PARALLEL, initSteps -> 2, - tol -> 1e-4) + tol -> 1e-4, + miniBatchFraction -> 1.0) @Since("1.5.0") override def copy(extra: ParamMap): KMeans = defaultCopy(extra) @@ -300,6 +315,10 @@ class KMeans @Since("1.5.0") ( @Since("1.5.0") def setSeed(value: Long): this.type = set(seed, value) + /** @group setParam */ + @Since("2.2.0") + def setMiniBatchFraction(value: Double): this.type = set(miniBatchFraction, value) + @Since("2.0.0") override def fit(dataset: Dataset[_]): KMeansModel = { transformSchema(dataset.schema, logging = true) @@ -314,7 +333,9 @@ class KMeans @Since("1.5.0") ( } val instr = Instrumentation.create(this, instances) - instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol) + instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol, + miniBatchFraction) + val algo = new MLlibKMeans() .setK($(k)) .setInitializationMode($(initMode)) @@ -322,6 +343,7 @@ class KMeans @Since("1.5.0") ( .setMaxIterations($(maxIter)) .setSeed($(seed)) .setEpsilon($(tol)) + .setMiniBatchFraction($(miniBatchFraction)) val parentModel = algo.run(instances, Option(instr)) val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index fa72b72e2d921..2f5055cfc095b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -45,15 +45,17 @@ class KMeans private ( private var maxIterations: Int, private var initializationMode: String, private var initializationSteps: Int, + private var miniBatchFraction: Double, private var epsilon: Double, private var seed: Long) extends Serializable with Logging { /** * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, - * initializationMode: "k-means||", initializationSteps: 2, epsilon: 1e-4, seed: random}. + * initializationMode: "k-means||", initializationSteps: 2, miniBatchFraction: 1.0, + * epsilon: 1e-4, seed: random}. */ @Since("0.8.0") - def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong()) + def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1.0, 1e-4, Utils.random.nextLong()) /** * Number of clusters to create (k). @@ -151,6 +153,26 @@ class KMeans private ( this } + + /** + * The fraction of data to be used for each EM iteration. + */ + @Since("2.2.0") + private[spark] def getMiniBatchFraction: Double = miniBatchFraction + + /** + * :: Experimental :: + * Set fraction of data to be used for each EM iteration. + * Default 1.0 + */ + @Since("2.2.0") + private[spark] def setMiniBatchFraction(fraction: Double): this.type = { + require(fraction > 0 && fraction <= 1.0, + s"Fraction for mini-batch EM must be in range (0, 1] but got ${fraction}") + this.miniBatchFraction = fraction + this + } + /** * The distance threshold within which we've consider centers to have converged. */ @@ -272,8 +294,14 @@ class KMeans private ( val costAccum = sc.doubleAccumulator val bcCenters = sc.broadcast(centers) + val sampled = if (miniBatchFraction < 1) { + data.sample(false, miniBatchFraction, 42 + iteration) + } else { + data + } + // Find the sum and count of points mapping to each center - val totalContribs = data.mapPartitions { points => + val totalContribs = sampled.mapPartitions { points => val thisCenters = bcCenters.value val dims = thisCenters.head.vector.size diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 119fe1dead9a9..c6d325798ea1f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -144,6 +144,14 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(model.getPredictionCol == predictionColName) } + test("SPARK-14174: enable mini-batch EM") { + val model1 = new KMeans().setK(k).setSeed(1).fit(dataset) + val cost1 = model1.computeCost(dataset) + val model2 = new KMeans().setK(k).setMiniBatchFraction(0.9).setSeed(1).fit(dataset) + val cost2 = model2.computeCost(dataset) + require(cost1 === cost2) + } + test("read/write") { def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = { assert(model.clusterCenters === model2.clusterCenters) From 4e8dca69b3669af408c4454450585817a7708c57 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Wed, 30 Nov 2016 19:18:57 +0800 Subject: [PATCH 2/5] fix a bug --- .../main/scala/org/apache/spark/ml/clustering/KMeans.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index d1841aa71f060..35fa0799552a1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -90,8 +90,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe * @group param */ @Since("2.2.0") - final val miniBatchFraction = new DoubleParam(this, "k", "The fraction of the data to update" + - " clustering centers per iteration. Must be in (0, 1].", + final val miniBatchFraction = new DoubleParam(this, "miniBatchFraction", "The fraction of the" + + " data to update clustering centers per iteration. Must be in (0, 1].", ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) /** @group getParam */ From 93f6f7f913cba69defda4d0eeaddcba234abd5a6 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Tue, 23 May 2017 16:29:26 +0800 Subject: [PATCH 3/5] update version --- .../main/scala/org/apache/spark/ml/clustering/KMeans.scala | 6 +++--- .../scala/org/apache/spark/mllib/clustering/KMeans.scala | 2 -- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 35fa0799552a1..fca185ca31e99 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -89,13 +89,13 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe * Default: 1.0. * @group param */ - @Since("2.2.0") + @Since("2.3.0") final val miniBatchFraction = new DoubleParam(this, "miniBatchFraction", "The fraction of the" + " data to update clustering centers per iteration. Must be in (0, 1].", ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) /** @group getParam */ - @Since("2.2.0") + @Since("2.3.0") def getMiniBatchFraction: Double = $(miniBatchFraction) /** @@ -316,7 +316,7 @@ class KMeans @Since("1.5.0") ( def setSeed(value: Long): this.type = set(seed, value) /** @group setParam */ - @Since("2.2.0") + @Since("2.3.0") def setMiniBatchFraction(value: Double): this.type = set(miniBatchFraction, value) @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 2f5055cfc095b..59ba886ad37d4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -157,7 +157,6 @@ class KMeans private ( /** * The fraction of data to be used for each EM iteration. */ - @Since("2.2.0") private[spark] def getMiniBatchFraction: Double = miniBatchFraction /** @@ -165,7 +164,6 @@ class KMeans private ( * Set fraction of data to be used for each EM iteration. * Default 1.0 */ - @Since("2.2.0") private[spark] def setMiniBatchFraction(fraction: Double): this.type = { require(fraction > 0 && fraction <= 1.0, s"Fraction for mini-batch EM must be in range (0, 1] but got ${fraction}") From 46d9c7bbbfdd663a19212c2ddc7431ccd6293022 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Wed, 24 May 2017 11:33:52 +0800 Subject: [PATCH 4/5] update test --- .../scala/org/apache/spark/ml/clustering/KMeansSuite.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index c6d325798ea1f..74288e1103bbb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -50,6 +50,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL) assert(kmeans.getInitSteps === 2) assert(kmeans.getTol === 1e-4) + assert(kmeans.getMiniBatchFraction === 1.0) val model = kmeans.setMaxIter(1).fit(dataset) MLTestingUtils.checkCopyAndUids(kmeans, model) @@ -68,6 +69,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR .setInitSteps(3) .setSeed(123) .setTol(1e-3) + .setMiniBatchFraction(0.5) assert(kmeans.getK === 9) assert(kmeans.getFeaturesCol === "test_feature") @@ -77,6 +79,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(kmeans.getInitSteps === 3) assert(kmeans.getSeed === 123) assert(kmeans.getTol === 1e-3) + assert(kmeans.getMiniBatchFraction === 0.5) } test("parameters validation") { @@ -89,6 +92,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR intercept[IllegalArgumentException] { new KMeans().setInitSteps(0) } + intercept[IllegalArgumentException] { + new KMeans().setMiniBatchFraction(0) + } } test("fit, transform and summary") { From a5ce8caad5a81c2e5c6789402748a0b4f15b2fbc Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Fri, 26 May 2017 10:15:15 +0800 Subject: [PATCH 5/5] fix some nits --- .../main/scala/org/apache/spark/ml/clustering/KMeans.scala | 6 +++--- .../scala/org/apache/spark/ml/clustering/KMeansSuite.scala | 6 ++++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index fca185ca31e99..76179c21d2d16 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -85,13 +85,13 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe def getInitSteps: Int = $(initSteps) /** - * The fraction of the data to update centers per iteration. Must be > 0 and ≤ 1. + * The fraction of data used to update centers per iteration. Must be > 0 and ≤ 1. * Default: 1.0. * @group param */ @Since("2.3.0") - final val miniBatchFraction = new DoubleParam(this, "miniBatchFraction", "The fraction of the" + - " data to update clustering centers per iteration. Must be in (0, 1].", + final val miniBatchFraction = new DoubleParam(this, "miniBatchFraction", "The fraction of" + + " data used to update cluster centers per iteration. Must be in (0, 1].", ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) /** @group getParam */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 74288e1103bbb..c350e464fd065 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -95,6 +95,12 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR intercept[IllegalArgumentException] { new KMeans().setMiniBatchFraction(0) } + intercept[IllegalArgumentException] { + new KMeans().setMiniBatchFraction(-0.01) + } + intercept[IllegalArgumentException] { + new KMeans().setMiniBatchFraction(1.01) + } } test("fit, transform and summary") {