/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.ml.clustering

import org.apache.flink.api.common.functions.RichMapFunction
import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFields
import org.apache.flink.api.scala.{DataSet, _}
import org.apache.flink.configuration.Configuration
import org.apache.flink.ml.common.{LabeledVector, _}
import org.apache.flink.ml.math.Breeze._
import org.apache.flink.ml.math.Vector
import org.apache.flink.ml.metrics.distances.EuclideanDistanceMetric

import scala.collection.JavaConverters._


/**
 * Implements the KMeans algorithm which calculates cluster centroids based on a set of training
 * data points and a set of k initial centroids.
 *
 * [[org.apache.flink.ml.clustering.KMeans]] is a [[org.apache.flink.ml.common.Learner]] which
 * needs to be trained on a set of data points and emits a
 * [[org.apache.flink.ml.clustering.KMeansModel]] which is a
 * [[org.apache.flink.ml.common.Transformer]] to assign new points to the learned cluster centroids.
 *
 * The KMeans algorithm works as described on Wikipedia
 * (http://en.wikipedia.org/wiki/K-means_clustering):
 *
 * Given an initial set of k means `m_1^(1), …, m_k^(1)` (see below), the algorithm proceeds by
 * alternating between two steps:
 *
 * ===Assignment step:===
 *
 * Assign each observation to the cluster whose mean yields the least within-cluster sum of
 * squares (WCSS). Since the sum of squares is the squared Euclidean distance, this is intuitively
 * the "nearest" mean. (Mathematically, this means partitioning the observations according to the
 * Voronoi diagram generated by the means).
 *
 * `S_i^(t) = { x_p : || x_p - m_i^(t) ||^2 ≤ || x_p - m_j^(t) ||^2 \forall j, 1 ≤ j ≤ k}`,
 * where each `x_p` is assigned to exactly one `S^{(t)}`, even if it could be assigned to two or
 * more of them.
 *
 * ===Update step:===
 *
 * Calculate the new means to be the centroids of the observations in the new clusters.
 *
 * `m^{(t+1)}_i = ( 1 / |S^{(t)}_i| ) \sum_{x_j \in S^{(t)}_i} x_j`
 *
 * Since the arithmetic mean is a least-squares estimator, this also minimizes the within-cluster
 * sum of squares (WCSS) objective.
 *
 * @example
 * {{{
 *       val trainingDS: DataSet[Vector] = env.fromCollection(Clustering.trainingData)
 *       val initialCentroids: DataSet[LabeledVector] = env.fromCollection(Clustering.initCentroids)
 *
 *       val kmeans = KMeans()
 *         .setInitialCentroids(initialCentroids)
 *         .setNumIterations(10)
 *
 *       val model = kmeans.fit(trainingDS)
 *
 *       val testDS: DataSet[Vector] = env.fromCollection(Clustering.testData)
 *
 *       val clusters: DataSet[LabeledVector] = model.transform(testDS)
 * }}}
 *
 * =Parameters=
 *
 * - [[KMeans.NumIterations]]:
 * Defines the number of iterations to recalculate the centroids of the clusters. As it
 * is a heuristic algorithm, there is no guarantee that it will converge to the global optimum. The
 * recalculation of the centroids and the reassignment of the data points are repeated until the
 * given number of iterations is reached.
 * (Default value: '''10''')
 *
 * - [[KMeans.InitialCentroids]]:
 * Defines the initial k centroids of the k clusters. They are used as the starting point of the
 * algorithm for clustering the data set. The centroids are recalculated as often as set in
 * [[KMeans.NumIterations]]. The choice of the initial centroids strongly affects the outcome of
 * the algorithm. This parameter has no default value and has to be set before calling `fit`.
 *
 */
class KMeans extends Learner[Vector, KMeansModel] with Serializable {

  import KMeans._

  /**
   * Sets the number of iterations.
   *
   * @param numIterations Number of times the centroids are recalculated
   * @return itself
   */
  def setNumIterations(numIterations: Int): KMeans = {
    parameters.add(NumIterations, numIterations)
    this
  }

  /**
   * Sets the initial centroids on which the algorithm will start computing.
   * These points should depend on the data and significantly influence the resulting centroids.
   *
   * @param initialCentroids A data set of labeled vectors; the label identifies the centroid
   * @return itself
   */
  def setInitialCentroids(initialCentroids: DataSet[LabeledVector]): KMeans = {
    parameters.add(InitialCentroids, initialCentroids)
    this
  }

