Skip to content

Commit

Permalink
[FLINK-1731] [ml] Migrated K-Means implementation to new ml pipeline …
Browse files Browse the repository at this point in the history
…interfaces
  • Loading branch information
FGoessler committed Jun 3, 2015
1 parent d6a9e71 commit 83d4aba
Showing 1 changed file with 78 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.flink.ml.common.{LabeledVector, _}
import org.apache.flink.ml.math.Breeze._
import org.apache.flink.ml.math.Vector
import org.apache.flink.ml.metrics.distances.EuclideanDistanceMetric
import org.apache.flink.ml.pipeline._

import scala.collection.JavaConverters._

Expand All @@ -34,10 +35,8 @@ import scala.collection.JavaConverters._
* Implements the KMeans algorithm which calculates cluster centroids based on a set of training
* data points and a set of k initial centroids.
*
* [[org.apache.flink.ml.clustering.KMeans]] is a [[org.apache.flink.ml.common.Learner]] which
* needs to be trained on a set of data points and emits a
* [[org.apache.flink.ml.clustering.KMeansModel]] which is a
* [[org.apache.flink.ml.common.Transformer]] to assign new points to the learned cluster centroids.
* [[KMeans]] is a [[Predictor]] which needs to be trained on a set of data points and can then be
* used to assign new points to the learned cluster centroids.
*
* The KMeans algorithm works as described on Wikipedia
* (http://en.wikipedia.org/wiki/K-means_clustering):
Expand Down Expand Up @@ -74,33 +73,39 @@ import scala.collection.JavaConverters._
* .setInitialCentroids(initialCentroids)
* .setNumIterations(10)
*
* val model = kmeans.fit(trainingDS)
* kmeans.fit(trainingDS)
*
* val testDS: DataSet[Vector] = env.fromCollection(Clustering.testData)
* // getting the computed centroids
* val centroidsResult = kmeans.centroids.get.collect()
*
* val clusters: DataSet[LabeledVector] = model.transform(testDS)
* // get matching clusters for new points
* val testDS: DataSet[Vector] = env.fromCollection(Clustering.testData)
* val clusters: DataSet[LabeledVector] = kmeans.predict(testDS)
* }}}
*
* =Parameters=
*
* - [[KMeans.NumIterations]]:
* - [[org.apache.flink.ml.clustering.KMeans.NumIterations]]:
* Defines the number of iterations to recalculate the centroids of the clusters. As it
* is a heuristic algorithm, there is no guarantee that it will converge to the global optimum. The
* recalculation of the cluster centroids and the reassignment of the data points are repeated
* until the given number of iterations is reached.
* (Default value: '''10''')
*
* - [[KMeans.InitialCentroids]]:
* - [[org.apache.flink.ml.clustering.KMeans.InitialCentroids]]:
* Defines the initial k centroids of the k clusters. They are used as the starting point of the
* algorithm for clustering the data set. The centroids are recalculated as often as set in
* [[KMeans.NumIterations]]. The choice of the initial centroids mainly affects the outcome of the
* algorithm.
* [[org.apache.flink.ml.clustering.KMeans.NumIterations]]. The choice of the initial centroids
* mainly affects the outcome of the algorithm.
*
*/
class KMeans extends Learner[Vector, KMeansModel] with Serializable {
class KMeans extends Predictor[KMeans] {

import KMeans._

/** Stores the learned clusters after the fit operation */
var centroids: Option[DataSet[LabeledVector]] = None

/**
* Sets the number of iterations.
*
Expand All @@ -124,41 +129,11 @@ class KMeans extends Learner[Vector, KMeansModel] with Serializable {
this
}

/**
 * Iteratively computes centroids that match the given input DataSet by adjusting the given
 * initial centroids.
 *
 * The instance parameters are merged with `fitParameters` (using `++`) before the initial
 * centroids and the number of iterations are looked up.
 *
 * @param input Training data set
 * @param fitParameters Parameter values
 * @return Trained KMeans Model which represents the final centroids.
 * @throws NoSuchElementException if no initial centroids or no number of iterations is
 *                                available in the merged parameters
 */
override def fit(input: DataSet[Vector], fitParameters: ParameterMap): KMeansModel = {
  val resultingParameters = this.parameters ++ fitParameters

  // Fail with a descriptive message instead of an opaque `None.get` when a value is missing.
  val centroids: DataSet[LabeledVector] = resultingParameters.get(InitialCentroids)
    .getOrElse(throw new NoSuchElementException(
      "KMeans requires initial centroids. Set them before calling fit."))
  val numIterations: Int = resultingParameters.get(NumIterations)
    .getOrElse(throw new NoSuchElementException(
      "KMeans requires the number of iterations to be set before calling fit."))

  val finalCentroids = centroids.iterate(numIterations) { currentCentroids =>
    val newCentroids: DataSet[LabeledVector] = input
      // label every point using the current centroids, which are shipped as a broadcast set
      .map(new SelectNearestCenterMapper).withBroadcastSet(currentCentroids, CENTROIDS)
      // attach a count of 1.0 so that sums and counts can be aggregated together
      .map(x => (x.label, x.vector, 1.0)).withForwardedFields("label->_1; vector->_2")
      .groupBy(x => x._1)
      // per label: sum up the vectors and the counts
      .reduce((p1, p2) => (p1._1, (p1._2.asBreeze + p2._2.asBreeze).fromBreeze, p1._3 + p2._3))
      .withForwardedFields("_1")
      // new centroid = vector sum divided element-wise by the number of points
      .map(x => LabeledVector(x._1, (x._2.asBreeze :/ x._3).fromBreeze))
      .withForwardedFields("_1->label")

    newCentroids
  }

  KMeansModel(finalCentroids)
}

}

/**
* Companion object of KMeans. Contains convenience functions and the parameter type definitions
* of the algorithm.
* Companion object of KMeans. Contains convenience functions, the parameter type definitions
* of the algorithm and the [[FitOperation]] & [[PredictOperation]].
*/
object KMeans {
val CENTROIDS = "centroids"
Expand All @@ -171,27 +146,71 @@ object KMeans {
val defaultValue = None
}

// ========================================== Factory methods ====================================

def apply(): KMeans = {
new KMeans()
}

}
// ========================================== Operations =========================================

/**
* The resulting model of final centroids after the KMeans algorithm. Can be used to determine to
* which centroid a vector belongs.
*
* @param centroids The learned centroids based on the training data.
*/

case class KMeansModel(centroids: DataSet[LabeledVector]) extends Transformer[Vector, LabeledVector]
with Serializable {

import KMeans._
/**
 * [[PredictOperation]] for vector types. The result type is a [[LabeledVector]].
 *
 * Every input vector is passed through a [[SelectNearestCenterMapper]] which receives the
 * centroids stored on the [[KMeans]] instance as a broadcast set named by [[CENTROIDS]].
 *
 * The operation throws a [[RuntimeException]] if the instance holds no centroids, i.e. it has
 * not been fitted yet.
 */
implicit def predictValues: PredictOperation[KMeans, Vector, LabeledVector] = {
  new PredictOperation[KMeans, Vector, LabeledVector] {
    override def predict(
        instance: KMeans,
        predictParameters: ParameterMap,
        input: DataSet[Vector])
      : DataSet[LabeledVector] = {

      instance.centroids match {
        case Some(centroids) =>
          input.map(new SelectNearestCenterMapper).withBroadcastSet(centroids, CENTROIDS)

        case None =>
          // The original message was missing the blank between its two concatenated parts.
          throw new RuntimeException("The KMeans model has not been trained. Call fit " +
            "before calling the predict operation.")
      }
    }
  }
}

override def transform(input: DataSet[Vector], parameters: ParameterMap):
DataSet[LabeledVector] = {
input.map(new SelectNearestCenterMapper).withBroadcastSet(centroids, CENTROIDS)
/**
 * [[FitOperation]] which iteratively computes centroids that match the given input DataSet by
 * adjusting the given initial centroids.
 *
 * The instance parameters are merged with `fitParameters` (using `++`) before the initial
 * centroids and the number of iterations are looked up. The resulting centroids are stored in
 * `instance.centroids`; nothing is returned.
 *
 * The operation throws a [[NoSuchElementException]] if no initial centroids or no number of
 * iterations is available in the merged parameters.
 */
implicit def fitKMeans: FitOperation[KMeans, Vector] = {
  new FitOperation[KMeans, Vector] {
    override def fit(
        instance: KMeans,
        fitParameters: ParameterMap,
        input: DataSet[Vector])
      : Unit = {
      val resultingParameters = instance.parameters ++ fitParameters

      // Fail with a descriptive message instead of an opaque `None.get` when a value is missing.
      val centroids: DataSet[LabeledVector] = resultingParameters.get(InitialCentroids)
        .getOrElse(throw new NoSuchElementException(
          "KMeans requires initial centroids. Set them before calling fit."))
      val numIterations: Int = resultingParameters.get(NumIterations)
        .getOrElse(throw new NoSuchElementException(
          "KMeans requires the number of iterations to be set before calling fit."))

      val finalCentroids = centroids.iterate(numIterations) { currentCentroids =>
        val newCentroids: DataSet[LabeledVector] = input
          // label every point using the current centroids, shipped as a broadcast set
          .map(new SelectNearestCenterMapper).withBroadcastSet(currentCentroids, CENTROIDS)
          // attach a count of 1.0 so that sums and counts can be aggregated together
          .map(x => (x.label, x.vector, 1.0)).withForwardedFields("label->_1; vector->_2")
          .groupBy(x => x._1)
          // per label: sum up the vectors and the counts
          .reduce((p1, p2) =>
            (p1._1, (p1._2.asBreeze + p2._2.asBreeze).fromBreeze, p1._3 + p2._3))
          .withForwardedFields("_1")
          // new centroid = vector sum divided element-wise by the number of points
          .map(x => LabeledVector(x._1, (x._2.asBreeze :/ x._3).fromBreeze))
          .withForwardedFields("_1->label")

        newCentroids
      }

      instance.centroids = Some(finalCentroids)
    }
  }
}

}
Expand Down

0 comments on commit 83d4aba

Please sign in to comment.