From cb51e6a4038ed71c8d7d153cf093941c413af8d7 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Mon, 4 May 2015 01:02:42 -0700
Subject: [PATCH] add BinaryClassificationEvaluator in PySpark

---
 python/pyspark/ml/evaluation.py               | 106 ++++++++++++++++++
 .../ml/param/_shared_params_code_gen.py       |   1 +
 python/pyspark/ml/param/shared.py             |  29 +++++
 python/pyspark/ml/pipeline.py                 |  22 +++-
 python/pyspark/ml/wrapper.py                  |  17 ++-
 python/run-tests                              |   1 +
 6 files changed, 174 insertions(+), 2 deletions(-)
 create mode 100644 python/pyspark/ml/evaluation.py

diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
new file mode 100644
index 0000000000000..35aea39d5cb18
--- /dev/null
+++ b/python/pyspark/ml/evaluation.py
@@ -0,0 +1,106 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.ml.wrapper import JavaEvaluator
+from pyspark.ml.param import Param, Params
+from pyspark.ml.param.shared import HasLabelCol, HasRawPredictionCol
+from pyspark.ml.util import keyword_only
+
+__all__ = ['BinaryClassificationEvaluator']
+
+
+class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol):
+    """
+    Evaluator for binary classification, which expects two input
+    columns: rawPrediction and label.
+
+    >>> from pyspark.mllib.linalg import Vectors
+    >>> scoreAndLabels = sc.parallelize([
+    ...     (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)])
+    >>> rawPredictionAndLabels = scoreAndLabels.map(
+    ...     lambda x: (Vectors.dense([1.0 - x[0], x[0]]), x[1]))
+    >>> dataset = rawPredictionAndLabels.toDF(["raw", "label"])
+    >>> evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw")
+    >>> evaluator.evaluate(dataset)
+    0.70...
+    >>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"})
+    0.83...
+    """
+
+    _java_class = "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator"
+
+    # a placeholder to make it appear in the generated doc
+    metricName = Param(Params._dummy(), "metricName",
+                       "metric name in evaluation (areaUnderROC|areaUnderPR)")
+
+    @keyword_only
+    def __init__(self, rawPredictionCol="rawPrediction", labelCol="label",
+                 metricName="areaUnderROC"):
+        """
+        __init__(self, rawPredictionCol="rawPrediction", labelCol="label",
+                 metricName="areaUnderROC")
+        """
+        super(BinaryClassificationEvaluator, self).__init__()
+        #: param for metric name in evaluation (areaUnderROC|areaUnderPR)
+        self.metricName = Param(self, "metricName",
+                                "metric name in evaluation (areaUnderROC|areaUnderPR)")
+        self._setDefault(rawPredictionCol="rawPrediction", labelCol="label",
+                         metricName="areaUnderROC")
+        kwargs = self.__init__._input_kwargs
+        self._set(**kwargs)
+
+    def setMetricName(self, value):
+        """
+        Sets the value of :py:attr:`metricName`.
+ """ + self.paramMap[self.metricName] = value + return self + + def getMetricName(self): + """ + Gets the value of metricName or its default value. + """ + return self.getOrDefault(self.metricName) + + @keyword_only + def setParams(self, rawPredictionCol="rawPrediction", labelCol="label", + metricName="areaUnderROC"): + """ + setParams(self, rawPredictionCol="rawPrediction", labelCol="label", + metricName="areaUnderROC") + Sets params for binary classification evaluator. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + +if __name__ == "__main__": + import doctest + from pyspark.context import SparkContext + from pyspark.sql import SQLContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + sc = SparkContext("local[2]", "ml.feature tests") + sqlContext = SQLContext(sc) + globs['sc'] = sc + globs['sqlContext'] = sqlContext + (failure_count, test_count) = doctest.testmod( + globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + if failure_count: + exit(-1) diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 6a3192465d66d..c71c823db2c81 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -93,6 +93,7 @@ def get$Name(self): ("featuresCol", "features column name", "'features'"), ("labelCol", "label column name", "'label'"), ("predictionCol", "prediction column name", "'prediction'"), + ("rawPredictionCol", "raw prediction column name", "'rawPrediction'"), ("inputCol", "input column name", None), ("outputCol", "output column name", None), ("numFeatures", "number of features", None)] diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 13b6749998ad0..4f243844f8caa 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -165,6 +165,35 @@ def getPredictionCol(self): return self.getOrDefault(self.predictionCol) +class HasRawPredictionCol(Params): + """ + Mixin for param rawPredictionCol: raw prediction column name. + """ + + # a placeholder to make it appear in the generated doc + rawPredictionCol = Param(Params._dummy(), "rawPredictionCol", "raw prediction column name") + + def __init__(self): + super(HasRawPredictionCol, self).__init__() + #: param for raw prediction column name + self.rawPredictionCol = Param(self, "rawPredictionCol", "raw prediction column name") + if 'rawPrediction' is not None: + self._setDefault(rawPredictionCol='rawPrediction') + + def setRawPredictionCol(self, value): + """ + Sets the value of :py:attr:`rawPredictionCol`. + """ + self.paramMap[self.rawPredictionCol] = value + return self + + def getRawPredictionCol(self): + """ + Gets the value of rawPredictionCol or its default value. + """ + return self.getOrDefault(self.rawPredictionCol) + + class HasInputCol(Params): """ Mixin for param inputCol: input column name. 
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 7c1ec3026da6f..69a5f21f53996 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -22,7 +22,7 @@
 from pyspark.mllib.common import inherit_doc
 
 
-__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel']
+__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel', 'Evaluator']
 
 
 @inherit_doc
@@ -168,3 +168,23 @@ def transform(self, dataset, params={}):
         for t in self.transformers:
             dataset = t.transform(dataset, paramMap)
         return dataset
+
+
+class Evaluator(object):
+    """
+    Base class for evaluators that compute metrics from predictions.
+    """
+
+    __metaclass__ = ABCMeta
+
+    @abstractmethod
+    def evaluate(self, dataset, params={}):
+        """
+        Evaluates the output.
+        :param dataset: a dataset that contains labels/observations
+               and predictions
+        :param params: an optional param map that overrides embedded
+               params
+        :return: metric
+        """
+        raise NotImplementedError()

diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 394f23c5e9b12..73741c4b40dfb 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -20,7 +20,7 @@
 from pyspark import SparkContext
 from pyspark.sql import DataFrame
 from pyspark.ml.param import Params
-from pyspark.ml.pipeline import Estimator, Transformer
+from pyspark.ml.pipeline import Estimator, Transformer, Evaluator
 from pyspark.mllib.common import inherit_doc
 
 
@@ -147,3 +147,18 @@ def __init__(self, java_model):
 
     def _java_obj(self):
         return self._java_model
+
+
+@inherit_doc
+class JavaEvaluator(Evaluator, JavaWrapper):
+    """
+    Base class for :py:class:`Evaluator`s that wrap Java/Scala
+    implementations.
+    """
+
+    __metaclass__ = ABCMeta
+
+    def evaluate(self, dataset, params={}):
+        java_obj = self._java_obj()
+        self._transfer_params_to_java(params, java_obj)
+        return java_obj.evaluate(dataset._jdf, self._empty_java_param_map())

diff --git a/python/run-tests b/python/run-tests
index 0e0eee3564e7c..f9ca26467f17e 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -100,6 +100,7 @@ function run_ml_tests() {
     run_test "pyspark/ml/classification.py"
     run_test "pyspark/ml/tuning.py"
    run_test "pyspark/ml/tests.py"
+    run_test "pyspark/ml/evaluation.py"
 }
 
 function run_streaming_tests() {
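
Reviewer note (not part of the patch): a minimal end-to-end sketch of how the
new evaluator is meant to be used. It assumes a live SparkContext `sc`, as in
the doctest above; `LogisticRegression` and the tiny training set are purely
illustrative, not something this change touches.

    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.evaluation import BinaryClassificationEvaluator
    from pyspark.mllib.linalg import Vectors

    # A made-up training set of (features, label) rows.
    training = sc.parallelize([
        (Vectors.dense([0.0]), 0.0),
        (Vectors.dense([1.0]), 0.0),
        (Vectors.dense([2.0]), 1.0),
        (Vectors.dense([3.0]), 1.0)]).toDF(["features", "label"])

    model = LogisticRegression(maxIter=5).fit(training)
    # transform() appends a "rawPrediction" column, which matches the
    # evaluator's default rawPredictionCol, so no extra wiring is needed.
    predictions = model.transform(training)

    evaluator = BinaryClassificationEvaluator()
    auc = evaluator.evaluate(predictions)  # areaUnderROC by default
    aupr = evaluator.evaluate(predictions,
                              {evaluator.metricName: "areaUnderPR"})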
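A second point of context (again, not part of the patch): because pipeline.py
now exports Evaluator as a public abstract class, a metric could in principle
be implemented in pure Python by overriding evaluate(). The class below is a
hypothetical illustration of that contract, not an API this change adds.

    from pyspark.ml.pipeline import Evaluator

    class AccuracyEvaluator(Evaluator):
        """Hypothetical pure-Python evaluator computing simple accuracy."""

        def evaluate(self, dataset, params={}):
            # Assumes the dataset has "prediction" and "label" columns.
            total = dataset.count()
            correct = dataset.filter("prediction = label").count()
            return float(correct) / total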