  /**
   * Iteratively computes centroids that match the given input DataSet by adjusting the given
   * initial centroids.
   *
   * @param input Training data set
   * @param fitParameters Parameter values; they are combined with the parameters set on this
   *                      learner via `++`
   * @return Trained KMeans Model which represents the final centroids.
   * @throws NoSuchElementException if no value is available for [[KMeans.InitialCentroids]] or
   *                                [[KMeans.NumIterations]]
   */
  override def fit(input: DataSet[Vector], fitParameters: ParameterMap): KMeansModel = {
    val resultingParameters = this.parameters ++ fitParameters

    // InitialCentroids has no default value. Fail with a message that names the missing
    // parameter instead of an anonymous `None.get`; the exception type stays the same.
    val centroids: DataSet[LabeledVector] = resultingParameters.get(InitialCentroids).getOrElse(
      throw new NoSuchElementException(
        "The initial centroids are not set. Use setInitialCentroids or provide them in the " +
          "fit parameters."))

    val numIterations: Int = resultingParameters.get(NumIterations).getOrElse(
      throw new NoSuchElementException("The number of iterations is not set."))

    val finalCentroids = centroids.iterate(numIterations) { currentCentroids =>
      // One iteration: label every point with its nearest centroid, sum up the points and their
      // count per label, and divide the sum by the count to obtain the new centroid.
      // NOTE: a centroid to which no point is assigned produces no group below and is therefore
      // not contained in the centroid set handed to the next iteration.
      val newCentroids: DataSet[LabeledVector] = input
        .map(new SelectNearestCenterMapper).withBroadcastSet(currentCentroids, CENTROIDS)
        .map(x => (x.label, x.vector, 1.0)).withForwardedFields("label->_1; vector->_2")
        .groupBy(x => x._1)
        .reduce((p1, p2) => (p1._1, (p1._2.asBreeze + p2._2.asBreeze).fromBreeze, p1._3 + p2._3))
        .withForwardedFields("_1")
        .map(x => LabeledVector(x._1, (x._2.asBreeze :/ x._3).fromBreeze))
        .withForwardedFields("_1->label")

      newCentroids
    }

    KMeansModel(finalCentroids)
  }

}

/**
 * Companion object of KMeans. Contains convenience functions and the parameter type definitions
 * of the algorithm.
 */
object KMeans {

  /** Name under which the current centroids are broadcast to [[SelectNearestCenterMapper]]. */
  val CENTROIDS = "centroids"

  /** Number of centroid recalculations. */
  case object NumIterations extends Parameter[Int] {
    val defaultValue = Some(10)
  }

  /** Centroids the algorithm starts from; there is no default value. */
  case object InitialCentroids extends Parameter[DataSet[LabeledVector]] {
    val defaultValue = None
  }

  /** Creates a new [[KMeans]] learner. */
  def apply(): KMeans = {
    new KMeans()
  }

}

/**
 * The resulting model of final centroids after the KMeans algorithm. Can be used to determine to
 * which centroid a vector belongs.
 *
 * @param centroids The learned centroids based on the training data.
 */
case class KMeansModel(centroids: DataSet[LabeledVector])
  extends Transformer[Vector, LabeledVector]
  with Serializable {

  import KMeans._

  /**
   * Labels every input vector with the label of its closest centroid.
   *
   * @param input Vectors to assign to the learned centroids
   * @param parameters Not read by this implementation
   * @return The input vectors, each wrapped in a [[LabeledVector]] carrying the centroid label
   */
  override def transform(input: DataSet[Vector], parameters: ParameterMap):
      DataSet[LabeledVector] = {
    input.map(new SelectNearestCenterMapper).withBroadcastSet(centroids, CENTROIDS)
  }

}

/**
 * Converts a given vector into a labeled vector where the label denotes the label of the closest
 * centroid. The centroids are read from the broadcast variable named [[KMeans.CENTROIDS]].
 */
@ForwardedFields(Array("*->vector"))
final class SelectNearestCenterMapper extends RichMapFunction[Vector, LabeledVector] {

  import KMeans._

  // Assigned in open(), i.e. before the first call of map().
  private var centroids: Traversable[LabeledVector] = null

  // Created once per function instance on first use instead of once per centroid comparison.
  // Being transient and lazy, it is not shipped with the serialized function.
  @transient private lazy val distanceMetric = EuclideanDistanceMetric()

  /** Reads the centroid values from a broadcast variable into a collection. */
  override def open(parameters: Configuration): Unit = {
    centroids = getRuntimeContext.getBroadcastVariable[LabeledVector](CENTROIDS).asScala
  }

  /**
   * Finds the centroid with the smallest Euclidean distance to `v`.
   *
   * @param v Vector to label
   * @return `v` labeled with the label of the closest centroid. On equal distances the centroid
   *         seen first wins. The label is `-1` if no centroid has a distance smaller than
   *         `Double.MaxValue`, which includes the case of an empty centroid collection.
   */
  override def map(v: Vector): LabeledVector = {
    val metric = distanceMetric

    // The accumulator is (smallest distance so far, label of the corresponding centroid).
    val (_, closestCentroidLabel) =
      centroids.foldLeft((Double.MaxValue, -1.0)) { (best, centroid) =>
        val distance = metric.distance(v, centroid.vector)
        if (distance < best._1) (distance, centroid.label) else best
      }

    LabeledVector(closestCentroidLabel, v)
  }

}