diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/clustering/Clustering.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/clustering/Clustering.scala index 82ad577b8739c..4510a325d30d8 100644 --- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/clustering/Clustering.scala +++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/clustering/Clustering.scala @@ -235,4 +235,22 @@ object Clustering { LabeledVector(6, DenseVector(-0.28812266733768305, -0.4380759022409115, -0.2696436452528952)), LabeledVector(8, DenseVector(0.46770288137823535, -0.4198470028007058, -0.1961898225195882)) ) + + /* + * Points paired with the label the KMeans prediction test expects for each of them. + */ + val testData = Seq[LabeledVector]( + LabeledVector(1, DenseVector(-0.37971876676276917, 0.4979574657403462, -0.4891930004726923)), + LabeledVector(6, DenseVector(-0.28812266733768305, -0.4380759022409115, -0.2696436452528952)), + LabeledVector(8, DenseVector(0.46770288137823535, -0.4198470028007058, -0.1961898225195882)), + LabeledVector(1, DenseVector(-0.4, 0.5, -0.5)), + LabeledVector(6, DenseVector(-0.3, -0.45, -0.27)), + LabeledVector(8, DenseVector(0.48, -0.42, -0.2)), + LabeledVector(1, DenseVector(-0.3, 0.47, -0.4)), + LabeledVector(6, DenseVector(-0.25, -0.4, -0.2)), + LabeledVector(8, DenseVector(0.5, -0.4, -0.25)), + LabeledVector(1, DenseVector(-0.28, 0.6, -0.5)), + LabeledVector(6, DenseVector(-0.2, -0.5, -0.2)), + LabeledVector(8, DenseVector(0.6, -0.4, -0.1)) + ) } diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/clustering/KMeansITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/clustering/KMeansITSuite.scala index 8fc160ef8c687..644312e59ab94 100644 --- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/clustering/KMeansITSuite.scala +++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/clustering/KMeansITSuite.scala @@ -24,24 +24,25 @@ import org.apache.flink.ml.math.DenseVector import 
org.apache.flink.test.util.FlinkTestBase import org.scalatest.{FlatSpec, Matchers} -import scala.collection.mutable - class KMeansITSuite extends FlatSpec with Matchers with FlinkTestBase { behavior of "The KMeans implementation" - it should "data points are clustered into 'K' cluster centers" in { - + def fixture = new { val env = ExecutionEnvironment.getExecutionEnvironment - - val learner = KMeans(). + val kmeans = KMeans(). setInitialCentroids(env.fromCollection(Clustering.centroidData)). setNumIterations(Clustering.iterations) val trainingDS = env.fromCollection(Clustering.trainingData) - val model = learner.fit(trainingDS) - val centroidsResult = model.centroids.collect() + kmeans.fit(trainingDS) + } + + it should "data points are clustered into 'K' cluster centers" in { + val f = fixture + + val centroidsResult = f.kmeans.centroids.get.collect() val centroidsExpected = Clustering.expectedCentroids @@ -67,6 +68,26 @@ class KMeansITSuite extends FlatSpec with Matchers with FlinkTestBase { entryVector should be(expectedVector +- 0.00001) } }) + } + + it should "predict points to cluster centers" in { + val f = fixture + + val vectorsWithExpectedLabels = Clustering.testData + // look-up table from each test vector to its expected label + val expectedMap = vectorsWithExpectedLabels.map(v => + v.vector.asInstanceOf[DenseVector] -> v.label + ).toMap + + // strip the labels and let the fitted model label the plain vectors + val plainVectors = vectorsWithExpectedLabels.map(v => v.vector) + val predictedVectors = f.kmeans.predict(f.env.fromCollection(plainVectors)) + + // check that every returned prediction carries the expected label of its vector + predictedVectors.collect() foreach (result => { + val expectedLabel = expectedMap(result.vector.asInstanceOf[DenseVector]) + result.label should be(expectedLabel) + }) }