Skip to content

Commit

Permalink
[FLINK-1731] [ml] Implementation of K-Means
Browse files Browse the repository at this point in the history
  • Loading branch information
peterschrott authored and FGoessler committed May 20, 2015
1 parent 912f8d9 commit 71aa47b
Showing 1 changed file with 228 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.ml.clustering

import org.apache.flink.api.common.functions.RichMapFunction
import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFields
import org.apache.flink.api.scala.{DataSet, _}
import org.apache.flink.configuration.Configuration
import org.apache.flink.ml.common.{LabeledVector, _}
import org.apache.flink.ml.math.Breeze._
import org.apache.flink.ml.math.Vector
import org.apache.flink.ml.metrics.distances.EuclideanDistanceMetric

import scala.collection.JavaConverters._


/**
* Implements the KMeans algorithm which calculates cluster centroids based on a set of training
* data points and a set of k initial centroids.
*
* [[org.apache.flink.ml.clustering.KMeans]] is a [[org.apache.flink.ml.common.Learner]] which
* needs to be trained on a set of data points and emits a
* [[org.apache.flink.ml.clustering.KMeansModel]] which is a
* [[org.apache.flink.ml.common.Transformer]] to assign new points to the learned cluster centroids.
*
* The KMeans algorithm works as described on Wikipedia
* (http://en.wikipedia.org/wiki/K-means_clustering):
*
* Given an initial set of k means `m_1^(1), …, m_k^(1)` (see below), the algorithm proceeds by alternating
* between two steps:
*
* ===Assignment step:===
*
* Assign each observation to the cluster whose mean yields the least within-cluster sum of
* squares (WCSS). Since the sum of squares is the squared Euclidean distance, this is intuitively
* the "nearest" mean. (Mathematically, this means partitioning the observations according to the
* Voronoi diagram generated by the means).
*
* `S_i^(t) = { x_p : || x_p - m_i^(t) ||^2 ≤ || x_p - m_j^(t) ||^2 \forall j, 1 ≤ j ≤ k}`,
* where each `x_p` is assigned to exactly one `S^{(t)}`, even if it could be assigned to two or
* more of them.
*
* ===Update step:===
*
* Calculate the new means to be the centroids of the observations in the new clusters.
*
* `m^{(t+1)}_i = ( 1 / |S^{(t)}_i| ) \sum_{x_j \in S^{(t)}_i} x_j`
*
* Since the arithmetic mean is a least-squares estimator, this also minimizes the within-cluster
* sum of squares (WCSS) objective.
*
* @example
* {{{
* val trainingDS: DataSet[Vector] = env.fromCollection(Clustering.trainingData)
* val initialCentroids: DataSet[LabeledVector] = env.fromCollection(Clustering.initCentroids)
*
* val kmeans = KMeans()
* .setInitialCentroids(initialCentroids)
* .setNumIterations(10)
*
* val model = kmeans.fit(trainingDS)
*
* val testDS: DataSet[Vector] = env.fromCollection(Clustering.testData)
*
* val clusters: DataSet[LabeledVector] = model.transform(testDS)
* }}}
*
* =Parameters=
*
* - [[KMeans.NumIterations]]:
* Defines the number of iterations to recalculate the centroids of the clusters. As it
* is a heuristic algorithm, there is no guarantee that it will converge to the global optimum. The
* centroids of the clusters and the reassignment of the data points will be repeated till the
* given number of iterations is reached.
* (Default value: '''10''')
*
* - [[KMeans.InitialCentroids]]:
* Defines the initial k centroids of the k clusters. They are used as the starting point of the
* algorithm for clustering the data set. The centroids are recalculated as often as set in
* [[KMeans.NumIterations]]. The choice of the initial centroids has a major effect on the outcome
* of the algorithm.
*
*/
class KMeans extends Learner[Vector, KMeansModel] with Serializable {

  import KMeans._

  /**
   * Sets the number of iterations.
   *
   * @param numIterations How many times the centroids are recalculated during [[fit]].
   * @return itself, to allow chaining of setter calls
   */
  def setNumIterations(numIterations: Int): KMeans = {
    parameters.add(NumIterations, numIterations)
    this
  }

  /**
   * Sets the initial centroids on which the algorithm will start computing.
   * The choice of these points significantly influences the resulting centroids, so they
   * should be chosen with the training data in mind.
   *
   * @param initialCentroids A data set of labeled vectors; the label identifies the cluster.
   * @return itself, to allow chaining of setter calls
   */
  def setInitialCentroids(initialCentroids: DataSet[LabeledVector]): KMeans = {
    parameters.add(InitialCentroids, initialCentroids)
    this
  }

