[SPARK-7333] [MLLIB] Add BinaryClassificationEvaluator to PySpark
This PR adds `BinaryClassificationEvaluator` to the Python ML Pipelines API, as a simple wrapper around the Scala implementation. cc oefirouz

Author: Xiangrui Meng <meng@databricks.com>

Closes #5885 from mengxr/SPARK-7333 and squashes the following commits:

25d7451 [Xiangrui Meng] fix tests in python 3
babdde7 [Xiangrui Meng] fix doc
cb51e6a [Xiangrui Meng] add BinaryClassificationEvaluator in PySpark

(cherry picked from commit ee374e8)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
mengxr committed May 5, 2015
1 parent 598902b commit dfb6bfc
Showing 8 changed files with 193 additions and 3 deletions.
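
At a glance, the new evaluator is used like this (a sketch mirroring the doctest shipped in this change; it assumes a live SQLContext bound to `sqlContext`):

    from pyspark.ml.evaluation import BinaryClassificationEvaluator
    from pyspark.mllib.linalg import Vectors

    # Each row pairs a rawPrediction vector [P(label 0), P(label 1)] with a label.
    data = [(Vectors.dense([0.9, 0.1]), 0.0), (Vectors.dense([0.2, 0.8]), 1.0)]
    dataset = sqlContext.createDataFrame(data, ["raw", "label"])

    evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw")
    auc = evaluator.evaluate(dataset)  # computes areaUnderROC by default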
16 changes: 16 additions & 0 deletions python/docs/pyspark.ml.rst
@@ -24,3 +24,19 @@ pyspark.ml.classification module
    :members:
    :undoc-members:
    :inherited-members:

pyspark.ml.tuning module
--------------------------------

.. automodule:: pyspark.ml.tuning
    :members:
    :undoc-members:
    :inherited-members:

pyspark.ml.evaluation module
--------------------------------

.. automodule:: pyspark.ml.evaluation
    :members:
    :undoc-members:
    :inherited-members:
107 changes: 107 additions & 0 deletions python/pyspark/ml/evaluation.py
@@ -0,0 +1,107 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.ml.wrapper import JavaEvaluator
from pyspark.ml.param import Param, Params
from pyspark.ml.param.shared import HasLabelCol, HasRawPredictionCol
from pyspark.ml.util import keyword_only
from pyspark.mllib.common import inherit_doc

__all__ = ['BinaryClassificationEvaluator']


@inherit_doc
class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol):
"""
Evaluator for binary classification, which expects two input
columns: rawPrediction and label.
>>> from pyspark.mllib.linalg import Vectors
>>> scoreAndLabels = map(lambda x: (Vectors.dense([1.0 - x[0], x[0]]), x[1]),
... [(0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)])
>>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["raw", "label"])
...
>>> evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw")
>>> evaluator.evaluate(dataset)
0.70...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"})
0.83...
"""

_java_class = "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator"

# a placeholder to make it appear in the generated doc
metricName = Param(Params._dummy(), "metricName",
"metric name in evaluation (areaUnderROC|areaUnderPR)")

@keyword_only
def __init__(self, rawPredictionCol="rawPrediction", labelCol="label",
metricName="areaUnderROC"):
"""
__init__(self, rawPredictionCol="rawPrediction", labelCol="label", \
metricName="areaUnderROC")
"""
super(BinaryClassificationEvaluator, self).__init__()
#: param for metric name in evaluation (areaUnderROC|areaUnderPR)
self.metricName = Param(self, "metricName",
"metric name in evaluation (areaUnderROC|areaUnderPR)")
self._setDefault(rawPredictionCol="rawPrediction", labelCol="label",
metricName="areaUnderROC")
kwargs = self.__init__._input_kwargs
self._set(**kwargs)

def setMetricName(self, value):
"""
Sets the value of :py:attr:`metricName`.
"""
self.paramMap[self.metricName] = value
return self

def getMetricName(self):
"""
Gets the value of metricName or its default value.
"""
return self.getOrDefault(self.metricName)

@keyword_only
def setParams(self, rawPredictionCol="rawPrediction", labelCol="label",
metricName="areaUnderROC"):
"""
setParams(self, rawPredictionCol="rawPrediction", labelCol="label", \
metricName="areaUnderROC")
Sets params for binary classification evaluator.
"""
kwargs = self.setParams._input_kwargs
return self._set(**kwargs)


if __name__ == "__main__":
import doctest
from pyspark.context import SparkContext
from pyspark.sql import SQLContext
globs = globals().copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
sc = SparkContext("local[2]", "ml.evaluation tests")
sqlContext = SQLContext(sc)
globs['sc'] = sc
globs['sqlContext'] = sqlContext
(failure_count, test_count) = doctest.testmod(
globs=globs, optionflags=doctest.ELLIPSIS)
sc.stop()
if failure_count:
exit(-1)
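
The API above gives two ways to switch metrics: `setMetricName` updates the param stored on the evaluator, while the param-map argument to `evaluate` overrides it for one call only. A sketch reusing `evaluator` and `dataset` from the doctest:

    evaluator.setMetricName("areaUnderPR")  # persists on the evaluator
    pr = evaluator.evaluate(dataset)
    # one-off override; the stored metricName is left untouched
    roc = evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderROC"})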
1 change: 1 addition & 0 deletions python/pyspark/ml/param/_shared_params_code_gen.py
@@ -93,6 +93,7 @@ def get$Name(self):
("featuresCol", "features column name", "'features'"),
("labelCol", "label column name", "'label'"),
("predictionCol", "prediction column name", "'prediction'"),
("rawPredictionCol", "raw prediction column name", "'rawPrediction'"),
("inputCol", "input column name", None),
("outputCol", "output column name", None),
("numFeatures", "number of features", None)]
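
This one-line tuple is sufficient because shared.py below is generated from this list: each (name, doc, defaultValueStr) entry is substituted into a class template. A simplified, hypothetical rendition of that expansion (the real template in this file also emits the __init__, setter, and getter bodies):

    name, doc, default = "rawPredictionCol", "raw prediction column name", "'rawPrediction'"
    Name = name[0].upper() + name[1:]  # -> "RawPredictionCol"
    print("class Has%s(Params):" % Name)
    print('    """Mixin for param %s: %s."""' % (name, doc))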
29 changes: 29 additions & 0 deletions python/pyspark/ml/param/shared.py
@@ -165,6 +165,35 @@ def getPredictionCol(self):
        return self.getOrDefault(self.predictionCol)


class HasRawPredictionCol(Params):
    """
    Mixin for param rawPredictionCol: raw prediction column name.
    """

    # a placeholder to make it appear in the generated doc
    rawPredictionCol = Param(Params._dummy(), "rawPredictionCol", "raw prediction column name")

    def __init__(self):
        super(HasRawPredictionCol, self).__init__()
        #: param for raw prediction column name
        self.rawPredictionCol = Param(self, "rawPredictionCol", "raw prediction column name")
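        # NB: the always-true guard below ("'rawPrediction' is not None") is
        # emitted verbatim by _shared_params_code_gen.py, which splices the
        # default-value string into an "is not None" check at generation time.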
        if 'rawPrediction' is not None:
            self._setDefault(rawPredictionCol='rawPrediction')

    def setRawPredictionCol(self, value):
        """
        Sets the value of :py:attr:`rawPredictionCol`.
        """
        self.paramMap[self.rawPredictionCol] = value
        return self

    def getRawPredictionCol(self):
        """
        Gets the value of rawPredictionCol or its default value.
        """
        return self.getOrDefault(self.rawPredictionCol)


class HasInputCol(Params):
    """
    Mixin for param inputCol: input column name.
23 changes: 22 additions & 1 deletion python/pyspark/ml/pipeline.py
@@ -22,7 +22,7 @@
from pyspark.mllib.common import inherit_doc


__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel']
__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel', 'Evaluator']


@inherit_doc
@@ -168,3 +168,24 @@ def transform(self, dataset, params={}):
        for t in self.transformers:
            dataset = t.transform(dataset, paramMap)
        return dataset


class Evaluator(object):
    """
    Base class for evaluators that compute metrics from predictions.
    """

    __metaclass__ = ABCMeta

    @abstractmethod
    def evaluate(self, dataset, params={}):
        """
        Evaluates the output.

        :param dataset: a dataset that contains labels/observations and
                        predictions
        :param params: an optional param map that overrides embedded
                       params
        :return: metric
        """
        raise NotImplementedError()
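
Since `Evaluator` is plain Python with no required JVM counterpart, a custom metric can plug into the same machinery. A minimal sketch, not part of this commit — `AccuracyEvaluator` and its hard-coded column names are hypothetical:

    class AccuracyEvaluator(Evaluator):
        def evaluate(self, dataset, params={}):
            # Fraction of rows whose prediction matches the label; assumes
            # the dataset has "prediction" and "label" columns.
            correct = dataset.filter(dataset.prediction == dataset.label).count()
            return float(correct) / dataset.count()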
17 changes: 16 additions & 1 deletion python/pyspark/ml/wrapper.py
@@ -20,7 +20,7 @@
from pyspark import SparkContext
from pyspark.sql import DataFrame
from pyspark.ml.param import Params
from pyspark.ml.pipeline import Estimator, Transformer
from pyspark.ml.pipeline import Estimator, Transformer, Evaluator
from pyspark.mllib.common import inherit_doc


@@ -147,3 +147,18 @@ def __init__(self, java_model):

    def _java_obj(self):
        return self._java_model


@inherit_doc
class JavaEvaluator(Evaluator, JavaWrapper):
    """
    Base class for :py:class:`Evaluator` subclasses that wrap
    Java/Scala implementations.
    """

    __metaclass__ = ABCMeta

    def evaluate(self, dataset, params={}):
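        # Push this evaluator's params, merged with any per-call overrides in
        # `params`, down to the wrapped Java object, then delegate to it; the
        # Java-side call receives an empty param map because the overrides
        # have already been applied.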
        java_obj = self._java_obj()
        self._transfer_params_to_java(params, java_obj)
        return java_obj.evaluate(dataset._jdf, self._empty_java_param_map())
2 changes: 1 addition & 1 deletion python/pyspark/sql/_types.py
@@ -652,7 +652,7 @@ def _python_to_sql_converter(dataType):

    if isinstance(dataType, StructType):
        names, types = zip(*[(f.name, f.dataType) for f in dataType.fields])
        converters = [_python_to_sql_converter(t) for t in types]

        def converter(obj):
            if isinstance(obj, dict):
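
Context for this one-liner: in Python 3, `map` returns a lazy, one-shot iterator instead of a list, so later indexing into `converters` or iterating it twice would break. A quick illustration:

    converters = map(str, [1, 2, 3])  # Python 3: a lazy iterator, not a list
    list(converters)                  # ['1', '2', '3'] -- consumes it
    list(converters)                  # [] -- already exhausted
    # converters[0] raises TypeError; the list comprehension avoids both pitfalls.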
1 change: 1 addition & 0 deletions python/run-tests
@@ -100,6 +100,7 @@ function run_ml_tests() {
run_test "pyspark/ml/classification.py"
run_test "pyspark/ml/tuning.py"
run_test "pyspark/ml/tests.py"
run_test "pyspark/ml/evaluation.py"
}

function run_streaming_tests() {
