-
Notifications
You must be signed in to change notification settings - Fork 28.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-7333][MLLIB] Add BinaryClassificationEvaluator to PySpark #5885
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||
# contributor license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright ownership. | ||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||
# (the "License"); you may not use this file except in compliance with | ||
# the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
from pyspark.ml.wrapper import JavaEvaluator | ||
from pyspark.ml.param import Param, Params | ||
from pyspark.ml.param.shared import HasLabelCol, HasRawPredictionCol | ||
from pyspark.ml.util import keyword_only | ||
from pyspark.mllib.common import inherit_doc | ||
|
||
__all__ = ['BinaryClassificationEvaluator'] | ||
|
||
|
||
@inherit_doc
class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol):
    """
    Evaluator for binary classification, which expects two input
    columns: rawPrediction and label.

    >>> from pyspark.mllib.linalg import Vectors
    >>> scoreAndLabels = sc.parallelize([
    ...     (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)])
    >>> rawPredictionAndLabels = scoreAndLabels.map(
    ...     lambda x: (Vectors.dense([1.0 - x[0], x[0]]), x[1]))
    >>> dataset = rawPredictionAndLabels.toDF(["raw", "label"])
    >>> evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw")
    >>> evaluator.evaluate(dataset)
    0.70...
    >>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"})
    0.83...
    """

    # Fully-qualified name of the JVM-side evaluator this Python class wraps;
    # JavaEvaluator uses it to instantiate the Scala object.
    _java_class = "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator"

    # a placeholder to make it appear in the generated doc
    metricName = Param(Params._dummy(), "metricName",
                       "metric name in evaluation (areaUnderROC|areaUnderPR)")

    @keyword_only
    def __init__(self, rawPredictionCol="rawPrediction", labelCol="label",
                 metricName="areaUnderROC"):
        """
        __init__(self, rawPredictionCol="rawPrediction", labelCol="label", \
                 metricName="areaUnderROC")
        """
        super(BinaryClassificationEvaluator, self).__init__()
        #: param for metric name in evaluation (areaUnderROC|areaUnderPR)
        self.metricName = Param(self, "metricName",
                                "metric name in evaluation (areaUnderROC|areaUnderPR)")
        self._setDefault(rawPredictionCol="rawPrediction", labelCol="label",
                         metricName="areaUnderROC")
        # `_input_kwargs` is attached by the @keyword_only decorator: it holds
        # exactly the keyword arguments the caller passed, so only those
        # params are explicitly set (the rest keep their defaults).
        kwargs = self.__init__._input_kwargs
        self._set(**kwargs)

    def setMetricName(self, value):
        """
        Sets the value of :py:attr:`metricName`.
        """
        self.paramMap[self.metricName] = value
        return self

    def getMetricName(self):
        """
        Gets the value of metricName or its default value.
        """
        return self.getOrDefault(self.metricName)

    @keyword_only
    def setParams(self, rawPredictionCol="rawPrediction", labelCol="label",
                  metricName="areaUnderROC"):
        """
        setParams(self, rawPredictionCol="rawPrediction", labelCol="label", \
                  metricName="areaUnderROC")
        Sets params for binary classification evaluator.
        """
        # Same @keyword_only mechanism as __init__: only explicitly-passed
        # keyword arguments are applied.
        kwargs = self.setParams._input_kwargs
        return self._set(**kwargs)
|
||
|
||
if __name__ == "__main__":
    import doctest
    from pyspark.context import SparkContext
    from pyspark.sql import SQLContext
    globs = globals().copy()
    # The doctests above reference `sc` and `sqlContext`, so provide a
    # dedicated local SparkContext/SQLContext pair for a standalone run.
    sc = SparkContext("local[2]", "ml.evaluation tests")
    sqlContext = SQLContext(sc)
    globs['sc'] = sc
    globs['sqlContext'] = sqlContext
    try:
        (failure_count, test_count) = doctest.testmod(
            globs=globs, optionflags=doctest.ELLIPSIS)
    finally:
        # Always release the SparkContext, even if a doctest errors out.
        sc.stop()
    if failure_count:
        exit(-1)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -165,6 +165,35 @@ def getPredictionCol(self): | |