  /**
   * Iteratively computes centroids that match the given input DataSet by adjusting the given
   * initial centroids.
   *
   * In every iteration each input vector is labeled with its nearest current centroid, the
   * vectors are grouped by that label, and the mean of every group becomes the new centroid
   * carrying the same label.
   *
   * NOTE: a centroid to which no input vector is assigned forms no group and is therefore
   * absent from the centroid set of the following iterations.
   *
   * @param input Training data set
   * @param fitParameters Parameter values; they are combined with the parameters set on this
   *                      instance via `++`
   * @return Trained KMeans Model which represents the final centroids.
   * @throws NoSuchElementException if no value is available for [[KMeans.InitialCentroids]]
   *                                or [[KMeans.NumIterations]]
   */
  override def fit(input: DataSet[Vector], fitParameters: ParameterMap): KMeansModel = {
    val resultingParameters = this.parameters ++ fitParameters

    // InitialCentroids declares no default value, so a missing setting must be reported with
    // an actionable message instead of the bare "None.get" that Option#get would produce.
    val centroids: DataSet[LabeledVector] = resultingParameters.get(InitialCentroids).getOrElse(
      throw new NoSuchElementException(
        "KMeans requires initial centroids. Provide them via setInitialCentroids or through " +
          "the fit parameters."))

    val numIterations: Int = resultingParameters.get(NumIterations).getOrElse(
      throw new NoSuchElementException(
        "KMeans requires the number of iterations. Provide it via setNumIterations or " +
          "through the fit parameters."))

    val finalCentroids = centroids.iterate(numIterations) { currentCentroids =>
      val newCentroids: DataSet[LabeledVector] = input
        // Assignment step: label every point with its nearest current centroid.
        .map(new SelectNearestCenterMapper).withBroadcastSet(currentCentroids, CENTROIDS)
        // Attach a count of 1.0 so that the reduce below can sum vectors and counts together.
        .map(x => (x.label, x.vector, 1.0)).withForwardedFields("label->_1; vector->_2")
        .groupBy(x => x._1)
        // Per cluster label: element-wise vector sum and number of members.
        .reduce { (p1, p2) =>
          (p1._1, (p1._2.asBreeze + p2._2.asBreeze).fromBreeze, p1._3 + p2._3)
        }
        .withForwardedFields("_1")
        // Update step: the new centroid is the vector sum divided by the member count.
        .map(x => LabeledVector(x._1, (x._2.asBreeze :/ x._3).fromBreeze))
        .withForwardedFields("_1->label")

      newCentroids
    }

    KMeansModel(finalCentroids)
  }

}

/**
 * Companion object of [[KMeans]]. Hosts the factory method, the parameter definitions of the
 * algorithm and the name of the broadcast variable used to ship the centroids.
 */
object KMeans {

  /** Creates a new [[KMeans]] learner with no parameters set explicitly. */
  def apply(): KMeans = new KMeans()

  /** Name under which the centroid data set is registered as a broadcast set. */
  val CENTROIDS = "centroids"

  /** Parameter holding the number of centroid recalculations; its default value is 10. */
  case object NumIterations extends Parameter[Int] {
    val defaultValue = Some(10)
  }

  /** Parameter holding the initial centroids; it has no default value. */
  case object InitialCentroids extends Parameter[DataSet[LabeledVector]] {
    val defaultValue = None
  }

}

/**
 * Model produced by the KMeans algorithm, holding the final centroids. It labels vectors with
 * the label of the centroid they are closest to.
 *
 * @param centroids The learned centroids based on the training data.
 */
case class KMeansModel(centroids: DataSet[LabeledVector])
  extends Transformer[Vector, LabeledVector]
  with Serializable {

  /**
   * Labels every vector of `input` with the label of its nearest centroid.
   *
   * @param input Vectors to assign to the learned centroids
   * @param parameters Not read by this implementation
   * @return The input vectors, each wrapped in a labeled vector carrying the centroid label
   */
  override def transform(
      input: DataSet[Vector],
      parameters: ParameterMap): DataSet[LabeledVector] = {
    val nearestCenterMapper = new SelectNearestCenterMapper
    // The mapper reads the centroids from the broadcast set registered under this name.
    input.map(nearestCenterMapper).withBroadcastSet(centroids, KMeans.CENTROIDS)
  }

}

/**
 * Converts a given vector into a labeled vector where the label denotes the label of the closest
 * centroid. The centroids are read from the broadcast variable named by `KMeans.CENTROIDS`.
 */
@ForwardedFields(Array("*->vector"))
final class SelectNearestCenterMapper extends RichMapFunction[Vector, LabeledVector] {

  import KMeans._

  // Assigned in open(); remains unset until then.
  private var centroids: Traversable[LabeledVector] = _

  /** Reads the centroid values from a broadcast variable into a collection. */
  override def open(parameters: Configuration): Unit = {
    centroids = getRuntimeContext.getBroadcastVariable[LabeledVector](CENTROIDS).asScala
  }

  /**
   * Finds the centroid with the smallest Euclidean distance to `v`.
   *
   * @param v The vector to assign to a centroid
   * @return `v` labeled with the label of the nearest centroid. On equal distances the centroid
   *         encountered first wins. The label is -1 if no centroid has a distance below
   *         `Double.MaxValue`, which includes the case of an empty centroid collection.
   */
  override def map(v: Vector): LabeledVector = {
    // Create the metric once per call rather than once per centroid.
    val metric = EuclideanDistanceMetric()

    // Mutable locals are kept on purpose: this runs for every input record and avoids
    // allocating intermediate objects per centroid.
    var minDistance: Double = Double.MaxValue
    var closestCentroidLabel: Double = -1
    centroids.foreach { centroid =>
      val distance = metric.distance(v, centroid.vector)
      if (distance < minDistance) {
        minDistance = distance
        closestCentroidLabel = centroid.label
      }
    }
    LabeledVector(closestCentroidLabel, v)
  }

}

0 comments on commit 71aa47b

Please sign in to comment.