From 54c03de0792223abd16bf9789e7310fbf34eef3f Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Fri, 5 Jun 2015 13:25:26 -0700 Subject: [PATCH 1/6] [SPARK-7690][ml][WIP] Multiclass Evaluator for ML Pipeline --- .../MulticlassClassificationEvaluator.scala | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala new file mode 100644 index 0000000000000..402900c1a962d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.{ParamValidators, Param} +import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.types.DoubleType + +/** + * :: Experimental :: + * Evaluator for binary classification, which expects two input columns: score and label. 
+ */ +@Experimental +class MulticlassClassificationEvaluator (override val uid: String) + extends Evaluator with HasPredictionCol with HasLabelCol { + + def this() = this(Identifiable.randomUID("mcEval")) + + /** + * param for metric name in evaluation (supports `"f1"` (default)) + * @group param + */ + val metricName: Param[String] = { + val allowedParams = ParamValidators.inArray(Array("f1")) + new Param(this, "metricName", "metric name in evaluation (f1)", allowedParams) + } + + /** @group getParam */ + def getMetricName: String = $(metricName) + + /** @group setParam */ + def setMetricName(value: String): this.type = set(metricName, value) + + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + setDefault(metricName -> "f1") + + override def evaluate(dataset: DataFrame): Double = { + val schema = dataset.schema + SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + + val predictionAndLabels = dataset.select($(predictionCol), $(labelCol)) + .map { case Row(prediction: Double, label: Double) => + (prediction, label) + } + val metrics = new MulticlassMetrics(predictionAndLabels) + val metric = $(metricName) match { + case "f1" => metrics.weightedFMeasure + } + metric + } + +} From eec9865f46a059d1c3a8a5b92a886385da675a27 Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Wed, 29 Jul 2015 22:41:51 -0700 Subject: [PATCH 2/6] Fix Python Indentation --- python/pyspark/ml/evaluation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 66f3a18567a6b..38510c98f2e35 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -229,8 +229,7 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio 2.842... """ # a placeholder to make it appear in the generated doc - metricName = Param(Params._dummy(), "metricName", - "metric name in evaluation (f1)") + metricName = Param(Params._dummy(), "metricName", "metric name in evaluation (f1)") @keyword_only def __init__(self, predictionCol="prediction", labelCol="label", From 032d2a3ccf443b25b14690deb59098b79b41a90e Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Thu, 30 Jul 2015 08:28:02 -0700 Subject: [PATCH 3/6] fix test --- python/pyspark/ml/evaluation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 38510c98f2e35..1de9318446c72 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -220,13 +220,13 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio """ Evaluator for Multiclass Classification, which expects two input columns: prediction and label. - >>> scoreAndLabels = [(-28.98343821, -27.0), (20.21491975, 21.5), - ... (-25.98418959, -22.0), (30.69731842, 33.0), (74.69283752, 71.0)] + >>> scoreAndLabels = [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), + ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)] >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["raw", "label"]) ... >>> evaluator = MulticlassClassificationEvaluator(predictionCol="raw") >>> evaluator.evaluate(dataset) - 2.842... + 0.66... 
""" # a placeholder to make it appear in the generated doc metricName = Param(Params._dummy(), "metricName", "metric name in evaluation (f1)") From 16115ae2e71edae1ebde8757129d75f329282850 Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Thu, 30 Jul 2015 19:20:52 -0700 Subject: [PATCH 4/6] code review fixes --- .../MulticlassClassificationEvaluator.scala | 10 +++++-- ...lticlassClassificationEvaluatorSuite.scala | 28 +++++++++++++++++++ python/pyspark/ml/evaluation.py | 16 +++++++---- 3 files changed, 47 insertions(+), 7 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index 941ebb86bf275..1e2312b250e9c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -40,8 +40,10 @@ class MulticlassClassificationEvaluator (override val uid: String) * @group param */ val metricName: Param[String] = { - val allowedParams = ParamValidators.inArray(Array("f1")) - new Param(this, "metricName", "metric name in evaluation (f1)", allowedParams) + val allowedParams = ParamValidators.inArray(Array("f1", "precision", + "recall", "weightedPrecision", "weightedRecall")) + new Param(this, "metricName", "metric name in evaluation " + + "(f1|precision|recall|weightedPrecision|weightedRecall)", allowedParams) } /** @group getParam */ @@ -70,6 +72,10 @@ class MulticlassClassificationEvaluator (override val uid: String) val metrics = new MulticlassMetrics(predictionAndLabels) val metric = $(metricName) match { case "f1" => metrics.weightedFMeasure + case "precision" => metrics.precision + case "recall" => metrics.recall + case "weightedPrecision" => metrics.weightedPrecision + case "weightedRecall" => metrics.weightedRecall } metric } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala new file mode 100644 index 0000000000000..6d8412b0b3701 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite + +class MulticlassClassificationEvaluatorSuite extends SparkFunSuite { + + test("params") { + ParamsSuite.checkParams(new MulticlassClassificationEvaluator) + } +} diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 1de9318446c72..fa988ebd4880d 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -222,14 +222,19 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio columns: prediction and label. >>> scoreAndLabels = [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)] - >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["raw", "label"]) + >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["prediction", "label"]) ... - >>> evaluator = MulticlassClassificationEvaluator(predictionCol="raw") + >>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction") >>> evaluator.evaluate(dataset) 0.66... + >>> evaluator.evaluate(dataset, {evaluator.metricName: "precision"}) + 0.66... + >>> evaluator.evaluate(dataset, {evaluator.metricName: "recall"}) + 0.66... """ # a placeholder to make it appear in the generated doc - metricName = Param(Params._dummy(), "metricName", "metric name in evaluation (f1)") + metricName = Param(Params._dummy(), "metricName", + "metric name in evaluation (f1|precision|recall|weightedPrecision|weightedRecall)") @keyword_only def __init__(self, predictionCol="prediction", labelCol="label", @@ -241,9 +246,10 @@ def __init__(self, predictionCol="prediction", labelCol="label", super(MulticlassClassificationEvaluator, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid) - #: param for metric name in evaluation (f1) + # param for metric name in evaluation (f1|precision|recall|weightedPrecision|weightedRecall) self.metricName = Param(self, "metricName", - "metric name in evaluation (f1)") + "metric name in evaluation" + " (f1|precision|recall|weightedPrecision|weightedRecall)") self._setDefault(predictionCol="prediction", labelCol="label", metricName="f1") kwargs = self.__init__._input_kwargs From 3f09a85007ab2201c9f186b16c8ed8bec36ea43b Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Thu, 30 Jul 2015 19:24:48 -0700 Subject: [PATCH 5/6] cleanup doc --- .../ml/evaluation/MulticlassClassificationEvaluator.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index 1e2312b250e9c..44f779c1908d7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types.DoubleType /** * :: Experimental :: - * Evaluator for binary classification, which expects two input columns: score and label. + * Evaluator for multiclass classification, which expects two input columns: score and label. 
*/ @Experimental class MulticlassClassificationEvaluator (override val uid: String) @@ -36,7 +36,8 @@ class MulticlassClassificationEvaluator (override val uid: String) def this() = this(Identifiable.randomUID("mcEval")) /** - * param for metric name in evaluation (supports `"f1"` (default)) + * param for metric name in evaluation (supports `"f1"` (default), `"precision"`, `"recall"`, + * `"weightedPrecision"`, `"weightedRecall"`) * @group param */ val metricName: Param[String] = { From 9bf4ec75240ec60eb75ba0be514a1c29d5d68afd Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Thu, 30 Jul 2015 19:40:13 -0700 Subject: [PATCH 6/6] fix indentation --- python/pyspark/ml/evaluation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index fa988ebd4880d..06e809352225b 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -234,7 +234,8 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio """ # a placeholder to make it appear in the generated doc metricName = Param(Params._dummy(), "metricName", - "metric name in evaluation (f1|precision|recall|weightedPrecision|weightedRecall)") + "metric name in evaluation " + "(f1|precision|recall|weightedPrecision|weightedRecall)") @keyword_only def __init__(self, predictionCol="prediction", labelCol="label",
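
A minimal Scala usage sketch of the evaluator introduced by this series, mirroring the PySpark doctest above. The enclosing object, the local SparkContext/SQLContext setup, and the printed values are illustrative only and are not part of the patches; only the MulticlassClassificationEvaluator API (setPredictionCol, setLabelCol, setMetricName, evaluate) comes from the diffs.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

object MulticlassEvaluatorExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("mc-eval-example").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Same (prediction, label) pairs as the PySpark doctest in this series.
    val predictionAndLabels = Seq(
      (0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
      (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0),
      (2.0, 2.0), (2.0, 0.0)).toDF("prediction", "label")

    val evaluator = new MulticlassClassificationEvaluator()
      .setPredictionCol("prediction")
      .setLabelCol("label")
      .setMetricName("f1") // default metric: weighted F-measure across classes

    println(evaluator.evaluate(predictionAndLabels)) // ~0.66, matching the doctest

    // The remaining metrics supported after patch 4.
    for (metric <- Seq("precision", "recall", "weightedPrecision", "weightedRecall")) {
      println(s"$metric = ${evaluator.setMetricName(metric).evaluate(predictionAndLabels)}")
    }

    sc.stop()
  }
}

As in the Scala param declaration, setMetricName accepts only "f1" (the default), "precision", "recall", "weightedPrecision", and "weightedRecall"; any other value is rejected by the ParamValidators.inArray check added in patch 4.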