From 6b596dcd6a8fe6a6574e4fdda80fdf424c7d8777 Mon Sep 17 00:00:00 2001 From: Zakaria_Hili Date: Tue, 22 Nov 2016 10:33:11 +0100 Subject: [PATCH 1/5] [SPARK-18356] [ML] Improve MLKmeans Performance --- .../apache/spark/ml/clustering/KMeans.scala | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 26505b4cc1501..dcd38ddaa21e1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -33,6 +33,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.VersionUtils.majorVersion /** @@ -306,12 +307,19 @@ class KMeans @Since("1.5.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): KMeansModel = { + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + fit(dataset, handlePersistence) + } + @Since("2.0.0") + protected def fit(dataset: Dataset[_], handlePersistence: Boolean): KMeansModel = { transformSchema(dataset.schema, logging = true) - val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { + val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => OldVectors.fromML(point) } - - val instr = Instrumentation.create(this, rdd) + if (handlePersistence) { + instances.persist(StorageLevel.MEMORY_AND_DISK) + } + val instr = Instrumentation.create(this, instances) instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol) val algo = new MLlibKMeans() @@ -321,12 +329,15 @@ class KMeans @Since("1.5.0") ( .setMaxIterations($(maxIter)) .setSeed($(seed)) .setEpsilon($(tol)) - val parentModel = algo.run(rdd, Option(instr)) + val parentModel = algo.run(instances, Option(instr)) val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(summary) instr.logSuccess(model) + if (handlePersistence) { + instances.unpersist() + } model } From d49da761d3d2ec5280b2104dd71f05c917329539 Mon Sep 17 00:00:00 2001 From: Zakaria_Hili Date: Tue, 22 Nov 2016 11:02:24 +0100 Subject: [PATCH 2/5] [SPARK-18356] [ML] Improve MLKmeans Performance --- .../src/main/scala/org/apache/spark/ml/clustering/KMeans.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index dcd38ddaa21e1..0d4e8891ebd78 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -310,7 +310,6 @@ class KMeans @Since("1.5.0") ( val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE fit(dataset, handlePersistence) } - @Since("2.0.0") protected def fit(dataset: Dataset[_], handlePersistence: Boolean): KMeansModel = { transformSchema(dataset.schema, logging = true) val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { @@ -336,7 +335,7 @@ class KMeans @Since("1.5.0") ( model.setSummary(summary) instr.logSuccess(model) if (handlePersistence) { - instances.unpersist() + instances.unpersist() } model } From ce596e859aee517346ed3871ac4434c6e96465e9 Mon Sep 17 00:00:00 2001 From: Zakaria_Hili Date: Tue, 22 Nov 2016 14:58:40 +0100 Subject: [PATCH 3/5] [SPARK-18356] [ML] Improve MLKmeans Performance --- .../main/scala/org/apache/spark/ml/clustering/KMeans.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 0d4e8891ebd78..f45cbb9fe6554 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -305,11 +305,12 @@ class KMeans @Since("1.5.0") ( @Since("1.5.0") def setSeed(value: Long): this.type = set(seed, value) - @Since("2.0.0") + @Since("2.2.0") override def fit(dataset: Dataset[_]): KMeansModel = { val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE fit(dataset, handlePersistence) - } +} + @Since("2.2.0") protected def fit(dataset: Dataset[_], handlePersistence: Boolean): KMeansModel = { transformSchema(dataset.schema, logging = true) val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { From fd4543d918b26b451000fe92fbe1cdb259a6687b Mon Sep 17 00:00:00 2001 From: Zakaria_Hili Date: Tue, 22 Nov 2016 16:07:37 +0100 Subject: [PATCH 4/5] [SPARK-18356] [ML] Improve MLKmeans Performance --- .../main/scala/org/apache/spark/ml/clustering/KMeans.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index f45cbb9fe6554..24a5e58c312fc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -305,11 +305,12 @@ class KMeans @Since("1.5.0") ( @Since("1.5.0") def setSeed(value: Long): this.type = set(seed, value) - @Since("2.2.0") + @Since("2.0.0") override def fit(dataset: Dataset[_]): KMeansModel = { val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE fit(dataset, handlePersistence) -} + } + @Since("2.2.0") protected def fit(dataset: Dataset[_], handlePersistence: Boolean): KMeansModel = { transformSchema(dataset.schema, logging = true) From f17a54c34e9aa007f36426c2b8459ab0bf784863 Mon Sep 17 00:00:00 2001 From: HILI Zakaria Date: Thu, 24 Nov 2016 12:59:32 +0100 Subject: [PATCH 5/5] Update KMeans.scala --- .../src/main/scala/org/apache/spark/ml/clustering/KMeans.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 24a5e58c312fc..e17e0f630120e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -310,7 +310,7 @@ class KMeans @Since("1.5.0") ( val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE fit(dataset, handlePersistence) } - + @Since("2.2.0") protected def fit(dataset: Dataset[_], handlePersistence: Boolean): KMeansModel = { transformSchema(dataset.schema, logging = true)