From bbb8058a036a858df76cf0dc83933a65a186b484 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Sat, 27 May 2017 14:10:24 +0800 Subject: [PATCH 1/8] create pr --- .../spark/ml/clustering/MiniBatchKMeans.scala | 490 ++++++++++++++++++ .../spark/mllib/clustering/KMeans.scala | 12 +- .../ml/clustering/MiniBatchKMeansSuite.scala | 178 +++++++ 3 files changed, 674 insertions(+), 6 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/clustering/MiniBatchKMeansSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala new file mode 100644 index 0000000000000..38cf9bdfcee1f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala @@ -0,0 +1,490 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.hadoop.fs.Path + +import org.apache.spark.Partitioner +import org.apache.spark.SparkException +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, VectorWithNorm} +import org.apache.spark.mllib.linalg.BLAS.{axpy, scal} +import org.apache.spark.mllib.linalg.VectorImplicits._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.storage.StorageLevel + +/** + * Common params for MiniBatchKMeans and MiniBatchKMeansModel + */ +private[clustering] trait MiniBatchKMeansParams extends Params with HasMaxIter with HasFeaturesCol + with HasSeed with HasPredictionCol with HasTol { + + /** + * The number of clusters to create (k). Must be > 1. Note that it is possible for fewer than + * k clusters to be returned, for example, if there are fewer than k distinct points to cluster. + * Default: 2. + * @group param + */ + @Since("2.3.0") + final val k = new IntParam(this, "k", "The number of clusters to create. " + + "Must be > 1.", ParamValidators.gt(1)) + + /** @group getParam */ + @Since("2.3.0") + def getK: Int = $(k) + + /** + * Param for the initialization algorithm. This can be either "random" to choose random points as + * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ + * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||. + * @group expertParam + */ + @Since("2.3.0") + final val initMode = new Param[String](this, "initMode", "The initialization algorithm. " + + "Supported options: 'random' and 'k-means||'.", + (value: String) => MLlibKMeans.validateInitMode(value)) + + /** @group expertGetParam */ + @Since("2.3.0") + def getInitMode: String = $(initMode) + + /** + * Param for the number of steps for the k-means|| initialization mode. This is an advanced + * setting -- the default of 2 is almost always enough. Must be > 0. Default: 2. + * @group expertParam + */ + @Since("2.3.0") + final val initSteps = new IntParam(this, "initSteps", "The number of steps for k-means|| " + + "initialization mode. Must be > 0.", ParamValidators.gt(0)) + + /** @group expertGetParam */ + @Since("2.3.0") + def getInitSteps: Int = $(initSteps) + + /** + * The fraction of data used to update centers per iteration. Must be > 0 and ≤ 1. + * Default: 1.0. + * @group param + */ + @Since("2.3.0") + final val fraction = new DoubleParam(this, "fraction", "The fraction of data used to " + + "update cluster centers per iteration. Must be in (0, 1].", + ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) + + /** @group getParam */ + @Since("2.3.0") + def getFraction: Double = $(fraction) + + /** + * Validates and transforms the input schema. + * @param schema input schema + * @return output schema + */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) + } +} + +/** + * Model fitted by MiniBatchKMeans. + * + * @param clusterCenters Centers of each cluster. + */ +@Since("2.3.0") +class MiniBatchKMeansModel private[ml] ( + @Since("2.3.0") override val uid: String, + @Since("2.3.0") val clusterCenters: Array[Vector]) + extends Model[MiniBatchKMeansModel] with MiniBatchKMeansParams with MLWritable { + + private lazy val clusterCentersWithNorm = + if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_)) + + @Since("2.3.0") + override def copy(extra: ParamMap): MiniBatchKMeansModel = { + val copied = copyValues(new MiniBatchKMeansModel(uid, clusterCenters), extra) + copied.setSummary(trainingSummary).setParent(this.parent) + } + + /** @group setParam */ + @Since("2.3.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("2.3.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + @Since("2.3.0") + override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) + val predictUDF = udf((vector: Vector) => predict(vector)) + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + + @Since("2.3.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + private[clustering] def predict(features: Vector): Int = + MLlibKMeans.findClosest(clusterCentersWithNorm, new VectorWithNorm(features))._1 + + /** + * Return the K-means cost (sum of squared distances of points to their nearest center) for this + * model on the given data. + */ + // TODO: Replace the temp fix when we have proper evaluators defined for clustering. + @Since("2.3.0") + def computeCost(dataset: Dataset[_]): Double = { + SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) + val bcCentersWithNorm = dataset.sparkSession.sparkContext.broadcast(clusterCentersWithNorm) + val cost = dataset.select(col($(featuresCol))).rdd.map { + case Row(point: Vector) => + MLlibKMeans.pointCost(bcCentersWithNorm.value, new VectorWithNorm(point)) + }.sum() + bcCentersWithNorm.destroy(blocking = false) + cost + } + + /** + * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance. + * + * For [[MiniBatchKMeansModel]], this does NOT currently save the training [[summary]]. + * An option to save [[summary]] may be added in the future. + * + */ + @Since("2.3.0") + override def write: MLWriter = new MiniBatchKMeansModel.MiniBatchKMeansModelWriter(this) + + private var trainingSummary: Option[MiniBatchKMeansSummary] = None + + private[clustering] def setSummary(summary: Option[MiniBatchKMeansSummary]): this.type = { + this.trainingSummary = summary + this + } + + /** + * Return true if there exists summary of model. + */ + @Since("2.3.0") + def hasSummary: Boolean = trainingSummary.nonEmpty + + /** + * Gets summary of model on training set. An exception is + * thrown if `trainingSummary == None`. + */ + @Since("2.3.0") + def summary: MiniBatchKMeansSummary = trainingSummary.getOrElse { + throw new SparkException( + s"No training summary available for the ${this.getClass.getSimpleName}") + } +} + +@Since("2.3.0") +object MiniBatchKMeansModel extends MLReadable[MiniBatchKMeansModel] { + + @Since("2.3.0") + override def read: MLReader[MiniBatchKMeansModel] = new MiniBatchKMeansModelReader + + @Since("2.3.0") + override def load(path: String): MiniBatchKMeansModel = super.load(path) + + /** Helper class for storing model data */ + private case class Data(clusterIdx: Int, clusterCenter: Vector) + + /** [[MLWriter]] instance for [[MiniBatchKMeansModel]] */ + private[MiniBatchKMeansModel] class MiniBatchKMeansModelWriter(instance: MiniBatchKMeansModel) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: cluster centers + val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) => + Data(idx, center) + } + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath) + } + } + + private class MiniBatchKMeansModelReader extends MLReader[MiniBatchKMeansModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[MiniBatchKMeansModel].getName + + override def load(path: String): MiniBatchKMeansModel = { + // Import implicits for Dataset Encoder + val sparkSession = super.sparkSession + import sparkSession.implicits._ + + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + + val clusterCenters = { + val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data] + data.collect().sortBy(_.clusterIdx).map(_.clusterCenter) + } + + val model = new MiniBatchKMeansModel(metadata.uid, clusterCenters) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } +} + +/** + * MiniBatch K-means clustering proposed by Sculley. + * + * @see Sculley, Web-Scale + * K-Means Clustering. + */ +@Since("2.3.0") +class MiniBatchKMeans @Since("2.3.0") ( + @Since("2.3.0") override val uid: String) + extends Estimator[MiniBatchKMeansModel] with MiniBatchKMeansParams + with DefaultParamsWritable { + + setDefault( + k -> 2, + maxIter -> 20, + initMode -> MLlibKMeans.K_MEANS_PARALLEL, + initSteps -> 2, + tol -> 1e-4, + fraction -> 1.0) + + @Since("2.3.0") + override def copy(extra: ParamMap): MiniBatchKMeans = defaultCopy(extra) + + @Since("2.3.0") + def this() = this(Identifiable.randomUID("minibatch-kmeans")) + + /** @group setParam */ + @Since("2.3.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("2.3.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + @Since("2.3.0") + def setK(value: Int): this.type = set(k, value) + + /** @group expertSetParam */ + @Since("2.3.0") + def setInitMode(value: String): this.type = set(initMode, value) + + /** @group expertSetParam */ + @Since("2.3.0") + def setInitSteps(value: Int): this.type = set(initSteps, value) + + /** @group setParam */ + @Since("2.3.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + @Since("2.3.0") + def setTol(value: Double): this.type = set(tol, value) + + /** @group setParam */ + @Since("2.3.0") + def setFraction(value: Double): this.type = set(fraction, value) + + /** @group setParam */ + @Since("2.3.0") + def setSeed(value: Long): this.type = set(seed, value) + + @Since("2.3.0") + override def fit(dataset: Dataset[_]): MiniBatchKMeansModel = { + transformSchema(dataset.schema, logging = true) + + val data = dataset.select(col($(featuresCol))).rdd.map { + case Row(point: Vector) => + new VectorWithNorm(point) + } + data.persist(StorageLevel.MEMORY_AND_DISK) + + val instr = Instrumentation.create(this, dataset) + instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol, + fraction) + + val initStartTime = System.nanoTime() + val centers: Array[VectorWithNorm] = initCenters(data) + val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9 + logInfo(f"Initialization with ${$(initMode)} took $initTimeInSeconds%.3f seconds.") + + val sc = dataset.sparkSession.sparkContext + + val iterationStartTime = System.nanoTime() + + val numFeatures = centers.head.vector.size + instr.logNumFeatures(numFeatures) + val numCenters = centers.length + val counts = Array.ofDim[Long](numCenters) + + var converged = false + var iteration = 0 + + // Execute iterations of Sculley's algorithm until converged + while (iteration < $(maxIter) && !converged) { + val iterStartTime = (System.nanoTime() - initStartTime) / 1e9 + + val costAccum = sc.doubleAccumulator + val bcCenters = sc.broadcast(centers) + val bcCounts = sc.broadcast(counts) + + val sampled = if ($(fraction) == 1.0) { + data + } else { + data.sample(false, $(fraction), iteration + 42) + } + + val totalContribs = sampled.mapPartitions { points => + val thisCenters = bcCenters.value + points.map { (point: VectorWithNorm) => + val (bestCenter, cost) = MLlibKMeans.findClosest(thisCenters, point) + costAccum.add(cost) + (bestCenter, point.vector) + } + }.partitionBy(new KeyPartitioner(numCenters)) + .mapPartitions { it => + val center = Vectors.zeros(numFeatures) + var count = -1L + var best = -1 + + it.foreach { + case (bestCenter, point) => + if (count < 0) { + axpy(1.0, bcCenters.value(bestCenter).vector, center) + count = bcCounts.value(bestCenter) + 1 + best = bestCenter + } else { + count += 1 + } + // learning rate + val lr = 1.0 / count + // center = center * (1 - lr) + point * lr + scal(1 - lr, center) + axpy(lr, point, center) + } + + if (count > 0) { + Iterator.single((best, (center, count))) + } else { + Iterator.empty + } + }.collectAsMap() + + // Update the cluster centers and costs + converged = true + totalContribs.foreach { case (j, (center, count)) => + val newCenter = new VectorWithNorm(center) + if (converged + && MLlibKMeans.fastSquaredDistance(newCenter, centers(j)) > $(tol) * $(tol)) { + converged = false + } + centers(j) = newCenter + counts(j) = count + } + + bcCenters.destroy(blocking = false) + bcCounts.destroy(blocking = false) + + val cost = costAccum.value + + val iterTimeInSeconds = (System.nanoTime() - iterStartTime) / 1e9 + logInfo(f"Iteration $iteration took $iterTimeInSeconds%.3f seconds, " + + f"cost on sampled data: $cost") + iteration += 1 + } + data.unpersist(blocking = false) + + val iterationTimeInSeconds = (System.nanoTime() - iterationStartTime) / 1e9 + logInfo(f"Iterations took $iterationTimeInSeconds%.3f seconds.") + + if (iteration == $(maxIter)) { + logInfo(s"MiniBatchKMeans reached the max number of iterations: ${$(maxIter)}.") + } else { + logInfo(s"MiniBatchKMeans converged in $iteration iterations.") + } + + new MiniBatchKMeansModel(uid, centers.map(_.vector.asML)) + } + + private class KeyPartitioner(partitions: Int) extends Partitioner { + require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.") + + override def numPartitions: Int = partitions + + override def getPartition(key: Any): Int = key.asInstanceOf[Int] + } + + private def initCenters(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = { + val algo = new MLlibKMeans() + .setK($(k)) + .setInitializationMode($(initMode)) + .setInitializationSteps($(initSteps)) + .setMaxIterations(0) + .setSeed($(seed)) + + $(initMode) match { + case MLlibKMeans.RANDOM => + algo.initRandom(data) + case MLlibKMeans.K_MEANS_PARALLEL => + algo.initKMeansParallel(data) + } + } + + @Since("2.3.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } +} + +@Since("2.3.0") +object MiniBatchKMeans extends DefaultParamsReadable[MiniBatchKMeans] { + + @Since("2.3.0") + override def load(path: String): MiniBatchKMeans = super.load(path) +} + +/** + * :: Experimental :: + * Summary of MiniBatchKMeans. + * + * @param predictions `DataFrame` produced by `MiniBatchKMeansModel.transform()`. + * @param predictionCol Name for column of predicted clusters in `predictions`. + * @param featuresCol Name for column of features in `predictions`. + * @param k Number of clusters. + */ +@Since("2.3.0") +@Experimental +class MiniBatchKMeansSummary private[clustering] ( + predictions: DataFrame, + predictionCol: String, + featuresCol: String, + k: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index fa72b72e2d921..74b9cbd9c813f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -328,7 +328,7 @@ class KMeans private ( /** * Initialize a set of cluster centers at random. */ - private def initRandom(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = { + private[spark] def initRandom(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = { // Select without replacement; may still produce duplicates if the data has < k distinct // points, so deduplicate the centroids to match the behavior of k-means|| in the same situation data.takeSample(false, k, new XORShiftRandom(this.seed).nextInt()) @@ -344,7 +344,7 @@ class KMeans private ( * * The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf. */ - private[clustering] def initKMeansParallel(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = { + private[spark] def initKMeansParallel(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = { // Initialize empty centers and point costs. var costs = data.map(_ => Double.PositiveInfinity) @@ -548,7 +548,7 @@ object KMeans { /** * Returns the index of the closest center to the given point, as well as the squared distance. */ - private[mllib] def findClosest( + private[spark] def findClosest( centers: TraversableOnce[VectorWithNorm], point: VectorWithNorm): (Int, Double) = { var bestDistance = Double.PositiveInfinity @@ -574,7 +574,7 @@ object KMeans { /** * Returns the K-means cost of a given point against the given cluster centers. */ - private[mllib] def pointCost( + private[spark] def pointCost( centers: TraversableOnce[VectorWithNorm], point: VectorWithNorm): Double = findClosest(centers, point)._2 @@ -583,7 +583,7 @@ object KMeans { * Returns the squared Euclidean distance between two vectors computed by * [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]]. */ - private[clustering] def fastSquaredDistance( + private[spark] def fastSquaredDistance( v1: VectorWithNorm, v2: VectorWithNorm): Double = { MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm) @@ -603,7 +603,7 @@ object KMeans { * * @see [[org.apache.spark.mllib.clustering.KMeans#fastSquaredDistance]] */ -private[clustering] +private[spark] class VectorWithNorm(val vector: Vector, val norm: Double) extends Serializable { def this(vector: Vector) = this(vector, Vectors.norm(vector, 2.0)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/MiniBatchKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/MiniBatchKMeansSuite.scala new file mode 100644 index 0000000000000..0c0d8c92859b0 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/MiniBatchKMeansSuite.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Dataset + +class MiniBatchKMeansSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { + + final val k = 5 + @transient var dataset: Dataset[_] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k) + } + + test("default parameters") { + val mbkm = new MiniBatchKMeans() + + assert(mbkm.getK === 2) + assert(mbkm.getFeaturesCol === "features") + assert(mbkm.getPredictionCol === "prediction") + assert(mbkm.getMaxIter === 20) + assert(mbkm.getInitMode === MLlibKMeans.K_MEANS_PARALLEL) + assert(mbkm.getInitSteps === 2) + assert(mbkm.getTol === 1e-4) + assert(mbkm.getFraction === 1.0) + val model = mbkm.setMaxIter(1).fit(dataset) + + MLTestingUtils.checkCopyAndUids(mbkm, model) + assert(model.hasSummary) + val copiedModel = model.copy(ParamMap.empty) + assert(copiedModel.hasSummary) + } + + test("set parameters") { + val mbkm = new MiniBatchKMeans() + .setK(9) + .setFeaturesCol("test_feature") + .setPredictionCol("test_prediction") + .setMaxIter(33) + .setInitMode(MLlibKMeans.RANDOM) + .setInitSteps(3) + .setSeed(123) + .setTol(1e-3) + .setFraction(0.1) + + assert(mbkm.getK === 9) + assert(mbkm.getFeaturesCol === "test_feature") + assert(mbkm.getPredictionCol === "test_prediction") + assert(mbkm.getMaxIter === 33) + assert(mbkm.getInitMode === MLlibKMeans.RANDOM) + assert(mbkm.getInitSteps === 3) + assert(mbkm.getSeed === 123) + assert(mbkm.getTol === 1e-3) + assert(mbkm.getFraction === 0.1) + } + + test("parameters validation") { + intercept[IllegalArgumentException] { + new MiniBatchKMeans().setK(1) + } + intercept[IllegalArgumentException] { + new MiniBatchKMeans().setInitMode("no_such_a_mode") + } + intercept[IllegalArgumentException] { + new MiniBatchKMeans().setInitSteps(0) + } + intercept[IllegalArgumentException] { + new MiniBatchKMeans().setFraction(0) + } + intercept[IllegalArgumentException] { + new MiniBatchKMeans().setFraction(1.01) + } + intercept[IllegalArgumentException] { + new MiniBatchKMeans().setFraction(-0.01) + } + } + + test("fit, transform and summary") { + val predictionColName = "minibatchkmeans_prediction" + val mbkm = new MiniBatchKMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) + val model = mbkm.fit(dataset) + assert(model.clusterCenters.length === k) + + val transformed = model.transform(dataset) + val expectedColumns = Array("features", predictionColName) + expectedColumns.foreach { column => + assert(transformed.columns.contains(column)) + } + val clusters = + transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet + assert(clusters.size === k) + assert(clusters === Set(0, 1, 2, 3, 4)) + assert(model.computeCost(dataset) < 0.1) + assert(model.hasParent) + + // Check validity of model summary + val numRows = dataset.count() + assert(model.hasSummary) + val summary = model.summary + assert(summary.predictionCol === predictionColName) + assert(summary.featuresCol === "features") + assert(summary.predictions.count() === numRows) + for (c <- Array(predictionColName, "features")) { + assert(summary.predictions.columns.contains(c)) + } + assert(summary.cluster.columns === Array(predictionColName)) + val clusterSizes = summary.clusterSizes + assert(clusterSizes.length === k) + assert(clusterSizes.sum === numRows) + assert(clusterSizes.forall(_ >= 0)) + + model.setSummary(None) + assert(!model.hasSummary) + } + + test("KMeansModel transform with non-default feature and prediction cols") { + val featuresColName = "minibatchkmeans_model_features" + val predictionColName = "minibatchkmeans_model_prediction" + + val model = new MiniBatchKMeans().setK(k).setSeed(1).fit(dataset) + model.setFeaturesCol(featuresColName).setPredictionCol(predictionColName) + + val transformed = model.transform(dataset.withColumnRenamed("features", featuresColName)) + Seq(featuresColName, predictionColName).foreach { column => + assert(transformed.columns.contains(column)) + } + assert(model.getFeaturesCol == featuresColName) + assert(model.getPredictionCol == predictionColName) + } + + test("read/write") { + def checkModelData(model: MiniBatchKMeansModel, model2: MiniBatchKMeansModel): Unit = { + assert(model.clusterCenters === model2.clusterCenters) + } + val kmeans = new MiniBatchKMeans() + testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, + KMeansSuite.allParamSettings, checkModelData) + } +} + +object MiniBatchKMeansSuite { + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "k" -> 3, + "maxIter" -> 2, + "tol" -> 0.01, + "fraction" -> 1.0 + ) +} From b16a5dcc2abaefc5dfcddb6ebe8cfedbe830632f Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Sat, 27 May 2017 14:15:04 +0800 Subject: [PATCH 2/8] update test --- .../apache/spark/ml/clustering/MiniBatchKMeansSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/MiniBatchKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/MiniBatchKMeansSuite.scala index 0c0d8c92859b0..b5959b8d3de6c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/MiniBatchKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/MiniBatchKMeansSuite.scala @@ -157,8 +157,8 @@ class MiniBatchKMeansSuite extends SparkFunSuite with MLlibTestSparkContext assert(model.clusterCenters === model2.clusterCenters) } val kmeans = new MiniBatchKMeans() - testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, - KMeansSuite.allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(kmeans, dataset, MiniBatchKMeansSuite.allParamSettings, + MiniBatchKMeansSuite.allParamSettings, checkModelData) } } @@ -173,6 +173,6 @@ object MiniBatchKMeansSuite { "k" -> 3, "maxIter" -> 2, "tol" -> 0.01, - "fraction" -> 1.0 + "fraction" -> 0.1 ) } From 5b8b615544079122a707ebc1407a379c6eb2c012 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Thu, 22 Jun 2017 13:24:13 +0800 Subject: [PATCH 3/8] update alg --- .../spark/ml/clustering/MiniBatchKMeans.scala | 76 ++++++++----------- 1 file changed, 30 insertions(+), 46 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala index 38cf9bdfcee1f..7d3a96f52cf26 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala @@ -343,10 +343,12 @@ class MiniBatchKMeans @Since("2.3.0") ( val numFeatures = centers.head.vector.size instr.logNumFeatures(numFeatures) + val numCenters = centers.length val counts = Array.ofDim[Long](numCenters) var converged = false + var batchSize = 0L var iteration = 0 // Execute iterations of Sculley's algorithm until converged @@ -363,52 +365,42 @@ class MiniBatchKMeans @Since("2.3.0") ( data.sample(false, $(fraction), iteration + 42) } + // Find the sum and count of points mapping to each center val totalContribs = sampled.mapPartitions { points => val thisCenters = bcCenters.value - points.map { (point: VectorWithNorm) => + val dims = thisCenters.head.vector.size + val sums = Array.fill(thisCenters.length)(Vectors.zeros(dims)) + val counts = Array.fill(thisCenters.length)(0L) + + points.foreach { point => val (bestCenter, cost) = MLlibKMeans.findClosest(thisCenters, point) costAccum.add(cost) - (bestCenter, point.vector) + val sum = sums(bestCenter) + axpy(1.0, point.vector, sum) + counts(bestCenter) += 1 } - }.partitionBy(new KeyPartitioner(numCenters)) - .mapPartitions { it => - val center = Vectors.zeros(numFeatures) - var count = -1L - var best = -1 - - it.foreach { - case (bestCenter, point) => - if (count < 0) { - axpy(1.0, bcCenters.value(bestCenter).vector, center) - count = bcCounts.value(bestCenter) + 1 - best = bestCenter - } else { - count += 1 - } - // learning rate - val lr = 1.0 / count - // center = center * (1 - lr) + point * lr - scal(1 - lr, center) - axpy(lr, point, center) - } - - if (count > 0) { - Iterator.single((best, (center, count))) - } else { - Iterator.empty - } - }.collectAsMap() - - // Update the cluster centers and costs + + counts.indices.filter(counts(_) > 0).map(j => (j, (sums(j), counts(j)))).iterator + }.reduceByKey { case ((sum1, count1), (sum2, count2)) => + axpy(1.0, sum2, sum1) + (sum1, count1 + count2) + }.collectAsMap() + + // Update the cluster centers, costs and counts converged = true - totalContribs.foreach { case (j, (center, count)) => - val newCenter = new VectorWithNorm(center) - if (converged - && MLlibKMeans.fastSquaredDistance(newCenter, centers(j)) > $(tol) * $(tol)) { + batchSize = 0 + totalContribs.foreach { case (j, (sum, count)) => + batchSize += count + val newCount = counts(j) + count + scal(1.0 / newCount, sum) + axpy(counts(j).toDouble / newCount, centers(j).vector, sum) + val newCenter = new VectorWithNorm(sum) + if (converged && + MLlibKMeans.fastSquaredDistance(newCenter, centers(j)) > $(tol) * $(tol)) { converged = false } centers(j) = newCenter - counts(j) = count + counts(j) = newCount } bcCenters.destroy(blocking = false) @@ -418,7 +410,7 @@ class MiniBatchKMeans @Since("2.3.0") ( val iterTimeInSeconds = (System.nanoTime() - iterStartTime) / 1e9 logInfo(f"Iteration $iteration took $iterTimeInSeconds%.3f seconds, " + - f"cost on sampled data: $cost") + f"cost on $batchSize instances: $cost") iteration += 1 } data.unpersist(blocking = false) @@ -435,14 +427,6 @@ class MiniBatchKMeans @Since("2.3.0") ( new MiniBatchKMeansModel(uid, centers.map(_.vector.asML)) } - private class KeyPartitioner(partitions: Int) extends Partitioner { - require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.") - - override def numPartitions: Int = partitions - - override def getPartition(key: Any): Int = key.asInstanceOf[Int] - } - private def initCenters(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = { val algo = new MLlibKMeans() .setK($(k)) From 12ac3d339c5ff6d528ca35064d3cd836081250b4 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Thu, 22 Jun 2017 13:31:30 +0800 Subject: [PATCH 4/8] del unused bc --- .../org/apache/spark/ml/clustering/MiniBatchKMeans.scala | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala index 7d3a96f52cf26..d13b62d8473bd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala @@ -353,11 +353,10 @@ class MiniBatchKMeans @Since("2.3.0") ( // Execute iterations of Sculley's algorithm until converged while (iteration < $(maxIter) && !converged) { - val iterStartTime = (System.nanoTime() - initStartTime) / 1e9 + val singleIterationStartTime = (System.nanoTime() - initStartTime) / 1e9 val costAccum = sc.doubleAccumulator val bcCenters = sc.broadcast(centers) - val bcCounts = sc.broadcast(counts) val sampled = if ($(fraction) == 1.0) { data @@ -402,14 +401,12 @@ class MiniBatchKMeans @Since("2.3.0") ( centers(j) = newCenter counts(j) = newCount } - bcCenters.destroy(blocking = false) - bcCounts.destroy(blocking = false) val cost = costAccum.value - val iterTimeInSeconds = (System.nanoTime() - iterStartTime) / 1e9 - logInfo(f"Iteration $iteration took $iterTimeInSeconds%.3f seconds, " + + val singleIterationTimeInSeconds = (System.nanoTime() - singleIterationStartTime) / 1e9 + logInfo(f"Iteration $iteration took $singleIterationTimeInSeconds%.3f seconds, " + f"cost on $batchSize instances: $cost") iteration += 1 } From f6ef7caacf015db1941054640952ce6a31a9e9e2 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Thu, 22 Jun 2017 16:19:33 +0800 Subject: [PATCH 5/8] del unused import --- .../scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala index d13b62d8473bd..1ab2c77583e36 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala @@ -19,7 +19,6 @@ package org.apache.spark.ml.clustering import org.apache.hadoop.fs.Path -import org.apache.spark.Partitioner import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} From 19d77f67e1b508360f4af46731afe3b656b5cf5e Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Thu, 22 Jun 2017 16:38:59 +0800 Subject: [PATCH 6/8] fix test style --- .../apache/spark/ml/clustering/MiniBatchKMeansSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/MiniBatchKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/MiniBatchKMeansSuite.scala index b5959b8d3de6c..6b33ad724f30b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/MiniBatchKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/MiniBatchKMeansSuite.scala @@ -164,10 +164,10 @@ class MiniBatchKMeansSuite extends SparkFunSuite with MLlibTestSparkContext object MiniBatchKMeansSuite { /** - * Mapping from all Params to valid settings which differ from the defaults. - * This is useful for tests which need to exercise all Params, such as save/load. - * This excludes input columns to simplify some tests. - */ + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ val allParamSettings: Map[String, Any] = Map( "predictionCol" -> "myPrediction", "k" -> 3, From 78c6deaba8cd41d8c53af31b33fa4f510c63e502 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Thu, 22 Jun 2017 17:57:52 +0800 Subject: [PATCH 7/8] copyvalue --- .../apache/spark/ml/clustering/MiniBatchKMeans.scala | 8 +++++++- .../spark/ml/clustering/MiniBatchKMeansSuite.scala | 10 +++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala index 1ab2c77583e36..9c2a66a24c1e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala @@ -420,7 +420,13 @@ class MiniBatchKMeans @Since("2.3.0") ( logInfo(s"MiniBatchKMeans converged in $iteration iterations.") } - new MiniBatchKMeansModel(uid, centers.map(_.vector.asML)) + val model = copyValues(new MiniBatchKMeansModel(uid, centers.map(_.vector.asML)) + .setParent(this)) + val summary = new MiniBatchKMeansSummary( + model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) + model.setSummary(Some(summary)) + instr.logSuccess(model) + model } private def initCenters(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/MiniBatchKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/MiniBatchKMeansSuite.scala index 6b33ad724f30b..e7572c5fdb09d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/MiniBatchKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/MiniBatchKMeansSuite.scala @@ -100,7 +100,7 @@ class MiniBatchKMeansSuite extends SparkFunSuite with MLlibTestSparkContext } test("fit, transform and summary") { - val predictionColName = "minibatchkmeans_prediction" + val predictionColName = "minibatch_kmeans_prediction" val mbkm = new MiniBatchKMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) val model = mbkm.fit(dataset) assert(model.clusterCenters.length === k) @@ -138,8 +138,8 @@ class MiniBatchKMeansSuite extends SparkFunSuite with MLlibTestSparkContext } test("KMeansModel transform with non-default feature and prediction cols") { - val featuresColName = "minibatchkmeans_model_features" - val predictionColName = "minibatchkmeans_model_prediction" + val featuresColName = "minibatch_kmeans_model_features" + val predictionColName = "minibatch_kmeans_model_prediction" val model = new MiniBatchKMeans().setK(k).setSeed(1).fit(dataset) model.setFeaturesCol(featuresColName).setPredictionCol(predictionColName) @@ -156,8 +156,8 @@ class MiniBatchKMeansSuite extends SparkFunSuite with MLlibTestSparkContext def checkModelData(model: MiniBatchKMeansModel, model2: MiniBatchKMeansModel): Unit = { assert(model.clusterCenters === model2.clusterCenters) } - val kmeans = new MiniBatchKMeans() - testEstimatorAndModelReadWrite(kmeans, dataset, MiniBatchKMeansSuite.allParamSettings, + val mbkm = new MiniBatchKMeans() + testEstimatorAndModelReadWrite(mbkm, dataset, MiniBatchKMeansSuite.allParamSettings, MiniBatchKMeansSuite.allParamSettings, checkModelData) } } From 92642e52206844413f67c81e712ceee2d1a6e736 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Thu, 22 Jun 2017 17:58:16 +0800 Subject: [PATCH 8/8] del unnecessary line --- .../scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala index 9c2a66a24c1e6..806fcb3886743 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/MiniBatchKMeans.scala @@ -434,7 +434,6 @@ class MiniBatchKMeans @Since("2.3.0") ( .setK($(k)) .setInitializationMode($(initMode)) .setInitializationSteps($(initSteps)) - .setMaxIterations(0) .setSeed($(seed)) $(initMode) match {