From fc6b5c18c0eb3e1983a2ea51a3c41c62d1c383a0 Mon Sep 17 00:00:00 2001 From: Till Rohrmann Date: Wed, 27 May 2015 20:08:19 +0200 Subject: [PATCH] [FLINK-2104] [ml] Fixes problem with type inference for fallback implicits where Nothing is not correctly treated (see SI-1570) --- .../apache/flink/ml/pipeline/Predictor.scala | 13 ++-- .../flink/ml/pipeline/Transformer.scala | 13 ++-- .../flink/ml/pipeline/PipelineITSuite.scala | 61 +++++++++++++++++-- 3 files changed, 66 insertions(+), 21 deletions(-) diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Predictor.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Predictor.scala index c0e66a077e0b4..8a6b2040fd786 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Predictor.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Predictor.scala @@ -67,24 +67,21 @@ object Predictor{ * * @tparam Self Type of the [[Predictor]] * @tparam Testing Type of the testing data - * @tparam Prediction Type of the predicted data * @return */ - implicit def fallbackPredictOperation[Self: ClassTag, Testing: ClassTag, Prediction: ClassTag] - : PredictOperation[Self, Testing, Prediction] = { - new PredictOperation[Self, Testing, Prediction] { + implicit def fallbackPredictOperation[Self: ClassTag, Testing: ClassTag] + : PredictOperation[Self, Testing, Any] = { + new PredictOperation[Self, Testing, Any] { override def predict( instance: Self, predictParameters: ParameterMap, input: DataSet[Testing]) - : DataSet[Prediction] = { + : DataSet[Any] = { val self = implicitly[ClassTag[Self]] val testing = implicitly[ClassTag[Testing]] - val prediction = implicitly[ClassTag[Prediction]] throw new RuntimeException("There is no PredictOperation defined for " + self.runtimeClass + - " which takes a DataSet[" + testing.runtimeClass + "] as input and returns a DataSet[" + - prediction.runtimeClass + "]") + " which takes a DataSet[" + testing.runtimeClass + "] as input.") } } } diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Transformer.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Transformer.scala index 02360bcfeb686..7e2c744748961 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Transformer.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Transformer.scala @@ -137,27 +137,24 @@ object Transformer{ * * @tparam Self Type of the [[Transformer]] for which the [[TransformOperation]] is defined * @tparam IN Input data type of the [[TransformOperation]] - * @tparam OUT Output data type of the [[TransformOperation]] * @return */ implicit def fallbackTransformOperation[ Self: ClassTag, - IN: ClassTag, - OUT: ClassTag] - : TransformOperation[Self, IN, OUT] = { - new TransformOperation[Self, IN, OUT] { + IN: ClassTag] + : TransformOperation[Self, IN, Any] = { + new TransformOperation[Self, IN, Any] { override def transform( instance: Self, transformParameters: ParameterMap, input: DataSet[IN]) - : DataSet[OUT] = { + : DataSet[Any] = { val self = implicitly[ClassTag[Self]] val in = implicitly[ClassTag[IN]] - val out = implicitly[ClassTag[OUT]] throw new RuntimeException("There is no TransformOperation defined for " + self.runtimeClass + " which takes a DataSet[" + in.runtimeClass + - "] as input and transforms it into a DataSet[" + out.runtimeClass + "]") + "] as input.") } } } diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/pipeline/PipelineITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/pipeline/PipelineITSuite.scala index 9909a18ef89b4..3597ec139b47c 100644 --- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/pipeline/PipelineITSuite.scala +++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/pipeline/PipelineITSuite.scala @@ -22,7 +22,8 @@ import breeze.linalg import org.apache.flink.api.common.ExecutionConfig import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer import org.apache.flink.api.scala._ -import org.apache.flink.ml.common.LabeledVector +import org.apache.flink.ml.classification.CoCoA +import org.apache.flink.ml.common.{ParameterMap, LabeledVector} import org.apache.flink.ml.math._ import org.apache.flink.ml.preprocessing.{PolynomialFeatures, StandardScaler} import org.apache.flink.ml.regression.MultipleLinearRegression @@ -86,16 +87,42 @@ class PipelineITSuite extends FlatSpec with Matchers with FlinkTestBase { val vData = List(DenseVector(1.0, 2.0, 3.0), DenseVector(2.0, 3.0, 4.0)) val vectorData = env.fromCollection(vData) + val labeledVectors = List(LabeledVector(1.0, DenseVector(1.0, 2.0)), + LabeledVector(2.0, DenseVector(2.0, 3.0)), + LabeledVector(3.0, DenseVector(3.0, 4.0))) + val labeledData = env.fromCollection(labeledVectors) + val doubles = List(1.0, 2.0, 3.0) + val doubleData = env.fromCollection(doubles) val pipeline = scaler.chainPredictor(mlr) - val exception = intercept[RuntimeException] { + val exceptionFit = intercept[RuntimeException] { pipeline.fit(vectorData) } - exception.getMessage should equal("There is no FitOperation defined for class org.apache." + + exceptionFit.getMessage should equal("There is no FitOperation defined for class org.apache." + "flink.ml.regression.MultipleLinearRegression which trains on a " + "DataSet[class org.apache.flink.ml.math.DenseVector]") + + // fit the pipeline so that the StandardScaler won't fail when predict is called on the pipeline + pipeline.fit(labeledData) + + // make sure that we have TransformOperation[StandardScaler, Double, Double] + implicit val standardScalerDoubleTransform = + new TransformOperation[StandardScaler, Double, Double] { + override def transform(instance: StandardScaler, transformParameters: ParameterMap, + input: DataSet[Double]): DataSet[Double] = { + input + } + } + + val exceptionPredict = intercept[RuntimeException] { + pipeline.predict(doubleData) + } + + exceptionPredict.getMessage should equal("There is no PredictOperation defined for class " + + "org.apache.flink.ml.regression.MultipleLinearRegression which takes a " + + "DataSet[double] as input.") } it should "throw an exception when the input data is not supported" in { @@ -109,12 +136,19 @@ class PipelineITSuite extends FlatSpec with Matchers with FlinkTestBase { val pipeline = scaler.chainTransformer(polyFeatures) - val exception = intercept[RuntimeException] { + val exceptionFit = intercept[RuntimeException] { pipeline.fit(doubleData) } - exception.getMessage should equal("There is no FitOperation defined for class org.apache." + + exceptionFit.getMessage should equal("There is no FitOperation defined for class org.apache." + "flink.ml.preprocessing.StandardScaler which trains on a DataSet[double]") + + val exceptionTransform = intercept[RuntimeException] { + pipeline.transform(doubleData) + } + + exceptionTransform.getMessage should equal("There is no TransformOperation defined for class " + + "org.apache.flink.ml.preprocessing.StandardScaler which takes a DataSet[double] as input.") } it should "support multiple transformers and a predictor" in { @@ -146,4 +180,21 @@ class PipelineITSuite extends FlatSpec with Matchers with FlinkTestBase { weightVector._2 should be (1.3131727 +- 0.01) } + + it should "throw an exception when the input data is not supported by a predictor" in { + val env = ExecutionEnvironment.getExecutionEnvironment + + val data = List(1.0, 2.0, 3.0) + val doubleData = env.fromCollection(data) + + val svm = CoCoA() + + intercept[RuntimeException] { + svm.fit(doubleData) + } + + intercept[RuntimeException] { + svm.predict(doubleData) + } + } }