From 09d7fd5c57dfe6ea0d04e0254b3a1826e130f5e6 Mon Sep 17 00:00:00 2001
From: Thomas FOURNIER
Date: Sun, 23 Oct 2016 16:04:14 +0200
Subject: [PATCH] [Flink-4865] [ml] Add EvaluateDataSet operation for LabeledVector

This closes #2684.
---
 .../apache/flink/ml/pipeline/Predictor.scala  | 52 ++++++++++++++++++-
 .../flink/ml/classification/SVMITSuite.scala  | 26 ++++++++++
 2 files changed, 77 insertions(+), 1 deletion(-)

diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Predictor.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Predictor.scala
index 9d11cff9e933c..1a7fd1a9e7d79 100644
--- a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Predictor.scala
+++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Predictor.scala
@@ -22,7 +22,7 @@
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.scala._
 import org.apache.flink.ml._
-import org.apache.flink.ml.common.{FlinkMLTools, ParameterMap, WithParameters}
+import org.apache.flink.ml.common.{LabeledVector, FlinkMLTools, ParameterMap, WithParameters}
 
 /** Predictor trait for Flink's pipeline operators.
   *
@@ -172,6 +172,56 @@ object Predictor {
       }
     }
   }
+
+  /** [[EvaluateDataSetOperation]] for test sets given as a `DataSet[LabeledVector]`.
+    *
+    * Pairs the label of every element with the value that the implicit [[PredictOperation]]
+    * predicts for the element's vector. Unlike [[defaultEvaluateDataSetOperation]], which
+    * works on `(Vector, Double)` tuples, the labeled vectors are consumed directly.
+    *
+    * Note: label and vector are cast unchecked to `PredictionValue` and `Testing`, so the
+    * predictor's types must be compatible with the ones stored in [[LabeledVector]].
+    *
+    * @param predictOperation supplies the model (`getModel`) and the prediction (`predict`)
+    * @param testingTypeInformation [[TypeInformation]] of the predictor's input type
+    * @param predictionValueTypeInformation [[TypeInformation]] of the predicted value type
+    * @tparam Instance type of the [[Estimator]] being evaluated
+    * @tparam Model model type of the [[PredictOperation]]
+    * @tparam Testing input type of `predict`; every `LabeledVector.vector` is cast to it
+    * @tparam PredictionValue result type of `predict`; every `LabeledVector.label` is cast to it
+    * @return operation emitting one `(true label, predicted label)` pair per element
+    */
+  implicit def labeledVectorEvaluateDataSetOperation[
+      Instance <: Estimator[Instance],
+      Model,
+      Testing,
+      PredictionValue](
+      implicit predictOperation: PredictOperation[Instance, Model, Testing, PredictionValue],
+      testingTypeInformation: TypeInformation[Testing],
+      predictionValueTypeInformation: TypeInformation[PredictionValue])
+    : EvaluateDataSetOperation[Instance, LabeledVector, PredictionValue] = {
+    new EvaluateDataSetOperation[Instance, LabeledVector, PredictionValue] {
+      override def evaluateDataSet(
+          instance: Instance,
+          evaluateParameters: ParameterMap,
+          testing: DataSet[LabeledVector])
+        : DataSet[(PredictionValue, PredictionValue)] = {
+        val resultingParameters = instance.parameters ++ evaluateParameters
+        val model = predictOperation.getModel(instance, resultingParameters)
+
+        // Type information of the emitted pairs, i.e. of the output type of the map below.
+        implicit val resultTypeInformation =
+          createTypeInformation[(PredictionValue, PredictionValue)]
+
+        testing.mapWithBcVariable(model) { (element, broadcastModel) =>
+          // Casts to type parameters are erased, so a mismatch is not detected here.
+          (element.label.asInstanceOf[PredictionValue],
+            predictOperation.predict(element.vector.asInstanceOf[Testing], broadcastModel))
+        }
+      }
+    }
+  }
+
 }
 
 /** Type class for the predict operation of [[Predictor]]. This predict operation works on DataSets.
diff --git a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala
index af17451d066a3..d642316c740d0 100644
--- a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala
+++ b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala
@@ -51,6 +51,32 @@ class SVMITSuite extends FlatSpec with Matchers with FlinkTestBase {
     }
   }
 
+  it should "evaluate with LabeledDataPoint" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    val classifier = SVM()
+      .setBlocks(env.getParallelism)
+      .setIterations(100)
+      .setLocalIterations(100)
+      .setRegularization(0.002)
+      .setStepsize(0.1)
+      .setSeed(0)
+
+    // The training collection doubles as the evaluation set and is handed to evaluate
+    // as-is, without first mapping its elements to tuples.
+    val labeledPoints = env.fromCollection(Classification.trainingData)
+
+    classifier.fit(labeledPoints)
+
+    val absoluteErrors = classifier.evaluate(labeledPoints).collect().map {
+      case (truth, prediction) => Math.abs(truth - prediction)
+    }
+
+    // Upper bound on the absolute error summed over every evaluated element.
+    val totalAbsoluteError = absoluteErrors.sum
+    totalAbsoluteError should be < 15.0
+  }
+
   it should "make (mostly) correct predictions" in {
     val env = ExecutionEnvironment.getExecutionEnvironment
 