Skip to content

Commit

Permalink
[FLINK-1731] [ml] adjusted unit test for KMeans for the new ml pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
FGoessler committed Jun 24, 2015
1 parent a351b79 commit 292ec0b
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -235,4 +235,22 @@ object Clustering {
LabeledVector(6, DenseVector(-0.28812266733768305, -0.4380759022409115, -0.2696436452528952)),
LabeledVector(8, DenseVector(0.46770288137823535, -0.4198470028007058, -0.1961898225195882))
)

/*
 * Labeled points used for testing: every vector is paired with the label it is
 * expected to carry. The labels 1, 6 and 8 each occur four times.
 */
val testData: Seq[LabeledVector] = Seq(
  1.0 -> DenseVector(-0.37971876676276917, 0.4979574657403462, -0.4891930004726923),
  6.0 -> DenseVector(-0.28812266733768305, -0.4380759022409115, -0.2696436452528952),
  8.0 -> DenseVector(0.46770288137823535, -0.4198470028007058, -0.1961898225195882),
  1.0 -> DenseVector(-0.4, 0.5, -0.5),
  6.0 -> DenseVector(-0.3, -0.45, -0.27),
  8.0 -> DenseVector(0.48, -0.42, -0.2),
  1.0 -> DenseVector(-0.3, 0.47, -0.4),
  6.0 -> DenseVector(-0.25, -0.4, -0.2),
  8.0 -> DenseVector(0.5, -0.4, -0.25),
  1.0 -> DenseVector(-0.28, 0.6, -0.5),
  6.0 -> DenseVector(-0.2, -0.5, -0.2),
  8.0 -> DenseVector(0.6, -0.4, -0.1)
).map { case (expectedLabel, point) => LabeledVector(expectedLabel, point) }
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,25 @@ import org.apache.flink.ml.math.DenseVector
import org.apache.flink.test.util.FlinkTestBase
import org.scalatest.{FlatSpec, Matchers}

import scala.collection.mutable

class KMeansITSuite extends FlatSpec with Matchers with FlinkTestBase {

behavior of "The KMeans implementation"

it should "data points are clustered into 'K' cluster centers" in {

// Test fixture. Because `fixture` is a `def`, every call builds a new anonymous
// instance: it obtains an execution environment, configures a KMeans instance with
// the initial centroids (`Clustering.centroidData`) and the iteration count
// (`Clustering.iterations`), and fits it on `Clustering.trainingData`.
// NOTE(review): this region comes from a diff view whose +/- markers are missing; the
// `learner`, `model` and `centroidsResult` lines are presumably the removed pre-commit
// variants of the neighbouring `kmeans` lines — confirm against the repository.
def fixture = new {
val env = ExecutionEnvironment.getExecutionEnvironment

val learner = KMeans().
val kmeans = KMeans().
setInitialCentroids(env.fromCollection(Clustering.centroidData)).
setNumIterations(Clustering.iterations)

// Training input as a data set created from the in-memory collection.
val trainingDS = env.fromCollection(Clustering.trainingData)

val model = learner.fit(trainingDS)
val centroidsResult = model.centroids.collect()
// Fit as part of fixture construction, before the fixture is handed to a test.
kmeans.fit(trainingDS)
}

it should "data points are clustered into 'K' cluster centers" in {
val f = fixture

val centroidsResult = f.kmeans.centroids.get.collect()

val centroidsExpected = Clustering.expectedCentroids

Expand All @@ -67,6 +68,26 @@ class KMeansITSuite extends FlatSpec with Matchers with FlinkTestBase {
entryVector should be(expectedVector +- 0.00001)
}
})
}

// Checks that `predict` labels every point of `Clustering.testData` with the label
// recorded for that point in the test data.
it should "predict points to cluster centers" in {
  val f = fixture

  val vectorsWithExpectedLabels = Clustering.testData
  // Lookup table from each point to its expected label.
  val expectedLabels = vectorsWithExpectedLabels.map { labeled =>
    labeled.vector.asInstanceOf[DenseVector] -> labeled.label
  }.toMap

  // Run the prediction on the bare vectors, i.e. without their labels.
  val plainVectors = vectorsWithExpectedLabels.map(_.vector)
  val predictions = f.kmeans.predict(f.env.fromCollection(plainVectors)).collect()

  // Without this check an empty (or truncated) prediction result would let the test
  // pass, since the per-element assertions below would simply not run.
  predictions.size should be(plainVectors.size)

  // Every returned vector must be a known test point carrying its expected label.
  predictions.foreach { prediction =>
    expectedLabels.get(prediction.vector.asInstanceOf[DenseVector]) match {
      case Some(expectedLabel) =>
        prediction.label should be(expectedLabel)
      case None =>
        // Fail with a readable message instead of an opaque `None.get` exception.
        fail(s"predict returned a vector that is not part of the test data: ${prediction.vector}")
    }
  }
}

Expand Down

0 comments on commit 292ec0b

Please sign in to comment.