From 00fd2c8064d407eb5b47da583362fbd6c6d441cd Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Mon, 30 May 2016 16:39:06 +0800 Subject: [PATCH] create pr --- .../org/apache/spark/ml/clustering/LDA.scala | 3 +- .../ml/regression/LinearRegression.scala | 1 - ...lticlassClassificationEvaluatorSuite.scala | 48 +++++++++++++++++++ 3 files changed, 49 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index ec60991af64ff..23960179ca72c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -31,8 +31,7 @@ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedL LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, OnlineLDAOptimizer => OldOnlineLDAOptimizer} import org.apache.spark.mllib.impl.PeriodicCheckpointer -import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Vector => OldVector, - Vectors => OldVectors} +import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.MatrixImplicits._ import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 6be2584785bd3..062902d63a7a7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -36,7 +36,6 @@ import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.evaluation.RegressionMetrics -import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala index 522f6675d7f46..fa7f05aa8075d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ class MulticlassClassificationEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -40,4 +41,51 @@ class MulticlassClassificationEvaluatorSuite test("should support all NumericType labels and not support other types") { MLTestingUtils.checkNumericTypes(new MulticlassClassificationEvaluator, spark) } + + test("Multiclass Classification Evaluator") { + val labelAndScores = spark.createDataFrame(Seq( + (0.0, 0.0), + (1.0, 1.0), + (2.0, 2.0), + (1.0, 2.0), + (0.0, 2.0))).toDF("label", "prediction") + + /** + * Using the following Python code to evaluate metrics. + * + * > from sklearn.metrics import * + * > y_true = [0, 1, 2, 1, 0] + * > y_pred = [0, 1, 2, 2, 2] + * > f1 = f1_score(y_true, y_pred, average='weighted') + * > precision = precision_score(y_true, y_pred, average='micro') + * > recall = recall_score(y_true, y_pred, average='micro') + * > weighted_precision = precision_score(y_true, y_pred, average='weighted') + * > weighted_recall = recall_score(y_true, y_pred, average='weighted') + * > accuracy = accuracy_score(y_true, y_pred) + */ + + // default = weighted f1 + val evaluator = new MulticlassClassificationEvaluator() + assert(evaluator.evaluate(labelAndScores) ~== 0.633333 absTol 0.01) + + // micro precision + evaluator.setMetricName("precision") + assert(evaluator.evaluate(labelAndScores) ~== 0.6 absTol 0.01) + + // micro recall + evaluator.setMetricName("recall") + assert(evaluator.evaluate(labelAndScores) ~== 0.6 absTol 0.01) + + // weighted precision + evaluator.setMetricName("weightedPrecision") + assert(evaluator.evaluate(labelAndScores) ~== 0.866667 absTol 0.01) + + // weighted recall + evaluator.setMetricName("weightedRecall") + assert(evaluator.evaluate(labelAndScores) ~== 0.6 absTol 0.01) + + // accuracy + evaluator.setMetricName("accuracy") + assert(evaluator.evaluate(labelAndScores) ~== 0.6 absTol 0.01) + } }