diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/package.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/package.scala index 554e155201045..14638650cef84 100644 --- a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/package.scala +++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/package.scala @@ -18,6 +18,7 @@ package org.apache.flink + import org.apache.flink.api.common.functions.{RichFilterFunction, RichMapFunction} import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.operators.DataSink @@ -25,6 +26,7 @@ import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment} import org.apache.flink.configuration.Configuration import org.apache.flink.ml.common.LabeledVector +import scala.collection.JavaConverters._ import scala.reflect.ClassTag package object ml { @@ -70,6 +72,14 @@ package object ml { dataSet.map(new BroadcastSingleElementMapperWithIteration[T, B, O](dataSet.clean(fun))) .withBroadcastSet(broadcastVariable, "broadcastVariable") } + + def mapWithBcSet[B, O: TypeInformation: ClassTag]( + broadcastVariable: DataSet[B])( + fun: (T, Seq[B]) => O) + : DataSet[O] = { + dataSet.map(new BroadcastSetMapper[T, B, O](dataSet.clean(fun))) + .withBroadcastSet(broadcastVariable, "broadcastVariable") + } } private class BroadcastSingleElementMapper[T, B, O]( @@ -101,7 +111,7 @@ package object ml { fun(value, broadcastVariable, getIterationRuntimeContext.getSuperstepNumber) } } - + private class BroadcastSingleElementFilter[T, B]( fun: (T, B) => Boolean) extends RichFilterFunction[T] { @@ -116,4 +126,21 @@ package object ml { fun(value, broadcastVariable) } } + + private class BroadcastSetMapper[T, B, O](fun: (T, Seq[B]) => O) + extends RichMapFunction[T, O] { + var broadcastVariable: Seq[B] = _ + + @throws(classOf[Exception]) + override def open(configuration: Configuration): Unit = { + broadcastVariable = getRuntimeContext + .getBroadcastVariable[B]("broadcastVariable") + .asScala + .toSeq + } + + override def map(value: T): O = { + fun(value, broadcastVariable) + } + } } diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/clustering/KMeans.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/clustering/KMeans.scala new file mode 100644 index 0000000000000..edaee3d04b0d8 --- /dev/null +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/clustering/KMeans.scala @@ -0,0 +1,614 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.clustering + +import org.apache.flink.api.common.functions.RichFilterFunction +import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFields +import org.apache.flink.api.scala.{DataSet, _} +import org.apache.flink.configuration.Configuration +import org.apache.flink.ml._ +import org.apache.flink.ml.common.FlinkMLTools.ModuloKeyPartitioner +import org.apache.flink.ml.common.{LabeledVector, _} +import org.apache.flink.ml.math.Breeze._ +import org.apache.flink.ml.math.{BLAS, Vector} +import org.apache.flink.ml.metrics.distances.EuclideanDistanceMetric +import org.apache.flink.ml.pipeline._ + +import scala.collection.JavaConverters._ +import scala.util.Random + + +/** + * Implements the KMeans algorithm which calculates cluster centroids based on set of training data + * points and a set of k initial centroids. + * + * [[KMeans]] is a [[Predictor]] which needs to be trained on a set of data points and can then be + * used to assign new points to the learned cluster centroids. + * + * The KMeans algorithm works as described on Wikipedia + * (http://en.wikipedia.org/wiki/K-means_clustering): + * + * Given an initial set of k means m1(1),…,mk(1) (see below), the algorithm proceeds by alternating + * between two steps: + * + * ===Assignment step:=== + * + * Assign each observation to the cluster whose mean yields the least within-cluster sum of + * squares (WCSS). Since the sum of squares is the squared Euclidean distance, this is intuitively + * the "nearest" mean. (Mathematically, this means partitioning the observations according to the + * Voronoi diagram generated by the means). + * + * `S_i^(t) = { x_p : || x_p - m_i^(t) ||^2 ≤ || x_p - m_j^(t) ||^2 \forall j, 1 ≤ j ≤ k}`, + * where each `x_p` is assigned to exactly one `S^{(t)}`, even if it could be assigned to two or + * more of them. + * + * ===Update step:=== + * + * Calculate the new means to be the centroids of the observations in the new clusters. + * + * `m^{(t+1)}_i = ( 1 / |S^{(t)}_i| ) \sum_{x_j \in S^{(t)}_i} x_j` + * + * Since the arithmetic mean is a least-squares estimator, this also minimizes the within-cluster + * sum of squares (WCSS) objective. + * + * @example + * {{{ + * val trainingDS: DataSet[Vector] = env.fromCollection(Clustering.trainingData) + * val initialCentroids: DataSet[LabledVector] = env.fromCollection(Clustering.initCentroids) + * + * val kmeans = KMeans() + * .setInitialCentroids(initialCentroids) + * .setNumIterations(10) + * + * kmeans.fit(trainingDS) + * + * // getting the computed centroids + * val centroidsResult = kmeans.centroids.get.collect() + * + * // get matching clusters for new points + * val testDS: DataSet[Vector] = env.fromCollection(Clustering.testData) + * val clusters: DataSet[LabeledVector] = kmeans.predict(testDS) + * }}} + * + * =Parameters= + * + * - [[org.apache.flink.ml.clustering.KMeans.NumIterations]]: + * Defines the number of iterations to recalculate the centroids of the clusters. As it + * is a heuristic algorithm, there is no guarantee that it will converge to the global optimum. The + * centroids of the clusters and the reassignment of the data points will be repeated till the + * given number of iterations is reached. + * (Default value: '''10''') + * + * - [[org.apache.flink.ml.clustering.KMeans.InitialCentroids]]: + * Defines the initial k centroids of the k clusters. They are used as start off point of the + * algorithm for clustering the data set. The centroids are recalculated as often as set in + * [[org.apache.flink.ml.clustering.KMeans.NumIterations]]. The choice of the initial centroids + * mainly affects the outcome of the algorithm. + * + * - [[org.apache.flink.ml.clustering.KMeans.InitialStrategy]]: + * Defines the initialization strategy to be used for initializing the KMeans algorithm in case + * the initial centroids are not provided. Allowed values are "random", "kmeans++" and "kmeans||". + * (Default Value: '''random''') + * + * - [[org.apache.flink.ml.clustering.KMeans.NumClusters]]: + * Defines the number of clusters required. This is essential to provide when only the + * initialization strategy is specified, not the initial centroids themselves. + * (Default Value: '''0''') + * + * - [[org.apache.flink.ml.clustering.KMeans.OversamplingFactor]]: + * Defines the oversampling rate for the kmeans|| initialization. + * (Default Value: '''2k'''), where k is the number of clusters. + * + * - [[org.apache.flink.ml.clustering.KMeans.KMeansParRounds]]: + * Defines the number of rounds for the kmeans|| initialization. + * (Default Value: '''5''') + * + */ +class KMeans extends Predictor[KMeans] { + + import KMeans._ + + /** + * Stores the learned clusters after the fit operation + */ + var centroids: Option[DataSet[Seq[LabeledVector]]] = None + + /** + * Sets the maximum number of iterations. + * + * @param numIterations The maximum number of iterations. + * @return itself + */ + def setNumIterations(numIterations: Int): KMeans = { + parameters.add(NumIterations, numIterations) + this + } + + /** + * Sets the number of clusters. + * + * @param numClusters The number of clusters + * @return itself + */ + def setNumClusters(numClusters: Int): KMeans = { + parameters.add(NumClusters, numClusters) + this + } + + /** + * Sets the initial centroids on which the algorithm will start computing. These points should + * depend on the data and will significantly influence the resulting centroids. + * Note that this setting will override [[setInitializationStrategy())]] and the size of + * initialCentroids will override the value, if set, by [[setNumClusters()]] + * + * @param initialCentroids A set of labeled vectors. + * @return itself + */ + def setInitialCentroids(initialCentroids: Seq[LabeledVector]): KMeans = { + parameters.add(InitialCentroids, initialCentroids) + this + } + + /** + * Automatically initialize the KMeans algorithm. Allowed options are "random", "kmeans++" and + * "kmeans||" + * + * @param initialStrategy + * @return itself + */ + def setInitializationStrategy(initialStrategy: String): KMeans = { + require(Array("random", "kmeans++", "kmeans||").contains(initialStrategy), s"$initialStrategy" + + s" is not supported") + parameters.add(InitialStrategy, initialStrategy) + this + } + + /** + * Oversampling factor to be used in case the initialization strategy is set to be "kmeans||" + * + * @param oversamplingFactor Oversampling factor(\ell) + * @return this + */ + def setOversamplingFactor(oversamplingFactor: Double): KMeans = { + require(oversamplingFactor > 0, "Oversampling factor must be positive.") + parameters.add(OversamplingFactor, oversamplingFactor) + this + } + + /** + * Number of initialization rounds to be done when the initialization strategy is set to be + * "kmeans||" + * + * @param numRounds Number of rounds(r) + * @return this + */ + def setNumRounds(numRounds: Int): KMeans = { + require(numRounds > 0, "Number of rounds must be positive") + parameters.add(KMeansParRounds, numRounds) + this + } + +} + +/** + * Companion object of KMeans. Contains convenience functions, the parameter type definitions + * of the algorithm and the [[FitOperation]] & [[PredictOperation]]. + */ +object KMeans { + + private val RANDOM_FRACTION = "random_sample_fraction" + private val PARINIT_SET = "par_init_solution_set" + private val PARINIT_COST = "par_init_solution_cost" + private val PARINIT_SAMPLE = "par_init_oversample_factor" + + /** Euclidean Distance Metric */ + val euclidean = EuclideanDistanceMetric() + + case object NumIterations extends Parameter[Int] { + val defaultValue = Some(10) + } + + case object InitialCentroids extends Parameter[Seq[LabeledVector]] { + val defaultValue = None + } + + case object InitialStrategy extends Parameter[String]{ + val defaultValue = Some("kmeans||") + } + + case object NumClusters extends Parameter[Int] { + val defaultValue = None + } + + case object OversamplingFactor extends Parameter[Double] { + val defaultValue = None + } + + case object KMeansParRounds extends Parameter[Int] { + val defaultValue = Some(5) + } + + // ========================================== Factory methods ==================================== + + def apply(): KMeans = { + new KMeans() + } + + // ========================================== Operations ========================================= + + /** Provides the operation that makes the predictions for individual examples. + * The label of the vector will be the index of the cluster the input vector belongs to. + * + * @tparam T + * @return A PredictOperation, through which it is possible to predict a value, given a + * feature vector + */ + implicit def predictVectors[T <: Vector] = { + new PredictOperation[KMeans, Seq[LabeledVector], T, Double](){ + + override def getModel( + self: KMeans, + predictParameters: ParameterMap) + : DataSet[Seq[LabeledVector]] = { + + self.centroids match { + case Some(model) => model + case None => { + throw new RuntimeException("The KMeans model has not been trained. Call first fit" + + "before calling the predict operation.") + } + } + } + + override def predict(value: T, model: Seq[LabeledVector]): Double = { + findNearestCentroid(value, model)._1 + } + } + } + + /** + * [[FitOperation]] which iteratively computes centroids that match the given input DataSet by + * adjusting the given initial centroids. + * + * @return A new [[FitOperation]] to train the model using the training data set. + */ + implicit def fitKMeans = { + new FitOperation[KMeans, Vector] { + override def fit(instance: KMeans, fitParameters: ParameterMap, trainingDS: DataSet[Vector]) + : Unit = { + val resultingParameters = instance.parameters ++ fitParameters + + // ================= INITIALIZATION OF KMEANS ========================== + val centroids: DataSet[Seq[LabeledVector]] = init(trainingDS, resultingParameters) + + val numIterations: Int = resultingParameters.get(NumIterations).get + + val finalCentroids = centroids.iterate(numIterations) { currentCentroids => + val newCentroids: DataSet[LabeledVector] = trainingDS + .mapWithBcVariable(currentCentroids) + { (dataPoint, centroids) => selectNearestCentroid(dataPoint, centroids) } + .map(x => (x.label, x.vector, 1.0)).withForwardedFields("label->_1; vector->_2") + .groupBy(x => x._1) + .reduce((p1, p2) => + (p1._1, (p1._2.asBreeze + p2._2.asBreeze).fromBreeze, p1._3 + p2._3)) + // TODO replace addition of Breeze vectors by future build in flink function + .withForwardedFields("_1") + .map(x => { + BLAS.scal(1.0 / x._3, x._2) + LabeledVector(x._1, x._2) + }) + .withForwardedFields("_1->label") + + // currentCentroids contains only one element. So, this is output only once + currentCentroids.mapWithBcSet(newCentroids){ + (_,newCenters) => newCenters + } + } + instance.centroids = Some(finalCentroids) + } + } + } + + /** + * Converts a given vector into a labeled vector where the label denotes the label of the closest + * centroid. + * + * @param dataPoint The vector to determine the nearest centroid. + * @param centroids A collection of the centroids. + * @return A [[LabeledVector]] consisting of the input vector and the label of the closest + * centroid. + */ + @ForwardedFields(Array("*->vector")) + private def selectNearestCentroid(dataPoint: Vector, centroids: Seq[LabeledVector]) = { + val nearest = findNearestCentroid(dataPoint, centroids) + LabeledVector(nearest._1, dataPoint) + } + + /** + * Finds the nearest centroid to a point and returns the distance to this centroid and label of it + * + * @param dataPoint The vector to determine the nearest centroid. + * @param centroids A collection of the centroids. + * @return A tuple of distance to the nearest centroid and label of this centroid + */ + private def findNearestCentroid(dataPoint: Vector, centroids: Seq[LabeledVector]) = { + var minDistance: Double = Double.MaxValue + var closestCentroidLabel: Double = -1 + centroids.foreach(centroid => { + val distance = euclidean.distance(dataPoint, centroid.vector) + if (distance < minDistance) { + minDistance = distance + closestCentroidLabel = centroid.label + } + }) + (closestCentroidLabel, minDistance) + } + + /** + * Returns the initial centroids for the KMeans algorithm based upon the information in + * parameter + * + * @param data The training data set + * @param parameter Parameter Map containing user parameters + * @return Initial centroids for KMeans clustering + */ + private def init(data: DataSet[Vector], parameter: ParameterMap): DataSet[Seq[LabeledVector]] = { + parameter.get(InitialCentroids) match { + case Some(value) => data.getExecutionEnvironment.fromElements(value) + case None => { + + val k = parameter.get(NumClusters) match{ + case Some(value) => value + case None => throw new RuntimeException("Specify the number of clusters.") + } + val l = parameter.get(OversamplingFactor) match{ + case Some(value) => value + case None => 2 * k // default value + } + val r = parameter.get(KMeansParRounds).get + + val blocks = data.getParallelism + + parameter.get(InitialStrategy) match { + case Some("random") => { + random(data.map(x => (x,1)), k) + } + case Some("kmeans++") => { + kmeans(data.map(x => (x,1)), k, blocks) + } + case Some("kmeans||") => { + parInit(data, k, blocks, l ,r) + } + case default => { + throw new RuntimeException("Specify a valid initialization strategy.") + } + } + } + } + } + + /** + * Pick k centers from data one by one using kmeans|| initialization scheme + * + * The k-means|| algorithm works as described by the original authors + * (http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf): + * + * Given a data set X with |X| points, the k-means|| algorithm proceeds as follows: + * + * 1. Initialize C \leftarrow \{\} + * 2. Let p be a point sampled uniformly at random from X. C \leftarrow C \cup \{p\} + * 3. for i \leftarrow 1 to r + * Let C' be the set of formed by independently sampling every point x in X with probability + * \ell\cdot\frac{d(x,C)}{sigma_nolimits{p \in X }d(p,C)} + * C \leftarrow C \cup C' + * 4. Assign weights to all point c in C as the number of points from X which are closest to c + * 5. Run kmeans++ initialization on the weighted set C and return k centers + * + * @param data Training data set + * @param k Number of clusters + * @param blocks Blocks in the data + * @param oversampling Oversampling rate (\ell) + * @param rounds Number of rounds (r) + * @return Initial centroids + */ + private def parInit( + data: DataSet[Vector], + k: Int, + blocks: Int, + oversampling: Double, + rounds: Int) + : DataSet[Seq[LabeledVector]] = { + // first pick one center randomly + val oversamplingFactor = data.getExecutionEnvironment.fromElements(oversampling) + + val initialCentroids = random(data.map(x => (x,1)), 1).map(x => x.head) + val unionOfSamples = initialCentroids.iterate(rounds){ + currentSet => { + // current cost + val currentCost = data.mapWithBcSet(currentSet){ + (vector, pointSet) => Math.pow(findNearestCentroid(vector, pointSet)._2, 2) + } + val sampledSet = data.filter(new RichFilterFunction[Vector] { + var currentSet: Seq[LabeledVector] = _ + var cost: Double = _ + var rng: Random = _ + var oversamplingFactor: Double = _ + override def open(parameter: Configuration): Unit ={ + currentSet = getRuntimeContext.getBroadcastVariable(PARINIT_SET).asScala + cost = getRuntimeContext.getBroadcastVariable(PARINIT_COST).get(0) + oversamplingFactor = getRuntimeContext.getBroadcastVariable(PARINIT_SAMPLE).get(0) + rng = new Random() + } + override def filter(value: Vector): Boolean = { + rng.nextDouble() < + oversamplingFactor * Math.pow(findNearestCentroid(value, currentSet)._2, 2) / cost + } + }).withBroadcastSet(currentCost, PARINIT_COST) + .withBroadcastSet(currentSet, PARINIT_SET) + .withBroadcastSet(oversamplingFactor, PARINIT_SAMPLE) + + // keep taking unions of independent samples at each step + currentSet.union(sampledSet.map(x => LabeledVector(0, x))) + } + } + + // now assign weights to points in the set + val weightedSample = data.mapWithBcSet(unionOfSamples){ + (vector, sampledSet) => { + val samples = sampledSet.toList + var minDistance: Double = Double.MaxValue + var closestCentroidIndex: Int = -1 + for (i <- 0 to samples.size - 1) { + val distance = EuclideanDistanceMetric().distance(vector, samples(i).vector) + if (distance < minDistance) { + minDistance = distance + closestCentroidIndex = i + } + } + // just assign a label of 1. We'll figure this out later. + (closestCentroidIndex, samples(closestCentroidIndex).vector, 1) + } + }.groupBy(0) + .reduce((a, b) => (a._1, a._2, a._3 + b._3)) + .map(x => (x._2,x._3)) + + // finally, do a kmeans++ on this weighted set + kmeans(weightedSample, k, blocks) + } + + /** + * Randomly initializes centroids from the data. + * Data is considered to be weighted. + * + * @param data Training data set + * @param k Number of centroids to be picked + * @return Initial random centroids + */ + private def random( + data: DataSet[(Vector, Int)], + k: Int) + : DataSet[Seq[LabeledVector]] = { + // we'll sample 10 times as many points as we actually need + // TODO Modify to use the Random Sample Operator as and when added. + + val fraction = data.map(x => 1).reduce(_ + _).map(x => 10 * (k + 0.0) / x) + + val extraSampledSet = data.filter(new RichFilterFunction[(Vector, Int)] { + var rng: Random = _ + var fraction: Double = _ + + override def open(parameters: Configuration): Unit ={ + rng = new Random() + fraction = getRuntimeContext.getBroadcastVariable(RANDOM_FRACTION).get(0) + } + override def filter(value: (Vector,Int)): Boolean = { + rng.nextDouble() < fraction * value._2 + }} + ).withBroadcastSet(fraction, RANDOM_FRACTION).map(x => x._1) + + data.getExecutionEnvironment.fromElements(k).mapWithBcSet(extraSampledSet){ + (required, largeSample) => + val output = Array.ofDim[LabeledVector](required) + for(i<- 0 to required - 1){ + output(i) = LabeledVector(i, largeSample(i)) + } + output.toSeq + } + } + + /** + * Pick k centers from data one by one using kmeans++ initialization scheme + * + * The k-means++ scheme works as described by the original authors in + * [[http://ilpubs.stanford.edu:8090/778/1/2006-13.pdf]] + * + * Given a data set X with |X| points, the k-means++ algorithm proceeds as follows: + * + * 1. Initialize C \leftarrow \{\} + * 2. Let p be a point sampled uniformly at random from X. C \leftarrow C \cup \{p\} + * 3. for i \leftarrow 1 to k-1 + * Choose an x in X with probability \frac{d(x,C)}{sigma_nolimits{p \in X} d(p,C)} where + * d(p,C) denotes the distance of p from its nearest center in C + * 4. Output C as the initial seed points. These can now be used in a KMeans algorithm as initial + * centers + * + * @param data Training data set + * @param k Number of clusters + * @param blocks Blocks of data + * @return Initial centroids + */ + private def kmeans( + data: DataSet[(Vector, Int)], + k: Int, + blocks: Int) + : DataSet[Seq[LabeledVector]] = { + // first pick one center randomly + val initialCentroids = random(data, 1) + initialCentroids.iterate(k - 1){ + currentCentroids => { + // sample one point from each block based on the local probability distribution + val blockSamples = FlinkMLTools.block(data, blocks, Some(ModuloKeyPartitioner)) + .mapWithBcVariable(currentCentroids) { + (block, centroids) => { + val rng = new Random() + // form a cumulative distribution + val distances = Array.ofDim[Double](block.values.length) + distances(0) = + findNearestCentroid(block.values.head._1, centroids)._2 * block.values.head._2 + for (i <- 1 to block.values.length - 1) { + distances(i) = distances(i - 1) + + findNearestCentroid(block.values(i)._1, centroids)._2 * block.values(i)._2 + } + val samplePoint = sampleFromDistribution(rng.nextDouble() * distances.last, distances) + (block.values(samplePoint)._1, distances.last) + } + } + // now sample one point from the block sample + currentCentroids.mapWithBcSet(blockSamples) { + (centroids, blockSample) => { + // find the next label to use + val rng = new Random() + val nextLabel = centroids.map(x => x.label).max + 1 + val blockArray = blockSample.toArray + val blockCostArray = blockArray.map(x => x._2) + val sampleBlock = sampleFromDistribution( + rng.nextDouble() * blockCostArray.last, blockCostArray) + centroids.toList.::(LabeledVector(nextLabel, blockArray(sampleBlock)._1)).toSeq + } + } + } + } + } + + /** + * Finds an index i such that r >= distribution(i - 1) and r < distribution(i) + * + * @param r Value to be searched for + * @param distribution An unscaled cumulative distribution + * @return Index with cumulative probability just more than r + */ + private def sampleFromDistribution(r: Double, distribution: Array[Double]): Int = { + for(i<- 1 to distribution.length - 1){ + if(r >= distribution(i - 1) && r < distribution(i)){ + return i + } + } + 0 + } +} diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/clustering/ClusteringData.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/clustering/ClusteringData.scala new file mode 100644 index 0000000000000..1f82e65454b38 --- /dev/null +++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/clustering/ClusteringData.scala @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.clustering + +import breeze.linalg.{DenseVector => BreezeDenseVector, Vector => BreezeVector} +import org.apache.flink.ml.common.LabeledVector +import org.apache.flink.ml.math.{DenseVector, Vector} +import org.apache.flink.ml.metrics.distances.EuclideanDistanceMetric + +/** + * Trainings- and test-data set for the K-Means implementation + * [[org.apache.flink.ml.clustering.KMeans]]. + */ +object ClusteringData { + + /* + * Number of iterations for the K-Means algorithm. + */ + val iterations = 10 + + /* + * Sequence of initial centroids. + */ + val centroidData: Seq[LabeledVector] = Seq( + LabeledVector(1, DenseVector(-0.1369104662767052, 0.2949172396037093, -0.01070450818187003)), + LabeledVector(2, DenseVector(0.43643950041582885, 0.30117329671833215, 0.20965108353159922)), + LabeledVector(3, DenseVector(0.26011627041438423, 0.22954649683337805, 0.2936286262276151)), + LabeledVector(4, DenseVector(-0.041980932305508145, 0.03116256923634109, 0.31065743174542293)), + LabeledVector(5, DenseVector(0.0984398491976613, -0.21227718242541602, -0.45083084300074255)), + LabeledVector(6, DenseVector(-0.2165269235451111, -0.47142840804338293, -0.02298954070830948)), + LabeledVector(7, DenseVector(-0.0632307695567563, 0.2387221400443612, 0.09416850805771804)), + LabeledVector(8, DenseVector(0.16383680898916775, -0.24586810465119346, 0.08783590589294081)), + LabeledVector(9, DenseVector(-0.24763544645492513, 0.19688995732231254, 0.4520904742796472)), + LabeledVector(10, DenseVector(0.16468044138881932, 0.06259522206982082, 0.12145870313604247)) + + ) + + /* + * 3 Dimensional DenseVectors from a Part of Cosmo-Gas Dataset + * Reference: http://nuage.cs.washington.edu/benchmark/ + */ + val trainingData: Seq[Vector] = Seq( + DenseVector(-0.489811986685, 0.496883004904, -0.483860999346), + DenseVector(-0.485296010971, 0.496421992779, -0.484212994576), + DenseVector(-0.481514006853, 0.496134012938, -0.48508900404), + DenseVector(-0.478542000055, 0.496246010065, -0.486301004887), + DenseVector(-0.475461006165, 0.496093004942, -0.487686008215), + DenseVector(-0.471846997738, 0.496558994055, -0.488242000341), + DenseVector(-0.467496991158, 0.497166007757, -0.48861899972), + DenseVector(-0.463036000729, 0.497680991888, -0.489721000195), + DenseVector(-0.458972990513, 0.4984369874, -0.490575999022), + DenseVector(-0.455772012472, 0.499684005976, -0.491737008095), + DenseVector(-0.453074991703, -0.499433010817, -0.492006987333), + DenseVector(-0.450913995504, -0.499316990376, -0.492769002914), + DenseVector(-0.448724985123, -0.499406009912, -0.493508011103), + DenseVector(-0.44715899229, -0.499680995941, -0.494500011206), + DenseVector(-0.445362001657, -0.499630987644, -0.495151996613), + DenseVector(-0.442811012268, -0.499303996563, -0.495151013136), + DenseVector(-0.439810991287, -0.499332994223, -0.49529799819), + DenseVector(-0.43678098917, -0.499361991882, -0.49545699358), + DenseVector(-0.433919012547, -0.499334007502, -0.495705991983), + DenseVector(-0.43117800355, -0.499345004559, -0.496196985245), + DenseVector(-0.428333997726, -0.499083012342, -0.496385991573), + DenseVector(-0.425300985575, -0.49844199419, -0.496405988932), + DenseVector(-0.421882003546, -0.497743010521, -0.496706992388), + DenseVector(-0.418137013912, -0.497193992138, -0.496524989605), + DenseVector(-0.414458990097, -0.496717989445, -0.49600699544), + DenseVector(-0.411509007215, -0.495965003967, -0.495519012213), + DenseVector(-0.40851598978, -0.49593898654, -0.495027005672), + DenseVector(-0.405084013939, -0.497224003077, -0.494318008423), + DenseVector(-0.402155995369, -0.498420000076, -0.493582010269), + DenseVector(-0.399185985327, -0.499316990376, -0.493566006422), + DenseVector(-0.396214991808, -0.499727994204, -0.494017004967), + DenseVector(-0.393094986677, -0.499821007252, -0.494278013706), + DenseVector(-0.389335989952, -0.499379009008, -0.494480013847), + DenseVector(-0.385125994682, -0.499267995358, -0.494628995657), + DenseVector(-0.380605995655, -0.499545991421, -0.495085000992), + DenseVector(-0.376213997602, -0.499879002571, -0.495617002249), + DenseVector(-0.372996985912, 0.499734997749, -0.496517002583), + DenseVector(-0.368934988976, 0.499749004841, -0.496690988541), + DenseVector(-0.363835990429, -0.499909996986, -0.496495991945), + DenseVector(-0.358395010233, 0.49980199337, -0.49607899785), + DenseVector(-0.353298008442, -0.499940007925, -0.495460003614), + DenseVector(-0.349240005016, -0.499356001616, -0.494697004557), + DenseVector(-0.345212012529, -0.499731004238, -0.494096010923), + DenseVector(-0.341008991003, 0.499749988317, -0.493512988091), + DenseVector(-0.336104005575, 0.498928010464, -0.49247199297), + DenseVector(-0.330855995417, 0.498306006193, -0.491232007742), + DenseVector(-0.32566100359, 0.498154014349, -0.490224003792), + DenseVector(-0.320849001408, 0.498154014349, -0.489493995905), + DenseVector(-0.316397994757, 0.49818199873, -0.488979011774), + DenseVector(-0.311291992664, 0.49848100543, -0.488063007593), + DenseVector(-0.30513599515, 0.498423010111, -0.487619996071), + DenseVector(-0.299059003592, 0.498239010572, -0.486963003874), + DenseVector(-0.295850992203, 0.497961014509, -0.486425995827), + DenseVector(-0.292504996061, 0.49786400795, -0.486220985651), + DenseVector(-0.287795990705, 0.496935009956, -0.486378014088), + DenseVector(-0.282094985247, 0.496926009655, -0.486101001501), + DenseVector(-0.230370000005, -0.423029005527, -0.190435007215), + DenseVector(-0.226144999266, -0.422674000263, -0.190456002951), + DenseVector(-0.221065998077, -0.422462999821, -0.190656006336), + DenseVector(-0.21570199728, -0.421921014786, -0.190736994147), + DenseVector(-0.211145997047, -0.421442002058, -0.190715998411), + DenseVector(-0.207176998258, -0.421692997217, -0.191337004304), + DenseVector(-0.202617004514, -0.421979010105, -0.192610993981), + DenseVector(-0.197987005115, -0.421979010105, -0.1939329952), + DenseVector(-0.193534001708, -0.42171099782, -0.195063993335), + DenseVector(-0.188442006707, -0.421469986439, -0.195659995079), + DenseVector(-0.18351200223, -0.421350002289, -0.196327000856), + DenseVector(-0.178878992796, -0.421176999807, -0.196639999747), + DenseVector(-0.173997998238, -0.420922011137, -0.19677400589), + DenseVector(-0.17026899755, -0.420855998993, -0.196733996272), + DenseVector(-0.166736006737, -0.420551985502, -0.196759000421), + DenseVector(-0.16250500083, -0.420587986708, -0.19698600471), + DenseVector(-0.158608004451, -0.420758008957, -0.196684002876), + DenseVector(-0.154406994581, -0.420715004206, -0.196183994412), + DenseVector(-0.15014000237, -0.420192986727, -0.1962479949), + DenseVector(-0.145583003759, -0.419409006834, -0.196958005428), + DenseVector(-0.141097992659, -0.41882699728, -0.197107002139), + DenseVector(-0.13644400239, -0.418215990067, -0.196890994906), + DenseVector(-0.132035002112, -0.417602986097, -0.196718007326), + DenseVector(-0.128143996, -0.417082995176, -0.196645006537), + DenseVector(-0.124609999359, -0.416640013456, -0.196575000882), + DenseVector(-0.12135899812, -0.416545987129, -0.196443006396), + DenseVector(-0.11831600219, -0.416736006737, -0.196152001619), + DenseVector(-0.114499002695, -0.416723996401, -0.195667997003), + DenseVector(-0.110071003437, -0.416583001614, -0.195353999734), + DenseVector(-0.105696000159, -0.416215986013, -0.195015996695), + DenseVector(-0.101567000151, -0.415634006262, -0.194840997458), + DenseVector(-0.0976777970791, -0.415030002594, -0.194831997156), + DenseVector(-0.0947626978159, -0.414595991373, -0.195255994797), + DenseVector(-0.0925178974867, -0.414178013802, -0.195669993758), + DenseVector(-0.0899709016085, -0.413747012615, -0.195713996887), + DenseVector(-0.0869152024388, -0.413572013378, -0.195683002472), + DenseVector(-0.0834548026323, -0.413212001324, -0.195618003607), + DenseVector(-0.0799069032073, -0.412741005421, -0.195555001497), + DenseVector(-0.0765667036176, -0.412616014481, -0.195696994662), + DenseVector(-0.0730601996183, -0.412665009499, -0.195963993669), + DenseVector(-0.0695542991161, -0.412683993578, -0.196098998189), + DenseVector(-0.0661773011088, -0.412420988083, -0.196201995015), + DenseVector(-0.062273401767, -0.412048995495, -0.196441993117), + DenseVector(-0.05775950104, -0.412072986364, -0.196572005749), + DenseVector(-0.0543152987957, -0.412909001112, -0.196082994342), + DenseVector(-0.0515625998378, -0.413266003132, -0.195540994406), + DenseVector(-0.0482833012938, -0.413659006357, -0.195500999689), + DenseVector(-0.0447212010622, -0.413929998875, -0.195748001337), + DenseVector(-0.0415252000093, -0.413904994726, -0.195501998067), + DenseVector(-0.0375672988594, -0.413911998272, -0.194977998734), + DenseVector(-0.0325004011393, -0.413509994745, -0.194503992796), + DenseVector(-0.0276785008609, -0.412813007832, -0.194422006607), + DenseVector(-0.0232041999698, -0.412286996841, -0.194086000323), + DenseVector(-0.0188629999757, -0.412007004023, -0.193660005927), + DenseVector(-0.0146391997114, -0.411799997091, -0.193238005042), + DenseVector(0.438068985939, -0.423878014088, -0.193721994758), + DenseVector(0.441168993711, -0.423072993755, -0.193834006786), + DenseVector(0.445264995098, -0.422354012728, -0.193957000971), + DenseVector(0.449609994888, -0.421038001776, -0.193471997976), + DenseVector(0.452950000763, -0.419548988342, -0.194122001529), + DenseVector(0.455969005823, -0.418231010437, -0.19481100142), + DenseVector(0.458950012922, -0.417178988457, -0.195061996579), + DenseVector(0.462146013975, -0.416909009218, -0.195247992873), + DenseVector(0.466147005558, -0.417118012905, -0.195875003934), + DenseVector(0.470245987177, -0.417659014463, -0.196611002088), + DenseVector(0.474249005318, -0.41837900877, -0.197458997369), + DenseVector(0.478522986174, -0.419180989265, -0.197898998857), + DenseVector(0.482955992222, -0.419600009918, -0.198066994548), + DenseVector(0.487857013941, -0.419793009758, -0.198157995939), + DenseVector(0.492332011461, -0.420217990875, -0.198266997933), + DenseVector(0.49594399333, -0.421918988228, -0.19885699451), + DenseVector(0.49856698513, -0.421321004629, -0.199806004763), + DenseVector(-0.49766099453, -0.419916987419, -0.200901001692), + DenseVector(-0.493865013123, -0.416572988033, -0.201345995069), + DenseVector(-0.491010010242, -0.417364001274, -0.201638996601), + DenseVector(-0.488465994596, -0.41782400012, -0.202114000916), + DenseVector(-0.486595988274, -0.418282985687, -0.202690005302), + DenseVector(-0.484320014715, -0.418422996998, -0.203033000231), + DenseVector(-0.481851994991, -0.418476998806, -0.202452003956), + DenseVector(-0.479981005192, -0.418624013662, -0.201403006911), + DenseVector(-0.47786000371, -0.418918997049, -0.200885996222), + DenseVector(-0.476034998894, -0.419658988714, -0.200096994638), + DenseVector(-0.473533004522, -0.420367002487, -0.199972003698), + DenseVector(-0.470723986626, -0.421427994967, -0.199258998036), + DenseVector(-0.467518001795, -0.421974986792, -0.199303001165), + DenseVector(-0.463970988989, -0.422064006329, -0.199428007007), + DenseVector(-0.459452986717, -0.422127008438, -0.199609994888), + DenseVector(-0.45430201292, -0.422224998474, -0.199546992779), + DenseVector(-0.44893398881, -0.421568006277, -0.199607998133), + DenseVector(-0.443767011166, -0.421770989895, -0.199814006686), + DenseVector(-0.438787996769, -0.421896994114, -0.199639007449), + DenseVector(-0.43403300643, -0.421761006117, -0.199591994286), + DenseVector(-0.429554998875, -0.421952992678, -0.199662998319), + DenseVector(-0.425689995289, -0.422435998917, -0.200434997678), + DenseVector(-0.422215998173, -0.423157989979, -0.20145599544), + DenseVector(-0.418269991875, -0.423471987247, -0.202366992831), + DenseVector(-0.414126008749, -0.42369300127, -0.203261002898), + DenseVector(-0.411013990641, -0.423835992813, -0.203920006752), + DenseVector(-0.408345997334, -0.423550993204, -0.204453006387), + DenseVector(-0.406082987785, -0.422883987427, -0.204586997628), + DenseVector(-0.403436988592, -0.422601014376, -0.204478994012), + DenseVector(-0.399006009102, -0.423094987869, -0.203880995512), + DenseVector(-0.39403501153, -0.422764986753, -0.202616006136), + DenseVector(-0.389073014259, -0.422423005104, -0.201444000006), + DenseVector(-0.384308993816, -0.422013998032, -0.201012000442), + DenseVector(-0.379889011383, -0.421714991331, -0.201112002134), + DenseVector(-0.375250011683, -0.420976996422, -0.201361998916), + DenseVector(-0.371003001928, -0.420338004827, -0.201533004642), + DenseVector(-0.366775989532, -0.420608013868, -0.201326996088), + DenseVector(-0.362919986248, -0.421036988497, -0.20096899569), + DenseVector(-0.358947008848, -0.421590000391, -0.201083004475) + + ) + + /* + * For reliable checking these expected-vectors are the output of Twister (another iterative + * framework) + * Reference: http://www.iterativemapreduce.org/ + */ + val expectedCentroids = Seq[LabeledVector]( + LabeledVector(1, DenseVector(-0.37971876676276917, 0.4979574657403462, -0.4891930004726923)), + LabeledVector(6, DenseVector(-0.28812266733768305, -0.4380759022409115, -0.2696436452528952)), + LabeledVector(8, DenseVector(0.46770288137823535, -0.4198470028007058, -0.1961898225195882)) + ) + + /* + * Contains points with their expected label. + */ + val testData = Seq[LabeledVector]( + LabeledVector(1, DenseVector(-0.37971876676276917, 0.4979574657403462, -0.4891930004726923)), + LabeledVector(6, DenseVector(-0.28812266733768305, -0.4380759022409115, -0.2696436452528952)), + LabeledVector(8, DenseVector(0.46770288137823535, -0.4198470028007058, -0.1961898225195882)), + LabeledVector(1, DenseVector(-0.4, 0.5, -0.5)), + LabeledVector(6, DenseVector(-0.3, -0.45, -0.27)), + LabeledVector(8, DenseVector(0.48, -0.42, -0.2)), + LabeledVector(1, DenseVector(-0.3, 0.47, -0.4)), + LabeledVector(6, DenseVector(-0.25, -0.4, -0.2)), + LabeledVector(8, DenseVector(0.5, -0.4, -0.25)), + LabeledVector(1, DenseVector(-0.28, 0.6, -0.5)), + LabeledVector(6, DenseVector(-0.2, -0.5, -0.2)), + LabeledVector(8, DenseVector(0.6, -0.4, -0.1)) + ) + + /** Finds the nearest cluster center from the centroid set to the given vector + * Returns the distance to the nearest center and the label of the nearest center + * + */ + def MinClusterDistance( + vec: Vector, + centroids: Seq[LabeledVector]) + : (Double, Double) = { + var minDistance: Double = Double.MaxValue + var closestCentroidLabel: Double = -1 + centroids.foreach { c => + val distance = EuclideanDistanceMetric().distance(vec, c.vector) + if (distance < minDistance) { + minDistance = distance + closestCentroidLabel = c.label + } + } + (minDistance, closestCentroidLabel) + } +} + diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/clustering/KMeansITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/clustering/KMeansITSuite.scala new file mode 100644 index 0000000000000..4328fe9e3942d --- /dev/null +++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/clustering/KMeansITSuite.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.clustering + +import org.apache.flink.api.scala._ +import org.apache.flink.ml._ +import org.apache.flink.ml.math +import org.apache.flink.ml.math.DenseVector +import org.apache.flink.test.util.FlinkTestBase +import org.scalatest.{FlatSpec, Matchers} + +class KMeansITSuite extends FlatSpec with Matchers with FlinkTestBase { + + behavior of "The KMeans implementation" + + def fixture = new { + val env = ExecutionEnvironment.getExecutionEnvironment + val kmeans = KMeans(). + setInitialCentroids(ClusteringData.centroidData). + setNumIterations(ClusteringData.iterations) + + val trainingDS = env.fromCollection(ClusteringData.trainingData) + + kmeans.fit(trainingDS) + } + + it should "cluster data points into 'K' cluster centers" in { + val f = fixture + + val centroidsResult = f.kmeans.centroids.get.collect().apply(0) + + val centroidsExpected = ClusteringData.expectedCentroids + + // the sizes must match + centroidsResult.length should be === centroidsExpected.length + + // create a lookup table for better matching + val expectedMap = centroidsExpected map (e => e.label->e.vector.asInstanceOf[DenseVector]) toMap + + // each of the results must be in lookup table + centroidsResult.iterator.foreach(result => { + val expectedVector = expectedMap.get(result.label).get + + // the type must match (not None) + expectedVector shouldBe a [math.DenseVector] + + val expectedData = expectedVector.asInstanceOf[DenseVector].data + val resultData = result.vector.asInstanceOf[DenseVector].data + + // match the individual values of the vector + expectedData zip resultData foreach { + case (expectedVector, entryVector) => + entryVector should be(expectedVector +- 0.00001) + } + }) + } + + it should "predict points to cluster centers" in { + val f = fixture + + val vectorsWithExpectedLabels = ClusteringData.testData + // create a lookup table for better matching + val expectedMap = vectorsWithExpectedLabels map (v => + v.vector.asInstanceOf[DenseVector] -> v.label + ) toMap + + // calculate the vector to cluster mapping on the plain vectors + val plainVectors = vectorsWithExpectedLabels.map(v => v.vector) + val predictedVectors = f.kmeans.predict(f.env.fromCollection(plainVectors)) + + // check if all vectors were labeled correctly + predictedVectors.collect() foreach (result => { + val expectedLabel = expectedMap.get(result._1.asInstanceOf[DenseVector]).get + result._2 should be(expectedLabel) + }) + + } + + it should "initialize k cluster centers randomly" in { + + val env = ExecutionEnvironment.getExecutionEnvironment + val kmeans = KMeans() + .setNumClusters(10) + .setNumIterations(ClusteringData.iterations) + .setInitializationStrategy("random") + + val trainingDS = env.fromCollection(ClusteringData.trainingData) + kmeans.fit(trainingDS) + + println(trainingDS.mapWithBcVariable(kmeans.centroids.get) { + (vector, centroid) => Math.pow(ClusteringData.MinClusterDistance(vector, centroid)._1, 2) + }.reduce(_ + _).collect().toArray.apply(0)) + } + + it should "initialize k cluster centers using kmeans++" in { + + val env = ExecutionEnvironment.getExecutionEnvironment + val kmeans = KMeans() + .setNumClusters(10) + .setNumIterations(ClusteringData.iterations) + .setInitializationStrategy("kmeans++") + + val trainingDS = env.fromCollection(ClusteringData.trainingData) + kmeans.fit(trainingDS) + + println(trainingDS.mapWithBcVariable(kmeans.centroids.get) { + (vector, centroid) => Math.pow(ClusteringData.MinClusterDistance(vector, centroid)._1, 2) + }.reduce(_ + _).collect().toArray.apply(0)) + } + + it should "initialize k cluster using kmeans||" in { + + val env = ExecutionEnvironment.getExecutionEnvironment + val kmeans = KMeans() + .setNumClusters(10) + .setNumIterations(ClusteringData.iterations) + .setInitializationStrategy("kmeans||") + + val trainingDS = env.fromCollection(ClusteringData.trainingData) + kmeans.fit(trainingDS) + + println(trainingDS.mapWithBcVariable(kmeans.centroids.get) { + (vector, centroid) => Math.pow(ClusteringData.MinClusterDistance(vector, centroid)._1, 2) + }.reduce(_ + _).collect().toArray.apply(0)) + } +}