diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index a741584982725..38ca2af80fa4a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -77,6 +77,34 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
   def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
     predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]
 
+  /**
+   * Returns (centerIndex, squaredDistance) pairs from the given point to every cluster center.
+   */
+  @Since("1.5.0")
+  def distanceToCenters(point: Vector): Iterable[(Int, Double)] = {
+    val pointWithNorm = new VectorWithNorm(point)
+    clusterCentersWithNorm.zipWithIndex.map {
+      case (c, i) =>
+        (i, KMeans.fastSquaredDistance(c, pointWithNorm))
+    }.toList
+  }
+
+  /**
+   * Maps each point to its (centerIndex, squaredDistance) pairs for every cluster center.
+   */
+  @Since("1.5.0")
+  def distanceToCenters(points: RDD[Vector]): RDD[(Vector, Iterable[(Int, Double)])] = {
+    val centersWithNorm = clusterCentersWithNorm
+    val bcCentersWithNorm = points.context.broadcast(centersWithNorm)
+    points.map(p => {
+      val pointWithNorm = new VectorWithNorm(p)
+      (p, bcCentersWithNorm.value.zipWithIndex.map {
+        case (c, i) =>
+          (i, KMeans.fastSquaredDistance(c, pointWithNorm))
+      }.toList)
+    })
+  }
+
   /**
    * Return the K-means cost (sum of squared distances of points to their nearest center) for this
    * model on the given data.
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 3003c62d9876c..9b45e88e61b35 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -250,7 +250,8 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
 
     for (initMode <- Seq(RANDOM, K_MEANS_PARALLEL)) {
       // Two iterations are sufficient no matter where the initial centers are.
-      val model = KMeans.train(rdd, k = 2, maxIterations = 2, runs = 1, initMode)
+      val k = 2
+      val model = KMeans.train(rdd, k = k, maxIterations = 2, runs = 1, initMode)
 
       val predicts = model.predict(rdd).collect()
 
@@ -259,6 +260,8 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
       assert(predicts(3) === predicts(4))
       assert(predicts(3) === predicts(5))
       assert(predicts(0) != predicts(3))
+
+      assert(model.distanceToCenters(rdd).flatMap(_._2).count === points.size * k)
     }
   }
 
@@ -341,6 +344,7 @@ class KMeansClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
       val model = KMeans.train(points, 2, 2, 1, initMode)
       val predictions = model.predict(points).collect()
       val cost = model.computeCost(points)
+      assert(model.distanceToCenters(points.first()).size === 2)
     }
   }
 }