return self.getOrDefault(self.predictionCol) | ||
|
||
|
||
class HasRawPredictionCol(Params):
    """
    Mixin for param rawPredictionCol: raw prediction column name.
    """

    # a placeholder to make it appear in the generated doc
    rawPredictionCol = Param(Params._dummy(), "rawPredictionCol", "raw prediction column name")

    def __init__(self):
        super(HasRawPredictionCol, self).__init__()
        #: param for raw prediction column name
        self.rawPredictionCol = Param(self, "rawPredictionCol", "raw prediction column name")
        # Set the default unconditionally. The generated guard it replaces
        # (`if 'rawPrediction' is not None:`) compared a string literal to
        # None and was therefore always true.
        self._setDefault(rawPredictionCol='rawPrediction')

    def setRawPredictionCol(self, value):
        """
        Sets the value of :py:attr:`rawPredictionCol`.
        """
        self.paramMap[self.rawPredictionCol] = value
        return self

    def getRawPredictionCol(self):
        """
        Gets the value of rawPredictionCol or its default value.
        """
        return self.getOrDefault(self.rawPredictionCol)
|
||
|
||
class HasInputCol(Params): | ||
""" | ||
Mixin for param inputCol: input column name. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,7 +22,7 @@ | |
from pyspark.mllib.common import inherit_doc | ||
|
||
|
||
__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel'] | ||
__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel', 'Evaluator'] | ||
|
||
|
||
@inherit_doc | ||
|
@@ -168,3 +168,24 @@ def transform(self, dataset, params={}): | |
for t in self.transformers: | ||
dataset = t.transform(dataset, paramMap) | ||
return dataset | ||
|
||
|
||
class Evaluator(object):
    """
    Abstract base class for evaluators that compute a single metric
    from a dataset of predictions. Concrete subclasses must implement
    :py:meth:`evaluate`.
    """

    __metaclass__ = ABCMeta

    @abstractmethod
    def evaluate(self, dataset, params={}):
        """
        Computes the evaluation metric for the given predictions.

        :param dataset: a dataset that contains labels/observations and
                        predictions
        :param params: an optional param map that overrides embedded
                       params
        :return: metric
        """
        raise NotImplementedError()
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,7 +20,7 @@ | |
from pyspark import SparkContext | ||
from pyspark.sql import DataFrame | ||
from pyspark.ml.param import Params | ||
from pyspark.ml.pipeline import Estimator, Transformer | ||
from pyspark.ml.pipeline import Estimator, Transformer, Evaluator | ||
from pyspark.mllib.common import inherit_doc | ||
|
||
|
||
|
@@ -147,3 +147,18 @@ def __init__(self, java_model): | |
|
||
def _java_obj(self): | ||
return self._java_model | ||
|
||
|
||
@inherit_doc
class JavaEvaluator(Evaluator, JavaWrapper):
    """
    Base class for :py:class:`Evaluator`s that wrap Java/Scala
    implementations.
    """

    __metaclass__ = ABCMeta

    def evaluate(self, dataset, params={}):
        """
        Evaluates the output on the given dataset by delegating to the
        wrapped Java/Scala evaluator, after copying this instance's
        params (plus any overrides in `params`) over to the Java side.
        """
        jvm_evaluator = self._java_obj()
        self._transfer_params_to_java(params, jvm_evaluator)
        return jvm_evaluator.evaluate(dataset._jdf, self._empty_java_param_map())
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we add a note here and in the Scala version that you can pass this the "prediction" column to get simple metrics like accuracy (precision), but that rawPrediction gives you much more info? Or maybe that belongs in the programming guide.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For binary classification, we only have `areaUnderROC` and `areaUnderPR` for now.